Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib.gridspec import GridSpec | |
| from scipy.stats import pearsonr | |
| from pathlib import Path | |
| import gradio as gr | |
| import spaces # HF Spaces GPU decorator | |
| from tribev2.demo_utils import TribeModel | |
| from tribev2.plotting import PlotBrain | |
| from nilearn import datasets | |
| from nilearn.surface import vol_to_surf | |
# Local cache directory for model weights downloaded at startup.
CACHE_FOLDER = Path("./cache")
CACHE_FOLDER.mkdir(exist_ok=True)
# Load model + plotter once at startup (avoids reloading per request)
print("Loading TRIBE v2 model...")
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE_FOLDER)
plotter = PlotBrain(mesh="fsaverage5", atlas_name="schaefer_2018", atlas_dim=400)
# Precompute atlas surface labels: project the volumetric Schaefer-400
# parcellation onto each hemisphere's fsaverage5 pial surface so per-vertex
# predictions can later be averaged into ROI time courses.
print("Preparing atlas...")
fsaverage = datasets.fetch_surf_fsaverage('fsaverage5')
atlas = plotter.get_atlas()
left_labels = vol_to_surf(atlas.maps, fsaverage.pial_left,
                          interpolation='nearest_most_frequent', radius=3).astype(int)
right_labels = vol_to_surf(atlas.maps, fsaverage.pial_right,
                           interpolation='nearest_most_frequent', radius=3).astype(int)
# Concatenated left+right per-vertex label ids; 0 is background and is
# excluded from the ROI label list.
surf_labels = np.concatenate([left_labels, right_labels])
unique_labels = np.unique(surf_labels)
unique_labels = unique_labels[unique_labels > 0]
# Atlas label names may arrive as bytes; normalize everything to str.
labels_text = [l.decode() if isinstance(l, bytes) else str(l) for l in atlas.labels]
def find_network_cols(keyword, labels=None, texts=None):
    """Return ROI-column indices whose atlas label text contains *keyword*.

    Args:
        keyword: Case-insensitive substring to match (e.g. ``'Vis'``).
        labels: Sequence of positive ROI label ids, one per ROI column.
            Defaults to the module-level ``unique_labels``.
        texts: Label-name strings indexed by label id. Defaults to the
            module-level ``labels_text``.

    Returns:
        List of column indices (positions within *labels*) whose label
        name contains *keyword*.
    """
    if labels is None:
        labels = unique_labels
    if texts is None:
        texts = labels_text
    needle = keyword.lower()
    # NOTE(review): texts[lbl_id] assumes index 0 of the name list is the
    # background entry. If atlas.labels excludes background, ids are off by
    # one (and the `lbl_id < len(texts)` guard drops the last ROI) —
    # verify against the atlas object.
    return [i for i, lbl_id in enumerate(labels)
            if lbl_id < len(texts) and needle in texts[lbl_id].lower()]
# Yeo-7 network groupings: each name maps to the ROI-column indices whose
# Schaefer label text contains the corresponding keyword.
NETWORKS = {
    'Visual': find_network_cols('Vis'),
    'Somatomotor': find_network_cols('SomMot'),
    'DorsAttn': find_network_cols('DorsAttn'),
    'SalVentAttn': find_network_cols('SalVentAttn'),
    'Limbic': find_network_cols('Limbic'),
    'Control': find_network_cols('Cont'),
    'Default': find_network_cols('Default'),
}
# Fixed hex colors per network, shared by every plot for visual consistency.
NET_COLORS = {
    'Visual': '#ff6b6b', 'Somatomotor': '#4ecdc4', 'DorsAttn': '#ffe66d',
    'SalVentAttn': '#a78bfa', 'Limbic': '#f9a826', 'Control': '#06ffa5',
    'Default': '#3d5a80',
}
print("Ready.")
def aggregate_roi(preds, labels=None, vertex_labels=None):
    """Average per-vertex predictions into one time course per ROI.

    Args:
        preds: Array of shape (n_seconds, n_vertices) of per-vertex values.
        labels: 1-D sequence of ROI label ids, one per output column.
            Defaults to the module-level ``unique_labels``.
        vertex_labels: 1-D array of length n_vertices assigning each vertex
            a label id. Defaults to the module-level ``surf_labels``.

    Returns:
        Array of shape (n_seconds, len(labels)). Columns for labels with no
        assigned vertices remain zero.
    """
    if labels is None:
        labels = unique_labels
    if vertex_labels is None:
        vertex_labels = surf_labels
    n_sec = preds.shape[0]
    roi = np.zeros((n_sec, len(labels)))
    for i, lbl in enumerate(labels):
        mask = vertex_labels == lbl
        if mask.sum() > 0:  # skip empty ROIs to avoid mean-of-empty warnings
            roi[:, i] = preds[:, mask].mean(axis=1)
    return roi
