Commit 6b01ba6a
Changed files (1)
tests
tests/lib/test_azure.py
@@ -1,7 +1,9 @@
-from typing import Union
-from typing_extensions import Literal
+from typing import Union, cast
+from typing_extensions import Literal, Protocol
+import httpx
import pytest
+from respx import MockRouter
from openai._models import FinalRequestOptions
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
@@ -22,6 +24,10 @@ async_client = AsyncAzureOpenAI(
)
+class MockRequestCall(Protocol):
+ request: httpx.Request
+
+
@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
req = client._build_request(
@@ -64,3 +70,81 @@ def test_client_copying_override_options(client: Client) -> None:
api_version="2022-05-01",
)
assert copied._custom_query == {"api-version": "2022-05-01"}
+
+
+@pytest.mark.respx()
+def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
+ respx_mock.post(
+ "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
+ ).mock(
+ side_effect=[
+ httpx.Response(500, json={"error": "server error"}),
+ httpx.Response(200, json={"foo": "bar"}),
+ ]
+ )
+
+ counter = 0
+
+ def token_provider() -> str:
+ nonlocal counter
+
+ counter += 1
+
+ if counter == 1:
+ return "first"
+
+ return "second"
+
+ client = AzureOpenAI(
+ api_version="2024-02-01",
+ azure_ad_token_provider=token_provider,
+ azure_endpoint="https://example-resource.azure.openai.com",
+ )
+ client.chat.completions.create(messages=[], model="gpt-4")
+
+ calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+ assert len(calls) == 2
+
+ assert calls[0].request.headers.get("Authorization") == "Bearer first"
+ assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx()
+async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
+ respx_mock.post(
+ "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
+ ).mock(
+ side_effect=[
+ httpx.Response(500, json={"error": "server error"}),
+ httpx.Response(200, json={"foo": "bar"}),
+ ]
+ )
+
+ counter = 0
+
+ def token_provider() -> str:
+ nonlocal counter
+
+ counter += 1
+
+ if counter == 1:
+ return "first"
+
+ return "second"
+
+ client = AsyncAzureOpenAI(
+ api_version="2024-02-01",
+ azure_ad_token_provider=token_provider,
+ azure_endpoint="https://example-resource.azure.openai.com",
+ )
+
+ await client.chat.completions.create(messages=[], model="gpt-4")
+
+ calls = cast("list[MockRequestCall]", respx_mock.calls)
+
+ assert len(calls) == 2
+
+ assert calls[0].request.headers.get("Authorization") == "Bearer first"
+ assert calls[1].request.headers.get("Authorization") == "Bearer second"