Commit 5449e208

Stainless Bot <dev+git@stainlessapi.com>
2024-09-25 20:45:37
feat(client): allow overriding retry count header (#1745)
1 parent 102fce4
Changed files (2)
src/openai/_base_client.py
@@ -413,8 +413,10 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
             headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
 
-        if retries_taken > 0:
-            headers.setdefault("x-stainless-retry-count", str(retries_taken))
+        # Don't set the retry count header if it was already set or removed by the caller. We check
+        # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
+        if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
+            headers["x-stainless-retry-count"] = str(retries_taken)
 
         return headers
 
tests/test_client.py
@@ -788,10 +788,71 @@ class TestOpenAI:
         )
 
         assert response.retries_taken == failures_before_success
-        if failures_before_success == 0:
-            assert "x-stainless-retry-count" not in response.http_request.headers
-        else:
-            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_omit_retry_count_header(
+        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4o",
+            extra_headers={"x-stainless-retry-count": Omit()},
+        )
+
+        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    def test_overwrite_retry_count_header(
+        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4o",
+            extra_headers={"x-stainless-retry-count": "42"},
+        )
+
+        assert response.http_request.headers.get("x-stainless-retry-count") == "42"
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -822,10 +883,7 @@ class TestOpenAI:
             model="gpt-4o",
         ) as response:
             assert response.retries_taken == failures_before_success
-            if failures_before_success == 0:
-                assert "x-stainless-retry-count" not in response.http_request.headers
-            else:
-                assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
 
 
 class TestAsyncOpenAI:
@@ -1590,10 +1648,73 @@ class TestAsyncOpenAI:
         )
 
         assert response.retries_taken == failures_before_success
-        if failures_before_success == 0:
-            assert "x-stainless-retry-count" not in response.http_request.headers
-        else:
-            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_omit_retry_count_header(
+        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = async_client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = await client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4o",
+            extra_headers={"x-stainless-retry-count": Omit()},
+        )
+
+        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
+
+    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
+    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
+    @pytest.mark.respx(base_url=base_url)
+    @pytest.mark.asyncio
+    async def test_overwrite_retry_count_header(
+        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
+    ) -> None:
+        client = async_client.with_options(max_retries=4)
+
+        nb_retries = 0
+
+        def retry_handler(_request: httpx.Request) -> httpx.Response:
+            nonlocal nb_retries
+            if nb_retries < failures_before_success:
+                nb_retries += 1
+                return httpx.Response(500)
+            return httpx.Response(200)
+
+        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
+
+        response = await client.chat.completions.with_raw_response.create(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "system",
+                }
+            ],
+            model="gpt-4o",
+            extra_headers={"x-stainless-retry-count": "42"},
+        )
+
+        assert response.http_request.headers.get("x-stainless-retry-count") == "42"
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1625,7 +1746,4 @@ class TestAsyncOpenAI:
             model="gpt-4o",
         ) as response:
             assert response.retries_taken == failures_before_success
-            if failures_before_success == 0:
-                assert "x-stainless-retry-count" not in response.http_request.headers
-            else:
-                assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success