Commit 532d3df1

Hynek Kydlíček <39408646+hynky1999@users.noreply.github.com>
2023-09-26 11:16:58
🐛 fixed asyncio bugs (#584)
1 parent 8944bd1
Changed files (1)
openai/api_requestor.py
@@ -6,11 +6,10 @@ import sys
 import threading
 import time
 import warnings
-from contextlib import asynccontextmanager
 from json import JSONDecodeError
 from typing import (
+    AsyncContextManager,
     AsyncGenerator,
-    AsyncIterator,
     Callable,
     Dict,
     Iterator,
@@ -368,8 +367,9 @@ class APIRequestor:
         request_id: Optional[str] = None,
         request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
-        ctx = aiohttp_session()
+        ctx = AioHTTPSession()
         session = await ctx.__aenter__()
+        result = None
         try:
             result = await self.arequest_raw(
                 method.lower(),
@@ -383,6 +383,9 @@ class APIRequestor:
             )
             resp, got_stream = await self._interpret_async_response(result, stream)
         except Exception:
+            # Close the request before exiting session context.
+            if result is not None:
+                result.release()
             await ctx.__aexit__(None, None, None)
             raise
         if got_stream:
@@ -393,10 +396,15 @@ class APIRequestor:
                     async for r in resp:
                         yield r
                 finally:
+                    # Close the request before exiting session context. Important to do it here
+                    # as if stream is not fully exhausted, we need to close the request nevertheless.
+                    result.release()
                     await ctx.__aexit__(None, None, None)
 
             return wrap_resp(), got_stream, self.api_key
         else:
+            # Close the request before exiting session context.
+            result.release()
             await ctx.__aexit__(None, None, None)
             return resp, got_stream, self.api_key
 
@@ -770,11 +778,22 @@ class APIRequestor:
         return resp
 
 
-@asynccontextmanager
-async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
-    user_set_session = openai.aiosession.get()
-    if user_set_session:
-        yield user_set_session
-    else:
-        async with aiohttp.ClientSession() as session:
-            yield session
+class AioHTTPSession(AsyncContextManager):
+    def __init__(self):
+        self._session = None
+        self._should_close_session = False
+
+    async def __aenter__(self):
+        self._session = openai.aiosession.get()
+        if self._session is None:
+            self._session = await aiohttp.ClientSession().__aenter__()
+            self._should_close_session = True
+
+        return self._session
+    
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        if self._session is None:
+            raise RuntimeError("Session is not initialized")
+
+        if self._should_close_session:
+            await self._session.__aexit__(exc_type, exc_value, traceback)
\ No newline at end of file