Upload brain_virality_predictor/classifier.py with huggingface_hub
Browse files
brain_virality_predictor/classifier.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Random Forest classifier trained on 40 brain-derived features.
|
| 3 |
+
3-class: good / okish / bad
|
| 4 |
+
Cross-validated accuracy target: 70–80%
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json, pickle
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 13 |
+
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
|
| 14 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 15 |
+
|
| 16 |
+
from .features import FEATURE_NAMES, feature_vector_to_dict
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Canonical label order; the position in this list doubles as the integer
# class id used everywhere else (0 = good, 1 = okish, 2 = bad).
LABELS = ["good", "okish", "bad"]

# Inverse mapping: label string -> integer class id.
LABEL_TO_INT = {label: position for position, label in enumerate(LABELS)}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ViralityClassifier:
    """Wraps sklearn RandomForest with our 40-feature brain pipeline.

    Class ids follow the module-level ``LABELS`` order
    (0 = good, 1 = okish, 2 = bad).
    """

    def __init__(self, n_estimators: int = 200, max_depth: Optional[int] = 12,
                 random_state: int = 42, class_weight: str = "balanced"):
        """Store hyperparameters only; the forest itself is built in fit().

        Args:
            n_estimators: Number of trees in the forest.
            max_depth: Depth cap per tree (``None`` grows trees fully).
            random_state: Seed shared by the train/test split, the CV folds
                and the forest, so runs are reproducible.
            class_weight: sklearn class-weight strategy; "balanced"
                counteracts label imbalance in the training data.
        """
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray,
            test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """Train the forest on a stratified split and report metrics.

        Args:
            X: (n_videos, 40) feature matrix (any array-like is accepted).
            y: (n_videos,) integer labels 0=good, 1=okish, 2=bad.
            test_size: Fraction of the data held out for the test split.
            verbose: Print a human-readable report when True.

        Returns:
            Training report dict with CV/test accuracy, per-class counts,
            the confusion matrix and sklearn's classification report.
        """
        # Accept plain Python lists as well as arrays: the (y == i).sum()
        # class-distribution counts below need ndarray comparison semantics.
        X = np.asarray(X)
        y = np.asarray(y)

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )

        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        self.model.fit(X_train, y_train)

        # 5-fold CV on the training split only; cross_val_score clones the
        # estimator, so the fitted model above is left untouched.
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        cv_scores = cross_val_score(self.model, X_train, y_train, cv=cv, scoring="accuracy")

        test_acc = float(self.model.score(X_test, y_test))
        y_pred = self.model.predict(X_test)

        report = {
            "cv_accuracy_mean": float(np.mean(cv_scores)),
            "cv_accuracy_std": float(np.std(cv_scores)),
            "test_accuracy": test_acc,
            "n_train": len(y_train),
            "n_test": len(y_test),
            "class_distribution_train": {lbl: int((y_train == i).sum()) for i, lbl in enumerate(LABELS)},
            "class_distribution_test": {lbl: int((y_test == i).sum()) for i, lbl in enumerate(LABELS)},
            "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
            "classification_report": classification_report(
                y_test, y_pred, target_names=LABELS, output_dict=True
            ),
        }

        if verbose:
            print(f"\n{'═'*60}")
            print(f" Classifier training report")
            print(f"{'═'*60}")
            print(f" CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f" Test accuracy: {report['test_accuracy']:.1%}")
            print(f"{'─'*60}")
            print(classification_report(y_test, y_pred, target_names=LABELS))
            print(f"{'═'*60}")

        return report

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Classify one video from its feature vector.

        Args:
            x: (40,) feature vector (any array-like is accepted).

        Returns:
            (label, confidence, probability_vector) — confidence is the
            winning class's probability; the vector follows LABELS order.

        Raises:
            RuntimeError: If called before fit() (or load()).
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        # Coerce so plain lists don't crash on .reshape below.
        x = np.asarray(x, dtype=float)
        proba = self.model.predict_proba(x.reshape(1, -1))[0]
        idx = int(np.argmax(proba))
        label = LABELS[idx]
        confidence = float(proba[idx])
        return label, confidence, proba

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """Return a human-readable explanation of the prediction for *x*.

        Includes the verdict, per-class probabilities, the ``top_n``
        highest/lowest raw feature values, and (once fitted) the forest's
        global feature importances.
        """
        label, confidence, proba = self.predict(x)
        feat_dict = feature_vector_to_dict(x)

        # Global (model-wide) importances, available once fitted.
        importance = None
        if self.model is not None and hasattr(self.model, "feature_importances_"):
            importance = dict(zip(self.feature_names, self.model.feature_importances_.tolist()))

        # Rank raw feature values, highest first; the "negative" slice is the
        # tail of the same ranking (lowest values, most negative last).
        sorted_feats = sorted(feat_dict.items(), key=lambda kv: kv[1], reverse=True)
        strongest_pos = sorted_feats[:top_n]
        strongest_neg = sorted_feats[-top_n:]

        explanation = {
            "verdict": label,
            "confidence": confidence,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, proba)},
            # float(...) keeps the payload JSON-serializable even when the
            # feature values arrive as numpy scalars.
            "strongest_positive_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_pos
            ],
            "strongest_negative_signals": [
                {"feature": k, "value": round(float(v), 4)} for k, v in strongest_neg
            ],
            "feature_importance": importance,
        }
        return explanation

    def save(self, path: str):
        """Pickle the fitted model + hyperparameters to *path*.

        Raises:
            RuntimeError: If the model has not been trained — a payload
                with ``model=None`` would load into an unusable instance.
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
        """Rebuild a classifier from a file written by save().

        SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        file — only load model files from trusted sources.
        """
        with open(path, "rb") as f:
            payload = pickle.load(f)
        inst = cls(**payload["params"])
        inst.model = payload["model"]
        return inst
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def label_distribution(labels: List[str]) -> Dict[str, int]:
    """Count occurrences of each known label in *labels*.

    Returns a dict keyed in ``LABELS`` order; strings not in ``LABELS``
    are ignored (they simply aren't keys in the result), matching the
    original per-label counting behavior.

    Uses a single O(n) Counter pass instead of one O(n) ``list.count``
    scan per label.
    """
    from collections import Counter  # local import: module deps unchanged

    counts = Counter(labels)
    return {lbl: counts[lbl] for lbl in LABELS}
|