main
1import json
2from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
3from datetime import datetime, timezone
4from typing_extensions import Literal, Annotated, TypeAliasType
5
6import pytest
7import pydantic
8from pydantic import Field
9
10from openai._utils import PropertyInfo
11from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
12from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
13
14
# Minimal one-field model reused by the construct() tests in this module.
class BasicModel(BaseModel):
    foo: str
17
18
@pytest.mark.parametrize("value", ["hello", 1], ids=["correct type", "mismatched"])
def test_basic(value: object) -> None:
    """`construct()` stores the value it is given, whether or not it matches the annotation."""
    model = BasicModel.construct(foo=value)
    assert model.foo == value
23
24
def test_directly_nested_model() -> None:
    """A dict given for a model-typed field becomes that model; a non-dict value is kept as-is."""

    class NestedModel(BaseModel):
        nested: BasicModel

    built = NestedModel.construct(nested={"foo": "Foo!"})
    assert built.nested.foo == "Foo!"

    # mismatched types
    mismatched = NestedModel.construct(nested="hello!")
    assert cast(Any, mismatched.nested) == "hello!"
35
36
def test_optional_nested_model() -> None:
    """An optional model field accepts `None`, builds dicts into the model and keeps other values untouched."""

    class NestedModel(BaseModel):
        nested: Optional[BasicModel]

    with_none = NestedModel.construct(nested=None)
    assert with_none.nested is None

    with_dict = NestedModel.construct(nested={"foo": "bar"})
    assert with_dict.nested is not None
    assert with_dict.nested.foo == "bar"

    # mismatched types: the set is passed through unchanged
    with_set = NestedModel.construct(nested={"foo"})
    stored = cast(Any, with_set.nested)
    assert isinstance(stored, set)
    assert stored == {"foo"}
52
53
def test_list_nested_model() -> None:
    """Each dict in a `List[BasicModel]` field is built into a model; mismatched values are kept."""

    class NestedModel(BaseModel):
        nested: List[BasicModel]

    built = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
    assert built.nested is not None
    assert isinstance(built.nested, list)
    assert len(built.nested) == 2
    first, second = built.nested[0], built.nested[1]
    assert first.foo == "bar"
    assert second.foo == "2"

    # mismatched types
    not_a_list = NestedModel.construct(nested=True)
    assert cast(Any, not_a_list.nested) is True

    wrong_items = NestedModel.construct(nested=[False])
    assert cast(Any, wrong_items.nested) == [False]
71
72
def test_optional_list_nested_model() -> None:
    """An `Optional[List[BasicModel]]` field handles a list of dicts, `None`, and mismatched values."""

    class NestedModel(BaseModel):
        nested: Optional[List[BasicModel]]

    populated = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
    assert populated.nested is not None
    assert isinstance(populated.nested, list)
    assert len(populated.nested) == 2
    assert populated.nested[0].foo == "bar"
    assert populated.nested[1].foo == "2"

    empty = NestedModel.construct(nested=None)
    assert empty.nested is None

    # mismatched types
    with_set = NestedModel.construct(nested={1})
    assert cast(Any, with_set.nested) == {1}

    wrong_items = NestedModel.construct(nested=[False])
    assert cast(Any, wrong_items.nested) == [False]
93
94
def test_list_optional_items_nested_model() -> None:
    """A `List[Optional[BasicModel]]` field keeps `None` entries and builds dict entries into models."""

    class NestedModel(BaseModel):
        nested: List[Optional[BasicModel]]

    built = NestedModel.construct(nested=[None, {"foo": "bar"}])
    assert built.nested is not None
    assert isinstance(built.nested, list)
    assert len(built.nested) == 2
    assert built.nested[0] is None
    assert built.nested[1] is not None
    assert built.nested[1].foo == "bar"

    # mismatched types
    not_a_list = NestedModel.construct(nested="foo")
    assert cast(Any, not_a_list.nested) == "foo"

    wrong_items = NestedModel.construct(nested=[False])
    assert cast(Any, wrong_items.nested) == [False]
