Commit 946aa8fa
Changed files (4)
openai
openai/api_resources/abstract/engine_api_resource.py
@@ -63,6 +63,7 @@ class EngineAPIResource(APIResource):
engine = params.pop("engine", None)
timeout = params.pop("timeout", None)
stream = params.get("stream", False)
+ headers = params.pop("headers", None)
if engine is None and cls.engine_required:
raise error.InvalidRequestError(
"Must provide an 'engine' parameter to create a %s" % cls, "engine"
@@ -87,7 +88,12 @@ class EngineAPIResource(APIResource):
)
url = cls.class_url(engine, api_type, api_version)
response, _, api_key = requestor.request(
- "post", url, params, stream=stream, request_id=request_id
+ "post",
+ url,
+ params=params,
+ headers=headers,
+ stream=stream,
+ request_id=request_id,
)
if stream:
openai/tests/test_endpoints.py
@@ -28,3 +28,9 @@ def test_completions_multiple_prompts():
prompt=["This was a test", "This was another test"], n=5, engine="ada"
)
assert len(result.choices) == 10
+
+
+def test_completions_model():
+ result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
+ assert len(result.choices) == 5
+ assert result.model.startswith("ada:")
openai/api_requestor.py
@@ -100,8 +100,8 @@ class APIRequestor:
result = self.request_raw(
method.lower(),
url,
- params,
- headers,
+ params=params,
+ supplied_headers=headers,
files=files,
stream=stream,
request_id=request_id,
@@ -212,18 +212,41 @@ class APIRequestor:
return headers
+ def _validate_headers(
+ self, supplied_headers: Optional[Dict[str, str]]
+ ) -> Dict[str, str]:
+ headers: Dict[str, str] = {}
+ if supplied_headers is None:
+ return headers
+
+ if not isinstance(supplied_headers, dict):
+ raise TypeError("Headers must be a dictionary")
+
+ for k, v in supplied_headers.items():
+ if not isinstance(k, str):
+ raise TypeError("Header keys must be strings")
+ if not isinstance(v, str):
+ raise TypeError("Header values must be strings")
+ headers[k] = v
+
+ # NOTE: It is possible to do more validation of the headers, but a request could always
+ # be made to the API manually with invalid headers, so we need to handle them server side.
+
+ return headers
+
def request_raw(
self,
method,
url,
+ *,
params=None,
- supplied_headers=None,
+ supplied_headers: Dict[str, str] = None,
files=None,
- stream=False,
+ stream: bool = False,
request_id: Optional[str] = None,
) -> requests.Response:
abs_url = "%s%s" % (self.api_base, url)
- headers = {}
+ headers = self._validate_headers(supplied_headers)
data = None
if method == "get" or method == "delete":
@@ -246,8 +269,6 @@ class APIRequestor:
)
headers = self.request_headers(method, headers, request_id)
- if supplied_headers is not None:
- headers.update(supplied_headers)
util.log_info("Request to OpenAI API", method=method, path=abs_url)
util.log_debug("Post details", data=data, api_version=self.api_version)
openai/openai_object.py
@@ -176,7 +176,12 @@ class OpenAIObject(dict):
organization=self.organization,
)
response, stream, api_key = requestor.request(
- method, url, params, stream=stream, headers=headers, request_id=request_id
+ method,
+ url,
+ params=params,
+ stream=stream,
+ headers=headers,
+ request_id=request_id,
)
if stream: