moonlantern1's picture
Upload brain_virality_predictor/tribe_wrapper.py with huggingface_hub
c12563b verified
"""
TRIBE v2 video → brain response inference with Yeo network masking.
Maps cortical vertices (fsaverage5) into 7 Yeo functional networks.
"""
import json, os, warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import torch
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
# ── Yeo 2011 7-network labels for the fsaverage5 surface ─────────────
# Maps integer network ID → network name. IDs 1–7 are the seven Yeo
# networks; 0 ("None") is the bucket used for labels that match no network.
# The vertex→network assignment itself is not stored here: YeoAtlas reads it
# from a lightweight JSON file, so nilearn is optional.
YEO7_NAMES = {
1: "Visual",
2: "Somatomotor",
3: "Dorsal_Attention",
4: "Ventral_Attention",
5: "Limbic",
6: "Frontoparietal",
7: "Default",
0: "None",
}
# Yeo7 network name → our 8 UX signals. Multi-mapping is allowed: one network
# may feed several signals, and one signal may draw on several networks.
# The "None" bucket of YEO7_NAMES has no entry here.
YEO7_TO_UX = {
"Visual": ["aesthetic_appeal", "visual_fluency"],
"Somatomotor": ["motor_readiness"],
"Dorsal_Attention": ["visual_fluency", "cognitive_load"],
"Ventral_Attention": ["surprise_novelty"],
"Limbic": ["reward_anticipation", "trust_affinity"],
"Frontoparietal": ["cognitive_load", "friction_anxiety"],
"Default": ["trust_affinity", "aesthetic_appeal"],
}
class YeoAtlas:
    """Yeo-2011 7-network parcellation lookup for the fsaverage5 surface.

    Holds two views of the same assignment:
      * ``vertex_to_network``: vertex index → integer network ID
      * ``network_vertices``: network name → list of vertex indices
    """

    def __init__(self, label_path: Optional[str] = None):
        """Load labels from ``label_path`` if it exists, else build a placeholder."""
        self.vertex_to_network: Dict[int, int] = {}
        # Pre-create one (empty) bucket per known network name.
        self.network_vertices: Dict[str, List[int]] = {
            label: [] for label in YEO7_NAMES.values()
        }

        have_labels = bool(label_path) and Path(label_path).exists()
        if not have_labels:
            self._generate_fallback()
            return
        self._load_label_file(label_path)

    def _load_label_file(self, path: str):
        """Read a JSON object of ``{vertex: network_id}`` and index it both ways."""
        # The labels come from our own JSON file, not a FreeSurfer .label/.annot.
        with open(path) as handle:
            raw = json.load(handle)

        mapping: Dict[int, int] = {}
        for key, value in raw.items():
            mapping[int(key)] = int(value)
        self.vertex_to_network = mapping

        for vertex, net_id in mapping.items():
            # Unknown network IDs are filed under the "None" bucket.
            bucket = self.network_vertices.setdefault(YEO7_NAMES.get(net_id, "None"), [])
            bucket.append(vertex)

    def _generate_fallback(self):
        """
        Build a placeholder parcellation when no label file is available.

        Vertex ``v`` is assigned network ``(v % 7) + 1``. This is NOT the real
        Yeo atlas; it only keeps the pipeline runnable, and a warning is
        emitted to make that obvious.
        """
        warnings.warn(
            "Yeo atlas labels not found — using deterministic fallback. "
            "For real neuroscience, download fsaverage5 Yeo2011_7Networks. "
            "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
        )
        total = 20484  # fsaverage5 total

        # vertex → network ID, in ascending vertex order
        self.vertex_to_network.update({vertex: vertex % 7 + 1 for vertex in range(total)})

        # network name → every 7th vertex, starting at (network ID - 1)
        for net_id in range(1, 8):
            self.network_vertices[YEO7_NAMES[net_id]].extend(range(net_id - 1, total, 7))

    def mask_for_network(self, name: str) -> np.ndarray:
        """Return the vertex indices of network ``name`` as an int64 array.

        An unknown name yields an empty array.
        """
        indices = self.network_vertices.get(name, [])
        return np.array(indices, dtype=np.int64)

    @property
    def n_vertices(self) -> int:
        """Number of vertices that currently have a network assignment."""
        return len(self.vertex_to_network)
