Commit 7de6c5ce

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2025-04-23 06:53:21
chore(internal): refactor retries to not use recursion
1 parent 1e7dea2
Changed files (1)
src
src/openai/_base_client.py
@@ -439,8 +439,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         headers = httpx.Headers(headers_dict)
 
         idempotency_header = self._idempotency_header
-        if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
-            options.idempotency_key = options.idempotency_key or self._idempotency_key()
+        if idempotency_header and options.idempotency_key and idempotency_header not in headers:
             headers[idempotency_header] = options.idempotency_key
 
         # Don't set these headers if they were already set or removed by the caller. We check
@@ -905,7 +904,6 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         self,
         cast_to: Type[ResponseT],
         options: FinalRequestOptions,
-        remaining_retries: Optional[int] = None,
         *,
         stream: Literal[True],
         stream_cls: Type[_StreamT],
@@ -916,7 +914,6 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         self,
         cast_to: Type[ResponseT],
         options: FinalRequestOptions,
-        remaining_retries: Optional[int] = None,
         *,
         stream: Literal[False] = False,
     ) -> ResponseT: ...
@@ -926,7 +923,6 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         self,
         cast_to: Type[ResponseT],
         options: FinalRequestOptions,
-        remaining_retries: Optional[int] = None,
         *,
         stream: bool = False,
         stream_cls: Type[_StreamT] | None = None,
@@ -936,126 +932,110 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         self,
         cast_to: Type[ResponseT],
         options: FinalRequestOptions,
-        remaining_retries: Optional[int] = None,
         *,
         stream: bool = False,
         stream_cls: type[_StreamT] | None = None,
     ) -> ResponseT | _StreamT:
-        if remaining_retries is not None:
-            retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
-        else:
-            retries_taken = 0
-
-        return self._request(
-            cast_to=cast_to,
-            options=options,
-            stream=stream,
-            stream_cls=stream_cls,
-            retries_taken=retries_taken,
-        )
+        cast_to = self._maybe_override_cast_to(cast_to, options)
 
-    def _request(
-        self,
-        *,
-        cast_to: Type[ResponseT],
-        options: FinalRequestOptions,
-        retries_taken: int,
-        stream: bool,
-        stream_cls: type[_StreamT] | None,
-    ) -> ResponseT | _StreamT:
         # create a copy of the options we were given so that if the
         # options are mutated later & we then retry, the retries are
         # given the original options
         input_options = model_copy(options)
-
-        cast_to = self._maybe_override_cast_to(cast_to, options)
-        options = self._prepare_options(options)
-
-        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
-        request = self._build_request(options, retries_taken=retries_taken)
-        self._prepare_request(request)
-
-        if options.idempotency_key:
+        if input_options.idempotency_key is None and input_options.method.lower() != "get":
             # ensure the idempotency key is reused between requests
-            input_options.idempotency_key = options.idempotency_key
+            input_options.idempotency_key = self._idempotency_key()
 
-        kwargs: HttpxSendArgs = {}
-        if self.custom_auth is not None:
-            kwargs["auth"] = self.custom_auth
+        response: httpx.Response | None = None
+        max_retries = input_options.get_max_retries(self.max_retries)
 
-        log.debug("Sending HTTP Request: %s %s", request.method, request.url)
+        retries_taken = 0
+        for retries_taken in range(max_retries + 1):
+            options = model_copy(input_options)
+            options = self._prepare_options(options)
 
-        try:
-            response = self._client.send(
-                request,
-                stream=stream or self._should_stream_response_body(request=request),
-                **kwargs,
-            )
-        except httpx.TimeoutException as err:
-            log.debug("Encountered httpx.TimeoutException", exc_info=True)
+            remaining_retries = max_retries - retries_taken
+            request = self._build_request(options, retries_taken=retries_taken)
+            self._prepare_request(request)
 
-            if remaining_retries > 0:
-                return self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                    response_headers=None,
-                )
+            kwargs: HttpxSendArgs = {}
+            if self.custom_auth is not None:
+                kwargs["auth"] = self.custom_auth
 
-            log.debug("Raising timeout error")
-            raise APITimeoutError(request=request) from err
-        except Exception as err:
-            log.debug("Encountered Exception", exc_info=True)
+            log.debug("Sending HTTP Request: %s %s", request.method, request.url)
 
-            if remaining_retries > 0:
-                return self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                    response_headers=None,
+            response = None
+            try:
+                response = self._client.send(
+                    request,
+                    stream=stream or self._should_stream_response_body(request=request),
+                    **kwargs,
                 )
+            except httpx.TimeoutException as err:
+                log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+                if remaining_retries > 0:
+                    self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=None,
+                    )
+                    continue
+
+                log.debug("Raising timeout error")
+                raise APITimeoutError(request=request) from err
+            except Exception as err:
+                log.debug("Encountered Exception", exc_info=True)
+
+                if remaining_retries > 0:
+                    self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=None,
+                    )
+                    continue
+
+                log.debug("Raising connection error")
+                raise APIConnectionError(request=request) from err
+
+            log.debug(
+                'HTTP Response: %s %s "%i %s" %s',
+                request.method,
+                request.url,
+                response.status_code,
+                response.reason_phrase,
+                response.headers,
+            )
+            log.debug("request_id: %s", response.headers.get("x-request-id"))
 
