Commit c556584e

Krista Pratico <krpratic@microsoft.com>
2023-04-28 01:01:49
allow for user passed requests.Session (#390)
1 parent 96e7642
Changed files (3)
openai/tests/test_endpoints.py
@@ -2,6 +2,7 @@ import io
 import json
 
 import pytest
+import requests
 
 import openai
 from openai import error
@@ -86,3 +87,32 @@ def test_timeout_does_not_error():
         model="ada",
         request_timeout=10,
     )
+
+
+def test_user_session():
+     with requests.Session() as session:
+        openai.requestssession = session
+
+        completion = openai.Completion.create(
+            prompt="hello world",
+            model="ada",
+        )
+        assert completion
+
+
+def test_user_session_factory():
+    def factory():
+        session = requests.Session()
+        session.mount(
+            "https://",
+            requests.adapters.HTTPAdapter(max_retries=4),
+        )
+        return session
+
+    openai.requestssession = factory
+
+    completion = openai.Completion.create(
+        prompt="hello world",
+        model="ada",
+    )
+    assert completion
openai/__init__.py
@@ -4,7 +4,7 @@
 
 import os
 import sys
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union, Callable
 
 from contextvars import ContextVar
 
@@ -36,6 +36,7 @@ from openai.error import APIError, InvalidRequestError, OpenAIError
 from openai.version import VERSION
 
 if TYPE_CHECKING:
+    import requests
     from aiohttp import ClientSession
 
 api_key = os.environ.get("OPENAI_API_KEY")
@@ -58,6 +59,10 @@ ca_bundle_path = None  # No longer used, feature was removed
 debug = False
 log = None  # Set to either 'debug' or 'info', controls console logging
 
+requestssession: Optional[
+    Union["requests.Session", Callable[[], "requests.Session"]]
+] = None # Provide a requests.Session or Session factory.
+
 aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
     "aiohttp-session", default=None
 )  # Acts as a global aiohttp ClientSession that reuses connections.
openai/api_requestor.py
@@ -76,6 +76,10 @@ def _aiohttp_proxies_arg(proxy) -> Optional[str]:
 
 
 def _make_session() -> requests.Session:
+    if openai.requestssession:
+        if isinstance(openai.requestssession, requests.Session):
+            return openai.requestssession
+        return openai.requestssession()
     if not openai.verify_ssl_certs:
         warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
     s = requests.Session()