Commit 6b01ba6a

Robert Craigie <robert@craigie.dev>
2024-07-10 02:03:49
fix(azure): refresh auth token during retries (#1533)
1 parent 41f682b
Changed files (1)
tests
tests/lib/test_azure.py
@@ -1,7 +1,9 @@
-from typing import Union
-from typing_extensions import Literal
+from typing import Union, cast
+from typing_extensions import Literal, Protocol
 
+import httpx
 import pytest
+from respx import MockRouter
 
 from openai._models import FinalRequestOptions
 from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
@@ -22,6 +24,10 @@ async_client = AsyncAzureOpenAI(
 )
 
 
+class MockRequestCall(Protocol):
+    request: httpx.Request
+
+
 @pytest.mark.parametrize("client", [sync_client, async_client])
 def test_implicit_deployment_path(client: Client) -> None:
     req = client._build_request(
@@ -64,3 +70,81 @@ def test_client_copying_override_options(client: Client) -> None:
         api_version="2022-05-01",
     )
     assert copied._custom_query == {"api-version": "2022-05-01"}
+
+
+@pytest.mark.respx()
+def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
+    respx_mock.post(
+        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
+    ).mock(
+        side_effect=[
+            httpx.Response(500, json={"error": "server error"}),
+            httpx.Response(200, json={"foo": "bar"}),
+        ]
+    )
+
+    counter = 0
+
+    def token_provider() -> str:
+        nonlocal counter
+
+        counter += 1
+
+        if counter == 1:
+            return "first"
+
+        return "second"
+
+    client = AzureOpenAI(
+        api_version="2024-02-01",
+        azure_ad_token_provider=token_provider,
+        azure_endpoint="https://example-resource.azure.openai.com",
+    )
+    client.chat.completions.create(messages=[], model="gpt-4")
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+    assert calls[0].request.headers.get("Authorization") == "Bearer first"
+    assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx()
+async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
+    respx_mock.post(
+        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
+    ).mock(
+        side_effect=[
+            httpx.Response(500, json={"error": "server error"}),
+            httpx.Response(200, json={"foo": "bar"}),
+        ]
+    )
+
+    counter = 0
+
+    def token_provider() -> str:
+        nonlocal counter
+
+        counter += 1
+
+        if counter == 1:
+            return "first"
+
+        return "second"
+
+    client = AsyncAzureOpenAI(
+        api_version="2024-02-01",
+        azure_ad_token_provider=token_provider,
+        azure_endpoint="https://example-resource.azure.openai.com",
+    )
+
+    await client.chat.completions.create(messages=[], model="gpt-4")
+
+    calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+    assert len(calls) == 2
+
+    assert calls[0].request.headers.get("Authorization") == "Bearer first"
+    assert calls[1].request.headers.get("Authorization") == "Bearer second"