# AudioTransDiar / rec_transcribe_extension.py
# (Hugging Face Space file header — uploader prthm11, commit 2a26bdf verified —
#  kept as a comment so the module remains importable.)
import os
import time
import wave
import queue
import threading
import datetime
from pathlib import Path
import wave
import pyaudio
from pyannote.audio import Pipeline
# Hugging Face access token for the gated pyannote model; read from the
# environment so no secret is hard-coded (may be None when unset).
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
try:
    # Load the pretrained speaker-diarization pipeline once at import time.
    # NOTE(review): the broad `except Exception` keeps the module importable
    # when the model cannot be fetched (missing token / offline); diarization
    # is then disabled via the DIARIZATION_AVAILABLE flag.
    diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HUGGINGFACE_TOKEN)
    DIARIZATION_AVAILABLE = True
except Exception:
    diarization_pipeline = None
    DIARIZATION_AVAILABLE = False
# Optional modules (import safely)
try:
from faster_whisper import WhisperModel
FASTER_WHISPER_AVAILABLE = True
except Exception:
FASTER_WHISPER_AVAILABLE = False
import numpy as np
# # Optional: voice activity detection
# try:
# import webrtcvad
# VAD_AVAILABLE = True
# except Exception:
# VAD_AVAILABLE = False
# ========== CONFIG ==========
# Timestamp suffix so each run writes to distinct chunk/output files.
RUN_TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
FORMAT = pyaudio.paInt16     # 16-bit signed PCM samples
CHANNELS = 1                 # mono
RATE = 44100                 # sample rate in Hz
CHUNK = 1024                 # frames per buffer read
CHUNK_DURATION_SECS = 5      # how long each saved chunk is (seconds)

# Output root is configurable via the OUTPUT_DIR env var (container default
# /app/...); fall back to /tmp when the preferred location is not writable.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/app/output_transcript_diarization"))
try:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
except PermissionError:
    OUTPUT_DIR = Path("/tmp/output_transcript_diarization")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

CHUNKS_DIR = OUTPUT_DIR / f"chunks_{RUN_TIMESTAMP}"
# BUGFIX: the chunks directory was never created (the mkdir call was
# commented out), so the first wave.open() on a chunk path raised
# FileNotFoundError.  Create it eagerly, after the OUTPUT_DIR fallback.
CHUNKS_DIR.mkdir(parents=True, exist_ok=True)

FINAL_WAV = OUTPUT_DIR / f"recorded_audio_{RUN_TIMESTAMP}.wav"
TRANSCRIPT_FILE = OUTPUT_DIR / f"transcript_{RUN_TIMESTAMP}.txt"
MODEL_NAME = "medium"  # faster-whisper model size; change as desired

# Single shared PyAudio instance used by every stream in this module.
audio = pyaudio.PyAudio()
def list_input_devices():
    """Print every capture-capable audio device with its index and channel count."""
    pa = pyaudio.PyAudio()
    print("Available audio devices (inputs):")
    device_count = pa.get_device_count()
    for idx in range(device_count):
        info = pa.get_device_info_by_index(idx)
        n_inputs = info.get("maxInputChannels", 0)
        if n_inputs > 0:
            print(f" {idx}: {info['name']} - {n_inputs} chans")
    pa.terminate()
def open_stream_for_device(device_index, channels=1):
    """Open and return a PyAudio capture stream on *device_index*.

    Uses the module-wide FORMAT/RATE/CHUNK settings and the shared
    ``audio`` PyAudio instance.
    """
    return audio.open(
        format=FORMAT,
        channels=channels,
        rate=RATE,
        input=True,
        frames_per_buffer=CHUNK,
        input_device_index=device_index,
    )
