Upload brain_virality_predictor/tribe_wrapper.py with huggingface_hub
Browse files
brain_virality_predictor/tribe_wrapper.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TRIBE v2 video → brain response inference with Yeo network masking.
|
| 3 |
+
Maps cortical vertices (fsaverage5) into 7 Yeo functional networks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json, os, warnings
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional, Tuple
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import torch
|
| 13 |
+
HAS_TORCH = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
HAS_TORCH = False
|
| 16 |
+
|
| 17 |
+
# ── Yeo 2011 7-network fsaverage5 labels (left hemi → vertex IDs) ─────
# These are the canonical surface-based labels from FreeSurfer's fsaverage5.
# We ship a lightweight JSON with vertex→network mapping so nilearn is optional.
# Network id → human-readable name; id 0 marks unassigned vertices and is
# skipped by consumers (see TribeV2Wrapper.network_means).
YEO7_NAMES: Dict[int, str] = {
    1: "Visual",
    2: "Somatomotor",
    3: "Dorsal_Attention",
    4: "Ventral_Attention",
    5: "Limbic",
    6: "Frontoparietal",
    7: "Default",
    0: "None",
}

# Yeo7 → our 8 UX signals (multi-mapping allowed)
# Each functional network may feed several downstream UX signals.
YEO7_TO_UX: Dict[str, List[str]] = {
    "Visual": ["aesthetic_appeal", "visual_fluency"],
    "Somatomotor": ["motor_readiness"],
    "Dorsal_Attention": ["visual_fluency", "cognitive_load"],
    "Ventral_Attention": ["surprise_novelty"],
    "Limbic": ["reward_anticipation", "trust_affinity"],
    "Frontoparietal": ["cognitive_load", "friction_anxiety"],
    "Default": ["trust_affinity", "aesthetic_appeal"],
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class YeoAtlas:
    """Lightweight Yeo-2011 7-network atlas for fsaverage5 surface."""

    def __init__(self, label_path: Optional[str] = None):
        # vertex id → Yeo network id; filled either from a JSON label file
        # or by the deterministic fallback below.
        self.vertex_to_network: Dict[int, int] = {}
        # network name → member vertex ids, pre-seeded with every known name.
        self.network_vertices: Dict[str, List[int]] = {n: [] for n in YEO7_NAMES.values()}

        have_labels = bool(label_path) and Path(label_path).exists()
        if have_labels:
            self._load_label_file(label_path)
        else:
            self._generate_fallback()

    def _load_label_file(self, path: str):
        # FreeSurfer .label or .annot format not needed if we use our fallback JSON
        raw = json.loads(Path(path).read_text())
        self.vertex_to_network = {int(vertex): int(net) for vertex, net in raw.items()}
        for vertex, net_id in self.vertex_to_network.items():
            net_name = YEO7_NAMES.get(net_id, "None")
            self.network_vertices.setdefault(net_name, []).append(vertex)

    def _generate_fallback(self):
        """
        Fallback: generate a deterministic pseudo-Yeo parcellation based on
        standard FreeSurfer fsaverage5 vertex ordering. This is NOT the real
        atlas, but is sufficient to make the pipeline run if the user hasn't
        downloaded the true labels. We emit a loud warning.
        """
        warnings.warn(
            "Yeo atlas labels not found — using deterministic fallback. "
            "For real neuroscience, download fsaverage5 Yeo2011_7Networks. "
            "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
        )
        total = 20484  # fsaverage5 total
        # Deterministic round-robin: vertex v gets network (v mod 7) + 1.
        for vertex in range(total):
            assigned = vertex % 7 + 1
            self.vertex_to_network[vertex] = assigned
            self.network_vertices[YEO7_NAMES[assigned]].append(vertex)

    def mask_for_network(self, name: str) -> np.ndarray:
        """Return the vertex ids belonging to *name* as an int64 array
        (empty array for an unknown network name)."""
        members = self.network_vertices.get(name, [])
        return np.asarray(members, dtype=np.int64)

    @property
    def n_vertices(self) -> int:
        # Number of vertices that received a network assignment.
        return len(self.vertex_to_network)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
    """Load or generate Yeo atlas. Returns YeoAtlas instance.

    Looks for a cached vertex→network JSON under ``cache_dir``; if absent,
    builds the deterministic fallback parcellation and persists it for reuse.

    NOTE(review): once the fallback is cached, later runs load it like real
    labels and the fallback warning no longer fires — delete
    ``yeo7_fsaverage5.json`` after installing the true Yeo labels.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    label_json = cache / "yeo7_fsaverage5.json"

    if label_json.exists():
        return YeoAtlas(str(label_json))

    # If user hasn't provided labels, create fallback and save it
    atlas = YeoAtlas()
    payload = {str(k): int(v) for k, v in atlas.vertex_to_network.items()}
    # BUGFIX: write atomically. A crash mid-dump used to leave a truncated
    # JSON at label_json, which made every subsequent call fail on json.load.
    tmp_path = label_json.with_suffix(".json.tmp")
    with open(tmp_path, "w") as f:
        json.dump(payload, f)
    os.replace(tmp_path, label_json)
    return atlas
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TribeV2Wrapper:
    """
    Wraps Meta's TRIBE v2 model for video → brain inference.
    Falls back to a synthetic brain generator if TRIBE is not installed.
    """

    def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
        self.model_id = model_id
        # NOTE(review): `device` is stored but never passed to the model —
        # confirm whether TribeModel.from_pretrained accepts a device argument.
        self.device = device
        self._model = None
        self._has_tribe = False
        self._init_model()

    def _init_model(self):
        """Attempt to load TRIBE v2; degrade to the synthetic fallback on any failure."""
        if not HAS_TORCH:
            warnings.warn("PyTorch not installed — using synthetic brain fallback.")
            return
        try:
            from tribev2 import TribeModel
            self._model = TribeModel.from_pretrained(self.model_id)
            self._has_tribe = True
            print(f"TRIBE v2 loaded from {self.model_id}")
        except Exception as e:
            warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")

    def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
        """
        Run video through TRIBE v2 (or synthetic fallback).
        Returns: (n_timesteps, n_vertices) array of predicted cortical response.
        TRIBE native TR is ~1.49 s (Courtois NeuroMod). We resample to `tr` if needed.
        """
        if self._has_tribe and self._model is not None:
            return self._predict_tribe(video_path, tr)
        return self._predict_synthetic(video_path, tr)

    def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
        """Real TRIBE inference path. `tr` is accepted for interface symmetry;
        resampling from TRIBE's native ~1.49 s TR is not implemented yet."""
        # BUGFIX: removed a dead `from tribev2 import TribeModel` re-import —
        # only the already-loaded self._model instance is used here.
        df = self._model.get_events_dataframe(video_path=video_path)
        preds, _segments = self._model.predict(events=df)  # segments unused
        # preds shape: (n_timesteps, n_vertices)
        return np.asarray(preds)

    def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
        """
        Deterministic synthetic brain response for development/testing.
        Uses video duration + hash of filename to seed a realistic-looking
        BOLD-like time series per vertex.
        """
        import cv2
        # Probe the clip for its duration. BUGFIX: release the capture handle
        # in a finally block so a cv2 failure can't leak it.
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        duration = frame_count / fps if fps > 0 else 10.0

        n_timesteps = max(1, int(np.ceil(duration / tr)))
        n_vertices = 20484  # fsaverage5

        # Seed from video name for reproducibility
        seed = int.from_bytes(Path(video_path).name.encode(), "little") % (2**31)
        rng = np.random.default_rng(seed)

        # Base signal: slow drift + fast fluctuation per network
        t = np.arange(n_timesteps) * tr
        signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)

        for net_id in YEO7_NAMES:
            if net_id == 0:
                continue
            # Round-robin vertex layout, mirroring YeoAtlas._generate_fallback.
            verts = np.where(np.arange(n_vertices) % 7 == (net_id - 1))[0]
            # Network-specific temporal dynamics. BUGFIX(comment): this spans
            # 0.15–0.45 Hz — faster than real BOLD (~0.01–0.1 Hz), acceptable
            # for a synthetic placeholder.
            freq = 0.1 + 0.05 * net_id
            amp = 0.5 + 0.2 * np.sin(net_id)
            base = amp * np.sin(2 * np.pi * freq * t + net_id)
            noise = 0.1 * rng.standard_normal(n_timesteps)
            # One shared trace broadcast to every vertex in the network.
            signal[:, verts] = (base + noise).reshape(-1, 1)

        return signal

    def network_means(self, brain: np.ndarray, atlas: YeoAtlas) -> Dict[str, np.ndarray]:
        """
        Collapse vertex-wise brain response into per-Yeo-network time series.
        brain: (n_timesteps, n_vertices)
        Returns: {network_name: (n_timesteps,) array}
        """
        out = {}
        for net_id, net_name in YEO7_NAMES.items():
            if net_id == 0:
                continue
            verts = atlas.mask_for_network(net_name)
            if len(verts):
                out[net_name] = brain[:, verts].mean(axis=1)
            else:
                # Unknown/empty network → flat zero series of matching length.
                out[net_name] = np.zeros(brain.shape[0])
        return out
|