main
  1import json
  2from typing import Any, List, Union, cast
  3from typing_extensions import Annotated
  4
  5import httpx
  6import pytest
  7import pydantic
  8
  9from openai import OpenAI, BaseModel, AsyncOpenAI
 10from openai._response import (
 11    APIResponse,
 12    BaseAPIResponse,
 13    AsyncAPIResponse,
 14    BinaryAPIResponse,
 15    AsyncBinaryAPIResponse,
 16    extract_response_type,
 17)
 18from openai._streaming import Stream
 19from openai._base_client import FinalRequestOptions
 20
 21from .utils import rich_print_str
 22
 23
 24class ConcreteBaseAPIResponse(APIResponse[bytes]): ...
 25
 26
 27class ConcreteAPIResponse(APIResponse[List[str]]): ...
 28
 29
 30class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ...
 31
 32
 33def test_extract_response_type_direct_classes() -> None:
 34    assert extract_response_type(BaseAPIResponse[str]) == str
 35    assert extract_response_type(APIResponse[str]) == str
 36    assert extract_response_type(AsyncAPIResponse[str]) == str
 37
 38
 39def test_extract_response_type_direct_class_missing_type_arg() -> None:
 40    with pytest.raises(
 41        RuntimeError,
 42        match="Expected type <class 'openai._response.AsyncAPIResponse'> to have a type argument at index 0 but it did not",
 43    ):
 44        extract_response_type(AsyncAPIResponse)
 45
 46
 47def test_extract_response_type_concrete_subclasses() -> None:
 48    assert extract_response_type(ConcreteBaseAPIResponse) == bytes
 49    assert extract_response_type(ConcreteAPIResponse) == List[str]
 50    assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response
 51
 52
 53def test_extract_response_type_binary_response() -> None:
 54    assert extract_response_type(BinaryAPIResponse) == bytes
 55    assert extract_response_type(AsyncBinaryAPIResponse) == bytes
 56
 57
