main
1# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
2from __future__ import annotations
3
4import json
5import inspect
6from types import TracebackType
7from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
8from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
9
10import httpx
11
12from ._utils import is_mapping, extract_type_var_from_base
13from ._exceptions import APIError
14
15if TYPE_CHECKING:
16 from ._client import OpenAI, AsyncOpenAI
17
18
# The parsed chunk type yielded by `Stream` / `AsyncStream` (see the Generic classes below).
_T = TypeVar("_T")
21
class Stream(Generic[_T]):
    """Provides the core interface to iterate over a synchronous stream response."""

    # The underlying HTTP response whose body is consumed as server-sent events.
    response: httpx.Response

    # Decoder that turns raw response bytes into `ServerSentEvent` objects.
    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: OpenAI,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[_T]:
        for item in self._iterator:
            yield item

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        # Decode the raw HTTP byte stream into individual server-sent events.
        yield from self._decoder.iter_bytes(self.response.iter_bytes())

    def __stream__(self) -> Iterator[_T]:
        """Iterate over the SSE stream, parsing each event's payload into `_T` instances.

        Raises:
            APIError: if the server reports an error payload mid-stream.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            for sse in iterator:
                # Terminal sentinel; stop iterating without yielding it.
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()

                    # NOTE(review): an `sse.event == "error"` check used to live here, but it
                    # could never be true inside this branch (the event name starts with
                    # "thread."), so the unreachable error handling was removed. `error`
                    # events do not carry the `thread.` prefix and are handled below.
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    if is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        # Fall back to a generic message when the payload doesn't provide a usable one.
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)
        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            response.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.response.close()
122
123
class AsyncStream(Generic[_T]):
    """Provides the core interface to iterate over an asynchronous stream response."""

    # The underlying HTTP response whose body is consumed as server-sent events.
    response: httpx.Response

    # Decoder that turns raw response bytes into `ServerSentEvent` objects.
    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncOpenAI,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for item in self._iterator:
            yield item

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        # Decode the raw HTTP byte stream into individual server-sent events.
        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
            yield sse

    async def __stream__(self) -> AsyncIterator[_T]:
        """Iterate over the SSE stream, parsing each event's payload into `_T` instances.

        Raises:
            APIError: if the server reports an error payload mid-stream.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            async for sse in iterator:
                # Terminal sentinel; stop iterating without yielding it.
                if sse.data.startswith("[DONE]"):
                    break

                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
                if sse.event and sse.event.startswith("thread."):
                    data = sse.json()

                    # NOTE(review): an `sse.event == "error"` check used to live here, but it
                    # could never be true inside this branch (the event name starts with
                    # "thread."), so the unreachable error handling was removed. `error`
                    # events do not carry the `thread.` prefix and are handled below.
                    yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
                else:
                    data = sse.json()
                    if is_mapping(data) and data.get("error"):
                        message = None
                        error = data.get("error")
                        if is_mapping(error):
                            message = error.get("message")
                        # Fall back to a generic message when the payload doesn't provide a usable one.
                        if not message or not isinstance(message, str):
                            message = "An error occurred during streaming"

                        raise APIError(
                            message=message,
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)
        finally:
            # Ensure the response is closed even if the consumer doesn't read all data
            await response.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.response.aclose()
225
226
class ServerSentEvent:
    """A single decoded server-sent event: its name, data payload, id and retry hint."""

    def __init__(
        self,
        *,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        self._id = id
        # The data payload is always stored as a string, defaulting to empty.
        self._data = data if data is not None else ""
        # Normalise an empty event name to `None` so `event` is either a real name or absent.
        self._event = event or None
        self._retry = retry

    @property
    def event(self) -> str | None:
        """The event name, if the stream specified one."""
        return self._event

    @property
    def id(self) -> str | None:
        """The event id, if the stream specified one."""
        return self._id

    @property
    def retry(self) -> int | None:
        """The reconnection delay hint, if the stream specified one."""
        return self._retry

    @property
    def data(self) -> str:
        """The raw data payload (possibly empty)."""
        return self._data

    def json(self) -> Any:
        """Parse the data payload as JSON and return the result."""
        return json.loads(self.data)

    @override
    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
266
267
class SSEDecoder:
    """Incremental decoder for the server-sent events wire format.

    Raw byte chunks are buffered into complete events (terminated by a blank
    line) and then parsed line-by-line into `ServerSentEvent` objects.
    """

    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    def __init__(self) -> None:
        # Parser state for the event currently being accumulated.
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        for chunk in self._iter_chunks(iterator):
            # Split before decoding so splitlines() only uses \r and \n
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event is not None:
                    yield event

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        pending = b""
        for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                pending += line
                # A blank line (two consecutive terminators, in any newline
                # convention) marks the end of an event.
                if pending.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield pending
                    pending = b""
        if pending:
            yield pending

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        async for chunk in self._aiter_chunks(iterator):
            # Split before decoding so splitlines() only uses \r and \n
            for raw_line in chunk.splitlines():
                event = self.decode(raw_line.decode("utf-8"))
                if event is not None:
                    yield event

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        pending = b""
        async for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                pending += line
                # A blank line (two consecutive terminators, in any newline
                # convention) marks the end of an event.
                if pending.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield pending
                    pending = b""
        if pending:
            yield pending

    def decode(self, line: str) -> ServerSentEvent | None:
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501

        if not line:
            # A blank line dispatches the accumulated event — but only if any
            # field has actually been seen since the last dispatch.
            if not self._event and not self._data and not self._last_event_id and self._retry is None:
                return None

            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = None
            self._data = []
            self._retry = None

            return sse

        if line.startswith(":"):
            # Comment line; ignored.
            return None

        fieldname, _, value = line.partition(":")

        # A single leading space in the value is not significant.
        if value[:1] == " ":
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Per the spec, ids containing a NUL character are ignored.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # Non-integer retry values are ignored.
                pass
        # Any other field name is ignored.

        return None
371
372
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural (duck-typed) interface for anything that decodes raw bytes into SSE events.

    `Stream`/`AsyncStream` only require their decoder to provide these two methods;
    `SSEDecoder` above is the concrete implementation.
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
382
383
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """TypeGuard: is `typ` (or its generic origin, e.g. `Stream` for `Stream[bytes]`)
    a class deriving from `Stream` / `AsyncStream`?
    """
    base_cls = get_origin(typ) or typ
    if not inspect.isclass(base_cls):
        return False
    return issubclass(base_cls, (Stream, AsyncStream))
388
389
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Stream[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    """
    # Imported lazily to avoid an import cycle with `_base_client`.
    from ._base_client import Stream, AsyncStream

    generic_bases = cast("tuple[type, ...]", (Stream, AsyncStream))
    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=generic_bases,
        failure_message=failure_message,
    )