"""
Random Forest classifier trained on 40 brain-derived features.
3-class: good / okish / bad
Cross-validated accuracy target: 70–80%
"""

import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.metrics import classification_report, confusion_matrix

from .features import FEATURE_NAMES, feature_vector_to_dict


LABELS = ["good", "okish", "bad"]
LABEL_TO_INT = {lbl: i for i, lbl in enumerate(LABELS)}


class ViralityClassifier:
    """Wraps sklearn RandomForest with our 40-feature brain pipeline."""

    def __init__(self, n_estimators: int = 200, max_depth: Optional[int] = 12,
                 random_state: int = 42, class_weight: str = "balanced"):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.class_weight = class_weight
        self.model: Optional[RandomForestClassifier] = None
        self.feature_names = FEATURE_NAMES

    def fit(self, X: np.ndarray, y: np.ndarray,
            test_size: float = 0.2,
            verbose: bool = True) -> Dict:
        """
        X: (n_videos, 40)
        y: (n_videos,) integer labels 0=good,1=okish,2=bad
        Returns training report dict.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=self.random_state
        )

        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,
        )
        self.model.fit(X_train, y_train)

        # cross-validation
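        # (run on the training split only; cross_val_score clones and refits the
        # estimator per fold, so the already-fitted self.model is left untouched)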
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.random_state)
        cv_scores = cross_val_score(self.model, X_train, y_train, cv=cv, scoring="accuracy")

        test_acc = float(self.model.score(X_test, y_test))
        y_pred = self.model.predict(X_test)

        report = {
            "cv_accuracy_mean": float(np.mean(cv_scores)),
            "cv_accuracy_std": float(np.std(cv_scores)),
            "test_accuracy": test_acc,
            "n_train": len(y_train),
            "n_test": len(y_test),
            "class_distribution_train": {lbl: int((y_train == i).sum()) for i, lbl in enumerate(LABELS)},
            "class_distribution_test": {lbl: int((y_test == i).sum()) for i, lbl in enumerate(LABELS)},
            "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
            "classification_report": classification_report(
                y_test, y_pred, target_names=LABELS, output_dict=True
            ),
        }

        if verbose:
            print(f"\n{'═'*60}")
            print(f"  Classifier training report")
            print(f"{'═'*60}")
            print(f"  CV accuracy : {report['cv_accuracy_mean']:.1%} ± {report['cv_accuracy_std']:.1%}")
            print(f"  Test accuracy: {report['test_accuracy']:.1%}")
            print(f"{'─'*60}")
            print(classification_report(y_test, y_pred, target_names=LABELS))
            print(f"{'═'*60}")

        return report

    def predict(self, x: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """
        x: (40,) feature vector
        Returns: (label, confidence, probability_vector)
        """
        if self.model is None:
            raise RuntimeError("Model not trained — call .fit() first.")
        proba = self.model.predict_proba(x.reshape(1, -1))[0]
        idx = int(np.argmax(proba))
        label = LABELS[idx]
        confidence = float(proba[idx])
        return label, confidence, proba

    def explain(self, x: np.ndarray, top_n: int = 3) -> Dict:
        """
        Return a human-readable explanation of the prediction.
        Includes the strongest positive/negative signals (highest- and
        lowest-valued raw features) plus global feature importances.
        """
        label, confidence, proba = self.predict(x)
        feat_dict = feature_vector_to_dict(x)

        # feature importance if model fitted
        importance = None
        if self.model is not None and hasattr(self.model, "feature_importances_"):
            importance = dict(zip(self.feature_names, self.model.feature_importances_.tolist()))

        # rank raw feature values (descending): head = highest, tail = lowest
        sorted_feats = sorted(feat_dict.items(), key=lambda kv: kv[1], reverse=True)
        strongest_pos = sorted_feats[:top_n]
        strongest_neg = sorted_feats[-top_n:]

        explanation = {
            "verdict": label,
            "confidence": confidence,
            "probabilities": {lbl: float(p) for lbl, p in zip(LABELS, proba)},
            "strongest_positive_signals": [
                {"feature": k, "value": round(v, 4)} for k, v in strongest_pos
            ],
            "strongest_negative_signals": [
                {"feature": k, "value": round(v, 4)} for k, v in strongest_neg
            ],
            "feature_importance": importance,
        }
        return explanation

    def save(self, path: str):
        """Pickle model + hyperparameters."""
        payload = {
            "model": self.model,
            "params": {
                "n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "random_state": self.random_state,
                "class_weight": self.class_weight,
            },
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        print(f"Model saved → {path}")

    @classmethod
    def load(cls, path: str) -> "ViralityClassifier":
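        """Restore a classifier saved with .save(); only unpickle files you trust."""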
        with open(path, "rb") as f:
            payload = pickle.load(f)
        inst = cls(**payload["params"])
        inst.model = payload["model"]
        return inst


def label_distribution(labels: List[str]) -> Dict[str, int]:
    return {lbl: labels.count(lbl) for lbl in LABELS}
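

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): fits on synthetic random
    # features and labels instead of the real 40-feature brain pipeline, so the
    # reported accuracies are meaningless; it just exercises the API end to end.
    # Run it as a module (python -m <package>.classifier) so the relative
    # import of .features resolves.
    import os
    import tempfile

    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(300, len(FEATURE_NAMES)))
    y_demo = rng.integers(0, len(LABELS), size=300)

    clf = ViralityClassifier(n_estimators=50)
    clf.fit(X_demo, y_demo, verbose=True)

    label, confidence, proba = clf.predict(X_demo[0])
    print(f"demo prediction: {label} ({confidence:.1%})")

    tmp_path = os.path.join(tempfile.mkdtemp(), "virality_rf_demo.pkl")
    clf.save(tmp_path)
    restored = ViralityClassifier.load(tmp_path)
    assert restored.predict(X_demo[0])[0] == label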