"""
Generate final figures and compile results for IEEE paper.

Uses seed 1 training data + all pre-trained model evaluations.
"""
import torch

# Limit PyTorch to 2 CPU threads (kept ahead of the other imports, as originally ordered).
torch.set_num_threads(2)

import gc
import json
import os
import time
from collections import Counter

import numpy as np
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    pipeline,
)
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    recall_score,
)
from scipy.stats import chi2
import matplotlib

matplotlib.use('Agg')  # non-interactive backend: figures are only written to files
import matplotlib.pyplot as plt

os.makedirs("/app/final_figures", exist_ok=True)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10


def log(msg):
    """Print *msg* and flush stdout immediately so progress shows up in buffered logs."""
    print(msg, flush=True)


# ── Load data ──
log("Loading datasets...")
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
sst2 = load_dataset("stanfordnlp/sst2")


def preprocess_tweet(text):
    """Normalize a tweet before classification.

    Whitespace-separated tokens starting with '@' (and longer than a bare '@')
    become '@user'; tokens starting with 'http' become 'http'. Empty or None
    input returns "". (Presumably mirrors the preprocessing expected by the
    cardiffnlp Twitter models — confirm against their model card.)
    """
    if not text:
        return ""
    return " ".join(
        '@user' if w.startswith('@') and len(w) > 1
        else ('http' if w.startswith('http') else w)
        for w in text.split()
    )


te_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
te_test_labels = list(tweeteval['test']['label'])
sst2_texts = list(sst2['validation']['sentence'])
sst2_labels = list(sst2['validation']['label'])
# Position i names integer TweetEval label i (0/1/2).
label_names = ['Negative', 'Neutral', 'Positive']

# ── Collect RoBERTa-FT predictions from trained model ──
log("Loading fine-tuned RoBERTa-base (seed 1)...")
tok_ft = AutoTokenizer.from_pretrained("roberta-base")
model_ft = AutoModelForSequenceClassification.from_pretrained("/app/ft_roberta_seed1/checkpoint-5702")


def tokenize_te(examples):
    """Batched ``Dataset.map`` callback: preprocess tweets, then tokenize them.

    No padding here; ``DataCollatorWithPadding`` pads each batch to its own
    longest sequence at prediction time.
    """
    texts = [preprocess_tweet(t) for t in examples['text']]
    return tok_ft(texts, truncation=True, max_length=128, padding=False)


# Only the test split is evaluated, so tokenize just that split instead of
# train/validation/test; the DatasetDict wrapper keeps ``te_enc['test']`` working.
te_enc = DatasetDict(test=tweeteval['test'].map(tokenize_te, batched=True, remove_columns=['text']))
collator = DataCollatorWithPadding(tokenizer=tok_ft)
# Explicit evaluation-only arguments instead of Trainer's implicit defaults:
# a scratch output directory and no experiment-tracking integrations.
eval_args = TrainingArguments(output_dir="/tmp/ft_roberta_seed1_eval", report_to="none")
trainer_ft = Trainer(model=model_ft, args=eval_args, data_collator=collator)
pred_output = trainer_ft.predict(te_enc['test'])
preds_roberta_ft = np.argmax(pred_output.predictions, axis=-1).tolist()
mr_ft = recall_score(te_test_labels, preds_roberta_ft, average='macro')
mf1_ft = f1_score(te_test_labels, preds_roberta_ft, average='macro')
acc_ft = accuracy_score(te_test_labels, preds_roberta_ft)
log(f" RoBERTa-FT: MR={mr_ft:.4f}, MF1={mf1_ft:.4f}, Acc={acc_ft:.4f}")
del model_ft, trainer_ft
gc.collect()

# ── Evaluate pre-trained models ──
# TweetEval predictions are stored as ints: 0 = Negative, 1 = Neutral, 2 = Positive.
all_preds = {'RoBERTa-FT': preds_roberta_ft}
all_metrics = {'RoBERTa-FT': {'macro_recall': mr_ft, 'macro_f1': mf1_ft, 'accuracy': acc_ft,
                              'params': '125M', 'sst2_accuracy': None}}

# Label strings a binary SST-2 checkpoint may emit for its positive class: a
# named label, the generic "LABEL_1", or a bare index. Compared case-insensitively.
POSITIVE_LABELS = {'positive', 'label_1', '1'}


def _is_positive_label(label):
    """Return True if a text-classification pipeline label denotes the positive class."""
    return str(label).lower() in POSITIVE_LABELS


