File size: 7,904 Bytes
c12563b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
TRIBE v2 video → brain response inference with Yeo network masking.
Maps cortical vertices (fsaverage5) into 7 Yeo functional networks.
"""

import json, os, warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

# ── Yeo 2011 7-network names, keyed by integer network ID ─────────────
# IDs 1-7 are the seven Yeo-2011 networks; 0 ("None") is the label for
# vertices that belong to no network (network_means skips it).
# Vertex→network assignments are NOT stored here: YeoAtlas loads them from a
# JSON {vertex_id: network_id} file, or builds a synthetic fallback, so
# nilearn is not required.
YEO7_NAMES = {
    1: "Visual",
    2: "Somatomotor",
    3: "Dorsal_Attention",
    4: "Ventral_Attention",
    5: "Limbic",
    6: "Frontoparietal",
    7: "Default",
    0: "None",
}

# Yeo7 network name → the UX signal names it contributes to (8 distinct
# signals in total). Many-to-many: a network may feed several signals and a
# signal (e.g. cognitive_load, trust_affinity) may be fed by several networks.
# "None" has no entry, so look names up with .get() if it may occur.
YEO7_TO_UX = {
    "Visual":           ["aesthetic_appeal", "visual_fluency"],
    "Somatomotor":      ["motor_readiness"],
    "Dorsal_Attention": ["visual_fluency", "cognitive_load"],
    "Ventral_Attention": ["surprise_novelty"],
    "Limbic":           ["reward_anticipation", "trust_affinity"],
    "Frontoparietal":   ["cognitive_load", "friction_anxiety"],
    "Default":          ["trust_affinity", "aesthetic_appeal"],
}


class YeoAtlas:
    """Lightweight Yeo-2011 7-network atlas for the fsaverage5 surface.

    Keeps two views of the same parcellation:

    * ``vertex_to_network`` -- vertex index -> integer network ID
      (an ID in ``YEO7_NAMES``).
    * ``network_vertices`` -- network name -> list of vertex indices.

    Labels are read from a JSON ``{"<vertex>": <network_id>}`` file when
    ``label_path`` names an existing file; otherwise a deterministic but
    NON-anatomical fallback is generated and a warning is emitted.
    """

    def __init__(self, label_path: Optional[str] = None):
        self.vertex_to_network: Dict[int, int] = {}
        # Pre-seed every known network name so lookups for a network with no
        # vertices return an empty list instead of missing.
        self.network_vertices: Dict[str, List[int]] = {n: [] for n in YEO7_NAMES.values()}

        if label_path and Path(label_path).exists():
            self._load_label_file(label_path)
        else:
            self._generate_fallback()

    def _load_label_file(self, path: str):
        """Populate the atlas from a JSON ``{vertex_id: network_id}`` mapping.

        Keys and values are coerced to ``int``. Network IDs missing from
        ``YEO7_NAMES`` are filed under ``"None"``.
        """
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        self.vertex_to_network = {int(k): int(v) for k, v in data.items()}
        for v, n in self.vertex_to_network.items():
            name = YEO7_NAMES.get(n, "None")
            self.network_vertices.setdefault(name, []).append(v)

    def _generate_fallback(self):
        """
        Fallback: assign the 20484 fsaverage5 vertices to networks 1-7 in
        round-robin order (vertex ``v`` -> network ``v % 7 + 1``).

        This has NO anatomical meaning. It only exists so the pipeline runs
        when the real labels haven't been downloaded, and it warns loudly.
        """
        warnings.warn(
            "Yeo atlas labels not found — using deterministic fallback. "
            "For real neuroscience, download the fsaverage5 Yeo2011_7Networks "
            "labels from https://github.com/ThomasYeoLab/CBIG/tree/master/stable_projects/brain_parcellation/Yeo2011/"
        )
        n_vertices = 20484  # fsaverage5 total, both hemispheres
        for v in range(n_vertices):
            net_id = (v % 7) + 1
            self.vertex_to_network[v] = net_id
            self.network_vertices[YEO7_NAMES[net_id]].append(v)

    def mask_for_network(self, name: str) -> np.ndarray:
        """Return the vertex indices of network ``name`` as an int64 array.

        Unknown names yield an empty array rather than raising.
        """
        return np.array(self.network_vertices.get(name, []), dtype=np.int64)

    @property
    def n_vertices(self) -> int:
        """Number of vertices that carry a label in this atlas."""
        return len(self.vertex_to_network)


def load_yeo_atlas(cache_dir: str = "./cache/yeo") -> YeoAtlas:
    """Load the Yeo-7 fsaverage5 atlas from ``cache_dir``, or fall back.

    Looks for ``<cache_dir>/yeo7_fsaverage5.json`` (a ``{vertex: network_id}``
    mapping). If it exists, it is loaded as the atlas. Otherwise the synthetic
    round-robin fallback is built, which emits a warning on every call.

    The fallback is deliberately NOT written to ``yeo7_fsaverage5.json``.
    Caching it there made every later call load the synthetic labels as if
    they were the real atlas, silently and without the warning.
    NOTE(review): a cache file written by an older version of this function
    may still hold fallback labels; delete it if no real labels were ever
    provided.

    Args:
        cache_dir: directory holding (or that should hold) the label JSON.
            It is created if missing.

    Returns:
        A YeoAtlas instance.
    """
    cache = Path(cache_dir)
    # Create the directory so users have an obvious place to drop real labels.
    cache.mkdir(parents=True, exist_ok=True)
    label_json = cache / "yeo7_fsaverage5.json"

    if label_json.exists():
        return YeoAtlas(str(label_json))

    # No real labels: use the (warning-emitting) fallback without persisting it.
    return YeoAtlas()


class TribeV2Wrapper:
    """
    Wraps Meta's TRIBE v2 model for video → brain inference.
    Falls back to a synthetic brain generator if PyTorch or TRIBE v2 cannot
    be loaded.

    Attributes:
        model_id: identifier passed to ``TribeModel.from_pretrained``.
        device: stored as given. NOTE(review): not currently passed to the
            model -- confirm how TRIBE v2 expects device placement.
    """

    def __init__(self, model_id: str = "facebook/tribev2", device: str = "auto"):
        self.model_id = model_id
        self.device = device
        self._model = None
        self._has_tribe = False
        self._init_model()

    def _init_model(self):
        """Try to load TRIBE v2, leaving the synthetic fallback active on failure.

        Failures are reported with ``warnings.warn`` rather than raised, so
        the wrapper is always usable.
        """
        if not HAS_TORCH:
            warnings.warn("PyTorch not installed — using synthetic brain fallback.")
            return
        try:
            from tribev2 import TribeModel
            self._model = TribeModel.from_pretrained(self.model_id)
            self._has_tribe = True
            print(f"TRIBE v2 loaded from {self.model_id}")
        except Exception as e:
            # Deliberately broad: a missing package, a download failure or a
            # bad checkpoint should all degrade to the synthetic path.
            warnings.warn(f"TRIBE v2 unavailable ({e}) — using synthetic brain fallback.")

    def predict_brain(self, video_path: str, tr: float = 1.5) -> np.ndarray:
        """
        Run a video through TRIBE v2 (or the synthetic fallback).

        Args:
            video_path: path to the input video file.
            tr: sampling interval in seconds. Only the synthetic fallback uses
                it. TRIBE output is returned at the model's native TR
                (~1.49 s, Courtois NeuroMod); no resampling is performed.

        Returns:
            (n_timesteps, n_vertices) array of predicted cortical response.
        """
        if self._has_tribe and self._model is not None:
            return self._predict_tribe(video_path, tr)
        return self._predict_synthetic(video_path, tr)

    def _predict_tribe(self, video_path: str, tr: float) -> np.ndarray:
        """Predict with the loaded TRIBE v2 model.

        ``tr`` is accepted for parity with ``_predict_synthetic`` but is
        currently ignored.
        """
        df = self._model.get_events_dataframe(video_path=video_path)
        preds, _segments = self._model.predict(events=df)
        # preds is presumably (n_timesteps, n_vertices) -- confirm against the
        # TRIBE v2 API. A torch tensor on GPU, or one tracking gradients,
        # makes np.asarray fail, so bring it to host memory first.
        if hasattr(preds, "detach"):
            preds = preds.detach().cpu().numpy()
        return np.asarray(preds)

    @staticmethod
    def _seed_from_path(video_path: str) -> int:
        """Return a deterministic 31-bit RNG seed derived from the file name.

        CRC-32 covers every byte of the name. The previous
        ``int.from_bytes(name, "little") % 2**31`` kept only the low 31 bits,
        which is roughly the first four characters, so ``clip_01.mp4`` and
        ``clip_02.mp4`` got identical seeds.
        """
        import zlib
        return zlib.crc32(Path(video_path).name.encode("utf-8")) & 0x7FFFFFFF

    def _predict_synthetic(self, video_path: str, tr: float) -> np.ndarray:
        """
        Deterministic synthetic brain response for development/testing.

        The time axis spans the video duration (from OpenCV metadata),
        sampled every ``tr`` seconds. If the frame count is unavailable, a
        10 s clip is assumed. Each of the 7 networks gets one sinusoid-plus-
        noise series (seeded from the file name) copied to all of its
        vertices. Vertex ``v`` belongs to network ``v % 7 + 1``, which matches
        the YeoAtlas round-robin fallback layout.

        Returns:
            float32 array of shape (n_timesteps, 20484).
        """
        import cv2

        cap = cv2.VideoCapture(video_path)
        try:
            opened = cap.isOpened()
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()

        if not opened:
            warnings.warn(
                f"Could not open {video_path!r}; assuming a 10 s clip for synthetic output."
            )
        # `not (fps > 0)` also catches NaN, which OpenCV can report.
        if not (fps > 0):
            fps = 30.0
        # Unknown/invalid frame count (0 or negative) -> assume a 10 s clip.
        duration = frame_count / fps if frame_count > 0 else 10.0

        n_timesteps = max(1, int(np.ceil(duration / tr)))
        n_vertices = 20484  # fsaverage5, both hemispheres

        rng = np.random.default_rng(self._seed_from_path(video_path))

        t = np.arange(n_timesteps) * tr
        signal = np.zeros((n_timesteps, n_vertices), dtype=np.float32)

        for net_id in YEO7_NAMES:
            if net_id == 0:
                continue  # "None" gets no synthetic signal
            verts = np.arange(net_id - 1, n_vertices, 7)
            # 0.010-0.100 Hz: the slow BOLD band, well below the Nyquist rate
            # for typical TRs (the old 0.15-0.45 Hz aliased at TR = 1.5 s).
            freq = 0.01 + 0.015 * (net_id - 1)
            amp = 0.5 + 0.2 * np.sin(net_id)
            base = amp * np.sin(2 * np.pi * freq * t + net_id)
            noise = 0.1 * rng.standard_normal(n_timesteps)
            signal[:, verts] = (base + noise)[:, None]

        return signal

    def network_means(self, brain: np.ndarray, atlas: "YeoAtlas") -> Dict[str, np.ndarray]:
        """
        Collapse a vertex-wise brain response into per-Yeo-network time series.

        Args:
            brain: (n_timesteps, n_vertices) response, e.g. from predict_brain.
            atlas: supplies the vertex indices of each network.

        Returns:
            {network_name: (n_timesteps,) array} for the 7 named networks.
            "None" is skipped, and a network with no vertices maps to zeros.
        """
        out = {}
        for net_id, net_name in YEO7_NAMES.items():
            if net_id == 0:
                continue
            verts = atlas.mask_for_network(net_name)
            if len(verts):
                out[net_name] = brain[:, verts].mean(axis=1)
            else:
                out[net_name] = np.zeros(brain.shape[0])
        return out