Commit 07079085

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-04-01 17:47:23
chore(client): validate that max_retries is not None (#1286)
1 parent 0470d1b
Changed files (2)
src/openai/_base_client.py
@@ -361,6 +361,11 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         self._strict_response_validation = _strict_response_validation
         self._idempotency_header = None
 
+        if max_retries is None:  # pyright: ignore[reportUnnecessaryComparison]
+            raise TypeError(
+                "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`"
+            )
+
     def _enforce_trailing_slash(self, url: URL) -> URL:
         if url.raw_path.endswith(b"/"):
             return url
tests/test_client.py
@@ -646,6 +646,10 @@ class TestOpenAI:
 
         assert isinstance(exc.value.__cause__, ValidationError)
 
+    def test_client_max_retries_validation(self) -> None:
+        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
+            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
+
     @pytest.mark.respx(base_url=base_url)
     def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
         class Model(BaseModel):
@@ -1368,6 +1372,12 @@ class TestAsyncOpenAI:
 
         assert isinstance(exc.value.__cause__, ValidationError)
 
+    async def test_client_max_retries_validation(self) -> None:
+        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
+            AsyncOpenAI(
+                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
+            )
+
     @pytest.mark.respx(base_url=base_url)
     @pytest.mark.asyncio
     async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: