main
   1from __future__ import annotations
   2
   3import asyncio
   4from types import TracebackType
   5from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
   6from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
   7
   8import httpx
   9
  10from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
  11from ..._compat import model_dump
  12from ..._models import construct_type
  13from ..._streaming import Stream, AsyncStream
  14from ...types.beta import AssistantStreamEvent
  15from ...types.beta.threads import (
  16    Run,
  17    Text,
  18    Message,
  19    ImageFile,
  20    TextDelta,
  21    MessageDelta,
  22    MessageContent,
  23    MessageContentDelta,
  24)
  25from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
  26
  27
  28class AssistantEventHandler:
  29    text_deltas: Iterable[str]
  30    """Iterator over just the text deltas in the stream.
  31
  32    This corresponds to the `thread.message.delta` event
  33    in the API.
  34
  35    ```py
  36    for text in stream.text_deltas:
  37        print(text, end="", flush=True)
  38    print()
  39    ```
  40    """
  41
  42    def __init__(self) -> None:
  43        self._current_event: AssistantStreamEvent | None = None
  44        self._current_message_content_index: int | None = None
  45        self._current_message_content: MessageContent | None = None
  46        self._current_tool_call_index: int | None = None
  47        self._current_tool_call: ToolCall | None = None
  48        self.__current_run_step_id: str | None = None
  49        self.__current_run: Run | None = None
  50        self.__run_step_snapshots: dict[str, RunStep] = {}
  51        self.__message_snapshots: dict[str, Message] = {}
  52        self.__current_message_snapshot: Message | None = None
  53
  54        self.text_deltas = self.__text_deltas__()
  55        self._iterator = self.__stream__()
  56        self.__stream: Stream[AssistantStreamEvent] | None = None
  57
  58    def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
  59        if self.__stream:
  60            raise RuntimeError(
  61                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
  62            )
  63
  64        self.__stream = stream
  65
  66    def __next__(self) -> AssistantStreamEvent:
  67        return self._iterator.__next__()
  68
  69    def __iter__(self) -> Iterator[AssistantStreamEvent]:
  70        for item in self._iterator:
  71            yield item
  72
  73    @property
  74    def current_event(self) -> AssistantStreamEvent | None:
  75        return self._current_event
  76
  77    @property
  78    def current_run(self) -> Run | None:
  79        return self.__current_run
  80
  81    @property
  82    def current_run_step_snapshot(self) -> RunStep | None:
  83        if not self.__current_run_step_id:
  84            return None
  85
  86        return self.__run_step_snapshots[self.__current_run_step_id]
  87
  88    @property
  89    def current_message_snapshot(self) -> Message | None:
  90        return self.__current_message_snapshot
  91
  92    def close(self) -> None:
  93        """
  94        Close the response and release the connection.
  95
  96        Automatically called when the context manager exits.
  97        """
  98        if self.__stream:
  99            self.__stream.close()
 100
 101    def until_done(self) -> None:
 102        """Waits until the stream has been consumed"""
 103        consume_sync_iterator(self)
 104
 105    def get_final_run(self) -> Run:
 106        """Wait for the stream to finish and returns the completed Run object"""
 107        self.until_done()
 108
 109        if not self.__current_run:
 110            raise RuntimeError("No final run object found")
 111
 112        return self.__current_run
 113
 114    def get_final_run_steps(self) -> list[RunStep]:
 115        """Wait for the stream to finish and returns the steps taken in this run"""
 116        self.until_done()
 117
 118        if not self.__run_step_snapshots:
 119            raise RuntimeError("No run steps found")
 120
 121        return [step for step in self.__run_step_snapshots.values()]
 122
 123    def get_final_messages(self) -> list[Message]:
 124        """Wait for the stream to finish and returns the messages emitted in this run"""
 125        self.until_done()
 126
 127        if not self.__message_snapshots:
 128            raise RuntimeError("No messages found")
 129
 130        return [message for message in self.__message_snapshots.values()]
 131
 132    def __text_deltas__(self) -> Iterator[str]:
 133        for event in self:
 134            if event.event != "thread.message.delta":
 135                continue
 136
 137            for content_delta in event.data.delta.content or []:
 138                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
 139                    yield content_delta.text.value
 140
 141    # event handlers
 142
 143    def on_end(self) -> None:
 144        """Fires when the stream has finished.
 145
 146        This happens if the stream is read to completion
 147        or if an exception occurs during iteration.
 148        """
 149
 150    def on_event(self, event: AssistantStreamEvent) -> None:
 151        """Callback that is fired for every Server-Sent-Event"""
 152
 153    def on_run_step_created(self, run_step: RunStep) -> None:
 154        """Callback that is fired when a run step is created"""
 155
 156    def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
 157        """Callback that is fired whenever a run step delta is returned from the API
 158
 159        The first argument is just the delta as sent by the API and the second argument
 160        is the accumulated snapshot of the run step. For example, a tool calls event may
 161        look like this:
 162
 163        # delta
 164        tool_calls=[
 165            RunStepDeltaToolCallsCodeInterpreter(
 166                index=0,
 167                type='code_interpreter',
 168                id=None,
 169                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
 170            )
 171        ]
 172        # snapshot
 173        tool_calls=[
 174            CodeToolCall(
 175                id='call_wKayJlcYV12NiadiZuJXxcfx',
 176                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
 177                type='code_interpreter',
 178                index=0
 179            )
 180        ],
 181        """
 182
 183    def on_run_step_done(self, run_step: RunStep) -> None:
 184        """Callback that is fired when a run step is completed"""
 185
 186    def on_tool_call_created(self, tool_call: ToolCall) -> None:
 187        """Callback that is fired when a tool call is created"""
 188
 189    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
 190        """Callback that is fired when a tool call delta is encountered"""
 191
 192    def on_tool_call_done(self, tool_call: ToolCall) -> None:
 193        """Callback that is fired when a tool call delta is encountered"""
 194
 195    def on_exception(self, exception: Exception) -> None:
 196        """Fired whenever an exception happens during streaming"""
 197
 198    def on_timeout(self) -> None:
 199        """Fires if the request times out"""
 200
 201    def on_message_created(self, message: Message) -> None:
 202        """Callback that is fired when a message is created"""
 203
 204    def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
 205        """Callback that is fired whenever a message delta is returned from the API
 206
 207        The first argument is just the delta as sent by the API and the second argument
 208        is the accumulated snapshot of the message. For example, a text content event may
 209        look like this:
 210
 211        # delta
 212        MessageDeltaText(
 213            index=0,
 214            type='text',
 215            text=Text(
 216                value=' Jane'
 217            ),
 218        )
 219        # snapshot
 220        MessageContentText(
 221            index=0,
 222            type='text',
 223            text=Text(
 224                value='Certainly, Jane'
 225            ),
 226        )
 227        """
 228
 229    def on_message_done(self, message: Message) -> None:
 230        """Callback that is fired when a message is completed"""
 231
 232    def on_text_created(self, text: Text) -> None:
 233        """Callback that is fired when a text content block is created"""
 234
 235    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
 236        """Callback that is fired whenever a text content delta is returned
 237        by the API.
 238
 239        The first argument is just the delta as sent by the API and the second argument
 240        is the accumulated snapshot of the text. For example:
 241
 242        on_text_delta(TextDelta(value="The"), Text(value="The")),
 243        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
 244        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
 245        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
 246        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
 247        """
 248
 249    def on_text_done(self, text: Text) -> None:
 250        """Callback that is fired when a text content block is finished"""
 251
 252    def on_image_file_done(self, image_file: ImageFile) -> None:
 253        """Callback that is fired when an image file block is finished"""
 254
 255    def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
 256        self._current_event = event
 257        self.on_event(event)
 258
 259        self.__current_message_snapshot, new_content = accumulate_event(
 260            event=event,
 261            current_message_snapshot=self.__current_message_snapshot,
 262        )
 263        if self.__current_message_snapshot is not None:
 264            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot
 265
 266        accumulate_run_step(
 267            event=event,
 268            run_step_snapshots=self.__run_step_snapshots,
 269        )
 270
 271        for content_delta in new_content:
 272            assert self.__current_message_snapshot is not None
 273
 274            block = self.__current_message_snapshot.content[content_delta.index]
 275            if block.type == "text":
 276                self.on_text_created(block.text)
 277
 278        if (
 279            event.event == "thread.run.completed"
 280            or event.event == "thread.run.cancelled"
 281            or event.event == "thread.run.expired"
 282            or event.event == "thread.run.failed"
 283            or event.event == "thread.run.requires_action"
 284            or event.event == "thread.run.incomplete"
 285        ):
 286            self.__current_run = event.data
 287            if self._current_tool_call:
 288                self.on_tool_call_done(self._current_tool_call)
 289        elif (
 290            event.event == "thread.run.created"
 291            or event.event == "thread.run.in_progress"
 292            or event.event == "thread.run.cancelling"
 293            or event.event == "thread.run.queued"
 294        ):
 295            self.__current_run = event.data
 296        elif event.event == "thread.message.created":
 297            self.on_message_created(event.data)
 298        elif event.event == "thread.message.delta":
 299            snapshot = self.__current_message_snapshot
 300            assert snapshot is not None
 301
 302            message_delta = event.data.delta
 303            if message_delta.content is not None:
 304                for content_delta in message_delta.content:
 305                    if content_delta.type == "text" and content_delta.text:
 306                        snapshot_content = snapshot.content[content_delta.index]
 307                        assert snapshot_content.type == "text"
 308                        self.on_text_delta(content_delta.text, snapshot_content.text)
 309
 310                    # If the delta is for a new message content:
 311                    # - emit on_text_done/on_image_file_done for the previous message content
 312                    # - emit on_text_created/on_image_created for the new message content
 313                    if content_delta.index != self._current_message_content_index:
 314                        if self._current_message_content is not None:
 315                            if self._current_message_content.type == "text":
 316                                self.on_text_done(self._current_message_content.text)
 317                            elif self._current_message_content.type == "image_file":
 318                                self.on_image_file_done(self._current_message_content.image_file)
 319
 320                        self._current_message_content_index = content_delta.index
 321                        self._current_message_content = snapshot.content[content_delta.index]
 322
 323                    # Update the current_message_content (delta event is correctly emitted already)
 324                    self._current_message_content = snapshot.content[content_delta.index]
 325
 326            self.on_message_delta(event.data.delta, snapshot)
 327        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
 328            self.__current_message_snapshot = event.data
 329            self.__message_snapshots[event.data.id] = event.data
 330
 331            if self._current_message_content_index is not None:
 332                content = event.data.content[self._current_message_content_index]
 333                if content.type == "text":
 334                    self.on_text_done(content.text)
 335                elif content.type == "image_file":
 336                    self.on_image_file_done(content.image_file)
 337
 338            self.on_message_done(event.data)
 339        elif event.event == "thread.run.step.created":
 340            self.__current_run_step_id = event.data.id
 341            self.on_run_step_created(event.data)
 342        elif event.event == "thread.run.step.in_progress":
 343            self.__current_run_step_id = event.data.id
 344        elif event.event == "thread.run.step.delta":
 345            step_snapshot = self.__run_step_snapshots[event.data.id]
 346
 347            run_step_delta = event.data.delta
 348            if (
 349                run_step_delta.step_details
 350                and run_step_delta.step_details.type == "tool_calls"
 351                and run_step_delta.step_details.tool_calls is not None
 352            ):
 353                assert step_snapshot.step_details.type == "tool_calls"
 354                for tool_call_delta in run_step_delta.step_details.tool_calls:
 355                    if tool_call_delta.index == self._current_tool_call_index:
 356                        self.on_tool_call_delta(
 357                            tool_call_delta,
 358                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
 359                        )
 360
 361                    # If the delta is for a new tool call:
 362                    # - emit on_tool_call_done for the previous tool_call
 363                    # - emit on_tool_call_created for the new tool_call
 364                    if tool_call_delta.index != self._current_tool_call_index:
 365                        if self._current_tool_call is not None:
 366                            self.on_tool_call_done(self._current_tool_call)
 367
 368                        self._current_tool_call_index = tool_call_delta.index
 369                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
 370                        self.on_tool_call_created(self._current_tool_call)
 371
 372                    # Update the current_tool_call (delta event is correctly emitted already)
 373                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
 374
 375            self.on_run_step_delta(
 376                event.data.delta,
 377                step_snapshot,
 378            )
 379        elif (
 380            event.event == "thread.run.step.completed"
 381            or event.event == "thread.run.step.cancelled"
 382            or event.event == "thread.run.step.expired"
 383            or event.event == "thread.run.step.failed"
 384        ):
 385            if self._current_tool_call:
 386                self.on_tool_call_done(self._current_tool_call)
 387
 388            self.on_run_step_done(event.data)
 389            self.__current_run_step_id = None
 390        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
 391            # currently no special handling
 392            ...
 393        else:
 394            # we only want to error at build-time
 395            if TYPE_CHECKING:  # type: ignore[unreachable]
 396                assert_never(event)
 397
 398        self._current_event = None
 399
 400    def __stream__(self) -> Iterator[AssistantStreamEvent]:
 401        stream = self.__stream
 402        if not stream:
 403            raise RuntimeError("Stream has not been started yet")
 404
 405        try:
 406            for event in stream:
 407                self._emit_sse_event(event)
 408
 409                yield event
 410        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
 411            self.on_timeout()
 412            self.on_exception(exc)
 413            raise
 414        except Exception as exc:
 415            self.on_exception(exc)
 416            raise
 417        finally:
 418            self.on_end()
 419
 420
