Commit 3edecbee
Changed files (11)
openai/api_resources/abstract/api_resource.py
@@ -13,14 +13,21 @@ class APIResource(OpenAIObject):
azure_deployments_prefix = "deployments"
@classmethod
- def retrieve(cls, id, api_key=None, request_id=None, **params):
+ def retrieve(
+ cls, id, api_key=None, request_id=None, request_timeout=None, **params
+ ):
instance = cls(id, api_key, **params)
- instance.refresh(request_id=request_id)
+ instance.refresh(request_id=request_id, request_timeout=request_timeout)
return instance
- def refresh(self, request_id=None):
+ def refresh(self, request_id=None, request_timeout=None):
self.refresh_from(
- self.request("get", self.instance_url(), request_id=request_id)
+ self.request(
+ "get",
+ self.instance_url(),
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
)
return self
openai/api_resources/abstract/engine_api_resource.py
@@ -77,7 +77,7 @@ class EngineAPIResource(APIResource):
timeout = params.pop("timeout", None)
stream = params.get("stream", False)
headers = params.pop("headers", None)
-
+ request_timeout = params.pop("request_timeout", None)
typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
if deployment_id is None and engine is None:
@@ -119,6 +119,7 @@ class EngineAPIResource(APIResource):
headers=headers,
stream=stream,
request_id=request_id,
+ request_timeout=request_timeout,
)
if stream:
openai/tests/test_endpoints.py
@@ -1,7 +1,10 @@
import io
import json
+import pytest
+
import openai
+from openai import error
# FILE TESTS
@@ -34,3 +37,24 @@ def test_completions_model():
result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
assert len(result.choices) == 5
assert result.model.startswith("ada")
+
+
+def test_timeout_raises_error():
+ # A query that should take a while to return
+ with pytest.raises(error.Timeout):
+ openai.Completion.create(
+ prompt="test" * 1000,
+ n=10,
+ model="ada",
+ max_tokens=100,
+ request_timeout=0.01,
+ )
+
+
+def test_timeout_does_not_error():
+ # A query that should be fast
+ openai.Completion.create(
+ prompt="test",
+ model="ada",
+ request_timeout=10,
+ )
openai/tests/test_exceptions.py
@@ -21,6 +21,7 @@ EXCEPTION_TEST_CASES = [
openai.error.SignatureVerificationError("message", "sig_header?"),
openai.error.APIConnectionError("message!", should_retry=True),
openai.error.TryAgain(),
+ openai.error.Timeout(),
openai.error.APIError(
message="message",
code=400,
openai/api_requestor.py
@@ -3,10 +3,11 @@ import platform
import threading
import warnings
from json import JSONDecodeError
-from typing import Dict, Iterator, Optional, Tuple, Union
+from typing import Dict, Iterator, Optional, Tuple, Union, overload
from urllib.parse import urlencode, urlsplit, urlunsplit
import requests
+from typing_extensions import Literal
import openai
from openai import error, util, version
@@ -99,6 +100,63 @@ class APIRequestor:
str += " (%s)" % (info["url"],)
return str
+ @overload
+ def request(
+ self,
+ method,
+ url,
+ params,
+ headers,
+ files,
+ stream: Literal[True],
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+ pass
+
+ @overload
+ def request(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ *,
+ stream: Literal[True],
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
+ pass
+
+ @overload
+ def request(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ stream: Literal[False] = ...,
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[OpenAIResponse, bool, str]:
+ pass
+
+ @overload
+ def request(
+ self,
+ method,
+ url,
+ params=...,
+ headers=...,
+ files=...,
+ stream: bool = ...,
+ request_id: Optional[str] = ...,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
+ ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
+ pass
+
def request(
self,
method,
@@ -106,8 +164,9 @@ class APIRequestor:
params=None,
headers=None,
files=None,
- stream=False,
+ stream: bool = False,
request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
result = self.request_raw(
method.lower(),
@@ -117,6 +176,7 @@ class APIRequestor:
files=files,
stream=stream,
request_id=request_id,
+ request_timeout=request_timeout,
)
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
@@ -179,7 +239,11 @@ class APIRequestor:
return error.APIError(message, rbody, rcode, resp, rheaders)
else:
return error.APIError(
- error_data.get("message"), rbody, rcode, resp, rheaders
+ f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
+ rbody,
+ rcode,
+ resp,
+ rheaders,
)
def request_headers(
@@ -256,6 +320,7 @@ class APIRequestor:
files=None,
stream: bool = False,
request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
abs_url = "%s%s" % (self.api_base, url)
headers = self._validate_headers(supplied_headers)
@@ -295,8 +360,10 @@ class APIRequestor:
data=data,
files=files,
stream=stream,
- timeout=TIMEOUT_SECS,
+ timeout=request_timeout if request_timeout else TIMEOUT_SECS,
)
+ except requests.exceptions.Timeout as e:
+ raise error.Timeout("Request timed out") from e
except requests.exceptions.RequestException as e:
raise error.APIConnectionError("Error communicating with OpenAI") from e
util.log_info(
@@ -304,6 +371,7 @@ class APIRequestor:
path=abs_url,
response_code=result.status_code,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
+ request_id=result.headers.get("X-Request-Id"),
)
# Don't read the whole stream for debug logging unless necessary.
if openai.log == "debug":
openai/cli.py
@@ -9,7 +9,6 @@ from typing import Optional
import requests
import openai
-import openai.wandb_logger
from openai.upload_progress import BufferReader
from openai.validators import (
apply_necessary_remediation,
@@ -542,6 +541,8 @@ class FineTune:
class WandbLogger:
@classmethod
def sync(cls, args):
+ import openai.wandb_logger
+
resp = openai.wandb_logger.WandbLogger.sync(
id=args.id,
n_fine_tunes=args.n_fine_tunes,
openai/error.py
@@ -76,6 +76,10 @@ class TryAgain(OpenAIError):
pass
+class Timeout(OpenAIError):
+ pass
+
+
class APIConnectionError(OpenAIError):
def __init__(
self,
openai/openai_object.py
@@ -1,6 +1,6 @@
import json
from copy import deepcopy
-from typing import Optional
+from typing import Optional, Tuple, Union
import openai
from openai import api_requestor, util
@@ -165,6 +165,7 @@ class OpenAIObject(dict):
stream=False,
plain_old_data=False,
request_id: Optional[str] = None,
+ request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
):
if params is None:
params = self._retrieve_params
@@ -182,6 +183,7 @@ class OpenAIObject(dict):
stream=stream,
headers=headers,
request_id=request_id,
+ request_timeout=request_timeout,
)
if stream:
openai/version.py
@@ -1,1 +1,1 @@
-VERSION = "0.22.1"
+VERSION = "0.23.0"
README.md
@@ -52,6 +52,10 @@ completion = openai.Completion.create(engine="ada", prompt="Hello world")
print(completion.choices[0].text)
```
+
+### Params
+All endpoints have a `.create` method that supports a `request_timeout` param. This param takes a `Union[float, Tuple[float, float]]` and will raise an `openai.error.Timeout` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts).
+
### Microsoft Azure Endpoints
In order to use the library with Microsoft Azure endpoints, you need to set the api_type, api_base and api_version in addition to the api_key. The api_type must be set to 'azure' and the others correspond to the properties of your endpoint.
setup.py
@@ -19,7 +19,8 @@ setup(
"pandas>=1.2.3", # Needed for CLI fine-tuning data preparation tool
"pandas-stubs>=1.1.0.11", # Needed for type hints for mypy
"openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format
- "numpy>=1.22.0", # To address a vuln in <1.21.6
+ "numpy",
+ "typing_extensions", # Needed for type hints for mypy
],
extras_require={
"dev": ["black~=21.6b0", "pytest==6.*"],