Commit ba4f7a97

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-12-13 07:17:24
refactor: simplify internal error handling (#968)
1 parent a7ebc26
Changed files (2)
src/openai/_base_client.py
@@ -873,40 +873,25 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         request = self._build_request(options)
         self._prepare_request(request)
 
-        response = None
-
         try:
             response = self._client.send(
                 request,
                 auth=self.custom_auth,
                 stream=stream or self._should_stream_response_body(request=request),
             )
-            log.debug(
-                'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
-            )
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
-            if retries > 0 and self._should_retry(err.response):
-                err.response.close()
+        except httpx.TimeoutException as err:
+            if retries > 0:
                 return self._retry_request(
                     options,
                     cast_to,
                     retries,
-                    err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
+                    response_headers=None,
                 )
 
-            # If the response is streamed then we need to explicitly read the response
-            # to completion before attempting to access the response text.
-            if not err.response.is_closed:
-                err.response.read()
-
-            raise self._make_status_error_from_response(err.response) from None
-        except httpx.TimeoutException as err:
-            if response is not None:
-                response.close()
-
+            raise APITimeoutError(request=request) from err
+        except Exception as err:
             if retries > 0:
                 return self._retry_request(
                     options,
@@ -914,25 +899,35 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
                     retries,
                     stream=stream,
                     stream_cls=stream_cls,
-                    response_headers=response.headers if response is not None else None,
+                    response_headers=None,
                 )
 
-            raise APITimeoutError(request=request) from err
-        except Exception as err:
-            if response is not None:
-                response.close()
+            raise APIConnectionError(request=request) from err
 
-            if retries > 0:
+        log.debug(
+            'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
+            if retries > 0 and self._should_retry(err.response):
+                err.response.close()
                 return self._retry_request(
                     options,
                     cast_to,
                     retries,
+                    err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
-                    response_headers=response.headers if response is not None else None,
                 )
 
-            raise APIConnectionError(request=request) from err
+            # If the response is streamed then we need to explicitly read the response
+            # to completion before attempting to access the response text.
+            if not err.response.is_closed:
+                err.response.read()
+
+            raise self._make_status_error_from_response(err.response) from None
 
         return self._process_response(
             cast_to=cast_to,
@@ -1340,40 +1335,25 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         request = self._build_request(options)
         await self._prepare_request(request)
 
-        response = None
-
         try:
             response = await self._client.send(
                 request,
                 auth=self.custom_auth,
                 stream=stream or self._should_stream_response_body(request=request),
             )
-            log.debug(
-                'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
-            )
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
-            if retries > 0 and self._should_retry(err.response):
-                await err.response.aclose()
+        except httpx.TimeoutException as err:
+            if retries > 0:
                 return await self._retry_request(
                     options,
                     cast_to,
                     retries,
-                    err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
+                    response_headers=None,
                 )
 
-            # If the response is streamed then we need to explicitly read the response
-            # to completion before attempting to access the response text.
-            if not err.response.is_closed:
-                await err.response.aread()
-
-            raise self._make_status_error_from_response(err.response) from None
-        except httpx.TimeoutException as err:
-            if response is not None:
-                await response.aclose()
-
+            raise APITimeoutError(request=request) from err
+        except Exception as err:
             if retries > 0:
                 return await self._retry_request(
                     options,
@@ -1381,25 +1361,35 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
                     retries,
                     stream=stream,
                     stream_cls=stream_cls,
-                    response_headers=response.headers if response is not None else None,
+                    response_headers=None,
                 )
 
-            raise APITimeoutError(request=request) from err
-        except Exception as err:
-            if response is not None:
-                await response.aclose()
+            raise APIConnectionError(request=request) from err
 
-            if retries > 0:
+        log.debug(
+            'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
+            if retries > 0 and self._should_retry(err.response):
+                await err.response.aclose()
                 return await self._retry_request(
                     options,
                     cast_to,
                     retries,
+                    err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
-                    response_headers=response.headers if response is not None else None,
                 )
 
-            raise APIConnectionError(request=request) from err
+            # If the response is streamed then we need to explicitly read the response
+            # to completion before attempting to access the response text.
+            if not err.response.is_closed:
+                await err.response.aread()
+
+            raise self._make_status_error_from_response(err.response) from None
 
         return self._process_response(
             cast_to=cast_to,
tests/test_client.py
@@ -24,7 +24,6 @@ from openai._exceptions import (
     OpenAIError,
     APIStatusError,
     APITimeoutError,
-    APIConnectionError,
     APIResponseValidationError,
 )
 from openai._base_client import (
@@ -46,14 +45,8 @@ def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
     return dict(url.params)
 
 
-_original_response_init = cast(Any, httpx.Response.__init__)  # type: ignore
-
-
-def _low_retry_response_init(*args: Any, **kwargs: Any) -> Any:
-    headers = cast("list[tuple[bytes, bytes]]", kwargs["headers"])
-    headers.append((b"retry-after", b"0.1"))
-
-    return _original_response_init(*args, **kwargs)
+def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
+    return 0.1
 
 
 def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
@@ -678,103 +671,51 @@ class TestOpenAI:
         calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
         assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
 
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    def test_retrying_timeout_errors_doesnt_leak(self) -> None:
-        def raise_for_status(response: httpx.Response) -> None:
-            raise httpx.TimeoutException("Test timeout error", request=response.request)
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APITimeoutError):
-                self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
-
-        assert _get_open_connections(self.client) == 0
-
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    def test_retrying_runtime_errors_doesnt_leak(self) -> None:
-        def raise_for_status(_response: httpx.Response) -> None:
-            raise RuntimeError("Test error")
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APIConnectionError):
-                self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
-
-        assert _get_open_connections(self.client) == 0
-
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    def test_retrying_status_errors_doesnt_leak(self) -> None:
-        def raise_for_status(response: httpx.Response) -> None:
-            response.status_code = 500
-            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APIStatusError):
-                self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+
+        with pytest.raises(APITimeoutError):
+            self.client.post(
+                "/chat/completions",
+                body=dict(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model="gpt-3.5-turbo",
+                ),
+                cast_to=httpx.Response,
+                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+            )
 
         assert _get_open_connections(self.client) == 0
 
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
 
-        def on_response(response: httpx.Response) -> None:
-            raise httpx.HTTPStatusError(
-                "Simulating an error inside httpx",
-                response=response,
-                request=response.request,
+        with pytest.raises(APIStatusError):
+            self.client.post(
+                "/chat/completions",
+                body=dict(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model="gpt-3.5-turbo",
+                ),
+                cast_to=httpx.Response,
+                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
             )
 
-        client = OpenAI(
-            base_url=base_url,
-            api_key=api_key,
-            _strict_response_validation=True,
-            http_client=httpx.Client(
-                event_hooks={
-                    "response": [on_response],
-                }
-            ),
-            max_retries=0,
-        )
-        with pytest.raises(APIStatusError):
-            client.post("/foo", cast_to=httpx.Response)
+        assert _get_open_connections(self.client) == 0
 
 
 class TestAsyncOpenAI:
@@ -1408,101 +1349,48 @@ class TestAsyncOpenAI:
         calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
         assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
 
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    async def test_retrying_timeout_errors_doesnt_leak(self) -> None:
-        def raise_for_status(response: httpx.Response) -> None:
-            raise httpx.TimeoutException("Test timeout error", request=response.request)
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APITimeoutError):
-                await self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
-
-        assert _get_open_connections(self.client) == 0
-
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    async def test_retrying_runtime_errors_doesnt_leak(self) -> None:
-        def raise_for_status(_response: httpx.Response) -> None:
-            raise RuntimeError("Test error")
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APIConnectionError):
-                await self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
-
-        assert _get_open_connections(self.client) == 0
-
-    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
-    async def test_retrying_status_errors_doesnt_leak(self) -> None:
-        def raise_for_status(response: httpx.Response) -> None:
-            response.status_code = 500
-            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)
-
-        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
-            with pytest.raises(APIStatusError):
-                await self.client.post(
-                    "/chat/completions",
-                    body=dict(
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": "Say this is a test",
-                            }
-                        ],
-                        model="gpt-3.5-turbo",
-                    ),
-                    cast_to=httpx.Response,
-                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
-                )
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+
+        with pytest.raises(APITimeoutError):
+            await self.client.post(
+                "/chat/completions",
+                body=dict(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model="gpt-3.5-turbo",
+                ),
+                cast_to=httpx.Response,
+                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+            )
 
         assert _get_open_connections(self.client) == 0
 
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    @pytest.mark.asyncio
-    async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
 
-        def on_response(response: httpx.Response) -> None:
-            raise httpx.HTTPStatusError(
-                "Simulating an error inside httpx",
-                response=response,
-                request=response.request,
+        with pytest.raises(APIStatusError):
+            await self.client.post(
+                "/chat/completions",
+                body=dict(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model="gpt-3.5-turbo",
+                ),
+                cast_to=httpx.Response,
+                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
             )
 
-        client = AsyncOpenAI(
-            base_url=base_url,
-            api_key=api_key,
-            _strict_response_validation=True,
-            http_client=httpx.AsyncClient(
-                event_hooks={
-                    "response": [on_response],
-                }
-            ),
-            max_retries=0,
-        )
-        with pytest.raises(APIStatusError):
-            await client.post("/foo", cast_to=httpx.Response)
+        assert _get_open_connections(self.client) == 0