Commit c93af95b
Changed files (14)
openai/api_resources/abstract/api_resource.py
@@ -1,11 +1,14 @@
from urllib.parse import quote_plus
from openai import api_requestor, error, util
+import openai
from openai.openai_object import OpenAIObject
+from openai.util import ApiType
class APIResource(OpenAIObject):
api_prefix = ""
+ azure_api_prefix = 'openai/deployments'
@classmethod
def retrieve(cls, id, api_key=None, request_id=None, **params):
@@ -32,7 +35,7 @@ class APIResource(OpenAIObject):
return "/%s/%ss" % (cls.api_prefix, base)
return "/%ss" % (base)
- def instance_url(self):
+ def instance_url(self, operation=None):
id = self.get("id")
if not isinstance(id, str):
@@ -42,10 +45,26 @@ class APIResource(OpenAIObject):
" `unicode`)" % (type(self).__name__, id, type(id)),
"id",
)
+ api_version = self.api_version or openai.api_version
- base = self.class_url()
- extn = quote_plus(id)
- return "%s/%s" % (base, extn)
+ if self.typed_api_type == ApiType.AZURE:
+ if not api_version:
+ raise error.InvalidRequestError("An API version is required for the Azure API type.")
+ if not operation:
+ raise error.InvalidRequestError(
+ "The request needs an operation (eg: 'search') for the Azure OpenAI API type."
+ )
+ extn = quote_plus(id)
+ return "/%s/%s/%s?api-version=%s" % (self.azure_api_prefix, extn, operation, api_version)
+
+ elif self.typed_api_type == ApiType.OPEN_AI:
+ base = self.class_url()
+ extn = quote_plus(id)
+ return "%s/%s" % (base, extn)
+
+ else:
+ raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)
+
# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
openai/api_resources/abstract/engine_api_resource.py
@@ -1,10 +1,13 @@
+from pydoc import apropos
import time
from typing import Optional
from urllib.parse import quote_plus
+import openai
from openai import api_requestor, error, util
from openai.api_resources.abstract.api_resource import APIResource
from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
MAX_TIMEOUT = 20
@@ -12,26 +15,46 @@ MAX_TIMEOUT = 20
class EngineAPIResource(APIResource):
engine_required = True
plain_old_data = False
+ azure_api_prefix = 'openai/deployments'
def __init__(self, engine: Optional[str] = None, **kwargs):
super().__init__(engine=engine, **kwargs)
@classmethod
- def class_url(cls, engine: Optional[str] = None):
+ def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None, api_version: Optional[str] = None):
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
- if engine is None:
- return "/%ss" % (base)
+ typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
+ api_version = api_version or openai.api_version
+
+ if typed_api_type == ApiType.AZURE:
+ if not api_version:
+ raise error.InvalidRequestError("An API version is required for the Azure API type.")
+ if engine is None:
+ raise error.InvalidRequestError(
+ "You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
+ )
+ extn = quote_plus(engine)
+ return "/%s/%s/%ss?api-version=%s" % (cls.azure_api_prefix, extn, base, api_version)
+
+ elif typed_api_type == ApiType.OPEN_AI:
+ if engine is None:
+ return "/%ss" % (base)
+
+ extn = quote_plus(engine)
+ return "/engines/%s/%ss" % (extn, base)
+
+ else:
+ raise error.InvalidAPIType('Unsupported API type %s' % api_type)
- extn = quote_plus(engine)
- return "/engines/%s/%ss" % (extn, base)
@classmethod
def create(
cls,
api_key=None,
api_base=None,
+ api_type=None,
request_id=None,
api_version=None,
organization=None,
@@ -58,10 +81,11 @@ class EngineAPIResource(APIResource):
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base,
+ api_type=api_type,
api_version=api_version,
organization=organization,
)
- url = cls.class_url(engine)
+ url = cls.class_url(engine, api_type, api_version)
response, _, api_key = requestor.request(
"post", url, params, stream=stream, request_id=request_id
)
@@ -103,14 +127,28 @@ class EngineAPIResource(APIResource):
"id",
)
- base = self.class_url(self.engine)
- extn = quote_plus(id)
- url = "%s/%s" % (base, extn)
+ params_connector = '?'
+ if self.typed_api_type == ApiType.AZURE:
+ api_version = self.api_version or openai.api_version
+ if not api_version:
+ raise error.InvalidRequestError("An API version is required for the Azure API type.")
+ extn = quote_plus(id)
+ base = self.OBJECT_NAME.replace(".", "/")
+ url = "/%s/%s/%ss/%s?api-version=%s" % (self.azure_api_prefix, self.engine, base, extn, api_version)
+ params_connector = '&'
+
+ elif self.typed_api_type == ApiType.OPEN_AI:
+ base = self.class_url(self.engine, self.api_type, self.api_version)
+ extn = quote_plus(id)
+ url = "%s/%s" % (base, extn)
+
+ else:
+ raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)
timeout = self.get("timeout")
if timeout is not None:
timeout = quote_plus(str(timeout))
- url += "?timeout={}".format(timeout)
+ url += params_connector + "timeout={}".format(timeout)
return url
def wait(self, timeout=None):
openai/api_resources/engine.py
@@ -3,7 +3,8 @@ import warnings
from openai import util
from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
-from openai.error import TryAgain
+from openai.error import InvalidAPIType, TryAgain
+from openai.util import ApiType
class Engine(ListableAPIResource, UpdateableAPIResource):
@@ -27,7 +28,12 @@ class Engine(ListableAPIResource, UpdateableAPIResource):
util.log_info("Waiting for model to warm up", error=e)
def search(self, **params):
- return self.request("post", self.instance_url() + "/search", params)
+ if self.typed_api_type == ApiType.AZURE:
+ return self.request("post", self.instance_url("search"), params)
+ elif self.typed_api_type == ApiType.OPEN_AI:
+ return self.request("post", self.instance_url() + "/search", params)
+ else:
+ raise InvalidAPIType('Unsupported API type %s' % self.api_type)
def embeddings(self, **params):
warnings.warn(
openai/tests/test_api_requestor.py
@@ -1,11 +1,12 @@
import json
-
+import pytest
import requests
from pytest_mock import MockerFixture
from openai import Model
+from openai.api_requestor import APIRequestor
-
+@pytest.mark.requestor
def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
# Fake out 'requests' and confirm that the X-Request-Id header is set.
@@ -25,3 +26,25 @@ def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource
got_request_id = got_headers.get("X-Request-Id")
assert got_request_id == fake_request_id
+
+@pytest.mark.requestor
+def test_requestor_open_ai_headers() -> None:
+ api_requestor = APIRequestor(key="test_key", api_type="open_ai")
+ headers = {"Test_Header": "Unit_Test_Header"}
+ headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
+ print(headers)
+ assert "Test_Header"in headers
+ assert headers["Test_Header"] == "Unit_Test_Header"
+ assert "Authorization"in headers
+ assert headers["Authorization"] == "Bearer test_key"
+
+@pytest.mark.requestor
+def test_requestor_azure_headers() -> None:
+ api_requestor = APIRequestor(key="test_key", api_type="azure")
+ headers = {"Test_Header": "Unit_Test_Header"}
+ headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
+ print(headers)
+ assert "Test_Header"in headers
+ assert headers["Test_Header"] == "Unit_Test_Header"
+ assert "api-key"in headers
+ assert headers["api-key"] == "test_key"
openai/tests/test_url_composition.py
@@ -0,0 +1,124 @@
+from sys import api_version
+import pytest
+
+from openai import Completion
+from openai import Engine
+from openai.util import ApiType
+
+@pytest.mark.url
+def test_completions_url_composition_azure() -> None:
+ url = Completion.class_url("test_engine", "azure", '2021-11-01-preview')
+ assert url == '/openai/deployments/test_engine/completions?api-version=2021-11-01-preview'
+
+@pytest.mark.url
+def test_completions_url_composition_default() -> None:
+ url = Completion.class_url("test_engine")
+ assert url == '/engines/test_engine/completions'
+
+@pytest.mark.url
+def test_completions_url_composition_open_ai() -> None:
+ url = Completion.class_url("test_engine", "open_ai")
+ assert url == '/engines/test_engine/completions'
+
+@pytest.mark.url
+def test_completions_url_composition_invalid_type() -> None:
+ with pytest.raises(Exception):
+ url = Completion.class_url("test_engine", "invalid")
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_azure() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version='2021-11-01-preview')
+ url = completion.instance_url()
+ assert url == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_azure_no_version() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version=None)
+ with pytest.raises(Exception):
+ completion.instance_url()
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_default() -> None:
+ completion = Completion(id="test_id", engine="test_engine")
+ url = completion.instance_url()
+ assert url == "/engines/test_engine/completions/test_id"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_open_ai() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="open_ai", api_version='2021-11-01-preview')
+ url = completion.instance_url()
+ assert url == "/engines/test_engine/completions/test_id"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_invalid() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="invalid")
+ with pytest.raises(Exception):
+ url = completion.instance_url()
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_timeout_azure() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version='2021-11-01-preview')
+ completion["timeout"] = 12
+ url = completion.instance_url()
+ assert url == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview&timeout=12"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_timeout_openai() -> None:
+ completion = Completion(id="test_id", engine="test_engine", api_type="open_ai" )
+ completion["timeout"] = 12
+ url = completion.instance_url()
+ assert url == "/engines/test_engine/completions/test_id?timeout=12"
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure() -> None:
+ engine = Engine(id="test_id", api_type="azure", api_version='2021-11-01-preview')
+ assert engine.api_type == "azure"
+ assert engine.typed_api_type == ApiType.AZURE
+ url = engine.instance_url("test_operation")
+ assert url == '/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview'
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure_no_version() -> None:
+ engine = Engine(id="test_id", api_type="azure", api_version=None)
+ assert engine.api_type == "azure"
+ assert engine.typed_api_type == ApiType.AZURE
+ with pytest.raises(Exception):
+ engine.instance_url("test_operation")
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure_no_operation() -> None:
+ engine = Engine(id="test_id", api_type="azure", api_version='2021-11-01-preview')
+ assert engine.api_type == "azure"
+ assert engine.typed_api_type == ApiType.AZURE
+ with pytest.raises(Exception):
+ engine.instance_url()
+
+@pytest.mark.url
+def test_engine_search_url_composition_default() -> None:
+ engine = Engine(id="test_id")
+ assert engine.api_type == None
+ assert engine.typed_api_type == ApiType.OPEN_AI
+ url = engine.instance_url()
+ assert url == '/engines/test_id'
+
+@pytest.mark.url
+def test_engine_search_url_composition_open_ai() -> None:
+ engine = Engine(id="test_id", api_type="open_ai")
+ assert engine.api_type == "open_ai"
+ assert engine.typed_api_type == ApiType.OPEN_AI
+ url = engine.instance_url()
+ assert url == '/engines/test_id'
+
+@pytest.mark.url
+def test_engine_search_url_composition_invalid_type() -> None:
+ engine = Engine(id="test_id", api_type="invalid")
+ assert engine.api_type == "invalid"
+ with pytest.raises(Exception):
+ assert engine.typed_api_type == ApiType.OPEN_AI
+
+@pytest.mark.url
+def test_engine_search_url_composition_invalid_search() -> None:
+ engine = Engine(id="test_id", api_type="invalid")
+ assert engine.api_type == "invalid"
+ with pytest.raises(Exception):
+ engine.search()
openai/__init__.py
@@ -27,7 +27,8 @@ api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH")
organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
-api_version = None
+api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
+api_version = '2021-11-01-preview' if api_type == "azure" else None
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
app_info = None
@@ -52,6 +53,7 @@ __all__ = [
"Search",
"api_base",
"api_key",
+ "api_type",
"api_key_path",
"api_version",
"app_info",
openai/api_requestor.py
@@ -1,3 +1,4 @@
+from email import header
import json
import platform
import threading
@@ -11,6 +12,7 @@ import requests
import openai
from openai import error, util, version
from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
TIMEOUT_SECS = 600
MAX_CONNECTION_RETRIES = 2
@@ -69,9 +71,10 @@ def parse_stream(rbody):
class APIRequestor:
- def __init__(self, key=None, api_base=None, api_version=None, organization=None):
+ def __init__(self, key=None, api_base=None, api_type=None, api_version=None, organization=None):
self.api_base = api_base or openai.api_base
self.api_key = key or util.default_api_key()
+ self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
self.api_version = api_version or openai.api_version
self.organization = organization or openai.organization
@@ -192,13 +195,14 @@ class APIRequestor:
headers = {
"X-OpenAI-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
- "Authorization": "Bearer %s" % (self.api_key,),
}
+ headers.update(util.api_key_to_header(self.api_type, self.api_key))
+
if self.organization:
headers["OpenAI-Organization"] = self.organization
- if self.api_version is not None:
+ if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
headers["OpenAI-Version"] = self.api_version
if request_id is not None:
headers["X-Request-Id"] = request_id
openai/error.py
@@ -146,6 +146,9 @@ class RateLimitError(OpenAIError):
class ServiceUnavailableError(OpenAIError):
pass
+class InvalidAPIType(OpenAIError):
+ pass
+
class SignatureVerificationError(OpenAIError):
def __init__(self, message, sig_header, http_body=None):
openai/openai_object.py
@@ -2,8 +2,10 @@ import json
from copy import deepcopy
from typing import Optional
+import openai
from openai import api_requestor, util
from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
class OpenAIObject(dict):
@@ -14,6 +16,7 @@ class OpenAIObject(dict):
id=None,
api_key=None,
api_version=None,
+ api_type=None,
organization=None,
response_ms: Optional[int] = None,
api_base=None,
@@ -30,6 +33,7 @@ class OpenAIObject(dict):
object.__setattr__(self, "api_key", api_key)
object.__setattr__(self, "api_version", api_version)
+ object.__setattr__(self, "api_type", api_type)
object.__setattr__(self, "organization", organization)
object.__setattr__(self, "api_base_override", api_base)
object.__setattr__(self, "engine", engine)
@@ -90,6 +94,7 @@ class OpenAIObject(dict):
self.get("id", None),
self.api_key,
self.api_version,
+ self.api_type,
self.organization,
),
dict(self), # state
@@ -128,11 +133,13 @@ class OpenAIObject(dict):
values,
api_key=None,
api_version=None,
+ api_type=None,
organization=None,
response_ms: Optional[int] = None,
):
self.api_key = api_key or getattr(values, "api_key", None)
self.api_version = api_version or getattr(values, "api_version", None)
+ self.api_type = api_type or getattr(values, "api_type", None)
self.organization = organization or getattr(values, "organization", None)
self._response_ms = response_ms or getattr(values, "_response_ms", None)
@@ -164,6 +171,7 @@ class OpenAIObject(dict):
requestor = api_requestor.APIRequestor(
key=self.api_key,
api_base=self.api_base_override or self.api_base(),
+ api_type=self.api_type,
api_version=self.api_version,
organization=self.organization,
)
@@ -233,6 +241,10 @@ class OpenAIObject(dict):
def openai_id(self):
return self.id
+ @property
+ def typed_api_type(self):
+ return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)
+
# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object
# wholesale because some data that's returned from the API may not be valid
@@ -243,6 +255,7 @@ class OpenAIObject(dict):
self.get("id"),
self.api_key,
api_version=self.api_version,
+ api_type=self.api_type,
organization=self.organization,
)
openai/openai_response.py
@@ -17,4 +17,4 @@ class OpenAIResponse:
@property
def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
- return None if h is None else int(h)
+ return None if h is None else round(float(h))
openai/util.py
@@ -3,6 +3,7 @@ import os
import re
import sys
from typing import Optional
+from enum import Enum
import openai
@@ -17,6 +18,21 @@ __all__ = [
"logfmt",
]
+api_key_to_header = lambda api, key: {"Authorization": f"Bearer {key}"} if api == ApiType.OPEN_AI else {"api-key": f"{key}"}
+
+class ApiType(Enum):
+ AZURE = 1
+ OPEN_AI = 2
+
+ @staticmethod
+ def from_str(label):
+ if label.lower() == 'azure':
+ return ApiType.AZURE
+ elif label.lower() in ('open_ai', 'openai'):
+ return ApiType.OPEN_AI
+ else:
+ raise openai.error.InvalidAPIType("The API type provided in invalid. Please select one of the supported API types: 'azure', 'open_ai'")
+
def _console_log_level():
if openai.log in ["debug", "info"]:
openai/version.py
@@ -1,1 +1,1 @@
-VERSION = "0.11.6"
+VERSION = "0.12.0"
.gitignore
@@ -4,4 +4,6 @@
/public/dist
__pycache__
build
-.ipynb_checkpoints
+*.egg
+.vscode/settings.json
+.ipynb_checkpoints
\ No newline at end of file
pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+ url: mark a test as part of the url composition tests.
+ requestor: mark test as part of the api_requestor tests.