# Twitter-RoBERTa: 3-class model; top_k=None returns a score for every label.
log("Evaluating Twitter-RoBERTa...")
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                device=-1, batch_size=32, top_k=None)
preds = []
for out in pipe(te_test_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    preds.append(int(np.argmax([scores.get('negative', 0), scores.get('neutral', 0), scores.get('positive', 0)])))
all_preds['Twitter-RoBERTa'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro')
acc = accuracy_score(te_test_labels, preds)
# SST-2 is binary: ignore the neutral score and compare positive vs negative.
preds_sst2 = []
for out in pipe(sst2_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    preds_sst2.append(1 if scores.get('positive', 0) > scores.get('negative', 0) else 0)
all_metrics['Twitter-RoBERTa'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '125M',
                                  'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" Twitter-RoBERTa: MR={mr:.4f}, SST2={all_metrics['Twitter-RoBERTa']['sst2_accuracy']:.4f}")
del pipe
gc.collect()

# DeBERTa-v3 and BERT-base are binary SST-2 classifiers: on TweetEval their
# output is mapped to Positive (2) or Negative (0); Neutral is never predicted.
log("Evaluating DeBERTa-v3...")
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
preds = [2 if _is_positive_label(out['label']) else 0
         for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['DeBERTa-v3'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
acc = accuracy_score(te_test_labels, preds)
preds_sst2 = [1 if _is_positive_label(out['label']) else 0
              for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['DeBERTa-v3'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '184M',
                             'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" DeBERTa-v3: MR={mr:.4f}, SST2={all_metrics['DeBERTa-v3']['sst2_accuracy']:.4f}")
del pipe
gc.collect()

# BERT-base
log("Evaluating BERT-base...")
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32)
preds = [2 if _is_positive_label(out['label']) else 0
         for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['BERT-base'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
acc = accuracy_score(te_test_labels, preds)
preds_sst2 = [1 if _is_positive_label(out['label']) else 0
              for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['BERT-base'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '110M',
                            'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" BERT-base: MR={mr:.4f}, SST2={all_metrics['BERT-base']['sst2_accuracy']:.4f}")
del pipe
gc.collect()

# ── McNemar Tests ──
log("\nMcNemar Statistical Significance Tests:")


def mcnemar_test(y_true, p1, p2):
    """Continuity-corrected McNemar test comparing two models' predictions.

    ``only_1`` counts items model 1 gets right and model 2 gets wrong;
    ``only_2`` the reverse. Returns ``(chi2_statistic, p_value)`` with df=1,
    or ``(0.0, 1.0)`` when the models never disagree on correctness.
    """
    only_1 = sum(1 for yt, a1, a2 in zip(y_true, p1, p2) if a1 == yt and a2 != yt)
    only_2 = sum(1 for yt, a1, a2 in zip(y_true, p1, p2) if a1 != yt and a2 == yt)
    if only_1 + only_2 == 0:
        return 0.0, 1.0
    stat = (abs(only_1 - only_2) - 1) ** 2 / (only_1 + only_2)
    # sf == 1 - cdf, but without the cancellation that rounds tiny p-values to 0.
    return stat, float(chi2.sf(stat, df=1))


mcnemar = {}
names = list(all_preds.keys())
for i, name_a in enumerate(names):
    for name_b in names[i + 1:]:
        stat, p = mcnemar_test(te_test_labels, all_preds[name_a], all_preds[name_b])
        sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else "ns"))
        mcnemar[f"{name_a} vs {name_b}"] = {'chi2': round(stat, 2), 'p': round(p, 6), 'sig': sig}
        log(f" {name_a} vs {name_b}: chi2={stat:.2f}, p={p:.6f} {sig}")

# ── FIGURES ──
log("\nGenerating figures...")


def save_figure(stem):
    """Lay out the current figure, save it as PNG (300 dpi) and PDF under /app/final_figures, then close it."""
    plt.tight_layout()
    plt.savefig(f'/app/final_figures/{stem}.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'/app/final_figures/{stem}.pdf', bbox_inches='tight')
    plt.close()


# Fig 1: Training Loss Curve
# (epoch, training loss) points are hard-coded here, presumably transcribed from
# the seed-1 training log — they are not recomputed by this script.
log(" Fig 1: Training curves...")
train_data = [
    (0.0004, 1.053), (0.035, 1.088), (0.070, 1.040), (0.105, 0.910), (0.140, 0.769),
    (0.175, 0.690), (0.211, 0.731), (0.246, 0.672), (0.281, 0.658), (0.316, 0.643),
    (0.351, 0.657), (0.386, 0.650), (0.421, 0.635), (0.456, 0.661), (0.491, 0.625),
    (0.526, 0.611), (0.561, 0.635), (0.596, 0.623), (0.631, 0.606), (0.666, 0.624),
    (0.702, 0.608), (0.737, 0.615), (0.772, 0.640), (0.807, 0.590), (0.842, 0.609),
    (0.877, 0.617), (0.912, 0.589), (0.947, 0.586), (0.982, 0.587), (1.017, 0.555),
    (1.052, 0.550), (1.087, 0.547), (1.122, 0.545), (1.157, 0.529), (1.193, 0.528),
    (1.228, 0.522), (1.263, 0.546), (1.298, 0.539), (1.333, 0.520), (1.368, 0.538),
    (1.403, 0.548), (1.438, 0.526), (1.473, 0.539), (1.508, 0.527), (1.543, 0.504),
    (1.578, 0.503), (1.613, 0.498), (1.649, 0.512), (1.684, 0.491), (1.719, 0.532),
    (1.754, 0.489), (1.789, 0.510), (1.824, 0.500), (1.859, 0.510), (1.894, 0.516),
    (1.929, 0.521), (1.964, 0.506), (1.999, 0.516),
]
# (epoch, validation loss, validation macro-recall %)
val_data = [(1.0, 0.593, 72.34), (2.0, 0.588, 75.18)]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
epochs_t, losses_t = zip(*train_data)
val_epochs, val_losses, val_recalls = zip(*val_data)

ax1.plot(epochs_t, losses_t, color='#2196F3', alpha=0.8, linewidth=1.5, label='Training Loss')
ax1.plot(val_epochs, val_losses, 'ro-', markersize=8, linewidth=2, label='Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Cross-Entropy Loss')
ax1.set_title('(a) Training & Validation Loss', fontweight='bold')
ax1.legend()
ax1.grid(alpha=0.3)

ax2.plot(val_epochs, val_recalls, 'go-', markersize=10, linewidth=2.5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Macro-Recall (%)')
ax2.set_title('(b) Validation Macro-Recall', fontweight='bold')
ax2.set_ylim(70, 78)
ax2.grid(alpha=0.3)
for e, r in zip(val_epochs, val_recalls):
    ax2.annotate(f'{r:.2f}%', (e, r), textcoords="offset points", xytext=(10, 5), fontsize=11, fontweight='bold')
save_figure('fig1_training_curves')

# Fig 2: Confusion Matrices (3 models)
log(" Fig 2: Confusion matrices...")
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
models_cm = [('Twitter-RoBERTa\n(Pre-trained)', all_preds['Twitter-RoBERTa'], 'Blues'),
             ('RoBERTa-FT\n(Ours)', all_preds['RoBERTa-FT'], 'Greens'),
             ('DeBERTa-v3\n(SST-2 only)', all_preds['DeBERTa-v3'], 'Oranges')]
tick_labels = ['Neg', 'Neu', 'Pos']
for idx, (name, preds, cmap) in enumerate(models_cm):
    # Pin the label set so the matrix is always 3x3, matching the 3x3 annotation loop below.
    cm = confusion_matrix(te_test_labels, preds, labels=[0, 1, 2], normalize='true')
    axes[idx].imshow(cm, interpolation='nearest', cmap=cmap, vmin=0, vmax=1)
    axes[idx].set(xticks=range(3), yticks=range(3), xticklabels=tick_labels, yticklabels=tick_labels,
                  title=name, ylabel='True' if idx == 0 else '', xlabel='Predicted')
    for i in range(3):
        for j in range(3):
            axes[idx].text(j, i, f'{cm[i, j]:.2f}', ha='center', va='center',
                           color='white' if cm[i, j] > 0.5 else 'black', fontsize=12, fontweight='bold')
save_figure('fig2_confusion_matrices')

# Fig 3: Model Comparison
log(" Fig 3: Model comparison...")
model_order = ['BERT-base', 'DeBERTa-v3', 'RoBERTa-FT', 'Twitter-RoBERTa']
model_labels = ['BERT-base\n(110M)', 'DeBERTa-v3\n(184M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-RoBERTa\n(125M)']
# A missing SST-2 score (None, e.g. RoBERTa-FT) is drawn as a zero-height, unlabelled bar.
sst2_accs = [all_metrics[m]['sst2_accuracy'] * 100 if all_metrics[m].get('sst2_accuracy') is not None else 0
             for m in model_order]
tweet_mrs = [all_metrics[m]['macro_recall'] * 100 for m in model_order]
x = np.arange(len(model_labels))
width = 0.35
fig, ax = plt.subplots(figsize=(9, 5))
b1 = ax.bar(x - width / 2, sst2_accs, width, label='SST-2 Accuracy (%)', color='#2196F3',
            edgecolor='black', linewidth=0.5)
b2 = ax.bar(x + width / 2, tweet_mrs, width, label='TweetEval Macro-Recall (%)', color='#FF9800',
            edgecolor='black', linewidth=0.5)
ax.set_ylabel('Score (%)')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, fontsize=9)
ax.set_title('Model Comparison: SST-2 vs TweetEval', fontweight='bold', fontsize=13)
ax.legend(fontsize=10)
ax.set_ylim(0, 105)
ax.grid(axis='y', alpha=0.3)
ax.axhline(y=95, color='red', linestyle='--', alpha=0.5)  # dashed reference line at 95%
for bar in [*b1, *b2]:
    h = bar.get_height()
    if h > 0:
        ax.text(bar.get_x() + bar.get_width() / 2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
save_figure('fig3_model_comparison')

# Fig 4: Per-class F1
log(" Fig 4: Per-class F1...")
fig, ax = plt.subplots(figsize=(8, 4.5))
x = np.arange(3)
w = 0.25
for idx, (name, preds, _cmap) in enumerate(models_cm):
    # target_names makes the report keyed by 'Negative'/'Neutral'/'Positive';
    # without it the keys are '0'/'1'/'2' and the lookup below raises KeyError.
    report = classification_report(te_test_labels, preds, labels=[0, 1, 2], target_names=label_names,
                                   output_dict=True, zero_division=0)
    f1s = [report[c]['f1-score'] * 100 for c in label_names]
    bars = ax.bar(x + idx * w, f1s, w, label=name.replace('\n', ' '), edgecolor='black', linewidth=0.5)
    for bar, v in zip(bars, f1s):
        if v > 0:
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 1, f'{v:.1f}', ha='center', fontsize=7)
ax.set_xticks(x + w)
ax.set_xticklabels(label_names)
ax.set_ylabel('F1-Score (%)')
ax.set_title('Per-Class F1 on TweetEval Sentiment', fontweight='bold')
ax.legend(fontsize=8)
ax.set_ylim(0, 100)
ax.grid(axis='y', alpha=0.3)
save_figure('fig4_per_class_f1')

# Fig 5: Dataset distribution
log(" Fig 5: Dataset distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
te_counts = Counter(te_test_labels)
te_values = [te_counts[0], te_counts[1], te_counts[2]]
axes[0].bar(label_names, te_values, color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5)
# Sample sizes are computed from the loaded data rather than hard-coded.
axes[0].set_title(f'TweetEval Test (n={len(te_test_labels):,})', fontweight='bold')
axes[0].set_ylabel('Count')
for i, v in enumerate(te_values):
    axes[0].text(i, v + 50, f'{v}\n({v / len(te_test_labels) * 100:.1f}%)', ha='center', fontsize=8)
sst2_counts = Counter(sst2_labels)
sst2_values = [sst2_counts[0], sst2_counts[1]]
axes[1].bar(['Negative', 'Positive'], sst2_values, color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[1].set_title(f'SST-2 Validation (n={len(sst2_labels):,})', fontweight='bold')
axes[1].set_ylabel('Count')
for i, v in enumerate(sst2_values):
    axes[1].text(i, v + 5, f'{v}\n({v / len(sst2_labels) * 100:.1f}%)', ha='center', fontsize=8)
save_figure('fig5_data_distribution')

# ── Save all results ──
results = {
    'roberta_ft_seed1': {
        'macro_recall': round(mr_ft * 100, 2),
        'macro_f1': round(mf1_ft * 100, 2),
        'accuracy': round(acc_ft * 100, 2),
        'val_epoch1': {'macro_recall': 72.34, 'macro_f1': 72.76},
        'val_epoch2': {'macro_recall': 75.18, 'macro_f1': 74.28},
    },
    # Float fractions become percentages; 'params' strings and None pass through unchanged.
    'all_models': {k: {kk: round(vv * 100, 2) if isinstance(vv, float) and vv < 1.01 else vv
                       for kk, vv in v.items()}
                   for k, v in all_metrics.items()},
    'mcnemar': mcnemar,
}
with open('/app/final_results.json', 'w') as f:
    json.dump(results, f, indent=2)

log("\n" + "=" * 70)
log("FINAL RESULTS")
log("=" * 70)
log(f"\n{'Model':<22} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}")
log("-" * 70)
for name, m in all_metrics.items():
    # `is not None` so a genuine 0.0 accuracy is printed rather than shown as N/A.
    sst2_str = f"{m['sst2_accuracy'] * 100:.2f}%" if m.get('sst2_accuracy') is not None else "N/A"
    log(f"{name:<22} {m['params']:>7} {sst2_str:>10} {m['macro_recall'] * 100:>12.2f}% {m['macro_f1'] * 100:>13.2f}%")
log("=" * 70)
log("\nDONE! All figures saved to /app/final_figures/")