rajvivan committed
Commit 7be4e88 · verified · 1 Parent(s): 3d99502

Add full_experiments.py

Files changed (1)
  1. code/full_experiments.py +540 -0
code/full_experiments.py ADDED
@@ -0,0 +1,540 @@
"""
COMPREHENSIVE EXPERIMENT SUITE for IEEE paper:
1) Fine-tune roberta-base on TweetEval (3 seeds)
2) Evaluate all models with multi-seed where applicable
3) Ablation: learning rate sweep
4) McNemar test between all model pairs
5) Generate all figures including training curves
"""

import torch
torch.set_num_threads(2)

import os, json, gc, time
import numpy as np
from collections import Counter
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding,
    EarlyStoppingCallback, pipeline
)
from sklearn.metrics import (
    recall_score, f1_score, accuracy_score,
    confusion_matrix, classification_report
)
from scipy.stats import chi2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

os.makedirs("/app/figures_v2", exist_ok=True)
os.makedirs("/app/training_logs", exist_ok=True)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10

def log(msg):
    print(msg, flush=True)

log("="*70)
log("IEEE PAPER: COMPREHENSIVE EXPERIMENT SUITE")
log("="*70)

# ══════════════════════════════════════════════════════════════════
# 1. LOAD DATA
# ══════════════════════════════════════════════════════════════════
log("\n[1/7] Loading datasets...")
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
sst2 = load_dataset("stanfordnlp/sst2")

label_names = ['Negative', 'Neutral', 'Positive']
log(f" TweetEval: train={len(tweeteval['train'])}, val={len(tweeteval['validation'])}, test={len(tweeteval['test'])}")
log(f" SST-2: train={len(sst2['train'])}, val={len(sst2['validation'])}")
log(f" TweetEval test distribution: {Counter(tweeteval['test']['label'])}")

def preprocess_tweet(text):
    if not text:
        return ""
    return " ".join(
        '@user' if w.startswith('@') and len(w) > 1 else ('http' if w.startswith('http') else w)
        for w in text.split()
    )
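# Example (derivable from the function above):
#   preprocess_tweet("@bob check http://x.co") -> "@user check http"
# i.e. user handles and URLs are masked, mirroring the cardiffnlp
# TweetEval preprocessing convention.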

# ══════════════════════════════════════════════════════════════════
# 2. FINE-TUNE ROBERTA-BASE ON TWEETEVAL (3 seeds)
# ══════════════════════════════════════════════════════════════════
log("\n[2/7] Fine-tuning roberta-base on TweetEval (3 seeds)...")

MODEL_FT = "roberta-base"
tok_ft = AutoTokenizer.from_pretrained(MODEL_FT)

def tokenize_tweeteval(examples):
    texts = [preprocess_tweet(t) for t in examples['text']]
    return tok_ft(texts, truncation=True, max_length=128, padding=False)

te_enc = tweeteval.map(tokenize_tweeteval, batched=True, remove_columns=['text'])
collator_ft = DataCollatorWithPadding(tokenizer=tok_ft)
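# padding=False above defers padding to the collator, which pads each batch
# to its longest sequence (dynamic padding) rather than to max_length.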

def compute_metrics_ft(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        'macro_recall': recall_score(labels, preds, average='macro'),
        'macro_f1': f1_score(labels, preds, average='macro'),
        'accuracy': accuracy_score(labels, preds),
    }
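# Macro-averaged recall is the official TweetEval metric for the sentiment
# task, which is why it also drives checkpoint selection
# (metric_for_best_model) in the training arguments below.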

seed_results_roberta = []
seed_preds_roberta = []
training_histories = []

for seed in [1, 2, 3]:
    log(f"\n --- Seed {seed}/3 ---")
    model_ft = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3)

    args = TrainingArguments(
        output_dir=f'/app/ft_roberta_seed{seed}',
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        learning_rate=1e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type='linear',
        eval_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='steps',
        logging_steps=100,
        logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model='macro_recall',
        greater_is_better=True,
        seed=seed,
        dataloader_num_workers=0,
        fp16=False,
        gradient_checkpointing=False,
        report_to='none',
        save_total_limit=1,
    )

    trainer = Trainer(
        model=model_ft, args=args,
        train_dataset=te_enc['train'],
        eval_dataset=te_enc['validation'],
        data_collator=collator_ft,
        compute_metrics=compute_metrics_ft,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    t0 = time.time()
    result = trainer.train()
    train_time = time.time() - t0

    # Save training history
    history = trainer.state.log_history
    training_histories.append(history)

    # Evaluate on test set
    test_metrics = trainer.evaluate(te_enc['test'])
    log(f" Seed {seed}: Macro-Recall={test_metrics['eval_macro_recall']:.4f}, "
        f"Macro-F1={test_metrics['eval_macro_f1']:.4f}, Acc={test_metrics['eval_accuracy']:.4f}, "
        f"Time={train_time/60:.1f}min")

    # Get predictions for McNemar
    preds_output = trainer.predict(te_enc['test'])
    preds = np.argmax(preds_output.predictions, axis=-1)
    seed_preds_roberta.append(preds.tolist())

    seed_results_roberta.append({
        'seed': seed,
        'macro_recall': test_metrics['eval_macro_recall'],
        'macro_f1': test_metrics['eval_macro_f1'],
        'accuracy': test_metrics['eval_accuracy'],
        'train_time_min': round(train_time/60, 1),
        'train_loss': result.training_loss,
    })

    del model_ft, trainer
    gc.collect()

# Aggregate multi-seed results
recalls = [r['macro_recall'] for r in seed_results_roberta]
f1s = [r['macro_f1'] for r in seed_results_roberta]
accs = [r['accuracy'] for r in seed_results_roberta]

log(f"\n RoBERTa-base (3 seeds): Macro-Recall={np.mean(recalls):.4f}+/-{np.std(recalls):.4f}, "
    f"Macro-F1={np.mean(f1s):.4f}+/-{np.std(f1s):.4f}, Acc={np.mean(accs):.4f}+/-{np.std(accs):.4f}")
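# Note: np.std defaults to the population std (ddof=0); for the sample std
# usually reported over a handful of seeds, pass ddof=1.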

# ══════════════════════════════════════════════════════════════════
# 3. EVALUATE PRE-TRAINED MODELS (multi-pass for consistency)
# ══════════════════════════════════════════════════════════════════
log("\n[3/7] Evaluating pre-trained models on TweetEval...")

te_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
te_test_labels = list(tweeteval['test']['label'])
sst2_texts = list(sst2['validation']['sentence'])
sst2_labels = list(sst2['validation']['label'])

all_model_preds = {}
all_model_metrics = {}

# Twitter-RoBERTa (already fine-tuned, 3-class)
log(" Twitter-RoBERTa...")
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                device=-1, batch_size=32, top_k=None)
preds_twr = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    label_scores = [scores.get('negative',0), scores.get('neutral',0), scores.get('positive',0)]
    preds_twr.append(np.argmax(label_scores))

mr_twr = recall_score(te_test_labels, preds_twr, average='macro')
mf1_twr = f1_score(te_test_labels, preds_twr, average='macro')
acc_twr = accuracy_score(te_test_labels, preds_twr)
log(f" Twitter-RoBERTa: MR={mr_twr:.4f}, MF1={mf1_twr:.4f}, Acc={acc_twr:.4f}")
all_model_preds['Twitter-RoBERTa'] = preds_twr
all_model_metrics['Twitter-RoBERTa'] = {'macro_recall': mr_twr, 'macro_f1': mf1_twr, 'accuracy': acc_twr, 'params': '125M'}

# SST-2 eval for Twitter-RoBERTa
preds_twr_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    preds_twr_sst2.append(1 if scores.get('positive',0) > scores.get('negative',0) else 0)
acc_twr_sst2 = accuracy_score(sst2_labels, preds_twr_sst2)
all_model_metrics['Twitter-RoBERTa']['sst2_accuracy'] = acc_twr_sst2
del pipe; gc.collect()

# DeBERTa-v3-base (binary, SST-2 fine-tuned)
log(" DeBERTa-v3-base...")
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
preds_deb = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    preds_deb.append(2 if out['label'].lower() == 'positive' else 0)
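# Binary SST-2 models can never predict Neutral (class 1), so neutral recall
# is 0 by construction on TweetEval; zero_division=0 below silences the
# resulting undefined-precision warning.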
mr_deb = recall_score(te_test_labels, preds_deb, average='macro')
mf1_deb = f1_score(te_test_labels, preds_deb, average='macro', zero_division=0)
acc_deb = accuracy_score(te_test_labels, preds_deb)
log(f" DeBERTa-v3: MR={mr_deb:.4f}, MF1={mf1_deb:.4f}, Acc={acc_deb:.4f}")
all_model_preds['DeBERTa-v3-base'] = preds_deb
all_model_metrics['DeBERTa-v3-base'] = {'macro_recall': mr_deb, 'macro_f1': mf1_deb, 'accuracy': acc_deb, 'params': '184M'}

preds_deb_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    preds_deb_sst2.append(1 if out['label'].lower() == 'positive' else 0)
all_model_metrics['DeBERTa-v3-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_deb_sst2)
del pipe; gc.collect()

# BERT-base
log(" BERT-base...")
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32)
preds_bert = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    preds_bert.append(2 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0)
mr_bert = recall_score(te_test_labels, preds_bert, average='macro')
mf1_bert = f1_score(te_test_labels, preds_bert, average='macro', zero_division=0)
acc_bert = accuracy_score(te_test_labels, preds_bert)
all_model_preds['BERT-base'] = preds_bert
all_model_metrics['BERT-base'] = {'macro_recall': mr_bert, 'macro_f1': mf1_bert, 'accuracy': acc_bert, 'params': '110M'}

preds_bert_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    preds_bert_sst2.append(1 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0)
all_model_metrics['BERT-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_bert_sst2)
del pipe; gc.collect()

# DistilBERT
log(" DistilBERT...")
pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
                device=-1, batch_size=32)
preds_distil = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    preds_distil.append(2 if out['label'] == 'POSITIVE' else 0)
mr_dist = recall_score(te_test_labels, preds_distil, average='macro')
mf1_dist = f1_score(te_test_labels, preds_distil, average='macro', zero_division=0)
acc_dist = accuracy_score(te_test_labels, preds_distil)
all_model_preds['DistilBERT'] = preds_distil
all_model_metrics['DistilBERT'] = {'macro_recall': mr_dist, 'macro_f1': mf1_dist, 'accuracy': acc_dist, 'params': '66M'}

preds_dist_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    preds_dist_sst2.append(1 if out['label'] == 'POSITIVE' else 0)
all_model_metrics['DistilBERT']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_dist_sst2)
del pipe; gc.collect()

# Add fine-tuned RoBERTa (use best seed predictions)
best_seed_idx = np.argmax(recalls)
all_model_preds['RoBERTa-FT'] = seed_preds_roberta[best_seed_idx]
all_model_metrics['RoBERTa-FT'] = {
    'macro_recall': np.mean(recalls), 'macro_recall_std': np.std(recalls),
    'macro_f1': np.mean(f1s), 'macro_f1_std': np.std(f1s),
    'accuracy': np.mean(accs), 'accuracy_std': np.std(accs),
    'params': '125M',
}

# ══════════════════════════════════════════════════════════════════
# 4. ABLATION: Learning Rate Sweep
# ══════════════════════════════════════════════════════════════════
log("\n[4/7] Ablation: Learning Rate Sweep (roberta-base, seed=1)...")

ablation_lrs = [5e-6, 1e-5, 3e-5]
ablation_results = []

for lr in ablation_lrs:
    log(f" LR={lr}...")
    model_ab = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3)
    args_ab = TrainingArguments(
        output_dir=f'/app/ablation_lr_{lr}',
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        learning_rate=lr,
        weight_decay=0.01,
        warmup_ratio=0.1,
        eval_strategy='epoch',
        save_strategy='epoch',
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model='macro_recall',
        greater_is_better=True,
        seed=1,
        dataloader_num_workers=0,
        fp16=False,
        report_to='none',
        save_total_limit=1,
        logging_strategy='steps',
        logging_steps=500,
        logging_first_step=True,
    )
    trainer_ab = Trainer(
        model=model_ab, args=args_ab,
        train_dataset=te_enc['train'],
        eval_dataset=te_enc['validation'],
        data_collator=collator_ft,
        compute_metrics=compute_metrics_ft,
    )
    trainer_ab.train()
    test_ab = trainer_ab.evaluate(te_enc['test'])
    ablation_results.append({
        'lr': lr,
        'macro_recall': test_ab['eval_macro_recall'],
        'macro_f1': test_ab['eval_macro_f1'],
        'accuracy': test_ab['eval_accuracy'],
    })
    log(f" LR={lr}: MR={test_ab['eval_macro_recall']:.4f}, MF1={test_ab['eval_macro_f1']:.4f}")
    del model_ab, trainer_ab; gc.collect()

# ══════════════════════════════════════════════════════════════════
# 5. McNEMAR TESTS
# ══════════════════════════════════════════════════════════════════
log("\n[5/7] McNemar Statistical Significance Tests...")

def mcnemar_test(y_true, y_pred1, y_pred2):
    b = sum(1 for yt, p1, p2 in zip(y_true, y_pred1, y_pred2) if p1 == yt and p2 != yt)
    c = sum(1 for yt, p1, p2 in zip(y_true, y_pred1, y_pred2) if p1 != yt and p2 == yt)
    if b + c == 0:
        return 0.0, 1.0
    stat = (abs(b - c) - 1)**2 / (b + c)
    p_val = 1 - chi2.cdf(stat, df=1)
    return stat, p_val
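# The function above is the continuity-corrected McNemar test on the
# discordant pairs b and c (exactly one of the two models correct); under H0
# the statistic is approximately chi^2 with 1 degree of freedom.
# Worked example: b=10, c=20 -> stat = (|10-20|-1)^2 / 30 = 2.70, p ~= 0.10.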

mcnemar_results = {}
model_names = list(all_model_preds.keys())
for i in range(len(model_names)):
    for j in range(i+1, len(model_names)):
        m1, m2 = model_names[i], model_names[j]
        stat, p = mcnemar_test(te_test_labels, all_model_preds[m1], all_model_preds[m2])
        sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else "ns"))
        mcnemar_results[f"{m1} vs {m2}"] = {'statistic': round(stat, 2), 'p_value': round(p, 6), 'significance': sig}
        log(f" {m1} vs {m2}: chi2={stat:.2f}, p={p:.6f} {sig}")

# ══════════════════════════════════════════════════════════════════
# 6. GENERATE ALL FIGURES
# ══════════════════════════════════════════════════════════════════
log("\n[6/7] Generating figures...")
# Fig 1: Training curves (all three seeds)
log(" Fig 1: Training loss curves...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
for sidx, hist in enumerate(training_histories):
    train_losses = [(h['step'], h['loss']) for h in hist if 'loss' in h and 'eval_loss' not in h]
    eval_recalls = [(h['epoch'], h['eval_macro_recall']) for h in hist if 'eval_macro_recall' in h]

    if train_losses:
        steps, losses = zip(*train_losses)
        ax1.plot(steps, losses, alpha=0.7, label=f'Seed {sidx+1}')
    if eval_recalls:
        epochs, recs = zip(*eval_recalls)
        ax2.plot(epochs, [r*100 for r in recs], 'o-', alpha=0.8, label=f'Seed {sidx+1}')

ax1.set_xlabel('Training Step'); ax1.set_ylabel('Loss')
ax1.set_title('(a) Training Loss', fontweight='bold')
ax1.legend(); ax1.grid(alpha=0.3)
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Macro-Recall (%)')
ax2.set_title('(b) Validation Macro-Recall', fontweight='bold')
ax2.legend(); ax2.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_training_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_training_curves.pdf', bbox_inches='tight')
plt.close()

# Fig 2: Confusion matrices (1x3 grid)
log(" Fig 2: Confusion matrices...")
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
models_cm = [('Twitter-RoBERTa', preds_twr), ('RoBERTa-FT (Ours)', seed_preds_roberta[best_seed_idx]), ('DeBERTa-v3', preds_deb)]
cmaps = ['Blues', 'Greens', 'Oranges']
for idx, (name, preds) in enumerate(models_cm):
    cm = confusion_matrix(te_test_labels, preds, normalize='true')
    im = axes[idx].imshow(cm, interpolation='nearest', cmap=cmaps[idx], vmin=0, vmax=1)
    axes[idx].set(xticks=range(3), yticks=range(3),
                  xticklabels=['Neg','Neu','Pos'], yticklabels=['Neg','Neu','Pos'],
                  title=name, ylabel='True' if idx==0 else '', xlabel='Predicted')
    for i in range(3):
        for j in range(3):
            axes[idx].text(j, i, f'{cm[i,j]:.2f}', ha='center', va='center',
                           color='white' if cm[i,j]>0.5 else 'black', fontsize=11)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_confusion_matrices.pdf', bbox_inches='tight')
plt.close()

# Fig 3: Model comparison bar chart
log(" Fig 3: Model comparison...")
models_plot = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)']
sst2_vals = [
    all_model_metrics['DistilBERT'].get('sst2_accuracy',0)*100,
    all_model_metrics['BERT-base'].get('sst2_accuracy',0)*100,
    0,  # RoBERTa-FT has no SST-2 eval; a zero-height bar reads as N/A (value labels skip h <= 0)
    all_model_metrics['Twitter-RoBERTa'].get('sst2_accuracy',0)*100,
    all_model_metrics['DeBERTa-v3-base'].get('sst2_accuracy',0)*100,
]
tweet_vals = [
    all_model_metrics['DistilBERT']['macro_recall']*100,
    all_model_metrics['BERT-base']['macro_recall']*100,
    np.mean(recalls)*100,
    all_model_metrics['Twitter-RoBERTa']['macro_recall']*100,
    all_model_metrics['DeBERTa-v3-base']['macro_recall']*100,
]
tweet_stds = [0, 0, np.std(recalls)*100, 0, 0]

x = np.arange(len(models_plot)); width = 0.35
fig, ax = plt.subplots(figsize=(10, 5))
bars1 = ax.bar(x - width/2, sst2_vals, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5)
bars2 = ax.bar(x + width/2, tweet_vals, width, yerr=tweet_stds, capsize=3,
               label='TweetEval Macro-Recall (%)', color='#FF9800', edgecolor='black', linewidth=0.5)
ax.set_ylabel('Score (%)'); ax.set_xticks(x); ax.set_xticklabels(models_plot, fontsize=8)
ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontweight='bold')
ax.legend(loc='upper left'); ax.set_ylim(0, 105); ax.grid(axis='y', alpha=0.3)
ax.axhline(y=95, color='red', linestyle='--', alpha=0.5); ax.text(4.3, 95.5, '95%', color='red', fontsize=7)
for bar in bars1+bars2:
    h = bar.get_height()
    if h > 0:
        ax.text(bar.get_x()+bar.get_width()/2., h+1, f'{h:.1f}', ha='center', va='bottom', fontsize=7)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_model_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_model_comparison.pdf', bbox_inches='tight')
plt.close()

# Fig 4: Ablation - Learning Rate
log(" Fig 4: Learning rate ablation...")
fig, ax = plt.subplots(figsize=(7, 4))
lrs = [r['lr'] for r in ablation_results]
mr_vals = [r['macro_recall']*100 for r in ablation_results]
mf1_vals = [r['macro_f1']*100 for r in ablation_results]
ax.plot(range(len(lrs)), mr_vals, 'o-', color='#2196F3', label='Macro-Recall', linewidth=2, markersize=8)
ax.plot(range(len(lrs)), mf1_vals, 's--', color='#FF9800', label='Macro-F1', linewidth=2, markersize=8)
ax.set_xticks(range(len(lrs))); ax.set_xticklabels([f'{lr:.0e}' for lr in lrs])
ax.set_xlabel('Learning Rate'); ax.set_ylabel('Score (%)')
ax.set_title('Ablation: Learning Rate Sensitivity (RoBERTa-base)', fontweight='bold')
ax.legend(); ax.grid(alpha=0.3)
for i, (mr, mf) in enumerate(zip(mr_vals, mf1_vals)):
    ax.annotate(f'{mr:.1f}', (i, mr), textcoords="offset points", xytext=(0,8), ha='center', fontsize=7)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_lr_ablation.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_lr_ablation.pdf', bbox_inches='tight')
plt.close()

# Fig 5: Per-class F1 comparison (3 models)
log(" Fig 5: Per-class F1...")
fig, ax = plt.subplots(figsize=(8, 4.5))
classes = ['Negative', 'Neutral', 'Positive']
x = np.arange(3); w = 0.25
for idx, (name, preds) in enumerate(models_cm):
    # target_names is required here: with integer labels the report dict is
    # keyed '0'/'1'/'2', so indexing by class name would raise a KeyError
    report = classification_report(te_test_labels, preds, labels=[0, 1, 2],
                                   target_names=classes, output_dict=True, zero_division=0)
    f1s_class = [report[c]['f1-score']*100 for c in classes]
    bars = ax.bar(x + idx*w, f1s_class, w, label=name, edgecolor='black', linewidth=0.5)
    for bar, v in zip(bars, f1s_class):
        ax.text(bar.get_x()+bar.get_width()/2., bar.get_height()+1, f'{v:.1f}', ha='center', fontsize=7)
ax.set_xticks(x + w); ax.set_xticklabels(classes)
ax.set_ylabel('F1-Score (%)'); ax.set_title('Per-Class F1 on TweetEval', fontweight='bold')
ax.legend(fontsize=9); ax.set_ylim(0, 100); ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_per_class_f1.pdf', bbox_inches='tight')
plt.close()

# Fig 6: Data distribution
log(" Fig 6: Dataset distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
te_counts = Counter(te_test_labels)
axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]],
            color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[0].set_title('TweetEval Test (n=12,284)', fontweight='bold'); axes[0].set_ylabel('Count')
for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]):
    axes[0].text(i, v+50, f'{v}\n({v/len(te_test_labels)*100:.1f}%)', ha='center', fontsize=8)
sst2_counts = Counter(sst2_labels)
axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]],
            color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[1].set_title('SST-2 Validation (n=872)', fontweight='bold'); axes[1].set_ylabel('Count')
for i, v in enumerate([sst2_counts[0], sst2_counts[1]]):
    axes[1].text(i, v+5, f'{v}\n({v/len(sst2_labels)*100:.1f}%)', ha='center', fontsize=8)
plt.tight_layout()
plt.savefig('/app/figures_v2/fig_data_distribution.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures_v2/fig_data_distribution.pdf', bbox_inches='tight')
plt.close()

# ══════════════════════════════════════════════════════════════════
# 7. SAVE ALL RESULTS
# ══════════════════════════════════════════════════════════════════
log("\n[7/7] Saving results...")

full_results = {
    'multi_seed_roberta_ft': seed_results_roberta,
    'multi_seed_summary': {
        'macro_recall_mean': round(np.mean(recalls), 4),
        'macro_recall_std': round(np.std(recalls), 4),
        'macro_f1_mean': round(np.mean(f1s), 4),
        'macro_f1_std': round(np.std(f1s), 4),
        'accuracy_mean': round(np.mean(accs), 4),
        'accuracy_std': round(np.std(accs), 4),
    },
    'all_models': {k: {kk: round(vv, 4) if isinstance(vv, float) else vv for kk, vv in v.items()}
                   for k, v in all_model_metrics.items()},
    'ablation_lr': ablation_results,
    'mcnemar_tests': mcnemar_results,
}

with open('/app/full_experiment_results.json', 'w') as f:
    json.dump(full_results, f, indent=2, default=str)

log("\n" + "="*70)
log("COMPLETE RESULTS SUMMARY")
log("="*70)
log(f"\n{'Model':<25} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}")
log("-"*73)
for name, m in all_model_metrics.items():
    sst2_str = f"{m.get('sst2_accuracy',0)*100:.2f}%" if 'sst2_accuracy' in m else "N/A"
    mr = m['macro_recall']
    mf1 = m['macro_f1']
    std_str = f"+/-{m.get('macro_recall_std',0)*100:.2f}" if 'macro_recall_std' in m else ""
    log(f"{name:<25} {m['params']:>7} {sst2_str:>10} {mr*100:>9.2f}%{std_str:>4} {mf1*100:>10.2f}%")
log("="*73)

log("\nAblation (LR):")
for r in ablation_results:
    log(f" LR={r['lr']:.0e}: MR={r['macro_recall']*100:.2f}%, MF1={r['macro_f1']*100:.2f}%")

log("\nMcNemar Tests:")
for pair, res in mcnemar_results.items():
    log(f" {pair}: p={res['p_value']:.6f} {res['significance']}")

log("\nAll figures saved to /app/figures_v2/")
log("Results saved to /app/full_experiment_results.json")
log("DONE!")