theformatisvalid committed on
Commit
2153792
·
verified ·
1 Parent(s): 37958f4

Upload 7 files

Browse files
src/classical_classifiers.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, Optional, Union, Tuple
2
+ import numpy as np
3
+ import pandas as pd
4
+ from sklearn.base import BaseEstimator, ClassifierMixin
5
+ from sklearn.linear_model import LogisticRegression
6
+ from sklearn.svm import SVC
7
+ from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, VotingClassifier, StackingClassifier
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import accuracy_score, classification_report
10
+ from sklearn.preprocessing import LabelEncoder
11
+
12
+ XGBClassifier = None
13
+ CatBoostClassifier = None
14
+ LGBMClassifier = None
15
+
16
+ try:
17
+ from xgboost import XGBClassifier
18
+ except ImportError:
19
+ pass
20
+
21
+ try:
22
+ from catboost import CatBoostClassifier
23
+ except ImportError:
24
+ pass
25
+
26
+ try:
27
+ from lightgbm import LGBMClassifier
28
+ except ImportError:
29
+ pass
30
+
31
+
32
def get_logistic_regression(
    penalty: str = "l2",
    C: float = 1.0,
    max_iter: int = 1000,
    solver: str = "liblinear",  # supports l1 and l2
    random_state: int = 42,
    l1_ratio: Optional[float] = None
) -> LogisticRegression:
    """Build a LogisticRegression whose solver is compatible with the penalty.

    Args:
        penalty: Regularization type: 'l1', 'l2', 'elasticnet', or 'none'.
        C: Inverse regularization strength.
        max_iter: Maximum solver iterations.
        solver: Preferred solver; silently replaced when it cannot handle
            the requested penalty.
        random_state: Seed for reproducibility.
        l1_ratio: Elastic-net mixing parameter; defaults to 0.5 when
            penalty='elasticnet' and no value is supplied (sklearn requires
            an explicit l1_ratio for elastic-net).

    Raises:
        ValueError: If penalty is not one of the supported strings.
    """
    if penalty not in ("l1", "l2", "elasticnet", "none"):
        raise ValueError("penalty must be 'l1', 'l2', 'elasticnet', or 'none'")
    if penalty == "l1" and solver not in ("liblinear", "saga"):
        solver = "liblinear"
    elif penalty == "elasticnet":
        # saga is the only solver supporting elastic-net; without this the
        # default liblinear solver would raise at fit time.
        solver = "saga"
        if l1_ratio is None:
            l1_ratio = 0.5
    elif penalty == "none" and solver == "liblinear":
        # liblinear always applies a penalty; fall back to lbfgs.
        solver = "lbfgs"
    return LogisticRegression(
        penalty=penalty,
        C=C,
        max_iter=max_iter,
        solver=solver,
        random_state=random_state,
        l1_ratio=l1_ratio
    )
50
+
51
+
52
def get_svm_linear(C: float = 1.0, random_state: int = 42) -> SVC:
    """Linear-kernel SVM with probability estimates enabled."""
    params = {
        "kernel": "linear",
        "C": C,
        "probability": True,  # needed for predict_proba / soft voting
        "random_state": random_state,
    }
    return SVC(**params)
54
+
55
+
56
def get_random_forest(
    n_estimators: int = 100,
    max_depth: Optional[int] = None,
    random_state: int = 42
) -> RandomForestClassifier:
    """Random-forest classifier with a fixed seed for reproducibility.

    max_depth=None lets trees grow until leaves are pure.
    """
    forest_kwargs = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "random_state": random_state,
    }
    return RandomForestClassifier(**forest_kwargs)
66
+
67
+
68
def get_gradient_boosting(
    model_type: str = "xgb",
    **kwargs
) -> Any:
    """Instantiate a gradient-boosting classifier from an optional backend.

    The return annotation is ``Any`` on purpose: the previous
    ``Union[XGBClassifier, ...]`` annotation evaluated ``XGBClassifier`` at
    definition time, which is ``None`` when xgboost is not installed.

    Args:
        model_type: 'xgb' (XGBoost), 'cat' (CatBoost) or 'lgb' (LightGBM).
        **kwargs: Forwarded to the backend classifier; a deterministic seed
            (and quiet output for CatBoost) is filled in when absent.

    Raises:
        ImportError: If the requested backend is not installed.
        ValueError: For an unknown model_type.
    """
    if model_type == "xgb":
        if XGBClassifier is None:
            raise ImportError("XGBoost not installed. Run: pip install xgboost")
        kwargs.setdefault("random_state", 42)
        return XGBClassifier(**kwargs)
    elif model_type == "cat":
        if CatBoostClassifier is None:
            raise ImportError("CatBoost not installed. Run: pip install catboost")
        # CatBoost is verbose by default and uses 'random_seed', not 'random_state'.
        kwargs.setdefault("verbose", False)
        kwargs.setdefault("random_seed", 42)
        return CatBoostClassifier(**kwargs)
    elif model_type == "lgb":
        if LGBMClassifier is None:
            raise ImportError("LightGBM not installed. Run: pip install lightgbm")
        kwargs.setdefault("random_state", 42)
        return LGBMClassifier(**kwargs)
    else:
        raise ValueError("model_type must be 'xgb', 'cat', or 'lgb'")
90
+
91
+
92
def get_bagging_classifier(
    base_estimator: str = "tree",
    n_estimators: int = 10,
    random_state: int = 42
) -> BaggingClassifier:
    """Bagging ensemble over decision trees ('tree') or logistic regression ('lr').

    Raises:
        ValueError: For any other base_estimator value.
    """
    if base_estimator == "tree":
        from sklearn.tree import DecisionTreeClassifier
        base = DecisionTreeClassifier(random_state=random_state)
    elif base_estimator == "lr":
        base = get_logistic_regression()
    else:
        raise ValueError("base_estimator must be 'tree' or 'lr'")
    return BaggingClassifier(
        estimator=base,
        n_estimators=n_estimators,
        random_state=random_state
    )
109
+
110
+
111
def get_stacking_classifier(
    final_estimator: Optional[BaseEstimator] = None,
    cv: int = 5,
    random_state: int = 42
) -> StackingClassifier:
    """Two-level stacking ensemble.

    Base layer: logistic regression + linear SVM, plus a small CatBoost
    model when catboost is installed. The meta-model defaults to logistic
    regression when none is supplied.
    """
    base_models = [
        ("lr", get_logistic_regression()),
        ("svm", get_svm_linear()),
    ]
    if CatBoostClassifier is not None:
        base_models.append(("cat", get_gradient_boosting("cat", iterations=100)))

    meta = final_estimator if final_estimator is not None else get_logistic_regression()

    return StackingClassifier(
        estimators=base_models,
        final_estimator=meta,
        cv=cv,
        passthrough=False  # meta-model sees only base predictions
    )
132
+
133
+
134
def get_voting_classifier(
    voting: str = "soft",
    use_catboost: bool = True
) -> VotingClassifier:
    """Voting ensemble of LR, linear SVM and a 50-tree random forest,
    optionally adding CatBoost when it is installed.
    """
    members = [
        ("lr", get_logistic_regression()),
        ("svm", get_svm_linear()),
        ("rf", get_random_forest(n_estimators=50)),
    ]
    if use_catboost and CatBoostClassifier is not None:
        members.append(("cat", get_gradient_boosting("cat", iterations=50, verbose=False)))

    return VotingClassifier(estimators=members, voting=voting)
150
+
151
+
152
def tpot_classifier(
    generations: int = 5,
    population_size: int = 20,
    cv: int = 5,
    random_state: int = 42,
    verbosity: int = 0
) -> Any:
    """Configure (but do not fit) a TPOT AutoML classifier.

    Raises:
        ImportError: If the optional tpot package is missing.
    """
    try:
        from tpot import TPOTClassifier
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError("TPOT not installed. Run: pip install tpot") from err

    return TPOTClassifier(
        generations=generations,
        population_size=population_size,
        cv=cv,
        random_state=random_state,
        verbosity=verbosity,
        n_jobs=-1  # use all available cores
    )
172
+
173
+
174
def h2o_classifier(
    max_runtime_secs: int = 300,
    seed: int = 42,
    exclude_algos: Optional[list] = None
) -> Any:
    """Configure (but do not launch) an H2O AutoML run.

    Args:
        max_runtime_secs: Wall-clock budget for the AutoML search.
        seed: Seed for reproducibility.
        exclude_algos: Optional list of algorithm names to skip.

    Raises:
        ImportError: If the optional h2o package is missing.
    """
    try:
        import h2o  # noqa: F401 -- ensures the H2O runtime is importable
        from h2o.automl import H2OAutoML
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError("H2O not installed. Run: pip install h2o") from err

    aml = H2OAutoML(
        max_runtime_secs=max_runtime_secs,
        seed=seed,
        exclude_algos=exclude_algos
    )
    return aml
191
+
192
+
193
def train_and_evaluate(
    model: Union[BaseEstimator, Any],
    X_train: Union[np.ndarray, pd.DataFrame],
    y_train: Union[np.ndarray, pd.Series],
    X_test: Union[np.ndarray, pd.DataFrame],
    y_test: Union[np.ndarray, pd.Series],
    is_h2o: bool = False
) -> Dict[str, Any]:
    """Fit *model* on the training split and report held-out metrics.

    When is_h2o is True, the X/y arguments are expected to be H2OFrames
    and *model* an H2OAutoML instance; otherwise *model* is an
    sklearn-style estimator.

    Returns:
        H2O branch: {'accuracy', 'auc', 'best_model'}.
        sklearn branch: {'accuracy', 'report', 'model'}.
    """
    if is_h2o:
        import h2o  # ensures the H2O runtime is importable in this branch

        # H2O trains on a single frame holding features + target.
        train_frame = X_train.cbind(y_train)
        test_frame = X_test.cbind(y_test)
        y_col = y_train.columns[0]

        model.train(x=X_train.columns.tolist(), y=y_col, training_frame=train_frame)
        perf = model.model_performance(test_frame)
        return {
            # accuracy() returns [[threshold, accuracy]]; previously [0]
            # returned the whole pair instead of the metric value.
            "accuracy": perf.accuracy()[0][1],
            # NOTE(review): _has_auc is a private H2O API -- confirm it
            # survives h2o upgrades.
            "auc": perf.auc() if perf._has_auc() else None,
            "best_model": model.leader
        }
    else:
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        return {
            "accuracy": accuracy_score(y_test, y_pred),
            "report": classification_report(y_test, y_pred, output_dict=True),
            "model": model
        }
src/imbalance_handling.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union, Dict, Optional, Any, Callable
2
+ import numpy as np
3
+ import pandas as pd
4
+ from collections import Counter
5
+
6
+
7
def compute_class_weights(y: Union[List, np.ndarray], method: str = "balanced") -> Union[Dict[int, float], None]:
    """Per-class weights inversely proportional to class frequency.

    Implements scikit-learn's documented 'balanced' heuristic directly:
    weight(c) = n_samples / (n_classes * count(c)), with classes in sorted
    (np.unique) order -- identical values to
    sklearn.utils.class_weight.compute_class_weight('balanced', ...),
    but without requiring sklearn at call time.

    Returns:
        Mapping class -> weight for method='balanced'; None otherwise.
    """
    if method != "balanced":
        return None
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts.astype(float))
    return dict(zip(classes, weights))
