"""
Random Forest classifier trained on 40 brain-derived features.

3-class: good / okish / bad
Cross-validated accuracy target: 70–80%
"""

import json
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.metrics import classification_report, confusion_matrix

from .features import FEATURE_NAMES, feature_vector_to_dict

LABELS = ["good", "okish", "bad"]
LABEL_TO_INT = {lbl: i for i, lbl in enumerate(LABELS)}


class ViralityClassifier:
    """Wraps sklearn RandomForest with our 40-feature brain pipeline.

    Labels are integer ids indexing into ``LABELS`` (0=good, 1=okish, 2=bad).
    """

    def __init__(
        self,
        n_estimators: int = 200,
        max_depth: Optional[int] = 12,
        random_state: int = 42,
        class_weight: Optional[str] = "balanced",
    ):
        """Store RandomForest hyperparameters; the model is built in ``fit``."""
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray, test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """
        Train the forest on a stratified train split and evaluate it.

        X: (n_videos, 40) feature matrix
        y: (n_videos,) integer labels 0=good, 1=okish, 2=bad
        test_size: fraction of samples held out for the test split
        verbose: print a human-readable summary when True

        Returns a training report dict: 5-fold CV accuracy (mean/std, computed
        on the training split only), held-out test accuracy, split sizes,
        per-split class counts, a confusion matrix whose rows/columns follow
        LABELS order, and sklearn's classification report as a dict.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )

        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        self.model.fit(X_train, y_train)

        # cross-validation: cross_val_score clones the estimator and refits it
        # per fold on X_train only, so the test split is never seen here.
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        cv_scores = cross_val_score(self.model, X_train, y_train, cv=cv, scoring="accuracy")

        test_acc = float(self.model.score(X_test, y_test))
        y_pred = self.model.predict(X_test)

        # Pin the full label set: without it the confusion matrix shrinks and
        # classification_report raises (target_names length mismatch) whenever
        # a class is absent from y_test and y_pred.
        label_ids = list(range(len(LABELS)))

        report = {
            "cv_accuracy_mean": float(np.mean(cv_scores)),
            "cv_accuracy_std": float(np.std(cv_scores)),
            "test_accuracy": test_acc,
            "n_train": len(y_train),
            "n_test": len(y_test),
            "class_distribution_train": {lbl: int((y_train == i).sum()) for i, lbl in enumerate(LABELS)},
            "class_distribution_test": {lbl: int((y_test == i).sum()) for i, lbl in enumerate(LABELS)},
            "confusion_matrix": confusion_matrix(y_test, y_pred, labels=label_ids).tolist(),
            "classification_report": classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS,
                output_dict=True, zero_division=0,
            ),
        }

        if verbose:
            print(f"\n{'═'*60}")
            print(f" Classifier training report")
            print(f"{'═'*60}")
            print(f" CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f" Test accuracy: {report['test_accuracy']:.1%}")
            print(f"{'─'*60}")
            print(classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS, zero_division=0,
            ))
            print(f"{'═'*60}")

        return report

    def _align_proba(self, raw_proba: np.ndarray) -> np.ndarray:
        """Scatter predict_proba columns (ordered by model.classes_) into a
        len(LABELS) vector indexed by label id; classes unseen in training get 0."""
        full = np.zeros(len(LABELS), dtype=float)
        full[np.asarray(self.model.classes_).astype(int)] = raw_proba
        return full

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """
        Classify a single feature vector.

        x: (40,) feature vector (any array-like is accepted)
        Returns: (label, confidence, probability_vector), where
        probability_vector has one entry per LABELS item, in LABELS order.

        Raises RuntimeError if the model has not been trained.
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        raw = self.model.predict_proba(np.asarray(x).reshape(1, -1))[0]
        # predict_proba columns follow model.classes_, not LABELS; realign so
        # the argmax index is a real label id even if a class was missing
        # from the training data.
        proba = self._align_proba(raw)
        idx = int(np.argmax(proba))
        label = LABELS[idx]
        confidence = float(proba[idx])
        return label, confidence, proba

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """
        Return a human-readable explanation of the prediction for x.

        The "strongest_positive_signals" / "strongest_negative_signals" lists
        are the top_n highest / lowest raw feature VALUES of x (not per-feature
        model attributions); both lists are in descending value order.
        "feature_importance" is the forest's global impurity-based importance.
        """
        label, confidence, proba = self.predict(x)
        feat_dict = feature_vector_to_dict(x)

        # feature importance if model fitted
        importance = None
        if self.model is not None and hasattr(self.model, "feature_importances_"):
            importance = dict(zip(self.feature_names, self.model.feature_importances_.tolist()))

        # strongest signal values
        top_n = max(int(top_n), 0)
        sorted_feats = sorted(feat_dict.items(), key=lambda kv: kv[1], reverse=True)
        strongest_pos = sorted_feats[:top_n]
        # Explicit start index: sorted_feats[-top_n:] would return every
        # feature when top_n == 0.
        strongest_neg = sorted_feats[max(len(sorted_feats) - top_n, 0):]

        return {
            "verdict": label,
            "confidence": confidence,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, proba)},
            # float() keeps the output JSON-serialisable for numpy scalars.
            "strongest_positive_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_pos
            ],
            "strongest_negative_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_neg
            ],
            "feature_importance": importance,
        }

    def save(self, path: str):
        """Pickle model + hyperparameters to path, creating parent dirs."""
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
        """Rebuild a classifier from a file written by ``save``.

        SECURITY: pickle.load can execute arbitrary code — only load files
        from trusted sources.
        """
        with open(path, "rb") as f:
            payload = pickle.load(f)
        inst = cls(**payload["params"])
        inst.model = payload["model"]
        return inst


def label_distribution(labels: List[str]) -> Dict[str, int]:
    """Count occurrences of each LABELS entry in labels (other values ignored)."""
    return {lbl: labels.count(lbl) for lbl in LABELS}