def build_dashboard(preds, roi_activations, save_path='dashboard.png'):
    """Render the single-image overview dashboard and save it as a PNG.

    Layout (4-row x 6-col GridSpec): rows 0-1 hold up to 12 per-second brain
    snapshots (6 per row), row 2 the network-level time courses, row 3 a
    metrics card (left) and an interpretation card (right).

    Args:
        preds: (n_seconds, n_vertices) predicted surface activations.
        roi_activations: (n_seconds, n_rois) ROI-averaged activations, with
            columns ordered like the module-level ``unique_labels``.
        save_path: Output PNG path.

    Returns:
        ``save_path``, after the figure has been written and closed.
    """
    n_sec = preds.shape[0]
    fig = plt.figure(figsize=(20, 14), facecolor='#0d1117')
    # 6 columns so the 6-per-row brain grid maps one-to-one onto GridSpec
    # slots. (With a 5-column grid, every 6th brain fell through to
    # gs[row, 4] and overplotted the previous subplot.)
    gs = GridSpec(4, 6, figure=fig,
                  height_ratios=[1.2, 1.2, 2, 1],
                  hspace=0.4, wspace=0.2)
    fig.suptitle('Neural Engagement Analysis — TRIBE v2',
                 fontsize=24, color='white', fontweight='bold', y=0.96)
    fig.text(0.5, 0.925,
             f'Predicted fMRI response across {n_sec} seconds · Schaefer-400 parcellation',
             ha='center', color='#8b949e', fontsize=12)

    # --- Rows 0-1: per-second brain snapshots (capped at 12) -------------
    brain_cols = 6
    for i in range(min(n_sec, 12)):
        row = i // brain_cols
        col = i % brain_cols
        ax = fig.add_subplot(gs[row, col])
        try:
            plotter.plot_surf(preds[i], axes=ax, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.3, alpha_cmap=(0, 0.2))
        except Exception:
            # Surface rendering can fail for degenerate frames; keep the
            # dashboard alive with a text placeholder.
            ax.text(0.5, 0.5, f't={i}s', ha='center', va='center', color='white')
        ax.set_title(f't = {i}s', color='white', fontsize=10, pad=2)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    # --- Row 2: network-level time courses -------------------------------
    ax_net = fig.add_subplot(gs[2, :])
    ax_net.set_facecolor('#161b22')
    time_axis = np.arange(n_sec)
    for name, cols in NETWORKS.items():
        if not cols:
            continue  # keyword matched no ROI; nothing to plot
        tc = roi_activations[:, cols].mean(axis=1)
        ax_net.plot(time_axis, tc, label=name, color=NET_COLORS.get(name, 'white'),
                    linewidth=2.5, marker='o', markersize=6, alpha=0.9)
    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    peak_visual = int(np.argmax(visual_tc))
    # Largest second-to-second jump in the visual time course, read as the
    # modality "transition" point (heuristic, descriptive only).
    transition = int(np.argmax(np.abs(np.diff(visual_tc))))
    ax_net.axvline(peak_visual, color='#ff6b6b', linestyle='--', alpha=0.3)
    ax_net.axvline(transition + 1, color='#4ecdc4', linestyle='--', alpha=0.3)
    ax_net.axhline(0, color='#30363d', linewidth=0.5)
    ax_net.set_xlabel('Time (seconds)', color='white', fontsize=11)
    ax_net.set_ylabel('Network Activation', color='white', fontsize=11)
    ax_net.set_title('Network-level time courses', color='white', fontsize=13, pad=10)
    ax_net.legend(loc='upper left', ncol=4, facecolor='#161b22',
                  edgecolor='#30363d', labelcolor='white', fontsize=9)
    ax_net.tick_params(colors='white')
    for spine in ax_net.spines.values():
        spine.set_color('#30363d')
    ax_net.grid(True, alpha=0.1, color='white')
    ax_net.set_xticks(time_axis)

    # --- Row 3, left: metrics card ---------------------------------------
    ax_m = fig.add_subplot(gs[3, :2])
    ax_m.set_facecolor('#161b22')
    ax_m.axis('off')
    # NOTE(review): Somatomotor is used as the auditory proxy here and in
    # the summary text — confirm that mapping is intended.
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, _ = pearsonr(visual_tc, sommot_tc)
    # Early-vs-late visual activation ("hook"); NaN when the clip is too
    # short or the late-window mean is non-positive.
    hook_ratio = visual_tc[:3].mean() / visual_tc[3:].mean() if n_sec > 3 and visual_tc[3:].mean() > 0 else float('nan')
    metrics = [
        ('Visual Peak', f'{visual_tc.max():.3f}', '#ff6b6b'),
        ('Cross-Modal r', f'{r:+.3f}', '#4ecdc4' if r > 0 else '#ff6b6b'),
        ('Hook Ratio', f'{hook_ratio:.2f}×' if not np.isnan(hook_ratio) else 'N/A', '#ffe66d'),
        ('Temporal Var', f'{preds.var(axis=0).mean():.4f}', '#a78bfa'),
        ('Duration', f'{n_sec}s', '#8b949e'),
    ]
    for i, (label, value, color) in enumerate(metrics):
        y = 0.85 - i * 0.16
        ax_m.text(0.05, y, label, color='#8b949e', fontsize=11, transform=ax_m.transAxes)
        ax_m.text(0.55, y, value, color=color, fontsize=14, fontweight='bold',
                  family='monospace', transform=ax_m.transAxes)

    # --- Row 3, right: interpretation card -------------------------------
    ax_i = fig.add_subplot(gs[3, 2:])
    ax_i.set_facecolor('#161b22')
    ax_i.axis('off')
    dominant = max(NETWORKS.keys(),
                   key=lambda n: roi_activations[:, NETWORKS[n]].mean() if NETWORKS[n] else -np.inf)
    lines = [
        ('OPENING', f'Visual-dominant hook (t=0–3s)'),
        ('PROFILE', f'{dominant} network leads predicted response'),
        ('TRANSITION', f'Modality handoff at t={transition+1}s'),
        ('COHERENCE', f'Visual↔Auditory r={r:+.2f}'),
        ('NOTE', 'Descriptive only. Not validated against engagement data.'),
    ]
    for i, (label, text) in enumerate(lines):
        y = 0.9 - i * 0.18
        ax_i.text(0.02, y, label, color='#ffa500', fontsize=10,
                  fontweight='bold', family='monospace', transform=ax_i.transAxes)
        ax_i.text(0.2, y, text, color='white', fontsize=11, transform=ax_i.transAxes)
    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
