Commit 5d1a7268
Changed files (2)
openai/tests/asyncio/test_endpoints.py
@@ -2,10 +2,10 @@ import io
import json
import pytest
+from aiohttp import ClientSession
import openai
from openai import error
-from aiohttp import ClientSession
pytestmark = [pytest.mark.asyncio]
openai/api_requestor.py
@@ -89,21 +89,19 @@ def _make_session() -> requests.Session:
    return s


-def parse_stream_helper(line: bytes):
+def parse_stream_helper(line: bytes) -> Optional[str]:
    if line:
        if line.strip() == b"data: [DONE]":
            # return here will cause GeneratorExit exception in urllib3
            # and it will close http connection with TCP Reset
            return None
-        if hasattr(line, "decode"):
-            line = line.decode("utf-8")
-        if line.startswith("data: "):
-            line = line[len("data: ") :]
-        return line
+        if line.startswith(b"data: "):
+            line = line[len(b"data: ") :]
+            return line.decode("utf-8")
    return None


-def parse_stream(rbody):
+def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    for line in rbody:
        _line = parse_stream_helper(line)
        if _line is not None:
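For reference, the reworked helper now matches the "data: " prefix on bytes and decodes only the payload it actually returns. A minimal standalone sketch of that logic (the sample SSE lines are illustrative, not taken from the test suite):

from typing import Optional


def parse_stream_helper(line: bytes) -> Optional[str]:
    # Same logic as the hunk above: compare and strip the prefix as bytes,
    # decode to str only for the payload that is returned.
    if line:
        if line.strip() == b"data: [DONE]":
            return None
        if line.startswith(b"data: "):
            return line[len(b"data: ") :].decode("utf-8")
    return None


assert parse_stream_helper(b'data: {"id": "x"}') == '{"id": "x"}'
assert parse_stream_helper(b"data: [DONE]") is None
assert parse_stream_helper(b"") is None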
@@ -111,10 +109,13 @@ def parse_stream(rbody):
async def parse_stream_async(rbody: aiohttp.StreamReader):
-    async for line, _ in rbody.iter_chunks():
-        _line = parse_stream_helper(line)
-        if _line is not None:
-            yield _line
+    async for chunk, _ in rbody.iter_chunks():
+        # While the `ChunkTupleAsyncStreamIterator` iterator is meant to iterate over chunks (and thus lines),
+        # it seems to still sometimes return multiple lines at a time, so let's split the chunk by lines again.
+        for line in chunk.splitlines():
+            _line = parse_stream_helper(line)
+            if _line is not None:
+                yield _line


class APIRequestor:
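The splitlines() call matters when a single chunk carries more than one SSE line. A small sketch of the new loop, using a plain async generator of bytes in place of aiohttp's iter_chunks() tuples and reusing the parse_stream_helper from the sketch above:

import asyncio
from typing import AsyncIterator


async def fake_chunks() -> AsyncIterator[bytes]:
    # Two events packed into one chunk, then the stream terminator.
    yield b'data: {"id": 1}\ndata: {"id": 2}\n'
    yield b"data: [DONE]\n"


async def parse_chunks(chunks: AsyncIterator[bytes]) -> AsyncIterator[str]:
    async for chunk in chunks:
        # Split the chunk back into individual lines before parsing each one.
        for line in chunk.splitlines():
            _line = parse_stream_helper(line)
            if _line is not None:
                yield _line


async def main() -> None:
    print([event async for event in parse_chunks(fake_chunks())])
    # -> ['{"id": 1}', '{"id": 2}']


asyncio.run(main())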
@@ -296,20 +297,25 @@ class APIRequestor:
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        ctx = aiohttp_session()
        session = await ctx.__aenter__()
-        result = await self.arequest_raw(
-            method.lower(),
-            url,
-            session,
-            params=params,
-            supplied_headers=headers,
-            files=files,
-            request_id=request_id,
-            request_timeout=request_timeout,
-        )
-        resp, got_stream = await self._interpret_async_response(result, stream)
+        try:
+            result = await self.arequest_raw(
+                method.lower(),
+                url,
+                session,
+                params=params,
+                supplied_headers=headers,
+                files=files,
+                request_id=request_id,
+                request_timeout=request_timeout,
+            )
+            resp, got_stream = await self._interpret_async_response(result, stream)
+        except Exception:
+            await ctx.__aexit__(None, None, None)
+            raise
        if got_stream:

            async def wrap_resp():
+                assert isinstance(resp, AsyncGenerator)
                try:
                    async for r in resp:
                        yield r
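The try/except around arequest_raw closes a leak: the aiohttp session is entered manually with ctx.__aenter__(), so if the request or the response interpretation raises, nothing would ever call ctx.__aexit__(). A simplified sketch of that pattern, assuming an aiohttp_session() async context manager like the library's and a hypothetical fetch() helper rather than the actual arequest method:

import contextlib
from typing import AsyncIterator

import aiohttp


@contextlib.asynccontextmanager
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
    async with aiohttp.ClientSession() as session:
        yield session


async def fetch(url: str) -> bytes:
    ctx = aiohttp_session()
    session = await ctx.__aenter__()
    try:
        # Anything that fails here must not leave the session open.
        async with session.get(url) as result:
            body = await result.read()
    except Exception:
        await ctx.__aexit__(None, None, None)
        raise
    # Non-streaming path: done with the session, close it now. The streaming
    # path instead defers this to the wrapping generator's finally block,
    # as in the hunk above.
    await ctx.__aexit__(None, None, None)
    return body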
@@ -612,7 +618,10 @@ class APIRequestor:
        else:
            return (
                self._interpret_response_line(
-                    result.content, result.status_code, result.headers, stream=False
+                    result.content.decode("utf-8"),
+                    result.status_code,
+                    result.headers,
+                    stream=False,
                ),
                False,
            )
@@ -635,13 +644,16 @@ class APIRequestor:
                util.log_warn(e, body=result.content)
            return (
                self._interpret_response_line(
-                    await result.read(), result.status, result.headers, stream=False
+                    (await result.read()).decode("utf-8"),
+                    result.status,
+                    result.headers,
+                    stream=False,
                ),
                False,
            )

    def _interpret_response_line(
-        self, rbody, rcode, rheaders, stream: bool
+        self, rbody: str, rcode: int, rheaders, stream: bool
    ) -> OpenAIResponse:
        # HTTP 204 response code does not have any content in the body.
        if rcode == 204:
@@ -655,13 +667,11 @@ class APIRequestor:
                headers=rheaders,
            )
        try:
-            if hasattr(rbody, "decode"):
-                rbody = rbody.decode("utf-8")
            data = json.loads(rbody)
-        except (JSONDecodeError, UnicodeDecodeError):
+        except (JSONDecodeError, UnicodeDecodeError) as e:
            raise error.APIError(
                f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
-            )
+            ) from e
        resp = OpenAIResponse(data, rheaders)
        # In the future, we might add a "status" parameter to errors
        # to better handle the "error while streaming" case.
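The only behavioural change in this last hunk is exception chaining: with `raise ... from e`, the original JSONDecodeError survives as the __cause__ of the APIError, so both appear in the traceback. A tiny illustration using a stand-in exception class rather than openai.error.APIError:

import json


class APIError(Exception):
    pass


rbody = "not json"
try:
    try:
        json.loads(rbody)
    except json.JSONDecodeError as e:
        raise APIError(f"HTTP code 500 from API ({rbody})") from e
except APIError as err:
    # The decode error is preserved as the explicit cause of the APIError.
    assert isinstance(err.__cause__, json.JSONDecodeError)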