Commit 080587bd

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2025-11-10 21:35:58
fix: compat with Python 3.14
1 parent 3f52ac8
Changed files (2)
src
openai
tests
src/openai/_models.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import os
 import inspect
+import weakref
 from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import (
@@ -598,6 +599,9 @@ class CachedDiscriminatorType(Protocol):
     __discriminator__: DiscriminatorDetails
 
 
+DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
+
+
 class DiscriminatorDetails:
     field_name: str
     """The name of the discriminator field in the variant class, e.g.
@@ -640,8 +644,9 @@ class DiscriminatorDetails:
 
 
 def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
-    if isinstance(union, CachedDiscriminatorType):
-        return union.__discriminator__
+    cached = DISCRIMINATOR_CACHE.get(union)
+    if cached is not None:
+        return cached
 
     discriminator_field_name: str | None = None
 
@@ -694,7 +699,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
         discriminator_field=discriminator_field_name,
         discriminator_alias=discriminator_alias,
     )
-    cast(CachedDiscriminatorType, union).__discriminator__ = details
+    DISCRIMINATOR_CACHE.setdefault(union, details)
     return details
 
 
tests/test_models.py
@@ -9,7 +9,7 @@ from pydantic import Field
 
 from openai._utils import PropertyInfo
 from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
-from openai._models import BaseModel, construct_type
+from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
 
 
 class BasicModel(BaseModel):
@@ -809,7 +809,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
 
     UnionType = cast(Any, Union[A, B])
 
-    assert not hasattr(UnionType, "__discriminator__")
+    assert not DISCRIMINATOR_CACHE.get(UnionType)
 
     m = construct_type(
         value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -818,7 +818,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
     assert m.type == "b"
     assert m.data == "foo"  # type: ignore[comparison-overlap]
 
-    discriminator = UnionType.__discriminator__
+    discriminator = DISCRIMINATOR_CACHE.get(UnionType)
     assert discriminator is not None
 
     m = construct_type(
@@ -830,7 +830,7 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
 
     # if the discriminator details object stays the same between invocations then
     # we hit the cache
-    assert UnionType.__discriminator__ is discriminator
+    assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
 
 
 @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")