113
114
def test_list_mismatched_type() -> None:
    """A non-list value given for a `List[str]` field is stored unchanged."""

    class NestedModel(BaseModel):
        nested: List[str]

    model = NestedModel.construct(nested=False)
    assert cast(Any, model.nested) is False
121
122
def test_raw_dictionary() -> None:
    """A `Dict[str, str]` field keeps a matching dict, and also keeps a mismatched value."""

    class NestedModel(BaseModel):
        nested: Dict[str, str]

    payload = {"hello": "world"}
    matching = NestedModel.construct(nested=payload)
    assert matching.nested == {"hello": "world"}

    # mismatched types
    mismatched = NestedModel.construct(nested=False)
    assert cast(Any, mismatched.nested) is False
133
134
def test_nested_dictionary_model() -> None:
    """Dict values of a `Dict[str, BasicModel]` field are built into models when they are dicts."""

    class NestedModel(BaseModel):
        nested: Dict[str, BasicModel]

    built = NestedModel.construct(nested={"hello": {"foo": "bar"}})
    assert isinstance(built.nested, dict)
    assert built.nested["hello"].foo == "bar"

    # mismatched types: the non-dict entry is left as-is
    mismatched = NestedModel.construct(nested={"hello": False})
    assert cast(Any, mismatched.nested["hello"]) is False
146
147
def test_unknown_fields() -> None:
    """Undeclared fields passed to `construct()` are readable as attributes and included in the dump."""
    scalar_extra = BasicModel.construct(foo="foo", unknown=1)
    assert scalar_extra.foo == "foo"
    assert cast(Any, scalar_extra).unknown == 1

    dict_extra = BasicModel.construct(foo="foo", unknown={"foo_bar": True})
    assert dict_extra.foo == "foo"
    assert cast(Any, dict_extra).unknown == {"foo_bar": True}

    dumped = model_dump(dict_extra)
    assert dumped == {"foo": "foo", "unknown": {"foo_bar": True}}
158
159
def test_strict_validation_unknown_fields() -> None:
    """Undeclared fields survive `parse_obj()` and show up in the dump."""

    class Model(BaseModel):
        foo: str

    parsed = parse_obj(Model, {"foo": "hello!", "user": "Robert"})
    assert parsed.foo == "hello!"
    assert cast(Any, parsed).user == "Robert"

    dumped = model_dump(parsed)
    assert dumped == {"foo": "hello!", "user": "Robert"}
169
170
def test_aliases() -> None:
    """`construct()` accepts the alias name and exposes the value under the field name."""

    class Model(BaseModel):
        my_field: int = Field(alias="myField")

    matching = Model.construct(myField=1)
    assert matching.my_field == 1

    # mismatched types
    mismatched = Model.construct(myField={"hello": False})
    assert cast(Any, mismatched.my_field) == {"hello": False}
181
182
def test_repr() -> None:
    """`str()` and `repr()` give the same constructor-style text."""
    expected = "BasicModel(foo='bar')"
    model = BasicModel(foo="bar")
    assert str(model) == expected
    assert repr(model) == expected
187
188
def test_repr_nested_model() -> None:
    """A nested model is rendered inline in both `str()` and `repr()` of its parent."""

    # The class names appear in the expected text, so they must stay as they are.
    class Child(BaseModel):
        name: str
        age: int

    class Parent(BaseModel):
        name: str
        child: Child

    expected = "Parent(name='Robert', child=Child(name='Foo', age=5))"
    child = Child(name="Foo", age=5)
    parent = Parent(name="Robert", child=child)
    assert str(parent) == expected
    assert repr(parent) == expected
201
202
def test_optional_list() -> None:
    """An optional list-of-models field handles `None`, an empty list and a populated list."""

    class Submodel(BaseModel):
        name: str

    class Model(BaseModel):
        items: Optional[List[Submodel]]

    missing = Model.construct(items=None)
    assert missing.items is None

    empty = Model.construct(items=[])
    assert empty.items == []

    populated = Model.construct(items=[{"name": "Robert"}])
    assert populated.items is not None
    assert len(populated.items) == 1
    (only_item,) = populated.items
    assert only_item.name == "Robert"
