Commit 3f6ef060

Krista Pratico <krpratic@microsoft.com>
2023-09-22 05:49:55
[azure] enable audio/whisper support (#613)
* enable azure for audio * reorder overloads * add additional tests * add helper function to utils * simplify - azure users will just need to pass model and deployment_id
1 parent e389823
Changed files (1)
openai
api_resources
openai/api_resources/audio.py
@@ -9,7 +9,9 @@ class Audio(APIResource):
     OBJECT_NAME = "audio"
 
     @classmethod
-    def _get_url(cls, action):
+    def _get_url(cls, action, deployment_id=None, api_type=None, api_version=None):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            return f"/{cls.azure_api_prefix}/deployments/{deployment_id}/audio/{action}?api-version={api_version}"
         return cls.class_url() + f"/{action}"
 
     @classmethod
@@ -50,6 +52,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -63,7 +67,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("transcriptions")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = requestor.request("post", url, files=files, params=data)
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
@@ -79,6 +84,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -92,7 +99,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("translations")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = requestor.request("post", url, files=files, params=data)
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
@@ -109,6 +117,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -122,7 +132,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("transcriptions")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = requestor.request("post", url, files=files, params=data)
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
@@ -139,6 +150,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -152,7 +165,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("translations")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = requestor.request("post", url, files=files, params=data)
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
@@ -168,6 +182,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -181,7 +197,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("transcriptions")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = await requestor.arequest(
             "post", url, files=files, params=data
         )
@@ -199,6 +216,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -212,7 +231,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("translations")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = await requestor.arequest(
             "post", url, files=files, params=data
         )
@@ -231,6 +251,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -244,7 +266,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("transcriptions")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = await requestor.arequest(
             "post", url, files=files, params=data
         )
@@ -263,6 +286,8 @@ class Audio(APIResource):
         api_type=None,
         api_version=None,
         organization=None,
+        *,
+        deployment_id=None,
         **params,
     ):
         requestor, files, data = cls._prepare_request(
@@ -276,7 +301,8 @@ class Audio(APIResource):
             organization=organization,
             **params,
         )
-        url = cls._get_url("translations")
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
+        url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
         response, _, api_key = await requestor.arequest(
             "post", url, files=files, params=data
         )