Commit dae0ec80

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-03-05 18:35:14
chore(client): improve error message for invalid http_client argument (#1216)
1 parent 3208335
Changed files (2)
src/openai/_base_client.py
@@ -780,6 +780,11 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
             else:
                 timeout = DEFAULT_TIMEOUT
 
+        if http_client is not None and not isinstance(http_client, httpx.Client):  # pyright: ignore[reportUnnecessaryIsInstance]
+            raise TypeError(
+                f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}"
+            )
+
         super().__init__(
             version=version,
             limits=limits,
@@ -1322,6 +1327,11 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
             else:
                 timeout = DEFAULT_TIMEOUT
 
+        if http_client is not None and not isinstance(http_client, httpx.AsyncClient):  # pyright: ignore[reportUnnecessaryIsInstance]
+            raise TypeError(
+                f"Invalid `http_client` argument; Expected an instance of `httpx.AsyncClient` but got {type(http_client)}"
+            )
+
         super().__init__(
             version=version,
             base_url=base_url,
tests/test_client.py
@@ -292,6 +292,16 @@ class TestOpenAI:
             timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
             assert timeout == DEFAULT_TIMEOUT  # our default
 
+    async def test_invalid_http_client(self) -> None:
+        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
+            async with httpx.AsyncClient() as http_client:
+                OpenAI(
+                    base_url=base_url,
+                    api_key=api_key,
+                    _strict_response_validation=True,
+                    http_client=cast(Any, http_client),
+                )
+
     def test_default_headers_option(self) -> None:
         client = OpenAI(
             base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
@@ -994,6 +1004,16 @@ class TestAsyncOpenAI:
             timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
             assert timeout == DEFAULT_TIMEOUT  # our default
 
+    def test_invalid_http_client(self) -> None:
+        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
+            with httpx.Client() as http_client:
+                AsyncOpenAI(
+                    base_url=base_url,
+                    api_key=api_key,
+                    _strict_response_validation=True,
+                    http_client=cast(Any, http_client),
+                )
+
     def test_default_headers_option(self) -> None:
         client = AsyncOpenAI(
             base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}