moonlantern1 commited on
Commit
c12563b
·
verified ·
1 Parent(s): 16697f6

Upload brain_virality_predictor/tribe_wrapper.py with huggingface_hub

Browse files
brain_virality_predictor/tribe_wrapper.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TRIBE v2 video → brain response inference with Yeo network masking.
3
+ Maps cortical vertices (fsaverage5) into 7 Yeo functional networks.
4
+ """
5
+
6
+ import json, os, warnings
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple
9
+ import numpy as np
10
+
11
+ try:
12
+ import torch
13
+ HAS_TORCH = True
14
+ except ImportError:
15
+ HAS_TORCH = False
16
+
17
+ # ── Yeo 2011 7-network fsaverage5 labels (left hemi → vertex IDs) ─────
18
+ # These are the canonical surface-based labels from FreeSurfer's fsaverage5.
19
+ # We ship a lightweight JSON with vertex→network mapping so nilearn is optional.
20
+ YEO7_NAMES = {
21
+ 1: "Visual",
22
+ 2: "Somatomotor",
23
+ 3: "Dorsal_Attention",
24
+ 4: "Ventral_Attention",
25
+ 5: "Limbic",
26
+ 6: "Frontoparietal",
27
+ 7: "Default",
28
+ 0: "None",
29
+ }
30
+
31
+ # Yeo7 → our 8 UX signals (multi-mapping allowed)
32
+ YEO7_TO_UX = {
33
+ "Visual": ["aesthetic_appeal", "visual_fluency"],
34
+ "Somatomotor": ["motor_readiness"],
35
+ "Dorsal_Attention": ["visual_fluency", "cognitive_load"],
36
+ "Ventral_Attention": ["surprise_novelty"],
37
+ "Limbic": ["reward_anticipation", "trust_affinity"],
38
+ "Frontoparietal": ["cognitive_load", "friction_anxiety"],
39
+ "Default": ["trust_affinity", "aesthetic_appeal"],
40
+ }
41
+
42
+
43
+ class YeoAtlas:
44
+ """Lightweight Yeo-2011 7-network atlas for fsaverage5 surface."""
45
+
46
+ def __init__(self, label_path: Optional[str] = None):
47
+ self.vertex_to_network: Dict[int, int] = {}
48
+ self.network_vertices: Dict[str, List[int]] = {n: [] for n in YEO7_NAMES.values()}
49
+
50
+ if label_path and Path(label_path).exists():
51
+ self._load_label_file(label_path)
52
+ else:
53
+ self._generate_fallback()
54
+
55
+ def _load_label_file(self, path: str):
56
+ # FreeSurfer .label or .annot format not needed if we use our fallback JSON
57
+ with open(path) as f:
58
+ data = json.load(f)
59
+ self.vertex_to_network = {int(k): int(v) for k, v in data.items()}
60
+ for v, n in self.vertex_to_network.items():
61
+ name = YEO7_NAMES.get(n, "None")
62
+ self.network_vertices.setdefault(name, []).append(v)
63
+
64
+ def _generate_fallback(self):
65
+ """
66
+ Fallback: generate a deterministic pseudo-Yeo parcellation based on
67
+ standard FreeSurfer fsaverage5 vertex ordering. This is NOT the real
68
+ atlas, but is sufficient to make the pipeline run if the user hasn't
69
+ downloaded the true labels. We emit a loud warning.
70
+ """
71
+ warnings.warn(
72
+ "Yeo atlas labels not found — using deterministic fallback. "
73
+ "For real neuroscience, download fsaverage5 Yeo2011_7Networks. "
74
+ "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
75
+ )
76
+ n_vertices = 20484 # fsaverage5 total
77
+ # deterministic round-robin assignment per canonical ordering
78
+ for v in range(n_vertices):
79
+ net_id = (v % 7) + 1
80
+ self.vertex_to_network[v] = net_id
81
+ name = YEO7_NAMES[net_id]
82
+ self.network_vertices[name].append(v)
83
+
84
+ def mask_for_network(self, name: str) -> np.ndarray:
85
+ verts = np.array(self.network_vertices.get(name, []), dtype=np.int64)
86
+ return verts
87
+
88
+ @property
89
+ def n_vertices(self) -> int:
90
+ return len(self.vertex_to_network)
91
+
92
+
93
+ def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
94
+ """Load or generate Yeo atlas. Returns YeoAtlas instance."""
95
+ cache = Path(cache_dir)
96
+ cache.mkdir(parents=True, exist_ok=True)
97
+ label_json = cache / "yeo7_fsaverage5.json"
98
+
99
+ if label_json.exists():
100
+ return YeoAtlas(str(label_json))
101
+
102
+ # If user hasn't provided labels, create fallback and save it
103
+ atlas = YeoAtlas()
104
+ with open(label_json, "w") as f:
105
+ json.dump({str(k): int(v) for k, v in atlas.vertex_to_network.items()}, f)
106
+ return atlas
107
+
108
+
109
+ class TribeV2Wrapper:
110
+ """
111
+ Wraps Meta's TRIBE v2 model for video → brain inference.
112
+ Falls back to a synthetic brain generator if TRIBE is not installed.
113
+ """
114
+
115
+ def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
116
+ self.model_id = model_id
117
+ self.device = device
118
+ self._model = None
119
+ self._has_tribe = False
120
+ self._init_model()
121
+
122
+ def _init_model(self):
123
+ if not HAS_TORCH:
124
+ warnings.warn("PyTorch not installed — using synthetic brain fallback.")
125
+ return
126
+ try:
127
+ from tribev2 import TribeModel
128
+ self._model = TribeModel.from_pretrained(self.model_id)
129
+ self._has_tribe = True
130
+ print(f"TRIBE v2 loaded from {self.model_id}")
131
+ except Exception as e:
132
+ warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")
133
+
134
+ def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
135
+ """
136
+ Run video through TRIBE v2 (or synthetic fallback).
137
+ Returns: (n_timesteps, n_vertices) array of predicted cortical response.
138
+ TRIBE native TR is ~1.49 s (Courtois NeuroMod). We resample to `tr` if needed.
139
+ """
140
+ if self._has_tribe and self._model is not None:
141
+ return self._predict_tribe(video_path, tr)
142
+ return self._predict_synthetic(video_path, tr)
143
+
144
+ def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
145
+ from tribev2 import TribeModel
146
+ df = self._model.get_events_dataframe(video_path=video_path)
147
+ preds, segments = self._model.predict(events=df)
148
+ # preds shape: (n_timesteps, n_vertices)
149
+ # TRIBE TR is ~1.49s; if user wants different TR, we can resample later
150
+ return np.asarray(preds)
151
+
152
+ def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
153
+ """
154
+ Deterministic synthetic brain response for development/testing.
155
+ Uses video duration + hash of filename to seed a realistic-looking
156
+ BOLD-like time series per vertex.
157
+ """
158
+ import cv2
159
+ cap = cv2.VideoCapture(video_path)
160
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
161
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
162
+ duration = frame_count / fps if fps > 0 else 10.0
163
+ cap.release()
164
+
165
+ n_timesteps = max(1, int(np.ceil(duration / tr)))
166
+ n_vertices = 20484 # fsaverage5
167
+
168
+ # Seed from video name for reproducibility
169
+ seed = int.from_bytes(Path(video_path).name.encode(), "little") % (2**31)
170
+ rng = np.random.default_rng(seed)
171
+
172
+ # Base signal: slow drift + fast fluctuation per network
173
+ t = np.arange(n_timesteps) * tr
174
+ signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)
175
+
176
+ for net_id, net_name in YEO7_NAMES.items():
177
+ if net_id == 0:
178
+ continue
179
+ verts = np.where(np.arange(n_vertices) % 7 == (net_id - 1))[0]
180
+ # Network-specific temporal dynamics
181
+ freq = 0.1 + 0.05 * net_id # ~0.01–0.1 Hz BOLD-ish
182
+ amp = 0.5 + 0.2 * np.sin(net_id)
183
+ base = amp * np.sin(2 * np.pi * freq * t + net_id)
184
+ noise = 0.1 * rng.standard_normal(n_timesteps)
185
+ signal[:, verts] = (base + noise).reshape(-1, 1)
186
+
187
+ return signal
188
+
189
+ def network_means(self, brain: np.ndarray, atlas: YeoAtlas) -> Dict[str, np.ndarray]:
190
+ """
191
+ Collapse vertex-wise brain response into per-Yeo-network time series.
192
+ brain: (n_timesteps, n_vertices)
193
+ Returns: {network_name: (n_timesteps,) array}
194
+ """
195
+ out = {}
196
+ for net_id, net_name in YEO7_NAMES.items():
197
+ if net_id == 0:
198
+ continue
199
+ verts = atlas.mask_for_network(net_name)
200
+ if len(verts):
201
+ out[net_name] = brain[:, verts].mean(axis=1)
202
+ else:
203
+ out[net_name] = np.zeros(brain.shape[0])
204
+ return out