"""
End-to-end training pipeline:

    1. Load downloaded videos from labeled folders
    2. Run TRIBE v2 (or synthetic fallback) → brain response
    3. Extract Yeo network time series
    4. Distill to 40 features (8 signals × 5 temporal stats)
    5. Train Random Forest classifier
    6. Save model + features + report
"""
import json
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import cv2

from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
from .signals import network_to_ux_signals
from .features import extract_features, FEATURE_NAMES
from .classifier import ViralityClassifier, LABELS, LABEL_TO_INT

# Width of the per-video feature vector (8 signals × 5 temporal stats, per the
# module docstring). Used only to shape the empty result when no video succeeds.
_N_FEATURES = 40


def process_video(video_path: str, tribe: TribeV2Wrapper, atlas,
                  tr: float = 1.5, verbose: bool = False) -> Optional[np.ndarray]:
    """Turn one video file into its feature vector.

    Args:
        video_path: Path to the video file.
        tribe: Wrapper whose ``predict_brain`` / ``network_means`` produce the
            network time series.
        atlas: Atlas object (from ``load_yeo_atlas()``) handed to
            ``tribe.network_means``.
        tr: Sampling interval forwarded to ``predict_brain`` and
            ``extract_features``.
        verbose: Print a one-line status for this video.

    Returns:
        The array returned by ``extract_features`` (40 features per the module
        docstring), or ``None`` if the file is missing, OpenCV cannot open it,
        or any step of the brain/feature pipeline raises.
    """
    path = Path(video_path)
    if not path.exists():
        if verbose:
            print(f" ✗ missing: {path}")
        return None

    # Cheap sanity check before the expensive model call: can OpenCV open it?
    cap = cv2.VideoCapture(str(path))
    try:
        readable = cap.isOpened()
    finally:
        cap.release()  # always free the capture handle, even if isOpened raises
    if not readable:
        if verbose:
            print(f" ✗ unreadable: {path.name}")
        return None

    try:
        brain = tribe.predict_brain(str(path), tr=tr)
        network_ts = tribe.network_means(brain, atlas)
        signals = network_to_ux_signals(network_ts)
        feats = extract_features(signals, tr=tr)
        if verbose:
            print(f" ✓ {path.name} → {feats.shape}")
        return feats
    except Exception as e:
        # Deliberate best-effort: one bad video must not abort the whole
        # dataset build; the caller simply skips videos that return None.
        if verbose:
            print(f" ✗ error on {path.name}: {e}")
        return None


def _normalize_cache_path(cache_path: str) -> Path:
    """Return the file ``np.savez`` will actually write for *cache_path*.

    ``np.savez`` appends ``.npz`` when the name lacks it, so normalizing up
    front keeps the existence check and the write pointing at the same file.
    """
    path = Path(cache_path)
    return path if path.suffix == ".npz" else path.with_name(path.name + ".npz")


def _load_cache(cache_file: Path, tr: float,
                verbose: bool) -> Optional[Tuple[np.ndarray, List[str]]]:
    """Load ``(X, labels)`` from *cache_file*, or None if it is unusable.

    A cache is unusable when it does not exist, holds zero rows, or records a
    ``tr`` different from the requested one. Caches written before ``tr`` was
    stored carry no ``tr`` entry and are accepted as-is.
    """
    if not cache_file.exists():
        return None
    # allow_pickle stays False: the cache only holds numeric/unicode arrays,
    # and unpickling a file from disk can execute arbitrary code.
    with np.load(cache_file) as data:
        if "tr" in data.files and not np.isclose(float(data["tr"]), tr):
            if verbose:
                print(f"Ignoring cache {cache_file}: built with "
                      f"tr={float(data['tr'])}, requested tr={tr}")
            return None
        X = data["X"]
        y_str = data["y_str"].tolist()
    if X.shape[0] == 0:
        # An empty cache would pin every later run to an empty dataset.
        return None
    if verbose:
        print(f"Loaded cached dataset: {X.shape} from {cache_file}")
    return X, y_str


