Commit 778e28e5
Changed files (3)
src/openai/lib/streaming/chat/__init__.py
src/openai/lib/streaming/chat/_completions.py
tests/lib/chat/test_completions_streaming.py
src/openai/lib/streaming/chat/__init__.py
@@ -21,6 +21,7 @@ from ._events import (
from ._completions import (
ChatCompletionStream as ChatCompletionStream,
AsyncChatCompletionStream as AsyncChatCompletionStream,
+ ChatCompletionStreamState as ChatCompletionStreamState,
ChatCompletionStreamManager as ChatCompletionStreamManager,
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
)
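With this re-export, the helper can be imported from the public streaming namespace instead of the private `_completions` module:

```py
# public import path enabled by this commit (same import the new test uses)
from openai.lib.streaming.chat import ChatCompletionStreamState
```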
src/openai/lib/streaming/chat/_completions.py
@@ -287,11 +287,31 @@ class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
class ChatCompletionStreamState(Generic[ResponseFormatT]):
+ """Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
+
+ This is useful in cases where you can't always use the `.stream()` method, e.g.
+
+ ```py
+ from openai.lib.streaming.chat import ChatCompletionStreamState
+
+ state = ChatCompletionStreamState()
+
+ stream = client.chat.completions.create(..., stream=True)
+ for chunk in stream:
+ state.handle_chunk(chunk)
+
+ # can also access the accumulated `ChatCompletion` mid-stream
+ state.current_completion_snapshot
+
+ print(state.get_final_completion())
+ ```
+ """
+
def __init__(
self,
*,
- input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
- response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
+ input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+ response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
) -> None:
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
self.__choice_event_states: list[ChoiceEventState] = []
@@ -301,6 +321,11 @@ class ChatCompletionStreamState(Generic[ResponseFormatT]):
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
+ """Parse the final completion object.
+
+ Note that this does not provide any guarantee that the stream has actually finished; you
+ must only call this method once the stream is complete.
+ """
return parse_chat_completion(
chat_completion=self.current_completion_snapshot,
response_format=self._rich_response_format,
@@ -312,8 +337,8 @@ class ChatCompletionStreamState(Generic[ResponseFormatT]):
assert self.__current_completion_snapshot is not None
return self.__current_completion_snapshot
- def handle_chunk(self, chunk: ChatCompletionChunk) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
- """Accumulate a new chunk into the snapshot and returns a list of events to yield."""
+ def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
+ """Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
return self._build_events(
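Taken together, a minimal sketch of how the new helper is meant to be driven when `.stream()` isn't an option; the model name, prompt, and event filtering here are illustrative rather than part of this commit:

```py
from openai import OpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState

client = OpenAI()
state = ChatCompletionStreamState()

stream = client.chat.completions.create(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    stream=True,
)
for chunk in stream:
    # fold the chunk into the running snapshot; the returned iterable
    # holds the stream events derived from this chunk
    for event in state.handle_chunk(chunk):
        if event.type == "content.delta":
            print(event.delta, end="")

# only valid once the stream is exhausted (see the docstring note above)
print(state.get_final_completion())
```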
tests/lib/chat/test_completions_streaming.py
@@ -13,12 +13,14 @@ from inline_snapshot import external, snapshot, outsource
import openai
from openai import OpenAI, AsyncOpenAI
-from openai._utils import assert_signatures_in_sync
+from openai._utils import consume_sync_iterator, assert_signatures_in_sync
from openai._compat import model_copy
+from openai.types.chat import ChatCompletionChunk
from openai.lib.streaming.chat import (
ContentDoneEvent,
ChatCompletionStream,
ChatCompletionStreamEvent,
+ ChatCompletionStreamState,
ChatCompletionStreamManager,
ParsedChatCompletionSnapshot,
)
@@ -997,6 +999,55 @@ FunctionToolCallArgumentsDoneEvent(
)
+@pytest.mark.respx(base_url=base_url)
+def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+ state = ChatCompletionStreamState()
+
+ def streamer(client: OpenAI) -> Iterator[ChatCompletionChunk]:
+ stream = client.chat.completions.create(
+ model="gpt-4o-2024-08-06",
+ messages=[
+ {
+ "role": "user",
+ "content": "What's the weather like in SF?",
+ },
+ ],
+ stream=True,
+ )
+ for chunk in stream:
+ state.handle_chunk(chunk)
+ yield chunk
+
+ _make_raw_stream_snapshot_request(
+ streamer,
+ content_snapshot=snapshot(external("e2aad469b71d*.bin")),
+ mock_client=client,
+ respx_mock=respx_mock,
+ )
+
+ assert print_obj(state.get_final_completion().choices, monkeypatch) == snapshot(
+ """\
+[
+ ParsedChoice[NoneType](
+ finish_reason='stop',
+ index=0,
+ logprobs=None,
+ message=ParsedChatCompletionMessage[NoneType](
+ audio=None,
+ content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
+recommend checking a reliable weather website or a weather app.",
+ function_call=None,
+ parsed=None,
+ refusal=None,
+ role='assistant',
+ tool_calls=[]
+ )
+ )
+]
+"""
+ )
+
+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_stream_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
@@ -1075,3 +1126,44 @@ def _make_stream_snapshot_request(
client.close()
return listener
+
+
+def _make_raw_stream_snapshot_request(
+ func: Callable[[OpenAI], Iterator[ChatCompletionChunk]],
+ *,
+ content_snapshot: Any,
+ respx_mock: MockRouter,
+ mock_client: OpenAI,
+) -> None:
+ live = os.environ.get("OPENAI_LIVE") == "1"
+ if live:
+
+ def _on_response(response: httpx.Response) -> None:
+ # update the content snapshot
+ assert outsource(response.read()) == content_snapshot
+
+ respx_mock.stop()
+
+ client = OpenAI(
+ http_client=httpx.Client(
+ event_hooks={
+ "response": [_on_response],
+ }
+ )
+ )
+ else:
+ respx_mock.post("/chat/completions").mock(
+ return_value=httpx.Response(
+ 200,
+ content=content_snapshot._old_value._load_value(),
+ headers={"content-type": "text/event-stream"},
+ )
+ )
+
+ client = mock_client
+
+ stream = func(client)
+ consume_sync_iterator(stream)
+
+ if live:
+ client.close()
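The new test and helper only exercise the sync client. Since `handle_chunk` does no I/O, the same state object should be equally usable with the async client; a sketch under that assumption (not covered by this commit's tests):

```py
import asyncio

from openai import AsyncOpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState


async def main() -> None:
    client = AsyncOpenAI()
    state = ChatCompletionStreamState()

    stream = await client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": "What's the weather like in SF?"}],
        stream=True,
    )
    async for chunk in stream:
        state.handle_chunk(chunk)  # accumulation itself is synchronous

    print(state.get_final_completion())


asyncio.run(main())
```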