# Tests for server-sent-event parsing in openai._streaming (Stream / AsyncStream).
1from __future__ import annotations
2
3from typing import Iterator, AsyncIterator
4
5import httpx
6import pytest
7
8from openai import OpenAI, AsyncOpenAI
9from openai._streaming import Stream, AsyncStream, ServerSentEvent
10
11
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_basic(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """A single event with both an ``event:`` name and a JSON ``data:`` line is parsed."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: completion\n",
            b'data: {"foo":true}\n',
            b"\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    first = await iter_next(events)
    assert first.event == "completion"
    assert first.json() == {"foo": True}

    # Nothing else should follow the single event.
    await assert_empty_iter(events)
27
28
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_data_missing_event(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """An event carrying only a ``data:`` field is parsed with ``event`` left as None."""

    def body() -> Iterator[bytes]:
        yield from (
            b'data: {"foo":true}\n',
            b"\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    only = await iter_next(events)
    assert only.event is None
    assert only.json() == {"foo": True}

    await assert_empty_iter(events)
43
44
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_event_missing_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """An event carrying only an ``event:`` field is emitted with empty ``data``."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b"\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    only = await iter_next(events)
    assert only.event == "ping"
    assert only.data == ""

    await assert_empty_iter(events)
59
60
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_events(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Two data-less events separated by blank lines are emitted in order."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b"\n",
            b"event: completion\n",
            b"\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    for expected_name in ("ping", "completion"):
        event = await iter_next(events)
        assert event.event == expected_name
        assert event.data == ""

    await assert_empty_iter(events)
81
82
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_events_with_data(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Consecutive events each keep their own name and JSON payload."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b'data: {"foo":true}\n',
            b"\n",
            b"event: completion\n",
            b'data: {"bar":false}\n',
            b"\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    expected = (
        ("ping", {"foo": True}),
        ("completion", {"bar": False}),
    )
    for name, payload in expected:
        event = await iter_next(events)
        assert event.event == name
        assert event.json() == payload

    await assert_empty_iter(events)
105
106
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_data_lines_with_empty_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Several ``data:`` lines, including empty ones, are joined with newlines into one payload."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            b"data: {\n",
            b'data: "foo":\n',
            b"data: \n",
            b"data:\n",
            b"data: true}\n",
            b"\n\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    event = await iter_next(events)
    assert event.event == "ping"
    assert event.json() == {"foo": True}
    # Both "data: " and "data:" contribute an empty line to the joined payload.
    assert event.data == '{\n"foo":\n\n\ntrue}'

    await assert_empty_iter(events)
127
128
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_data_json_escaped_double_new_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """An escaped double newline inside the JSON payload does not terminate the event."""

    def body() -> Iterator[bytes]:
        yield from (
            b"event: ping\n",
            # No trailing newline here: the terminator arrives in the next chunk.
            b'data: {"foo": "my long\\n\\ncontent"}',
            b"\n\n",
        )

    events = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    event = await iter_next(events)
    assert event.event == "ping"
    assert event.json() == {"foo": "my long\n\ncontent"}

    await assert_empty_iter(events)
144
145
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_data_lines(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Several ``data:`` lines are joined with newlines into a single JSON payload."""

    def body() -> Iterator[bytes]:
        yield b"event: ping\n"
        yield b"data: {\n"
        yield b'data: "foo":\n'
        yield b"data: true}\n"
        yield b"\n\n"

    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    sse = await iter_next(iterator)
    assert sse.event == "ping"
    assert sse.json() == {"foo": True}
    # Pin the raw join as well, matching test_multiple_data_lines_with_empty_line:
    # each "data: " prefix is stripped and the values are joined with "\n".
    assert sse.data == '{\n"foo":\ntrue}'

    await assert_empty_iter(iterator)
163
164
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_special_new_line_character(
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> None:
    """A U+2028 LINE SEPARATOR inside a payload must not split the data line.

    The second event carries the UTF-8 bytes ``e2 80 a8`` (U+2028). The parser
    is expected to split fields only on SSE line terminators, so the payload
    must come through intact and the surrounding events must be unaffected.
    """

    def body() -> Iterator[bytes]:
        yield b'data: {"content":" culpa"}\n'
        yield b"\n"
        yield b'data: {"content":" \xe2\x80\xa8"}\n'
        yield b"\n"
        yield b'data: {"content":"foo"}\n'
        yield b"\n"

    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": " culpa"}

    sse = await iter_next(iterator)
    assert sse.event is None
    # Spelled as an escape on purpose: some editors and tools render a raw
    # U+2028 in source as a line break, which split this literal in two.
    assert sse.json() == {"content": " \u2028"}

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": "foo"}

    await assert_empty_iter(iterator)
194
195
@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multi_byte_character_multiple_chunks(
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> None:
    """UTF-8 characters split across chunk boundaries are reassembled before decoding."""

    def body() -> Iterator[bytes]:
        yield b'data: {"content":"'
        # bytes taken from the string 'известни' and arbitrarily split
        # so that some multi-byte characters span multiple chunks
        yield b"\xd0"
        yield b"\xb8\xd0\xb7\xd0"
        yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8"
        yield b'"}\n'
        yield b"\n"

    iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client)

    sse = await iter_next(iterator)
    assert sse.event is None
    assert sse.json() == {"content": "известни"}

    # As in every other test here, make sure no stray trailing event was produced.
    await assert_empty_iter(iterator)
217
218
async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
    """Re-expose a synchronous byte iterator as an async iterator.

    Each chunk is yielded unchanged and in order, so chunk boundaries are
    preserved exactly.
    """
    for piece in iter:
        yield piece
222
223
async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent:
    """Return the next item from either a sync or an async iterator.

    Exhaustion is not handled here. An async iterator propagates
    ``StopAsyncIteration``. For a sync iterator, ``StopIteration`` escapes
    this coroutine, and Python re-raises it as ``RuntimeError``.
    """
    if not isinstance(iter, AsyncIterator):
        return next(iter)
    return await iter.__anext__()
229
230
async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None:
    """Assert that ``iter`` yields no further events.

    The sync and async cases are checked separately so each one expects its own
    exhaustion exception. Previously a sync iterator was advanced inside a
    coroutine (via ``iter_next``), where an escaping ``StopIteration`` is turned
    into ``RuntimeError``. The check therefore had to accept any
    ``RuntimeError``, which would also let a genuine RuntimeError from the
    parser pass as "empty".
    """
    if isinstance(iter, AsyncIterator):
        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()
    else:
        # StopIteration is caught by pytest.raises before it can leave this
        # coroutine, so it is not converted into RuntimeError.
        with pytest.raises(StopIteration):
            next(iter)
234
235
def make_event_iterator(
    content: Iterator[bytes],
    *,
    sync: bool,
    client: OpenAI,
    async_client: AsyncOpenAI,
) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
    """Wrap raw SSE byte chunks in a fake 200 response and return its event iterator.

    With ``sync=False`` the chunks are replayed through ``to_aiter`` and parsed
    by an ``AsyncStream`` bound to ``async_client``. Otherwise a ``Stream``
    bound to ``client`` parses them synchronously.
    """
    if not sync:
        async_response = httpx.Response(200, content=to_aiter(content))
        return AsyncStream(cast_to=object, client=async_client, response=async_response)._iter_events()

    sync_response = httpx.Response(200, content=content)
    return Stream(cast_to=object, client=client, response=sync_response)._iter_events()