Commit 7af43ce0
Changed files (2)
openai
api_resources
openai/api_resources/file.py
@@ -192,16 +192,17 @@ class File(ListableAPIResource, DeletableAPIResource):
id, api_key, api_base, api_type, api_version, organization
)
- result = await requestor.arequest_raw("get", url)
- if not 200 <= result.status < 300:
- raise requestor.handle_error_response(
- result.content,
- result.status,
- json.loads(cast(bytes, result.content)),
- result.headers,
- stream_error=False,
- )
- return result.content
+ async with api_requestor.aiohttp_session() as session:
+ result = await requestor.arequest_raw("get", url, session)
+ if not 200 <= result.status < 300:
+ raise requestor.handle_error_response(
+ result.content,
+ result.status,
+ json.loads(cast(bytes, result.content)),
+ result.headers,
+ stream_error=False,
+ )
+ return result.content
@classmethod
def __find_matching_files(cls, name, all_files, purpose):
openai/api_requestor.py
@@ -4,8 +4,18 @@ import platform
import sys
import threading
import warnings
+from contextlib import asynccontextmanager
from json import JSONDecodeError
-from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
+from typing import (
+ AsyncGenerator,
+ AsyncIterator,
+ Dict,
+ Iterator,
+ Optional,
+ Tuple,
+ Union,
+ overload,
+)
from urllib.parse import urlencode, urlsplit, urlunsplit
import aiohttp
@@ -284,17 +294,19 @@ class APIRequestor:
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
- result = await self.arequest_raw(
- method.lower(),
- url,
- params=params,
- supplied_headers=headers,
- files=files,
- request_id=request_id,
- request_timeout=request_timeout,
- )
- resp, got_stream = await self._interpret_async_response(result, stream)
- return resp, got_stream, self.api_key
+ async with aiohttp_session() as session:
+ result = await self.arequest_raw(
+ method.lower(),
+ url,
+ session,
+ params=params,
+ supplied_headers=headers,
+ files=files,
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+ resp, got_stream = await self._interpret_async_response(result, stream)
+ return resp, got_stream, self.api_key
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
try:
@@ -514,6 +526,7 @@ class APIRequestor:
self,
method,
url,
+ session,
*,
params=None,
supplied_headers: Optional[Dict[str, str]] = None,
@@ -534,7 +547,6 @@ class APIRequestor:
timeout = aiohttp.ClientTimeout(
total=request_timeout if request_timeout else TIMEOUT_SECS
)
- user_set_session = openai.aiosession.get()
if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
@@ -552,11 +564,7 @@ class APIRequestor:
"timeout": timeout,
}
try:
- if user_set_session:
- result = await user_set_session.request(**request_kwargs)
- else:
- async with aiohttp.ClientSession() as session:
- result = await session.request(**request_kwargs)
+ result = await session.request(**request_kwargs)
util.log_info(
"OpenAI API response",
path=abs_url,
@@ -648,3 +656,13 @@ class APIRequestor:
rbody, rcode, resp.data, rheaders, stream_error=stream_error
)
return resp
+
+
+@asynccontextmanager
+async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
+ user_set_session = openai.aiosession.get()
+ if user_set_session:
+ yield user_set_session
+ else:
+ async with aiohttp.ClientSession() as session:
+ yield session