Spaces:
Build error
Build error
Upload 7 files
Browse files- src/classical_classifiers.py +221 -0
- src/imbalance_handling.py +219 -0
- src/main.py +544 -0
- src/model_evaluation.py +308 -0
- src/model_interpretation.py +320 -0
- src/neural_classifiers.py +286 -0
- src/text_preprocessing.py +277 -0
src/classical_classifiers.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Optional, Union, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from sklearn.base import BaseEstimator, ClassifierMixin
|
| 5 |
+
from sklearn.linear_model import LogisticRegression
|
| 6 |
+
from sklearn.svm import SVC
|
| 7 |
+
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, VotingClassifier, StackingClassifier
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
from sklearn.metrics import accuracy_score, classification_report
|
| 10 |
+
from sklearn.preprocessing import LabelEncoder
|
| 11 |
+
|
| 12 |
+
XGBClassifier = None
|
| 13 |
+
CatBoostClassifier = None
|
| 14 |
+
LGBMClassifier = None
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from xgboost import XGBClassifier
|
| 18 |
+
except ImportError:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from catboost import CatBoostClassifier
|
| 23 |
+
except ImportError:
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from lightgbm import LGBMClassifier
|
| 28 |
+
except ImportError:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_logistic_regression(
    penalty: str = "l2",
    C: float = 1.0,
    max_iter: int = 1000,
    solver: str = "liblinear",  # default supports l1 and l2
    random_state: int = 42
) -> LogisticRegression:
    """Build a scikit-learn LogisticRegression with a compatible solver.

    Args:
        penalty: Regularization type: 'l1', 'l2', 'elasticnet', or 'none'.
        C: Inverse regularization strength (smaller = stronger).
        max_iter: Maximum optimizer iterations.
        solver: Requested solver; silently coerced to one that supports
            the chosen penalty (see compatibility notes below).
        random_state: Seed for reproducibility.

    Returns:
        An unfitted LogisticRegression instance.

    Raises:
        ValueError: If penalty is not one of the accepted values.
    """
    if penalty not in ("l1", "l2", "elasticnet", "none"):
        raise ValueError("penalty must be 'l1', 'l2', 'elasticnet', or 'none'")
    # Solver/penalty compatibility (per scikit-learn docs):
    #   l1         -> liblinear or saga
    #   elasticnet -> saga only (and requires l1_ratio)
    #   none       -> any solver except liblinear
    if penalty == "l1" and solver not in ("liblinear", "saga"):
        solver = "liblinear"
    extra = {}
    if penalty == "elasticnet":
        # BUG FIX: elasticnet previously kept the default 'liblinear' solver
        # and omitted l1_ratio, so the estimator always failed at fit time.
        solver = "saga"
        extra["l1_ratio"] = 0.5  # even l1/l2 mix as a sensible default
    elif penalty == "none" and solver == "liblinear":
        # BUG FIX: liblinear cannot fit an unpenalized model.
        solver = "lbfgs"
    return LogisticRegression(
        penalty=penalty,
        C=C,
        max_iter=max_iter,
        solver=solver,
        random_state=random_state,
        **extra
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_svm_linear(C: float = 1.0, random_state: int = 42) -> SVC:
    """Build a linear-kernel SVC with probability estimates enabled.

    probability=True makes predict_proba available (needed e.g. for
    soft voting), at the cost of an internal cross-validation at fit time.
    """
    params = {
        "kernel": "linear",
        "C": C,
        "probability": True,
        "random_state": random_state,
    }
    return SVC(**params)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_random_forest(
    n_estimators: int = 100,
    max_depth: Optional[int] = None,
    random_state: int = 42
) -> RandomForestClassifier:
    """Build an unfitted RandomForestClassifier.

    Args:
        n_estimators: Number of trees in the forest.
        max_depth: Maximum tree depth; None lets trees grow fully.
        random_state: Seed for reproducibility.
    """
    config = dict(
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=random_state,
    )
    return RandomForestClassifier(**config)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_gradient_boosting(
    model_type: str = "xgb",
    **kwargs
) -> Any:
    """Build a gradient-boosting classifier from an optional backend.

    BUG FIX: the return annotation previously spelled
    Union[XGBClassifier, ...], which is evaluated at def time; when
    xgboost is not installed XGBClassifier is None, so the annotation
    silently degraded to an Optional of the other two. 'Any' is honest
    and always valid.

    Args:
        model_type: 'xgb' (XGBoost), 'cat' (CatBoost), or 'lgb' (LightGBM).
        **kwargs: Forwarded to the chosen classifier's constructor;
            a deterministic seed is filled in when not supplied.

    Returns:
        An unfitted classifier instance from the chosen library.

    Raises:
        ImportError: If the chosen backend is not installed.
        ValueError: If model_type is not recognized.
    """
    if model_type == "xgb":
        if XGBClassifier is None:
            raise ImportError("XGBoost not installed. Run: pip install xgboost")
        kwargs.setdefault("random_state", 42)
        return XGBClassifier(**kwargs)
    elif model_type == "cat":
        if CatBoostClassifier is None:
            raise ImportError("CatBoost not installed. Run: pip install catboost")
        kwargs.setdefault("verbose", False)
        kwargs.setdefault("random_seed", 42)  # CatBoost names its seed 'random_seed'
        return CatBoostClassifier(**kwargs)
    elif model_type == "lgb":
        if LGBMClassifier is None:
            raise ImportError("LightGBM not installed. Run: pip install lightgbm")
        kwargs.setdefault("random_state", 42)
        return LGBMClassifier(**kwargs)
    else:
        raise ValueError("model_type must be 'xgb', 'cat', or 'lgb'")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_bagging_classifier(
    base_estimator: str = "tree",
    n_estimators: int = 10,
    random_state: int = 42
) -> BaggingClassifier:
    """Build a bagging ensemble over a decision tree or logistic regression.

    Args:
        base_estimator: 'tree' for a decision tree, 'lr' for logistic regression.
        n_estimators: Number of bootstrap members.
        random_state: Seed for reproducibility.

    Raises:
        ValueError: If base_estimator is not 'tree' or 'lr'.
    """
    if base_estimator not in ("tree", "lr"):
        raise ValueError("base_estimator must be 'tree' or 'lr'")
    if base_estimator == "tree":
        from sklearn.tree import DecisionTreeClassifier
        chosen = DecisionTreeClassifier(random_state=random_state)
    else:
        chosen = get_logistic_regression()
    return BaggingClassifier(
        estimator=chosen,
        n_estimators=n_estimators,
        random_state=random_state,
    )
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_stacking_classifier(
    final_estimator: Optional[BaseEstimator] = None,
    cv: int = 5,
    random_state: int = 42
) -> StackingClassifier:
    """Build a stacking ensemble of LR + linear SVM (+ CatBoost when available).

    Args:
        final_estimator: Meta-learner; defaults to a logistic regression.
        cv: Cross-validation folds used to produce out-of-fold predictions.
        random_state: Accepted for interface symmetry.
            NOTE(review): currently unused by this function — confirm intent.
    """
    base_learners = [("lr", get_logistic_regression()),
                     ("svm", get_svm_linear())]
    # CatBoost joins the ensemble only when the optional dependency is present.
    if CatBoostClassifier is not None:
        base_learners.append(("cat", get_gradient_boosting("cat", iterations=100)))

    meta = final_estimator if final_estimator is not None else get_logistic_regression()

    return StackingClassifier(
        estimators=base_learners,
        final_estimator=meta,
        cv=cv,
        passthrough=False
    )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_voting_classifier(
    voting: str = "soft",
    use_catboost: bool = True
) -> VotingClassifier:
    """Build a voting ensemble of LR, linear SVM, and a small random forest.

    Args:
        voting: 'soft' (average probabilities) or 'hard' (majority label).
        use_catboost: Add a CatBoost member when the library is installed.
    """
    members = [
        ("lr", get_logistic_regression()),
        ("svm", get_svm_linear()),
        ("rf", get_random_forest(n_estimators=50)),
    ]
    # Optional CatBoost member — only when both requested and importable.
    if use_catboost and CatBoostClassifier is not None:
        members.append(
            ("cat", get_gradient_boosting("cat", iterations=50, verbose=False))
        )

    return VotingClassifier(estimators=members, voting=voting)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def tpot_classifier(
    generations: int = 5,
    population_size: int = 20,
    cv: int = 5,
    random_state: int = 42,
    verbosity: int = 0
) -> Any:
    """Build a TPOT AutoML classifier (genetic pipeline search).

    Raises:
        ImportError: If the optional 'tpot' package is not installed.
    """
    try:
        from tpot import TPOTClassifier
    except ImportError:
        raise ImportError("TPOT not installed. Run: pip install tpot")

    settings = {
        "generations": generations,
        "population_size": population_size,
        "cv": cv,
        "random_state": random_state,
        "verbosity": verbosity,
        "n_jobs": -1,  # use all cores for the pipeline search
    }
    return TPOTClassifier(**settings)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def h2o_classifier(
    max_runtime_secs: int = 300,
    seed: int = 42,
    exclude_algos: Optional[list] = None
) -> Any:
    """Build an H2O AutoML search object (not yet trained).

    Args:
        max_runtime_secs: Wall-clock budget for the AutoML search.
        seed: Seed for reproducibility.
        exclude_algos: Algorithm family names to skip, or None for all.

    Raises:
        ImportError: If the optional 'h2o' package is not installed.
    """
    try:
        import h2o
        from h2o.automl import H2OAutoML
    except ImportError:
        raise ImportError("H2O not installed. Run: pip install h2o")

    return H2OAutoML(
        max_runtime_secs=max_runtime_secs,
        seed=seed,
        exclude_algos=exclude_algos,
    )
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def train_and_evaluate(
    model: Union[BaseEstimator, Any],
    X_train: Union[np.ndarray, pd.DataFrame],
    y_train: Union[np.ndarray, pd.Series],
    X_test: Union[np.ndarray, pd.DataFrame],
    y_test: Union[np.ndarray, pd.Series],
    is_h2o: bool = False
) -> Dict[str, Any]:
    """Fit *model* on the training split and score it on the test split.

    Two execution paths:
      - is_h2o=True: arguments are expected to be H2OFrames (they must
        support .cbind and .columns) and *model* an H2OAutoML instance.
      - is_h2o=False: plain scikit-learn style fit/predict.

    Returns:
        A dict of metrics; keys differ between the two paths
        ('accuracy'/'auc'/'best_model' vs 'accuracy'/'report'/'model').
    """
    if is_h2o:
        import h2o
        # Join features and target into single frames, as H2O trains on
        # one frame plus a target column name.
        train_frame = X_train.cbind(y_train)
        test_frame = X_test.cbind(y_test)
        y_col = y_train.columns[0]

        model.train(x=X_train.columns.tolist(), y=y_col, training_frame=train_frame)
        perf = model.model_performance(test_frame)
        return {
            # NOTE(review): H2O's perf.accuracy() typically returns
            # [[threshold, accuracy]]; indexing [0] yields the pair, not the
            # scalar — confirm downstream consumers expect this shape.
            "accuracy": perf.accuracy()[0],
            # NOTE(review): _has_auc() is a private H2O API — verify it is
            # still available in the pinned h2o version.
            "auc": perf.auc() if perf._has_auc() else None,
            # .leader assumes *model* is an H2OAutoML object, not a plain
            # H2O estimator — TODO confirm callers only pass AutoML here.
            "best_model": model.leader
        }
    else:
        # Standard scikit-learn flow: fit, predict, report.
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        return {
            "accuracy": accuracy_score(y_test, y_pred),
            # output_dict=True gives a nested dict (per-class P/R/F1)
            # instead of a formatted string.
            "report": classification_report(y_test, y_pred, output_dict=True),
            "model": model
        }
|
src/imbalance_handling.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Union, Dict, Optional, Any, Callable
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from collections import Counter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_class_weights(y: Union[List, np.ndarray], method: str = "balanced") -> Union[Dict[int, float], None]:
    """Map each class label to a weight, or None when weighting is disabled.

    Args:
        y: Iterable of class labels.
        method: Only "balanced" is supported (inverse-frequency weights via
            scikit-learn); any other value returns None.
    """
    if method != "balanced":
        return None
    from sklearn.utils.class_weight import compute_class_weight
    classes = np.unique(y)
    weights = compute_class_weight('balanced', classes=classes, y=y)
    return {cls: w for cls, w in zip(classes, weights)}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_pytorch_weighted_loss(class_weights: Optional[Dict[int, float]] = None,
                              num_classes: Optional[int] = None) -> 'torch.nn.Module':
    """Return a CrossEntropyLoss, per-class weighted when weights are given.

    Args:
        class_weights: Mapping class index -> weight; iterated in sorted key
            order to build the weight tensor. None yields an unweighted loss.
        num_classes: Accepted for interface compatibility.
            NOTE(review): currently unused by this function — confirm intent.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        raise ImportError("PyTorch not installed")

    if class_weights is None:
        return nn.CrossEntropyLoss()
    ordered = [class_weights[key] for key in sorted(class_weights.keys())]
    return nn.CrossEntropyLoss(weight=torch.tensor(ordered, dtype=torch.float))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_tensorflow_weighted_loss(class_weights: Optional[Dict[int, float]] = None) -> Union[str, Callable]:
    """Return a Keras-compatible loss, per-class weighted when weights are given.

    BUG FIX: the return annotation previously claimed Callable, but the
    no-weights path returns the string 'sparse_categorical_crossentropy'
    (a Keras loss identifier) — the annotation now reflects both cases.

    Args:
        class_weights: Mapping class index -> weight; iterated in sorted key
            order. None or empty dict yields the built-in identifier string.

    Returns:
        Either the string 'sparse_categorical_crossentropy' or a closure
        computing a weighted sparse categorical cross-entropy.
    """
    if not class_weights:
        return 'sparse_categorical_crossentropy'

    weight_list = [class_weights[i] for i in sorted(class_weights.keys())]

    import tensorflow as tf

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        # Per-sample weight = weight of that sample's true class, selected
        # via one-hot masking.
        y_true = tf.cast(y_true, tf.int32)
        y_true_one_hot = tf.one_hot(y_true, depth=len(weight_list))
        weights = tf.reduce_sum(y_true_one_hot * weight_list, axis=1)
        unweighted_losses = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
        weighted_losses = unweighted_losses * weights
        return tf.reduce_mean(weighted_losses)

    return weighted_sparse_categorical_crossentropy
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Rebalance (X, y) with an imbalanced-learn sampler.

    Args:
        X: Feature matrix.
        y: Target labels.
        method: One of 'random_under', 'random_over', 'smote', 'adasyn'.
        random_state: Seed for reproducibility.

    Returns:
        The resampled (X, y) pair.

    Raises:
        ImportError: If imbalanced-learn is not installed.
        ValueError: If method is not recognized.
    """
    from imblearn.over_sampling import SMOTE, ADASYN
    from imblearn.under_sampling import RandomUnderSampler
    from imblearn.over_sampling import RandomOverSampler

    factories = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }
    if method not in factories:
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")

    sampler = factories[method](random_state=random_state)
    return sampler.fit_resample(X, y)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # language code
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42
) -> Tuple[List[str], List[Any]]:
    """Produce augmented copies of *texts*, num_aug per input, with labels.

    Supported augmentation_type values: 'synonym', 'insert', 'delete',
    'swap', 'eda', 'back_trans', 'llm' (not implemented).

    NOTE(review): random_state is accepted but never used to seed nlpaug
    or transformers — augmentation is not reproducible; confirm intent.
    NOTE(review): 'eda' maps to AntonymAug only, not the full EDA recipe.

    Raises:
        ImportError: If nlpaug is not installed.
        NotImplementedError: For augmentation_type == 'llm'.
        ValueError: For an unknown augmentation_type.
    """
    try:
        import nlpaug.augmenter.word as naw
        # NOTE(review): 'nas' is imported but unused in this function.
        import nlpaug.augmenter.sentence as nas
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")

    augmented_texts = []
    augmented_labels = []

    if augmentation_type == "synonym":
        if lang == "en":
            # WordNet-based synonym substitution (English only).
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # Non-English: fall back to multilingual-BERT contextual
            # substitution (downloads the model on first use).
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        # Pick default MarianMT round-trip models if none supplied;
        # ru -> en -> ru by default, else en -> ru -> en.
        if not model_name:
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            # NOTE(review): when a custom model_name is given the same model
            # is used for both directions — confirm this is intended.
            back_model = model_name

        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)

            def back_translate(text):
                # Best-effort round trip; on any failure return the
                # original text unchanged.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text

            # Back-translation returns early with ONLY the augmented copies
            # (num_aug per input), same as the generic path below.
            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            return augmented, labels_aug
        except Exception as e:
            # Pipeline setup failed (e.g. no network/model): fall back to
            # contextual substitution and continue into the generic loop.
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")

    # Generic path: num_aug augmented copies per (text, label); failures
    # fall back to appending the original text so output counts stay stable.
    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            try:
                aug_text = aug.augment(text)
                # Newer nlpaug versions return a list even for one input.
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception as e:
                augmented_texts.append(text)
                augmented_labels.append(label)

    return augmented_texts, augmented_labels
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42
) -> Union[
    Tuple[List[str], List[Any]],  # for augmentation
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # for sampling + weights
]:
    """Balance a labeled text dataset by augmentation, resampling, or both.

    Args:
        texts: Input documents.
        labels: Class label per document.
        strategy: 'augmentation' (generate minority texts), 'sampling'
            (resample embeddings), or 'both'.
        minority_classes: Classes to up-augment; defaults to the rarest.
        augmentation_type: Passed through to augment_texts.
        sampling_method: Passed through to apply_sampling.
        lang: Language code for augmentation.
        embedding_func: text -> vector; required for 'sampling'.
        class_weights: Also return balanced class weights for sampling paths.
        random_state: Seed forwarded to the helpers.

    Returns:
        (texts, labels) for augmentation, or (X, y, weights) for sampling.

    Raises:
        ValueError: Unknown strategy, or 'sampling' without embedding_func.
    """
    label_counts = Counter(labels)
    # BUG FIX: min_count was previously computed only when minority_classes
    # was None, so passing minority_classes explicitly made the num_aug
    # formula below raise NameError. Compute it unconditionally.
    min_count = min(label_counts.values())
    if minority_classes is None:
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]

    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]

        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            # Enough copies to roughly reach the majority class size.
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )

        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels

    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights

    elif strategy == "both":
        # First augment the raw texts, then resample their embeddings.
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels, strategy="augmentation", minority_classes=minority_classes,
            augmentation_type=augmentation_type, lang=lang, random_state=random_state
        )
        if embedding_func is None:
            # No embedder: augmentation alone is the best we can do.
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")
|
src/main.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
from typing import List, Dict, Any, Union
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 10 |
+
import shap
|
| 11 |
+
|
| 12 |
+
st.set_page_config(
|
| 13 |
+
page_title="Text Classifiers",
|
| 14 |
+
layout="wide",
|
| 15 |
+
initial_sidebar_state="expanded"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from text_preprocessing import (
|
| 19 |
+
preprocess_text, get_contextual_embeddings, TextVectorizer
|
| 20 |
+
)
|
| 21 |
+
from classical_classifiers import (
|
| 22 |
+
get_logistic_regression, get_svm_linear, get_random_forest,
|
| 23 |
+
get_gradient_boosting, get_voting_classifier
|
| 24 |
+
)
|
| 25 |
+
from neural_classifiers import get_transformer_classifier
|
| 26 |
+
from model_evaluation import evaluate_model
|
| 27 |
+
from model_interpretation import (
|
| 28 |
+
get_linear_feature_importance,
|
| 29 |
+
analyze_errors,
|
| 30 |
+
get_transformer_attention,
|
| 31 |
+
visualize_attention_weights,
|
| 32 |
+
get_token_importance_captum,
|
| 33 |
+
plot_token_importance
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
import warnings
|
| 37 |
+
|
| 38 |
+
warnings.filterwarnings("ignore")
|
| 39 |
+
|
| 40 |
+
if 'models' not in st.session_state:
|
| 41 |
+
st.session_state.models = {}
|
| 42 |
+
if 'results' not in st.session_state:
|
| 43 |
+
st.session_state.results = {}
|
| 44 |
+
if 'dataset' not in st.session_state:
|
| 45 |
+
st.session_state.dataset = None
|
| 46 |
+
if 'task_type' not in st.session_state:
|
| 47 |
+
st.session_state.task_type = None
|
| 48 |
+
if 'preprocessed' not in st.session_state:
|
| 49 |
+
st.session_state.preprocessed = None
|
| 50 |
+
if 'X' not in st.session_state:
|
| 51 |
+
st.session_state.X = None
|
| 52 |
+
if 'y' not in st.session_state:
|
| 53 |
+
st.session_state.y = None
|
| 54 |
+
if 'feature_names' not in st.session_state:
|
| 55 |
+
st.session_state.feature_names = None
|
| 56 |
+
if 'vectorizer' not in st.session_state:
|
| 57 |
+
st.session_state.vectorizer = None
|
| 58 |
+
if 'vectorizer_type' not in st.session_state:
|
| 59 |
+
st.session_state.vectorizer_type = None
|
| 60 |
+
if 'X_test' not in st.session_state:
|
| 61 |
+
st.session_state.X_test = None
|
| 62 |
+
if 'y_test' not in st.session_state:
|
| 63 |
+
st.session_state.y_test = None
|
| 64 |
+
if 'test_texts' not in st.session_state:
|
| 65 |
+
st.session_state.test_texts = None
|
| 66 |
+
if 'label_encoder' not in st.session_state:
|
| 67 |
+
st.session_state.label_encoder = None
|
| 68 |
+
if 'rubert_model' not in st.session_state:
|
| 69 |
+
st.session_state.rubert_model = None
|
| 70 |
+
if 'rubert_tokenizer' not in st.session_state:
|
| 71 |
+
st.session_state.rubert_tokenizer = None
|
| 72 |
+
if 'rubert_trained' not in st.session_state:
|
| 73 |
+
st.session_state.rubert_trained = False
|
| 74 |
+
|
| 75 |
+
st.sidebar.title("Setup")
|
| 76 |
+
|
| 77 |
+
st.sidebar.subheader("1. Upload Dataset (JSONL)")
|
| 78 |
+
uploaded_file = st.sidebar.file_uploader("Upload .jsonl file", type=["jsonl"])
|
| 79 |
+
|
| 80 |
+
if uploaded_file:
|
| 81 |
+
try:
|
| 82 |
+
raw_data = []
|
| 83 |
+
lines = uploaded_file.getvalue().decode("utf-8").splitlines()
|
| 84 |
+
for line in lines:
|
| 85 |
+
if line.strip():
|
| 86 |
+
raw_data.append(json.loads(line))
|
| 87 |
+
st.session_state.dataset = raw_data
|
| 88 |
+
|
| 89 |
+
first = raw_data[0]
|
| 90 |
+
if 'sentiment' in first:
|
| 91 |
+
st.session_state.task_type = "binary"
|
| 92 |
+
labels = [item['sentiment'] for item in raw_data]
|
| 93 |
+
elif 'category' in first:
|
| 94 |
+
st.session_state.task_type = "multiclass"
|
| 95 |
+
labels = [item['category'] for item in raw_data]
|
| 96 |
+
elif 'tags' in first:
|
| 97 |
+
st.session_state.task_type = "multilabel"
|
| 98 |
+
labels = [item['tags'] for item in raw_data]
|
| 99 |
+
else:
|
| 100 |
+
st.sidebar.error("No label field found")
|
| 101 |
+
st.session_state.task_type = None
|
| 102 |
+
st.session_state.dataset = None
|
| 103 |
+
|
| 104 |
+
if st.session_state.task_type:
|
| 105 |
+
st.sidebar.success(f"Loaded {len(raw_data)} samples. Task: {st.session_state.task_type}")
|
| 106 |
+
if st.session_state.task_type == "binary":
|
| 107 |
+
id2label = {0: "Negative", 1: "Positive"}
|
| 108 |
+
label2id = {"Negative": 0, "Positive": 1}
|
| 109 |
+
elif st.session_state.task_type == "multiclass":
|
| 110 |
+
id2label = {0: "Политика", 1: "Экономика", 2: "Спорт", 3: "Культура"}
|
| 111 |
+
label2id = {"Политика": 0, "Экономика": 1, "Спорт": 2, "Культура": 3}
|
| 112 |
+
else:
|
| 113 |
+
id2label = None
|
| 114 |
+
label2id = None
|
| 115 |
+
|
| 116 |
+
st.session_state.id2label = id2label
|
| 117 |
+
st.session_state.label2id = label2id
|
| 118 |
+
except Exception as e:
|
| 119 |
+
st.sidebar.error(f"Failed to parse JSONL: {e}")
|
| 120 |
+
st.session_state.dataset = None
|
| 121 |
+
|
| 122 |
+
if st.session_state.dataset is not None:
|
| 123 |
+
st.sidebar.subheader("2. Preprocess Text")
|
| 124 |
+
lang = st.sidebar.selectbox("Language", ["ru", "en"], index=0)
|
| 125 |
+
st.session_state.preprocess_lang = 'ru'
|
| 126 |
+
if st.sidebar.button("Run Preprocessing"):
|
| 127 |
+
with st.spinner("Preprocessing..."):
|
| 128 |
+
texts = [item['text'] for item in st.session_state.dataset]
|
| 129 |
+
preprocessed = [preprocess_text(text, lang='ru', remove_stopwords=False) for text in texts]
|
| 130 |
+
st.session_state.preprocessed = preprocessed
|
| 131 |
+
st.sidebar.success("Preprocessing done!")
|
| 132 |
+
|
| 133 |
+
if st.session_state.preprocessed is not None:
|
| 134 |
+
st.sidebar.subheader("3. Vectorization (Classical)")
|
| 135 |
+
vectorizer_type = st.sidebar.selectbox("Method", ["TF-IDF", "RuBERT Embeddings"])
|
| 136 |
+
if st.sidebar.button("Vectorize"):
|
| 137 |
+
with st.spinner("Vectorizing..."):
|
| 138 |
+
if vectorizer_type == "TF-IDF":
|
| 139 |
+
vectorizer = TextVectorizer()
|
| 140 |
+
if not isinstance(st.session_state.preprocessed[0], str):
|
| 141 |
+
st.session_state.preprocessed = [
|
| 142 |
+
' '.join(text) for text in st.session_state.preprocessed
|
| 143 |
+
]
|
| 144 |
+
st.sidebar.write("Using max_features=5000")
|
| 145 |
+
X = vectorizer.tfidf(st.session_state.preprocessed, max_features=5000)
|
| 146 |
+
st.sidebar.write(f"X shape: {X.shape}")
|
| 147 |
+
st.session_state.vectorizer = vectorizer
|
| 148 |
+
st.session_state.feature_names = vectorizer.tfidf_vectorizer.get_feature_names_out()
|
| 149 |
+
else:
|
| 150 |
+
X = []
|
| 151 |
+
for text in st.session_state.preprocessed:
|
| 152 |
+
emb = get_contextual_embeddings([text], model_name="DeepPavlov/rubert-base-cased")
|
| 153 |
+
X.append(emb[0])
|
| 154 |
+
X = np.array(X)
|
| 155 |
+
st.session_state.vectorizer = None
|
| 156 |
+
st.session_state.feature_names = None
|
| 157 |
+
st.session_state.X = X
|
| 158 |
+
st.session_state.vectorizer_type = vectorizer_type
|
| 159 |
+
|
| 160 |
+
if st.session_state.task_type == "binary":
|
| 161 |
+
y = np.array([item['sentiment'] for item in st.session_state.dataset])
|
| 162 |
+
elif st.session_state.task_type == "multiclass":
|
| 163 |
+
y = np.array([item['category'] for item in st.session_state.dataset])
|
| 164 |
+
else:
|
| 165 |
+
y = [item['tags'] for item in st.session_state.dataset]
|
| 166 |
+
st.session_state.y = y
|
| 167 |
+
st.sidebar.success("Vectorization complete!")
|
| 168 |
+
|
| 169 |
+
if st.session_state.X is not None:
|
| 170 |
+
st.sidebar.subheader("4. Train Classical Models")
|
| 171 |
+
model_options = ["Logistic Regression", "SVM", "Random Forest", "XGBoost", "Voting"]
|
| 172 |
+
selected_models = st.sidebar.multiselect("Models", model_options)
|
| 173 |
+
if st.sidebar.button("Train Classical Models"):
|
| 174 |
+
from sklearn.model_selection import train_test_split
|
| 175 |
+
from sklearn.preprocessing import LabelEncoder
|
| 176 |
+
|
| 177 |
+
X = st.session_state.X
|
| 178 |
+
y = st.session_state.y
|
| 179 |
+
|
| 180 |
+
if st.session_state.task_type == "multiclass":
|
| 181 |
+
le = LabelEncoder()
|
| 182 |
+
y_encoded = le.fit_transform(y)
|
| 183 |
+
st.session_state.label_encoder = le
|
| 184 |
+
y_for_split = y_encoded
|
| 185 |
+
else:
|
| 186 |
+
y_for_split = y if st.session_state.task_type == "binary" else np.array([len(tags) for tags in y])
|
| 187 |
+
|
| 188 |
+
if st.session_state.task_type == "multilabel":
|
| 189 |
+
split_idx = int(0.8 * len(X))
|
| 190 |
+
X_train, X_test = X[:split_idx], X[split_idx:]
|
| 191 |
+
y_train, y_test = y[:split_idx], y[split_idx:]
|
| 192 |
+
test_texts = [item['text'] for item in st.session_state.dataset[split_idx:]]
|
| 193 |
+
else:
|
| 194 |
+
indices = np.arange(len(X))
|
| 195 |
+
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
|
| 196 |
+
X, y_for_split, indices, test_size=0.2,
|
| 197 |
+
stratify=y_for_split if st.session_state.task_type != "multilabel" else None,
|
| 198 |
+
random_state=42
|
| 199 |
+
)
|
| 200 |
+
test_texts = [st.session_state.dataset[i]['text'] for i in idx_test]
|
| 201 |
+
if st.session_state.task_type == "multiclass":
|
| 202 |
+
y_train = le.inverse_transform(y_train)
|
| 203 |
+
y_test = le.inverse_transform(y_test)
|
| 204 |
+
|
| 205 |
+
st.session_state.X_test = X_test
|
| 206 |
+
st.session_state.y_test = y_test
|
| 207 |
+
st.session_state.test_texts = test_texts
|
| 208 |
+
|
| 209 |
+
for name in selected_models:
|
| 210 |
+
try:
|
| 211 |
+
with st.spinner(f"Training {name}..."):
|
| 212 |
+
if name == "Logistic Regression":
|
| 213 |
+
model = get_logistic_regression()
|
| 214 |
+
model.fit(X_train, y_train)
|
| 215 |
+
st.session_state.models[name] = model
|
| 216 |
+
elif name == "SVM":
|
| 217 |
+
model = get_svm_linear()
|
| 218 |
+
model.fit(X_train, y_train)
|
| 219 |
+
st.session_state.models[name] = model
|
| 220 |
+
elif name == "Random Forest":
|
| 221 |
+
model = get_random_forest()
|
| 222 |
+
model.fit(X_train, y_train)
|
| 223 |
+
st.session_state.models[name] = model
|
| 224 |
+
elif name == "XGBoost":
|
| 225 |
+
model = get_gradient_boosting("xgb", n_estimators=100)
|
| 226 |
+
model.fit(X_train, y_train)
|
| 227 |
+
st.session_state.models[name] = model
|
| 228 |
+
elif name == "Voting":
|
| 229 |
+
model = get_voting_classifier()
|
| 230 |
+
model.fit(X_train, y_train)
|
| 231 |
+
st.session_state.models[name] = model
|
| 232 |
+
|
| 233 |
+
if st.session_state.task_type != "multilabel":
|
| 234 |
+
metrics = evaluate_model(model, X_test, y_test)
|
| 235 |
+
st.session_state.results[name] = metrics
|
| 236 |
+
except Exception as e:
|
| 237 |
+
st.sidebar.error(f"Failed to train {name}: {e}")
|
| 238 |
+
continue
|
| 239 |
+
st.sidebar.success("Classical models trained!")
|
| 240 |
+
|
| 241 |
+
if st.session_state.dataset is not None and st.session_state.task_type in ["binary", "multiclass"]:
|
| 242 |
+
st.sidebar.subheader("5. Train RuBERT (Transformer)")
|
| 243 |
+
if st.sidebar.button("Train RuBERT"):
|
| 244 |
+
with st.spinner("Loading RuBERT..."):
|
| 245 |
+
try:
|
| 246 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
| 247 |
+
|
| 248 |
+
num_labels = 2 if st.session_state.task_type == "binary" else 4
|
| 249 |
+
model_name = "DeepPavlov/rubert-base-cased"
|
| 250 |
+
|
| 251 |
+
config = AutoConfig.from_pretrained(
|
| 252 |
+
model_name,
|
| 253 |
+
num_labels=num_labels,
|
| 254 |
+
id2label=st.session_state.id2label,
|
| 255 |
+
label2id=st.session_state.label2id
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 259 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
|
| 260 |
+
|
| 261 |
+
st.session_state.rubert_model = model
|
| 262 |
+
st.session_state.rubert_tokenizer = tokenizer
|
| 263 |
+
st.session_state.rubert_trained = True
|
| 264 |
+
st.sidebar.success("RuBERT loaded with correct labels!")
|
| 265 |
+
except Exception as e:
|
| 266 |
+
st.sidebar.error(f"RuBERT loading failed: {e}")
|
| 267 |
+
st.exception(e)
|
| 268 |
+
|
| 269 |
+
st.title("Text Classifiers")
|
| 270 |
+
|
| 271 |
+
tab1, tab2, tab3, tab4 = st.tabs([
|
| 272 |
+
"Classify",
|
| 273 |
+
"Interpret",
|
| 274 |
+
"Compare",
|
| 275 |
+
"Error Analysis"
|
| 276 |
+
])
|
| 277 |
+
|
| 278 |
+
with tab1:
|
| 279 |
+
st.subheader("Classify New Text")
|
| 280 |
+
input_text = st.text_area("Enter text", "Сегодня прошёл важный матч по хоккею.")
|
| 281 |
+
|
| 282 |
+
if st.button("Classify"):
|
| 283 |
+
cols = st.columns(2)
|
| 284 |
+
with cols[0]:
|
| 285 |
+
st.markdown("### Classical Models")
|
| 286 |
+
if not st.session_state.models:
|
| 287 |
+
st.info("No classical models trained")
|
| 288 |
+
else:
|
| 289 |
+
tokens = preprocess_text(input_text, lang='ru', remove_stopwords=False)
|
| 290 |
+
preprocessed = " ".join(tokens)
|
| 291 |
+
if st.session_state.vectorizer_type == "TF-IDF":
|
| 292 |
+
X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
|
| 293 |
+
else:
|
| 294 |
+
X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
|
| 295 |
+
|
| 296 |
+
for name, model in st.session_state.models.items():
|
| 297 |
+
pred = model.predict(X_input)[0]
|
| 298 |
+
st.write(f"**{name}**: {pred}")
|
| 299 |
+
if hasattr(model, "predict_proba"):
|
| 300 |
+
proba = model.predict_proba(X_input)[0]
|
| 301 |
+
st.write(f"Probabilities: {dict(zip(model.classes_, proba))}")
|
| 302 |
+
|
| 303 |
+
with cols[1]:
|
| 304 |
+
st.markdown("### RuBERT")
|
| 305 |
+
if not st.session_state.rubert_trained:
|
| 306 |
+
st.info("Train RuBERT in sidebar")
|
| 307 |
+
else:
|
| 308 |
+
try:
|
| 309 |
+
from transformers import pipeline
|
| 310 |
+
|
| 311 |
+
pipe = pipeline(
|
| 312 |
+
"text-classification",
|
| 313 |
+
model=st.session_state.rubert_model,
|
| 314 |
+
tokenizer=st.session_state.rubert_tokenizer,
|
| 315 |
+
device=-1
|
| 316 |
+
)
|
| 317 |
+
result = pipe(input_text)
|
| 318 |
+
label = result[0]['label']
|
| 319 |
+
confidence = result[0]['score']
|
| 320 |
+
|
| 321 |
+
if label.startswith("LABEL_") and st.session_state.id2label:
|
| 322 |
+
label_id = int(label.replace("LABEL_", ""))
|
| 323 |
+
readable_label = st.session_state.id2label.get(label_id, label)
|
| 324 |
+
else:
|
| 325 |
+
readable_label = label
|
| 326 |
+
|
| 327 |
+
st.write(f"**Prediction**: {readable_label}")
|
| 328 |
+
st.write(f"**Confidence**: {confidence:.3f}")
|
| 329 |
+
except Exception as e:
|
| 330 |
+
st.error(f"RuBERT inference failed: {e}")
|
| 331 |
+
|
| 332 |
+
with tab2:
|
| 333 |
+
subtab1, subtab2, subtab3 = st.tabs(["SHAP / LIME", "Attention Map", "Captum Heatmap"])
|
| 334 |
+
|
| 335 |
+
with subtab1:
|
| 336 |
+
st.subheader("SHAP: Local Explanation for One Text")
|
| 337 |
+
if not st.session_state.models:
|
| 338 |
+
st.info("Train a classical model first")
|
| 339 |
+
else:
|
| 340 |
+
model_name = st.selectbox("Model", list(st.session_state.models.keys()), key="shap_model")
|
| 341 |
+
text_for_explain = st.text_area("Text to explain", "Прекрасная новость о росте экономики!", key="shap_text")
|
| 342 |
+
top_k = st.slider("Top features to show", 5, 30, 15)
|
| 343 |
+
|
| 344 |
+
if st.button("Explain with SHAP"):
|
| 345 |
+
try:
|
| 346 |
+
import shap
|
| 347 |
+
|
| 348 |
+
model = st.session_state.models[model_name]
|
| 349 |
+
tokens = preprocess_text(text_for_explain, lang='ru', remove_stopwords=False)
|
| 350 |
+
preprocessed = " ".join(tokens)
|
| 351 |
+
|
| 352 |
+
if st.session_state.vectorizer_type == "TF-IDF":
|
| 353 |
+
X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
|
| 354 |
+
feature_names = st.session_state.feature_names
|
| 355 |
+
else:
|
| 356 |
+
X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
|
| 357 |
+
feature_names = [f"emb_{i}" for i in range(X_input.shape[1])]
|
| 358 |
+
|
| 359 |
+
background = st.session_state.X[:100]
|
| 360 |
+
# st.write(f"DEBUG: st.session_state.X shape = {st.session_state.X.shape}")
|
| 361 |
+
# st.write(f"DEBUG: X_input shape = {X_input.shape}")
|
| 362 |
+
# st.write(f'DEBUG: background shape = {background.shape}')
|
| 363 |
+
if "tree" in str(type(model)).lower():
|
| 364 |
+
explainer = shap.TreeExplainer(model)
|
| 365 |
+
shap_values = explainer.shap_values(X_input)
|
| 366 |
+
else:
|
| 367 |
+
explainer = shap.KernelExplainer(model.predict_proba, background)
|
| 368 |
+
shap_values = explainer.shap_values(X_input, nsamples=200)
|
| 369 |
+
|
| 370 |
+
if isinstance(shap_values, list):
|
| 371 |
+
probs = model.predict_proba(X_input)[0]
|
| 372 |
+
target_class = int(np.argmax(probs))
|
| 373 |
+
single_shap = shap_values[target_class][0]
|
| 374 |
+
expected_val = explainer.expected_value[target_class]
|
| 375 |
+
else:
|
| 376 |
+
sv = shap_values
|
| 377 |
+
if sv.ndim == 1:
|
| 378 |
+
single_shap = sv
|
| 379 |
+
expected_val = explainer.expected_value
|
| 380 |
+
elif sv.ndim == 2:
|
| 381 |
+
if sv.shape[0] == 1:
|
| 382 |
+
single_shap = sv[0]
|
| 383 |
+
expected_val = explainer.expected_value
|
| 384 |
+
elif sv.shape[1] == X_input.shape[1]:
|
| 385 |
+
probs = model.predict_proba(X_input)[0]
|
| 386 |
+
target_class = int(np.argmax(probs))
|
| 387 |
+
single_shap = sv[:, target_class]
|
| 388 |
+
expected_val = explainer.expected_value[target_class] if isinstance(
|
| 389 |
+
explainer.expected_value, (list, np.ndarray)) else explainer.expected_value
|
| 390 |
+
else:
|
| 391 |
+
single_shap = sv[0]
|
| 392 |
+
expected_val = explainer.expected_value
|
| 393 |
+
elif sv.ndim == 3:
|
| 394 |
+
if sv.shape[0] != 1:
|
| 395 |
+
raise ValueError("SHAP explanation for more than one sample not supported")
|
| 396 |
+
probs = model.predict_proba(X_input)[0]
|
| 397 |
+
target_class = int(np.argmax(probs))
|
| 398 |
+
single_shap = sv[0, :, target_class]
|
| 399 |
+
if isinstance(explainer.expected_value, (list, np.ndarray)) and len(
|
| 400 |
+
explainer.expected_value) == sv.shape[2]:
|
| 401 |
+
expected_val = explainer.expected_value[target_class]
|
| 402 |
+
else:
|
| 403 |
+
expected_val = explainer.expected_value
|
| 404 |
+
else:
|
| 405 |
+
raise ValueError(f"Unsupported SHAP shape: {sv.shape}")
|
| 406 |
+
|
| 407 |
+
single_shap = np.array(single_shap).flatten()
|
| 408 |
+
if single_shap.shape[0] != X_input.shape[1]:
|
| 409 |
+
raise ValueError(
|
| 410 |
+
f"SHAP vector length {single_shap.shape[0]} != input features {X_input.shape[1]}")
|
| 411 |
+
|
| 412 |
+
if st.session_state.vectorizer_type == "TF-IDF":
|
| 413 |
+
text_vector = X_input[0]
|
| 414 |
+
nonzero_indices = np.where(text_vector != 0)[0]
|
| 415 |
+
if len(nonzero_indices) == 0:
|
| 416 |
+
st.warning("No known words from training vocabulary found in this text.")
|
| 417 |
+
else:
|
| 418 |
+
filtered_shap = single_shap[nonzero_indices]
|
| 419 |
+
filtered_features = text_vector[nonzero_indices]
|
| 420 |
+
filtered_names = [st.session_state.feature_names[i] for i in nonzero_indices]
|
| 421 |
+
|
| 422 |
+
explanation = shap.Explanation(
|
| 423 |
+
values=filtered_shap,
|
| 424 |
+
base_values=expected_val,
|
| 425 |
+
data=filtered_features,
|
| 426 |
+
feature_names=filtered_names
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
plt.figure(figsize=(10, min(8, top_k * 0.3)))
|
| 430 |
+
shap.plots.waterfall(explanation, max_display=top_k, show=False)
|
| 431 |
+
st.pyplot(plt.gcf())
|
| 432 |
+
plt.close()
|
| 433 |
+
else:
|
| 434 |
+
explanation = shap.Explanation(
|
| 435 |
+
values=single_shap,
|
| 436 |
+
base_values=expected_val,
|
| 437 |
+
data=X_input[0],
|
| 438 |
+
feature_names=feature_names
|
| 439 |
+
)
|
| 440 |
+
plt.figure(figsize=(10, min(8, top_k * 0.3)))
|
| 441 |
+
shap.plots.waterfall(explanation, max_display=top_k, show=False)
|
| 442 |
+
st.pyplot(plt.gcf())
|
| 443 |
+
plt.close()
|
| 444 |
+
|
| 445 |
+
except Exception as e:
|
| 446 |
+
st.error(f"SHAP error: {e}")
|
| 447 |
+
st.exception(e)
|
| 448 |
+
|
| 449 |
+
with subtab2:
|
| 450 |
+
st.subheader("Transformer Attention Map")
|
| 451 |
+
if not st.session_state.rubert_trained:
|
| 452 |
+
st.info("Train RuBERT first")
|
| 453 |
+
else:
|
| 454 |
+
text_att = st.text_area("Text for attention", "Матч завершился победой ЦСКА", key="att_text")
|
| 455 |
+
layer = st.slider("Layer", 0, 11, 6)
|
| 456 |
+
head = st.slider("Head", 0, 11, 0)
|
| 457 |
+
if st.button("Visualize Attention"):
|
| 458 |
+
try:
|
| 459 |
+
tokens, attn = get_transformer_attention(
|
| 460 |
+
st.session_state.rubert_model,
|
| 461 |
+
st.session_state.rubert_tokenizer,
|
| 462 |
+
text_att,
|
| 463 |
+
device="cpu"
|
| 464 |
+
)
|
| 465 |
+
weights = attn[layer, head, :len(tokens), :len(tokens)]
|
| 466 |
+
|
| 467 |
+
fig, ax = plt.subplots(figsize=(10, 4))
|
| 468 |
+
sns.heatmap(
|
| 469 |
+
weights,
|
| 470 |
+
xticklabels=tokens,
|
| 471 |
+
yticklabels=tokens,
|
| 472 |
+
cmap="viridis",
|
| 473 |
+
ax=ax
|
| 474 |
+
)
|
| 475 |
+
plt.xticks(rotation=45, ha="right")
|
| 476 |
+
plt.yticks(rotation=0)
|
| 477 |
+
plt.title(f"Attention: Layer {layer}, Head {head}")
|
| 478 |
+
st.pyplot(fig)
|
| 479 |
+
plt.close(fig)
|
| 480 |
+
except Exception as e:
|
| 481 |
+
st.error(f"Attention failed: {e}")
|
| 482 |
+
st.exception(e)
|
| 483 |
+
|
| 484 |
+
with subtab3:
|
| 485 |
+
st.subheader("Token Importance (Captum)")
|
| 486 |
+
if not st.session_state.rubert_trained:
|
| 487 |
+
st.info("Train RuBERT first")
|
| 488 |
+
else:
|
| 489 |
+
text_captum = st.text_area("Text for Captum", "Это очень плохая новость для политики", key="captum_text")
|
| 490 |
+
method = "IntegratedGradients"
|
| 491 |
+
if st.button("Compute Token Importance"):
|
| 492 |
+
try:
|
| 493 |
+
tokens, importance = get_token_importance_captum(
|
| 494 |
+
st.session_state.rubert_model,
|
| 495 |
+
st.session_state.rubert_tokenizer,
|
| 496 |
+
text_captum,
|
| 497 |
+
device="cpu"
|
| 498 |
+
)
|
| 499 |
+
valid = [(t, imp) for t, imp in zip(tokens, importance) if t not in ["[CLS]", "[SEP]", "[PAD]"]]
|
| 500 |
+
if valid:
|
| 501 |
+
tokens_clean, imp_clean = zip(*valid)
|
| 502 |
+
indices = np.argsort(np.abs(imp_clean))[-15:][::-1]
|
| 503 |
+
tokens_top = [tokens_clean[i] for i in indices]
|
| 504 |
+
imp_top = [imp_clean[i] for i in indices]
|
| 505 |
+
|
| 506 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 507 |
+
colors = ["red" if x < 0 else "green" for x in imp_top]
|
| 508 |
+
ax.barh(range(len(imp_top)), imp_top, color=colors)
|
| 509 |
+
ax.set_yticks(range(len(imp_top)))
|
| 510 |
+
ax.set_yticklabels(tokens_top)
|
| 511 |
+
ax.invert_yaxis()
|
| 512 |
+
ax.set_xlabel("Attribution Score")
|
| 513 |
+
ax.set_title("Token Importance")
|
| 514 |
+
st.pyplot(fig)
|
| 515 |
+
plt.close(fig)
|
| 516 |
+
else:
|
| 517 |
+
st.warning("No valid tokens")
|
| 518 |
+
except Exception as e:
|
| 519 |
+
st.error(f"Captum failed: {e}")
|
| 520 |
+
st.exception(e)
|
| 521 |
+
|
| 522 |
+
with tab3:
|
| 523 |
+
st.subheader("Model Comparison")
|
| 524 |
+
if st.session_state.results:
|
| 525 |
+
df = pd.DataFrame(st.session_state.results).T
|
| 526 |
+
st.dataframe(df)
|
| 527 |
+
else:
|
| 528 |
+
st.info("Train models to see metrics")
|
| 529 |
+
|
| 530 |
+
with tab4:
|
| 531 |
+
st.subheader("Error Analysis")
|
| 532 |
+
if st.session_state.X_test is None:
|
| 533 |
+
st.info("Train models first")
|
| 534 |
+
else:
|
| 535 |
+
model_name = st.selectbox("Model for error analysis", list(st.session_state.models.keys()), key="err_model")
|
| 536 |
+
if st.button("Analyze Errors"):
|
| 537 |
+
model = st.session_state.models[model_name]
|
| 538 |
+
y_pred = model.predict(st.session_state.X_test)
|
| 539 |
+
errors = analyze_errors(
|
| 540 |
+
st.session_state.y_test,
|
| 541 |
+
y_pred,
|
| 542 |
+
st.session_state.test_texts
|
| 543 |
+
)
|
| 544 |
+
st.dataframe(errors[['text', 'true_label', 'pred_label']].head(20))
|
src/model_evaluation.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Union, Callable, Optional, Tuple, List
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from sklearn.model_selection import (
|
| 8 |
+
StratifiedKFold, GroupKFold, TimeSeriesSplit,
|
| 9 |
+
GridSearchCV, RandomizedSearchCV
|
| 10 |
+
)
|
| 11 |
+
from sklearn.metrics import (
|
| 12 |
+
accuracy_score, precision_score, recall_score, f1_score,
|
| 13 |
+
roc_auc_score, average_precision_score, log_loss,
|
| 14 |
+
confusion_matrix, classification_report
|
| 15 |
+
)
|
| 16 |
+
from sklearn.base import BaseEstimator
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
warnings.filterwarnings("ignore")
|
| 20 |
+
|
| 21 |
+
OPTUNA_AVAILABLE = False
|
| 22 |
+
HYPEROPT_AVAILABLE = False
|
| 23 |
+
try:
|
| 24 |
+
import optuna
|
| 25 |
+
from optuna.samplers import TPESampler
|
| 26 |
+
|
| 27 |
+
OPTUNA_AVAILABLE = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
|
| 33 |
+
|
| 34 |
+
HYPEROPT_AVAILABLE = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
WANDB_AVAILABLE = False
|
| 39 |
+
try:
|
| 40 |
+
import wandb
|
| 41 |
+
|
| 42 |
+
WANDB_AVAILABLE = True
|
| 43 |
+
except ImportError:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_cv_splitter(
    cv_type: str = "stratified",
    n_splits: int = 5,
    groups: Optional[np.ndarray] = None,
    random_state: int = 42
):
    """Return a scikit-learn cross-validation splitter for the requested strategy.

    Args:
        cv_type: One of ``"stratified"``, ``"group"`` or ``"time"``.
        n_splits: Number of folds.
        groups: Group labels; required only when ``cv_type == "group"``.
        random_state: Seed used by the shuffled stratified splitter.

    Raises:
        ValueError: If ``cv_type`` is unknown, or ``groups`` is missing
            for the group strategy.
    """
    if cv_type == "stratified":
        return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    if cv_type == "group":
        if groups is None:
            raise ValueError("groups must be provided for GroupKFold")
        return GroupKFold(n_splits=n_splits)
    if cv_type == "time":
        return TimeSeriesSplit(n_splits=n_splits)
    raise ValueError("cv_type must be 'stratified', 'group', or 'time'")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def grid_search_cv(
    model: BaseEstimator,
    X: np.ndarray,
    y: np.ndarray,
    param_grid: Dict[str, List],
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    verbose: int = 1
) -> GridSearchCV:
    """Exhaustively search ``param_grid`` under cross-validation.

    Returns the fitted :class:`GridSearchCV` object; the best estimator,
    best parameters and CV results are available on it.
    """
    splitter = get_cv_splitter(cv_type, n_splits, groups)
    searcher = GridSearchCV(
        model,
        param_grid,
        cv=splitter,
        scoring=scoring,
        verbose=verbose,
        n_jobs=-1,
    )
    searcher.fit(X, y)
    return searcher
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def random_search_cv(
    model: BaseEstimator,
    X: np.ndarray,
    y: np.ndarray,
    param_distributions: Dict[str, Any],
    n_iter: int = 20,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    verbose: int = 1
) -> RandomizedSearchCV:
    """Sample ``n_iter`` parameter settings from ``param_distributions`` under CV.

    Returns the fitted :class:`RandomizedSearchCV` object. The search is
    seeded (``random_state=42``) so repeated runs draw the same candidates.
    """
    splitter = get_cv_splitter(cv_type, n_splits, groups)
    searcher = RandomizedSearchCV(
        model,
        param_distributions,
        n_iter=n_iter,
        cv=splitter,
        scoring=scoring,
        verbose=verbose,
        n_jobs=-1,
        random_state=42,
    )
    searcher.fit(X, y)
    return searcher
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _optuna_objective(
    trial,
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    cv,
    scoring: str = "f1_macro"
) -> float:
    """Optuna objective: build a model from trial params, return the mean CV score.

    Hyperparameter spaces are hard-wired for factory functions whose name
    contains "logistic" or "random_forest"; any other factory is expected
    to accept the ``trial`` itself and perform its own suggestions.
    """
    fn_name = model_fn.__name__.lower()
    if "logistic" in fn_name:
        C = trial.suggest_float("C", 1e-4, 1e2, log=True)
        penalty = trial.suggest_categorical("penalty", ["l1", "l2"])
        # liblinear is required for an L1 penalty; lbfgs handles L2.
        solver = "liblinear" if penalty == "l1" else "lbfgs"
        model = model_fn(C=C, penalty=penalty, solver=solver)
    elif "random_forest" in fn_name:
        model = model_fn(
            n_estimators=trial.suggest_int("n_estimators", 50, 300),
            max_depth=trial.suggest_int("max_depth", 3, 20),
        )
    else:
        model = model_fn(trial)

    fold_scores = []
    for train_idx, val_idx in cv.split(X, y):
        model.fit(X[train_idx], y[train_idx])
        y_pred = model.predict(X[val_idx])
        if scoring == "f1_macro":
            fold_scores.append(f1_score(y[val_idx], y_pred, average="macro"))
        elif scoring == "roc_auc":
            # roc_auc needs scores, not hard labels; assumes binary (positive column 1).
            fold_scores.append(roc_auc_score(y[val_idx], model.predict_proba(X[val_idx])[:, 1]))
        else:
            raise ValueError(f"Scoring {scoring} not implemented in custom Optuna loop")
    return np.mean(fold_scores)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def optuna_tuning(
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    n_trials: int = 50,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    direction: str = "maximize"
) -> "optuna.Study":
    """Run an Optuna hyperparameter search over ``model_fn``.

    Args:
        model_fn: Model factory passed to :func:`_optuna_objective`.
        X, y: Training data.
        n_trials: Number of Optuna trials.
        cv_type, n_splits, groups: Forwarded to :func:`get_cv_splitter`.
        scoring: ``"f1_macro"`` or ``"roc_auc"`` (see ``_optuna_objective``).
        direction: Optimization direction for the study.

    Returns:
        The completed :class:`optuna.Study`.

    Raises:
        ImportError: If optuna is not installed.
    """
    # The return annotation is a string on purpose: with a bare
    # ``-> optuna.Study`` the annotation is evaluated at import time and
    # raises NameError when optuna is missing, defeating the module's
    # OPTUNA_AVAILABLE try/except guard.
    if not OPTUNA_AVAILABLE:
        raise ImportError("optuna is not installed; install it to use optuna_tuning")
    cv = get_cv_splitter(cv_type, n_splits, groups)
    study = optuna.create_study(direction=direction, sampler=TPESampler(seed=42))
    study.optimize(
        lambda trial: _optuna_objective(trial, model_fn, X, y, cv, scoring),
        n_trials=n_trials
    )
    return study
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def hyperopt_tuning(
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    space: Dict,
    max_evals: int = 50,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None
):
    """Tune ``model_fn`` with hyperopt's TPE over ``space``.

    Args:
        model_fn: Factory called as ``model_fn(**params)`` for each candidate.
        X, y: Training data.
        space: Hyperopt search space.
        max_evals: Number of evaluations.
        cv_type, n_splits, groups: Forwarded to :func:`get_cv_splitter`.
        scoring: ``"f1_macro"`` or ``"roc_auc"``.

    Returns:
        Tuple ``(best_params, trials)``.

    Raises:
        ImportError: If hyperopt is not installed.
        ValueError: For an unsupported ``scoring`` value.
    """
    if not HYPEROPT_AVAILABLE:
        raise ImportError("hyperopt is not installed; install it to use hyperopt_tuning")
    cv = get_cv_splitter(cv_type, n_splits, groups)

    def objective(params):
        model = model_fn(**params)
        scores = []
        for train_idx, val_idx in cv.split(X, y):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            model.fit(X_train, y_train)
            y_pred = model.predict(X_val)
            if scoring == "f1_macro":
                score = f1_score(y_val, y_pred, average="macro")
            elif scoring == "roc_auc":
                y_proba = model.predict_proba(X_val)[:, 1]
                score = roc_auc_score(y_val, y_proba)
            else:
                # The original silently fell back to score = -1; raise instead,
                # matching the Optuna objective's handling of unknown metrics.
                raise ValueError(f"Scoring {scoring} not implemented in hyperopt objective")
            scores.append(score)
        # BUG FIX: the original appended -score AND negated the mean, so the
        # returned loss equalled the mean score and hyperopt (a minimizer)
        # minimized the metric. Negate exactly once: higher score -> lower loss.
        return {'loss': -np.mean(scores), 'status': STATUS_OK}

    trials = Trials()
    best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    return best, trials
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def compute_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: Optional[np.ndarray] = None,
    average: str = "macro"
) -> Dict[str, float]:
    """Compute standard classification metrics.

    Always includes accuracy, precision, recall and F1 (with the given
    averaging). When ``y_proba`` is supplied, probability-based metrics
    (ROC-AUC, PR-AUC, log-loss) are added; for multiclass targets these
    fall back to NaN if scikit-learn cannot compute them.
    """
    metrics: Dict[str, float] = {}
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    metrics["precision"] = precision_score(y_true, y_pred, average=average, zero_division=0)
    metrics["recall"] = recall_score(y_true, y_pred, average=average, zero_division=0)
    metrics["f1"] = f1_score(y_true, y_pred, average=average, zero_division=0)

    if y_proba is None:
        return metrics

    if len(np.unique(y_true)) == 2:
        # Binary case: score on the positive-class column.
        positive_scores = y_proba[:, 1]
        metrics["roc_auc"] = roc_auc_score(y_true, positive_scores)
        metrics["pr_auc"] = average_precision_score(y_true, positive_scores)
        metrics["log_loss"] = log_loss(y_true, y_proba)
    else:
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, y_proba, multi_class="ovr", average=average)
            metrics["pr_auc"] = average_precision_score(y_true, y_proba, average=average)
            metrics["log_loss"] = log_loss(y_true, y_proba)
        except ValueError:
            # e.g. a class missing from y_true makes OVR AUC undefined.
            metrics["roc_auc"] = np.nan
            metrics["pr_auc"] = np.nan

    return metrics
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def evaluate_model(
    model: BaseEstimator,
    X_test: np.ndarray,
    y_test: np.ndarray,
    average: str = "macro",
    return_pred: bool = False
) -> Union[Dict[str, float], Tuple[Dict[str, float], np.ndarray, Optional[np.ndarray]]]:
    """Score a fitted estimator on a held-out set.

    Returns the metrics dict, or ``(metrics, y_pred, y_proba)`` when
    ``return_pred`` is True. ``y_proba`` is None for models that do not
    expose ``predict_proba``.
    """
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test) if hasattr(model, "predict_proba") else None

    metrics = compute_classification_metrics(y_test, y_pred, y_proba, average=average)

    return (metrics, y_pred, y_proba) if return_pred else metrics
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def get_early_stopping(
    monitor: str = "val_loss",
    patience: int = 5,
    mode: str = "min",
    framework: str = "keras"
):
    """Build early-stopping callbacks for the given framework.

    For Keras, returns ``[EarlyStopping, ReduceLROnPlateau]`` both wired
    to the same monitored quantity. There is no callback API for the
    PyTorch path here, so that branch raises ``NotImplementedError``.
    """
    if framework == "keras":
        # Imported lazily so TensorFlow is only required when actually used.
        from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
        return [
            EarlyStopping(monitor=monitor, patience=patience, restore_best_weights=True, mode=mode),
            ReduceLROnPlateau(monitor=monitor, factor=0.5, patience=3, min_lr=1e-7, mode=mode),
        ]
    if framework == "pytorch":
        raise NotImplementedError("PyTorch callbacks require custom training loop")
    raise ValueError("framework must be 'keras' or 'pytorch'")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def init_wandb(
    project_name: str = "text-classification",
    run_name: Optional[str] = None,
    config: Optional[Dict] = None
):
    """Start a Weights & Biases run.

    Returns the ``wandb`` module with an active run, or ``None`` when the
    wandb package is not installed (graceful no-op).
    """
    if WANDB_AVAILABLE:
        wandb.init(project=project_name, name=run_name, config=config)
        return wandb
    return None
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def log_metrics_to_wandb(metrics: Dict[str, float]):
    """Send ``metrics`` to the active W&B run; silently no-op otherwise."""
    if not WANDB_AVAILABLE:
        return
    if wandb.run:
        wandb.log(metrics)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def suggest_transformer_hparams(trial) -> Dict[str, Any]:
    """Sample a transformer fine-tuning configuration from an Optuna trial.

    Keys follow the HuggingFace ``TrainingArguments`` naming (note the
    trial parameter is registered as ``batch_size`` but returned under
    ``per_device_train_batch_size``).
    """
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    epochs = trial.suggest_int("num_train_epochs", 2, 6)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.3)
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.2)
    return {
        "learning_rate": lr,
        "per_device_train_batch_size": batch_size,
        "num_train_epochs": epochs,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
    }
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def evaluate_transformer_outputs(
    y_true: List[int],
    y_pred: List[int],
    y_logits: Optional[np.ndarray] = None
) -> Dict[str, float]:
    """Compute macro-averaged metrics for transformer predictions.

    When raw ``y_logits`` are supplied they are converted to class
    probabilities with a softmax so probability-based metrics (ROC-AUC,
    PR-AUC, log-loss) can be included.
    """
    y_true_arr = np.array(y_true)
    y_pred_arr = np.array(y_pred)

    y_proba = None
    if y_logits is not None:
        y_proba = torch.softmax(torch.tensor(y_logits), dim=-1).numpy()

    return compute_classification_metrics(y_true_arr, y_pred_arr, y_proba, average="macro")
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def confusion_matrix_df(y_true: np.ndarray, y_pred: np.ndarray, labels: Optional[List] = None) -> pd.DataFrame:
    """Return the confusion matrix as a labeled DataFrame.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        labels: Explicit label order; if None, the sorted union of labels
            seen in ``y_true`` and ``y_pred`` is used.

    Returns:
        DataFrame with rows ``True_<label>`` and columns ``Pred_<label>``.
    """
    if labels is None:
        # BUG FIX: the original derived row/column names from y_true only,
        # while sklearn's confusion_matrix (with labels=None) uses the union
        # of y_true and y_pred — a shape/name mismatch whenever y_pred
        # contains a label absent from y_true. Compute the union up front
        # and pass it through so the matrix and the names always agree.
        labels = sorted(np.unique(np.concatenate([np.asarray(y_true), np.asarray(y_pred)])))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(cm, index=[f"True_{l}" for l in labels], columns=[f"Pred_{l}" for l in labels])
