main
  1import json
  2from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
  3from datetime import datetime, timezone
  4from typing_extensions import Literal, Annotated, TypeAliasType
  5
  6import pytest
  7import pydantic
  8from pydantic import Field
  9
 10from openai._utils import PropertyInfo
 11from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
 12from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
 13
 14
# Minimal single-field model used as the nested/target type by many tests in this module.
class BasicModel(BaseModel):
    foo: str
 17
 18
 19@pytest.mark.parametrize("value", ["hello", 1], ids=["correct type", "mismatched"])
 20def test_basic(value: object) -> None:
 21    m = BasicModel.construct(foo=value)
 22    assert m.foo == value
 23
 24
 25def test_directly_nested_model() -> None:
 26    class NestedModel(BaseModel):
 27        nested: BasicModel
 28
 29    m = NestedModel.construct(nested={"foo": "Foo!"})
 30    assert m.nested.foo == "Foo!"
 31
 32    # mismatched types
 33    m = NestedModel.construct(nested="hello!")
 34    assert cast(Any, m.nested) == "hello!"
 35
 36
 37def test_optional_nested_model() -> None:
 38    class NestedModel(BaseModel):
 39        nested: Optional[BasicModel]
 40
 41    m1 = NestedModel.construct(nested=None)
 42    assert m1.nested is None
 43
 44    m2 = NestedModel.construct(nested={"foo": "bar"})
 45    assert m2.nested is not None
 46    assert m2.nested.foo == "bar"
 47
 48    # mismatched types
 49    m3 = NestedModel.construct(nested={"foo"})
 50    assert isinstance(cast(Any, m3.nested), set)
 51    assert cast(Any, m3.nested) == {"foo"}
 52
 53
 54def test_list_nested_model() -> None:
 55    class NestedModel(BaseModel):
 56        nested: List[BasicModel]
 57
 58    m = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
 59    assert m.nested is not None
 60    assert isinstance(m.nested, list)
 61    assert len(m.nested) == 2
 62    assert m.nested[0].foo == "bar"
 63    assert m.nested[1].foo == "2"
 64
 65    # mismatched types
 66    m = NestedModel.construct(nested=True)
 67    assert cast(Any, m.nested) is True
 68
 69    m = NestedModel.construct(nested=[False])
 70    assert cast(Any, m.nested) == [False]
 71
 72
 73def test_optional_list_nested_model() -> None:
 74    class NestedModel(BaseModel):
 75        nested: Optional[List[BasicModel]]
 76
 77    m1 = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}])
 78    assert m1.nested is not None
 79    assert isinstance(m1.nested, list)
 80    assert len(m1.nested) == 2
 81    assert m1.nested[0].foo == "bar"
 82    assert m1.nested[1].foo == "2"
 83
 84    m2 = NestedModel.construct(nested=None)
 85    assert m2.nested is None
 86
 87    # mismatched types
 88    m3 = NestedModel.construct(nested={1})
 89    assert cast(Any, m3.nested) == {1}
 90
 91    m4 = NestedModel.construct(nested=[False])
 92    assert cast(Any, m4.nested) == [False]
 93
 94
 95def test_list_optional_items_nested_model() -> None:
 96    class NestedModel(BaseModel):
 97        nested: List[Optional[BasicModel]]
 98
 99    m = NestedModel.construct(nested=[None, {"foo": "bar"}])
100    assert m.nested is not None
101    assert isinstance(m.nested, list)
102    assert len(m.nested) == 2
103    assert m.nested[0] is None
104    assert m.nested[1] is not None
105    assert m.nested[1].foo == "bar"
106
107    # mismatched types
108    m3 = NestedModel.construct(nested="foo")
109    assert cast(Any, m3.nested) == "foo"
110
111    m4 = NestedModel.construct(nested=[False])
112    assert cast(Any, m4.nested) == [False]
113
114
def test_list_mismatched_type() -> None:
    """A non-list value for a ``List[str]`` field is stored unchanged."""

    class NestedModel(BaseModel):
        nested: List[str]

    stored = cast(Any, NestedModel.construct(nested=False).nested)
    assert stored is False
