1from __future__ import annotations
2
3import inspect
4from types import TracebackType
5from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
6from typing_extensions import Self, Iterator, assert_never
7
8from jiter import from_json
9
10from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
11from ._events import (
12 ChunkEvent,
13 ContentDoneEvent,
14 RefusalDoneEvent,
15 ContentDeltaEvent,
16 RefusalDeltaEvent,
17 LogprobsContentDoneEvent,
18 LogprobsRefusalDoneEvent,
19 ChatCompletionStreamEvent,
20 LogprobsContentDeltaEvent,
21 LogprobsRefusalDeltaEvent,
22 FunctionToolCallArgumentsDoneEvent,
23 FunctionToolCallArgumentsDeltaEvent,
24)
25from .._deltas import accumulate_delta
26from ...._types import Omit, IncEx, omit
27from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
28from ...._compat import model_dump
29from ...._models import build, construct_type
30from ..._parsing import (
31 ResponseFormatT,
32 has_parseable_input,
33 maybe_parse_content,
34 parse_chat_completion,
35 get_input_tool_by_name,
36 solve_response_format_t,
37 parse_function_tool_arguments,
38)
39from ...._streaming import Stream, AsyncStream
40from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolUnionParam
41from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
42from ....types.chat.chat_completion import ChoiceLogprobs
43from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
44from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
45
46
class ChatCompletionStream(Generic[ResponseFormatT]):
    """Synchronous wrapper around the raw Chat Completions SSE stream.

    Adds higher-level events such as `content.done`, supports automatically
    parsing responses & tool calls, and accumulates the individual chunks
    into a full `ChatCompletion` snapshot.

    https://platform.openai.com/docs/api-reference/streaming
    """

    def __init__(
        self,
        *,
        raw_stream: Stream[ChatCompletionChunk],
        response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
        input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._iterator = self.__stream__()
        self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)

    def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
        return next(self._iterator)

    def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        yield from self._iterator

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self._response.close()

    def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedChatCompletion` object.

        If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
        property will be the content deserialised into that class, if there was any content returned
        by the API.
        """
        self.until_done()
        return self._state.get_final_completion()

    def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        consume_sync_iterator(self)
        return self

    @property
    def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
        return self._state.current_completion_snapshot

    def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        # drop events that are not real chunks (e.g. Azure's "Asynchronous
        # Filter" interleaves extra payloads) before accumulating state
        for chunk in self._raw_stream:
            if _is_valid_chat_completion_chunk_weak(chunk):
                yield from self._state.handle_chunk(chunk)
121
122
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
    """Context manager over a `ChatCompletionStream` that is returned by `.stream()`.

    Ensures the underlying HTTP response cannot be leaked if the stream is
    not read to completion.

    Usage:
    ```py
    with client.chat.completions.stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[ChatCompletionChunk]],
        *,
        response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
        input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    ) -> None:
        self.__api_request = api_request
        self.__response_format = response_format
        self.__input_tools = input_tools
        self.__stream: ChatCompletionStream[ResponseFormatT] | None = None

    def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
        # the request is only issued once the context is actually entered
        self.__stream = ChatCompletionStream(
            raw_stream=self.__api_request(),
            response_format=self.__response_format,
            input_tools=self.__input_tools,
        )
        return self.__stream

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        stream = self.__stream
        if stream is not None:
            stream.close()
168
169
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
    """Asynchronous wrapper around the raw Chat Completions SSE stream.

    Adds higher-level events such as `content.done`, supports automatically
    parsing responses & tool calls, and accumulates the individual chunks
    into a full `ChatCompletion` snapshot.

    https://platform.openai.com/docs/api-reference/streaming
    """

    def __init__(
        self,
        *,
        raw_stream: AsyncStream[ChatCompletionChunk],
        response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
        input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    ) -> None:
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        self._iterator = self.__stream__()
        self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)

    async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        async for event in self._iterator:
            yield event

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self._response.aclose()

    async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
        """Waits until the stream has been read to completion and returns
        the accumulated `ParsedChatCompletion` object.

        If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
        property will be the content deserialised into that class, if there was any content returned
        by the API.
        """
        await self.until_done()
        return self._state.get_final_completion()

    async def until_done(self) -> Self:
        """Blocks until the stream has been consumed."""
        await consume_async_iterator(self)
        return self

    @property
    def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
        return self._state.current_completion_snapshot

    async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
        # drop events that are not real chunks (e.g. Azure's "Asynchronous
        # Filter" interleaves extra payloads) before accumulating state
        async for chunk in self._raw_stream:
            if not _is_valid_chat_completion_chunk_weak(chunk):
                continue
            for event in self._state.handle_chunk(chunk):
                yield event
