Commit 7af43ce0

Damien Deville <damien@openai.com>
2023-01-07 04:15:42
Fix API requestor hanging when not using a global session (#167)
1 parent fb4b672
Changed files (2)
openai
openai/api_resources/file.py
@@ -192,16 +192,17 @@ class File(ListableAPIResource, DeletableAPIResource):
             id, api_key, api_base, api_type, api_version, organization
         )
 
-        result = await requestor.arequest_raw("get", url)
-        if not 200 <= result.status < 300:
-            raise requestor.handle_error_response(
-                result.content,
-                result.status,
-                json.loads(cast(bytes, result.content)),
-                result.headers,
-                stream_error=False,
-            )
-        return result.content
+        async with api_requestor.aiohttp_session() as session:
+            result = await requestor.arequest_raw("get", url, session)
+            if not 200 <= result.status < 300:
+                raise requestor.handle_error_response(
+                    result.content,
+                    result.status,
+                    json.loads(cast(bytes, result.content)),
+                    result.headers,
+                    stream_error=False,
+                )
+            return result.content
 
     @classmethod
     def __find_matching_files(cls, name, all_files, purpose):
openai/api_requestor.py
@@ -4,8 +4,18 @@ import platform
 import sys
 import threading
 import warnings
+from contextlib import asynccontextmanager
 from json import JSONDecodeError
-from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
+from typing import (
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 from urllib.parse import urlencode, urlsplit, urlunsplit
 
 import aiohttp
@@ -284,17 +294,19 @@ class APIRequestor:
         request_id: Optional[str] = None,
         request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
-        result = await self.arequest_raw(
-            method.lower(),
-            url,
-            params=params,
-            supplied_headers=headers,
-            files=files,
-            request_id=request_id,
-            request_timeout=request_timeout,
-        )
-        resp, got_stream = await self._interpret_async_response(result, stream)
-        return resp, got_stream, self.api_key
+        async with aiohttp_session() as session:
+            result = await self.arequest_raw(
+                method.lower(),
+                url,
+                session,
+                params=params,
+                supplied_headers=headers,
+                files=files,
+                request_id=request_id,
+                request_timeout=request_timeout,
+            )
+            resp, got_stream = await self._interpret_async_response(result, stream)
+            return resp, got_stream, self.api_key
 
     def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
         try:
@@ -514,6 +526,7 @@ class APIRequestor:
         self,
         method,
         url,
+        session,
         *,
         params=None,
         supplied_headers: Optional[Dict[str, str]] = None,
@@ -534,7 +547,6 @@ class APIRequestor:
             timeout = aiohttp.ClientTimeout(
                 total=request_timeout if request_timeout else TIMEOUT_SECS
             )
-        user_set_session = openai.aiosession.get()
 
         if files:
             # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
@@ -552,11 +564,7 @@ class APIRequestor:
             "timeout": timeout,
         }
         try:
-            if user_set_session:
-                result = await user_set_session.request(**request_kwargs)
-            else:
-                async with aiohttp.ClientSession() as session:
-                    result = await session.request(**request_kwargs)
+            result = await session.request(**request_kwargs)
             util.log_info(
                 "OpenAI API response",
                 path=abs_url,
@@ -648,3 +656,13 @@ class APIRequestor:
                 rbody, rcode, resp.data, rheaders, stream_error=stream_error
             )
         return resp
+
+
+@asynccontextmanager
+async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
+    user_set_session = openai.aiosession.get()
+    if user_set_session:
+        yield user_set_session
+    else:
+        async with aiohttp.ClientSession() as session:
+            yield session