main
  1# mypy: ignore-errors
  2from __future__ import annotations
  3
  4import queue
  5import asyncio
  6from typing import Any, Union, Callable, AsyncGenerator, cast
  7from typing_extensions import TYPE_CHECKING
  8
  9from .. import _legacy_response
 10from .._extras import numpy as np, sounddevice as sd
 11from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
 12
 13if TYPE_CHECKING:
 14    import numpy.typing as npt
 15
 16SAMPLE_RATE = 24000
 17
 18
 19class LocalAudioPlayer:
 20    def __init__(
 21        self,
 22        should_stop: Union[Callable[[], bool], None] = None,
 23    ):
 24        self.channels = 1
 25        self.dtype = np.float32
 26        self.should_stop = should_stop
 27
 28    async def _tts_response_to_buffer(
 29        self,
 30        response: Union[
 31            _legacy_response.HttpxBinaryResponseContent,
 32            AsyncStreamedBinaryAPIResponse,
 33            StreamedBinaryAPIResponse,
 34        ],
 35    ) -> npt.NDArray[np.float32]:
 36        chunks: list[bytes] = []
 37        if isinstance(response, _legacy_response.HttpxBinaryResponseContent) or isinstance(
 38            response, StreamedBinaryAPIResponse
 39        ):
 40            for chunk in response.iter_bytes(chunk_size=1024):
 41                if chunk:
 42                    chunks.append(chunk)
 43        else:
 44            async for chunk in response.iter_bytes(chunk_size=1024):
 45                if chunk:
 46                    chunks.append(chunk)
 47
 48        audio_bytes = b"".join(chunks)
 49        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
 50        audio_np = audio_np.reshape(-1, 1)
 51        return audio_np
 52
 53    async def play(
 54        self,
 55        input: Union[
 56            npt.NDArray[np.int16],
 57            npt.NDArray[np.float32],
 58            _legacy_response.HttpxBinaryResponseContent,
 59            AsyncStreamedBinaryAPIResponse,
 60            StreamedBinaryAPIResponse,
 61        ],
 62    ) -> None:
 63        audio_content: npt.NDArray[np.float32]
 64        if isinstance(input, np.ndarray):
 65            if input.dtype == np.int16 and self.dtype == np.float32:
 66                audio_content = (input.astype(np.float32) / 32767.0).reshape(-1, self.channels)
 67            elif input.dtype == np.float32:
 68                audio_content = cast("npt.NDArray[np.float32]", input)
 69            else:
 70                raise ValueError(f"Unsupported dtype: {input.dtype}")
 71        else:
 72            audio_content = await self._tts_response_to_buffer(input)
 73
 74        loop = asyncio.get_event_loop()
 75        event = asyncio.Event()
 76        idx = 0
 77
 78        def callback(
 79            outdata: npt.NDArray[np.float32],
 80            frame_count: int,
 81            _time_info: Any,
 82            _status: Any,
 83        ):
 84            nonlocal idx
 85
 86            remainder = len(audio_content) - idx
 87            if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
 88                loop.call_soon_threadsafe(event.set)
 89                raise sd.CallbackStop
 90            valid_frames = frame_count if remainder >= frame_count else remainder
 91            outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
 92            outdata[valid_frames:] = 0
 93            idx += valid_frames
 94
 95        stream = sd.OutputStream(
 96            samplerate=SAMPLE_RATE,
 97            callback=callback,
 98            dtype=audio_content.dtype,
 99            channels=audio_content.shape[1],
100        )
101        with stream:
102            await event.wait()
103
104    async def play_stream(
105        self,
106        buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
107    ) -> None:
108        loop = asyncio.get_event_loop()
109        event = asyncio.Event()
110        buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)
111
112        async def buffer_producer():
113            async for buffer in buffer_stream:
114                if buffer is None:
115                    break
116                await loop.run_in_executor(None, buffer_queue.put, buffer)
117            await loop.run_in_executor(None, buffer_queue.put, None)  # Signal completion
118
119        def callback(
120            outdata: npt.NDArray[np.float32],
121            frame_count: int,
122            _time_info: Any,
123            _status: Any,
124        ):
125            nonlocal current_buffer, buffer_pos
126
127            frames_written = 0
128            while frames_written < frame_count:
129                if current_buffer is None or buffer_pos >= len(current_buffer):
130                    try:
131                        current_buffer = buffer_queue.get(timeout=0.1)
132                        if current_buffer is None:
133                            loop.call_soon_threadsafe(event.set)
134                            raise sd.CallbackStop
135                        buffer_pos = 0
136
137                        if current_buffer.dtype == np.int16 and self.dtype == np.float32:
138                            current_buffer = (current_buffer.astype(np.float32) / 32767.0).reshape(-1, self.channels)
139
140                    except queue.Empty:
141                        outdata[frames_written:] = 0
142                        return
143
144                remaining_frames = len(current_buffer) - buffer_pos
145                frames_to_write = min(frame_count - frames_written, remaining_frames)
146                outdata[frames_written : frames_written + frames_to_write] = current_buffer[
147                    buffer_pos : buffer_pos + frames_to_write
148                ]
149                buffer_pos += frames_to_write
150                frames_written += frames_to_write
151
152        current_buffer = None
153        buffer_pos = 0
154
155        producer_task = asyncio.create_task(buffer_producer())
156
157        with sd.OutputStream(
158            samplerate=SAMPLE_RATE,
159            channels=self.channels,
160            dtype=self.dtype,
161            callback=callback,
162        ):
163            await event.wait()
164
165        await producer_task