SID2000 commited on
Commit
c869f0e
·
verified ·
1 Parent(s): d9bfcc6

Upload pages/6_Live_Inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pages/6_Live_Inference.py +294 -0
pages/6_Live_Inference.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Live Brain Prediction - Real-Time Inference from Webcam, Screen, or Video."""
2
+
3
+ import time
4
+
5
+ import numpy as np
6
+ import plotly.graph_objects as go
7
+ import streamlit as st
8
+ from plotly.subplots import make_subplots
9
+
10
+ from session import init_session, show_analysis_log
11
+ from theme import inject_theme, glow_card, section_header
12
+ from utils import make_roi_indices, COGNITIVE_DIMENSIONS
13
+
14
+ st.set_page_config(page_title="Live Inference", page_icon="🔴", layout="wide")
15
+ init_session()
16
+ inject_theme()
17
+ show_analysis_log()
18
+
19
+ st.title("🔴 Live Brain Prediction")
20
+ st.markdown("Real-time brain activation prediction from webcam, screen capture, or video file.")
21
+
22
# --- Check Dependencies ---
# The live-capture/engine modules are optional; record what failed so the
# page can degrade gracefully instead of dying on import.
deps_ok = True
missing = []

try:
    from live_capture import WebcamCapture, ScreenCapture, FileStreamer, get_capture_source
    from live_engine import LiveInferenceEngine, CORTEXLAB_AVAILABLE
except ImportError as e:
    deps_ok = False
    missing.append(str(e))
    # Fallback definition: the sidebar and methodology sections read this
    # flag unconditionally, so without it a failed import turns into a
    # NameError crash further down the page.
    CORTEXLAB_AVAILABLE = False

# Surface the failure instead of silently disabling the Start button.
if not deps_ok:
    st.error("Live inference modules could not be imported:\n\n" + "\n".join(missing))
32
+
33
# --- Sidebar: source selection and runtime settings ---
with st.sidebar:
    st.header("Live Inference")

    # Human-readable labels for the capture-source option keys.
    source_labels = {"webcam": "Webcam + Mic", "screen": "Screen Capture", "file": "Video File"}
    source_type = st.selectbox(
        "Source",
        ["webcam", "screen", "file"],
        format_func=source_labels.get,
    )

    # The uploader only exists while the "file" source is selected.
    if source_type == "file":
        uploaded_file = st.file_uploader("Upload video", type=["mp4", "avi", "mkv", "mov", "webm"])

    st.subheader("Settings")
    capture_fps = st.slider(
        "Capture FPS",
        0.5,
        5.0,
        1.0,
        0.5,
        help="Frames per second. Higher = more responsive but more CPU/GPU load.",
    )

    if CORTEXLAB_AVAILABLE:
        # Real inference: let the user pick the compute device.
        device = st.selectbox("Device", ["auto", "cuda", "cpu"])
        st.success("CortexLab detected. Real inference available.")
    else:
        device = "cpu"
        st.warning("CortexLab not installed. Running in **simulation mode** (predictions from image statistics).")
        with st.expander("Install CortexLab"):
            st.code("pip install -e ../cortexlab[analysis]", language="bash")

    st.subheader("Display")
    show_brain_3d = st.checkbox("Show 3D brain", value=True)
    show_timeline = st.checkbox("Show cognitive load timeline", value=True)
    timeline_window = st.slider("Timeline window (seconds)", 10, 120, 60)
60
+
61
# --- Initialize Engine state ---
roi_indices, n_vertices = make_roi_indices()

# Engine handle and run flag persist across Streamlit reruns via session state.
st.session_state.setdefault("live_engine", None)
st.session_state.setdefault("live_running", False)

# --- Controls ---
col_start, col_stop, col_status = st.columns([1, 1, 2])

is_running = st.session_state.get("live_running", False)

with col_start:
    # Start is disabled while a session is already running.
    start_clicked = st.button(
        "▶ Start", type="primary", use_container_width=True, disabled=is_running
    )

with col_stop:
    # Stop is only enabled while running.
    stop_clicked = st.button(
        "⬛ Stop", use_container_width=True, disabled=not is_running
    )
79
+
80
# Handle Start: build the selected capture source, spin up the engine,
# then rerun so the UI immediately reflects the running state.
if start_clicked and deps_ok:
    if source_type == "webcam":
        capture = WebcamCapture(fps=capture_fps)
    elif source_type == "screen":
        capture = ScreenCapture(fps=capture_fps)
    elif source_type == "file":
        # Guard clause: nothing to stream without an upload.
        if uploaded_file is None:
            st.error("Upload a video file first.")
            st.stop()
        import os
        import tempfile
        # Write the upload to a unique temp file. (Previously the upload's
        # bare filename was joined onto the shared temp dir, so two uploads
        # with the same name would clobber each other.) Keep the extension so
        # the video decoder can identify the container format.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.read())
            tmp_path = tmp.name
        capture = FileStreamer(file_path=tmp_path, fps=capture_fps)

    # Create and start the inference engine.
    engine = LiveInferenceEngine(
        n_vertices=n_vertices,
        roi_indices=roi_indices,
        device=device,
    )
    engine.start(capture)
    st.session_state["live_engine"] = engine
    st.session_state["live_running"] = True
    st.rerun()

# Handle Stop: shut the engine down but keep it in session state so its
# last predictions remain browsable.
if stop_clicked:
    engine = st.session_state.get("live_engine")
    if engine:
        engine.stop()
    st.session_state["live_running"] = False
    st.rerun()
116
+
117
# --- Status Bar ---
with col_status:
    # NOTE: `engine` assigned here is also read by the live-display section below.
    engine = st.session_state.get("live_engine")
    running = st.session_state.get("live_running", False)
    if engine and running:
        metrics = engine.get_metrics()
        st.markdown(f"""