220
221
def test_nested_union_of_models() -> None:
    """For a union of models, the variant whose fields match the data is the one built."""

    class Submodel1(BaseModel):
        bar: bool

    class Submodel2(BaseModel):
        thing: str

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2]

    model = Model.construct(foo={"thing": "hello"})
    variant = model.foo
    assert isinstance(variant, Submodel2)
    assert variant.thing == "hello"
235
236
def test_nested_union_of_mixed_types() -> None:
    """A union that mixes a model with literal values resolves each kind of input correctly."""

    class Submodel1(BaseModel):
        bar: bool

    class Model(BaseModel):
        foo: Union[Submodel1, Literal[True], Literal["CARD_HOLDER"]]

    as_bool = Model.construct(foo=True)
    assert as_bool.foo is True

    as_str = Model.construct(foo="CARD_HOLDER")
    assert as_str.foo == "CARD_HOLDER"

    as_model = Model.construct(foo={"bar": False})
    assert isinstance(as_model.foo, Submodel1)
    assert as_model.foo.bar is False
253
254
def test_nested_union_multiple_variants() -> None:
    """A union with several models and `None` picks the matching variant; a missing field reads as `None`."""

    class Submodel1(BaseModel):
        bar: bool

    class Submodel2(BaseModel):
        thing: str

    class Submodel3(BaseModel):
        foo: int

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2, None, Submodel3]

    second = Model.construct(foo={"thing": "hello"})
    assert isinstance(second.foo, Submodel2)
    assert second.foo.thing == "hello"

    explicit_none = Model.construct(foo=None)
    assert explicit_none.foo is None

    omitted = Model.construct()
    assert omitted.foo is None

    # the string "1" comes back as the int 1 on the third variant
    third = Model.construct(foo={"foo": "1"})
    assert isinstance(third.foo, Submodel3)
    assert third.foo.foo == 1
281
282
def test_nested_union_invalid_data() -> None:
    """Data matching no union variant is still accepted; the outcome differs between Pydantic v1 and v2."""

    class Submodel1(BaseModel):
        level: int

    class Submodel2(BaseModel):
        name: str

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2]

    scalar = Model.construct(foo=True)
    assert cast(bool, scalar.foo) is True

    wrong_field_type = Model.construct(foo={"name": 3})
    if not PYDANTIC_V1:
        # v2: the first variant is used and the raw int is kept
        assert isinstance(wrong_field_type.foo, Submodel1)
        assert wrong_field_type.foo.name == 3  # type: ignore
    else:
        # v1: the second variant is used and the int becomes a string
        assert isinstance(wrong_field_type.foo, Submodel2)
        assert wrong_field_type.foo.name == "3"
303
304
def test_list_of_unions() -> None:
    """Each list entry is resolved against the union on its own; non-matching entries are kept as-is."""

    class Submodel1(BaseModel):
        level: int

    class Submodel2(BaseModel):
        name: str

    class Model(BaseModel):
        items: List[Union[Submodel1, Submodel2]]

    valid = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
    assert len(valid.items) == 2
    head, tail = valid.items[0], valid.items[1]
    assert isinstance(head, Submodel1)
    assert head.level == 1
    assert isinstance(tail, Submodel2)
    assert tail.name == "Robert"

    partly_invalid = Model.construct(items=[{"level": -1}, 156])
    assert len(partly_invalid.items) == 2
    assert isinstance(partly_invalid.items[0], Submodel1)
    assert partly_invalid.items[0].level == -1
    assert cast(Any, partly_invalid.items[1]) == 156