# Model deriving from `pydantic.BaseModel` directly instead of the `BaseModel`
# imported from `openai`. The "mismatched basemodel" tests in this file pass
# it to `parse(to=...)` and expect a `TypeError`.
class PydanticModel(pydantic.BaseModel): ...
 59
 60
 61def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
 62    response = APIResponse(
 63        raw=httpx.Response(200, content=b"foo"),
 64        client=client,
 65        stream=False,
 66        stream_cls=None,
 67        cast_to=str,
 68        options=FinalRequestOptions.construct(method="get", url="/foo"),
 69    )
 70
 71    with pytest.raises(
 72        TypeError,
 73        match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
 74    ):
 75        response.parse(to=PydanticModel)
 76
 77
 78@pytest.mark.asyncio
 79async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None:
 80    response = AsyncAPIResponse(
 81        raw=httpx.Response(200, content=b"foo"),
 82        client=async_client,
 83        stream=False,
 84        stream_cls=None,
 85        cast_to=str,
 86        options=FinalRequestOptions.construct(method="get", url="/foo"),
 87    )
 88
 89    with pytest.raises(
 90        TypeError,
 91        match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`",
 92    ):
 93        await response.parse(to=PydanticModel)
 94
 95
 96def test_response_parse_custom_stream(client: OpenAI) -> None:
 97    response = APIResponse(
 98        raw=httpx.Response(200, content=b"foo"),
 99        client=client,
100        stream=True,
101        stream_cls=None,
102        cast_to=str,
103        options=FinalRequestOptions.construct(method="get", url="/foo"),
104    )
105
106    stream = response.parse(to=Stream[int])
107    assert stream._cast_to == int
108
109
@pytest.mark.asyncio
async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None:
    """Parsing a streaming async response to `Stream[int]` must yield a stream casting to `int`."""
    raw_response = httpx.Response(200, content=b"foo")
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=raw_response,
        client=async_client,
        stream=True,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed_stream = await response.parse(to=Stream[int])
    assert parsed_stream._cast_to == int
123
124
# Model built on the `BaseModel` imported from `openai`. Tests in this file
# use it as the `parse(to=...)` target for a JSON body with `foo` and `bar` keys.
class CustomModel(BaseModel):
    foo: str
    bar: int
128
129
def test_response_parse_custom_model(client: OpenAI) -> None:
    """A JSON body must be parsed into the model passed via `to=`, overriding `cast_to`."""
    body = json.dumps({"foo": "hello!", "bar": 2})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=httpx.Response(200, content=body),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    model = response.parse(to=CustomModel)
    assert model.foo == "hello!"
    assert model.bar == 2
143
144
@pytest.mark.asyncio
async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None:
    """A JSON body must be parsed into the model passed via `to=` on the async response."""
    body = json.dumps({"foo": "hello!", "bar": 2})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=body),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    model = await response.parse(to=CustomModel)
    assert model.foo == "hello!"
    assert model.bar == 2
159
160
def test_response_basemodel_request_id(client: OpenAI) -> None:
    """The `x-request-id` header must be exposed as `_request_id` without leaking into output."""
    raw_response = httpx.Response(
        200,
        headers={"x-request-id": "my-req-id"},
        content=json.dumps({"foo": "hello!", "bar": 2}),
    )
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    model = response.parse(to=CustomModel)
    assert model._request_id == "my-req-id"
    assert model.foo == "hello!"
    assert model.bar == 2
    # The request id must not appear in the serialised dict...
    assert model.to_dict() == {"foo": "hello!", "bar": 2}
    # ...nor may these internal names appear in the rich-printed form.
    for hidden_name in ("_request_id", "__exclude_fields__"):
        assert hidden_name not in rich_print_str(model)
182
183
@pytest.mark.asyncio
async def test_async_response_basemodel_request_id(async_client: AsyncOpenAI) -> None:
    """The `x-request-id` header must be exposed as `_request_id` on models parsed by `AsyncAPIResponse`.

    Uses the `async_client` fixture so that the async response wrapper is
    paired with an async client, like the other async tests in this module.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(
            200,
            headers={"x-request-id": "my-req-id"},
            content=json.dumps({"foo": "hello!", "bar": 2}),
        ),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    obj = await response.parse(to=CustomModel)
    assert obj._request_id == "my-req-id"
    assert obj.foo == "hello!"
    assert obj.bar == 2
    # The request id must not appear in the serialised dict.
    assert obj.to_dict() == {"foo": "hello!", "bar": 2}
    # Same output checks as the sync variant of this test.
    assert "_request_id" not in rich_print_str(obj)
    assert "__exclude_fields__" not in rich_print_str(obj)
204
205
def test_response_parse_annotated_type(client: OpenAI) -> None:
    """An `Annotated[...]` wrapper around the model must still parse into that model."""
    body = json.dumps({"foo": "hello!", "bar": 2})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=httpx.Response(200, content=body),
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    # The cast only satisfies the type checker; at runtime `to` is the Annotated form.
    annotated_model = cast("type[CustomModel]", Annotated[CustomModel, "random metadata"])
    model = response.parse(to=annotated_model)
    assert model.foo == "hello!"
    assert model.bar == 2
220    assert obj.bar == 2
221
222
@pytest.mark.asyncio
async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None:
    """An `Annotated[...]` wrapper around the model must still parse into that model (async).

    Explicitly marked with `pytest.mark.asyncio`, matching the other marked
    async tests in this module, so collection does not depend on the
    configured asyncio mode.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    # The cast only satisfies the type checker; at runtime `to` is the Annotated form.
    obj = await response.parse(
        to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]),
    )
    assert obj.foo == "hello!"
    assert obj.bar == 2
237    assert obj.bar == 2
238
239
@pytest.mark.parametrize(
    "content, expected",
    [
        ("false", False),
        ("true", True),
        ("False", False),
        ("True", True),
        ("TrUe", True),
        ("FalSe", False),
    ],
)
def test_response_parse_bool(client: OpenAI, content: str, expected: bool) -> None:
    """Textual booleans in any letter case must parse to the matching `bool` singleton."""
    raw_response = httpx.Response(200, content=content)
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    parsed = response.parse(to=bool)
    # Identity check: the result must be the real `True`/`False` object.
    assert parsed is expected
263
264
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "content, expected",
    [
        ("false", False),
        ("true", True),
        ("False", False),
        ("True", True),
        ("TrUe", True),
        ("FalSe", False),
    ],
)
async def test_async_response_parse_bool(async_client: AsyncOpenAI, content: str, expected: bool) -> None:
    """Textual booleans in any letter case must parse to the matching `bool` singleton (async).

    Requests the `async_client` fixture so the annotated `AsyncOpenAI` type
    matches the fixture the other async tests in this module use, and carries
    an explicit `pytest.mark.asyncio` marker like those tests.
    """
    response = AsyncAPIResponse(
        raw=httpx.Response(200, content=content),
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=FinalRequestOptions.construct(method="get", url="/foo"),
    )

    result = await response.parse(to=bool)
    # Identity check: the result must be the real `True`/`False` object.
    assert result is expected
288
289
# Second model built on the `BaseModel` imported from `openai`; the union tests
# in this file combine it with `CustomModel` as `Union[CustomModel, OtherModel]`.
class OtherModel(BaseModel):
    a: str
292
293
@pytest.mark.parametrize("client", [False], indirect=True)  # loose validation
def test_response_parse_expect_model_union_non_json_content(client: OpenAI) -> None:
    """With a non-JSON content type, parsing to a union of models must return the raw text."""
    raw_response = httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = APIResponse(
        raw=raw_response,
        client=client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    # `cast(Any, ...)` lets the Union be passed where the type checker expects a class.
    model_union = cast(Any, Union[CustomModel, OtherModel])
    parsed = response.parse(to=model_union)
    assert isinstance(parsed, str)
    assert parsed == "foo"
308
309
@pytest.mark.asyncio
@pytest.mark.parametrize("async_client", [False], indirect=True)  # loose validation
async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncOpenAI) -> None:
    """With a non-JSON content type, async parsing to a union of models must return the raw text."""
    raw_response = httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"})
    request_options = FinalRequestOptions.construct(method="get", url="/foo")
    response = AsyncAPIResponse(
        raw=raw_response,
        client=async_client,
        stream=False,
        stream_cls=None,
        cast_to=str,
        options=request_options,
    )

    # `cast(Any, ...)` lets the Union be passed where the type checker expects a class.
    model_union = cast(Any, Union[CustomModel, OtherModel])
    parsed = await response.parse(to=model_union)
    assert isinstance(parsed, str)
    assert parsed == "foo"