# Type variable bound to `AssistantEventHandler`; parameterises `AssistantStreamManager`
# so that entering the manager returns the caller's concrete handler type.
AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler)
 422
 423
 424class AssistantStreamManager(Generic[AssistantEventHandlerT]):
 425    """Wrapper over AssistantStreamEventHandler that is returned by `.stream()`
 426    so that a context manager can be used.
 427
 428    ```py
 429    with client.threads.create_and_run_stream(...) as stream:
 430        for event in stream:
 431            ...
 432    ```
 433    """
 434
 435    def __init__(
 436        self,
 437        api_request: Callable[[], Stream[AssistantStreamEvent]],
 438        *,
 439        event_handler: AssistantEventHandlerT,
 440    ) -> None:
 441        self.__stream: Stream[AssistantStreamEvent] | None = None
 442        self.__event_handler = event_handler
 443        self.__api_request = api_request
 444
 445    def __enter__(self) -> AssistantEventHandlerT:
 446        self.__stream = self.__api_request()
 447        self.__event_handler._init(self.__stream)
 448        return self.__event_handler
 449
 450    def __exit__(
 451        self,
 452        exc_type: type[BaseException] | None,
 453        exc: BaseException | None,
 454        exc_tb: TracebackType | None,
 455    ) -> None:
 456        if self.__stream is not None:
 457            self.__stream.close()
 458
 459
 460class AsyncAssistantEventHandler:
 461    text_deltas: AsyncIterable[str]
 462    """Iterator over just the text deltas in the stream.
 463
 464    This corresponds to the `thread.message.delta` event
 465    in the API.
 466
 467    ```py
 468    async for text in stream.text_deltas:
 469        print(text, end="", flush=True)
 470    print()
 471    ```
 472    """
 473
 474    def __init__(self) -> None:
 475        self._current_event: AssistantStreamEvent | None = None
 476        self._current_message_content_index: int | None = None
 477        self._current_message_content: MessageContent | None = None
 478        self._current_tool_call_index: int | None = None
 479        self._current_tool_call: ToolCall | None = None
 480        self.__current_run_step_id: str | None = None
 481        self.__current_run: Run | None = None
 482        self.__run_step_snapshots: dict[str, RunStep] = {}
 483        self.__message_snapshots: dict[str, Message] = {}
 484        self.__current_message_snapshot: Message | None = None
 485
 486        self.text_deltas = self.__text_deltas__()
 487        self._iterator = self.__stream__()
 488        self.__stream: AsyncStream[AssistantStreamEvent] | None = None
 489
 490    def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
 491        if self.__stream:
 492            raise RuntimeError(
 493                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
 494            )
 495
 496        self.__stream = stream
 497
 498    async def __anext__(self) -> AssistantStreamEvent:
 499        return await self._iterator.__anext__()
 500
 501    async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
 502        async for item in self._iterator:
 503            yield item
 504
 505    async def close(self) -> None:
 506        """
 507        Close the response and release the connection.
 508
 509        Automatically called when the context manager exits.
 510        """
 511        if self.__stream:
 512            await self.__stream.close()
 513
 514    @property
 515    def current_event(self) -> AssistantStreamEvent | None:
 516        return self._current_event
 517
 518    @property
 519    def current_run(self) -> Run | None:
 520        return self.__current_run
 521
 522    @property
 523    def current_run_step_snapshot(self) -> RunStep | None:
 524        if not self.__current_run_step_id:
 525            return None
 526
 527        return self.__run_step_snapshots[self.__current_run_step_id]
 528
 529    @property
 530    def current_message_snapshot(self) -> Message | None:
 531        return self.__current_message_snapshot
 532
 533    async def until_done(self) -> None:
 534        """Waits until the stream has been consumed"""
 535        await consume_async_iterator(self)
 536
 537    async def get_final_run(self) -> Run:
 538        """Wait for the stream to finish and returns the completed Run object"""
 539        await self.until_done()
 540
 541        if not self.__current_run:
 542            raise RuntimeError("No final run object found")
 543
 544        return self.__current_run
 545
 546    async def get_final_run_steps(self) -> list[RunStep]:
 547        """Wait for the stream to finish and returns the steps taken in this run"""
 548        await self.until_done()
 549
 550        if not self.__run_step_snapshots:
 551            raise RuntimeError("No run steps found")
 552
 553        return [step for step in self.__run_step_snapshots.values()]
 554
 555    async def get_final_messages(self) -> list[Message]:
 556        """Wait for the stream to finish and returns the messages emitted in this run"""
 557        await self.until_done()
 558
 559        if not self.__message_snapshots:
 560            raise RuntimeError("No messages found")
 561
 562        return [message for message in self.__message_snapshots.values()]
 563
 564    async def __text_deltas__(self) -> AsyncIterator[str]:
 565        async for event in self:
 566            if event.event != "thread.message.delta":
 567                continue
 568
 569            for content_delta in event.data.delta.content or []:
 570                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
 571                    yield content_delta.text.value
 572
 573    # event handlers
 574
 575    async def on_end(self) -> None:
 576        """Fires when the stream has finished.
 577
 578        This happens if the stream is read to completion
 579        or if an exception occurs during iteration.
 580        """
 581
 582    async def on_event(self, event: AssistantStreamEvent) -> None:
 583        """Callback that is fired for every Server-Sent-Event"""
 584
 585    async def on_run_step_created(self, run_step: RunStep) -> None:
 586        """Callback that is fired when a run step is created"""
 587
 588    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
 589        """Callback that is fired whenever a run step delta is returned from the API
 590
 591        The first argument is just the delta as sent by the API and the second argument
 592        is the accumulated snapshot of the run step. For example, a tool calls event may
 593        look like this:
 594
 595        # delta
 596        tool_calls=[
 597            RunStepDeltaToolCallsCodeInterpreter(
 598                index=0,
 599                type='code_interpreter',
 600                id=None,
 601                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
 602            )
 603        ]
 604        # snapshot
 605        tool_calls=[
 606            CodeToolCall(
 607                id='call_wKayJlcYV12NiadiZuJXxcfx',
 608                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
 609                type='code_interpreter',
 610                index=0
 611            )
 612        ],
 613        """
 614
 615    async def on_run_step_done(self, run_step: RunStep) -> None:
 616        """Callback that is fired when a run step is completed"""
 617
 618    async def on_tool_call_created(self, tool_call: ToolCall) -> None:
 619        """Callback that is fired when a tool call is created"""
 620
 621    async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
 622        """Callback that is fired when a tool call delta is encountered"""
 623
 624    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
 625        """Callback that is fired when a tool call delta is encountered"""
 626
 627    async def on_exception(self, exception: Exception) -> None:
 628        """Fired whenever an exception happens during streaming"""
 629
 630    async def on_timeout(self) -> None:
 631        """Fires if the request times out"""
 632
 633    async def on_message_created(self, message: Message) -> None:
 634        """Callback that is fired when a message is created"""
 635
 636    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
 637        """Callback that is fired whenever a message delta is returned from the API
 638
 639        The first argument is just the delta as sent by the API and the second argument
 640        is the accumulated snapshot of the message. For example, a text content event may
 641        look like this:
 642
 643        # delta
 644        MessageDeltaText(
 645            index=0,
 646            type='text',
 647            text=Text(
 648                value=' Jane'
 649            ),
 650        )
 651        # snapshot
 652        MessageContentText(
 653            index=0,
 654            type='text',
 655            text=Text(
 656                value='Certainly, Jane'
 657            ),
 658        )
 659        """
 660
 661    async def on_message_done(self, message: Message) -> None:
 662        """Callback that is fired when a message is completed"""
 663
 664    async def on_text_created(self, text: Text) -> None:
 665        """Callback that is fired when a text content block is created"""
 666
 667    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
 668        """Callback that is fired whenever a text content delta is returned
 669        by the API.
 670
 671        The first argument is just the delta as sent by the API and the second argument
 672        is the accumulated snapshot of the text. For example:
 673
 674        on_text_delta(TextDelta(value="The"), Text(value="The")),
 675        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
 676        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
 677        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
 678        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equivalent")),
 679        """
 680
 681    async def on_text_done(self, text: Text) -> None:
 682        """Callback that is fired when a text content block is finished"""
 683
 684    async def on_image_file_done(self, image_file: ImageFile) -> None:
 685        """Callback that is fired when an image file block is finished"""
 686
 687    async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
 688        self._current_event = event
 689        await self.on_event(event)
 690
 691        self.__current_message_snapshot, new_content = accumulate_event(
 692            event=event,
 693            current_message_snapshot=self.__current_message_snapshot,
 694        )
 695        if self.__current_message_snapshot is not None:
 696            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot
 697
 698        accumulate_run_step(
 699            event=event,
 700            run_step_snapshots=self.__run_step_snapshots,
 701        )
 702
 703        for content_delta in new_content:
 704            assert self.__current_message_snapshot is not None
 705
 706            block = self.__current_message_snapshot.content[content_delta.index]
 707            if block.type == "text":
 708                await self.on_text_created(block.text)
 709
 710        if (
 711            event.event == "thread.run.completed"
 712            or event.event == "thread.run.cancelled"
 713            or event.event == "thread.run.expired"
 714            or event.event == "thread.run.failed"
 715            or event.event == "thread.run.requires_action"
 716            or event.event == "thread.run.incomplete"
 717        ):
 718            self.__current_run = event.data
 719            if self._current_tool_call:
 720                await self.on_tool_call_done(self._current_tool_call)
 721        elif (
 722            event.event == "thread.run.created"
 723            or event.event == "thread.run.in_progress"
 724            or event.event == "thread.run.cancelling"
 725            or event.event == "thread.run.queued"
 726        ):
 727            self.__current_run = event.data
 728        elif event.event == "thread.message.created":
 729            await self.on_message_created(event.data)
 730        elif event.event == "thread.message.delta":
 731            snapshot = self.__current_message_snapshot
 732            assert snapshot is not None
 733
 734            message_delta = event.data.delta
 735            if message_delta.content is not None:
 736                for content_delta in message_delta.content:
 737                    if content_delta.type == "text" and content_delta.text:
 738                        snapshot_content = snapshot.content[content_delta.index]
 739                        assert snapshot_content.type == "text"
 740                        await self.on_text_delta(content_delta.text, snapshot_content.text)
 741
 742                    # If the delta is for a new message content:
 743                    # - emit on_text_done/on_image_file_done for the previous message content
 744                    # - emit on_text_created/on_image_created for the new message content
 745                    if content_delta.index != self._current_message_content_index:
 746                        if self._current_message_content is not None:
 747                            if self._current_message_content.type == "text":
 748                                await self.on_text_done(self._current_message_content.text)
 749                            elif self._current_message_content.type == "image_file":
 750                                await self.on_image_file_done(self._current_message_content.image_file)
 751
 752                        self._current_message_content_index = content_delta.index
 753                        self._current_message_content = snapshot.content[content_delta.index]
 754
 755                    # Update the current_message_content (delta event is correctly emitted already)
 756                    self._current_message_content = snapshot.content[content_delta.index]
 757
 758            await self.on_message_delta(event.data.delta, snapshot)
 759        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
 760            self.__current_message_snapshot = event.data
 761            self.__message_snapshots[event.data.id] = event.data
 762
 763            if self._current_message_content_index is not None:
 764                content = event.data.content[self._current_message_content_index]
 765                if content.type == "text":
 766                    await self.on_text_done(content.text)
 767                elif content.type == "image_file":
 768                    await self.on_image_file_done(content.image_file)
 769
 770            await self.on_message_done(event.data)
 771        elif event.event == "thread.run.step.created":
 772            self.__current_run_step_id = event.data.id
 773            await self.on_run_step_created(event.data)
 774        elif event.event == "thread.run.step.in_progress":
 775            self.__current_run_step_id = event.data.id
 776        elif event.event == "thread.run.step.delta":
 777            step_snapshot = self.__run_step_snapshots[event.data.id]
 778
 779            run_step_delta = event.data.delta
 780            if (
 781                run_step_delta.step_details
 782                and run_step_delta.step_details.type == "tool_calls"
 783                and run_step_delta.step_details.tool_calls is not None
 784            ):
 785                assert step_snapshot.step_details.type == "tool_calls"
 786                for tool_call_delta in run_step_delta.step_details.tool_calls:
 787                    if tool_call_delta.index == self._current_tool_call_index:
 788                        await self.on_tool_call_delta(
 789                            tool_call_delta,
 790                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
 791                        )
 792
 793                    # If the delta is for a new tool call:
 794                    # - emit on_tool_call_done for the previous tool_call
 795                    # - emit on_tool_call_created for the new tool_call
 796                    if tool_call_delta.index != self._current_tool_call_index:
 797                        if self._current_tool_call is not None:
 798                            await self.on_tool_call_done(self._current_tool_call)
 799
 800                        self._current_tool_call_index = tool_call_delta.index
 801                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
 802                        await self.on_tool_call_created(self._current_tool_call)
 803
 804                    # Update the current_tool_call (delta event is correctly emitted already)
 805                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
 806
 807            await self.on_run_step_delta(
 808                event.data.delta,
 809                step_snapshot,
 810            )
 811        elif (
 812            event.event == "thread.run.step.completed"
 813            or event.event == "thread.run.step.cancelled"
 814            or event.event == "thread.run.step.expired"
 815            or event.event == "thread.run.step.failed"
 816        ):
 817            if self._current_tool_call:
 818                await self.on_tool_call_done(self._current_tool_call)
 819
 820            await self.on_run_step_done(event.data)
 821            self.__current_run_step_id = None
 822        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
 823            # currently no special handling
 824            ...
 825        else:
 826            # we only want to error at build-time
 827            if TYPE_CHECKING:  # type: ignore[unreachable]
 828                assert_never(event)
 829
 830        self._current_event = None
 831
 832    async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
 833        stream = self.__stream
 834        if not stream:
 835            raise RuntimeError("Stream has not been started yet")
 836
 837        try:
 838            async for event in stream:
 839                await self._emit_sse_event(event)
 840
 841                yield event
 842        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
 843            await self.on_timeout()
 844            await self.on_exception(exc)
 845            raise
 846        except Exception as exc:
 847            await self.on_exception(exc)
 848            raise
 849        finally:
 850            await self.on_end()
 851
 852
