Commit 25d16be1

Johan Stenberg (MSFT) <johan.stenberg@microsoft.com>
2025-09-03 23:52:05
feat(client): support callable api_key (#2588)
Co-authored-by: Krista Pratico <krpratic@microsoft.com>
1 parent 8672413
Changed files (5)
src
openai
lib
resources
beta
realtime
realtime
tests
src/openai/lib/azure.py
@@ -94,7 +94,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         azure_endpoint: str,
         azure_deployment: str | None = None,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -114,7 +114,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         *,
         azure_deployment: str | None = None,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -134,7 +134,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         *,
         base_url: str,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -154,7 +154,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         api_version: str | None = None,
         azure_endpoint: str | None = None,
         azure_deployment: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -258,7 +258,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
     def copy(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -345,7 +345,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
             "api-version": self._api_version,
             "deployment": self._azure_deployment or model,
         }
-        if self.api_key != "<missing API key>":
+        if self.api_key and self.api_key != "<missing API key>":
             auth_headers = {"api-key": self.api_key}
         else:
             token = self._get_azure_ad_token()
@@ -372,7 +372,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         azure_endpoint: str,
         azure_deployment: str | None = None,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -393,7 +393,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         *,
         azure_deployment: str | None = None,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -414,7 +414,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         *,
         base_url: str,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -435,7 +435,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         azure_endpoint: str | None = None,
         azure_deployment: str | None = None,
         api_version: str | None = None,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         azure_ad_token: str | None = None,
         azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
         organization: str | None = None,
@@ -539,7 +539,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
     def copy(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -628,7 +628,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
             "api-version": self._api_version,
             "deployment": self._azure_deployment or model,
         }
-        if self.api_key != "<missing API key>":
+        if self.api_key and self.api_key != "<missing API key>":
             auth_headers = {"api-key": self.api_key}
         else:
             token = await self._get_azure_ad_token()
src/openai/resources/beta/realtime/realtime.py
@@ -358,6 +358,7 @@ class AsyncRealtimeConnectionManager:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
         extra_query = self.__extra_query
+        await self.__client._refresh_api_key()
         auth_headers = self.__client.auth_headers
         if is_async_azure_client(self.__client):
             url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -540,6 +541,7 @@ class RealtimeConnectionManager:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
         extra_query = self.__extra_query
+        self.__client._refresh_api_key()
         auth_headers = self.__client.auth_headers
         if is_azure_client(self.__client):
             url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
src/openai/resources/realtime/realtime.py
@@ -326,6 +326,7 @@ class AsyncRealtimeConnectionManager:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
         extra_query = self.__extra_query
+        await self.__client._refresh_api_key()
         auth_headers = self.__client.auth_headers
         if is_async_azure_client(self.__client):
             url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -507,6 +508,7 @@ class RealtimeConnectionManager:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
         extra_query = self.__extra_query
+        self.__client._refresh_api_key()
         auth_headers = self.__client.auth_headers
         if is_azure_client(self.__client):
             url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
src/openai/_client.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import TYPE_CHECKING, Any, Union, Mapping
+from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
 from typing_extensions import Self, override
 
 import httpx
@@ -25,6 +25,7 @@ from ._utils import (
     get_async_library,
 )
 from ._compat import cached_property
+from ._models import FinalRequestOptions
 from ._version import __version__
 from ._streaming import Stream as Stream, AsyncStream as AsyncStream
 from ._exceptions import OpenAIError, APIStatusError
@@ -96,7 +97,7 @@ class OpenAI(SyncAPIClient):
     def __init__(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | None | Callable[[], str] = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -134,7 +135,12 @@ class OpenAI(SyncAPIClient):
             raise OpenAIError(
                 "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
             )
-        self.api_key = api_key
+        if callable(api_key):
+            self.api_key = ""
+            self._api_key_provider: Callable[[], str] | None = api_key
+        else:
+            self.api_key = api_key
+            self._api_key_provider = None
 
         if organization is None:
             organization = os.environ.get("OPENAI_ORG_ID")
@@ -295,6 +301,15 @@ class OpenAI(SyncAPIClient):
     def qs(self) -> Querystring:
         return Querystring(array_format="brackets")
 
+    def _refresh_api_key(self) -> None:
+        if self._api_key_provider:
+            self.api_key = self._api_key_provider()
+
+    @override
+    def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
+        self._refresh_api_key()
+        return super()._prepare_options(options)
+
     @property
     @override
     def auth_headers(self) -> dict[str, str]:
@@ -318,7 +333,7 @@ class OpenAI(SyncAPIClient):
     def copy(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | Callable[[], str] | None = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -356,7 +371,7 @@ class OpenAI(SyncAPIClient):
 
         http_client = http_client or self._client
         return self.__class__(
-            api_key=api_key or self.api_key,
+            api_key=api_key or self._api_key_provider or self.api_key,
             organization=organization or self.organization,
             project=project or self.project,
             webhook_secret=webhook_secret or self.webhook_secret,
@@ -427,7 +442,7 @@ class AsyncOpenAI(AsyncAPIClient):
     def __init__(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -465,7 +480,12 @@ class AsyncOpenAI(AsyncAPIClient):
             raise OpenAIError(
                 "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
             )
-        self.api_key = api_key
+        if callable(api_key):
+            self.api_key = ""
+            self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
+        else:
+            self.api_key = api_key
+            self._api_key_provider = None
 
         if organization is None:
             organization = os.environ.get("OPENAI_ORG_ID")
@@ -626,6 +646,15 @@ class AsyncOpenAI(AsyncAPIClient):
     def qs(self) -> Querystring:
         return Querystring(array_format="brackets")
 
+    async def _refresh_api_key(self) -> None:
+        if self._api_key_provider:
+            self.api_key = await self._api_key_provider()
+
+    @override
+    async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
+        await self._refresh_api_key()
+        return await super()._prepare_options(options)
+
     @property
     @override
     def auth_headers(self) -> dict[str, str]:
@@ -649,7 +678,7 @@ class AsyncOpenAI(AsyncAPIClient):
     def copy(
         self,
         *,
-        api_key: str | None = None,
+        api_key: str | Callable[[], Awaitable[str]] | None = None,
         organization: str | None = None,
         project: str | None = None,
         webhook_secret: str | None = None,
@@ -687,7 +716,7 @@ class AsyncOpenAI(AsyncAPIClient):
 
         http_client = http_client or self._client
         return self.__class__(
-            api_key=api_key or self.api_key,
+            api_key=api_key or self._api_key_provider or self.api_key,
             organization=organization or self.organization,
             project=project or self.project,
             webhook_secret=webhook_secret or self.webhook_secret,
tests/test_client.py
@@ -11,7 +11,7 @@ import asyncio
 import inspect
 import subprocess
 import tracemalloc
-from typing import Any, Union, cast
+from typing import Any, Union, Protocol, cast
 from textwrap import dedent
 from unittest import mock
 from typing_extensions import Literal
@@ -41,6 +41,10 @@ base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 api_key = "My API Key"
 
 
+class MockRequestCall(Protocol):
+    request: httpx.Request
+
+
 def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
     request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
     url = httpx.URL(request.url)
@@ -337,7 +341,9 @@ class TestOpenAI:
 
     def test_validate_headers(self) -> None:
         client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+        options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
+        request = client._build_request(options)
+
         assert request.headers.get("Authorization") == f"Bearer {api_key}"
 
         with pytest.raises(OpenAIError):
@@ -939,6 +945,62 @@ class TestOpenAI:
         assert exc_info.value.response.status_code == 302
         assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
 
+    def test_api_key_before_after_refresh_provider(self) -> None:
+        client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
+
+        assert client.api_key == ""
+        assert "Authorization" not in client.auth_headers
+
+        client._refresh_api_key()
+
+        assert client.api_key == "test_bearer_token"
+        assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
+
+    def test_api_key_before_after_refresh_str(self) -> None:
+        client = OpenAI(base_url=base_url, api_key="test_api_key")
+
+        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+        client._refresh_api_key()
+
+        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+
+    @pytest.mark.respx()
+    def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
+        respx_mock.post(base_url + "/chat/completions").mock(
+            side_effect=[
+                httpx.Response(500, json={"error": "server error"}),
+                httpx.Response(200, json={"foo": "bar"}),
+            ]
+        )
+
+        counter = 0
+
+        def token_provider() -> str:
+            nonlocal counter
+
+            counter += 1
+
+            if counter == 1:
+                return "first"
+
+            return "second"
+
+        client = OpenAI(base_url=base_url, api_key=token_provider)
+        client.chat.completions.create(messages=[], model="gpt-4")
+
+        calls = cast("list[MockRequestCall]", respx_mock.calls)
+        assert len(calls) == 2
+
+        assert calls[0].request.headers.get("Authorization") == "Bearer first"
+        assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+    def test_copy_auth(self) -> None:
+        client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
+            api_key=lambda: "test_bearer_token_2"
+        )
+        client._refresh_api_key()
+        assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
+
 
 class TestAsyncOpenAI:
     client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1220,9 +1282,10 @@ class TestAsyncOpenAI:
         assert request.headers.get("x-foo") == "stainless"
         assert request.headers.get("x-stainless-lang") == "my-overriding-header"
 
-    def test_validate_headers(self) -> None:
+    async def test_validate_headers(self) -> None:
         client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+        options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
+        request = client._build_request(options)
         assert request.headers.get("Authorization") == f"Bearer {api_key}"
 
         with pytest.raises(OpenAIError):
@@ -1887,3 +1950,70 @@ class TestAsyncOpenAI:
 
         assert exc_info.value.response.status_code == 302
         assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+
+    @pytest.mark.asyncio
+    async def test_api_key_before_after_refresh_provider(self) -> None:
+        async def mock_api_key_provider():
+            return "test_bearer_token"
+
+        client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
+
+        assert client.api_key == ""
+        assert "Authorization" not in client.auth_headers
+
+        await client._refresh_api_key()
+
+        assert client.api_key == "test_bearer_token"
+        assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
+
+    @pytest.mark.asyncio
+    async def test_api_key_before_after_refresh_str(self) -> None:
+        client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
+
+        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+        await client._refresh_api_key()
+
+        assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+
+    @pytest.mark.asyncio
+    @pytest.mark.respx()
+    async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
+        respx_mock.post(base_url + "/chat/completions").mock(
+            side_effect=[
+                httpx.Response(500, json={"error": "server error"}),
+                httpx.Response(200, json={"foo": "bar"}),
+            ]
+        )
+
+        counter = 0
+
+        async def token_provider() -> str:
+            nonlocal counter
+
+            counter += 1
+
+            if counter == 1:
+                return "first"
+
+            return "second"
+
+        client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
+        await client.chat.completions.create(messages=[], model="gpt-4")
+
+        calls = cast("list[MockRequestCall]", respx_mock.calls)
+        assert len(calls) == 2
+
+        assert calls[0].request.headers.get("Authorization") == "Bearer first"
+        assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+    @pytest.mark.asyncio
+    async def test_copy_auth(self) -> None:
+        async def token_provider_1() -> str:
+            return "test_bearer_token_1"
+
+        async def token_provider_2() -> str:
+            return "test_bearer_token_2"
+
+        client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
+        await client._refresh_api_key()
+        assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}