"""
TRIBE v2 video → brain response inference with Yeo network masking.

Maps cortical vertices (fsaverage5) into the 7 Yeo functional networks.
"""

import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    import torch

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

# Total vertex count of the fsaverage5 surface used throughout this module.
FSAVERAGE5_N_VERTICES = 20484

# ── Yeo 2011 7-network labels (network id → name) ─────────────────────
# Real labels are read from a lightweight JSON (vertex id → network id) so
# nilearn / FreeSurfer tooling is not needed at runtime. Id 0 is "None".
YEO7_NAMES = {
    1: "Visual",
    2: "Somatomotor",
    3: "Dorsal_Attention",
    4: "Ventral_Attention",
    5: "Limbic",
    6: "Frontoparietal",
    7: "Default",
    0: "None",
}

# Yeo7 → our 8 UX signals (multi-mapping allowed)
YEO7_TO_UX = {
    "Visual": ["aesthetic_appeal", "visual_fluency"],
    "Somatomotor": ["motor_readiness"],
    "Dorsal_Attention": ["visual_fluency", "cognitive_load"],
    "Ventral_Attention": ["surprise_novelty"],
    "Limbic": ["reward_anticipation", "trust_affinity"],
    "Frontoparietal": ["cognitive_load", "friction_anxiety"],
    "Default": ["trust_affinity", "aesthetic_appeal"],
}


class YeoAtlas:
    """Lightweight Yeo-2011 7-network atlas for the fsaverage5 surface.

    Attributes:
        vertex_to_network: vertex id → Yeo network id (a key of YEO7_NAMES).
        network_vertices: network name → list of vertex ids in that network.
    """

    def __init__(self, label_path: Optional[str] = None):
        """Load labels from ``label_path`` (JSON), else use the pseudo-atlas.

        If ``label_path`` is None/empty or the file does not exist, the
        deterministic round-robin fallback is generated and a warning emitted.
        """
        self.vertex_to_network: Dict[int, int] = {}
        self.network_vertices: Dict[str, List[int]] = {
            n: [] for n in YEO7_NAMES.values()
        }
        if label_path and Path(label_path).exists():
            self._load_label_file(label_path)
        else:
            self._generate_fallback()

    def _load_label_file(self, path: str):
        """Read a JSON object mapping vertex id (str) → network id (int).

        Only this JSON format is supported; FreeSurfer ``.label``/``.annot``
        files must be converted first. Network ids missing from YEO7_NAMES
        are filed under "None".
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        self.vertex_to_network = {int(k): int(v) for k, v in data.items()}
        for v, n in self.vertex_to_network.items():
            name = YEO7_NAMES.get(n, "None")
            self.network_vertices.setdefault(name, []).append(v)
        if self._is_round_robin_fallback():
            # Earlier versions of load_yeo_atlas() cached the pseudo-atlas under
            # the real label filename; make sure it is not mistaken for data.
            warnings.warn(
                f"Yeo label file {path} matches the synthetic round-robin "
                "fallback, not the real Yeo2011 atlas. Replace it with real "
                "fsaverage5 Yeo2011_7Networks labels."
            )

    def _is_round_robin_fallback(self) -> bool:
        """Return True if the mapping is exactly the fallback's (v % 7) + 1 layout."""
        return len(self.vertex_to_network) == FSAVERAGE5_N_VERTICES and all(
            0 <= v < FSAVERAGE5_N_VERTICES and n == (v % 7) + 1
            for v, n in self.vertex_to_network.items()
        )

    def _generate_fallback(self):
        """Fill the atlas with a deterministic pseudo-Yeo parcellation.

        Vertex ``v`` is assigned network ``(v % 7) + 1`` (round-robin). This is
        NOT the real atlas and has no anatomical meaning; it only lets the
        pipeline run when the real labels have not been downloaded, so a
        warning is always emitted.
        """
        warnings.warn(
            "Yeo atlas labels not found — using deterministic fallback. "
            "For real neuroscience, download fsaverage5 Yeo2011_7Networks "
            "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
        )
        for v in range(FSAVERAGE5_N_VERTICES):
            net_id = (v % 7) + 1
            self.vertex_to_network[v] = net_id
            self.network_vertices[YEO7_NAMES[net_id]].append(v)

    def mask_for_network(self, name: str) -> np.ndarray:
        """Return the vertex ids of network ``name`` as an int64 array (empty if unknown)."""
        return np.array(self.network_vertices.get(name, []), dtype=np.int64)

    @property
    def n_vertices(self) -> int:
        """Number of labelled vertices."""
        return len(self.vertex_to_network)


