Commit e081d99e

Robert Craigie <robert@craigie.dev>
2025-01-14 04:23:17
Revert "chore(internal): streaming refactors (#2012)"
This reverts commit d76a748f606743407f94dfc26758095560e2082a.
1 parent 82ccc98
Changed files (1)
src
src/openai/_streaming.py
@@ -59,22 +59,42 @@ class Stream(Generic[_T]):
             if sse.data.startswith("[DONE]"):
                 break
 
-            data = sse.json()
-            if is_mapping(data) and data.get("error"):
-                message = None
-                error = data.get("error")
-                if is_mapping(error):
-                    message = error.get("message")
-                if not message or not isinstance(message, str):
-                    message = "An error occurred during streaming"
-
-                raise APIError(
-                    message=message,
-                    request=self.response.request,
-                    body=data["error"],
-                )
-
-            yield process_data(data=data, cast_to=cast_to, response=response)
+            if sse.event is None:
+                data = sse.json()
+                if is_mapping(data) and data.get("error"):
+                    message = None
+                    error = data.get("error")
+                    if is_mapping(error):
+                        message = error.get("message")
+                    if not message or not isinstance(message, str):
+                        message = "An error occurred during streaming"
+
+                    raise APIError(
+                        message=message,
+                        request=self.response.request,
+                        body=data["error"],
+                    )
+
+                yield process_data(data=data, cast_to=cast_to, response=response)
+
+            else:
+                data = sse.json()
+
+                if sse.event == "error" and is_mapping(data) and data.get("error"):
+                    message = None
+                    error = data.get("error")
+                    if is_mapping(error):
+                        message = error.get("message")
+                    if not message or not isinstance(message, str):
+                        message = "An error occurred during streaming"
+
+                    raise APIError(
+                        message=message,
+                        request=self.response.request,
+                        body=data["error"],
+                    )
+
+                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
 
         # Ensure the entire stream is consumed
         for _sse in iterator:
@@ -141,22 +161,42 @@ class AsyncStream(Generic[_T]):
             if sse.data.startswith("[DONE]"):
                 break
 
-            data = sse.json()
-            if is_mapping(data) and data.get("error"):
-                message = None
-                error = data.get("error")
-                if is_mapping(error):
-                    message = error.get("message")
-                if not message or not isinstance(message, str):
-                    message = "An error occurred during streaming"
-
-                raise APIError(
-                    message=message,
-                    request=self.response.request,
-                    body=data["error"],
-                )
-
-            yield process_data(data=data, cast_to=cast_to, response=response)
+            if sse.event is None:
+                data = sse.json()
+                if is_mapping(data) and data.get("error"):
+                    message = None
+                    error = data.get("error")
+                    if is_mapping(error):
+                        message = error.get("message")
+                    if not message or not isinstance(message, str):
+                        message = "An error occurred during streaming"
+
+                    raise APIError(
+                        message=message,
+                        request=self.response.request,
+                        body=data["error"],
+                    )
+
+                yield process_data(data=data, cast_to=cast_to, response=response)
+
+            else:
+                data = sse.json()
+
+                if sse.event == "error" and is_mapping(data) and data.get("error"):
+                    message = None
+                    error = data.get("error")
+                    if is_mapping(error):
+                        message = error.get("message")
+                    if not message or not isinstance(message, str):
+                        message = "An error occurred during streaming"
+
+                    raise APIError(
+                        message=message,
+                        request=self.response.request,
+                        body=data["error"],
+                    )
+
+                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
 
         # Ensure the entire stream is consumed
         async for _sse in iterator: