| """ |
| Hugging Face Spaces - Gradio App for Stutter Analysis |
| ===================================================== |
| This is a standalone Gradio app for deployment on Hugging Face Spaces. |
| |
| To deploy: |
| 1. Create a new Space on huggingface.co/spaces |
| 2. Choose "Gradio" as SDK |
| 3. Upload this folder's contents |
| 4. Add your model checkpoint to the Space |
| """ |
|
|
| import gradio as gr |
| import torch |
| import torchaudio |
| import tempfile |
| import os |
| import json |
| import soundfile as sf |
| import librosa |
| from datetime import datetime |
| from transformers import WavLMModel |
| import torch.nn as nn |
| import whisper |
|
|
| |
| |
| |
|
|
class WaveLmStutterClassification(nn.Module):
    """WavLM encoder with a linear head for multi-label stutter classification.

    Wraps the pretrained ``microsoft/wavlm-base`` encoder; hidden states are
    mean-pooled over time and projected to one logit per stutter label.
    """

    def __init__(self, num_labels=5, freeze_encoder=True, unfreeze_last_n_layers=1):
        """Build the model.

        Args:
            num_labels: Number of independent stutter categories to predict.
            freeze_encoder: When True, disable gradients for the whole WavLM
                encoder (fine-tune only the head plus any unfrozen layers).
            unfreeze_last_n_layers: Re-enable gradients for this many of the
                top-most transformer layers (applies regardless of
                ``freeze_encoder``; it is a no-op when nothing was frozen).
        """
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size

        if freeze_encoder:
            for weight in self.wavlm.parameters():
                weight.requires_grad = False

        if unfreeze_last_n_layers > 0:
            # Selectively re-enable training for the top transformer layers.
            for encoder_layer in self.wavlm.encoder.layers[-unfreeze_last_n_layers:]:
                for weight in encoder_layer.parameters():
                    weight.requires_grad = True

        self.classifier = nn.Linear(self.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, input_values, attention_mask=None):
        """Return raw per-label logits of shape (batch, num_labels).

        NOTE(review): pooling is a plain temporal mean, so padded frames are
        included when an attention_mask is supplied — confirm callers feed
        unpadded (or uniformly padded) audio.
        """
        encoder_out = self.wavlm(input_values, attention_mask=attention_mask)
        pooled = encoder_out.last_hidden_state.mean(dim=1)
        return self.classifier(pooled)
|
|
| |
| |
| |
|
|
# The five stutter categories the classifier predicts; index i of the model's
# output logits corresponds to STUTTER_LABELS[i].
STUTTER_LABELS = ['Prolongation', 'Block', 'SoundRep', 'WordRep', 'Interjection']


# Human-readable explanation for each category, rendered in the UI
# "Definitions" tab.
STUTTER_DEFINITIONS = {
    'Prolongation': 'Sound stretched longer than normal (e.g., "Ssssssnake")',
    'Block': 'Complete stoppage of airflow/sound with tension',
    'SoundRep': 'Sound/syllable repetition (e.g., "B-b-b-ball")',
    'WordRep': 'Whole word repetition (e.g., "I-I-I want")',
    'Interjection': 'Filler words like "um", "uh", "like"'
}


# Exclusive upper bounds (percent of words containing a stutter) for each
# severity band; rates at or above 'severe' map to "Very Severe".
SEVERITY_THRESHOLDS = {'very_mild': 5, 'mild': 10, 'moderate': 20, 'severe': 30}
|
|
| |
| |
| |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| wavlm_model = None |
| whisper_model = None |
|
|
def load_models():
    """Load the WavLM stutter classifier and the Whisper ASR model.

    Populates the module-level ``wavlm_model`` and ``whisper_model`` globals.
    Classifier weights are restored from a local checkpoint file when one is
    present; otherwise the randomly initialized head is used (with a warning).
    """
    global wavlm_model, whisper_model

    print("Loading WavLM model...")
    wavlm_model = WaveLmStutterClassification(num_labels=5)

    checkpoint_path = "wavlm_stutter_classification_best.pth"
    if not os.path.exists(checkpoint_path):
        # Deployment note: the checkpoint must be uploaded alongside the app.
        print("WARNING: No checkpoint found, using random weights")
    else:
        state = torch.load(checkpoint_path, map_location=device)
        if isinstance(state, dict) and 'model_state_dict' in state:
            # Full training checkpoint: weights live under 'model_state_dict'.
            wavlm_model.load_state_dict(state['model_state_dict'])
            print(f"Loaded checkpoint with {state.get('val_accuracy', 'N/A')} accuracy")
        else:
            # Bare state_dict saved directly to disk.
            wavlm_model.load_state_dict(state)
            print("Loaded checkpoint (direct state_dict format)")

    wavlm_model.to(device)
    wavlm_model.eval()

    print("Loading Whisper model...")
    whisper_model = whisper.load_model("base", device=device)

    print("Models loaded!")