def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
    """Load the Yeo-7 atlas from ``cache_dir``, or build the pseudo-atlas fallback.

    If ``<cache_dir>/yeo7_fsaverage5.json`` exists it is loaded as the
    vertex → network mapping. Otherwise the deterministic fallback is returned
    (YeoAtlas warns). The fallback is deliberately NOT written to that path:
    doing so would make every later call load the pseudo-labels as if they
    were real, and the fallback warning would never fire again.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    label_json = cache / "yeo7_fsaverage5.json"
    if label_json.exists():
        return YeoAtlas(str(label_json))
    return YeoAtlas()


class TribeV2Wrapper:
    """
    Wraps Meta's TRIBE v2 model for video → brain inference.

    Falls back to a deterministic synthetic brain generator when PyTorch is
    missing or loading TRIBE v2 fails for any reason.
    """

    def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
        self.model_id = model_id
        # NOTE(review): ``device`` is stored but not used anywhere in this class.
        self.device = device
        self._model = None
        self._has_tribe = False
        self._init_model()

    def _init_model(self):
        """Try to load TRIBE v2; on any failure warn and keep the synthetic fallback."""
        if not HAS_TORCH:
            warnings.warn("PyTorch not installed — using synthetic brain fallback.")
            return
        try:
            from tribev2 import TribeModel

            self._model = TribeModel.from_pretrained(self.model_id)
            self._has_tribe = True
            print(f"TRIBE v2 loaded from {self.model_id}")
        except Exception as e:  # deliberate best-effort: any load failure → fallback
            warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")

    def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
        """
        Run a video through TRIBE v2, or the synthetic fallback if TRIBE is not loaded.

        Args:
            video_path: path to the input video file.
            tr: repetition time in seconds. Only the synthetic fallback uses it
                (to choose the number of timesteps). TRIBE predictions are
                returned as-is, without resampling (per the original notes
                TRIBE's native TR is ~1.49 s — unverified here).

        Returns:
            Array of predicted cortical response, expected shape
            (n_timesteps, n_vertices).
        """
        if self._has_tribe and self._model is not None:
            return self._predict_tribe(video_path, tr)
        return self._predict_synthetic(video_path, tr)

    def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
        """Predict with the loaded TRIBE model.

        ``tr`` is accepted for signature symmetry but currently ignored.
        Presumably ``preds`` is (n_timesteps, n_vertices) — confirm against
        the tribev2 API.
        """
        df = self._model.get_events_dataframe(video_path=video_path)
        preds, _segments = self._model.predict(events=df)
        return np.asarray(preds)

    def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
        """
        Deterministic synthetic brain response for development/testing.

        The number of timesteps is ``max(1, ceil(duration / tr))``, where the
        duration comes from OpenCV video metadata and defaults to 10 s when
        the frame count is unavailable (e.g. unreadable file). The RNG is
        seeded from the file *name* (not the full path), so identically named
        videos give identical output. Each round-robin network ((v % 7) + 1,
        the same layout as the YeoAtlas fallback) gets one sinusoid plus
        Gaussian noise, shared by all of its vertices.
        """
        import cv2

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        # OpenCV reports 0 (or junk) when metadata is unavailable.
        if not (fps > 0):  # also catches NaN
            fps = 30.0
        duration = frame_count / fps if frame_count > 0 else 10.0

        n_timesteps = max(1, int(np.ceil(duration / tr)))
        n_vertices = FSAVERAGE5_N_VERTICES

        # Seed from video name for reproducibility
        seed = int.from_bytes(Path(video_path).name.encode(), "little") % (2**31)
        rng = np.random.default_rng(seed)

        t = np.arange(n_timesteps) * tr
        signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)
        for net_id in YEO7_NAMES:  # order 1..7 then 0 — keeps RNG draw order stable
            if net_id == 0:
                continue
            verts = np.arange(net_id - 1, n_vertices, 7)
            # freq spans 0.15–0.45 Hz (well above the ~0.01–0.1 Hz typical of
            # BOLD); with tr=1.5 s the Nyquist limit is ~0.33 Hz, so networks
            # 5–7 alias. Kept as-is so existing synthetic outputs are unchanged.
            freq = 0.1 + 0.05 * net_id
            amp = 0.5 + 0.2 * np.sin(net_id)
            base = amp * np.sin(2 * np.pi * freq * t + net_id)
            noise = 0.1 * rng.standard_normal(n_timesteps)
            signal[:, verts] = (base + noise).reshape(-1, 1)
        return signal

    def network_means(self, brain: np.ndarray, atlas: YeoAtlas) -> Dict[str, np.ndarray]:
        """
        Collapse a vertex-wise brain response into per-Yeo-network time series.

        Args:
            brain: (n_timesteps, n_vertices) response; column indices must cover
                every vertex id in ``atlas``.
            atlas: vertex → network labels.

        Returns:
            {network_name: (n_timesteps,) array} for the 7 named networks
            ("None" is skipped). Networks with no vertices get zeros.
        """
        out = {}
        for net_id, net_name in YEO7_NAMES.items():
            if net_id == 0:
                continue
            verts = atlas.mask_for_network(net_name)
            if len(verts):
                out[net_name] = brain[:, verts].mean(axis=1)
            else:
                out[net_name] = np.zeros(brain.shape[0])
        return out