main
1from __future__ import annotations
2
3import asyncio
4from types import TracebackType
5from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
6from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
7
8import httpx
9
10from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
11from ..._compat import model_dump
12from ..._models import construct_type
13from ..._streaming import Stream, AsyncStream
14from ...types.beta import AssistantStreamEvent
15from ...types.beta.threads import (
16 Run,
17 Text,
18 Message,
19 ImageFile,
20 TextDelta,
21 MessageDelta,
22 MessageContent,
23 MessageContentDelta,
24)
25from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
26
27
class AssistantEventHandler:
    """Consumes an Assistants API event stream and dispatches typed callbacks.

    Iterating the handler (directly, or via `until_done()` / the
    `get_final_*()` helpers) pulls events from the attached stream, folds
    them into message and run-step snapshots, and invokes the matching
    `on_*` callbacks. Every callback is a no-op here; subclass and override
    the ones you need.

    A handler instance can only be attached to a single stream.
    """

    text_deltas: Iterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None
        # Cursor over the content block (of the current message) that deltas
        # are streaming into. Cleared when the message finishes so it never
        # leaks into the next message.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None
        # Cursor over the tool call (of the current run step) that deltas are
        # streaming into. Cleared when the run step (or the run) ends so it
        # never leaks into the next step.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None
        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Both generators are lazy: nothing runs until they are iterated, by
        # which time `_init()` is expected to have attached the stream.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: Stream[AssistantStreamEvent] | None = None

    def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
        """Attach `stream` to this handler.

        Raises:
            RuntimeError: if a stream has already been attached.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    def __next__(self) -> AssistantStreamEvent:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[AssistantStreamEvent]:
        # Deliberately a plain loop rather than `yield from`: closing this
        # wrapper generator (e.g. breaking out of a `for` over the handler or
        # over `text_deltas`) must not close the shared `_iterator`, so that
        # iteration can be resumed later (e.g. by `until_done()`).
        for item in self._iterator:
            yield item

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event whose callbacks are currently running, else None."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The latest Run object received from a `thread.run.*` event."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """Accumulated snapshot of the run step currently in progress, if any."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """Accumulated snapshot of the most recent message."""
        return self.__current_message_snapshot

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            self.__stream.close()

    def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        consume_sync_iterator(self)

    def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object

        Raises:
            RuntimeError: if no run event was received.
        """
        self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run

        Raises:
            RuntimeError: if no run step was received.
        """
        self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return list(self.__run_step_snapshots.values())

    def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run

        Raises:
            RuntimeError: if no message was received.
        """
        self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return list(self.__message_snapshots.values())

    def __text_deltas__(self) -> Iterator[str]:
        # Drives the shared event iterator and yields only non-empty text deltas.
        for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is finished

        i.e. when the next tool call starts, when its run step ends,
        or when the run ends while the tool call is still open.
        """

    def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    def on_timeout(self) -> None:
        """Fires if the request times out (followed by `on_exception`)"""

    def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Update the snapshots for one stream event and fire its callbacks.

        `current_event` points at `event` while the callbacks run and is reset
        to None once they have all returned.
        """
        self._current_event = event
        self.on_event(event)

        # Fold the event into the accumulated snapshots first, so the callbacks
        # below receive up-to-date snapshots.
        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        # `new_content` holds the deltas that opened a brand-new content block.
        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                self.on_text_created(block.text)

        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
            or event.event == "thread.run.incomplete"
        ):
            self.__current_run = event.data
            # The run stopped: a tool call still open at this point is finished.
            if self._current_tool_call is not None:
                self.on_tool_call_done(self._current_tool_call)
                self._current_tool_call = None
                self._current_tool_call_index = None
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        self.on_text_delta(content_delta.text, snapshot_content.text)

                    # Moving on to a different content block: the previous block
                    # is complete, so emit its on_text_done / on_image_file_done.
                    # (on_text_created for the new block was already emitted from
                    # `new_content` above.)
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index
                        self._current_message_content = snapshot.content[content_delta.index]

                    # Keep the cursor pointing at the latest snapshot of the block.
                    self._current_message_content = snapshot.content[content_delta.index]

            self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            # Close out the block that was streaming last, using the final data.
            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    self.on_text_done(content.text)
                elif content.type == "image_file":
                    self.on_image_file_done(content.image_file)

            self.on_message_done(event.data)

            # Reset the content cursor. Without this, the next message's first
            # delta could re-emit on_text_done for a block of *this* message,
            # and a later message completed without deltas would index its own
            # content with this message's stale index.
            self._current_message_content_index = None
            self._current_message_content = None
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        self.on_tool_call_created(self._current_tool_call)

                    # Keep the cursor pointing at the latest snapshot of the tool call.
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call is not None:
                self.on_tool_call_done(self._current_tool_call)

            self.on_run_step_done(event.data)
            self.__current_run_step_id = None
            # Reset the tool-call cursor. Otherwise the run-end branch above (or
            # a later step's completion) would report this tool call as done a
            # second time, and a later step's first tool call reusing the same
            # index would be treated as a continuation and never get
            # on_tool_call_created.
            self._current_tool_call = None
            self._current_tool_call_index = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    def __stream__(self) -> Iterator[AssistantStreamEvent]:
        """Drive the attached stream, dispatching callbacks for every event.

        A timeout fires `on_timeout()` then `on_exception()`; any other
        `Exception` fires `on_exception()`. The exception is re-raised in both
        cases, and `on_end()` runs in a `finally` once the stream is obtained.

        Raises:
            RuntimeError: if no stream has been attached via `_init()`.
        """
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            for event in stream:
                self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            self.on_timeout()
            self.on_exception(exc)
            raise
        except Exception as exc:
            self.on_exception(exc)
            raise
        finally:
            self.on_end()
