main
1# mypy: ignore-errors
2from __future__ import annotations
3
4import queue
5import asyncio
6from typing import Any, Union, Callable, AsyncGenerator, cast
7from typing_extensions import TYPE_CHECKING
8
9from .. import _legacy_response
10from .._extras import numpy as np, sounddevice as sd
11from .._response import StreamedBinaryAPIResponse, AsyncStreamedBinaryAPIResponse
12
13if TYPE_CHECKING:
14 import numpy.typing as npt
15
# Sample rate passed as `samplerate` to every `sd.OutputStream` opened in this module.
# NOTE(review): presumably matches the PCM rate of the audio handed to the player — confirm.
SAMPLE_RATE = 24000
17
18
class LocalAudioPlayer:
    """Play audio through a ``sounddevice`` output stream.

    Audio can be supplied as a NumPy array (``int16`` or ``float32``), as a
    binary API response whose bytes are interpreted as ``int16`` samples, or
    as an async generator yielding array chunks.
    """

    def __init__(
        self,
        should_stop: Union[Callable[[], bool], None] = None,
    ):
        """Create a player.

        Args:
            should_stop: Optional zero-argument callable polled from the audio
                callback of :meth:`play`; when it returns ``True`` playback is
                stopped early.
        """
        self.channels = 1
        self.dtype = np.float32
        self.should_stop = should_stop

    def _normalize_frames(self, buffer: npt.NDArray[Any]) -> npt.NDArray[Any]:
        """Return ``buffer`` in the ``(frames, channels)`` layout used for output.

        ``int16`` samples are scaled to ``float32`` (divided by 32767) when the
        player dtype is ``float32``. One-dimensional arrays are reshaped into a
        column so they can be copied into the 2-D ``outdata`` buffer. Arrays
        that are already multi-dimensional are returned unchanged.
        """
        if buffer.dtype == np.int16 and self.dtype == np.float32:
            return (buffer.astype(np.float32) / 32767.0).reshape(-1, self.channels)
        if buffer.ndim == 1:
            return buffer.reshape(-1, self.channels)
        return buffer

    async def _tts_response_to_buffer(
        self,
        response: Union[
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> npt.NDArray[np.float32]:
        """Read a whole binary response and decode it into float32 frames.

        Args:
            response: Sync or async binary response; it is consumed in
                1024-byte chunks.

        Returns:
            A ``(frames, 1)`` ``float32`` array obtained by interpreting the
            bytes as ``int16`` samples and dividing by 32767.
        """
        chunks: list[bytes] = []
        if isinstance(response, (_legacy_response.HttpxBinaryResponseContent, StreamedBinaryAPIResponse)):
            # Synchronous responses expose a plain iterator.
            for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)
        else:
            async for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    chunks.append(chunk)

        audio_bytes = b"".join(chunks)
        audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        return audio_np.reshape(-1, 1)

    async def play(
        self,
        input: Union[
            npt.NDArray[np.int16],
            npt.NDArray[np.float32],
            _legacy_response.HttpxBinaryResponseContent,
            AsyncStreamedBinaryAPIResponse,
            StreamedBinaryAPIResponse,
        ],
    ) -> None:
        """Play a complete piece of audio and return once playback has ended.

        Args:
            input: Either an ``int16``/``float32`` NumPy array or a binary
                response that is fully read before playback starts. 1-D arrays
                are treated as a single column of frames.

        Raises:
            ValueError: If an array with an unsupported dtype is given.
        """
        audio_content: npt.NDArray[np.float32]
        if isinstance(input, np.ndarray):
            convertible_int16 = input.dtype == np.int16 and self.dtype == np.float32
            if not convertible_int16 and input.dtype != np.float32:
                raise ValueError(f"Unsupported dtype: {input.dtype}")
            # Also turns 1-D float32 input into (frames, channels) so that
            # `shape[1]` below is always defined.
            audio_content = cast("npt.NDArray[np.float32]", self._normalize_frames(input))
        else:
            audio_content = await self._tts_response_to_buffer(input)

        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        idx = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            nonlocal idx

            remainder = len(audio_content) - idx
            if remainder == 0 or (callable(self.should_stop) and self.should_stop()):
                # The final block is still played after CallbackStop, so it
                # must not be left uninitialised.
                outdata[:] = 0
                # The callback runs outside the event loop thread.
                loop.call_soon_threadsafe(event.set)
                raise sd.CallbackStop
            valid_frames = min(frame_count, remainder)
            outdata[:valid_frames] = audio_content[idx : idx + valid_frames]
            outdata[valid_frames:] = 0  # pad the last, partially filled block
            idx += valid_frames

        stream = sd.OutputStream(
            samplerate=SAMPLE_RATE,
            callback=callback,
            dtype=audio_content.dtype,
            channels=audio_content.shape[1],
        )
        with stream:
            await event.wait()

    async def play_stream(
        self,
        buffer_stream: AsyncGenerator[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None], None],
    ) -> None:
        """Play audio chunks as they are produced by ``buffer_stream``.

        Chunks are passed to the audio callback through a bounded queue
        (50 entries). A ``None`` item, or exhaustion of the generator, ends
        playback. If the generator raises, playback is ended and the exception
        is re-raised from this coroutine.

        Args:
            buffer_stream: Async generator yielding ``float32`` or ``int16``
                arrays, optionally terminated by ``None``.
        """
        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        buffer_queue: queue.Queue[Union[npt.NDArray[np.float32], npt.NDArray[np.int16], None]] = queue.Queue(maxsize=50)

        async def buffer_producer():
            try:
                async for buffer in buffer_stream:
                    if buffer is None:
                        break
                    # `put` blocks when the queue is full, so run it off-loop.
                    await loop.run_in_executor(None, buffer_queue.put, buffer)
            finally:
                # Always signal completion, even when the generator fails;
                # otherwise the callback would wait for data forever.
                await loop.run_in_executor(None, buffer_queue.put, None)

        current_buffer = None
        buffer_pos = 0

        def callback(
            outdata: npt.NDArray[np.float32],
            frame_count: int,
            _time_info: Any,
            _status: Any,
        ):
            nonlocal current_buffer, buffer_pos

            frames_written = 0
            while frames_written < frame_count:
                if current_buffer is None or buffer_pos >= len(current_buffer):
                    try:
                        current_buffer = buffer_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Nothing available yet: fill the rest with silence.
                        outdata[frames_written:] = 0
                        return
                    if current_buffer is None:
                        # End of stream: silence the unwritten tail, since the
                        # final block is still played after CallbackStop.
                        outdata[frames_written:] = 0
                        loop.call_soon_threadsafe(event.set)
                        raise sd.CallbackStop
                    buffer_pos = 0
                    current_buffer = self._normalize_frames(current_buffer)

                remaining_frames = len(current_buffer) - buffer_pos
                frames_to_write = min(frame_count - frames_written, remaining_frames)
                outdata[frames_written : frames_written + frames_to_write] = current_buffer[
                    buffer_pos : buffer_pos + frames_to_write
                ]
                buffer_pos += frames_to_write
                frames_written += frames_to_write

        producer_task = asyncio.create_task(buffer_producer())

        with sd.OutputStream(
            samplerate=SAMPLE_RATE,
            channels=self.channels,
            dtype=self.dtype,
            callback=callback,
        ):
            await event.wait()

        # Surfaces any exception raised while consuming `buffer_stream`.
        await producer_task
165 await producer_task