# Uploaded by ceh-vedant -- "Update app.py" (commit 1458172, verified)
import sys
import os
import types
import logging
import re

# Shim for removed audioop module (Python 3.13+)
# NOTE(review): some audio dependencies do `import audioop` at import time;
# registering an empty stub module lets them import on Python 3.13+, where
# the real module was removed. Any actual attribute access on the stub will
# still raise AttributeError -- this only satisfies the import, it does not
# provide audioop functionality.
if 'audioop' not in sys.modules:
    sys.modules['audioop'] = types.ModuleType('audioop')

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Non-interactive Agg backend: figures are rendered to PNG files on a
# headless server, never shown on a display.
matplotlib.use('Agg')

# Module-level logger; basicConfig is a no-op if the host process already
# configured logging handlers.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# ---------------------------------------------------------------------------
# Monkey-patch TRIBE's ExtractWordsFromAudio to build word-level events
# WITHOUT calling whisperx (which requires CUDA libs unavailable on CPU).
#
# Instead, we use a simple heuristic: split the transcript text into words
# and distribute them evenly across the audio duration. This gives TRIBE
# enough word-level signal for its text encoder without needing ASR.
# ---------------------------------------------------------------------------
def _patched_get_transcript_from_audio(wav_filename, language="english"):
"""CPU-safe replacement that creates word events from audio duration.
When the audio was generated from known text (gTTS), the global
CURRENT_SCRIPT_TEXT will contain that text. Otherwise we create
a minimal placeholder so TRIBE's pipeline doesn't crash.
"""
import pandas as pd
import soundfile as sf
from pathlib import Path
wav_filename = Path(wav_filename)
# Get audio duration
try:
info = sf.info(str(wav_filename))
duration = info.duration
except Exception:
duration = 30.0 # fallback
# Use the known script text if available, otherwise a placeholder
text = _CURRENT_SCRIPT_TEXT or "audio content placeholder"
# Tokenize into words
raw_words = text.split()
if not raw_words:
return pd.DataFrame(columns=["text", "start", "duration", "sequence_id", "sentence"])
# Split into sentences (rough: split on . ! ?)
sentences = re.split(r'(?<=[.!?])\s+', text)
sentences = [s.strip() for s in sentences if s.strip()]
if not sentences:
sentences = [text]
# Distribute words evenly across the audio duration
word_duration = duration / len(raw_words)
words = []
word_idx = 0
for sent_idx, sentence in enumerate(sentences):
sent_words = sentence.split()
for w in sent_words:
if word_idx >= len(raw_words):
break
words.append({
"text": w.replace('"', ''),
"start": word_idx * word_duration,
"duration": word_duration * 0.9,
"sequence_id": sent_idx,
"sentence": sentence.replace('"', ''),
})
word_idx += 1
return pd.DataFrame(words)
# Global to pass text from the analyze function to the monkey-patch
_CURRENT_SCRIPT_TEXT = None
def apply_patches():
    """Patch TRIBE's ExtractWordsFromAudio to avoid whisperx/CUDA dependency.

    Replaces the class attribute ``_get_transcript_from_audio`` with the
    CPU-safe heuristic above. Failures (e.g. tribev2 not importable yet)
    are logged and swallowed so the app can still start; the patch is
    re-applied inside load_model() as a second chance.
    """
    try:
        from tribev2.eventstransforms import ExtractWordsFromAudio
        # staticmethod() stores the replacement unbound on the class, so it
        # is called with exactly the (wav_filename, language) signature.
        ExtractWordsFromAudio._get_transcript_from_audio = staticmethod(
            _patched_get_transcript_from_audio
        )
        logger.info("Patched ExtractWordsFromAudio (CPU-safe, no whisperx)")
    except Exception as e:
        # Best-effort: a warning here is expected when tribev2 is not yet
        # installed/importable at module import time.
        logger.warning(f"Could not patch ExtractWordsFromAudio: {e}")


# Apply patches at import time
apply_patches()
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
model = None