# Type variable bound to `AsyncAssistantEventHandler`; it parameterizes
# `AsyncAssistantStreamManager` so that entering the manager returns the same
# handler type that was passed to it.
AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler)
 854
 855
 856class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
 857    """Wrapper over AsyncAssistantStreamEventHandler that is returned by `.stream()`
 858    so that an async context manager can be used without `await`ing the
 859    original client call.
 860
 861    ```py
 862    async with client.threads.create_and_run_stream(...) as stream:
 863        async for event in stream:
 864            ...
 865    ```
 866    """
 867
 868    def __init__(
 869        self,
 870        api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
 871        *,
 872        event_handler: AsyncAssistantEventHandlerT,
 873    ) -> None:
 874        self.__stream: AsyncStream[AssistantStreamEvent] | None = None
 875        self.__event_handler = event_handler
 876        self.__api_request = api_request
 877
 878    async def __aenter__(self) -> AsyncAssistantEventHandlerT:
 879        self.__stream = await self.__api_request
 880        self.__event_handler._init(self.__stream)
 881        return self.__event_handler
 882
 883    async def __aexit__(
 884        self,
 885        exc_type: type[BaseException] | None,
 886        exc: BaseException | None,
 887        exc_tb: TracebackType | None,
 888    ) -> None:
 889        if self.__stream is not None:
 890            await self.__stream.close()
 891
 892
 893def accumulate_run_step(
 894    *,
 895    event: AssistantStreamEvent,
 896    run_step_snapshots: dict[str, RunStep],
 897) -> None:
 898    if event.event == "thread.run.step.created":
 899        run_step_snapshots[event.data.id] = event.data
 900        return
 901
 902    if event.event == "thread.run.step.delta":
 903        data = event.data
 904        snapshot = run_step_snapshots[data.id]
 905
 906        if data.delta:
 907            merged = accumulate_delta(
 908                cast(
 909                    "dict[object, object]",
 910                    model_dump(snapshot, exclude_unset=True, warnings=False),
 911                ),
 912                cast(
 913                    "dict[object, object]",
 914                    model_dump(data.delta, exclude_unset=True, warnings=False),
 915                ),
 916            )
 917            run_step_snapshots[snapshot.id] = cast(RunStep, construct_type(type_=RunStep, value=merged))
 918
 919    return None
 920
 921
 922def accumulate_event(
 923    *,
 924    event: AssistantStreamEvent,
 925    current_message_snapshot: Message | None,
 926) -> tuple[Message | None, list[MessageContentDelta]]:
 927    """Returns a tuple of message snapshot and newly created text message deltas"""
 928    if event.event == "thread.message.created":
 929        return event.data, []
 930
 931    new_content: list[MessageContentDelta] = []
 932
 933    if event.event != "thread.message.delta":
 934        return current_message_snapshot, []
 935
 936    if not current_message_snapshot:
 937        raise RuntimeError("Encountered a message delta with no previous snapshot")
 938
 939    data = event.data
 940    if data.delta.content:
 941        for content_delta in data.delta.content:
 942            try:
 943                block = current_message_snapshot.content[content_delta.index]
 944            except IndexError:
 945                current_message_snapshot.content.insert(
 946                    content_delta.index,
 947                    cast(
 948                        MessageContent,
 949                        construct_type(
 950                            # mypy doesn't allow Content for some reason
 951                            type_=cast(Any, MessageContent),
 952                            value=model_dump(content_delta, exclude_unset=True, warnings=False),
 953                        ),
 954                    ),
 955                )
 956                new_content.append(content_delta)
 957            else:
 958                merged = accumulate_delta(
 959                    cast(
 960                        "dict[object, object]",
 961                        model_dump(block, exclude_unset=True, warnings=False),
 962                    ),
 963                    cast(
 964                        "dict[object, object]",
 965                        model_dump(content_delta, exclude_unset=True, warnings=False),
 966                    ),
 967                )
 968                current_message_snapshot.content[content_delta.index] = cast(
 969                    MessageContent,
 970                    construct_type(
 971                        # mypy doesn't allow Content for some reason
 972                        type_=cast(Any, MessageContent),
 973                        value=merged,
 974                    ),
 975                )
 976
 977    return current_message_snapshot, new_content
 978
 979
 980def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
 981    for key, delta_value in delta.items():
 982        if key not in acc:
 983            acc[key] = delta_value
 984            continue
 985
 986        acc_value = acc[key]
 987        if acc_value is None:
 988            acc[key] = delta_value
 989            continue
 990
 991        # the `index` property is used in arrays of objects so it should
 992        # not be accumulated like other values e.g.
 993        # [{'foo': 'bar', 'index': 0}]
 994        #
 995        # the same applies to `type` properties as they're used for
 996        # discriminated unions
 997        if key == "index" or key == "type":
 998            acc[key] = delta_value
 999            continue
