Commit 98d8b2ac

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2024-08-05 21:20:50
feat(client): add `retries_taken` to raw response class (#1601)
Parent: 5c7030e
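
Summary of the change, as read from the diff below: a `retries_taken` count
is computed in the request loops of `SyncAPIClient` and `AsyncAPIClient`,
threaded through `_process_response`, and exposed on both response classes:
`LegacyAPIResponse` (returned by `.with_raw_response`) and `BaseAPIResponse`
(behind `.with_streaming_response`). Parametrized tests assert the count
matches the number of failed attempts before a success.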
src/openai/_base_client.py
@@ -1051,6 +1051,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
             response=response,
             stream=stream,
             stream_cls=stream_cls,
+            retries_taken=options.get_max_retries(self.max_retries) - retries,
         )
 
     def _retry_request(
@@ -1092,6 +1093,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         response: httpx.Response,
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+        retries_taken: int = 0,
     ) -> ResponseT:
         if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
             return cast(
@@ -1103,6 +1105,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1122,6 +1125,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1135,6 +1139,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
             stream=stream,
             stream_cls=stream_cls,
             options=options,
+            retries_taken=retries_taken,
         )
         if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
             return cast(ResponseT, api_response)
@@ -1625,6 +1630,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
             response=response,
             stream=stream,
             stream_cls=stream_cls,
+            retries_taken=options.get_max_retries(self.max_retries) - retries,
         )
 
     async def _retry_request(
@@ -1664,6 +1670,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         response: httpx.Response,
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
+        retries_taken: int = 0,
     ) -> ResponseT:
         if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
             return cast(
@@ -1675,6 +1682,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1694,6 +1702,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
                     stream=stream,
                     stream_cls=stream_cls,
                     options=options,
+                    retries_taken=retries_taken,
                 ),
             )
 
@@ -1707,6 +1716,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
             stream=stream,
             stream_cls=stream_cls,
             options=options,
+            retries_taken=retries_taken,
         )
         if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
             return cast(ResponseT, api_response)
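
The count is derived from the retry countdown at the call site:
`options.get_max_retries(self.max_retries)` resolves the configured ceiling,
and `retries` is the budget still remaining when a request finally gets
through, so the difference is the number of retries actually made. A minimal
sketch of that arithmetic, with illustrative values that are assumptions
rather than anything from the commit:

    max_retries = 4   # ceiling, as resolved by options.get_max_retries(...)
    retries = 2       # hypothetical budget left at the successful attempt
    retries_taken = max_retries - retries
    assert retries_taken == 2  # two attempts failed before the success
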
src/openai/_legacy_response.py
@@ -5,7 +5,18 @@ import inspect
 import logging
 import datetime
 import functools
-from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Union,
+    Generic,
+    TypeVar,
+    Callable,
+    Iterator,
+    AsyncIterator,
+    cast,
+    overload,
+)
 from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
 
 import anyio
@@ -53,6 +64,9 @@ class LegacyAPIResponse(Generic[R]):
 
     http_response: httpx.Response
 
+    retries_taken: int
+    """The number of retries made. If no retries happened this will be `0`"""
+
     def __init__(
         self,
         *,
@@ -62,6 +76,7 @@ class LegacyAPIResponse(Generic[R]):
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
         options: FinalRequestOptions,
+        retries_taken: int = 0,
     ) -> None:
         self._cast_to = cast_to
         self._client = client
@@ -70,6 +85,7 @@ class LegacyAPIResponse(Generic[R]):
         self._stream_cls = stream_cls
         self._options = options
         self.http_response = raw
+        self.retries_taken = retries_taken
 
     @property
     def request_id(self) -> str | None:
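
`LegacyAPIResponse` is the object `.with_raw_response` returns, so the new
attribute can be read straight off a raw response. A minimal usage sketch,
assuming a configured client (API key from the environment) and mirroring
the call shape used in the tests below:

    from openai import OpenAI

    client = OpenAI(max_retries=4)
    response = client.chat.completions.with_raw_response.create(
        messages=[{"role": "system", "content": "content"}],
        model="gpt-4-turbo",
    )
    # 0 when the first attempt succeeds, otherwise the number of retries.
    print(response.retries_taken)
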
src/openai/_response.py
@@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]):
 
     http_response: httpx.Response
 
+    retries_taken: int
+    """The number of retries made. If no retries happened this will be `0`"""
+
     def __init__(
         self,
         *,
@@ -64,6 +67,7 @@ class BaseAPIResponse(Generic[R]):
         stream: bool,
         stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
         options: FinalRequestOptions,
+        retries_taken: int = 0,
     ) -> None:
         self._cast_to = cast_to
         self._client = client
@@ -72,6 +76,7 @@ class BaseAPIResponse(Generic[R]):
         self._stream_cls = stream_cls
         self._options = options
         self.http_response = raw
+        self.retries_taken = retries_taken
 
     @property
     def headers(self) -> httpx.Headers:
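
The same field lands on `BaseAPIResponse`, the base of the newer response
class that `.with_streaming_response` yields, so streaming callers get the
count too. A sketch under the same assumptions as above:

    from openai import OpenAI

    client = OpenAI(max_retries=4)
    with client.chat.completions.with_streaming_response.create(
        messages=[{"role": "system", "content": "content"}],
        model="gpt-4-turbo",
    ) as response:
        print(response.retries_taken)  # same semantics as the raw response
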
tests/test_client.py
@@ -758,6 +758,65 @@ class TestOpenAI:
 
         assert _get_open_connections(self.client) == 0
 
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
+        client = client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4-turbo",
+        )
+
+        assert response.retries_taken == failures_before_success
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_retries_taken_new_response_class(
+        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        with client.chat.completions.with_streaming_response.create(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4-turbo",
+        ) as response:
+            assert response.retries_taken == failures_before_success
+
 
 class TestAsyncOpenAI:
     client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1488,3 +1547,66 @@ class TestAsyncOpenAI:
             )
 
         assert _get_open_connections(self.client) == 0
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_retries_taken(
+        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = async_client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = await client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4-turbo",
+        )
+
+        assert response.retries_taken == failures_before_success
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_retries_taken_new_response_class(
+        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = async_client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        async with client.chat.completions.with_streaming_response.create(
+            messages=[
+                {
+                    "content": "content",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4-turbo",
+        ) as response:
+            assert response.retries_taken == failures_before_success
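
In all four tests the mock handler serves a 500 for the first
`failures_before_success` requests and a 200 afterwards; each 500 consumes
exactly one retry (the client is given `max_retries=4`, enough to cover the
worst case of 4 failures), so `retries_taken` must equal
`failures_before_success`. A hedged restatement of that invariant:

    # Each failure before the success triggers exactly one retry.
    for failures_before_success in (0, 2, 4):
        attempts = failures_before_success + 1  # failures plus the success
        retries_taken = attempts - 1
        assert retries_taken == failures_before_success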