def save_wav_from_frames(path: Path, frames: list, nchannels=1):
    """Write raw int16 audio frames to *path* as a WAV file.

    Samples are peak-normalized to 90% of the int16 range so every saved
    file has a comparable volume level.

    Args:
        path: destination .wav path (its parent directory must exist).
        frames: list of raw byte buffers as returned by ``stream.read()``.
        nchannels: channel count written into the WAV header.
    """
    raw = b''.join(frames)
    audio_array = np.frombuffer(raw, dtype=np.int16)
    # Hoist the peak computation (was evaluated twice) and guard against an
    # empty frame list, where np.max over an empty array raises ValueError.
    peak = int(np.max(np.abs(audio_array))) if audio_array.size else 0
    if peak > 0:
        # Normalize: scale to 90% of the int16 range.
        audio_array = (audio_array / peak * 32767 * 0.9).astype(np.int16)
    with wave.open(str(path), 'wb') as wf:
        wf.setnchannels(nchannels)
        wf.setsampwidth(audio.get_sample_size(FORMAT))
        wf.setframerate(RATE)
        wf.writeframes(audio_array.tobytes())
def merge_mono_files_to_stereo(mic_path: Path, sys_path: Path, out_path: Path):
    """Interleave two mono WAVs into one stereo WAV (mic -> left, system -> right).

    Naive merge: both inputs must use the module-level RATE; the output is
    truncated to the shorter input's frame count.

    Raises:
        ValueError: if either input's sample rate differs from RATE.
    """
    with wave.open(str(mic_path), 'rb') as wm, wave.open(str(sys_path), 'rb') as ws:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not (wm.getframerate() == ws.getframerate() == RATE):
            raise ValueError(
                f"sample-rate mismatch: {wm.getframerate()} / {ws.getframerate()} "
                f"(expected {RATE})"
            )
        sampwidth = wm.getsampwidth()
        nframes = min(wm.getnframes(), ws.getnframes())
        mic_bytes = wm.readframes(nframes)
        sys_bytes = ws.readframes(nframes)
    # Convert bytes to int16 samples.
    mic_arr = np.frombuffer(mic_bytes, dtype=np.int16)
    sys_arr = np.frombuffer(sys_bytes, dtype=np.int16)
    # Interleave sample-by-sample: even indices = left (mic), odd = right (sys).
    stereo = np.empty((nframes * 2,), dtype=np.int16)
    stereo[0::2] = mic_arr[:nframes]
    stereo[1::2] = sys_arr[:nframes]
    with wave.open(str(out_path), 'wb') as wf:
        wf.setnchannels(2)
        wf.setsampwidth(sampwidth)
        wf.setframerate(RATE)
        wf.writeframes(stereo.tobytes())
class Transcriber:
    """Thin wrapper around a faster-whisper ``WhisperModel``.

    The model is loaded eagerly in ``__init__`` when faster-whisper is
    installed; otherwise ``self.model`` stays ``None`` and
    ``transcribe_file`` is a no-op returning ``None``.
    """

    def __init__(self):
        """Load the faster-whisper model if the package is available."""
        self.model = None
        if not FASTER_WHISPER_AVAILABLE:
            print("faster-whisper not available. Transcription will be disabled.")
            return
        print("Loading faster-whisper model. This may take some time...")
        # Pick the device via torch when present; otherwise assume CPU.
        try:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            device = "cpu"
        # Safe compute types: half precision on GPU, full precision on CPU.
        compute_type = "float16" if device == "cuda" else "float32"
        try:
            # Instantiating the model may download weights on the first run.
            self.model = WhisperModel(MODEL_NAME, device=device, compute_type=compute_type)
            print(f"Model loaded on {device} (compute_type={compute_type}).")
        except Exception as e:
            print("Failed to load faster-whisper model:", e)
            print("Continuing without transcription.")
            self.model = None

    def transcribe_file(self, wav_path: str):
        """Transcribe *wav_path*; return the joined text or None on failure."""
        if not self.model:
            return None
        try:
            segments, info = self.model.transcribe(wav_path, beam_size=5)
            return " ".join(seg.text.strip() for seg in segments)
        except Exception as e:
            print(f"Transcription error for {wav_path}: {e}")
            return None