327
328
def test_union_of_lists() -> None:
    """With a union of list types, a fully matching variant is used; mixed lists come back as `SubModel1` items."""

    class SubModel1(BaseModel):
        level: int

    class SubModel2(BaseModel):
        name: str

    class Model(BaseModel):
        items: Union[List[SubModel1], List[SubModel2]]

    # with one valid entry
    single = Model.construct(items=[{"name": "Robert"}])
    assert len(single.items) == 1
    assert isinstance(single.items[0], SubModel2)
    assert single.items[0].name == "Robert"

    # with two entries pointing to different types
    mixed = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
    assert len(mixed.items) == 2
    assert isinstance(mixed.items[0], SubModel1)
    assert mixed.items[0].level == 1
    assert isinstance(mixed.items[1], SubModel1)
    assert cast(Any, mixed.items[1]).name == "Robert"

    # with two entries pointing to *completely* different types
    unrelated = Model.construct(items=[{"level": -1}, 156])
    assert len(unrelated.items) == 2
    assert isinstance(unrelated.items[0], SubModel1)
    assert unrelated.items[0].level == -1
    assert cast(Any, unrelated.items[1]) == 156
359
360
def test_dict_of_union() -> None:
    """Each dict value is resolved against the union independently of the others."""

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        foo: str

    class Model(BaseModel):
        data: Dict[str, Union[SubModel1, SubModel2]]

    model = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}})
    assert len(model.data) == 2

    hello_entry = model.data["hello"]
    assert isinstance(hello_entry, SubModel1)
    assert hello_entry.name == "there"

    foo_entry = model.data["foo"]
    assert isinstance(foo_entry, SubModel2)
    assert foo_entry.foo == "bar"

    # TODO: test mismatched type
379
380
def test_double_nested_union() -> None:
    """Union resolution works for items of a list nested inside a dict."""

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        bar: str

    class Model(BaseModel):
        data: Dict[str, List[Union[SubModel1, SubModel2]]]

    model = Model.construct(data={"foo": [{"bar": "baz"}, {"name": "Robert"}]})
    entries = model.data["foo"]
    assert len(entries) == 2

    first = entries[0]
    assert isinstance(first, SubModel2)
    assert first.bar == "baz"

    second = entries[1]
    assert isinstance(second, SubModel1)
    assert second.name == "Robert"

    # TODO: test mismatched type
403
404
def test_union_of_dict() -> None:
    """With a union of dict types and values that fit neither fully, every value comes back as `SubModel1`."""

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        foo: str

    class Model(BaseModel):
        data: Union[Dict[str, SubModel1], Dict[str, SubModel2]]

    model = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}})
    assert len(model.data) == 2

    assert isinstance(model.data["hello"], SubModel1)
    assert model.data["hello"].name == "there"

    # this entry has no `name`, yet it is still a SubModel1 carrying the extra `foo`
    assert isinstance(model.data["foo"], SubModel1)
    assert cast(Any, model.data["foo"]).foo == "bar"
421
422
def test_iso8601_datetime() -> None:
    """An ISO 8601 string is turned into an aware `datetime` by both `construct()` and `parse_obj()`."""

    class Model(BaseModel):
        created_at: datetime

    raw = "2019-12-27T18:11:19.117Z"
    expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc)

    # The JSON rendering of the timestamp differs between Pydantic major versions.
    expected_json = (
        '{"created_at": "2019-12-27T18:11:19.117000+00:00"}'
        if PYDANTIC_V1
        else '{"created_at":"2019-12-27T18:11:19.117000Z"}'
    )

    constructed = Model.construct(created_at=raw)
    assert constructed.created_at == expected
    assert model_json(constructed) == expected_json

    parsed = parse_obj(Model, {"created_at": raw})
    assert parsed.created_at == expected
    assert model_json(parsed) == expected_json
441
442
def test_does_not_coerce_int() -> None:
    """Values given for an `int` field are stored unchanged, including floats, strings and bools."""

    class Model(BaseModel):
        bar: int

    build = Model.construct
    assert build(bar=1).bar == 1
    assert build(bar=10.9).bar == 10.9
    assert build(bar="19").bar == "19"  # type: ignore[comparison-overlap]
    assert build(bar=False).bar is False
