Commit c93af95b

Sorin Suciu <sorin.suciu@microsoft.com>
2022-01-22 08:46:27
Add an option to use Azure endpoints for the /completions & /search operations. (#45) tag: v0.12.0
* Add an option to use Azure endpoints for the /completions operation. * Add the azure endpoints option for the /search operation + small fixes. * errata * Adressed CR comments
1 parent f4be8f2
openai/api_resources/abstract/api_resource.py
@@ -1,11 +1,14 @@
 from urllib.parse import quote_plus
 
 from openai import api_requestor, error, util
+import openai
 from openai.openai_object import OpenAIObject
+from openai.util import ApiType
 
 
 class APIResource(OpenAIObject):
     api_prefix = ""
+    azure_api_prefix = 'openai/deployments'
 
     @classmethod
     def retrieve(cls, id, api_key=None, request_id=None, **params):
@@ -32,7 +35,7 @@ class APIResource(OpenAIObject):
             return "/%s/%ss" % (cls.api_prefix, base)
         return "/%ss" % (base)
 
-    def instance_url(self):
+    def instance_url(self, operation=None):
         id = self.get("id")
 
         if not isinstance(id, str):
@@ -42,10 +45,26 @@ class APIResource(OpenAIObject):
                 " `unicode`)" % (type(self).__name__, id, type(id)),
                 "id",
             )
+        api_version = self.api_version or openai.api_version
 
-        base = self.class_url()
-        extn = quote_plus(id)
-        return "%s/%s" % (base, extn)
+        if self.typed_api_type == ApiType.AZURE:
+            if not api_version:
+                raise error.InvalidRequestError("An API version is required for the Azure API type.")
+            if not operation:
+                raise error.InvalidRequestError(
+                    "The request needs an operation (eg: 'search') for the Azure OpenAI API type."
+                )
+            extn = quote_plus(id)
+            return "/%s/%s/%s?api-version=%s" % (self.azure_api_prefix, extn, operation, api_version)
+
+        elif self.typed_api_type == ApiType.OPEN_AI:
+            base = self.class_url()
+            extn = quote_plus(id)
+            return "%s/%s" % (base, extn)
+
+        else:
+            raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)
+    
 
     # The `method_` and `url_` arguments are suffixed with an underscore to
     # avoid conflicting with actual request parameters in `params`.
openai/api_resources/abstract/engine_api_resource.py
@@ -1,10 +1,13 @@
+from pydoc import apropos
 import time
 from typing import Optional
 from urllib.parse import quote_plus
 
+import openai
 from openai import api_requestor, error, util
 from openai.api_resources.abstract.api_resource import APIResource
 from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
 
 MAX_TIMEOUT = 20
 
@@ -12,26 +15,46 @@ MAX_TIMEOUT = 20
 class EngineAPIResource(APIResource):
     engine_required = True
     plain_old_data = False
+    azure_api_prefix = 'openai/deployments'
 
     def __init__(self, engine: Optional[str] = None, **kwargs):
         super().__init__(engine=engine, **kwargs)
 
     @classmethod
-    def class_url(cls, engine: Optional[str] = None):
+    def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None, api_version: Optional[str] = None):
         # Namespaces are separated in object names with periods (.) and in URLs
         # with forward slashes (/), so replace the former with the latter.
         base = cls.OBJECT_NAME.replace(".", "/")  # type: ignore
-        if engine is None:
-            return "/%ss" % (base)
+        typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
+        api_version = api_version or openai.api_version
+
+        if typed_api_type == ApiType.AZURE:
+            if not api_version:
+                raise error.InvalidRequestError("An API version is required for the Azure API type.")
+            if engine is None:
+                raise error.InvalidRequestError(
+                    "You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
+                )
+            extn = quote_plus(engine)
+            return "/%s/%s/%ss?api-version=%s" % (cls.azure_api_prefix, extn, base, api_version)
+
+        elif typed_api_type == ApiType.OPEN_AI:
+            if engine is None:
+                return "/%ss" % (base)
+
+            extn = quote_plus(engine)
+            return "/engines/%s/%ss" % (extn, base)
+
+        else:
+            raise error.InvalidAPIType('Unsupported API type %s' % api_type)
 
