Commit 532d3df1
Changed files (1)
openai
openai/api_requestor.py
@@ -6,11 +6,10 @@ import sys
import threading
import time
import warnings
-from contextlib import asynccontextmanager
from json import JSONDecodeError
from typing import (
+ AsyncContextManager,
AsyncGenerator,
- AsyncIterator,
Callable,
Dict,
Iterator,
@@ -368,8 +367,9 @@ class APIRequestor:
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
- ctx = aiohttp_session()
+ ctx = AioHTTPSession()
session = await ctx.__aenter__()
+ result = None
try:
result = await self.arequest_raw(
method.lower(),
@@ -383,6 +383,9 @@ class APIRequestor:
)
resp, got_stream = await self._interpret_async_response(result, stream)
except Exception:
+ # Close the request before exiting session context.
+ if result is not None:
+ result.release()
await ctx.__aexit__(None, None, None)
raise
if got_stream:
@@ -393,10 +396,15 @@ class APIRequestor:
async for r in resp:
yield r
finally:
+ # Close the request before exiting session context. Important to do it here
+ # as if stream is not fully exhausted, we need to close the request nevertheless.
+ result.release()
await ctx.__aexit__(None, None, None)
return wrap_resp(), got_stream, self.api_key
else:
+ # Close the request before exiting session context.
+ result.release()
await ctx.__aexit__(None, None, None)
return resp, got_stream, self.api_key
@@ -770,11 +778,22 @@ class APIRequestor:
return resp
-@asynccontextmanager
-async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
- user_set_session = openai.aiosession.get()
- if user_set_session:
- yield user_set_session
- else:
- async with aiohttp.ClientSession() as session:
- yield session
+class AioHTTPSession(AsyncContextManager):
+ def __init__(self):
+ self._session = None
+ self._should_close_session = False
+
+ async def __aenter__(self):
+ self._session = openai.aiosession.get()
+ if self._session is None:
+ self._session = await aiohttp.ClientSession().__aenter__()
+ self._should_close_session = True
+
+ return self._session
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ if self._session is None:
+ raise RuntimeError("Session is not initialized")
+
+ if self._should_close_session:
+ await self._session.__aexit__(exc_type, exc_value, traceback)
\ No newline at end of file