451
452
def test_int_to_float_safe_conversion() -> None:
    """An int given for a `float` field becomes a float, except for `2**53 + 1`, which stays an int."""

    class Model(BaseModel):
        float_field: float

    from_int = Model.construct(float_field=10)
    assert from_int.float_field == 10.0
    assert isinstance(from_int.float_field, float)

    from_float = Model.construct(float_field=10.12)
    assert from_float.float_field == 10.12
    assert isinstance(from_float.float_field, float)

    # number too big
    too_big = 2**53 + 1
    from_big_int = Model.construct(float_field=too_big)
    assert from_big_int.float_field == too_big
    assert isinstance(from_big_int.float_field, int)
469
470
def test_deprecated_alias() -> None:
    """A field aliased to `model_id` can sit alongside a `model_id` property that returns the field."""

    class Model(BaseModel):
        resource_id: str = Field(alias="model_id")

        @property
        def model_id(self) -> str:
            return self.resource_id

    # The same expectations hold whether the model is constructed or parsed.
    constructed = Model.construct(model_id="id")
    parsed = parse_obj(Model, {"model_id": "id"})
    for instance in (constructed, parsed):
        assert instance.model_id == "id"
        assert instance.resource_id == "id"
        assert instance.resource_id is instance.model_id
488
489
def test_omitted_fields() -> None:
    """`model_fields_set` only lists fields that were passed explicitly, even when the value is `None`."""

    class Model(BaseModel):
        resource_id: Optional[str] = None

    omitted = Model.construct()
    assert omitted.resource_id is None
    assert "resource_id" not in omitted.model_fields_set

    explicit_none = Model.construct(resource_id=None)
    assert explicit_none.resource_id is None
    assert "resource_id" in explicit_none.model_fields_set

    with_value = Model.construct(resource_id="foo")
    assert with_value.resource_id == "foo"
    assert "resource_id" in with_value.model_fields_set
505
506
def test_to_dict() -> None:
    """`to_dict()` uses API names and omits unset fields by default; the exclusion flags and `mode` are honoured."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    populated = Model(FOO="hello")
    assert populated.to_dict() == {"FOO": "hello"}
    assert populated.to_dict(use_api_names=False) == {"foo": "hello"}

    unset = Model()
    assert unset.to_dict() == {}
    assert unset.to_dict(exclude_unset=False) == {"FOO": None}
    assert unset.to_dict(exclude_unset=False, exclude_none=True) == {}
    assert unset.to_dict(exclude_unset=False, exclude_defaults=True) == {}

    explicit_none = Model(FOO=None)
    assert explicit_none.to_dict() == {"FOO": None}
    assert explicit_none.to_dict(exclude_none=True) == {}
    assert explicit_none.to_dict(exclude_defaults=True) == {}

    class Model2(BaseModel):
        created_at: datetime

    time_str = "2024-03-21T11:39:01.275859"
    timestamped = Model2.construct(created_at=time_str)
    # python mode yields the datetime object, json mode yields the ISO string
    assert timestamped.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
    assert timestamped.to_dict(mode="json") == {"created_at": time_str}

    if PYDANTIC_V1:
        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            populated.to_dict(warnings=False)
537
538
def test_forwards_compat_model_dump_method() -> None:
    """`model_dump()` works on either Pydantic version; v2-only options raise `ValueError` on v1."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    populated = Model(FOO="hello")
    assert populated.model_dump() == {"foo": "hello"}
    assert populated.model_dump(include={"bar"}) == {}
    assert populated.model_dump(exclude={"foo"}) == {}
    assert populated.model_dump(by_alias=True) == {"FOO": "hello"}

    unset = Model()
    assert unset.model_dump() == {"foo": None}
    assert unset.model_dump(exclude_unset=True) == {}
    assert unset.model_dump(exclude_none=True) == {}
    assert unset.model_dump(exclude_defaults=True) == {}

    explicit_none = Model(FOO=None)
    assert explicit_none.model_dump() == {"foo": None}
    assert explicit_none.model_dump(exclude_none=True) == {}

    if PYDANTIC_V1:
        with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
            populated.model_dump(round_trip=True)

        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            populated.model_dump(warnings=False)
