| """ |
| TRIBE v2 video → brain response inference with Yeo network masking. |
| Maps cortical vertices (fsaverage5) into 7 Yeo functional networks. |
| """ |
|
|
| import json, os, warnings |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
|
|
| try: |
| import torch |
| HAS_TORCH = True |
| except ImportError: |
| HAS_TORCH = False |
|
|
| |
| |
| |
# Integer atlas label -> Yeo-2011 network name.
# The tuple lists the seven networks in label order, so enumerate(start=1)
# assigns labels 1..7. Label 0 is added afterwards so it stays the LAST key:
# code that iterates this dict sees the seven real networks first and the
# "None" placeholder at the end.
YEO7_NAMES = dict(
    enumerate(
        (
            "Visual",
            "Somatomotor",
            "Dorsal_Attention",
            "Ventral_Attention",
            "Limbic",
            "Frontoparietal",
            "Default",
        ),
        start=1,
    )
)
YEO7_NAMES[0] = "None"


# Network name -> list of UX metric identifiers associated with that network.
# A metric may appear under more than one network. The "None" placeholder has
# no entry here.
# NOTE(review): nothing in this file reads this table — confirm its consumers
# before renaming any key or metric.
YEO7_TO_UX = dict(
    Visual=["aesthetic_appeal", "visual_fluency"],
    Somatomotor=["motor_readiness"],
    Dorsal_Attention=["visual_fluency", "cognitive_load"],
    Ventral_Attention=["surprise_novelty"],
    Limbic=["reward_anticipation", "trust_affinity"],
    Frontoparietal=["cognitive_load", "friction_anxiety"],
    Default=["trust_affinity", "aesthetic_appeal"],
)
|
|
|
|
class YeoAtlas:
    """Vertex <-> network lookup for the Yeo-2011 7-network atlas (fsaverage5).

    The atlas is populated from a JSON label file when one is supplied and
    present on disk; otherwise a deterministic placeholder parcellation is
    generated and a warning is emitted.
    """

    def __init__(self, label_path: Optional[str] = None):
        """Build the atlas from ``label_path`` or fall back to the placeholder.

        Args:
            label_path: Path to a JSON file mapping vertex index to network
                id. A falsy value or a path that does not exist triggers the
                fallback parcellation.
        """
        # vertex index -> network id
        self.vertex_to_network: Dict[int, int] = {}
        # network name -> vertex indices; every known name starts out empty
        self.network_vertices: Dict[str, List[int]] = {}
        for network_name in YEO7_NAMES.values():
            self.network_vertices[network_name] = []

        labels_on_disk = bool(label_path) and Path(label_path).exists()
        if not labels_on_disk:
            self._generate_fallback()
            return
        self._load_label_file(label_path)

    def _load_label_file(self, path: str):
        """Read a ``{vertex: network_id}`` JSON object and index it by network.

        Keys and values are coerced with ``int``. Network ids that are not in
        ``YEO7_NAMES`` are filed under the ``"None"`` name.
        """
        with open(path) as handle:
            raw_labels = json.load(handle)

        parsed: Dict[int, int] = {}
        for vertex_key, network_id in raw_labels.items():
            parsed[int(vertex_key)] = int(network_id)
        self.vertex_to_network = parsed

        # Second pass over the de-duplicated mapping, so a vertex is listed
        # once even if the file spelled its key in more than one way.
        for vertex, network_id in parsed.items():
            network_name = YEO7_NAMES.get(network_id, "None")
            bucket = self.network_vertices.setdefault(network_name, [])
            bucket.append(vertex)

    def _generate_fallback(self):
        """Fill the atlas with a placeholder parcellation and warn about it.

        Vertex ``v`` (0..20483) is assigned network id ``(v % 7) + 1``. This
        is NOT the real Yeo atlas; it only keeps the pipeline runnable when
        the true labels have not been downloaded.
        """
        warnings.warn(
            "Yeo atlas labels not found — using deterministic fallback. "
            "For real neuroscience, download fsaverage5 Yeo2011_7Networks. "
            "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
        )
        total_vertices = 20484

        self.vertex_to_network.update(
            (vertex, vertex % 7 + 1) for vertex in range(total_vertices)
        )
        # Network k owns vertices k-1, k-1+7, k-1+14, ... in ascending order.
        for network_id in range(1, 8):
            members = range(network_id - 1, total_vertices, 7)
            self.network_vertices[YEO7_NAMES[network_id]].extend(members)

    def mask_for_network(self, name: str) -> np.ndarray:
        """Return the vertex indices of network ``name`` as an int64 array.

        Despite the name, the result is an index array rather than a boolean
        mask. An unknown ``name`` yields an empty array.
        """
        vertex_ids = self.network_vertices.get(name, [])
        return np.array(vertex_ids, dtype=np.int64)

    @property
    def n_vertices(self) -> int:
        """Number of vertices that currently have a network assignment."""
        return len(self.vertex_to_network)
