Commit 946aa8fa

hallacy <hallacy@openai.com>
2022-02-03 05:59:23
[headers] Allow user provided headers in completion (#116) (#71) tag: v0.14.0
1 parent 62b51ca
Changed files (4)
openai/api_resources/abstract/engine_api_resource.py
@@ -63,6 +63,7 @@ class EngineAPIResource(APIResource):
         engine = params.pop("engine", None)
         timeout = params.pop("timeout", None)
         stream = params.get("stream", False)
+        headers = params.pop("headers", None)
         if engine is None and cls.engine_required:
             raise error.InvalidRequestError(
                 "Must provide an 'engine' parameter to create a %s" % cls, "engine"
@@ -87,7 +88,12 @@ class EngineAPIResource(APIResource):
         )
         url = cls.class_url(engine, api_type, api_version)
         response, _, api_key = requestor.request(
-            "post", url, params, stream=stream, request_id=request_id
+            "post",
+            url,
+            params=params,
+            headers=headers,
+            stream=stream,
+            request_id=request_id,
         )
 
         if stream:
openai/tests/test_endpoints.py
@@ -28,3 +28,9 @@ def test_completions_multiple_prompts():
         prompt=["This was a test", "This was another test"], n=5, engine="ada"
     )
     assert len(result.choices) == 10
+
+
+def test_completions_model():
+    result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
+    assert len(result.choices) == 5
+    assert result.model.startswith("ada:")
openai/api_requestor.py
@@ -100,8 +100,8 @@ class APIRequestor:
         result = self.request_raw(
             method.lower(),
             url,
-            params,
-            headers,
+            params=params,
+            supplied_headers=headers,
             files=files,
             stream=stream,
             request_id=request_id,
@@ -212,18 +212,41 @@ class APIRequestor:
 
         return headers
 
+    def _validate_headers(
+        self, supplied_headers: Optional[Dict[str, str]]
+    ) -> Dict[str, str]:
+        headers: Dict[str, str] = {}
+        if supplied_headers is None:
+            return headers
+
+        if not isinstance(supplied_headers, dict):
+            raise TypeError("Headers must be a dictionary")
+
+        for k, v in supplied_headers.items():
+            if not isinstance(k, str):
+                raise TypeError("Header keys must be strings")
+            if not isinstance(v, str):
+                raise TypeError("Header values must be strings")
+            headers[k] = v
+
+        # NOTE: It is possible to do more validation of the headers, but a request could always
+        # be made to the API manually with invalid headers, so we need to handle them server side.
+
+        return headers
+
     def request_raw(
         self,
         method,
         url,
+        *,
         params=None,
-        supplied_headers=None,
+        supplied_headers: Dict[str, str] = None,
         files=None,
-        stream=False,
+        stream: bool = False,
         request_id: Optional[str] = None,
     ) -> requests.Response:
         abs_url = "%s%s" % (self.api_base, url)
-        headers = {}
+        headers = self._validate_headers(supplied_headers)
 
         data = None
         if method == "get" or method == "delete":
@@ -246,8 +269,6 @@ class APIRequestor:
             )
 
         headers = self.request_headers(method, headers, request_id)
-        if supplied_headers is not None:
-            headers.update(supplied_headers)
 
         util.log_info("Request to OpenAI API", method=method, path=abs_url)
         util.log_debug("Post details", data=data, api_version=self.api_version)
openai/openai_object.py
@@ -176,7 +176,12 @@ class OpenAIObject(dict):
             organization=self.organization,
         )
         response, stream, api_key = requestor.request(
-            method, url, params, stream=stream, headers=headers, request_id=request_id
+            method,
+            url,
+            params=params,
+            stream=stream,
+            headers=headers,
+            request_id=request_id,
         )
 
         if stream: