Commit e081d99e
Changed files (1)
src
openai
src/openai/_streaming.py
@@ -59,22 +59,42 @@ class Stream(Generic[_T]):
if sse.data.startswith("[DONE]"):
break
- data = sse.json()
- if is_mapping(data) and data.get("error"):
- message = None
- error = data.get("error")
- if is_mapping(error):
- message = error.get("message")
- if not message or not isinstance(message, str):
- message = "An error occurred during streaming"
-
- raise APIError(
- message=message,
- request=self.response.request,
- body=data["error"],
- )
-
- yield process_data(data=data, cast_to=cast_to, response=response)
+ if sse.event is None:
+ data = sse.json()
+ if is_mapping(data) and data.get("error"):
+ message = None
+ error = data.get("error")
+ if is_mapping(error):
+ message = error.get("message")
+ if not message or not isinstance(message, str):
+ message = "An error occurred during streaming"
+
+ raise APIError(
+ message=message,
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data=data, cast_to=cast_to, response=response)
+
+ else:
+ data = sse.json()
+
+ if sse.event == "error" and is_mapping(data) and data.get("error"):
+ message = None
+ error = data.get("error")
+ if is_mapping(error):
+ message = error.get("message")
+ if not message or not isinstance(message, str):
+ message = "An error occurred during streaming"
+
+ raise APIError(
+ message=message,
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
# Ensure the entire stream is consumed
for _sse in iterator:
@@ -141,22 +161,42 @@ class AsyncStream(Generic[_T]):
if sse.data.startswith("[DONE]"):
break
- data = sse.json()
- if is_mapping(data) and data.get("error"):
- message = None
- error = data.get("error")
- if is_mapping(error):
- message = error.get("message")
- if not message or not isinstance(message, str):
- message = "An error occurred during streaming"
-
- raise APIError(
- message=message,
- request=self.response.request,
- body=data["error"],
- )
-
- yield process_data(data=data, cast_to=cast_to, response=response)
+ if sse.event is None:
+ data = sse.json()
+ if is_mapping(data) and data.get("error"):
+ message = None
+ error = data.get("error")
+ if is_mapping(error):
+ message = error.get("message")
+ if not message or not isinstance(message, str):
+ message = "An error occurred during streaming"
+
+ raise APIError(
+ message=message,
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data=data, cast_to=cast_to, response=response)
+
+ else:
+ data = sse.json()
+
+ if sse.event == "error" and is_mapping(data) and data.get("error"):
+ message = None
+ error = data.get("error")
+ if is_mapping(error):
+ message = error.get("message")
+ if not message or not isinstance(message, str):
+ message = "An error occurred during streaming"
+
+ raise APIError(
+ message=message,
+ request=self.response.request,
+ body=data["error"],
+ )
+
+ yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
# Ensure the entire stream is consumed
async for _sse in iterator: