Commit 5481c2ec
Changed files (4)
src
openai
src/openai/_utils/__init__.py
@@ -25,6 +25,7 @@ from ._utils import (
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
+ is_azure_client as is_azure_client,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
@@ -32,6 +33,7 @@ from ._utils import (
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
+ is_async_azure_client as is_async_azure_client,
)
from ._typing import (
is_list_type as is_list_type,
src/openai/_utils/_utils.py
@@ -5,6 +5,7 @@ import re
import inspect
import functools
from typing import (
+ TYPE_CHECKING,
Any,
Tuple,
Mapping,
@@ -30,6 +31,9 @@ _MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
+if TYPE_CHECKING:
+ from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
+
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
@@ -412,3 +416,15 @@ def json_safe(data: object) -> object:
return data.isoformat()
return data
+
+
+def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
+ from ..lib.azure import AzureOpenAI
+
+ return isinstance(client, AzureOpenAI)
+
+
+def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
+ from ..lib.azure import AsyncAzureOpenAI
+
+ return isinstance(client, AsyncAzureOpenAI)
src/openai/lib/azure.py
@@ -7,7 +7,7 @@ from typing_extensions import Self, override
import httpx
-from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
+from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
@@ -307,6 +307,21 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
return options
+ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
+ auth_headers = {}
+ query = {
+ **extra_query,
+ "api-version": self._api_version,
+ "deployment": model,
+ }
+ if self.api_key != "<missing API key>":
+ auth_headers = {"api-key": self.api_key}
+ else:
+ token = self._get_azure_ad_token()
+ if token:
+ auth_headers = {"Authorization": f"Bearer {token}"}
+ return query, auth_headers
+
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@overload
@@ -555,3 +570,18 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
raise ValueError("Unable to handle auth")
return options
+
+ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
+ auth_headers = {}
+ query = {
+ **extra_query,
+ "api-version": self._api_version,
+ "deployment": model,
+ }
+ if self.api_key != "<missing API key>":
+ auth_headers = {"api-key": self.api_key}
+ else:
+ token = await self._get_azure_ad_token()
+ if token:
+ auth_headers = {"Authorization": f"Bearer {token}"}
+ return query, auth_headers
src/openai/resources/beta/realtime/realtime.py
@@ -21,9 +21,11 @@ from .sessions import (
)
from ...._types import NOT_GIVEN, Query, Headers, NotGiven
from ...._utils import (
+ is_azure_client,
maybe_transform,
strip_not_given,
async_maybe_transform,
+ is_async_azure_client,
)
from ...._compat import cached_property
from ...._models import construct_type_unchecked
@@ -319,11 +321,16 @@ class AsyncRealtimeConnectionManager:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
+ extra_query = self.__extra_query
+ auth_headers = self.__client.auth_headers
+ if is_async_azure_client(self.__client):
+ extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
+
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
- **self.__extra_query,
+ **extra_query,
},
)
log.debug("Connecting to %s", url)
@@ -336,7 +343,7 @@ class AsyncRealtimeConnectionManager:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
- **self.__client.auth_headers,
+ **auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,
@@ -496,11 +503,16 @@ class RealtimeConnectionManager:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
+ extra_query = self.__extra_query
+ auth_headers = self.__client.auth_headers
+ if is_azure_client(self.__client):
+ extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
+
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
- **self.__extra_query,
+ **extra_query,
},
)
log.debug("Connecting to %s", url)
@@ -513,7 +525,7 @@ class RealtimeConnectionManager:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
- **self.__client.auth_headers,
+ **auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,