1from __future__ import annotations
  2
  3import inspect
  4from types import TracebackType
  5from typing import Any, List, Generic, Iterable, Awaitable, cast
  6from typing_extensions import Self, Callable, Iterator, AsyncIterator
  7
  8from ._types import ParsedResponseSnapshot
  9from ._events import (
 10    ResponseStreamEvent,
 11    ResponseTextDoneEvent,
 12    ResponseCompletedEvent,
 13    ResponseTextDeltaEvent,
 14    ResponseFunctionCallArgumentsDeltaEvent,
 15)
 16from ...._types import Omit, omit
 17from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
 18from ...._models import build, construct_type_unchecked
 19from ...._streaming import Stream, AsyncStream
 20from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
 21from ..._parsing._responses import TextFormatT, parse_text, parse_response
 22from ....types.responses.tool_param import ToolParam
 23from ....types.responses.parsed_response import (
 24    ParsedContent,
 25    ParsedResponseOutputMessage,
 26    ParsedResponseFunctionToolCall,
 27)
 28
 29
 30class ResponseStream(Generic[TextFormatT]):
 31    def __init__(
 32        self,
 33        *,
 34        raw_stream: Stream[RawResponseStreamEvent],
 35        text_format: type[TextFormatT] | Omit,
 36        input_tools: Iterable[ToolParam] | Omit,
 37        starting_after: int | None,
 38    ) -> None:
 39        self._raw_stream = raw_stream
 40        self._response = raw_stream.response
 41        self._iterator = self.__stream__()
 42        self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
 43        self._starting_after = starting_after
 44
 45    def __next__(self) -> ResponseStreamEvent[TextFormatT]:
 46        return self._iterator.__next__()
 47
 48    def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
 49        for item in self._iterator:
 50            yield item
 51
 52    def __enter__(self) -> Self:
 53        return self
 54
 55    def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
 56        for sse_event in self._raw_stream:
 57            events_to_fire = self._state.handle_event(sse_event)
 58            for event in events_to_fire:
 59                if self._starting_after is None or event.sequence_number > self._starting_after:
 60                    yield event
 61
 62    def __exit__(
 63        self,
 64        exc_type: type[BaseException] | None,
 65        exc: BaseException | None,
 66        exc_tb: TracebackType | None,
 67    ) -> None:
 68        self.close()
 69
 70    def close(self) -> None:
 71        """
 72        Close the response and release the connection.
 73
 74        Automatically called if the response body is read to completion.
 75        """
 76        self._response.close()
 77
 78    def get_final_response(self) -> ParsedResponse[TextFormatT]:
 79        """Waits until the stream has been read to completion and returns
 80        the accumulated `ParsedResponse` object.
 81        """
 82        self.until_done()
 83        response = self._state._completed_response
 84        if not response:
 85            raise RuntimeError("Didn't receive a `response.completed` event.")
 86
 87        return response
 88
 89    def until_done(self) -> Self:
 90        """Blocks until the stream has been consumed."""
 91        consume_sync_iterator(self)
 92        return self
 93
 94
 95class ResponseStreamManager(Generic[TextFormatT]):
 96    def __init__(
 97        self,
 98        api_request: Callable[[], Stream[RawResponseStreamEvent]],
 99        *,
100        text_format: type[TextFormatT] | Omit,
101        input_tools: Iterable[ToolParam] | Omit,
102        starting_after: int | None,
103    ) -> None:
104        self.__stream: ResponseStream[TextFormatT] | None = None
105        self.__api_request = api_request
106        self.__text_format = text_format
107        self.__input_tools = input_tools
108        self.__starting_after = starting_after
109
110    def __enter__(self) -> ResponseStream[TextFormatT]:
111        raw_stream = self.__api_request()
112
113        self.__stream = ResponseStream(
114            raw_stream=raw_stream,
115            text_format=self.__text_format,
116            input_tools=self.__input_tools,
117            starting_after=self.__starting_after,
118        )
119
120        return self.__stream
121
122    def __exit__(
123        self,
124        exc_type: type[BaseException] | None,
125        exc: BaseException | None,
126        exc_tb: TracebackType | None,
127    ) -> None:
128        if self.__stream is not None:
129            self.__stream.close()
130
131
class AsyncResponseStream(Generic[TextFormatT]):
    """Asynchronous, iterable wrapper around a raw response event stream.

    Each raw event read from ``raw_stream`` is handed to a
    ``ResponseStreamState``; the events that state returns are then yielded to
    the caller. When ``starting_after`` is not ``None``, only events whose
    ``sequence_number`` is strictly greater than it are yielded.
    """

    def __init__(
        self,
        *,
        raw_stream: AsyncStream[RawResponseStreamEvent],
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        self._starting_after = starting_after
        self._raw_stream = raw_stream
        self._response = raw_stream.response
        # `__stream__` is an async generator function: its body does not run
        # until first awaited, so creating it before `_state` exists is safe.
        self._iterator = self.__stream__()
        self._state = ResponseStreamState(input_tools=input_tools, text_format=text_format)

    async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        """Yield the events produced by the state for every raw event, filtered by `starting_after`."""
        async for raw_event in self._raw_stream:
            for parsed_event in self._state.handle_event(raw_event):
                cutoff = self._starting_after
                # Skip anything at or before the requested starting point.
                if cutoff is not None and parsed_event.sequence_number <= cutoff:
                    continue
                yield parsed_event

    async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
        """Return the next event from the shared underlying async iterator."""
        pending = self._iterator.__anext__()
        return await pending

    async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
        """Iterate over the remaining events of the shared underlying async iterator."""
        async for event in self._iterator:
            yield event

    async def __aenter__(self) -> Self:
        """Return the stream itself; `__aexit__` closes it."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream on leaving the `async with` block, whether or not an exception occurred."""
        await self.close()

    async def close(self) -> None:
        """
        Close the underlying response and release the connection.

        The original documentation notes that this is also called automatically
        once the response body has been read to completion; that happens
        outside this class and is not visible here.
        """
        await self._response.aclose()

    async def until_done(self) -> Self:
        """Wait until every remaining event has been consumed, then return `self`."""
        await consume_async_iterator(self)
        return self

    async def get_final_response(self) -> ParsedResponse[TextFormatT]:
        """Consume the rest of the stream and return the accumulated `ParsedResponse`.

        Raises:
            RuntimeError: if the stream ended without a `response.completed` event.
        """
        await self.until_done()
        completed = self._state._completed_response
        if completed:
            return completed

        raise RuntimeError("Didn't receive a `response.completed` event.")
195
196
class AsyncResponseStreamManager(Generic[TextFormatT]):
    """Async context manager that awaits the request on entry and hands out an `AsyncResponseStream`.

    The stored awaitable is awaited in `__aenter__`; the stream created there
    is closed again in `__aexit__`.
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
        *,
        text_format: type[TextFormatT] | Omit,
        input_tools: Iterable[ToolParam] | Omit,
        starting_after: int | None,
    ) -> None:
        self.__api_request = api_request
        self.__text_format = text_format
        self.__input_tools = input_tools
        self.__starting_after = starting_after
        # Set by `__aenter__`; remains `None` if the manager is never entered.
        self.__stream: AsyncResponseStream[TextFormatT] | None = None

    async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
        """Await the stored request and wrap the resulting raw stream in an `AsyncResponseStream`."""
        stream = AsyncResponseStream(
            raw_stream=await self.__api_request,
            text_format=self.__text_format,
            input_tools=self.__input_tools,
            starting_after=self.__starting_after,
        )
        # Remember the stream so `__aexit__` can close it.
        self.__stream = stream
        return stream

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream created by `__aenter__`, if there is one."""
        stream = self.__stream
        if stream is None:
            return

        await stream.close()
232
233
class ResponseStreamState(Generic[TextFormatT]):
    """Accumulates raw stream events into a snapshot and derives the events to emit.

    `handle_event` is the entry point: it first folds the raw event into the
    running snapshot (`accumulate_event`) and then returns the list of events
    that should be surfaced for it.
    """

    def __init__(
        self,
        *,
        input_tools: Iterable[ToolParam] | Omit,
        text_format: type[TextFormatT] | Omit,
    ) -> None:
        # Running snapshot; created from the first event (see `accumulate_event`).
        self.__current_snapshot: ParsedResponseSnapshot | None = None
        # Set once a `response.completed` event has been accumulated.
        self._completed_response: ParsedResponse[TextFormatT] | None = None

        # Materialise the tools so the iterable can be reused later.
        if is_given(input_tools):
            self._input_tools = list(input_tools)
        else:
            self._input_tools = []

        self._text_format = text_format
        # Only a class-valued `text_format` is kept here; anything else becomes
        # the `omit` sentinel. This attribute is not read within this class.
        self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit

    def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
        """Fold `event` into the snapshot and return the events to emit for it.

        Text deltas, text completion, function-call argument deltas and
        `response.completed` are re-wrapped via `build(...)` with extra data
        (accumulated snapshot, parsed text, parsed response). Every other
        event is returned unchanged. The result always holds exactly one event.
        """
        snapshot = self.accumulate_event(event)
        self.__current_snapshot = snapshot

        if event.type == "response.output_text.delta":
            message = snapshot.output[event.output_index]
            assert message.type == "message"

            part = message.content[event.content_index]
            assert part.type == "output_text"

            text_delta = build(
                ResponseTextDeltaEvent,
                content_index=event.content_index,
                delta=event.delta,
                item_id=event.item_id,
                output_index=event.output_index,
                sequence_number=event.sequence_number,
                logprobs=event.logprobs,
                type="response.output_text.delta",
                # Text accumulated so far for this content part.
                snapshot=part.text,
            )
            return [text_delta]

        if event.type == "response.output_text.done":
            message = snapshot.output[event.output_index]
            assert message.type == "message"

            part = message.content[event.content_index]
            assert part.type == "output_text"

            text_done = build(
                ResponseTextDoneEvent[TextFormatT],
                content_index=event.content_index,
                item_id=event.item_id,
                output_index=event.output_index,
                sequence_number=event.sequence_number,
                logprobs=event.logprobs,
                type="response.output_text.done",
                text=event.text,
                parsed=parse_text(event.text, text_format=self._text_format),
            )
            return [text_done]

        if event.type == "response.function_call_arguments.delta":
            call = snapshot.output[event.output_index]
            assert call.type == "function_call"

            arguments_delta = build(
                ResponseFunctionCallArgumentsDeltaEvent,
                delta=event.delta,
                item_id=event.item_id,
                output_index=event.output_index,
                sequence_number=event.sequence_number,
                type="response.function_call_arguments.delta",
                # Argument string accumulated so far for this call.
                snapshot=call.arguments,
            )
            return [arguments_delta]

        if event.type == "response.completed":
            # `accumulate_event` stores the parsed response for this event type
            # before we get here.
            completed = self._completed_response
            assert completed is not None

            completed_event = build(
                ResponseCompletedEvent,
                sequence_number=event.sequence_number,
                type="response.completed",
                response=completed,
            )
            return [completed_event]

        # Anything else is passed through untouched.
        return [event]

    def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        """Apply `event` to the running snapshot and return the snapshot.

        If no snapshot exists yet, one is created from `event` via
        `_create_initial_response` (which raises `RuntimeError` unless the
        event is `response.created`). Otherwise the existing snapshot is
        mutated in place. A `response.completed` event additionally stores the
        result of `parse_response(...)` in `_completed_response`.
        """
        snapshot = self.__current_snapshot
        if snapshot is None:
            return self._create_initial_response(event)

        if event.type == "response.output_item.added":
            outputs = snapshot.output
            if event.item.type == "function_call":
                outputs.append(
                    construct_type_unchecked(
                        type_=cast(Any, ParsedResponseFunctionToolCall),
                        value=event.item.to_dict(),
                    )
                )
            elif event.item.type == "message":
                outputs.append(
                    construct_type_unchecked(
                        type_=cast(Any, ParsedResponseOutputMessage),
                        value=event.item.to_dict(),
                    )
                )
            else:
                # Other item kinds are stored as received.
                outputs.append(event.item)
        elif event.type == "response.content_part.added":
            target = snapshot.output[event.output_index]
            if target.type == "message":
                new_part = construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
                target.content.append(new_part)
        elif event.type == "response.output_text.delta":
            target = snapshot.output[event.output_index]
            if target.type == "message":
                part = target.content[event.content_index]
                assert part.type == "output_text"
                part.text += event.delta
        elif event.type == "response.function_call_arguments.delta":
            target = snapshot.output[event.output_index]
            if target.type == "function_call":
                target.arguments += event.delta
        elif event.type == "response.completed":
            self._completed_response = parse_response(
                text_format=self._text_format,
                response=event.response,
                input_tools=self._input_tools,
            )

        return snapshot

    def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
        """Build the first snapshot from a `response.created` event.

        Raises:
            RuntimeError: if `event` is of any other type.
        """
        if event.type == "response.created":
            return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())

        raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")