|
|
| |
| |
| |
|
|
def preprocess_audio(audio_path):
    """Convert audio to 16kHz mono using soundfile or librosa."""
    try:
        samples, sr = sf.read(audio_path, dtype='float32')
        # Downmix multi-channel recordings to mono by averaging channels.
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
    except Exception as e:
        print(f"Soundfile load failed, trying librosa: {e}")
        # librosa decodes, downmixes, and resamples to 16 kHz in one call.
        samples, sr = librosa.load(audio_path, sr=16000, mono=True)

    waveform = torch.from_numpy(samples).float()

    # Only the soundfile path can yield a non-16 kHz rate; resample it here.
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform.unsqueeze(0)).squeeze(0)

    return waveform, 16000
|
|
def chunk_audio(waveform, sr, chunk_sec=3.0):
    """Split audio into chunks"""
    size = int(chunk_sec * sr)
    total = len(waveform)
    segments = []

    for begin in range(0, total, size):
        stop = min(begin + size, total)
        piece = waveform[begin:stop]

        # Zero-pad the trailing chunk so every chunk has identical length.
        deficit = size - len(piece)
        if deficit > 0:
            piece = torch.nn.functional.pad(piece, (0, deficit))

        segments.append({
            'chunk': piece,
            'start': begin / sr,   # seconds
            'end': stop / sr,      # seconds (pre-padding end time)
        })

    return segments
|
|
def analyze_chunk(chunk_waveform, threshold=0.5):
    """Classify the stutter types present in a single audio chunk.

    Applies a sigmoid to the classifier logits so each label gets an
    independent probability; labels above ``threshold`` count as detected.
    """
    with torch.no_grad():
        batch = chunk_waveform.unsqueeze(0).to(device)
        probs = torch.sigmoid(wavlm_model(batch)).cpu().numpy()[0]

    probabilities = {label: float(p) for label, p in zip(STUTTER_LABELS, probs)}
    detected = [label for label, p in probabilities.items() if p > threshold]
    return {'detected': detected, 'probabilities': probabilities}
|
|
def get_severity(word_stutter_rate, thresholds=None):
    """Map a word-level stutter rate (percent) to a severity label and score.

    Args:
        word_stutter_rate: Percentage of words containing at least one stutter.
        thresholds: Optional mapping with 'very_mild', 'mild', 'moderate' and
            'severe' exclusive upper bounds (percent). Defaults to the
            module-level SEVERITY_THRESHOLDS, preserving prior behavior.

    Returns:
        Tuple of (severity label, integer score 1-5).
    """
    if thresholds is None:
        thresholds = SEVERITY_THRESHOLDS
    if word_stutter_rate < thresholds['very_mild']:
        return 'Very Mild', 1
    elif word_stutter_rate < thresholds['mild']:
        return 'Mild', 2
    elif word_stutter_rate < thresholds['moderate']:
        return 'Moderate', 3
    elif word_stutter_rate < thresholds['severe']:
        return 'Severe', 4
    else:
        return 'Very Severe', 5
