main
  1# mypy: ignore-errors
  2from __future__ import annotations
  3
  4import io
  5import time
  6import wave
  7import asyncio
  8from typing import Any, Type, Union, Generic, TypeVar, Callable, overload
  9from typing_extensions import TYPE_CHECKING, Literal
 10
 11from .._types import FileTypes, FileContent
 12from .._extras import numpy as np, sounddevice as sd
 13
 14if TYPE_CHECKING:
 15    import numpy.typing as npt
 16
 17SAMPLE_RATE = 24000
 18
 19DType = TypeVar("DType", bound=np.generic)
 20
 21
 22class Microphone(Generic[DType]):
 23    def __init__(
 24        self,
 25        channels: int = 1,
 26        dtype: Type[DType] = np.int16,
 27        should_record: Union[Callable[[], bool], None] = None,
 28        timeout: Union[float, None] = None,
 29    ):
 30        self.channels = channels
 31        self.dtype = dtype
 32        self.should_record = should_record
 33        self.buffer_chunks = []
 34        self.timeout = timeout
 35        self.has_record_function = callable(should_record)
 36
 37    def _ndarray_to_wav(self, audio_data: npt.NDArray[DType]) -> FileTypes:
 38        buffer: FileContent = io.BytesIO()
 39        with wave.open(buffer, "w") as wav_file:
 40            wav_file.setnchannels(self.channels)
 41            wav_file.setsampwidth(np.dtype(self.dtype).itemsize)
 42            wav_file.setframerate(SAMPLE_RATE)
 43            wav_file.writeframes(audio_data.tobytes())
 44        buffer.seek(0)
 45        return ("audio.wav", buffer, "audio/wav")
 46
 47    @overload
 48    async def record(self, return_ndarray: Literal[True]) -> npt.NDArray[DType]: ...
 49
 50    @overload
 51    async def record(self, return_ndarray: Literal[False]) -> FileTypes: ...
 52
 53    @overload
 54    async def record(self, return_ndarray: None = ...) -> FileTypes: ...
 55
 56    async def record(self, return_ndarray: Union[bool, None] = False) -> Union[npt.NDArray[DType], FileTypes]:
 57        loop = asyncio.get_event_loop()
 58        event = asyncio.Event()
 59        self.buffer_chunks: list[npt.NDArray[DType]] = []
 60        start_time = time.perf_counter()
 61
 62        def callback(
 63            indata: npt.NDArray[DType],
 64            _frame_count: int,
 65            _time_info: Any,
 66            _status: Any,
 67        ):
 68            execution_time = time.perf_counter() - start_time
 69            reached_recording_timeout = execution_time > self.timeout if self.timeout is not None else False
 70            if reached_recording_timeout:
 71                loop.call_soon_threadsafe(event.set)
 72                raise sd.CallbackStop
 73
 74            should_be_recording = self.should_record() if callable(self.should_record) else True
 75            if not should_be_recording:
 76                loop.call_soon_threadsafe(event.set)
 77                raise sd.CallbackStop
 78
 79            self.buffer_chunks.append(indata.copy())
 80
 81        stream = sd.InputStream(
 82            callback=callback,
 83            dtype=self.dtype,
 84            samplerate=SAMPLE_RATE,
 85            channels=self.channels,
 86        )
 87        with stream:
 88            await event.wait()
 89
 90        # Concatenate all chunks into a single buffer, handle empty case
 91        concatenated_chunks: npt.NDArray[DType] = (
 92            np.concatenate(self.buffer_chunks, axis=0)
 93            if len(self.buffer_chunks) > 0
 94            else np.array([], dtype=self.dtype)
 95        )
 96
 97        if return_ndarray:
 98            return concatenated_chunks
 99        else:
100            return self._ndarray_to_wav(concatenated_chunks)