Commit 0f0aec55

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-03-11 21:28:46
chore(internal): improve deserialisation of discriminated unions (#1227)
1 parent b3e37c8
Changed files (3)
src/openai/_utils/_transform.py
@@ -51,6 +51,7 @@ class PropertyInfo:
     alias: str | None
     format: PropertyFormat | None
     format_template: str | None
+    discriminator: str | None
 
     def __init__(
         self,
@@ -58,14 +59,16 @@ class PropertyInfo:
         alias: str | None = None,
         format: PropertyFormat | None = None,
         format_template: str | None = None,
+        discriminator: str | None = None,
     ) -> None:
         self.alias = alias
         self.format = format
         self.format_template = format_template
+        self.discriminator = discriminator
 
     @override
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')"
+        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
 
 
 def maybe_transform(
src/openai/_models.py
@@ -10,6 +10,7 @@ from typing_extensions import (
     Protocol,
     Required,
     TypedDict,
+    TypeGuard,
     final,
     override,
     runtime_checkable,
@@ -31,6 +32,7 @@ from ._types import (
     HttpxRequestFiles,
 )
 from ._utils import (
+    PropertyInfo,
     is_list,
     is_given,
     is_mapping,
@@ -39,6 +41,7 @@ from ._utils import (
     strip_not_given,
     extract_type_arg,
     is_annotated_type,
+    strip_annotated_type,
 )
 from ._compat import (
     PYDANTIC_V2,
@@ -55,6 +58,9 @@ from ._compat import (
 )
 from ._constants import RAW_RESPONSE_HEADER
 
+if TYPE_CHECKING:
+    from pydantic_core.core_schema import ModelField, ModelFieldsSchema
+
 __all__ = ["BaseModel", "GenericModel"]
 
 _T = TypeVar("_T")
@@ -268,7 +274,6 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
 
 def is_basemodel(type_: type) -> bool:
     """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
-    origin = get_origin(type_) or type_
     if is_union(type_):
         for variant in get_args(type_):
             if is_basemodel(variant):
@@ -276,6 +281,11 @@ def is_basemodel(type_: type) -> bool:
 
         return False
 
+    return is_basemodel_type(type_)
+
+
+def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
+    origin = get_origin(type_) or type_
     return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
 
 
@@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
     """
     # unwrap `Annotated[T, ...]` -> `T`
     if is_annotated_type(type_):
+        meta = get_args(type_)[1:]
         type_ = extract_type_arg(type_, 0)
+    else:
+        meta = tuple()
 
     # we need to use the origin class for any types that are subscripted generics
     # e.g. Dict[str, object]
@@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
         except Exception:
             pass
 
+        # if the type is a discriminated union then we want to construct the right variant
+        # in the union, even if the data doesn't match exactly, otherwise we'd break code
+        # that relies on the constructed class types, e.g.
+        #
+        # class FooType:
+        #   kind: Literal['foo']
+        #   value: str
+        #
+        # class BarType:
+        #   kind: Literal['bar']
+        #   value: int
+        #
+        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
+        # we'd end up constructing `FooType` when it should be `BarType`.
+        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
+        if discriminator and is_mapping(value):
+            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
+            if variant_value and isinstance(variant_value, str):
+                variant_type = discriminator.mapping.get(variant_value)
+                if variant_type:
+                    return construct_type(type_=variant_type, value=value)
+
         # if the data is not valid, use the first variant that doesn't fail while deserializing
         for variant in args:
             try:
@@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
     return value
 
 
+@runtime_checkable
+class CachedDiscriminatorType(Protocol):
+    __discriminator__: DiscriminatorDetails
+
+
+class DiscriminatorDetails:
+    field_name: str
+    """The name of the discriminator field in the variant class, e.g.
+
+    ```py
+    class Foo(BaseModel):
+        type: Literal['foo']
+    ```
+
+    Will result in field_name='type'
+    """
+
+    field_alias_from: str | None
+    """The name of the discriminator field in the API response, e.g.
+
+    ```py
+    class Foo(BaseModel):
+        type: Literal['foo'] = Field(alias='type_from_api')
+    ```
+
+    Will result in field_alias_from='type_from_api'
+    """
+
+    mapping: dict[str, type]
+    """Mapping of discriminator value to variant type, e.g.
+
+    {'foo': FooVariant, 'bar': BarVariant}
+    """
+
+    def __init__(
+        self,
+        *,
+        mapping: dict[str, type],
+        discriminator_field: str,
+        discriminator_alias: str | None,
+    ) -> None:
+        self.mapping = mapping
+        self.field_name = discriminator_field
+        self.field_alias_from = discriminator_alias
+
+
+def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
+    if isinstance(union, CachedDiscriminatorType):
+        return union.__discriminator__
+
+    discriminator_field_name: str | None = None
+
+    for annotation in meta_annotations:
+        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
+            discriminator_field_name = annotation.discriminator
+            break
+
+    if not discriminator_field_name:
+        return None
+
+    mapping: dict[str, type] = {}
+    discriminator_alias: str | None = None
+
+    for variant in get_args(union):
+        variant = strip_annotated_type(variant)
+        if is_basemodel_type(variant):
+            if PYDANTIC_V2:
+                field = _extract_field_schema_pv2(variant, discriminator_field_name)
+                if not field:
+                    continue
+
+                # Note: if one variant defines an alias then they all should
+                discriminator_alias = field.get("serialization_alias")
+
+                field_schema = field["schema"]
+
+                if field_schema["type"] == "literal":
+                    for entry in field_schema["expected"]:
+                        if isinstance(entry, str):
+                            mapping[entry] = variant
+            else:
+                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+                if not field_info:
+                    continue
+
+                # Note: if one variant defines an alias then they all should
+                discriminator_alias = field_info.alias
+
+                if field_info.annotation and is_literal_type(field_info.annotation):
+                    for entry in get_args(field_info.annotation):
+                        if isinstance(entry, str):
+                            mapping[entry] = variant
+
+    if not mapping:
+        return None
+
+    details = DiscriminatorDetails(
+        mapping=mapping,
+        discriminator_field=discriminator_field_name,
+        discriminator_alias=discriminator_alias,
+    )
+    cast(CachedDiscriminatorType, union).__discriminator__ = details
+    return details
+
+
+def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
+    schema = model.__pydantic_core_schema__
+    if schema["type"] != "model":
+        return None
+
+    fields_schema = schema["schema"]
+    if fields_schema["type"] != "model-fields":
+        return None
+
+    fields_schema = cast("ModelFieldsSchema", fields_schema)
+
+    field = fields_schema["fields"].get(field_name)
+    if not field:
+        return None
+
+    return cast("ModelField", field)  # pyright: ignore[reportUnnecessaryCast]
+
+
 def validate_type(*, type_: type[_T], value: object) -> _T:
     """Strict validation that the given value matches the expected type"""
     if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
tests/test_models.py
@@ -7,6 +7,7 @@ import pytest
 import pydantic
 from pydantic import Field
 
+from openai._utils import PropertyInfo
 from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
 from openai._models import BaseModel, construct_type
 
@@ -583,3 +584,182 @@ def test_annotated_types() -> None:
     )
     assert isinstance(m, Model)
     assert m.value == "foo"
+
+
+def test_discriminated_unions_invalid_data() -> None:
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    m = construct_type(
+        value={"type": "b", "data": "foo"},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, B)
+    assert m.type == "b"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+    m = construct_type(
+        value={"type": "a", "data": 100},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, A)
+    assert m.type == "a"
+    if PYDANTIC_V2:
+        assert m.data == 100  # type: ignore[comparison-overlap]
+    else:
+        # pydantic v1 automatically converts inputs to strings
+        # if the expected type is a str
+        assert m.data == "100"
+
+
+def test_discriminated_unions_unknown_variant() -> None:
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    m = construct_type(
+        value={"type": "c", "data": None, "new_thing": "bar"},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+
+    # just chooses the first variant
+    assert isinstance(m, A)
+    assert m.type == "c"  # type: ignore[comparison-overlap]
+    assert m.data == None  # type: ignore[unreachable]
+    assert m.new_thing == "bar"
+
+
+def test_discriminated_unions_invalid_data_nested_unions() -> None:
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    class C(BaseModel):
+        type: Literal["c"]
+
+        data: bool
+
+    m = construct_type(
+        value={"type": "b", "data": "foo"},
+        type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, B)
+    assert m.type == "b"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+    m = construct_type(
+        value={"type": "c", "data": "foo"},
+        type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, C)
+    assert m.type == "c"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+
+def test_discriminated_unions_with_aliases_invalid_data() -> None:
+    class A(BaseModel):
+        foo_type: Literal["a"] = Field(alias="type")
+
+        data: str
+
+    class B(BaseModel):
+        foo_type: Literal["b"] = Field(alias="type")
+
+        data: int
+
+    m = construct_type(
+        value={"type": "b", "data": "foo"},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]),
+    )
+    assert isinstance(m, B)
+    assert m.foo_type == "b"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+    m = construct_type(
+        value={"type": "a", "data": 100},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]),
+    )
+    assert isinstance(m, A)
+    assert m.foo_type == "a"
+    if PYDANTIC_V2:
+        assert m.data == 100  # type: ignore[comparison-overlap]
+    else:
+        # pydantic v1 automatically converts inputs to strings
+        # if the expected type is a str
+        assert m.data == "100"
+
+
+def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None:
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: bool
+
+    class B(BaseModel):
+        type: Literal["a"]
+
+        data: int
+
+    m = construct_type(
+        value={"type": "a", "data": "foo"},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, B)
+    assert m.type == "a"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+
+def test_discriminated_unions_invalid_data_uses_cache() -> None:
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    UnionType = cast(Any, Union[A, B])
+
+    assert not hasattr(UnionType, "__discriminator__")
+
+    m = construct_type(
+        value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
+    )
+    assert isinstance(m, B)
+    assert m.type == "b"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+    discriminator = UnionType.__discriminator__
+    assert discriminator is not None
+
+    m = construct_type(
+        value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
+    )
+    assert isinstance(m, B)
+    assert m.type == "b"
+    assert m.data == "foo"  # type: ignore[comparison-overlap]
+
+    # if the discriminator details object stays the same between invocations then
+    # we hit the cache
+    assert UnionType.__discriminator__ is discriminator