def load_model():
    """Load the TRIBE v2 model once and cache it in the module global.

    Returns a short status string (shown directly in the UI) instead of
    raising, so a failed load never crashes a Gradio callback.
    """
    global model
    if model is not None:
        return "βœ… Already loaded!"
    try:
        # Re-apply the whisperx patch in case import order mattered and
        # tribev2 only became importable after module import time.
        apply_patches()
        from tribev2 import TribeModel
        model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="/tmp/tribe_cache")
    except Exception as err:
        import traceback
        traceback.print_exc()
        return f"❌ Error loading model: {str(err)}"
    return "βœ… Model loaded!"
# ---------------------------------------------------------------------------
# Brain region definitions (approximate vertex ranges on fsaverage5)
# ---------------------------------------------------------------------------
# Each entry: (label, fractional start, fractional end, bar color).
REGIONS = [
    ("Visual cortex", 0.00, 0.15, "#378ADD"),
    ("Auditory cortex", 0.15, 0.30, "#D85A30"),
    ("Language (Broca's area)", 0.30, 0.45, "#7F77DD"),
    ("Prefrontal (attention)", 0.45, 0.62, "#1D9E75"),
    ("Temporal (memory)", 0.62, 0.78, "#BA7517"),
    ("Emotion (limbic)", 0.78, 1.00, "#D4537E"),
]


def score_predictions(preds):
    """Score each REGION from a 2-D prediction array.

    Averages |preds| over the first axis, maps each region's fractional
    range onto the first half of the vertices, and scores it relative to
    the global peak on a 0-100 scale.

    Returns
    -------
    (dict, float)
        Per-region scores and their mean, each rounded to one decimal.
    """
    magnitude = np.mean(np.abs(preds), axis=0)
    peak = magnitude.max() + 1e-8  # epsilon guards against divide-by-zero
    hemi_len = len(magnitude) // 2
    region_scores = {}
    for label, frac_lo, frac_hi, _color in REGIONS:
        lo, hi = int(hemi_len * frac_lo), int(hemi_len * frac_hi)
        region_scores[label] = round(float(np.mean(magnitude[lo:hi]) / peak * 100), 1)
    overall = round(sum(region_scores.values()) / len(region_scores), 1)
    return region_scores, overall
def make_brain_plot(preds):
    """Render left/right hemisphere surface maps of |preds| to a PNG.

    Returns the saved image path, or None when rendering fails (e.g.
    nilearn missing or the fsaverage fetch unavailable) -- callers treat
    None as "no brain image".
    """
    try:
        from nilearn import plotting, datasets
        avg = np.mean(np.abs(preds), axis=0)
        # Min-max normalise to [0, 1]; epsilon guards constant inputs.
        avg_norm = (avg - avg.min()) / (avg.max() - avg.min() + 1e-8)
        # First half of the vertices is plotted on the left hemisphere,
        # second half on the right (mirrors score_predictions' split).
        half = len(avg_norm) // 2
        fsaverage = datasets.fetch_surf_fsaverage("fsaverage5")
        fig, axes = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={"projection": "3d"})
        fig.patch.set_facecolor("#111111")
        plotting.plot_surf_stat_map(fsaverage.infl_left, avg_norm[:half], hemi="left",
            view="lateral", colorbar=True, cmap="hot", title="Left hemisphere", axes=axes[0], figure=fig)
        plotting.plot_surf_stat_map(fsaverage.infl_right, avg_norm[half:], hemi="right",
            view="lateral", colorbar=True, cmap="hot", title="Right hemisphere", axes=axes[1], figure=fig)
        plt.tight_layout()
        plt.savefig("/tmp/brain_map.png", dpi=130, bbox_inches="tight", facecolor="#111111")
        # Close this specific figure (not just the "current" one) so repeated
        # analyses don't accumulate open matplotlib figures.
        plt.close(fig)
        return "/tmp/brain_map.png"
    except Exception as e:
        # Use the module logger (not print) for consistency with the rest of
        # the file; include the traceback to diagnose nilearn/dataset issues.
        logger.warning("Brain plot error: %s", e, exc_info=True)
        return None