565
566
def test_compat_method_no_error_for_warnings() -> None:
    """The `model_dump` compat helper accepts `warnings=False` and returns a dict."""

    class Model(BaseModel):
        foo: Optional[str]

    instance = Model(foo="hello")
    dumped = model_dump(instance, warnings=False)
    assert isinstance(dumped, dict)
573
574
def test_to_json() -> None:
    """`to_json()` uses API names and omits unset fields by default; the exclusion flags are honoured."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    populated = Model(FOO="hello")
    assert json.loads(populated.to_json()) == {"FOO": "hello"}
    assert json.loads(populated.to_json(use_api_names=False)) == {"foo": "hello"}

    # The compact output has a space after the colon on Pydantic v1 only.
    compact = '{"FOO": "hello"}' if PYDANTIC_V1 else '{"FOO":"hello"}'
    assert populated.to_json(indent=None) == compact

    unset = Model()
    assert json.loads(unset.to_json()) == {}
    assert json.loads(unset.to_json(exclude_unset=False)) == {"FOO": None}
    assert json.loads(unset.to_json(exclude_unset=False, exclude_none=True)) == {}
    assert json.loads(unset.to_json(exclude_unset=False, exclude_defaults=True)) == {}

    explicit_none = Model(FOO=None)
    assert json.loads(explicit_none.to_json()) == {"FOO": None}
    assert json.loads(explicit_none.to_json(exclude_none=True)) == {}

    if PYDANTIC_V1:
        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            populated.to_json(warnings=False)
601
602
def test_forwards_compat_model_dump_json_method() -> None:
    """`model_dump_json()` works on either Pydantic version; v2-only options raise `ValueError` on v1."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    m = Model(FOO="hello")
    assert json.loads(m.model_dump_json()) == {"foo": "hello"}
    assert json.loads(m.model_dump_json(include={"bar"})) == {}
    assert json.loads(m.model_dump_json(include={"foo"})) == {"foo": "hello"}
    assert json.loads(m.model_dump_json(by_alias=True)) == {"FOO": "hello"}

    # indent=2 means nested lines are indented by exactly two spaces
    assert m.model_dump_json(indent=2) == '{\n  "foo": "hello"\n}'

    m2 = Model()
    assert json.loads(m2.model_dump_json()) == {"foo": None}
    assert json.loads(m2.model_dump_json(exclude_unset=True)) == {}
    assert json.loads(m2.model_dump_json(exclude_none=True)) == {}
    assert json.loads(m2.model_dump_json(exclude_defaults=True)) == {}

    m3 = Model(FOO=None)
    assert json.loads(m3.model_dump_json()) == {"foo": None}
    assert json.loads(m3.model_dump_json(exclude_none=True)) == {}

    if PYDANTIC_V1:
        with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
            m.model_dump_json(round_trip=True)

        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            m.model_dump_json(warnings=False)
631
632
def test_type_compat() -> None:
    """Our model type can be passed where Pydantic's `BaseModel` is expected."""

    class OurModel(BaseModel):
        foo: Optional[str] = None

    def takes_pydantic(model: pydantic.BaseModel) -> None:  # noqa: ARG001
        ...

    instance = OurModel()
    takes_pydantic(instance)
643
644
def test_annotated_types() -> None:
    """`construct_type()` builds the model wrapped in `Annotated`, whatever the metadata."""

    class Model(BaseModel):
        value: str

    annotated = cast(Any, Annotated[Model, "random metadata"])
    built = construct_type(value={"value": "foo"}, type_=annotated)
    assert isinstance(built, Model)
    assert built.value == "foo"