15
+
16
+
17
def get_pytorch_weighted_loss(class_weights: Optional[Dict[int, float]] = None,
                              num_classes: Optional[int] = None) -> 'torch.nn.Module':
    """CrossEntropyLoss, optionally weighted per class.

    Args:
        class_weights: Mapping class index -> weight; the weight tensor
            follows sorted class-index order.
        num_classes: Unused; kept for backward compatibility with callers.

    Raises:
        ImportError: If PyTorch is not available.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError as err:
        # Chain the original error so the underlying failure stays visible.
        raise ImportError("PyTorch not installed") from err

    if class_weights is not None:
        ordered = [class_weights[c] for c in sorted(class_weights.keys())]
        weight_tensor = torch.tensor(ordered, dtype=torch.float)
        return nn.CrossEntropyLoss(weight=weight_tensor)
    else:
        return nn.CrossEntropyLoss()
30
+
31
+
32
def get_tensorflow_weighted_loss(class_weights: Optional[Dict[int, float]] = None) -> Callable:
    """Sparse categorical cross-entropy, optionally weighted per class.

    Without weights this returns the Keras loss identifier string; with
    weights it returns a closure that scales each sample's loss by the
    weight of its true class.
    """
    if not class_weights:
        # Keras resolves this identifier to the standard unweighted loss.
        return 'sparse_categorical_crossentropy'

    ordered_weights = [class_weights[c] for c in sorted(class_weights.keys())]

    import tensorflow as tf

    def weighted_sparse_categorical_crossentropy(y_true, y_pred):
        labels = tf.cast(y_true, tf.int32)
        # One weight per sample, picked out via a one-hot mask.
        one_hot = tf.one_hot(labels, depth=len(ordered_weights))
        sample_weights = tf.reduce_sum(one_hot * ordered_weights, axis=1)
        base_losses = tf.keras.losses.sparse_categorical_crossentropy(labels, y_pred)
        return tf.reduce_mean(base_losses * sample_weights)

    return weighted_sparse_categorical_crossentropy
49
+
50
+
51
def apply_sampling(
    X: np.ndarray,
    y: np.ndarray,
    method: str = "random_under",
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Re-balance (X, y) with an imbalanced-learn sampler.

    Args:
        X: Feature matrix.
        y: Labels.
        method: 'random_under', 'random_over', 'smote' or 'adasyn'.
        random_state: Seed for reproducibility.

    Raises:
        ValueError: For an unknown method. Checked *before* importing
            imbalanced-learn so bad arguments fail fast even when the
            optional dependency is missing.
        ImportError: If imbalanced-learn is not installed.
    """
    if method not in ("random_under", "random_over", "smote", "adasyn"):
        raise ValueError("method must be one of: random_under, random_over, smote, adasyn")

    from imblearn.over_sampling import SMOTE, ADASYN, RandomOverSampler
    from imblearn.under_sampling import RandomUnderSampler

    sampler_classes = {
        "random_under": RandomUnderSampler,
        "random_over": RandomOverSampler,
        "smote": SMOTE,
        "adasyn": ADASYN,
    }
    sampler = sampler_classes[method](random_state=random_state)
    X_res, y_res = sampler.fit_resample(X, y)
    return X_res, y_res
74
+
75
+
76
def augment_texts(
    texts: List[str],
    labels: List[Any],
    augmentation_type: str = "synonym",
    aug_p: float = 0.1,
    lang: str = "ru",  # language code
    model_name: Optional[str] = None,
    num_aug: int = 1,
    random_state: int = 42
) -> Tuple[List[str], List[Any]]:
    """Generate augmented copies of labelled texts with nlpaug.

    Supported augmentation_type values: 'synonym', 'insert', 'delete',
    'swap', 'eda', 'back_trans', 'llm'. Returns only the augmented texts
    and their labels (the originals are NOT included); each input text
    yields num_aug augmented variants.

    NOTE(review): random_state is accepted but never used -- nlpaug's own
    randomness is not seeded here; confirm whether reproducibility matters
    to callers.

    Raises:
        ImportError: If nlpaug is not installed.
        NotImplementedError: For augmentation_type='llm'.
        ValueError: For an unknown augmentation_type.
    """
    try:
        import nlpaug.augmenter.word as naw
        import nlpaug.augmenter.sentence as nas
    except ImportError:
        raise ImportError("Install nlpaug: pip install nlpaug")

    augmented_texts = []
    augmented_labels = []

    if augmentation_type == "synonym":
        if lang == "en":
            # WordNet-based synonym replacement (English only).
            aug = naw.SynonymAug(aug_p=aug_p, aug_max=None)
        else:
            # Non-English: contextual substitution via multilingual BERT
            # instead of a synonym dictionary.
            aug = naw.ContextualWordEmbsAug(
                model_path='bert-base-multilingual-cased',
                action="substitute",
                aug_p=aug_p,
                device='cpu'
            )
    elif augmentation_type == "insert":
        aug = naw.RandomWordAug(action="insert", aug_p=aug_p)
    elif augmentation_type == "delete":
        aug = naw.RandomWordAug(action="delete", aug_p=aug_p)
    elif augmentation_type == "swap":
        aug = naw.RandomWordAug(action="swap", aug_p=aug_p)
    elif augmentation_type == "eda":
        # NOTE(review): labelled 'eda' but only applies antonym replacement,
        # not the full Easy-Data-Augmentation recipe -- confirm intent.
        aug = naw.AntonymAug()
    elif augmentation_type == "back_trans":
        # Pick a translation model pair; default is ru<->en round-trip.
        if not model_name:
            if lang == "ru":
                model_name = "Helsinki-NLP/opus-mt-ru-en"
                back_model = "Helsinki-NLP/opus-mt-en-ru"
            else:
                model_name = "Helsinki-NLP/opus-mt-en-ru"
                back_model = "Helsinki-NLP/opus-mt-ru-en"
        else:
            # NOTE(review): with an explicit model_name the same model is
            # used both ways, which is only correct if it translates in
            # both directions -- confirm against callers.
            back_model = model_name

        try:
            from transformers import pipeline
            translator1 = pipeline("translation", model=model_name, tokenizer=model_name)
            translator2 = pipeline("translation", model=back_model, tokenizer=back_model)

            def back_translate(text):
                # Best-effort: any per-text failure returns the original text.
                try:
                    trans = translator1(text)[0]['translation_text']
                    back = translator2(trans)[0]['translation_text']
                    return back
                except Exception:
                    return text

            # Back-translation returns here directly; the generic loop
            # below is not reached on the success path.
            augmented = [back_translate(t) for t in texts for _ in range(num_aug)]
            labels_aug = [l for l in labels for _ in range(num_aug)]
            return augmented, labels_aug
        except Exception as e:
            # Pipeline construction failed (e.g. no model download):
            # fall through to contextual substitution instead.
            print(f"Back-translation failed: {e}. Falling back to synonym augmentation.")
            aug = naw.ContextualWordEmbsAug(model_path='bert-base-multilingual-cased', aug_p=aug_p)
    elif augmentation_type == "llm":
        raise NotImplementedError("LLM-controlled augmentation requires external API (e.g., OpenAI, YandexGPT)")
    else:
        raise ValueError("Unknown augmentation_type")

    # Generic path: apply the selected augmenter num_aug times per text.
    for text, label in zip(texts, labels):
        for _ in range(num_aug):
            try:
                aug_text = aug.augment(text)
                # Newer nlpaug versions return a list even for one input.
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
            except Exception as e:
                # Keep dataset size consistent: fall back to the original text.
                augmented_texts.append(text)
                augmented_labels.append(label)

    return augmented_texts, augmented_labels