|
|
|
|
def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> "YeoAtlas":
    """Load the Yeo atlas from ``cache_dir``, or build the placeholder atlas.

    Looks for ``yeo7_fsaverage5.json`` inside ``cache_dir``. If the file is
    present it is loaded; otherwise a ``YeoAtlas`` is constructed without a
    label file, which generates the deterministic fallback parcellation and
    emits a warning.

    The fallback is deliberately NOT written to ``yeo7_fsaverage5.json``.
    Persisting it under the real atlas' filename would make every later run
    load the placeholder as if it were genuine labels, without any warning.

    Args:
        cache_dir: Directory expected to contain ``yeo7_fsaverage5.json``.
            It is created if missing so there is a place to put the labels.

    Returns:
        A populated ``YeoAtlas`` instance.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    label_json = cache / "yeo7_fsaverage5.json"

    if label_json.exists():
        return YeoAtlas(str(label_json))

    # No label file: return the in-memory fallback (it warns on creation).
    return YeoAtlas()
|
|
|
|
class TribeV2Wrapper:
    """
    Wraps Meta's TRIBE v2 model for video → brain inference.
    Falls back to a synthetic brain generator if TRIBE is not installed.

    The synthetic generator only needs OpenCV to read the video duration; if
    OpenCV is missing or the video cannot be read, a default duration is used
    so the fallback still produces output.
    """

    # Vertex count of the synthetic response (the fsaverage5 size used in this file).
    _SYNTH_N_VERTICES = 20484
    # Frame rate assumed when the container reports none.
    _DEFAULT_FPS = 30.0
    # Duration (seconds) assumed when the video length cannot be determined.
    _DEFAULT_DURATION_S = 10.0

    def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
        """Store the configuration and try to load the TRIBE model.

        Args:
            model_id: Identifier passed to ``TribeModel.from_pretrained``.
            device: Stored on the instance; the code in this class does not
                use it further.
        """
        self.model_id = model_id
        self.device = device
        self._model = None
        self._has_tribe = False
        self._init_model()

    def _init_model(self):
        """Load TRIBE v2 if possible; otherwise leave the synthetic fallback active.

        Any failure while importing or loading the model is reported as a
        warning rather than raised, so construction always succeeds.
        """
        if not HAS_TORCH:
            warnings.warn("PyTorch not installed — using synthetic brain fallback.")
            return
        try:
            from tribev2 import TribeModel
            self._model = TribeModel.from_pretrained(self.model_id)
            self._has_tribe = True
            print(f"TRIBE v2 loaded from {self.model_id}")
        except Exception as e:
            # Deliberately broad: a missing package, missing weights or a
            # load error should all degrade to the synthetic generator.
            warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")

    def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
        """
        Run video through TRIBE v2 (or synthetic fallback).
        Returns: (n_timesteps, n_vertices) array of predicted cortical response.

        ``tr`` (seconds per timestep) is honoured by the synthetic fallback
        only. The TRIBE path returns the model output unchanged: no
        resampling to ``tr`` is performed, so its time axis follows the
        model's native sampling (described by the original author as
        ~1.49 s, Courtois NeuroMod).

        Raises:
            ValueError: In the synthetic path, if ``tr`` is not positive.
        """
        if self._has_tribe and self._model is not None:
            return self._predict_tribe(video_path, tr)
        return self._predict_synthetic(video_path, tr)

    def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
        """Run the loaded TRIBE model on ``video_path``.

        ``tr`` is accepted for signature parity with the synthetic path but
        is not used here. The second value returned by ``predict`` is
        discarded.
        """
        events = self._model.get_events_dataframe(video_path=video_path)
        preds, _segments = self._model.predict(events=events)
        return np.asarray(preds)

    @staticmethod
    def _seed_from_path(video_path: str) -> int:
        """Derive a deterministic RNG seed from the file name of ``video_path``.

        Only the final path component is used, so the same file name gives
        the same seed regardless of directory.
        """
        # The bytes are read little-endian, i.e. the FIRST character is the
        # least-significant byte. Reducing this modulo 2**31 would keep only
        # about the first four bytes and make names such as "video1.mp4" and
        # "video2.mp4" collide, so the full integer is used instead
        # (numpy accepts arbitrarily large non-negative integer seeds).
        return int.from_bytes(Path(video_path).name.encode(), "little")

    def _probe_duration(self, video_path: str) -> float:
        """Return the video duration in seconds, or a default if unknown.

        Falls back to ``_DEFAULT_DURATION_S`` (with a warning) when OpenCV is
        not installed or when the video reports no usable frame count.
        """
        try:
            import cv2
        except ImportError:
            warnings.warn(
                "OpenCV (cv2) not installed — assuming a "
                f"{self._DEFAULT_DURATION_S:g} s video for the synthetic brain."
            )
            return self._DEFAULT_DURATION_S

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or self._DEFAULT_FPS
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the capture even if a property read raises.
            cap.release()

        # `not fps > 0` also rejects NaN, which a plain `fps <= 0` would not.
        if frame_count <= 0 or not fps > 0:
            warnings.warn(
                f"Could not determine the duration of {video_path!r} — assuming "
                f"{self._DEFAULT_DURATION_S:g} s for the synthetic brain."
            )
            return self._DEFAULT_DURATION_S
        return frame_count / fps

    def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
        """
        Deterministic synthetic brain response for development/testing.

        The number of timesteps comes from the video duration divided by
        ``tr`` (at least one). The noise is seeded from the file name, so
        repeated calls for the same name and duration give the same array.
        Each of the seven networks gets one sinusoid-plus-noise time series,
        shared by all vertices of that network.

        Returns: float32 array of shape (n_timesteps, 20484).

        Raises:
            ValueError: If ``tr`` is not a positive number.
        """
        if not tr > 0:
            raise ValueError(f"tr must be a positive number of seconds, got {tr!r}")

        duration = self._probe_duration(video_path)
        n_timesteps = max(1, int(np.ceil(duration / tr)))
        n_vertices = self._SYNTH_N_VERTICES

        rng = np.random.default_rng(self._seed_from_path(video_path))

        t = np.arange(n_timesteps) * tr
        signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)

        for net_id in YEO7_NAMES:
            if net_id == 0:
                # Label 0 is the "None" placeholder, not a network.
                continue
            # Each network gets its own frequency, amplitude and phase.
            freq = 0.1 + 0.05 * net_id
            amp = 0.5 + 0.2 * np.sin(net_id)
            base = amp * np.sin(2 * np.pi * freq * t + net_id)
            noise = 0.1 * rng.standard_normal(n_timesteps)
            # Network k owns every 7th vertex starting at k-1, i.e. the
            # vertices with `v % 7 == k - 1`. The column vector broadcasts
            # the same series to all of them.
            signal[:, net_id - 1::7] = (base + noise)[:, np.newaxis]

        return signal

    def network_means(self, brain: np.ndarray, atlas: "YeoAtlas") -> Dict[str, np.ndarray]:
        """
        Collapse vertex-wise brain response into per-Yeo-network time series.
        brain: (n_timesteps, n_vertices)
        Returns: {network_name: (n_timesteps,) array}

        The "None" label (id 0) is skipped. A network with no vertices in
        ``atlas`` gets an all-zero series of length ``brain.shape[0]``.
        """
        out = {}
        for net_id, net_name in YEO7_NAMES.items():
            if net_id == 0:
                continue
            verts = atlas.mask_for_network(net_name)
            if len(verts):
                out[net_name] = brain[:, verts].mean(axis=1)
            else:
                out[net_name] = np.zeros(brain.shape[0])
        return out
|
|