def diarization_hook(audio_path: str):
    """Run speaker diarization on *audio_path*.

    Returns a list of (start, end, speaker) tuples, or None when the
    pyannote pipeline could not be loaded at import time.
    """
    if not DIARIZATION_AVAILABLE:
        return None
    annotation = diarization_pipeline(audio_path)
    return [
        (turn.start, turn.end, speaker)
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]
# Recorder threads
def record_loop(device_index, out_queue, label="mic"):
"""
Continuously read bytes from device stream and push full-second frames to queue.
"""
try:
stream = open_stream_for_device(device_index, channels=CHANNELS)
except Exception as e:
print(f"Could not open stream for device {device_index} ({label}): {e}")
return
frames_per_chunk = int(RATE / CHUNK * CHUNK_DURATION_SECS)
frames = []
print(f"Recording from device {device_index} ({label}) ... Press Ctrl+C to stop.")
error_count = 0
try:
while True:
try:
data = stream.read(CHUNK, exception_on_overflow=False)
error_count = 0 # reset on success
except Exception as e:
print(f"Read error on device {device_index} ({label}): {e}")
error_count += 1
if error_count > 10:
print(f"Too many errors on device {device_index} ({label}). Stopping this thread.")
break
continue
frames.append(data)
if len(frames) >= frames_per_chunk:
ts = datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S_%f")
filename = CHUNKS_DIR / f"{label}_{ts}.wav"
out_queue.put((filename, frames.copy()))
frames = []
except KeyboardInterrupt:
print(f"Recording thread {label} received KeyboardInterrupt.")
finally:
try:
stream.stop_stream()
stream.close()
except Exception:
pass
print(f"Recording thread for {label} exited.")
def chunk_writer_and_transcribe_worker(in_queue: queue.Queue, final_frames_list: list, transcriber: Transcriber, single_channel_label="mic"):
    """
    Save chunk WAV files and optionally send for transcription.
    Also store frames for final concatenated WAV.

    Consumes (path, frames) tuples from ``in_queue`` until the module-level
    ``stop_event`` is set AND the queue has drained.  For each chunk it
    writes the WAV, appends the frames to ``final_frames_list`` (mutated in
    place for the caller), runs diarization, and — when a transcriber with a
    loaded model was supplied — appends speaker-labelled transcript lines to
    TRANSCRIPT_FILE.

    NOTE(review): relies on the global ``stop_event`` created in
    run_recording(); this worker must only be started after run_recording
    has set it up, or the queue.Empty branch raises NameError.
    """
    while True:
        try:
            # Poll with a timeout so the stop condition is re-checked
            # roughly once per second.
            filename, frames = in_queue.get(timeout=1.0)
        except queue.Empty:
            if stop_event.is_set() and in_queue.empty():
                break
            continue
        save_wav_from_frames(filename, frames, nchannels=CHANNELS)
        print(f"Saved chunk: {filename.name}")
        # Keep every frame so the caller can write one concatenated WAV later.
        final_frames_list.extend(frames)
        diar = diarization_hook(str(filename))
        diar_segments = []
        if diar:
            diar_segments = diar  # list of (start, end, speaker)
        # Transcribe chunk and get segments with timestamps
        if transcriber and transcriber.model:
            try:
                # Get segments with timestamps
                segments, info = transcriber.model.transcribe(str(filename), beam_size=5)
                # For each segment, find the speaker
                for seg in segments:
                    seg_start = seg.start
                    seg_end = seg.end
                    seg_text = seg.text.strip()
                    # Find speaker whose segment overlaps with this transcription segment
                    speaker = "Unknown"
                    for d_start, d_end, d_speaker in diar_segments:
                        # First diarization turn that overlaps the
                        # transcription segment wins.
                        if (seg_start < d_end) and (seg_end > d_start):
                            speaker = d_speaker
                            break
                    line = f"[{filename.name}] {seg_start:.2f}-{seg_end:.2f} Speaker {speaker}: {seg_text}\n"
                    print(line.strip())
                    with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                        tf.write(line)
            except Exception as e:
                print(f"Transcription error for {filename.name}: {e}")
    print("Chunk writer/transcriber worker exiting.")
