Commit 6e159bf0

Gerardo Lecaros <gerardo.lecaros.e@gmail.com>
2023-06-07 07:03:49
Support for Azure Dall-e (#439)
* This PR updates #337 with updates for it to work with the latest API preview --------- Co-authored-by: Christian Mürtz <t-cmurtz@microsoft.com>
1 parent 77e5cf9
Changed files (3)
openai/api_resources/image.py
@@ -2,7 +2,7 @@
 from typing import Any, List
 
 import openai
-from openai import api_requestor, util
+from openai import api_requestor, error, util
 from openai.api_resources.abstract import APIResource
 
 
@@ -10,8 +10,11 @@ class Image(APIResource):
     OBJECT_NAME = "images"
 
     @classmethod
-    def _get_url(cls, action):
-        return cls.class_url() + f"/{action}"
+    def _get_url(cls, action, azure_action, api_type, api_version):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None:
+            return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}"
+        else:
+            return f"{cls.class_url()}/{action}"
 
     @classmethod
     def create(
@@ -31,12 +34,20 @@ class Image(APIResource):
             organization=organization,
         )
 
-        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
 
         response, _, api_key = requestor.request(
-            "post", cls._get_url("generations"), params
+            "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
         )
 
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            requestor.api_base = "" # operation_location is a full url
+            response, _, api_key = requestor._poll(
+                "get", response.operation_location,
+                until=lambda response: response.data['status'] in [ 'succeeded' ],
+                failed=lambda response: response.data['status'] in [ 'failed' ]
+            )
+
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
         )
@@ -60,12 +71,20 @@ class Image(APIResource):
             organization=organization,
         )
 
-        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
 
         response, _, api_key = await requestor.arequest(
-            "post", cls._get_url("generations"), params
+            "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
         )
 
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            requestor.api_base = "" # operation_location is a full url
+            response, _, api_key = await requestor._apoll(
+                "get", response.operation_location,
+                until=lambda response: response.data['status'] in [ 'succeeded' ],
+                failed=lambda response: response.data['status'] in [ 'failed' ]
+            )
+
         return util.convert_to_openai_object(
             response, api_key, api_version, organization
         )
@@ -88,9 +107,9 @@ class Image(APIResource):
             api_version=api_version,
             organization=organization,
         )
-        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
 
-        url = cls._get_url("variations")
+        url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version)
 
         files: List[Any] = []
         for key, value in params.items():
@@ -109,6 +128,9 @@ class Image(APIResource):
         organization=None,
         **params,
     ):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
+
         requestor, url, files = cls._prepare_create_variation(
             image,
             api_key,
@@ -136,6 +158,9 @@ class Image(APIResource):
         organization=None,
         **params,
     ):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
+
         requestor, url, files = cls._prepare_create_variation(
             image,
             api_key,
@@ -171,9 +196,9 @@ class Image(APIResource):
             api_version=api_version,
             organization=organization,
         )
-        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+        api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
 
-        url = cls._get_url("edits")
+        url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version)
 
         files: List[Any] = []
         for key, value in params.items():
@@ -195,6 +220,9 @@ class Image(APIResource):
         organization=None,
         **params,
     ):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
+
         requestor, url, files = cls._prepare_create_edit(
             image,
             mask,
@@ -224,6 +252,9 @@ class Image(APIResource):
         organization=None,
         **params,
     ):
+        if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+            raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
+
         requestor, url, files = cls._prepare_create_edit(
             image,
             mask,
openai/api_requestor.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import time
 import platform
 import sys
 import threading
@@ -10,6 +11,7 @@ from json import JSONDecodeError
 from typing import (
     AsyncGenerator,
     AsyncIterator,
+    Callable,
     Dict,
     Iterator,
     Optional,
@@ -151,6 +153,70 @@ class APIRequestor:
             str += " (%s)" % (info["url"],)
         return str
 
+    def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
+        if not predicate(response):
+            return
+        error_data = response.data['error']
+        message = error_data.get('message', 'Operation failed')
+        code = error_data.get('code')
+        raise error.OpenAIError(message=message, code=code)
+
+    def _poll(
+        self,
+        method,
+        url,
+        until,
+        failed,
+        params = None,
+        headers = None,
+        interval = None,
+        delay = None
+    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+        if delay:
+            time.sleep(delay)
+
+        response, b, api_key = self.request(method, url, params, headers)
+        self._check_polling_response(response, failed)
+        start_time = time.time()
+        while not until(response):
+            if time.time() - start_time > TIMEOUT_SECS:
+                raise error.Timeout("Operation polling timed out.")
+
+            time.sleep(interval or response.retry_after or 10)
+            response, b, api_key = self.request(method, url, params, headers)
+            self._check_polling_response(response, failed)
+
+        response.data = response.data['result']
+        return response, b, api_key
+
+    async def _apoll(
+        self,
+        method,
+        url,
+        until,
+        failed,
+        params = None,
+        headers = None,
+        interval = None,
+        delay = None
+    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+        if delay:
+            await asyncio.sleep(delay)
+
+        response, b, api_key = await self.arequest(method, url, params, headers)
+        self._check_polling_response(response, failed)
+        start_time = time.time()
+        while not until(response):
+            if time.time() - start_time > TIMEOUT_SECS:
+                raise error.Timeout("Operation polling timed out.")
+
+            await asyncio.sleep(interval or response.retry_after or 10)
+            response, b, api_key = await self.arequest(method, url, params, headers)
+            self._check_polling_response(response, failed)
+
+        response.data = response.data['result']
+        return response, b, api_key
+
     @overload
     def request(
         self,
openai/openai_response.py
@@ -10,6 +10,17 @@ class OpenAIResponse:
     def request_id(self) -> Optional[str]:
         return self._headers.get("request-id")
 
+    @property
+    def retry_after(self) -> Optional[int]:
+        try:
+            return int(self._headers.get("retry-after"))
+        except TypeError:
+            return None
+
+    @property
+    def operation_location(self) -> Optional[str]:
+        return self._headers.get("operation-location")
+
     @property
     def organization(self) -> Optional[str]:
         return self._headers.get("OpenAI-Organization")