| """ |
| Random Forest classifier trained on 40 brain-derived features. |
| 3-class: good / okish / bad |
| Cross-validated accuracy target: 70–80% |
| """ |
|
|
| import json, pickle |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
|
|
| from sklearn.ensemble import RandomForestClassifier |
| from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split |
| from sklearn.metrics import classification_report, confusion_matrix |
|
|
| from .features import FEATURE_NAMES, feature_vector_to_dict |
|
|
|
|
# Class names; position in this list is the integer id the model trains on.
LABELS = ["good", "okish", "bad"]
# Reverse lookup: label name -> integer id (0-based, same order as LABELS).
LABEL_TO_INT = dict(zip(LABELS, range(len(LABELS))))
|
|
|
|
class ViralityClassifier:
    """Random-forest wrapper for the 40-feature brain pipeline.

    Integer class ids index into the module-level ``LABELS`` list
    (0 = "good", 1 = "okish", 2 = "bad").
    """

    def __init__(self, n_estimators: int = 200, max_depth: Optional[int] = 12,
                 random_state: int = 42, class_weight: str = "balanced"):
        """Store hyperparameters; the forest itself is only built by :meth:`fit`.

        Args:
            n_estimators: Number of trees in the forest.
            max_depth: Maximum depth of each tree (``None`` = unlimited).
            random_state: Seed shared by the train/test split, the CV fold
                shuffling and the forest itself.
            class_weight: Forwarded unchanged to ``RandomForestClassifier``.
        """
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray,
            test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """Train on a stratified split and evaluate on the held-out part.

        Args:
            X: Feature matrix, presumably ``(n_videos, 40)`` in
                ``FEATURE_NAMES`` order.
            y: Integer labels ``(n_videos,)``: 0=good, 1=okish, 2=bad.
            test_size: Fraction of samples held out for the test set.
            verbose: When true, print a human-readable summary.

        Returns:
            Report dict containing the 5-fold CV accuracy (mean/std) measured on
            the training split, held-out test accuracy, split sizes, per-class
            counts, the confusion matrix and sklearn's classification report.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )

        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        self.model.fit(X_train, y_train)

        # cross_val_score clones the estimator for every fold, so the forest
        # fitted above is not modified by this call.
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        cv_scores = cross_val_score(self.model, X_train, y_train, cv=cv, scoring="accuracy")

        test_acc = float(self.model.score(X_test, y_test))
        y_pred = self.model.predict(X_test)

        # Pin the full label set. Then the confusion matrix is always
        # len(LABELS) x len(LABELS), and target_names stays aligned even if a
        # class is missing from a small test split. Without labels=,
        # classification_report raises on such a mismatch.
        label_ids = list(range(len(LABELS)))

        report = {
            "cv_accuracy_mean": float(np.mean(cv_scores)),
            "cv_accuracy_std": float(np.std(cv_scores)),
            "test_accuracy": test_acc,
            "n_train": len(y_train),
            "n_test": len(y_test),
            "class_distribution_train": {lbl: int((y_train == i).sum()) for i, lbl in enumerate(LABELS)},
            "class_distribution_test": {lbl: int((y_test == i).sum()) for i, lbl in enumerate(LABELS)},
            "confusion_matrix": confusion_matrix(y_test, y_pred, labels=label_ids).tolist(),
            "classification_report": classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS,
                output_dict=True, zero_division=0,
            ),
        }

        if verbose:
            print(f"\n{'═'*60}")
            print(" Classifier training report")
            print(f"{'═'*60}")
            print(f" CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f" Test accuracy: {report['test_accuracy']:.1%}")
            print(f"{'─'*60}")
            print(classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS, zero_division=0,
            ))
            print(f"{'═'*60}")

        return report

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Classify a single feature vector.

        Args:
            x: One sample's features, presumably shape ``(40,)``. Any
                array-like that flattens to a single row is accepted.

        Returns:
            ``(label, confidence, probability_vector)``. The probability vector
            has one entry per ``LABELS`` item, in ``LABELS`` order.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        row = np.asarray(x).reshape(1, -1)
        raw = self.model.predict_proba(row)[0]
        # predict_proba columns follow model.classes_, which lists only the
        # classes seen during training. Scatter them into a full LABELS-length
        # vector so that index i always means LABELS[i].
        proba = np.zeros(len(LABELS), dtype=float)
        proba[np.asarray(self.model.classes_, dtype=int)] = raw
        idx = int(np.argmax(proba))
        return LABELS[idx], float(proba[idx]), proba

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """Return a JSON-friendly explanation of the prediction for *x*.

        Features are ranked by their raw value in *x*, not by their
        contribution to the prediction. The ``top_n`` largest values are
        reported as "positive" signals and the ``top_n`` smallest as
        "negative" signals. The forest's global feature importances are
        included when available.

        Args:
            x: One sample's feature vector (see :meth:`predict`).
            top_n: How many features to list on each side. Negative values
                are treated as 0.

        Returns:
            Dict with ``verdict``, ``confidence``, per-label ``probabilities``,
            ``strongest_positive_signals``, ``strongest_negative_signals`` and
            ``feature_importance`` (``None`` if the model exposes none).
        """
        x = np.asarray(x)
        label, confidence, proba = self.predict(x)
        feat_dict = feature_vector_to_dict(x)

        importance = None
        if self.model is not None and hasattr(self.model, "feature_importances_"):
            importance = dict(zip(self.feature_names, self.model.feature_importances_.tolist()))

        ranked = sorted(feat_dict.items(), key=lambda kv: kv[1], reverse=True)
        n = max(int(top_n), 0)
        strongest_pos = ranked[:n]
        # Use an explicit start index: ranked[-0:] would return the whole list.
        strongest_neg = ranked[max(len(ranked) - n, 0):]

        return {
            "verdict": label,
            "confidence": confidence,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, proba)},
            # float() keeps numpy scalars out of the dict so it stays JSON-serializable.
            "strongest_positive_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_pos
            ],
            "strongest_negative_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_neg
            ],
            "feature_importance": importance,
        }

    def save(self, path: str):
        """Pickle the fitted model plus the hyperparameters needed to rebuild the wrapper.

        Parent directories of *path* are created if missing.
        """
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
        """Rebuild a classifier from a file written by :meth:`save`.

        Security: ``pickle.load`` can execute arbitrary code, so only load
        files from a trusted source.
        """
        with open(path, "rb") as f:
            payload = pickle.load(f)
        inst = cls(**payload["params"])
        inst.model = payload["model"]
        return inst
|
|
|
|
def label_distribution(labels: List[str]) -> Dict[str, int]:
    """Count how many times each known label occurs in *labels*.

    Counting is done in a single pass, so any iterable works (list, tuple,
    generator, numpy array of str), not only lists.

    Args:
        labels: Label strings. Values not present in ``LABELS`` are ignored.

    Returns:
        Dict mapping every entry of ``LABELS``, in ``LABELS`` order, to its
        count. Absent labels map to 0.
    """
    from collections import Counter

    counts = Counter(labels)
    return {lbl: counts[lbl] for lbl in LABELS}
|
|