|
|
| |
| |
| |
|
|
def analyze_audio(audio_file, threshold=0.5):
    """Main analysis function for Gradio.

    Pipeline: preprocess to 16 kHz mono -> split into fixed-length chunks ->
    classify each chunk with WavLM -> transcribe with Whisper (word
    timestamps) -> attribute chunk detections to overlapping words ->
    score severity.

    Args:
        audio_file: Filesystem path to the uploaded audio (Gradio 'filepath').
        threshold: Per-label sigmoid probability cutoff for detection.

    Returns:
        A 4-tuple of Markdown strings: (summary, annotated transcription,
        timeline table, stutter-type definitions). On failure the first
        element carries the error message and the others are empty.
    """
    # Fallback lazy load in case the startup-time load_models() was skipped.
    if wavlm_model is None:
        load_models()

    if audio_file is None:
        return "β οΈ Please upload an audio file", "", "", ""

    try:
        print(f"Starting analysis of: {audio_file}")

        waveform, sr = preprocess_audio(audio_file)
        duration = len(waveform) / sr
        print(f"Audio preprocessed: {duration:.1f}s, {sr}Hz")

        chunks = chunk_audio(waveform, sr)

        stutter_counts = {label: 0 for label in STUTTER_LABELS}
        timeline = []
        # FIX: cache one inference result per chunk. Previously the word
        # alignment loop below re-ran analyze_chunk (a full WavLM forward
        # pass) for every (word, overlapping chunk) pair — O(words x chunks)
        # redundant model calls for identical inputs.
        chunk_results = []

        for chunk_info in chunks:
            result = analyze_chunk(chunk_info['chunk'], threshold)
            chunk_results.append(result)
            for label in result['detected']:
                stutter_counts[label] += 1

            timeline.append({
                'time': f"{chunk_info['start']:.1f}s - {chunk_info['end']:.1f}s",
                'detected': ', '.join(result['detected']) if result['detected'] else 'Clear',
                'probs': result['probabilities']
            })

        # Word-level transcription: Whisper attaches start/end times per word.
        whisper_result = whisper_model.transcribe(audio_file, word_timestamps=True)
        transcription = whisper_result['text']

        # Flatten per-segment word lists into a single timeline of words.
        words = []
        if 'segments' in whisper_result:
            for seg in whisper_result['segments']:
                if 'words' in seg:
                    words.extend(seg['words'])

        # Attribute cached chunk detections to each word whose time span
        # overlaps the chunk's span.
        words_with_stutter = 0
        annotated_words = []

        for word_info in words:
            word_start = word_info.get('start', 0)
            word_end = word_info.get('end', 0)
            word_text = word_info.get('word', '')

            word_stutters = []
            for chunk_info, chunk_result in zip(chunks, chunk_results):
                # Interval-overlap test between the word and the chunk.
                if word_start < chunk_info['end'] and word_end > chunk_info['start']:
                    word_stutters.extend(chunk_result['detected'])

            word_stutters = list(set(word_stutters))
            if word_stutters:
                words_with_stutter += 1
                annotated_words.append(f"**[{word_text}]**({', '.join(word_stutters)})")
            else:
                annotated_words.append(word_text)

        # Guard against division by zero when Whisper returns no words.
        total_words = len(words) if words else 1
        word_stutter_rate = (words_with_stutter / total_words) * 100
        severity_label, severity_score = get_severity(word_stutter_rate)

        summary = f"""
## π Analysis Summary

**Duration:** {duration:.1f} seconds
**Total Words:** {total_words}
**Words with Stutters:** {words_with_stutter} ({word_stutter_rate:.1f}%)

### Severity: {severity_label} ({severity_score}/5)

### Stutter Type Counts:
"""
        for label, count in stutter_counts.items():
            if count > 0:
                summary += f"- **{label}**: {count} occurrences\n"

        annotated_text = " ".join(annotated_words) if annotated_words else transcription

        # Cap the table at the first 15 chunks to keep the UI readable.
        timeline_text = "| Time | Detected Stutters |\n|------|-------------------|\n"
        for t in timeline[:15]:
            timeline_text += f"| {t['time']} | {t['detected']} |\n"

        definitions = "## π Stutter Type Definitions\n\n"
        for label, desc in STUTTER_DEFINITIONS.items():
            definitions += f"**{label}:** {desc}\n\n"

        return summary, annotated_text, timeline_text, definitions

    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error in analyze_audio: {error_trace}")
        return f"β Error: {str(e)}\n\n```\n{error_trace}\n```", "", "", ""
|
|
| |
| |
| |
|
|
# Gradio UI. FIX: the Soft theme must be configured on the gr.Blocks
# constructor — Blocks.launch() does not accept a `theme` keyword.
with gr.Blocks(title="ποΈ Stutter Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ποΈ Speech Fluency Analysis System

    Upload an audio file to analyze stuttering patterns using AI.

    **Supported formats:** WAV, MP3, M4A, FLAC
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input column: audio source plus detection-sensitivity control.
            audio_input = gr.Audio(
                label="Upload Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            threshold_slider = gr.Slider(
                minimum=0.3,
                maximum=0.7,
                value=0.5,
                step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more conservative"
            )
            analyze_btn = gr.Button("π Analyze Speech", variant="primary")

        with gr.Column(scale=2):
            summary_output = gr.Markdown(label="Summary")

    with gr.Tabs():
        with gr.Tab("π Transcription"):
            transcription_output = gr.Markdown(label="Annotated Transcription")

        with gr.Tab("π Timeline"):
            timeline_output = gr.Markdown(label="Timeline Analysis")

        with gr.Tab("π Definitions"):
            definitions_output = gr.Markdown(label="Stutter Definitions")

    # Wire the button to the analysis pipeline; the four outputs map 1:1 to
    # the Markdown components declared above.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold_slider],
        outputs=[summary_output, transcription_output, timeline_output, definitions_output]
    )

    gr.Markdown("""
    ---
    **Disclaimer:** This tool is for educational/research purposes.
    Consult a qualified speech-language pathologist for clinical diagnosis.

    Built with WavLM + Whisper | [GitHub](https://github.com/abhicodes-here2001/Multimodal-stuttering-analysis)
    """)
|
|
| |
# Preload models at import time so the first request on Hugging Face Spaces
# doesn't pay the download/initialization cost.
load_models()


if __name__ == "__main__":
    # FIX: Blocks.launch() has no `theme` parameter, so passing
    # theme=gr.themes.Soft() here raises TypeError; theming must be
    # configured on the gr.Blocks constructor instead.
    demo.launch()
|
|