main
1import json
2from typing import Any, List, Union, cast
3from typing_extensions import Annotated
4
5import httpx
6import pytest
7import pydantic
8
9from openai import OpenAI, BaseModel, AsyncOpenAI
10from openai._response import (
11 APIResponse,
12 BaseAPIResponse,
13 AsyncAPIResponse,
14 BinaryAPIResponse,
15 AsyncBinaryAPIResponse,
16 extract_response_type,
17)
18from openai._streaming import Stream
19from openai._base_client import FinalRequestOptions
20
21from .utils import rich_print_str
22
23
class ConcreteBaseAPIResponse(BaseAPIResponse[bytes]):
    """Concrete subclass of ``BaseAPIResponse`` whose type argument is ``bytes``.

    Lets ``extract_response_type`` be exercised on a subclass of the base
    response class itself, not only on ``APIResponse`` subclasses.
    """
25
26
class ConcreteAPIResponse(APIResponse[List[str]]):
    """Concrete ``APIResponse`` subclass parametrised with ``List[str]``."""
28
29
class ConcreteAsyncAPIResponse(AsyncAPIResponse[httpx.Response]):
    """Concrete ``AsyncAPIResponse`` subclass whose type argument is ``httpx.Response``.

    Subclasses the async response class (as its name says) so that type-argument
    extraction through an async subclass is covered.
    """
31
32
def test_extract_response_type_direct_classes() -> None:
    """A directly parametrised response class yields its own type argument."""
    parametrised = (BaseAPIResponse[str], APIResponse[str], AsyncAPIResponse[str])
    for response_type in parametrised:
        assert extract_response_type(response_type) == str
37
38
def test_extract_response_type_direct_class_missing_type_arg() -> None:
    """An unparametrised response class has no type argument, so extraction raises."""
    expected_message = (
        "Expected type <class 'openai._response.AsyncAPIResponse'> to have a type argument at index 0 but it did not"
    )
    with pytest.raises(RuntimeError, match=expected_message):
        extract_response_type(AsyncAPIResponse)
45
46
def test_extract_response_type_concrete_subclasses() -> None:
    """Concrete subclasses resolve the type argument given to their generic base."""
    cases = [
        (ConcreteBaseAPIResponse, bytes),
        (ConcreteAPIResponse, List[str]),
        (ConcreteAsyncAPIResponse, httpx.Response),
    ]
    for subclass, expected in cases:
        assert extract_response_type(subclass) == expected
51
52
def test_extract_response_type_binary_response() -> None:
    """Both binary response wrappers (sync and async) resolve to ``bytes``."""
    for binary_cls in (BinaryAPIResponse, AsyncBinaryAPIResponse):
        assert extract_response_type(binary_cls) == bytes
56
57
class PydanticModel(pydantic.BaseModel):
    """Plain pydantic model that deliberately does not derive from ``openai.BaseModel``."""
59
60
def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
    """Parsing into a bare pydantic model is rejected with an explanatory ``TypeError``."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    expected_message = "Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`"
    with pytest.raises(TypeError, match=expected_message):
        response.parse(to=PydanticModel)
76
77
@pytest.mark.asyncio
async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None:
    """Async variant: parsing into a bare pydantic model raises ``TypeError``."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=raw_response,
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    expected_message = "Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`"
    with pytest.raises(TypeError, match=expected_message):
        await response.parse(to=PydanticModel)
94
95
def test_response_parse_custom_stream(client: OpenAI) -> None:
    """``parse(to=Stream[int])`` returns a stream whose ``_cast_to`` is ``int``."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed_stream = response.parse(to=Stream[int])
    assert parsed_stream._cast_to == int
108
109
@pytest.mark.asyncio
async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None:
    """Async variant: ``parse(to=Stream[int])`` yields a stream whose ``_cast_to`` is ``int``."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=raw_response,
        client=async_client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed_stream = await response.parse(to=Stream[int])
    assert parsed_stream._cast_to == int
123
124
class CustomModel(BaseModel):
    """Model with a string field ``foo`` and an integer field ``bar``, used as a parse target."""

    foo: str
    bar: int
128
129
def test_response_parse_custom_model(client: OpenAI) -> None:
    """A JSON body is parsed into a user-defined ``openai.BaseModel`` subclass."""
    payload = {"foo": "hello!", "bar": 2}
    response = APIResponse(
        raw=httpx.Response(200, content=json.dumps(payload)),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    parsed = response.parse(to=CustomModel)
    assert parsed.foo == payload["foo"]
    assert parsed.bar == payload["bar"]
143
144
@pytest.mark.asyncio
async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None:
    """Async variant: a JSON body is parsed into a user-defined ``openai.BaseModel``."""
    payload = {"foo": "hello!", "bar": 2}
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps(payload)),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    parsed = await response.parse(to=CustomModel)
    assert parsed.foo == payload["foo"]
    assert parsed.bar == payload["bar"]
