main
  1from __future__ import annotations
  2
  3from typing import Iterator, AsyncIterator
  4
  5import httpx
  6import pytest
  7
  8from openai import OpenAI, AsyncOpenAI
  9from openai._streaming import Stream, AsyncStream, ServerSentEvent
 10
 11
 12@pytest.mark.asyncio
 13@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 14async def test_basic(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
 15    def body() -> Iterator[bytes]:
 16        yield b"event: completion\n"
 17        yield b'data: {"foo":true}\n'
 18        yield b"\n"
 19
 20    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 21
 22    sse = await iter_next(iterator)
 23    assert sse.event == "completion"
 24    assert sse.json() == {"foo": True}
 25
 26    await assert_empty_iter(iterator)
 27
 28
 29@pytest.mark.asyncio
 30@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 31async def test_data_missing_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
 32    def body() -> Iterator[bytes]:
 33        yield b'data: {"foo":true}\n'
 34        yield b"\n"
 35
 36    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 37
 38    sse = await iter_next(iterator)
 39    assert sse.event is None
 40    assert sse.json() == {"foo": True}
 41
 42    await assert_empty_iter(iterator)
 43
 44
 45@pytest.mark.asyncio
 46@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 47async def test_event_missing_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
 48    def body() -> Iterator[bytes]:
 49        yield b"event: ping\n"
 50        yield b"\n"
 51
 52    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 53
 54    sse = await iter_next(iterator)
 55    assert sse.event == "ping"
 56    assert sse.data == ""
 57
 58    await assert_empty_iter(iterator)
 59
 60
 61@pytest.mark.asyncio
 62@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 63async def test_multiple_events(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
 64    def body() -> Iterator[bytes]:
 65        yield b"event: ping\n"
 66        yield b"\n"
 67        yield b"event: completion\n"
 68        yield b"\n"
 69
 70    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 71
 72    sse = await iter_next(iterator)
 73    assert sse.event == "ping"
 74    assert sse.data == ""
 75
 76    sse = await iter_next(iterator)
 77    assert sse.event == "completion"
 78    assert sse.data == ""
 79
 80    await assert_empty_iter(iterator)
 81
 82
 83@pytest.mark.asyncio
 84@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 85async def test_multiple_events_with_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
 86    def body() -> Iterator[bytes]:
 87        yield b"event: ping\n"
 88        yield b'data: {"foo":true}\n'
 89        yield b"\n"
 90        yield b"event: completion\n"
 91        yield b'data: {"bar":false}\n'
 92        yield b"\n"
 93
 94    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)
 95
 96    sse = await iter_next(iterator)
 97    assert sse.event == "ping"
 98    assert sse.json() == {"foo": True}
 99
100    sse = await iter_next(iterator)
101    assert sse.event == "completion"
102    assert sse.json() == {"bar": False}
103
104    await assert_empty_iter(iterator)
105
106
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_data_lines_with_empty_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Several `data:` lines, including empty ones, are joined with newlines.

    The empty `data: ` and `data:` lines each contribute an empty segment, so
    the joined payload contains consecutive newline characters.
    """

    def chunks() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b"data: {\n",
            b'data: "foo":\n',
            b"data: \n",
            b"data:\n",
            b"data: true}\n",
            b"\n\n",
        )

    events = make_event_iterator(content=chunks(), sync=sync, client=client, async_client=async_client)

    first = await iter_next(events)
    assert first.event == "ping"
    assert first.json() == {"foo": True}
    assert first.data == '{\n"foo":\n\n\ntrue}'

    # Nothing else may be produced after the single event.
    await assert_empty_iter(events)
127
128
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_data_json_escaped_double_new_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """JSON-escaped `\\n\\n` inside a payload is not mistaken for an event boundary."""

    def chunks() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            # The payload holds the two-character escape sequences, not raw newlines.
            b'data: {"foo": "my long\\n\\ncontent"}',
            b"\n\n",
        )

    events = make_event_iterator(content=chunks(), sync=sync, client=client, async_client=async_client)

    first = await iter_next(events)
    assert first.event == "ping"
    assert first.json() == {"foo": "my long\n\ncontent"}

    # Nothing else may be produced after the single event.
    await assert_empty_iter(events)
144
145
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_data_lines(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """A JSON payload spread over several `data:` lines is reassembled into one event."""

    def chunks() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b"data: {\n",
            b'data: "foo":\n',
            b"data: true}\n",
            b"\n\n",
        )

    events = make_event_iterator(content=chunks(), sync=sync, client=client, async_client=async_client)

    first = await iter_next(events)
    assert first.event == "ping"
    assert first.json() == {"foo": True}

    # Nothing else may be produced after the single event.
    await assert_empty_iter(events)
163
164
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_special_new_line_character(
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> None:
    """U+2028 (LINE SEPARATOR) inside a payload must not split the event.

    The second event's content is a space followed by U+2028, sent as the
    UTF-8 bytes `\\xe2\\x80\\xa8`. `str.splitlines()` treats that code point as
    a line boundary, so this checks the payload survives decoding intact and
    that the neighbouring events are unaffected.
    """

    def body() -> Iterator[bytes]:
        yield b'data: {"content":" culpa"}\n'
        yield b"\n"
        yield b'data: {"content":" \xe2\x80\xa8"}\n'
        yield b"\n"
        yield b'data: {"content":"foo"}\n'
        yield b"\n"

    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": " culpa"}

    sse = await iter_next(iterator)
    assert sse.event is None
    # Written as an escape so the invisible character cannot be lost in editing.
    assert sse.json() == {"content": " \u2028"}

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": "foo"}

    await assert_empty_iter(iterator)
194
195
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multi_byte_character_multiple_chunks(
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> None:
    """Multi-byte UTF-8 characters split across chunks are decoded correctly.

    The payload is delivered in arbitrary byte slices so that some two-byte
    characters straddle a chunk boundary; the decoded JSON must still contain
    the original string, and no extra event may follow.
    """

    def body() -> Iterator[bytes]:
        yield b'data: {"content":"'
        # bytes taken from the string 'известни' and arbitrarily split
        # so that some multi-byte characters span multiple chunks
        yield b"\xd0"
        yield b"\xb8\xd0\xb7\xd0"
        yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8"
        yield b'"}\n'
        yield b"\n"

    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": "известни"}

    # Like the other tests in this module, make sure the split chunks did not
    # produce any additional (spurious) event.
    await assert_empty_iter(iterator)
217
218
async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
    """Expose a synchronous byte iterator as an async generator.

    Each item of `iter` is re-yielded unchanged and in the same order.
    """
    for piece in iter:
        yield piece
222
223
async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent:
    """Return the next item from either a sync or an async iterator.

    Async iterators are advanced by awaiting `__anext__()`; anything else is
    advanced with the builtin `next()`.
    """
    if not isinstance(iter, AsyncIterator):
        return next(iter)

    return await iter.__anext__()
229
230
async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None:
    """Assert that `iter` has no further items.

    An exhausted async iterator raises `StopAsyncIteration`. For a sync
    iterator, `next()` raises `StopIteration` inside the `iter_next`
    coroutine, which Python surfaces as `RuntimeError`; both are accepted.
    """
    exhausted_errors = (StopAsyncIteration, RuntimeError)
    with pytest.raises(exhausted_errors):
        await iter_next(iter)
234
235
def make_event_iterator(
    content: Iterator[bytes],
    *,
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
    """Build an SSE event iterator over `content` for the sync or async client.

    When `sync` is true the bytes are wrapped in a 200 `httpx.Response` and fed
    to `Stream`; otherwise they are first adapted with `to_aiter` and fed to
    `AsyncStream`. In both cases the stream's `_iter_events()` is returned.
    """
    if not sync:
        async_response = httpx.Response(200, content=to_aiter(content))
        async_stream = AsyncStream(cast_to=object, client=async_client, response=async_response)
        return async_stream._iter_events()

    sync_response = httpx.Response(200, content=content)
    sync_stream = Stream(cast_to=object, client=client, response=sync_response)
    return sync_stream._iter_events()