Commit 0abf6413
Changed files (30)
openai
api_resources
abstract
openai/api_resources/abstract/api_resource.py
@@ -20,6 +20,13 @@ class APIResource(OpenAIObject):
instance.refresh(request_id=request_id, request_timeout=request_timeout)
return instance
+ @classmethod
+ def aretrieve(
+ cls, id, api_key=None, request_id=None, request_timeout=None, **params
+ ):
+ instance = cls(id, api_key, **params)
+ return instance.arefresh(request_id=request_id, request_timeout=request_timeout)
+
def refresh(self, request_id=None, request_timeout=None):
self.refresh_from(
self.request(
@@ -31,6 +38,17 @@ class APIResource(OpenAIObject):
)
return self
+ async def arefresh(self, request_id=None, request_timeout=None):
+ self.refresh_from(
+ await self.arequest(
+ "get",
+ self.instance_url(operation="refresh"),
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+ )
+ return self
+
@classmethod
def class_url(cls):
if cls == APIResource:
@@ -116,6 +134,31 @@ class APIResource(OpenAIObject):
response, api_key, api_version, organization
)
+ @classmethod
+ async def _astatic_request(
+ cls,
+ method_,
+ url_,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor = api_requestor.APIRequestor(
+ api_key,
+ api_version=api_version,
+ organization=organization,
+ api_base=api_base,
+ api_type=api_type,
+ )
+ response, _, api_key = await requestor.arequest(
+ method_, url_, params, request_id=request_id
+ )
+ return util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
+
@classmethod
def _get_api_type_and_version(
cls, api_type: Optional[str] = None, api_version: Optional[str] = None
openai/api_resources/abstract/createable_api_resource.py
@@ -7,15 +7,13 @@ class CreateableAPIResource(APIResource):
plain_old_data = False
@classmethod
- def create(
+ def __prepare_create_requestor(
cls,
api_key=None,
api_base=None,
api_type=None,
- request_id=None,
api_version=None,
organization=None,
- **params,
):
requestor = api_requestor.APIRequestor(
api_key,
@@ -35,6 +33,26 @@ class CreateableAPIResource(APIResource):
url = cls.class_url()
else:
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+ return requestor, url
+
+ @classmethod
+ def create(
+ cls,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url = cls.__prepare_create_requestor(
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ )
response, _, api_key = requestor.request(
"post", url, params, request_id=request_id
@@ -47,3 +65,34 @@ class CreateableAPIResource(APIResource):
organization,
plain_old_data=cls.plain_old_data,
)
+
+ @classmethod
+ async def acreate(
+ cls,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url = cls.__prepare_create_requestor(
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ )
+
+ response, _, api_key = await requestor.arequest(
+ "post", url, params, request_id=request_id
+ )
+
+ return util.convert_to_openai_object(
+ response,
+ api_key,
+ api_version,
+ organization,
+ plain_old_data=cls.plain_old_data,
+ )
openai/api_resources/abstract/deletable_api_resource.py
@@ -1,4 +1,5 @@
from urllib.parse import quote_plus
+from typing import Awaitable
from openai import error
from openai.api_resources.abstract.api_resource import APIResource
@@ -7,7 +8,7 @@ from openai.util import ApiType
class DeletableAPIResource(APIResource):
@classmethod
- def delete(cls, sid, api_type=None, api_version=None, **params):
+ def __prepare_delete(cls, sid, api_type=None, api_version=None):
if isinstance(cls, APIResource):
raise ValueError(".delete may only be called as a class method now.")
@@ -28,7 +29,20 @@ class DeletableAPIResource(APIResource):
url = "%s/%s" % (base, extn)
else:
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+ return url
+
+ @classmethod
+ def delete(cls, sid, api_type=None, api_version=None, **params):
+ url = cls.__prepare_delete(sid, api_type, api_version)
return cls._static_request(
"delete", url, api_type=api_type, api_version=api_version, **params
)
+
+ @classmethod
+ def adelete(cls, sid, api_type=None, api_version=None, **params) -> Awaitable:
+ url = cls.__prepare_delete(sid, api_type, api_version)
+
+ return cls._astatic_request(
+ "delete", url, api_type=api_type, api_version=api_version, **params
+ )
openai/api_resources/abstract/engine_api_resource.py
@@ -61,12 +61,11 @@ class EngineAPIResource(APIResource):
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
@classmethod
- def create(
+ def __prepare_create_request(
cls,
api_key=None,
api_base=None,
api_type=None,
- request_id=None,
api_version=None,
organization=None,
**params,
@@ -112,6 +111,45 @@ class EngineAPIResource(APIResource):
organization=organization,
)
url = cls.class_url(engine, api_type, api_version)
+ return (
+ deployment_id,
+ engine,
+ timeout,
+ stream,
+ headers,
+ request_timeout,
+ typed_api_type,
+ requestor,
+ url,
+ params,
+ )
+
+ @classmethod
+ def create(
+ cls,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ (
+ deployment_id,
+ engine,
+ timeout,
+ stream,
+ headers,
+ request_timeout,
+ typed_api_type,
+ requestor,
+ url,
+ params,
+ ) = cls.__prepare_create_request(
+ api_key, api_base, api_type, api_version, organization, **params
+ )
+
response, _, api_key = requestor.request(
"post",
url,
@@ -151,6 +189,70 @@ class EngineAPIResource(APIResource):
return obj
+ @classmethod
+ async def acreate(
+ cls,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ (
+ deployment_id,
+ engine,
+ timeout,
+ stream,
+ headers,
+ request_timeout,
+ typed_api_type,
+ requestor,
+ url,
+ params,
+ ) = cls.__prepare_create_request(
+ api_key, api_base, api_type, api_version, organization, **params
+ )
+ response, _, api_key = await requestor.arequest(
+ "post",
+ url,
+ params=params,
+ headers=headers,
+ stream=stream,
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+
+ if stream:
+ # must be an iterator
+ assert not isinstance(response, OpenAIResponse)
+ return (
+ util.convert_to_openai_object(
+ line,
+ api_key,
+ api_version,
+ organization,
+ engine=engine,
+ plain_old_data=cls.plain_old_data,
+ )
+ async for line in response
+ )
+ else:
+ obj = util.convert_to_openai_object(
+ response,
+ api_key,
+ api_version,
+ organization,
+ engine=engine,
+ plain_old_data=cls.plain_old_data,
+ )
+
+ if timeout is not None:
+ await obj.await_(timeout=timeout or None)
+
+ return obj
+
def instance_url(self):
id = self.get("id")
@@ -206,3 +308,18 @@ class EngineAPIResource(APIResource):
break
self.refresh()
return self
+
+ async def await_(self, timeout=None):
+ """Async version of `EngineApiResource.wait`"""
+ start = time.time()
+ while self.status != "complete":
+ self.timeout = (
+ min(timeout + start - time.time(), MAX_TIMEOUT)
+ if timeout is not None
+ else MAX_TIMEOUT
+ )
+ if self.timeout < 0:
+ del self.timeout
+ break
+ await self.arefresh()
+ return self
openai/api_resources/abstract/listable_api_resource.py
@@ -9,15 +9,13 @@ class ListableAPIResource(APIResource):
return cls.list(*args, **params).auto_paging_iter()
@classmethod
- def list(
+ def __prepare_list_requestor(
cls,
api_key=None,
- request_id=None,
api_version=None,
organization=None,
api_base=None,
api_type=None,
- **params,
):
requestor = api_requestor.APIRequestor(
api_key,
@@ -38,6 +36,26 @@ class ListableAPIResource(APIResource):
url = cls.class_url()
else:
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+ return requestor, url
+
+ @classmethod
+ def list(
+ cls,
+ api_key=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ api_base=None,
+ api_type=None,
+ **params,
+ ):
+ requestor, url = cls.__prepare_list_requestor(
+ api_key,
+ api_version,
+ organization,
+ api_base,
+ api_type,
+ )
response, _, api_key = requestor.request(
"get", url, params, request_id=request_id
@@ -47,3 +65,31 @@ class ListableAPIResource(APIResource):
)
openai_object._retrieve_params = params
return openai_object
+
+ @classmethod
+ async def alist(
+ cls,
+ api_key=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ api_base=None,
+ api_type=None,
+ **params,
+ ):
+ requestor, url = cls.__prepare_list_requestor(
+ api_key,
+ api_version,
+ organization,
+ api_base,
+ api_type,
+ )
+
+ response, _, api_key = await requestor.arequest(
+ "get", url, params, request_id=request_id
+ )
+ openai_object = util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
+ openai_object._retrieve_params = params
+ return openai_object
openai/api_resources/abstract/nested_resource_class_methods.py
@@ -3,8 +3,12 @@ from urllib.parse import quote_plus
from openai import api_requestor, util
-def nested_resource_class_methods(
- resource, path=None, operations=None, resource_plural=None
+def _nested_resource_class_methods(
+ resource,
+ path=None,
+ operations=None,
+ resource_plural=None,
+ async_=False,
):
if resource_plural is None:
resource_plural = "%ss" % resource
@@ -43,8 +47,34 @@ def nested_resource_class_methods(
response, api_key, api_version, organization
)
+ async def anested_resource_request(
+ cls,
+ method,
+ url,
+ api_key=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor = api_requestor.APIRequestor(
+ api_key, api_version=api_version, organization=organization
+ )
+ response, _, api_key = await requestor.arequest(
+ method, url, params, request_id=request_id
+ )
+ return util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
+
resource_request_method = "%ss_request" % resource
- setattr(cls, resource_request_method, classmethod(nested_resource_request))
+ setattr(
+ cls,
+ resource_request_method,
+ classmethod(
+ anested_resource_request if async_ else nested_resource_request
+ ),
+ )
for operation in operations:
if operation == "create":
@@ -100,3 +130,25 @@ def nested_resource_class_methods(
return cls
return wrapper
+
+
+def nested_resource_class_methods(
+ resource,
+ path=None,
+ operations=None,
+ resource_plural=None,
+):
+ return _nested_resource_class_methods(
+ resource, path, operations, resource_plural, async_=False
+ )
+
+
+def anested_resource_class_methods(
+ resource,
+ path=None,
+ operations=None,
+ resource_plural=None,
+):
+ return _nested_resource_class_methods(
+ resource, path, operations, resource_plural, async_=True
+ )
openai/api_resources/abstract/updateable_api_resource.py
@@ -1,4 +1,5 @@
from urllib.parse import quote_plus
+from typing import Awaitable
from openai.api_resources.abstract.api_resource import APIResource
@@ -8,3 +9,8 @@ class UpdateableAPIResource(APIResource):
def modify(cls, sid, **params):
url = "%s/%s" % (cls.class_url(), quote_plus(sid))
return cls._static_request("post", url, **params)
+
+ @classmethod
+ def amodify(cls, sid, **params) -> Awaitable:
+ url = "%s/%s" % (cls.class_url(), quote_plus(sid))
+ return cls._astatic_request("post", url, **params)
openai/api_resources/answer.py
@@ -10,3 +10,8 @@ class Answer(OpenAIObject):
def create(cls, **params):
instance = cls()
return instance.request("post", cls.get_url(), params)
+
+ @classmethod
+ def acreate(cls, **params):
+ instance = cls()
+ return instance.arequest("post", cls.get_url(), params)
openai/api_resources/classification.py
@@ -10,3 +10,8 @@ class Classification(OpenAIObject):
def create(cls, **params):
instance = cls()
return instance.request("post", cls.get_url(), params)
+
+ @classmethod
+ def acreate(cls, **params):
+ instance = cls()
+ return instance.arequest("post", cls.get_url(), params)
openai/api_resources/completion.py
@@ -28,3 +28,23 @@ class Completion(EngineAPIResource):
raise
util.log_info("Waiting for model to warm up", error=e)
+
+ @classmethod
+ async def acreate(cls, *args, **kwargs):
+ """
+ Creates a new completion for the provided prompt and parameters.
+
+ See https://beta.openai.com/docs/api-reference/completions/create for a list
+ of valid parameters.
+ """
+ start = time.time()
+ timeout = kwargs.pop("timeout", None)
+
+ while True:
+ try:
+ return await super().acreate(*args, **kwargs)
+ except TryAgain as e:
+ if timeout is not None and time.time() > start + timeout:
+ raise
+
+ util.log_info("Waiting for model to warm up", error=e)
openai/api_resources/customer.py
@@ -10,3 +10,8 @@ class Customer(OpenAIObject):
def create(cls, customer, endpoint, **params):
instance = cls()
return instance.request("post", cls.get_url(customer, endpoint), params)
+
+ @classmethod
+ def acreate(cls, customer, endpoint, **params):
+ instance = cls()
+ return instance.arequest("post", cls.get_url(customer, endpoint), params)
openai/api_resources/deployment.py
@@ -11,10 +11,7 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc
OBJECT_NAME = "deployments"
@classmethod
- def create(cls, *args, **kwargs):
- """
- Creates a new deployment for the provided prompt and parameters.
- """
+ def _check_create(cls, *args, **kwargs):
typed_api_type, _ = cls._get_api_type_and_version(
kwargs.get("api_type", None), None
)
@@ -45,10 +42,24 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc
param="scale_settings",
)
+ @classmethod
+ def create(cls, *args, **kwargs):
+ """
+ Creates a new deployment for the provided prompt and parameters.
+ """
+ cls._check_create(*args, **kwargs)
return super().create(*args, **kwargs)
@classmethod
- def list(cls, *args, **kwargs):
+ def acreate(cls, *args, **kwargs):
+ """
+ Creates a new deployment for the provided prompt and parameters.
+ """
+ cls._check_create(*args, **kwargs)
+ return super().acreate(*args, **kwargs)
+
+ @classmethod
+ def _check_list(cls, *args, **kwargs):
typed_api_type, _ = cls._get_api_type_and_version(
kwargs.get("api_type", None), None
)
@@ -57,10 +68,18 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc
"Deployment operations are only available for the Azure API type."
)
+ @classmethod
+ def list(cls, *args, **kwargs):
+ cls._check_list(*args, **kwargs)
return super().list(*args, **kwargs)
@classmethod
- def delete(cls, *args, **kwargs):
+ def alist(cls, *args, **kwargs):
+ cls._check_list(*args, **kwargs)
+ return super().alist(*args, **kwargs)
+
+ @classmethod
+ def _check_delete(cls, *args, **kwargs):
typed_api_type, _ = cls._get_api_type_and_version(
kwargs.get("api_type", None), None
)
@@ -69,10 +88,18 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc
"Deployment operations are only available for the Azure API type."
)
+ @classmethod
+ def delete(cls, *args, **kwargs):
+ cls._check_delete(*args, **kwargs)
return super().delete(*args, **kwargs)
@classmethod
- def retrieve(cls, *args, **kwargs):
+ def adelete(cls, *args, **kwargs):
+ cls._check_delete(*args, **kwargs)
+ return super().adelete(*args, **kwargs)
+
+ @classmethod
+ def _check_retrieve(cls, *args, **kwargs):
typed_api_type, _ = cls._get_api_type_and_version(
kwargs.get("api_type", None), None
)
@@ -81,4 +108,12 @@ class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResourc
"Deployment operations are only available for the Azure API type."
)
+ @classmethod
+ def retrieve(cls, *args, **kwargs):
+ cls._check_retrieve(*args, **kwargs)
return super().retrieve(*args, **kwargs)
+
+ @classmethod
+ def aretrieve(cls, *args, **kwargs):
+ cls._check_retrieve(*args, **kwargs)
+ return super().aretrieve(*args, **kwargs)
openai/api_resources/edit.py
@@ -31,3 +31,27 @@ class Edit(EngineAPIResource):
raise
util.log_info("Waiting for model to warm up", error=e)
+
+ @classmethod
+ async def acreate(cls, *args, **kwargs):
+ """
+ Creates a new edit for the provided input, instruction, and parameters.
+ """
+ start = time.time()
+ timeout = kwargs.pop("timeout", None)
+
+ api_type = kwargs.pop("api_type", None)
+ typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
+ if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
+ raise error.InvalidAPIType(
+ "This operation is not supported by the Azure OpenAI API yet."
+ )
+
+ while True:
+ try:
+ return await super().acreate(*args, **kwargs)
+ except TryAgain as e:
+ if timeout is not None and time.time() > start + timeout:
+ raise
+
+ util.log_info("Waiting for model to warm up", error=e)
openai/api_resources/embedding.py
@@ -50,3 +50,42 @@ class Embedding(EngineAPIResource):
raise
util.log_info("Waiting for model to warm up", error=e)
+
+ @classmethod
+ async def acreate(cls, *args, **kwargs):
+ """
+ Creates a new embedding for the provided input and parameters.
+
+ See https://beta.openai.com/docs/api-reference/embeddings for a list
+ of valid parameters.
+ """
+ start = time.time()
+ timeout = kwargs.pop("timeout", None)
+
+ user_provided_encoding_format = kwargs.get("encoding_format", None)
+
+ # If encoding format was not explicitly specified, we opaquely use base64 for performance
+ if not user_provided_encoding_format:
+ kwargs["encoding_format"] = "base64"
+
+ while True:
+ try:
+ response = await super().acreate(*args, **kwargs)
+
+ # If a user specifies base64, we'll just return the encoded string.
+ # This is only for the default case.
+ if not user_provided_encoding_format:
+ for data in response.data:
+
+ # If an engine isn't using this optimization, don't do anything
+ if type(data["embedding"]) == str:
+ data["embedding"] = np.frombuffer(
+ base64.b64decode(data["embedding"]), dtype="float32"
+ ).tolist()
+
+ return response
+ except TryAgain as e:
+ if timeout is not None and time.time() > start + timeout:
+ raise
+
+ util.log_info("Waiting for model to warm up", error=e)
openai/api_resources/engine.py
@@ -27,6 +27,23 @@ class Engine(ListableAPIResource, UpdateableAPIResource):
util.log_info("Waiting for model to warm up", error=e)
+ async def agenerate(self, timeout=None, **params):
+ start = time.time()
+ while True:
+ try:
+ return await self.arequest(
+ "post",
+ self.instance_url() + "/generate",
+ params,
+ stream=params.get("stream"),
+ plain_old_data=True,
+ )
+ except TryAgain as e:
+ if timeout is not None and time.time() > start + timeout:
+ raise
+
+ util.log_info("Waiting for model to warm up", error=e)
+
def search(self, **params):
if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
return self.request("post", self.instance_url("search"), params)
@@ -35,6 +52,14 @@ class Engine(ListableAPIResource, UpdateableAPIResource):
else:
raise InvalidAPIType("Unsupported API type %s" % self.api_type)
+ def asearch(self, **params):
+ if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
+ return self.arequest("post", self.instance_url("search"), params)
+ elif self.typed_api_type == ApiType.OPEN_AI:
+ return self.arequest("post", self.instance_url() + "/search", params)
+ else:
+ raise InvalidAPIType("Unsupported API type %s" % self.api_type)
+
def embeddings(self, **params):
warnings.warn(
"Engine.embeddings is deprecated, use Embedding.create", DeprecationWarning
openai/api_resources/file.py
@@ -12,7 +12,7 @@ class File(ListableAPIResource, DeletableAPIResource):
OBJECT_NAME = "files"
@classmethod
- def create(
+ def __prepare_file_create(
cls,
file,
purpose,
@@ -56,13 +56,69 @@ class File(ListableAPIResource, DeletableAPIResource):
)
else:
files.append(("file", ("file", file, "application/octet-stream")))
+
+ return requestor, url, files
+
+ @classmethod
+ def create(
+ cls,
+ file,
+ purpose,
+ model=None,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ user_provided_filename=None,
+ ):
+ requestor, url, files = cls.__prepare_file_create(
+ file,
+ purpose,
+ model,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ user_provided_filename,
+ )
response, _, api_key = requestor.request("post", url, files=files)
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
@classmethod
- def download(
+ async def acreate(
+ cls,
+ file,
+ purpose,
+ model=None,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ user_provided_filename=None,
+ ):
+ requestor, url, files = cls.__prepare_file_create(
+ file,
+ purpose,
+ model,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ user_provided_filename,
+ )
+ response, _, api_key = await requestor.arequest("post", url, files=files)
+ return util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
+
+ @classmethod
+ def __prepare_file_download(
cls,
id,
api_key=None,
@@ -84,17 +140,33 @@ class File(ListableAPIResource, DeletableAPIResource):
if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
base = cls.class_url()
- url = "/%s%s/%s/content?api-version=%s" % (
+ url = "/%s%s/%s/content?api-version=%s" % (
cls.azure_api_prefix,
base,
id,
api_version,
)
elif typed_api_type == ApiType.OPEN_AI:
- url = f"{cls.class_url()}/{id}/content"
+ url = "%s/%s/content" % (cls.class_url(), id)
else:
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+ return requestor, url
+
+ @classmethod
+ def download(
+ cls,
+ id,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ ):
+ requestor, url = cls.__prepare_file_download(
+ id, api_key, api_base, api_type, api_version, organization
+ )
+
result = requestor.request_raw("get", url)
if not 200 <= result.status_code < 300:
raise requestor.handle_error_response(
@@ -107,25 +179,32 @@ class File(ListableAPIResource, DeletableAPIResource):
return result.content
@classmethod
- def find_matching_files(
+ async def adownload(
cls,
- name,
- bytes,
- purpose,
+ id,
api_key=None,
api_base=None,
api_type=None,
api_version=None,
organization=None,
):
- """Find already uploaded files with the same name, size, and purpose."""
- all_files = cls.list(
- api_key=api_key,
- api_base=api_base or openai.api_base,
- api_type=api_type,
- api_version=api_version,
- organization=organization,
- ).get("data", [])
+ requestor, url = cls.__prepare_file_download(
+ id, api_key, api_base, api_type, api_version, organization
+ )
+
+ result = await requestor.arequest_raw("get", url)
+ if not 200 <= result.status < 300:
+ raise requestor.handle_error_response(
+ result.content,
+ result.status,
+ json.loads(cast(bytes, result.content)),
+ result.headers,
+ stream_error=False,
+ )
+ return result.content
+
+ @classmethod
+ def __find_matching_files(cls, name, all_files, purpose):
matching_files = []
basename = os.path.basename(name)
for f in all_files:
@@ -140,3 +219,49 @@ class File(ListableAPIResource, DeletableAPIResource):
continue
matching_files.append(f)
return matching_files
+
+ @classmethod
+ def find_matching_files(
+ cls,
+ name,
+ bytes,
+ purpose,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ ):
+ """Find already uploaded files with the same name, size, and purpose."""
+ all_files = cls.list(
+ api_key=api_key,
+ api_base=api_base or openai.api_base,
+ api_type=api_type,
+ api_version=api_version,
+ organization=organization,
+ ).get("data", [])
+ return cls.__find_matching_files(name, all_files, purpose)
+
+ @classmethod
+ async def afind_matching_files(
+ cls,
+ name,
+ bytes,
+ purpose,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ ):
+ """Find already uploaded files with the same name, size, and purpose."""
+ all_files = (
+ await cls.alist(
+ api_key=api_key,
+ api_base=api_base or openai.api_base,
+ api_type=api_type,
+ api_version=api_version,
+ organization=organization,
+ )
+ ).get("data", [])
+ return cls.__find_matching_files(name, all_files, purpose)
openai/api_resources/fine_tune.py
@@ -16,7 +16,7 @@ class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource)
OBJECT_NAME = "fine-tunes"
@classmethod
- def cancel(
+ def _prepare_cancel(
cls,
id,
api_key=None,
@@ -44,10 +44,50 @@ class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource)
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
instance = cls(id, api_key, **params)
+ return instance, url
+
+ @classmethod
+ def cancel(
+ cls,
+ id,
+ api_key=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ **params,
+ ):
+ instance, url = cls._prepare_cancel(
+ id,
+ api_key,
+ api_type,
+ request_id,
+ api_version,
+ **params,
+ )
return instance.request("post", url, request_id=request_id)
@classmethod
- def stream_events(
+ def acancel(
+ cls,
+ id,
+ api_key=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ **params,
+ ):
+ instance, url = cls._prepare_cancel(
+ id,
+ api_key,
+ api_type,
+ request_id,
+ api_version,
+ **params,
+ )
+ return instance.arequest("post", url, request_id=request_id)
+
+ @classmethod
+ def _prepare_stream_events(
cls,
id,
api_key=None,
@@ -85,7 +125,32 @@ class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource)
else:
raise error.InvalidAPIType("Unsupported API type %s" % api_type)
- response, _, api_key = requestor.request(
+ return requestor, url
+
+ @classmethod
+ def stream_events(
+ cls,
+ id,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ request_id=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url = cls._prepare_stream_events(
+ id,
+ api_key,
+ api_base,
+ api_type,
+ request_id,
+ api_version,
+ organization,
+ **params,
+ )
+
+ response, _, api_key = requestor.request(
"get", url, params, stream=True, request_id=request_id
)
openai/api_resources/image.py
@@ -22,7 +22,12 @@ class Image(APIResource):
return instance.request("post", cls._get_url("generations"), params)
@classmethod
- def create_variation(
+ def acreate(cls, **params):
+ instance = cls()
+ return instance.arequest("post", cls._get_url("generations"), params)
+
+ @classmethod
+ def _prepare_create_variation(
cls,
image,
api_key=None,
@@ -47,6 +52,28 @@ class Image(APIResource):
for key, value in params.items():
files.append((key, (None, value)))
files.append(("image", ("image", image, "application/octet-stream")))
+ return requestor, url, files
+
+ @classmethod
+ def create_variation(
+ cls,
+ image,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url, files = cls._prepare_create_variation(
+ image,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ **params,
+ )
response, _, api_key = requestor.request("post", url, files=files)
@@ -55,7 +82,34 @@ class Image(APIResource):
)
@classmethod
- def create_edit(
+ async def acreate_variation(
+ cls,
+ image,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url, files = cls._prepare_create_variation(
+ image,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ **params,
+ )
+
+ response, _, api_key = await requestor.arequest("post", url, files=files)
+
+ return util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
+
+ @classmethod
+ def _prepare_create_edit(
cls,
image,
mask,
@@ -82,9 +136,62 @@ class Image(APIResource):
files.append((key, (None, value)))
files.append(("image", ("image", image, "application/octet-stream")))
files.append(("mask", ("mask", mask, "application/octet-stream")))
+ return requestor, url, files
+
+ @classmethod
+ def create_edit(
+ cls,
+ image,
+ mask,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url, files = cls._prepare_create_edit(
+ image,
+ mask,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ **params,
+ )
response, _, api_key = requestor.request("post", url, files=files)
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
+
+ @classmethod
+ async def acreate_edit(
+ cls,
+ image,
+ mask,
+ api_key=None,
+ api_base=None,
+ api_type=None,
+ api_version=None,
+ organization=None,
+ **params,
+ ):
+ requestor, url, files = cls._prepare_create_edit(
+ image,
+ mask,
+ api_key,
+ api_base,
+ api_type,
+ api_version,
+ organization,
+ **params,
+ )
+
+ response, _, api_key = await requestor.arequest("post", url, files=files)
+
+ return util.convert_to_openai_object(
+ response, api_key, api_version, organization
+ )
openai/api_resources/moderation.py
@@ -11,7 +11,7 @@ class Moderation(OpenAIObject):
return "/moderations"
@classmethod
- def create(cls, input: Union[str, List[str]], model: Optional[str] = None, api_key: Optional[str] = None):
+ def _prepare_create(cls, input, model, api_key):
if model is not None and model not in cls.VALID_MODEL_NAMES:
raise ValueError(
f"The parameter model should be chosen from {cls.VALID_MODEL_NAMES} "
@@ -22,4 +22,24 @@ class Moderation(OpenAIObject):
params = {"input": input}
if model is not None:
params["model"] = model
+ return instance, params
+
+ @classmethod
+ def create(
+ cls,
+ input: Union[str, List[str]],
+ model: Optional[str] = None,
+ api_key: Optional[str] = None,
+ ):
+ instance, params = cls._prepare_create(input, model, api_key)
return instance.request("post", cls.get_url(), params)
+
+ @classmethod
+ def acreate(
+ cls,
+ input: Union[str, List[str]],
+ model: Optional[str] = None,
+ api_key: Optional[str] = None,
+ ):
+ instance, params = cls._prepare_create(input, model, api_key)
+ return instance.arequest("post", cls.get_url(), params)
openai/api_resources/search.py
@@ -28,3 +28,24 @@ class Search(EngineAPIResource):
raise
util.log_info("Waiting for model to warm up", error=e)
+
+ @classmethod
+ async def acreate(cls, *args, **kwargs):
+ """
+ Creates a new search for the provided input and parameters.
+
+ See https://beta.openai.com/docs/api-reference/search for a list
+ of valid parameters.
+ """
+
+ start = time.time()
+ timeout = kwargs.pop("timeout", None)
+
+ while True:
+ try:
+ return await super().acreate(*args, **kwargs)
+ except TryAgain as e:
+ if timeout is not None and time.time() > start + timeout:
+ raise
+
+ util.log_info("Waiting for model to warm up", error=e)
openai/tests/asyncio/__init__.py
openai/tests/asyncio/test_endpoints.py
@@ -0,0 +1,65 @@
+import io
+import json
+
+import pytest
+
+import openai
+from openai import error
+
+
+pytestmark = [pytest.mark.asyncio]
+
+
+# FILE TESTS
+async def test_file_upload():
+ result = await openai.File.acreate(
+ file=io.StringIO(json.dumps({"text": "test file data"})),
+ purpose="search",
+ )
+ assert result.purpose == "search"
+ assert "id" in result
+
+ result = await openai.File.aretrieve(id=result.id)
+ assert result.status == "uploaded"
+
+
+# COMPLETION TESTS
+async def test_completions():
+ result = await openai.Completion.acreate(
+ prompt="This was a test", n=5, engine="ada"
+ )
+ assert len(result.choices) == 5
+
+
+async def test_completions_multiple_prompts():
+ result = await openai.Completion.acreate(
+ prompt=["This was a test", "This was another test"], n=5, engine="ada"
+ )
+ assert len(result.choices) == 10
+
+
+async def test_completions_model():
+ result = await openai.Completion.acreate(prompt="This was a test", n=5, model="ada")
+ assert len(result.choices) == 5
+ assert result.model.startswith("ada")
+
+
+async def test_timeout_raises_error():
+ # A query that should take awhile to return
+ with pytest.raises(error.Timeout):
+ await openai.Completion.acreate(
+ prompt="test" * 1000,
+ n=10,
+ model="ada",
+ max_tokens=100,
+ request_timeout=0.01,
+ )
+
+
+async def test_timeout_does_not_error():
+ # A query that should be fast
+ await openai.Completion.acreate(
+ prompt="test",
+ model="ada",
+ request_timeout=10,
+ )
openai/tests/__init__.py
openai/tests/test_long_examples_validator.py
@@ -20,29 +20,28 @@ def test_long_examples_validator() -> None:
# the order of these matters
unprepared_training_data = [
{"prompt": long_prompt, "completion": long_completion}, # 1 of 2 duplicates
- {"prompt": short_prompt, "completion": short_completion},
+ {"prompt": short_prompt, "completion": short_completion},
{"prompt": long_prompt, "completion": long_completion}, # 2 of 2 duplicates
-
]
with NamedTemporaryFile(suffix="jsonl", mode="w") as training_data:
for prompt_completion_row in unprepared_training_data:
training_data.write(json.dumps(prompt_completion_row) + "\n")
training_data.flush()
-
+
prepared_data_cmd_output = subprocess.run(
- [f"openai tools fine_tunes.prepare_data -f {training_data.name}"],
- stdout=subprocess.PIPE,
- text=True,
+ [f"openai tools fine_tunes.prepare_data -f {training_data.name}"],
+ stdout=subprocess.PIPE,
+ text=True,
input="y\ny\ny\ny\ny", # apply all recommendations, one at a time
stderr=subprocess.PIPE,
encoding="utf-8",
- shell=True
+ shell=True,
)
# validate data was prepared successfully
- assert prepared_data_cmd_output.stderr == ""
+ assert prepared_data_cmd_output.stderr == ""
# validate get_long_indexes() applied during optional_fn() call in long_examples_validator()
assert "indices of the long examples has changed" in prepared_data_cmd_output.stdout
-
- return prepared_data_cmd_output.stdout
\ No newline at end of file
+
+ return prepared_data_cmd_output.stdout
openai/__init__.py
@@ -3,7 +3,8 @@
# Originally forked from the MIT-licensed Stripe Python bindings.
import os
-from typing import Optional
+from contextvars import ContextVar
+from typing import Optional, TYPE_CHECKING
from openai.api_resources import (
Answer,
@@ -24,6 +25,9 @@ from openai.api_resources import (
)
from openai.error import APIError, InvalidRequestError, OpenAIError
+if TYPE_CHECKING:
+ from aiohttp import ClientSession
+
api_key = os.environ.get("OPENAI_API_KEY")
# Path of a file with an API key, whose contents can change. Supercedes
# `api_key` if set. The main use case is volume-mounted Kubernetes secrets,
@@ -44,6 +48,11 @@ ca_bundle_path = None # No longer used, feature was removed
debug = False
log = None # Set to either 'debug' or 'info', controls console logging
+aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
+ "aiohttp-session", default=None
+) # Acts as a global aiohttp ClientSession that reuses connections.
+# This is user-supplied; otherwise, a session is remade for each request.
+
__all__ = [
"APIError",
"Answer",
openai/api_requestor.py
@@ -1,12 +1,14 @@
+import asyncio
import json
import platform
import sys
import threading
import warnings
from json import JSONDecodeError
-from typing import Dict, Iterator, Optional, Tuple, Union, overload
+from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
from urllib.parse import urlencode, urlsplit, urlunsplit
+import aiohttp
import requests
if sys.version_info >= (3, 8):
@@ -49,6 +51,20 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
)
+def _aiohttp_proxies_arg(proxy) -> Optional[str]:
+    """Returns a value suitable for the 'proxy' argument to 'aiohttp.ClientSession.request'."""
+ if proxy is None:
+ return None
+ elif isinstance(proxy, str):
+ return proxy
+ elif isinstance(proxy, dict):
+ return proxy["https"] if "https" in proxy else proxy["http"]
+ else:
+ raise ValueError(
+ "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
+ )
+
+
def _make_session() -> requests.Session:
if not openai.verify_ssl_certs:
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
@@ -63,18 +79,32 @@ def _make_session() -> requests.Session:
return s
+def parse_stream_helper(line):
+ if line:
+ if line == b"data: [DONE]":
+ # return here will cause GeneratorExit exception in urllib3
+ # and it will close http connection with TCP Reset
+ return None
+ if hasattr(line, "decode"):
+ line = line.decode("utf-8")
+ if line.startswith("data: "):
+ line = line[len("data: ") :]
+ return line
+ return None
+
+
def parse_stream(rbody):
for line in rbody:
- if line:
- if line == b"data: [DONE]":
- # return here will cause GeneratorExit exception in urllib3
- # and it will close http connection with TCP Reset
- continue
- if hasattr(line, "decode"):
- line = line.decode("utf-8")
- if line.startswith("data: "):
- line = line[len("data: ") :]
- yield line
+ _line = parse_stream_helper(line)
+ if _line is not None:
+ yield _line
+
+
+async def parse_stream_async(rbody: aiohttp.StreamReader):
+ async for line in rbody:
+ _line = parse_stream_helper(line)
+ if _line is not None:
+ yield _line
class APIRequestor:
@@ -186,6 +216,86 @@ class APIRequestor:
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
+ @overload
+ async def arequest(
+ self,
+ method,
+ url,
+ params,
+ headers,
+ files,
+ stream: Literal[True],
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
+ pass
+
+ @overload
+ async def arequest(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ *,
+ stream: Literal[True],
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
+ pass
+
+ @overload
+ async def arequest(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ stream: Literal[False] = ...,
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[OpenAIResponse, bool, str]:
+ pass
+
+ @overload
+ async def arequest(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ stream: bool = ...,
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
+ pass
+
+ async def arequest(
+ self,
+ method,
+ url,
+ params=None,
+ headers=None,
+ files=None,
+ stream: bool = False,
+ request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
+ ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
+ result = await self.arequest_raw(
+ method.lower(),
+ url,
+ params=params,
+ supplied_headers=headers,
+ files=files,
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+ resp, got_stream = await self._interpret_async_response(result, stream)
+ return resp, got_stream, self.api_key
+
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
try:
error_data = resp["error"]
@@ -315,18 +425,15 @@ class APIRequestor:
return headers
- def request_raw(
+ def _prepare_request_raw(
self,
- method,
url,
- *,
- params=None,
- supplied_headers: Dict[str, str] = None,
- files=None,
- stream: bool = False,
- request_id: Optional[str] = None,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
- ) -> requests.Response:
+ supplied_headers,
+ method,
+ params,
+ files,
+ request_id: Optional[str],
+ ) -> Tuple[str, Dict[str, str], Optional[bytes]]:
abs_url = "%s%s" % (self.api_base, url)
headers = self._validate_headers(supplied_headers)
@@ -355,6 +462,24 @@ class APIRequestor:
util.log_info("Request to OpenAI API", method=method, path=abs_url)
util.log_debug("Post details", data=data, api_version=self.api_version)
+ return abs_url, headers, data
+
+ def request_raw(
+ self,
+ method,
+ url,
+ *,
+ params=None,
+ supplied_headers: Optional[Dict[str, str]] = None,
+ files=None,
+ stream: bool = False,
+ request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
+ ) -> requests.Response:
+ abs_url, headers, data = self._prepare_request_raw(
+ url, supplied_headers, method, params, files, request_id
+ )
+
if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
try:
@@ -385,6 +510,71 @@ class APIRequestor:
)
return result
+ async def arequest_raw(
+ self,
+ method,
+ url,
+ *,
+ params=None,
+ supplied_headers: Optional[Dict[str, str]] = None,
+ files=None,
+ request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
+ ) -> aiohttp.ClientResponse:
+ abs_url, headers, data = self._prepare_request_raw(
+ url, supplied_headers, method, params, files, request_id
+ )
+
+ if isinstance(request_timeout, tuple):
+ timeout = aiohttp.ClientTimeout(
+ connect=request_timeout[0],
+ total=request_timeout[1],
+ )
+ else:
+ timeout = aiohttp.ClientTimeout(
+ total=request_timeout if request_timeout else TIMEOUT_SECS
+ )
+ user_set_session = openai.aiosession.get()
+
+ if files:
+ # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
+ # For now we use the private `requests` method that is known to have worked so far.
+ data, content_type = requests.models.RequestEncodingMixin._encode_files( # type: ignore
+ files, data
+ )
+ headers["Content-Type"] = content_type
+ request_kwargs = {
+ "method": method,
+ "url": abs_url,
+ "headers": headers,
+ "data": data,
+ "proxy": _aiohttp_proxies_arg(openai.proxy),
+ "timeout": timeout,
+ }
+ try:
+ if user_set_session:
+ result = await user_set_session.request(**request_kwargs)
+ else:
+ async with aiohttp.ClientSession() as session:
+ result = await session.request(**request_kwargs)
+ util.log_info(
+ "OpenAI API response",
+ path=abs_url,
+ response_code=result.status,
+ processing_ms=result.headers.get("OpenAI-Processing-Ms"),
+ request_id=result.headers.get("X-Request-Id"),
+ )
+ # Don't read the whole stream for debug logging unless necessary.
+ if openai.log == "debug":
+ util.log_debug(
+ "API response body", body=result.content, headers=result.headers
+ )
+ return result
+ except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
+ raise error.Timeout("Request timed out") from e
+ except aiohttp.ClientError as e:
+ raise error.APIConnectionError("Error communicating with OpenAI") from e
+
def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
@@ -404,6 +594,29 @@ class APIRequestor:
False,
)
+ async def _interpret_async_response(
+ self, result: aiohttp.ClientResponse, stream: bool
+ ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+ """Returns the response(s) and a bool indicating whether it is a stream."""
+ if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
+ return (
+ self._interpret_response_line(
+ line, result.status, result.headers, stream=True
+ )
+ async for line in parse_stream_async(result.content)
+ ), True
+ else:
+ try:
+ await result.read()
+ except aiohttp.ClientError as e:
+ util.log_warn(e, body=result.content)
+ return (
+ self._interpret_response_line(
+ await result.read(), result.status, result.headers, stream=False
+ ),
+ False,
+ )
+
def _interpret_response_line(
self, rbody, rcode, rheaders, stream: bool
) -> OpenAIResponse:
openai/embeddings_utils.py
@@ -23,6 +23,19 @@ def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float
return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"]
+@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
+async def aget_embedding(
+ text: str, engine="text-similarity-davinci-001"
+) -> List[float]:
+
+ # replace newlines, which can negatively affect performance.
+ text = text.replace("\n", " ")
+
+ return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][
+ "embedding"
+ ]
+
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str], engine="text-similarity-babbage-001"
@@ -37,6 +50,20 @@ def get_embeddings(
return [d["embedding"] for d in data]
+@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
+async def aget_embeddings(
+ list_of_text: List[str], engine="text-similarity-babbage-001"
+) -> List[List[float]]:
+ assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
+
+ # replace newlines, which can negatively affect performance.
+ list_of_text = [text.replace("\n", " ") for text in list_of_text]
+
+ data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data
+ data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
+ return [d["embedding"] for d in data]
+
+
def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
openai/openai_object.py
@@ -207,6 +207,57 @@ class OpenAIObject(dict):
plain_old_data=plain_old_data,
)
+ async def arequest(
+ self,
+ method,
+ url,
+ params=None,
+ headers=None,
+ stream=False,
+ plain_old_data=False,
+ request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
+ ):
+ if params is None:
+ params = self._retrieve_params
+ requestor = api_requestor.APIRequestor(
+ key=self.api_key,
+ api_base=self.api_base_override or self.api_base(),
+ api_type=self.api_type,
+ api_version=self.api_version,
+ organization=self.organization,
+ )
+ response, stream, api_key = await requestor.arequest(
+ method,
+ url,
+ params=params,
+ stream=stream,
+ headers=headers,
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+
+ if stream:
+ assert not isinstance(response, OpenAIResponse) # must be an iterator
+ return (
+ util.convert_to_openai_object(
+ line,
+ api_key,
+ self.api_version,
+ self.organization,
+ plain_old_data=plain_old_data,
+ )
+ for line in response
+ )
+ else:
+ return util.convert_to_openai_object(
+ response,
+ api_key,
+ self.api_version,
+ self.organization,
+ plain_old_data=plain_old_data,
+ )
+
def __repr__(self):
ident_parts = [type(self).__name__]
README.md
@@ -211,6 +211,32 @@ image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting",
```
+## Async API
+
+Async support is available in the API by prepending `a` to a network-bound method:
+
+```python
+import openai
+openai.api_key = "sk-..." # supply your API key however you choose
+
+async def create_completion():
+ completion_resp = await openai.Completion.acreate(prompt="This is a test", engine="davinci")
+
+```
+
+To make async requests more efficient, you can pass in your own
+`aiohttp.ClientSession`, but you must manually close the client session at the end
+of your program/event loop:
+
+```python
+import openai
+from aiohttp import ClientSession
+
+openai.aiosession.set(ClientSession())
+# At the end of your program, close the http session
+await openai.aiosession.get().close()
+```
+
See the [usage guide](https://beta.openai.com/docs/guides/images) for more details.
## Requirements
setup.py
@@ -26,9 +26,10 @@ setup(
"openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format
"numpy",
'typing_extensions;python_version<"3.8"', # Needed for type hints for mypy
+ "aiohttp", # Needed for async support
],
extras_require={
- "dev": ["black~=21.6b0", "pytest==6.*"],
+ "dev": ["black~=21.6b0", "pytest==6.*", "pytest-asyncio", "pytest-mock"],
"wandb": ["wandb"],
"embeddings": [
"scikit-learn>=1.0.2", # Needed for embedding utils, versions >= 1.1 require python 3.8