Commit 6e159bf0
Changed files (3)
openai
openai/api_resources/image.py
@@ -2,7 +2,7 @@
from typing import Any, List
import openai
-from openai import api_requestor, util
+from openai import api_requestor, error, util
from openai.api_resources.abstract import APIResource
@@ -10,8 +10,11 @@ class Image(APIResource):
OBJECT_NAME = "images"
@classmethod
- def _get_url(cls, action):
- return cls.class_url() + f"/{action}"
+ def _get_url(cls, action, azure_action, api_type, api_version):
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None:
+ return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}"
+ else:
+ return f"{cls.class_url()}/{action}"
@classmethod
def create(
@@ -31,12 +34,20 @@ class Image(APIResource):
organization=organization,
)
- _, api_version = cls._get_api_type_and_version(api_type, api_version)
+ api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
response, _, api_key = requestor.request(
- "post", cls._get_url("generations"), params
+ "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
)
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ requestor.api_base = "" # operation_location is a full url
+ response, _, api_key = requestor._poll(
+ "get", response.operation_location,
+ until=lambda response: response.data['status'] in [ 'succeeded' ],
+ failed=lambda response: response.data['status'] in [ 'failed' ]
+ )
+
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
@@ -60,12 +71,20 @@ class Image(APIResource):
organization=organization,
)
- _, api_version = cls._get_api_type_and_version(api_type, api_version)
+ api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
response, _, api_key = await requestor.arequest(
- "post", cls._get_url("generations"), params
+ "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
)
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ requestor.api_base = "" # operation_location is a full url
+ response, _, api_key = await requestor._apoll(
+ "get", response.operation_location,
+ until=lambda response: response.data['status'] in [ 'succeeded' ],
+ failed=lambda response: response.data['status'] in [ 'failed' ]
+ )
+
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
@@ -88,9 +107,9 @@ class Image(APIResource):
api_version=api_version,
organization=organization,
)
- _, api_version = cls._get_api_type_and_version(api_type, api_version)
+ api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("variations")
+ url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version)
files: List[Any] = []
for key, value in params.items():
@@ -109,6 +128,9 @@ class Image(APIResource):
organization=None,
**params,
):
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
+
requestor, url, files = cls._prepare_create_variation(
image,
api_key,
@@ -136,6 +158,9 @@ class Image(APIResource):
organization=None,
**params,
):
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
+
requestor, url, files = cls._prepare_create_variation(
image,
api_key,
@@ -171,9 +196,9 @@ class Image(APIResource):
api_version=api_version,
organization=organization,
)
- _, api_version = cls._get_api_type_and_version(api_type, api_version)
+ api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
- url = cls._get_url("edits")
+ url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version)
files: List[Any] = []
for key, value in params.items():
@@ -195,6 +220,9 @@ class Image(APIResource):
organization=None,
**params,
):
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
+
requestor, url, files = cls._prepare_create_edit(
image,
mask,
@@ -224,6 +252,9 @@ class Image(APIResource):
organization=None,
**params,
):
+ if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
+
requestor, url, files = cls._prepare_create_edit(
image,
mask,
openai/api_requestor.py
@@ -1,5 +1,6 @@
import asyncio
import json
+import time
import platform
import sys
import threading
@@ -10,6 +11,7 @@ from json import JSONDecodeError
from typing import (
AsyncGenerator,
AsyncIterator,
+ Callable,
Dict,
Iterator,
Optional,
@@ -151,6 +153,70 @@ class APIRequestor:
str += " (%s)" % (info["url"],)
return str
+ def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
+ if not predicate(response):
+ return
+ error_data = response.data['error']
+ message = error_data.get('message', 'Operation failed')
+ code = error_data.get('code')
+ raise error.OpenAIError(message=message, code=code)
+
+ def _poll(
+ self,
+ method,
+ url,
+ until,
+ failed,
+ params = None,
+ headers = None,
+ interval = None,
+ delay = None
+ ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+ if delay:
+ time.sleep(delay)
+
+ response, b, api_key = self.request(method, url, params, headers)
+ self._check_polling_response(response, failed)
+ start_time = time.time()
+ while not until(response):
+ if time.time() - start_time > TIMEOUT_SECS:
+ raise error.Timeout("Operation polling timed out.")
+
+ time.sleep(interval or response.retry_after or 10)
+ response, b, api_key = self.request(method, url, params, headers)
+ self._check_polling_response(response, failed)
+
+ response.data = response.data['result']
+ return response, b, api_key
+
+ async def _apoll(
+ self,
+ method,
+ url,
+ until,
+ failed,
+ params = None,
+ headers = None,
+ interval = None,
+ delay = None
+ ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+ if delay:
+ await asyncio.sleep(delay)
+
+ response, b, api_key = await self.arequest(method, url, params, headers)
+ self._check_polling_response(response, failed)
+ start_time = time.time()
+ while not until(response):
+ if time.time() - start_time > TIMEOUT_SECS:
+ raise error.Timeout("Operation polling timed out.")
+
+ await asyncio.sleep(interval or response.retry_after or 10)
+ response, b, api_key = await self.arequest(method, url, params, headers)
+ self._check_polling_response(response, failed)
+
+ response.data = response.data['result']
+ return response, b, api_key
+
@overload
def request(
self,
openai/openai_response.py
@@ -10,6 +10,17 @@ class OpenAIResponse:
def request_id(self) -> Optional[str]:
return self._headers.get("request-id")
+ @property
+ def retry_after(self) -> Optional[int]:
+ try:
+ return int(self._headers.get("retry-after"))
+ except TypeError:
+ return None
+
+ @property
+ def operation_location(self) -> Optional[str]:
+ return self._headers.get("operation-location")
+
@property
def organization(self) -> Optional[str]:
return self._headers.get("OpenAI-Organization")