Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from matplotlib.gridspec import GridSpec
|
| 5 |
+
from scipy.stats import pearsonr
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import spaces # HF Spaces GPU decorator
|
| 9 |
+
|
| 10 |
+
from tribev2.demo_utils import TribeModel
|
| 11 |
+
from tribev2.plotting import PlotBrain
|
| 12 |
+
from nilearn import datasets
|
| 13 |
+
from nilearn.surface import vol_to_surf
|
| 14 |
+
|
| 15 |
+
CACHE_FOLDER = Path("./cache")
|
| 16 |
+
CACHE_FOLDER.mkdir(exist_ok=True)
|
| 17 |
+
|
| 18 |
+
# Load model + plotter once at startup (avoids reloading per request)
|
| 19 |
+
print("Loading TRIBE v2 model...")
|
| 20 |
+
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE_FOLDER)
|
| 21 |
+
plotter = PlotBrain(mesh="fsaverage5", atlas_name="schaefer_2018", atlas_dim=400)
|
| 22 |
+
|
| 23 |
+
# Precompute atlas surface labels
|
| 24 |
+
print("Preparing atlas...")
|
| 25 |
+
fsaverage = datasets.fetch_surf_fsaverage('fsaverage5')
|
| 26 |
+
atlas = plotter.get_atlas()
|
| 27 |
+
left_labels = vol_to_surf(atlas.maps, fsaverage.pial_left,
|
| 28 |
+
interpolation='nearest_most_frequent', radius=3).astype(int)
|
| 29 |
+
right_labels = vol_to_surf(atlas.maps, fsaverage.pial_right,
|
| 30 |
+
interpolation='nearest_most_frequent', radius=3).astype(int)
|
| 31 |
+
surf_labels = np.concatenate([left_labels, right_labels])
|
| 32 |
+
unique_labels = np.unique(surf_labels)
|
| 33 |
+
unique_labels = unique_labels[unique_labels > 0]
|
| 34 |
+
labels_text = [l.decode() if isinstance(l, bytes) else str(l) for l in atlas.labels]
|
| 35 |
+
|
| 36 |
+
def find_network_cols(keyword):
|
| 37 |
+
return [i for i, lbl_id in enumerate(unique_labels)
|
| 38 |
+
if lbl_id < len(labels_text) and keyword.lower() in labels_text[lbl_id].lower()]
|
| 39 |
+
|
| 40 |
+
NETWORKS = {
|
| 41 |
+
'Visual': find_network_cols('Vis'),
|
| 42 |
+
'Somatomotor': find_network_cols('SomMot'),
|
| 43 |
+
'DorsAttn': find_network_cols('DorsAttn'),
|
| 44 |
+
'SalVentAttn': find_network_cols('SalVentAttn'),
|
| 45 |
+
'Limbic': find_network_cols('Limbic'),
|
| 46 |
+
'Control': find_network_cols('Cont'),
|
| 47 |
+
'Default': find_network_cols('Default'),
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
NET_COLORS = {
|
| 51 |
+
'Visual': '#ff6b6b', 'Somatomotor': '#4ecdc4', 'DorsAttn': '#ffe66d',
|
| 52 |
+
'SalVentAttn': '#a78bfa', 'Limbic': '#f9a826', 'Control': '#06ffa5',
|
| 53 |
+
'Default': '#3d5a80',
|
| 54 |
+
}
|
| 55 |
+
print("Ready.")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def aggregate_roi(preds):
|
| 59 |
+
n_sec = preds.shape[0]
|
| 60 |
+
roi = np.zeros((n_sec, len(unique_labels)))
|
| 61 |
+
for i, lbl in enumerate(unique_labels):
|
| 62 |
+
mask = surf_labels == lbl
|
| 63 |
+
if mask.sum() > 0:
|
| 64 |
+
roi[:, i] = preds[:, mask].mean(axis=1)
|
| 65 |
+
return roi
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_dashboard(preds, roi_activations, save_path='dashboard.png'):
|
| 69 |
+
n_sec = preds.shape[0]
|
| 70 |
+
fig = plt.figure(figsize=(20, 14), facecolor='#0d1117')
|
| 71 |
+
gs = GridSpec(4, 5, figure=fig,
|
| 72 |
+
height_ratios=[1.2, 1.2, 2, 1],
|
| 73 |
+
hspace=0.4, wspace=0.2)
|
| 74 |
+
|
| 75 |
+
fig.suptitle('Neural Engagement Analysis — TRIBE v2',
|
| 76 |
+
fontsize=24, color='white', fontweight='bold', y=0.96)
|
| 77 |
+
fig.text(0.5, 0.925,
|
| 78 |
+
f'Predicted fMRI response across {n_sec} seconds · Schaefer-400 parcellation',
|
| 79 |
+
ha='center', color='#8b949e', fontsize=12)
|
| 80 |
+
|
| 81 |
+
brain_cols = 6
|
| 82 |
+
for i in range(min(n_sec, 12)):
|
| 83 |
+
row = i // brain_cols
|
| 84 |
+
col = i % brain_cols
|
| 85 |
+
ax = fig.add_subplot(gs[row, col] if col < 5 else gs[row, 4])
|
| 86 |
+
try:
|
| 87 |
+
plotter.plot_surf(preds[i], axes=ax, views='left',
|
| 88 |
+
cmap='fire', norm_percentile=99,
|
| 89 |
+
vmin=0.3, alpha_cmap=(0, 0.2))
|
| 90 |
+
except Exception:
|
| 91 |
+
ax.text(0.5, 0.5, f't={i}s', ha='center', va='center', color='white')
|
| 92 |
+
ax.set_title(f't = {i}s', color='white', fontsize=10, pad=2)
|
| 93 |
+
ax.axis('off')
|
| 94 |
+
ax.set_facecolor('#0d1117')
|
| 95 |
+
|
| 96 |
+
ax_net = fig.add_subplot(gs[2, :])
|
| 97 |
+
ax_net.set_facecolor('#161b22')
|
| 98 |
+
time_axis = np.arange(n_sec)
|
| 99 |
+
for name, cols in NETWORKS.items():
|
| 100 |
+
if not cols:
|
| 101 |
+
continue
|
| 102 |
+
tc = roi_activations[:, cols].mean(axis=1)
|
| 103 |
+
ax_net.plot(time_axis, tc, label=name, color=NET_COLORS.get(name, 'white'),
|
| 104 |
+
linewidth=2.5, marker='o', markersize=6, alpha=0.9)
|
| 105 |
+
|
| 106 |
+
visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
|
| 107 |
+
peak_visual = int(np.argmax(visual_tc))
|
| 108 |
+
transition = int(np.argmax(np.abs(np.diff(visual_tc))))
|
| 109 |
+
ax_net.axvline(peak_visual, color='#ff6b6b', linestyle='--', alpha=0.3)
|
| 110 |
+
ax_net.axvline(transition + 1, color='#4ecdc4', linestyle='--', alpha=0.3)
|
| 111 |
+
|
| 112 |
+
ax_net.axhline(0, color='#30363d', linewidth=0.5)
|
| 113 |
+
ax_net.set_xlabel('Time (seconds)', color='white', fontsize=11)
|
| 114 |
+
ax_net.set_ylabel('Network Activation', color='white', fontsize=11)
|
| 115 |
+
ax_net.set_title('Network-level time courses', color='white', fontsize=13, pad=10)
|
| 116 |
+
ax_net.legend(loc='upper left', ncol=4, facecolor='#161b22',
|
| 117 |
+
edgecolor='#30363d', labelcolor='white', fontsize=9)
|
| 118 |
+
ax_net.tick_params(colors='white')
|
| 119 |
+
for spine in ax_net.spines.values():
|
| 120 |
+
spine.set_color('#30363d')
|
| 121 |
+
ax_net.grid(True, alpha=0.1, color='white')
|
| 122 |
+
ax_net.set_xticks(time_axis)
|
| 123 |
+
|
| 124 |
+
# Metrics card
|
| 125 |
+
ax_m = fig.add_subplot(gs[3, :2])
|
| 126 |
+
ax_m.set_facecolor('#161b22')
|
| 127 |
+
ax_m.axis('off')
|
| 128 |
+
sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
|
| 129 |
+
r, _ = pearsonr(visual_tc, sommot_tc)
|
| 130 |
+
hook_ratio = visual_tc[:3].mean() / visual_tc[3:].mean() if n_sec > 3 and visual_tc[3:].mean() > 0 else float('nan')
|
| 131 |
+
metrics = [
|
| 132 |
+
('Visual Peak', f'{visual_tc.max():.3f}', '#ff6b6b'),
|
| 133 |
+
('Cross-Modal r', f'{r:+.3f}', '#4ecdc4' if r > 0 else '#ff6b6b'),
|
| 134 |
+
('Hook Ratio', f'{hook_ratio:.2f}×' if not np.isnan(hook_ratio) else 'N/A', '#ffe66d'),
|
| 135 |
+
('Temporal Var', f'{preds.var(axis=0).mean():.4f}', '#a78bfa'),
|
| 136 |
+
('Duration', f'{n_sec}s', '#8b949e'),
|
| 137 |
+
]
|
| 138 |
+
for i, (label, value, color) in enumerate(metrics):
|
| 139 |
+
y = 0.85 - i * 0.16
|
| 140 |
+
ax_m.text(0.05, y, label, color='#8b949e', fontsize=11, transform=ax_m.transAxes)
|
| 141 |
+
ax_m.text(0.55, y, value, color=color, fontsize=14, fontweight='bold',
|
| 142 |
+
family='monospace', transform=ax_m.transAxes)
|
| 143 |
+
|
| 144 |
+
# Interpretation card
|
| 145 |
+
ax_i = fig.add_subplot(gs[3, 2:])
|
| 146 |
+
ax_i.set_facecolor('#161b22')
|
| 147 |
+
ax_i.axis('off')
|
| 148 |
+
dominant = max(NETWORKS.keys(),
|
| 149 |
+
key=lambda n: roi_activations[:, NETWORKS[n]].mean() if NETWORKS[n] else -np.inf)
|
| 150 |
+
lines = [
|
| 151 |
+
('OPENING', f'Visual-dominant hook (t=0–3s)'),
|
| 152 |
+
('PROFILE', f'{dominant} network leads predicted response'),
|
| 153 |
+
('TRANSITION', f'Modality handoff at t={transition+1}s'),
|
| 154 |
+
('COHERENCE', f'Visual↔Auditory r={r:+.2f}'),
|
| 155 |
+
('NOTE', 'Descriptive only. Not validated against engagement data.'),
|
| 156 |
+
]
|
| 157 |
+
for i, (label, text) in enumerate(lines):
|
| 158 |
+
y = 0.9 - i * 0.18
|
| 159 |
+
ax_i.text(0.02, y, label, color='#ffa500', fontsize=10,
|
| 160 |
+
fontweight='bold', family='monospace', transform=ax_i.transAxes)
|
| 161 |
+
ax_i.text(0.2, y, text, color='white', fontsize=11, transform=ax_i.transAxes)
|
| 162 |
+
|
| 163 |
+
plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
|
| 164 |
+
plt.close(fig)
|
| 165 |
+
return save_path
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def build_per_frame(preds, roi_activations, events_df, save_path='per_frame.png'):
|
| 169 |
+
n_sec = preds.shape[0]
|
| 170 |
+
net_tcs = {name: roi_activations[:, cols].mean(axis=1) if cols else np.zeros(n_sec)
|
| 171 |
+
for name, cols in NETWORKS.items()}
|
| 172 |
+
baseline_mean = preds.mean()
|
| 173 |
+
baseline_std = preds.std()
|
| 174 |
+
|
| 175 |
+
def words_at(t):
|
| 176 |
+
active = events_df[(events_df['start'] <= t) & (events_df['start'] + events_df['duration'] >= t)]
|
| 177 |
+
return active[active['type'] == 'Word']['text'].dropna().tolist()
|
| 178 |
+
|
| 179 |
+
fig = plt.figure(figsize=(16, 2.5 * n_sec), facecolor='#0d1117')
|
| 180 |
+
gs = GridSpec(n_sec, 3, figure=fig, width_ratios=[1.2, 1.5, 2.3],
|
| 181 |
+
hspace=0.3, wspace=0.15)
|
| 182 |
+
fig.suptitle('Per-Second Neural Breakdown', fontsize=22,
|
| 183 |
+
color='white', fontweight='bold', y=0.995)
|
| 184 |
+
|
| 185 |
+
for t in range(n_sec):
|
| 186 |
+
ax_b = fig.add_subplot(gs[t, 0])
|
| 187 |
+
try:
|
| 188 |
+
plotter.plot_surf(preds[t], axes=ax_b, views='left',
|
| 189 |
+
cmap='fire', norm_percentile=99,
|
| 190 |
+
vmin=0.1, alpha_cmap=(0, 0.2))
|
| 191 |
+
except Exception:
|
| 192 |
+
ax_b.text(0.5, 0.5, f't={t}s', ha='center', va='center', color='white')
|
| 193 |
+
ax_b.set_facecolor('#0d1117')
|
| 194 |
+
ax_b.axis('off')
|
| 195 |
+
ax_b.set_title(f't = {t}s', color='white', fontsize=14, fontweight='bold', pad=4)
|
| 196 |
+
|
| 197 |
+
ax_bar = fig.add_subplot(gs[t, 1])
|
| 198 |
+
ax_bar.set_facecolor('#161b22')
|
| 199 |
+
names = list(net_tcs.keys())
|
| 200 |
+
vals = [net_tcs[n][t] for n in names]
|
| 201 |
+
colors = [NET_COLORS[n] for n in names]
|
| 202 |
+
bars = ax_bar.barh(range(len(names)), vals, color=colors, alpha=0.9)
|
| 203 |
+
ax_bar.axvline(0, color='#30363d', linewidth=1)
|
| 204 |
+
ax_bar.set_yticks(range(len(names)))
|
| 205 |
+
ax_bar.set_yticklabels([n[:10] for n in names], color='white', fontsize=9)
|
| 206 |
+
ax_bar.tick_params(colors='white', labelsize=8)
|
| 207 |
+
ax_bar.set_xlim(-0.15, 0.45)
|
| 208 |
+
for spine in ax_bar.spines.values():
|
| 209 |
+
spine.set_color('#30363d')
|
| 210 |
+
ax_bar.grid(True, axis='x', alpha=0.1)
|
| 211 |
+
for bar, v in zip(bars, vals):
|
| 212 |
+
x = v + (0.01 if v >= 0 else -0.01)
|
| 213 |
+
ax_bar.text(x, bar.get_y() + bar.get_height() / 2, f'{v:+.2f}',
|
| 214 |
+
va='center', ha='left' if v >= 0 else 'right',
|
| 215 |
+
color='white', fontsize=8)
|
| 216 |
+
|
| 217 |
+
ax_t = fig.add_subplot(gs[t, 2])
|
| 218 |
+
ax_t.set_facecolor('#161b22')
|
| 219 |
+
ax_t.axis('off')
|
| 220 |
+
net_vals = {n: net_tcs[n][t] for n in net_tcs}
|
| 221 |
+
dominant = max(net_vals, key=net_vals.get)
|
| 222 |
+
weakest = min(net_vals, key=net_vals.get)
|
| 223 |
+
frame_mean = preds[t].mean()
|
| 224 |
+
z = (frame_mean - baseline_mean) / baseline_std if baseline_std > 0 else 0
|
| 225 |
+
words = words_at(t)
|
| 226 |
+
stimulus = ' '.join(words) if words else '(silent / visual only)'
|
| 227 |
+
|
| 228 |
+
v_, s_ = net_tcs['Visual'][t], net_tcs['Somatomotor'][t]
|
| 229 |
+
if v_ > 0.15 and s_ < 0.02:
|
| 230 |
+
modality, mod_color = 'visual-dominant', '#ff6b6b'
|
| 231 |
+
elif s_ > 0.05 and v_ < 0.15:
|
| 232 |
+
modality, mod_color = 'auditory-dominant', '#4ecdc4'
|
| 233 |
+
elif v_ > 0.1 and s_ > 0.05:
|
| 234 |
+
modality, mod_color = 'multimodal', '#ffe66d'
|
| 235 |
+
else:
|
| 236 |
+
modality, mod_color = 'low activation', '#8b949e'
|
| 237 |
+
|
| 238 |
+
y = 0.92
|
| 239 |
+
ax_t.text(0.02, y, 'STIMULUS', color='#8b949e', fontsize=9,
|
| 240 |
+
fontweight='bold', transform=ax_t.transAxes)
|
| 241 |
+
ax_t.text(0.22, y, stimulus, color='white', fontsize=11,
|
| 242 |
+
style='italic', transform=ax_t.transAxes)
|
| 243 |
+
y -= 0.15
|
| 244 |
+
ax_t.text(0.02, y, 'MODALITY', color='#8b949e', fontsize=9,
|
| 245 |
+
fontweight='bold', transform=ax_t.transAxes)
|
| 246 |
+
ax_t.text(0.22, y, modality, color=mod_color, fontsize=11,
|
| 247 |
+
fontweight='bold', transform=ax_t.transAxes)
|
| 248 |
+
ax_t.text(0.55, y, f'z = {z:+.2f}', color='#8b949e', fontsize=10,
|
| 249 |
+
transform=ax_t.transAxes)
|
| 250 |
+
y -= 0.15
|
| 251 |
+
ax_t.text(0.02, y, 'DOMINANT', color='#8b949e', fontsize=9,
|
| 252 |
+
fontweight='bold', transform=ax_t.transAxes)
|
| 253 |
+
ax_t.text(0.22, y, f'{dominant} ({net_vals[dominant]:+.3f})',
|
| 254 |
+
color=NET_COLORS[dominant], fontsize=11, transform=ax_t.transAxes)
|
| 255 |
+
y -= 0.15
|
| 256 |
+
ax_t.text(0.02, y, 'WEAKEST', color='#8b949e', fontsize=9,
|
| 257 |
+
fontweight='bold', transform=ax_t.transAxes)
|
| 258 |
+
ax_t.text(0.22, y, f'{weakest} ({net_vals[weakest]:+.3f})',
|
| 259 |
+
color=NET_COLORS[weakest], fontsize=11, transform=ax_t.transAxes)
|
| 260 |
+
|
| 261 |
+
plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
|
| 262 |
+
plt.close(fig)
|
| 263 |
+
return save_path
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@spaces.GPU(duration=300) # 5 min GPU allocation per request
|
| 267 |
+
def analyze_video(video_file, progress=gr.Progress()):
|
| 268 |
+
if video_file is None:
|
| 269 |
+
return None, None, None, "Upload a video first."
|
| 270 |
+
|
| 271 |
+
progress(0.05, desc="Reading video...")
|
| 272 |
+
video_path = Path(video_file)
|
| 273 |
+
|
| 274 |
+
progress(0.1, desc="Extracting events (audio + speech transcription)...")
|
| 275 |
+
events_df = model.get_events_dataframe(video_path=video_path)
|
| 276 |
+
|
| 277 |
+
progress(0.4, desc="Predicting brain response...")
|
| 278 |
+
preds, segments = model.predict(events=events_df)
|
| 279 |
+
|
| 280 |
+
progress(0.65, desc="Aggregating network activity...")
|
| 281 |
+
roi_activations = aggregate_roi(preds)
|
| 282 |
+
|
| 283 |
+
progress(0.75, desc="Building dashboard...")
|
| 284 |
+
dash = build_dashboard(preds, roi_activations, 'dashboard.png')
|
| 285 |
+
|
| 286 |
+
progress(0.85, desc="Building per-frame breakdown...")
|
| 287 |
+
breakdown = build_per_frame(preds, roi_activations, events_df, 'per_frame.png')
|
| 288 |
+
|
| 289 |
+
progress(0.92, desc="Rendering brain animation...")
|
| 290 |
+
mp4_path = 'brain_activity.mp4'
|
| 291 |
+
plotter.plot_timesteps_mp4(preds, mp4_path, segments=segments)
|
| 292 |
+
|
| 293 |
+
visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
|
| 294 |
+
sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
|
| 295 |
+
r, p = pearsonr(visual_tc, sommot_tc)
|
| 296 |
+
summary = f"""Duration: {preds.shape[0]}s
|
| 297 |
+
Visual peak: {visual_tc.max():.3f} at t={int(visual_tc.argmax())}s
|
| 298 |
+
Auditory peak: {sommot_tc.max():.3f} at t={int(sommot_tc.argmax())}s
|
| 299 |
+
Cross-modal coherence: r = {r:+.3f} (p = {p:.3f})
|
| 300 |
+
Structure: {'visual-first with auditory handoff' if visual_tc.argmax() < sommot_tc.argmax() else 'auditory-first with visual support'}
|
| 301 |
+
|
| 302 |
+
Note: Metrics are descriptive. Engagement prediction requires validation against real performance data."""
|
| 303 |
+
|
| 304 |
+
return dash, breakdown, mp4_path, summary
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange"), title="Humeo Neural Analyzer") as app:
|
| 308 |
+
gr.Markdown("""
|
| 309 |
+
# 🧠 Humeo Neural Content Analyzer
|
| 310 |
+
Upload a video. Get predicted fMRI brain response, network-level time courses, and per-second neural breakdown.
|
| 311 |
+
|
| 312 |
+
*Powered by Meta TRIBE v2 · Analysis takes ~2–3 minutes per video.*
|
| 313 |
+
""")
|
| 314 |
+
|
| 315 |
+
with gr.Row():
|
| 316 |
+
with gr.Column(scale=1):
|
| 317 |
+
video_in = gr.Video(label="Upload video (MP4, short-form recommended)")
|
| 318 |
+
btn = gr.Button("Analyze", variant="primary", size="lg")
|
| 319 |
+
with gr.Column(scale=1):
|
| 320 |
+
summary_out = gr.Textbox(label="Summary", lines=8, show_copy_button=True)
|
| 321 |
+
|
| 322 |
+
with gr.Tab("Dashboard"):
|
| 323 |
+
dash_out = gr.Image(label="Overview")
|
| 324 |
+
with gr.Tab("Per-Frame Breakdown"):
|
| 325 |
+
frame_out = gr.Image(label="Second-by-second analysis")
|
| 326 |
+
with gr.Tab("Brain Animation"):
|
| 327 |
+
mp4_out = gr.Video(label="Brain activity video")
|
| 328 |
+
|
| 329 |
+
btn.click(analyze_video, inputs=video_in,
|
| 330 |
+
outputs=[dash_out, frame_out, mp4_out, summary_out])
|
| 331 |
+
|
| 332 |
+
gr.Markdown("---\n*Prototype for Humeo R&D. Not a validated engagement prediction tool.*")
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
if __name__ == "__main__":
|
| 336 |
+
app.launch()
|