Commit 5d1a7268

Damien Deville <damien@openai.com>
2023-01-13 09:06:42
Fix some typing issues (#177)
1 parent ef8f1f1
Changed files (2)
openai/tests/asyncio/test_endpoints.py
@@ -2,10 +2,10 @@ import io
 import json
 
 import pytest
+from aiohttp import ClientSession
 
 import openai
 from openai import error
-from aiohttp import ClientSession
 
 
 pytestmark = [pytest.mark.asyncio]
openai/api_requestor.py
@@ -89,21 +89,19 @@ def _make_session() -> requests.Session:
     return s
 
 
-def parse_stream_helper(line: bytes):
+def parse_stream_helper(line: bytes) -> Optional[str]:
     if line:
         if line.strip() == b"data: [DONE]":
             # return here will cause GeneratorExit exception in urllib3
             # and it will close http connection with TCP Reset
             return None
-        if hasattr(line, "decode"):
-            line = line.decode("utf-8")
-        if line.startswith("data: "):
-            line = line[len("data: ") :]
-        return line
+        if line.startswith(b"data: "):
+            line = line[len(b"data: ") :]
+        return line.decode("utf-8")
     return None
 
 
-def parse_stream(rbody):
+def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
         _line = parse_stream_helper(line)
         if _line is not None:
@@ -111,10 +109,13 @@ def parse_stream(rbody):
 
 
 async def parse_stream_async(rbody: aiohttp.StreamReader):
-    async for line, _ in rbody.iter_chunks():
-        _line = parse_stream_helper(line)
-        if _line is not None:
-            yield _line
+    async for chunk, _ in rbody.iter_chunks():
+        # While the `ChunkTupleAsyncStreamIterator` iterator is meant to iterate over chunks (and thus lines) it seems
+        # to still sometimes return multiple lines at a time, so let's split the chunk by lines again.
+        for line in chunk.splitlines():
+            _line = parse_stream_helper(line)
+            if _line is not None:
+                yield _line
 
 
 class APIRequestor:
@@ -296,20 +297,25 @@ class APIRequestor:
     ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
         ctx = aiohttp_session()
         session = await ctx.__aenter__()
-        result = await self.arequest_raw(
-            method.lower(),
-            url,
-            session,
-            params=params,
-            supplied_headers=headers,
-            files=files,
-            request_id=request_id,
-            request_timeout=request_timeout,
-        )
-        resp, got_stream = await self._interpret_async_response(result, stream)
+        try:
+            result = await self.arequest_raw(
+                method.lower(),
+                url,
+                session,
+                params=params,
+                supplied_headers=headers,
+                files=files,
+                request_id=request_id,
+                request_timeout=request_timeout,
+            )
+            resp, got_stream = await self._interpret_async_response(result, stream)
+        except Exception:
+            await ctx.__aexit__(None, None, None)
+            raise
         if got_stream:
 
             async def wrap_resp():
+                assert isinstance(resp, AsyncGenerator)
                 try:
                     async for r in resp:
                         yield r
@@ -612,7 +618,10 @@ class APIRequestor:
         else:
             return (
                 self._interpret_response_line(
-                    result.content, result.status_code, result.headers, stream=False
+                    result.content.decode("utf-8"),
+                    result.status_code,
+                    result.headers,
+                    stream=False,
                 ),
                 False,
             )
@@ -635,13 +644,16 @@ class APIRequestor:
                 util.log_warn(e, body=result.content)
             return (
                 self._interpret_response_line(
-                    await result.read(), result.status, result.headers, stream=False
+                    (await result.read()).decode("utf-8"),
+                    result.status,
+                    result.headers,
+                    stream=False,
                 ),
                 False,
             )
 
     def _interpret_response_line(
-        self, rbody, rcode, rheaders, stream: bool
+        self, rbody: str, rcode: int, rheaders, stream: bool
     ) -> OpenAIResponse:
         # HTTP 204 response code does not have any content in the body.
         if rcode == 204:
@@ -655,13 +667,11 @@ class APIRequestor:
                 headers=rheaders,
             )
         try:
-            if hasattr(rbody, "decode"):
-                rbody = rbody.decode("utf-8")
             data = json.loads(rbody)
-        except (JSONDecodeError, UnicodeDecodeError):
+        except (JSONDecodeError, UnicodeDecodeError) as e:
             raise error.APIError(
                 f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
-            )
+            ) from e
         resp = OpenAIResponse(data, rheaders)
         # In the future, we might add a "status" parameter to errors
         # to better handle the "error while streaming" case.