SID2000 committed on
Commit
d9bfcc6
·
verified ·
1 Parent(s): 488e9e1

Upload live_engine.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. live_engine.py +284 -0
live_engine.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real-time brain prediction engine.
2
+
3
+ Runs in a background thread, consuming frames from a capture source,
4
+ extracting features, and producing brain predictions via TRIBE v2.
5
+
6
+ When CortexLab is not installed, falls back to a simulation mode that
7
+ generates synthetic predictions from frame statistics.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import time
13
+ import threading
14
+ import logging
15
+ from collections import deque
16
+ from dataclasses import dataclass, field
17
+
18
+ import numpy as np
19
+
20
+ from live_capture import BaseCapture, MediaFrame
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
# Probe for an optional CortexLab installation; simulation mode is used when absent.
CORTEXLAB_AVAILABLE = False
try:
    from cortexlab.inference.predictor import TribeModel
except ImportError:
    pass
else:
    CORTEXLAB_AVAILABLE = True
30
+
31
+
32
@dataclass
class LivePrediction:
    """A single brain prediction produced by the live engine.

    Bundles per-vertex activation values with the source frame's capture
    timestamp, derived cognitive-load scores, and the wall-clock time the
    prediction took to compute.
    """
    vertex_data: np.ndarray  # per-vertex activations, shape (n_vertices,)
    timestamp: float  # capture timestamp of the source frame (seconds)
    cognitive_load: dict[str, float] = field(default_factory=dict)  # dimension name -> score, clamped to [0, 1]
    processing_time_ms: float = 0.0  # inference latency for this frame, in milliseconds
39
+
40
+
41
@dataclass
class LiveMetrics:
    """Aggregated metrics from the live engine.

    Updated by the engine's worker thread while it runs; readers obtain
    the current snapshot via ``LiveInferenceEngine.get_metrics()``.
    """
    fps: float = 0.0  # prediction throughput over the recent frame window
    total_frames: int = 0  # frames seen by the capture source so far
    total_predictions: int = 0  # predictions produced since start()
    avg_latency_ms: float = 0.0  # per-prediction processing latency (ms)
    is_running: bool = False  # True between start() and stop() / end of input
    mode: str = "simulation"  # "simulation" or "cortexlab"
