Commit 052e611a

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-12-20 00:23:44
chore(internal): minor utils restructuring (#992)
1 parent 3d6ca3e
src/openai/_utils/__init__.py
@@ -9,13 +9,11 @@ from ._utils import is_tuple_t as is_tuple_t
 from ._utils import parse_date as parse_date
 from ._utils import is_sequence as is_sequence
 from ._utils import coerce_float as coerce_float
-from ._utils import is_list_type as is_list_type
 from ._utils import is_mapping_t as is_mapping_t
 from ._utils import removeprefix as removeprefix
 from ._utils import removesuffix as removesuffix
 from ._utils import extract_files as extract_files
 from ._utils import is_sequence_t as is_sequence_t
-from ._utils import is_union_type as is_union_type
 from ._utils import required_args as required_args
 from ._utils import coerce_boolean as coerce_boolean
 from ._utils import coerce_integer as coerce_integer
@@ -23,15 +21,20 @@ from ._utils import file_from_path as file_from_path
 from ._utils import parse_datetime as parse_datetime
 from ._utils import strip_not_given as strip_not_given
 from ._utils import deepcopy_minimal as deepcopy_minimal
-from ._utils import extract_type_arg as extract_type_arg
-from ._utils import is_required_type as is_required_type
 from ._utils import get_async_library as get_async_library
-from ._utils import is_annotated_type as is_annotated_type
 from ._utils import maybe_coerce_float as maybe_coerce_float
 from ._utils import get_required_header as get_required_header
 from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
 from ._utils import maybe_coerce_integer as maybe_coerce_integer
-from ._utils import strip_annotated_type as strip_annotated_type
+from ._typing import is_list_type as is_list_type
+from ._typing import is_union_type as is_union_type
+from ._typing import extract_type_arg as extract_type_arg
+from ._typing import is_required_type as is_required_type
+from ._typing import is_annotated_type as is_annotated_type
+from ._typing import strip_annotated_type as strip_annotated_type
+from ._typing import extract_type_var_from_base as extract_type_var_from_base
+from ._streams import consume_sync_iterator as consume_sync_iterator
+from ._streams import consume_async_iterator as consume_async_iterator
 from ._transform import PropertyInfo as PropertyInfo
 from ._transform import transform as transform
 from ._transform import maybe_transform as maybe_transform
src/openai/_utils/_streams.py
@@ -0,0 +1,12 @@
+from typing import Any
+from typing_extensions import Iterator, AsyncIterator
+
+
+def consume_sync_iterator(iterator: Iterator[Any]) -> None:
+    for _ in iterator:
+        ...
+
+
+async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
+    async for _ in iterator:
+        ...
src/openai/_utils/_transform.py
@@ -6,9 +6,8 @@ from typing_extensions import Literal, get_args, override, get_type_hints
 
 import pydantic
 
-from ._utils import (
-    is_list,
-    is_mapping,
+from ._utils import is_list, is_mapping
+from ._typing import (
     is_list_type,
     is_union_type,
     extract_type_arg,
src/openai/_utils/_typing.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from typing import Any, cast
+from typing_extensions import Required, Annotated, get_args, get_origin
+
+from .._types import InheritsGeneric
+from .._compat import is_union as _is_union
+
+
+def is_annotated_type(typ: type) -> bool:
+    return get_origin(typ) == Annotated
+
+
+def is_list_type(typ: type) -> bool:
+    return (get_origin(typ) or typ) == list
+
+
+def is_union_type(typ: type) -> bool:
+    return _is_union(get_origin(typ))
+
+
+def is_required_type(typ: type) -> bool:
+    return get_origin(typ) == Required
+
+
+# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
+def strip_annotated_type(typ: type) -> type:
+    if is_required_type(typ) or is_annotated_type(typ):
+        return strip_annotated_type(cast(type, get_args(typ)[0]))
+
+    return typ
+
+
+def extract_type_arg(typ: type, index: int) -> type:
+    args = get_args(typ)
+    try:
+        return cast(type, args[index])
+    except IndexError as err:
+        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
+
+
+def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
+    """Given a type like `Foo[T]`, returns the generic type variable `T`.
+
+    This also handles the case where a concrete subclass is given, e.g.
+    ```py
+    class MyResponse(Foo[bytes]):
+        ...
+
+    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
+    ```
+    """
+    cls = cast(object, get_origin(typ) or typ)
+    if cls in generic_bases:
+        # we're given the class directly
+        return extract_type_arg(typ, index)
+
+    # if a subclass is given
+    # ---
+    # this is needed as __orig_bases__ is not present in the typeshed stubs
+    # because it is intended to be for internal use only, however there does
+    # not seem to be a way to resolve generic TypeVars for inherited subclasses
+    # without using it.
+    if isinstance(cls, InheritsGeneric):
+        target_base_class: Any | None = None
+        for base in cls.__orig_bases__:
+            if base.__origin__ in generic_bases:
+                target_base_class = base
+                break
+
+        if target_base_class is None:
+            raise RuntimeError(
+                "Could not find the generic base class;\n"
+                "This should never happen;\n"
+                f"Does {cls} inherit from one of {generic_bases} ?"
+            )
+
+        return extract_type_arg(target_base_class, index)
+
+    raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
src/openai/_utils/_utils.py
@@ -16,12 +16,11 @@ from typing import (
     overload,
 )
 from pathlib import Path
-from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
+from typing_extensions import TypeGuard
 
 import sniffio
 
 from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
-from .._compat import is_union as _is_union
 from .._compat import parse_date as parse_date
 from .._compat import parse_datetime as parse_datetime
 
@@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
     return isinstance(obj, list)
 
 
-def is_annotated_type(typ: type) -> bool:
-    return get_origin(typ) == Annotated
-
-
-def is_list_type(typ: type) -> bool:
-    return (get_origin(typ) or typ) == list
-
-
-def is_union_type(typ: type) -> bool:
-    return _is_union(get_origin(typ))
-
-
-def is_required_type(typ: type) -> bool:
-    return get_origin(typ) == Required
-
-
-# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
-def strip_annotated_type(typ: type) -> type:
-    if is_required_type(typ) or is_annotated_type(typ):
-        return strip_annotated_type(cast(type, get_args(typ)[0]))
-
-    return typ
-
-
-def extract_type_arg(typ: type, index: int) -> type:
-    args = get_args(typ)
-    try:
-        return cast(type, args[index])
-    except IndexError as err:
-        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
-
-
 def deepcopy_minimal(item: _T) -> _T:
     """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
 
src/openai/_response.py
@@ -5,12 +5,12 @@ import logging
 import datetime
 import functools
 from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
-from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
+from typing_extensions import Awaitable, ParamSpec, override, get_origin
 
 import httpx
 
 from ._types import NoneType, UnknownResponse, BinaryResponseContent
-from ._utils import is_given
+from ._utils import is_given, extract_type_var_from_base
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER
 from ._exceptions import APIResponseValidationError
@@ -221,12 +221,13 @@ class MissingStreamClassError(TypeError):
 
 
 def _extract_stream_chunk_type(stream_cls: type) -> type:
-    args = get_args(stream_cls)
-    if not args:
-        raise TypeError(
-            f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
-        )
-    return cast(type, args[0])
+    from ._base_client import Stream, AsyncStream
+
+    return extract_type_var_from_base(
+        stream_cls,
+        index=0,
+        generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
+    )
 
 
 def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
src/openai/_streaming.py
@@ -2,12 +2,12 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
-from typing_extensions import override
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
+from typing_extensions import Self, override
 
 import httpx
 
-from ._types import ResponseT
 from ._utils import is_mapping
 from ._exceptions import APIError
 
@@ -15,7 +15,10 @@ if TYPE_CHECKING:
     from ._client import OpenAI, AsyncOpenAI
 
 
-class Stream(Generic[ResponseT]):
+_T = TypeVar("_T")
+
+
+class Stream(Generic[_T]):
     """Provides the core interface to iterate over a synchronous stream response."""
 
     response: httpx.Response
@@ -23,7 +26,7 @@ class Stream(Generic[ResponseT]):
     def __init__(
         self,
         *,
-        cast_to: type[ResponseT],
+        cast_to: type[_T],
         response: httpx.Response,
         client: OpenAI,
     ) -> None:
@@ -33,18 +36,18 @@ class Stream(Generic[ResponseT]):
         self._decoder = SSEDecoder()
         self._iterator = self.__stream__()
 
-    def __next__(self) -> ResponseT:
+    def __next__(self) -> _T:
         return self._iterator.__next__()
 
-    def __iter__(self) -> Iterator[ResponseT]:
+    def __iter__(self) -> Iterator[_T]:
         for item in self._iterator:
             yield item
 
     def _iter_events(self) -> Iterator[ServerSentEvent]:
         yield from self._decoder.iter(self.response.iter_lines())
 
-    def __stream__(self) -> Iterator[ResponseT]:
-        cast_to = self._cast_to
+    def __stream__(self) -> Iterator[_T]:
+        cast_to = cast(Any, self._cast_to)
         response = self.response
         process_data = self._client._process_response_data
         iterator = self._iter_events()
@@ -68,8 +71,27 @@ class Stream(Generic[ResponseT]):
         for _sse in iterator:
             ...
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        self.response.close()
 
-class AsyncStream(Generic[ResponseT]):
+
+class AsyncStream(Generic[_T]):
     """Provides the core interface to iterate over an asynchronous stream response."""
 
     response: httpx.Response
@@ -77,7 +99,7 @@ class AsyncStream(Generic[ResponseT]):
     def __init__(
         self,
         *,
-        cast_to: type[ResponseT],
+        cast_to: type[_T],
         response: httpx.Response,
         client: AsyncOpenAI,
     ) -> None:
@@ -87,10 +109,10 @@ class AsyncStream(Generic[ResponseT]):
         self._decoder = SSEDecoder()
         self._iterator = self.__stream__()
 
-    async def __anext__(self) -> ResponseT:
+    async def __anext__(self) -> _T:
         return await self._iterator.__anext__()
 
-    async def __aiter__(self) -> AsyncIterator[ResponseT]:
+    async def __aiter__(self) -> AsyncIterator[_T]:
         async for item in self._iterator:
             yield item
 
@@ -98,8 +120,8 @@ class AsyncStream(Generic[ResponseT]):
         async for sse in self._decoder.aiter(self.response.aiter_lines()):
             yield sse
 
-    async def __stream__(self) -> AsyncIterator[ResponseT]:
-        cast_to = self._cast_to
+    async def __stream__(self) -> AsyncIterator[_T]:
+        cast_to = cast(Any, self._cast_to)
         response = self.response
         process_data = self._client._process_response_data
         iterator = self._iter_events()
@@ -123,6 +145,25 @@ class AsyncStream(Generic[ResponseT]):
         async for _sse in iterator:
             ...
 
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.close()
+
+    async def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        await self.response.aclose()
+
 
 class ServerSentEvent:
     def __init__(
src/openai/_types.py
@@ -353,3 +353,17 @@ StrBytesIntFloat = Union[str, bytes, int, float]
 IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
 
 PostParser = Callable[[Any], Any]
+
+
+@runtime_checkable
+class InheritsGeneric(Protocol):
+    """Represents a type that has inherited from `Generic`
+    The `__orig_bases__` property can be used to determine the resolved
+    type variable for a given base class.
+    """
+
+    __orig_bases__: tuple[_GenericAlias]
+
+
+class _GenericAlias(Protocol):
+    __origin__: type[object]