Commit 0f0aec55
Changed files (3)
src
openai
_utils
tests
src/openai/_utils/_transform.py
@@ -51,6 +51,7 @@ class PropertyInfo:
alias: str | None
format: PropertyFormat | None
format_template: str | None
+ discriminator: str | None
def __init__(
self,
@@ -58,14 +59,16 @@ class PropertyInfo:
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
+ discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
+ self.discriminator = discriminator
@override
def __repr__(self) -> str:
- return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')"
+ return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
def maybe_transform(
src/openai/_models.py
@@ -10,6 +10,7 @@ from typing_extensions import (
Protocol,
Required,
TypedDict,
+ TypeGuard,
final,
override,
runtime_checkable,
@@ -31,6 +32,7 @@ from ._types import (
HttpxRequestFiles,
)
from ._utils import (
+ PropertyInfo,
is_list,
is_given,
is_mapping,
@@ -39,6 +41,7 @@ from ._utils import (
strip_not_given,
extract_type_arg,
is_annotated_type,
+ strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
@@ -55,6 +58,9 @@ from ._compat import (
)
from ._constants import RAW_RESPONSE_HEADER
+if TYPE_CHECKING:
+ from pydantic_core.core_schema import ModelField, ModelFieldsSchema
+
__all__ = ["BaseModel", "GenericModel"]
_T = TypeVar("_T")
@@ -268,7 +274,6 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
- origin = get_origin(type_) or type_
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
@@ -276,6 +281,11 @@ def is_basemodel(type_: type) -> bool:
return False
+ return is_basemodel_type(type_)
+
+
+def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
+ origin = get_origin(type_) or type_
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
@@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
"""
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
+ meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
+ else:
+ meta = tuple()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
@@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
except Exception:
pass
+ # if the type is a discriminated union then we want to construct the right variant
+ # in the union, even if the data doesn't match exactly, otherwise we'd break code
+ # that relies on the constructed class types, e.g.
+ #
+ # class FooType:
+ # kind: Literal['foo']
+ # value: str
+ #
+ # class BarType:
+ # kind: Literal['bar']
+ # value: int
+ #
+ # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
+ # we'd end up constructing `FooType` when it should be `BarType`.
+ discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
+ if discriminator and is_mapping(value):
+ variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
+ if variant_value and isinstance(variant_value, str):
+ variant_type = discriminator.mapping.get(variant_value)
+ if variant_type:
+ return construct_type(type_=variant_type, value=value)
+
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
@@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
return value
+@runtime_checkable
+class CachedDiscriminatorType(Protocol):
+ __discriminator__: DiscriminatorDetails
+
+
+class DiscriminatorDetails:
+ field_name: str
+ """The name of the discriminator field in the variant class, e.g.
+
+ ```py
+ class Foo(BaseModel):
+ type: Literal['foo']
+ ```
+
+ Will result in field_name='type'
+ """
+
+ field_alias_from: str | None
+ """The name of the discriminator field in the API response, e.g.
+
+ ```py
+ class Foo(BaseModel):
+ type: Literal['foo'] = Field(alias='type_from_api')
+ ```
+
+ Will result in field_alias_from='type_from_api'
+ """
+
+ mapping: dict[str, type]
+ """Mapping of discriminator value to variant type, e.g.
+
+ {'foo': FooVariant, 'bar': BarVariant}
+ """
+
+ def __init__(
+ self,
+ *,
+ mapping: dict[str, type],
+ discriminator_field: str,
+ discriminator_alias: str | None,
+ ) -> None:
+ self.mapping = mapping
+ self.field_name = discriminator_field
+ self.field_alias_from = discriminator_alias
+
+
+def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
+ if isinstance(union, CachedDiscriminatorType):
+ return union.__discriminator__
+
+ discriminator_field_name: str | None = None
+
+ for annotation in meta_annotations:
+ if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
+ discriminator_field_name = annotation.discriminator
+ break
+
+ if not discriminator_field_name:
+ return None
+
+ mapping: dict[str, type] = {}
+ discriminator_alias: str | None = None
+
+ for variant in get_args(union):
+ variant = strip_annotated_type(variant)
+ if is_basemodel_type(variant):
+ if PYDANTIC_V2:
+ field = _extract_field_schema_pv2(variant, discriminator_field_name)
+ if not field:
+ continue
+
+ # Note: if one variant defines an alias then they all should
+ discriminator_alias = field.get("serialization_alias")
+
+ field_schema = field["schema"]
+
+ if field_schema["type"] == "literal":
+ for entry in field_schema["expected"]:
+ if isinstance(entry, str):
+ mapping[entry] = variant
+ else:
+ field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+ if not field_info:
+ continue
+
+ # Note: if one variant defines an alias then they all should
+ discriminator_alias = field_info.alias
+
+ if field_info.annotation and is_literal_type(field_info.annotation):
+ for entry in get_args(field_info.annotation):
+ if isinstance(entry, str):
+ mapping[entry] = variant
+
+ if not mapping:
+ return None
+
+ details = DiscriminatorDetails(
+ mapping=mapping,
+ discriminator_field=discriminator_field_name,
+ discriminator_alias=discriminator_alias,
+ )
+ cast(CachedDiscriminatorType, union).__discriminator__ = details
+ return details
+
+
+def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
+ schema = model.__pydantic_core_schema__
+ if schema["type"] != "model":
+ return None
+
+ fields_schema = schema["schema"]
+ if fields_schema["type"] != "model-fields":
+ return None
+
+ fields_schema = cast("ModelFieldsSchema", fields_schema)
+
+ field = fields_schema["fields"].get(field_name)
+ if not field:
+ return None
+
+ return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
+
+
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
tests/test_models.py
@@ -7,6 +7,7 @@ import pytest
import pydantic
from pydantic import Field
+from openai._utils import PropertyInfo
from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
from openai._models import BaseModel, construct_type
@@ -583,3 +584,182 @@ def test_annotated_types() -> None:
)
assert isinstance(m, Model)
assert m.value == "foo"
+
+
+def test_discriminated_unions_invalid_data() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: str
+
+ class B(BaseModel):
+ type: Literal["b"]
+
+ data: int
+
+ m = construct_type(
+ value={"type": "b", "data": "foo"},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+ )
+ assert isinstance(m, B)
+ assert m.type == "b"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+ m = construct_type(
+ value={"type": "a", "data": 100},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+ )
+ assert isinstance(m, A)
+ assert m.type == "a"
+ if PYDANTIC_V2:
+ assert m.data == 100 # type: ignore[comparison-overlap]
+ else:
+ # pydantic v1 automatically converts inputs to strings
+ # if the expected type is a str
+ assert m.data == "100"
+
+
+def test_discriminated_unions_unknown_variant() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: str
+
+ class B(BaseModel):
+ type: Literal["b"]
+
+ data: int
+
+ m = construct_type(
+ value={"type": "c", "data": None, "new_thing": "bar"},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+ )
+
+ # just chooses the first variant
+ assert isinstance(m, A)
+ assert m.type == "c" # type: ignore[comparison-overlap]
+ assert m.data == None # type: ignore[unreachable]
+ assert m.new_thing == "bar"
+
+
+def test_discriminated_unions_invalid_data_nested_unions() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: str
+
+ class B(BaseModel):
+ type: Literal["b"]
+
+ data: int
+
+ class C(BaseModel):
+ type: Literal["c"]
+
+ data: bool
+
+ m = construct_type(
+ value={"type": "b", "data": "foo"},
+ type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]),
+ )
+ assert isinstance(m, B)
+ assert m.type == "b"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+ m = construct_type(
+ value={"type": "c", "data": "foo"},
+ type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]),
+ )
+ assert isinstance(m, C)
+ assert m.type == "c"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+
+def test_discriminated_unions_with_aliases_invalid_data() -> None:
+ class A(BaseModel):
+ foo_type: Literal["a"] = Field(alias="type")
+
+ data: str
+
+ class B(BaseModel):
+ foo_type: Literal["b"] = Field(alias="type")
+
+ data: int
+
+ m = construct_type(
+ value={"type": "b", "data": "foo"},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]),
+ )
+ assert isinstance(m, B)
+ assert m.foo_type == "b"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+ m = construct_type(
+ value={"type": "a", "data": 100},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]),
+ )
+ assert isinstance(m, A)
+ assert m.foo_type == "a"
+ if PYDANTIC_V2:
+ assert m.data == 100 # type: ignore[comparison-overlap]
+ else:
+ # pydantic v1 automatically converts inputs to strings
+ # if the expected type is a str
+ assert m.data == "100"
+
+
+def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: bool
+
+ class B(BaseModel):
+ type: Literal["a"]
+
+ data: int
+
+ m = construct_type(
+ value={"type": "a", "data": "foo"},
+ type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+ )
+ assert isinstance(m, B)
+ assert m.type == "a"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+
+def test_discriminated_unions_invalid_data_uses_cache() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: str
+
+ class B(BaseModel):
+ type: Literal["b"]
+
+ data: int
+
+ UnionType = cast(Any, Union[A, B])
+
+ assert not hasattr(UnionType, "__discriminator__")
+
+ m = construct_type(
+ value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
+ )
+ assert isinstance(m, B)
+ assert m.type == "b"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+ discriminator = UnionType.__discriminator__
+ assert discriminator is not None
+
+ m = construct_type(
+ value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
+ )
+ assert isinstance(m, B)
+ assert m.type == "b"
+ assert m.data == "foo" # type: ignore[comparison-overlap]
+
+ # if the discriminator details object stays the same between invocations then
+ # we hit the cache
+ assert UnionType.__discriminator__ is discriminator