| """ |
End-to-end training pipeline:
1. Load downloaded videos from labeled folders
2. Run TRIBE v2 (or synthetic fallback) → brain response
3. Extract Yeo network time series
4. Distill to 40 features (8 signals × 5 temporal stats)
5. Train Random Forest classifier
6. Save model + features + report
| """ |
|
|
| import json, pickle |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
| import cv2 |
|
|
| from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas |
| from .signals import network_to_ux_signals |
| from .features import extract_features, FEATURE_NAMES |
| from .classifier import ViralityClassifier, LABELS, LABEL_TO_INT |
|
|
|
|
def process_video(video_path: str, tribe: TribeV2Wrapper, atlas,
                  tr: float = 1.5, verbose: bool = False) -> Optional[np.ndarray]:
    """
    Convert one video file into its brain-derived feature vector.

    Steps: ``tribe.predict_brain`` -> ``tribe.network_means`` (with ``atlas``)
    -> ``network_to_ux_signals`` -> ``extract_features``.

    Args:
        video_path: Location of the video; anything ``Path`` accepts.
        tribe: Wrapper providing ``predict_brain`` and ``network_means``.
        atlas: Atlas object forwarded unchanged to ``tribe.network_means``.
        tr: Repetition time, passed to both ``predict_brain`` and
            ``extract_features``.
        verbose: When True, print a one-line status for this video.

    Returns:
        Whatever ``extract_features`` returns (expected to be a (40,) feature
        vector — TODO confirm against ``features.py``), or None when the file
        does not exist, OpenCV cannot open it, or any pipeline step raises.
    """
    path = Path(video_path)

    if not path.exists():
        if verbose:
            print(f" β missing: {path}")
        return None

    # Cheap readability probe: open with OpenCV and release immediately so
    # corrupt/unsupported files are skipped before the expensive model call.
    capture = cv2.VideoCapture(str(path))
    is_readable = capture.isOpened()
    if not is_readable:
        if verbose:
            print(f" β unreadable: {path.name}")
        capture.release()
        return None
    capture.release()

    # Best-effort: one failing video must not abort a whole dataset build,
    # so any error is reported (when verbose) and turned into None.
    try:
        brain_response = tribe.predict_brain(str(path), tr=tr)
        per_network = tribe.network_means(brain_response, atlas)
        ux_signals = network_to_ux_signals(per_network)
        feature_vec = extract_features(ux_signals, tr=tr)
        # The status line stays inside the try so a missing ``.shape`` is
        # handled like any other pipeline failure.
        if verbose:
            print(f" β {path.name} β {feature_vec.shape}")
        return feature_vec
    except Exception as err:
        if verbose:
            print(f" β error on {path.name}: {err}")
        return None
