Commit 61f3346b
Changed files (7)
src/openai/_utils/_typing.py
@@ -45,7 +45,13 @@ def extract_type_arg(typ: type, index: int) -> type:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
-def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
+def extract_type_var_from_base(
+ typ: type,
+ *,
+ generic_bases: tuple[type, ...],
+ index: int,
+ failure_message: str | None = None,
+) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
@@ -104,4 +110,4 @@ def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], in
return extracted
- raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
+ raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
src/openai/__init__.py
@@ -9,6 +9,7 @@ from . import types
from ._types import NoneType, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
+from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._exceptions import (
@@ -59,6 +60,7 @@ __all__ = [
"OpenAI",
"AsyncOpenAI",
"file_from_path",
+ "BaseModel",
]
from .lib import azure as _azure
src/openai/_legacy_response.py
@@ -5,25 +5,28 @@ import inspect
import logging
import datetime
import functools
-from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast
-from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin
+from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
+from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
import anyio
import httpx
+import pydantic
from ._types import NoneType
from ._utils import is_given
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
+from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
- from ._base_client import Stream, BaseClient, AsyncStream
+ from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
+_T = TypeVar("_T")
log: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):
_cast_to: type[R]
_client: BaseClient[Any, Any]
- _parsed: R | None
+ _parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
@@ -62,27 +65,60 @@ class LegacyAPIResponse(Generic[R]):
) -> None:
self._cast_to = cast_to
self._client = client
- self._parsed = None
+ self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
+ @overload
+ def parse(self, *, to: type[_T]) -> _T:
+ ...
+
+ @overload
def parse(self) -> R:
+ ...
+
+ def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
+ NOTE: For the async client: this will become a coroutine in the next major version.
+
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
- NOTE: For the async client: this will become a coroutine in the next major version.
+ You can customise the type that the response is parsed into through
+ the `to` argument, e.g.
+
+ ```py
+ from openai import BaseModel
+
+
+ class MyModel(BaseModel):
+ foo: str
+
+
+ obj = response.parse(to=MyModel)
+ print(obj.foo)
+ ```
+
+ We support parsing:
+ - `BaseModel`
+ - `dict`
+ - `list`
+ - `Union`
+ - `str`
+ - `httpx.Response`
"""
- if self._parsed is not None:
- return self._parsed
+ cache_key = to if to is not None else self._cast_to
+ cached = self._parsed_by_type.get(cache_key)
+ if cached is not None:
+ return cached # type: ignore[no-any-return]
- parsed = self._parse()
+ parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
- self._parsed = parsed
+ self._parsed_by_type[cache_key] = parsed
return parsed
@property
@@ -135,13 +171,29 @@ class LegacyAPIResponse(Generic[R]):
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
- def _parse(self) -> R:
+ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if self._stream:
+ if to:
+ if not is_stream_class_type(to):
+ raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
+
+ return cast(
+ _T,
+ to(
+ cast_to=extract_stream_chunk_type(
+ to,
+ failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
+ ),
+ response=self.http_response,
+ client=cast(Any, self._client),
+ ),
+ )
+
if self._stream_cls:
return cast(
R,
self._stream_cls(
- cast_to=_extract_stream_chunk_type(self._stream_cls),
+ cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
@@ -160,7 +212,7 @@ class LegacyAPIResponse(Generic[R]):
),
)
- cast_to = self._cast_to
+ cast_to = to if to is not None else self._cast_to
if cast_to is NoneType:
return cast(R, None)
@@ -186,14 +238,9 @@ class LegacyAPIResponse(Generic[R]):
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
- # The check here is necessary as we are subverting the the type system
- # with casts as the relationship between TypeVars and Types are very strict
- # which means we must return *exactly* what was input or transform it in a
- # way that retains the TypeVar state. As we cannot do that in this function
- # then we have to resort to using `cast`. At the time of writing, we know this
- # to be safe as we have handled all the types that could be bound to the
- # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
- # this function would become unsafe but a type checker would not report an error.
+ if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
+ raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
+
if (
cast_to is not object
and not origin is list
@@ -202,12 +249,12 @@ class LegacyAPIResponse(Generic[R]):
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
- f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
+ f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
- content_type, *_ = response.headers.get("content-type").split(";")
+ content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_to):
try:
@@ -253,15 +300,6 @@ class MissingStreamClassError(TypeError):
)
-def _extract_stream_chunk_type(stream_cls: type) -> type:
- args = get_args(stream_cls)
- if not args:
- raise TypeError(
- f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
- )
- return cast(type, args[0])
-
-
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
src/openai/_response.py
@@ -16,25 +16,29 @@ from typing import (
Iterator,
AsyncIterator,
cast,
+ overload,
)
from typing_extensions import Awaitable, ParamSpec, override, get_origin
import anyio
import httpx
+import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
+from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import OpenAIError, APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
- from ._base_client import Stream, BaseClient, AsyncStream
+ from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
+_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
@@ -44,7 +48,7 @@ log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
_cast_to: type[R]
_client: BaseClient[Any, Any]
- _parsed: R | None
+ _parsed_by_type: dict[type[Any], Any]
_is_sse_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
@@ -63,7 +67,7 @@ class BaseAPIResponse(Generic[R]):
) -> None:
self._cast_to = cast_to
self._client = client
- self._parsed = None
+ self._parsed_by_type = {}
self._is_sse_stream = stream
self._stream_cls = stream_cls
self._options = options
@@ -116,8 +120,24 @@ class BaseAPIResponse(Generic[R]):
f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
)
- def _parse(self) -> R:
+ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if self._is_sse_stream:
+ if to:
+ if not is_stream_class_type(to):
+ raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
+
+ return cast(
+ _T,
+ to(
+ cast_to=extract_stream_chunk_type(
+ to,
+ failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
+ ),
+ response=self.http_response,
+ client=cast(Any, self._client),
+ ),
+ )
+
if self._stream_cls:
return cast(
R,
@@ -141,7 +161,7 @@ class BaseAPIResponse(Generic[R]):
),
)
- cast_to = self._cast_to
+ cast_to = to if to is not None else self._cast_to
if cast_to is NoneType:
return cast(R, None)
@@ -171,14 +191,9 @@ class BaseAPIResponse(Generic[R]):
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
- # The check here is necessary as we are subverting the the type system
- # with casts as the relationship between TypeVars and Types are very strict
- # which means we must return *exactly* what was input or transform it in a
- # way that retains the TypeVar state. As we cannot do that in this function
- # then we have to resort to using `cast`. At the time of writing, we know this
- # to be safe as we have handled all the types that could be bound to the
- # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
- # this function would become unsafe but a type checker would not report an error.
+ if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
+ raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
+
if (
cast_to is not object
and not origin is list
@@ -187,12 +202,12 @@ class BaseAPIResponse(Generic[R]):
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
- f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
+ f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
- content_type, *_ = response.headers.get("content-type").split(";")
+ content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_to):
try:
@@ -228,22 +243,55 @@ class BaseAPIResponse(Generic[R]):
class APIResponse(BaseAPIResponse[R]):
+ @overload
+ def parse(self, *, to: type[_T]) -> _T:
+ ...
+
+ @overload
def parse(self) -> R:
+ ...
+
+ def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
+
+ You can customise the type that the response is parsed into through
+ the `to` argument, e.g.
+
+ ```py
+ from openai import BaseModel
+
+
+ class MyModel(BaseModel):
+ foo: str
+
+
+ obj = response.parse(to=MyModel)
+ print(obj.foo)
+ ```
+
+ We support parsing:
+ - `BaseModel`
+ - `dict`
+ - `list`
+ - `Union`
+ - `str`
+ - `httpx.Response`
"""
- if self._parsed is not None:
- return self._parsed
+ cache_key = to if to is not None else self._cast_to
+ cached = self._parsed_by_type.get(cache_key)
+ if cached is not None:
+ return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
self.read()
- parsed = self._parse()
+ parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
- self._parsed = parsed
+ self._parsed_by_type[cache_key] = parsed
return parsed
def read(self) -> bytes:
@@ -297,22 +345,55 @@ class APIResponse(BaseAPIResponse[R]):
class AsyncAPIResponse(BaseAPIResponse[R]):
+ @overload
+ async def parse(self, *, to: type[_T]) -> _T:
+ ...
+
+ @overload
async def parse(self) -> R:
+ ...
+
+ async def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
+
+ You can customise the type that the response is parsed into through
+ the `to` argument, e.g.
+
+ ```py
+ from openai import BaseModel
+
+
+ class MyModel(BaseModel):
+ foo: str
+
+
+ obj = response.parse(to=MyModel)
+ print(obj.foo)
+ ```
+
+ We support parsing:
+ - `BaseModel`
+ - `dict`
+ - `list`
+ - `Union`
+ - `str`
+ - `httpx.Response`
"""
- if self._parsed is not None:
- return self._parsed
+ cache_key = to if to is not None else self._cast_to
+ cached = self._parsed_by_type.get(cache_key)
+ if cached is not None:
+ return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
await self.read()
- parsed = self._parse()
+ parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
- self._parsed = parsed
+ self._parsed_by_type[cache_key] = parsed
return parsed
async def read(self) -> bytes:
@@ -708,26 +789,6 @@ def async_to_custom_raw_response_wrapper(
return wrapped
-def extract_stream_chunk_type(stream_cls: type) -> type:
- """Given a type like `Stream[T]`, returns the generic type variable `T`.
-
- This also handles the case where a concrete subclass is given, e.g.
- ```py
- class MyStream(Stream[bytes]):
- ...
-
- extract_stream_chunk_type(MyStream) -> bytes
- ```
- """
- from ._base_client import Stream, AsyncStream
-
- return extract_type_var_from_base(
- stream_cls,
- index=0,
- generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
- )
-
-
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
src/openai/_streaming.py
@@ -2,13 +2,14 @@
from __future__ import annotations
import json
+import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
-from typing_extensions import Self, override
+from typing_extensions import Self, TypeGuard, override, get_origin
import httpx
-from ._utils import is_mapping
+from ._utils import is_mapping, extract_type_var_from_base
from ._exceptions import APIError
if TYPE_CHECKING:
@@ -281,3 +282,34 @@ class SSEDecoder:
pass # Field is ignored.
return None
+
+
+def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
+ """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
+ origin = get_origin(typ) or typ
+ return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
+
+
+def extract_stream_chunk_type(
+ stream_cls: type,
+ *,
+ failure_message: str | None = None,
+) -> type:
+ """Given a type like `Stream[T]`, returns the generic type variable `T`.
+
+ This also handles the case where a concrete subclass is given, e.g.
+ ```py
+ class MyStream(Stream[bytes]):
+ ...
+
+ extract_stream_chunk_type(MyStream) -> bytes
+ ```
+ """
+ from ._base_client import Stream, AsyncStream
+
+ return extract_type_var_from_base(
+ stream_cls,
+ index=0,
+ generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
+ failure_message=failure_message,
+ )
tests/test_legacy_response.py
@@ -0,0 +1,65 @@
+import json
+
+import httpx
+import pytest
+import pydantic
+
+from openai import OpenAI, BaseModel
+from openai._streaming import Stream
+from openai._base_client import FinalRequestOptions
+from openai._legacy_response import LegacyAPIResponse
+
+
+class PydanticModel(pydantic.BaseModel):
+ ...
+
+
+def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
+ response = LegacyAPIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ with pytest.raises(
+ TypeError,
+ match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+ ):
+ response.parse(to=PydanticModel)
+
+
+def test_response_parse_custom_stream(client: OpenAI) -> None:
+ response = LegacyAPIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=client,
+ stream=True,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ stream = response.parse(to=Stream[int])
+ assert stream._cast_to == int
+
+
+class CustomModel(BaseModel):
+ foo: str
+ bar: int
+
+
+def test_response_parse_custom_model(client: OpenAI) -> None:
+ response = LegacyAPIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = response.parse(to=CustomModel)
+ assert obj.foo == "hello!"
+ assert obj.bar == 2
tests/test_response.py
@@ -1,8 +1,11 @@
+import json
from typing import List
import httpx
import pytest
+import pydantic
+from openai import OpenAI, BaseModel, AsyncOpenAI
from openai._response import (
APIResponse,
BaseAPIResponse,
@@ -11,6 +14,8 @@ from openai._response import (
AsyncBinaryAPIResponse,
extract_response_type,
)
+from openai._streaming import Stream
+from openai._base_client import FinalRequestOptions
class ConcreteBaseAPIResponse(APIResponse[bytes]):
@@ -48,3 +53,107 @@ def test_extract_response_type_concrete_subclasses() -> None:
def test_extract_response_type_binary_response() -> None:
assert extract_response_type(BinaryAPIResponse) == bytes
assert extract_response_type(AsyncBinaryAPIResponse) == bytes
+
+
+class PydanticModel(pydantic.BaseModel):
+ ...
+
+
+def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
+ response = APIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ with pytest.raises(
+ TypeError,
+ match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+ ):
+ response.parse(to=PydanticModel)
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None:
+ response = AsyncAPIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=async_client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ with pytest.raises(
+ TypeError,
+ match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
+ ):
+ await response.parse(to=PydanticModel)
+
+
+def test_response_parse_custom_stream(client: OpenAI) -> None:
+ response = APIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=client,
+ stream=True,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ stream = response.parse(to=Stream[int])
+ assert stream._cast_to == int
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None:
+ response = AsyncAPIResponse(
+ raw=httpx.Response(200, content=b"foo"),
+ client=async_client,
+ stream=True,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ stream = await response.parse(to=Stream[int])
+ assert stream._cast_to == int
+
+
+class CustomModel(BaseModel):
+ foo: str
+ bar: int
+
+
+def test_response_parse_custom_model(client: OpenAI) -> None:
+ response = APIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = response.parse(to=CustomModel)
+ assert obj.foo == "hello!"
+ assert obj.bar == 2
+
+
+@pytest.mark.asyncio
+async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None:
+ response = AsyncAPIResponse(
+ raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
+ client=async_client,
+ stream=False,
+ stream_cls=None,
+ cast_to=str,
+ options=FinalRequestOptions.construct(method="get", url="/foo"),
+ )
+
+ obj = await response.parse(to=CustomModel)
+ assert obj.foo == "hello!"
+ assert obj.bar == 2