# NOTE(review): the three lines that were here ("Spaces:", "Sleeping",
# "Sleeping") appear to be scraped Hugging Face Spaces UI status text, not
# code — commented out so the file parses.
| import os | |
| import time | |
| import wave | |
| import queue | |
| import threading | |
| import datetime | |
| from pathlib import Path | |
| import wave | |
| import pyaudio | |
| from pyannote.audio import Pipeline | |
# Load the pyannote speaker-diarization pipeline at import time.
# Requires a valid HUGGINGFACE_TOKEN env var and (on first run) network
# access; on any failure diarization is disabled rather than crashing.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
try:
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization", use_auth_token=HUGGINGFACE_TOKEN
    )
    DIARIZATION_AVAILABLE = True
except Exception as e:
    # Surface the reason instead of swallowing it silently so a bad/missing
    # token is diagnosable from the logs; diarization simply stays off.
    print("Speaker diarization unavailable:", e)
    diarization_pipeline = None
    DIARIZATION_AVAILABLE = False
# Optional modules (import safely)
# faster-whisper is optional: when it cannot be imported, transcription is
# disabled later via the FASTER_WHISPER_AVAILABLE flag (see Transcriber).
try:
    from faster_whisper import WhisperModel
    FASTER_WHISPER_AVAILABLE = True
except Exception:
    FASTER_WHISPER_AVAILABLE = False
| import numpy as np | |
| # # Optional: voice activity detection | |
| # try: | |
| # import webrtcvad | |
| # VAD_AVAILABLE = True | |
| # except Exception: | |
| # VAD_AVAILABLE = False | |
# ========== CONFIG ==========
# Timestamp used to namespace every artifact produced by this run.
RUN_TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
FORMAT = pyaudio.paInt16  # 16-bit signed PCM
CHANNELS = 1  # mono
# RATE = 16000
RATE = 44100  # sample rate in Hz
CHUNK = 1024  # frames per buffer read
CHUNK_DURATION_SECS = 5  # how long each saved chunk is (seconds)
# OUTPUT_DIR = Path("output_transcript_diarization")
# Output root is overridable via env var; default is a container-style path.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/app/output_transcript_diarization"))
try:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
except PermissionError:
    # Fall back to a writable location (e.g. when /app is read-only).
    OUTPUT_DIR = Path("/tmp/output_transcript_diarization")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
CHUNKS_DIR = OUTPUT_DIR / f"chunks_{RUN_TIMESTAMP}"  # per-run chunk WAVs
FINAL_WAV = OUTPUT_DIR / f"recorded_audio_{RUN_TIMESTAMP}.wav"
TRANSCRIPT_FILE = OUTPUT_DIR / f"transcript_{RUN_TIMESTAMP}.txt"
MODEL_NAME = "medium"  # if using faster-whisper; change as desired
# NOTE(review): CHUNKS_DIR is never created here (the mkdir below is commented
# out) — chunk saving relies on the directory being created elsewhere; verify.
# OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
# Single shared PyAudio instance used for opening streams and sample sizes.
audio = pyaudio.PyAudio()
def list_input_devices():
    """Print every audio capture device with its index and input channel count."""
    pa = pyaudio.PyAudio()
    print("Available audio devices (inputs):")
    # Walk every device the host API reports; keep only ones that can capture.
    for idx in range(pa.get_device_count()):
        info = pa.get_device_info_by_index(idx)
        if info.get("maxInputChannels", 0) > 0:
            print(f" {idx}: {info['name']} - {info['maxInputChannels']} chans")
    pa.terminate()
def open_stream_for_device(device_index, channels=1):
    """Open and return a PyAudio capture stream on *device_index* using the
    module-level FORMAT/RATE/CHUNK configuration."""
    stream_args = dict(
        format=FORMAT,
        channels=channels,
        rate=RATE,
        input=True,
        frames_per_buffer=CHUNK,
        input_device_index=device_index,
    )
    return audio.open(**stream_args)