-        extn = quote_plus(engine)
-        return "/engines/%s/%ss" % (extn, base)
 
     @classmethod
     def create(
         cls,
         api_key=None,
         api_base=None,
+        api_type=None,
         request_id=None,
         api_version=None,
         organization=None,
@@ -58,10 +81,11 @@ class EngineAPIResource(APIResource):
         requestor = api_requestor.APIRequestor(
             api_key,
             api_base=api_base,
+            api_type=api_type,
             api_version=api_version,
             organization=organization,
         )
-        url = cls.class_url(engine)
+        url = cls.class_url(engine, api_type, api_version)
         response, _, api_key = requestor.request(
             "post", url, params, stream=stream, request_id=request_id
         )
@@ -103,14 +127,28 @@ class EngineAPIResource(APIResource):
                 "id",
             )
 
-        base = self.class_url(self.engine)
-        extn = quote_plus(id)
-        url = "%s/%s" % (base, extn)
+        params_connector = '?'
+        if self.typed_api_type == ApiType.AZURE:
+            api_version = self.api_version or openai.api_version
+            if not api_version:
+                raise error.InvalidRequestError("An API version is required for the Azure API type.")
+            extn = quote_plus(id)
+            base = self.OBJECT_NAME.replace(".", "/")
+            url = "/%s/%s/%ss/%s?api-version=%s" % (self.azure_api_prefix, self.engine, base, extn, api_version)
+            params_connector = '&'
+
+        elif self.typed_api_type == ApiType.OPEN_AI:
+            base = self.class_url(self.engine, self.api_type, self.api_version)
+            extn = quote_plus(id)
+            url = "%s/%s" % (base, extn)
+
+        else:
+            raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)
 
         timeout = self.get("timeout")
         if timeout is not None:
             timeout = quote_plus(str(timeout))
-            url += "?timeout={}".format(timeout)
+            url += params_connector + "timeout={}".format(timeout)
         return url
 
     def wait(self, timeout=None):
openai/api_resources/engine.py
@@ -3,7 +3,8 @@ import warnings
 
 from openai import util
 from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
-from openai.error import TryAgain
+from openai.error import InvalidAPIType, TryAgain
+from openai.util import ApiType
 
 
 class Engine(ListableAPIResource, UpdateableAPIResource):
@@ -27,7 +28,12 @@ class Engine(ListableAPIResource, UpdateableAPIResource):
                 util.log_info("Waiting for model to warm up", error=e)
 
     def search(self, **params):
