import sys
import os
import types
import logging
import re

# Shim for removed audioop module (Python 3.13+): some transitive deps still
# `import audioop`; an empty stand-in module keeps them importable.
if 'audioop' not in sys.modules:
    sys.modules['audioop'] = types.ModuleType('audioop')

import gradio as gr
import numpy as np
import matplotlib

# Select the non-interactive backend BEFORE importing pyplot so that no GUI
# toolkit is ever initialised — this app runs headless.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Monkey-patch TRIBE's ExtractWordsFromAudio to build word-level events
# WITHOUT calling whisperx (which requires CUDA libs unavailable on CPU).
#
# Instead, we use a simple heuristic: split the transcript text into words
# and distribute them evenly across the audio duration. This gives TRIBE
# enough word-level signal for its text encoder without needing ASR.
# ---------------------------------------------------------------------------


def _patched_get_transcript_from_audio(wav_filename, language="english"):
    """CPU-safe replacement that creates word events from audio duration.

    When the audio was generated from known text (gTTS), the global
    CURRENT_SCRIPT_TEXT will contain that text. Otherwise we create a
    minimal placeholder so TRIBE's pipeline doesn't crash.

    Returns a DataFrame with columns:
        text, start, duration, sequence_id, sentence
    mirroring what the whisperx-based original produced.
    """
    # Local imports: pandas/soundfile are only needed on this code path.
    import pandas as pd
    import soundfile as sf
    from pathlib import Path

    wav_filename = Path(wav_filename)

    # Get audio duration; fall back to 30s if the file can't be probed
    # (best-effort — the timings are synthetic anyway).
    try:
        info = sf.info(str(wav_filename))
        duration = info.duration
    except Exception:
        duration = 30.0  # fallback

    # Use the known script text if available, otherwise a placeholder
    text = _CURRENT_SCRIPT_TEXT or "audio content placeholder"

    # Tokenize into words
    raw_words = text.split()
    if not raw_words:
        return pd.DataFrame(columns=["text", "start", "duration", "sequence_id", "sentence"])

    # Split into sentences (rough: split on . ! ?)
    sentences = re.split(r'(?<=[.!?])\s+', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        sentences = [text]

    # Distribute words evenly across the audio duration
    word_duration = duration / len(raw_words)
    words = []
    word_idx = 0
    total = len(raw_words)
    for sent_idx, sentence in enumerate(sentences):
        if word_idx >= total:
            break  # all words consumed — no point scanning more sentences
        clean_sentence = sentence.replace('"', '')
        for w in sentence.split():
            if word_idx >= total:
                break
            words.append({
                "text": w.replace('"', ''),
                "start": word_idx * word_duration,
                # 0.9 leaves a small gap between consecutive word events
                "duration": word_duration * 0.9,
                "sequence_id": sent_idx,
                "sentence": clean_sentence,
            })
            word_idx += 1
    return pd.DataFrame(words)


# Global used to pass text from the analyze function to the monkey-patch.
_CURRENT_SCRIPT_TEXT = None


def apply_patches():
    """Patch TRIBE's ExtractWordsFromAudio to avoid whisperx/CUDA dependency."""
    try:
        from tribev2.eventstransforms import ExtractWordsFromAudio
        ExtractWordsFromAudio._get_transcript_from_audio = staticmethod(
            _patched_get_transcript_from_audio
        )
        logger.info("Patched ExtractWordsFromAudio (CPU-safe, no whisperx)")
    except Exception as e:
        # Best-effort: if tribev2 isn't importable yet we retry in load_model().
        logger.warning(f"Could not patch ExtractWordsFromAudio: {e}")


# Apply patches at import time
apply_patches()

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
model = None


def load_model():
    """Load the TRIBE v2 model (once) into the module-level `model` global.

    Returns a human-readable status string; callers check it for "Error".
    """
    global model
    if model is not None:
        return "✅ Already loaded!"
    try:
        apply_patches()  # re-apply in case import order matters
        from tribev2 import TribeModel
        model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="/tmp/tribe_cache")
        return "✅ Model loaded!"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error loading model: {str(e)}"


# ---------------------------------------------------------------------------
# Brain region definitions (approximate vertex ranges on fsaverage5)
# Each entry: (display name, fractional start, fractional end, bar color).
# The fractions index into ONE hemisphere's vertices (see score_predictions).
# ---------------------------------------------------------------------------
REGIONS = [
    ("Visual cortex", 0.00, 0.15, "#378ADD"),
    ("Auditory cortex", 0.15, 0.30, "#D85A30"),
    ("Language (Broca's area)", 0.30, 0.45, "#7F77DD"),
    ("Prefrontal (attention)", 0.45, 0.62, "#1D9E75"),
    ("Temporal (memory)", 0.62, 0.78, "#BA7517"),
    ("Emotion (limbic)", 0.78, 1.00, "#D4537E"),
]