161
+
162
+
163
def balance_text_dataset(
    texts: List[str],
    labels: List[Any],
    strategy: str = "augmentation",
    minority_classes: Optional[List[Any]] = None,
    augmentation_type: str = "synonym",
    sampling_method: str = "smote",
    lang: str = "ru",
    embedding_func: Optional[Callable] = None,
    class_weights: bool = False,
    random_state: int = 42
) -> Union[
    Tuple[List[str], List[Any]],  # for augmentation
    Tuple[np.ndarray, np.ndarray, Optional[Dict]]  # for sampling + weights
]:
    """Balance a labelled text dataset by augmentation, resampling, or both.

    Strategies:
        'augmentation': augment minority-class texts and append them.
        'sampling': embed texts with embedding_func, then resample vectors.
        'both': augment first, then (if embedding_func given) resample.

    Raises:
        ValueError: For an unknown strategy, or strategy='sampling'
            without an embedding_func.
    """
    label_counts = Counter(labels)
    # Smallest relevant class size is needed both to auto-detect minority
    # classes and to size the augmentation factor below. Previously it was
    # only computed when minority_classes was None, which raised a
    # NameError whenever callers passed minority_classes explicitly.
    if minority_classes is None:
        min_count = min(label_counts.values())
        minority_classes = [lbl for lbl, cnt in label_counts.items() if cnt == min_count]
    else:
        # Counter returns 0 for unseen labels; clamp to 1 to avoid a
        # division by zero for a minority class absent from `labels`.
        min_count = max(1, min(label_counts[lbl] for lbl in minority_classes))

    if strategy == "augmentation":
        minority_texts = [t for t, l in zip(texts, labels) if l in minority_classes]
        minority_labels = [l for l in labels if l in minority_classes]

        aug_texts, aug_labels = augment_texts(
            minority_texts, minority_labels,
            augmentation_type=augmentation_type,
            lang=lang,
            # Roughly close the gap to the majority class size.
            num_aug=max(1, int((max(label_counts.values()) / min_count)) - 1),
            random_state=random_state
        )

        balanced_texts = texts + aug_texts
        balanced_labels = labels + aug_labels
        return balanced_texts, balanced_labels

    elif strategy == "sampling":
        if embedding_func is None:
            raise ValueError("embedding_func is required for sampling strategy")
        X_embed = np.array([embedding_func(t) for t in texts])
        X_res, y_res = apply_sampling(X_embed, np.array(labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights

    elif strategy == "both":
        aug_texts, aug_labels = balance_text_dataset(
            texts, labels, strategy="augmentation", minority_classes=minority_classes,
            augmentation_type=augmentation_type, lang=lang, random_state=random_state
        )
        if embedding_func is None:
            # No embeddings available: return the augmented text dataset as-is.
            return aug_texts, aug_labels
        X_embed = np.array([embedding_func(t) for t in aug_texts])
        X_res, y_res = apply_sampling(X_embed, np.array(aug_labels), method=sampling_method, random_state=random_state)
        weights = compute_class_weights(y_res) if class_weights else None
        return X_res, y_res, weights
    else:
        raise ValueError("strategy must be 'augmentation', 'sampling', or 'both'")
src/main.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import json
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from typing import List, Dict, Any, Union
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
10
+ import shap
11
+
12
# Streamlit requires set_page_config to be the first st.* call in the app,
# which is why it precedes the project-local imports below.
st.set_page_config(
    page_title="Text Classifiers",
    layout="wide",
    initial_sidebar_state="expanded"
)
17
+
18
+ from text_preprocessing import (
19
+ preprocess_text, get_contextual_embeddings, TextVectorizer
20
+ )
21
+ from classical_classifiers import (
22
+ get_logistic_regression, get_svm_linear, get_random_forest,
23
+ get_gradient_boosting, get_voting_classifier
24
+ )
25
+ from neural_classifiers import get_transformer_classifier
26
+ from model_evaluation import evaluate_model
27
+ from model_interpretation import (
28
+ get_linear_feature_importance,
29
+ analyze_errors,
30
+ get_transformer_attention,
31
+ visualize_attention_weights,
32
+ get_token_importance_captum,
33
+ plot_token_importance
34
+ )
35
+
36
+ import warnings
37
+
38
+ warnings.filterwarnings("ignore")
39
+
40
# Initialise every st.session_state key the app reads, so widget-triggered
# reruns always see consistent defaults.
_SESSION_DEFAULTS = {
    'models': {},
    'results': {},
    'dataset': None,
    'task_type': None,
    'preprocessed': None,
    'X': None,
    'y': None,
    'feature_names': None,
    'vectorizer': None,
    'vectorizer_type': None,
    'X_test': None,
    'y_test': None,
    'test_texts': None,
    'label_encoder': None,
    'rubert_model': None,
    'rubert_tokenizer': None,
    'rubert_trained': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
74
+
75
st.sidebar.title("Setup")

st.sidebar.subheader("1. Upload Dataset (JSONL)")
uploaded_file = st.sidebar.file_uploader("Upload .jsonl file", type=["jsonl"])

if uploaded_file:
    try:
        # Parse the upload as JSON Lines: one JSON object per non-blank line.
        raw_data = []
        lines = uploaded_file.getvalue().decode("utf-8").splitlines()
        for line in lines:
            if line.strip():
                raw_data.append(json.loads(line))
        st.session_state.dataset = raw_data

        # Infer the task type from whichever label field the FIRST record
        # carries ('sentiment' -> binary, 'category' -> multiclass,
        # 'tags' -> multilabel).
        first = raw_data[0]
        if 'sentiment' in first:
            st.session_state.task_type = "binary"
            # NOTE(review): these `labels` lists are built but never used
            # afterwards -- labels are re-extracted from the dataset later.
            labels = [item['sentiment'] for item in raw_data]
        elif 'category' in first:
            st.session_state.task_type = "multiclass"
            labels = [item['category'] for item in raw_data]
        elif 'tags' in first:
            st.session_state.task_type = "multilabel"
            labels = [item['tags'] for item in raw_data]
        else:
            st.sidebar.error("No label field found")
            st.session_state.task_type = None
            st.session_state.dataset = None

        if st.session_state.task_type:
            st.sidebar.success(f"Loaded {len(raw_data)} samples. Task: {st.session_state.task_type}")
            # Fixed id<->label maps for the two single-label tasks;
            # multilabel has no such mapping.
            if st.session_state.task_type == "binary":
                id2label = {0: "Negative", 1: "Positive"}
                label2id = {"Negative": 0, "Positive": 1}
            elif st.session_state.task_type == "multiclass":
                id2label = {0: "Политика", 1: "Экономика", 2: "Спорт", 3: "Культура"}
                label2id = {"Политика": 0, "Экономика": 1, "Спорт": 2, "Культура": 3}
            else:
                id2label = None
                label2id = None

            st.session_state.id2label = id2label
            st.session_state.label2id = label2id
    except Exception as e:
        # Any parse/shape failure invalidates the whole upload.
        st.sidebar.error(f"Failed to parse JSONL: {e}")
        st.session_state.dataset = None
121
+
122
if st.session_state.dataset is not None:
    st.sidebar.subheader("2. Preprocess Text")
    lang = st.sidebar.selectbox("Language", ["ru", "en"], index=0)
    # Honour the user's language choice; previously 'ru' was hard-coded in
    # both the session write and the preprocess_text call, which made the
    # Language selectbox a no-op.
    st.session_state.preprocess_lang = lang
    if st.sidebar.button("Run Preprocessing"):
        with st.spinner("Preprocessing..."):
            texts = [item['text'] for item in st.session_state.dataset]
            preprocessed = [preprocess_text(text, lang=lang, remove_stopwords=False) for text in texts]
            st.session_state.preprocessed = preprocessed
            st.sidebar.success("Preprocessing done!")
132
+
133
if st.session_state.preprocessed is not None:
    st.sidebar.subheader("3. Vectorization (Classical)")
    vectorizer_type = st.sidebar.selectbox("Method", ["TF-IDF", "RuBERT Embeddings"])
    if st.sidebar.button("Vectorize"):
        with st.spinner("Vectorizing..."):
            if vectorizer_type == "TF-IDF":
                vectorizer = TextVectorizer()
                # Preprocessing may have produced token lists; TF-IDF needs
                # whitespace-joined strings.
                if not isinstance(st.session_state.preprocessed[0], str):
                    st.session_state.preprocessed = [
                        ' '.join(text) for text in st.session_state.preprocessed
                    ]
                st.sidebar.write("Using max_features=5000")
                X = vectorizer.tfidf(st.session_state.preprocessed, max_features=5000)
                st.sidebar.write(f"X shape: {X.shape}")
                st.session_state.vectorizer = vectorizer
                st.session_state.feature_names = vectorizer.tfidf_vectorizer.get_feature_names_out()
            else:
                # Contextual embeddings, one text at a time.
                # NOTE(review): calling get_contextual_embeddings per text
                # presumably reloads/re-runs the model each call -- batching
                # the whole list would likely be much faster; confirm.
                X = []
                for text in st.session_state.preprocessed:
                    emb = get_contextual_embeddings([text], model_name="DeepPavlov/rubert-base-cased")
                    X.append(emb[0])
                X = np.array(X)
                st.session_state.vectorizer = None
                st.session_state.feature_names = None
            st.session_state.X = X
            st.session_state.vectorizer_type = vectorizer_type

            # Extract labels matching the detected task type.
            if st.session_state.task_type == "binary":
                y = np.array([item['sentiment'] for item in st.session_state.dataset])
            elif st.session_state.task_type == "multiclass":
                y = np.array([item['category'] for item in st.session_state.dataset])
            else:
                # Multilabel: keep as list of tag lists (ragged).
                y = [item['tags'] for item in st.session_state.dataset]
            st.session_state.y = y
            st.sidebar.success("Vectorization complete!")
168
+
169
if st.session_state.X is not None:
    st.sidebar.subheader("4. Train Classical Models")
    model_options = ["Logistic Regression", "SVM", "Random Forest", "XGBoost", "Voting"]
    selected_models = st.sidebar.multiselect("Models", model_options)
    if st.sidebar.button("Train Classical Models"):
        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import LabelEncoder

        X = st.session_state.X
        y = st.session_state.y

        # Build a label vector usable for a stratified split. For
        # multilabel data the tag-count is used only as a split key.
        if st.session_state.task_type == "multiclass":
            le = LabelEncoder()
            y_encoded = le.fit_transform(y)
            st.session_state.label_encoder = le
            y_for_split = y_encoded
        else:
            y_for_split = y if st.session_state.task_type == "binary" else np.array([len(tags) for tags in y])

        if st.session_state.task_type == "multilabel":
            # Plain 80/20 ordered split; stratification is undefined for
            # ragged multilabel targets.
            split_idx = int(0.8 * len(X))
            X_train, X_test = X[:split_idx], X[split_idx:]
            y_train, y_test = y[:split_idx], y[split_idx:]
            test_texts = [item['text'] for item in st.session_state.dataset[split_idx:]]
        else:
            # Track indices so the raw test texts can be recovered for
            # later error analysis.
            indices = np.arange(len(X))
            X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
                X, y_for_split, indices, test_size=0.2,
                stratify=y_for_split if st.session_state.task_type != "multilabel" else None,
                random_state=42
            )
            test_texts = [st.session_state.dataset[i]['text'] for i in idx_test]
            if st.session_state.task_type == "multiclass":
                # Train on the original string labels.
                # NOTE(review): XGBoost generally requires numeric class
                # labels -- confirm the XGBoost branch works for multiclass.
                y_train = le.inverse_transform(y_train)
                y_test = le.inverse_transform(y_test)

        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.test_texts = test_texts

        # Train each selected model independently; one failure must not
        # abort the rest.
        for name in selected_models:
            try:
                with st.spinner(f"Training {name}..."):
                    if name == "Logistic Regression":
                        model = get_logistic_regression()
                        model.fit(X_train, y_train)
                        st.session_state.models[name] = model
                    elif name == "SVM":
                        model = get_svm_linear()
                        model.fit(X_train, y_train)
                        st.session_state.models[name] = model
                    elif name == "Random Forest":
                        model = get_random_forest()
                        model.fit(X_train, y_train)
                        st.session_state.models[name] = model
                    elif name == "XGBoost":
                        model = get_gradient_boosting("xgb", n_estimators=100)
                        model.fit(X_train, y_train)
                        st.session_state.models[name] = model
                    elif name == "Voting":
                        model = get_voting_classifier()
                        model.fit(X_train, y_train)
                        st.session_state.models[name] = model

                    # Metrics are only well-defined for single-label tasks here.
                    if st.session_state.task_type != "multilabel":
                        metrics = evaluate_model(model, X_test, y_test)
                        st.session_state.results[name] = metrics
            except Exception as e:
                st.sidebar.error(f"Failed to train {name}: {e}")
                continue
        st.sidebar.success("Classical models trained!")
240
+
241
if st.session_state.dataset is not None and st.session_state.task_type in ["binary", "multiclass"]:
    st.sidebar.subheader("5. Train RuBERT (Transformer)")
    if st.sidebar.button("Train RuBERT"):
        with st.spinner("Loading RuBERT..."):
            try:
                from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

                # NOTE(review): despite the button label, this only LOADS a
                # pretrained encoder with a freshly initialised
                # classification head -- no fine-tuning happens here.
                num_labels = 2 if st.session_state.task_type == "binary" else 4
                model_name = "DeepPavlov/rubert-base-cased"

                # Bake the human-readable label maps into the model config
                # so pipeline() outputs named labels instead of LABEL_n.
                config = AutoConfig.from_pretrained(
                    model_name,
                    num_labels=num_labels,
                    id2label=st.session_state.id2label,
                    label2id=st.session_state.label2id
                )

                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)

                st.session_state.rubert_model = model
                st.session_state.rubert_tokenizer = tokenizer
                st.session_state.rubert_trained = True
                st.sidebar.success("RuBERT loaded with correct labels!")
            except Exception as e:
                st.sidebar.error(f"RuBERT loading failed: {e}")
                st.exception(e)
268
+
269
st.title("Text Classifiers")

# Main content area: one tab per workflow stage.
tab1, tab2, tab3, tab4 = st.tabs([
    "Classify",
    "Interpret",
    "Compare",
    "Error Analysis"
])
277
+
278
with tab1:
    st.subheader("Classify New Text")
    input_text = st.text_area("Enter text", "Сегодня прошёл важный матч по хоккею.")

    if st.button("Classify"):
        # Side-by-side: classical model predictions vs RuBERT.
        cols = st.columns(2)
        with cols[0]:
            st.markdown("### Classical Models")
            if not st.session_state.models:
                st.info("No classical models trained")
            else:
                # Re-apply the same preprocessing + vectorization pipeline
                # used during training so features line up.
                tokens = preprocess_text(input_text, lang='ru', remove_stopwords=False)
                preprocessed = " ".join(tokens)
                if st.session_state.vectorizer_type == "TF-IDF":
                    X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
                else:
                    X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")

                for name, model in st.session_state.models.items():
                    pred = model.predict(X_input)[0]
                    st.write(f"**{name}**: {pred}")
                    # Hard-voting ensembles have no predict_proba.
                    if hasattr(model, "predict_proba"):
                        proba = model.predict_proba(X_input)[0]
                        st.write(f"Probabilities: {dict(zip(model.classes_, proba))}")

        with cols[1]:
            st.markdown("### RuBERT")
            if not st.session_state.rubert_trained:
                st.info("Train RuBERT in sidebar")
            else:
                try:
                    from transformers import pipeline

                    # device=-1 forces CPU inference.
                    pipe = pipeline(
                        "text-classification",
                        model=st.session_state.rubert_model,
                        tokenizer=st.session_state.rubert_tokenizer,
                        device=-1
                    )
                    result = pipe(input_text)
                    label = result[0]['label']
                    confidence = result[0]['score']

                    # Map generic LABEL_n outputs back to human-readable
                    # names when the config lacks id2label entries.
                    if label.startswith("LABEL_") and st.session_state.id2label:
                        label_id = int(label.replace("LABEL_", ""))
                        readable_label = st.session_state.id2label.get(label_id, label)
                    else:
                        readable_label = label

                    st.write(f"**Prediction**: {readable_label}")
                    st.write(f"**Confidence**: {confidence:.3f}")
                except Exception as e:
                    st.error(f"RuBERT inference failed: {e}")
331
+
332
+ with tab2:
333
+ subtab1, subtab2, subtab3 = st.tabs(["SHAP / LIME", "Attention Map", "Captum Heatmap"])
334
+
335
+ with subtab1:
336
+ st.subheader("SHAP: Local Explanation for One Text")
337
+ if not st.session_state.models:
338
+ st.info("Train a classical model first")
339
+ else:
340
+ model_name = st.selectbox("Model", list(st.session_state.models.keys()), key="shap_model")
341
+ text_for_explain = st.text_area("Text to explain", "Прекрасная новость о росте экономики!", key="shap_text")
342
+ top_k = st.slider("Top features to show", 5, 30, 15)
343
+
344
+ if st.button("Explain with SHAP"):
345
+ try:
346
+ import shap
347
+
348
+ model = st.session_state.models[model_name]
349
+ tokens = preprocess_text(text_for_explain, lang='ru', remove_stopwords=False)
350
+ preprocessed = " ".join(tokens)
351
+
352
+ if st.session_state.vectorizer_type == "TF-IDF":
353
+ X_input = st.session_state.vectorizer.tfidf_vectorizer.transform([preprocessed]).toarray()
354
+ feature_names = st.session_state.feature_names
355
+ else:
356
+ X_input = get_contextual_embeddings([preprocessed], model_name="DeepPavlov/rubert-base-cased")
357
+ feature_names = [f"emb_{i}" for i in range(X_input.shape[1])]
358
+
359
+ background = st.session_state.X[:100]
360
+ # st.write(f"DEBUG: st.session_state.X shape = {st.session_state.X.shape}")
361
+ # st.write(f"DEBUG: X_input shape = {X_input.shape}")
362
+ # st.write(f'DEBUG: background shape = {background.shape}')
363
+ if "tree" in str(type(model)).lower():
364
+ explainer = shap.TreeExplainer(model)
365
+ shap_values = explainer.shap_values(X_input)
366
+ else:
367
+ explainer = shap.KernelExplainer(model.predict_proba, background)
368
+ shap_values = explainer.shap_values(X_input, nsamples=200)
369
+
370
+ if isinstance(shap_values, list):
371
+ probs = model.predict_proba(X_input)[0]
372
+ target_class = int(np.argmax(probs))
373
+ single_shap = shap_values[target_class][0]
374
+ expected_val = explainer.expected_value[target_class]
375
+ else:
376
+ sv = shap_values
377
+ if sv.ndim == 1:
378
+ single_shap = sv
379
+ expected_val = explainer.expected_value
380
+ elif sv.ndim == 2:
381
+ if sv.shape[0] == 1:
382
+ single_shap = sv[0]
383
+ expected_val = explainer.expected_value
384
+ elif sv.shape[1] == X_input.shape[1]:
385
+ probs = model.predict_proba(X_input)[0]
386
+ target_class = int(np.argmax(probs))
387
+ single_shap = sv[:, target_class]
388
+ expected_val = explainer.expected_value[target_class] if isinstance(
389
+ explainer.expected_value, (list, np.ndarray)) else explainer.expected_value
390
+ else:
391
+ single_shap = sv[0]
392
+ expected_val = explainer.expected_value
393
+ elif sv.ndim == 3:
394
+ if sv.shape[0] != 1:
395
+ raise ValueError("SHAP explanation for more than one sample not supported")
396
+ probs = model.predict_proba(X_input)[0]
397
+ target_class = int(np.argmax(probs))
398
+ single_shap = sv[0, :, target_class]
399
+ if isinstance(explainer.expected_value, (list, np.ndarray)) and len(
400
+ explainer.expected_value) == sv.shape[2]:
401
+ expected_val = explainer.expected_value[target_class]
402
+ else:
403
+ expected_val = explainer.expected_value
404
+ else:
405
+ raise ValueError(f"Unsupported SHAP shape: {sv.shape}")
406
+
407
+ single_shap = np.array(single_shap).flatten()
408
+ if single_shap.shape[0] != X_input.shape[1]:
409
+ raise ValueError(
410
+ f"SHAP vector length {single_shap.shape[0]} != input features {X_input.shape[1]}")
411
+
412
+ if st.session_state.vectorizer_type == "TF-IDF":
413
+ text_vector = X_input[0]
414
+ nonzero_indices = np.where(text_vector != 0)[0]
415
+ if len(nonzero_indices) == 0:
416
+ st.warning("No known words from training vocabulary found in this text.")
417
+ else:
418
+ filtered_shap = single_shap[nonzero_indices]
419
+ filtered_features = text_vector[nonzero_indices]
420
+ filtered_names = [st.session_state.feature_names[i] for i in nonzero_indices]
421
+
422
+ explanation = shap.Explanation(
423
+ values=filtered_shap,
424
+ base_values=expected_val,
425
+ data=filtered_features,
426
+ feature_names=filtered_names
427
+ )
428
+
429
+ plt.figure(figsize=(10, min(8, top_k * 0.3)))
430
+ shap.plots.waterfall(explanation, max_display=top_k, show=False)
431
+ st.pyplot(plt.gcf())
432
+ plt.close()
433
+ else:
434
+ explanation = shap.Explanation(
435
+ values=single_shap,
436
+ base_values=expected_val,
437
+ data=X_input[0],
438
+ feature_names=feature_names
439
+ )
440
+ plt.figure(figsize=(10, min(8, top_k * 0.3)))
441
+ shap.plots.waterfall(explanation, max_display=top_k, show=False)
442
+ st.pyplot(plt.gcf())
443
+ plt.close()
444
+
445
+ except Exception as e:
446
+ st.error(f"SHAP error: {e}")
447
+ st.exception(e)
448
+
449
+ with subtab2:
450
+ st.subheader("Transformer Attention Map")
451
+ if not st.session_state.rubert_trained:
452
+ st.info("Train RuBERT first")
453
+ else:
454
+ text_att = st.text_area("Text for attention", "Матч завершился победой ЦСКА", key="att_text")
455
+ layer = st.slider("Layer", 0, 11, 6)
456
+ head = st.slider("Head", 0, 11, 0)
457
+ if st.button("Visualize Attention"):
458
+ try:
459
+ tokens, attn = get_transformer_attention(
460
+ st.session_state.rubert_model,
461
+ st.session_state.rubert_tokenizer,
462
+ text_att,
463
+ device="cpu"
464
+ )
465
+ weights = attn[layer, head, :len(tokens), :len(tokens)]
466
+
467
+ fig, ax = plt.subplots(figsize=(10, 4))
468
+ sns.heatmap(
469
+ weights,
470
+ xticklabels=tokens,
471
+ yticklabels=tokens,
472
+ cmap="viridis",
473
+ ax=ax
474
+ )
475
+ plt.xticks(rotation=45, ha="right")
476
+ plt.yticks(rotation=0)
477
+ plt.title(f"Attention: Layer {layer}, Head {head}")
478
+ st.pyplot(fig)
479
+ plt.close(fig)
480
+ except Exception as e:
481
+ st.error(f"Attention failed: {e}")
482
+ st.exception(e)
483
+
484
+ with subtab3:
485
+ st.subheader("Token Importance (Captum)")
486
+ if not st.session_state.rubert_trained:
487
+ st.info("Train RuBERT first")
488
+ else:
489
+ text_captum = st.text_area("Text for Captum", "Это очень плохая новость для политики", key="captum_text")
490
+ method = "IntegratedGradients"
491
+ if st.button("Compute Token Importance"):
492
+ try:
493
+ tokens, importance = get_token_importance_captum(
494
+ st.session_state.rubert_model,
495
+ st.session_state.rubert_tokenizer,
496
+ text_captum,
497
+ device="cpu"
498
+ )
499
+ valid = [(t, imp) for t, imp in zip(tokens, importance) if t not in ["[CLS]", "[SEP]", "[PAD]"]]
500
+ if valid:
501
+ tokens_clean, imp_clean = zip(*valid)
502
+ indices = np.argsort(np.abs(imp_clean))[-15:][::-1]
503
+ tokens_top = [tokens_clean[i] for i in indices]
504
+ imp_top = [imp_clean[i] for i in indices]
505
+
506
+ fig, ax = plt.subplots(figsize=(8, 6))
507
+ colors = ["red" if x < 0 else "green" for x in imp_top]
508
+ ax.barh(range(len(imp_top)), imp_top, color=colors)
509
+ ax.set_yticks(range(len(imp_top)))
510
+ ax.set_yticklabels(tokens_top)
511
+ ax.invert_yaxis()
512
+ ax.set_xlabel("Attribution Score")
513
+ ax.set_title("Token Importance")
514
+ st.pyplot(fig)
515
+ plt.close(fig)
516
+ else:
517
+ st.warning("No valid tokens")
518
+ except Exception as e:
519
+ st.error(f"Captum failed: {e}")
520
+ st.exception(e)
521
+
522
+ with tab3:
523
+ st.subheader("Model Comparison")
524
+ if st.session_state.results:
525
+ df = pd.DataFrame(st.session_state.results).T
526
+ st.dataframe(df)
527
+ else:
528
+ st.info("Train models to see metrics")
529
+
530
+ with tab4:
531
+ st.subheader("Error Analysis")
532
+ if st.session_state.X_test is None:
533
+ st.info("Train models first")
534
+ else:
535
+ model_name = st.selectbox("Model for error analysis", list(st.session_state.models.keys()), key="err_model")
536
+ if st.button("Analyze Errors"):
537
+ model = st.session_state.models[model_name]
538
+ y_pred = model.predict(st.session_state.X_test)
539
+ errors = analyze_errors(
540
+ st.session_state.y_test,
541
+ y_pred,
542
+ st.session_state.test_texts
543
+ )
544
+ st.dataframe(errors[['text', 'true_label', 'pred_label']].head(20))
src/model_evaluation.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, Union, Callable, Optional, Tuple, List
2
+ import numpy as np
3
+ import pandas as pd
4
+ from collections import defaultdict
5
+ import torch
6
+
7
+ from sklearn.model_selection import (
8
+ StratifiedKFold, GroupKFold, TimeSeriesSplit,
9
+ GridSearchCV, RandomizedSearchCV
10
+ )
11
+ from sklearn.metrics import (
12
+ accuracy_score, precision_score, recall_score, f1_score,
13
+ roc_auc_score, average_precision_score, log_loss,
14
+ confusion_matrix, classification_report
15
+ )
16
+ from sklearn.base import BaseEstimator
17
+ import warnings
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+ OPTUNA_AVAILABLE = False
22
+ HYPEROPT_AVAILABLE = False
23
+ try:
24
+ import optuna
25
+ from optuna.samplers import TPESampler
26
+
27
+ OPTUNA_AVAILABLE = True
28
+ except ImportError:
29
+ pass
30
+
31
+ try:
32
+ from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
33
+
34
+ HYPEROPT_AVAILABLE = True
35
+ except ImportError:
36
+ pass
37
+
38
+ WANDB_AVAILABLE = False
39
+ try:
40
+ import wandb
41
+
42
+ WANDB_AVAILABLE = True
43
+ except ImportError:
44
+ pass
45
+
46
+
47
def get_cv_splitter(
    cv_type: str = "stratified",
    n_splits: int = 5,
    groups: Optional[np.ndarray] = None,
    random_state: int = 42
):
    """Build a scikit-learn cross-validation splitter.

    Args:
        cv_type: 'stratified' (shuffled StratifiedKFold), 'group'
            (GroupKFold, requires `groups`), or 'time' (TimeSeriesSplit).
        n_splits: number of folds.
        groups: group labels, only consulted for cv_type='group'.
        random_state: shuffle seed for the stratified splitter.

    Raises:
        ValueError: unknown cv_type, or 'group' without `groups`.
    """
    if cv_type == "time":
        return TimeSeriesSplit(n_splits=n_splits)
    if cv_type == "group":
        if groups is None:
            raise ValueError("groups must be provided for GroupKFold")
        return GroupKFold(n_splits=n_splits)
    if cv_type == "stratified":
        return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    raise ValueError("cv_type must be 'stratified', 'group', or 'time'")
63
+
64
+
65
def grid_search_cv(
    model: BaseEstimator,
    X: np.ndarray,
    y: np.ndarray,
    param_grid: Dict[str, List],
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    verbose: int = 1
) -> GridSearchCV:
    """Exhaustive hyperparameter search with cross-validation.

    Returns the fitted GridSearchCV object; inspect best_estimator_,
    best_params_ and cv_results_ on it. Uses all cores (n_jobs=-1).
    """
    splitter = get_cv_splitter(cv_type, n_splits, groups)
    searcher = GridSearchCV(
        estimator=model,
        param_grid=param_grid,
        cv=splitter,
        scoring=scoring,
        verbose=verbose,
        n_jobs=-1,
    )
    searcher.fit(X, y)
    return searcher
82
+
83
+
84
def random_search_cv(
    model: BaseEstimator,
    X: np.ndarray,
    y: np.ndarray,
    param_distributions: Dict[str, Any],
    n_iter: int = 20,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    verbose: int = 1
) -> RandomizedSearchCV:
    """Randomized hyperparameter search with cross-validation.

    Samples n_iter parameter settings from param_distributions (fixed
    seed 42 for reproducibility) and returns the fitted
    RandomizedSearchCV object. Uses all cores (n_jobs=-1).
    """
    splitter = get_cv_splitter(cv_type, n_splits, groups)
    searcher = RandomizedSearchCV(
        estimator=model,
        param_distributions=param_distributions,
        n_iter=n_iter,
        cv=splitter,
        scoring=scoring,
        verbose=verbose,
        n_jobs=-1,
        random_state=42,
    )
    searcher.fit(X, y)
    return searcher
103
+
104
+
105
def _optuna_objective(
    trial,
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    cv,
    scoring: str = "f1_macro"
) -> float:
    """Cross-validated objective for an Optuna study.

    The search space is chosen by inspecting model_fn.__name__: factories
    whose name contains "logistic" or "random_forest" get a built-in
    space; any other factory is called as model_fn(trial) and is expected
    to sample its own hyperparameters from the trial. Returns the mean
    fold score (higher is better).
    """
    # Heuristic dispatch on the factory's *name*, not its return type.
    if "logistic" in model_fn.__name__.lower():
        C = trial.suggest_float("C", 1e-4, 1e2, log=True)
        penalty = trial.suggest_categorical("penalty", ["l1", "l2"])
        # lbfgs does not support l1; liblinear does.
        solver = "liblinear" if penalty == "l1" else "lbfgs"
        model = model_fn(C=C, penalty=penalty, solver=solver)
    elif "random_forest" in model_fn.__name__.lower():
        n_estimators = trial.suggest_int("n_estimators", 50, 300)
        max_depth = trial.suggest_int("max_depth", 3, 20)
        model = model_fn(n_estimators=n_estimators, max_depth=max_depth)
    else:
        # Fallback: the factory itself defines the search space.
        model = model_fn(trial)

    scores = []
    for train_idx, val_idx in cv.split(X, y):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)
        if scoring == "f1_macro":
            score = f1_score(y_val, y_pred, average="macro")
        elif scoring == "roc_auc":
            # NOTE(review): [:, 1] assumes binary classification — confirm
            # callers never request roc_auc with more than two classes.
            y_proba = model.predict_proba(X_val)[:, 1]
            score = roc_auc_score(y_val, y_proba)
        else:
            raise ValueError(f"Scoring {scoring} not implemented in custom Optuna loop")
        scores.append(score)
    return np.mean(scores)
