Commit 5481c2ec

Krista Pratico <krpratic@microsoft.com>
2024-12-20 01:28:52
feat(azure): support for the Realtime API (#1963)
1 parent cacd374
Changed files (4)
src
openai
_utils
lib
resources
beta
realtime
src/openai/_utils/__init__.py
@@ -25,6 +25,7 @@ from ._utils import (
     coerce_integer as coerce_integer,
     file_from_path as file_from_path,
     parse_datetime as parse_datetime,
+    is_azure_client as is_azure_client,
     strip_not_given as strip_not_given,
     deepcopy_minimal as deepcopy_minimal,
     get_async_library as get_async_library,
@@ -32,6 +33,7 @@ from ._utils import (
     get_required_header as get_required_header,
     maybe_coerce_boolean as maybe_coerce_boolean,
     maybe_coerce_integer as maybe_coerce_integer,
+    is_async_azure_client as is_async_azure_client,
 )
 from ._typing import (
     is_list_type as is_list_type,
src/openai/_utils/_utils.py
@@ -5,6 +5,7 @@ import re
 import inspect
 import functools
 from typing import (
+    TYPE_CHECKING,
     Any,
     Tuple,
     Mapping,
@@ -30,6 +31,9 @@ _MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
 _SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
 CallableT = TypeVar("CallableT", bound=Callable[..., Any])
 
+if TYPE_CHECKING:
+    from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
+
 
 def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
     return [item for sublist in t for item in sublist]
@@ -412,3 +416,15 @@ def json_safe(data: object) -> object:
         return data.isoformat()
 
     return data
+
+
+def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
+    from ..lib.azure import AzureOpenAI
+
+    return isinstance(client, AzureOpenAI)
+
+
+def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
+    from ..lib.azure import AsyncAzureOpenAI
+
+    return isinstance(client, AsyncAzureOpenAI)
src/openai/lib/azure.py
@@ -7,7 +7,7 @@ from typing_extensions import Self, override
 
 import httpx
 
-from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
+from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
 from .._utils import is_given, is_mapping
 from .._client import OpenAI, AsyncOpenAI
 from .._compat import model_copy
@@ -307,6 +307,21 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
 
         return options
 
+    def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
+        auth_headers = {}
+        query = {
+            **extra_query,
+            "api-version": self._api_version,
+            "deployment": model,
+        }
+        if self.api_key != "<missing API key>":
+            auth_headers = {"api-key": self.api_key}
+        else:
+            token = self._get_azure_ad_token()
+            if token:
+                auth_headers = {"Authorization": f"Bearer {token}"}
+        return query, auth_headers
+
 
 class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
     @overload
@@ -555,3 +570,18 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
             raise ValueError("Unable to handle auth")
 
         return options
+
+    async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
+        auth_headers = {}
+        query = {
+            **extra_query,
+            "api-version": self._api_version,
+            "deployment": model,
+        }
+        if self.api_key != "<missing API key>":
+            auth_headers = {"api-key": self.api_key}
+        else:
+            token = await self._get_azure_ad_token()
+            if token:
+                auth_headers = {"Authorization": f"Bearer {token}"}
+        return query, auth_headers
src/openai/resources/beta/realtime/realtime.py
@@ -21,9 +21,11 @@ from .sessions import (
 )
 from ...._types import NOT_GIVEN, Query, Headers, NotGiven
 from ...._utils import (
+    is_azure_client,
     maybe_transform,
     strip_not_given,
     async_maybe_transform,
+    is_async_azure_client,
 )
 from ...._compat import cached_property
 from ...._models import construct_type_unchecked
@@ -319,11 +321,16 @@ class AsyncRealtimeConnectionManager:
         except ImportError as exc:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
+        extra_query = self.__extra_query
+        auth_headers = self.__client.auth_headers
+        if is_async_azure_client(self.__client):
+            extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
+
         url = self._prepare_url().copy_with(
             params={
                 **self.__client.base_url.params,
                 "model": self.__model,
-                **self.__extra_query,
+                **extra_query,
             },
         )
         log.debug("Connecting to %s", url)
@@ -336,7 +343,7 @@ class AsyncRealtimeConnectionManager:
                 user_agent_header=self.__client.user_agent,
                 additional_headers=_merge_mappings(
                     {
-                        **self.__client.auth_headers,
+                        **auth_headers,
                         "OpenAI-Beta": "realtime=v1",
                     },
                     self.__extra_headers,
@@ -496,11 +503,16 @@ class RealtimeConnectionManager:
         except ImportError as exc:
             raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
 
+        extra_query = self.__extra_query
+        auth_headers = self.__client.auth_headers
+        if is_azure_client(self.__client):
+            extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
+
         url = self._prepare_url().copy_with(
             params={
                 **self.__client.base_url.params,
                 "model": self.__model,
-                **self.__extra_query,
+                **extra_query,
             },
         )
         log.debug("Connecting to %s", url)
@@ -513,7 +525,7 @@ class RealtimeConnectionManager:
                 user_agent_header=self.__client.user_agent,
                 additional_headers=_merge_mappings(
                     {
-                        **self.__client.auth_headers,
+                        **auth_headers,
                         "OpenAI-Beta": "realtime=v1",
                     },
                     self.__extra_headers,