def build_per_frame(preds, roi_activations, events_df, save_path='per_frame.png'):
    """Render a tall per-second breakdown figure and save it as a PNG.

    For each second: a brain snapshot (col 0), a horizontal network-activation
    bar chart (col 1), and a text card with the stimulus words, a heuristic
    modality label, and the dominant/weakest networks (col 2).

    Args:
        preds: (n_seconds, n_vertices) predicted surface activations.
        roi_activations: (n_seconds, n_rois) ROI-averaged activations.
        events_df: DataFrame with at least 'start', 'duration', 'type' and
            'text' columns (as produced by model.get_events_dataframe).
        save_path: Output PNG path.

    Returns:
        ``save_path``, after the figure has been written and closed.
    """
    n_sec = preds.shape[0]
    # Per-network time courses; networks with no matched ROIs get zeros.
    net_tcs = {name: roi_activations[:, cols].mean(axis=1) if cols else np.zeros(n_sec)
               for name, cols in NETWORKS.items()}
    # Whole-recording baseline used for the per-frame z-score.
    baseline_mean = preds.mean()
    baseline_std = preds.std()
    def words_at(t):
        # Return the 'Word' event texts whose [start, start+duration]
        # interval covers second t.
        active = events_df[(events_df['start'] <= t) & (events_df['start'] + events_df['duration'] >= t)]
        return active[active['type'] == 'Word']['text'].dropna().tolist()
    # One row per second; figure height scales with clip length.
    fig = plt.figure(figsize=(16, 2.5 * n_sec), facecolor='#0d1117')
    gs = GridSpec(n_sec, 3, figure=fig, width_ratios=[1.2, 1.5, 2.3],
                  hspace=0.3, wspace=0.15)
    fig.suptitle('Per-Second Neural Breakdown', fontsize=22,
                 color='white', fontweight='bold', y=0.995)
    for t in range(n_sec):
        # Column 0: brain surface snapshot for this second.
        ax_b = fig.add_subplot(gs[t, 0])
        try:
            plotter.plot_surf(preds[t], axes=ax_b, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.1, alpha_cmap=(0, 0.2))
        except Exception:
            # Fall back to a text placeholder if surface plotting fails.
            ax_b.text(0.5, 0.5, f't={t}s', ha='center', va='center', color='white')
        ax_b.set_facecolor('#0d1117')
        ax_b.axis('off')
        ax_b.set_title(f't = {t}s', color='white', fontsize=14, fontweight='bold', pad=4)
        # Column 1: horizontal bars of per-network activation at time t.
        ax_bar = fig.add_subplot(gs[t, 1])
        ax_bar.set_facecolor('#161b22')
        names = list(net_tcs.keys())
        vals = [net_tcs[n][t] for n in names]
        colors = [NET_COLORS[n] for n in names]
        bars = ax_bar.barh(range(len(names)), vals, color=colors, alpha=0.9)
        ax_bar.axvline(0, color='#30363d', linewidth=1)
        ax_bar.set_yticks(range(len(names)))
        ax_bar.set_yticklabels([n[:10] for n in names], color='white', fontsize=9)
        ax_bar.tick_params(colors='white', labelsize=8)
        # Fixed x-limits so bars are comparable across rows.
        ax_bar.set_xlim(-0.15, 0.45)
        for spine in ax_bar.spines.values():
            spine.set_color('#30363d')
        ax_bar.grid(True, axis='x', alpha=0.1)
        # Numeric labels just past each bar end, on the sign-matching side.
        for bar, v in zip(bars, vals):
            x = v + (0.01 if v >= 0 else -0.01)
            ax_bar.text(x, bar.get_y() + bar.get_height() / 2, f'{v:+.2f}',
                        va='center', ha='left' if v >= 0 else 'right',
                        color='white', fontsize=8)
        # Column 2: text card (stimulus / modality / dominant / weakest).
        ax_t = fig.add_subplot(gs[t, 2])
        ax_t.set_facecolor('#161b22')
        ax_t.axis('off')
        net_vals = {n: net_tcs[n][t] for n in net_tcs}
        dominant = max(net_vals, key=net_vals.get)
        weakest = min(net_vals, key=net_vals.get)
        frame_mean = preds[t].mean()
        # z-score of this frame's mean activation vs the whole recording.
        z = (frame_mean - baseline_mean) / baseline_std if baseline_std > 0 else 0
        words = words_at(t)
        stimulus = ' '.join(words) if words else '(silent / visual only)'
        # Heuristic modality classification from Visual vs Somatomotor
        # levels; thresholds are hand-tuned, not validated.
        # NOTE(review): Somatomotor is treated as the auditory proxy here —
        # confirm that mapping is intended.
        v_, s_ = net_tcs['Visual'][t], net_tcs['Somatomotor'][t]
        if v_ > 0.15 and s_ < 0.02:
            modality, mod_color = 'visual-dominant', '#ff6b6b'
        elif s_ > 0.05 and v_ < 0.15:
            modality, mod_color = 'auditory-dominant', '#4ecdc4'
        elif v_ > 0.1 and s_ > 0.05:
            modality, mod_color = 'multimodal', '#ffe66d'
        else:
            modality, mod_color = 'low activation', '#8b949e'
        y = 0.92
        ax_t.text(0.02, y, 'STIMULUS', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, stimulus, color='white', fontsize=11,
                  style='italic', transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'MODALITY', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, modality, color=mod_color, fontsize=11,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.55, y, f'z = {z:+.2f}', color='#8b949e', fontsize=10,
                  transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'DOMINANT', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{dominant} ({net_vals[dominant]:+.3f})',
                  color=NET_COLORS[dominant], fontsize=11, transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'WEAKEST', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{weakest} ({net_vals[weakest]:+.3f})',
                  color=NET_COLORS[weakest], fontsize=11, transform=ax_t.transAxes)
    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path
