"""
Generate final figures and compile results for IEEE paper.
Uses seed 1 training data + all pre-trained model evaluations.
"""
import torch
torch.set_num_threads(2)
import json, gc, os
import numpy as np
from collections import Counter
from datasets import load_dataset
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, DataCollatorWithPadding
from sklearn.metrics import recall_score, f1_score, accuracy_score, confusion_matrix, classification_report
from scipy.stats import chi2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
os.makedirs("/app/final_figures", exist_ok=True)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10
def log(msg): print(msg, flush=True)
# ── Load data ──
log("Loading datasets...")
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
sst2 = load_dataset("stanfordnlp/sst2")
def preprocess_tweet(text):
    """Mask user mentions as '@user' and URLs as 'http', per the TweetEval convention."""
    if not text:
        return ""
    return " ".join('@user' if w.startswith('@') and len(w) > 1
                    else ('http' if w.startswith('http') else w)
                    for w in text.split())
te_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
te_test_labels = list(tweeteval['test']['label'])
sst2_texts = list(sst2['validation']['sentence'])
sst2_labels = list(sst2['validation']['label'])
label_names = ['Negative', 'Neutral', 'Positive']
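# tweet_eval "sentiment" label ids follow this order: 0=negative, 1=neutral, 2=positive.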
# ── Collect RoBERTa-FT predictions from trained model ──
log("Loading fine-tuned RoBERTa-base (seed 1)...")
tok_ft = AutoTokenizer.from_pretrained("roberta-base")
model_ft = AutoModelForSequenceClassification.from_pretrained("/app/ft_roberta_seed1/checkpoint-5702")
def tokenize_te(examples):
    texts = [preprocess_tweet(t) for t in examples['text']]
    return tok_ft(texts, truncation=True, max_length=128, padding=False)
te_enc = tweeteval.map(tokenize_te, batched=True, remove_columns=['text'])
collator = DataCollatorWithPadding(tokenizer=tok_ft)
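# A bare Trainer (default TrainingArguments) suffices here; it is used only for batched prediction.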
trainer_ft = Trainer(model=model_ft, data_collator=collator)
pred_output = trainer_ft.predict(te_enc['test'])
preds_roberta_ft = np.argmax(pred_output.predictions, axis=-1).tolist()
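# Macro-averaged recall is TweetEval's official metric for the sentiment task;
# macro-F1 and accuracy are reported alongside it.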
mr_ft = recall_score(te_test_labels, preds_roberta_ft, average='macro')
mf1_ft = f1_score(te_test_labels, preds_roberta_ft, average='macro')
acc_ft = accuracy_score(te_test_labels, preds_roberta_ft)
log(f" RoBERTa-FT: MR={mr_ft:.4f}, MF1={mf1_ft:.4f}, Acc={acc_ft:.4f}")
del model_ft, trainer_ft; gc.collect()
# ── Evaluate pre-trained models ──
all_preds = {'RoBERTa-FT': preds_roberta_ft}
all_metrics = {'RoBERTa-FT': {'macro_recall': mr_ft, 'macro_f1': mf1_ft, 'accuracy': acc_ft, 'params': '125M', 'sst2_accuracy': None}}
# Twitter-RoBERTa
log("Evaluating Twitter-RoBERTa...")
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=-1, batch_size=32, top_k=None)
preds = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    preds.append(int(np.argmax([scores.get('negative', 0), scores.get('neutral', 0), scores.get('positive', 0)])))
all_preds['Twitter-RoBERTa'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro')
acc = accuracy_score(te_test_labels, preds)
preds_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    preds_sst2.append(1 if scores.get('positive', 0) > scores.get('negative', 0) else 0)
all_metrics['Twitter-RoBERTa'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '125M', 'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" Twitter-RoBERTa: MR={mr:.4f}, SST2={all_metrics['Twitter-RoBERTa']['sst2_accuracy']:.4f}")
del pipe; gc.collect()
# DeBERTa-v3
log("Evaluating DeBERTa-v3...")
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
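# This checkpoint is binary (SST-2 only), so it can never predict Neutral;
# its positive label maps to TweetEval id 2 and negative to 0.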
preds = [2 if out['label'].lower()=='positive' else 0 for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['DeBERTa-v3'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
acc = accuracy_score(te_test_labels, preds)
preds_sst2 = [1 if out['label'].lower()=='positive' else 0 for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['DeBERTa-v3'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '184M', 'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" DeBERTa-v3: MR={mr:.4f}, SST2={all_metrics['DeBERTa-v3']['sst2_accuracy']:.4f}")
del pipe; gc.collect()
# BERT-base
log("Evaluating BERT-base...")
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32)
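# textattack checkpoints may emit generic LABEL_0/LABEL_1 names instead of
# NEGATIVE/POSITIVE, hence the permissive label match below.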
preds = [2 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0 for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['BERT-base'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
preds_sst2 = [1 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0 for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['BERT-base'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': accuracy_score(te_test_labels, preds), 'params': '110M', 'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
del pipe; gc.collect()
# ── McNemar Tests ──
log("\nMcNemar Statistical Significance Tests:")
def mcnemar_test(y_true, p1, p2):
    """Continuity-corrected McNemar test on the discordant predictions of two models."""
    b = sum(1 for yt, a1, a2 in zip(y_true, p1, p2) if a1 == yt and a2 != yt)  # model 1 right, model 2 wrong
    c = sum(1 for yt, a1, a2 in zip(y_true, p1, p2) if a1 != yt and a2 == yt)  # model 2 right, model 1 wrong
    if b + c == 0:
        return 0.0, 1.0
    stat = (abs(b - c) - 1) ** 2 / (b + c)
    return stat, chi2.sf(stat, df=1)  # sf(x) is a numerically stabler 1 - cdf(x)
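# Sanity check with hypothetical counts: b=120, c=80 discordant pairs gives
# stat = (|120-80| - 1)**2 / 200 = 39**2 / 200 = 7.605 and p ~= 0.0058 (i.e. "**").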
mcnemar = {}
names = list(all_preds.keys())
for i in range(len(names)):
    for j in range(i + 1, len(names)):
        stat, p = mcnemar_test(te_test_labels, all_preds[names[i]], all_preds[names[j]])
        sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else "ns"))
        mcnemar[f"{names[i]} vs {names[j]}"] = {'chi2': round(stat, 2), 'p': round(p, 6), 'sig': sig}
        log(f" {names[i]} vs {names[j]}: chi2={stat:.2f}, p={p:.6f} {sig}")
# ── FIGURES ──
log("\nGenerating figures...")
# Fig 1: Training Loss Curve
log(" Fig 1: Training curves...")
train_data = [
(0.0004, 1.053), (0.035, 1.088), (0.070, 1.040), (0.105, 0.910), (0.140, 0.769),
(0.175, 0.690), (0.211, 0.731), (0.246, 0.672), (0.281, 0.658), (0.316, 0.643),
(0.351, 0.657), (0.386, 0.650), (0.421, 0.635), (0.456, 0.661), (0.491, 0.625),
(0.526, 0.611), (0.561, 0.635), (0.596, 0.623), (0.631, 0.606), (0.666, 0.624),
(0.702, 0.608), (0.737, 0.615), (0.772, 0.640), (0.807, 0.590), (0.842, 0.609),
(0.877, 0.617), (0.912, 0.589), (0.947, 0.586), (0.982, 0.587),
(1.017, 0.555), (1.052, 0.550), (1.087, 0.547), (1.122, 0.545), (1.157, 0.529),
(1.193, 0.528), (1.228, 0.522), (1.263, 0.546), (1.298, 0.539), (1.333, 0.520),
(1.368, 0.538), (1.403, 0.548), (1.438, 0.526), (1.473, 0.539), (1.508, 0.527),
(1.543, 0.504), (1.578, 0.503), (1.613, 0.498), (1.649, 0.512), (1.684, 0.491),
(1.719, 0.532), (1.754, 0.489), (1.789, 0.510), (1.824, 0.500), (1.859, 0.510),
(1.894, 0.516), (1.929, 0.521), (1.964, 0.506), (1.999, 0.516),
]
val_data = [(1.0, 0.593, 72.34), (2.0, 0.588, 75.18)]
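# Each val_data row is (epoch, validation loss, validation macro-recall %).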
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
epochs_t, losses_t = zip(*train_data)
ax1.plot(epochs_t, losses_t, color='#2196F3', alpha=0.8, linewidth=1.5, label='Training Loss')
val_epochs, val_losses, _ = zip(*val_data)
ax1.plot(val_epochs, val_losses, 'ro-', markersize=8, linewidth=2, label='Validation Loss')
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Cross-Entropy Loss')
ax1.set_title('(a) Training & Validation Loss', fontweight='bold')
ax1.legend(); ax1.grid(alpha=0.3)
val_epochs2, _, val_recalls = zip(*val_data)
ax2.plot(val_epochs2, val_recalls, 'go-', markersize=10, linewidth=2.5)
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Macro-Recall (%)')
ax2.set_title('(b) Validation Macro-Recall', fontweight='bold')
ax2.set_ylim(70, 78); ax2.grid(alpha=0.3)
for e, r in zip(val_epochs2, val_recalls): ax2.annotate(f'{r:.2f}%', (e, r), textcoords="offset points", xytext=(10,5), fontsize=11, fontweight='bold')
plt.tight_layout()
plt.savefig('/app/final_figures/fig1_training_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig1_training_curves.pdf', bbox_inches='tight')
plt.close()
# Fig 2: Confusion Matrices (3 models)
log(" Fig 2: Confusion matrices...")
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
models_cm = [('Twitter-RoBERTa\n(Pre-trained)', all_preds['Twitter-RoBERTa'], 'Blues'),
('RoBERTa-FT\n(Ours)', all_preds['RoBERTa-FT'], 'Greens'),
('DeBERTa-v3\n(SST-2 only)', all_preds['DeBERTa-v3'], 'Oranges')]
for idx, (name, preds, cmap) in enumerate(models_cm):
    cm = confusion_matrix(te_test_labels, preds, normalize='true')
    axes[idx].imshow(cm, interpolation='nearest', cmap=cmap, vmin=0, vmax=1)
    axes[idx].set(xticks=range(3), yticks=range(3), xticklabels=['Neg', 'Neu', 'Pos'], yticklabels=['Neg', 'Neu', 'Pos'],
                  title=name, ylabel='True' if idx == 0 else '', xlabel='Predicted')
    for i in range(3):
        for j in range(3):
            axes[idx].text(j, i, f'{cm[i, j]:.2f}', ha='center', va='center',
                           color='white' if cm[i, j] > 0.5 else 'black', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.savefig('/app/final_figures/fig2_confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig2_confusion_matrices.pdf', bbox_inches='tight')
plt.close()
# Fig 3: Model Comparison
log(" Fig 3: Model comparison...")
model_order = ['BERT-base', 'DeBERTa-v3', 'RoBERTa-FT', 'Twitter-RoBERTa']
model_labels = ['BERT-base\n(110M)', 'DeBERTa-v3\n(184M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-RoBERTa\n(125M)']
sst2_accs = [(all_metrics[m]['sst2_accuracy'] or 0) * 100 for m in model_order]
tweet_mrs = [all_metrics[m]['macro_recall']*100 for m in model_order]
x = np.arange(len(model_labels)); width = 0.35
fig, ax = plt.subplots(figsize=(9, 5))
b1 = ax.bar(x - width/2, sst2_accs, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5)
b2 = ax.bar(x + width/2, tweet_mrs, width, label='TweetEval Macro-Recall (%)', color='#FF9800', edgecolor='black', linewidth=0.5)
ax.set_ylabel('Score (%)'); ax.set_xticks(x); ax.set_xticklabels(model_labels, fontsize=9)
ax.set_title('Model Comparison: SST-2 vs TweetEval', fontweight='bold', fontsize=13)
ax.legend(fontsize=10); ax.set_ylim(0, 105); ax.grid(axis='y', alpha=0.3)
ax.axhline(y=95, color='red', linestyle='--', alpha=0.5)
for bar in list(b1) + list(b2):
    h = bar.get_height()
    if h > 0:
        ax.text(bar.get_x() + bar.get_width() / 2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
plt.tight_layout()
plt.savefig('/app/final_figures/fig3_model_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig3_model_comparison.pdf', bbox_inches='tight')
plt.close()
# Fig 4: Per-class F1
log(" Fig 4: Per-class F1...")
fig, ax = plt.subplots(figsize=(8, 4.5))
x = np.arange(3); w = 0.25
for idx, (name, preds, cmap) in enumerate(models_cm):
    # target_names makes the report keys match label_names (otherwise they are '0'/'1'/'2').
    report = classification_report(te_test_labels, preds, target_names=label_names,
                                   output_dict=True, zero_division=0)
    f1s = [report[c]['f1-score'] * 100 for c in label_names]
    bars = ax.bar(x + idx * w, f1s, w, label=name.replace('\n', ' '), edgecolor='black', linewidth=0.5)
    for bar, v in zip(bars, f1s):
        if v > 0:
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 1, f'{v:.1f}', ha='center', fontsize=7)
ax.set_xticks(x + w); ax.set_xticklabels(label_names)
ax.set_ylabel('F1-Score (%)'); ax.set_title('Per-Class F1 on TweetEval Sentiment', fontweight='bold')
ax.legend(fontsize=8); ax.set_ylim(0, 100); ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('/app/final_figures/fig4_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig4_per_class_f1.pdf', bbox_inches='tight')
plt.close()
# Fig 5: Dataset distribution
log(" Fig 5: Dataset distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
te_counts = Counter(te_test_labels)
axes[0].bar(label_names, [te_counts[0], te_counts[1], te_counts[2]], color=['#e74c3c','#95a5a6','#2ecc71'], edgecolor='black', linewidth=0.5)
axes[0].set_title('TweetEval Test (n=12,284)', fontweight='bold'); axes[0].set_ylabel('Count')
for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]):
    axes[0].text(i, v + 50, f'{v}\n({v/len(te_test_labels)*100:.1f}%)', ha='center', fontsize=8)
sst2_counts = Counter(sst2_labels)
axes[1].bar(['Negative','Positive'], [sst2_counts[0], sst2_counts[1]], color=['#e74c3c','#2ecc71'], edgecolor='black', linewidth=0.5)
axes[1].set_title('SST-2 Validation (n=872)', fontweight='bold'); axes[1].set_ylabel('Count')
for i, v in enumerate([sst2_counts[0], sst2_counts[1]]):
    axes[1].text(i, v + 5, f'{v}\n({v/len(sst2_labels)*100:.1f}%)', ha='center', fontsize=8)
plt.tight_layout()
plt.savefig('/app/final_figures/fig5_data_distribution.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig5_data_distribution.pdf', bbox_inches='tight')
plt.close()
# ── Save all results ──
results = {
'roberta_ft_seed1': {'macro_recall': round(mr_ft*100,2), 'macro_f1': round(mf1_ft*100,2), 'accuracy': round(acc_ft*100,2),
'val_epoch1': {'macro_recall': 72.34, 'macro_f1': 72.76},
'val_epoch2': {'macro_recall': 75.18, 'macro_f1': 74.28}},
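    # Float fractions in [0, 1] are converted to percentages; strings and None pass through.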
'all_models': {k: {kk: round(vv*100,2) if isinstance(vv, float) and vv < 1.01 else vv for kk, vv in v.items()} for k,v in all_metrics.items()},
'mcnemar': mcnemar,
}
with open('/app/final_results.json', 'w') as f:
    json.dump(results, f, indent=2)
log("\n" + "="*70)
log("FINAL RESULTS")
log("="*70)
log(f"\n{'Model':<22} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}")
log("-"*70)
for name, m in all_metrics.items():
    sst2_str = f"{m['sst2_accuracy']*100:.2f}%" if m.get('sst2_accuracy') else "N/A"
    log(f"{name:<22} {m['params']:>7} {sst2_str:>10} {m['macro_recall']*100:>12.2f}% {m['macro_f1']*100:>13.2f}%")
log("="*70)
log("\nDONE! All figures saved to /app/final_figures/")