Commit e7ee4daa
Changed files (3)
openai
openai/api_resources/abstract/engine_api_resource.py
@@ -236,7 +236,7 @@ class EngineAPIResource(APIResource):
engine=engine,
plain_old_data=cls.plain_old_data,
)
- for line in response
+ async for line in response
)
else:
obj = util.convert_to_openai_object(
openai/tests/asyncio/test_endpoints.py
@@ -5,6 +5,7 @@ import pytest
import openai
from openai import error
+from aiohttp import ClientSession
pytestmark = [pytest.mark.asyncio]
@@ -63,3 +64,26 @@ async def test_timeout_does_not_error():
model="ada",
request_timeout=10,
)
+
+
+async def test_completions_stream_finishes_global_session():
+ async with ClientSession() as session:
+ openai.aiosession.set(session)
+
+ # A query that should be fast
+ parts = []
+ async for part in await openai.Completion.acreate(
+ prompt="test", model="ada", request_timeout=3, stream=True
+ ):
+ parts.append(part)
+ assert len(parts) > 1
+
+
+async def test_completions_stream_finishes_local_session():
+ # A query that should be fast
+ parts = []
+ async for part in await openai.Completion.acreate(
+ prompt="test", model="ada", request_timeout=3, stream=True
+ ):
+ parts.append(part)
+ assert len(parts) > 1
openai/api_requestor.py
@@ -89,9 +89,9 @@ def _make_session() -> requests.Session:
return s
-def parse_stream_helper(line):
+def parse_stream_helper(line: bytes):
if line:
- if line == b"data: [DONE]":
+ if line.strip() == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
@@ -111,7 +111,7 @@ def parse_stream(rbody):
async def parse_stream_async(rbody: aiohttp.StreamReader):
- async for line in rbody:
+ async for line, _ in rbody.iter_chunks():
_line = parse_stream_helper(line)
if _line is not None:
yield _line
@@ -294,18 +294,31 @@ class APIRequestor:
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
- async with aiohttp_session() as session:
- result = await self.arequest_raw(
- method.lower(),
- url,
- session,
- params=params,
- supplied_headers=headers,
- files=files,
- request_id=request_id,
- request_timeout=request_timeout,
- )
- resp, got_stream = await self._interpret_async_response(result, stream)
+ ctx = aiohttp_session()
+ session = await ctx.__aenter__()
+ result = await self.arequest_raw(
+ method.lower(),
+ url,
+ session,
+ params=params,
+ supplied_headers=headers,
+ files=files,
+ request_id=request_id,
+ request_timeout=request_timeout,
+ )
+ resp, got_stream = await self._interpret_async_response(result, stream)
+ if got_stream:
+
+ async def wrap_resp():
+ try:
+ async for r in resp:
+ yield r
+ finally:
+ await ctx.__aexit__(None, None, None)
+
+ return wrap_resp(), got_stream, self.api_key
+ else:
+ await ctx.__aexit__(None, None, None)
return resp, got_stream, self.api_key
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
@@ -507,7 +520,9 @@ class APIRequestor:
except requests.exceptions.Timeout as e:
raise error.Timeout("Request timed out: {}".format(e)) from e
except requests.exceptions.RequestException as e:
- raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
+ raise error.APIConnectionError(
+ "Error communicating with OpenAI: {}".format(e)
+ ) from e
util.log_info(
"OpenAI API response",
path=abs_url,