# 5 min GPU allocation per request
@spaces.GPU(duration=300)
def analyze_video(video_file, progress=gr.Progress()):
    """Run the full analysis pipeline on one uploaded video.

    Steps: extract events (audio + transcription) -> predict per-vertex
    fMRI response -> aggregate into ROI time courses -> render dashboard,
    per-frame breakdown, and brain-activity animation -> build a text
    summary.

    Args:
        video_file: Filesystem path of the uploaded video, or None.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Tuple (dashboard PNG path, per-frame PNG path, animation MP4 path,
        summary text). When no file was provided, the three paths are None
        and the summary carries the error message.
    """
    if video_file is None:
        return None, None, None, "Upload a video first."
    progress(0.05, desc="Reading video...")
    video_path = Path(video_file)
    progress(0.1, desc="Extracting events (audio + speech transcription)...")
    events_df = model.get_events_dataframe(video_path=video_path)
    progress(0.4, desc="Predicting brain response...")
    preds, segments = model.predict(events=events_df)
    progress(0.65, desc="Aggregating network activity...")
    roi_activations = aggregate_roi(preds)
    progress(0.75, desc="Building dashboard...")
    dash = build_dashboard(preds, roi_activations, 'dashboard.png')
    progress(0.85, desc="Building per-frame breakdown...")
    breakdown = build_per_frame(preds, roi_activations, events_df, 'per_frame.png')
    progress(0.92, desc="Rendering brain animation...")
    mp4_path = 'brain_activity.mp4'
    plotter.plot_timesteps_mp4(preds, mp4_path, segments=segments)
    # Summary statistics. NOTE(review): Somatomotor is labeled "Auditory"
    # in the summary — confirm that proxy mapping is intended.
    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, p = pearsonr(visual_tc, sommot_tc)
    summary = f"""Duration: {preds.shape[0]}s
Visual peak: {visual_tc.max():.3f} at t={int(visual_tc.argmax())}s
Auditory peak: {sommot_tc.max():.3f} at t={int(sommot_tc.argmax())}s
Cross-modal coherence: r = {r:+.3f} (p = {p:.3f})
Structure: {'visual-first with auditory handoff' if visual_tc.argmax() < sommot_tc.argmax() else 'auditory-first with visual support'}
Note: Metrics are descriptive. Engagement prediction requires validation against real performance data."""
    return dash, breakdown, mp4_path, summary