419
420
# Type variable for the concrete (possibly user-subclassed) sync event handler
# that `AssistantStreamManager.__enter__` hands back.
AssistantEventHandlerT = TypeVar(
    "AssistantEventHandlerT",
    bound=AssistantEventHandler,
)
422
423
class AssistantStreamManager(Generic[AssistantEventHandlerT]):
    """Context manager returned by `.stream()` around an `AssistantEventHandler`.

    The API request is deferred until `__enter__`, which performs it, attaches
    the resulting stream to the event handler, and returns that handler.
    `__exit__` closes the stream (if one was opened) and never suppresses
    exceptions.

    ```py
    with client.threads.create_and_run_stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[AssistantStreamEvent]],
        *,
        event_handler: AssistantEventHandlerT,
    ) -> None:
        self.__api_request = api_request
        self.__event_handler = event_handler
        # Filled in by __enter__; remains None if the request was never made.
        self.__stream: Stream[AssistantStreamEvent] | None = None

    def __enter__(self) -> AssistantEventHandlerT:
        opened = self.__api_request()
        self.__stream = opened
        handler = self.__event_handler
        handler._init(opened)
        return handler

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        opened = self.__stream
        if opened is None:
            return
        opened.close()
458
459
class AsyncAssistantEventHandler:
    """Consumes an async Assistants API event stream and dispatches callbacks.

    Async counterpart of `AssistantEventHandler`: iterating the handler with
    `async for` (or awaiting `until_done()` / the `get_final_*()` helpers)
    pulls events from the attached stream, folds them into message and
    run-step snapshots, and awaits the matching `on_*` callbacks. Every
    callback is a no-op here; subclass and override the ones you need.

    A handler instance can only be attached to a single stream.
    """

    text_deltas: AsyncIterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    async for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None
        # Cursor over the content block (of the current message) that deltas
        # are streaming into. Cleared when the message finishes so it never
        # leaks into the next message.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None
        # Cursor over the tool call (of the current run step) that deltas are
        # streaming into. Cleared when the run step (or the run) ends so it
        # never leaks into the next step.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None
        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Both async generators are lazy: nothing runs until they are iterated,
        # by which time `_init()` is expected to have attached the stream.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None

    def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
        """Attach `stream` to this handler.

        Raises:
            RuntimeError: if a stream has already been attached.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    async def __anext__(self) -> AssistantStreamEvent:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
        # Deliberately a plain loop over the shared `_iterator`: closing this
        # wrapper (e.g. breaking out of an `async for` over the handler or over
        # `text_deltas`) leaves `_iterator` open so iteration can be resumed.
        async for item in self._iterator:
            yield item

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            await self.__stream.close()

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event whose callbacks are currently running, else None."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The latest Run object received from a `thread.run.*` event."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """Accumulated snapshot of the run step currently in progress, if any."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """Accumulated snapshot of the most recent message."""
        return self.__current_message_snapshot

    async def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        await consume_async_iterator(self)

    async def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object

        Raises:
            RuntimeError: if no run event was received.
        """
        await self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    async def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run

        Raises:
            RuntimeError: if no run step was received.
        """
        await self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return list(self.__run_step_snapshots.values())

    async def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run

        Raises:
            RuntimeError: if no message was received.
        """
        await self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return list(self.__message_snapshots.values())

    async def __text_deltas__(self) -> AsyncIterator[str]:
        # Drives the shared event iterator and yields only non-empty text deltas.
        async for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    async def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    async def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    async def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    async def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    async def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is finished

        i.e. when the next tool call starts, when its run step ends,
        or when the run ends while the tool call is still open.
        """

    async def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    async def on_timeout(self) -> None:
        """Fires if the request times out (followed by `on_exception`)"""

    async def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    async def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    async def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    async def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    async def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Update the snapshots for one stream event and await its callbacks.

        `current_event` points at `event` while the callbacks run and is reset
        to None once they have all returned.
        """
        self._current_event = event
        await self.on_event(event)

        # Fold the event into the accumulated snapshots first, so the callbacks
        # below receive up-to-date snapshots.
        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        # `new_content` holds the deltas that opened a brand-new content block.
        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                await self.on_text_created(block.text)

        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
            or event.event == "thread.run.incomplete"
        ):
            self.__current_run = event.data
            # The run stopped: a tool call still open at this point is finished.
            if self._current_tool_call is not None:
                await self.on_tool_call_done(self._current_tool_call)
                self._current_tool_call = None
                self._current_tool_call_index = None
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            await self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        await self.on_text_delta(content_delta.text, snapshot_content.text)

                    # Moving on to a different content block: the previous block
                    # is complete, so emit its on_text_done / on_image_file_done.
                    # (on_text_created for the new block was already emitted from
                    # `new_content` above.)
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                await self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                await self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index
                        self._current_message_content = snapshot.content[content_delta.index]

                    # Keep the cursor pointing at the latest snapshot of the block.
                    self._current_message_content = snapshot.content[content_delta.index]

            await self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            # Close out the block that was streaming last, using the final data.
            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    await self.on_text_done(content.text)
                elif content.type == "image_file":
                    await self.on_image_file_done(content.image_file)

            await self.on_message_done(event.data)

            # Reset the content cursor. Without this, the next message's first
            # delta could re-emit on_text_done for a block of *this* message,
            # and a later message completed without deltas would index its own
            # content with this message's stale index.
            self._current_message_content_index = None
            self._current_message_content = None
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            await self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        await self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            await self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        await self.on_tool_call_created(self._current_tool_call)

                    # Keep the cursor pointing at the latest snapshot of the tool call.
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            await self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call is not None:
                await self.on_tool_call_done(self._current_tool_call)

            await self.on_run_step_done(event.data)
            self.__current_run_step_id = None
            # Reset the tool-call cursor. Otherwise the run-end branch above (or
            # a later step's completion) would report this tool call as done a
            # second time, and a later step's first tool call reusing the same
            # index would be treated as a continuation and never get
            # on_tool_call_created.
            self._current_tool_call = None
            self._current_tool_call_index = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
        """Drive the attached stream, awaiting callbacks for every event.

        A timeout awaits `on_timeout()` then `on_exception()`; any other
        `Exception` awaits `on_exception()`. The exception is re-raised in both
        cases, and `on_end()` is awaited in a `finally` once the stream is
        obtained.

        Raises:
            RuntimeError: if no stream has been attached via `_init()`.
        """
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            async for event in stream:
                await self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            await self.on_timeout()
            await self.on_exception(exc)
            raise
        except Exception as exc:
            await self.on_exception(exc)
            raise
        finally:
            await self.on_end()
