"""Humeo Neural Content Analyzer: Gradio app for Hugging Face Spaces.

Uploads a video, runs Meta's TRIBE v2 fMRI-encoding model on the extracted
audio/speech events, and renders a dashboard, a per-second breakdown, and a
brain-activity animation. Prototype only; not a validated engagement tool.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr
from pathlib import Path

import gradio as gr
import spaces  # HF Spaces GPU decorator

from tribev2.demo_utils import TribeModel
from tribev2.plotting import PlotBrain
from nilearn import datasets
from nilearn.surface import vol_to_surf

CACHE_FOLDER = Path("./cache")
CACHE_FOLDER.mkdir(exist_ok=True)

# Load model + plotter once at startup (avoids reloading per request).
print("Loading TRIBE v2 model...")
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE_FOLDER)
plotter = PlotBrain(mesh="fsaverage5", atlas_name="schaefer_2018", atlas_dim=400)

# Precompute atlas surface labels: project the volumetric Schaefer-400 atlas
# onto the fsaverage5 surface so per-vertex predictions can be pooled per parcel.
print("Preparing atlas...")
fsaverage = datasets.fetch_surf_fsaverage('fsaverage5')
atlas = plotter.get_atlas()
left_labels = vol_to_surf(atlas.maps, fsaverage.pial_left,
                          interpolation='nearest_most_frequent', radius=3).astype(int)
right_labels = vol_to_surf(atlas.maps, fsaverage.pial_right,
                           interpolation='nearest_most_frequent', radius=3).astype(int)
surf_labels = np.concatenate([left_labels, right_labels])
unique_labels = np.unique(surf_labels)
unique_labels = unique_labels[unique_labels > 0]  # drop background (label 0)
labels_text = [l.decode() if isinstance(l, bytes) else str(l) for l in atlas.labels]


def find_network_cols(keyword):
    """Return ROI-matrix column indices whose parcel name contains `keyword`.

    Assumes `atlas.labels` is indexable by the integer label IDs found in the
    volume (i.e., includes a background entry at index 0)."""
    return [i for i, lbl_id in enumerate(unique_labels)
            if lbl_id < len(labels_text) and keyword.lower() in labels_text[lbl_id].lower()]


# Yeo 7-network groupings of the Schaefer-400 parcels, keyed by the
# abbreviations used in the Schaefer label names.
NETWORKS = {
    'Visual': find_network_cols('Vis'),
    'Somatomotor': find_network_cols('SomMot'),
    'DorsAttn': find_network_cols('DorsAttn'),
    'SalVentAttn': find_network_cols('SalVentAttn'),
    'Limbic': find_network_cols('Limbic'),
    'Control': find_network_cols('Cont'),
    'Default': find_network_cols('Default'),
}
NET_COLORS = {
    'Visual': '#ff6b6b',
    'Somatomotor': '#4ecdc4',
    'DorsAttn': '#ffe66d',
    'SalVentAttn': '#a78bfa',
    'Limbic': '#f9a826',
    'Control': '#06ffa5',
    'Default': '#3d5a80',
}
print("Ready.")


def aggregate_roi(preds):
    """Average per-vertex predictions within each atlas parcel."""
    n_sec = preds.shape[0]
    roi = np.zeros((n_sec, len(unique_labels)))
    for i, lbl in enumerate(unique_labels):
        mask = surf_labels == lbl
        if mask.sum() > 0:
            roi[:, i] = preds[:, mask].mean(axis=1)
    return roi
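# Shape sketch (assumed interface): `preds` is (n_seconds, n_vertices) on
# fsaverage5 (10242 vertices per hemisphere, 20484 total), aligned with
# `surf_labels` above; the result is (n_seconds, n_parcels), one column per
# surviving Schaefer parcel, e.g.:
#   aggregate_roi(np.zeros((10, surf_labels.shape[0]))).shape
#   # -> (10, len(unique_labels))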
def build_dashboard(preds, roi_activations, save_path='dashboard.png'):
    n_sec = preds.shape[0]
    fig = plt.figure(figsize=(20, 14), facecolor='#0d1117')
    gs = GridSpec(4, 5, figure=fig, height_ratios=[1.2, 1.2, 2, 1],
                  hspace=0.4, wspace=0.2)
    fig.suptitle('Neural Engagement Analysis — TRIBE v2', fontsize=24,
                 color='white', fontweight='bold', y=0.96)
    fig.text(0.5, 0.925,
             f'Predicted fMRI response across {n_sec} seconds · Schaefer-400 parcellation',
             ha='center', color='#8b949e', fontsize=12)

    # Brain snapshots: up to 10 timepoints across the top two rows
    # (one subplot per grid column, 5 per row).
    brain_cols = 5
    for i in range(min(n_sec, 10)):
        row = i // brain_cols
        col = i % brain_cols
        ax = fig.add_subplot(gs[row, col])
        try:
            plotter.plot_surf(preds[i], axes=ax, views='left', cmap='fire',
                              norm_percentile=99, vmin=0.3, alpha_cmap=(0, 0.2))
        except Exception:
            ax.text(0.5, 0.5, f't={i}s', ha='center', va='center', color='white')
        ax.set_title(f't = {i}s', color='white', fontsize=10, pad=2)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    # Network-level time courses.
    ax_net = fig.add_subplot(gs[2, :])
    ax_net.set_facecolor('#161b22')
    time_axis = np.arange(n_sec)
    for name, cols in NETWORKS.items():
        if not cols:
            continue
        tc = roi_activations[:, cols].mean(axis=1)
        ax_net.plot(time_axis, tc, label=name, color=NET_COLORS.get(name, 'white'),
                    linewidth=2.5, marker='o', markersize=6, alpha=0.9)
    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    peak_visual = int(np.argmax(visual_tc))
    # Largest second-to-second change in the visual time course; diff index i
    # marks the change between t=i and t=i+1.
    transition = int(np.argmax(np.abs(np.diff(visual_tc))))
    ax_net.axvline(peak_visual, color='#ff6b6b', linestyle='--', alpha=0.3)
    ax_net.axvline(transition + 1, color='#4ecdc4', linestyle='--', alpha=0.3)
    ax_net.axhline(0, color='#30363d', linewidth=0.5)
    ax_net.set_xlabel('Time (seconds)', color='white', fontsize=11)
    ax_net.set_ylabel('Network Activation', color='white', fontsize=11)
    ax_net.set_title('Network-level time courses', color='white', fontsize=13, pad=10)
    ax_net.legend(loc='upper left', ncol=4, facecolor='#161b22',
                  edgecolor='#30363d', labelcolor='white', fontsize=9)
    ax_net.tick_params(colors='white')
    for spine in ax_net.spines.values():
        spine.set_color('#30363d')
    ax_net.grid(True, alpha=0.1, color='white')
    ax_net.set_xticks(time_axis)

    # Metrics card.
    ax_m = fig.add_subplot(gs[3, :2])
    ax_m.set_facecolor('#161b22')
    ax_m.axis('off')
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, _ = pearsonr(visual_tc, sommot_tc)
    # Ratio of mean visual activation in the first 3 s vs. the remainder.
    hook_ratio = (visual_tc[:3].mean() / visual_tc[3:].mean()
                  if n_sec > 3 and visual_tc[3:].mean() > 0 else float('nan'))
    metrics = [
        ('Visual Peak', f'{visual_tc.max():.3f}', '#ff6b6b'),
        ('Cross-Modal r', f'{r:+.3f}', '#4ecdc4' if r > 0 else '#ff6b6b'),
        ('Hook Ratio', f'{hook_ratio:.2f}×' if not np.isnan(hook_ratio) else 'N/A', '#ffe66d'),
        ('Temporal Var', f'{preds.var(axis=0).mean():.4f}', '#a78bfa'),
        ('Duration', f'{n_sec}s', '#8b949e'),
    ]
    for i, (label, value, color) in enumerate(metrics):
        y = 0.85 - i * 0.16
        ax_m.text(0.05, y, label, color='#8b949e', fontsize=11, transform=ax_m.transAxes)
        ax_m.text(0.55, y, value, color=color, fontsize=14, fontweight='bold',
                  family='monospace', transform=ax_m.transAxes)

    # Interpretation card.
    ax_i = fig.add_subplot(gs[3, 2:])
    ax_i.set_facecolor('#161b22')
    ax_i.axis('off')
    dominant = max(NETWORKS.keys(),
                   key=lambda n: roi_activations[:, NETWORKS[n]].mean()
                   if NETWORKS[n] else -np.inf)
    lines = [
        ('OPENING', 'Visual-dominant hook (t=0–3s)'),
        ('PROFILE', f'{dominant} network leads predicted response'),
        ('TRANSITION', f'Modality handoff at t={transition + 1}s'),
        ('COHERENCE', f'Visual↔Auditory r={r:+.2f}'),
        ('NOTE', 'Descriptive only. Not validated against engagement data.'),
    ]
    for i, (label, text) in enumerate(lines):
        y = 0.9 - i * 0.18
        ax_i.text(0.02, y, label, color='#ffa500', fontsize=10, fontweight='bold',
                  family='monospace', transform=ax_i.transAxes)
        ax_i.text(0.2, y, text, color='white', fontsize=11, transform=ax_i.transAxes)

    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
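# Modality note (interpretive assumption used in both figures and in
# `analyze_video` below): the Yeo 7-network scheme underlying Schaefer-400
# folds auditory cortex into the Somatomotor network, so SomMot activation is
# treated as a coarse proxy for auditory drive. It also reflects motor and
# tactile content, so the "auditory" readouts here are approximate.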
def build_per_frame(preds, roi_activations, events_df, save_path='per_frame.png'):
    n_sec = preds.shape[0]
    net_tcs = {name: roi_activations[:, cols].mean(axis=1) if cols else np.zeros(n_sec)
               for name, cols in NETWORKS.items()}
    baseline_mean = preds.mean()
    baseline_std = preds.std()

    def words_at(t):
        """Words whose [start, start + duration] interval covers second t."""
        active = events_df[(events_df['start'] <= t) &
                           (events_df['start'] + events_df['duration'] >= t)]
        return active[active['type'] == 'Word']['text'].dropna().tolist()

    # One row per second; figure height grows with video length.
    fig = plt.figure(figsize=(16, 2.5 * n_sec), facecolor='#0d1117')
    gs = GridSpec(n_sec, 3, figure=fig, width_ratios=[1.2, 1.5, 2.3],
                  hspace=0.3, wspace=0.15)
    fig.suptitle('Per-Second Neural Breakdown', fontsize=22, color='white',
                 fontweight='bold', y=0.995)

    for t in range(n_sec):
        # Column 1: brain surface snapshot.
        ax_b = fig.add_subplot(gs[t, 0])
        try:
            plotter.plot_surf(preds[t], axes=ax_b, views='left', cmap='fire',
                              norm_percentile=99, vmin=0.1, alpha_cmap=(0, 0.2))
        except Exception:
            ax_b.text(0.5, 0.5, f't={t}s', ha='center', va='center', color='white')
        ax_b.set_facecolor('#0d1117')
        ax_b.axis('off')
        ax_b.set_title(f't = {t}s', color='white', fontsize=14, fontweight='bold', pad=4)

        # Column 2: per-network activation bars.
        ax_bar = fig.add_subplot(gs[t, 1])
        ax_bar.set_facecolor('#161b22')
        names = list(net_tcs.keys())
        vals = [net_tcs[n][t] for n in names]
        colors = [NET_COLORS[n] for n in names]
        bars = ax_bar.barh(range(len(names)), vals, color=colors, alpha=0.9)
        ax_bar.axvline(0, color='#30363d', linewidth=1)
        ax_bar.set_yticks(range(len(names)))
        ax_bar.set_yticklabels([n[:10] for n in names], color='white', fontsize=9)
        ax_bar.tick_params(colors='white', labelsize=8)
        ax_bar.set_xlim(-0.15, 0.45)
        for spine in ax_bar.spines.values():
            spine.set_color('#30363d')
        ax_bar.grid(True, axis='x', alpha=0.1)
        for bar, v in zip(bars, vals):
            x = v + (0.01 if v >= 0 else -0.01)
            ax_bar.text(x, bar.get_y() + bar.get_height() / 2, f'{v:+.2f}',
                        va='center', ha='left' if v >= 0 else 'right',
                        color='white', fontsize=8)

        # Column 3: text card (stimulus, modality call, dominant/weakest nets).
        ax_t = fig.add_subplot(gs[t, 2])
        ax_t.set_facecolor('#161b22')
        ax_t.axis('off')
        net_vals = {n: net_tcs[n][t] for n in net_tcs}
        dominant = max(net_vals, key=net_vals.get)
        weakest = min(net_vals, key=net_vals.get)
        frame_mean = preds[t].mean()
        z = (frame_mean - baseline_mean) / baseline_std if baseline_std > 0 else 0
        words = words_at(t)
        stimulus = ' '.join(words) if words else '(silent / visual only)'
        # Heuristic modality call from Visual vs. Somatomotor levels
        # (thresholds are hand-tuned, not calibrated).
        v_, s_ = net_tcs['Visual'][t], net_tcs['Somatomotor'][t]
        if v_ > 0.15 and s_ < 0.02:
            modality, mod_color = 'visual-dominant', '#ff6b6b'
        elif s_ > 0.05 and v_ < 0.15:
            modality, mod_color = 'auditory-dominant', '#4ecdc4'
        elif v_ > 0.1 and s_ > 0.05:
            modality, mod_color = 'multimodal', '#ffe66d'
        else:
            modality, mod_color = 'low activation', '#8b949e'
        y = 0.92
        ax_t.text(0.02, y, 'STIMULUS', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, stimulus, color='white', fontsize=11, style='italic',
                  transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'MODALITY', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, modality, color=mod_color, fontsize=11,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.55, y, f'z = {z:+.2f}', color='#8b949e', fontsize=10,
                  transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'DOMINANT', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{dominant} ({net_vals[dominant]:+.3f})',
                  color=NET_COLORS[dominant], fontsize=11, transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'WEAKEST', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{weakest} ({net_vals[weakest]:+.3f})',
                  color=NET_COLORS[weakest], fontsize=11, transform=ax_t.transAxes)

    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
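# Optional local check (hedged sketch, not wired into the app): exercises the
# two plotting functions above with synthetic predictions so layout changes can
# be verified without the TRIBE model, transcription, or a GPU. `_smoke_test`,
# its fake data, and the event columns ('start', 'duration', 'type', 'text')
# are illustrative assumptions mirroring how `build_per_frame` reads events.
def _smoke_test(n_sec=8):
    import pandas as pd
    rng = np.random.default_rng(0)
    fake_preds = rng.random((n_sec, surf_labels.shape[0])) * 0.3  # (sec, vertices)
    fake_events = pd.DataFrame({'start': [0, 2], 'duration': [2, 3],
                                'type': ['Word', 'Word'],
                                'text': ['hello', 'world']})
    roi = aggregate_roi(fake_preds)
    build_dashboard(fake_preds, roi, 'test_dashboard.png')
    build_per_frame(fake_preds, roi, fake_events, 'test_per_frame.png')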
@spaces.GPU(duration=300)  # 5 min GPU allocation per request
def analyze_video(video_file, progress=gr.Progress()):
    if video_file is None:
        return None, None, None, "Upload a video first."
    progress(0.05, desc="Reading video...")
    video_path = Path(video_file)

    progress(0.1, desc="Extracting events (audio + speech transcription)...")
    events_df = model.get_events_dataframe(video_path=video_path)

    progress(0.4, desc="Predicting brain response...")
    preds, segments = model.predict(events=events_df)

    progress(0.65, desc="Aggregating network activity...")
    roi_activations = aggregate_roi(preds)

    progress(0.75, desc="Building dashboard...")
    dash = build_dashboard(preds, roi_activations, 'dashboard.png')

    progress(0.85, desc="Building per-frame breakdown...")
    breakdown = build_per_frame(preds, roi_activations, events_df, 'per_frame.png')

    progress(0.92, desc="Rendering brain animation...")
    mp4_path = 'brain_activity.mp4'
    plotter.plot_timesteps_mp4(preds, mp4_path, segments=segments)

    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, p = pearsonr(visual_tc, sommot_tc)
    structure = ('visual-first with auditory handoff'
                 if visual_tc.argmax() < sommot_tc.argmax()
                 else 'auditory-first with visual support')
    summary = f"""Duration: {preds.shape[0]}s
Visual peak: {visual_tc.max():.3f} at t={int(visual_tc.argmax())}s
Auditory peak: {sommot_tc.max():.3f} at t={int(sommot_tc.argmax())}s
Cross-modal coherence: r = {r:+.3f} (p = {p:.3f})
Structure: {structure}

Note: Metrics are descriptive. Engagement prediction requires validation against real performance data."""
    return dash, breakdown, mp4_path, summary


with gr.Blocks(theme=gr.themes.Base(primary_hue="orange"),
               title="Humeo Neural Analyzer") as app:
    gr.Markdown("""
# 🧠 Humeo Neural Content Analyzer
Upload a video. Get a predicted fMRI brain response, network-level time courses,
and a per-second neural breakdown.
*Powered by Meta TRIBE v2 · Analysis takes ~2–3 minutes per video.*
""")
    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video(label="Upload video (MP4, short-form recommended)")
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_out = gr.Textbox(label="Summary", lines=8, show_copy_button=True)
    with gr.Tab("Dashboard"):
        dash_out = gr.Image(label="Overview")
    with gr.Tab("Per-Frame Breakdown"):
        frame_out = gr.Image(label="Second-by-second analysis")
    with gr.Tab("Brain Animation"):
        mp4_out = gr.Video(label="Brain activity video")
    btn.click(analyze_video, inputs=video_in,
              outputs=[dash_out, frame_out, mp4_out, summary_out])
    gr.Markdown("---\n*Prototype for Humeo R&D. Not a validated engagement prediction tool.*")

if __name__ == "__main__":
    app.launch()
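# Deployment note (assumptions): on Hugging Face ZeroGPU Spaces this file is
# expected to be `app.py`, with tribev2, nilearn, gradio, and spaces listed in
# requirements.txt. The `spaces` package is designed so that `@spaces.GPU` has
# no effect outside a Space, so the app should also run locally via
# `python app.py`, using whatever hardware is available.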