-            log.debug("Raising connection error")
-            raise APIConnectionError(request=request) from err
-
-        log.debug(
-            'HTTP Response: %s %s "%i %s" %s',
-            request.method,
-            request.url,
-            response.status_code,
-            response.reason_phrase,
-            response.headers,
-        )
-        log.debug("request_id: %s", response.headers.get("x-request-id"))
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
+                log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+                if remaining_retries > 0 and self._should_retry(err.response):
+                    err.response.close()
+                    self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=response,
+                    )
+                    continue
 
-        try:
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
-            log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
-            if remaining_retries > 0 and self._should_retry(err.response):
-                err.response.close()
-                return self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    response_headers=err.response.headers,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                )
+                # If the response is streamed then we need to explicitly read the response
+                # to completion before attempting to access the response text.
+                if not err.response.is_closed:
+                    err.response.read()
 
-            # If the response is streamed then we need to explicitly read the response
-            # to completion before attempting to access the response text.
-            if not err.response.is_closed:
-                err.response.read()
+                log.debug("Re-raising status error")
+                raise self._make_status_error_from_response(err.response) from None
 
-            log.debug("Re-raising status error")
-            raise self._make_status_error_from_response(err.response) from None
+            break
 
+        assert response is not None, "could not resolve response (should never happen)"
         return self._process_response(
             cast_to=cast_to,
             options=options,
@@ -1065,37 +1045,20 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
             retries_taken=retries_taken,
         )
 
-    def _retry_request(
-        self,
-        options: FinalRequestOptions,
-        cast_to: Type[ResponseT],
-        *,
-        retries_taken: int,
-        response_headers: httpx.Headers | None,
-        stream: bool,
-        stream_cls: type[_StreamT] | None,
-    ) -> ResponseT | _StreamT:
-        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+    def _sleep_for_retry(
+        self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+    ) -> None:
+        remaining_retries = max_retries - retries_taken
         if remaining_retries == 1:
             log.debug("1 retry left")
         else:
             log.debug("%i retries left", remaining_retries)
 
-        timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+        timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
         log.info("Retrying request to %s in %f seconds", options.url, timeout)
 
-        # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
-        # different thread if necessary.
         time.sleep(timeout)
 
-        return self._request(
-            options=options,
-            cast_to=cast_to,
-            retries_taken=retries_taken + 1,
-            stream=stream,
-            stream_cls=stream_cls,
-        )
-
     def _process_response(
         self,
         *,
@@ -1453,7 +1416,6 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         options: FinalRequestOptions,
         *,
         stream: Literal[False] = False,
-        remaining_retries: Optional[int] = None,
     ) -> ResponseT: ...
 
     @overload
@@ -1464,7 +1426,6 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         *,
         stream: Literal[True],
         stream_cls: type[_AsyncStreamT],
-        remaining_retries: Optional[int] = None,
     ) -> _AsyncStreamT: ...
 
     @overload
@@ -1475,7 +1436,6 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         *,
         stream: bool,
         stream_cls: type[_AsyncStreamT] | None = None,
-        remaining_retries: Optional[int] = None,
     ) -> ResponseT | _AsyncStreamT: ...
 
     async def request(
@@ -1485,120 +1445,112 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         *,
         stream: bool = False,
         stream_cls: type[_AsyncStreamT] | None = None,
-        remaining_retries: Optional[int] = None,
-    ) -> ResponseT | _AsyncStreamT:
-        if remaining_retries is not None:
-            retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
-        else:
-            retries_taken = 0
-
-        return await self._request(
-            cast_to=cast_to,
-            options=options,
-            stream=stream,
-            stream_cls=stream_cls,
-            retries_taken=retries_taken,
-        )
-
-    async def _request(
-        self,
-        cast_to: Type[ResponseT],
-        options: FinalRequestOptions,
-        *,
-        stream: bool,
-        stream_cls: type[_AsyncStreamT] | None,
-        retries_taken: int,
     ) -> ResponseT | _AsyncStreamT:
         if self._platform is None:
             # `get_platform` can make blocking IO calls so we
             # execute it earlier while we are in an async context
             self._platform = await asyncify(get_platform)()
 
+        cast_to = self._maybe_override_cast_to(cast_to, options)
+
         # create a copy of the options we were given so that if the
         # options are mutated later & we then retry, the retries are
         # given the original options
         input_options = model_copy(options)
-
-        cast_to = self._maybe_override_cast_to(cast_to, options)
-        options = await self._prepare_options(options)
-
-        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
-        request = self._build_request(options, retries_taken=retries_taken)
-        await self._prepare_request(request)
-
-        if options.idempotency_key:
+        if input_options.idempotency_key is None and input_options.method.lower() != "get":
             # ensure the idempotency key is reused between requests
