Commit 080587bd
Changed files (2)
src
openai
tests
src/openai/_models.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import os
import inspect
+import weakref
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
@@ -598,6 +599,9 @@ class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
+DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
+
+
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
@@ -640,8 +644,9 @@ class DiscriminatorDetails:
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
- if isinstance(union, CachedDiscriminatorType):
- return union.__discriminator__
+ cached = DISCRIMINATOR_CACHE.get(union)
+ if cached is not None:
+ return cached
discriminator_field_name: str | None = None
@@ -694,7 +699,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
- cast(CachedDiscriminatorType, union).__discriminator__ = details
+ DISCRIMINATOR_CACHE.setdefault(union, details)
return details
tests/test_models.py
@@ -9,7 +9,7 @@ from pydantic import Field
from openai._utils import PropertyInfo
from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
-from openai._models import BaseModel, construct_type
+from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
class BasicModel(BaseModel):
@@ -809,7 +809,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
UnionType = cast(Any, Union[A, B])
- assert not hasattr(UnionType, "__discriminator__")
+ assert not DISCRIMINATOR_CACHE.get(UnionType)
m = construct_type(
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -818,7 +818,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
assert m.type == "b"
assert m.data == "foo" # type: ignore[comparison-overlap]
- discriminator = UnionType.__discriminator__
+ discriminator = DISCRIMINATOR_CACHE.get(UnionType)
assert discriminator is not None
m = construct_type(
@@ -830,7 +830,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
# if the discriminator details object stays the same between invocations then
# we hit the cache
- assert UnionType.__discriminator__ is discriminator
+ assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")