Commit 297796b7
Changed files (2)
src
src/openai/_base_client.py
@@ -79,7 +79,7 @@ from ._constants import (
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
)
-from ._streaming import Stream, AsyncStream
+from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
APIStatusError,
APITimeoutError,
@@ -431,6 +431,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
return merge_url
+ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
+ return SSEDecoder()
+
def _build_request(
self,
options: FinalRequestOptions,
src/openai/_streaming.py
@@ -5,7 +5,7 @@ import json
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
-from typing_extensions import Self, TypeGuard, override, get_origin
+from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
import httpx
@@ -24,6 +24,8 @@ class Stream(Generic[_T]):
response: httpx.Response
+ _decoder: SSEDecoder | SSEBytesDecoder
+
def __init__(
self,
*,
@@ -34,7 +36,7 @@ class Stream(Generic[_T]):
self.response = response
self._cast_to = cast_to
self._client = client
- self._decoder = SSEDecoder()
+ self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
def __next__(self) -> _T:
@@ -45,7 +47,10 @@ class Stream(Generic[_T]):
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
- yield from self._decoder.iter(self.response.iter_lines())
+ if isinstance(self._decoder, SSEBytesDecoder):
+ yield from self._decoder.iter_bytes(self.response.iter_bytes())
+ else:
+ yield from self._decoder.iter(self.response.iter_lines())
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):
response: httpx.Response
+ _decoder: SSEDecoder | SSEBytesDecoder
+
def __init__(
self,
*,
@@ -107,7 +114,7 @@ class AsyncStream(Generic[_T]):
self.response = response
self._cast_to = cast_to
self._client = client
- self._decoder = SSEDecoder()
+ self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
async def __anext__(self) -> _T:
@@ -118,8 +125,12 @@ class AsyncStream(Generic[_T]):
yield item
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
- async for sse in self._decoder.aiter(self.response.aiter_lines()):
- yield sse
+ if isinstance(self._decoder, SSEBytesDecoder):
+ async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+ yield sse
+ else:
+ async for sse in self._decoder.aiter(self.response.aiter_lines()):
+ yield sse
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -284,6 +295,17 @@ class SSEDecoder:
return None
+@runtime_checkable
+class SSEBytesDecoder(Protocol):
+ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ ...
+
+ def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+ """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ ...
+
+
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