def make_score_chart(scores, overall):
    """Draw a horizontal bar chart of per-region scores and save it as a PNG.

    Returns the path of the written image file.
    """
    fig, ax = plt.subplots(figsize=(9, 4))
    fig.patch.set_facecolor("#1a1a1a")
    ax.set_facecolor("#1a1a1a")

    labels = [region[0] for region in REGIONS]
    palette = [region[3] for region in REGIONS]
    values = [scores.get(label, 0) for label in labels]

    bars = ax.barh(labels, values, color=palette, height=0.55)
    ax.set_xlim(0, 100)
    # Dashed guide line at the 70-point threshold.
    ax.axvline(70, color="#888", linestyle="--", linewidth=1, alpha=0.6)
    ax.set_xlabel("Activation score", color="#ccc", fontsize=11)
    ax.set_title(f"Brain region activation | Overall: {overall}/100",
                 color="white", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(colors="#ccc")
    for spine in ax.spines.values():
        spine.set_edgecolor("#333")

    # Numeric label just past the end of each bar.
    for bar, value in zip(bars, values):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f"{value}", va="center", color="white", fontsize=10, fontweight="bold")

    plt.tight_layout()
    plt.savefig("/tmp/score_chart.png", dpi=130, bbox_inches="tight", facecolor="#1a1a1a")
    plt.close()
    return "/tmp/score_chart.png"
def generate_suggestions(scores, overall):
    """Turn region scores into actionable writing tips plus a status line.

    Each region scoring below 70 contributes one tip (a missing region is
    treated as strong and contributes none). Returns a Markdown string.
    """
    # (region key, tip shown when that region scores below 70) -- order
    # matters: tips appear in this fixed order.
    tip_table = [
        ("Prefrontal (attention)", "β†’ Open with a bold question or surprising fact to boost attention"),
        ("Emotion (limbic)", "β†’ Add emotional language β€” 'imagine', 'feel', personal stories"),
        ("Temporal (memory)", "β†’ Include specific numbers or data points to improve memorability"),
        ("Visual cortex", "β†’ Use more visual language β€” describe what viewers will 'see'"),
        ("Language (Broca's area)", "β†’ Break long sentences into shorter, punchier ones"),
        ("Auditory cortex", "β†’ Add rhythm and repetition β€” the brain responds to sound patterns"),
    ]
    tips = [tip for region, tip in tip_table if scores.get(region, 100) < 70]
    if not tips:
        tips = ["β†’ Excellent! Consider adding a strong call-to-action at the end"]

    if overall >= 75:
        status = "🟒 Strong"
    elif overall >= 55:
        status = "🟑 Good, needs polish"
    else:
        status = "πŸ”΄ Needs work"
    return f"**Overall: {overall}/100 β€” {status}**\n\n" + "\n".join(tips)
