import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr
from pathlib import Path
import gradio as gr
import spaces # HF Spaces GPU decorator
from tribev2.demo_utils import TribeModel
from tribev2.plotting import PlotBrain
from nilearn import datasets
from nilearn.surface import vol_to_surf
CACHE_FOLDER = Path("./cache")
CACHE_FOLDER.mkdir(exist_ok=True)
# Load model + plotter once at startup (avoids reloading per request)
print("Loading TRIBE v2 model...")
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE_FOLDER)
plotter = PlotBrain(mesh="fsaverage5", atlas_name="schaefer_2018", atlas_dim=400)
# Precompute atlas surface labels
print("Preparing atlas...")
fsaverage = datasets.fetch_surf_fsaverage('fsaverage5')
atlas = plotter.get_atlas()
left_labels = vol_to_surf(atlas.maps, fsaverage.pial_left,
                          interpolation='nearest_most_frequent', radius=3).astype(int)
right_labels = vol_to_surf(atlas.maps, fsaverage.pial_right,
                           interpolation='nearest_most_frequent', radius=3).astype(int)
surf_labels = np.concatenate([left_labels, right_labels])
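# fsaverage5 has 10,242 vertices per hemisphere, so surf_labels holds 20,484 entries.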
unique_labels = np.unique(surf_labels)
unique_labels = unique_labels[unique_labels > 0]  # drop 0 (background / unlabeled vertices)
labels_text = [l.decode() if isinstance(l, bytes) else str(l) for l in atlas.labels]
def find_network_cols(keyword):
    return [i for i, lbl_id in enumerate(unique_labels)
            if lbl_id < len(labels_text) and keyword.lower() in labels_text[lbl_id].lower()]
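# Schaefer-400 (7-network) label strings embed the Yeo-7 network name
# (e.g. '7Networks_LH_Vis_1'), so a substring match per keyword recovers each network.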
NETWORKS = {
    'Visual': find_network_cols('Vis'),
    'Somatomotor': find_network_cols('SomMot'),
    'DorsAttn': find_network_cols('DorsAttn'),
    'SalVentAttn': find_network_cols('SalVentAttn'),
    'Limbic': find_network_cols('Limbic'),
    'Control': find_network_cols('Cont'),
    'Default': find_network_cols('Default'),
}
NET_COLORS = {
    'Visual': '#ff6b6b', 'Somatomotor': '#4ecdc4', 'DorsAttn': '#ffe66d',
    'SalVentAttn': '#a78bfa', 'Limbic': '#f9a826', 'Control': '#06ffa5',
    'Default': '#3d5a80',
}
print("Ready.")
def aggregate_roi(preds):
    """Average vertex-level predictions into one time course per Schaefer ROI."""
    n_sec = preds.shape[0]
    roi = np.zeros((n_sec, len(unique_labels)))
    for i, lbl in enumerate(unique_labels):
        mask = surf_labels == lbl
        if mask.sum() > 0:
            roi[:, i] = preds[:, mask].mean(axis=1)
    return roi
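# Illustrative shapes (a sketch; assumes fsaverage5 vertices and Schaefer-400 ROIs):
#   preds = np.random.randn(12, surf_labels.size)  # 12 s of vertex-level predictions
#   roi = aggregate_roi(preds)                     # -> (12, len(unique_labels))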
def build_dashboard(preds, roi_activations, save_path='dashboard.png'):
    n_sec = preds.shape[0]
    fig = plt.figure(figsize=(20, 14), facecolor='#0d1117')
    gs = GridSpec(4, 5, figure=fig,
                  height_ratios=[1.2, 1.2, 2, 1],
                  hspace=0.4, wspace=0.2)
    fig.suptitle('Neural Engagement Analysis — TRIBE v2',
                 fontsize=24, color='white', fontweight='bold', y=0.96)
    fig.text(0.5, 0.925,
             f'Predicted fMRI response across {n_sec} seconds · Schaefer-400 parcellation',
             ha='center', color='#8b949e', fontsize=12)
    # Brain snapshots fill the top two rows; the grid has 5 columns, so at most 10 fit.
    brain_cols = 5
    for i in range(min(n_sec, 2 * brain_cols)):
        row = i // brain_cols
        col = i % brain_cols
        ax = fig.add_subplot(gs[row, col])
        try:
            plotter.plot_surf(preds[i], axes=ax, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.3, alpha_cmap=(0, 0.2))
        except Exception:
            ax.text(0.5, 0.5, f't={i}s', ha='center', va='center', color='white')
        ax.set_title(f't = {i}s', color='white', fontsize=10, pad=2)
        ax.axis('off')
        ax.set_facecolor('#0d1117')
    ax_net = fig.add_subplot(gs[2, :])
    ax_net.set_facecolor('#161b22')
    time_axis = np.arange(n_sec)
    for name, cols in NETWORKS.items():
        if not cols:
            continue
        tc = roi_activations[:, cols].mean(axis=1)
        ax_net.plot(time_axis, tc, label=name, color=NET_COLORS.get(name, 'white'),
                    linewidth=2.5, marker='o', markersize=6, alpha=0.9)
    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    peak_visual = int(np.argmax(visual_tc))
    # Heuristic "transition": the largest second-to-second jump in visual activation
    transition = int(np.argmax(np.abs(np.diff(visual_tc))))
    ax_net.axvline(peak_visual, color='#ff6b6b', linestyle='--', alpha=0.3)
    ax_net.axvline(transition + 1, color='#4ecdc4', linestyle='--', alpha=0.3)
    ax_net.axhline(0, color='#30363d', linewidth=0.5)
    ax_net.set_xlabel('Time (seconds)', color='white', fontsize=11)
    ax_net.set_ylabel('Network Activation', color='white', fontsize=11)
    ax_net.set_title('Network-level time courses', color='white', fontsize=13, pad=10)
    ax_net.legend(loc='upper left', ncol=4, facecolor='#161b22',
                  edgecolor='#30363d', labelcolor='white', fontsize=9)
    ax_net.tick_params(colors='white')
    for spine in ax_net.spines.values():
        spine.set_color('#30363d')
    ax_net.grid(True, alpha=0.1, color='white')
    ax_net.set_xticks(time_axis)
    # Metrics card
    ax_m = fig.add_subplot(gs[3, :2])
    ax_m.set_facecolor('#161b22')
    ax_m.axis('off')
    # Somatomotor serves as the auditory proxy here: in the Yeo 7-network scheme
    # used by the Schaefer labels, auditory cortex falls within the SomMot network.
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, _ = pearsonr(visual_tc, sommot_tc)
    # "Hook ratio": mean visual activation over the first 3 s vs. the rest of the clip
    hook_ratio = (visual_tc[:3].mean() / visual_tc[3:].mean()
                  if n_sec > 3 and visual_tc[3:].mean() > 0 else float('nan'))
    metrics = [
        ('Visual Peak', f'{visual_tc.max():.3f}', '#ff6b6b'),
        ('Cross-Modal r', f'{r:+.3f}', '#4ecdc4' if r > 0 else '#ff6b6b'),
        ('Hook Ratio', f'{hook_ratio:.2f}×' if not np.isnan(hook_ratio) else 'N/A', '#ffe66d'),
        ('Temporal Var', f'{preds.var(axis=0).mean():.4f}', '#a78bfa'),
        ('Duration', f'{n_sec}s', '#8b949e'),
    ]
    for i, (label, value, color) in enumerate(metrics):
        y = 0.85 - i * 0.16
        ax_m.text(0.05, y, label, color='#8b949e', fontsize=11, transform=ax_m.transAxes)
        ax_m.text(0.55, y, value, color=color, fontsize=14, fontweight='bold',
                  family='monospace', transform=ax_m.transAxes)
    # Interpretation card
    ax_i = fig.add_subplot(gs[3, 2:])
    ax_i.set_facecolor('#161b22')
    ax_i.axis('off')
    dominant = max(NETWORKS.keys(),
                   key=lambda n: roi_activations[:, NETWORKS[n]].mean() if NETWORKS[n] else -np.inf)
    lines = [
        ('OPENING', 'Visual-dominant hook (t=0–3s)'),
        ('PROFILE', f'{dominant} network leads predicted response'),
        ('TRANSITION', f'Modality handoff at t={transition+1}s'),
        ('COHERENCE', f'Visual↔Auditory r={r:+.2f}'),
        ('NOTE', 'Descriptive only. Not validated against engagement data.'),
    ]
    for i, (label, text) in enumerate(lines):
        y = 0.9 - i * 0.18
        ax_i.text(0.02, y, label, color='#ffa500', fontsize=10,
                  fontweight='bold', family='monospace', transform=ax_i.transAxes)
        ax_i.text(0.2, y, text, color='white', fontsize=11, transform=ax_i.transAxes)
    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
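# Per-second breakdown: one figure row per second, with a brain snapshot,
# per-network activation bars, and text annotations (stimulus words, modality, extremes).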
def build_per_frame(preds, roi_activations, events_df, save_path='per_frame.png'):
    n_sec = preds.shape[0]
    net_tcs = {name: roi_activations[:, cols].mean(axis=1) if cols else np.zeros(n_sec)
               for name, cols in NETWORKS.items()}
    baseline_mean = preds.mean()
    baseline_std = preds.std()

    def words_at(t):
        """Transcript words whose [start, start + duration] interval covers second t."""
        active = events_df[(events_df['start'] <= t) &
                           (events_df['start'] + events_df['duration'] >= t)]
        return active[active['type'] == 'Word']['text'].dropna().tolist()

    fig = plt.figure(figsize=(16, 2.5 * n_sec), facecolor='#0d1117')
    gs = GridSpec(n_sec, 3, figure=fig, width_ratios=[1.2, 1.5, 2.3],
                  hspace=0.3, wspace=0.15)
    fig.suptitle('Per-Second Neural Breakdown', fontsize=22,
                 color='white', fontweight='bold', y=0.995)
    for t in range(n_sec):
        # Column 1: brain snapshot
        ax_b = fig.add_subplot(gs[t, 0])
        try:
            plotter.plot_surf(preds[t], axes=ax_b, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.1, alpha_cmap=(0, 0.2))
        except Exception:
            ax_b.text(0.5, 0.5, f't={t}s', ha='center', va='center', color='white')
        ax_b.set_facecolor('#0d1117')
        ax_b.axis('off')
        ax_b.set_title(f't = {t}s', color='white', fontsize=14, fontweight='bold', pad=4)
        # Column 2: per-network activation bars
        ax_bar = fig.add_subplot(gs[t, 1])
        ax_bar.set_facecolor('#161b22')
        names = list(net_tcs.keys())
        vals = [net_tcs[n][t] for n in names]
        colors = [NET_COLORS[n] for n in names]
        bars = ax_bar.barh(range(len(names)), vals, color=colors, alpha=0.9)
        ax_bar.axvline(0, color='#30363d', linewidth=1)
        ax_bar.set_yticks(range(len(names)))
        ax_bar.set_yticklabels([n[:10] for n in names], color='white', fontsize=9)
        ax_bar.tick_params(colors='white', labelsize=8)
        ax_bar.set_xlim(-0.15, 0.45)  # fixed range so bars are comparable across seconds
        for spine in ax_bar.spines.values():
            spine.set_color('#30363d')
        ax_bar.grid(True, axis='x', alpha=0.1)
        for bar, v in zip(bars, vals):
            x = v + (0.01 if v >= 0 else -0.01)
            ax_bar.text(x, bar.get_y() + bar.get_height() / 2, f'{v:+.2f}',
                        va='center', ha='left' if v >= 0 else 'right',
                        color='white', fontsize=8)
        # Column 3: text annotations
        ax_t = fig.add_subplot(gs[t, 2])
        ax_t.set_facecolor('#161b22')
        ax_t.axis('off')
        net_vals = {n: net_tcs[n][t] for n in net_tcs}
        dominant = max(net_vals, key=net_vals.get)
        weakest = min(net_vals, key=net_vals.get)
        frame_mean = preds[t].mean()
        # z-score of this second's mean activation relative to the whole clip
        z = (frame_mean - baseline_mean) / baseline_std if baseline_std > 0 else 0
        words = words_at(t)
        stimulus = ' '.join(words) if words else '(silent / visual only)'
        # Rough modality heuristic: hand-tuned thresholds on visual vs. somatomotor
        # (auditory-proxy) activation, not calibrated values
        v_, s_ = net_tcs['Visual'][t], net_tcs['Somatomotor'][t]
        if v_ > 0.15 and s_ < 0.02:
            modality, mod_color = 'visual-dominant', '#ff6b6b'
        elif s_ > 0.05 and v_ < 0.15:
            modality, mod_color = 'auditory-dominant', '#4ecdc4'
        elif v_ > 0.1 and s_ > 0.05:
            modality, mod_color = 'multimodal', '#ffe66d'
        else:
            modality, mod_color = 'low activation', '#8b949e'
        y = 0.92
        ax_t.text(0.02, y, 'STIMULUS', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, stimulus, color='white', fontsize=11,
                  style='italic', transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'MODALITY', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, modality, color=mod_color, fontsize=11,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.55, y, f'z = {z:+.2f}', color='#8b949e', fontsize=10,
                  transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'DOMINANT', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{dominant} ({net_vals[dominant]:+.3f})',
                  color=NET_COLORS[dominant], fontsize=11, transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'WEAKEST', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{weakest} ({net_vals[weakest]:+.3f})',
                  color=NET_COLORS[weakest], fontsize=11, transform=ax_t.transAxes)
    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
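# End-to-end pipeline: video file -> event dataframe -> predicted fMRI response
# -> dashboard, per-second breakdown, and brain animation.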
@spaces.GPU(duration=300)  # 5 min GPU allocation per request
def analyze_video(video_file, progress=gr.Progress()):
    if video_file is None:
        return None, None, None, "Upload a video first."
    progress(0.05, desc="Reading video...")
    video_path = Path(video_file)
    progress(0.1, desc="Extracting events (audio + speech transcription)...")
    events_df = model.get_events_dataframe(video_path=video_path)
    progress(0.4, desc="Predicting brain response...")
    preds, segments = model.predict(events=events_df)
    progress(0.65, desc="Aggregating network activity...")
    roi_activations = aggregate_roi(preds)
    progress(0.75, desc="Building dashboard...")
    dash = build_dashboard(preds, roi_activations, 'dashboard.png')
    progress(0.85, desc="Building per-frame breakdown...")
    breakdown = build_per_frame(preds, roi_activations, events_df, 'per_frame.png')
    progress(0.92, desc="Rendering brain animation...")
    mp4_path = 'brain_activity.mp4'
    plotter.plot_timesteps_mp4(preds, mp4_path, segments=segments)
    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, p = pearsonr(visual_tc, sommot_tc)
    summary = f"""Duration: {preds.shape[0]}s
Visual peak: {visual_tc.max():.3f} at t={int(visual_tc.argmax())}s
Auditory peak (SomMot proxy): {sommot_tc.max():.3f} at t={int(sommot_tc.argmax())}s
Cross-modal coherence: r = {r:+.3f} (p = {p:.3f})
Structure: {'visual-first with auditory handoff' if visual_tc.argmax() < sommot_tc.argmax() else 'auditory-first with visual support'}
Note: Metrics are descriptive. Engagement prediction requires validation against real performance data."""
    return dash, breakdown, mp4_path, summary
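# Gradio UI: upload + summary on top, results in three tabs below.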
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange"), title="Humeo Neural Analyzer") as app:
    gr.Markdown("""
    # 🧠 Humeo Neural Content Analyzer
    Upload a video. Get predicted fMRI brain response, network-level time courses, and a per-second neural breakdown.
    *Powered by Meta TRIBE v2 · Analysis takes ~2–3 minutes per video.*
    """)
    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video(label="Upload video (MP4, short-form recommended)")
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_out = gr.Textbox(label="Summary", lines=8, show_copy_button=True)
    with gr.Tab("Dashboard"):
        dash_out = gr.Image(label="Overview")
    with gr.Tab("Per-Frame Breakdown"):
        frame_out = gr.Image(label="Second-by-second analysis")
    with gr.Tab("Brain Animation"):
        mp4_out = gr.Video(label="Brain activity video")
    btn.click(analyze_video, inputs=video_in,
              outputs=[dash_out, frame_out, mp4_out, summary_out])
    gr.Markdown("---\n*Prototype for Humeo R&D. Not a validated engagement prediction tool.*")

if __name__ == "__main__":
    app.launch()
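# To run locally (assumes the tribev2 package and model weights are reachable):
#   python app.py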