Commit bed71312

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2024-12-12 20:44:38
chore(internal): add support for TypeAliasType (#1942)
1 parent f32d466
src/openai/_utils/__init__.py
@@ -40,6 +40,7 @@ from ._typing import (
     is_iterable_type as is_iterable_type,
     is_required_type as is_required_type,
     is_annotated_type as is_annotated_type,
+    is_type_alias_type as is_type_alias_type,
     strip_annotated_type as strip_annotated_type,
     extract_type_var_from_base as extract_type_var_from_base,
 )
src/openai/_utils/_typing.py
@@ -1,8 +1,17 @@
 from __future__ import annotations
 
+import sys
+import typing
+import typing_extensions
 from typing import Any, TypeVar, Iterable, cast
 from collections import abc as _c_abc
-from typing_extensions import Required, Annotated, get_args, get_origin
+from typing_extensions import (
+    TypeIs,
+    Required,
+    Annotated,
+    get_args,
+    get_origin,
+)
 
 from .._types import InheritsGeneric
 from .._compat import is_union as _is_union
@@ -36,6 +45,26 @@ def is_typevar(typ: type) -> bool:
     return type(typ) == TypeVar  # type: ignore
 
 
+_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
+if sys.version_info >= (3, 12):
+    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
+
+
+def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
+    """Return whether the provided argument is an instance of `TypeAliasType`.
+
+    ```python
+    type Int = int
+    is_type_alias_type(Int)
+    # > True
+    Str = TypeAliasType("Str", str)
+    is_type_alias_type(Str)
+    # > True
+    ```
+    """
+    return isinstance(tp, _TYPE_ALIAS_TYPES)
+
+
 # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
 def strip_annotated_type(typ: type) -> type:
     if is_required_type(typ) or is_annotated_type(typ):
src/openai/_legacy_response.py
@@ -24,7 +24,7 @@ import httpx
 import pydantic
 
 from ._types import NoneType
-from ._utils import is_given, extract_type_arg, is_annotated_type
+from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type
 from ._models import BaseModel, is_basemodel, add_request_id
 from ._constants import RAW_RESPONSE_HEADER
 from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -195,9 +195,15 @@ class LegacyAPIResponse(Generic[R]):
         return self.http_response.elapsed
 
     def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+        cast_to = to if to is not None else self._cast_to
+
+        # unwrap `TypeAlias('Name', T)` -> `T`
+        if is_type_alias_type(cast_to):
+            cast_to = cast_to.__value__  # type: ignore[unreachable]
+
         # unwrap `Annotated[T, ...]` -> `T`
-        if to and is_annotated_type(to):
-            to = extract_type_arg(to, 0)
+        if cast_to and is_annotated_type(cast_to):
+            cast_to = extract_type_arg(cast_to, 0)
 
         if self._stream:
             if to:
@@ -233,18 +239,12 @@ class LegacyAPIResponse(Generic[R]):
             return cast(
                 R,
                 stream_cls(
-                    cast_to=self._cast_to,
+                    cast_to=cast_to,
                     response=self.http_response,
                     client=cast(Any, self._client),
                 ),
             )
 
-        cast_to = to if to is not None else self._cast_to
-
-        # unwrap `Annotated[T, ...]` -> `T`
-        if is_annotated_type(cast_to):
-            cast_to = extract_type_arg(cast_to, 0)
-
         if cast_to is NoneType:
             return cast(R, None)
 
src/openai/_models.py
@@ -47,6 +47,7 @@ from ._utils import (
     strip_not_given,
     extract_type_arg,
     is_annotated_type,
+    is_type_alias_type,
     strip_annotated_type,
 )
 from ._compat import (
@@ -453,6 +454,8 @@ def construct_type(*, value: object, type_: object) -> object:
     # we allow `object` as the input type because otherwise, passing things like
     # `Literal['value']` will be reported as a type error by type checkers
     type_ = cast("type[object]", type_)
+    if is_type_alias_type(type_):
+        type_ = type_.__value__  # type: ignore[unreachable]
 
     # unwrap `Annotated[T, ...]` -> `T`
     if is_annotated_type(type_):
src/openai/_response.py
@@ -25,7 +25,7 @@ import httpx
 import pydantic
 
 from ._types import NoneType
-from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
+from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
 from ._models import BaseModel, is_basemodel, add_request_id
 from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
 from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -126,9 +126,15 @@ class BaseAPIResponse(Generic[R]):
         )
 
     def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+        cast_to = to if to is not None else self._cast_to
+
+        # unwrap `TypeAlias('Name', T)` -> `T`
+        if is_type_alias_type(cast_to):
+            cast_to = cast_to.__value__  # type: ignore[unreachable]
+
         # unwrap `Annotated[T, ...]` -> `T`
-        if to and is_annotated_type(to):
-            to = extract_type_arg(to, 0)
+        if cast_to and is_annotated_type(cast_to):
+            cast_to = extract_type_arg(cast_to, 0)
 
         if self._is_sse_stream:
             if to:
@@ -164,18 +170,12 @@ class BaseAPIResponse(Generic[R]):
             return cast(
                 R,
                 stream_cls(
-                    cast_to=self._cast_to,
+                    cast_to=cast_to,
                     response=self.http_response,
                     client=cast(Any, self._client),
                 ),
             )
 
-        cast_to = to if to is not None else self._cast_to
-
-        # unwrap `Annotated[T, ...]` -> `T`
-        if is_annotated_type(cast_to):
-            cast_to = extract_type_arg(cast_to, 0)
-
         if cast_to is NoneType:
             return cast(R, None)
 
tests/test_models.py
@@ -1,7 +1,7 @@
 import json
 from typing import Any, Dict, List, Union, Optional, cast
 from datetime import datetime, timezone
-from typing_extensions import Literal, Annotated
+from typing_extensions import Literal, Annotated, TypeAliasType
 
 import pytest
 import pydantic
@@ -828,3 +828,19 @@ def test_discriminated_unions_invalid_data_uses_cache() -> None:
     # if the discriminator details object stays the same between invocations then
     # we hit the cache
     assert UnionType.__discriminator__ is discriminator
+
+
+@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
+def test_type_alias_type() -> None:
+    Alias = TypeAliasType("Alias", str)
+
+    class Model(BaseModel):
+        alias: Alias
+        union: Union[int, Alias]
+
+    m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model)
+    assert isinstance(m, Model)
+    assert isinstance(m.alias, str)
+    assert m.alias == "foo"
+    assert isinstance(m.union, str)
+    assert m.union == "bar"
tests/utils.py
@@ -19,6 +19,7 @@ from openai._utils import (
     is_union_type,
     extract_type_arg,
     is_annotated_type,
+    is_type_alias_type,
 )
 from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
 from openai._models import BaseModel
@@ -58,6 +59,9 @@ def assert_matches_type(
     path: list[str],
     allow_none: bool = False,
 ) -> None:
+    if is_type_alias_type(type_):
+        type_ = type_.__value__
+
     # unwrap `Annotated[T, ...]` -> `T`
     if is_annotated_type(type_):
         type_ = extract_type_arg(type_, 0)