def build_dataset(video_root: str, tribe: Optional[TribeV2Wrapper] = None,
                  tr: float = 1.5, verbose: bool = True,
                  cache_path: Optional[str] = None) -> Tuple[np.ndarray, List[str]]:
    """Scan ``video_root/<label>/*.mp4`` for each label in ``LABELS``.

    Each video is processed with :func:`process_video`; videos that fail are
    skipped.

    Args:
        video_root: Directory containing one sub-folder per label.
        tribe: Wrapper to use; a default ``TribeV2Wrapper()`` is built only
            when needed (i.e. not on a cache hit).
        tr: Sampling interval forwarded to :func:`process_video`.
        verbose: Print progress.
        cache_path: Optional ``.npz`` cache. It is loaded if present and
            compatible, otherwise written after processing. An empty result
            is never cached.

    Returns:
        ``(X, labels)`` where ``X`` stacks one feature vector per successful
        video (shape ``(0, 40)`` if none succeeded) and ``labels`` holds the
        matching label strings.
    """
    cache_file = _normalize_cache_path(cache_path) if cache_path else None

    # Check the cache before constructing the model/atlas so a hit skips
    # that setup entirely.
    if cache_file is not None:
        cached = _load_cache(cache_file, tr, verbose)
        if cached is not None:
            return cached

    root = Path(video_root)
    if tribe is None:
        tribe = TribeV2Wrapper()
    atlas = load_yeo_atlas()

    X_list, y_list = [], []
    for label in LABELS:
        folder = root / label
        if not folder.exists():
            print(f" ⚠ folder missing: {folder}")
            continue
        videos = sorted(folder.glob("*.mp4"))
        if verbose:
            print(f"\nProcessing {label}: {len(videos)} videos")
        for vp in videos:
            feats = process_video(str(vp), tribe, atlas, tr=tr, verbose=verbose)
            if feats is not None:
                X_list.append(feats)
                y_list.append(label)

    X = np.stack(X_list) if X_list else np.empty((0, _N_FEATURES), dtype=np.float32)

    # Only cache a non-empty dataset; `tr` is stored so a later run with a
    # different sampling interval does not silently reuse these features.
    if cache_file is not None and X_list:
        np.savez(cache_file, X=X, y_str=np.array(y_list), tr=np.float64(tr))
        if verbose:
            print(f"Saved cache → {cache_file}")
    return X, y_list


def _json_default(obj):
    """``json.dump`` fallback for NumPy values.

    Converts ndarray to a list and a NumPy scalar to the Python scalar;
    anything else raises ``TypeError``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def train_pipeline(video_root: str, model_out: str, report_out: str,
                   tr: float = 1.5, use_cache: bool = True,
                   classifier_params: Optional[Dict] = None) -> Dict:
    """Build the dataset, train the classifier, and save the model and report.

    Args:
        video_root: Directory with one sub-folder per label.
        model_out: Destination passed to ``ViralityClassifier.save``.
        report_out: Destination for the JSON report returned by ``fit``.
        tr: Sampling interval used for feature extraction.
        use_cache: Cache features in ``<video_root>/../brain_features.npz``.
        classifier_params: Keyword arguments for ``ViralityClassifier``.

    Returns:
        The report dict returned by ``ViralityClassifier.fit``.

    Raises:
        ValueError: If no video could be turned into features.
    """
    cache = Path(video_root).parent / "brain_features.npz" if use_cache else None
    # Forward `tr` so features use the caller's sampling interval rather than
    # build_dataset's default.
    X, y_str = build_dataset(video_root, tr=tr,
                             cache_path=str(cache) if cache else None)
    if X.shape[0] == 0:
        raise ValueError(
            f"No usable videos found under {video_root}; expected sub-folders "
            f"{list(LABELS)} containing *.mp4 files"
        )
    y = np.array([LABEL_TO_INT[label] for label in y_str], dtype=np.int32)

    print(f"\n{'═'*60}")
    print(f" Dataset: {X.shape[0]} videos × {X.shape[1]} features")
    for lbl in LABELS:
        n = (y == LABEL_TO_INT[lbl]).sum()
        print(f" {lbl}: {n}")
    print(f"{'═'*60}")

    clf = ViralityClassifier(**(classifier_params or {}))
    report = clf.fit(X, y, test_size=0.2, verbose=True)

    # Save
    clf.save(model_out)
    with open(report_out, "w", encoding="utf-8") as f:
        # The report may contain NumPy arrays/scalars; convert them explicitly.
        json.dump(report, f, indent=2, default=_json_default)
    print(f"Model saved → {model_out}")
    print(f"Report saved → {report_out}")
    return report