# Provenance: uploaded to Hugging Face Hub as brain_virality_predictor/trainer.py
# by moonlantern1 (revision 8c2a812, verified).
"""
End-to-end training pipeline:
1. Load downloaded videos from labeled folders
2. Run TRIBE v2 (or synthetic fallback) β†’ brain response
3. Extract Yeo network time series
4. Distill to 40 features (8 signals Γ— 5 temporal stats)
5. Train Random Forest classifier
6. Save model + features + report
"""
import json, pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import cv2
from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
from .signals import network_to_ux_signals
from .features import extract_features, FEATURE_NAMES
from .classifier import ViralityClassifier, LABELS, LABEL_TO_INT
def process_video(video_path: str, tribe: TribeV2Wrapper, atlas,
                  tr: float = 1.5, verbose: bool = False) -> Optional[np.ndarray]:
    """
    Turn one video file into a (40,) feature vector.

    Stages: TRIBE brain prediction -> Yeo network mean time series ->
    UX signals -> temporal feature extraction. Best-effort: returns None
    when the file is missing, OpenCV cannot decode it, or any stage raises.
    """
    path = Path(video_path)
    if not path.exists():
        if verbose:
            print(f" βœ— missing: {path}")
        return None
    # Sanity check: confirm OpenCV can actually open the file before the
    # (expensive) TRIBE forward pass.
    capture = cv2.VideoCapture(str(path))
    openable = capture.isOpened()
    capture.release()
    if not openable:
        if verbose:
            print(f" βœ— unreadable: {path.name}")
        return None
    try:
        brain = tribe.predict_brain(str(path), tr=tr)
        networks = tribe.network_means(brain, atlas)
        ux = network_to_ux_signals(networks)
        features = extract_features(ux, tr=tr)
    except Exception as exc:
        # Deliberate catch-all: a single bad video must not abort the dataset build.
        if verbose:
            print(f" βœ— error on {path.name}: {exc}")
        return None
    if verbose:
        print(f" βœ“ {path.name} β†’ {features.shape}")
    return features
def build_dataset(video_root: str, tribe: Optional[TribeV2Wrapper] = None,
                  tr: float = 1.5, verbose: bool = True,
                  cache_path: Optional[str] = None) -> Tuple[np.ndarray, List[str]]:
    """
    Scan video_root/{good,okish,bad}/*.mp4, process each video, and return
    (feature matrix X, label strings y).

    Parameters
    ----------
    video_root : directory containing one sub-folder per label in LABELS.
    tribe      : optional pre-built TRIBE wrapper; a default one is created
                 when None.
    tr         : repetition time forwarded to process_video.
    verbose    : print per-video progress.
    cache_path : optional .npz path; when it exists it is loaded instead of
                 re-processing, otherwise the result is saved there.

    Returns
    -------
    X : (n_videos, len(FEATURE_NAMES)) float array.
    y : list of label strings aligned with rows of X.
    """
    video_root = Path(video_root)
    if tribe is None:
        tribe = TribeV2Wrapper()
    atlas = load_yeo_atlas()

    # Cache check. NOTE(review): the cache key does not include `tr`, so
    # features cached at a different TR would be reused silently — delete the
    # cache file when changing tr.
    if cache_path and Path(cache_path).exists():
        data = np.load(cache_path, allow_pickle=True)
        X, y_str = data["X"], data["y_str"].tolist()
        if verbose:
            print(f"Loaded cached dataset: {X.shape} from {cache_path}")
        return X, y_str

    X_list, y_list = [], []
    for label in LABELS:
        folder = video_root / label
        if not folder.exists():
            print(f" ⚠ folder missing: {folder}")
            continue
        videos = sorted(folder.glob("*.mp4"))
        if verbose:
            print(f"\nProcessing {label}: {len(videos)} videos")
        for vp in videos:
            feats = process_video(str(vp), tribe, atlas, tr=tr, verbose=verbose)
            if feats is not None:
                X_list.append(feats)
                y_list.append(label)

    # Use len(FEATURE_NAMES) instead of a hard-coded 40 so the empty-dataset
    # shape stays consistent with the feature extractor.
    X = np.stack(X_list) if X_list else np.empty((0, len(FEATURE_NAMES)), dtype=np.float32)
    if cache_path:
        np.savez(cache_path, X=X, y_str=np.array(y_list))
        if verbose:
            print(f"Saved cache β†’ {cache_path}")
    return X, y_list
def train_pipeline(video_root: str, model_out: str, report_out: str,
                   tr: float = 1.5, use_cache: bool = True,
                   classifier_params: Optional[Dict] = None) -> Dict:
    """
    Full pipeline: build dataset β†’ train classifier β†’ save model + JSON report.

    Parameters
    ----------
    video_root        : root folder with per-label video sub-folders.
    model_out         : path the trained ViralityClassifier is saved to.
    report_out        : path the JSON training report is written to.
    tr                : repetition time forwarded to feature extraction.
    use_cache         : reuse/save a brain_features.npz next to video_root.
    classifier_params : optional kwargs for ViralityClassifier.

    Returns
    -------
    The training report dict produced by clf.fit.
    """
    cache = Path(video_root).parent / "brain_features.npz" if use_cache else None
    # BUG FIX: `tr` was previously not forwarded, so a non-default TR was
    # silently ignored during feature extraction.
    X, y_str = build_dataset(video_root, tr=tr,
                             cache_path=str(cache) if cache else None)
    y = np.array([LABEL_TO_INT[l] for l in y_str], dtype=np.int32)

    print(f"\n{'═'*60}")
    print(f" Dataset: {X.shape[0]} videos Γ— {X.shape[1]} features")
    for lbl in LABELS:
        n = (y == LABEL_TO_INT[lbl]).sum()
        print(f" {lbl}: {n}")
    print(f"{'═'*60}")

    if classifier_params is None:
        classifier_params = {}
    clf = ViralityClassifier(**classifier_params)
    report = clf.fit(X, y, test_size=0.2, verbose=True)

    # Save model and report. NOTE(review): json.dump assumes clf.fit returns
    # only JSON-serializable values (no raw numpy scalars) — confirm in
    # ViralityClassifier.
    clf.save(model_out)
    with open(report_out, "w") as f:
        json.dump(report, f, indent=2)
    print(f"Model saved β†’ {model_out}")
    print(f"Report saved β†’ {report_out}")
    return report