<div style="display: flex; gap: 1.5rem; align-items: center; padding: 0.5rem;">
<span style="color: #EF4444; font-size: 1.2rem;">● LIVE</span>
<span style="color: #94A3B8;">Mode: <b style="color: #06B6D4;">{metrics.mode}</b></span>
<span style="color: #94A3B8;">FPS: <b style="color: #10B981;">{metrics.fps:.1f}</b></span>
<span style="color: #94A3B8;">Predictions: <b style="color: #A29BFE;">{metrics.total_predictions}</b></span>
<span style="color: #94A3B8;">Latency: <b style="color: #FFEAA7;">{metrics.avg_latency_ms:.0f}ms</b></span>
</div>
""", unsafe_allow_html=True)
    elif not running:
        st.markdown('<span style="color: #64748B;">Ready. Select a source and click Start.</span>', unsafe_allow_html=True)

st.divider()
135
+
136
# --- Live Display ---
# NOTE(review): this section was reconstructed from a diff with flattened
# indentation. The nesting chosen here (prediction widgets under
# `if predictions:`, auto-refresh at the running-branch level so the page
# keeps polling even before the first prediction arrives) matches the paired
# `else:` below — confirm against the original source if available.
if st.session_state.get("live_running") and engine:
    predictions = engine.get_predictions(timeline_window)

    if predictions:
        latest = predictions[-1]

        # --- Cognitive Load Metrics (latest frame) ---
        cog = latest.cognitive_load
        card_specs = [
            ("Overall", "Overall", "#7C3AED"),
            ("Visual", "Visual Complexity", "#00D2FF"),
            ("Auditory", "Auditory Demand", "#FF6B6B"),
            ("Language", "Language Processing", "#A29BFE"),
            ("Executive", "Executive Load", "#FFEAA7"),
        ]
        for col, (label, key, color) in zip(st.columns(5), card_specs):
            with col:
                glow_card(label, f"{cog.get(key, 0):.2f}", "", color)

        col_brain, col_timeline = st.columns([1, 1])

        # --- 3D Brain ---
        if show_brain_3d:
            with col_brain:
                section_header("Brain Activation", f"t = {latest.timestamp:.1f}s")
                try:
                    from brain_mesh import (
                        load_fsaverage_mesh, render_interactive_3d,
                    )
                    # fsaverage4 is the low-resolution mesh: fast enough for live updates.
                    coords, faces = load_fsaverage_mesh("left", "fsaverage4")
                    n_mesh = coords.shape[0]

                    # Resample predicted vertex values onto the mesh vertex count:
                    # linear interpolation when too few, truncation when too many.
                    vd = latest.vertex_data
                    if len(vd) < n_mesh:
                        vd = np.interp(np.linspace(0, len(vd) - 1, n_mesh), np.arange(len(vd)), vd)
                    elif len(vd) > n_mesh:
                        vd = vd[:n_mesh]

                    fig_brain = render_interactive_3d(
                        coords, faces, vd, cmap="Inferno", vmin=0, vmax=0.8,
                        bg_color="#050510", initial_view="Lateral Left",
                    )
                    # Guard the chart call too: render_interactive_3d may return
                    # a falsy figure, and plotting None would raise.
                    if fig_brain:
                        fig_brain.update_layout(height=400, margin=dict(l=0, r=0, t=0, b=0))
                        st.plotly_chart(fig_brain, use_container_width=True)
                except Exception as e:
                    st.warning(f"Brain render error: {e}")

        # --- Cognitive Load Timeline ---
        if show_timeline:
            with col_timeline:
                section_header("Cognitive Load Timeline", f"{len(predictions)} data points")

                fig_tl = go.Figure()
                timestamps = [p.timestamp for p in predictions]
                dim_colors = {
                    "Visual Complexity": "#00D2FF",
                    "Auditory Demand": "#FF6B6B",
                    "Language Processing": "#A29BFE",
                    "Executive Load": "#FFEAA7",
                }

                for dim, color in dim_colors.items():
                    values = [p.cognitive_load.get(dim, 0) for p in predictions]
                    fig_tl.add_trace(go.Scatter(
                        x=timestamps, y=values, name=dim.split()[0],
                        line=dict(color=color, width=2), mode="lines",
                    ))

                fig_tl.update_layout(
                    xaxis_title="Time (seconds)", yaxis_title="Load",
                    yaxis_range=[0, 1.05], height=400,
                    template="plotly_dark",
                    legend=dict(orientation="h", yanchor="bottom", y=1.02),
                    margin=dict(l=40, r=10, t=10, b=40),
                )
                st.plotly_chart(fig_tl, use_container_width=True)

        # --- Store latest predictions for other pages ---
        all_vertex_data = np.array([p.vertex_data for p in predictions])
        st.session_state["brain_predictions"] = all_vertex_data
        st.session_state["roi_indices"] = roi_indices
        st.session_state["data_source"] = "live_inference"

        # --- Navigation ---
        st.divider()
        st.markdown("**Explore live predictions in other tools:**")
        c1, c2, c3, c4 = st.columns(4)
        with c1: st.page_link("pages/5_Brain_Viewer.py", label="Brain Viewer", icon="🧠")
        with c2: st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load", icon="📊")
        with c3: st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
        with c4: st.page_link("pages/4_Connectivity.py", label="Connectivity", icon="🔗")

    # --- Auto-refresh: poll the engine roughly once per second ---
    time.sleep(1.0)
    st.rerun()
230
+
231
+ else:
232
+ # --- Not running: show instructions ---
233
+ st.markdown("""
234
+ <div style="
235
+ text-align: center; padding: 3rem 2rem;
236
+ background: rgba(15, 15, 40, 0.4);
237
+ border: 1px solid rgba(100, 100, 255, 0.15);
238
+ border-radius: 16px; margin: 1rem 0;
239
+ ">
240
+ <div style="font-size: 3rem; margin-bottom: 1rem;">🧠</div>
241
+ <h3 style="color: #F1F5F9; margin-bottom: 0.5rem;">Ready for Live Brain Prediction</h3>
242
+ <p style="color: #94A3B8; max-width: 600px; margin: 0 auto;">
243
+ Select a source (webcam, screen capture, or video file) from the sidebar,
244
+ then click <b>Start</b> to begin real-time brain activation prediction.
245
+ </p>
246
+ <div style="margin-top: 1.5rem; display: flex; justify-content: center; gap: 2rem;">
247
+ <div style="text-align: center;">
248
+ <div style="font-size: 1.5rem;">📹</div>
249
+ <div style="color: #06B6D4; font-size: 0.85rem; font-weight: 600;">Webcam</div>
250
+ <div style="color: #64748B; font-size: 0.75rem;">Live camera feed</div>
251
+ </div>
252
+ <div style="text-align: center;">
253
+ <div style="font-size: 1.5rem;">🖥️</div>
254
+ <div style="color: #7C3AED; font-size: 0.85rem; font-weight: 600;">Screen</div>
255
+ <div style="color: #64748B; font-size: 0.75rem;">Capture display</div>
256
+ </div>
257
+ <div style="text-align: center;">
258
+ <div style="font-size: 1.5rem;">🎬</div>
259
+ <div style="color: #EC4899; font-size: 0.85rem; font-weight: 600;">Video File</div>
260
+ <div style="color: #64748B; font-size: 0.75rem;">Frame-by-frame</div>
261
+ </div>
262
+ </div>
263
+ </div>
264
+ """, unsafe_allow_html=True)
265
+
266
+ # Show last predictions if available
267
+ if st.session_state.get("brain_predictions") is not None and st.session_state.get("data_source") == "live_inference":
268
+ st.info(f"Previous session predictions available ({st.session_state['brain_predictions'].shape[0]} timepoints). Navigate to analysis pages to explore them.")
269
+
270
# --- Methodology ---
with st.expander("About Live Inference", expanded=False):
    # Hoist the mode-dependent strings out of the markdown body so the
    # f-string below stays readable.
    if CORTEXLAB_AVAILABLE:
        mode_name = "Real (CortexLab)"
        mode_desc = (
            "**Real Inference**: Uses TRIBE v2 to extract features (V-JEPA2, Wav2Vec-BERT, LLaMA 3.2) "
            "and predict fMRI brain activation at each captured frame. Requires GPU for interactive speed."
        )
    else:
        mode_name = "Simulation"
        mode_desc = (
            "**Simulation Mode**: CortexLab is not installed. Predictions are generated from image "
            "statistics (brightness, contrast, color variance) mapped to brain ROIs. This demonstrates "
            "the pipeline without requiring GPU or model weights."
        )

    st.markdown(f"""
**Mode: {mode_name}**

{mode_desc}

**Sources:**
- **Webcam**: Captures frames via OpenCV. Requires `pip install opencv-python`.
- **Screen Capture**: Captures display via mss. Requires `pip install mss Pillow`.
- **Video File**: Reads uploaded video frame-by-frame at the specified FPS.

**Cognitive Load Dimensions** are computed from predicted vertex activations
grouped by HCP MMP1.0 ROIs (same method as the Cognitive Load Scorer page).

**Performance:**
- Simulation mode: ~1-5ms per frame (CPU)
- Real inference with GPU: ~50-200ms per frame
- Real inference with CPU: ~5-30s per frame (not recommended)

**To enable real inference:**
```bash
pip install -e path/to/cortexlab[analysis]
```
""")