Commit 3765cc2c (parent 6172976)
Author: Stainless Bot <dev@stainlessapi.com>
Date: 2024-09-20 00:11:56

feat(client): send retry count header
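
The change threads a `retries_taken` count from the request entry points into header construction and, on retried attempts only, attaches it as an `x-stainless-retry-count` header. A minimal sketch of the resulting behaviour, mirroring the header logic in the diff below (the standalone `build_headers` helper is illustrative, not the SDK's actual `_build_headers` method):

    import httpx

    def build_headers(base_headers: dict, *, retries_taken: int = 0) -> httpx.Headers:
        # The real client also merges default headers, validates auth, and
        # handles idempotency keys; only the retry-count logic is shown here.
        headers = httpx.Headers(base_headers)
        # The header is only attached on retries, so the very first attempt
        # (retries_taken == 0) carries no x-stainless-retry-count at all.
        if retries_taken > 0:
            headers.setdefault("x-stainless-retry-count", str(retries_taken))
        return headers

    assert "x-stainless-retry-count" not in build_headers({"accept": "application/json"})
    assert build_headers({}, retries_taken=1)["x-stainless-retry-count"] == "1"
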
Changed files (3)
src/openai/lib/azure.py
@@ -53,13 +53,15 @@ class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
     def _build_request(
         self,
         options: FinalRequestOptions,
+        *,
+        retries_taken: int = 0,
     ) -> httpx.Request:
         if options.url in _deployments_endpoints and is_mapping(options.json_data):
             model = options.json_data.get("model")
             if model is not None and not "/deployments" in str(self.base_url):
                 options.url = f"/deployments/{model}{options.url}"
 
-        return super()._build_request(options)
+        return super()._build_request(options, retries_taken=retries_taken)
 
 
 class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
src/openai/_base_client.py
@@ -401,14 +401,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
     ) -> _exceptions.APIStatusError:
         raise NotImplementedError()
 
-    def _remaining_retries(
-        self,
-        remaining_retries: Optional[int],
-        options: FinalRequestOptions,
-    ) -> int:
-        return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
-
-    def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
+    def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
         custom_headers = options.headers or {}
         headers_dict = _merge_mappings(self.default_headers, custom_headers)
         self._validate_headers(headers_dict, custom_headers)
@@ -420,6 +413,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
             headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
 
+        if retries_taken > 0:
+            headers.setdefault("x-stainless-retry-count", str(retries_taken))
+
         return headers
 
     def _prepare_url(self, url: str) -> URL:
@@ -441,6 +437,8 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
     def _build_request(
         self,
         options: FinalRequestOptions,
+        *,
+        retries_taken: int = 0,
     ) -> httpx.Request:
         if log.isEnabledFor(logging.DEBUG):
             log.debug("Request options: %s", model_dump(options, exclude_unset=True))
@@ -456,7 +454,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
             else:
                 raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
 
-        headers = self._build_headers(options)
+        headers = self._build_headers(options, retries_taken=retries_taken)
         params = _merge_mappings(self.default_query, options.params)
         content_type = headers.get("Content-Type")
         files = options.files
@@ -939,12 +937,17 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         stream: bool = False,
         stream_cls: type[_StreamT] | None = None,
     ) -> ResponseT | _StreamT:
+        if remaining_retries is not None:
+            retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
+        else:
+            retries_taken = 0
+
         return self._request(
             cast_to=cast_to,
             options=options,
             stream=stream,
             stream_cls=stream_cls,
-            remaining_retries=remaining_retries,
+            retries_taken=retries_taken,
         )
 
     def _request(
@@ -952,7 +955,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         *,
         cast_to: Type[ResponseT],
         options: FinalRequestOptions,
-        remaining_retries: int | None,
+        retries_taken: int,
         stream: bool,
         stream_cls: type[_StreamT] | None,
     ) -> ResponseT | _StreamT:
@@ -964,8 +967,8 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         cast_to = self._maybe_override_cast_to(cast_to, options)
         options = self._prepare_options(options)
 
-        retries = self._remaining_retries(remaining_retries, options)
-        request = self._build_request(options)
+        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+        request = self._build_request(options, retries_taken=retries_taken)
         self._prepare_request(request)
 
         kwargs: HttpxSendArgs = {}
@@ -983,11 +986,11 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         except httpx.TimeoutException as err:
             log.debug("Encountered httpx.TimeoutException", exc_info=True)
 
-            if retries > 0:
+            if remaining_retries > 0:
                 return self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
+                    retries_taken=retries_taken,
                     stream=stream,
                     stream_cls=stream_cls,
                     response_headers=None,
@@ -998,11 +1001,11 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         except Exception as err:
             log.debug("Encountered Exception", exc_info=True)
 
-            if retries > 0:
+            if remaining_retries > 0:
                 return self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
+                    retries_taken=retries_taken,
                     stream=stream,
                     stream_cls=stream_cls,
                     response_headers=None,
@@ -1026,13 +1029,13 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
             log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
 
-            if retries > 0 and self._should_retry(err.response):
+            if remaining_retries > 0 and self._should_retry(err.response):
                 err.response.close()
                 return self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
-                    err.response.headers,
+                    retries_taken=retries_taken,
+                    response_headers=err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
                 )
@@ -1051,26 +1054,26 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
             response=response,
             stream=stream,
             stream_cls=stream_cls,
-            retries_taken=options.get_max_retries(self.max_retries) - retries,
+            retries_taken=retries_taken,
         )
 
     def _retry_request(
         self,
         options: FinalRequestOptions,
         cast_to: Type[ResponseT],
-        remaining_retries: int,
-        response_headers: httpx.Headers | None,
         *,
+        retries_taken: int,
+        response_headers: httpx.Headers | None,
         stream: bool,
         stream_cls: type[_StreamT] | None,
     ) -> ResponseT | _StreamT:
-        remaining = remaining_retries - 1
-        if remaining == 1:
+        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+        if remaining_retries == 1:
             log.debug("1 retry left")
         else:
-            log.debug("%i retries left", remaining)
+            log.debug("%i retries left", remaining_retries)
 
-        timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+        timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
         log.info("Retrying request to %s in %f seconds", options.url, timeout)
 
         # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
@@ -1080,7 +1083,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         return self._request(
             options=options,
             cast_to=cast_to,
-            remaining_retries=remaining,
+            retries_taken=retries_taken + 1,
             stream=stream,
             stream_cls=stream_cls,
         )
@@ -1512,12 +1515,17 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         stream_cls: type[_AsyncStreamT] | None = None,
         remaining_retries: Optional[int] = None,
     ) -> ResponseT | _AsyncStreamT:
+        if remaining_retries is not None:
+            retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
+        else:
+            retries_taken = 0
+
         return await self._request(
             cast_to=cast_to,
             options=options,
             stream=stream,
             stream_cls=stream_cls,
-            remaining_retries=remaining_retries,
+            retries_taken=retries_taken,
         )
 
     async def _request(
@@ -1527,7 +1535,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         *,
         stream: bool,
         stream_cls: type[_AsyncStreamT] | None,
-        remaining_retries: int | None,
+        retries_taken: int,
     ) -> ResponseT | _AsyncStreamT:
         if self._platform is None:
             # `get_platform` can make blocking IO calls so we
@@ -1542,8 +1550,8 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         cast_to = self._maybe_override_cast_to(cast_to, options)
         options = await self._prepare_options(options)
 
-        retries = self._remaining_retries(remaining_retries, options)
-        request = self._build_request(options)
+        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+        request = self._build_request(options, retries_taken=retries_taken)
         await self._prepare_request(request)
 
         kwargs: HttpxSendArgs = {}
@@ -1559,11 +1567,11 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         except httpx.TimeoutException as err:
             log.debug("Encountered httpx.TimeoutException", exc_info=True)
 
-            if retries > 0:
+            if remaining_retries > 0:
                 return await self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
+                    retries_taken=retries_taken,
                     stream=stream,
                     stream_cls=stream_cls,
                     response_headers=None,
@@ -1574,11 +1582,11 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         except Exception as err:
             log.debug("Encountered Exception", exc_info=True)
 
-            if retries > 0:
+            if remaining_retries > 0:
                 return await self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
+                    retries_taken=retries_taken,
                     stream=stream,
                     stream_cls=stream_cls,
                     response_headers=None,
@@ -1596,13 +1604,13 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
             log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
 
-            if retries > 0 and self._should_retry(err.response):
+            if remaining_retries > 0 and self._should_retry(err.response):
                 await err.response.aclose()
                 return await self._retry_request(
                     input_options,
                     cast_to,
-                    retries,
-                    err.response.headers,
+                    retries_taken=retries_taken,
+                    response_headers=err.response.headers,
                     stream=stream,
                     stream_cls=stream_cls,
                 )
@@ -1621,26 +1629,26 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
             response=response,
             stream=stream,
             stream_cls=stream_cls,
-            retries_taken=options.get_max_retries(self.max_retries) - retries,
+            retries_taken=retries_taken,
         )
 
     async def _retry_request(
         self,
         options: FinalRequestOptions,
         cast_to: Type[ResponseT],
-        remaining_retries: int,
-        response_headers: httpx.Headers | None,
         *,
+        retries_taken: int,
+        response_headers: httpx.Headers | None,
         stream: bool,
         stream_cls: type[_AsyncStreamT] | None,
     ) -> ResponseT | _AsyncStreamT:
-        remaining = remaining_retries - 1
-        if remaining == 1:
+        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+        if remaining_retries == 1:
             log.debug("1 retry left")
         else:
-            log.debug("%i retries left", remaining)
+            log.debug("%i retries left", remaining_retries)
 
-        timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+        timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
         log.info("Retrying request to %s in %f seconds", options.url, timeout)
 
         await anyio.sleep(timeout)
@@ -1648,7 +1656,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         return await self._request(
             options=options,
             cast_to=cast_to,
-            remaining_retries=remaining,
+            retries_taken=retries_taken + 1,
             stream=stream,
             stream_cls=stream_cls,
         )
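
The bookkeeping is inverted by this refactor: rather than threading the number of retries still remaining through every call site, the clients now pass the number of retries already taken and derive what is left from the per-request maximum. A rough sketch of the arithmetic, assuming the SDK's default of two retries (the `next_attempt` helper is illustrative only):

    MAX_RETRIES = 2  # the default; per-request options can override this

    def next_attempt(retries_taken: int) -> tuple:
        # remaining_retries is recomputed at each step instead of being
        # decremented and passed along, matching the refactored _request.
        remaining_retries = MAX_RETRIES - retries_taken
        return remaining_retries, remaining_retries > 0

    assert next_attempt(0) == (2, True)   # first attempt fails -> retry #1
    assert next_attempt(1) == (1, True)   # retry #1 fails -> retry #2
    assert next_attempt(2) == (0, False)  # retry #2 fails -> give up
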
tests/test_client.py
@@ -788,6 +788,10 @@ class TestOpenAI:
         )
 
         assert response.retries_taken == failures_before_success
+        if failures_before_success == 0:
+            assert "x-stainless-retry-count" not in response.http_request.headers
+        else:
+            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -818,6 +822,10 @@ class TestOpenAI:
             model="gpt-4o",
         ) as response:
             assert response.retries_taken == failures_before_success
+            if failures_before_success == 0:
+                assert "x-stainless-retry-count" not in response.http_request.headers
+            else:
+                assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
 
 
 class TestAsyncOpenAI:
@@ -1582,6 +1590,10 @@ class TestAsyncOpenAI:
         )
 
         assert response.retries_taken == failures_before_success
+        if failures_before_success == 0:
+            assert "x-stainless-retry-count" not in response.http_request.headers
+        else:
+            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1613,3 +1625,7 @@ class TestAsyncOpenAI:
             model="gpt-4o",
         ) as response:
             assert response.retries_taken == failures_before_success
+            if failures_before_success == 0:
+                assert "x-stainless-retry-count" not in response.http_request.headers
+            else:
+                assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
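
As the tests above do, callers can observe the new header on the raw response object. A hedged usage sketch, assuming OPENAI_API_KEY is set; whether any retries actually happen depends on the server, whereas the tests force failures with a mocked transport:

    from openai import OpenAI

    client = OpenAI(max_retries=2)
    response = client.chat.completions.with_raw_response.create(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-4o",
    )
    if response.retries_taken > 0:
        # The final (successful) request carried the retry count.
        assert response.http_request.headers["x-stainless-retry-count"] == str(response.retries_taken)
    else:
        assert "x-stainless-retry-count" not in response.http_request.headers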