Commit 3edecbee

hallacy <hallacy@openai.com>
2022-08-24 11:06:43
Hallacy/v23 (#116) tag: v0.23.0
* overload output type depending on stream literal (#142) * Bump to v22 * [numpy] change version (#143) * [numpy] change version * update comments * no version for numpy * Fix timeouts (#137) * Fix timeouts * Rename to request_timeout and add to readme * Dev/hallacy/request timeout takes tuples (#144) * Add tuple typing for request_timeout * imports * [api_requestor] Log request_id with response (#145) * Only import wandb as needed (#146) Co-authored-by: Felipe Petroski Such <felipe@openai.com> Co-authored-by: Henrique Oliveira Pinto <hponde@openai.com> Co-authored-by: Rachel Lim <rachel@openai.com>
1 parent f3e3083
openai/api_resources/abstract/api_resource.py
@@ -13,14 +13,21 @@ class APIResource(OpenAIObject):
     azure_deployments_prefix = "deployments"
 
     @classmethod
-    def retrieve(cls, id, api_key=None, request_id=None, **params):
+    def retrieve(
+        cls, id, api_key=None, request_id=None, request_timeout=None, **params
+    ):
         instance = cls(id, api_key, **params)
-        instance.refresh(request_id=request_id)
+        instance.refresh(request_id=request_id, request_timeout=request_timeout)
         return instance
 
-    def refresh(self, request_id=None):
+    def refresh(self, request_id=None, request_timeout=None):
         self.refresh_from(
-            self.request("get", self.instance_url(), request_id=request_id)
+            self.request(
+                "get",
+                self.instance_url(),
+                request_id=request_id,
+                request_timeout=request_timeout,
+            )
         )
         return self
 
openai/api_resources/abstract/engine_api_resource.py
@@ -77,7 +77,7 @@ class EngineAPIResource(APIResource):
         timeout = params.pop("timeout", None)
         stream = params.get("stream", False)
         headers = params.pop("headers", None)
-
+        request_timeout = params.pop("request_timeout", None)
         typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
         if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
             if deployment_id is None and engine is None:
@@ -119,6 +119,7 @@ class EngineAPIResource(APIResource):
             headers=headers,
             stream=stream,
             request_id=request_id,
+            request_timeout=request_timeout,
         )
 
         if stream:
openai/tests/test_endpoints.py
@@ -1,7 +1,10 @@
 import io
 import json
 
+import pytest
+
 import openai
+from openai import error
 
 
 # FILE TESTS
@@ -34,3 +37,24 @@ def test_completions_model():
     result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
     assert len(result.choices) == 5
     assert result.model.startswith("ada")
+
+
+def test_timeout_raises_error():
+    # A query that should take awhile to return
+    with pytest.raises(error.Timeout):
+        openai.Completion.create(
+            prompt="test" * 1000,
+            n=10,
+            model="ada",
+            max_tokens=100,
+            request_timeout=0.01,
+        )
+
+
+def test_timeout_does_not_error():
+    # A query that should be fast
+    openai.Completion.create(
+        prompt="test",
+        model="ada",
+        request_timeout=10,
+    )
openai/tests/test_exceptions.py
@@ -21,6 +21,7 @@ EXCEPTION_TEST_CASES = [
     openai.error.SignatureVerificationError("message", "sig_header?"),
     openai.error.APIConnectionError("message!", should_retry=True),
     openai.error.TryAgain(),
+    openai.error.Timeout(),
     openai.error.APIError(
         message="message",
         code=400,
openai/api_requestor.py
@@ -3,10 +3,11 @@ import platform
 import threading
 import warnings
 from json import JSONDecodeError
-from typing import Dict, Iterator, Optional, Tuple, Union
+from typing import Dict, Iterator, Optional, Tuple, Union, overload
 from urllib.parse import urlencode, urlsplit, urlunsplit
 
 import requests
+from typing_extensions import Literal
 
 import openai
 from openai import error, util, version
@@ -99,6 +100,63 @@ class APIRequestor:
             str += " (%s)" % (info["url"],)
         return str
 
+    @overload
+    def request(
+        self,
+        method,
+        url,
+        params,
+        headers,
+        files,
+        stream: Literal[True],
+        request_id: Optional[str] = ...,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+        pass
+
+    @overload
+    def request(
+        self,
+        method,
+        url,
+        params=...,
+        headers=...,
+        files=...,
+        *,
+        stream: Literal[True],
+        request_id: Optional[str] = ...,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+        pass
+
+    @overload
+    def request(
+        self,
+        method,
+        url,
+        params=...,
+        headers=...,
+        files=...,
+        stream: Literal[False] = ...,
+        request_id: Optional[str] = ...,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+    ) -> Tuple[OpenAIResponse, bool, str]:
+        pass
+
+    @overload
+    def request(
+        self,
+        method,
+        url,
+        params=...,
+        headers=...,
+        files=...,
+        stream: bool = ...,
+        request_id: Optional[str] = ...,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
+        pass
+
     def request(
         self,
         method,
@@ -106,8 +164,9 @@ class APIRequestor:
         params=None,
         headers=None,
         files=None,
-        stream=False,
+        stream: bool = False,
         request_id: Optional[str] = None,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
         result = self.request_raw(
             method.lower(),
@@ -117,6 +176,7 @@ class APIRequestor:
             files=files,
             stream=stream,
             request_id=request_id,
+            request_timeout=request_timeout,
         )
         resp, got_stream = self._interpret_response(result, stream)
         return resp, got_stream, self.api_key
@@ -179,7 +239,11 @@ class APIRequestor:
             return error.APIError(message, rbody, rcode, resp, rheaders)
         else:
             return error.APIError(
-                error_data.get("message"), rbody, rcode, resp, rheaders
+                f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
+                rbody,
+                rcode,
+                resp,
+                rheaders,
             )
 
     def request_headers(
@@ -256,6 +320,7 @@ class APIRequestor:
         files=None,
         stream: bool = False,
         request_id: Optional[str] = None,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ) -> requests.Response:
         abs_url = "%s%s" % (self.api_base, url)
         headers = self._validate_headers(supplied_headers)
@@ -295,8 +360,10 @@ class APIRequestor:
                 data=data,
                 files=files,
                 stream=stream,
-                timeout=TIMEOUT_SECS,
+                timeout=request_timeout if request_timeout else TIMEOUT_SECS,
             )
+        except requests.exceptions.Timeout as e:
+            raise error.Timeout("Request timed out") from e
         except requests.exceptions.RequestException as e:
             raise error.APIConnectionError("Error communicating with OpenAI") from e
         util.log_info(
@@ -304,6 +371,7 @@ class APIRequestor:
             path=abs_url,
             response_code=result.status_code,
             processing_ms=result.headers.get("OpenAI-Processing-Ms"),
+            request_id=result.headers.get("X-Request-Id"),
         )
         # Don't read the whole stream for debug logging unless necessary.
         if openai.log == "debug":
openai/cli.py
@@ -9,7 +9,6 @@ from typing import Optional
 import requests
 
 import openai
-import openai.wandb_logger
 from openai.upload_progress import BufferReader
 from openai.validators import (
     apply_necessary_remediation,
@@ -542,6 +541,8 @@ class FineTune:
 class WandbLogger:
     @classmethod
     def sync(cls, args):
+        import openai.wandb_logger
+
         resp = openai.wandb_logger.WandbLogger.sync(
             id=args.id,
             n_fine_tunes=args.n_fine_tunes,
openai/error.py
@@ -76,6 +76,10 @@ class TryAgain(OpenAIError):
     pass
 
 
+class Timeout(OpenAIError):
+    pass
+
+
 class APIConnectionError(OpenAIError):
     def __init__(
         self,
openai/openai_object.py
@@ -1,6 +1,6 @@
 import json
 from copy import deepcopy
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import openai
 from openai import api_requestor, util
@@ -165,6 +165,7 @@ class OpenAIObject(dict):
         stream=False,
         plain_old_data=False,
         request_id: Optional[str] = None,
+        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ):
         if params is None:
             params = self._retrieve_params
@@ -182,6 +183,7 @@ class OpenAIObject(dict):
             stream=stream,
             headers=headers,
             request_id=request_id,
+            request_timeout=request_timeout,
         )
 
         if stream:
openai/version.py
@@ -1,1 +1,1 @@
-VERSION = "0.22.1"
+VERSION = "0.23.0"
README.md
@@ -52,6 +52,10 @@ completion = openai.Completion.create(engine="ada", prompt="Hello world")
 print(completion.choices[0].text)
 ```
 
+
+### Params
+All endpoints have a `.create` method that support a `request_timeout` param.  This param takes a `Union[float, Tuple[float, float]]` and will raise a `openai.error.TimeoutError` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts).
+
 ### Microsoft Azure Endpoints
 
 In order to use the library with Microsoft Azure endpoints, you need to set the api_type, api_base and api_version in addition to the api_key. The api_type must be set to 'azure' and the others correspond to the properties of your endpoint.
setup.py
@@ -19,7 +19,8 @@ setup(
         "pandas>=1.2.3",  # Needed for CLI fine-tuning data preparation tool
         "pandas-stubs>=1.1.0.11",  # Needed for type hints for mypy
         "openpyxl>=3.0.7",  # Needed for CLI fine-tuning data preparation tool xlsx format
-        "numpy>=1.22.0",  # To address a vuln in <1.21.6
+        "numpy",
+        "typing_extensions",  # Needed for type hints for mypy
     ],
     extras_require={
         "dev": ["black~=21.6b0", "pytest==6.*"],