Commit 49cab6f2
Changed files (7)
src/openai/_legacy_response.py
@@ -13,7 +13,7 @@ import httpx
import pydantic
from ._types import NoneType
-from ._utils import is_given
+from ._utils import is_given, extract_type_arg, is_annotated_type
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -174,6 +174,10 @@ class LegacyAPIResponse(Generic[R]):
return self.http_response.elapsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+ # unwrap `Annotated[T, ...]` -> `T`
+ if to and is_annotated_type(to):
+ to = extract_type_arg(to, 0)
+
if self._stream:
if to:
if not is_stream_class_type(to):
@@ -215,6 +219,11 @@ class LegacyAPIResponse(Generic[R]):
)
cast_to = to if to is not None else self._cast_to
+
+ # unwrap `Annotated[T, ...]` -> `T`
+ if is_annotated_type(cast_to):
+ cast_to = extract_type_arg(cast_to, 0)
+
if cast_to is NoneType:
return cast(R, None)
src/openai/_models.py
@@ -30,7 +30,16 @@ from ._types import (
AnyMapping,
HttpxRequestFiles,
)
-from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given
+from ._utils import (
+ is_list,
+ is_given,
+ is_mapping,
+ parse_date,
+ parse_datetime,
+ strip_not_given,
+ extract_type_arg,
+ is_annotated_type,
+)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
@@ -275,6 +284,9 @@ def construct_type(*, value: object, type_: type) -> object:
If the given value does not match the expected type then it is returned as-is.
"""
+ # unwrap `Annotated[T, ...]` -> `T`
+ if is_annotated_type(type_):
+ type_ = extract_type_arg(type_, 0)
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
src/openai/_response.py
@@ -25,7 +25,7 @@ import httpx
import pydantic
from ._types import NoneType
-from ._utils import is_given, extract_type_var_from_base
+from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -121,6 +121,10 @@ class BaseAPIResponse(Generic[R]):
)
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+ # unwrap `Annotated[T, ...]` -> `T`
+ if to and is_annotated_type(to):
+ to = extract_type_arg(to, 0)
+
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
@@ -162,6 +166,11 @@ class BaseAPIResponse(Generic[R]):
)
cast_to = to if to is not None else self._cast_to
+
+ # unwrap `Annotated[T, ...]` -> `T`
+ if is_annotated_type(cast_to):
+ cast_to = extract_type_arg(cast_to, 0)
+
if cast_to is NoneType:
return cast(R, None)
tests/test_legacy_response.py
@@ -1,4 +1,6 @@
import json
+from typing import cast
+from typing_extensions import Annotated
import httpx
import pytest
@@ -63,3 +65,20 @@ def test_response_parse_custom_model(client: OpenAI) -> None:
obj = response.parse(to=CustomModel)
assert obj.foo == "hello!"
assert obj.bar == 2
+
+
+def test_response_parse_annotated_type(client: OpenAI) -> None:
+ response = LegacyAPIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = response.parse(
+ to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+ )
+ assert obj.foo == "hello!"
+ assert obj.bar == 2
tests/test_models.py
@@ -1,14 +1,14 @@
import json
from typing import Any, Dict, List, Union, Optional, cast
from datetime import datetime, timezone
-from typing_extensions import Literal
+from typing_extensions import Literal, Annotated
import pytest
import pydantic
from pydantic import Field
from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
-from openai._models import BaseModel
+from openai._models import BaseModel, construct_type
class BasicModel(BaseModel):
@@ -571,3 +571,15 @@ def test_type_compat() -> None:
foo: Optional[str] = None
takes_pydantic(OurModel())
+
+
+def test_annotated_types() -> None:
+ class Model(BaseModel):
+ value: str
+
+ m = construct_type(
+ value={"value": "foo"},
+ type_=cast(Any, Annotated[Model, "random metadata"]),
+ )
+ assert isinstance(m, Model)
+ assert m.value == "foo"
tests/test_response.py
@@ -1,5 +1,6 @@
import json
-from typing import List
+from typing import List, cast
+from typing_extensions import Annotated
import httpx
import pytest
@@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> N
obj = await response.parse(to=CustomModel)
assert obj.foo == "hello!"
assert obj.bar == 2
+
+
+def test_response_parse_annotated_type(client: OpenAI) -> None:
+ response = APIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = response.parse(
+ to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+ )
+ assert obj.foo == "hello!"
+ assert obj.bar == 2
+
+
+async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None:
+ response = AsyncAPIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=async_client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = await response.parse(
+ to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
+ )
+ assert obj.foo == "hello!"
+ assert obj.bar == 2
tests/utils.py
@@ -14,6 +14,8 @@ from openai._utils import (
is_list,
is_list_type,
is_union_type,
+ extract_type_arg,
+ is_annotated_type,
)
from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields
from openai._models import BaseModel
@@ -49,6 +51,10 @@ def assert_matches_type(
path: list[str],
allow_none: bool = False,
) -> None:
+ # unwrap `Annotated[T, ...]` -> `T`
+ if is_annotated_type(type_):
+ type_ = extract_type_arg(type_, 0)
+
if allow_none and value is None:
return