140
+
141
+
142
def optuna_tuning(
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    n_trials: int = 50,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None,
    direction: str = "maximize"
) -> "optuna.Study":
    """Run an Optuna TPE study over _optuna_objective.

    Returns the completed Study (best_params / best_value live on it).

    Fix: the return annotation was the bare name `optuna.Study`, which is
    evaluated at function-definition time and raised NameError on module
    import whenever Optuna was not installed — defeating the module's
    OPTUNA_AVAILABLE guard. The annotation is now a string (deferred), and
    a missing Optuna raises an explicit ImportError at call time instead.
    """
    if not OPTUNA_AVAILABLE:
        raise ImportError("Install Optuna: pip install optuna")
    cv = get_cv_splitter(cv_type, n_splits, groups)
    # Fixed sampler seed keeps studies reproducible across runs.
    study = optuna.create_study(direction=direction, sampler=TPESampler(seed=42))
    study.optimize(
        lambda trial: _optuna_objective(trial, model_fn, X, y, cv, scoring),
        n_trials=n_trials
    )
    return study
160
+
161
+
162
def hyperopt_tuning(
    model_fn: Callable,
    X: np.ndarray,
    y: np.ndarray,
    space: Dict,
    max_evals: int = 50,
    cv_type: str = "stratified",
    n_splits: int = 5,
    scoring: str = "f1_macro",
    groups: Optional[np.ndarray] = None
):
    """Hyperopt TPE search over `space`, maximising the mean CV score.

    Returns (best_params, trials).

    Fix: the objective previously appended the NEGATED fold score and then
    negated the mean again (loss = -mean(-score) = +mean(score)). Since
    hyperopt minimises `loss`, it was actively selecting the WORST
    hyperparameters. Raw scores are now accumulated and negated exactly
    once. Also raises a clear ImportError when hyperopt is missing instead
    of an opaque NameError.
    """
    if not HYPEROPT_AVAILABLE:
        raise ImportError("Install hyperopt: pip install hyperopt")
    cv = get_cv_splitter(cv_type, n_splits, groups)

    def objective(params):
        # Fresh model per parameter setting; refit on every fold.
        model = model_fn(**params)
        scores = []
        for train_idx, val_idx in cv.split(X, y):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            model.fit(X_train, y_train)
            y_pred = model.predict(X_val)
            if scoring == "f1_macro":
                score = f1_score(y_val, y_pred, average="macro")
            elif scoring == "roc_auc":
                y_proba = model.predict_proba(X_val)[:, 1]
                score = roc_auc_score(y_val, y_proba)
            else:
                score = -1
            scores.append(score)
        # hyperopt minimises `loss`; negate once so higher scores win.
        return {'loss': -np.mean(scores), 'status': STATUS_OK}

    trials = Trials()
    best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    return best, trials