851
852
# Type variable for the concrete (possibly user-subclassed) async event handler
# that `AsyncAssistantStreamManager.__aenter__` hands back.
AsyncAssistantEventHandlerT = TypeVar(
    "AsyncAssistantEventHandlerT",
    bound=AsyncAssistantEventHandler,
)
854
855
class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
    """Async context manager returned by `.stream()` around an `AsyncAssistantEventHandler`.

    It lets callers use `async with` directly on the not-yet-awaited API
    request. `__aenter__` awaits the request, attaches the resulting stream
    to the event handler, and returns that handler. `__aexit__` closes the
    stream (if one was opened) and never suppresses exceptions.

    ```py
    async with client.threads.create_and_run_stream(...) as stream:
        async for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
        *,
        event_handler: AsyncAssistantEventHandlerT,
    ) -> None:
        self.__api_request = api_request
        self.__event_handler = event_handler
        # Filled in by __aenter__; remains None if the request never completed.
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None

    async def __aenter__(self) -> AsyncAssistantEventHandlerT:
        opened = await self.__api_request
        self.__stream = opened
        handler = self.__event_handler
        handler._init(opened)
        return handler

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        opened = self.__stream
        if opened is None:
            return
        await opened.close()
891
892
def accumulate_run_step(
    *,
    event: AssistantStreamEvent,
    run_step_snapshots: dict[str, RunStep],
) -> None:
    """Update `run_step_snapshots` in place with a single stream event.

    - `thread.run.step.created`: store the new run step under its id.
    - `thread.run.step.delta`: look up the snapshot for the step id (KeyError
      if there is none) and, when the delta is truthy, replace it with the
      snapshot merged with the delta via `accumulate_delta`.
    - any other event: ignored.
    """
    if event.event == "thread.run.step.created":
        run_step_snapshots[event.data.id] = event.data
        return

    if event.event != "thread.run.step.delta":
        return None

    step_event = event.data
    previous = run_step_snapshots[step_event.id]
    if not step_event.delta:
        return None

    # Merge on plain dicts (only explicitly-set fields), then rebuild the model.
    previous_dict = cast(
        "dict[object, object]",
        model_dump(previous, exclude_unset=True, warnings=False),
    )
    delta_dict = cast(
        "dict[object, object]",
        model_dump(step_event.delta, exclude_unset=True, warnings=False),
    )
    combined = accumulate_delta(previous_dict, delta_dict)
    run_step_snapshots[previous.id] = cast(RunStep, construct_type(type_=RunStep, value=combined))
    return None
920
921
def accumulate_event(
    *,
    event: AssistantStreamEvent,
    current_message_snapshot: Message | None,
) -> tuple[Message | None, list[MessageContentDelta]]:
    """Fold one stream event into the running message snapshot.

    Returns `(snapshot, new_content)`, where `new_content` lists the content
    deltas whose index had no block in the snapshot yet (blocks opened by
    this event).

    - `thread.message.created`: the event payload becomes the new snapshot.
    - `thread.message.delta`: each content delta is merged into (or inserted
      into) the snapshot's `content` list in place.
    - any other event: the snapshot is returned unchanged.

    Raises:
        RuntimeError: for a message delta when there is no snapshot to merge into.
    """
    if event.event == "thread.message.created":
        return event.data, []

    if event.event != "thread.message.delta":
        return current_message_snapshot, []

    if not current_message_snapshot:
        raise RuntimeError("Encountered a message delta with no previous snapshot")

    snapshot = current_message_snapshot
    opened: list[MessageContentDelta] = []

    for piece in event.data.delta.content or []:
        try:
            existing = snapshot.content[piece.index]
        except IndexError:
            # First delta for this index: it becomes a brand-new content block.
            fresh_block = construct_type(
                # mypy doesn't allow Content for some reason
                type_=cast(Any, MessageContent),
                value=model_dump(piece, exclude_unset=True, warnings=False),
            )
            snapshot.content.insert(piece.index, cast(MessageContent, fresh_block))
            opened.append(piece)
            continue

        combined = accumulate_delta(
            cast("dict[object, object]", model_dump(existing, exclude_unset=True, warnings=False)),
            cast("dict[object, object]", model_dump(piece, exclude_unset=True, warnings=False)),
        )
        snapshot.content[piece.index] = cast(
            MessageContent,
            # mypy doesn't allow Content for some reason
            construct_type(type_=cast(Any, MessageContent), value=combined),
        )

    return snapshot, opened
978
979
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
    """Merge a streamed `delta` dict into the accumulated `acc` dict, in place.

    For every key in `delta`:

    - missing from `acc`, or `None` there: take the delta value as-is;
    - `index` / `type`: overwrite. `index` positions entries inside arrays of
      objects (e.g. `[{'foo': 'bar', 'index': 0}]`) and `type` discriminates
      unions, so neither may be accumulated like other values;
    - str + str: concatenate; number + number: add;
    - dict + dict: merge recursively;
    - list + list: if every accumulated entry is a str/int/float, extend
      (such lists only ever gain entries); otherwise each delta entry must be
      a dict carrying an integer `index`, and is merged into (or inserted at)
      that position;
    - any other combination: keep the accumulated value.

    Returns `acc` itself.

    Raises:
        TypeError: a list delta entry is not a dict, its `index` is not an int,
            or the accumulated entry at that index is not a dict.
        RuntimeError: a list delta entry has no `index` key.
    """
    for key, incoming in delta.items():
        current = acc.get(key)
        if current is None or key in ("index", "type"):
            acc[key] = incoming
            continue

        if isinstance(current, str) and isinstance(incoming, str):
            acc[key] = current + incoming
        elif isinstance(current, (int, float)) and isinstance(incoming, (int, float)):
            acc[key] = current + incoming
        elif is_dict(current) and is_dict(incoming):
            acc[key] = accumulate_delta(current, incoming)
        elif is_list(current) and is_list(incoming):
            if all(isinstance(item, (str, int, float)) for item in current):
                current.extend(incoming)
                continue

            for entry in incoming:
                if not is_dict(entry):
                    raise TypeError(f"Unexpected list delta entry is not a dictionary: {entry}")

                try:
                    position = entry["index"]
                except KeyError as exc:
                    raise RuntimeError(f"Expected list delta entry to have an `index` key; {entry}") from exc

                if not isinstance(position, int):
                    raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {position}")

                try:
                    existing = current[position]
                except IndexError:
                    current.insert(position, entry)
                    continue

                if not is_dict(existing):
                    raise TypeError("not handled yet")

                current[position] = accumulate_delta(existing, entry)

            acc[key] = current
        else:
            acc[key] = current

    return acc