def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
    """Load the Yeo atlas from ``cache_dir``, or build the placeholder atlas.

    Looks for ``yeo7_fsaverage5.json`` inside ``cache_dir`` (the directory is
    created if missing, so there is an obvious place to drop the file). When
    the file exists it is loaded; otherwise a fallback ``YeoAtlas`` is
    returned, which emits a warning.

    The fallback is deliberately NOT written to ``yeo7_fsaverage5.json``:
    saving it under the name reserved for real labels would make every later
    call load the placeholder as if it were the genuine atlas, without any
    warning.

    Returns:
        A ``YeoAtlas`` instance.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    label_json = cache / "yeo7_fsaverage5.json"
    if label_json.exists():
        return YeoAtlas(str(label_json))
    # No user-provided labels: regenerate the (cheap) placeholder on each call
    # so that the warning keeps firing until real labels are installed.
    return YeoAtlas()
class TribeV2Wrapper:
    """
    Wraps Meta's TRIBE v2 model for video → brain inference.
    Falls back to a synthetic brain generator if TRIBE is not installed.

    The model is loaded eagerly in the constructor. If PyTorch is missing or
    loading TRIBE fails for any reason, a warning is emitted and every
    prediction comes from the deterministic synthetic generator instead.
    """

    def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
        """
        model_id: identifier handed to ``TribeModel.from_pretrained``.
        device: stored on the instance; nothing in this class forwards it to
            the model.
        """
        self.model_id = model_id
        self.device = device
        self._model = None
        self._has_tribe = False
        self._init_model()

    def _init_model(self):
        """Try to load TRIBE v2; on any failure leave the synthetic fallback active."""
        if not HAS_TORCH:
            warnings.warn("PyTorch not installed — using synthetic brain fallback.")
            return
        try:
            from tribev2 import TribeModel
            self._model = TribeModel.from_pretrained(self.model_id)
            self._has_tribe = True
            print(f"TRIBE v2 loaded from {self.model_id}")
        except Exception as e:
            # Deliberately broad: a missing package, a failed download, or a
            # bad checkpoint should all degrade to the synthetic generator.
            warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")

    def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
        """
        Run video through TRIBE v2 (or synthetic fallback).
        Returns: (n_timesteps, n_vertices) array of predicted cortical response.

        ``tr`` (seconds per sample) controls the sampling interval of the
        synthetic fallback only. TRIBE predictions are returned as produced by
        the model (native TR is ~1.49 s, Courtois NeuroMod) and are NOT
        resampled to ``tr`` here.
        """
        if self._has_tribe and self._model is not None:
            return self._predict_tribe(video_path, tr)
        return self._predict_synthetic(video_path, tr)

    def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
        """Run the loaded TRIBE model on ``video_path``.

        ``tr`` is accepted for signature parity with the synthetic path but is
        not used: no resampling is performed.
        """
        df = self._model.get_events_dataframe(video_path=video_path)
        preds, _segments = self._model.predict(events=df)
        # preds shape: (n_timesteps, n_vertices)
        # TRIBE TR is ~1.49s; if user wants different TR, we can resample later
        return np.asarray(preds)

    def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
        """
        Deterministic synthetic brain response for development/testing.

        The number of timesteps is derived from the video duration and ``tr``;
        the noise is seeded from a checksum of the file name, so the same file
        name always yields the same output. Each network (vertex index modulo
        7) gets one sinusoid-plus-noise time series shared by all its vertices.

        Returns: float32 array of shape (n_timesteps, 20484).
        Raises: ValueError if ``tr`` is not a positive number.
        """
        import zlib

        import cv2

        if not tr > 0:
            raise ValueError(f"tr must be a positive number of seconds, got {tr!r}")

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            raw_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            # Always free the capture handle, even if a property query fails.
            cap.release()

        # Use the 10 s default whenever the container gives no usable
        # duration (unreadable file, zero/negative/NaN frame count or fps),
        # rather than silently collapsing to a single timestep.
        if fps > 0 and np.isfinite(raw_frames) and raw_frames > 0:
            duration = int(raw_frames) / fps
        else:
            duration = 10.0

        n_timesteps = max(1, int(np.ceil(duration / tr)))
        n_vertices = 20484  # fsaverage5

        # Seed from the whole file name for reproducibility. A CRC is used
        # because reducing the raw little-endian bytes modulo 2**31 keeps only
        # the first ~4 characters, so "clip_a.mp4" and "clip_b.mp4" collided.
        seed = zlib.crc32(Path(video_path).name.encode("utf-8"))
        rng = np.random.default_rng(seed)

        # Base signal: slow drift + fast fluctuation per network
        t = np.arange(n_timesteps) * tr
        signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)
        vertex_slot = np.arange(n_vertices) % 7  # loop-invariant network slot per vertex
        for net_id in YEO7_NAMES:
            if net_id == 0:
                continue
            verts = np.where(vertex_slot == (net_id - 1))[0]
            # Network-specific temporal dynamics.
            # NOTE: this gives 0.15–0.45 Hz across networks 1–7, which is
            # above the typical ~0.01–0.1 Hz BOLD band; at the default
            # tr=1.5 s (Nyquist ≈ 0.33 Hz) networks 5–7 are aliased.
            freq = 0.1 + 0.05 * net_id
            amp = 0.5 + 0.2 * np.sin(net_id)
            base = amp * np.sin(2 * np.pi * freq * t + net_id)
            noise = 0.1 * rng.standard_normal(n_timesteps)
            # One column of values broadcast to every vertex of the network.
            signal[:, verts] = (base + noise).reshape(-1, 1)
        return signal

    def network_means(self, brain: np.ndarray, atlas: YeoAtlas) -> Dict[str, np.ndarray]:
        """
        Collapse vertex-wise brain response into per-Yeo-network time series.
        brain: (n_timesteps, n_vertices)
        Returns: {network_name: (n_timesteps,) array}

        The "None" network (ID 0) is skipped. A network with no vertices in
        ``atlas`` yields an all-zero series.
        """
        out = {}
        for net_id, net_name in YEO7_NAMES.items():
            if net_id == 0:
                continue
            verts = atlas.mask_for_network(net_name)
            if len(verts):
                out[net_name] = brain[:, verts].mean(axis=1)
            else:
                out[net_name] = np.zeros(brain.shape[0])
        return out