121
122
def test_raw_dictionary() -> None:
    """Plain ``Dict[str, str]`` fields hold dicts as-is and keep non-dict values untouched."""

    class NestedModel(BaseModel):
        nested: Dict[str, str]

    matching = NestedModel.construct(nested={"hello": "world"})
    assert matching.nested == {"hello": "world"}

    # mismatched types
    mismatched = NestedModel.construct(nested=False)
    assert cast(Any, mismatched.nested) is False
133
134
def test_nested_dictionary_model() -> None:
    """Dict values are built into the value model; non-mapping values pass through."""

    class NestedModel(BaseModel):
        nested: Dict[str, BasicModel]

    built = NestedModel.construct(nested={"hello": {"foo": "bar"}})
    assert isinstance(built.nested, dict)
    assert built.nested["hello"].foo == "bar"

    # mismatched types
    raw = NestedModel.construct(nested={"hello": False})
    assert cast(Any, raw.nested["hello"]) is False
146
147
def test_unknown_fields() -> None:
    """Extra keyword arguments to ``construct`` are kept as attributes and included in dumps."""
    scalar_extra = BasicModel.construct(foo="foo", unknown=1)
    assert scalar_extra.foo == "foo"
    assert cast(Any, scalar_extra).unknown == 1

    nested_extra = BasicModel.construct(foo="foo", unknown={"foo_bar": True})
    assert nested_extra.foo == "foo"
    assert cast(Any, nested_extra).unknown == {"foo_bar": True}

    dumped = model_dump(nested_extra)
    assert dumped == {"foo": "foo", "unknown": {"foo_bar": True}}
158
159
def test_strict_validation_unknown_fields() -> None:
    """Validated parsing also retains keys that are not declared on the model."""

    class Model(BaseModel):
        foo: str

    payload = dict(foo="hello!", user="Robert")
    parsed = parse_obj(Model, payload)
    assert parsed.foo == "hello!"
    assert cast(Any, parsed).user == "Robert"

    assert model_dump(parsed) == {"foo": "hello!", "user": "Robert"}
169
170
def test_aliases() -> None:
    """``construct`` accepts a field's alias as the keyword and fills the Python-named field."""

    class Model(BaseModel):
        my_field: int = Field(alias="myField")

    typed = Model.construct(myField=1)
    assert typed.my_field == 1

    # mismatched types
    untyped = Model.construct(myField={"hello": False})
    assert cast(Any, untyped.my_field) == {"hello": False}
181
182
def test_repr() -> None:
    """str() and repr() both render the class name followed by field values."""
    expected = "BasicModel(foo='bar')"
    model = BasicModel(foo="bar")
    assert str(model) == expected
    assert repr(model) == expected
187
188
def test_repr_nested_model() -> None:
    """Nested models render recursively in both str() and repr()."""

    class Child(BaseModel):
        name: str
        age: int

    class Parent(BaseModel):
        name: str
        child: Child

    expected = "Parent(name='Robert', child=Child(name='Foo', age=5))"
    model = Parent(name="Robert", child=Child(name="Foo", age=5))
    for render in (str, repr):
        assert render(model) == expected
201
202
def test_optional_list() -> None:
    """``Optional[List[Model]]`` handles None, empty lists, and lists of dict items."""

    class Submodel(BaseModel):
        name: str

    class Model(BaseModel):
        items: Optional[List[Submodel]]

    none_items = Model.construct(items=None)
    assert none_items.items is None

    empty_items = Model.construct(items=[])
    assert empty_items.items == []

    one_item = Model.construct(items=[{"name": "Robert"}])
    assert one_item.items is not None
    assert len(one_item.items) == 1
    (only,) = one_item.items
    assert only.name == "Robert"
220
221
def test_nested_union_of_models() -> None:
    """A dict for a union-of-models field is built into the variant whose fields it matches."""

    class Submodel1(BaseModel):
        bar: bool

    class Submodel2(BaseModel):
        thing: str

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2]

    chosen = Model.construct(foo={"thing": "hello"}).foo
    assert isinstance(chosen, Submodel2)
    assert chosen.thing == "hello"
