timagonch Claude Sonnet 4.6 committed on
Commit
1ce3127
·
1 Parent(s): e906eea

Deploy algospeak classifier — model loaded from HF Hub at runtime

Browse files
Files changed (6) hide show
  1. Dockerfile +4 -4
  2. app.py +70 -0
  3. poc/config.yaml +35 -0
  4. poc/src/inference.py +328 -0
  5. poc/src/model.py +116 -0
  6. requirements.txt +9 -3
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.13.5-slim
2
 
3
  WORKDIR /app
4
 
@@ -9,12 +9,12 @@ RUN apt-get update && apt-get install -y \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  COPY requirements.txt ./
12
- COPY src/ ./src/
13
 
14
- RUN pip3 install -r requirements.txt
15
 
16
  EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
+ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
 
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  COPY requirements.txt ./
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
 
14
+ COPY . .
15
 
16
  EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — Algospeak Classifier demo
3
+
4
+ Streamlit UI for the dual BERTweet model.
5
+ Type a social media post and see the predicted class + confidence scores.
6
+ """
7
+
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ sys.path.insert(0, str(Path(__file__).parent / "poc" / "src"))
12
+
13
+ import yaml
14
+ import torch
15
+ import numpy as np
16
+ import emoji
17
+ import streamlit as st
18
+ from transformers import AutoTokenizer
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ from inference import load_unsupervised_encoder, classify_text
22
+
23
+ BASE_DIR = Path(__file__).parent
24
+ MODEL_REPO = "timagonch/algospeak-classifier-model"
25
+
26
+ CLASS_COLORS = {
27
+ "Allowed": "green",
28
+ "Offensive Language": "red",
29
+ "Mature Content": "orange",
30
+ "Algospeak": "violet",
31
+ }
32
+
33
+
34
+ @st.cache_resource(show_spinner="Loading model...")
35
+ def load_model():
36
+ with open(BASE_DIR / "poc" / "config.yaml") as f:
37
+ cfg = yaml.safe_load(f)
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+ checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
41
+ prototypes_path = hf_hub_download(repo_id=MODEL_REPO, filename="prototypes.npy")
42
+
43
+ encoder = load_unsupervised_encoder(checkpoint_path, cfg, device)
44
+ prototypes = np.load(prototypes_path)
45
+ tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"], use_fast=False)
46
+ return encoder, prototypes, tokenizer, cfg, device
47
+
48
+
49
+ # ─────────────────────────────────────────────────────────────────────
50
+ # UI
51
+ # ─────────────────────────────────────────────────────────────────────
52
+
53
+ st.title("Algospeak Classifier")
54
+ st.caption("Dual BERTweet model · type a social media post to classify it.")
55
+
56
+ text = st.text_area("Post text", height=120, placeholder="Type something here...")
57
+
58
+ if st.button("Classify", type="primary") and text.strip():
59
+ encoder, prototypes, tokenizer, cfg, device = load_model()
60
+ result = classify_text(text, encoder, prototypes, tokenizer, cfg["max_length"], device)
61
+
62
+ label = result["predicted_label"]
63
+ color = CLASS_COLORS[label]
64
+
65
+ st.markdown(f"## :{color}[{label}]")
66
+ st.divider()
67
+
68
+ st.write("**Confidence scores:**")
69
+ for name, score in sorted(result["scores"].items(), key=lambda x: -x[1]):
70
+ st.progress(float(score), text=f"{name}: {score:.1%}")
poc/config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dual BERTweet Configuration
2
+
3
+ # Classes
4
+ num_classes: 4
5
+ class_labels:
6
+ 0: "Allowed"
7
+ 1: "Offensive Language"
8
+ 2: "Mature Content"
9
+ 3: "Algospeak"
10
+
11
+ # Model
12
+ model_name: "vinai/bertweet-base"
13
+ embedding_dim: 768
14
+ max_length: 128
15
+
16
+ # Training
17
+ batch_size: 32
18
+ learning_rate: 2.0e-5
19
+ weight_decay: 0.01
20
+ num_epochs: 20
21
+ warmup_steps: 200
22
+ early_stopping_patience: 5
23
+ fp16: true
24
+ gradient_clip: 1.0
25
+
26
+ # Loss
27
+ temperature: 0.07
28
+
29
+ # Paths (relative to project root)
30
+ train_csv: "data/splits/train.csv"
31
+ val_csv: "data/splits/val.csv"
32
+ test_csv: "data/splits/test.csv"
33
+ prepared_dir: "poc/data/prepared"
34
+ checkpoint_dir: "poc/checkpoints"
35
+ results_dir: "poc/results"
poc/src/inference.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py
3
+
4
+ Inference and full evaluation for the dual BERTweet model.
5
+
6
+ Inference uses only the unsupervised encoder:
7
+ 1. Build class prototypes from the training set (average embedding per class).
8
+ 2. For a new post: encode -> cosine similarity to each prototype -> argmax = class.
9
+
10
+ Evaluation produces:
11
+ - Accuracy (overall + per-class)
12
+ - Precision, Recall, F1 (per-class, macro, weighted)
13
+ - Confusion matrix (saved as PNG)
14
+ - ROC curves + AUC per class (saved as PNG)
15
+ - Full metrics saved to JSON
16
+
17
+ Usage:
18
+ uv run python poc/src/inference.py
19
+ """
20
+
21
+ import sys
22
+ import json
23
+ import yaml
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import numpy as np
27
+ import pandas as pd
28
+ import matplotlib.pyplot as plt
29
+ import matplotlib
30
+ import emoji
31
+ matplotlib.use("Agg") # non-interactive backend for saving figures
32
+
33
+ from pathlib import Path
34
+ from torch.utils.data import TensorDataset, DataLoader
35
+ from transformers import AutoTokenizer
36
+ from sklearn.metrics import (
37
+ accuracy_score,
38
+ classification_report,
39
+ confusion_matrix,
40
+ roc_curve,
41
+ auc,
42
+ )
43
+
44
+ sys.path.insert(0, str(Path(__file__).parent))
45
+ from model import DualEncoderModel, BERTweetEncoder
46
+
47
+ BASE_DIR = Path(__file__).resolve().parent.parent.parent
48
+
49
+ CLASS_PREFIX = {
50
+ 0: "Allowed:",
51
+ 1: "Offensive Language:",
52
+ 2: "Mature Content:",
53
+ 3: "Algospeak:",
54
+ }
55
+
56
+ CLASS_NAMES = ["Allowed", "Offensive Language", "Mature Content", "Algospeak"]
57
+
58
+
59
+ def load_config() -> dict:
60
+ with open(BASE_DIR / "poc" / "config.yaml") as f:
61
+ return yaml.safe_load(f)
62
+
63
+
64
+ def load_unsupervised_encoder(ckpt_path: Path, cfg: dict, device: torch.device):
65
+ """Load the full dual model from checkpoint, return only the unsupervised encoder."""
66
+ model = DualEncoderModel(cfg["model_name"], cfg["temperature"])
67
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
68
+ model.load_state_dict(ckpt["model_state_dict"])
69
+ model = model.to(device)
70
+ model.eval()
71
+ print(f"Loaded checkpoint from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
72
+ return model.unsupervised
73
+
74
+
75
+ def load_dataset(path: Path) -> TensorDataset:
76
+ data = torch.load(path, map_location="cpu", weights_only=True)
77
+ return TensorDataset(
78
+ data["unsup_ids"],
79
+ data["unsup_mask"],
80
+ data["labels"],
81
+ )
82
+
83
+
84
+ def get_embeddings(
85
+ encoder: BERTweetEncoder,
86
+ dataset: TensorDataset,
87
+ batch_sz: int,
88
+ device: torch.device,
89
+ ) -> tuple[np.ndarray, np.ndarray]:
90
+ """Run all samples through the unsupervised encoder. Returns (embeddings, labels)."""
91
+ loader = DataLoader(dataset, batch_size=batch_sz, shuffle=False, num_workers=2)
92
+ all_embs, all_labels = [], []
93
+
94
+ with torch.no_grad():
95
+ for unsup_ids, unsup_mask, labels in loader:
96
+ unsup_ids = unsup_ids.to(device)
97
+ unsup_mask = unsup_mask.to(device)
98
+ embs = encoder(unsup_ids, unsup_mask)
99
+ all_embs.append(embs.cpu().numpy())
100
+ all_labels.append(labels.numpy())
101
+
102
+ return np.vstack(all_embs), np.concatenate(all_labels)
103
+
104
+
105
+ def build_prototypes(
106
+ embeddings: np.ndarray,
107
+ labels: np.ndarray,
108
+ num_classes: int,
109
+ ) -> np.ndarray:
110
+ """Average embedding per class -> [num_classes, D] prototype matrix."""
111
+ D = embeddings.shape[1]
112
+ prototypes = np.zeros((num_classes, D), dtype=np.float32)
113
+ for cls in range(num_classes):
114
+ mask = labels == cls
115
+ if mask.sum() > 0:
116
+ proto = embeddings[mask].mean(axis=0)
117
+ prototypes[cls] = proto / (np.linalg.norm(proto) + 1e-8)
118
+ return prototypes
119
+
120
+
121
+ def predict(
122
+ embeddings: np.ndarray,
123
+ prototypes: np.ndarray,
124
+ ) -> tuple[np.ndarray, np.ndarray]:
125
+ """
126
+ Cosine similarity of each embedding to each prototype.
127
+ Returns (predicted_labels, score_matrix [N, num_classes]).
128
+ Scores are softmax-normalized cosine similarities — used for ROC curves.
129
+ """
130
+ # cosine similarity: embeddings are already L2-normalized, prototypes also normalized
131
+ sim = embeddings @ prototypes.T # [N, num_classes]
132
+ scores = torch.softmax(torch.tensor(sim / 0.1), dim=-1).numpy() # [N, num_classes]
133
+ preds = sim.argmax(axis=1)
134
+ return preds, scores
135
+
136
+
137
+ # ─────────────────────────────────────────────────────────────────────
138
+ # Plotting helpers
139
+ # ─────────────────────────────────────────────────────────────────────
140
+
141
+ def plot_confusion_matrix(y_true, y_pred, out_path: Path):
142
+ cm = confusion_matrix(y_true, y_pred)
143
+ fig, ax = plt.subplots(figsize=(7, 6))
144
+ im = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
145
+ plt.colorbar(im, ax=ax)
146
+
147
+ ax.set_xticks(range(len(CLASS_NAMES)))
148
+ ax.set_yticks(range(len(CLASS_NAMES)))
149
+ ax.set_xticklabels(CLASS_NAMES, rotation=30, ha="right", fontsize=9)
150
+ ax.set_yticklabels(CLASS_NAMES, fontsize=9)
151
+ ax.set_xlabel("Predicted")
152
+ ax.set_ylabel("True")
153
+ ax.set_title("Confusion Matrix")
154
+
155
+ thresh = cm.max() / 2.0
156
+ for i in range(cm.shape[0]):
157
+ for j in range(cm.shape[1]):
158
+ ax.text(j, i, str(cm[i, j]),
159
+ ha="center", va="center",
160
+ color="white" if cm[i, j] > thresh else "black", fontsize=10)
161
+
162
+ plt.tight_layout()
163
+ plt.savefig(out_path, dpi=150)
164
+ plt.close()
165
+ print(f" Confusion matrix saved -> {out_path}")
166
+
167
+
168
+ def plot_roc_curves(y_true, scores, num_classes: int, out_path: Path):
169
+ fig, ax = plt.subplots(figsize=(8, 6))
170
+ colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]
171
+
172
+ for cls in range(num_classes):
173
+ y_bin = (y_true == cls).astype(int)
174
+ fpr, tpr, _ = roc_curve(y_bin, scores[:, cls])
175
+ roc_auc = auc(fpr, tpr)
176
+ ax.plot(fpr, tpr, color=colors[cls], lw=2,
177
+ label=f"{CLASS_NAMES[cls]} (AUC={roc_auc:.3f})")
178
+
179
+ ax.plot([0, 1], [0, 1], "k--", lw=1)
180
+ ax.set_xlabel("False Positive Rate")
181
+ ax.set_ylabel("True Positive Rate")
182
+ ax.set_title("ROC Curves (One-vs-Rest)")
183
+ ax.legend(loc="lower right", fontsize=9)
184
+ plt.tight_layout()
185
+ plt.savefig(out_path, dpi=150)
186
+ plt.close()
187
+ print(f" ROC curves saved -> {out_path}")
188
+
189
+
190
+ # ─────────────────────────────────────────────────────────────────────
191
+ # Main evaluation
192
+ # ─────────────────────────────────────────────────────────────────────
193
+
194
+ def evaluate_split(
195
+ encoder: BERTweetEncoder,
196
+ prototypes: np.ndarray,
197
+ split: str,
198
+ cfg: dict,
199
+ device: torch.device,
200
+ results_dir: Path,
201
+ ) -> dict:
202
+ print(f"\n--- Evaluating {split} split ---")
203
+ dataset = load_dataset(BASE_DIR / cfg["prepared_dir"] / f"{split}.pt")
204
+ embs, labels = get_embeddings(encoder, dataset, cfg["batch_size"], device)
205
+ preds, scores = predict(embs, prototypes)
206
+
207
+ # Save per-sample predictions CSV
208
+ csv_df = pd.read_csv(BASE_DIR / cfg[f"{split}_csv"])
209
+ csv_df = csv_df.dropna(subset=["text"]).reset_index(drop=True)
210
+ pred_df = pd.DataFrame({
211
+ "text": csv_df["text"].astype(str),
212
+ "true_label": [CLASS_NAMES[i] for i in labels],
213
+ "predicted_label": [CLASS_NAMES[i] for i in preds],
214
+ "correct": labels == preds,
215
+ })
216
+ pred_df.to_csv(results_dir / f"predictions_{split}.csv", index=False)
217
+ print(f" Predictions saved -> {results_dir / f'predictions_{split}.csv'}")
218
+
219
+ acc = accuracy_score(labels, preds)
220
+ report = classification_report(
221
+ labels, preds, target_names=CLASS_NAMES, output_dict=True
222
+ )
223
+
224
+ print(f" Accuracy: {acc:.4f}")
225
+ print(classification_report(labels, preds, target_names=CLASS_NAMES, digits=4))
226
+
227
+ plot_confusion_matrix(labels, preds, results_dir / f"confusion_matrix_{split}.png")
228
+ plot_roc_curves(labels, scores, cfg["num_classes"], results_dir / f"roc_curves_{split}.png")
229
+
230
+ aucs = {}
231
+ for cls in range(cfg["num_classes"]):
232
+ y_bin = (labels == cls).astype(int)
233
+ fpr, tpr, _ = roc_curve(y_bin, scores[:, cls])
234
+ aucs[CLASS_NAMES[cls]] = round(auc(fpr, tpr), 4)
235
+
236
+ return {
237
+ "split": split,
238
+ "accuracy": round(acc, 4),
239
+ "macro_f1": round(report["macro avg"]["f1-score"], 4),
240
+ "weighted_f1": round(report["weighted avg"]["f1-score"], 4),
241
+ "per_class": {
242
+ CLASS_NAMES[i]: {
243
+ "precision": round(report[CLASS_NAMES[i]]["precision"], 4),
244
+ "recall": round(report[CLASS_NAMES[i]]["recall"], 4),
245
+ "f1": round(report[CLASS_NAMES[i]]["f1-score"], 4),
246
+ }
247
+ for i in range(cfg["num_classes"])
248
+ },
249
+ "auc_per_class": aucs,
250
+ "mean_auc": round(np.mean(list(aucs.values())), 4),
251
+ }
252
+
253
+
254
+ def classify_text(text: str, encoder, prototypes, tokenizer, max_length, device) -> dict:
255
+ """Classify a single raw text string. Returns predicted class and similarity scores."""
256
+ enc = tokenizer(
257
+ emoji.demojize(text), padding="max_length", truncation=True,
258
+ max_length=max_length, return_tensors="pt",
259
+ )
260
+ with torch.no_grad():
261
+ emb = encoder(enc["input_ids"].to(device), enc["attention_mask"].to(device))
262
+ emb = emb.cpu().numpy()
263
+
264
+ sim = emb @ prototypes.T
265
+ scores = torch.softmax(torch.tensor(sim / 0.1), dim=-1).numpy()[0]
266
+ pred = int(sim.argmax())
267
+
268
+ return {
269
+ "predicted_class": pred,
270
+ "predicted_label": CLASS_NAMES[pred],
271
+ "scores": {CLASS_NAMES[i]: round(float(scores[i]), 4)
272
+ for i in range(len(CLASS_NAMES))},
273
+ }
274
+
275
+
276
+ def main():
277
+ cfg = load_config()
278
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
279
+ print(f"Device: {device}")
280
+
281
+ ckpt_dir = BASE_DIR / cfg["checkpoint_dir"]
282
+ results_dir = BASE_DIR / cfg["results_dir"]
283
+ results_dir.mkdir(parents=True, exist_ok=True)
284
+
285
+ # Load unsupervised encoder
286
+ encoder = load_unsupervised_encoder(ckpt_dir / "best_model.pt", cfg, device)
287
+
288
+ # Build prototypes from training set
289
+ print("\nBuilding class prototypes from training set...")
290
+ train_ds = load_dataset(BASE_DIR / cfg["prepared_dir"] / "train.pt")
291
+ train_embs, train_labels = get_embeddings(encoder, train_ds, cfg["batch_size"], device)
292
+ prototypes = build_prototypes(train_embs, train_labels, cfg["num_classes"])
293
+ np.save(results_dir / "prototypes.npy", prototypes)
294
+ print(f" Prototypes saved -> {results_dir / 'prototypes.npy'}")
295
+
296
+ # Evaluate val and test splits
297
+ all_results = []
298
+ for split in ["val", "test"]:
299
+ result = evaluate_split(encoder, prototypes, split, cfg, device, results_dir)
300
+ all_results.append(result)
301
+
302
+ # Save metrics
303
+ metrics_path = results_dir / "metrics.json"
304
+ with open(metrics_path, "w") as f:
305
+ json.dump(all_results, f, indent=2)
306
+ print(f"\nAll metrics saved -> {metrics_path}")
307
+
308
+ # Summary
309
+ print("\n=== SUMMARY ===")
310
+ for r in all_results:
311
+ print(f"{r['split']:6s} | acc={r['accuracy']:.4f} | macro_f1={r['macro_f1']:.4f} | mean_auc={r['mean_auc']:.4f}")
312
+
313
+ # Quick example inference
314
+ print("\n=== Example inference ===")
315
+ tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"], use_fast=False)
316
+ examples = [
317
+ "I had a great day today, went for a walk in the park.",
318
+ "I'm going to k!ll that n!gga if he shows up again.",
319
+ "she posted an onlyfans link in her bio",
320
+ "gonna unalive myself fr fr cant take this anymore",
321
+ ]
322
+ for text in examples:
323
+ result = classify_text(text, encoder, prototypes, tokenizer, cfg["max_length"], device)
324
+ print(f" [{result['predicted_label']}] {text[:70]}")
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
poc/src/model.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model.py
3
+
4
+ Dual BERTweet architecture for algospeak content moderation.
5
+
6
+ Two independent BERTweet encoders trained jointly with supervised InfoNCE loss:
7
+ - supervised encoder: receives "[CLASS_LABEL]: text" — class-aware during training
8
+ - unsupervised encoder: receives raw text only — the inference model
9
+
10
+ At inference, only the unsupervised encoder is used. Its embeddings are compared
11
+ to class prototypes (built from training data) via cosine similarity.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers import AutoModel
18
+
19
+
20
+ class BERTweetEncoder(nn.Module):
21
+ """
22
+ Wraps vinai/bertweet-base and returns an L2-normalized CLS token embedding.
23
+ """
24
+
25
+ def __init__(self, model_name: str):
26
+ super().__init__()
27
+ self.bert = AutoModel.from_pretrained(model_name)
28
+
29
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
30
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
31
+ cls_emb = outputs.last_hidden_state[:, 0, :] # [B, 768]
32
+ return F.normalize(cls_emb, dim=-1) # L2 normalize -> cosine-ready
33
+
34
+
35
+ class DualEncoderModel(nn.Module):
36
+ """
37
+ Two independent BERTweet encoders trained with supervised InfoNCE loss.
38
+
39
+ supervised encoder:
40
+ Input: "[CLASS_LABEL]: <text>" (e.g. "Offensive Language: I hate you")
41
+ Produces class-aware embeddings during training.
42
+ Discarded after training.
43
+
44
+ unsupervised encoder:
45
+ Input: raw text
46
+ Trained (via InfoNCE) to match the supervised encoder's embedding space.
47
+ Used exclusively at inference.
48
+ """
49
+
50
+ def __init__(self, model_name: str, temperature: float):
51
+ super().__init__()
52
+ self.supervised = BERTweetEncoder(model_name)
53
+ self.unsupervised = BERTweetEncoder(model_name)
54
+ self.temperature = temperature
55
+
56
+ def forward(
57
+ self,
58
+ sup_ids: torch.Tensor,
59
+ sup_mask: torch.Tensor,
60
+ unsup_ids: torch.Tensor,
61
+ unsup_mask: torch.Tensor,
62
+ labels: torch.Tensor,
63
+ ):
64
+ e_s = self.supervised(sup_ids, sup_mask) # [B, D]
65
+ e_u = self.unsupervised(unsup_ids, unsup_mask) # [B, D]
66
+ loss = supervised_infonce_loss(e_s, e_u, labels, self.temperature)
67
+ return loss, e_s, e_u
68
+
69
+
70
+ def supervised_infonce_loss(
71
+ e_s: torch.Tensor,
72
+ e_u: torch.Tensor,
73
+ labels: torch.Tensor,
74
+ temperature: float,
75
+ ) -> torch.Tensor:
76
+ """
77
+ Cross-encoder supervised InfoNCE loss.
78
+
79
+ For each unsupervised embedding e_u_i:
80
+ Positives: all supervised embeddings e_s_j where label_j == label_i
81
+ Negatives: all supervised embeddings e_s_j where label_j != label_i
82
+
83
+ Loss = mean_i [ -log( sum_{j: pos} exp(sim_ij/τ) / sum_j exp(sim_ij/τ) ) ]
84
+
85
+ Both e_s and e_u are L2-normalized so sim = dot product = cosine similarity.
86
+
87
+ Args:
88
+ e_s: [B, D] supervised encoder embeddings
89
+ e_u: [B, D] unsupervised encoder embeddings
90
+ labels: [B] integer class labels
91
+ temperature: scalar τ (typically 0.07)
92
+
93
+ Returns:
94
+ Scalar loss.
95
+ """
96
+ # Similarity matrix: unsupervised queries supervised keys — [B, B]
97
+ sim = torch.mm(e_u, e_s.T) / temperature
98
+
99
+ # Positive mask: True where label_j == label_i — [B, B]
100
+ pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
101
+
102
+ # Numerical stability: subtract row max before exp
103
+ sim_max, _ = sim.max(dim=1, keepdim=True)
104
+ sim = sim - sim_max.detach()
105
+
106
+ exp_sim = torch.exp(sim)
107
+ pos_sum = (exp_sim * pos_mask).sum(dim=1) # [B]
108
+ all_sum = exp_sim.sum(dim=1) # [B]
109
+
110
+ # Skip samples with no positives in this batch (shouldn't happen at batch_size >= num_classes)
111
+ valid = pos_sum > 0
112
+ if not valid.any():
113
+ return torch.tensor(0.0, requires_grad=True, device=e_s.device)
114
+
115
+ loss = -torch.log(pos_sum[valid] / all_sum[valid])
116
+ return loss.mean()
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.57.1
3
+ streamlit>=1.56.0
4
+ numpy
5
+ pyyaml>=6.0.3
6
+ emoji==0.6.0
7
+ scikit-learn>=1.8.0
8
+ sentencepiece
9
+ huggingface_hub