Commit 6634f525

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2024-10-21 15:23:51
chore(tests): add more retry tests (#1806)
1 parent a366120
Changed files (1)
tests/test_client.py
@@ -10,6 +10,7 @@ import inspect
 import tracemalloc
 from typing import Any, Union, cast
 from unittest import mock
+from typing_extensions import Literal
 
 import httpx
 import pytest
@@ -764,7 +765,14 @@ class TestOpenAI:
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
+    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
+    def test_retries_taken(
+        self,
+        client: OpenAI,
+        failures_before_success: int,
+        failure_mode: Literal["status", "exception"],
+        respx_mock: MockRouter,
+    ) -> None:
         client = client.with_options(max_retries=4)
 
         nb_retries = 0
@@ -773,6 +781,8 @@ class TestOpenAI:
             nonlocal nb_retries
             if nb_retries < failures_before_success:
                 nb_retries += 1
+                if failure_mode == "exception":
+                    raise RuntimeError("oops")
                 return httpx.Response(500)
             return httpx.Response(200)
 
@@ -1623,8 +1633,13 @@ class TestAsyncOpenAI:
     @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
     async def test_retries_taken(
-        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
+        self,
+        async_client: AsyncOpenAI,
+        failures_before_success: int,
+        failure_mode: Literal["status", "exception"],
+        respx_mock: MockRouter,
     ) -> None:
         client = async_client.with_options(max_retries=4)
 
@@ -1634,6 +1649,8 @@ class TestAsyncOpenAI:
             nonlocal nb_retries
             if nb_retries < failures_before_success:
                 nb_retries += 1
+                if failure_mode == "exception":
+                    raise RuntimeError("oops")
                 return httpx.Response(500)
             return httpx.Response(200)