655
656
def test_discriminated_unions_invalid_data() -> None:
    """The discriminator picks the variant even when the rest of the payload has the wrong types."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    discriminated = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])

    as_b = construct_type(value={"type": "b", "data": "foo"}, type_=discriminated)
    assert isinstance(as_b, B)
    assert as_b.type == "b"
    assert as_b.data == "foo"  # type: ignore[comparison-overlap]

    as_a = construct_type(value={"type": "a", "data": 100}, type_=discriminated)
    assert isinstance(as_a, A)
    assert as_a.type == "a"
    if not PYDANTIC_V1:
        assert as_a.data == 100  # type: ignore[comparison-overlap]
    else:
        # pydantic v1 automatically converts inputs to strings
        # if the expected type is a str
        assert as_a.data == "100"
688
689
def test_discriminated_unions_unknown_variant() -> None:
    """An unrecognised discriminator value falls back to the first variant and keeps all data."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    m = construct_type(
        value={"type": "c", "data": None, "new_thing": "bar"},
        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
    )

    # just chooses the first variant
    assert isinstance(m, A)
    assert m.type == "c"  # type: ignore[comparison-overlap]
    # identity check: `None` is a singleton, so `is` is the correct comparison
    assert m.data is None  # type: ignore[unreachable]
    assert m.new_thing == "bar"
711
712
def test_discriminated_unions_invalid_data_nested_unions() -> None:
    """Discriminator lookup also covers variants that sit inside a nested `Union`."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    class C(BaseModel):
        type: Literal["c"]

        data: bool

    nested_union = cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")])

    as_b = construct_type(value={"type": "b", "data": "foo"}, type_=nested_union)
    assert isinstance(as_b, B)
    assert as_b.type == "b"
    assert as_b.data == "foo"  # type: ignore[comparison-overlap]

    as_c = construct_type(value={"type": "c", "data": "foo"}, type_=nested_union)
    assert isinstance(as_c, C)
    assert as_c.type == "c"
    assert as_c.data == "foo"  # type: ignore[comparison-overlap]
744
745
def test_discriminated_unions_with_aliases_invalid_data() -> None:
    """A discriminator named by field name still works when the payload uses the field's alias."""

    class A(BaseModel):
        foo_type: Literal["a"] = Field(alias="type")

        data: str

    class B(BaseModel):
        foo_type: Literal["b"] = Field(alias="type")

        data: int

    discriminated = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")])

    as_b = construct_type(value={"type": "b", "data": "foo"}, type_=discriminated)
    assert isinstance(as_b, B)
    assert as_b.foo_type == "b"
    assert as_b.data == "foo"  # type: ignore[comparison-overlap]

    as_a = construct_type(value={"type": "a", "data": 100}, type_=discriminated)
    assert isinstance(as_a, A)
    assert as_a.foo_type == "a"
    if not PYDANTIC_V1:
        assert as_a.data == 100  # type: ignore[comparison-overlap]
    else:
        # pydantic v1 automatically converts inputs to strings
        # if the expected type is a str
        assert as_a.data == "100"
777
778
def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None:
    """When two variants share a discriminator value, the later one (`B`) is the one built here."""

    class A(BaseModel):
        type: Literal["a"]

        data: bool

    class B(BaseModel):
        type: Literal["a"]

        data: int

    discriminated = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])
    built = construct_type(value={"type": "a", "data": "foo"}, type_=discriminated)

    assert isinstance(built, B)
    assert built.type == "a"
    assert built.data == "foo"  # type: ignore[comparison-overlap]