196
+
197
+
198
def compute_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: Optional[np.ndarray] = None,
    average: str = "macro"
) -> Dict[str, float]:
    """Standard classification metrics, plus probability metrics when available.

    Binary problems (two unique labels in y_true) use the positive-class
    column y_proba[:, 1]; multiclass uses one-vs-rest ROC AUC.

    Fix: on the multiclass fallback path the old code left "log_loss"
    missing from the dict whenever roc_auc_score raised first, so the
    returned key set varied with the input. All three probability keys are
    now always present (NaN when not computable), and metrics that did
    succeed before the failure are kept rather than overwritten with NaN.
    """
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "f1": f1_score(y_true, y_pred, average=average, zero_division=0),
    }

    if y_proba is not None:
        if len(np.unique(y_true)) == 2:
            metrics["roc_auc"] = roc_auc_score(y_true, y_proba[:, 1])
            metrics["pr_auc"] = average_precision_score(y_true, y_proba[:, 1])
            metrics["log_loss"] = log_loss(y_true, y_proba)
        else:
            try:
                metrics["roc_auc"] = roc_auc_score(y_true, y_proba, multi_class="ovr", average=average)
                metrics["pr_auc"] = average_precision_score(y_true, y_proba, average=average)
                metrics["log_loss"] = log_loss(y_true, y_proba)
            except ValueError:
                # Keep whatever was computed; fill the rest with NaN so
                # the key set is stable across inputs.
                metrics.setdefault("roc_auc", np.nan)
                metrics.setdefault("pr_auc", np.nan)
                metrics.setdefault("log_loss", np.nan)

    return metrics
226
+
227
+
228
def evaluate_model(
    model: BaseEstimator,
    X_test: np.ndarray,
    y_test: np.ndarray,
    average: str = "macro",
    return_pred: bool = False
) -> Union[Dict[str, float], Tuple[Dict[str, float], np.ndarray, Optional[np.ndarray]]]:
    """Score a fitted classifier on a held-out split.

    Returns the metrics dict, or (metrics, y_pred, y_proba) when
    return_pred is True. y_proba is None for models that do not expose
    predict_proba.
    """
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test) if hasattr(model, "predict_proba") else None

    metrics = compute_classification_metrics(y_test, y_pred, y_proba, average=average)

    if not return_pred:
        return metrics
    return metrics, y_pred, y_proba
245
+
246
+
247
def get_early_stopping(
    monitor: str = "val_loss",
    patience: int = 5,
    mode: str = "min",
    framework: str = "keras"
):
    """Return training callbacks for Keras: [EarlyStopping, ReduceLROnPlateau].

    EarlyStopping restores the best weights; ReduceLROnPlateau halves the
    learning rate after 3 stagnant epochs (floor 1e-7). The 'pytorch'
    branch deliberately raises NotImplementedError — there is no callback
    object to hand back for a hand-written training loop.
    """
    if framework == "pytorch":
        raise NotImplementedError("PyTorch callbacks require custom training loop")
    if framework != "keras":
        raise ValueError("framework must be 'keras' or 'pytorch'")

    # Imported lazily so this module loads without TensorFlow installed.
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
    return [
        EarlyStopping(monitor=monitor, patience=patience, restore_best_weights=True, mode=mode),
        ReduceLROnPlateau(monitor=monitor, factor=0.5, patience=3, min_lr=1e-7, mode=mode),
    ]
262
+
263
+
264
def init_wandb(
    project_name: str = "text-classification",
    run_name: Optional[str] = None,
    config: Optional[Dict] = None
):
    """Start a Weights & Biases run.

    Returns the wandb module so callers can keep logging through it, or
    None when wandb is not installed (silent no-op by design).
    """
    if not WANDB_AVAILABLE:
        return None
    wandb.init(project=project_name, name=run_name, config=config)
    return wandb
273
+
274
+
275
def log_metrics_to_wandb(metrics: Dict[str, float]):
    """Send metrics to the active W&B run; no-op when wandb is missing or inactive."""
    if not WANDB_AVAILABLE:
        return
    if wandb.run:
        wandb.log(metrics)
278
+
279
+
280
def suggest_transformer_hparams(trial) -> Dict[str, Any]:
    """Sample a transformer fine-tuning configuration from an Optuna trial.

    Keys follow HuggingFace TrainingArguments naming (note the sampled
    "batch_size" is returned under "per_device_train_batch_size").
    """
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    batch = trial.suggest_categorical("batch_size", [8, 16, 32])
    epochs = trial.suggest_int("num_train_epochs", 2, 6)
    decay = trial.suggest_float("weight_decay", 0.0, 0.3)
    warmup = trial.suggest_float("warmup_ratio", 0.0, 0.2)
    return {
        "learning_rate": lr,
        "per_device_train_batch_size": batch,
        "num_train_epochs": epochs,
        "weight_decay": decay,
        "warmup_ratio": warmup,
    }
288
+
289
+
290
def evaluate_transformer_outputs(
    y_true: List[int],
    y_pred: List[int],
    y_logits: Optional[np.ndarray] = None
) -> Dict[str, float]:
    """Macro-averaged metrics for transformer predictions.

    Raw logits, when given, are softmaxed into class probabilities so the
    probability-based metrics (ROC/PR AUC, log loss) can be computed.
    """
    true_arr = np.array(y_true)
    pred_arr = np.array(y_pred)

    proba = None
    if y_logits is not None:
        proba = torch.softmax(torch.tensor(y_logits), dim=-1).numpy()

    return compute_classification_metrics(true_arr, pred_arr, proba, average="macro")
302
+
303
+
304
def confusion_matrix_df(y_true: np.ndarray, y_pred: np.ndarray, labels: Optional[List] = None) -> pd.DataFrame:
    """Confusion matrix as a DataFrame with True_*/Pred_* row/column labels.

    Fix: when labels is None, sklearn's confusion_matrix is built over the
    sorted UNION of labels appearing in y_true and y_pred, but the old
    code labelled the frame with unique(y_true) only — misaligning (or
    crashing) whenever y_pred contained a class absent from y_true. The
    label list is now resolved first and passed to confusion_matrix so the
    matrix and its labels always agree.
    """
    if labels is None:
        labels = sorted(np.unique(np.concatenate([np.asarray(y_true), np.asarray(y_pred)])))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(
        cm,
        index=[f"True_{l}" for l in labels],
        columns=[f"Pred_{l}" for l in labels],
    )
src/model_interpretation.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional, Union, Callable, Tuple
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from collections import defaultdict
7
+ import torch
8
+
9
+ from sklearn.base import BaseEstimator
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+ from sklearn.decomposition import PCA
12
+ from sklearn.manifold import TSNE
13
+ import warnings
14
+ warnings.filterwarnings("ignore")
15
+
16
+ SHAP_AVAILABLE = False
17
+ LIME_AVAILABLE = False
18
+ CAPTUM_AVAILABLE = False
19
+ UMAP_AVAILABLE = False
20
+
21
+ try:
22
+ import shap
23
+ SHAP_AVAILABLE = True
24
+ except ImportError:
25
+ pass
26
+
27
+ try:
28
+ import lime
29
+ import lime.lime_text
30
+ LIME_AVAILABLE = True
31
+ except ImportError:
32
+ pass
33
+
34
+ try:
35
+ import captum
36
+ import captum.attr
37
+ CAPTUM_AVAILABLE = True
38
+ except ImportError:
39
+ pass
40
+
41
+ try:
42
+ import umap
43
+ UMAP_AVAILABLE = True
44
+ except ImportError:
45
+ pass
46
+
47
+
48
def get_linear_feature_importance(
    model: BaseEstimator,
    feature_names: Optional[List[str]] = None,
    class_index: int = -1
) -> pd.DataFrame:
    """Rank a linear model's features by coefficient magnitude.

    For multiclass coef_ matrices, class_index == -1 averages the
    coefficients over classes; any other value selects that class's row.
    Returns a DataFrame with columns [feature, weight], sorted by |weight|
    descending.

    Raises:
        ValueError: the model has no coef_ attribute.
    """
    if not hasattr(model, "coef_"):
        raise ValueError("Model does not have coef_ attribute")

    coef = model.coef_
    if coef.ndim == 1:
        weights = coef
    elif class_index == -1:
        weights = np.mean(coef, axis=0)
    else:
        weights = coef[class_index]

    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(weights))]

    frame = pd.DataFrame({"feature": feature_names, "weight": weights})
    return frame.sort_values("weight", key=abs, ascending=False).reset_index(drop=True)
71
+
72
+
73
def analyze_tfidf_class_keywords(
    tfidf_matrix: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    top_k: int = 20
) -> Dict[Any, pd.DataFrame]:
    """Top TF-IDF terms per class, as {class: DataFrame[word, tfidf_score]}.

    Works for dense arrays AND scipy sparse matrices. Fix: the old code
    tested hasattr(tfidf_matrix, 'A1'), but scipy sparse matrices expose
    .A, not .A1 (.A1 lives on np.matrix) — so the "sparse" branch was
    never taken and np.mean over a sparse matrix returned a 2-D np.matrix,
    corrupting the argsort-based ranking. np.asarray(...).ravel()
    normalises every case to a flat 1-D vector.
    """
    results: Dict[Any, pd.DataFrame] = {}
    for cls in np.unique(y):
        mask = (y == cls)
        # .mean(axis=0) exists on both ndarray and scipy sparse; flatten
        # whatever comes back (1-D array, (1, n) matrix, ...) to 1-D.
        avg_tfidf = np.asarray(tfidf_matrix[mask].mean(axis=0)).ravel()
        top_indices = np.argsort(avg_tfidf)[::-1][:top_k]
        results[cls] = pd.DataFrame({
            "word": [feature_names[i] for i in top_indices],
            "tfidf_score": [avg_tfidf[i] for i in top_indices],
        })
    return results