|
|
|
|
def build_dataset(video_root: str, tribe: Optional[TribeV2Wrapper] = None,
                  tr: float = 1.5, verbose: bool = True,
                  cache_path: Optional[str] = None) -> Tuple[np.ndarray, List[str]]:
    """
    Scan ``video_root/<label>/*.mp4`` for every label in ``LABELS``, turn each
    video into a feature vector, and return the stacked matrix plus labels.

    Args:
        video_root: Directory containing one sub-folder per label.
        tribe: Wrapper used for brain prediction. A default ``TribeV2Wrapper()``
            (and the Yeo atlas) is only created when the cache misses.
        tr: Repetition time forwarded to ``process_video``.
        verbose: Print per-folder / per-video progress.
        cache_path: Optional cache file. If it exists it is loaded and returned
            as-is; otherwise a non-empty freshly built dataset is written to it.

    Returns:
        ``(X, y_str)``: ``X`` stacks one row per successfully processed video
        (an empty ``(0, 40)`` float32 array when none succeed) and ``y_str``
        holds the matching label strings.

    Note:
        The cache is keyed only by ``cache_path``; it is NOT invalidated when
        ``video_root``, ``tr`` or the videos change. Delete it to rebuild.
    """
    video_root = Path(video_root)

    # Check the cache first so a hit never pays for model / atlas loading.
    if cache_path and Path(cache_path).exists():
        # Only plain numeric/unicode arrays are written below, so pickle
        # support is unnecessary; the context manager closes the zip handle.
        with np.load(cache_path) as data:
            X = data["X"]
            y_str = data["y_str"].tolist()
        if verbose:
            print(f"Loaded cached dataset: {X.shape} from {cache_path}")
        return X, y_str

    if tribe is None:
        tribe = TribeV2Wrapper()
    atlas = load_yeo_atlas()

    X_list, y_list = [], []
    for label in LABELS:
        folder = video_root / label
        if not folder.exists():
            # Printed regardless of ``verbose``: a missing class folder is a warning.
            print(f" β folder missing: {folder}")
            continue
        videos = sorted(folder.glob("*.mp4"))
        if verbose:
            print(f"\nProcessing {label}: {len(videos)} videos")
        for vp in videos:
            feats = process_video(str(vp), tribe, atlas, tr=tr, verbose=verbose)
            if feats is not None:
                X_list.append(feats)
                y_list.append(label)

    X = np.stack(X_list) if X_list else np.empty((0, 40), dtype=np.float32)

    # Never cache an empty dataset: it would be served on every later run,
    # even after videos have been added.
    if cache_path and X_list:
        # Write through a file handle: np.savez appends ".npz" to string paths
        # lacking that suffix, which would make the exists() check above miss.
        with open(cache_path, "wb") as fh:
            np.savez(fh, X=X, y_str=np.array(y_list))
        if verbose:
            print(f"Saved cache β {cache_path}")
    return X, y_list
|
|
|
|
def train_pipeline(video_root: str, model_out: str, report_out: str,
                   tr: float = 1.5, use_cache: bool = True,
                   classifier_params: Optional[Dict] = None) -> Dict:
    """
    Build the feature dataset, train a ``ViralityClassifier``, and persist
    both the model and its training report.

    Args:
        video_root: Directory with one sub-folder per label (see ``build_dataset``).
        model_out: Destination passed to ``clf.save``.
        report_out: Path of the JSON file the training report is written to.
        tr: Repetition time forwarded to ``build_dataset`` / feature extraction.
        use_cache: If True, features are cached at
            ``<parent of video_root>/brain_features.npz``. That cache does not
            record ``tr``; delete it after changing ``tr``.
        classifier_params: Keyword arguments for ``ViralityClassifier``.

    Returns:
        The report returned by ``clf.fit``.

    Raises:
        ValueError: If no video could be turned into features.
    """
    cache = Path(video_root).parent / "brain_features.npz" if use_cache else None
    # ``tr`` must be forwarded; otherwise features are always extracted at the
    # build_dataset default regardless of what the caller asked for.
    X, y_str = build_dataset(
        video_root,
        tr=tr,
        cache_path=str(cache) if cache is not None else None,
    )

    y = np.array([LABEL_TO_INT[label] for label in y_str], dtype=np.int32)

    print(f"\n{'β'*60}")
    print(f" Dataset: {X.shape[0]} videos Γ {X.shape[1]} features")
    for lbl in LABELS:
        n = (y == LABEL_TO_INT[lbl]).sum()
        print(f" {lbl}: {n}")
    print(f"{'β'*60}")

    # Fail with a clear message instead of an opaque error from deep inside fit().
    if X.shape[0] == 0:
        raise ValueError(
            f"No usable videos found under {video_root}; nothing to train on."
        )

    clf = ViralityClassifier(**(classifier_params or {}))
    report = clf.fit(X, y, test_size=0.2, verbose=True)

    clf.save(model_out)
    with open(report_out, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"Model saved β {model_out}")
    print(f"Report saved β {report_out}")

    return report
|
|