Commit e7ee4daa

Nuno Campos <nuno@boringbits.io>
2023-01-13 01:55:39
Several fixes to make `Completion.acreate(stream=True)` work (#172)
* Added a failing test case for async completion stream * Consume async generator with async for * Consume the stream in chunks as sent by API, to avoid "empty" parts The API will send chunks like ``` b'data: {"id": "cmpl-6W18L0k1kFoHUoSsJOwcPq7DKBaGX", "object": "text_completion", "created": 1673088873, "choices": [{"text": "_", "index": 0, "logprobs": null, "finish_reason": null}], "model": "ada"}\n\n' ``` The default iterator will break on each `\n` character, whereas iter_chunks will just output parts as they arrive * Add another test using global aiosession * Manually consume the aiohttp_session async context manager to ensure that the session is only closed once the response stream is finished Previously we'd exit the with statement before the response stream is consumed by the caller, therefore, unless we're using a global ClientSession, the session is closed (and thus the request) before it should be. * Ensure we close the session even if the caller raises an exception while consuming the stream
1 parent 71cee6a
Changed files (3)
openai
api_resources
tests
openai/api_resources/abstract/engine_api_resource.py
@@ -236,7 +236,7 @@ class EngineAPIResource(APIResource):
                     engine=engine,
                     plain_old_data=cls.plain_old_data,
                 )
-                for line in response
+                async for line in response
             )
         else:
             obj = util.convert_to_openai_object(
openai/tests/asyncio/test_endpoints.py
@@ -5,6 +5,7 @@ import pytest
 
 import openai
 from openai import error
+from aiohttp import ClientSession
 
 
 pytestmark = [pytest.mark.asyncio]
@@ -63,3 +64,26 @@ async def test_timeout_does_not_error():
         model="ada",
         request_timeout=10,
     )
+
+
+async def test_completions_stream_finishes_global_session():
+    async with ClientSession() as session:
+        openai.aiosession.set(session)
+
+        # A query that should be fast
+        parts = []
+        async for part in await openai.Completion.acreate(
+            prompt="test", model="ada", request_timeout=3, stream=True
+        ):
+            parts.append(part)
+        assert len(parts) > 1
+
+
+async def test_completions_stream_finishes_local_session():
+    # A query that should be fast
+    parts = []
+    async for part in await openai.Completion.acreate(
+        prompt="test", model="ada", request_timeout=3, stream=True
+    ):
+        parts.append(part)
+    assert len(parts) > 1
openai/api_requestor.py
@@ -89,9 +89,9 @@ def _make_session() -> requests.Session:
     return s
 
 
-def parse_stream_helper(line):
+def parse_stream_helper(line: bytes):
     if line:
-        if line == b"data: [DONE]":
+        if line.strip() == b"data: [DONE]":
             # return here will cause GeneratorExit exception in urllib3
             # and it will close http connection with TCP Reset
             return None
@@ -111,7 +111,7 @@ def parse_stream(rbody):
 
 
 async def parse_stream_async(rbody: aiohttp.StreamReader):
-    async for line in rbody:
+    async for line, _ in rbody.iter_chunks():
         _line = parse_stream_helper(line)
         if _line is not None:
             yield _line
@@ -294,18 +294,31 @@ class APIRequestor:
         request_id: Optional[str] = None,
         request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
     ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
-        async with aiohttp_session() as session:
-            result = await self.arequest_raw(
-                method.lower(),
-                url,
-                session,
-                params=params,
-                supplied_headers=headers,
-                files=files,
-                request_id=request_id,
-                request_timeout=request_timeout,
-            )
-            resp, got_stream = await self._interpret_async_response(result, stream)
+        ctx = aiohttp_session()
+        session = await ctx.__aenter__()
+        result = await self.arequest_raw(
+            method.lower(),
+            url,
+            session,
+            params=params,
+            supplied_headers=headers,
+            files=files,
+            request_id=request_id,
+            request_timeout=request_timeout,
+        )
+        resp, got_stream = await self._interpret_async_response(result, stream)
+        if got_stream:
+
+            async def wrap_resp():
+                try:
+                    async for r in resp:
+                        yield r
+                finally:
+                    await ctx.__aexit__(None, None, None)
+
+            return wrap_resp(), got_stream, self.api_key
+        else:
+            await ctx.__aexit__(None, None, None)
             return resp, got_stream, self.api_key
 
     def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
@@ -507,7 +520,9 @@ class APIRequestor:
         except requests.exceptions.Timeout as e:
             raise error.Timeout("Request timed out: {}".format(e)) from e
         except requests.exceptions.RequestException as e:
-            raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
+            raise error.APIConnectionError(
+                "Error communicating with OpenAI: {}".format(e)
+            ) from e
         util.log_info(
             "OpenAI API response",
             path=abs_url,