91
+
92
+
93
def explain_with_shap(
    model: BaseEstimator,
    X_train: np.ndarray,
    X_test: np.ndarray,
    feature_names: Optional[List[str]] = None,
    plot_type: str = "bar",
    max_display: int = 20
):
    """Render a SHAP summary plot for (up to) the first 100 test rows.

    Tree models get the fast TreeExplainer; anything else falls back to
    KernelExplainer with a 100-row background sample from X_train.

    Fixes: raises a clear ImportError when shap is missing (the module
    deliberately tolerates it at import time via SHAP_AVAILABLE), and the
    previous if/else around summary_plot had byte-identical branches —
    summary_plot accepts both a single array and a per-class list, so the
    duplication is collapsed into one call.
    """
    if not SHAP_AVAILABLE:
        raise ImportError("Install SHAP: pip install shap")

    if "tree" in str(type(model)).lower():
        explainer = shap.TreeExplainer(model)
    else:
        explainer = shap.KernelExplainer(model.predict_proba, X_train[:100])

    sample = X_test[:100]
    shap_values = explainer.shap_values(sample)

    if feature_names is None:
        feature_names = [f"feat_{i}" for i in range(X_test.shape[1])]

    plt.figure(figsize=(10, 6))
    shap.summary_plot(
        shap_values, sample,
        feature_names=feature_names,
        plot_type=plot_type,
        max_display=max_display,
        show=False,
    )
    plt.tight_layout()
    plt.show()
118
+
119
+
120
def explain_text_with_lime(
    model: Any,
    text: str,
    tokenizer: Callable,
    class_names: List[str],
    num_features: int = 10,
    num_samples: int = 5000
):
    """Show a LIME text explanation in a notebook.

    `tokenizer` stays in the signature for backward compatibility but is
    not needed: LIME perturbs raw strings and the model's own vectorizer
    handles tokenisation. Fix: the old predict_fn tokenised every one of
    the `num_samples` perturbed texts and discarded the result — pure
    wasted work — and a missing lime package surfaced as a NameError
    instead of a clear ImportError.
    """
    if not LIME_AVAILABLE:
        raise ImportError("Install LIME: pip install lime")

    def predict_fn(texts):
        if not hasattr(model, "vectorizer"):
            raise NotImplementedError("Custom predict_fn needed for your pipeline")
        X = model.vectorizer.transform(texts)
        return model.predict_proba(X.toarray())

    explainer = lime.lime_text.LimeTextExplainer(class_names=class_names)
    exp = explainer.explain_instance(text, predict_fn, num_features=num_features, num_samples=num_samples)
    exp.show_in_notebook()
139
+
140
+
141
def visualize_attention_weights(
    tokens: List[str],
    attention_weights: np.ndarray,
    layer: int = 0,
    head: int = 0,
    figsize: Tuple[int, int] = (10, 2)
):
    """Heatmap of one attention head, cropped to the given token span.

    Raises:
        ValueError: attention_weights is not (layers, heads, seq, seq).
    """
    if attention_weights.ndim != 4:
        raise ValueError("attention_weights must be 4D: (layers, heads, seq, seq)")

    n = len(tokens)
    grid = attention_weights[layer, head, :n, :n]

    plt.figure(figsize=figsize)
    sns.heatmap(
        grid,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        cbar=True
    )
    plt.title(f"Attention Layer {layer}, Head {head}")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
166
+
167
+
168
def get_transformer_attention(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Return (tokens, attention) for one text from a HF transformer.

    `attention` has shape (layers, heads, seq, seq) for the single input
    sequence. Fix: the old code raised ImportError when Captum was
    missing, but nothing here uses Captum — only torch and the supplied
    model/tokenizer are required, so the bogus guard is removed.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
        attentions = outputs.attentions

    # Stack the per-layer tuple -> (layers, 1, heads, seq, seq), then drop
    # the batch axis (batch size is always 1 here).
    attn = torch.stack(attentions, dim=0).squeeze(1).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attn
189
+
190
+
191
def analyze_errors(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    texts: List[str],
    labels: Optional[List[Any]] = None
) -> pd.DataFrame:
    """Collect misclassified samples into a DataFrame.

    Columns: index, text, true_label, pred_label. Returns an empty
    DataFrame (no columns) when every prediction is correct.
    NOTE(review): `labels` is accepted but currently unused — kept for
    interface compatibility; confirm whether callers expect label mapping.
    """
    rows = [
        {"index": i, "text": txt, "true_label": t, "pred_label": p}
        for i, (t, p, txt) in enumerate(zip(y_true, y_pred, texts))
        if t != p
    ]
    return pd.DataFrame(rows)
207
+
208
+
209
def compare_model_errors(
    models: Dict[str, BaseEstimator],
    X_test: np.ndarray,
    y_test: np.ndarray,
    texts: List[str]
) -> Dict[str, pd.DataFrame]:
    """Run analyze_errors for every fitted model; results keyed by model name."""
    return {
        name: analyze_errors(y_test, model.predict(X_test), texts)
        for name, model in models.items()
    }
221
+
222
+
223
def plot_embeddings(
    embeddings: np.ndarray,
    labels: np.ndarray,
    method: str = "umap",
    n_components: int = 2,
    figsize: Tuple[int, int] = (12, 8),
    title: str = "Embedding Projection"
):
    """Scatter-plot document embeddings reduced to 2-D with t-SNE or UMAP.

    Points are coloured by `labels` (tab10 colormap). Raises ImportError
    for method='umap' when umap-learn is not installed, ValueError for an
    unknown method.
    """
    if method == "umap":
        if not UMAP_AVAILABLE:
            raise ImportError("Install UMAP: pip install umap-learn")
        reducer = umap.UMAP(n_components=n_components, random_state=42, n_jobs=-1)
    elif method == "tsne":
        reducer = TSNE(n_components=n_components, random_state=42, n_jobs=-1)
    else:
        raise ValueError("method must be 'tsne' or 'umap'")

    points = reducer.fit_transform(embeddings)

    plt.figure(figsize=figsize)
    sc = plt.scatter(points[:, 0], points[:, 1], c=labels, cmap="tab10", alpha=0.7)
    plt.colorbar(sc)
    plt.title(title)
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.tight_layout()
    plt.show()
250
+
251
+
252
def get_token_importance_captum(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Per-token attributions via Captum's Layer Integrated Gradients.

    Attributes the model's own predicted class (argmax of the logits for
    this text) and sums attributions over the embedding dimension, giving
    one signed score per input token. Returns (tokens, attributions).

    Raises:
        ImportError: if Captum is not installed.
    """
    if not CAPTUM_AVAILABLE:
        raise ImportError("Install Captum: pip install captum")

    from captum.attr import LayerIntegratedGradients
    import torch

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    model = model.to(device)
    model.eval()

    # Target the class the model itself predicts for this text.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred_class = torch.argmax(outputs.logits, dim=1).item()

    # attention_mask is captured from the enclosing scope so the baseline
    # forward passes see the same mask as the real input.
    def forward_func(input_ids):
        return model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Baseline: all-zero token ids, except the real [CLS]/[SEP] ids at the
    # sequence ends. NOTE(review): assumes id 0 is a neutral baseline
    # token ([PAD] in BERT vocabularies) — confirm for other tokenizers.
    baseline_ids = torch.zeros_like(input_ids).to(device)
    baseline_ids[:, 0] = tokenizer.cls_token_id
    baseline_ids[:, -1] = tokenizer.sep_token_id

    # NOTE(review): model.bert.embeddings hard-codes a BERT-style model;
    # DistilBERT/RoBERTa expose embeddings under different attribute names.
    lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        target=pred_class,
        return_convergence_delta=True
    )

    # Collapse the embedding dimension: one signed score per token.
    attributions = attributions.sum(dim=-1).squeeze(0).cpu().detach().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attributions
301
+
302
+
303
def plot_token_importance(tokens: List[str], importance: np.ndarray, top_k: int = 20):
    """Horizontal bar chart of the top-k tokens by |attribution|.

    Special tokens ([CLS]/[SEP]/[PAD]) are dropped first; green bars are
    positive attributions, red negative. Silently returns when no
    non-special tokens remain.
    """
    specials = {"[CLS]", "[SEP]", "[PAD]"}
    pairs = [(tok, val) for tok, val in zip(tokens, importance) if tok not in specials]
    if not pairs:
        return

    toks, vals = zip(*pairs)
    order = np.argsort(np.abs(vals))[-top_k:][::-1]
    top_tokens = [toks[i] for i in order]
    top_vals = [vals[i] for i in order]

    plt.figure(figsize=(10, 6))
    bar_colors = ["red" if v < 0 else "green" for v in top_vals]
    plt.barh(range(len(top_vals)), top_vals, color=bar_colors)
    plt.yticks(range(len(top_vals)), top_tokens)
    plt.gca().invert_yaxis()
    plt.xlabel("Attribution Score")
    plt.title("Token Importance (Green: positive, Red: negative)")
    plt.tight_layout()
    plt.show()
src/neural_classifiers.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Union, Tuple, Dict, Any, Literal
3
+ import numpy as np
4
+
5
+ try:
6
+ import tensorflow as tf
7
+ from tensorflow.keras import layers, models, optimizers, callbacks
8
+ from tensorflow.keras.models import Model
9
+ from tensorflow.keras.layers import (
10
+ Input, Embedding, Dense, Dropout, GlobalMaxPooling1D,
11
+ Conv1D, LSTM, GRU, Bidirectional, Attention, GlobalAveragePooling1D
12
+ )
13
+ TF_AVAILABLE = True
14
+ except ImportError:
15
+ TF_AVAILABLE = False
16
+
17
+ try:
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn.utils.rnn import pad_sequence
21
+ from transformers import (
22
+ AutoTokenizer, AutoModel, AutoConfig,
23
+ BertForSequenceClassification, RobertaForSequenceClassification,
24
+ DistilBertForSequenceClassification, Trainer, TrainingArguments
25
+ )
26
+ from transformers.tokenization_utils_base import BatchEncoding
27
+ TORCH_AVAILABLE = True
28
+ except ImportError:
29
+ TORCH_AVAILABLE = False
30
+
31
+
32
if TF_AVAILABLE:
    _KerasLayerBase = tf.keras.layers.Layer
else:
    # Allow this module to import without TensorFlow; the layer is
    # unusable in that case but nothing crashes at import time.
    _KerasLayerBase = object


class AttentionLayer(_KerasLayerBase):
    """Additive attention pooling: (batch, seq, dim) -> (batch, dim).

    Scores each timestep with tanh(h_t · W + b_t), softmaxes the scores
    over the sequence axis, and returns the attention-weighted sum of the
    inputs.

    Fix: the class previously subclassed tf.keras.layers.Layer
    unconditionally, so importing this module raised NameError whenever
    TensorFlow was absent — defeating the TF_AVAILABLE try/except guard
    at the top of the file. The base class is now chosen conditionally.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # W: (dim, 1) projection producing one raw score per timestep.
        self.W = self.add_weight(
            shape=(input_shape[-1], 1),
            initializer='random_normal',
            trainable=True,
            name='attention_weight'
        )
        # b: (seq, 1) per-position bias — ties the layer to a fixed
        # sequence length set at build time.
        self.b = self.add_weight(
            shape=(input_shape[1], 1),
            initializer='zeros',
            trainable=True,
            name='attention_bias'
        )
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        e = tf.keras.activations.tanh(tf.matmul(inputs, self.W) + self.b)
        e = tf.squeeze(e, axis=-1)
        a = tf.nn.softmax(e, axis=1)  # normalise over the sequence axis
        a = tf.expand_dims(a, axis=-1)
        weighted_input = inputs * a
        return tf.reduce_sum(weighted_input, axis=1)  # weighted sum over time
58
+
59
+
60
def build_mlp(
    input_dim: int,
    num_classes: int,
    hidden_dims: tuple = (256, 128),
    dropout: float = 0.3,
    activation: str = 'relu'
) -> 'tf.keras.Model':
    """Dense feed-forward classifier over fixed-size feature vectors.

    One Dense+Dropout pair per entry in hidden_dims, then a softmax (or,
    for num_classes <= 2, sigmoid) output head of num_classes units.

    Fix: hidden_dims previously had a mutable list default ([256, 128]) —
    the classic shared-mutable-default hazard. A tuple default is
    immutable and accepts the same call patterns (any iterable of ints).
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    inputs = Input(shape=(input_dim,))
    x = inputs
    for dim in hidden_dims:
        x = Dense(dim, activation=activation)(x)
        x = Dropout(dropout)(x)
    # NOTE(review): for binary problems this yields Dense(num_classes,
    # sigmoid) — confirm callers pair it with the matching loss.
    outputs = Dense(num_classes, activation='softmax' if num_classes > 2 else 'sigmoid')(x)
    return models.Model(inputs, outputs)
76
+
77
+
78
def build_kim_cnn(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    filter_sizes: tuple = (3, 4, 5),
    num_filters: int = 100,
    dropout: float = 0.5,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Kim (2014) multi-window CNN for sentence classification.

    Parallel Conv1D branches — one per window size — are max-pooled over
    time and concatenated before the dropout + softmax/sigmoid head. When
    pre_embed_matrix is given it is used as a frozen embedding table.

    Fix: filter_sizes previously had a mutable list default ([3, 4, 5]);
    a tuple default removes the shared-mutable-default hazard while
    accepting the same call patterns (any iterable of ints).
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    inputs = Input(shape=(max_len,))
    if pre_embed_matrix is not None:
        embedding = Embedding(
            vocab_size, embed_dim,
            weights=[pre_embed_matrix],
            trainable=False  # frozen pretrained vectors
        )(inputs)
    else:
        embedding = Embedding(vocab_size, embed_dim)(inputs)

    pooled_outputs = []
    for fs in filter_sizes:
        x = Conv1D(num_filters, fs, activation='relu')(embedding)
        x = GlobalMaxPooling1D()(x)
        pooled_outputs.append(x)

    merged = tf.concat(pooled_outputs, axis=1)
    x = Dropout(dropout)(merged)
    outputs = Dense(num_classes, activation='softmax' if num_classes > 2 else 'sigmoid')(x)
    return models.Model(inputs, outputs)
110
+
111
+
112
def build_lstm(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    lstm_units: int = 128,
    dropout: float = 0.3,
    bidirectional: bool = False,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """(Bi)LSTM text classifier over integer token sequences.

    When pre_embed_matrix is given it is used as a frozen embedding table;
    `dropout` is applied to both the input and recurrent connections of
    the LSTM.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")

    inputs = Input(shape=(max_len,))
    embed_kwargs = {}
    if pre_embed_matrix is not None:
        embed_kwargs = {"weights": [pre_embed_matrix], "trainable": False}
    x = Embedding(vocab_size, embed_dim, **embed_kwargs)(inputs)

    recurrent = LSTM(lstm_units, dropout=dropout, recurrent_dropout=dropout)
    x = Bidirectional(recurrent)(x) if bidirectional else recurrent(x)

    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    outputs = Dense(num_classes, activation=head_activation)(x)
    return models.Model(inputs, outputs)
138
+
139
+
140
def build_cnn_lstm(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    filter_size: int = 3,
    num_filters: int = 128,
    lstm_units: int = 64,
    dropout: float = 0.3,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Build a hybrid classifier: Conv1D feature extractor feeding an LSTM.

    The convolution uses 'same' padding so the LSTM still sees `max_len`
    timesteps. A frozen pretrained embedding is used when supplied.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    token_ids = Input(shape=(max_len,))
    if pre_embed_matrix is None:
        features = Embedding(vocab_size, embed_dim)(token_ids)
    else:
        features = Embedding(
            vocab_size, embed_dim, weights=[pre_embed_matrix], trainable=False
        )(token_ids)

    features = Conv1D(num_filters, filter_size, activation='relu', padding='same')(features)
    features = LSTM(lstm_units, dropout=dropout)(features)
    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    logits = Dense(num_classes, activation=head_activation)(features)
    return models.Model(token_ids, logits)
163
+
164
+
165
def build_birnn_attention(
    max_len: int,
    vocab_size: int,
    embed_dim: int,
    num_classes: int,
    rnn_units: int = 64,
    dropout: float = 0.3,
    pre_embed_matrix: Optional[np.ndarray] = None
) -> 'tf.keras.Model':
    """Build a bidirectional LSTM classifier with an attention pooling layer.

    The BiLSTM returns the full sequence, which `AttentionLayer` collapses
    into a single document vector before the classification head.
    """
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow not available")
    token_ids = Input(shape=(max_len,))
    if pre_embed_matrix is None:
        states = Embedding(vocab_size, embed_dim)(token_ids)
    else:
        states = Embedding(
            vocab_size, embed_dim, weights=[pre_embed_matrix], trainable=False
        )(token_ids)

    states = Bidirectional(LSTM(rnn_units, return_sequences=True, dropout=dropout))(states)
    pooled = AttentionLayer()(states)
    head_activation = 'softmax' if num_classes > 2 else 'sigmoid'
    logits = Dense(num_classes, activation=head_activation)(pooled)
    return models.Model(token_ids, logits)
186
+
187
+
188
# Short aliases -> Hugging Face Hub model ids for Russian-capable encoders
# accepted by get_transformer_classifier().
_RUSSIAN_TRANSFORMERS = {
    "rubert": "DeepPavlov/rubert-base-cased",
    "ruroberta": "sberbank-ai/ruRoberta-large",
    "distilbert-multilingual": "distilbert-base-multilingual-cased"
}
193
+
194
def get_transformer_classifier(
    model_name: str = "rubert",
    num_classes: int = 2,
    problem_type: Literal["single_label", "multi_label"] = "single_label"
) -> Tuple[Any, Any]:
    """Load a pretrained Russian transformer with a classification head.

    `model_name` is one of the aliases in `_RUSSIAN_TRANSFORMERS`. The
    architecture class is picked from the resolved model id, and the
    config's `problem_type` is set so the HF loss matches single- vs
    multi-label training.

    Returns:
        (model, tokenizer) pair.

    Raises:
        ImportError: if torch/transformers are missing.
        ValueError: on an unknown `model_name` alias.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch or transformers not available")

    if model_name not in _RUSSIAN_TRANSFORMERS:
        raise ValueError(f"Unknown model_name. Choose from: {list(_RUSSIAN_TRANSFORMERS.keys())}")

    model_id = _RUSSIAN_TRANSFORMERS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Pick the architecture class from the model id. The "roberta" test must
    # come first: it is the more specific match among the supported ids.
    lowered_id = model_id.lower()
    if "roberta" in lowered_id:
        architecture = RobertaForSequenceClassification
    elif "distilbert" in lowered_id:
        architecture = DistilBertForSequenceClassification
    else:
        architecture = BertForSequenceClassification
    model = architecture.from_pretrained(model_id, num_labels=num_classes)

    model.config.problem_type = (
        "multi_label_classification"
        if problem_type == "multi_label"
        else "single_label_classification"
    )
    return model, tokenizer
228
+
229
+
230
def quantize_pytorch_model(model: 'torch.nn.Module', backend: str = "qnnpack") -> 'torch.nn.Module':
    """Apply eager-mode post-training static quantization to `model` in place.

    Args:
        model: The float module to quantize (mutated in place and returned).
        backend: Quantization backend config name; "qnnpack" targets
            ARM/mobile, "fbgemm" targets x86 servers.

    Returns:
        The same module object, converted to its quantized form.

    Raises:
        ImportError: if PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)
    # NOTE(review): no calibration pass runs between prepare() and convert(),
    # so the inserted observers never see data and the chosen scales/zero
    # points may be degenerate — confirm callers run representative batches
    # through the prepared model first, or consider dynamic quantization.
    torch.quantization.convert(model, inplace=True)
    return model
238
+
239
+
240
def prune_keras_model(model: 'tf.keras.Model', sparsity: float = 0.5) -> 'tf.keras.Model':
    """Wrap `model` for magnitude pruning up to the given final sparsity.

    Uses a polynomial sparsity schedule from 0.0 to `sparsity` over the
    first 1000 training steps. Requires the optional
    tensorflow-model-optimization package.
    """
    try:
        import tensorflow_model_optimization as tfmot
    except ImportError:
        raise ImportError("Install tensorflow-model-optimization for pruning")
    schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=sparsity,
        begin_step=0,
        end_step=1000,
    )
    return tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=schedule)