235
236
def test_nested_union_of_mixed_types() -> None:
    """Unions mixing a model and literals keep literal values and build the model from dicts."""

    class Submodel1(BaseModel):
        bar: bool

    class Model(BaseModel):
        foo: Union[Submodel1, Literal[True], Literal["CARD_HOLDER"]]

    literal_bool = Model.construct(foo=True)
    assert literal_bool.foo is True

    literal_str = Model.construct(foo="CARD_HOLDER")
    assert literal_str.foo == "CARD_HOLDER"

    as_model = Model.construct(foo={"bar": False})
    assert isinstance(as_model.foo, Submodel1)
    assert as_model.foo.bar is False
253
254
def test_nested_union_multiple_variants() -> None:
    """A union with ``None`` among several models picks the matching model; omitted means None."""

    class Submodel1(BaseModel):
        bar: bool

    class Submodel2(BaseModel):
        thing: str

    class Submodel3(BaseModel):
        foo: int

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2, None, Submodel3]

    second = Model.construct(foo={"thing": "hello"})
    assert isinstance(second.foo, Submodel2)
    assert second.foo.thing == "hello"

    explicit_none = Model.construct(foo=None)
    assert explicit_none.foo is None

    omitted = Model.construct()
    assert omitted.foo is None

    # the string "1" ends up as the int 1 on the matched Submodel3
    third = Model.construct(foo={"foo": "1"})
    assert isinstance(third.foo, Submodel3)
    assert third.foo.foo == 1
281
282
def test_nested_union_invalid_data() -> None:
    """Data matching no union variant: scalars pass through; mappings resolve per pydantic version."""

    class Submodel1(BaseModel):
        level: int

    class Submodel2(BaseModel):
        name: str

    class Model(BaseModel):
        foo: Union[Submodel1, Submodel2]

    scalar = Model.construct(foo=True)
    assert cast(bool, scalar.foo) is True

    mapping = Model.construct(foo={"name": 3})
    if not PYDANTIC_V1:
        # v2: the first variant is constructed and the int is left unconverted
        assert isinstance(mapping.foo, Submodel1)
        assert mapping.foo.name == 3  # type: ignore
    else:
        # v1: 3 is converted to "3", which matches Submodel2
        assert isinstance(mapping.foo, Submodel2)
        assert mapping.foo.name == "3"
303
304
def test_list_of_unions() -> None:
    """Each list element is built into its own union variant; unmatched scalars are kept."""

    class Submodel1(BaseModel):
        level: int

    class Submodel2(BaseModel):
        name: str

    class Model(BaseModel):
        items: List[Union[Submodel1, Submodel2]]

    mixed = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
    assert len(mixed.items) == 2
    head, tail = mixed.items
    assert isinstance(head, Submodel1)
    assert head.level == 1
    assert isinstance(tail, Submodel2)
    assert tail.name == "Robert"

    partial = Model.construct(items=[{"level": -1}, 156])
    assert len(partial.items) == 2
    head, tail = partial.items
    assert isinstance(head, Submodel1)
    assert head.level == -1
    assert cast(Any, tail) == 156
327
328
def test_union_of_lists() -> None:
    """A union of list types commits to a single list type for all elements."""

    class SubModel1(BaseModel):
        level: int

    class SubModel2(BaseModel):
        name: str

    class Model(BaseModel):
        items: Union[List[SubModel1], List[SubModel2]]

    # with one valid entry
    single = Model.construct(items=[{"name": "Robert"}])
    assert len(single.items) == 1
    only = single.items[0]
    assert isinstance(only, SubModel2)
    assert only.name == "Robert"

    # with two entries pointing to different types: both end up as SubModel1
    pair = Model.construct(items=[{"level": 1}, {"name": "Robert"}])
    assert len(pair.items) == 2
    first, second = pair.items
    assert isinstance(first, SubModel1)
    assert first.level == 1
    assert isinstance(second, SubModel1)
    assert cast(Any, second).name == "Robert"

    # with two entries pointing to *completely* different types
    odd = Model.construct(items=[{"level": -1}, 156])
    assert len(odd.items) == 2
    first, second = odd.items
    assert isinstance(first, SubModel1)
    assert first.level == -1
    assert cast(Any, second) == 156
