Commit 7aad3405

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-11-30 08:02:45
fix(client): ensure retried requests are closed (#902)
1 parent 2e7e897
Changed files (3)
src/openai/_base_client.py
@@ -72,6 +72,7 @@ from ._constants import (
     DEFAULT_TIMEOUT,
     DEFAULT_MAX_RETRIES,
     RAW_RESPONSE_HEADER,
+    STREAMED_RAW_RESPONSE_HEADER,
 )
 from ._streaming import Stream, AsyncStream
 from ._exceptions import (
@@ -363,14 +364,21 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         self,
         response: httpx.Response,
     ) -> APIStatusError:
-        err_text = response.text.strip()
-        body = err_text
+        if response.is_closed and not response.is_stream_consumed:
+            # We can't read the response body as it has been closed
+            # before it was read. This can happen if an event hook
+            # raises a status error.
+            body = None
+            err_msg = f"Error code: {response.status_code}"
+        else:
+            err_text = response.text.strip()
+            body = err_text
 
-        try:
-            body = json.loads(err_text)
-            err_msg = f"Error code: {response.status_code} - {body}"
-        except Exception:
-            err_msg = err_text or f"Error code: {response.status_code}"
+            try:
+                body = json.loads(err_text)
+                err_msg = f"Error code: {response.status_code} - {body}"
+            except Exception:
+                err_msg = err_text or f"Error code: {response.status_code}"
 
         return self._make_status_error(err_msg, body=body, response=response)
 
@@ -534,6 +542,12 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         except pydantic.ValidationError as err:
             raise APIResponseValidationError(response=response, body=data) from err
 
+    def _should_stream_response_body(self, *, request: httpx.Request) -> bool:
+        if request.headers.get(STREAMED_RAW_RESPONSE_HEADER) == "true":
+            return True
+
+        return False
+
     @property
     def qs(self) -> Querystring:
         return Querystring()
@@ -606,7 +620,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
             if response_headers is not None:
                 retry_header = response_headers.get("retry-after")
                 try:
-                    retry_after = int(retry_header)
+                    retry_after = float(retry_header)
                 except Exception:
                     retry_date_tuple = email.utils.parsedate_tz(retry_header)
                     if retry_date_tuple is None:
@@ -862,14 +876,21 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         request = self._build_request(options)
         self._prepare_request(request)
 
+        response = None
+
         try:
-            response = self._client.send(request, auth=self.custom_auth, stream=stream)
+            response = self._client.send(
+                request,
+                auth=self.custom_auth,
+                stream=stream or self._should_stream_response_body(request=request),
+            )
             log.debug(
                 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
             if retries > 0 and self._should_retry(err.response):
+                err.response.close()
                 return self._retry_request(
                     options,
                     cast_to,
@@ -881,9 +902,14 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
 
             # If the response is streamed then we need to explicitly read the response
             # to completion before attempting to access the response text.
-            err.response.read()
+            if not err.response.is_closed:
+                err.response.read()
+
             raise self._make_status_error_from_response(err.response) from None
         except httpx.TimeoutException as err:
+            if response is not None:
+                response.close()
+
             if retries > 0:
                 return self._retry_request(
                     options,
@@ -891,9 +917,14 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
                     retries,
                     stream=stream,
                     stream_cls=stream_cls,
+                    response_headers=response.headers if response is not None else None,
                 )
+
             raise APITimeoutError(request=request) from err
         except Exception as err:
+            if response is not None:
+                response.close()
+
             if retries > 0:
                 return self._retry_request(
                     options,
@@ -901,7 +932,9 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
                     retries,
                     stream=stream,
                     stream_cls=stream_cls,
+                    response_headers=response.headers if response is not None else None,
                 )
+
             raise APIConnectionError(request=request) from err
 
         return self._process_response(
@@ -917,7 +950,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         options: FinalRequestOptions,
         cast_to: Type[ResponseT],
         remaining_retries: int,
-        response_headers: Optional[httpx.Headers] = None,
+        response_headers: httpx.Headers | None,
         *,
         stream: bool,
         stream_cls: type[_StreamT] | None,
@@ -1303,14 +1336,21 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         request = self._build_request(options)
         await self._prepare_request(request)
 
+        response = None
+
         try:
-            response = await self._client.send(request, auth=self.custom_auth, stream=stream)
+            response = await self._client.send(
+                request,
+                auth=self.custom_auth,
+                stream=stream or self._should_stream_response_body(request=request),
+            )
             log.debug(
                 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
             )
             response.raise_for_status()
         except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
             if retries > 0 and self._should_retry(err.response):
+                await err.response.aclose()
                 return await self._retry_request(
                     options,
                     cast_to,
@@ -1322,19 +1362,39 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
 
             # If the response is streamed then we need to explicitly read the response
             # to completion before attempting to access the response text.
-            await err.response.aread()
+            if not err.response.is_closed:
+                await err.response.aread()
+
             raise self._make_status_error_from_response(err.response) from None
-        except httpx.ConnectTimeout as err:
-            if retries > 0:
-                return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
-            raise APITimeoutError(request=request) from err
         except httpx.TimeoutException as err:
+            if response is not None:
+                await response.aclose()
+
             if retries > 0:
-                return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
+                return await self._retry_request(
+                    options,
+                    cast_to,
+                    retries,
+                    stream=stream,
+                    stream_cls=stream_cls,
+                    response_headers=response.headers if response is not None else None,
+                )
+
             raise APITimeoutError(request=request) from err
         except Exception as err:
+            if response is not None:
+                await response.aclose()
+
             if retries > 0:
-                return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
+                return await self._retry_request(
+                    options,
+                    cast_to,
+                    retries,
+                    stream=stream,
+                    stream_cls=stream_cls,
+                    response_headers=response.headers if response is not None else None,
+                )
+
             raise APIConnectionError(request=request) from err
 
         return self._process_response(
@@ -1350,7 +1410,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         options: FinalRequestOptions,
         cast_to: Type[ResponseT],
         remaining_retries: int,
-        response_headers: Optional[httpx.Headers] = None,
+        response_headers: httpx.Headers | None,
         *,
         stream: bool,
         stream_cls: type[_AsyncStreamT] | None,
src/openai/_constants.py
@@ -3,6 +3,7 @@
 import httpx
 
 RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
+STREAMED_RAW_RESPONSE_HEADER = "X-Stainless-Streamed-Raw-Response"
 
 # default timeout is 10 minutes
 DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0)
tests/test_client.py
@@ -18,7 +18,12 @@ from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
 from openai._client import OpenAI, AsyncOpenAI
 from openai._models import BaseModel, FinalRequestOptions
 from openai._streaming import Stream, AsyncStream
-from openai._exceptions import APIResponseValidationError
+from openai._exceptions import (
+    APIStatusError,
+    APITimeoutError,
+    APIConnectionError,
+    APIResponseValidationError,
+)
 from openai._base_client import (
     DEFAULT_TIMEOUT,
     HTTPX_DEFAULT_TIMEOUT,
@@ -38,6 +43,24 @@ def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
     return dict(url.params)
 
 
+_original_response_init = cast(Any, httpx.Response.__init__)  # type: ignore
+
+
+def _low_retry_response_init(*args: Any, **kwargs: Any) -> Any:
+    headers = cast("list[tuple[bytes, bytes]]", kwargs["headers"])
+    headers.append((b"retry-after", b"0.1"))
+
+    return _original_response_init(*args, **kwargs)
+
+
+def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
+    transport = client._client._transport
+    assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport)
+
+    pool = transport._pool
+    return len(pool._requests)
+
+
 class TestOpenAI:
     client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
 
@@ -592,6 +615,104 @@ class TestOpenAI:
         calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
         assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
 
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    def test_retrying_timeout_errors_doesnt_leak(self) -> None:
+        def raise_for_status(response: httpx.Response) -> None:
+            raise httpx.TimeoutException("Test timeout error", request=response.request)
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APITimeoutError):
+                self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    def test_retrying_runtime_errors_doesnt_leak(self) -> None:
+        def raise_for_status(_response: httpx.Response) -> None:
+            raise RuntimeError("Test error")
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APIConnectionError):
+                self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    def test_retrying_status_errors_doesnt_leak(self) -> None:
+        def raise_for_status(response: httpx.Response) -> None:
+            response.status_code = 500
+            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APIStatusError):
+                self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+        def on_response(response: httpx.Response) -> None:
+            raise httpx.HTTPStatusError(
+                "Simulating an error inside httpx",
+                response=response,
+                request=response.request,
+            )
+
+        client = OpenAI(
+            base_url=base_url,
+            api_key=api_key,
+            _strict_response_validation=True,
+            http_client=httpx.Client(
+                event_hooks={
+                    "response": [on_response],
+                }
+            ),
+            max_retries=0,
+        )
+        with pytest.raises(APIStatusError):
+            client.post("/foo", cast_to=httpx.Response)
+
 
 class TestAsyncOpenAI:
     client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1162,3 +1283,102 @@ class TestAsyncOpenAI:
         options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
         calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
         assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
+
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    async def test_retrying_timeout_errors_doesnt_leak(self) -> None:
+        def raise_for_status(response: httpx.Response) -> None:
+            raise httpx.TimeoutException("Test timeout error", request=response.request)
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APITimeoutError):
+                await self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    async def test_retrying_runtime_errors_doesnt_leak(self) -> None:
+        def raise_for_status(_response: httpx.Response) -> None:
+            raise RuntimeError("Test error")
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APIConnectionError):
+                await self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
+    async def test_retrying_status_errors_doesnt_leak(self) -> None:
+        def raise_for_status(response: httpx.Response) -> None:
+            response.status_code = 500
+            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)
+
+        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
+            with pytest.raises(APIStatusError):
+                await self.client.post(
+                    "/chat/completions",
+                    body=dict(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Say this is a test",
+                            }
+                        ],
+                        model="gpt-3.5-turbo",
+                    ),
+                    cast_to=httpx.Response,
+                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
+                )
+
+        assert _get_open_connections(self.client) == 0
+
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
+        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
+
+        def on_response(response: httpx.Response) -> None:
+            raise httpx.HTTPStatusError(
+                "Simulating an error inside httpx",
+                response=response,
+                request=response.request,
+            )
+
+        client = AsyncOpenAI(
+            base_url=base_url,
+            api_key=api_key,
+            _strict_response_validation=True,
+            http_client=httpx.AsyncClient(
+                event_hooks={
+                    "response": [on_response],
+                }
+            ),
+            max_retries=0,
+        )
+        with pytest.raises(APIStatusError):
+            await client.post("/foo", cast_to=httpx.Response)