-            input_options.idempotency_key = options.idempotency_key
+            input_options.idempotency_key = self._idempotency_key()
 
-        kwargs: HttpxSendArgs = {}
-        if self.custom_auth is not None:
-            kwargs["auth"] = self.custom_auth
+        response: httpx.Response | None = None
+        max_retries = input_options.get_max_retries(self.max_retries)
 
-        try:
-            response = await self._client.send(
-                request,
-                stream=stream or self._should_stream_response_body(request=request),
-                **kwargs,
-            )
-        except httpx.TimeoutException as err:
-            log.debug("Encountered httpx.TimeoutException", exc_info=True)
+        retries_taken = 0
+        for retries_taken in range(max_retries + 1):
+            options = model_copy(input_options)
+            options = await self._prepare_options(options)
 
-            if remaining_retries > 0:
-                return await self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                    response_headers=None,
-                )
+            remaining_retries = max_retries - retries_taken
+            request = self._build_request(options, retries_taken=retries_taken)
+            await self._prepare_request(request)
 
-            log.debug("Raising timeout error")
-            raise APITimeoutError(request=request) from err
-        except Exception as err:
-            log.debug("Encountered Exception", exc_info=True)
+            kwargs: HttpxSendArgs = {}
+            if self.custom_auth is not None:
+                kwargs["auth"] = self.custom_auth
 
-            if remaining_retries > 0:
-                return await self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                    response_headers=None,
-                )
+            log.debug("Sending HTTP Request: %s %s", request.method, request.url)
 
-            log.debug("Raising connection error")
-            raise APIConnectionError(request=request) from err
+            response = None
+            try:
+                response = await self._client.send(
+                    request,
+                    stream=stream or self._should_stream_response_body(request=request),
+                    **kwargs,
+                )
+            except httpx.TimeoutException as err:
+                log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+                if remaining_retries > 0:
+                    await self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=None,
+                    )
+                    continue
+
+                log.debug("Raising timeout error")
+                raise APITimeoutError(request=request) from err
+            except Exception as err:
+                log.debug("Encountered Exception", exc_info=True)
+
+                if remaining_retries > 0:
+                    await self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=None,
+                    )
+                    continue
+
+                log.debug("Raising connection error")
+                raise APIConnectionError(request=request) from err
+
+            log.debug(
+                'HTTP Response: %s %s "%i %s" %s',
+                request.method,
+                request.url,
+                response.status_code,
+                response.reason_phrase,
+                response.headers,
+            )
+            log.debug("request_id: %s", response.headers.get("x-request-id"))
 
-        log.debug(
-            'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
-        )
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
+                log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+                if remaining_retries > 0 and self._should_retry(err.response):
+                    await err.response.aclose()
+                    await self._sleep_for_retry(
+                        retries_taken=retries_taken,
+                        max_retries=max_retries,
+                        options=input_options,
+                        response=response,
+                    )
+                    continue
 
-        try:
-            response.raise_for_status()
-        except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
-            log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
-            if remaining_retries > 0 and self._should_retry(err.response):
-                await err.response.aclose()
-                return await self._retry_request(
-                    input_options,
-                    cast_to,
-                    retries_taken=retries_taken,
-                    response_headers=err.response.headers,
-                    stream=stream,
-                    stream_cls=stream_cls,
-                )
+                # If the response is streamed then we need to explicitly read the response
+                # to completion before attempting to access the response text.
+                if not err.response.is_closed:
+                    await err.response.aread()
 
-            # If the response is streamed then we need to explicitly read the response
-            # to completion before attempting to access the response text.
-            if not err.response.is_closed:
-                await err.response.aread()
+                log.debug("Re-raising status error")
+                raise self._make_status_error_from_response(err.response) from None
 
-            log.debug("Re-raising status error")
-            raise self._make_status_error_from_response(err.response) from None
+            break
 
+        assert response is not None, "could not resolve response (should never happen)"
         return await self._process_response(
             cast_to=cast_to,
             options=options,
@@ -1608,35 +1560,20 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
             retries_taken=retries_taken,
         )
 
-    async def _retry_request(
-        self,
-        options: FinalRequestOptions,
-        cast_to: Type[ResponseT],
-        *,
-        retries_taken: int,
-        response_headers: httpx.Headers | None,
-        stream: bool,
-        stream_cls: type[_AsyncStreamT] | None,
-    ) -> ResponseT | _AsyncStreamT:
-        remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+    async def _sleep_for_retry(
+        self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+    ) -> None:
+        remaining_retries = max_retries - retries_taken
         if remaining_retries == 1:
             log.debug("1 retry left")
         else:
             log.debug("%i retries left", remaining_retries)
 
-        timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+        timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
         log.info("Retrying request to %s in %f seconds", options.url, timeout)
 
         await anyio.sleep(timeout)
 
-        return await self._request(
-            options=options,
-            cast_to=cast_to,
-            retries_taken=retries_taken + 1,
-            stream=stream,
-            stream_cls=stream_cls,
-        )
-
     async def _process_response(
         self,
         *,