252
+
253
+
254
def prepare_keras_inputs(
    texts: list,
    tokenizer=None,
    max_len: int = 128,
    vocab: Optional[dict] = None
) -> np.ndarray:
    """Turn raw texts into a padded integer-id matrix for Keras models.

    With a HuggingFace-style `tokenizer`, returns its `input_ids` (numpy).
    Otherwise a Keras `Tokenizer` is used: `vocab` (a word->index mapping)
    is applied when given, else a fresh vocabulary is fitted on `texts`.
    """
    if tokenizer is None:
        from tensorflow.keras.preprocessing.text import Tokenizer
        from tensorflow.keras.preprocessing.sequence import pad_sequences
        word_indexer = Tokenizer(oov_token="<OOV>")
        if vocab:
            word_indexer.word_index = vocab
        else:
            word_indexer.fit_on_texts(texts)
        return pad_sequences(word_indexer.texts_to_sequences(texts), maxlen=max_len)

    encoded = tokenizer(texts, truncation=True, padding=True, max_length=max_len, return_tensors="np")
    return encoded['input_ids']
273
+
274
+
275
def compile_keras_model(
    model: 'tf.keras.Model',
    learning_rate: float = 2e-5,
    num_classes: int = 2
):
    """Compile `model` with Adam and a loss matching the class count.

    More than two classes uses sparse categorical cross-entropy (integer
    labels); otherwise binary cross-entropy. Accuracy is tracked.
    """
    if num_classes > 2:
        objective = 'sparse_categorical_crossentropy'
    else:
        objective = 'binary_crossentropy'
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss=objective,
        metrics=['accuracy'],
    )
    return model
src/text_preprocessing.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+ from typing import List, Optional, Union, Dict, Any, Callable
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
8
+ from nltk.corpus import stopwords
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk import download as nltk_download
11
+ from nltk.stem import WordNetLemmatizer
12
+ import spacy
13
+ from gensim.models import KeyedVectors
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch
16
+ import emoji
17
# (debug print removed: importing a library module should not write to stdout)
18
+
19
# Fetch the NLTK data this module depends on (punkt tokenizer models,
# stopword lists, WordNet). Best-effort: offline environments get a console
# warning instead of a hard failure at import time.
try:
    nltk_download('punkt', quiet=True)
    nltk_download('stopwords', quiet=True)
    nltk_download('wordnet', quiet=True)
except Exception as e:
    print(f"Warning: NLTK data download failed: {e}")

# Lazily-initialized module-level singletons, populated by the _load_*
# helpers below so heavy models are only loaded on first use.
_SPACY_MODEL = None        # cached spaCy pipeline (_load_spacy_model)
_NLTK_LEMMATIZER = None    # cached WordNetLemmatizer (_load_nltk_lemmatizer)
_BERT_TOKENIZER = None     # cached transformers tokenizer (_load_bert_model)
_BERT_MODEL = None         # cached transformers encoder (_load_bert_model)
30
+
31
+
32
def _load_spacy_model(lang: str = "en_core_web_sm"):
    """Load and cache a spaCy pipeline.

    Args:
        lang: spaCy model package name (e.g. "en_core_web_sm").

    Returns:
        The loaded spaCy Language pipeline.

    Raises:
        ValueError: if the requested model package is not installed.
    """
    global _SPACY_MODEL
    # Bug fix: the cache previously ignored `lang`, so once any model was
    # loaded, every later call got that first model regardless of the
    # language requested. Track which package is cached and reload on change.
    cached_lang = getattr(_load_spacy_model, "_cached_lang", None)
    if _SPACY_MODEL is None or cached_lang != lang:
        try:
            _SPACY_MODEL = spacy.load(lang)
        except OSError:
            raise ValueError(
                f"spaCy model '{lang}' not found. Please install it via: python -m spacy download {lang}"
            )
        _load_spacy_model._cached_lang = lang
    return _SPACY_MODEL
42
+
43
+
44
def _load_nltk_lemmatizer():
    """Return the shared WordNetLemmatizer, instantiating it lazily."""
    global _NLTK_LEMMATIZER
    lemmatizer = _NLTK_LEMMATIZER
    if lemmatizer is None:
        lemmatizer = WordNetLemmatizer()
        _NLTK_LEMMATIZER = lemmatizer
    return lemmatizer
49
+
50
+
51
def _load_bert_model(model_name: str = "bert-base-uncased"):
    """Load and cache a transformers tokenizer/encoder pair.

    Args:
        model_name: Hugging Face model id.

    Returns:
        (tokenizer, model) tuple.
    """
    global _BERT_TOKENIZER, _BERT_MODEL
    # Bug fix: the cache previously ignored `model_name`, so asking for a
    # different model after the first call silently returned the old one.
    cached_name = getattr(_load_bert_model, "_cached_name", None)
    if _BERT_TOKENIZER is None or _BERT_MODEL is None or cached_name != model_name:
        _BERT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
        _BERT_MODEL = AutoModel.from_pretrained(model_name)
        _load_bert_model._cached_name = model_name
    return _BERT_TOKENIZER, _BERT_MODEL
