Commit 297796b7

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-02-29 20:32:23
chore(internal): minor core client restructuring (#1199)
1 parent e41abf7
Changed files (2)
src/openai/_base_client.py
@@ -79,7 +79,7 @@ from ._constants import (
     RAW_RESPONSE_HEADER,
     OVERRIDE_CAST_TO_HEADER,
 )
-from ._streaming import Stream, AsyncStream
+from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
 from ._exceptions import (
     APIStatusError,
     APITimeoutError,
@@ -431,6 +431,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
 
         return merge_url
 
+    def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
+        return SSEDecoder()
+
     def _build_request(
         self,
         options: FinalRequestOptions,
src/openai/_streaming.py
@@ -5,7 +5,7 @@ import json
 import inspect
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
-from typing_extensions import Self, TypeGuard, override, get_origin
+from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
 
 import httpx
 
@@ -24,6 +24,8 @@ class Stream(Generic[_T]):
 
     response: httpx.Response
 
+    _decoder: SSEDecoder | SSEBytesDecoder
+
     def __init__(
         self,
         *,
@@ -34,7 +36,7 @@ class Stream(Generic[_T]):
         self.response = response
         self._cast_to = cast_to
         self._client = client
-        self._decoder = SSEDecoder()
+        self._decoder = client._make_sse_decoder()
         self._iterator = self.__stream__()
 
     def __next__(self) -> _T:
@@ -45,7 +47,10 @@ class Stream(Generic[_T]):
             yield item
 
     def _iter_events(self) -> Iterator[ServerSentEvent]:
-        yield from self._decoder.iter(self.response.iter_lines())
+        if isinstance(self._decoder, SSEBytesDecoder):
+            yield from self._decoder.iter_bytes(self.response.iter_bytes())
+        else:
+            yield from self._decoder.iter(self.response.iter_lines())
 
     def __stream__(self) -> Iterator[_T]:
         cast_to = cast(Any, self._cast_to)
@@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):
 
     response: httpx.Response
 
+    _decoder: SSEDecoder | SSEBytesDecoder
+
     def __init__(
         self,
         *,
@@ -107,7 +114,7 @@ class AsyncStream(Generic[_T]):
         self.response = response
         self._cast_to = cast_to
         self._client = client
-        self._decoder = SSEDecoder()
+        self._decoder = client._make_sse_decoder()
         self._iterator = self.__stream__()
 
     async def __anext__(self) -> _T:
@@ -118,8 +125,12 @@ class AsyncStream(Generic[_T]):
             yield item
 
     async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
-        async for sse in self._decoder.aiter(self.response.aiter_lines()):
-            yield sse
+        if isinstance(self._decoder, SSEBytesDecoder):
+            async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+                yield sse
+        else:
+            async for sse in self._decoder.aiter(self.response.aiter_lines()):
+                yield sse
 
     async def __stream__(self) -> AsyncIterator[_T]:
         cast_to = cast(Any, self._cast_to)
@@ -284,6 +295,17 @@ class SSEDecoder:
         return None
 
 
+@runtime_checkable
+class SSEBytesDecoder(Protocol):
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        ...
+
+    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        ...
+
+
 def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
     """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
     origin = get_origin(typ) or typ