main
1# mypy: ignore-errors
2from __future__ import annotations
3
4import io
5import time
6import wave
7import asyncio
8from typing import Any, Type, Union, Generic, TypeVar, Callable, overload
9from typing_extensions import TYPE_CHECKING, Literal
10
11from .._types import FileTypes, FileContent
12from .._extras import numpy as np, sounddevice as sd
13
14if TYPE_CHECKING:
15 import numpy.typing as npt
16
# Sample rate in Hz; used both to open the input stream and in the WAV header.
SAMPLE_RATE: int = 24_000

# Sample-type parameter for ``Microphone``, restricted to NumPy scalar types.
DType = TypeVar("DType", bound=np.generic)


class Microphone(Generic[DType]):
    """Record audio from the default input device using ``sounddevice``.

    Audio is captured at ``SAMPLE_RATE`` Hz and returned either as a raw
    NumPy array or as an in-memory WAV file tuple.

    Args:
        channels: Number of input channels to record.
        dtype: NumPy sample type requested from the input stream; its item
            size is also used as the WAV sample width.
        should_record: Optional zero-argument predicate polled from the audio
            callback; recording stops the first time it returns a falsy value.
        timeout: Optional maximum recording duration in seconds, measured
            from the start of ``record()``.
    """

    def __init__(
        self,
        channels: int = 1,
        dtype: Type[DType] = np.int16,
        should_record: Union[Callable[[], bool], None] = None,
        timeout: Union[float, None] = None,
    ):
        self.channels = channels
        self.dtype = dtype
        self.should_record = should_record
        # Blocks captured by the most recent ``record()`` call.
        self.buffer_chunks: list[npt.NDArray[DType]] = []
        self.timeout = timeout
        self.has_record_function = callable(should_record)

    def _ndarray_to_wav(self, audio_data: npt.NDArray[DType]) -> FileTypes:
        """Encode ``audio_data`` as an in-memory WAV file.

        Returns:
            A ``(filename, file_object, content_type)`` tuple whose file
            object is rewound to the start.
        """
        # NOTE(review): the ``wave`` module always writes an integer-PCM
        # header, so float dtypes (e.g. float32) would produce a WAV most
        # readers misinterpret -- confirm only integer dtypes are used here.
        buffer: FileContent = io.BytesIO()
        with wave.open(buffer, "w") as wav_file:
            wav_file.setnchannels(self.channels)
            wav_file.setsampwidth(np.dtype(self.dtype).itemsize)
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_data.tobytes())
        # ``wave`` leaves caller-supplied file objects open; rewind for reading.
        buffer.seek(0)
        return ("audio.wav", buffer, "audio/wav")

    @overload
    async def record(self, return_ndarray: Literal[True]) -> npt.NDArray[DType]: ...

    @overload
    async def record(self, return_ndarray: Literal[False]) -> FileTypes: ...

    @overload
    async def record(self, return_ndarray: None = ...) -> FileTypes: ...

    async def record(
        self, return_ndarray: Union[bool, None] = False
    ) -> Union[npt.NDArray[DType], FileTypes]:
        """Record audio until a stop condition is met.

        Recording stops when ``timeout`` seconds have elapsed since this call,
        when ``should_record`` returns a falsy value, or when the input stream
        becomes inactive for any other reason (for example an exception raised
        inside the audio callback or a device error). If neither ``timeout``
        nor ``should_record`` is set, recording continues until the coroutine
        is cancelled or the stream fails.

        Args:
            return_ndarray: If truthy, return the raw samples; otherwise
                return them encoded as a WAV file tuple.

        Returns:
            The captured samples as a 2-D ``(frames, channels)`` array, or a
            ``("audio.wav", BytesIO, "audio/wav")`` tuple.
        """
        loop = asyncio.get_running_loop()
        finished = asyncio.Event()
        self.buffer_chunks = []
        start_time = time.perf_counter()

        def callback(
            indata: npt.NDArray[DType],
            _frame_count: int,
            _time_info: Any,
            _status: Any,
        ) -> None:
            # Runs on the PortAudio thread. Raising CallbackStop completes the
            # stream, which in turn triggers ``on_finished`` below.
            elapsed = time.perf_counter() - start_time
            if self.timeout is not None and elapsed > self.timeout:
                raise sd.CallbackStop
            if callable(self.should_record) and not self.should_record():
                raise sd.CallbackStop
            # sounddevice reuses ``indata`` between calls, so keep a copy.
            self.buffer_chunks.append(indata.copy())

        def on_finished() -> None:
            # Invoked (off the event-loop thread) whenever the stream becomes
            # inactive -- normal stop, abort, or a callback error -- so the
            # awaiting coroutine is always woken and can never hang.
            loop.call_soon_threadsafe(finished.set)

        stream = sd.InputStream(
            callback=callback,
            finished_callback=on_finished,
            dtype=self.dtype,
            samplerate=SAMPLE_RATE,
            channels=self.channels,
        )
        with stream:
            await finished.wait()

        # Join all blocks; an empty recording keeps the same 2-D
        # (frames, channels) layout as a non-empty one.
        if self.buffer_chunks:
            concatenated_chunks: npt.NDArray[DType] = np.concatenate(self.buffer_chunks, axis=0)
        else:
            concatenated_chunks = np.empty((0, self.channels), dtype=self.dtype)

        if return_ndarray:
            return concatenated_chunks
        return self._ndarray_to_wav(concatenated_chunks)