359
360
def test_dict_of_union() -> None:
    """Each value of a ``Dict[str, Union[...]]`` is built into its matching variant.

    A value that is not a mapping cannot be built into any model variant and is
    expected to be stored unchanged, the same way list elements are treated.
    """

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        foo: str

    class Model(BaseModel):
        data: Dict[str, Union[SubModel1, SubModel2]]

    m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}})
    assert len(m.data) == 2
    assert isinstance(m.data["hello"], SubModel1)
    assert m.data["hello"].name == "there"
    assert isinstance(m.data["foo"], SubModel2)
    assert m.data["foo"].foo == "bar"

    # mismatched types: a non-mapping value matches no variant and is kept as-is,
    # while sibling values are still built normally
    m = Model.construct(data={"hello": {"name": "there"}, "foo": 156})
    assert len(m.data) == 2
    assert isinstance(m.data["hello"], SubModel1)
    assert m.data["hello"].name == "there"
    assert cast(Any, m.data["foo"]) == 156
379
380
def test_double_nested_union() -> None:
    """Unions nested in ``Dict[str, List[Union[...]]]`` are resolved per list entry.

    Entries that are not mappings cannot be built into any model variant and are
    expected to be stored unchanged.
    """

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        bar: str

    class Model(BaseModel):
        data: Dict[str, List[Union[SubModel1, SubModel2]]]

    m = Model.construct(data={"foo": [{"bar": "baz"}, {"name": "Robert"}]})
    assert len(m.data["foo"]) == 2

    entry1 = m.data["foo"][0]
    assert isinstance(entry1, SubModel2)
    assert entry1.bar == "baz"

    entry2 = m.data["foo"][1]
    assert isinstance(entry2, SubModel1)
    assert entry2.name == "Robert"

    # mismatched types: a scalar entry matches no variant and is kept as-is,
    # without affecting how its sibling entry is built
    m = Model.construct(data={"foo": [{"bar": "baz"}, 156]})
    assert len(m.data["foo"]) == 2

    entry1 = m.data["foo"][0]
    assert isinstance(entry1, SubModel2)
    assert entry1.bar == "baz"

    assert cast(Any, m.data["foo"][1]) == 156
403
404
def test_union_of_dict() -> None:
    """A union of dict types builds every value with one dict variant (here ``Dict[str, SubModel1]``)."""

    class SubModel1(BaseModel):
        name: str

    class SubModel2(BaseModel):
        foo: str

    class Model(BaseModel):
        data: Union[Dict[str, SubModel1], Dict[str, SubModel2]]

    data = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}}).data
    assert len(data) == 2

    hello = data["hello"]
    assert isinstance(hello, SubModel1)
    assert hello.name == "there"

    # the second value matches SubModel2's shape but is still built as SubModel1
    foo = data["foo"]
    assert isinstance(foo, SubModel1)
    assert cast(Any, foo).foo == "bar"
421
422
def test_iso8601_datetime() -> None:
    """ISO 8601 strings become aware datetimes via both ``construct`` and ``parse_obj``."""

    class Model(BaseModel):
        created_at: datetime

    raw = "2019-12-27T18:11:19.117Z"
    expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc)
    # pydantic v1 and v2 serialize the same instant with different JSON formatting
    expected_json = (
        '{"created_at": "2019-12-27T18:11:19.117000+00:00"}'
        if PYDANTIC_V1
        else '{"created_at":"2019-12-27T18:11:19.117000Z"}'
    )

    def check(model: Model) -> None:
        assert model.created_at == expected
        assert model_json(model) == expected_json

    check(Model.construct(created_at=raw))
    check(parse_obj(Model, dict(created_at=raw)))