244
245
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
    """Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.

    Ensures the underlying HTTP response cannot be leaked if the stream is
    not read to completion.

    Usage:
    ```py
    async with client.chat.completions.stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
        *,
        response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
        input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    ) -> None:
        self.__api_request = api_request
        self.__response_format = response_format
        self.__input_tools = input_tools
        self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None

    async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
        # the request is only awaited once the context is actually entered
        self.__stream = AsyncChatCompletionStream(
            raw_stream=await self.__api_request,
            response_format=self.__response_format,
            input_tools=self.__input_tools,
        )
        return self.__stream

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        stream = self.__stream
        if stream is not None:
            await stream.close()
291
292
293class ChatCompletionStreamState(Generic[ResponseFormatT]):
294 """Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
295
296 This is useful in cases where you can't always use the `.stream()` method, e.g.
297
298 ```py
299 from openai.lib.streaming.chat import ChatCompletionStreamState
300
301 state = ChatCompletionStreamState()
302
303 stream = client.chat.completions.create(..., stream=True)
304 for chunk in response:
305 state.handle_chunk(chunk)
306
307 # can also access the accumulated `ChatCompletion` mid-stream
308 state.current_completion_snapshot
309
310 print(state.get_final_completion())
311 ```
312 """
313
314 def __init__(
315 self,
316 *,
317 input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
318 response_format: type[ResponseFormatT] | ResponseFormatParam | Omit = omit,
319 ) -> None:
320 self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
321 self.__choice_event_states: list[ChoiceEventState] = []
322
323 self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
324 self._response_format = response_format
325 self._rich_response_format: type | Omit = response_format if inspect.isclass(response_format) else omit
326
327 def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
328 """Parse the final completion object.
329
330 Note this does not provide any guarantees that the stream has actually finished, you must
331 only call this method when the stream is finished.
332 """
333 return parse_chat_completion(
334 chat_completion=self.current_completion_snapshot,
335 response_format=self._rich_response_format,
336 input_tools=self._input_tools,
337 )
338
339 @property
340 def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
341 assert self.__current_completion_snapshot is not None
342 return self.__current_completion_snapshot
343
344 def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
345 """Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
346 self.__current_completion_snapshot = self._accumulate_chunk(chunk)
347
348 return self._build_events(
349 chunk=chunk,
350 completion_snapshot=self.__current_completion_snapshot,
351 )
352
353 def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
354 try:
355 return self.__choice_event_states[choice.index]
356 except IndexError:
357 choice_state = ChoiceEventState(input_tools=self._input_tools)
358 self.__choice_event_states.append(choice_state)
359 return choice_state
360
361 def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
362 completion_snapshot = self.__current_completion_snapshot
363
364 if completion_snapshot is None:
365 return _convert_initial_chunk_into_snapshot(chunk)
366
367 for choice in chunk.choices:
368 try:
369 choice_snapshot = completion_snapshot.choices[choice.index]
370 previous_tool_calls = choice_snapshot.message.tool_calls or []
371
372 choice_snapshot.message = cast(
373 ParsedChatCompletionMessageSnapshot,
374 construct_type(
375 type_=ParsedChatCompletionMessageSnapshot,
376 value=accumulate_delta(
377 cast(
378 "dict[object, object]",
379 model_dump(
380 choice_snapshot.message,
381 # we don't want to serialise / deserialise our custom properties
382 # as they won't appear in the delta and we don't want to have to
383 # continuosly reparse the content
384 exclude=cast(
385 # cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
386 IncEx,
387 {
388 "parsed": True,
389 "tool_calls": {
390 idx: {"function": {"parsed_arguments": True}}
391 for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
392 },
393 },
394 ),
395 ),
396 ),
397 cast("dict[object, object]", choice.delta.to_dict()),
398 ),
399 ),
400 )
401
402 # ensure tools that have already been parsed are added back into the newly
403 # constructed message snapshot
404 for tool_index, prev_tool in enumerate(previous_tool_calls):
405 new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
406
407 if prev_tool.type == "function":
408 assert new_tool.type == "function"
409 new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
410 elif TYPE_CHECKING: # type: ignore[unreachable]
411 assert_never(prev_tool)
412 except IndexError:
413 choice_snapshot = cast(
414 ParsedChoiceSnapshot,
415 construct_type(
416 type_=ParsedChoiceSnapshot,
417 value={
418 **choice.model_dump(exclude_unset=True, exclude={"delta"}),
419 "message": choice.delta.to_dict(),
420 },
421 ),
422 )
423 completion_snapshot.choices.append(choice_snapshot)
424
425 if choice.finish_reason:
426 choice_snapshot.finish_reason = choice.finish_reason
427
428 if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
429 if choice.finish_reason == "length":
430 # at the time of writing, `.usage` will always be `None` but
431 # we include it here in case that is changed in the future
432 raise LengthFinishReasonError(completion=completion_snapshot)
433
434 if choice.finish_reason == "content_filter":
435 raise ContentFilterFinishReasonError()
436
437 if (
438 choice_snapshot.message.content
439 and not choice_snapshot.message.refusal
440 and is_given(self._rich_response_format)
441 # partial parsing fails on white-space
442 and choice_snapshot.message.content.lstrip()
443 ):
444 choice_snapshot.message.parsed = from_json(
445 bytes(choice_snapshot.message.content, "utf-8"),
446 partial_mode=True,
447 )
448
449 for tool_call_chunk in choice.delta.tool_calls or []:
450 tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
451
452 if tool_call_snapshot.type == "function":
453 input_tool = get_input_tool_by_name(
454 input_tools=self._input_tools, name=tool_call_snapshot.function.name
455 )
456
457 if (
458 input_tool
459 and input_tool.get("function", {}).get("strict")
460 and tool_call_snapshot.function.arguments
461 ):
462 tool_call_snapshot.function.parsed_arguments = from_json(
463 bytes(tool_call_snapshot.function.arguments, "utf-8"),
464 partial_mode=True,
465 )
466 elif TYPE_CHECKING: # type: ignore[unreachable]
467 assert_never(tool_call_snapshot)
468
469 if choice.logprobs is not None:
470 if choice_snapshot.logprobs is None:
471 choice_snapshot.logprobs = build(
472 ChoiceLogprobs,
473 content=choice.logprobs.content,
474 refusal=choice.logprobs.refusal,
475 )
476 else:
477 if choice.logprobs.content:
478 if choice_snapshot.logprobs.content is None:
479 choice_snapshot.logprobs.content = []
480
481 choice_snapshot.logprobs.content.extend(choice.logprobs.content)
482
483 if choice.logprobs.refusal:
484 if choice_snapshot.logprobs.refusal is None:
485 choice_snapshot.logprobs.refusal = []
486
487 choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
488
489 completion_snapshot.usage = chunk.usage
490 completion_snapshot.system_fingerprint = chunk.system_fingerprint
491
492 return completion_snapshot
493
494 def _build_events(
495 self,
496 *,
497 chunk: ChatCompletionChunk,
498 completion_snapshot: ParsedChatCompletionSnapshot,
499 ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
500 events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
501
502 events_to_fire.append(
503 build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
504 )
505
506 for choice in chunk.choices:
507 choice_state = self._get_choice_state(choice)
508 choice_snapshot = completion_snapshot.choices[choice.index]
509
510 if choice.delta.content is not None and choice_snapshot.message.content is not None:
511 events_to_fire.append(
512 build(
513 ContentDeltaEvent,
514 type="content.delta",
515 delta=choice.delta.content,
516 snapshot=choice_snapshot.message.content,
517 parsed=choice_snapshot.message.parsed,
518 )
519 )
520
521 if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
522 events_to_fire.append(
523 build(
524 RefusalDeltaEvent,
525 type="refusal.delta",
526 delta=choice.delta.refusal,
527 snapshot=choice_snapshot.message.refusal,
528 )
529 )
530
531 if choice.delta.tool_calls:
532 tool_calls = choice_snapshot.message.tool_calls
533 assert tool_calls is not None
534
535 for tool_call_delta in choice.delta.tool_calls:
536 tool_call = tool_calls[tool_call_delta.index]
537
538 if tool_call.type == "function":
539 assert tool_call_delta.function is not None
540 events_to_fire.append(
541 build(
542 FunctionToolCallArgumentsDeltaEvent,
543 type="tool_calls.function.arguments.delta",
544 name=tool_call.function.name,
545 index=tool_call_delta.index,
546 arguments=tool_call.function.arguments,
547 parsed_arguments=tool_call.function.parsed_arguments,
548 arguments_delta=tool_call_delta.function.arguments or "",
549 )
550 )
551 elif TYPE_CHECKING: # type: ignore[unreachable]
552 assert_never(tool_call)
553
554 if choice.logprobs is not None and choice_snapshot.logprobs is not None:
555 if choice.logprobs.content and choice_snapshot.logprobs.content:
556 events_to_fire.append(
557 build(
558 LogprobsContentDeltaEvent,
559 type="logprobs.content.delta",
560 content=choice.logprobs.content,
561 snapshot=choice_snapshot.logprobs.content,
562 ),
563 )
564
565 if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
566 events_to_fire.append(
567 build(
568 LogprobsRefusalDeltaEvent,
569 type="logprobs.refusal.delta",
570 refusal=choice.logprobs.refusal,
571 snapshot=choice_snapshot.logprobs.refusal,
572 ),
573 )
574
575 events_to_fire.extend(
576 choice_state.get_done_events(
577 choice_chunk=choice,
578 choice_snapshot=choice_snapshot,
579 response_format=self._response_format,
580 )
581 )
582
583 return events_to_fire
584
585
586class ChoiceEventState:
587 def __init__(self, *, input_tools: list[ChatCompletionToolUnionParam]) -> None:
588 self._input_tools = input_tools
589
590 self._content_done = False
591 self._refusal_done = False
592 self._logprobs_content_done = False
593 self._logprobs_refusal_done = False
594 self._done_tool_calls: set[int] = set()
595 self.__current_tool_call_index: int | None = None
596
597 def get_done_events(
598 self,
599 *,
600 choice_chunk: ChoiceChunk,
601 choice_snapshot: ParsedChoiceSnapshot,
602 response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
603 ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
604 events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
605
606 if choice_snapshot.finish_reason:
607 events_to_fire.extend(
608 self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
609 )
610
611 if (
612 self.__current_tool_call_index is not None
613 and self.__current_tool_call_index not in self._done_tool_calls
614 ):
615 self._add_tool_done_event(
616 events_to_fire=events_to_fire,
617 choice_snapshot=choice_snapshot,
618 tool_index=self.__current_tool_call_index,
619 )
620
621 for tool_call in choice_chunk.delta.tool_calls or []:
622 if self.__current_tool_call_index != tool_call.index:
623 events_to_fire.extend(
624 self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
625 )
626
627 if self.__current_tool_call_index is not None:
628 self._add_tool_done_event(
629 events_to_fire=events_to_fire,
630 choice_snapshot=choice_snapshot,
631 tool_index=self.__current_tool_call_index,
632 )
633
634 self.__current_tool_call_index = tool_call.index
635
636 return events_to_fire
637
638 def _content_done_events(
639 self,
640 *,
641 choice_snapshot: ParsedChoiceSnapshot,
642 response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
643 ) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
644 events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
645
646 if choice_snapshot.message.content and not self._content_done:
647 self._content_done = True
648
649 parsed = maybe_parse_content(
650 response_format=response_format,
651 message=choice_snapshot.message,
652 )
653
654 # update the parsed content to now use the richer `response_format`
655 # as opposed to the raw JSON-parsed object as the content is now
656 # complete and can be fully validated.
657 choice_snapshot.message.parsed = parsed
658
659 events_to_fire.append(
660 build(
661 # we do this dance so that when the `ContentDoneEvent` instance
662 # is printed at runtime the class name will include the solved
663 # type variable, e.g. `ContentDoneEvent[MyModelType]`
664 cast( # pyright: ignore[reportUnnecessaryCast]
665 "type[ContentDoneEvent[ResponseFormatT]]",
666 cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
667 ),
668 type="content.done",
669 content=choice_snapshot.message.content,
670 parsed=parsed,
671 ),
672 )
673
674 if choice_snapshot.message.refusal is not None and not self._refusal_done:
675 self._refusal_done = True
676 events_to_fire.append(
677 build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
678 )
679
680 if (
681 choice_snapshot.logprobs is not None
682 and choice_snapshot.logprobs.content is not None
683 and not self._logprobs_content_done
684 ):
685 self._logprobs_content_done = True
686 events_to_fire.append(
687 build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
688 )
689
690 if (
691 choice_snapshot.logprobs is not None
692 and choice_snapshot.logprobs.refusal is not None
693 and not self._logprobs_refusal_done
694 ):
695 self._logprobs_refusal_done = True
696 events_to_fire.append(
697 build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
698 )
699
700 return events_to_fire
701
702 def _add_tool_done_event(
703 self,
704 *,
705 events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
706 choice_snapshot: ParsedChoiceSnapshot,
707 tool_index: int,
708 ) -> None:
709 if tool_index in self._done_tool_calls:
710 return
711
712 self._done_tool_calls.add(tool_index)
713
714 assert choice_snapshot.message.tool_calls is not None
715 tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
716
717 if tool_call_snapshot.type == "function":
718 parsed_arguments = parse_function_tool_arguments(
719 input_tools=self._input_tools, function=tool_call_snapshot.function
720 )
721
722 # update the parsed content to potentially use a richer type
723 # as opposed to the raw JSON-parsed object as the content is now
724 # complete and can be fully validated.
725 tool_call_snapshot.function.parsed_arguments = parsed_arguments
726
727 events_to_fire.append(
728 build(
729 FunctionToolCallArgumentsDoneEvent,
730 type="tool_calls.function.arguments.done",
731 index=tool_index,
732 name=tool_call_snapshot.function.name,
733 arguments=tool_call_snapshot.function.arguments,
734 parsed_arguments=parsed_arguments,
735 )
736 )
737 elif TYPE_CHECKING: # type: ignore[unreachable]
738 assert_never(tool_call_snapshot)
739
740
741def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
742 data = chunk.to_dict()
743 choices = cast("list[object]", data["choices"])
744
745 for choice in chunk.choices:
746 choices[choice.index] = {
747 **choice.model_dump(exclude_unset=True, exclude={"delta"}),
748 "message": choice.delta.to_dict(),
749 }
750
751 return cast(
752 ParsedChatCompletionSnapshot,
753 construct_type(
754 type_=ParsedChatCompletionSnapshot,
755 value={
756 "system_fingerprint": None,
757 **data,
758 "object": "chat.completion",
759 },
760 ),
761 )
762
763
764def _is_valid_chat_completion_chunk_weak(sse_event: ChatCompletionChunk) -> bool:
765 # Although the _raw_stream is always supposed to contain only objects adhering to ChatCompletionChunk schema,
766 # this is broken by the Azure OpenAI in case of Asynchronous Filter enabled.
767 # An easy filter is to check for the "object" property:
768 # - should be "chat.completion.chunk" for a ChatCompletionChunk;
769 # - is an empty string for Asynchronous Filter events.
770 return sse_event.object == "chat.completion.chunk" # type: ignore # pylance reports this as a useless check