Commit 7e1e4572
Changed files (2)
src
openai
tests
src/openai/_streaming.py
@@ -24,7 +24,7 @@ class Stream(Generic[_T]):
response: httpx.Response
- _decoder: SSEDecoder | SSEBytesDecoder
+ _decoder: SSEBytesDecoder
def __init__(
self,
@@ -47,10 +47,7 @@ class Stream(Generic[_T]):
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
- if isinstance(self._decoder, SSEBytesDecoder):
- yield from self._decoder.iter_bytes(self.response.iter_bytes())
- else:
- yield from self._decoder.iter(self.response.iter_lines())
+ yield from self._decoder.iter_bytes(self.response.iter_bytes())
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -151,12 +148,8 @@ class AsyncStream(Generic[_T]):
yield item
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
- if isinstance(self._decoder, SSEBytesDecoder):
- async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
- yield sse
- else:
- async for sse in self._decoder.aiter(self.response.aiter_lines()):
- yield sse
+ async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+ yield sse
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -282,21 +275,49 @@ class SSEDecoder:
self._last_event_id = None
self._retry = None
- def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
- """Given an iterator that yields lines, iterate over it & yield every event encountered"""
- for line in iterator:
- line = line.rstrip("\n")
- sse = self.decode(line)
- if sse is not None:
- yield sse
-
- async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
- """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
- async for line in iterator:
- line = line.rstrip("\n")
- sse = self.decode(line)
- if sse is not None:
- yield sse
+ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ for chunk in self._iter_chunks(iterator):
+ # Split before decoding so splitlines() only uses \r and \n
+ for raw_line in chunk.splitlines():
+ line = raw_line.decode("utf-8")
+ sse = self.decode(line)
+ if sse:
+ yield sse
+
+ def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+ data = b""
+ for chunk in iterator:
+ for line in chunk.splitlines(keepends=True):
+ data += line
+ if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+ yield data
+ data = b""
+ if data:
+ yield data
+
+ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ async for chunk in self._aiter_chunks(iterator):
+ # Split before decoding so splitlines() only uses \r and \n
+ for raw_line in chunk.splitlines():
+ line = raw_line.decode("utf-8")
+ sse = self.decode(line)
+ if sse:
+ yield sse
+
+ async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
+ data = b""
+ async for chunk in iterator:
+ for line in chunk.splitlines(keepends=True):
+ data += line
+ if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+ yield data
+ data = b""
+ if data:
+ yield data
def decode(self, line: str) -> ServerSentEvent | None:
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
tests/test_streaming.py
@@ -1,104 +1,248 @@
+from __future__ import annotations
+
from typing import Iterator, AsyncIterator
+import httpx
import pytest
-from openai._streaming import SSEDecoder
+from openai import OpenAI, AsyncOpenAI
+from openai._streaming import Stream, AsyncStream, ServerSentEvent
@pytest.mark.asyncio
-async def test_basic_async() -> None:
- async def body() -> AsyncIterator[str]:
- yield "event: completion"
- yield 'data: {"foo":true}'
- yield ""
-
- async for sse in SSEDecoder().aiter(body()):
- assert sse.event == "completion"
- assert sse.json() == {"foo": True}
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_basic(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: completion\n"
+ yield b'data: {"foo":true}\n'
+ yield b"\n"
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
-def test_basic() -> None:
- def body() -> Iterator[str]:
- yield "event: completion"
- yield 'data: {"foo":true}'
- yield ""
-
- it = SSEDecoder().iter(body())
- sse = next(it)
+ sse = await iter_next(iterator)
assert sse.event == "completion"
assert sse.json() == {"foo": True}
- with pytest.raises(StopIteration):
- next(it)
+ await assert_empty_iter(iterator)
-def test_data_missing_event() -> None:
- def body() -> Iterator[str]:
- yield 'data: {"foo":true}'
- yield ""
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_data_missing_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b'data: {"foo":true}\n'
+ yield b"\n"
- it = SSEDecoder().iter(body())
- sse = next(it)
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
assert sse.event is None
assert sse.json() == {"foo": True}
- with pytest.raises(StopIteration):
- next(it)
+ await assert_empty_iter(iterator)
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_event_missing_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b"\n"
-def test_event_missing_data() -> None:
- def body() -> Iterator[str]:
- yield "event: ping"
- yield ""
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
- it = SSEDecoder().iter(body())
- sse = next(it)
+ sse = await iter_next(iterator)
assert sse.event == "ping"
assert sse.data == ""
- with pytest.raises(StopIteration):
- next(it)
+ await assert_empty_iter(iterator)
-def test_multiple_events() -> None:
- def body() -> Iterator[str]:
- yield "event: ping"
- yield ""
- yield "event: completion"
- yield ""
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_events(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b"\n"
+ yield b"event: completion\n"
+ yield b"\n"
- it = SSEDecoder().iter(body())
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
- sse = next(it)
+ sse = await iter_next(iterator)
assert sse.event == "ping"
assert sse.data == ""
- sse = next(it)
+ sse = await iter_next(iterator)
assert sse.event == "completion"
assert sse.data == ""
- with pytest.raises(StopIteration):
- next(it)
-
-
-def test_multiple_events_with_data() -> None:
- def body() -> Iterator[str]:
- yield "event: ping"
- yield 'data: {"foo":true}'
- yield ""
- yield "event: completion"
- yield 'data: {"bar":false}'
- yield ""
+ await assert_empty_iter(iterator)
- it = SSEDecoder().iter(body())
- sse = next(it)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_events_with_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b'data: {"foo":true}\n'
+ yield b"\n"
+ yield b"event: completion\n"
+ yield b'data: {"bar":false}\n'
+ yield b"\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
assert sse.event == "ping"
assert sse.json() == {"foo": True}
- sse = next(it)
+ sse = await iter_next(iterator)
assert sse.event == "completion"
assert sse.json() == {"bar": False}
- with pytest.raises(StopIteration):
- next(it)
+ await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_data_lines_with_empty_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b"data: {\n"
+ yield b'data: "foo":\n'
+ yield b"data: \n"
+ yield b"data:\n"
+ yield b"data: true}\n"
+ yield b"\n\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
+ assert sse.event == "ping"
+ assert sse.json() == {"foo": True}
+ assert sse.data == '{\n"foo":\n\n\ntrue}'
+
+ await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_data_json_escaped_double_new_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b'data: {"foo": "my long\\n\\ncontent"}'
+ yield b"\n\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
+ assert sse.event == "ping"
+ assert sse.json() == {"foo": "my long\n\ncontent"}
+
+ await assert_empty_iter(iterator)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multiple_data_lines(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
+ def body() -> Iterator[bytes]:
+ yield b"event: ping\n"
+ yield b"data: {\n"
+ yield b'data: "foo":\n'
+ yield b"data: true}\n"
+ yield b"\n\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
+ assert sse.event == "ping"
+ assert sse.json() == {"foo": True}
+
+ await assert_empty_iter(iterator)
+
+
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_special_new_line_character(
+ sync: bool,
+ client: OpenAI,
+ async_client: AsyncOpenAI,
+) -> None:
+ def body() -> Iterator[bytes]:
+ yield b'data: {"content":" culpa"}\n'
+ yield b"\n"
+ yield b'data: {"content":" \xe2\x80\xa8"}\n'
+ yield b"\n"
+ yield b'data: {"content":"foo"}\n'
+ yield b"\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
+ assert sse.event is None
+ assert sse.json() == {"content": " culpa"}
+
+ sse = await iter_next(iterator)
+ assert sse.event is None
+ assert sse.json() == {"content": "
"}
+
+ sse = await iter_next(iterator)
+ assert sse.event is None
+ assert sse.json() == {"content": "foo"}
+
+ await assert_empty_iter(iterator)
+
+
+@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
+async def test_multi_byte_character_multiple_chunks(
+ sync: bool,
+ client: OpenAI,
+ async_client: AsyncOpenAI,
+) -> None:
+ def body() -> Iterator[bytes]:
+ yield b'data: {"content":"'
+ # bytes taken from the string 'известни' and arbitrarily split
+ # so that some multi-byte characters span multiple chunks
+ yield b"\xd0"
+ yield b"\xb8\xd0\xb7\xd0"
+ yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8"
+ yield b'"}\n'
+ yield b"\n"
+
+ iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
+
+ sse = await iter_next(iterator)
+ assert sse.event is None
+ assert sse.json() == {"content": "известни"}
+
+
+async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
+ for chunk in iter:
+ yield chunk
+
+
+async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent:
+ if isinstance(iter, AsyncIterator):
+ return await iter.__anext__()
+
+ return next(iter)
+
+
+async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None:
+ with pytest.raises((StopAsyncIteration, RuntimeError)):
+ await iter_next(iter)
+
+
+def make_event_iterator(
+ content: Iterator[bytes],
+ *,
+ sync: bool,
+ client: OpenAI,
+ async_client: AsyncOpenAI,
+) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
+ if sync:
+ return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events()
+
+ return AsyncStream(
+ cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
+ )._iter_events()