def save_wav_from_frames(path: Path, frames: list, nchannels=1):
    """Write 16-bit PCM *frames* to *path* as a WAV file, peak-normalized.

    The concatenated samples are scaled so the loudest sample sits at 90% of
    the int16 range (all-silence input is written untouched). The parent
    directory is created if missing — the per-run CHUNKS_DIR is otherwise
    never created, which previously made every chunk save fail.

    Fixes: empty *frames* no longer raises (``np.max`` on an empty array
    raises ValueError); the peak is computed once instead of twice.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    raw = b''.join(frames)
    audio_array = np.frombuffer(raw, dtype=np.int16)
    # Peak-normalize to 90% of int16 full scale. NOTE: each chunk is
    # normalized independently, so gain can vary from chunk to chunk.
    peak = int(np.max(np.abs(audio_array))) if audio_array.size else 0
    if peak > 0:
        audio_array = (audio_array / peak * 32767 * 0.9).astype(np.int16)
    with wave.open(str(path), 'wb') as wf:
        wf.setnchannels(nchannels)
        wf.setsampwidth(audio.get_sample_size(FORMAT))
        wf.setframerate(RATE)
        wf.writeframes(audio_array.tobytes())
def merge_mono_files_to_stereo(mic_path: Path, sys_path: Path, out_path: Path):
    """
    Create a simple stereo WAV: mic -> left channel, system -> right channel.

    Very naive merge: both inputs must be int16 mono at the configured RATE;
    the output is truncated to the shorter input.

    Raises:
        ValueError: on sample-rate or sample-width mismatch. (The original
        used ``assert``, which is stripped under ``python -O`` and would have
        let mismatched files through silently.)
    """
    with wave.open(str(mic_path), 'rb') as wm, wave.open(str(sys_path), 'rb') as ws:
        if not (wm.getframerate() == ws.getframerate() == RATE):
            raise ValueError("sample rate mismatch between mic/system WAVs")
        if wm.getsampwidth() != ws.getsampwidth():
            raise ValueError("sample width mismatch between mic/system WAVs")
        sampwidth = wm.getsampwidth()
        nframes = min(wm.getnframes(), ws.getnframes())
        mic_bytes = wm.readframes(nframes)
        sys_bytes = ws.readframes(nframes)
    # Interleave int16 samples: even indices = left (mic), odd = right (sys).
    mic_arr = np.frombuffer(mic_bytes, dtype=np.int16)
    sys_arr = np.frombuffer(sys_bytes, dtype=np.int16)
    stereo = np.empty((nframes * 2,), dtype=np.int16)
    stereo[0::2] = mic_arr[:nframes]
    stereo[1::2] = sys_arr[:nframes]
    with wave.open(str(out_path), 'wb') as wf:
        wf.setnchannels(2)
        wf.setsampwidth(sampwidth)
        wf.setframerate(RATE)
        wf.writeframes(stereo.tobytes())
class Transcriber:
    """Thin wrapper around faster-whisper that degrades to a no-op when the
    library is missing or the model fails to load."""

    def __init__(self):
        """Load the faster-whisper model if the library is importable."""
        self.model = None
        if not FASTER_WHISPER_AVAILABLE:
            print("faster-whisper not available. Transcription will be disabled.")
            return
        print("Loading faster-whisper model. This may take some time...")
        # Prefer GPU when torch reports one; otherwise fall back to CPU.
        try:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            device = "cpu"
        # float16 is only safe on GPU; CPU runs use float32.
        compute_type = "float16" if device == "cuda" else "float32"
        try:
            # Instantiating the model may download weights on first run.
            self.model = WhisperModel(MODEL_NAME, device=device, compute_type=compute_type)
            print(f"Model loaded on {device} (compute_type={compute_type}).")
        except Exception as e:
            print("Failed to load faster-whisper model:", e)
            print("Continuing without transcription.")
            self.model = None

    def transcribe_file(self, wav_path: str):
        """Transcribe *wav_path* and return the joined text, or None when no
        model is loaded or transcription fails."""
        if not self.model:
            return None
        try:
            segments, _info = self.model.transcribe(wav_path, beam_size=5)
            return " ".join(seg.text.strip() for seg in segments)
        except Exception as e:
            print(f"Transcription error for {wav_path}: {e}")
            return None
def diarization_hook(audio_path: str):
    """Run speaker diarization on *audio_path*.

    Returns a list of (start, end, speaker) tuples, or None when the
    pyannote pipeline could not be loaded at import time.
    """
    if not DIARIZATION_AVAILABLE:
        return None
    annotation = diarization_pipeline(audio_path)
    return [
        (turn.start, turn.end, speaker)
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]
# Recorder threads
def record_loop(device_index, out_queue, label="mic"):
    """
    Continuously read audio from *device_index* and push CHUNK_DURATION_SECS
    worth of frames at a time onto *out_queue* as (filename, frames) pairs.

    Runs until the shared ``stop_event`` is set, a KeyboardInterrupt arrives,
    or more than 10 consecutive read errors occur. Fixes vs. the original:
    the loop previously ignored ``stop_event`` (the daemon thread spun until
    process exit) and silently dropped any partially-filled chunk buffered at
    shutdown; it also used the deprecated ``datetime.utcnow()``.
    """
    try:
        stream = open_stream_for_device(device_index, channels=CHANNELS)
    except Exception as e:
        print(f"Could not open stream for device {device_index} ({label}): {e}")
        return

    def _chunk_filename():
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        ts = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
        return CHUNKS_DIR / f"{label}_{ts}.wav"

    frames_per_chunk = int(RATE / CHUNK * CHUNK_DURATION_SECS)
    frames = []
    print(f"Recording from device {device_index} ({label}) ... Press Ctrl+C to stop.")
    error_count = 0
    try:
        # Exit promptly when the controller signals shutdown (same shared
        # event the writer worker already checks).
        while not stop_event.is_set():
            try:
                data = stream.read(CHUNK, exception_on_overflow=False)
                error_count = 0  # reset on success
            except Exception as e:
                print(f"Read error on device {device_index} ({label}): {e}")
                error_count += 1
                if error_count > 10:
                    print(f"Too many errors on device {device_index} ({label}). Stopping this thread.")
                    break
                continue
            frames.append(data)
            if len(frames) >= frames_per_chunk:
                out_queue.put((_chunk_filename(), frames.copy()))
                frames = []
    except KeyboardInterrupt:
        print(f"Recording thread {label} received KeyboardInterrupt.")
    finally:
        if frames:
            # Flush the trailing partial chunk so tail audio is preserved.
            out_queue.put((_chunk_filename(), frames.copy()))
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        print(f"Recording thread for {label} exited.")
def chunk_writer_and_transcribe_worker(in_queue: queue.Queue, final_frames_list: list, transcriber: Transcriber, single_channel_label="mic"):
    """
    Save chunk WAV files and optionally send for transcription.
    Also store frames for final concatenated WAV.

    Consumes (filename, frames) pairs from *in_queue* until the shared
    ``stop_event`` is set AND the queue has drained. For each chunk it:
      1. writes the WAV to disk,
      2. appends the raw frames to *final_frames_list* (for the final WAV),
      3. runs diarization (if available) and per-segment transcription
         (if *transcriber* has a loaded model), labelling each transcribed
         segment with the first diarization turn it overlaps,
      4. appends the labelled lines to TRANSCRIPT_FILE.

    ``single_channel_label`` is currently unused; kept for interface
    stability with existing callers that pass "mic"/"sys".
    """
    while True:
        try:
            # Poll with a timeout so shutdown can be noticed between chunks.
            filename, frames = in_queue.get(timeout=1.0)
        except queue.Empty:
            # Exit only once shutdown was requested and all work is drained.
            if stop_event.is_set() and in_queue.empty():
                break
            continue
        save_wav_from_frames(filename, frames, nchannels=CHANNELS)
        print(f"Saved chunk: {filename.name}")
        final_frames_list.extend(frames)
        diar = diarization_hook(str(filename))
        diar_segments = []
        if diar:
            diar_segments = diar  # list of (start, end, speaker)
        # Transcribe chunk and get segments with timestamps
        if transcriber and transcriber.model:
            try:
                # Get segments with timestamps
                segments, info = transcriber.model.transcribe(str(filename), beam_size=5)
                # For each segment, find the speaker
                for seg in segments:
                    seg_start = seg.start
                    seg_end = seg.end
                    seg_text = seg.text.strip()
                    # Find speaker whose segment overlaps with this transcription segment
                    speaker = "Unknown"
                    for d_start, d_end, d_speaker in diar_segments:
                        # If diarization segment overlaps with transcription segment
                        if (seg_start < d_end) and (seg_end > d_start):
                            speaker = d_speaker
                            break
                    line = f"[{filename.name}] {seg_start:.2f}-{seg_end:.2f} Speaker {speaker}: {seg_text}\n"
                    print(line.strip())
                    # Append mode so all chunks accumulate in one transcript.
                    with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                        tf.write(line)
            except Exception as e:
                print(f"Transcription error for {filename.name}: {e}")
    print("Chunk writer/transcriber worker exiting.")
def run_recording(mic_index, sys_index=None, chunk_secs=CHUNK_DURATION_SECS, model_name=MODEL_NAME, no_transcribe=False):
    """
    Orchestrate a recording session: spawn recorder and writer threads, wait
    for the shared ``stop_event``, then write the final WAV(s).

    Args:
        mic_index: PyAudio device index for the microphone.
        sys_index: optional device index for system/loopback audio; when set,
            a second recorder/writer pair runs without transcription.
        chunk_secs: length of each saved chunk; written back to the
            module-level CHUNK_DURATION_SECS so recorder threads see it.
        model_name: faster-whisper model name; written back to MODEL_NAME.
        no_transcribe: when True, skip loading the transcription model.
    """
    global CHUNK_DURATION_SECS, MODEL_NAME, stop_event
    CHUNK_DURATION_SECS = chunk_secs
    MODEL_NAME = model_name
    # Start transcriber if enabled
    transcriber = None if no_transcribe else Transcriber()
    # Queues and threads
    q = queue.Queue()
    final_frames = []
    # Reuse an externally-provided stop_event (e.g. set by a Flask endpoint);
    # create one only for standalone runs.
    if 'stop_event' not in globals() or stop_event is None:
        stop_event = threading.Event()
    mic_thread = threading.Thread(target=record_loop, args=(mic_index, q, "mic"), daemon=True)
    mic_thread.start()
    q_sys = None
    writer_thread_sys = None
    final_frames_sys = []
    if sys_index is not None:
        # Second capture path for system audio; no transcriber attached.
        q_sys = queue.Queue()
        sys_thread = threading.Thread(target=record_loop, args=(sys_index, q_sys, "sys"), daemon=True)
        sys_thread.start()
        writer_thread_sys = threading.Thread(target=chunk_writer_and_transcribe_worker, args=(q_sys, final_frames_sys, None, "sys"), daemon=True)
        writer_thread_sys.start()
    writer_thread = threading.Thread(target=chunk_writer_and_transcribe_worker, args=(q, final_frames, transcriber, "mic"), daemon=True)
    writer_thread.start()
    try:
        # wait until the shared stop_event is set by the caller (Flask / api_stop-recording)
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("\nStopping all threads.")
        stop_event.set()
        time.sleep(1.0)
    # Give the writer threads a bounded window to drain their queues.
    writer_thread.join(timeout=5)
    if writer_thread_sys:
        writer_thread_sys.join(timeout=5)
    if final_frames:
        save_wav_from_frames(FINAL_WAV, final_frames, nchannels=CHANNELS)
        print(f"Saved final WAV: {FINAL_WAV}")
    if final_frames and final_frames_sys:
        # Merge mic (left) + system (right) into one stereo file.
        final_sys_wav = OUTPUT_DIR / "recorded_system_full.wav"
        save_wav_from_frames(final_sys_wav, final_frames_sys, nchannels=CHANNELS)
        stereo_path = OUTPUT_DIR / "recorded_audio_stereo.wav"
        merge_mono_files_to_stereo(FINAL_WAV, final_sys_wav, stereo_path)
        print(f"Saved merged stereo WAV: {stereo_path}")
    audio.terminate()
    print("Done. Transcript (if any) saved to:", TRANSCRIPT_FILE)
# Main
if __name__ == "__main__":
    list_input_devices()
    raw = input("\nEnter the device index for your microphone (or press ENTER to use default): ").strip()
    if raw == "":
        # Reuse the module-level PyAudio instance instead of instantiating a
        # second one that was never terminated (resource leak).
        mic_index = audio.get_default_input_device_info()['index']
    else:
        mic_index = int(raw)
    run_recording(mic_index)