441
442
def test_does_not_coerce_int() -> None:
    """``construct`` leaves values in an ``int`` field exactly as given, without coercion."""

    class Model(BaseModel):
        bar: int

    for raw in (1, 10.9, "19"):
        assert Model.construct(bar=raw).bar == raw
    # checked by identity so that False cannot be confused with 0
    assert Model.construct(bar=False).bar is False
451
452
def test_int_to_float_safe_conversion() -> None:
    """Small ints in a ``float`` field become floats; an int beyond 2**53 stays an int."""

    class Model(BaseModel):
        float_field: float

    small = Model.construct(float_field=10).float_field
    assert small == 10.0
    assert isinstance(small, float)

    already_float = Model.construct(float_field=10.12).float_field
    assert already_float == 10.12
    assert isinstance(already_float, float)

    # number too big
    big = 2**53 + 1
    huge = Model.construct(float_field=big).float_field
    assert huge == big
    assert isinstance(huge, int)
469
470
def test_deprecated_alias() -> None:
    """A field aliased to ``model_id`` is reachable via the alias property and its own name."""

    class Model(BaseModel):
        resource_id: str = Field(alias="model_id")

        @property
        def model_id(self) -> str:
            return self.resource_id

    def check(instance: Model) -> None:
        assert instance.model_id == "id"
        assert instance.resource_id == "id"
        assert instance.resource_id is instance.model_id

    check(Model.construct(model_id="id"))
    check(parse_obj(Model, {"model_id": "id"}))
488
489
def test_omitted_fields() -> None:
    """Only explicitly passed fields are recorded in ``model_fields_set``, even when None."""

    class Model(BaseModel):
        resource_id: Optional[str] = None

    defaulted = Model.construct()
    assert defaulted.resource_id is None
    assert "resource_id" not in defaulted.model_fields_set

    explicit_none = Model.construct(resource_id=None)
    assert explicit_none.resource_id is None
    assert "resource_id" in explicit_none.model_fields_set

    explicit_value = Model.construct(resource_id="foo")
    assert explicit_value.resource_id == "foo"
    assert "resource_id" in explicit_value.model_fields_set
505
506
def test_to_dict() -> None:
    """``to_dict`` uses API (alias) names, omits unset fields by default, and honours its options."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    with_value = Model(FOO="hello")
    assert with_value.to_dict() == {"FOO": "hello"}
    assert with_value.to_dict(use_api_names=False) == {"foo": "hello"}

    unset = Model()
    assert unset.to_dict() == {}
    assert unset.to_dict(exclude_unset=False) == {"FOO": None}
    assert unset.to_dict(exclude_unset=False, exclude_none=True) == {}
    assert unset.to_dict(exclude_unset=False, exclude_defaults=True) == {}

    explicit_none = Model(FOO=None)
    assert explicit_none.to_dict() == {"FOO": None}
    assert explicit_none.to_dict(exclude_none=True) == {}
    assert explicit_none.to_dict(exclude_defaults=True) == {}

    class Model2(BaseModel):
        created_at: datetime

    time_str = "2024-03-21T11:39:01.275859"
    dated = Model2.construct(created_at=time_str)
    assert dated.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
    assert dated.to_dict(mode="json") == {"created_at": time_str}

    if PYDANTIC_V1:
        # the v2-only `warnings` option must be rejected under v1
        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            with_value.to_dict(warnings=False)
537
538
def test_forwards_compat_model_dump_method() -> None:
    """``model_dump`` defaults to Python field names and supports include/exclude/by_alias/exclude_*."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    populated = Model(FOO="hello")
    assert populated.model_dump() == {"foo": "hello"}
    assert populated.model_dump(include={"bar"}) == {}
    assert populated.model_dump(exclude={"foo"}) == {}
    assert populated.model_dump(by_alias=True) == {"FOO": "hello"}

    defaulted = Model()
    assert defaulted.model_dump() == {"foo": None}
    assert defaulted.model_dump(exclude_unset=True) == {}
    assert defaulted.model_dump(exclude_none=True) == {}
    assert defaulted.model_dump(exclude_defaults=True) == {}

    explicit_none = Model(FOO=None)
    assert explicit_none.model_dump() == {"foo": None}
    assert explicit_none.model_dump(exclude_none=True) == {}

    if PYDANTIC_V1:
        # v2-only options must be rejected under v1
        with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
            populated.model_dump(round_trip=True)

        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            populated.model_dump(warnings=False)
