moonlantern1 committed on
Commit
e7b3bfa
·
verified ·
1 Parent(s): ae73961

Upload brain_virality_predictor/classifier.py with huggingface_hub

Browse files
brain_virality_predictor/classifier.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Random Forest classifier trained on 40 brain-derived features.
3
+ 3-class: good / okish / bad
4
+ Cross-validated accuracy target: 70–80%
5
+ """
6
+
7
+ import json, pickle
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+ import numpy as np
11
+
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
14
+ from sklearn.metrics import classification_report, confusion_matrix
15
+
16
+ from .features import FEATURE_NAMES, feature_vector_to_dict
17
+
18
+
19
# Canonical class names; index position doubles as the integer class label
# used throughout training and prediction (0=good, 1=okish, 2=bad).
LABELS = ["good", "okish", "bad"]
# Reverse lookup: label string -> integer class index.
LABEL_TO_INT = dict(zip(LABELS, range(len(LABELS))))
21
+
22
+
23
class ViralityClassifier:
    """Wraps sklearn RandomForest with our 40-feature brain pipeline."""

    def __init__(self, n_estimators: int = 200, max_depth: Optional[int] = 12,
                 random_state: int = 42, class_weight: str = "balanced"):
        # Hyperparameters are kept as plain attributes so save()/load()
        # can round-trip them without pickling the wrapper itself.
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray,
            test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """
        Train the forest and return a metrics report.

        X: (n_videos, 40) feature matrix
        y: (n_videos,) integer labels 0=good,1=okish,2=bad
        Returns training report dict (CV + held-out accuracy, per-class stats).
        """
        # Stratified split so the held-out slice keeps the class balance of y.
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )

        forest = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        forest.fit(X_tr, y_tr)
        self.model = forest

        # 5-fold CV on the training portion; cross_val_score refits clones,
        # so the model fitted above is left untouched.
        folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        fold_acc = cross_val_score(forest, X_tr, y_tr, cv=folds, scoring="accuracy")

        y_hat = forest.predict(X_te)
        report = {
            "cv_accuracy_mean": float(np.mean(fold_acc)),
            "cv_accuracy_std": float(np.std(fold_acc)),
            "test_accuracy": float(forest.score(X_te, y_te)),
            "n_train": len(y_tr),
            "n_test": len(y_te),
            "class_distribution_train": {lbl: int((y_tr == i).sum()) for i, lbl in enumerate(LABELS)},
            "class_distribution_test": {lbl: int((y_te == i).sum()) for i, lbl in enumerate(LABELS)},
            "confusion_matrix": confusion_matrix(y_te, y_hat).tolist(),
            "classification_report": classification_report(
                y_te, y_hat, target_names=LABELS, output_dict=True
            ),
        }

        if verbose:
            heavy = "═" * 60
            light = "─" * 60
            print(f"\n{heavy}")
            print(" Classifier training report")
            print(heavy)
            print(f" CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f" Test accuracy: {report['test_accuracy']:.1%}")
            print(light)
            print(classification_report(y_te, y_hat, target_names=LABELS))
            print(heavy)

        return report

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """
        Classify a single 40-dim feature vector.

        x: (40,) feature vector
        Returns: (label, confidence, probability_vector)
        Raises RuntimeError when called before fit()/load().
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        probs = self.model.predict_proba(x.reshape(1, -1))[0]
        best = int(np.argmax(probs))
        return LABELS[best], float(probs[best]), probs

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """
        Build a human-readable breakdown of a single prediction:
        verdict, class probabilities, the highest/lowest feature values,
        and (when available) the forest's global feature importances.
        """
        verdict, conf, probs = self.predict(x)
        values = feature_vector_to_dict(x)

        # Global importances are only present once the forest is fitted.
        importance = None
        if self.model is not None and hasattr(self.model, "feature_importances_"):
            importance = dict(zip(self.feature_names, self.model.feature_importances_.tolist()))

        # Rank raw feature values, largest first; ends of the ranking give
        # the strongest high/low signals for this video.
        ranked = sorted(values.items(), key=lambda kv: kv[1], reverse=True)

        return {
            "verdict": verdict,
            "confidence": conf,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, probs)},
            "strongest_positive_signals": [
                {"feature": name, "value": round(val, 4)} for name, val in ranked[:top_n]
            ],
            "strongest_negative_signals": [
                {"feature": name, "value": round(val, 4)} for name, val in ranked[-top_n:]
            ],
            "feature_importance": importance,
        }

    def save(self, path: str):
        """Pickle model + hyperparameters."""
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as fh:
            pickle.dump(payload, fh)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
        """Restore a classifier previously written by save().

        NOTE(review): pickle.load executes arbitrary code — only load
        files from trusted sources.
        """
        with open(path, "rb") as fh:
            payload = pickle.load(fh)
        restored = cls(**payload["params"])
        restored.model = payload["model"]
        return restored
157
+
158
+
159
def label_distribution(labels: List[str]) -> Dict[str, int]:
    """Tally how often each canonical label occurs in *labels*.

    Entries outside LABELS are ignored; every canonical label appears
    in the result (with 0 when absent), in LABELS order.
    """
    tally = {lbl: 0 for lbl in LABELS}
    for entry in labels:
        if entry in tally:
            tally[entry] += 1
    return tally