main
1from __future__ import annotations
2
3import inspect
4from types import TracebackType
5from typing import Any, List, Generic, Iterable, Awaitable, cast
6from typing_extensions import Self, Callable, Iterator, AsyncIterator
7
8from ._types import ParsedResponseSnapshot
9from ._events import (
10 ResponseStreamEvent,
11 ResponseTextDoneEvent,
12 ResponseCompletedEvent,
13 ResponseTextDeltaEvent,
14 ResponseFunctionCallArgumentsDeltaEvent,
15)
16from ...._types import Omit, omit
17from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
18from ...._models import build, construct_type_unchecked
19from ...._streaming import Stream, AsyncStream
20from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
21from ..._parsing._responses import TextFormatT, parse_text, parse_response
22from ....types.responses.tool_param import ToolParam
23from ....types.responses.parsed_response import (
24 ParsedContent,
25 ParsedResponseOutputMessage,
26 ParsedResponseFunctionToolCall,
27)
28
29
class ResponseStream(Generic[TextFormatT]):
    """Synchronous iterator over the enriched events of a streamed response.

    Every raw event read from `raw_stream` is passed through a
    `ResponseStreamState`, which turns it into zero or more higher-level
    events. When `starting_after` is not None, events whose
    `sequence_number` is less than or equal to it are skipped.
    """

    def __init__(
        self,
        *,
        raw_stream: Stream[RawResponseStreamEvent],
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._starting_after = starting_after
        self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
        # Creating the generator does not run it; nothing is read from
        # `raw_stream` until the first event is requested.
        self._iterator = self.__stream__()

    def __next__(self) -> ResponseStreamEvent[TextFormatT]:
        return next(self._iterator)

    def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
        # Deliberately a plain loop rather than `yield from`: closing this outer
        # generator must not close the shared `self._iterator`.
        for next_event in self._iterator:
            yield next_event

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
        """Translate raw events into enriched events, honouring `starting_after`."""
        for raw_event in self._raw_stream:
            for event in self._state.handle_event(raw_event):
                # Events at or before the resume point have already been seen.
                if self._starting_after is not None and event.sequence_number <= self._starting_after:
                    continue
                yield event

    def close(self) -> None:
        """
        Close the response and release the connection.
        """
        # NOTE(review): the original docstring also states this happens
        # automatically once the body has been read to completion; that is not
        # done in this class — presumably the underlying stream does it. Confirm.
        self._response.close()

    def get_final_response(self) -> ParsedResponse[TextFormatT]:
        """Consume the remainder of the stream and return the parsed final response.

        Raises:
            RuntimeError: if the stream ended without a `response.completed` event.
        """
        self.until_done()
        final = self._state._completed_response
        if final is None:
            raise RuntimeError("Didn't receive a `response.completed` event.")
        return final

    def until_done(self) -> Self:
        """Block until the stream has been consumed, then return `self`."""
        consume_sync_iterator(self)
        return self
93
94
class ResponseStreamManager(Generic[TextFormatT]):
    """Context manager that defers the streaming request until the `with` block is entered.

    `__enter__` issues the request and returns a `ResponseStream`. `__exit__`
    closes that stream, if one was created.
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[RawResponseStreamEvent]],
        *,
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        # Only capture what is needed here; the request itself is issued on enter.
        self.__api_request = api_request
        self.__text_format = text_format
        self.__input_tools = input_tools
        self.__starting_after = starting_after
        self.__stream: ResponseStream[TextFormatT] | None = None

    def __enter__(self) -> ResponseStream[TextFormatT]:
        """Issue the request and wrap the raw stream in a `ResponseStream`."""
        opened = ResponseStream(
            raw_stream=self.__api_request(),
            text_format=self.__text_format,
            input_tools=self.__input_tools,
            starting_after=self.__starting_after,
        )
        self.__stream = opened
        return opened

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream opened by `__enter__`; do nothing if none was opened."""
        opened = self.__stream
        if opened is None:
            return
        opened.close()
130
131
class AsyncResponseStream(Generic[TextFormatT]):
    """Asynchronous iterator over the enriched events of a streamed response.

    Every raw event read from `raw_stream` is passed through a
    `ResponseStreamState`, which turns it into zero or more higher-level
    events. When `starting_after` is not None, events whose
    `sequence_number` is less than or equal to it are skipped.
    """

    def __init__(
        self,
        *,
        raw_stream: AsyncStream[RawResponseStreamEvent],
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._starting_after = starting_after
        self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
        # Creating the async generator does not run it; nothing is read until
        # the first event is awaited.
        self._iterator = self.__stream__()

    async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        async for next_event in self._iterator:
            yield next_event

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        """Translate raw events into enriched events, honouring `starting_after`."""
        async for raw_event in self._raw_stream:
            for event in self._state.handle_event(raw_event):
                # Events at or before the resume point have already been seen.
                if self._starting_after is not None and event.sequence_number <= self._starting_after:
                    continue
                yield event

    async def close(self) -> None:
        """
        Close the response and release the connection.
        """
        # NOTE(review): the original docstring also states this happens
        # automatically once the body has been read to completion; that is not
        # done in this class — presumably the underlying stream does it. Confirm.
        await self._response.aclose()

    async def get_final_response(self) -> ParsedResponse[TextFormatT]:
        """Consume the remainder of the stream and return the parsed final response.

        Raises:
            RuntimeError: if the stream ended without a `response.completed` event.
        """
        await self.until_done()
        final = self._state._completed_response
        if final is None:
            raise RuntimeError("Didn't receive a `response.completed` event.")
        return final

    async def until_done(self) -> Self:
        """Wait until the stream has been consumed, then return `self`."""
        await consume_async_iterator(self)
        return self
195
196
class AsyncResponseStreamManager(Generic[TextFormatT]):
    """Async context manager that awaits the streaming request on `async with` entry.

    `__aenter__` awaits the pending request and returns an
    `AsyncResponseStream`. `__aexit__` closes that stream, if one was created.
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
        *,
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        # NOTE(review): `api_request` is a single awaitable. If it is a
        # coroutine it can only be awaited once, so this manager is presumably
        # single-use. Confirm with callers.
        self.__api_request = api_request
        self.__text_format = text_format
        self.__input_tools = input_tools
        self.__starting_after = starting_after
        self.__stream: AsyncResponseStream[TextFormatT] | None = None

    async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
        """Await the request and wrap the raw stream in an `AsyncResponseStream`."""
        opened = AsyncResponseStream(
            raw_stream=await self.__api_request,
            text_format=self.__text_format,
            input_tools=self.__input_tools,
            starting_after=self.__starting_after,
        )
        self.__stream = opened
        return opened

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream opened by `__aenter__`; do nothing if none was opened."""
        opened = self.__stream
        if opened is None:
            return
        await opened.close()
232
233
class ResponseStreamState(Generic[TextFormatT]):
    """Fold raw Responses stream events into a snapshot and derive enriched events.

    The snapshot is seeded from the first event, which must be
    `response.created`. It is then mutated in place as output items, content
    parts and deltas arrive. On `response.completed`, the final response is
    parsed and stored in `_completed_response`.
    """

    def __init__(
        self,
        *,
        input_tools: Iterable[ToolParam] | Omit,
        text_format: type[TextFormatT] | Omit,
    ) -> None:
        self.__current_snapshot: ParsedResponseSnapshot | None = None
        self._completed_response: ParsedResponse[TextFormatT] | None = None
        # Copy into a list so the tools can be iterated again when parsing the
        # completed response, even if the caller passed a one-shot iterable.
        self._input_tools = list(input_tools) if is_given(input_tools) else []
        self._text_format = text_format
        # Keep `text_format` only when it is an actual class; otherwise store `omit`.
        self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit

    def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
        """Apply `event` to the snapshot and return the events to surface to the caller.

        These event types are re-emitted as enriched events:
        - text deltas carry the accumulated text so far;
        - text-done carries the parsed text;
        - function-call argument deltas carry the accumulated arguments;
        - `response.completed` carries the parsed final response.

        Every other event is passed through unchanged.
        """
        # `accumulate_event` reads the previous snapshot before it is replaced.
        self.__current_snapshot = snapshot = self.accumulate_event(event)

        if event.type == "response.output_text.delta":
            message = snapshot.output[event.output_index]
            assert message.type == "message"
            part = message.content[event.content_index]
            assert part.type == "output_text"
            return [
                build(
                    ResponseTextDeltaEvent,
                    content_index=event.content_index,
                    delta=event.delta,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    logprobs=event.logprobs,
                    type="response.output_text.delta",
                    snapshot=part.text,
                )
            ]

        if event.type == "response.output_text.done":
            message = snapshot.output[event.output_index]
            assert message.type == "message"
            # Indexed outside the assert so the lookup still happens under `-O`.
            part = message.content[event.content_index]
            assert part.type == "output_text"
            return [
                build(
                    ResponseTextDoneEvent[TextFormatT],
                    content_index=event.content_index,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    logprobs=event.logprobs,
                    type="response.output_text.done",
                    text=event.text,
                    parsed=parse_text(event.text, text_format=self._text_format),
                )
            ]

        if event.type == "response.function_call_arguments.delta":
            call = snapshot.output[event.output_index]
            assert call.type == "function_call"
            return [
                build(
                    ResponseFunctionCallArgumentsDeltaEvent,
                    delta=event.delta,
                    item_id=event.item_id,
                    output_index=event.output_index,
                    sequence_number=event.sequence_number,
                    type="response.function_call_arguments.delta",
                    snapshot=call.arguments,
                )
            ]

        if event.type == "response.completed":
            # Populated by `accumulate_event` for this same event.
            final = self._completed_response
            assert final is not None
            return [
                build(
                    ResponseCompletedEvent,
                    sequence_number=event.sequence_number,
                    type="response.completed",
                    response=final,
                )
            ]

        return [event]

    def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        """Update the current snapshot in place for `event` and return it.

        The first event seeds the snapshot via `_create_initial_response`.
        Afterwards, added output items, added content parts, text deltas and
        function-call argument deltas are merged into the snapshot. On
        `response.completed`, the final response is parsed into
        `_completed_response`. Any other event leaves the snapshot unchanged.
        """
        snapshot = self.__current_snapshot
        if snapshot is None:
            return self._create_initial_response(event)

        if event.type == "response.output_item.added":
            added = event.item
            # Function calls and messages are re-typed as their parsed variants so
            # later deltas can be accumulated on them; other items are kept as-is.
            if added.type == "function_call":
                entry: Any = construct_type_unchecked(
                    type_=cast(Any, ParsedResponseFunctionToolCall), value=added.to_dict()
                )
            elif added.type == "message":
                entry = construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=added.to_dict())
            else:
                entry = added
            snapshot.output.append(entry)
        elif event.type == "response.content_part.added":
            target = snapshot.output[event.output_index]
            if target.type == "message":
                target.content.append(
                    construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
                )
        elif event.type == "response.output_text.delta":
            target = snapshot.output[event.output_index]
            if target.type == "message":
                part = target.content[event.content_index]
                assert part.type == "output_text"
                part.text += event.delta
        elif event.type == "response.function_call_arguments.delta":
            target = snapshot.output[event.output_index]
            if target.type == "function_call":
                target.arguments += event.delta
        elif event.type == "response.completed":
            self._completed_response = parse_response(
                text_format=self._text_format,
                response=event.response,
                input_tools=self._input_tools,
            )

        return snapshot

    def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        """Seed the snapshot from the stream's first event.

        Raises:
            RuntimeError: if the first event is not `response.created`.
        """
        if event.type == "response.created":
            return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
        raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
372 return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())