57
+
58
+
59
def clean_text(text: str) -> str:
    """Strip markup, URLs and non-printable characters; normalize whitespace.

    Removes HTML/XML tags and http(s)/www URLs, drops any character outside
    `string.printable` (i.e. non-ASCII), then collapses whitespace runs to
    single spaces and trims the ends.
    """
    no_markup = re.sub(r"<[^>]+>", "", text)
    no_urls = re.sub(r"https?://\S+|www\.\S+", "", no_markup)
    ascii_only = "".join(filter(string.printable.__contains__, no_urls))
    return re.sub(r"\s+", " ", ascii_only).strip()
65
+
66
+
67
def replace_emojis(text: str) -> str:
    """Replace each emoji with its textual name, padded with spaces.

    Using " " as both delimiters (instead of the default ":" pair) keeps the
    demojized name from fusing with adjacent words during tokenization.
    """
    return emoji.demojize(text, delimiters=(" ", " "))
69
+
70
+
71
def preprocess_text(
    text: str,
    lang: str = "en",
    remove_stopwords: bool = True,
    use_spacy: bool = True,
    lemmatize: bool = True,
    emoji_to_text: bool = True,
    lowercase: bool = True,
    spacy_model: Optional[str] = None,
    replace_entities: bool = False  # off by default: numbers/URLs kept verbatim
) -> List[str]:
    """Normalize one document and return its token list.

    Pipeline: optional emoji-to-text, HTML tag stripping, optional
    URL/email/number placeholder substitution, punctuation removal,
    whitespace collapsing, optional lowercasing, then tokenization (spaCy or
    NLTK) with optional lemmatization and stopword removal.

    Args:
        text: Raw input document.
        lang: ISO language code ("en", "ru", ...) used to pick the spaCy
            model and the NLTK stopword list.
        remove_stopwords: Drop stopwords after tokenization.
        use_spacy: Tokenize/lemmatize with spaCy; otherwise NLTK.
        lemmatize: Emit lemmas instead of surface forms.
        emoji_to_text: Convert emoji to their textual names first.
        lowercase: Lowercase before tokenization (placeholders too).
        spacy_model: Explicit spaCy package name, overriding `lang`.
        replace_entities: Substitute URL/EMAIL/NUM placeholder tokens.

    Returns:
        List of non-empty, non-punctuation tokens.
    """
    if emoji_to_text:
        text = replace_emojis(text)

    # Strip HTML/XML tags.
    text = re.sub(r"<[^>]+>", "", text)

    # Bug fix: entity replacement must run BEFORE punctuation stripping.
    # Previously it ran after, so the URL/email patterns could never match
    # (':' '/' '.' '@' were already gone) and the "<NUM>" placeholder's
    # angle brackets were re-introduced after being banned. Placeholders are
    # plain word tokens so they survive the punctuation pass below.
    if replace_entities:
        text = re.sub(r"https?://\S+|www\.\S+", " URL ", text)
        text = re.sub(r"\S+@\S+", " EMAIL ", text)
        text = re.sub(r"\b\d+\b", " NUM ", text)

    # Replace every non-word, non-space character with a space, then
    # collapse whitespace runs.
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()

    if lowercase:
        text = text.lower()

    if use_spacy:
        spacy_lang = spacy_model or ("en_core_web_sm" if lang == "en" else f"{lang}_core_news_sm")
        nlp = _load_spacy_model(spacy_lang)
        doc = nlp(text)
        kept = [tok for tok in doc if not tok.is_space and not tok.is_punct]
        tokens = [tok.lemma_ for tok in kept] if lemmatize else [tok.text for tok in kept]
        if remove_stopwords:
            tokens = [t for t in tokens if not nlp.vocab[t].is_stop]
    else:
        tokens = word_tokenize(text)
        if lemmatize:
            lemmatizer = _load_nltk_lemmatizer()
            tokens = [lemmatizer.lemmatize(t) for t in tokens]
        if remove_stopwords:
            # Bug fix: NLTK names its stopword lists by full language name
            # ("english"), not ISO code ("en"), so the old
            # `lang in stopwords.fileids()` check silently skipped stopword
            # removal for ISO codes. Map common codes; unknown codes still
            # fall back to an empty set rather than raising.
            iso_to_nltk = {
                "en": "english", "ru": "russian", "de": "german",
                "fr": "french", "es": "spanish", "it": "italian",
                "pt": "portuguese", "nl": "dutch",
            }
            stop_lang = iso_to_nltk.get(lang, lang)
            stop_words = set(stopwords.words(stop_lang)) if stop_lang in stopwords.fileids() else set()
            tokens = [t for t in tokens if t not in stop_words]

    return [t for t in tokens if t not in string.punctuation and len(t) > 0]
126
+
127
+
128
class TextVectorizer:
    """Fit-and-transform wrappers around sklearn count/tf-idf vectorizers.

    Each method fits a fresh vectorizer on the given corpus and keeps it on
    the instance so the fitted vocabulary can be inspected or reused.
    """

    def __init__(self):
        # Set by bow() / tfidf() (ngrams delegates to tfidf()).
        self.bow_vectorizer = None
        self.tfidf_vectorizer = None

    def bow(self, texts: List[str], **kwargs) -> np.ndarray:
        """Dense bag-of-words counts; extra kwargs go to CountVectorizer."""
        vectorizer = CountVectorizer(**kwargs)
        self.bow_vectorizer = vectorizer
        return vectorizer.fit_transform(texts).toarray()

    def tfidf(self, texts: List[str], max_features: int = 5000, **kwargs) -> np.ndarray:
        """Dense TF-IDF matrix; input is assumed to be pre-lowercased."""
        kwargs['max_features'] = max_features
        vectorizer = TfidfVectorizer(lowercase=False, **kwargs)
        self.tfidf_vectorizer = vectorizer
        return vectorizer.fit_transform(texts).toarray()

    def ngrams(self, texts: List[str], ngram_range: tuple = (1, 2), **kwargs) -> np.ndarray:
        """TF-IDF over word n-grams (defaults to unigrams + bigrams)."""
        kwargs.setdefault("ngram_range", ngram_range)
        return self.tfidf(texts, **kwargs)
145
+
146
+
147
class EmbeddingVectorizer:
    """Document embeddings from static word vectors (word2vec/fastText/GloVe).

    Load one or more pretrained vector sets, then call get_embeddings() to
    turn tokenized documents into fixed-size vectors by aggregating their
    per-token vectors (mean or max). Out-of-vocabulary tokens are skipped.
    """

    def __init__(self):
        self.word2vec_model = None   # gensim KeyedVectors (binary word2vec format)
        self.fasttext_model = None   # gensim KeyedVectors (native gensim save)
        self.glove_vectors = None    # dict[str, np.ndarray] parsed from text GloVe

    def load_word2vec(self, path: str):
        """Load binary-format word2vec vectors via gensim."""
        self.word2vec_model = KeyedVectors.load_word2vec_format(path, binary=True)

    def load_fasttext(self, path: str):
        """Load gensim-native fastText KeyedVectors."""
        self.fasttext_model = KeyedVectors.load(path)

    def load_glove(self, glove_file: str, vocab_size: int = 400000, dim: int = 300):
        """Parse a text-format GloVe file into an in-memory dict.

        Only the first `vocab_size` lines are read. `dim` is informational;
        the actual vector length comes from the file itself.
        """
        self.glove_vectors = {}
        with open(glove_file, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i >= vocab_size:
                    break
                values = line.split()
                word = values[0]
                self.glove_vectors[word] = np.array(values[1:], dtype="float32")

    def _get_word_vector(self, word: str, method: str = "word2vec") -> Optional[np.ndarray]:
        """Look `word` up in the selected vector set; None if missing or unloaded."""
        if method == "word2vec" and self.word2vec_model and word in self.word2vec_model:
            return self.word2vec_model[word]
        elif method == "fasttext" and self.fasttext_model and word in self.fasttext_model:
            return self.fasttext_model[word]
        elif method == "glove" and self.glove_vectors and word in self.glove_vectors:
            return self.glove_vectors[word]
        return None

    def _aggregate_vectors(
        self, vectors: List[np.ndarray], strategy: str = "mean", dim: Optional[int] = None
    ) -> np.ndarray:
        """Collapse token vectors into a single document vector.

        `dim` sets the size of the zero vector returned for empty documents;
        it falls back to 300 (the old hard-coded value) when unknown.

        Raises:
            ValueError: if `strategy` is neither 'mean' nor 'max'.
        """
        if not vectors:
            return np.zeros(dim if dim is not None else 300)
        if strategy == "mean":
            return np.mean(vectors, axis=0)
        elif strategy == "max":
            return np.max(vectors, axis=0)
        else:
            raise ValueError("Strategy must be 'mean' or 'max'")

    def get_embeddings(
        self,
        tokenized_texts: List[List[str]],
        method: str = "word2vec",
        aggregation: str = "mean",
    ) -> np.ndarray:
        """Embed each tokenized document into one vector per document.

        Bug fix: empty/all-OOV documents previously always got a 300-dim
        zero vector, producing a ragged object array (or a crash on recent
        numpy) whenever the loaded vectors had a different dimensionality.
        The fallback now matches the dimensionality actually observed in the
        corpus (resolved in a first pass, before any aggregation).
        """
        per_doc = []
        observed_dim = None
        for tokens in tokenized_texts:
            found = [
                v for v in (self._get_word_vector(t, method=method) for t in tokens)
                if v is not None
            ]
            per_doc.append(found)
            if found and observed_dim is None:
                observed_dim = len(found[0])
        return np.array([
            self._aggregate_vectors(vectors, strategy=aggregation, dim=observed_dim)
            for vectors in per_doc
        ])
206
+
207
+
208
def get_contextual_embeddings(
    texts: List[str],
    model_name: str = "bert-base-uncased",
    aggregation: str = "mean",
    device: str = "cpu",
) -> np.ndarray:
    """Encode texts with a (cached) BERT-style encoder, pooled per document.

    Each text is tokenized (truncated to 512 wordpieces), passed through the
    encoder without gradients, and its last hidden states are pooled with
    mean or max after dropping the first and last positions — a simple
    [CLS]/[SEP] heuristic.

    Raises:
        ValueError: if `aggregation` is neither 'mean' nor 'max'.
    """
    tokenizer, model = _load_bert_model(model_name)
    model.to(device)
    model.eval()

    pooled = []
    reducers = {"mean": np.mean, "max": np.max}
    with torch.no_grad():
        for text in texts:
            batch = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            hidden = model(**batch).last_hidden_state[0].cpu().numpy()

            # Drop first/last positions as a cheap [CLS]/[SEP] exclusion.
            if len(hidden) > 2:
                hidden = hidden[1:-1]

            reduce_fn = reducers.get(aggregation)
            if reduce_fn is None:
                raise ValueError("aggregation must be 'mean' or 'max'")
            pooled.append(reduce_fn(hidden, axis=0))

    return np.array(pooled)
245
+
246
+
247
def extract_meta_features(texts: Union[List[str], pd.Series]) -> pd.DataFrame:
    """Compute shallow per-document statistics as a DataFrame.

    Columns: text_length, avg_word_length, num_unique_words,
    punctuation_ratio, uppercase_ratio, digit_ratio, flesch_reading_ease.
    The readability column is always NaN — no readability library is wired
    in; it is kept only for schema compatibility with downstream consumers.

    Args:
        texts: Raw documents as a list or pandas Series.

    Returns:
        One row per document, in input order.
    """
    if isinstance(texts, pd.Series):
        texts = texts.tolist()

    rows = []
    for text in texts:
        n_chars = len(text)
        words = text.split()
        # The [0] fallback keeps np.mean defined (0.0) for empty documents.
        word_lengths = [len(w) for w in words] if words else [0]

        num_punct = sum(1 for c in text if c in string.punctuation)
        num_upper = sum(1 for c in text if c.isupper())
        num_digits = sum(1 for c in text if c.isdigit())

        rows.append({
            "text_length": n_chars,
            "avg_word_length": np.mean(word_lengths),
            "num_unique_words": len(set(words)) if words else 0,
            "punctuation_ratio": num_punct / n_chars if n_chars > 0 else 0,
            "uppercase_ratio": num_upper / n_chars if n_chars > 0 else 0,
            "digit_ratio": num_digits / n_chars if n_chars > 0 else 0,
            # Fixed: removed a dead try/except that unconditionally set NaN.
            "flesch_reading_ease": np.nan,
        })

    return pd.DataFrame(rows)