def score_predictions(preds):
    """Score each REGIONS entry from model predictions.

    preds: array-like of per-timestep vertex predictions; averaged (abs)
    over time, then normalised by the global max. Only the first half of
    the vertex axis (left hemisphere) is scored — a deliberate heuristic.

    Returns (scores_dict, overall) where scores are 0-100 floats rounded
    to one decimal and overall is their mean.
    """
    avg = np.mean(np.abs(preds), axis=0)
    global_max = avg.max() + 1e-8  # epsilon guards divide-by-zero
    half = len(avg) // 2
    scores = {}
    for name, s, e, _ in REGIONS:
        start, end = int(half * s), int(half * e)
        scores[name] = round(float(np.mean(avg[start:end]) / global_max * 100), 1)
    return scores, round(sum(scores.values()) / len(scores), 1)


def make_brain_plot(preds):
    """Render left/right fsaverage5 surface activation maps to a PNG.

    Returns the saved image path, or None if plotting fails (e.g. nilearn
    missing or the fsaverage download is unavailable).
    """
    try:
        from nilearn import plotting, datasets
        avg = np.mean(np.abs(preds), axis=0)
        # Min-max normalise to [0, 1] for a stable colour scale
        avg_norm = (avg - avg.min()) / (avg.max() - avg.min() + 1e-8)
        half = len(avg_norm) // 2  # first half = left hemi, second = right
        fsaverage = datasets.fetch_surf_fsaverage("fsaverage5")
        fig, axes = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={"projection": "3d"})
        fig.patch.set_facecolor("#111111")
        plotting.plot_surf_stat_map(fsaverage.infl_left, avg_norm[:half],
                                    hemi="left", view="lateral", colorbar=True,
                                    cmap="hot", title="Left hemisphere",
                                    axes=axes[0], figure=fig)
        plotting.plot_surf_stat_map(fsaverage.infl_right, avg_norm[half:],
                                    hemi="right", view="lateral", colorbar=True,
                                    cmap="hot", title="Right hemisphere",
                                    axes=axes[1], figure=fig)
        plt.tight_layout()
        plt.savefig("/tmp/brain_map.png", dpi=130, bbox_inches="tight", facecolor="#111111")
        plt.close()
        return "/tmp/brain_map.png"
    except Exception as e:
        # Use the module logger (with traceback) instead of print
        logger.exception(f"Brain plot error: {e}")
        return None


def make_score_chart(scores, overall):
    """Render a dark-themed horizontal bar chart of region scores to a PNG.

    scores: dict produced by score_predictions; overall: its mean score.
    Returns the saved image path.
    """
    fig, ax = plt.subplots(figsize=(9, 4))
    fig.patch.set_facecolor("#1a1a1a")
    ax.set_facecolor("#1a1a1a")
    names = [r[0] for r in REGIONS]
    colors = [r[3] for r in REGIONS]
    vals = [scores.get(n, 0) for n in names]
    bars = ax.barh(names, vals, color=colors, height=0.55)
    ax.set_xlim(0, 100)
    # Dashed guide at 70 — the "good enough" threshold used by the tips
    ax.axvline(70, color="#888", linestyle="--", linewidth=1, alpha=0.6)
    ax.set_xlabel("Activation score", color="#ccc", fontsize=11)
    ax.set_title(f"Brain region activation | Overall: {overall}/100",
                 color="white", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(colors="#ccc")
    for spine in ax.spines.values():
        spine.set_edgecolor("#333")
    # Numeric label just past the end of each bar
    for bar, val in zip(bars, vals):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f"{val}", va="center", color="white", fontsize=10, fontweight="bold")
    plt.tight_layout()
    plt.savefig("/tmp/score_chart.png", dpi=130, bbox_inches="tight", facecolor="#1a1a1a")
    plt.close()
    return "/tmp/score_chart.png"


def generate_suggestions(scores, overall):
    """Build a markdown suggestions string from region scores.

    A tip is emitted for every region scoring below 70; missing regions
    default to 100 so they never trigger a tip.
    """
    tips = []
    if scores.get("Prefrontal (attention)", 100) < 70:
        tips.append("→ Open with a bold question or surprising fact to boost attention")
    if scores.get("Emotion (limbic)", 100) < 70:
        tips.append("→ Add emotional language — 'imagine', 'feel', personal stories")
    if scores.get("Temporal (memory)", 100) < 70:
        tips.append("→ Include specific numbers or data points to improve memorability")
    if scores.get("Visual cortex", 100) < 70:
        tips.append("→ Use more visual language — describe what viewers will 'see'")
    if scores.get("Language (Broca's area)", 100) < 70:
        tips.append("→ Break long sentences into shorter, punchier ones")
    if scores.get("Auditory cortex", 100) < 70:
        tips.append("→ Add rhythm and repetition — the brain responds to sound patterns")
    if not tips:
        tips.append("→ Excellent! Consider adding a strong call-to-action at the end")
    status = "🟢 Strong" if overall >= 75 else "🟡 Good, needs polish" if overall >= 55 else "🔴 Needs work"
    return f"**Overall: {overall}/100 — {status}**\n\n" + "\n".join(tips)