-        return self.request("post", self.instance_url() + "/search", params)
+        if self.typed_api_type == ApiType.AZURE:
+            return self.request("post", self.instance_url("search"), params)
+        elif self.typed_api_type == ApiType.OPEN_AI:
+            return self.request("post", self.instance_url() + "/search", params)
+        else:
+            raise InvalidAPIType('Unsupported API type %s' % self.api_type)
 
     def embeddings(self, **params):
         warnings.warn(
openai/tests/test_api_requestor.py
@@ -1,11 +1,12 @@
 import json
-
+import pytest
 import requests
 from pytest_mock import MockerFixture
 
 from openai import Model
+from openai.api_requestor import APIRequestor
 
-
+@pytest.mark.requestor
 def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
     # Fake out 'requests' and confirm that the X-Request-Id header is set.
 
@@ -25,3 +26,25 @@ def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
     Model.retrieve("xxx", request_id=fake_request_id)  # arbitrary API resource
     got_request_id = got_headers.get("X-Request-Id")
     assert got_request_id == fake_request_id
+
+@pytest.mark.requestor
+def test_requestor_open_ai_headers() -> None:
+    api_requestor = APIRequestor(key="test_key", api_type="open_ai")
+    headers = {"Test_Header": "Unit_Test_Header"}
+    headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
+    print(headers)
+    assert "Test_Header"in headers
+    assert headers["Test_Header"] == "Unit_Test_Header"
+    assert "Authorization"in headers
+    assert headers["Authorization"] == "Bearer test_key"
+
+@pytest.mark.requestor
+def test_requestor_azure_headers() -> None:
+    api_requestor = APIRequestor(key="test_key", api_type="azure")
+    headers = {"Test_Header": "Unit_Test_Header"}
+    headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
+    print(headers)
+    assert "Test_Header"in headers
+    assert headers["Test_Header"] == "Unit_Test_Header"
+    assert "api-key"in headers
+    assert headers["api-key"] == "test_key"
openai/tests/test_url_composition.py
@@ -0,0 +1,124 @@
+from sys import api_version
+import pytest
+
+from openai import Completion
+from openai import Engine
+from openai.util import ApiType
+
+@pytest.mark.url
+def test_completions_url_composition_azure() -> None:
+    url = Completion.class_url("test_engine", "azure", '2021-11-01-preview')
+    assert url == '/openai/deployments/test_engine/completions?api-version=2021-11-01-preview'
+
+@pytest.mark.url
+def test_completions_url_composition_default() -> None:
+    url = Completion.class_url("test_engine")
+    assert url == '/engines/test_engine/completions'
+
+@pytest.mark.url
+def test_completions_url_composition_open_ai() -> None:
+    url = Completion.class_url("test_engine", "open_ai")
+    assert url == '/engines/test_engine/completions'
+
+@pytest.mark.url
+def test_completions_url_composition_invalid_type() -> None:
+    with pytest.raises(Exception):
+        url = Completion.class_url("test_engine", "invalid")
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_azure() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version='2021-11-01-preview')
+    url = completion.instance_url()
+    assert url == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_azure_no_version() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version=None)
+    with pytest.raises(Exception):
+        completion.instance_url()
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_default() -> None:
+    completion = Completion(id="test_id", engine="test_engine")
+    url = completion.instance_url()
+    assert url == "/engines/test_engine/completions/test_id"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_open_ai() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="open_ai", api_version='2021-11-01-preview')
+    url = completion.instance_url()
+    assert url == "/engines/test_engine/completions/test_id"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_invalid() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="invalid")
+    with pytest.raises(Exception):
+        url = completion.instance_url()
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_timeout_azure() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="azure", api_version='2021-11-01-preview')
+    completion["timeout"] = 12
+    url = completion.instance_url()
+    assert url == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview&timeout=12"
+
+@pytest.mark.url
+def test_completions_url_composition_instance_url_timeout_openai() -> None:
+    completion = Completion(id="test_id", engine="test_engine", api_type="open_ai" )
+    completion["timeout"] = 12
+    url = completion.instance_url()
+    assert url == "/engines/test_engine/completions/test_id?timeout=12"
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure() -> None:
+    engine = Engine(id="test_id", api_type="azure", api_version='2021-11-01-preview')
+    assert engine.api_type == "azure"
+    assert engine.typed_api_type == ApiType.AZURE
+    url = engine.instance_url("test_operation")
+    assert url == '/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview'
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure_no_version() -> None:
+    engine = Engine(id="test_id", api_type="azure", api_version=None)
+    assert engine.api_type == "azure"
+    assert engine.typed_api_type == ApiType.AZURE
+    with pytest.raises(Exception):
+        engine.instance_url("test_operation")
+
+@pytest.mark.url
+def test_engine_search_url_composition_azure_no_operation() -> None:
+    engine = Engine(id="test_id", api_type="azure", api_version='2021-11-01-preview')
+    assert engine.api_type == "azure"
+    assert engine.typed_api_type == ApiType.AZURE
+    with pytest.raises(Exception):
+        engine.instance_url()    
+
+@pytest.mark.url
+def test_engine_search_url_composition_default() -> None:
+    engine = Engine(id="test_id")
+    assert engine.api_type == None
+    assert engine.typed_api_type == ApiType.OPEN_AI
+    url = engine.instance_url()
+    assert url == '/engines/test_id'
+
+@pytest.mark.url
+def test_engine_search_url_composition_open_ai() -> None:
+    engine = Engine(id="test_id", api_type="open_ai")
+    assert engine.api_type == "open_ai"
+    assert engine.typed_api_type == ApiType.OPEN_AI
+    url = engine.instance_url()
+    assert url == '/engines/test_id'
+
+@pytest.mark.url
+def test_engine_search_url_composition_invalid_type() -> None:
+    engine = Engine(id="test_id", api_type="invalid")
+    assert engine.api_type == "invalid"
+    with pytest.raises(Exception):
+        assert engine.typed_api_type == ApiType.OPEN_AI
+
+@pytest.mark.url
+def test_engine_search_url_composition_invalid_search() -> None:
+    engine = Engine(id="test_id", api_type="invalid")
+    assert engine.api_type == "invalid"
+    with pytest.raises(Exception):
+        engine.search()
openai/__init__.py
@@ -27,7 +27,8 @@ api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH")
 
 organization = os.environ.get("OPENAI_ORGANIZATION")
 api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