def run_recording(mic_index, sys_index=None, chunk_secs=CHUNK_DURATION_SECS, model_name=MODEL_NAME, no_transcribe=False):
    """Start recorder/writer threads and block until ``stop_event`` is set.

    Args:
        mic_index: PyAudio device index for the microphone.
        sys_index: optional second device (system/loopback audio); its chunks
            are saved but NOT transcribed (its worker gets transcriber=None).
        chunk_secs: overrides the module-level CHUNK_DURATION_SECS.
        model_name: overrides the module-level MODEL_NAME.
        no_transcribe: when True, skip loading the whisper model entirely.

    Side effects: rebinds the module globals CHUNK_DURATION_SECS, MODEL_NAME
    and stop_event; writes WAV/transcript files under OUTPUT_DIR; terminates
    the shared PyAudio instance on exit, so this can only run once per process.
    """
    global CHUNK_DURATION_SECS, MODEL_NAME, stop_event
    CHUNK_DURATION_SECS = chunk_secs
    MODEL_NAME = model_name
    # Start transcriber if enabled
    transcriber = None if no_transcribe else Transcriber()
    # Queues and threads
    q = queue.Queue()
    final_frames = []
    # Reuse an externally-provided stop_event (e.g. from a Flask wrapper),
    # otherwise create one.  NOTE(review): an event left set by a previous
    # call would make this run exit immediately.
    if 'stop_event' not in globals() or stop_event is None:
        stop_event = threading.Event()
    mic_thread = threading.Thread(target=record_loop, args=(mic_index, q, "mic"), daemon=True)
    mic_thread.start()
    q_sys = None
    writer_thread_sys = None
    final_frames_sys = []
    if sys_index is not None:
        q_sys = queue.Queue()
        sys_thread = threading.Thread(target=record_loop, args=(sys_index, q_sys, "sys"), daemon=True)
        sys_thread.start()
        # System-audio chunks are saved but never transcribed (transcriber=None).
        writer_thread_sys = threading.Thread(target=chunk_writer_and_transcribe_worker, args=(q_sys, final_frames_sys, None, "sys"), daemon=True)
        writer_thread_sys.start()
    writer_thread = threading.Thread(target=chunk_writer_and_transcribe_worker, args=(q, final_frames, transcriber, "mic"), daemon=True)
    writer_thread.start()
    try:
        # wait until the shared stop_event is set by the caller (Flask / api_stop-recording)
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("\nStopping all threads.")
        stop_event.set()
        time.sleep(1.0)
    # Give the writer threads a bounded window to drain their queues.
    writer_thread.join(timeout=5)
    if writer_thread_sys:
        writer_thread_sys.join(timeout=5)
    if final_frames:
        save_wav_from_frames(FINAL_WAV, final_frames, nchannels=CHANNELS)
        print(f"Saved final WAV: {FINAL_WAV}")
    if final_frames and final_frames_sys:
        final_sys_wav = OUTPUT_DIR / "recorded_system_full.wav"
        save_wav_from_frames(final_sys_wav, final_frames_sys, nchannels=CHANNELS)
        stereo_path = OUTPUT_DIR / "recorded_audio_stereo.wav"
        merge_mono_files_to_stereo(FINAL_WAV, final_sys_wav, stereo_path)
        print(f"Saved merged stereo WAV: {stereo_path}")
    audio.terminate()
    print("Done. Transcript (if any) saved to:", TRANSCRIPT_FILE)
# Main
if __name__ == "__main__":
    list_input_devices()
    raw = input("\nEnter the device index for your microphone (or press ENTER to use default): ").strip()
    if raw == "":
        # BUGFIX: the temporary PyAudio instance used to look up the default
        # device was previously leaked; terminate it after reading the index.
        pa = pyaudio.PyAudio()
        mic_index = pa.get_default_input_device_info()['index']
        pa.terminate()
    else:
        mic_index = int(raw)
    run_recording(mic_index)