moonlantern1 committed
Commit 516e91b · verified · 1 parent: fd56d95

Create app.py

Files changed (1): app.py (+336, -0)
app.py ADDED
@@ -0,0 +1,336 @@
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr
from pathlib import Path
import gradio as gr
import spaces  # HF Spaces GPU decorator

from tribev2.demo_utils import TribeModel
from tribev2.plotting import PlotBrain
from nilearn import datasets
from nilearn.surface import vol_to_surf

CACHE_FOLDER = Path("./cache")
CACHE_FOLDER.mkdir(exist_ok=True)

# Load model + plotter once at startup (avoids reloading per request)
print("Loading TRIBE v2 model...")
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE_FOLDER)
plotter = PlotBrain(mesh="fsaverage5", atlas_name="schaefer_2018", atlas_dim=400)

# Precompute atlas surface labels
print("Preparing atlas...")
fsaverage = datasets.fetch_surf_fsaverage('fsaverage5')
atlas = plotter.get_atlas()
left_labels = vol_to_surf(atlas.maps, fsaverage.pial_left,
                          interpolation='nearest_most_frequent', radius=3).astype(int)
right_labels = vol_to_surf(atlas.maps, fsaverage.pial_right,
                           interpolation='nearest_most_frequent', radius=3).astype(int)
surf_labels = np.concatenate([left_labels, right_labels])
unique_labels = np.unique(surf_labels)
unique_labels = unique_labels[unique_labels > 0]
labels_text = [l.decode() if isinstance(l, bytes) else str(l) for l in atlas.labels]

def find_network_cols(keyword):
    return [i for i, lbl_id in enumerate(unique_labels)
            if lbl_id < len(labels_text) and keyword.lower() in labels_text[lbl_id].lower()]

NETWORKS = {
    'Visual': find_network_cols('Vis'),
    'Somatomotor': find_network_cols('SomMot'),
    'DorsAttn': find_network_cols('DorsAttn'),
    'SalVentAttn': find_network_cols('SalVentAttn'),
    'Limbic': find_network_cols('Limbic'),
    'Control': find_network_cols('Cont'),
    'Default': find_network_cols('Default'),
}

NET_COLORS = {
    'Visual': '#ff6b6b', 'Somatomotor': '#4ecdc4', 'DorsAttn': '#ffe66d',
    'SalVentAttn': '#a78bfa', 'Limbic': '#f9a826', 'Control': '#06ffa5',
    'Default': '#3d5a80',
}
print("Ready.")

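# Shape reference for the arrays used below: fsaverage5 has 10,242 vertices per
# hemisphere, so surf_labels assigns one Schaefer parcel id to each of the 20,484
# surface vertices, model predictions are expected as [n_seconds, 20484] arrays,
# and unique_labels holds the non-background parcel ids (up to 400).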
def aggregate_roi(preds):
    n_sec = preds.shape[0]
    roi = np.zeros((n_sec, len(unique_labels)))
    for i, lbl in enumerate(unique_labels):
        mask = surf_labels == lbl
        if mask.sum() > 0:
            roi[:, i] = preds[:, mask].mean(axis=1)
    return roi

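# Example usage (a sketch with synthetic data; no model call required):
#   toy_preds = np.random.rand(4, surf_labels.size)   # [n_sec, n_vertices]
#   toy_roi = aggregate_roi(toy_preds)                 # shape (4, len(unique_labels))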
def build_dashboard(preds, roi_activations, save_path='dashboard.png'):
    n_sec = preds.shape[0]
    fig = plt.figure(figsize=(20, 14), facecolor='#0d1117')
    gs = GridSpec(4, 5, figure=fig,
                  height_ratios=[1.2, 1.2, 2, 1],
                  hspace=0.4, wspace=0.2)

    fig.suptitle('Neural Engagement Analysis — TRIBE v2',
                 fontsize=24, color='white', fontweight='bold', y=0.96)
    fig.text(0.5, 0.925,
             f'Predicted fMRI response across {n_sec} seconds · Schaefer-400 parcellation',
             ha='center', color='#8b949e', fontsize=12)

    # Brain snapshots: 5 per row to match the 5-column GridSpec (up to 10 shown)
    brain_cols = 5
    for i in range(min(n_sec, 10)):
        row = i // brain_cols
        col = i % brain_cols
        ax = fig.add_subplot(gs[row, col])
        try:
            plotter.plot_surf(preds[i], axes=ax, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.3, alpha_cmap=(0, 0.2))
        except Exception:
            ax.text(0.5, 0.5, f't={i}s', ha='center', va='center', color='white')
        ax.set_title(f't = {i}s', color='white', fontsize=10, pad=2)
        ax.axis('off')
        ax.set_facecolor('#0d1117')

    ax_net = fig.add_subplot(gs[2, :])
    ax_net.set_facecolor('#161b22')
    time_axis = np.arange(n_sec)
    for name, cols in NETWORKS.items():
        if not cols:
            continue
        tc = roi_activations[:, cols].mean(axis=1)
        ax_net.plot(time_axis, tc, label=name, color=NET_COLORS.get(name, 'white'),
                    linewidth=2.5, marker='o', markersize=6, alpha=0.9)

    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    peak_visual = int(np.argmax(visual_tc))
    transition = int(np.argmax(np.abs(np.diff(visual_tc))))
    ax_net.axvline(peak_visual, color='#ff6b6b', linestyle='--', alpha=0.3)
    ax_net.axvline(transition + 1, color='#4ecdc4', linestyle='--', alpha=0.3)

    ax_net.axhline(0, color='#30363d', linewidth=0.5)
    ax_net.set_xlabel('Time (seconds)', color='white', fontsize=11)
    ax_net.set_ylabel('Network Activation', color='white', fontsize=11)
    ax_net.set_title('Network-level time courses', color='white', fontsize=13, pad=10)
    ax_net.legend(loc='upper left', ncol=4, facecolor='#161b22',
                  edgecolor='#30363d', labelcolor='white', fontsize=9)
    ax_net.tick_params(colors='white')
    for spine in ax_net.spines.values():
        spine.set_color('#30363d')
    ax_net.grid(True, alpha=0.1, color='white')
    ax_net.set_xticks(time_axis)

    # Metrics card
    ax_m = fig.add_subplot(gs[3, :2])
    ax_m.set_facecolor('#161b22')
    ax_m.axis('off')
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, _ = pearsonr(visual_tc, sommot_tc)
    hook_ratio = visual_tc[:3].mean() / visual_tc[3:].mean() if n_sec > 3 and visual_tc[3:].mean() > 0 else float('nan')
    metrics = [
        ('Visual Peak', f'{visual_tc.max():.3f}', '#ff6b6b'),
        ('Cross-Modal r', f'{r:+.3f}', '#4ecdc4' if r > 0 else '#ff6b6b'),
        ('Hook Ratio', f'{hook_ratio:.2f}×' if not np.isnan(hook_ratio) else 'N/A', '#ffe66d'),
        ('Temporal Var', f'{preds.var(axis=0).mean():.4f}', '#a78bfa'),
        ('Duration', f'{n_sec}s', '#8b949e'),
    ]
    for i, (label, value, color) in enumerate(metrics):
        y = 0.85 - i * 0.16
        ax_m.text(0.05, y, label, color='#8b949e', fontsize=11, transform=ax_m.transAxes)
        ax_m.text(0.55, y, value, color=color, fontsize=14, fontweight='bold',
                  family='monospace', transform=ax_m.transAxes)

    # Interpretation card
    ax_i = fig.add_subplot(gs[3, 2:])
    ax_i.set_facecolor('#161b22')
    ax_i.axis('off')
    dominant = max(NETWORKS.keys(),
                   key=lambda n: roi_activations[:, NETWORKS[n]].mean() if NETWORKS[n] else -np.inf)
    lines = [
        ('OPENING', 'Visual-dominant hook (t=0–3s)'),
        ('PROFILE', f'{dominant} network leads predicted response'),
        ('TRANSITION', f'Modality handoff at t={transition+1}s'),
        ('COHERENCE', f'Visual↔Auditory r={r:+.2f}'),
        ('NOTE', 'Descriptive only. Not validated against engagement data.'),
    ]
    for i, (label, text) in enumerate(lines):
        y = 0.9 - i * 0.18
        ax_i.text(0.02, y, label, color='#ffa500', fontsize=10,
                  fontweight='bold', family='monospace', transform=ax_i.transAxes)
        ax_i.text(0.2, y, text, color='white', fontsize=11, transform=ax_i.transAxes)

    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path

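# Layout smoke test (a sketch; 'dashboard_test.png' is an arbitrary output name).
# build_dashboard only needs arrays shaped like the model output, and the brain
# panels fall back to plain text if plot_surf raises, so the figure layout can be
# previewed with synthetic data:
#   fake_preds = np.random.rand(6, surf_labels.size) * 0.5
#   build_dashboard(fake_preds, aggregate_roi(fake_preds), 'dashboard_test.png')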
def build_per_frame(preds, roi_activations, events_df, save_path='per_frame.png'):
    n_sec = preds.shape[0]
    net_tcs = {name: roi_activations[:, cols].mean(axis=1) if cols else np.zeros(n_sec)
               for name, cols in NETWORKS.items()}
    baseline_mean = preds.mean()
    baseline_std = preds.std()

    def words_at(t):
        active = events_df[(events_df['start'] <= t) & (events_df['start'] + events_df['duration'] >= t)]
        return active[active['type'] == 'Word']['text'].dropna().tolist()

    fig = plt.figure(figsize=(16, 2.5 * n_sec), facecolor='#0d1117')
    gs = GridSpec(n_sec, 3, figure=fig, width_ratios=[1.2, 1.5, 2.3],
                  hspace=0.3, wspace=0.15)
    fig.suptitle('Per-Second Neural Breakdown', fontsize=22,
                 color='white', fontweight='bold', y=0.995)

    for t in range(n_sec):
        ax_b = fig.add_subplot(gs[t, 0])
        try:
            plotter.plot_surf(preds[t], axes=ax_b, views='left',
                              cmap='fire', norm_percentile=99,
                              vmin=0.1, alpha_cmap=(0, 0.2))
        except Exception:
            ax_b.text(0.5, 0.5, f't={t}s', ha='center', va='center', color='white')
        ax_b.set_facecolor('#0d1117')
        ax_b.axis('off')
        ax_b.set_title(f't = {t}s', color='white', fontsize=14, fontweight='bold', pad=4)

        ax_bar = fig.add_subplot(gs[t, 1])
        ax_bar.set_facecolor('#161b22')
        names = list(net_tcs.keys())
        vals = [net_tcs[n][t] for n in names]
        colors = [NET_COLORS[n] for n in names]
        bars = ax_bar.barh(range(len(names)), vals, color=colors, alpha=0.9)
        ax_bar.axvline(0, color='#30363d', linewidth=1)
        ax_bar.set_yticks(range(len(names)))
        ax_bar.set_yticklabels([n[:10] for n in names], color='white', fontsize=9)
        ax_bar.tick_params(colors='white', labelsize=8)
        ax_bar.set_xlim(-0.15, 0.45)
        for spine in ax_bar.spines.values():
            spine.set_color('#30363d')
        ax_bar.grid(True, axis='x', alpha=0.1)
        for bar, v in zip(bars, vals):
            x = v + (0.01 if v >= 0 else -0.01)
            ax_bar.text(x, bar.get_y() + bar.get_height() / 2, f'{v:+.2f}',
                        va='center', ha='left' if v >= 0 else 'right',
                        color='white', fontsize=8)

        ax_t = fig.add_subplot(gs[t, 2])
        ax_t.set_facecolor('#161b22')
        ax_t.axis('off')
        net_vals = {n: net_tcs[n][t] for n in net_tcs}
        dominant = max(net_vals, key=net_vals.get)
        weakest = min(net_vals, key=net_vals.get)
        frame_mean = preds[t].mean()
        z = (frame_mean - baseline_mean) / baseline_std if baseline_std > 0 else 0
        words = words_at(t)
        stimulus = ' '.join(words) if words else '(silent / visual only)'

        v_, s_ = net_tcs['Visual'][t], net_tcs['Somatomotor'][t]
        if v_ > 0.15 and s_ < 0.02:
            modality, mod_color = 'visual-dominant', '#ff6b6b'
        elif s_ > 0.05 and v_ < 0.15:
            modality, mod_color = 'auditory-dominant', '#4ecdc4'
        elif v_ > 0.1 and s_ > 0.05:
            modality, mod_color = 'multimodal', '#ffe66d'
        else:
            modality, mod_color = 'low activation', '#8b949e'

        y = 0.92
        ax_t.text(0.02, y, 'STIMULUS', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, stimulus, color='white', fontsize=11,
                  style='italic', transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'MODALITY', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, modality, color=mod_color, fontsize=11,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.55, y, f'z = {z:+.2f}', color='#8b949e', fontsize=10,
                  transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'DOMINANT', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{dominant} ({net_vals[dominant]:+.3f})',
                  color=NET_COLORS[dominant], fontsize=11, transform=ax_t.transAxes)
        y -= 0.15
        ax_t.text(0.02, y, 'WEAKEST', color='#8b949e', fontsize=9,
                  fontweight='bold', transform=ax_t.transAxes)
        ax_t.text(0.22, y, f'{weakest} ({net_vals[weakest]:+.3f})',
                  color=NET_COLORS[weakest], fontsize=11, transform=ax_t.transAxes)

    plt.savefig(save_path, dpi=130, bbox_inches='tight', facecolor='#0d1117')
    plt.close(fig)
    return save_path

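# build_per_frame reads only the 'start', 'duration', 'type' and 'text' columns of
# events_df (spoken words are rows with type == 'Word'). A minimal stand-in for
# layout testing might look like this (a sketch; real events come from
# model.get_events_dataframe):
#   import pandas as pd
#   toy_events = pd.DataFrame([
#       {'start': 0.0, 'duration': 1.0, 'type': 'Word', 'text': 'hello'},
#       {'start': 1.5, 'duration': 0.5, 'type': 'Word', 'text': 'world'},
#   ])
#   fake_preds = np.random.rand(4, surf_labels.size) * 0.5
#   build_per_frame(fake_preds, aggregate_roi(fake_preds), toy_events, 'per_frame_test.png')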
@spaces.GPU(duration=300)  # 5 min GPU allocation per request
def analyze_video(video_file, progress=gr.Progress()):
    if video_file is None:
        return None, None, None, "Upload a video first."

    progress(0.05, desc="Reading video...")
    video_path = Path(video_file)

    progress(0.1, desc="Extracting events (audio + speech transcription)...")
    events_df = model.get_events_dataframe(video_path=video_path)

    progress(0.4, desc="Predicting brain response...")
    preds, segments = model.predict(events=events_df)

    progress(0.65, desc="Aggregating network activity...")
    roi_activations = aggregate_roi(preds)

    progress(0.75, desc="Building dashboard...")
    dash = build_dashboard(preds, roi_activations, 'dashboard.png')

    progress(0.85, desc="Building per-frame breakdown...")
    breakdown = build_per_frame(preds, roi_activations, events_df, 'per_frame.png')

    progress(0.92, desc="Rendering brain animation...")
    mp4_path = 'brain_activity.mp4'
    plotter.plot_timesteps_mp4(preds, mp4_path, segments=segments)

    visual_tc = roi_activations[:, NETWORKS['Visual']].mean(axis=1)
    sommot_tc = roi_activations[:, NETWORKS['Somatomotor']].mean(axis=1)
    r, p = pearsonr(visual_tc, sommot_tc)
    summary = f"""Duration: {preds.shape[0]}s
Visual peak: {visual_tc.max():.3f} at t={int(visual_tc.argmax())}s
Auditory peak: {sommot_tc.max():.3f} at t={int(sommot_tc.argmax())}s
Cross-modal coherence: r = {r:+.3f} (p = {p:.3f})
Structure: {'visual-first with auditory handoff' if visual_tc.argmax() < sommot_tc.argmax() else 'auditory-first with visual support'}

Note: Metrics are descriptive. Engagement prediction requires validation against real performance data."""

    return dash, breakdown, mp4_path, summary

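# Local sanity check (a sketch; "sample_clip.mp4" is a placeholder path). On HF
# Spaces the @spaces.GPU decorator requests a ZeroGPU slot for the call; run
# outside Spaces it should act as a pass-through, though that depends on the
# installed `spaces` package version:
#   dash_png, frames_png, mp4, summary = analyze_video("sample_clip.mp4")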
with gr.Blocks(theme=gr.themes.Base(primary_hue="orange"), title="Humeo Neural Analyzer") as app:
    gr.Markdown("""
# 🧠 Humeo Neural Content Analyzer
Upload a video. Get predicted fMRI brain response, network-level time courses, and per-second neural breakdown.

*Powered by Meta TRIBE v2 · Analysis takes ~2–3 minutes per video.*
""")

    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video(label="Upload video (MP4, short-form recommended)")
            btn = gr.Button("Analyze", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_out = gr.Textbox(label="Summary", lines=8, show_copy_button=True)

    with gr.Tab("Dashboard"):
        dash_out = gr.Image(label="Overview")
    with gr.Tab("Per-Frame Breakdown"):
        frame_out = gr.Image(label="Second-by-second analysis")
    with gr.Tab("Brain Animation"):
        mp4_out = gr.Video(label="Brain activity video")

    btn.click(analyze_video, inputs=video_in,
              outputs=[dash_out, frame_out, mp4_out, summary_out])

    gr.Markdown("---\n*Prototype for Humeo R&D. Not a validated engagement prediction tool.*")


if __name__ == "__main__":
    app.launch()
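# A running instance can also be driven programmatically (a sketch using
# gradio_client; the URL and api_name below are assumptions, so confirm them
# with client.view_api()):
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   dash, frames, mp4, summary = client.predict(handle_file("clip.mp4"), api_name="/analyze_video")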