-api_version = None
+api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
+api_version = '2021-11-01-preview' if api_type == "azure" else None
 verify_ssl_certs = True  # No effect. Certificates are always verified.
 proxy = None
 app_info = None
@@ -52,6 +53,7 @@ __all__ = [
     "Search",
     "api_base",
     "api_key",
+    "api_type",
     "api_key_path",
     "api_version",
     "app_info",
openai/api_requestor.py
@@ -1,3 +1,4 @@
+from email import header
 import json
 import platform
 import threading
@@ -11,6 +12,7 @@ import requests
 import openai
 from openai import error, util, version
 from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
 
 TIMEOUT_SECS = 600
 MAX_CONNECTION_RETRIES = 2
@@ -69,9 +71,10 @@ def parse_stream(rbody):
 
 
 class APIRequestor:
-    def __init__(self, key=None, api_base=None, api_version=None, organization=None):
+    def __init__(self, key=None, api_base=None, api_type=None, api_version=None, organization=None):
         self.api_base = api_base or openai.api_base
         self.api_key = key or util.default_api_key()
+        self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
         self.api_version = api_version or openai.api_version
         self.organization = organization or openai.organization
 
@@ -192,13 +195,14 @@ class APIRequestor:
         headers = {
             "X-OpenAI-Client-User-Agent": json.dumps(ua),
             "User-Agent": user_agent,
-            "Authorization": "Bearer %s" % (self.api_key,),
         }
 
+        headers.update(util.api_key_to_header(self.api_type, self.api_key))
+
         if self.organization:
             headers["OpenAI-Organization"] = self.organization
 
-        if self.api_version is not None:
+        if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
             headers["OpenAI-Version"] = self.api_version
         if request_id is not None:
             headers["X-Request-Id"] = request_id
openai/error.py
@@ -146,6 +146,9 @@ class RateLimitError(OpenAIError):
 class ServiceUnavailableError(OpenAIError):
     pass
 
+class InvalidAPIType(OpenAIError):
+    pass
+
 
 class SignatureVerificationError(OpenAIError):
     def __init__(self, message, sig_header, http_body=None):
openai/openai_object.py
@@ -2,8 +2,10 @@ import json
 from copy import deepcopy
 from typing import Optional
 
+import openai
 from openai import api_requestor, util
 from openai.openai_response import OpenAIResponse
+from openai.util import ApiType
 
 
 class OpenAIObject(dict):
@@ -14,6 +16,7 @@ class OpenAIObject(dict):
         id=None,
         api_key=None,
         api_version=None,
+        api_type=None,
         organization=None,
         response_ms: Optional[int] = None,
         api_base=None,
@@ -30,6 +33,7 @@ class OpenAIObject(dict):
 
         object.__setattr__(self, "api_key", api_key)
         object.__setattr__(self, "api_version", api_version)
+        object.__setattr__(self, "api_type", api_type)
         object.__setattr__(self, "organization", organization)
         object.__setattr__(self, "api_base_override", api_base)
         object.__setattr__(self, "engine", engine)
@@ -90,6 +94,7 @@ class OpenAIObject(dict):
                 self.get("id", None),
                 self.api_key,
                 self.api_version,
+                self.api_type,
                 self.organization,
             ),
             dict(self),  # state
