"""
TRIBE v2 video → brain response inference with Yeo network masking.
Maps cortical vertices (fsaverage5) into 7 Yeo functional networks.
"""
import json, os, warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import torch
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
# ── Yeo 2011 7-network labels for fsaverage5 (vertex ID → network ID) ─────
# These are the canonical surface-based labels from FreeSurfer's fsaverage5.
# We ship a lightweight JSON with vertex→network mapping so nilearn is optional.
YEO7_NAMES = {
1: "Visual",
2: "Somatomotor",
3: "Dorsal_Attention",
4: "Ventral_Attention",
5: "Limbic",
6: "Frontoparietal",
7: "Default",
0: "None",
}
# Yeo7 → our 8 UX signals (multi-mapping allowed)
YEO7_TO_UX = {
"Visual": ["aesthetic_appeal", "visual_fluency"],
"Somatomotor": ["motor_readiness"],
"Dorsal_Attention": ["visual_fluency", "cognitive_load"],
"Ventral_Attention": ["surprise_novelty"],
"Limbic": ["reward_anticipation", "trust_affinity"],
"Frontoparietal": ["cognitive_load", "friction_anxiety"],
"Default": ["trust_affinity", "aesthetic_appeal"],
}
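# Illustrative sketch (not used elsewhere in this module): collapsing per-network
# scores into the 8 UX signals via YEO7_TO_UX. The plain averaging rule is an
# assumption for demonstration, not part of the TRIBE pipeline.
def networks_to_ux_scores(network_scores: Dict[str, float]) -> Dict[str, float]:
    """Average each UX signal over the Yeo networks mapped onto it."""
    per_signal: Dict[str, List[float]] = {}
    for net, signals in YEO7_TO_UX.items():
        for sig in signals:
            per_signal.setdefault(sig, []).append(network_scores.get(net, 0.0))
    return {sig: float(np.mean(vals)) for sig, vals in per_signal.items()}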
class YeoAtlas:
"""Lightweight Yeo-2011 7-network atlas for fsaverage5 surface."""
def __init__(self, label_path: Optional[str] = None):
self.vertex_to_network: Dict[int, int] = {}
self.network_vertices: Dict[str, List[int]] = {n: [] for n in YEO7_NAMES.values()}
if label_path and Path(label_path).exists():
self._load_label_file(label_path)
else:
self._generate_fallback()
def _load_label_file(self, path: str):
# Our JSON already stores the vertex→network mapping, so parsing FreeSurfer .label/.annot files is not needed here.
with open(path) as f:
data = json.load(f)
self.vertex_to_network = {int(k): int(v) for k, v in data.items()}
for v, n in self.vertex_to_network.items():
name = YEO7_NAMES.get(n, "None")
self.network_vertices.setdefault(name, []).append(v)
def _generate_fallback(self):
"""
Fallback: generate a deterministic pseudo-Yeo parcellation based on
standard FreeSurfer fsaverage5 vertex ordering. This is NOT the real
atlas, but is sufficient to make the pipeline run if the user hasn't
downloaded the true labels. We emit a loud warning.
"""
warnings.warn(
"Yeo atlas labels not found — using deterministic fallback. "
"For real neuroscience, download fsaverage5 Yeo2011_7Networks. "
"labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
)
n_vertices = 20484 # fsaverage5 total
# deterministic round-robin assignment per canonical ordering
for v in range(n_vertices):
net_id = (v % 7) + 1
self.vertex_to_network[v] = net_id
name = YEO7_NAMES[net_id]
self.network_vertices[name].append(v)
def mask_for_network(self, name: str) -> np.ndarray:
verts = np.array(self.network_vertices.get(name, []), dtype=np.int64)
return verts
@property
def n_vertices(self) -> int:
return len(self.vertex_to_network)
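# Hedged sketch: building the yeo7_fsaverage5.json that YeoAtlas expects from the
# real FreeSurfer .annot files instead of the fallback. Requires nibabel; the
# annot file names and the left-then-right vertex ordering are assumptions based
# on the standard Yeo2011 fsaverage5 release (e.g. lh.Yeo2011_7Networks_N1000.annot).
def annot_to_yeo_json(lh_annot: str, rh_annot: str, out_json: str) -> None:
    import nibabel.freesurfer as fs
    mapping: Dict[str, int] = {}
    offset = 0
    for annot_path in (lh_annot, rh_annot):
        labels, _ctab, _names = fs.read_annot(annot_path)  # per-vertex label ids (0-7)
        for v, lab in enumerate(labels):
            mapping[str(offset + v)] = max(int(lab), 0)  # -1 / 0 = unassigned
        offset += len(labels)
    with open(out_json, "w") as f:
        json.dump(mapping, f)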
def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
"""Load or generate Yeo atlas. Returns YeoAtlas instance."""
cache = Path(cache_dir)
cache.mkdir(parents=True, exist_ok=True)
label_json = cache / "yeo7_fsaverage5.json"
if label_json.exists():
return YeoAtlas(str(label_json))
# If user hasn't provided labels, create fallback and save it
atlas = YeoAtlas()
with open(label_json, "w") as f:
json.dump({str(k): int(v) for k, v in atlas.vertex_to_network.items()}, f)
return atlas
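# Usage sketch for the atlas helpers (paths and network name are illustrative):
#   atlas = load_yeo_atlas("./cache/yeo")
#   dmn_vertices = atlas.mask_for_network("Default")  # np.ndarray of vertex IDs
#   print(atlas.n_vertices, dmn_vertices.shape)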
class TribeV2Wrapper:
"""
Wraps Meta's TRIBE v2 model for video → brain inference.
Falls back to a synthetic brain generator if TRIBE is not installed.
"""
def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
self.model_id = model_id
self.device = device
self._model = None
self._has_tribe = False
self._init_model()
def _init_model(self):
if not HAS_TORCH:
warnings.warn("PyTorch not installed — using synthetic brain fallback.")
return
try:
from tribev2 import TribeModel
self._model = TribeModel.from_pretrained(self.model_id)
self._has_tribe = True
print(f"TRIBE v2 loaded from {self.model_id}")
except Exception as e:
warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")
def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
"""
Run video through TRIBE v2 (or synthetic fallback).
Returns: (n_timesteps, n_vertices) array of predicted cortical response.
TRIBE's native TR is ~1.49 s (Courtois NeuroMod) and is returned as-is; the synthetic fallback honors `tr`. Resampling is left to the caller.
"""
if self._has_tribe and self._model is not None:
return self._predict_tribe(video_path, tr)
return self._predict_synthetic(video_path, tr)
def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
from tribev2 import TribeModel
df = self._model.get_events_dataframe(video_path=video_path)
preds, segments = self._model.predict(events=df)
# preds shape: (n_timesteps, n_vertices)
# TRIBE's native TR is ~1.49 s; `tr` is not applied here (resampling is left to the caller)
return np.asarray(preds)
def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
"""
Deterministic synthetic brain response for development/testing.
Uses video duration + hash of filename to seed a realistic-looking
BOLD-like time series per vertex.
"""
import cv2
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = frame_count / fps if fps > 0 else 10.0
cap.release()
n_timesteps = max(1, int(np.ceil(duration / tr)))
n_vertices = 20484 # fsaverage5
# Seed from video name for reproducibility
seed = int.from_bytes(Path(video_path).name.encode(), "little") % (2**31)
rng = np.random.default_rng(seed)
# Base signal: slow drift + fast fluctuation per network
t = np.arange(n_timesteps) * tr
signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)
for net_id, net_name in YEO7_NAMES.items():
if net_id == 0:
continue
verts = np.where(np.arange(n_vertices) % 7 == (net_id - 1))[0]
# Network-specific temporal dynamics
            freq = 0.01 + 0.01 * net_id  # ~0.02–0.08 Hz, typical BOLD fluctuation range
amp = 0.5 + 0.2 * np.sin(net_id)
base = amp * np.sin(2 * np.pi * freq * t + net_id)
noise = 0.1 * rng.standard_normal(n_timesteps)
signal[:, verts] = (base + noise).reshape(-1, 1)
return signal
def network_means(self, brain: np.ndarray, atlas: YeoAtlas) -> Dict[str, np.ndarray]:
"""
Collapse vertex-wise brain response into per-Yeo-network time series.
brain: (n_timesteps, n_vertices)
Returns: {network_name: (n_timesteps,) array}
"""
out = {}
for net_id, net_name in YEO7_NAMES.items():
if net_id == 0:
continue
verts = atlas.mask_for_network(net_name)
if len(verts):
out[net_name] = brain[:, verts].mean(axis=1)
else:
out[net_name] = np.zeros(brain.shape[0])
return out
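if __name__ == "__main__":
    # Minimal end-to-end sketch. "example.mp4" is a placeholder path; if TRIBE or
    # the real Yeo labels are absent, the synthetic/fallback paths above are used.
    import sys
    video = sys.argv[1] if len(sys.argv) > 1 else "example.mp4"
    atlas = load_yeo_atlas()
    wrapper = TribeV2Wrapper()
    brain = wrapper.predict_brain(video, tr=1.5)       # (n_timesteps, n_vertices)
    per_network = wrapper.network_means(brain, atlas)  # {network: (n_timesteps,)}
    for name, ts in per_network.items():
        print(f"{name:>20s}: {ts.shape[0]} timesteps, mean={ts.mean():.3f}")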