Commit 25d16be1
Changed files (5)
src
openai
tests
src/openai/lib/azure.py
@@ -94,7 +94,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
@@ -114,7 +114,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
*,
azure_deployment: str | None = None,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
@@ -134,7 +134,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
*,
base_url: str,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
@@ -154,7 +154,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
api_version: str | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
@@ -258,7 +258,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
def copy(
self,
*,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -345,7 +345,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
- if self.api_key != "<missing API key>":
+ if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
@@ -372,7 +372,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
@@ -393,7 +393,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
*,
azure_deployment: str | None = None,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
@@ -414,7 +414,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
*,
base_url: str,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
@@ -435,7 +435,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
@@ -539,7 +539,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
def copy(
self,
*,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -628,7 +628,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
- if self.api_key != "<missing API key>":
+ if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
src/openai/resources/beta/realtime/realtime.py
@@ -358,6 +358,7 @@ class AsyncRealtimeConnectionManager:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
extra_query = self.__extra_query
+ await self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -540,6 +541,7 @@ class RealtimeConnectionManager:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
extra_query = self.__extra_query
+ self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
src/openai/resources/realtime/realtime.py
@@ -326,6 +326,7 @@ class AsyncRealtimeConnectionManager:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
extra_query = self.__extra_query
+ await self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -507,6 +508,7 @@ class RealtimeConnectionManager:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
extra_query = self.__extra_query
+ self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
src/openai/_client.py
@@ -3,7 +3,7 @@
from __future__ import annotations
import os
-from typing import TYPE_CHECKING, Any, Union, Mapping
+from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
from typing_extensions import Self, override
import httpx
@@ -25,6 +25,7 @@ from ._utils import (
get_async_library,
)
from ._compat import cached_property
+from ._models import FinalRequestOptions
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import OpenAIError, APIStatusError
@@ -96,7 +97,7 @@ class OpenAI(SyncAPIClient):
def __init__(
self,
*,
- api_key: str | None = None,
+ api_key: str | None | Callable[[], str] = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -134,7 +135,12 @@ class OpenAI(SyncAPIClient):
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
- self.api_key = api_key
+ if callable(api_key):
+ self.api_key = ""
+ self._api_key_provider: Callable[[], str] | None = api_key
+ else:
+ self.api_key = api_key
+ self._api_key_provider = None
if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
@@ -295,6 +301,15 @@ class OpenAI(SyncAPIClient):
def qs(self) -> Querystring:
return Querystring(array_format="brackets")
+ def _refresh_api_key(self) -> None:
+ if self._api_key_provider:
+ self.api_key = self._api_key_provider()
+
+ @override
+ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
+ self._refresh_api_key()
+ return super()._prepare_options(options)
+
@property
@override
def auth_headers(self) -> dict[str, str]:
@@ -318,7 +333,7 @@ class OpenAI(SyncAPIClient):
def copy(
self,
*,
- api_key: str | None = None,
+ api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -356,7 +371,7 @@ class OpenAI(SyncAPIClient):
http_client = http_client or self._client
return self.__class__(
- api_key=api_key or self.api_key,
+ api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
@@ -427,7 +442,7 @@ class AsyncOpenAI(AsyncAPIClient):
def __init__(
self,
*,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -465,7 +480,12 @@ class AsyncOpenAI(AsyncAPIClient):
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
- self.api_key = api_key
+ if callable(api_key):
+ self.api_key = ""
+ self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
+ else:
+ self.api_key = api_key
+ self._api_key_provider = None
if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
@@ -626,6 +646,15 @@ class AsyncOpenAI(AsyncAPIClient):
def qs(self) -> Querystring:
return Querystring(array_format="brackets")
+ async def _refresh_api_key(self) -> None:
+ if self._api_key_provider:
+ self.api_key = await self._api_key_provider()
+
+ @override
+ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
+ await self._refresh_api_key()
+ return await super()._prepare_options(options)
+
@property
@override
def auth_headers(self) -> dict[str, str]:
@@ -649,7 +678,7 @@ class AsyncOpenAI(AsyncAPIClient):
def copy(
self,
*,
- api_key: str | None = None,
+ api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
@@ -687,7 +716,7 @@ class AsyncOpenAI(AsyncAPIClient):
http_client = http_client or self._client
return self.__class__(
- api_key=api_key or self.api_key,
+ api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
tests/test_client.py
@@ -11,7 +11,7 @@ import asyncio
import inspect
import subprocess
import tracemalloc
-from typing import Any, Union, cast
+from typing import Any, Union, Protocol, cast
from textwrap import dedent
from unittest import mock
from typing_extensions import Literal
@@ -41,6 +41,10 @@ base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"
+class MockRequestCall(Protocol):
+ request: httpx.Request
+
+
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
url = httpx.URL(request.url)
@@ -337,7 +341,9 @@ class TestOpenAI:
def test_validate_headers(self) -> None:
client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
+ request = client._build_request(options)
+
assert request.headers.get("Authorization") == f"Bearer {api_key}"
with pytest.raises(OpenAIError):
@@ -939,6 +945,62 @@ class TestOpenAI:
assert exc_info.value.response.status_code == 302
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+ def test_api_key_before_after_refresh_provider(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
+
+ assert client.api_key == ""
+ assert "Authorization" not in client.auth_headers
+
+ client._refresh_api_key()
+
+ assert client.api_key == "test_bearer_token"
+ assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
+
+ def test_api_key_before_after_refresh_str(self) -> None:
+ client = OpenAI(base_url=base_url, api_key="test_api_key")
+
+ assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+ client._refresh_api_key()
+
+ assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+
+ @pytest.mark.respx()
+ def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
+ respx_mock.post(base_url + "/chat/completions").mock(
+ side_effect=[
+ httpx.Response(500, json={"error": "server error"}),
+ httpx.Response(200, json={"foo": "bar"}),
+ ]
+ )
+
+ counter = 0
+
+ def token_provider() -> str:
+ nonlocal counter
+
+ counter += 1
+
+ if counter == 1:
+ return "first"
+
+ return "second"
+
+ client = OpenAI(base_url=base_url, api_key=token_provider)
+ client.chat.completions.create(messages=[], model="gpt-4")
+
+ calls = cast("list[MockRequestCall]", respx_mock.calls)
+ assert len(calls) == 2
+
+ assert calls[0].request.headers.get("Authorization") == "Bearer first"
+ assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+ def test_copy_auth(self) -> None:
+ client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
+ api_key=lambda: "test_bearer_token_2"
+ )
+ client._refresh_api_key()
+ assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
+
class TestAsyncOpenAI:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1220,9 +1282,10 @@ class TestAsyncOpenAI:
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
- def test_validate_headers(self) -> None:
+ async def test_validate_headers(self) -> None:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
+ request = client._build_request(options)
assert request.headers.get("Authorization") == f"Bearer {api_key}"
with pytest.raises(OpenAIError):
@@ -1887,3 +1950,70 @@ class TestAsyncOpenAI:
assert exc_info.value.response.status_code == 302
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+
+ @pytest.mark.asyncio
+ async def test_api_key_before_after_refresh_provider(self) -> None:
+ async def mock_api_key_provider():
+ return "test_bearer_token"
+
+ client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
+
+ assert client.api_key == ""
+ assert "Authorization" not in client.auth_headers
+
+ await client._refresh_api_key()
+
+ assert client.api_key == "test_bearer_token"
+ assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
+
+ @pytest.mark.asyncio
+ async def test_api_key_before_after_refresh_str(self) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
+
+ assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+ await client._refresh_api_key()
+
+ assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
+
+ @pytest.mark.asyncio
+ @pytest.mark.respx()
+ async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
+ respx_mock.post(base_url + "/chat/completions").mock(
+ side_effect=[
+ httpx.Response(500, json={"error": "server error"}),
+ httpx.Response(200, json={"foo": "bar"}),
+ ]
+ )
+
+ counter = 0
+
+ async def token_provider() -> str:
+ nonlocal counter
+
+ counter += 1
+
+ if counter == 1:
+ return "first"
+
+ return "second"
+
+ client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
+ await client.chat.completions.create(messages=[], model="gpt-4")
+
+ calls = cast("list[MockRequestCall]", respx_mock.calls)
+ assert len(calls) == 2
+
+ assert calls[0].request.headers.get("Authorization") == "Bearer first"
+ assert calls[1].request.headers.get("Authorization") == "Bearer second"
+
+ @pytest.mark.asyncio
+ async def test_copy_auth(self) -> None:
+ async def token_provider_1() -> str:
+ return "test_bearer_token_1"
+
+ async def token_provider_2() -> str:
+ return "test_bearer_token_2"
+
+ client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
+ await client._refresh_api_key()
+ assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}