# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from __future__ import annotations

import gc
import os
import sys
import json
import asyncio
import inspect
import tracemalloc
from typing import Any, Union, Protocol, cast
from unittest import mock
from typing_extensions import Literal

import httpx
import pytest
from respx import MockRouter
from pydantic import ValidationError

from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
from openai._types import Omit
from openai._utils import asyncify
from openai._models import BaseModel, FinalRequestOptions
from openai._streaming import Stream, AsyncStream
from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
from openai._base_client import (
    DEFAULT_TIMEOUT,
    HTTPX_DEFAULT_TIMEOUT,
    BaseClient,
    OtherPlatform,
    DefaultHttpxClient,
    DefaultAsyncHttpxClient,
    get_platform,
    make_request_options,
)

from .utils import update_env

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"


class MockRequestCall(Protocol):
    request: httpx.Request


def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    url = httpx.URL(request.url)
    return dict(url.params)


def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
    return 0.1

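# NOTE: this helper reaches into httpx/httpcore internals (`_transport._pool._requests`)
# to count in-flight requests; it may need updating if those private attributes change.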
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    transport = client._client._transport
    assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport)

    pool = transport._pool
    return len(pool._requests)


class TestOpenAI:
    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self, client: OpenAI) -> None:
        copied = client.copy()
        assert id(copied) != id(client)

        copied = client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert client.api_key == "My API Key"

    def test_copy_default_options(self, client: OpenAI) -> None:
        # options that have a default are overridden correctly
        copied = client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(client.timeout, httpx.Timeout)
        copied = client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
        client.close()

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

        client.close()

    def test_copy_signature(self, client: OpenAI) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, client: OpenAI) -> None:
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client_copy = client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

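    # The timeout tests below read the effective timeout back off the built request:
    # httpx records it as a dict under `request.extensions["timeout"]`.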
    def test_request_timeout(self, client: OpenAI) -> None:
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

        client.close()

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

            client.close()

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

            client.close()

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

            client.close()

    async def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        test_client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        test_client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

        test_client.close()
        test_client2.close()

    def test_validate_headers(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
        request = client._build_request(options)

        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

        client.close()

    def test_request_extra_json(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

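    # A fixed multipart boundary is forced via the Content-Type header below so that
    # the raw request body can be compared byte-for-byte.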
    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions.construct(
                method="post",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

        client.close()

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"
        client.close()

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"
        client.close()

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"
        client.close()

    def test_copied_client_does_not_close_http(self) -> None:
        test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not test_client.is_closed()

        copied = test_client.copy()
        assert copied is not test_client

        del copied

        assert not test_client.is_closed()

    def test_client_context_manager(self) -> None:
        test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with test_client as c2:
            assert c2 is test_client
            assert not c2.is_closed()
            assert not test_client.is_closed()
        assert test_client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        non_strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = non_strict_client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

        strict_client.close()
        non_strict_client.close()

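    # `time.time` is frozen below at 1696004797 (Fri, 29 Sep 2023 16:26:37 GMT), so the
    # HTTP-date forms of `Retry-After` resolve to deterministic offsets from "now".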
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(
        self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
    ) -> None:
        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

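    # Retry backoff is patched to 0.1s (`_low_retry_timeout`) in the tests below so that
    # exercising the retry path stays fast.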
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        assert _get_open_connections(client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()
        assert _get_open_connections(client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Test that the proxy environment variables are set correctly
        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

        client = DefaultHttpxClient()

        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"

    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
    def test_default_client_creation(self) -> None:
        # Ensure that the client can be initialized without any exceptions
        DefaultHttpxClient(
            verify=True,
            cert=None,
            trust_env=True,
            http1=True,
            http2=False,
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"

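    # The tests below cover callable `api_key` providers: the provider is not invoked at
    # construction time and is re-evaluated per request attempt, as exercised by
    # `test_api_key_refresh_on_retry`.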
    def test_api_key_before_after_refresh_provider(self) -> None:
        client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")

        assert client.api_key == ""
        assert "Authorization" not in client.auth_headers

        client._refresh_api_key()

        assert client.api_key == "test_bearer_token"
        assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"

    def test_api_key_before_after_refresh_str(self) -> None:
        client = OpenAI(base_url=base_url, api_key="test_api_key")

        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
        client._refresh_api_key()

        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"

    @pytest.mark.respx()
    def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
        respx_mock.post(base_url + "/chat/completions").mock(
            side_effect=[
                httpx.Response(500, json={"error": "server error"}),
                httpx.Response(200, json={"foo": "bar"}),
            ]
        )

        counter = 0

        def token_provider() -> str:
            nonlocal counter

            counter += 1

            if counter == 1:
                return "first"

            return "second"

        client = OpenAI(base_url=base_url, api_key=token_provider)
        client.chat.completions.create(messages=[], model="gpt-4")

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(calls) == 2

        assert calls[0].request.headers.get("Authorization") == "Bearer first"
        assert calls[1].request.headers.get("Authorization") == "Bearer second"

    def test_copy_auth(self) -> None:
        client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
            api_key=lambda: "test_bearer_token_2"
        )
        client._refresh_api_key()
        assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}


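# The async suite below mirrors `TestOpenAI` one-for-one, using `AsyncOpenAI` and the
# `async_client` fixture in place of the synchronous client.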
class TestAsyncOpenAI:
    @pytest.mark.respx(base_url=base_url)
    async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await async_client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await async_client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self, async_client: AsyncOpenAI) -> None:
        copied = async_client.copy()
        assert id(copied) != id(async_client)

        copied = async_client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert async_client.api_key == "My API Key"

    def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
        # options that have a default are overridden correctly
        copied = async_client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert async_client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(async_client.timeout, httpx.Timeout)
        copied = async_client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(async_client.timeout, httpx.Timeout)

    async def test_copy_default_headers(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
        await client.close()

    async def test_copy_default_query(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

        await client.close()

    def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            async_client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(async_client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client_copy = async_client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
        request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = async_client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

        await client.close()

    async def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

            await client.close()

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

            await client.close()

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

            await client.close()

    def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    async def test_default_headers_option(self) -> None:
        test_client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        test_client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

        await test_client.close()
        await test_client2.close()

    async def test_validate_headers(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
        request = client._build_request(options)
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    async def test_default_query_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

        await client.close()

    def test_request_extra_json(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="post",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(
        self, respx_mock: MockRouter, async_client: AsyncOpenAI
    ) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await async_client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    async def test_base_url_setter(self) -> None:
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

        await client.close()

    async def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

1546 @pytest.mark.parametrize(
1547 "client",
1548 [
1549 AsyncOpenAI(
1550 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1551 ),
1552 AsyncOpenAI(
1553 base_url="http://localhost:5000/custom/path/",
1554 api_key=api_key,
1555 _strict_response_validation=True,
1556 http_client=httpx.AsyncClient(),
1557 ),
1558 ],
1559 ids=["standard", "custom http client"],
1560 )
1561 async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1562 request = client._build_request(
1563 FinalRequestOptions(
1564 method="post",
1565 url="/foo",
1566 json_data={"foo": "bar"},
1567 ),
1568 )
1569 assert request.url == "http://localhost:5000/custom/path/foo"
1570 await client.close()
1571
1572 @pytest.mark.parametrize(
1573 "client",
1574 [
1575 AsyncOpenAI(
1576 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1577 ),
1578 AsyncOpenAI(
1579 base_url="http://localhost:5000/custom/path/",
1580 api_key=api_key,
1581 _strict_response_validation=True,
1582 http_client=httpx.AsyncClient(),
1583 ),
1584 ],
1585 ids=["standard", "custom http client"],
1586 )
1587 async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1588 request = client._build_request(
1589 FinalRequestOptions(
1590 method="post",
1591 url="/foo",
1592 json_data={"foo": "bar"},
1593 ),
1594 )
1595 assert request.url == "http://localhost:5000/custom/path/foo"
1596 await client.close()
1597
1598 @pytest.mark.parametrize(
1599 "client",
1600 [
1601 AsyncOpenAI(
1602 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1603 ),
1604 AsyncOpenAI(
1605 base_url="http://localhost:5000/custom/path/",
1606 api_key=api_key,
1607 _strict_response_validation=True,
1608 http_client=httpx.AsyncClient(),
1609 ),
1610 ],
1611 ids=["standard", "custom http client"],
1612 )
1613 async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1614 request = client._build_request(
1615 FinalRequestOptions(
1616 method="post",
1617 url="https://myapi.com/foo",
1618 json_data={"foo": "bar"},
1619 ),
1620 )
1621 assert request.url == "https://myapi.com/foo"
1622 await client.close()
1623
1624 async def test_copied_client_does_not_close_http(self) -> None:
1625 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1626 assert not test_client.is_closed()
1627
1628 copied = test_client.copy()
1629 assert copied is not test_client
1630
1631 del copied
1632
1633 await asyncio.sleep(0.2)
1634 assert not test_client.is_closed()
1635
1636 async def test_client_context_manager(self) -> None:
1637 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1638 async with test_client as c2:
1639 assert c2 is test_client
1640 assert not c2.is_closed()
1641 assert not test_client.is_closed()
1642 assert test_client.is_closed()
1643
    @pytest.mark.respx(base_url=base_url)
    async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await async_client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        non_strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await non_strict_client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

        await strict_client.close()
        await non_strict_client.close()

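    # The table below exercises Retry-After parsing with time.time() mocked to
    # 1696004797 (Fri, 29 Sep 2023 16:26:37 GMT).  Integer values and HTTP-dates
    # that resolve to a delay between 0 and 60 seconds are honoured; anything else
    # (negative, over 60s, unparseable, or missing) falls back to the client's
    # exponential backoff.  A rough sketch of the fallback those rows expect,
    # assuming a 0.5s initial delay that doubles per retry already consumed and is
    # capped at 8s (jitter accounts for the pytest.approx tolerance):
    #
    #     expected = min(0.5 * 2 ** (options.max_retries - remaining_retries), 8.0)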
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    async def test_parse_retry_after_header(
        self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
    ) -> None:
        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

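    # The two "doesnt_leak" tests patch the retry delay down to 0.1s, make every
    # attempt fail (a timeout or an HTTP 500), and then check that the underlying
    # httpx connection pool holds no in-flight requests once the error surfaces.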
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await async_client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__aenter__()

        assert _get_open_connections(async_client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await async_client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__aenter__()
        assert _get_open_connections(async_client) == 0

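    # test_retries_taken drives a handler that fails `failures_before_success` times
    # (via an HTTP 500 or a raised exception) before succeeding, then checks that
    # response.retries_taken and the x-stainless-retry-count request header both
    # match the number of failures.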
    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    async def test_retries_taken(
        self,
        async_client: AsyncOpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

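    # Passing Omit() for x-stainless-retry-count in extra_headers should suppress the
    # header entirely, even when retries actually happened.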
    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

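    # An explicit value supplied via extra_headers should win over the automatically
    # computed retry count.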
    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

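    # Same retry bookkeeping as above, but exercised through the streaming response
    # wrapper (with_streaming_response) instead of with_raw_response.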
    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    async def test_get_platform(self) -> None:
        platform = await asyncify(get_platform)()
        assert isinstance(platform, (str, OtherPlatform))

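    # DefaultAsyncHttpxClient inherits httpx's trust_env behaviour, so setting
    # HTTPS_PROXY should result in exactly one proxy mount for the "https://" pattern.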
    async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Test that proxy environment variables are picked up by the default client
        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

        client = DefaultAsyncHttpxClient()

        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"

    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
    async def test_default_client_creation(self) -> None:
        # Ensure that the client can be initialized without any exceptions
        DefaultAsyncHttpxClient(
            verify=True,
            cert=None,
            trust_env=True,
            http1=True,
            http2=False,
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

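    # follow_redirects defaults to True, so a 302 from the mock should be chased to
    # the redirected location; the test after it disables redirects per request and
    # expects the 302 itself to surface as an APIStatusError.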
    @pytest.mark.respx(base_url=base_url)
    async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            await async_client.post(
                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
            )

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"

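    # The remaining tests cover api_key supplied as an async callable (a token
    # provider): the key starts out empty, _refresh_api_key() awaits the provider,
    # and the Authorization header is rebuilt from whatever the provider returned.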
    @pytest.mark.asyncio
    async def test_api_key_before_after_refresh_provider(self) -> None:
        async def mock_api_key_provider() -> str:
            return "test_bearer_token"

        client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)

        assert client.api_key == ""
        assert "Authorization" not in client.auth_headers

        await client._refresh_api_key()

        assert client.api_key == "test_bearer_token"
        assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"

    @pytest.mark.asyncio
    async def test_api_key_before_after_refresh_str(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")

        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
        await client._refresh_api_key()

        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"

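    # When a request is retried, the token provider is invoked again before the new
    # attempt, so the retry carries a freshly issued bearer token.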
    @pytest.mark.asyncio
    @pytest.mark.respx()
    async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
        respx_mock.post(base_url + "/chat/completions").mock(
            side_effect=[
                httpx.Response(500, json={"error": "server error"}),
                httpx.Response(200, json={"foo": "bar"}),
            ]
        )

        counter = 0

        async def token_provider() -> str:
            nonlocal counter

            counter += 1

            if counter == 1:
                return "first"

            return "second"

        client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
        await client.chat.completions.create(messages=[], model="gpt-4")

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(calls) == 2

        assert calls[0].request.headers.get("Authorization") == "Bearer first"
        assert calls[1].request.headers.get("Authorization") == "Bearer second"

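    # copy(api_key=...) should swap in the new token provider; after a refresh the
    # copied client must authenticate with the second provider's token only.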
    @pytest.mark.asyncio
    async def test_copy_auth(self) -> None:
        async def token_provider_1() -> str:
            return "test_bearer_token_1"

        async def token_provider_2() -> str:
            return "test_bearer_token_2"

        client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
        await client._refresh_api_key()
        assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}