# ---------------------------------------------------------------------------
# Main analysis function
# ---------------------------------------------------------------------------
def analyze(input_mode, script_text, audio_file, progress=gr.Progress()):
    """Run the full pipeline: input → audio → TRIBE events → brain predictions.

    Returns (brain_image_path, score_image_path, suggestions_markdown,
    predictions_file_path); the first/second/fourth are None on errors.
    """
    global _CURRENT_SCRIPT_TEXT
    # Input validation per mode
    if input_mode == "Text" and (not script_text or not script_text.strip()):
        return None, None, "⚠️ Please paste your script text first.", None
    if input_mode == "Audio" and audio_file is None:
        return None, None, "⚠️ Please upload an audio file first.", None

    # Lazy-load the model on first use
    if model is None:
        progress(0.1, desc="Loading TRIBE v2 model (first time ~5 mins)...")
        msg = load_model()
        if "Error" in msg:
            return None, None, msg, None

    try:
        if input_mode == "Text":
            progress(0.2, desc="Converting text to speech...")
            from gtts import gTTS
            from langdetect import detect
            text = script_text.strip()
            lang = detect(text)
            audio_path = "/tmp/script_audio.mp3"
            tts = gTTS(text=text, lang=lang)
            tts.save(audio_path)
            # Store text so the monkey-patched transcriber can use it
            # instead of running ASR on the audio we just synthesised.
            _CURRENT_SCRIPT_TEXT = text
            progress(0.4, desc="Running TRIBE v2 on generated audio...")
        else:
            import shutil
            progress(0.2, desc="Loading audio file...")
            ext = os.path.splitext(audio_file)[1] or ".mp3"
            audio_path = f"/tmp/input_audio{ext}"
            shutil.copy(audio_file, audio_path)
            # No known text for uploaded audio
            _CURRENT_SCRIPT_TEXT = None
            progress(0.4, desc="Running TRIBE v2 on audio...")

        # Common to both branches: build TRIBE's events dataframe
        df = model.get_events_dataframe(audio_path=audio_path)

        progress(0.6, desc="Predicting brain response...")
        preds, segments = model.predict(events=df)
        progress(0.75, desc="Scoring regions...")
        scores, overall = score_predictions(preds)
        progress(0.85, desc="Rendering maps...")
        brain_img = make_brain_plot(preds)
        score_img = make_score_chart(scores, overall)
        suggestions = generate_suggestions(scores, overall)
        np.save("/tmp/brain_predictions.npy", preds)
        progress(1.0, desc="Done!")
        return brain_img, score_img, suggestions, "/tmp/brain_predictions.npy"
    except Exception as e:
        import traceback
        full_error = traceback.format_exc()
        # Log through the module logger instead of print
        logger.exception("Analysis failed")
        return None, None, f"❌ Error:\n{str(e)}\n\nFull traceback:\n{full_error}", None
    finally:
        # Always clear the global so stale text never leaks into a later run
        _CURRENT_SCRIPT_TEXT = None


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
css = "#title{text-align:center} #subtitle{text-align:center;color:#888;font-size:14px}"

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
    gr.Markdown("# 🧠 Script Brain Optimizer", elem_id="title")
    gr.Markdown("Analyze your script or audio → real fMRI predictions via **TRIBE v2** → iterate", elem_id="subtitle")
    with gr.Row():
        with gr.Column(scale=1):
            input_mode = gr.Radio(
                choices=["Text", "Audio"],
                value="Text",
                label="Input type",
                info="Text: paste your script | Audio: upload MP3/WAV"
            )
            script_input = gr.Textbox(
                label="Your script",
                placeholder="Paste your content script here...",
                lines=10,
                max_lines=20,
                visible=True
            )
            audio_input = gr.Audio(
                label="Upload audio file (MP3, WAV, M4A, FLAC)",
                type="filepath",
                sources=["upload"],
                visible=False
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)
                analyze_btn = gr.Button("🧠 Analyze", variant="primary", scale=3)
            suggestions_out = gr.Markdown(value="*Add your content and click Analyze...*")
            download_out = gr.File(label="Download predictions (.npy)")
        with gr.Column(scale=2):
            brain_img_out = gr.Image(label="Brain activation map", height=320)
            score_img_out = gr.Image(label="Region scores", height=280)

    def toggle_mode(mode):
        # Show exactly one input widget depending on the selected mode
        return gr.update(visible=mode == "Text"), gr.update(visible=mode == "Audio")

    input_mode.change(fn=toggle_mode, inputs=[input_mode],
                      outputs=[script_input, audio_input])
    analyze_btn.click(fn=analyze,
                      inputs=[input_mode, script_input, audio_input],
                      outputs=[brain_img_out, score_img_out, suggestions_out, download_out])
    clear_btn.click(
        fn=lambda: ("", None, None, None, "*Add your content and click Analyze...*", None),
        outputs=[script_input, audio_input, brain_img_out, score_img_out,
                 suggestions_out, download_out]
    )
    gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)