Commit c556584e
Changed files (3)
openai
openai/tests/test_endpoints.py
@@ -2,6 +2,7 @@ import io
import json
import pytest
+import requests
import openai
from openai import error
@@ -86,3 +87,32 @@ def test_timeout_does_not_error():
model="ada",
request_timeout=10,
)
+
+
+def test_user_session():
+ with requests.Session() as session:
+ openai.requestssession = session
+
+ completion = openai.Completion.create(
+ prompt="hello world",
+ model="ada",
+ )
+ assert completion
+
+
+def test_user_session_factory():
+ def factory():
+ session = requests.Session()
+ session.mount(
+ "https://",
+ requests.adapters.HTTPAdapter(max_retries=4),
+ )
+ return session
+
+ openai.requestssession = factory
+
+ completion = openai.Completion.create(
+ prompt="hello world",
+ model="ada",
+ )
+ assert completion
openai/__init__.py
@@ -4,7 +4,7 @@
import os
import sys
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union, Callable
from contextvars import ContextVar
@@ -36,6 +36,7 @@ from openai.error import APIError, InvalidRequestError, OpenAIError
from openai.version import VERSION
if TYPE_CHECKING:
+ import requests
from aiohttp import ClientSession
api_key = os.environ.get("OPENAI_API_KEY")
@@ -58,6 +59,10 @@ ca_bundle_path = None # No longer used, feature was removed
debug = False
log = None # Set to either 'debug' or 'info', controls console logging
+requestssession: Optional[
+ Union["requests.Session", Callable[[], "requests.Session"]]
+] = None # Provide a requests.Session or Session factory.
+
aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
"aiohttp-session", default=None
) # Acts as a global aiohttp ClientSession that reuses connections.
openai/api_requestor.py
@@ -76,6 +76,10 @@ def _aiohttp_proxies_arg(proxy) -> Optional[str]:
def _make_session() -> requests.Session:
+ if openai.requestssession:
+ if isinstance(openai.requestssession, requests.Session):
+ return openai.requestssession
+ return openai.requestssession()
if not openai.verify_ssl_certs:
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
s = requests.Session()