797
798
def test_discriminated_unions_invalid_data_uses_cache() -> None:
    """The first construction fills `DISCRIMINATOR_CACHE` for the union; the second leaves the same entry in place."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    UnionType = cast(Any, Union[A, B])
    payload = {"type": "b", "data": "foo"}

    # nothing is cached before the first construction
    assert not DISCRIMINATOR_CACHE.get(UnionType)

    first = construct_type(
        value=payload,
        type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]),
    )
    assert isinstance(first, B)
    assert first.type == "b"
    assert first.data == "foo"  # type: ignore[comparison-overlap]

    cached_details = DISCRIMINATOR_CACHE.get(UnionType)
    assert cached_details is not None

    second = construct_type(
        value=payload,
        type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]),
    )
    assert isinstance(second, B)
    assert second.type == "b"
    assert second.data == "foo"  # type: ignore[comparison-overlap]

    # if the discriminator details object stays the same between invocations then
    # we hit the cache
    assert DISCRIMINATOR_CACHE.get(UnionType) is cached_details
834
835
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
    """Fields annotated with a `TypeAliasType`, alone or inside a union, are constructed as the aliased type."""
    Alias = TypeAliasType("Alias", str)  # pyright: ignore

    class Model(BaseModel):
        alias: Alias
        union: Union[int, Alias]

    built = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model)
    assert isinstance(built, Model)

    assert isinstance(built.alias, str)
    assert built.alias == "foo"

    assert isinstance(built.union, str)
    assert built.union == "bar"
850
851
# NOTE(review): the skip reason below mentions TypeAliasType although this test does not
# use it; it looks copied from another test. Confirm the real v1 limitation and reword.
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_field_named_cls() -> None:
    """A model with a field literally named `cls` can be built through `construct_type()`."""

    class Model(BaseModel):
        cls: str

    built = construct_type(value={"cls": "foo"}, type_=Model)
    assert isinstance(built, Model)
    assert isinstance(built.cls, str)
860
861
def test_discriminated_union_case() -> None:
    """The outer discriminator selects `ModelB` even though the payload lacks its `required` field."""

    class A(BaseModel):
        type: Literal["a"]

        data: bool

    class B(BaseModel):
        type: Literal["b"]

        data: List[Union[A, object]]

    class ModelA(BaseModel):
        type: Literal["modelA"]

        data: int

    class ModelB(BaseModel):
        type: Literal["modelB"]

        required: str

        data: Union[A, B]

    # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
    outer_union = cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")])
    payload = {"type": "modelB", "data": {"type": "a", "data": True}}
    built = construct_type(value=payload, type_=outer_union)

    assert isinstance(built, ModelB)
892
893
def test_nested_discriminated_union() -> None:
    """A discriminated union nested in a variant of another one resolves, even with an inner field missing."""

    class InnerType1(BaseModel):
        type: Literal["type_1"]

    class InnerModel(BaseModel):
        inner_value: str

    class InnerType2(BaseModel):
        type: Literal["type_2"]
        some_inner_model: InnerModel

    class Type1(BaseModel):
        base_type: Literal["base_type_1"]
        value: Annotated[Union[InnerType1, InnerType2], PropertyInfo(discriminator="type")]

    class Type2(BaseModel):
        base_type: Literal["base_type_2"]

    T = Annotated[Union[Type1, Type2], PropertyInfo(discriminator="base_type")]

    # the inner payload deliberately omits `some_inner_model`
    payload = {
        "base_type": "base_type_1",
        "value": {
            "type": "type_2",
        },
    }
    model = construct_type(type_=T, value=payload)
    assert isinstance(model, Type1)
    assert isinstance(model.value, InnerType2)
937
938
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now")
def test_extra_properties() -> None:
    """Undeclared keys are built as the value type annotated on `__pydantic_extra__`."""

    class Item(BaseModel):
        prop: int

    class Model(BaseModel):
        __pydantic_extra__: Dict[str, Item] = Field(init=False)  # pyright: ignore[reportIncompatibleVariableOverride]

        other: str

        if TYPE_CHECKING:
            # seen by type checkers only: unknown attributes are typed as `Item`
            def __getattr__(self, attr: str) -> Item: ...

    payload = {
        "a": {"prop": 1},
        "other": "foo",
    }
    built = construct_type(type_=Model, value=payload)
    assert isinstance(built, Model)
    assert built.a.prop == 1
    assert isinstance(built.a, Item)
    assert built.other == "foo"