@@ -128,11 +133,13 @@ class OpenAIObject(dict):
         values,
         api_key=None,
         api_version=None,
+        api_type=None,
         organization=None,
         response_ms: Optional[int] = None,
     ):
         self.api_key = api_key or getattr(values, "api_key", None)
         self.api_version = api_version or getattr(values, "api_version", None)
+        self.api_type = api_type or getattr(values, "api_type", None)
         self.organization = organization or getattr(values, "organization", None)
         self._response_ms = response_ms or getattr(values, "_response_ms", None)
 
@@ -164,6 +171,7 @@ class OpenAIObject(dict):
         requestor = api_requestor.APIRequestor(
             key=self.api_key,
             api_base=self.api_base_override or self.api_base(),
+            api_type=self.api_type,
             api_version=self.api_version,
             organization=self.organization,
         )
@@ -233,6 +241,10 @@ class OpenAIObject(dict):
     def openai_id(self):
         return self.id
 
+    @property
+    def typed_api_type(self):
+        return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)
+
     # This class overrides __setitem__ to throw exceptions on inputs that it
     # doesn't like. This can cause problems when we try to copy an object
     # wholesale because some data that's returned from the API may not be valid
@@ -243,6 +255,7 @@ class OpenAIObject(dict):
             self.get("id"),
             self.api_key,
             api_version=self.api_version,
+            api_type=self.api_type,
             organization=self.organization,
         )
 
openai/openai_response.py
@@ -17,4 +17,4 @@ class OpenAIResponse:
     @property
     def response_ms(self) -> Optional[int]:
         h = self._headers.get("Openai-Processing-Ms")
-        return None if h is None else int(h)
+        return None if h is None else round(float(h))
openai/util.py
@@ -3,6 +3,7 @@ import os
 import re
 import sys
 from typing import Optional
+from enum import Enum
 
 import openai
 
@@ -17,6 +18,21 @@ __all__ = [
     "logfmt",
 ]
 
+api_key_to_header = lambda api, key: {"Authorization": f"Bearer {key}"} if api == ApiType.OPEN_AI else {"api-key": f"{key}"}
+
+class ApiType(Enum):
+    AZURE = 1
+    OPEN_AI = 2
+
+    @staticmethod
+    def from_str(label):
+        if label.lower() == 'azure':
+            return ApiType.AZURE
+        elif label.lower() in ('open_ai', 'openai'):
+            return ApiType.OPEN_AI
+        else:
+            raise openai.error.InvalidAPIType("The API type provided in invalid. Please select one of the supported API types: 'azure', 'open_ai'")    
+
 
 def _console_log_level():
     if openai.log in ["debug", "info"]:
openai/version.py
@@ -1,1 +1,1 @@
-VERSION = "0.11.6"
+VERSION = "0.12.0"
.gitignore
@@ -4,4 +4,6 @@
 /public/dist
 __pycache__
 build
-.ipynb_checkpoints
+*.egg
+.vscode/settings.json
+.ipynb_checkpoints
\ No newline at end of file
pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    url: mark a test as part of the url composition tests.
+    requestor: mark test as part of the api_requestor tests.