# --- Gradio UI: one upload column, one summary column, three result tabs ---
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange"), title="Humeo Neural Analyzer") as app:
    gr.Markdown("""
# 🧠 Humeo Neural Content Analyzer
Upload a video. Get predicted fMRI brain response, network-level time courses, and per-second neural breakdown.
*Powered by Meta TRIBE v2 · Analysis takes ~2–3 minutes per video.*
""")
    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video(label="Upload video (MP4, short-form recommended)")
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_out = gr.Textbox(label="Summary", lines=8, show_copy_button=True)
    with gr.Tab("Dashboard"):
        dash_out = gr.Image(label="Overview")
    with gr.Tab("Per-Frame Breakdown"):
        frame_out = gr.Image(label="Second-by-second analysis")
    with gr.Tab("Brain Animation"):
        mp4_out = gr.Video(label="Brain activity video")
    # Wire the button to the pipeline; outputs fill the three tabs plus the
    # summary textbox, in order.
    btn.click(analyze_video, inputs=video_in,
              outputs=[dash_out, frame_out, mp4_out, summary_out])
    gr.Markdown("---\n*Prototype for Humeo R&D. Not a validated engagement prediction tool.*")
# Launch only when executed as a script (HF Spaces also runs this path).
if __name__ == "__main__":
    app.launch()