565
566
def test_compat_method_no_error_for_warnings() -> None:
    """The ``model_dump`` compat helper accepts ``warnings=`` and returns a dict on any pydantic version."""

    class Model(BaseModel):
        foo: Optional[str]

    dumped = model_dump(Model(foo="hello"), warnings=False)
    assert isinstance(dumped, dict)
573
574
def test_to_json() -> None:
    """``to_json`` emits API (alias) names, honours exclude_* options, and supports ``indent=None``."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    with_value = Model(FOO="hello")
    assert json.loads(with_value.to_json()) == {"FOO": "hello"}
    assert json.loads(with_value.to_json(use_api_names=False)) == {"foo": "hello"}

    # pydantic v1 puts a space after the colon; v2 does not
    compact = '{"FOO": "hello"}' if PYDANTIC_V1 else '{"FOO":"hello"}'
    assert with_value.to_json(indent=None) == compact

    unset = Model()
    assert json.loads(unset.to_json()) == {}
    assert json.loads(unset.to_json(exclude_unset=False)) == {"FOO": None}
    assert json.loads(unset.to_json(exclude_unset=False, exclude_none=True)) == {}
    assert json.loads(unset.to_json(exclude_unset=False, exclude_defaults=True)) == {}

    explicit_none = Model(FOO=None)
    assert json.loads(explicit_none.to_json()) == {"FOO": None}
    assert json.loads(explicit_none.to_json(exclude_none=True)) == {}

    if PYDANTIC_V1:
        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            with_value.to_json(warnings=False)
601
602
def test_forwards_compat_model_dump_json_method() -> None:
    """``model_dump_json`` defaults to Python field names and supports the usual dump options."""

    class Model(BaseModel):
        foo: Optional[str] = Field(alias="FOO", default=None)

    populated = Model(FOO="hello")
    assert json.loads(populated.model_dump_json()) == {"foo": "hello"}
    assert json.loads(populated.model_dump_json(include={"bar"})) == {}
    assert json.loads(populated.model_dump_json(include={"foo"})) == {"foo": "hello"}
    assert json.loads(populated.model_dump_json(by_alias=True)) == {"FOO": "hello"}

    assert populated.model_dump_json(indent=2) == '{\n  "foo": "hello"\n}'

    defaulted = Model()
    assert json.loads(defaulted.model_dump_json()) == {"foo": None}
    assert json.loads(defaulted.model_dump_json(exclude_unset=True)) == {}
    assert json.loads(defaulted.model_dump_json(exclude_none=True)) == {}
    assert json.loads(defaulted.model_dump_json(exclude_defaults=True)) == {}

    explicit_none = Model(FOO=None)
    assert json.loads(explicit_none.model_dump_json()) == {"foo": None}
    assert json.loads(explicit_none.model_dump_json(exclude_none=True)) == {}

    if PYDANTIC_V1:
        # v2-only options must be rejected under v1
        with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
            populated.model_dump_json(round_trip=True)

        with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
            populated.model_dump_json(warnings=False)
631
632
def test_type_compat() -> None:
    """Instances of our ``BaseModel`` subclasses are accepted where ``pydantic.BaseModel`` is expected.

    This is primarily a static type-checking test; at runtime it only needs to not raise.
    """

    def takes_pydantic(model: pydantic.BaseModel) -> None:  # noqa: ARG001
        ...

    class OurModel(BaseModel):
        foo: Optional[str] = None

    instance = OurModel()
    takes_pydantic(instance)
643
644
def test_annotated_types() -> None:
    """``construct_type`` looks through ``Annotated`` metadata to the underlying model."""

    class Model(BaseModel):
        value: str

    annotated = cast(Any, Annotated[Model, "random metadata"])
    built = construct_type(value={"value": "foo"}, type_=annotated)
    assert isinstance(built, Model)
    assert built.value == "foo"
655
656
def test_discriminated_unions_invalid_data() -> None:
    """The discriminator picks the variant even when the remaining data does not fit it."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    union_type = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])

    m = construct_type(value={"type": "b", "data": "foo"}, type_=union_type)
    assert isinstance(m, B)
    assert m.type == "b"
    assert m.data == "foo"  # type: ignore[comparison-overlap]

    m = construct_type(value={"type": "a", "data": 100}, type_=union_type)
    assert isinstance(m, A)
    assert m.type == "a"
    if not PYDANTIC_V1:
        assert m.data == 100  # type: ignore[comparison-overlap]
    else:
        # pydantic v1 automatically converts inputs to strings
        # if the expected type is a str
        assert m.data == "100"
