Commit 052e611a
Changed files (8)
src/openai/_utils/__init__.py
@@ -9,13 +9,11 @@ from ._utils import is_tuple_t as is_tuple_t
from ._utils import parse_date as parse_date
from ._utils import is_sequence as is_sequence
from ._utils import coerce_float as coerce_float
-from ._utils import is_list_type as is_list_type
from ._utils import is_mapping_t as is_mapping_t
from ._utils import removeprefix as removeprefix
from ._utils import removesuffix as removesuffix
from ._utils import extract_files as extract_files
from ._utils import is_sequence_t as is_sequence_t
-from ._utils import is_union_type as is_union_type
from ._utils import required_args as required_args
from ._utils import coerce_boolean as coerce_boolean
from ._utils import coerce_integer as coerce_integer
@@ -23,15 +21,20 @@ from ._utils import file_from_path as file_from_path
from ._utils import parse_datetime as parse_datetime
from ._utils import strip_not_given as strip_not_given
from ._utils import deepcopy_minimal as deepcopy_minimal
-from ._utils import extract_type_arg as extract_type_arg
-from ._utils import is_required_type as is_required_type
from ._utils import get_async_library as get_async_library
-from ._utils import is_annotated_type as is_annotated_type
from ._utils import maybe_coerce_float as maybe_coerce_float
from ._utils import get_required_header as get_required_header
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
from ._utils import maybe_coerce_integer as maybe_coerce_integer
-from ._utils import strip_annotated_type as strip_annotated_type
+from ._typing import is_list_type as is_list_type
+from ._typing import is_union_type as is_union_type
+from ._typing import extract_type_arg as extract_type_arg
+from ._typing import is_required_type as is_required_type
+from ._typing import is_annotated_type as is_annotated_type
+from ._typing import strip_annotated_type as strip_annotated_type
+from ._typing import extract_type_var_from_base as extract_type_var_from_base
+from ._streams import consume_sync_iterator as consume_sync_iterator
+from ._streams import consume_async_iterator as consume_async_iterator
from ._transform import PropertyInfo as PropertyInfo
from ._transform import transform as transform
from ._transform import maybe_transform as maybe_transform
src/openai/_utils/_streams.py
@@ -0,0 +1,12 @@
+from typing import Any
+from typing_extensions import Iterator, AsyncIterator
+
+
+def consume_sync_iterator(iterator: Iterator[Any]) -> None:
+ for _ in iterator:
+ ...
+
+
+async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
+ async for _ in iterator:
+ ...
src/openai/_utils/_transform.py
@@ -6,9 +6,8 @@ from typing_extensions import Literal, get_args, override, get_type_hints
import pydantic
-from ._utils import (
- is_list,
- is_mapping,
+from ._utils import is_list, is_mapping
+from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
src/openai/_utils/_typing.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from typing import Any, cast
+from typing_extensions import Required, Annotated, get_args, get_origin
+
+from .._types import InheritsGeneric
+from .._compat import is_union as _is_union
+
+
+def is_annotated_type(typ: type) -> bool:
+ return get_origin(typ) == Annotated
+
+
+def is_list_type(typ: type) -> bool:
+ return (get_origin(typ) or typ) == list
+
+
+def is_union_type(typ: type) -> bool:
+ return _is_union(get_origin(typ))
+
+
+def is_required_type(typ: type) -> bool:
+ return get_origin(typ) == Required
+
+
+# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
+def strip_annotated_type(typ: type) -> type:
+ if is_required_type(typ) or is_annotated_type(typ):
+ return strip_annotated_type(cast(type, get_args(typ)[0]))
+
+ return typ
+
+
+def extract_type_arg(typ: type, index: int) -> type:
+ args = get_args(typ)
+ try:
+ return cast(type, args[index])
+ except IndexError as err:
+ raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
+
+
+def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
+ """Given a type like `Foo[T]`, returns the generic type variable `T`.
+
+ This also handles the case where a concrete subclass is given, e.g.
+ ```py
+ class MyResponse(Foo[bytes]):
+ ...
+
+ extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
+ ```
+ """
+ cls = cast(object, get_origin(typ) or typ)
+ if cls in generic_bases:
+ # we're given the class directly
+ return extract_type_arg(typ, index)
+
+ # if a subclass is given
+ # ---
+ # this is needed as __orig_bases__ is not present in the typeshed stubs
+ # because it is intended to be for internal use only, however there does
+ # not seem to be a way to resolve generic TypeVars for inherited subclasses
+ # without using it.
+ if isinstance(cls, InheritsGeneric):
+ target_base_class: Any | None = None
+ for base in cls.__orig_bases__:
+ if base.__origin__ in generic_bases:
+ target_base_class = base
+ break
+
+ if target_base_class is None:
+ raise RuntimeError(
+ "Could not find the generic base class;\n"
+ "This should never happen;\n"
+ f"Does {cls} inherit from one of {generic_bases} ?"
+ )
+
+ return extract_type_arg(target_base_class, index)
+
+ raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
src/openai/_utils/_utils.py
@@ -16,12 +16,11 @@ from typing import (
overload,
)
from pathlib import Path
-from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
+from typing_extensions import TypeGuard
import sniffio
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
-from .._compat import is_union as _is_union
from .._compat import parse_date as parse_date
from .._compat import parse_datetime as parse_datetime
@@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
-def is_annotated_type(typ: type) -> bool:
- return get_origin(typ) == Annotated
-
-
-def is_list_type(typ: type) -> bool:
- return (get_origin(typ) or typ) == list
-
-
-def is_union_type(typ: type) -> bool:
- return _is_union(get_origin(typ))
-
-
-def is_required_type(typ: type) -> bool:
- return get_origin(typ) == Required
-
-
-# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
-def strip_annotated_type(typ: type) -> type:
- if is_required_type(typ) or is_annotated_type(typ):
- return strip_annotated_type(cast(type, get_args(typ)[0]))
-
- return typ
-
-
-def extract_type_arg(typ: type, index: int) -> type:
- args = get_args(typ)
- try:
- return cast(type, args[index])
- except IndexError as err:
- raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
-
-
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
src/openai/_response.py
@@ -5,12 +5,12 @@ import logging
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
-from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
+from typing_extensions import Awaitable, ParamSpec, override, get_origin
import httpx
from ._types import NoneType, UnknownResponse, BinaryResponseContent
-from ._utils import is_given
+from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._exceptions import APIResponseValidationError
@@ -221,12 +221,13 @@ class MissingStreamClassError(TypeError):
def _extract_stream_chunk_type(stream_cls: type) -> type:
- args = get_args(stream_cls)
- if not args:
- raise TypeError(
- f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
- )
- return cast(type, args[0])
+ from ._base_client import Stream, AsyncStream
+
+ return extract_type_var_from_base(
+ stream_cls,
+ index=0,
+ generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
+ )
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
src/openai/_streaming.py
@@ -2,12 +2,12 @@
from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
-from typing_extensions import override
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
+from typing_extensions import Self, override
import httpx
-from ._types import ResponseT
from ._utils import is_mapping
from ._exceptions import APIError
@@ -15,7 +15,10 @@ if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI
-class Stream(Generic[ResponseT]):
+_T = TypeVar("_T")
+
+
+class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""
response: httpx.Response
@@ -23,7 +26,7 @@ class Stream(Generic[ResponseT]):
def __init__(
self,
*,
- cast_to: type[ResponseT],
+ cast_to: type[_T],
response: httpx.Response,
client: OpenAI,
) -> None:
@@ -33,18 +36,18 @@ class Stream(Generic[ResponseT]):
self._decoder = SSEDecoder()
self._iterator = self.__stream__()
- def __next__(self) -> ResponseT:
+ def __next__(self) -> _T:
return self._iterator.__next__()
- def __iter__(self) -> Iterator[ResponseT]:
+ def __iter__(self) -> Iterator[_T]:
for item in self._iterator:
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter(self.response.iter_lines())
- def __stream__(self) -> Iterator[ResponseT]:
- cast_to = self._cast_to
+ def __stream__(self) -> Iterator[_T]:
+ cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
@@ -68,8 +71,27 @@ class Stream(Generic[ResponseT]):
for _sse in iterator:
...
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ """
+ Close the response and release the connection.
+
+ Automatically called if the response body is read to completion.
+ """
+ self.response.close()
-class AsyncStream(Generic[ResponseT]):
+
+class AsyncStream(Generic[_T]):
"""Provides the core interface to iterate over an asynchronous stream response."""
response: httpx.Response
@@ -77,7 +99,7 @@ class AsyncStream(Generic[ResponseT]):
def __init__(
self,
*,
- cast_to: type[ResponseT],
+ cast_to: type[_T],
response: httpx.Response,
client: AsyncOpenAI,
) -> None:
@@ -87,10 +109,10 @@ class AsyncStream(Generic[ResponseT]):
self._decoder = SSEDecoder()
self._iterator = self.__stream__()
- async def __anext__(self) -> ResponseT:
+ async def __anext__(self) -> _T:
return await self._iterator.__anext__()
- async def __aiter__(self) -> AsyncIterator[ResponseT]:
+ async def __aiter__(self) -> AsyncIterator[_T]:
async for item in self._iterator:
yield item
@@ -98,8 +120,8 @@ class AsyncStream(Generic[ResponseT]):
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse
- async def __stream__(self) -> AsyncIterator[ResponseT]:
- cast_to = self._cast_to
+ async def __stream__(self) -> AsyncIterator[_T]:
+ cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
@@ -123,6 +145,25 @@ class AsyncStream(Generic[ResponseT]):
async for _sse in iterator:
...
+ async def __aenter__(self) -> Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ await self.close()
+
+ async def close(self) -> None:
+ """
+ Close the response and release the connection.
+
+ Automatically called if the response body is read to completion.
+ """
+ await self.response.aclose()
+
class ServerSentEvent:
def __init__(
src/openai/_types.py
@@ -353,3 +353,17 @@ StrBytesIntFloat = Union[str, bytes, int, float]
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
PostParser = Callable[[Any], Any]
+
+
+@runtime_checkable
+class InheritsGeneric(Protocol):
+ """Represents a type that has inherited from `Generic`
+ The `__orig_bases__` property can be used to determine the resolved
+ type variable for a given base class.
+ """
+
+ __orig_bases__: tuple[_GenericAlias]
+
+
+class _GenericAlias(Protocol):
+ __origin__: type[object]