# ---------------------------------------------------------------------------
# Main analysis function
# ---------------------------------------------------------------------------
def analyze(input_mode, script_text, audio_file, progress=gr.Progress()):
    """Gradio callback: run the full script -> brain-prediction pipeline.

    Parameters
    ----------
    input_mode : str
        "Text" or "Audio" (value of the radio button).
    script_text : str
        Script text, used when input_mode == "Text".
    audio_file : str or None
        Filepath of the uploaded audio, used when input_mode == "Audio".
    progress : gr.Progress
        Progress reporter injected by Gradio.

    Returns
    -------
    tuple
        (brain image path, score chart path, suggestions markdown,
        predictions .npy path) -- elements may be None on failure, with the
        third slot carrying the user-facing error/warning message.
    """
    global _CURRENT_SCRIPT_TEXT
    # Input validation: return a user-facing warning instead of raising.
    if input_mode == "Text" and (not script_text or not script_text.strip()):
        return None, None, "⚠️ Please paste your script text first.", None
    if input_mode == "Audio" and audio_file is None:
        return None, None, "⚠️ Please upload an audio file first.", None
    # Lazy model load on first use (heavy download, hence the progress note).
    if model is None:
        progress(0.1, desc="Loading TRIBE v2 model (first time ~5 mins)...")
        msg = load_model()
        if "Error" in msg:
            return None, None, msg, None
    try:
        if input_mode == "Text":
            progress(0.2, desc="Converting text to speech...")
            from gtts import gTTS
            from langdetect import detect
            text = script_text.strip()
            # Auto-detect TTS language from the script itself.
            lang = detect(text)
            audio_path = "/tmp/script_audio.mp3"
            tts = gTTS(text=text, lang=lang)
            tts.save(audio_path)
            # Store text so the monkey-patched transcriber can use it
            # instead of running ASR on the audio we just synthesised.
            _CURRENT_SCRIPT_TEXT = text
            progress(0.4, desc="Running TRIBE v2 on generated audio...")
            df = model.get_events_dataframe(audio_path=audio_path)
        else:
            import shutil
            progress(0.2, desc="Loading audio file...")
            # Keep the original extension so downstream decoding picks the
            # right format; default to .mp3 when there is none.
            ext = os.path.splitext(audio_file)[1] or ".mp3"
            audio_path = f"/tmp/input_audio{ext}"
            shutil.copy(audio_file, audio_path)
            # No known text for uploaded audio
            _CURRENT_SCRIPT_TEXT = None
            progress(0.4, desc="Running TRIBE v2 on audio...")
            df = model.get_events_dataframe(audio_path=audio_path)
        progress(0.6, desc="Predicting brain response...")
        preds, segments = model.predict(events=df)
        progress(0.75, desc="Scoring regions...")
        scores, overall = score_predictions(preds)
        progress(0.85, desc="Rendering maps...")
        brain_img = make_brain_plot(preds)
        score_img = make_score_chart(scores, overall)
        suggestions = generate_suggestions(scores, overall)
        # Raw predictions are offered as a downloadable artifact.
        np.save("/tmp/brain_predictions.npy", preds)
        progress(1.0, desc="Done!")
        return brain_img, score_img, suggestions, "/tmp/brain_predictions.npy"
    except Exception as e:
        import traceback
        full_error = traceback.format_exc()
        print(full_error)
        # Surface the full traceback in the UI to ease remote debugging.
        return None, None, f"❌ Error:\n{str(e)}\n\nFull traceback:\n{full_error}", None
    finally:
        # Always clear the hand-off global so stale text from one request
        # cannot leak into the next.
        _CURRENT_SCRIPT_TEXT = None
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
css = "#title{text-align:center} #subtitle{text-align:center;color:#888;font-size:14px}"

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
    gr.Markdown("# 🧠 Script Brain Optimizer", elem_id="title")
    gr.Markdown("Analyze your script or audio β†’ real fMRI predictions via **TRIBE v2** β†’ iterate", elem_id="subtitle")
    with gr.Row():
        # Left column: inputs, suggestions text, and download link.
        with gr.Column(scale=1):
            input_mode = gr.Radio(
                choices=["Text", "Audio"], value="Text",
                label="Input type",
                info="Text: paste your script | Audio: upload MP3/WAV"
            )
            script_input = gr.Textbox(
                label="Your script",
                placeholder="Paste your content script here...",
                lines=10, max_lines=20, visible=True
            )
            # Hidden until the user switches the radio to "Audio".
            audio_input = gr.Audio(
                label="Upload audio file (MP3, WAV, M4A, FLAC)",
                type="filepath", sources=["upload"], visible=False
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)
                analyze_btn = gr.Button("🧠 Analyze", variant="primary", scale=3)
            suggestions_out = gr.Markdown(value="*Add your content and click Analyze...*")
            download_out = gr.File(label="Download predictions (.npy)")
        # Right column: rendered result images.
        with gr.Column(scale=2):
            brain_img_out = gr.Image(label="Brain activation map", height=320)
            score_img_out = gr.Image(label="Region scores", height=280)

    def toggle_mode(mode):
        # Show exactly one of the two input widgets depending on the radio.
        return gr.update(visible=mode=="Text"), gr.update(visible=mode=="Audio")

    input_mode.change(fn=toggle_mode, inputs=[input_mode],
        outputs=[script_input, audio_input])
    analyze_btn.click(fn=analyze, inputs=[input_mode, script_input, audio_input],
        outputs=[brain_img_out, score_img_out, suggestions_out, download_out])
    # Clear resets all six components to their initial state (lambda returns
    # one value per entry in `outputs`, in order).
    clear_btn.click(
        fn=lambda: ("", None, None, None, "*Add your content and click Analyze...*", None),
        outputs=[script_input, audio_input, brain_img_out, score_img_out, suggestions_out, download_out]
    )
    gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")

if __name__ == "__main__":
    # 0.0.0.0 exposes the server inside the container/Space network.
    demo.launch(server_name="0.0.0.0", server_port=7860)