Add full_experiments.py
Browse files- code/full_experiments.py +540 -0
code/full_experiments.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
COMPREHENSIVE EXPERIMENT SUITE for IEEE paper:
|
| 3 |
+
1) Fine-tune roberta-base on TweetEval (3 seeds)
|
| 4 |
+
2) Evaluate all models with multi-seed where applicable
|
| 5 |
+
3) Ablation: learning rate sweep
|
| 6 |
+
4) McNemar test between all model pairs
|
| 7 |
+
5) Generate all figures including training curves
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
torch.set_num_threads(2)
|
| 12 |
+
|
| 13 |
+
import os, json, gc, time, sys
|
| 14 |
+
import numpy as np
|
| 15 |
+
from collections import Counter, defaultdict
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
from transformers import (
|
| 18 |
+
AutoTokenizer, AutoModelForSequenceClassification,
|
| 19 |
+
TrainingArguments, Trainer, DataCollatorWithPadding,
|
| 20 |
+
EarlyStoppingCallback, pipeline
|
| 21 |
+
)
|
| 22 |
+
from sklearn.metrics import (
|
| 23 |
+
recall_score, f1_score, accuracy_score, precision_score,
|
| 24 |
+
confusion_matrix, classification_report
|
| 25 |
+
)
|
| 26 |
+
from scipy.stats import chi2
|
| 27 |
+
import matplotlib
|
| 28 |
+
matplotlib.use('Agg')
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
|
| 31 |
+
os.makedirs("/app/figures_v2", exist_ok=True)
|
| 32 |
+
os.makedirs("/app/training_logs", exist_ok=True)
|
| 33 |
+
plt.rcParams['font.family'] = 'serif'
|
| 34 |
+
plt.rcParams['font.size'] = 10
|
| 35 |
+
|
| 36 |
+
def log(msg):
|
| 37 |
+
print(msg, flush=True)
|
| 38 |
+
|
| 39 |
+
log("="*70)
|
| 40 |
+
log("IEEE PAPER: COMPREHENSIVE EXPERIMENT SUITE")
|
| 41 |
+
log("="*70)
|
| 42 |
+
|
| 43 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
# 1. LOAD DATA
|
| 45 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
log("\n[1/7] Loading datasets...")
|
| 47 |
+
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
|
| 48 |
+
sst2 = load_dataset("stanfordnlp/sst2")
|
| 49 |
+
|
| 50 |
+
label_names = ['Negative', 'Neutral', 'Positive']
|
| 51 |
+
log(f" TweetEval: train={len(tweeteval['train'])}, val={len(tweeteval['validation'])}, test={len(tweeteval['test'])}")
|
| 52 |
+
log(f" SST-2: train={len(sst2['train'])}, val={len(sst2['validation'])}")
|
| 53 |
+
log(f" TweetEval test distribution: {Counter(tweeteval['test']['label'])}")
|
| 54 |
+
|
| 55 |
+
def preprocess_tweet(text):
|
| 56 |
+
if not text: return ""
|
| 57 |
+
return " ".join(
|
| 58 |
+
'@user' if w.startswith('@') and len(w)>1 else ('http' if w.startswith('http') else w)
|
| 59 |
+
for w in text.split()
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
+
# 2. FINE-TUNE ROBERTA-BASE ON TWEETEVAL (3 seeds)
|
| 64 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
+
log("\n[2/7] Fine-tuning roberta-base on TweetEval (3 seeds)...")
|
| 66 |
+
|
| 67 |
+
MODEL_FT = "roberta-base"
|
| 68 |
+
tok_ft = AutoTokenizer.from_pretrained(MODEL_FT)
|
| 69 |
+
|
| 70 |
+
def tokenize_tweeteval(examples):
|
| 71 |
+
texts = [preprocess_tweet(t) for t in examples['text']]
|
| 72 |
+
return tok_ft(texts, truncation=True, max_length=128, padding=False)
|
| 73 |
+
|
| 74 |
+
te_enc = tweeteval.map(tokenize_tweeteval, batched=True, remove_columns=['text'])
|
| 75 |
+
collator_ft = DataCollatorWithPadding(tokenizer=tok_ft)
|
| 76 |
+
|
| 77 |
+
def compute_metrics_ft(eval_pred):
|
| 78 |
+
logits, labels = eval_pred
|
| 79 |
+
preds = np.argmax(logits, axis=-1)
|
| 80 |
+
return {
|
| 81 |
+
'macro_recall': recall_score(labels, preds, average='macro'),
|
| 82 |
+
'macro_f1': f1_score(labels, preds, average='macro'),
|
| 83 |
+
'accuracy': accuracy_score(labels, preds),
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
seed_results_roberta = []
|
| 87 |
+
seed_preds_roberta = []
|
| 88 |
+
training_histories = []
|
| 89 |
+
|
| 90 |
+
for seed in [1, 2, 3]:
|
| 91 |
+
log(f"\n --- Seed {seed}/3 ---")
|
| 92 |
+
model_ft = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3)
|
| 93 |
+
|
| 94 |
+
args = TrainingArguments(
|
| 95 |
+
output_dir=f'/app/ft_roberta_seed{seed}',
|
| 96 |
+
num_train_epochs=2,
|
| 97 |
+
per_device_train_batch_size=16,
|
| 98 |
+
per_device_eval_batch_size=64,
|
| 99 |
+
learning_rate=1e-5,
|
| 100 |
+
weight_decay=0.01,
|
| 101 |
+
warmup_ratio=0.1,
|
| 102 |
+
lr_scheduler_type='linear',
|
| 103 |
+
eval_strategy='epoch',
|
| 104 |
+
save_strategy='epoch',
|
| 105 |
+
logging_strategy='steps',
|
| 106 |
+
logging_steps=100,
|
| 107 |
+
logging_first_step=True,
|
| 108 |
+
disable_tqdm=True,
|
| 109 |
+
load_best_model_at_end=True,
|
| 110 |
+
metric_for_best_model='macro_recall',
|
| 111 |
+
greater_is_better=True,
|
| 112 |
+
seed=seed,
|
| 113 |
+
dataloader_num_workers=0,
|
| 114 |
+
fp16=False,
|
| 115 |
+
gradient_checkpointing=False,
|
| 116 |
+
report_to='none',
|
| 117 |
+
save_total_limit=1,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
trainer = Trainer(
|
| 121 |
+
model=model_ft, args=args,
|
| 122 |
+
train_dataset=te_enc['train'],
|
| 123 |
+
eval_dataset=te_enc['validation'],
|
| 124 |
+
data_collator=collator_ft,
|
| 125 |
+
compute_metrics=compute_metrics_ft,
|
| 126 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
t0 = time.time()
|
| 130 |
+
result = trainer.train()
|
| 131 |
+
train_time = time.time() - t0
|
| 132 |
+
|
| 133 |
+
# Save training history
|
| 134 |
+
history = trainer.state.log_history
|
| 135 |
+
training_histories.append(history)
|
| 136 |
+
|
| 137 |
+
# Evaluate on test set
|
| 138 |
+
test_metrics = trainer.evaluate(te_enc['test'])
|
| 139 |
+
log(f" Seed {seed}: Macro-Recall={test_metrics['eval_macro_recall']:.4f}, "
|
| 140 |
+
f"Macro-F1={test_metrics['eval_macro_f1']:.4f}, Acc={test_metrics['eval_accuracy']:.4f}, "
|
| 141 |
+
f"Time={train_time/60:.1f}min")
|
| 142 |
+
|
| 143 |
+
# Get predictions for McNemar
|
| 144 |
+
preds_output = trainer.predict(te_enc['test'])
|
| 145 |
+
preds = np.argmax(preds_output.predictions, axis=-1)
|
| 146 |
+
seed_preds_roberta.append(preds.tolist())
|
| 147 |
+
|
| 148 |
+
seed_results_roberta.append({
|
| 149 |
+
'seed': seed,
|
| 150 |
+
'macro_recall': test_metrics['eval_macro_recall'],
|
| 151 |
+
'macro_f1': test_metrics['eval_macro_f1'],
|
| 152 |
+
'accuracy': test_metrics['eval_accuracy'],
|
| 153 |
+
'train_time_min': round(train_time/60, 1),
|
| 154 |
+
'train_loss': result.training_loss,
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
del model_ft, trainer
|
| 158 |
+
gc.collect()
|
| 159 |
+
|
| 160 |
+
# Aggregate multi-seed results
|
| 161 |
+
recalls = [r['macro_recall'] for r in seed_results_roberta]
|
| 162 |
+
f1s = [r['macro_f1'] for r in seed_results_roberta]
|
| 163 |
+
accs = [r['accuracy'] for r in seed_results_roberta]
|
| 164 |
+
|
| 165 |
+
log(f"\n RoBERTa-base (3 seeds): Macro-Recall={np.mean(recalls):.4f}+/-{np.std(recalls):.4f}, "
|
| 166 |
+
f"Macro-F1={np.mean(f1s):.4f}+/-{np.std(f1s):.4f}, Acc={np.mean(accs):.4f}+/-{np.std(accs):.4f}")
|
| 167 |
+
|
| 168 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
# 3. EVALUATE PRE-TRAINED MODELS (multi-pass for consistency)
|
| 170 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
+
log("\n[3/7] Evaluating pre-trained models on TweetEval...")
|
| 172 |
+
|
| 173 |
+
te_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
|
| 174 |
+
te_test_labels = list(tweeteval['test']['label'])
|
| 175 |
+
sst2_texts = list(sst2['validation']['sentence'])
|
| 176 |
+
sst2_labels = list(sst2['validation']['label'])
|
| 177 |
+
|
| 178 |
+
all_model_preds = {}
|
| 179 |
+
all_model_metrics = {}
|
| 180 |
+
|
| 181 |
+
# Twitter-RoBERTa (already fine-tuned, 3-class)
|
| 182 |
+
log(" Twitter-RoBERTa...")
|
| 183 |
+
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest",
|
| 184 |
+
device=-1, batch_size=32, top_k=None)
|
| 185 |
+
preds_twr = []
|
| 186 |
+
for out in pipe(te_test_texts, truncation=True, max_length=128):
|
| 187 |
+
scores = {r['label'].lower(): r['score'] for r in out}
|
| 188 |
+
label_scores = [scores.get('negative',0), scores.get('neutral',0), scores.get('positive',0)]
|
| 189 |
+
preds_twr.append(np.argmax(label_scores))
|
| 190 |
+
|
| 191 |
+
mr_twr = recall_score(te_test_labels, preds_twr, average='macro')
|
| 192 |
+
mf1_twr = f1_score(te_test_labels, preds_twr, average='macro')
|
| 193 |
+
acc_twr = accuracy_score(te_test_labels, preds_twr)
|
| 194 |
+
log(f" Twitter-RoBERTa: MR={mr_twr:.4f}, MF1={mf1_twr:.4f}, Acc={acc_twr:.4f}")
|
| 195 |
+
all_model_preds['Twitter-RoBERTa'] = preds_twr
|
| 196 |
+
all_model_metrics['Twitter-RoBERTa'] = {'macro_recall': mr_twr, 'macro_f1': mf1_twr, 'accuracy': acc_twr, 'params': '125M'}
|
| 197 |
+
|
| 198 |
+
# SST-2 eval for Twitter-RoBERTa
|
| 199 |
+
preds_twr_sst2 = []
|
| 200 |
+
for out in pipe(sst2_texts, truncation=True, max_length=128):
|
| 201 |
+
scores = {r['label'].lower(): r['score'] for r in out}
|
| 202 |
+
preds_twr_sst2.append(1 if scores.get('positive',0) > scores.get('negative',0) else 0)
|
| 203 |
+
acc_twr_sst2 = accuracy_score(sst2_labels, preds_twr_sst2)
|
| 204 |
+
all_model_metrics['Twitter-RoBERTa']['sst2_accuracy'] = acc_twr_sst2
|
| 205 |
+
del pipe; gc.collect()
|
| 206 |
+
|
| 207 |
+
# DeBERTa-v3-base (binary, SST-2 fine-tuned)
|
| 208 |
+
log(" DeBERTa-v3-base...")
|
| 209 |
+
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
|
| 210 |
+
preds_deb = []
|
| 211 |
+
for out in pipe(te_test_texts, truncation=True, max_length=128):
|
| 212 |
+
preds_deb.append(2 if out['label'].lower() == 'positive' else 0)
|
| 213 |
+
mr_deb = recall_score(te_test_labels, preds_deb, average='macro')
|
| 214 |
+
mf1_deb = f1_score(te_test_labels, preds_deb, average='macro', zero_division=0)
|
| 215 |
+
acc_deb = accuracy_score(te_test_labels, preds_deb)
|
| 216 |
+
log(f" DeBERTa-v3: MR={mr_deb:.4f}, MF1={mf1_deb:.4f}, Acc={acc_deb:.4f}")
|
| 217 |
+
all_model_preds['DeBERTa-v3-base'] = preds_deb
|
| 218 |
+
all_model_metrics['DeBERTa-v3-base'] = {'macro_recall': mr_deb, 'macro_f1': mf1_deb, 'accuracy': acc_deb, 'params': '184M'}
|
| 219 |
+
|
| 220 |
+
preds_deb_sst2 = []
|
| 221 |
+
for out in pipe(sst2_texts, truncation=True, max_length=128):
|
| 222 |
+
preds_deb_sst2.append(1 if out['label'].lower() == 'positive' else 0)
|
| 223 |
+
all_model_metrics['DeBERTa-v3-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_deb_sst2)
|
| 224 |
+
del pipe; gc.collect()
|
| 225 |
+
|
| 226 |
+
# BERT-base
|
| 227 |
+
log(" BERT-base...")
|
| 228 |
+
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32)
|
| 229 |
+
preds_bert = []
|
| 230 |
+
for out in pipe(te_test_texts, truncation=True, max_length=128):
|
| 231 |
+
preds_bert.append(2 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0)
|
| 232 |
+
mr_bert = recall_score(te_test_labels, preds_bert, average='macro')
|
| 233 |
+
mf1_bert = f1_score(te_test_labels, preds_bert, average='macro', zero_division=0)
|
| 234 |
+
acc_bert = accuracy_score(te_test_labels, preds_bert)
|
| 235 |
+
all_model_preds['BERT-base'] = preds_bert
|
| 236 |
+
all_model_metrics['BERT-base'] = {'macro_recall': mr_bert, 'macro_f1': mf1_bert, 'accuracy': acc_bert, 'params': '110M'}
|
| 237 |
+
|
| 238 |
+
preds_bert_sst2 = []
|
| 239 |
+
for out in pipe(sst2_texts, truncation=True, max_length=128):
|
| 240 |
+
preds_bert_sst2.append(1 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0)
|
| 241 |
+
all_model_metrics['BERT-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_bert_sst2)
|
| 242 |
+
del pipe; gc.collect()
|
| 243 |
+
|
| 244 |
+
# DistilBERT
|
| 245 |
+
log(" DistilBERT...")
|
| 246 |
+
pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
| 247 |
+
device=-1, batch_size=32)
|
| 248 |
+
preds_distil = []
|
| 249 |
+
for out in pipe(te_test_texts, truncation=True, max_length=128):
|
| 250 |
+
preds_distil.append(2 if out['label'] == 'POSITIVE' else 0)
|
| 251 |
+
mr_dist = recall_score(te_test_labels, preds_distil, average='macro')
|
| 252 |
+
mf1_dist = f1_score(te_test_labels, preds_distil, average='macro', zero_division=0)
|
| 253 |
+
acc_dist = accuracy_score(te_test_labels, preds_distil)
|
| 254 |
+
all_model_preds['DistilBERT'] = preds_distil
|
| 255 |
+
all_model_metrics['DistilBERT'] = {'macro_recall': mr_dist, 'macro_f1': mf1_dist, 'accuracy': acc_dist, 'params': '66M'}
|
| 256 |
+
|
| 257 |
+
preds_dist_sst2 = []
|
| 258 |
+
for out in pipe(sst2_texts, truncation=True, max_length=128):
|
| 259 |
+
preds_dist_sst2.append(1 if out['label'] == 'POSITIVE' else 0)
|
| 260 |
+
all_model_metrics['DistilBERT']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_dist_sst2)
|
| 261 |
+
del pipe; gc.collect()
|
| 262 |
+
|
| 263 |
+
# Add fine-tuned RoBERTa (use best seed predictions)
|
| 264 |
+
best_seed_idx = np.argmax(recalls)
|
| 265 |
+
all_model_preds['RoBERTa-FT'] = seed_preds_roberta[best_seed_idx]
|
| 266 |
+
all_model_metrics['RoBERTa-FT'] = {
|
| 267 |
+
'macro_recall': np.mean(recalls), 'macro_recall_std': np.std(recalls),
|
| 268 |
+
'macro_f1': np.mean(f1s), 'macro_f1_std': np.std(f1s),
|
| 269 |
+
'accuracy': np.mean(accs), 'accuracy_std': np.std(accs),
|
| 270 |
+
'params': '125M',
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 274 |
+
# 4. ABLATION: Learning Rate Sweep
|
| 275 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 276 |
+
log("\n[4/7] Ablation: Learning Rate Sweep (roberta-base, seed=1)...")
|
| 277 |
+
|
| 278 |
+
ablation_lrs = [5e-6, 1e-5, 3e-5]
|
| 279 |
+
ablation_results = []
|
| 280 |
+
|
| 281 |
+
for lr in ablation_lrs:
|
| 282 |
+
log(f" LR={lr}...")
|
| 283 |
+
model_ab = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3)
|
| 284 |
+
args_ab = TrainingArguments(
|
| 285 |
+
output_dir=f'/app/ablation_lr_{lr}',
|
| 286 |
+
num_train_epochs=2,
|
| 287 |
+
per_device_train_batch_size=16,
|
| 288 |
+
per_device_eval_batch_size=64,
|
| 289 |
+
learning_rate=lr,
|
| 290 |
+
weight_decay=0.01,
|
| 291 |
+
warmup_ratio=0.1,
|
| 292 |
+
eval_strategy='epoch',
|
| 293 |
+
save_strategy='epoch',
|
| 294 |
+
disable_tqdm=True,
|
| 295 |
+
load_best_model_at_end=True,
|
| 296 |
+
metric_for_best_model='macro_recall',
|
| 297 |
+
greater_is_better=True,
|
| 298 |
+
seed=1,
|
| 299 |
+
dataloader_num_workers=0,
|
| 300 |
+
fp16=False,
|
| 301 |
+
report_to='none',
|
| 302 |
+
save_total_limit=1,
|
| 303 |
+
logging_strategy='steps',
|
| 304 |
+
logging_steps=500,
|
| 305 |
+
logging_first_step=True,
|
| 306 |
+
)
|
| 307 |
+
trainer_ab = Trainer(
|
| 308 |
+
model=model_ab, args=args_ab,
|
| 309 |
+
train_dataset=te_enc['train'],
|
| 310 |
+
eval_dataset=te_enc['validation'],
|
| 311 |
+
data_collator=collator_ft,
|
| 312 |
+
compute_metrics=compute_metrics_ft,
|
| 313 |
+
)
|
| 314 |
+
trainer_ab.train()
|
| 315 |
+
test_ab = trainer_ab.evaluate(te_enc['test'])
|
| 316 |
+
ablation_results.append({
|
| 317 |
+
'lr': lr,
|
| 318 |
+
'macro_recall': test_ab['eval_macro_recall'],
|
| 319 |
+
'macro_f1': test_ab['eval_macro_f1'],
|
| 320 |
+
'accuracy': test_ab['eval_accuracy'],
|
| 321 |
+
})
|
| 322 |
+
log(f" LR={lr}: MR={test_ab['eval_macro_recall']:.4f}, MF1={test_ab['eval_macro_f1']:.4f}")
|
| 323 |
+
del model_ab, trainer_ab; gc.collect()
|
| 324 |
+
|
| 325 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 326 |
+
# 5. McNEMAR TESTS
|
| 327 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 328 |
+
log("\n[5/7] McNemar Statistical Significance Tests...")
|
| 329 |
+
|
| 330 |
+
def mcnemar_test(y_true, y_pred1, y_pred2):
|
| 331 |
+
b = sum(1 for yt,p1,p2 in zip(y_true,y_pred1,y_pred2) if p1==yt and p2!=yt)
|
| 332 |
+
c = sum(1 for yt,p1,p2 in zip(y_true,y_pred1,y_pred2) if p1!=yt and p2==yt)
|
| 333 |
+
if b + c == 0: return 0.0, 1.0
|
| 334 |
+
stat = (abs(b - c) - 1)**2 / (b + c)
|
| 335 |
+
p_val = 1 - chi2.cdf(stat, df=1)
|
| 336 |
+
return stat, p_val
|
| 337 |
+
|
| 338 |
+
mcnemar_results = {}
|
| 339 |
+
model_names = list(all_model_preds.keys())
|
| 340 |
+
for i in range(len(model_names)):
|
| 341 |
+
for j in range(i+1, len(model_names)):
|
| 342 |
+
m1, m2 = model_names[i], model_names[j]
|
| 343 |
+
stat, p = mcnemar_test(te_test_labels, all_model_preds[m1], all_model_preds[m2])
|
| 344 |
+
sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else "ns"))
|
| 345 |
+
mcnemar_results[f"{m1} vs {m2}"] = {'statistic': round(stat, 2), 'p_value': round(p, 6), 'significance': sig}
|
| 346 |
+
log(f" {m1} vs {m2}: chi2={stat:.2f}, p={p:.6f} {sig}")
|
| 347 |
+
|
| 348 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 349 |
+
# 6. GENERATE ALL FIGURES
|
| 350 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 351 |
+
log("\n[6/7] Generating figures...")
|
| 352 |
+
|
| 353 |
+
# Fig 1: Training curves (from seed 1)
|
| 354 |
+
log(" Fig 1: Training loss curves...")
|
| 355 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
|
| 356 |
+
for sidx, hist in enumerate(training_histories):
|
| 357 |
+
train_losses = [(h['step'], h['loss']) for h in hist if 'loss' in h and 'eval_loss' not in h]
|
| 358 |
+
eval_losses = [(h['epoch'], h['eval_loss']) for h in hist if 'eval_loss' in h]
|
| 359 |
+
eval_recalls = [(h['epoch'], h['eval_macro_recall']) for h in hist if 'eval_macro_recall' in h]
|
| 360 |
+
|
| 361 |
+
if train_losses:
|
| 362 |
+
steps, losses = zip(*train_losses)
|
| 363 |
+
ax1.plot(steps, losses, alpha=0.7, label=f'Seed {sidx+1}')
|
| 364 |
+
if eval_recalls:
|
| 365 |
+
epochs, recs = zip(*eval_recalls)
|
| 366 |
+
ax2.plot(epochs, [r*100 for r in recs], 'o-', alpha=0.8, label=f'Seed {sidx+1}')
|
| 367 |
+
|
| 368 |
+
ax1.set_xlabel('Training Step'); ax1.set_ylabel('Loss')
|
| 369 |
+
ax1.set_title('(a) Training Loss', fontweight='bold')
|
| 370 |
+
ax1.legend(); ax1.grid(alpha=0.3)
|
| 371 |
+
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Macro-Recall (%)')
|
| 372 |
+
ax2.set_title('(b) Validation Macro-Recall', fontweight='bold')
|
| 373 |
+
ax2.legend(); ax2.grid(alpha=0.3)
|
| 374 |
+
plt.tight_layout()
|
| 375 |
+
plt.savefig('/app/figures_v2/fig_training_curves.png', dpi=300, bbox_inches='tight')
|
| 376 |
+
plt.savefig('/app/figures_v2/fig_training_curves.pdf', bbox_inches='tight')
|
| 377 |
+
plt.close()
|
| 378 |
+
|
| 379 |
+
# Fig 2: Confusion matrices (2x2 grid)
|
| 380 |
+
log(" Fig 2: Confusion matrices...")
|
| 381 |
+
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
|
| 382 |
+
models_cm = [('Twitter-RoBERTa', preds_twr), ('RoBERTa-FT (Ours)', seed_preds_roberta[best_seed_idx]), ('DeBERTa-v3', preds_deb)]
|
| 383 |
+
cmaps = ['Blues', 'Greens', 'Oranges']
|
| 384 |
+
for idx, (name, preds) in enumerate(models_cm):
|
| 385 |
+
cm = confusion_matrix(te_test_labels, preds, normalize='true')
|
| 386 |
+
im = axes[idx].imshow(cm, interpolation='nearest', cmap=cmaps[idx], vmin=0, vmax=1)
|
| 387 |
+
axes[idx].set(xticks=range(3), yticks=range(3),
|
| 388 |
+
xticklabels=['Neg','Neu','Pos'], yticklabels=['Neg','Neu','Pos'],
|
| 389 |
+
title=name, ylabel='True' if idx==0 else '', xlabel='Predicted')
|
| 390 |
+
for i in range(3):
|
| 391 |
+
for j in range(3):
|
| 392 |
+
axes[idx].text(j, i, f'{cm[i,j]:.2f}', ha='center', va='center',
|
| 393 |
+
color='white' if cm[i,j]>0.5 else 'black', fontsize=11)
|
| 394 |
+
plt.tight_layout()
|
| 395 |
+
plt.savefig('/app/figures_v2/fig_confusion_matrices.png', dpi=300, bbox_inches='tight')
|
| 396 |
+
plt.savefig('/app/figures_v2/fig_confusion_matrices.pdf', bbox_inches='tight')
|
| 397 |
+
plt.close()
|
| 398 |
+
|
| 399 |
+
# Fig 3: Model comparison bar chart
|
| 400 |
+
log(" Fig 3: Model comparison...")
|
| 401 |
+
models_plot = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)']
|
| 402 |
+
sst2_vals = [
|
| 403 |
+
all_model_metrics['DistilBERT'].get('sst2_accuracy',0)*100,
|
| 404 |
+
all_model_metrics['BERT-base'].get('sst2_accuracy',0)*100,
|
| 405 |
+
np.mean(accs)*100, # RoBERTa-FT doesn't have SST-2 eval, use N/A
|
| 406 |
+
all_model_metrics['Twitter-RoBERTa'].get('sst2_accuracy',0)*100,
|
| 407 |
+
all_model_metrics['DeBERTa-v3-base'].get('sst2_accuracy',0)*100,
|
| 408 |
+
]
|
| 409 |
+
tweet_vals = [
|
| 410 |
+
all_model_metrics['DistilBERT']['macro_recall']*100,
|
| 411 |
+
all_model_metrics['BERT-base']['macro_recall']*100,
|
| 412 |
+
np.mean(recalls)*100,
|
| 413 |
+
all_model_metrics['Twitter-RoBERTa']['macro_recall']*100,
|
| 414 |
+
all_model_metrics['DeBERTa-v3-base']['macro_recall']*100,
|
| 415 |
+
]
|
| 416 |
+
tweet_stds = [0, 0, np.std(recalls)*100, 0, 0]
|
| 417 |
+
|
| 418 |
+
x = np.arange(len(models_plot)); width = 0.35
|
| 419 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 420 |
+
bars1 = ax.bar(x - width/2, sst2_vals, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5)
|
| 421 |
+
bars2 = ax.bar(x + width/2, tweet_vals, width, yerr=tweet_stds, capsize=3,
|
| 422 |
+
label='TweetEval Macro-Recall (%)', color='#FF9800', edgecolor='black', linewidth=0.5)
|
| 423 |
+
ax.set_ylabel('Score (%)'); ax.set_xticks(x); ax.set_xticklabels(models_plot, fontsize=8)
|
| 424 |
+
ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontweight='bold')
|
| 425 |
+
ax.legend(loc='upper left'); ax.set_ylim(0, 105); ax.grid(axis='y', alpha=0.3)
|
| 426 |
+
ax.axhline(y=95, color='red', linestyle='--', alpha=0.5); ax.text(4.3, 95.5, '95%', color='red', fontsize=7)
|
| 427 |
+
for bar in bars1+bars2:
|
| 428 |
+
h = bar.get_height()
|
| 429 |
+
if h > 0: ax.text(bar.get_x()+bar.get_width()/2., h+1, f'{h:.1f}', ha='center', va='bottom', fontsize=7)
|
| 430 |
+
plt.tight_layout()
|
| 431 |
+
plt.savefig('/app/figures_v2/fig_model_comparison.png', dpi=300, bbox_inches='tight')
|
| 432 |
+
plt.savefig('/app/figures_v2/fig_model_comparison.pdf', bbox_inches='tight')
|
| 433 |
+
plt.close()
|
| 434 |
+
|
| 435 |
+
# Fig 4: Ablation - Learning Rate
|
| 436 |
+
log(" Fig 4: Learning rate ablation...")
|
| 437 |
+
fig, ax = plt.subplots(figsize=(7, 4))
|
| 438 |
+
lrs = [r['lr'] for r in ablation_results]
|
| 439 |
+
mr_vals = [r['macro_recall']*100 for r in ablation_results]
|
| 440 |
+
mf1_vals = [r['macro_f1']*100 for r in ablation_results]
|
| 441 |
+
ax.plot(range(len(lrs)), mr_vals, 'o-', color='#2196F3', label='Macro-Recall', linewidth=2, markersize=8)
|
| 442 |
+
ax.plot(range(len(lrs)), mf1_vals, 's--', color='#FF9800', label='Macro-F1', linewidth=2, markersize=8)
|
| 443 |
+
ax.set_xticks(range(len(lrs))); ax.set_xticklabels([f'{lr:.0e}' for lr in lrs])
|
| 444 |
+
ax.set_xlabel('Learning Rate'); ax.set_ylabel('Score (%)')
|
| 445 |
+
ax.set_title('Ablation: Learning Rate Sensitivity (RoBERTa-base)', fontweight='bold')
|
| 446 |
+
ax.legend(); ax.grid(alpha=0.3)
|
| 447 |
+
for i, (mr, mf) in enumerate(zip(mr_vals, mf1_vals)):
|
| 448 |
+
ax.annotate(f'{mr:.1f}', (i, mr), textcoords="offset points", xytext=(0,8), ha='center', fontsize=7)
|
| 449 |
+
plt.tight_layout()
|
| 450 |
+
plt.savefig('/app/figures_v2/fig_lr_ablation.png', dpi=300, bbox_inches='tight')
|
| 451 |
+
plt.savefig('/app/figures_v2/fig_lr_ablation.pdf', bbox_inches='tight')
|
| 452 |
+
plt.close()
|
| 453 |
+
|
| 454 |
+
# Fig 5: Per-class F1 comparison (3 models)
|
| 455 |
+
log(" Fig 5: Per-class F1...")
|
| 456 |
+
fig, ax = plt.subplots(figsize=(8, 4.5))
|
| 457 |
+
classes = ['Negative', 'Neutral', 'Positive']
|
| 458 |
+
x = np.arange(3); w = 0.25
|
| 459 |
+
for idx, (name, preds) in enumerate(models_cm):
|
| 460 |
+
report = classification_report(te_test_labels, preds, output_dict=True, zero_division=0)
|
| 461 |
+
f1s_class = [report[c]['f1-score']*100 for c in classes]
|
| 462 |
+
bars = ax.bar(x + idx*w, f1s_class, w, label=name, edgecolor='black', linewidth=0.5)
|
| 463 |
+
for bar, v in zip(bars, f1s_class):
|
| 464 |
+
ax.text(bar.get_x()+bar.get_width()/2., bar.get_height()+1, f'{v:.1f}', ha='center', fontsize=7)
|
| 465 |
+
ax.set_xticks(x + w); ax.set_xticklabels(classes)
|
| 466 |
+
ax.set_ylabel('F1-Score (%)'); ax.set_title('Per-Class F1 on TweetEval', fontweight='bold')
|
| 467 |
+
ax.legend(fontsize=9); ax.set_ylim(0, 100); ax.grid(axis='y', alpha=0.3)
|
| 468 |
+
plt.tight_layout()
|
| 469 |
+
plt.savefig('/app/figures_v2/fig_per_class_f1.png', dpi=300, bbox_inches='tight')
|
| 470 |
+
plt.savefig('/app/figures_v2/fig_per_class_f1.pdf', bbox_inches='tight')
|
| 471 |
+
plt.close()
|
| 472 |
+
|
| 473 |
+
# Fig 6: Data distribution
|
| 474 |
+
log(" Fig 6: Dataset distribution...")
|
| 475 |
+
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
|
| 476 |
+
te_counts = Counter(te_test_labels)
|
| 477 |
+
axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]],
|
| 478 |
+
color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5)
|
| 479 |
+
axes[0].set_title('TweetEval Test (n=12,284)', fontweight='bold'); axes[0].set_ylabel('Count')
|
| 480 |
+
for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]):
|
| 481 |
+
axes[0].text(i, v+50, f'{v}\n({v/len(te_test_labels)*100:.1f}%)', ha='center', fontsize=8)
|
| 482 |
+
sst2_counts = Counter(sst2_labels)
|
| 483 |
+
axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]],
|
| 484 |
+
color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5)
|
| 485 |
+
axes[1].set_title('SST-2 Validation (n=872)', fontweight='bold'); axes[1].set_ylabel('Count')
|
| 486 |
+
for i, v in enumerate([sst2_counts[0], sst2_counts[1]]):
|
| 487 |
+
axes[1].text(i, v+5, f'{v}\n({v/len(sst2_labels)*100:.1f}%)', ha='center', fontsize=8)
|
| 488 |
+
plt.tight_layout()
|
| 489 |
+
plt.savefig('/app/figures_v2/fig_data_distribution.png', dpi=300, bbox_inches='tight')
|
| 490 |
+
plt.savefig('/app/figures_v2/fig_data_distribution.pdf', bbox_inches='tight')
|
| 491 |
+
plt.close()
|
| 492 |
+
|
| 493 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 494 |
+
# 7. SAVE ALL RESULTS
|
| 495 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 496 |
+
log("\n[7/7] Saving results...")
|
| 497 |
+
|
| 498 |
+
full_results = {
|
| 499 |
+
'multi_seed_roberta_ft': seed_results_roberta,
|
| 500 |
+
'multi_seed_summary': {
|
| 501 |
+
'macro_recall_mean': round(np.mean(recalls), 4),
|
| 502 |
+
'macro_recall_std': round(np.std(recalls), 4),
|
| 503 |
+
'macro_f1_mean': round(np.mean(f1s), 4),
|
| 504 |
+
'macro_f1_std': round(np.std(f1s), 4),
|
| 505 |
+
'accuracy_mean': round(np.mean(accs), 4),
|
| 506 |
+
'accuracy_std': round(np.std(accs), 4),
|
| 507 |
+
},
|
| 508 |
+
'all_models': {k: {kk: round(vv, 4) if isinstance(vv, float) else vv for kk, vv in v.items()}
|
| 509 |
+
for k, v in all_model_metrics.items()},
|
| 510 |
+
'ablation_lr': ablation_results,
|
| 511 |
+
'mcnemar_tests': mcnemar_results,
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
with open('/app/full_experiment_results.json', 'w') as f:
|
| 515 |
+
json.dump(full_results, f, indent=2, default=str)
|
| 516 |
+
|
| 517 |
+
log("\n" + "="*70)
|
| 518 |
+
log("COMPLETE RESULTS SUMMARY")
|
| 519 |
+
log("="*70)
|
| 520 |
+
log(f"\n{'Model':<25} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}")
|
| 521 |
+
log("-"*73)
|
| 522 |
+
for name, m in all_model_metrics.items():
|
| 523 |
+
sst2_str = f"{m.get('sst2_accuracy',0)*100:.2f}%" if 'sst2_accuracy' in m else "N/A"
|
| 524 |
+
mr = m['macro_recall']
|
| 525 |
+
mf1 = m['macro_f1']
|
| 526 |
+
std_str = f"+/-{m.get('macro_recall_std',0)*100:.2f}" if 'macro_recall_std' in m else ""
|
| 527 |
+
log(f"{name:<25} {m['params']:>7} {sst2_str:>10} {mr*100:>9.2f}%{std_str:>4} {mf1*100:>10.2f}%")
|
| 528 |
+
log("="*73)
|
| 529 |
+
|
| 530 |
+
log(f"\nAblation (LR):")
|
| 531 |
+
for r in ablation_results:
|
| 532 |
+
log(f" LR={r['lr']:.0e}: MR={r['macro_recall']*100:.2f}%, MF1={r['macro_f1']*100:.2f}%")
|
| 533 |
+
|
| 534 |
+
log(f"\nMcNemar Tests:")
|
| 535 |
+
for pair, res in mcnemar_results.items():
|
| 536 |
+
log(f" {pair}: p={res['p_value']:.6f} {res['significance']}")
|
| 537 |
+
|
| 538 |
+
log(f"\nAll figures saved to /app/figures_v2/")
|
| 539 |
+
log(f"Results saved to /app/full_experiment_results.json")
|
| 540 |
+
log("DONE!")
|