50
+
51
+
52
class LiveInferenceEngine:
    """Background engine for real-time brain prediction.

    Consumes frames from a capture source and produces brain predictions.
    If CortexLab is installed and the model loads successfully, uses the
    real TRIBE v2 model. Otherwise, falls back to simulation mode that
    generates plausible predictions from frame statistics.

    Thread model: ``start()`` spawns a daemon worker running
    ``_inference_loop``. The prediction buffer is guarded by a lock;
    ``LiveMetrics`` fields are written only by the worker thread
    (single writer), so readers see at worst slightly stale values.
    """

    def __init__(
        self,
        n_vertices: int = 580,
        roi_indices: dict | None = None,
        buffer_size: int = 120,
        checkpoint: str = "facebook/tribev2",
        device: str = "auto",
        cache_folder: str = "./cache",
    ):
        """Configure the engine; no threads or models are created here.

        Args:
            n_vertices: Length of each prediction's vertex vector.
            roi_indices: Mapping of ROI name -> numpy index array into the
                vertex vector (used by simulation and cognitive-load scoring).
            buffer_size: Maximum number of predictions retained in memory.
            checkpoint: Checkpoint identifier passed to the model loader.
            device: Device spec forwarded to ``TribeModel.from_pretrained``.
            cache_folder: Download/cache directory for model weights.
        """
        self.n_vertices = n_vertices
        self.roi_indices = roi_indices or {}
        self.buffer_size = buffer_size
        self.checkpoint = checkpoint
        self.device = device
        self.cache_folder = cache_folder

        self._predictions: deque[LivePrediction] = deque(maxlen=buffer_size)
        self._running = False
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()
        self._model = None
        self._metrics = LiveMetrics()
        self._capture: BaseCapture | None = None
        # Rolling window of recent per-frame latencies (ms) so that
        # avg_latency_ms reports a true average rather than the last sample.
        self._latencies: deque[float] = deque(maxlen=30)

    def start(self, capture: BaseCapture):
        """Start the inference engine with a media capture source.

        Idempotent while running: a second call is a no-op. Attempts to
        load the TRIBE v2 model when CortexLab is importable; any load
        failure downgrades the engine to simulation mode.
        """
        if self._running:
            return

        self._capture = capture
        self._running = True
        self._metrics = LiveMetrics(is_running=True)

        # Try to load CortexLab model
        if CORTEXLAB_AVAILABLE:
            try:
                logger.info("Loading TRIBE v2 model...")
                self._model = TribeModel.from_pretrained(
                    self.checkpoint, device=self.device, cache_folder=self.cache_folder
                )
                self._metrics.mode = "cortexlab"
                logger.info("Model loaded. Using real inference.")
            except Exception as e:
                logger.warning("Failed to load model: %s. Using simulation mode.", e)
                self._model = None
                self._metrics.mode = "simulation"
        else:
            self._metrics.mode = "simulation"

        capture.start()
        self._thread = threading.Thread(target=self._inference_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Stop the engine and capture source.

        Signals the worker loop to exit, stops the capture, and waits up
        to 5 seconds for the worker thread to finish.
        """
        self._running = False
        if self._capture:
            self._capture.stop()
        if self._thread:
            self._thread.join(timeout=5.0)
        self._metrics.is_running = False

    def get_latest_prediction(self) -> LivePrediction | None:
        """Return the most recent prediction, or None if there is none yet."""
        with self._lock:
            return self._predictions[-1] if self._predictions else None

    def get_predictions(self, n: int = 60) -> list[LivePrediction]:
        """Return up to the last ``n`` predictions, oldest first."""
        with self._lock:
            return list(self._predictions)[-n:]

    def get_metrics(self) -> LiveMetrics:
        """Return the live metrics object (updated in place by the worker)."""
        return self._metrics

    def _inference_loop(self):
        """Main loop: consume frames, produce predictions.

        Polls the capture for new frames, skipping frames already
        processed (tracked via the capture's frame counter), and exits
        when ``stop()`` is called or the capture source ends.
        """
        frame_times = deque(maxlen=30)  # timestamps for FPS estimation
        last_frame_count = 0

        while self._running:
            frame = self._capture.get_latest_frame()
            if frame is None:
                time.sleep(0.1)
                continue

            # Skip if we already processed this frame
            current_count = self._capture.frame_count
            if current_count == last_frame_count:
                time.sleep(0.05)
                continue
            last_frame_count = current_count

            start = time.time()

            if self._model is not None and self._metrics.mode == "cortexlab":
                prediction = self._run_real_inference(frame)
            else:
                prediction = self._run_simulation(frame)

            elapsed_ms = (time.time() - start) * 1000
            prediction.processing_time_ms = elapsed_ms

            with self._lock:
                self._predictions.append(prediction)

            # Update metrics (single writer: this thread).
            frame_times.append(time.time())
            self._latencies.append(elapsed_ms)
            self._metrics.total_predictions += 1
            self._metrics.total_frames = current_count
            # BUGFIX: previously only the last sample was stored here,
            # so "avg_latency_ms" was not an average at all.
            self._metrics.avg_latency_ms = sum(self._latencies) / len(self._latencies)
            if len(frame_times) >= 2:
                self._metrics.fps = (len(frame_times) - 1) / (frame_times[-1] - frame_times[0])

            # Check if capture stopped (file ended)
            if not self._capture.is_running:
                self._running = False
                self._metrics.is_running = False

    def _run_real_inference(self, frame: MediaFrame) -> LivePrediction:
        """Run actual TRIBE v2 inference on a frame.

        For real-time use the frame is wrapped in a one-frame video file
        so it can pass through the model's standard event pipeline
        (``get_events_dataframe`` + ``predict``). Any failure falls back
        to simulation mode for this frame.
        """
        import tempfile
        import os

        tmp_path = os.path.join(tempfile.gettempdir(), "cortexlab_live_frame.mp4")
        try:
            # Save frame as temporary video (1 frame)
            import cv2
            h, w = frame.video_frame.shape[:2]
            out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (w, h))
            # Capture frames are RGB; OpenCV writers expect BGR.
            out.write(cv2.cvtColor(frame.video_frame, cv2.COLOR_RGB2BGR))
            out.release()

            events = self._model.get_events_dataframe(video_path=tmp_path)
            preds, _ = self._model.predict(events, verbose=False)
            vertex_data = preds.mean(axis=0) if preds.ndim == 2 else preds

            # Normalize to [0, 1]
            vmin, vmax = vertex_data.min(), vertex_data.max()
            if vmax > vmin:
                vertex_data = (vertex_data - vmin) / (vmax - vmin)
        except Exception as e:
            logger.warning("Inference failed: %s. Falling back to simulation.", e)
            return self._run_simulation(frame)
        finally:
            # BUGFIX: the temp file was previously removed only on the
            # success path, leaking it whenever inference raised.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        cog_load = self._compute_cognitive_load(vertex_data)
        return LivePrediction(
            vertex_data=vertex_data,
            timestamp=frame.timestamp,
            cognitive_load=cog_load,
        )

    def _run_simulation(self, frame: MediaFrame) -> LivePrediction:
        """Generate plausible predictions from frame statistics.

        Uses frame brightness/color as a proxy for visual complexity,
        creating biologically-inspired activation patterns. Seeding the
        RNG from the frame timestamp keeps output deterministic per frame.
        """
        rng = np.random.default_rng(int(frame.timestamp * 1000) % (2**31))

        # Base noise
        vertex_data = rng.standard_normal(self.n_vertices) * 0.03

        if frame.video_frame is not None:
            img = frame.video_frame.astype(np.float32) / 255.0

            # Visual complexity from image statistics
            brightness = img.mean()
            contrast = img.std()
            color_variance = img.var(axis=(0, 1)).mean()

            # Map to ROI activations (roi_indices values are assumed to be
            # numpy index arrays — boolean masking below requires that).
            for roi_name, vertices in self.roi_indices.items():
                valid = vertices[vertices < self.n_vertices]
                if len(valid) == 0:
                    continue

                # Visual ROIs respond to brightness/contrast
                if roi_name in ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"]:
                    activation = contrast * 0.8 + color_variance * 0.5
                # Auditory ROIs get low baseline
                elif roi_name in ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"]:
                    activation = 0.05 + rng.random() * 0.1
                # Language ROIs moderate
                elif roi_name in ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2"]:
                    activation = brightness * 0.3
                # Executive ROIs track change
                elif roi_name in ["46", "9-46d", "8Av", "8Ad", "FEF"]:
                    activation = contrast * 0.5
                else:
                    activation = 0.1

                vertex_data[valid] = activation + rng.standard_normal(len(valid)) * 0.05

        vertex_data = np.clip(vertex_data, 0, 1)
        cog_load = self._compute_cognitive_load(vertex_data)

        return LivePrediction(
            vertex_data=vertex_data,
            timestamp=frame.timestamp,
            cognitive_load=cog_load,
        )

    def _compute_cognitive_load(self, vertex_data: np.ndarray) -> dict[str, float]:
        """Compute cognitive load dimensions from vertex data.

        Each dimension is the mean absolute activation of its ROIs,
        scaled by the median absolute activation (baseline) and clamped
        to [0, 1]. An "Overall" key averages all dimensions.
        """
        from utils import COGNITIVE_DIMENSIONS

        # Guard against an all-zero vector dividing by zero.
        baseline = max(float(np.median(np.abs(vertex_data))), 1e-8)
        scores = {}
        for dim, rois in COGNITIVE_DIMENSIONS.items():
            vals = []
            for roi in rois:
                if roi in self.roi_indices:
                    verts = self.roi_indices[roi]
                    valid = verts[verts < len(vertex_data)]
                    if len(valid) > 0:
                        vals.append(np.abs(vertex_data[valid]).mean())
            scores[dim] = min(float(np.mean(vals)) / baseline, 1.0) if vals else 0.0
        scores["Overall"] = float(np.mean(list(scores.values()))) if scores else 0.0
        return scores