Commit 77e5cf97
Changed files (4)
openai/tests/test_api_requestor.py
@@ -67,3 +67,35 @@ def test_requestor_azure_ad_headers() -> None:
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test_key"
+
+
+@pytest.mark.requestor
+def test_requestor_cycle_sessions(mocker: MockerFixture) -> None:
+    # HACK: we need to purge the _thread_context to not interfere
+    # with other tests
+    from openai.api_requestor import _thread_context
+
+    delattr(_thread_context, "session")
+
+    api_requestor = APIRequestor(key="test_key", api_type="azure_ad")
+
+    mock_session = mocker.MagicMock()
+    mocker.patch("openai.api_requestor._make_session", lambda: mock_session)
+
+    # We don't call `session.close()` if not enough time has elapsed
+    api_requestor.request_raw("get", "http://example.com")
+    mock_session.request.assert_called()
+    api_requestor.request_raw("get", "http://example.com")
+    mock_session.close.assert_not_called()
+
+    mocker.patch("openai.api_requestor.MAX_SESSION_LIFETIME_SECS", 0)
+
+    # Because the lifetime is now 0, the original session will be closed
+    # before the next call and a new session will be created
+    mock_session_2 = mocker.MagicMock()
+    mocker.patch("openai.api_requestor._make_session", lambda: mock_session_2)
+    api_requestor.request_raw("get", "http://example.com")
+    mock_session.close.assert_called()
+    mock_session_2.request.assert_called()
+
+    delattr(_thread_context, "session")
openai/tests/test_util.py
@@ -1,3 +1,4 @@
+import json
from tempfile import NamedTemporaryFile

import pytest
@@ -28,3 +29,27 @@ def test_openai_api_key_path_with_malformed_key(api_key_file) -> None:
    api_key_file.flush()
    with pytest.raises(ValueError, match="Malformed API key"):
        util.default_api_key()
+
+
+def test_key_order_openai_object_rendering() -> None:
+    sample_response = {
+        "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
+        "object": "chat.completion",
+        "created": 1685855844,
+        "model": "gpt-3.5-turbo-0301",
+        "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was the first time that the World Series was played at a neutral site because of the COVID-19 pandemic.",
+                },
+                "finish_reason": "stop",
+                "index": 0,
+            }
+        ],
+    }
+
+    oai_object = util.convert_to_openai_object(sample_response)
+    # The `__str__` method used to sort keys while dumping to JSON
+    assert list(json.loads(str(oai_object)).keys()) == list(sample_response.keys())
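
The ordering behaviour this test relies on is standard `json` behaviour: `json.dumps` preserves dict insertion order unless `sort_keys=True` is passed. A minimal standalone illustration, not part of the commit (the payload values are made up):

import json

payload = {"id": "chatcmpl-123", "object": "chat.completion", "created": 1685855844}

# Default: insertion order survives a dump/load round trip
assert list(json.loads(json.dumps(payload)).keys()) == ["id", "object", "created"]

# With sort_keys=True (the old __str__ behaviour) keys come back alphabetized
assert list(json.loads(json.dumps(payload, sort_keys=True)).keys()) == ["created", "id", "object"]
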
openai/api_requestor.py
@@ -3,6 +3,7 @@ import json
import platform
import sys
import threading
+import time
import warnings
from contextlib import asynccontextmanager
from json import JSONDecodeError
@@ -32,6 +33,7 @@ from openai.openai_response import OpenAIResponse
from openai.util import ApiType

TIMEOUT_SECS = 600
+MAX_SESSION_LIFETIME_SECS = 180
MAX_CONNECTION_RETRIES = 2

# Has one attribute per thread, 'session'.
@@ -516,6 +518,14 @@ class APIRequestor:
        if not hasattr(_thread_context, "session"):
            _thread_context.session = _make_session()
+            _thread_context.session_create_time = time.time()
+        elif (
+            time.time() - getattr(_thread_context, "session_create_time", 0)
+            >= MAX_SESSION_LIFETIME_SECS
+        ):
+            _thread_context.session.close()
+            _thread_context.session = _make_session()
+            _thread_context.session_create_time = time.time()
        try:
            result = _thread_context.session.request(
                method,
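
The session handling added to `request_raw` boils down to: remember when the session was created, and close and replace it once it is older than `MAX_SESSION_LIFETIME_SECS`. A standalone sketch of that pattern (the `SessionCycler` class below is illustrative only; the real code keeps the session and its creation time on a `threading.local()`):

import time

import requests

MAX_SESSION_LIFETIME_SECS = 180  # same value the commit introduces


class SessionCycler:
    """Hand out a requests.Session, replacing it once it exceeds its max lifetime."""

    def __init__(self, max_lifetime: float = MAX_SESSION_LIFETIME_SECS) -> None:
        self.max_lifetime = max_lifetime
        self._session = None
        self._created_at = 0.0

    def get(self) -> requests.Session:
        now = time.time()
        if self._session is None:
            self._session = requests.Session()
            self._created_at = now
        elif now - self._created_at >= self.max_lifetime:
            # Close the stale session so pooled keep-alive connections are
            # released, then start a fresh one, mirroring request_raw above.
            self._session.close()
            self._session = requests.Session()
            self._created_at = now
        return self._session

Like the diff, the sketch refreshes the session lazily on the next request rather than with a background timer.
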
openai/openai_object.py
@@ -278,7 +278,7 @@ class OpenAIObject(dict):
    def __str__(self):
        obj = self.to_dict_recursive()
-        return json.dumps(obj, sort_keys=True, indent=2)
+        return json.dumps(obj, indent=2)

    def to_dict(self):
        return dict(self)