159
160
def test_response_basemodel_request_id(client: OpenAI) -> None:
    """The ``x-request-id`` header surfaces as ``_request_id``.

    The attribute must not leak into ``to_dict()`` or into the rich-rendered repr.
    """
    payload = {"foo": "hello!", "bar": 2}
    raw_response = httpx.Response(
        200,
        headers={"x-request-id": "my-req-id"},
        content=json.dumps(payload),
    )
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed = response.parse(to=CustomModel)
    assert parsed._request_id == "my-req-id"
    assert parsed.foo == "hello!"
    assert parsed.bar == 2
    assert parsed.to_dict() == payload

    # Neither the private attribute nor the exclusion bookkeeping may be rendered.
    for hidden in ("_request_id", "__exclude_fields__"):
        assert hidden not in rich_print_str(parsed)
182
183
@pytest.mark.asyncio
async def test_async_response_basemodel_request_id(async_client: AsyncOpenAI) -> None:
    """Async variant: the ``x-request-id`` header surfaces as ``_request_id``.

    As in the sync test, the attribute must be absent from ``to_dict()`` and
    from the rich-rendered repr.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(
            200,
            headers={"x-request-id": "my-req-id"},
            content=json.dumps({"foo": "hello!", "bar": 2}),
        ),
        # An AsyncAPIResponse is paired with the async client fixture, matching
        # every other async test in this module.
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = await response.parse(to=CustomModel)
    assert obj._request_id == "my-req-id"
    assert obj.foo == "hello!"
    assert obj.bar == 2
    assert obj.to_dict() == {"foo": "hello!", "bar": 2}
    # Same repr-exclusion guarantees that the sync test checks.
    assert "_request_id" not in rich_print_str(obj)
    assert "__exclude_fields__" not in rich_print_str(obj)
204
205
def test_response_parse_annotated_type(client: OpenAI) -> None:
    """An ``Annotated[...]`` wrapper around a model is unwrapped when parsing."""
    payload = {"foo": "hello!", "bar": 2}
    response = APIResponse(
        raw=httpx.Response(200, content=json.dumps(payload)),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    annotated_target = cast("type[CustomModel]", Annotated[CustomModel, "random metadata"])
    parsed = response.parse(to=annotated_target)
    assert parsed.foo == payload["foo"]
    assert parsed.bar == payload["bar"]
221
222
# Every other coroutine test in this module carries this marker; pytest-asyncio
# requires it in strict mode, where an unmarked coroutine test is not awaited.
@pytest.mark.asyncio
async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None:
    """Async variant: an ``Annotated[...]`` wrapper around a model is unwrapped when parsing."""
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = await response.parse(
        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
    )
    assert obj.foo == "hello!"
    assert obj.bar == 2
238
239
@pytest.mark.parametrize(
    "content, expected",
    # Mixed-case spellings; the expected bool is True exactly for "true" ignoring case.
    [(text, text.lower() == "true") for text in ("false", "true", "False", "True", "TrUe", "FalSe")],
)
def test_response_parse_bool(client: OpenAI, content: str, expected: bool) -> None:
    """A plain-text body is parsed to ``bool`` case-insensitively."""
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=httpx.Response(200, content=content),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    assert response.parse(to=bool) is expected
263
264
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "content, expected",
    [
        ("false", False),
        ("true", True),
        ("False", False),
        ("True", True),
        ("TrUe", True),
        ("FalSe", False),
    ],
)
async def test_async_response_parse_bool(async_client: AsyncOpenAI, content: str, expected: bool) -> None:
    """Async variant: a plain-text body is parsed to ``bool`` case-insensitively.

    Uses the ``async_client`` fixture (not the sync ``client``) so that the
    ``AsyncAPIResponse`` is paired with an ``AsyncOpenAI`` client, as in the
    other async tests. The asyncio marker is required under pytest-asyncio
    strict mode.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=content),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    result = await response.parse(to=bool)
    assert result is expected
288
289
class OtherModel(BaseModel):
    """Second union member for the non-JSON tests; has a single string field ``a``."""

    a: str
292
293
# `False` selects the fixture variant with loose (non-strict) validation.
@pytest.mark.parametrize("client", [False], indirect=True)
def test_response_parse_expect_model_union_non_json_content(client: OpenAI) -> None:
    """A non-JSON body requested as a model union falls back to the raw text."""
    raw_response = httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    model_union = cast(Any, Union[CustomModel, OtherModel])
    parsed = response.parse(to=model_union)
    assert isinstance(parsed, str)
    assert parsed == "foo"
308
309
@pytest.mark.asyncio
# `False` selects the fixture variant with loose (non-strict) validation.
@pytest.mark.parametrize("async_client", [False], indirect=True)
async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncOpenAI) -> None:
    """Async variant: a non-JSON body requested as a model union falls back to the raw text."""
    raw_response = httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=raw_response,
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    model_union = cast(Any, Union[CustomModel, OtherModel])
    parsed = await response.parse(to=model_union)
    assert isinstance(parsed, str)
    assert parsed == "foo"