Commit 61f3346b

Author: Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
Date:   2024-01-29 20:05:32
Parent: 63de8ef

feat(client): support parsing custom response types (#1111)
src/openai/_utils/_typing.py
@@ -45,7 +45,13 @@ def extract_type_arg(typ: type, index: int) -> type:
         raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
 
 
-def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
+def extract_type_var_from_base(
+    typ: type,
+    *,
+    generic_bases: tuple[type, ...],
+    index: int,
+    failure_message: str | None = None,
+) -> type:
     """Given a type like `Foo[T]`, returns the generic type variable `T`.
 
     This also handles the case where a concrete subclass is given, e.g.
@@ -104,4 +110,4 @@ def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], in
 
         return extracted
 
-    raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
+    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
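
The new `failure_message` parameter lets a call site supply a domain-specific error instead of the generic "could not resolve" message. A minimal sketch of a caller, assuming this commit is applied (the `MyStream` subclass is illustrative and the private-module imports are shown only for demonstration):

```py
from typing import cast

from openai._streaming import Stream
from openai._utils import extract_type_var_from_base


class MyStream(Stream[bytes]):  # illustrative concrete subclass
    ...


# Resolves the `T` in `Stream[T]`; on failure the custom message is raised.
chunk_type = extract_type_var_from_base(
    MyStream,
    index=0,
    generic_bases=cast("tuple[type, ...]", (Stream,)),
    failure_message="Expected a parameterised stream type, e.g. Stream[bytes]",
)
assert chunk_type is bytes
```
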
src/openai/__init__.py
@@ -9,6 +9,7 @@ from . import types
 from ._types import NoneType, Transport, ProxiesTypes
 from ._utils import file_from_path
 from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
+from ._models import BaseModel
 from ._version import __title__, __version__
 from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
 from ._exceptions import (
@@ -59,6 +60,7 @@ __all__ = [
     "OpenAI",
     "AsyncOpenAI",
     "file_from_path",
+    "BaseModel",
 ]
 
 from .lib import azure as _azure
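
Exporting `BaseModel` from the package root lets users define custom parse targets without importing from private modules, and it pairs with the runtime check added further down that rejects plain `pydantic.BaseModel` subclasses. A short sketch of the distinction:

```py
import pydantic

from openai import BaseModel


class Good(BaseModel):  # accepted by response.parse(to=...)
    foo: str


class Bad(pydantic.BaseModel):  # rejected at parse time with a TypeError
    foo: str
```
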
src/openai/_legacy_response.py
@@ -5,25 +5,28 @@ import inspect
 import logging
 import datetime
 import functools
-from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast
-from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin
+from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
+from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
 
 import anyio
 import httpx
+import pydantic
 
 from ._types import NoneType
 from ._utils import is_given
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER
+from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
 from ._exceptions import APIResponseValidationError
 
 if TYPE_CHECKING:
     from ._models import FinalRequestOptions
-    from ._base_client import Stream, BaseClient, AsyncStream
+    from ._base_client import BaseClient
 
 
 P = ParamSpec("P")
 R = TypeVar("R")
+_T = TypeVar("_T")
 
 log: logging.Logger = logging.getLogger(__name__)
 
@@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):
 
     _cast_to: type[R]
     _client: BaseClient[Any, Any]
-    _parsed: R | None
+    _parsed_by_type: dict[type[Any], Any]
     _stream: bool
     _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
     _options: FinalRequestOptions
@@ -62,27 +65,60 @@ class LegacyAPIResponse(Generic[R]):
     ) -> None:
         self._cast_to = cast_to
         self._client = client
-        self._parsed = None
+        self._parsed_by_type = {}
         self._stream = stream
         self._stream_cls = stream_cls
         self._options = options
         self.http_response = raw
 
+    @overload
+    def parse(self, *, to: type[_T]) -> _T:
+        ...
+
+    @overload
     def parse(self) -> R:
+        ...
+
+    def parse(self, *, to: type[_T] | None = None) -> R | _T:
         """Returns the rich python representation of this response's data.
 
+        NOTE: For the async client: this will become a coroutine in the next major version.
+
         For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
 
-        NOTE: For the async client: this will become a coroutine in the next major version.
+        You can customise the type that the response is parsed into through
+        the `to` argument, e.g.
+
+        ```py
+        from openai import BaseModel
+
+
+        class MyModel(BaseModel):
+            foo: str
+
+
+        obj = response.parse(to=MyModel)
+        print(obj.foo)
+        ```
+
+        We support parsing:
+          - `BaseModel`
+          - `dict`
+          - `list`
+          - `Union`
+          - `str`
+          - `httpx.Response`
         """
-        if self._parsed is not None:
-            return self._parsed
+        cache_key = to if to is not None else self._cast_to
+        cached = self._parsed_by_type.get(cache_key)
+        if cached is not None:
+            return cached  # type: ignore[no-any-return]
 
-        parsed = self._parse()
+        parsed = self._parse(to=to)
         if is_given(self._options.post_parser):
             parsed = self._options.post_parser(parsed)
 
-        self._parsed = parsed
+        self._parsed_by_type[cache_key] = parsed
         return parsed
 
     @property
@@ -135,13 +171,29 @@ class LegacyAPIResponse(Generic[R]):
         """The time taken for the complete request/response cycle to complete."""
         return self.http_response.elapsed
 
-    def _parse(self) -> R:
+    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
         if self._stream:
+            if to:
+                if not is_stream_class_type(to):
+                    raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
+
+                return cast(
+                    _T,
+                    to(
+                        cast_to=extract_stream_chunk_type(
+                            to,
+                            failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
+                        ),
+                        response=self.http_response,
+                        client=cast(Any, self._client),
+                    ),
+                )
+
             if self._stream_cls:
                 return cast(
                     R,
                     self._stream_cls(
-                        cast_to=_extract_stream_chunk_type(self._stream_cls),
+                        cast_to=extract_stream_chunk_type(self._stream_cls),
                         response=self.http_response,
                         client=cast(Any, self._client),
                     ),
@@ -160,7 +212,7 @@ class LegacyAPIResponse(Generic[R]):
                 ),
             )
 
-        cast_to = self._cast_to
+        cast_to = to if to is not None else self._cast_to
         if cast_to is NoneType:
             return cast(R, None)
 
@@ -186,14 +238,9 @@ class LegacyAPIResponse(Generic[R]):
                 raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
             return cast(R, response)
 
-        # The check here is necessary as we are subverting the the type system
-        # with casts as the relationship between TypeVars and Types are very strict
-        # which means we must return *exactly* what was input or transform it in a
-        # way that retains the TypeVar state. As we cannot do that in this function
-        # then we have to resort to using `cast`. At the time of writing, we know this
-        # to be safe as we have handled all the types that could be bound to the
-        # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
-        # this function would become unsafe but a type checker would not report an error.
+        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
+            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
+
         if (
             cast_to is not object
             and not origin is list
@@ -202,12 +249,12 @@ class LegacyAPIResponse(Generic[R]):
             and not issubclass(origin, BaseModel)
         ):
             raise RuntimeError(
-                f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
+                f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
             )
 
         # split is required to handle cases where additional information is included
         # in the response, e.g. application/json; charset=utf-8
-        content_type, *_ = response.headers.get("content-type").split(";")
+        content_type, *_ = response.headers.get("content-type", "*").split(";")
         if content_type != "application/json":
             if is_basemodel(cast_to):
                 try:
@@ -253,15 +300,6 @@ class MissingStreamClassError(TypeError):
         )
 
 
-def _extract_stream_chunk_type(stream_cls: type) -> type:
-    args = get_args(stream_cls)
-    if not args:
-        raise TypeError(
-            f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
-        )
-    return cast(type, args[0])
-
-
 def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
     """Higher order function that takes one of our bound API methods and wraps it
     to support returning the raw `APIResponse` object directly.
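
Since `parse()` can now be called repeatedly with different `to=` targets, the single `_parsed` slot becomes a dict keyed on the target type. The pattern in isolation looks roughly like this (a standalone sketch, not SDK code; `PerTypeCache` is hypothetical):

```py
from __future__ import annotations

from typing import Any, TypeVar

_T = TypeVar("_T")


class PerTypeCache:
    def __init__(self) -> None:
        self._parsed_by_type: dict[type[Any], Any] = {}

    def parse(self, *, to: type[_T], data: Any) -> _T:
        cached = self._parsed_by_type.get(to)
        if cached is not None:
            return cached  # each target type keeps its own cached result
        parsed = to(data)  # stand-in for the real parsing logic
        self._parsed_by_type[to] = parsed
        return parsed


cache = PerTypeCache()
assert cache.parse(to=int, data="1") == 1
assert cache.parse(to=str, data="1") == "1"  # distinct entry from the int result
```

One subtlety carried over from the diff: a cached value of `None` is indistinguishable from a cache miss under `dict.get`, so parsing to `NoneType` simply re-runs `_parse` on every call.
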
src/openai/_response.py
@@ -16,25 +16,29 @@ from typing import (
     Iterator,
     AsyncIterator,
     cast,
+    overload,
 )
 from typing_extensions import Awaitable, ParamSpec, override, get_origin
 
 import anyio
 import httpx
+import pydantic
 
 from ._types import NoneType
 from ._utils import is_given, extract_type_var_from_base
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
+from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
 from ._exceptions import OpenAIError, APIResponseValidationError
 
 if TYPE_CHECKING:
     from ._models import FinalRequestOptions
-    from ._base_client import Stream, BaseClient, AsyncStream
+    from ._base_client import BaseClient
 
 
 P = ParamSpec("P")
 R = TypeVar("R")
+_T = TypeVar("_T")
 _APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
 _AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
 
@@ -44,7 +48,7 @@ log: logging.Logger = logging.getLogger(__name__)
 class BaseAPIResponse(Generic[R]):
     _cast_to: type[R]
     _client: BaseClient[Any, Any]
-    _parsed: R | None
+    _parsed_by_type: dict[type[Any], Any]
     _is_sse_stream: bool
     _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
     _options: FinalRequestOptions
@@ -63,7 +67,7 @@ class BaseAPIResponse(Generic[R]):
     ) -> None:
         self._cast_to = cast_to
         self._client = client
-        self._parsed = None
+        self._parsed_by_type = {}
         self._is_sse_stream = stream
         self._stream_cls = stream_cls
         self._options = options
@@ -116,8 +120,24 @@ class BaseAPIResponse(Generic[R]):
             f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
         )
 
-    def _parse(self) -> R:
+    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
         if self._is_sse_stream:
+            if to:
+                if not is_stream_class_type(to):
+                    raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
+
+                return cast(
+                    _T,
+                    to(
+                        cast_to=extract_stream_chunk_type(
+                            to,
+                            failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
+                        ),
+                        response=self.http_response,
+                        client=cast(Any, self._client),
+                    ),
+                )
+
             if self._stream_cls:
                 return cast(
                     R,
@@ -141,7 +161,7 @@ class BaseAPIResponse(Generic[R]):
                 ),
             )
 
-        cast_to = self._cast_to
+        cast_to = to if to is not None else self._cast_to
         if cast_to is NoneType:
             return cast(R, None)
 
@@ -171,14 +191,9 @@ class BaseAPIResponse(Generic[R]):
                 raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
             return cast(R, response)
 
-        # The check here is necessary as we are subverting the the type system
-        # with casts as the relationship between TypeVars and Types are very strict
-        # which means we must return *exactly* what was input or transform it in a
-        # way that retains the TypeVar state. As we cannot do that in this function
-        # then we have to resort to using `cast`. At the time of writing, we know this
-        # to be safe as we have handled all the types that could be bound to the
-        # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
-        # this function would become unsafe but a type checker would not report an error.
+        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
+            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
+
         if (
             cast_to is not object
             and not origin is list
@@ -187,12 +202,12 @@ class BaseAPIResponse(Generic[R]):
             and not issubclass(origin, BaseModel)
         ):
             raise RuntimeError(
-                f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
+                f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
             )
 
         # split is required to handle cases where additional information is included
         # in the response, e.g. application/json; charset=utf-8
-        content_type, *_ = response.headers.get("content-type").split(";")
+        content_type, *_ = response.headers.get("content-type", "*").split(";")
         if content_type != "application/json":
             if is_basemodel(cast_to):
                 try:
@@ -228,22 +243,55 @@ class BaseAPIResponse(Generic[R]):
 
 
 class APIResponse(BaseAPIResponse[R]):
+    @overload
+    def parse(self, *, to: type[_T]) -> _T:
+        ...
+
+    @overload
     def parse(self) -> R:
+        ...
+
+    def parse(self, *, to: type[_T] | None = None) -> R | _T:
         """Returns the rich python representation of this response's data.
 
         For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
+
+        You can customise the type that the response is parsed into through
+        the `to` argument, e.g.
+
+        ```py
+        from openai import BaseModel
+
+
+        class MyModel(BaseModel):
+            foo: str
+
+
+        obj = response.parse(to=MyModel)
+        print(obj.foo)
+        ```
+
+        We support parsing:
+          - `BaseModel`
+          - `dict`
+          - `list`
+          - `Union`
+          - `str`
+          - `httpx.Response`
         """
-        if self._parsed is not None:
-            return self._parsed
+        cache_key = to if to is not None else self._cast_to
+        cached = self._parsed_by_type.get(cache_key)
+        if cached is not None:
+            return cached  # type: ignore[no-any-return]
 
         if not self._is_sse_stream:
             self.read()
 
-        parsed = self._parse()
+        parsed = self._parse(to=to)
         if is_given(self._options.post_parser):
             parsed = self._options.post_parser(parsed)
 
-        self._parsed = parsed
+        self._parsed_by_type[cache_key] = parsed
         return parsed
 
     def read(self) -> bytes:
@@ -297,22 +345,55 @@ class APIResponse(BaseAPIResponse[R]):
 
 
 class AsyncAPIResponse(BaseAPIResponse[R]):
+    @overload
+    async def parse(self, *, to: type[_T]) -> _T:
+        ...
+
+    @overload
     async def parse(self) -> R:
+        ...
+
+    async def parse(self, *, to: type[_T] | None = None) -> R | _T:
         """Returns the rich python representation of this response's data.
 
         For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
+
+        You can customise the type that the response is parsed into through
+        the `to` argument, e.g.
+
+        ```py
+        from openai import BaseModel
+
+
+        class MyModel(BaseModel):
+            foo: str
+
+
+        obj = response.parse(to=MyModel)
+        print(obj.foo)
+        ```
+
+        We support parsing:
+          - `BaseModel`
+          - `dict`
+          - `list`
+          - `Union`
+          - `str`
+          - `httpx.Response`
         """
-        if self._parsed is not None:
-            return self._parsed
+        cache_key = to if to is not None else self._cast_to
+        cached = self._parsed_by_type.get(cache_key)
+        if cached is not None:
+            return cached  # type: ignore[no-any-return]
 
         if not self._is_sse_stream:
             await self.read()
 
-        parsed = self._parse()
+        parsed = self._parse(to=to)
         if is_given(self._options.post_parser):
             parsed = self._options.post_parser(parsed)
 
-        self._parsed = parsed
+        self._parsed_by_type[cache_key] = parsed
         return parsed
 
     async def read(self) -> bytes:
@@ -708,26 +789,6 @@ def async_to_custom_raw_response_wrapper(
     return wrapped
 
 
-def extract_stream_chunk_type(stream_cls: type) -> type:
-    """Given a type like `Stream[T]`, returns the generic type variable `T`.
-
-    This also handles the case where a concrete subclass is given, e.g.
-    ```py
-    class MyStream(Stream[bytes]):
-        ...
-
-    extract_stream_chunk_type(MyStream) -> bytes
-    ```
-    """
-    from ._base_client import Stream, AsyncStream
-
-    return extract_type_var_from_base(
-        stream_cls,
-        index=0,
-        generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
-    )
-
-
 def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
     """Given a type like `APIResponse[T]`, returns the generic type variable `T`.
 
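
The `@overload` pair on `parse` is what gives callers precise static types: with no argument the response's own `R` is returned, while `to=SomeType` narrows the result to that type. A sketch of how a type checker resolves the two forms (`MyModel` and `handle` are illustrative):

```py
from openai import APIResponse, BaseModel


class MyModel(BaseModel):
    foo: str


def handle(response: APIResponse[str]) -> None:
    a = response.parse()            # matches the bare overload: inferred as str
    b = response.parse(to=MyModel)  # matches the to= overload: inferred as MyModel
    print(len(a), b.foo)
```
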
src/openai/_streaming.py
@@ -2,13 +2,14 @@
 from __future__ import annotations
 
 import json
+import inspect
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
-from typing_extensions import Self, override
+from typing_extensions import Self, TypeGuard, override, get_origin
 
 import httpx
 
-from ._utils import is_mapping
+from ._utils import is_mapping, extract_type_var_from_base
 from ._exceptions import APIError
 
 if TYPE_CHECKING:
@@ -281,3 +282,34 @@ class SSEDecoder:
             pass  # Field is ignored.
 
         return None
+
+
+def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
+    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
+    origin = get_origin(typ) or typ
+    return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
+
+
+def extract_stream_chunk_type(
+    stream_cls: type,
+    *,
+    failure_message: str | None = None,
+) -> type:
+    """Given a type like `Stream[T]`, returns the generic type variable `T`.
+
+    This also handles the case where a concrete subclass is given, e.g.
+    ```py
+    class MyStream(Stream[bytes]):
+        ...
+
+    extract_stream_chunk_type(MyStream) -> bytes
+    ```
+    """
+    from ._base_client import Stream, AsyncStream
+
+    return extract_type_var_from_base(
+        stream_cls,
+        index=0,
+        generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
+        failure_message=failure_message,
+    )
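
Both helpers accept either a parameterised alias or a concrete subclass, which is what the `parse(to=Stream[int])` tests below rely on. A quick check of the expected behaviour, assuming this commit is applied:

```py
from openai._streaming import (
    AsyncStream,
    Stream,
    extract_stream_chunk_type,
    is_stream_class_type,
)

assert is_stream_class_type(Stream[int])         # get_origin() resolves the alias
assert is_stream_class_type(AsyncStream[bytes])
assert not is_stream_class_type(list)            # unrelated classes are rejected


class MyStream(Stream[bytes]):  # concrete subclass, as in the docstring above
    ...


assert extract_stream_chunk_type(Stream[int]) is int
assert extract_stream_chunk_type(MyStream) is bytes
```
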
tests/test_legacy_response.py
@@ -0,0 +1,65 @@
+import json
+
+import httpx
+import pytest
+import pydantic
+
+from openai import OpenAI, BaseModel
+from openai._streaming import Stream
+from openai._base_client import FinalRequestOptions
+from openai._legacy_response import LegacyAPIResponse
+
+
+class PydanticModel(pydantic.BaseModel):
+    ...
+
+
+def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
+    response = LegacyAPIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    with pytest.raises(
+        TypeError,
+        match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+    ):
+        response.parse(to=PydanticModel)
+
+
+def test_response_parse_custom_stream(client: OpenAI) -> None:
+    response = LegacyAPIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=client,
+        stream=True,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    stream = response.parse(to=Stream[int])
+    assert stream._cast_to == int
+
+
+class CustomModel(BaseModel):
+    foo: str
+    bar: int
+
+
+def test_response_parse_custom_model(client: OpenAI) -> None:
+    response = LegacyAPIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = response.parse(to=CustomModel)
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
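
These tests assume `client` / `async_client` pytest fixtures from the suite's conftest. A minimal stand-in would look something like this (the base URL and key are placeholders; the real fixtures may differ):

```py
import pytest

from openai import AsyncOpenAI, OpenAI


@pytest.fixture
def client() -> OpenAI:
    return OpenAI(base_url="http://127.0.0.1:4010", api_key="placeholder")


@pytest.fixture
def async_client() -> AsyncOpenAI:
    return AsyncOpenAI(base_url="http://127.0.0.1:4010", api_key="placeholder")
```
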
tests/test_response.py
@@ -1,8 +1,11 @@
+import json
 from typing import List
 
 import httpx
 import pytest
+import pydantic
 
+from openai import OpenAI, BaseModel, AsyncOpenAI
 from openai._response import (
     APIResponse,
     BaseAPIResponse,
@@ -11,6 +14,8 @@ from openai._response import (
     AsyncBinaryAPIResponse,
     extract_response_type,
 )
+from openai._streaming import Stream
+from openai._base_client import FinalRequestOptions
 
 
 class ConcreteBaseAPIResponse(APIResponse[bytes]):
@@ -48,3 +53,107 @@ def test_extract_response_type_concrete_subclasses() -> None:
 def test_extract_response_type_binary_response() -> None:
     assert extract_response_type(BinaryAPIResponse) == bytes
     assert extract_response_type(AsyncBinaryAPIResponse) == bytes
+
+
+class PydanticModel(pydantic.BaseModel):
+    ...
+
+
+def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
+    response = APIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    with pytest.raises(
+        TypeError,
+        match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+    ):
+        response.parse(to=PydanticModel)
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None:
+    response = AsyncAPIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=async_client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    with pytest.raises(
+        TypeError,
+        match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+    ):
+        await response.parse(to=PydanticModel)
+
+
+def test_response_parse_custom_stream(client: OpenAI) -> None:
+    response = APIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=client,
+        stream=True,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    stream = response.parse(to=Stream[int])
+    assert stream._cast_to == int
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None:
+    response = AsyncAPIResponse(
+        raw=httpx.Response(200, content=b"foo"),
+        client=async_client,
+        stream=True,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    stream = await response.parse(to=Stream[int])
+    assert stream._cast_to == int
+
+
+class CustomModel(BaseModel):
+    foo: str
+    bar: int
+
+
+def test_response_parse_custom_model(client: OpenAI) -> None:
+    response = APIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = response.parse(to=CustomModel)
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None:
+    response = AsyncAPIResponse(
+        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+        client=async_client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    obj = await response.parse(to=CustomModel)
+    assert obj.foo == "hello!"
+    assert obj.bar == 2
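
Putting the feature together end to end: a raw response from any endpoint can now be parsed into a user-defined model. A hedged sketch, assuming the `with_raw_response` surface available in this era of the SDK (the endpoint and the fields on `CompletionSummary` are illustrative):

```py
from openai import BaseModel, OpenAI


class CompletionSummary(BaseModel):  # illustrative subset of the response fields
    id: str
    model: str


client = OpenAI()
response = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
)
summary = response.parse(to=CompletionSummary)  # parse into the custom model
print(summary.id, summary.model)
```
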