# Provenance (Hugging Face Hub page header): uploaded by moonlantern1
# Commit e7b3bfa (verified): "Upload brain_virality_predictor/classifier.py with huggingface_hub"
"""
Random Forest classifier trained on 40 brain-derived features.
3-class: good / okish / bad
Cross-validated accuracy target: 70–80%
"""
import json, pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from .features import FEATURE_NAMES, feature_vector_to_dict
# Class names in integer-label order: a name's position in this list is the
# integer id used for it in the training targets.
LABELS = ["good", "okish", "bad"]

# Reverse lookup: class name -> integer id.
LABEL_TO_INT = dict(zip(LABELS, range(len(LABELS))))
class ViralityClassifier:
    """Wraps sklearn RandomForest with our 40-feature brain pipeline.

    Targets are integer class ids that index into ``LABELS``
    (0=good, 1=okish, 2=bad). ``self.model`` stays ``None`` until
    :meth:`fit` or :meth:`load` provides a fitted forest.
    """

    def __init__(self, n_estimators: int = 200, max_depth: Optional[int] = 12,
                 random_state: int = 42, class_weight: str = "balanced"):
        """Store the forest hyperparameters; no model is built until ``fit``."""
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray,
            test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """Train the forest on a stratified split and report its quality.

        Args:
            X: (n_videos, 40) feature matrix.
            y: (n_videos,) integer labels 0=good, 1=okish, 2=bad.
            test_size: Fraction of samples held out for the test metrics.
            verbose: When true, print a human-readable training report.

        Returns:
            Training report dict: CV accuracy mean/std, held-out accuracy,
            split sizes, per-class counts, confusion matrix and the
            per-class classification report.
        """
        # The per-class counts below rely on element-wise ``==``, which a
        # plain Python list does not provide.
        y = np.asarray(y)
        # Explicit label ids keep the confusion matrix and report aligned with
        # LABELS even when a class is missing from the held-out split.
        label_ids = list(range(len(LABELS)))

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )
        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        self.model.fit(X_train, y_train)

        # Cross-validation on the training portion only; cross_val_score fits
        # clones of the estimator, so the model fitted above is left untouched.
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        cv_scores = cross_val_score(self.model, X_train, y_train, cv=cv, scoring="accuracy")

        # Predict once and derive accuracy from the same predictions.
        y_pred = self.model.predict(X_test)
        test_acc = float(np.mean(y_pred == y_test))

        report = {
            "cv_accuracy_mean": float(np.mean(cv_scores)),
            "cv_accuracy_std": float(np.std(cv_scores)),
            "test_accuracy": test_acc,
            "n_train": len(y_train),
            "n_test": len(y_test),
            "class_distribution_train": {
                lbl: int((y_train == i).sum()) for i, lbl in enumerate(LABELS)
            },
            "class_distribution_test": {
                lbl: int((y_test == i).sum()) for i, lbl in enumerate(LABELS)
            },
            "confusion_matrix": confusion_matrix(
                y_test, y_pred, labels=label_ids
            ).tolist(),
            "classification_report": classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS,
                output_dict=True,
            ),
        }
        if verbose:
            print(f"\n{'═'*60}")
            print(" Classifier training report")
            print(f"{'═'*60}")
            print(f" CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f" Test accuracy: {report['test_accuracy']:.1%}")
            print(f"{'─'*60}")
            print(classification_report(
                y_test, y_pred, labels=label_ids, target_names=LABELS
            ))
            print(f"{'═'*60}")
        return report

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Classify a single feature vector.

        Args:
            x: (40,) feature vector.

        Returns:
            ``(label, confidence, probability_vector)`` where the probability
            vector has one entry per name in ``LABELS``, in that order.

        Raises:
            RuntimeError: If no model has been fitted or loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        raw = self.model.predict_proba(np.asarray(x).reshape(1, -1))[0]
        # predict_proba columns follow model.classes_, which omits any class
        # absent from the training data. Scatter into a LABELS-aligned vector
        # so that column position always means the same class.
        proba = np.zeros(len(LABELS), dtype=float)
        proba[np.asarray(self.model.classes_, dtype=int)] = raw
        idx = int(np.argmax(proba))
        label = LABELS[idx]
        confidence = float(proba[idx])
        return label, confidence, proba

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """Return a human-readable explanation of the prediction for ``x``.

        Args:
            x: (40,) feature vector.
            top_n: How many of the highest- and lowest-valued features to
                list. Values <= 0 yield empty signal lists.

        Returns:
            Dict with the verdict, its confidence, per-class probabilities,
            the ``top_n`` largest and smallest feature values, and the
            forest's feature importances (``None`` if unavailable).

        Raises:
            RuntimeError: If no model has been fitted or loaded.
        """
        label, confidence, proba = self.predict(x)
        feat_dict = feature_vector_to_dict(x)

        # predict() has already rejected a missing model, so only the
        # attribute check is needed here.
        importance = None
        if hasattr(self.model, "feature_importances_"):
            importance = dict(
                zip(self.feature_names, self.model.feature_importances_.tolist())
            )

        # Features ranked by raw value, largest first.
        ranked = sorted(feat_dict.items(), key=lambda kv: kv[1], reverse=True)
        # Guard the tail slice: ranked[-0:] would return the whole list.
        top_n = max(int(top_n), 0)
        strongest_pos = ranked[:top_n]
        strongest_neg = ranked[-top_n:] if top_n else []

        explanation = {
            "verdict": label,
            "confidence": confidence,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, proba)},
            # float() keeps NumPy scalar types out of the report.
            "strongest_positive_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_pos
            ],
            "strongest_negative_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_neg
            ],
            "feature_importance": importance,
        }
        return explanation

    def save(self, path: str) -> None:
        """Pickle model + hyperparameters to ``path``, creating parent dirs."""
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
        """Rebuild a classifier from a file written by :meth:`save`.

        SECURITY: unpickling executes arbitrary code embedded in the file.
        Only load model files that come from a trusted source.
        """
        with open(path, "rb") as f:
            payload = pickle.load(f)
        inst = cls(**payload["params"])
        inst.model = payload["model"]
        return inst
def label_distribution(labels: List[str]) -> Dict[str, int]:
    """Count how often each class name from ``LABELS`` occurs in ``labels``.

    Every name in ``LABELS`` gets an entry (0 when absent); entries of
    ``labels`` that are not in ``LABELS`` are not reported.
    """
    counts: Dict[str, int] = {}
    for name in LABELS:
        counts[name] = labels.count(name)
    return counts