688
689
def test_discriminated_unions_unknown_variant() -> None:
    """An unrecognised discriminator value falls back to the first variant and keeps all raw data."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    m = construct_type(
        value={"type": "c", "data": None, "new_thing": "bar"},
        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
    )

    # just chooses the first variant
    assert isinstance(m, A)
    assert m.type == "c"  # type: ignore[comparison-overlap]
    # compare against the None singleton by identity, not equality (PEP 8 / ruff E711)
    assert m.data is None  # type: ignore[unreachable]
    assert m.new_thing == "bar"
711
712
def test_discriminated_unions_invalid_data_nested_unions() -> None:
    """Discrimination also works when the union has nested ``Union`` members."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    class C(BaseModel):
        type: Literal["c"]

        data: bool

    union_type = cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")])

    from_inner = construct_type(value={"type": "b", "data": "foo"}, type_=union_type)
    assert isinstance(from_inner, B)
    assert from_inner.type == "b"
    assert from_inner.data == "foo"  # type: ignore[comparison-overlap]

    from_outer = construct_type(value={"type": "c", "data": "foo"}, type_=union_type)
    assert isinstance(from_outer, C)
    assert from_outer.type == "c"
    assert from_outer.data == "foo"  # type: ignore[comparison-overlap]
744
745
def test_discriminated_unions_with_aliases_invalid_data() -> None:
    """A discriminator named by its Python field is matched against the aliased key in the data."""

    class A(BaseModel):
        foo_type: Literal["a"] = Field(alias="type")

        data: str

    class B(BaseModel):
        foo_type: Literal["b"] = Field(alias="type")

        data: int

    union_type = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")])

    m = construct_type(value={"type": "b", "data": "foo"}, type_=union_type)
    assert isinstance(m, B)
    assert m.foo_type == "b"
    assert m.data == "foo"  # type: ignore[comparison-overlap]

    m = construct_type(value={"type": "a", "data": 100}, type_=union_type)
    assert isinstance(m, A)
    assert m.foo_type == "a"
    if not PYDANTIC_V1:
        assert m.data == 100  # type: ignore[comparison-overlap]
    else:
        # pydantic v1 automatically converts inputs to strings
        # if the expected type is a str
        assert m.data == "100"
777
778
def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None:
    """When variants share a discriminator value, the later variant (``B`` here) is chosen."""

    class A(BaseModel):
        type: Literal["a"]

        data: bool

    class B(BaseModel):
        type: Literal["a"]

        data: int

    union_type = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])
    chosen = construct_type(value={"type": "a", "data": "foo"}, type_=union_type)

    assert isinstance(chosen, B)
    assert chosen.type == "a"
    assert chosen.data == "foo"  # type: ignore[comparison-overlap]
797
798
def test_discriminated_unions_invalid_data_uses_cache() -> None:
    """Discriminator details are computed once per union and reused on later constructions."""

    class A(BaseModel):
        type: Literal["a"]

        data: str

    class B(BaseModel):
        type: Literal["b"]

        data: int

    UnionType = cast(Any, Union[A, B])

    assert not DISCRIMINATOR_CACHE.get(UnionType)

    def construct_b() -> object:
        # build a fresh Annotated wrapper each time; the cache is read by the inner union type
        return construct_type(
            value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
        )

    def check(m: object) -> None:
        assert isinstance(m, B)
        assert m.type == "b"
        assert m.data == "foo"  # type: ignore[comparison-overlap]

    check(construct_b())

    discriminator = DISCRIMINATOR_CACHE.get(UnionType)
    assert discriminator is not None

    check(construct_b())

    # if the discriminator details object stays the same between invocations then
    # we hit the cache
    assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