|
src/model_interpretation.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional, Union, Callable, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from sklearn.base import BaseEstimator
|
| 10 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 11 |
+
from sklearn.decomposition import PCA
|
| 12 |
+
from sklearn.manifold import TSNE
|
| 13 |
+
import warnings
|
| 14 |
+
warnings.filterwarnings("ignore")
|
| 15 |
+
|
| 16 |
+
SHAP_AVAILABLE = False
|
| 17 |
+
LIME_AVAILABLE = False
|
| 18 |
+
CAPTUM_AVAILABLE = False
|
| 19 |
+
UMAP_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import shap
|
| 23 |
+
SHAP_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import lime
|
| 29 |
+
import lime.lime_text
|
| 30 |
+
LIME_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import captum
|
| 36 |
+
import captum.attr
|
| 37 |
+
CAPTUM_AVAILABLE = True
|
| 38 |
+
except ImportError:
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
import umap
|
| 43 |
+
UMAP_AVAILABLE = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_linear_feature_importance(
    model: BaseEstimator,
    feature_names: Optional[List[str]] = None,
    class_index: int = -1
) -> pd.DataFrame:
    """Rank the features of a linear model by coefficient magnitude.

    Args:
        model: Fitted linear estimator exposing a ``coef_`` attribute.
        feature_names: Optional feature labels; auto-generated when omitted.
        class_index: For multi-class models, which row of ``coef_`` to use;
            ``-1`` averages coefficients across all classes.

    Returns:
        DataFrame with ``feature`` and ``weight`` columns, sorted by
        absolute weight, descending.

    Raises:
        ValueError: If ``model`` has no ``coef_`` attribute.
    """
    if not hasattr(model, "coef_"):
        raise ValueError("Model does not have coef_ attribute")

    coef = model.coef_
    if coef.ndim == 1:
        weights = coef
    elif class_index == -1:
        # Average over classes when no particular class is requested.
        weights = np.mean(coef, axis=0)
    else:
        weights = coef[class_index]

    names = feature_names if feature_names is not None else [
        f"feature_{i}" for i in range(len(weights))
    ]
    ranked = pd.DataFrame({"feature": names, "weight": weights})
    return ranked.sort_values("weight", key=abs, ascending=False).reset_index(drop=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def analyze_tfidf_class_keywords(
    tfidf_matrix: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    top_k: int = 20
) -> Dict[Any, pd.DataFrame]:
    """For each class, list the ``top_k`` terms with highest mean TF-IDF.

    Accepts either a dense (n_docs, n_terms) array or a scipy sparse matrix.

    Bug fix: the sparse check used to probe the *input matrix* for ``.A1``,
    but scipy sparse matrices have no ``A1`` — it is the ``np.matrix`` that
    ``mean(axis=0)`` returns which does. The old code therefore left sparse
    means as a 2-D matrix and failed while indexing feature names.

    Args:
        tfidf_matrix: Document-term TF-IDF scores.
        y: Per-document class labels aligned with the matrix rows.
        feature_names: Term string for each matrix column.
        top_k: Number of keywords to keep per class.

    Returns:
        Mapping from class label to a DataFrame with ``word`` and
        ``tfidf_score`` columns, highest scores first.
    """
    results: Dict[Any, pd.DataFrame] = {}
    for cls in np.unique(y):
        mean_scores = np.mean(tfidf_matrix[y == cls], axis=0)
        # Sparse input: mean(axis=0) yields np.matrix; flatten it to 1-D.
        if hasattr(mean_scores, "A1"):
            mean_scores = mean_scores.A1
        top_indices = np.argsort(mean_scores)[::-1][:top_k]
        results[cls] = pd.DataFrame({
            "word": [feature_names[i] for i in top_indices],
            "tfidf_score": [mean_scores[i] for i in top_indices],
        })
    return results
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def explain_with_shap(
    model: BaseEstimator,
    X_train: np.ndarray,
    X_test: np.ndarray,
    feature_names: Optional[List[str]] = None,
    plot_type: str = "bar",
    max_display: int = 20
):
    """Render a SHAP summary plot for a fitted classifier.

    Tree models get the fast ``TreeExplainer``; anything else falls back to the
    model-agnostic ``KernelExplainer``. Background data and the explained rows
    are both capped at 100 samples to keep runtime reasonable.

    Args:
        model: Fitted classifier (``predict_proba`` needed for KernelExplainer).
        X_train: Background data for the kernel explainer.
        X_test: Rows to explain (first 100 used).
        feature_names: Optional column names; auto-generated when omitted.
        plot_type: Forwarded to ``shap.summary_plot`` (e.g. "bar", "dot").
        max_display: Maximum number of features shown.

    Raises:
        ImportError: If the optional ``shap`` dependency is missing.
    """
    # Bug fix: the module-level SHAP_AVAILABLE flag was never checked, so a
    # missing shap install surfaced here as a confusing NameError.
    if not SHAP_AVAILABLE:
        raise ImportError("Install SHAP: pip install shap")

    if "tree" in str(type(model)).lower():
        explainer = shap.TreeExplainer(model)
    else:
        explainer = shap.KernelExplainer(model.predict_proba, X_train[:100])

    shap_values = explainer.shap_values(X_test[:100])

    if feature_names is None:
        feature_names = [f"feat_{i}" for i in range(X_test.shape[1])]

    plt.figure(figsize=(10, 6))
    # summary_plot accepts both a single array and a per-class list, so the
    # original's duplicated if/else branches were collapsed into one call.
    shap.summary_plot(
        shap_values,
        X_test[:100],
        feature_names=feature_names,
        plot_type=plot_type,
        max_display=max_display,
        show=False
    )
    plt.tight_layout()
    plt.show()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def explain_text_with_lime(
    model: Any,
    text: str,
    tokenizer: Callable,
    class_names: List[str],
    num_features: int = 10,
    num_samples: int = 5000
):
    """Explain one text prediction with LIME and render it inline.

    Args:
        model: Pipeline-like object exposing ``vectorizer`` and
            ``predict_proba``.
        text: Raw document to explain.
        tokenizer: Kept for interface compatibility; LIME perturbs raw strings
            and the model's own vectorizer tokenizes them, so it is unused.
        class_names: Human-readable labels, index-aligned with model outputs.
        num_features: Words shown in the explanation.
        num_samples: Perturbed samples LIME draws.

    Returns:
        The LIME explanation object (also shown in the notebook).

    Raises:
        ImportError: If the optional ``lime`` dependency is missing.
        NotImplementedError: If ``model`` exposes no ``vectorizer``.
    """
    # Bug fix: LIME_AVAILABLE was never checked; a missing install crashed
    # with NameError on `lime.lime_text` below.
    if not LIME_AVAILABLE:
        raise ImportError("Install LIME: pip install lime")

    def predict_fn(texts):
        # The old version also tokenized every perturbed sample and threw the
        # result away; the vectorizer does its own tokenization.
        if not hasattr(model, "vectorizer"):
            raise NotImplementedError("Custom predict_fn needed for your pipeline")
        X = model.vectorizer.transform(texts)
        return model.predict_proba(X.toarray())

    explainer = lime.lime_text.LimeTextExplainer(class_names=class_names)
    exp = explainer.explain_instance(text, predict_fn, num_features=num_features, num_samples=num_samples)
    exp.show_in_notebook()
    return exp
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def visualize_attention_weights(
    tokens: List[str],
    attention_weights: np.ndarray,
    layer: int = 0,
    head: int = 0,
    figsize: Tuple[int, int] = (10, 2)
):
    """Draw one attention head's token-to-token weights as a heatmap.

    Args:
        tokens: Token strings labelling both heatmap axes.
        attention_weights: 4-D array shaped (layers, heads, seq, seq).
        layer: Layer index to display.
        head: Head index to display.
        figsize: Matplotlib figure size.

    Raises:
        ValueError: If ``attention_weights`` is not 4-dimensional.
    """
    if attention_weights.ndim != 4:
        raise ValueError("attention_weights must be 4D: (layers, heads, seq, seq)")

    seq_len = len(tokens)
    head_weights = attention_weights[layer, head, :seq_len, :seq_len]

    plt.figure(figsize=figsize)
    sns.heatmap(
        head_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        cbar=True,
    )
    plt.title(f"Attention Layer {layer}, Head {head}")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def get_transformer_attention(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Run one forward pass and return tokens plus all attention maps.

    Bug fix: this function only needs torch/transformers, yet it used to raise
    ImportError whenever *Captum* was absent — refusing to run needlessly.
    The incorrect guard was removed.

    Args:
        model: HF transformer accepting ``output_attentions=True``.
        tokenizer: Matching tokenizer.
        text: Input document (truncated to 512 tokens).
        device: Torch device string.

    Returns:
        ``(tokens, attn)`` where ``attn`` is shaped (layers, heads, seq, seq)
        for the single input sequence.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
        attentions = outputs.attentions

    # Stack the per-layer tuple and drop the batch dimension (batch size 1).
    attn = torch.stack(attentions, dim=0).squeeze(1).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attn
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def analyze_errors(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    texts: List[str],
    labels: Optional[List[Any]] = None
) -> pd.DataFrame:
    """Collect the misclassified samples into a DataFrame.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        texts: Original documents, aligned with the labels.
        labels: Accepted for interface compatibility.
            NOTE(review): currently unused — confirm intent.

    Returns:
        DataFrame with ``index``, ``text``, ``true_label``, ``pred_label``
        rows, one per misclassified sample (empty frame if none).
    """
    mismatches = [
        {"index": idx, "text": doc, "true_label": truth, "pred_label": guess}
        for idx, (truth, guess, doc) in enumerate(zip(y_true, y_pred, texts))
        if truth != guess
    ]
    return pd.DataFrame(mismatches)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def compare_model_errors(
    models: Dict[str, BaseEstimator],
    X_test: np.ndarray,
    y_test: np.ndarray,
    texts: List[str]
) -> Dict[str, pd.DataFrame]:
    """Run each fitted model on the test set and tabulate its errors.

    Args:
        models: Mapping of model name to fitted estimator.
        X_test: Feature matrix for the test set.
        y_test: Ground-truth labels.
        texts: Raw documents aligned with the rows of ``X_test``.

    Returns:
        Mapping of model name to its ``analyze_errors`` DataFrame.
    """
    return {
        name: analyze_errors(y_test, estimator.predict(X_test), texts)
        for name, estimator in models.items()
    }
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def plot_embeddings(
    embeddings: np.ndarray,
    labels: np.ndarray,
    method: str = "umap",
    n_components: int = 2,
    figsize: Tuple[int, int] = (12, 8),
    title: str = "Embedding Projection"
):
    """Project embeddings to 2-D with t-SNE or UMAP and scatter-plot by label.

    Args:
        embeddings: (n_samples, dim) array of vectors.
        labels: Per-sample labels used for point coloring.
        method: "umap" (default) or "tsne".
        n_components: Output dimensionality of the projection.
        figsize: Matplotlib figure size.
        title: Plot title.

    Raises:
        ValueError: For an unknown ``method``.
        ImportError: When "umap" is requested but umap-learn is missing.
    """
    if method == "tsne":
        projector = TSNE(n_components=n_components, random_state=42, n_jobs=-1)
    elif method == "umap":
        if not UMAP_AVAILABLE:
            raise ImportError("Install UMAP: pip install umap-learn")
        projector = umap.UMAP(n_components=n_components, random_state=42, n_jobs=-1)
    else:
        raise ValueError("method must be 'tsne' or 'umap'")

    points = projector.fit_transform(embeddings)

    plt.figure(figsize=figsize)
    handle = plt.scatter(points[:, 0], points[:, 1], c=labels, cmap="tab10", alpha=0.7)
    plt.colorbar(handle)
    plt.title(title)
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.tight_layout()
    plt.show()
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_token_importance_captum(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Attribute a prediction to input tokens via Layer Integrated Gradients.

    Runs the model once to obtain the predicted class, then attributes that
    class's logit to the embedding layer with Captum's
    ``LayerIntegratedGradients`` against a zero baseline that keeps the
    [CLS]/[SEP] positions.

    Args:
        model: HF sequence-classification model.
            NOTE(review): ``model.bert.embeddings`` below assumes a BERT-style
            architecture — confirm before using with RoBERTa/DistilBERT
            checkpoints, whose embedding attribute is named differently.
        tokenizer: Matching tokenizer (must define cls/sep token ids).
        text: Input document (truncated to 512 tokens).
        device: Torch device string.

    Returns:
        ``(tokens, attributions)``: one summed attribution score per token.

    Raises:
        ImportError: If Captum is not installed.
    """
    if not CAPTUM_AVAILABLE:
        raise ImportError("Install Captum: pip install captum")

    from captum.attr import LayerIntegratedGradients
    import torch

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    model = model.to(device)
    model.eval()

    # First pass: pick the class whose logit will be attributed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred_class = torch.argmax(outputs.logits, dim=1).item()

    def forward_func(input_ids):
        # Attention mask is closed over; only ids vary along the IG path.
        return model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Baseline: all-zero ids except the special [CLS]/[SEP] positions.
    baseline_ids = torch.zeros_like(input_ids).to(device)
    baseline_ids[:, 0] = tokenizer.cls_token_id
    baseline_ids[:, -1] = tokenizer.sep_token_id

    lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        target=pred_class,
        return_convergence_delta=True
    )

    # Collapse the embedding dimension to one scalar per token.
    attributions = attributions.sum(dim=-1).squeeze(0).cpu().detach().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attributions
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def plot_token_importance(tokens: List[str], importance: np.ndarray, top_k: int = 20):
    """Horizontal bar chart of the top-k tokens by absolute attribution.

    Special tokens ([CLS]/[SEP]/[PAD]) are dropped first. Green bars push the
    prediction toward the target class, red bars away from it. Returns
    silently when no regular tokens remain.
    """
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")
    kept = [
        (tok, score)
        for tok, score in zip(tokens, importance)
        if tok not in special_tokens
    ]
    if not kept:
        return

    words, scores = zip(*kept)
    top_idx = np.argsort(np.abs(scores))[-top_k:][::-1]
    top_words = [words[i] for i in top_idx]
    top_scores = [scores[i] for i in top_idx]

    plt.figure(figsize=(10, 6))
    bar_colors = ["red" if s < 0 else "green" for s in top_scores]
    plt.barh(range(len(top_scores)), top_scores, color=bar_colors)
    plt.yticks(range(len(top_scores)), top_words)
    plt.gca().invert_yaxis()  # largest |score| at the top
    plt.xlabel("Attribution Score")
    plt.title("Token Importance (Green: positive, Red: negative)")
    plt.tight_layout()
    plt.show()
|
src/neural_classifiers.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Union, Tuple, Dict, Any, Literal
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow.keras import layers, models, optimizers, callbacks
|
| 8 |
+
from tensorflow.keras.models import Model
|
| 9 |
+
from tensorflow.keras.layers import (
|
| 10 |
+
Input, Embedding, Dense, Dropout, GlobalMaxPooling1D,
|
| 11 |
+
Conv1D, LSTM, GRU, Bidirectional, Attention, GlobalAveragePooling1D
|
| 12 |
+
)
|
| 13 |
+
TF_AVAILABLE = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
TF_AVAILABLE = False
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 21 |
+
from transformers import (
|
| 22 |
+
AutoTokenizer, AutoModel, AutoConfig,
|
| 23 |
+
BertForSequenceClassification, RobertaForSequenceClassification,
|
| 24 |
+
DistilBertForSequenceClassification, Trainer, TrainingArguments
|
| 25 |
+
)
|
| 26 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 27 |
+
TORCH_AVAILABLE = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
TORCH_AVAILABLE = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if TF_AVAILABLE:
    class AttentionLayer(tf.keras.layers.Layer):
        """Additive self-attention pooling over a sequence.

        Input: (batch, time, features). Output: the attention-weighted sum
        over the time axis, shape (batch, features).
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def build(self, input_shape):
            # One scalar score per timestep: e_t = tanh(h_t . W + b_t).
            self.W = self.add_weight(
                shape=(input_shape[-1], 1),
                initializer='random_normal',
                trainable=True,
                name='attention_weight'
            )
            self.b = self.add_weight(
                shape=(input_shape[1], 1),
                initializer='zeros',
                trainable=True,
                name='attention_bias'
            )
            super().build(input_shape)

        def call(self, inputs, **kwargs):
            e = tf.keras.activations.tanh(tf.matmul(inputs, self.W) + self.b)
            e = tf.squeeze(e, axis=-1)
            a = tf.nn.softmax(e, axis=1)       # attention distribution over time
            a = tf.expand_dims(a, axis=-1)
            weighted_input = inputs * a
            return tf.reduce_sum(weighted_input, axis=1)
else:
    class AttentionLayer:
        """Import-time stub used when TensorFlow is unavailable.

        Bug fix: the Keras subclass used to be defined unconditionally, so
        importing this module without TensorFlow crashed with NameError
        despite the TF_AVAILABLE flag. This stub keeps the module importable
        and fails loudly only on use.
        """

        def __init__(self, *args, **kwargs):
            raise ImportError("TensorFlow not available")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_mlp(
    input_dim: int,
    num_classes: int,
    hidden_dims: Optional[list] = None,
    dropout: float = 0.3,
    activation: str = 'relu'
) -> 'tf.keras.Model':
    """Build a plain feed-forward classifier over fixed-size feature vectors.

    Args:
        input_dim: Size of the input feature vector.
        num_classes: Output units (softmax when > 2, else sigmoid).
        hidden_dims: Widths of the hidden layers; defaults to [256, 128].
        dropout: Dropout rate applied after every hidden layer.
        activation: Hidden-layer activation.

    Raises:
        ImportError: If TensorFlow is not installed.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    if hidden_dims is None:
        # Replaces the mutable default argument [256, 128]; same effective default.
        hidden_dims = [256, 128]

    inputs = Input(shape=(input_dim,))
    x = inputs
    for dim in hidden_dims:
        x = Dense(dim, activation=activation)(x)
        x = Dropout(dropout)(x)
    outputs = Dense(num_classes, activation='softmax' if num_classes > 2 else 'sigmoid')(x)
    return models.Model(inputs, outputs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_kim_cnn(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    filter_sizes: Optional[list] = None,
    num_filters: int = 100,
    dropout: float = 0.5,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Kim (2014)-style multi-width CNN for sentence classification.

    Parallel Conv1D branches with different kernel widths are max-pooled and
    concatenated before the classification head.

    Args:
        max_len: Padded sequence length.
        vocab_size: Embedding vocabulary size.
        embed_dim: Embedding dimensionality.
        num_classes: Output units (softmax when > 2, else sigmoid).
        filter_sizes: Convolution kernel widths; defaults to [3, 4, 5].
        num_filters: Filters per branch.
        dropout: Dropout rate before the output layer.
        pre_embed_matrix: Optional pretrained embedding weights (frozen).

    Raises:
        ImportError: If TensorFlow is not installed.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    if filter_sizes is None:
        # Replaces the mutable default argument [3, 4, 5]; same effective default.
        filter_sizes = [3, 4, 5]

    inputs = Input(shape=(max_len,))
    if pre_embed_matrix is not None:
        embedding = Embedding(
            vocab_size, embed_dim,
            weights=[pre_embed_matrix],
            trainable=False
        )(inputs)
    else:
        embedding = Embedding(vocab_size, embed_dim)(inputs)

    pooled_outputs = []
    for fs in filter_sizes:
        x = Conv1D(num_filters, fs, activation='relu')(embedding)
        x = GlobalMaxPooling1D()(x)
        pooled_outputs.append(x)

    merged = tf.concat(pooled_outputs, axis=1)
    x = Dropout(dropout)(merged)
    outputs = Dense(num_classes, activation='softmax' if num_classes > 2 else 'sigmoid')(x)
    return models.Model(inputs, outputs)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def build_lstm(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    lstm_units: int = 128,
    dropout: float = 0.3,
    bidirectional: bool = False,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Build an (optionally bidirectional) LSTM text classifier.

    Args:
        max_len: Padded sequence length.
        vocab_size: Embedding vocabulary size.
        embed_dim: Embedding dimensionality.
        num_classes: Output units (softmax when > 2, else sigmoid).
        lstm_units: Hidden size of the LSTM.
        dropout: Input and recurrent dropout rate.
        bidirectional: Wrap the LSTM in a Bidirectional layer when True.
        pre_embed_matrix: Optional pretrained embedding weights (frozen).

    Raises:
        ImportError: If TensorFlow is not installed.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    token_ids = Input(shape=(max_len,))
    if pre_embed_matrix is None:
        embedded = Embedding(vocab_size, embed_dim)(token_ids)
    else:
        # Pretrained embeddings are loaded frozen.
        embedded = Embedding(
            vocab_size, embed_dim, weights=[pre_embed_matrix], trainable=False
        )(token_ids)

    recurrent = LSTM(lstm_units, dropout=dropout, recurrent_dropout=dropout)
    encoded = Bidirectional(recurrent)(embedded) if bidirectional else recurrent(embedded)

    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    predictions = Dense(num_classes, activation=head_activation)(encoded)
    return models.Model(token_ids, predictions)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def build_cnn_lstm(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    filter_size: int = 3,
    num_filters: int = 128,
    lstm_units: int = 64,
    dropout: float = 0.3,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Conv1D feature extractor followed by an LSTM, for text classification.

    Args:
        max_len: Padded sequence length.
        vocab_size: Embedding vocabulary size.
        embed_dim: Embedding dimensionality.
        num_classes: Output units (softmax when > 2, else sigmoid).
        filter_size: Convolution kernel width.
        num_filters: Number of convolution filters.
        lstm_units: Hidden size of the LSTM.
        dropout: LSTM input dropout rate.
        pre_embed_matrix: Optional pretrained embedding weights (frozen).

    Raises:
        ImportError: If TensorFlow is not installed.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    token_ids = Input(shape=(max_len,))
    if pre_embed_matrix is None:
        embedded = Embedding(vocab_size, embed_dim)(token_ids)
    else:
        embedded = Embedding(
            vocab_size, embed_dim, weights=[pre_embed_matrix], trainable=False
        )(token_ids)

    # 'same' padding keeps the time dimension for the downstream LSTM.
    conv_features = Conv1D(num_filters, filter_size, activation='relu', padding='same')(embedded)
    sequence_summary = LSTM(lstm_units, dropout=dropout)(conv_features)

    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    predictions = Dense(num_classes, activation=head_activation)(sequence_summary)
    return models.Model(token_ids, predictions)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def build_birnn_attention(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    rnn_units: int = 64,
    dropout: float = 0.3,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """BiLSTM encoder pooled with the custom additive AttentionLayer.

    Args:
        max_len: Padded sequence length.
        vocab_size: Embedding vocabulary size.
        embed_dim: Embedding dimensionality.
        num_classes: Output units (softmax when > 2, else sigmoid).
        rnn_units: Hidden size of each LSTM direction.
        dropout: LSTM input dropout rate.
        pre_embed_matrix: Optional pretrained embedding weights (frozen).

    Raises:
        ImportError: If TensorFlow is not installed.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    token_ids = Input(shape=(max_len,))
    if pre_embed_matrix is None:
        embedded = Embedding(vocab_size, embed_dim)(token_ids)
    else:
        embedded = Embedding(
            vocab_size, embed_dim, weights=[pre_embed_matrix], trainable=False
        )(token_ids)

    # return_sequences=True so the attention layer sees every timestep.
    encoded = Bidirectional(LSTM(rnn_units, return_sequences=True, dropout=dropout))(embedded)
    pooled = AttentionLayer()(encoded)

    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    predictions = Dense(num_classes, activation=head_activation)(pooled)
    return models.Model(token_ids, predictions)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Short aliases for the Russian / multilingual checkpoints supported here.
_RUSSIAN_TRANSFORMERS = {
    "rubert": "DeepPavlov/rubert-base-cased",
    "ruroberta": "sberbank-ai/ruRoberta-large",
    "distilbert-multilingual": "distilbert-base-multilingual-cased"
}

def get_transformer_classifier(
    model_name: str = "rubert",
    num_classes: int = 2,
    problem_type: Literal["single_label", "multi_label"] = "single_label"
) -> Tuple[Any, Any]:
    """Load a pretrained transformer and tokenizer for sequence classification.

    Args:
        model_name: Alias from ``_RUSSIAN_TRANSFORMERS``.
        num_classes: Number of classification labels.
        problem_type: "single_label" or "multi_label"; sets the HF config so
            the appropriate loss is used during training.

    Returns:
        ``(model, tokenizer)`` pair.

    Raises:
        ImportError: If PyTorch/transformers are missing.
        ValueError: For an unknown ``model_name``.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch or transformers not available")

    if model_name not in _RUSSIAN_TRANSFORMERS:
        raise ValueError(f"Unknown model_name. Choose from: {list(_RUSSIAN_TRANSFORMERS.keys())}")

    checkpoint = _RUSSIAN_TRANSFORMERS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Pick the architecture-specific head class from the checkpoint id.
    lowered = checkpoint.lower()
    if "roberta" in lowered:
        head_cls = RobertaForSequenceClassification
    elif "distilbert" in lowered:
        head_cls = DistilBertForSequenceClassification
    else:
        head_cls = BertForSequenceClassification
    model = head_cls.from_pretrained(checkpoint, num_labels=num_classes)

    model.config.problem_type = (
        "multi_label_classification"
        if problem_type == "multi_label"
        else "single_label_classification"
    )
    return model, tokenizer
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def quantize_pytorch_model(model: 'torch.nn.Module', backend: str = "qnnpack") -> 'torch.nn.Module':
    """Apply eager-mode post-training static quantization in place.

    Args:
        model: Model to quantize (modified in place and also returned).
        backend: Quantization engine config name ("qnnpack" or "fbgemm").

    Returns:
        The quantized model (same object as the input).

    Raises:
        ImportError: If PyTorch is not installed.

    NOTE(review): no calibration pass runs between ``prepare()`` and
    ``convert()``, so the observers never see data — confirm callers feed
    representative inputs before convert, or consider
    ``torch.quantization.quantize_dynamic`` instead.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    return model
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def prune_keras_model(model: 'tf.keras.Model', sparsity: float = 0.5) -> 'tf.keras.Model':
    """Wrap a Keras model for magnitude pruning up to ``sparsity``.

    Uses a polynomial decay schedule from 0 to ``sparsity`` over 1000 steps.
    The returned wrapper still has to be compiled and trained with the
    pruning callbacks for the sparsity masks to take effect.

    Args:
        model: Compiled or uncompiled Keras model.
        sparsity: Final fraction of weights to zero out.

    Raises:
        ImportError: If tensorflow-model-optimization is not installed.
    """
    try:
        import tensorflow_model_optimization as tfmot
    except ImportError:
        raise ImportError("Install tensorflow-model-optimization for pruning")
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0, final_sparsity=sparsity, begin_step=0, end_step=1000
        )
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    return model_for_pruning
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def prepare_keras_inputs(
    texts: list,
    tokenizer=None,
    max_len: int = 128,
    vocab: Optional[dict] = None
) -> np.ndarray:
    """Turn raw texts into a padded integer-id matrix for Keras models.

    With a HF-style ``tokenizer`` the texts are encoded directly; otherwise a
    Keras ``Tokenizer`` is used, either fitted on ``texts`` or seeded with a
    pre-built ``vocab`` word index.

    Args:
        texts: Raw documents.
        tokenizer: Optional HF tokenizer (callable returning numpy tensors).
        max_len: Truncation/padding length.
        vocab: Optional pre-built word index for the Keras tokenizer path.

    Returns:
        2-D integer array of token ids.
    """
    if tokenizer is not None:
        encoded = tokenizer(texts, truncation=True, padding=True, max_length=max_len, return_tensors="np")
        return encoded['input_ids']

    from tensorflow.keras.preprocessing.text import Tokenizer
    from tensorflow.keras.preprocessing.sequence import pad_sequences

    keras_tok = Tokenizer(oov_token="<OOV>")
    if vocab:
        # Reuse an existing vocabulary instead of re-fitting.
        keras_tok.word_index = vocab
    else:
        keras_tok.fit_on_texts(texts)
    id_sequences = keras_tok.texts_to_sequences(texts)
    return pad_sequences(id_sequences, maxlen=max_len)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def compile_keras_model(
    model: 'tf.keras.Model',
    learning_rate: float = 2e-5,
    num_classes: int = 2
):
    """Compile a Keras classifier with Adam and an appropriate loss.

    Bug fix: the loss used to be picked from ``num_classes`` alone, choosing
    ``binary_crossentropy`` for 2 classes. But the builders in this module
    always emit ``Dense(num_classes)`` output units, so a 2-class model has
    TWO sigmoid outputs and binary crossentropy with integer labels breaks
    on shape. The loss is now selected from the model's actual output width.

    Args:
        model: Keras model to compile (modified in place and returned).
        learning_rate: Adam learning rate.
        num_classes: Fallback class count if the output shape is unavailable.

    Returns:
        The compiled model.
    """
    try:
        output_units = int(model.output_shape[-1])
    except (AttributeError, TypeError):
        # Subclassed models may not expose output_shape before a call.
        output_units = num_classes
    loss = 'binary_crossentropy' if output_units == 1 else 'sparse_categorical_crossentropy'
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss=loss,
        metrics=['accuracy']
    )
    return model
|
src/text_preprocessing.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import string
|
| 3 |
+
from typing import List, Optional, Union, Dict, Any, Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
|
| 8 |
+
from nltk.corpus import stopwords
|
| 9 |
+
from nltk.tokenize import word_tokenize
|
| 10 |
+
from nltk import download as nltk_download
|
| 11 |
+
from nltk.stem import WordNetLemmatizer
|
| 12 |
+
import spacy
|
| 13 |
+
from gensim.models import KeyedVectors
|
| 14 |
+
from transformers import AutoTokenizer, AutoModel
|
| 15 |
+
import torch
|
| 16 |
+
import emoji
|
| 17 |
+
print('PREPROCESSING IMPORTED')
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
nltk_download('punkt', quiet=True)
|
| 21 |
+
nltk_download('stopwords', quiet=True)
|
| 22 |
+
nltk_download('wordnet', quiet=True)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Warning: NLTK data download failed: {e}")
|
| 25 |
+
|
| 26 |
+
_SPACY_MODEL = None
|
| 27 |
+
_NLTK_LEMMATIZER = None
|
| 28 |
+
_BERT_TOKENIZER = None
|
| 29 |
+
_BERT_MODEL = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _load_spacy_model(lang: str = "en_core_web_sm"):
    """Return a module-level cached spaCy pipeline, loading it on first use.

    Args:
        lang: Full spaCy model package name (e.g. "en_core_web_sm").

    Raises:
        ValueError: If the model package is not installed.

    NOTE(review): the cache ignores ``lang`` after the first call — requesting
    a different model later still returns the originally loaded pipeline.
    """
    global _SPACY_MODEL
    if _SPACY_MODEL is None:
        try:
            _SPACY_MODEL = spacy.load(lang)
        except OSError:
            # spaCy raises OSError when the model package is absent.
            raise ValueError(
                f"spaCy model '{lang}' not found. Please install it via: python -m spacy download {lang}"
            )
    return _SPACY_MODEL
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _load_nltk_lemmatizer():
    """Return the module-level cached WordNet lemmatizer, created lazily."""
    global _NLTK_LEMMATIZER
    if _NLTK_LEMMATIZER is None:
        _NLTK_LEMMATIZER = WordNetLemmatizer()
    return _NLTK_LEMMATIZER
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _load_bert_model(model_name: str = "bert-base-uncased"):
    """Return cached ``(tokenizer, model)`` for a HF checkpoint, loaded lazily.

    NOTE(review): like the other caches in this module, ``model_name`` is
    ignored after the first call — a different checkpoint requested later is
    not loaded.
    """
    global _BERT_TOKENIZER, _BERT_MODEL
    if _BERT_TOKENIZER is None or _BERT_MODEL is None:
        _BERT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
        _BERT_MODEL = AutoModel.from_pretrained(model_name)
    return _BERT_TOKENIZER, _BERT_MODEL
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def clean_text(text: str) -> str:
    """Remove HTML tags, URLs and control characters; collapse whitespace.

    Bug fix: the old implementation kept only ``string.printable`` characters,
    which is ASCII-only and silently deleted all Cyrillic text — fatal for
    this project's Russian-language models. Non-printable/control characters
    are now dropped while every Unicode letter is preserved.

    Args:
        text: Raw document.

    Returns:
        Cleaned, single-spaced text.
    """
    text = re.sub(r"<[^>]+>", "", text)
    text = re.sub(r"https?://\S+|www\.\S+", "", text)
    # Keep printable characters and whitespace (tabs/newlines collapse below).
    text = "".join(ch for ch in text if ch.isprintable() or ch.isspace())
    text = re.sub(r"\s+", " ", text).strip()
    return text
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def replace_emojis(text: str) -> str:
    """Replace emoji with their ``:name:`` text codes, padded with spaces."""
    return emoji.demojize(text, delimiters=(" ", " "))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def preprocess_text(
    text: str,
    lang: str = "en",
    remove_stopwords: bool = True,
    use_spacy: bool = True,
    lemmatize: bool = True,
    emoji_to_text: bool = True,
    lowercase: bool = True,
    spacy_model: Optional[str] = None,
    replace_entities: bool = False  # off by default: numbers/URLs left as-is
) -> List[str]:
    """Normalize and tokenize one document.

    Pipeline: optional emoji-to-text, HTML stripping, optional entity
    replacement, punctuation stripping, whitespace collapsing, optional
    lowercasing, then spaCy or NLTK tokenization with optional lemmatization
    and stopword removal.

    Bug fixes vs. the previous version:
    - Entity replacement now runs BEFORE punctuation stripping; previously
      ':' and '/' were already gone, so the URL/e-mail patterns could never
      match and those branches were dead code.
    - The NLTK stopword branch compared an ISO code ("en") against NLTK's
      full-name fileids ("english"), so stopwords were silently never
      removed; common codes are now mapped.
    - Removed redundant function-local ``import re`` / ``import string``
      that shadowed the module-level imports.

    Args:
        text: Raw document.
        lang: ISO language code ("en", "ru", ...).
        remove_stopwords: Drop stopwords after tokenization.
        use_spacy: Use spaCy (True) or NLTK (False) for tokenization.
        lemmatize: Lemmatize tokens.
        emoji_to_text: Convert emoji to ``:name:`` codes first.
        lowercase: Lowercase before tokenizing.
        spacy_model: Explicit spaCy model name; derived from ``lang`` otherwise.
        replace_entities: Replace numbers/URLs/e-mails with placeholder tokens.

    Returns:
        List of processed token strings.
    """
    if emoji_to_text:
        text = replace_emojis(text)

    # Strip HTML tags first.
    text = re.sub(r"<[^>]+>", "", text)

    if replace_entities:
        # URLs/e-mails before numbers so digits inside them are not mangled;
        # the punctuation strip below reduces "<URL>" to a bare "URL" token.
        text = re.sub(r"https?://\S+|www\.\S+", "<URL>", text)
        text = re.sub(r"\S+@\S+", "<EMAIL>", text)
        text = re.sub(r"\b\d+\b", "<NUM>", text)

    # Replace anything that is neither a word character nor whitespace.
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()

    if lowercase:
        text = text.lower()

    if use_spacy:
        spacy_lang = spacy_model or ("en_core_web_sm" if lang == "en" else f"{lang}_core_news_sm")
        nlp = _load_spacy_model(spacy_lang)
        doc = nlp(text)
        tokens = [
            (tok.lemma_ if lemmatize else tok.text)
            for tok in doc
            if not tok.is_space and not tok.is_punct
        ]
        if remove_stopwords:
            tokens = [t for t in tokens if not nlp.vocab[t].is_stop]
    else:
        tokens = word_tokenize(text)
        if lemmatize:
            lemmatizer = _load_nltk_lemmatizer()
            tokens = [lemmatizer.lemmatize(t) for t in tokens]
        if remove_stopwords:
            # NLTK names its stopword corpora by full language name.
            _ISO_TO_NLTK = {
                "en": "english", "ru": "russian", "de": "german",
                "fr": "french", "es": "spanish", "it": "italian",
            }
            nltk_lang = _ISO_TO_NLTK.get(lang, lang)
            stop_words = set(stopwords.words(nltk_lang)) if nltk_lang in stopwords.fileids() else set()
            tokens = [t for t in tokens if t not in stop_words]

    return [t for t in tokens if t not in string.punctuation and t]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TextVectorizer:
    """Thin wrapper over sklearn count / TF-IDF vectorizers.

    The fitted vectorizers are kept on the instance so callers can transform
    new texts with the same vocabulary later.
    """

    def __init__(self):
        # Populated by bow() / tfidf() respectively.
        self.bow_vectorizer = None
        self.tfidf_vectorizer = None

    def bow(self, texts: List[str], **kwargs) -> np.ndarray:
        """Fit a CountVectorizer and return the dense document-term matrix."""
        vectorizer = CountVectorizer(**kwargs)
        self.bow_vectorizer = vectorizer
        return vectorizer.fit_transform(texts).toarray()

    def tfidf(self, texts: List[str], max_features: int = 5000, **kwargs) -> np.ndarray:
        """Fit a TfidfVectorizer (no re-lowercasing) and return a dense matrix."""
        kwargs['max_features'] = max_features
        vectorizer = TfidfVectorizer(lowercase=False, **kwargs)
        self.tfidf_vectorizer = vectorizer
        return vectorizer.fit_transform(texts).toarray()

    def ngrams(self, texts: List[str], ngram_range: tuple = (1, 2), **kwargs) -> np.ndarray:
        """TF-IDF over word n-grams (default: unigrams + bigrams)."""
        kwargs.setdefault("ngram_range", ngram_range)
        return self.tfidf(texts, **kwargs)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class EmbeddingVectorizer:
    """Document embeddings from pretrained static word vectors.

    Supports three backends: binary word2vec, gensim-saved fastText
    KeyedVectors, and plain-text GloVe files. A document vector is the
    mean or max over the vectors of its in-vocabulary tokens.
    """

    # Fallback dimensionality used only when no model is loaded; real
    # dimensionality is inferred from the loaded vectors (see _embedding_dim).
    DEFAULT_DIM = 300

    def __init__(self):
        self.word2vec_model = None   # gensim KeyedVectors (word2vec binary format)
        self.fasttext_model = None   # gensim KeyedVectors (native .kv save)
        self.glove_vectors = None    # dict[str, np.ndarray]

    def load_word2vec(self, path: str):
        """Load binary-format word2vec vectors from *path*."""
        self.word2vec_model = KeyedVectors.load_word2vec_format(path, binary=True)

    def load_fasttext(self, path: str):
        """Load gensim-saved fastText KeyedVectors from *path*."""
        self.fasttext_model = KeyedVectors.load(path)

    def load_glove(self, glove_file: str, vocab_size: int = 400000, dim: int = 300):
        """Read up to *vocab_size* vectors from a plain-text GloVe file."""
        self.glove_vectors = {}
        with open(glove_file, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i >= vocab_size:
                    break
                values = line.split()
                word = values[0]
                vector = np.array(values[1:], dtype="float32")
                self.glove_vectors[word] = vector

    def _embedding_dim(self, method: str = "word2vec") -> int:
        """Best-effort vector dimensionality for *method*.

        Falls back to ``DEFAULT_DIM`` when the requested backend is not loaded.
        """
        if method == "word2vec" and self.word2vec_model is not None:
            return self.word2vec_model.vector_size
        if method == "fasttext" and self.fasttext_model is not None:
            return self.fasttext_model.vector_size
        if method == "glove" and self.glove_vectors:
            return len(next(iter(self.glove_vectors.values())))
        return self.DEFAULT_DIM

    def _get_word_vector(self, word: str, method: str = "word2vec") -> Optional[np.ndarray]:
        """Return the vector for *word* from the requested backend, or None if OOV."""
        if method == "word2vec" and self.word2vec_model and word in self.word2vec_model:
            return self.word2vec_model[word]
        elif method == "fasttext" and self.fasttext_model and word in self.fasttext_model:
            return self.fasttext_model[word]
        elif method == "glove" and self.glove_vectors and word in self.glove_vectors:
            return self.glove_vectors[word]
        return None

    def _aggregate_vectors(
        self, vectors: List[np.ndarray], strategy: str = "mean", dim: int = 300
    ) -> np.ndarray:
        """Pool a list of word vectors into one document vector.

        BUGFIX: the empty-document fallback was hard-coded to np.zeros(300).
        With models of any other dimensionality this produced ragged rows and
        made np.array() in get_embeddings() raise on modern NumPy. The caller
        now passes the model's true dimensionality via *dim* (default kept at
        300 for backward compatibility).
        """
        if not vectors:
            return np.zeros(dim)
        if strategy == "mean":
            return np.mean(vectors, axis=0)
        elif strategy == "max":
            return np.max(vectors, axis=0)
        else:
            raise ValueError("Strategy must be 'mean' or 'max'")

    def get_embeddings(
        self,
        tokenized_texts: List[List[str]],
        method: str = "word2vec",
        aggregation: str = "mean",
    ) -> np.ndarray:
        """Embed each tokenized document; OOV tokens are skipped.

        Documents with no in-vocabulary tokens map to a zero vector of the
        model's dimensionality, so the returned array is always rectangular.
        """
        dim = self._embedding_dim(method)
        embeddings = []
        for tokens in tokenized_texts:
            vectors = [
                self._get_word_vector(token, method=method) for token in tokens
            ]
            vectors = [v for v in vectors if v is not None]
            doc_vec = self._aggregate_vectors(vectors, strategy=aggregation, dim=dim)
            embeddings.append(doc_vec)
        return np.array(embeddings)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_contextual_embeddings(
    texts: List[str],
    model_name: str = "bert-base-uncased",
    aggregation: str = "mean",
    device: str = "cpu",
) -> np.ndarray:
    """Encode each text with a pretrained transformer and pool token states.

    Texts are processed one at a time (truncated to 512 tokens); the final
    hidden states are pooled by mean or max over token positions, with the
    first and last positions dropped when possible (a simple heuristic to
    exclude the [CLS]/[SEP] special tokens).

    Returns an array of shape (len(texts), hidden_size).
    Raises ValueError for an unknown *aggregation*.
    """
    tokenizer, model = _load_bert_model(model_name)
    model.to(device)
    model.eval()

    pooled = []
    with torch.no_grad():
        for text in texts:
            encoded = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            hidden = model(**encoded).last_hidden_state[0].cpu().numpy()

            # Drop the [CLS]/[SEP] positions when the sequence is long enough.
            if len(hidden) > 2:
                hidden = hidden[1:-1]

            if aggregation == "mean":
                pooled.append(np.mean(hidden, axis=0))
            elif aggregation == "max":
                pooled.append(np.max(hidden, axis=0))
            else:
                raise ValueError("aggregation must be 'mean' or 'max'")

    return np.array(pooled)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def extract_meta_features(texts: Union[List[str], pd.Series]) -> pd.DataFrame:
    """Compute shallow surface statistics for each text.

    Parameters
    ----------
    texts : list of str or pandas Series of str

    Returns
    -------
    pd.DataFrame with one row per input text and columns:
    text_length, avg_word_length, num_unique_words, punctuation_ratio,
    uppercase_ratio, digit_ratio, flesch_reading_ease.

    Notes
    -----
    The flesch_reading_ease column is an all-NaN placeholder: the original
    readability computation was removed (the code carried a dead
    try/except that assigned NaN on both paths), and the column is kept
    so downstream consumers of the schema keep working.
    """
    if isinstance(texts, pd.Series):
        texts = texts.tolist()

    rows = []
    for text in texts:
        length = len(text)
        words = text.split()
        # Empty text contributes a single 0-length "word" so the mean is 0.
        word_lengths = [len(w) for w in words] if words else [0]

        num_punct = sum(1 for c in text if c in string.punctuation)
        num_upper = sum(1 for c in text if c.isupper())
        num_digits = sum(1 for c in text if c.isdigit())

        rows.append({
            "text_length": length,
            "avg_word_length": np.mean(word_lengths),
            "num_unique_words": len(set(words)) if words else 0,
            # Character-level ratios; guarded against division by zero.
            "punctuation_ratio": num_punct / length if length > 0 else 0,
            "uppercase_ratio": num_upper / length if length > 0 else 0,
            "digit_ratio": num_digits / length if length > 0 else 0,
            "flesch_reading_ease": np.nan,
        })

    return pd.DataFrame(rows)
|