# (stray "main" branch-name marker from the original copy/paste — not part of the module)
  1# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
  2from __future__ import annotations
  3
  4import json
  5import inspect
  6from types import TracebackType
  7from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
  8from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
  9
 10import httpx
 11
 12from ._utils import is_mapping, extract_type_var_from_base
 13from ._exceptions import APIError
 14
 15if TYPE_CHECKING:
 16    from ._client import OpenAI, AsyncOpenAI
 17
 18
 19_T = TypeVar("_T")
 20
 21
class Stream(Generic[_T]):
    """Provides the core interface to iterate over a synchronous stream response."""

    # The underlying HTTP response; its body is consumed incrementally as SSE bytes.
    response: httpx.Response

    # Decodes the raw response bytes into `ServerSentEvent` objects.
    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: OpenAI,
    ) -> None:
        """Wrap `response` so each decoded SSE payload is parsed and cast to `cast_to`.

        The decoder is obtained from the client so client subclasses can customize
        SSE decoding via `_make_sse_decoder()`.
        """
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # A single shared generator so `__next__` and `__iter__` observe the same position.
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[_T]:
        for item in self._iterator:
            yield item

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        # Feed the raw byte stream through the SSE decoder.
        yield from self._decoder.iter_bytes(self.response.iter_bytes())

    def __stream__(self) -> Iterator[_T]:
        """Generator yielding parsed chunks.

        Raises:
            APIError: when the server reports an error inside the stream payload.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            for sse in iterator:
                # "[DONE]" is the server's end-of-stream sentinel; stop iterating.
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()

                    # NOTE(review): `sse.event == "error"` cannot be true inside this
                    # `startswith("thread.")` branch — presumably kept for parity with
                    # the non-thread branch below; confirm before removing.
                    if sse.event == "error" and is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        # Fall back to a generic message when the payload lacks a usable one.
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    # Re-wrap so downstream parsing sees both the event name and its payload.
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    # Surface server-reported errors embedded in the stream payload.
                    if is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)

        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            response.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.response.close()
122
123
class AsyncStream(Generic[_T]):
    """Provides the core interface to iterate over an asynchronous stream response."""

    # The underlying HTTP response; its body is consumed incrementally as SSE bytes.
    response: httpx.Response

    # Decodes the raw response bytes into `ServerSentEvent` objects.
    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncOpenAI,
    ) -> None:
        """Wrap `response` so each decoded SSE payload is parsed and cast to `cast_to`.

        The decoder is obtained from the client so client subclasses can customize
        SSE decoding via `_make_sse_decoder()`.
        """
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # A single shared async generator so `__anext__` and `__aiter__` share position.
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for item in self._iterator:
            yield item

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        # Feed the raw byte stream through the SSE decoder.
        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
            yield sse

    async def __stream__(self) -> AsyncIterator[_T]:
        """Async generator yielding parsed chunks.

        Raises:
            APIError: when the server reports an error inside the stream payload.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            async for sse in iterator:
                # "[DONE]" is the server's end-of-stream sentinel; stop iterating.
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()

                    # NOTE(review): `sse.event == "error"` cannot be true inside this
                    # `startswith("thread.")` branch — presumably kept for parity with
                    # the non-thread branch below; confirm before removing.
                    if sse.event == "error" and is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        # Fall back to a generic message when the payload lacks a usable one.
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    # Re-wrap so downstream parsing sees both the event name and its payload.
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    # Surface server-reported errors embedded in the stream payload.
                    if is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)

        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            await response.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.response.aclose()
225
226
227class ServerSentEvent:
228    def __init__(
229        self,
230        *,
231        event: str | None = None,
232        data: str | None = None,
233        id: str | None = None,
234        retry: int | None = None,
235    ) -> None:
236        if data is None:
237            data = ""
238
239        self._id = id
240        self._data = data
241        self._event = event or None
242        self._retry = retry
243
244    @property
245    def event(self) -> str | None:
246        return self._event
247
248    @property
249    def id(self) -> str | None:
250        return self._id
251
252    @property
253    def retry(self) -> int | None:
254        return self._retry
255
256    @property
257    def data(self) -> str:
258        return self._data
259
260    def json(self) -> Any:
261        return json.loads(self.data)
262
263    @override
264    def __repr__(self) -> str:
265        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
266
267
class SSEDecoder:
    """Incremental decoder turning raw SSE byte streams into `ServerSentEvent`s.

    State accumulates across `decode()` calls until a blank line dispatches
    the pending event, per the WHATWG event-stream interpretation rules.
    """

    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    def __init__(self) -> None:
        self._data = []
        self._event = None
        self._retry = None
        self._last_event_id = None

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        for chunk in self._iter_chunks(iterator):
            # Split while still bytes so splitlines() only honours \r and \n,
            # not the extra unicode line separators a decoded str would add.
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event is not None:
                    yield event

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        buffer = b""
        for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                buffer += line
                # A double newline (in any of the three line-ending styles)
                # terminates one SSE chunk.
                if buffer.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield buffer
                    buffer = b""
        if buffer:
            # Flush whatever trailing partial chunk remains at end of stream.
            yield buffer

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        async for chunk in self._aiter_chunks(iterator):
            # Split while still bytes so splitlines() only honours \r and \n.
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event is not None:
                    yield event

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        buffer = b""
        async for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                buffer += line
                if buffer.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield buffer
                    buffer = b""
        if buffer:
            yield buffer

    def decode(self, line: str) -> ServerSentEvent | None:
        """Feed one line into the decoder; return an event when one is complete.

        See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
        """
        if not line:
            # Blank line: dispatch the accumulated event, unless nothing accumulated.
            if not self._event and not self._data and not self._last_event_id and self._retry is None:
                return None

            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = None
            self._data = []
            self._retry = None
            return sse

        if line.startswith(":"):
            # Comment line, ignored.
            return None

        fieldname, _, value = line.partition(":")
        # The spec strips exactly one leading space from the value, if present.
        if value[:1] == " ":
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id" and "\0" not in value:
            # IDs containing NUL are ignored per the spec.
            self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        # Any other field name is ignored per the spec.

        return None
371
372
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for SSE decoders accepted by `Stream`/`AsyncStream`.

    Any object providing these two methods (e.g. `SSEDecoder`) satisfies this
    protocol; `@runtime_checkable` additionally allows `isinstance()` checks.
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
382
383
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
    # Unwrap parameterised generics, e.g. Stream[bytes] -> Stream.
    base = get_origin(typ) or typ
    if not inspect.isclass(base):
        return False
    return issubclass(base, (Stream, AsyncStream))
388
389
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Extract the chunk type `T` from a stream class such as `Stream[T]`.

    Concrete subclasses are handled too, e.g.
    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    """
    # Imported lazily to avoid a circular import with the base client module.
    from ._base_client import Stream, AsyncStream

    stream_bases = cast("tuple[type, ...]", (Stream, AsyncStream))
    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=stream_bases,
        failure_message=failure_message,
    )