Commit 7e1e4572

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-04-03 21:36:41
fix(client): correct logic for line decoding in streaming (#1293)
1 parent f0bdef0
Changed files (2)
src/openai/_streaming.py
@@ -24,7 +24,7 @@ class Stream(Generic[_T]):
 
     response: httpx.Response
 
-    _decoder: SSEDecoder | SSEBytesDecoder
+    _decoder: SSEBytesDecoder
 
     def __init__(
         self,
@@ -47,10 +47,7 @@ class Stream(Generic[_T]):
             yield item
 
     def _iter_events(self) -> Iterator[ServerSentEvent]:
-        if isinstance(self._decoder, SSEBytesDecoder):
-            yield from self._decoder.iter_bytes(self.response.iter_bytes())
-        else:
-            yield from self._decoder.iter(self.response.iter_lines())
+        yield from self._decoder.iter_bytes(self.response.iter_bytes())
 
     def __stream__(self) -> Iterator[_T]:
         cast_to = cast(Any, self._cast_to)
@@ -151,12 +148,8 @@ class AsyncStream(Generic[_T]):
             yield item
 
     async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
-        if isinstance(self._decoder, SSEBytesDecoder):
-            async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
-                yield sse
-        else:
-            async for sse in self._decoder.aiter(self.response.aiter_lines()):
-                yield sse
+        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+            yield sse
 
     async def __stream__(self) -> AsyncIterator[_T]:
         cast_to = cast(Any, self._cast_to)
@@ -282,21 +275,49 @@ class SSEDecoder:
         self._last_event_id = None
         self._retry = None
 
-    def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
-        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
-        for line in iterator:
-            line = line.rstrip("\n")
-            sse = self.decode(line)
-            if sse is not None:
-                yield sse
-
-    async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
-        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
-        async for line in iterator:
-            line = line.rstrip("\n")
-            sse = self.decode(line)
-            if sse is not None:
-                yield sse
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        for chunk in self._iter_chunks(iterator):
+            # Split before decoding so splitlines() only uses \r and \n
+            for raw_line in chunk.splitlines():
+                line = raw_line.decode("utf-8")
+                sse = self.decode(line)
+                if sse:
+                    yield sse
+
+    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
+        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+        data = b""
+        for chunk in iterator:
+            for line in chunk.splitlines(keepends=True):
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
+
+    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+        async for chunk in self._aiter_chunks(iterator):
+            # Split before decoding so splitlines() only uses \r and \n
+            for raw_line in chunk.splitlines():
+                line = raw_line.decode("utf-8")
+                sse = self.decode(line)
+                if sse:
+                    yield sse
+
+    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+        data = b""
+        async for chunk in iterator:
+            for line in chunk.splitlines(keepends=True):
+                data += line
+                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                    yield data
+                    data = b""
+        if data:
+            yield data
 
     def decode(self, line: str) -> ServerSentEvent | None:
         # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
tests/test_streaming.py
@@ -1,104 +1,248 @@
+from __future__ import annotations
+
 from typing import Iterator, AsyncIterator
 
+import httpx
 import pytest
 
-from openai._streaming import SSEDecoder
+from openai import OpenAI, AsyncOpenAI
+from openai._streaming import Stream, AsyncStream, ServerSentEvent
 
 
 @pytest.mark.asyncio
-async def test_basic_async() -> None:
-    async def body() -> AsyncIterator[str]:
-        yield "event: completion"
-        yield 'data: {"foo":true}'
-        yield ""
-
-    async for sse in SSEDecoder().aiter(body()):
-        assert sse.event == "completion"
-        assert sse.json() == {"foo": True}
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_basic(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: completion\n"
+        yield b'data: {"foo":true}\n'
+        yield b"\n"
 
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 
-def test_basic() -> None:
-    def body() -> Iterator[str]:
-        yield "event: completion"
-        yield 'data: {"foo":true}'
-        yield ""
-
-    it = SSEDecoder().iter(body())
-    sse = next(it)
+    sse = await iter_next(iterator)
     assert sse.event == "completion"
     assert sse.json() == {"foo": True}
 
-    with pytest.raises(StopIteration):
-        next(it)
+    await assert_empty_iter(iterator)
 
 
-def test_data_missing_event() -> None:
-    def body() -> Iterator[str]:
-        yield 'data: {"foo":true}'
-        yield ""
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_data_missing_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b'data: {"foo":true}\n'
+        yield b"\n"
 
-    it = SSEDecoder().iter(body())
-    sse = next(it)
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
     assert sse.event is None
     assert sse.json() == {"foo": True}
 
-    with pytest.raises(StopIteration):
-        next(it)
+    await assert_empty_iter(iterator)
+
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_event_missing_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b"\n"
 
-def test_event_missing_data() -> None:
-    def body() -> Iterator[str]:
-        yield "event: ping"
-        yield ""
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 
-    it = SSEDecoder().iter(body())
-    sse = next(it)
+    sse = await iter_next(iterator)
     assert sse.event == "ping"
     assert sse.data == ""
 
-    with pytest.raises(StopIteration):
-        next(it)
+    await assert_empty_iter(iterator)
 
 
-def test_multiple_events() -> None:
-    def body() -> Iterator[str]:
-        yield "event: ping"
-        yield ""
-        yield "event: completion"
-        yield ""
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_events(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b"\n"
+        yield b"event: completion\n"
+        yield b"\n"
 
-    it = SSEDecoder().iter(body())
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 
-    sse = next(it)
+    sse = await iter_next(iterator)
     assert sse.event == "ping"
     assert sse.data == ""
 
-    sse = next(it)
+    sse = await iter_next(iterator)
     assert sse.event == "completion"
     assert sse.data == ""
 
-    with pytest.raises(StopIteration):
-        next(it)
-
-
-def test_multiple_events_with_data() -> None:
-    def body() -> Iterator[str]:
-        yield "event: ping"
-        yield 'data: {"foo":true}'
-        yield ""
-        yield "event: completion"
-        yield 'data: {"bar":false}'
-        yield ""
+    await assert_empty_iter(iterator)
 
-    it = SSEDecoder().iter(body())
 
-    sse = next(it)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_events_with_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b'data: {"foo":true}\n'
+        yield b"\n"
+        yield b"event: completion\n"
+        yield b'data: {"bar":false}\n'
+        yield b"\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
     assert sse.event == "ping"
     assert sse.json() == {"foo": True}
 
-    sse = next(it)
+    sse = await iter_next(iterator)
     assert sse.event == "completion"
     assert sse.json() == {"bar": False}
 
-    with pytest.raises(StopIteration):
-        next(it)
+    await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_data_lines_with_empty_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b"data: {\n"
+        yield b'data: "foo":\n'
+        yield b"data: \n"
+        yield b"data:\n"
+        yield b"data: true}\n"
+        yield b"\n\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
+    assert sse.event == "ping"
+    assert sse.json() == {"foo": True}
+    assert sse.data == '{\n"foo":\n\n\ntrue}'
+
+    await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_data_json_escaped_double_new_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b'data: {"foo": "my long\\n\\ncontent"}'
+        yield b"\n\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
+    assert sse.event == "ping"
+    assert sse.json() == {"foo": "my long\n\ncontent"}
+
+    await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_data_lines(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+    def body() -> Iterator[bytes]:
+        yield b"event: ping\n"
+        yield b"data: {\n"
+        yield b'data: "foo":\n'
+        yield b"data: true}\n"
+        yield b"\n\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
+    assert sse.event == "ping"
+    assert sse.json() == {"foo": True}
+
+    await assert_empty_iter(iterator)
+
+
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_special_new_line_character(
+    sync: bool,
+    client: OpenAI,
+    async_client: AsyncOpenAI,
+) -> None:
+    def body() -> Iterator[bytes]:
+        yield b'data: {"content":" culpa"}\n'
+        yield b"\n"
+        yield b'data: {"content":" \xe2\x80\xa8"}\n'
+        yield b"\n"
+        yield b'data: {"content":"foo"}\n'
+        yield b"\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
+    assert sse.event is None
+    assert sse.json() == {"content": " culpa"}
+
+    sse = await iter_next(iterator)
+    assert sse.event is None
+    assert sse.json() == {"content": " 
"}
+
+    sse = await iter_next(iterator)
+    assert sse.event is None
+    assert sse.json() == {"content": "foo"}
+
+    await assert_empty_iter(iterator)
+
+
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multi_byte_character_multiple_chunks(
+    sync: bool,
+    client: OpenAI,
+    async_client: AsyncOpenAI,
+) -> None:
+    def body() -> Iterator[bytes]:
+        yield b'data: {"content":"'
+        # bytes taken from the string 'известни' and arbitrarily split
+        # so that some multi-byte characters span multiple chunks
+        yield b"\xd0"
+        yield b"\xb8\xd0\xb7\xd0"
+        yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8"
+        yield b'"}\n'
+        yield b"\n"
+
+    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+    sse = await iter_next(iterator)
+    assert sse.event is None
+    assert sse.json() == {"content": "известни"}
+
+
+async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
+    for chunk in iter:
+        yield chunk
+
+
+async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent:
+    if isinstance(iter, AsyncIterator):
+        return await iter.__anext__()
+
+    return next(iter)
+
+
+async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None:
+    with pytest.raises((StopAsyncIteration, RuntimeError)):
+        await iter_next(iter)
+
+
+def make_event_iterator(
+    content: Iterator[bytes],
+    *,
+    sync: bool,
+    client: OpenAI,
+    async_client: AsyncOpenAI,
+) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
+    if sync:
+        return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events()
+
+    return AsyncStream(
+        cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
+    )._iter_events()