Commit 49cab6f2

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-03-08 22:23:35
chore(internal): support parsing Annotated types (#1222)
1 parent 6147fbe
src/openai/_legacy_response.py
@@ -13,7 +13,7 @@ import httpx
 import pydantic
 
 from ._types import NoneType
-from ._utils import is_given
+from ._utils import is_given, extract_type_arg, is_annotated_type
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER
 from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -174,6 +174,10 @@ class LegacyAPIResponse(Generic[R]):
         return self.http_response.elapsed
 
     def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+        # unwrap `Annotated[T, ...]` -> `T`
+        if to and is_annotated_type(to):
+            to = extract_type_arg(to, 0)
+
         if self._stream:
             if to:
                 if not is_stream_class_type(to):
@@ -215,6 +219,11 @@ class LegacyAPIResponse(Generic[R]):
             )
 
         cast_to = to if to is not None else self._cast_to
+
+        # unwrap `Annotated[T, ...]` -> `T`
+        if is_annotated_type(cast_to):
+            cast_to = extract_type_arg(cast_to, 0)
+
         if cast_to is NoneType:
             return cast(R, None)
 
src/openai/_models.py
@@ -30,7 +30,16 @@ from ._types import (
     AnyMapping,
     HttpxRequestFiles,
 )
-from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given
+from ._utils import (
+    is_list,
+    is_given,
+    is_mapping,
+    parse_date,
+    parse_datetime,
+    strip_not_given,
+    extract_type_arg,
+    is_annotated_type,
+)
 from ._compat import (
     PYDANTIC_V2,
     ConfigDict,
@@ -275,6 +284,9 @@ def construct_type(*, value: object, type_: type) -> object:
 
     If the given value does not match the expected type then it is returned as-is.
     """
+    # unwrap `Annotated[T, ...]` -> `T`
+    if is_annotated_type(type_):
+        type_ = extract_type_arg(type_, 0)
 
     # we need to use the origin class for any types that are subscripted generics
     # e.g. Dict[str, object]
src/openai/_response.py
@@ -25,7 +25,7 @@ import httpx
 import pydantic
 
 from ._types import NoneType
-from ._utils import is_given, extract_type_var_from_base
+from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
 from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -121,6 +121,10 @@ class BaseAPIResponse(Generic[R]):
         )
 
     def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+        # unwrap `Annotated[T, ...]` -> `T`
+        if to and is_annotated_type(to):
+            to = extract_type_arg(to, 0)
+
         if self._is_sse_stream:
             if to:
                 if not is_stream_class_type(to):
@@ -162,6 +166,11 @@ class BaseAPIResponse(Generic[R]):
             )
 
         cast_to = to if to is not None else self._cast_to
+
+        # unwrap `Annotated[T, ...]` -> `T`
+        if is_annotated_type(cast_to):
+            cast_to = extract_type_arg(cast_to, 0)
+
         if cast_to is NoneType:
             return cast(R, None)
 
tests/test_legacy_response.py
@@ -1,4 +1,6 @@
 import json
+from typing import cast
+from typing_extensions import Annotated
 
 import httpx
 import pytest
@@ -63,3 +65,20 @@ def test_response_parse_custom_model(client: OpenAI) -> None:
     obj = response.parse(to=CustomModel)
     assert obj.foo == "hello!"
     assert obj.bar == 2
+
+
+def test_response_parse_annotated_type(client: OpenAI) -> None:
+    response = LegacyAPIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = response.parse(
+        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+    )
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
tests/test_models.py
@@ -1,14 +1,14 @@
 import json
 from typing import Any, Dict, List, Union, Optional, cast
 from datetime import datetime, timezone
-from typing_extensions import Literal
+from typing_extensions import Literal, Annotated
 
 import pytest
 import pydantic
 from pydantic import Field
 
 from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
-from openai._models import BaseModel
+from openai._models import BaseModel, construct_type
 
 
 class BasicModel(BaseModel):
@@ -571,3 +571,15 @@ def test_type_compat() -> None:
         foo: Optional[str] = None
 
     takes_pydantic(OurModel())
+
+
+def test_annotated_types() -> None:
+    class Model(BaseModel):
+        value: str
+
+    m = construct_type(
+        value={"value": "foo"},
+        type_=cast(Any, Annotated[Model, "random metadata"]),
+    )
+    assert isinstance(m, Model)
+    assert m.value == "foo"
tests/test_response.py
@@ -1,5 +1,6 @@
 import json
-from typing import List
+from typing import List, cast
+from typing_extensions import Annotated
 
 import httpx
 import pytest
@@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> N
     obj = await response.parse(to=CustomModel)
     assert obj.foo == "hello!"
     assert obj.bar == 2
+
+
+def test_response_parse_annotated_type(client: OpenAI) -> None:
+    response = APIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = response.parse(
+        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+    )
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
+
+
+async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None:
+    response = AsyncAPIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=async_client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = await response.parse(
+        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+    )
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
tests/utils.py
@@ -14,6 +14,8 @@ from openai._utils import (
     is_list,
     is_list_type,
     is_union_type,
+    extract_type_arg,
+    is_annotated_type,
 )
 from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
 from openai._models import BaseModel
@@ -49,6 +51,10 @@ def assert_matches_type(
     path: list[str],
     allow_none: bool = False,
 ) -> None:
+    # unwrap `Annotated[T, ...]` -> `T`
+    if is_annotated_type(type_):
+        type_ = extract_type_arg(type_, 0)
+
     if allow_none and value is None:
         return