Commit 07079085
Changed files (2)
src
openai
tests
src/openai/_base_client.py
@@ -361,6 +361,11 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
+ if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
+ raise TypeError(
+ "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`"
+ )
+
def _enforce_trailing_slash(self, url: URL) -> URL:
if url.raw_path.endswith(b"/"):
return url
tests/test_client.py
@@ -646,6 +646,10 @@ class TestOpenAI:
assert isinstance(exc.value.__cause__, ValidationError)
+ def test_client_max_retries_validation(self) -> None:
+ with pytest.raises(TypeError, match=r"max_retries cannot be None"):
+ OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
+
@pytest.mark.respx(base_url=base_url)
def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
@@ -1368,6 +1372,12 @@ class TestAsyncOpenAI:
assert isinstance(exc.value.__cause__, ValidationError)
+ async def test_client_max_retries_validation(self) -> None:
+ with pytest.raises(TypeError, match=r"max_retries cannot be None"):
+ AsyncOpenAI(
+ base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
+ )
+
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: