Commit 778e28e5

Robert Craigie <robert@craigie.dev>
2024-11-30 05:05:14
feat(client): make ChatCompletionStreamState public (#1898)
1 parent 6974a98
Changed files (3)
src/openai/lib/streaming/chat/__init__.py
@@ -21,6 +21,7 @@ from ._events import (
 from ._completions import (
     ChatCompletionStream as ChatCompletionStream,
     AsyncChatCompletionStream as AsyncChatCompletionStream,
+    ChatCompletionStreamState as ChatCompletionStreamState,
     ChatCompletionStreamManager as ChatCompletionStreamManager,
     AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
 )
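With this re-export in place, the class is importable from the public streaming module. A minimal sketch of manual accumulation (mirroring the docstring example added below; the model and prompt are taken from the new test):

```py
from openai import OpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState

client = OpenAI()
state = ChatCompletionStreamState()

# accumulate raw chunks manually instead of going through `.stream()`
stream = client.chat.completions.create(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    stream=True,
)
for chunk in stream:
    state.handle_chunk(chunk)

print(state.get_final_completion())
```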
src/openai/lib/streaming/chat/_completions.py
@@ -287,11 +287,31 @@ class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
 
 
 class ChatCompletionStreamState(Generic[ResponseFormatT]):
+    """Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
+
+    This is useful in cases where you can't use the `.stream()` method, e.g.
+
+    ```py
+    from openai.lib.streaming.chat import ChatCompletionStreamState
+
+    state = ChatCompletionStreamState()
+
+    stream = client.chat.completions.create(..., stream=True)
+    for chunk in stream:
+        state.handle_chunk(chunk)
+
+        # can also access the accumulated `ChatCompletion` mid-stream
+        state.current_completion_snapshot
+
+    print(state.get_final_completion())
+    ```
+    """
+
     def __init__(
         self,
         *,
-        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
-        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
+        input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
     ) -> None:
         self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
         self.__choice_event_states: list[ChoiceEventState] = []
@@ -301,6 +321,11 @@ class ChatCompletionStreamState(Generic[ResponseFormatT]):
         self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
 
     def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
+        """Parse the final completion object.
+
+        Note that this does not guarantee that the stream has actually finished; you must
+        only call this method once the stream is finished.
+        """
         return parse_chat_completion(
             chat_completion=self.current_completion_snapshot,
             response_format=self._rich_response_format,
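Both constructor arguments now default to `NOT_GIVEN`, so a bare `ChatCompletionStreamState()` works; passing a Pydantic model as `response_format` remains possible. A hedged sketch of how a rich `response_format` flows through to `get_final_completion()` (`Weather` and `accumulate` are illustrative names, and it assumes the request producing the chunks asked for a matching structured-output format):

```py
from typing import Iterable, Optional

from pydantic import BaseModel

from openai.types.chat import ChatCompletionChunk
from openai.lib.streaming.chat import ChatCompletionStreamState


class Weather(BaseModel):
    city: str
    conditions: str


def accumulate(chunks: Iterable[ChatCompletionChunk]) -> Optional[Weather]:
    # assumption: the request that produced `chunks` asked for this structured output
    state = ChatCompletionStreamState(response_format=Weather)
    for chunk in chunks:
        state.handle_chunk(chunk)
    # the final completion's content is parsed into the `response_format` model
    return state.get_final_completion().choices[0].message.parsed
```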
@@ -312,8 +337,8 @@ class ChatCompletionStreamState(Generic[ResponseFormatT]):
         assert self.__current_completion_snapshot is not None
         return self.__current_completion_snapshot
 
-    def handle_chunk(self, chunk: ChatCompletionChunk) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
-        """Accumulate a new chunk into the snapshot and returns a list of events to yield."""
+    def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
+        """Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
         self.__current_completion_snapshot = self._accumulate_chunk(chunk)
 
         return self._build_events(
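The return annotation is loosened from `list` to `Iterable`, but callers iterate it the same way. A minimal sketch of consuming the per-chunk events, using the `ContentDoneEvent` type the tests below import (`print_content` is an illustrative name):

```py
from typing import Iterable

from openai.types.chat import ChatCompletionChunk
from openai.lib.streaming.chat import ChatCompletionStreamState, ContentDoneEvent


def print_content(chunks: Iterable[ChatCompletionChunk]) -> None:
    state = ChatCompletionStreamState()
    for chunk in chunks:
        # each chunk may yield zero or more higher-level stream events
        for event in state.handle_chunk(chunk):
            if isinstance(event, ContentDoneEvent):
                print(event.content)
```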
tests/lib/chat/test_completions_streaming.py
@@ -13,12 +13,14 @@ from inline_snapshot import external, snapshot, outsource
 
 import openai
 from openai import OpenAI, AsyncOpenAI
-from openai._utils import assert_signatures_in_sync
+from openai._utils import consume_sync_iterator, assert_signatures_in_sync
 from openai._compat import model_copy
+from openai.types.chat import ChatCompletionChunk
 from openai.lib.streaming.chat import (
     ContentDoneEvent,
     ChatCompletionStream,
     ChatCompletionStreamEvent,
+    ChatCompletionStreamState,
     ChatCompletionStreamManager,
     ParsedChatCompletionSnapshot,
 )
@@ -997,6 +999,55 @@ FunctionToolCallArgumentsDoneEvent(
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    state = ChatCompletionStreamState()
+
+    def streamer(client: OpenAI) -> Iterator[ChatCompletionChunk]:
+        stream = client.chat.completions.create(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in SF?",
+                },
+            ],
+            stream=True,
+        )
+        for chunk in stream:
+            state.handle_chunk(chunk)
+            yield chunk
+
+    _make_raw_stream_snapshot_request(
+        streamer,
+        content_snapshot=snapshot(external("e2aad469b71d*.bin")),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(state.get_final_completion().choices, monkeypatch) == snapshot(
+        """\
+[
+    ParsedChoice[NoneType](
+        finish_reason='stop',
+        index=0,
+        logprobs=None,
+        message=ParsedChatCompletionMessage[NoneType](
+            audio=None,
+            content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I 
+recommend checking a reliable weather website or a weather app.",
+            function_call=None,
+            parsed=None,
+            refusal=None,
+            role='assistant',
+            tool_calls=[]
+        )
+    )
+]
+"""
+    )
+
+
 @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 def test_stream_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
     checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
@@ -1075,3 +1126,44 @@ def _make_stream_snapshot_request(
         client.close()
 
     return listener
+
+
+def _make_raw_stream_snapshot_request(
+    func: Callable[[OpenAI], Iterator[ChatCompletionChunk]],
+    *,
+    content_snapshot: Any,
+    respx_mock: MockRouter,
+    mock_client: OpenAI,
+) -> None:
+    live = os.environ.get("OPENAI_LIVE") == "1"
+    if live:
+
+        def _on_response(response: httpx.Response) -> None:
+            # update the content snapshot
+            assert outsource(response.read()) == content_snapshot
+
+        respx_mock.stop()
+
+        client = OpenAI(
+            http_client=httpx.Client(
+                event_hooks={
+                    "response": [_on_response],
+                }
+            )
+        )
+    else:
+        respx_mock.post("/chat/completions").mock(
+            return_value=httpx.Response(
+                200,
+                content=content_snapshot._old_value._load_value(),
+                headers={"content-type": "text/event-stream"},
+            )
+        )
+
+        client = mock_client
+
+    stream = func(client)
+    consume_sync_iterator(stream)
+
+    if live:
+        client.close()
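For reference, the test exhausts the stream via `consume_sync_iterator` so that every chunk passes through the state helper; a sketch of the behaviour this relies on, inferred from the helper's name and usage here (the real implementation lives in `openai._utils`):

```py
from typing import Any, Iterator


# inferred behaviour: drain the iterator, discarding each item
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
    for _ in iterator:
        ...
```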