834
835
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
    """Fields annotated with a ``TypeAliasType``, alone or inside a union, hold the aliased type."""
    Alias = TypeAliasType("Alias", str)  # pyright: ignore

    class Model(BaseModel):
        alias: Alias
        union: Union[int, Alias]

    m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model)
    assert isinstance(m, Model)
    for field_name, expected in (("alias", "foo"), ("union", "bar")):
        value = getattr(m, field_name)
        assert isinstance(value, str)
        assert value == expected
850
851
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now")
def test_field_named_cls() -> None:
    """A model with a field literally named ``cls`` can be built via ``construct_type``.

    NOTE(review): presumably a regression guard against the field clashing with the
    ``cls`` parameter of classmethod constructors; confirm against ``BaseModel.construct``.
    """

    class Model(BaseModel):
        cls: str

    m = construct_type(value={"cls": "foo"}, type_=Model)
    assert isinstance(m, Model)
    assert isinstance(m.cls, str)
    # pin the value as well as its type, so a misrouted argument cannot pass unnoticed
    assert m.cls == "foo"
860
861
def test_discriminated_union_case() -> None:
    """The discriminator selects ``ModelB`` even though its required field is missing."""

    class A(BaseModel):
        type: Literal["a"]

        data: bool

    class B(BaseModel):
        type: Literal["b"]

        data: List[Union[A, object]]

    class ModelA(BaseModel):
        type: Literal["modelA"]

        data: int

    class ModelB(BaseModel):
        type: Literal["modelB"]

        required: str

        data: Union[A, B]

    # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
    payload = {"type": "modelB", "data": {"type": "a", "data": True}}
    union_type = cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")])
    result = construct_type(value=payload, type_=union_type)

    assert isinstance(result, ModelB)
892
893
def test_nested_discriminated_union() -> None:
    """An inner discriminated union is resolved even when the chosen inner variant lacks fields."""

    class InnerType1(BaseModel):
        type: Literal["type_1"]

    class InnerModel(BaseModel):
        inner_value: str

    class InnerType2(BaseModel):
        type: Literal["type_2"]
        some_inner_model: InnerModel

    class Type1(BaseModel):
        base_type: Literal["base_type_1"]
        value: Annotated[
            Union[
                InnerType1,
                InnerType2,
            ],
            PropertyInfo(discriminator="type"),
        ]

    class Type2(BaseModel):
        base_type: Literal["base_type_2"]

    OuterUnion = Annotated[Union[Type1, Type2], PropertyInfo(discriminator="base_type")]

    # `some_inner_model` is intentionally absent from the inner payload
    payload = {
        "base_type": "base_type_1",
        "value": {"type": "type_2"},
    }
    result = construct_type(type_=OuterUnion, value=payload)

    assert isinstance(result, Type1)
    assert isinstance(result.value, InnerType2)
937
938
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now")
def test_extra_properties() -> None:
    """Undeclared keys are built using the type annotated on ``__pydantic_extra__``."""

    class Item(BaseModel):
        prop: int

    class Model(BaseModel):
        __pydantic_extra__: Dict[str, Item] = Field(init=False)  # pyright: ignore[reportIncompatibleVariableOverride]

        other: str

        if TYPE_CHECKING:

            def __getattr__(self, attr: str) -> Item: ...

    payload = {
        "a": {"prop": 1},
        "other": "foo",
    }
    result = construct_type(type_=Model, value=payload)

    assert isinstance(result, Model)
    extra = result.a
    assert extra.prop == 1
    assert isinstance(extra, Item)
    assert result.other == "foo"
963    assert model.other == "foo"