1000
1001        if isinstance(acc_value, str) and isinstance(delta_value, str):
1002            acc_value += delta_value
1003        elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
1004            acc_value += delta_value
1005        elif is_dict(acc_value) and is_dict(delta_value):
1006            acc_value = accumulate_delta(acc_value, delta_value)
1007        elif is_list(acc_value) and is_list(delta_value):
1008            # for lists of non-dictionary items we'll only ever get new entries
1009            # in the array, existing entries will never be changed
1010            if all(isinstance(x, (str, int, float)) for x in acc_value):
1011                acc_value.extend(delta_value)
1012                continue
1013
1014            for delta_entry in delta_value:
1015                if not is_dict(delta_entry):
1016                    raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
1017
1018                try:
1019                    index = delta_entry["index"]
1020                except KeyError as exc:
1021                    raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
1022
1023                if not isinstance(index, int):
1024                    raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
1025
1026                try:
1027                    acc_entry = acc_value[index]
1028                except IndexError:
1029                    acc_value.insert(index, delta_entry)
1030                else:
1031                    if not is_dict(acc_entry):
1032                        raise TypeError("not handled yet")
1033
1034                    acc_value[index] = accumulate_delta(acc_entry, delta_entry)
1035
1036        acc[key] = acc_value
1037
1038    return acc