| """ |
| Generate final figures and compile results for IEEE paper. |
| Uses seed 1 training data + all pre-trained model evaluations. |
| """ |
|
|
| import torch |
| torch.set_num_threads(2) |
|
|
| import json, gc, os, time |
| import numpy as np |
| from collections import Counter |
| from datasets import load_dataset |
| from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding |
| from sklearn.metrics import recall_score, f1_score, accuracy_score, confusion_matrix, classification_report |
| from scipy.stats import chi2 |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
os.makedirs("/app/final_figures", exist_ok=True)

# Global matplotlib style applied to every figure produced below:
# serif font at 10 pt.
plt.rcParams.update({'font.family': 'serif', 'font.size': 10})
|
|
def log(msg):
    """Print *msg* to stdout and flush immediately.

    Flushing keeps progress messages visible right away even when stdout is
    piped or redirected during the long-running evaluation steps.
    """
    print(msg, flush=True)
|
|
| |
log("Loading datasets...")
# TweetEval (3-class sentiment) and SST-2 (binary sentiment) from the HF hub.
tweeteval, sst2 = (
    load_dataset("cardiffnlp/tweet_eval", "sentiment"),
    load_dataset("stanfordnlp/sst2"),
)
|
|
def preprocess_tweet(text):
    """Mask user handles and links in a tweet.

    Every whitespace-separated token that starts with ``@`` and is longer than
    one character becomes ``@user``. Every other token starting with ``http``
    becomes ``http``. Tokens are re-joined with single spaces, so runs of
    whitespace collapse. Falsy input (``None``/``""``) yields ``""``.
    """
    if not text:
        return ""
    tokens = []
    for token in text.split():
        if token.startswith('@') and len(token) > 1:
            token = '@user'
        elif token.startswith('http'):
            token = 'http'
        tokens.append(token)
    return " ".join(tokens)
|
|
# Evaluation inputs shared by every model below. Tweets receive the same
# handle/URL masking (preprocess_tweet) that tokenize_te applies for the
# fine-tuned model.
te_test_texts = list(map(preprocess_tweet, tweeteval['test']['text']))
te_test_labels = list(tweeteval['test']['label'])
sst2_texts, sst2_labels = (
    list(sst2['validation']['sentence']),
    list(sst2['validation']['label']),
)
# Display names indexed by class id 0/1/2, matching the
# (negative, neutral, positive) argmax order used for the 3-class model below.
label_names = ['Negative', 'Neutral', 'Positive']
|
|
| |
log("Loading fine-tuned RoBERTa-base (seed 1)...")
# Tokenizer from the public base checkpoint; classifier weights from the
# local seed-1 fine-tuning checkpoint.
tok_ft = AutoTokenizer.from_pretrained("roberta-base")
model_ft = AutoModelForSequenceClassification.from_pretrained(
    "/app/ft_roberta_seed1/checkpoint-5702"
)
|
|
def tokenize_te(examples):
    """Tokenize a batch of TweetEval rows for the fine-tuned RoBERTa model.

    Each entry of ``examples['text']`` is first normalised with
    ``preprocess_tweet``, then tokenized with truncation to 128 tokens.
    No padding is applied here; batches are padded later by the data collator.
    """
    cleaned = list(map(preprocess_tweet, examples['text']))
    return tok_ft(cleaned, truncation=True, max_length=128, padding=False)
|
|
# Tokenize every TweetEval split (only 'test' is consumed below). Raw text is
# dropped; the 'label' column is kept.
te_enc = tweeteval.map(tokenize_te, batched=True, remove_columns=['text'])
# Pads each prediction batch dynamically (tokenize_te uses padding=False).
collator = DataCollatorWithPadding(tokenizer=tok_ft)

# A bare Trainer (no TrainingArguments) serves purely as a batched
# inference loop over the encoded test split.
trainer_ft = Trainer(model=model_ft, data_collator=collator)
pred_output = trainer_ft.predict(te_enc['test'])
preds_roberta_ft = np.argmax(pred_output.predictions, axis=-1).tolist()

mr_ft, mf1_ft, acc_ft = (
    recall_score(te_test_labels, preds_roberta_ft, average='macro'),
    f1_score(te_test_labels, preds_roberta_ft, average='macro'),
    accuracy_score(te_test_labels, preds_roberta_ft),
)
log(f" RoBERTa-FT: MR={mr_ft:.4f}, MF1={mf1_ft:.4f}, Acc={acc_ft:.4f}")

# Free the fine-tuned model before the pipeline-based evaluations load theirs.
del model_ft
del trainer_ft
gc.collect()
|
|
| |
# Registries that each evaluation section below fills in:
# per-model test-set predictions and per-model summary metrics.
# The fine-tuned model has no SST-2 score, hence sst2_accuracy=None.
all_preds = {}
all_preds['RoBERTa-FT'] = preds_roberta_ft
all_metrics = {}
all_metrics['RoBERTa-FT'] = {
    'macro_recall': mr_ft,
    'macro_f1': mf1_ft,
    'accuracy': acc_ft,
    'params': '125M',
    'sst2_accuracy': None,
}
|
|
| |
log("Evaluating Twitter-RoBERTa...")
# 3-class model; top_k=None makes the pipeline return the score of every
# label for each input, so the class decision is made here.
pipe = pipeline(
    "text-classification",
    model="cardiffnlp/twitter-roberta-base-sentiment-latest",
    device=-1,
    batch_size=32,
    top_k=None,
)

# TweetEval: argmax over (negative, neutral, positive) -> class ids 0 / 1 / 2.
preds = []
for label_scores in pipe(te_test_texts, truncation=True, max_length=128):
    by_label = {entry['label'].lower(): entry['score'] for entry in label_scores}
    ordered = [by_label.get(lbl, 0) for lbl in ('negative', 'neutral', 'positive')]
    preds.append(np.argmax(ordered))
all_preds['Twitter-RoBERTa'] = preds

mr = recall_score(te_test_labels, preds, average='macro')
mf1 = f1_score(te_test_labels, preds, average='macro')
acc = accuracy_score(te_test_labels, preds)

# SST-2 is binary: ignore neutral and call an input positive only when its
# positive score beats its negative score.
preds_sst2 = []
for label_scores in pipe(sst2_texts, truncation=True, max_length=128):
    by_label = {entry['label'].lower(): entry['score'] for entry in label_scores}
    is_positive = by_label.get('positive', 0) > by_label.get('negative', 0)
    preds_sst2.append(1 if is_positive else 0)

all_metrics['Twitter-RoBERTa'] = {
    'macro_recall': mr,
    'macro_f1': mf1,
    'accuracy': acc,
    'params': '125M',
    'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2),
}
log(f" Twitter-RoBERTa: MR={mr:.4f}, SST2={all_metrics['Twitter-RoBERTa']['sst2_accuracy']:.4f}")
del pipe
gc.collect()
|
|
| |
log("Evaluating DeBERTa-v3...")
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
# Binary (SST-2) classifier. The positive class may be reported by name
# ('positive' / 'POSITIVE') or generically ('LABEL_1' / '1'), depending on the
# checkpoint's config. Accept every spelling, exactly like the BERT-base
# section, so a generic id2label cannot silently turn every prediction into
# "negative".
positive_tags = ('POSITIVE', 'LABEL_1', '1')

# TweetEval: the binary model never predicts neutral; positive -> 2, else 0.
preds = [2 if out['label'].upper() in positive_tags else 0
         for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['DeBERTa-v3'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
# zero_division=0: the neutral class is never predicted, so its precision is undefined.
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
acc = accuracy_score(te_test_labels, preds)

preds_sst2 = [1 if out['label'].upper() in positive_tags else 0
              for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['DeBERTa-v3'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '184M',
                             'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
log(f" DeBERTa-v3: MR={mr:.4f}, SST2={all_metrics['DeBERTa-v3']['sst2_accuracy']:.4f}")
del pipe
gc.collect()
|
|
| |
log("Evaluating BERT-base...")
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32)
# Binary (SST-2) classifier: accept both named ('POSITIVE') and generic
# ('LABEL_1' / '1') spellings of the positive label.
positive_tags = ('POSITIVE', 'LABEL_1', '1')

# TweetEval: the binary model never predicts neutral; positive -> 2, else 0.
preds = [2 if out['label'].upper() in positive_tags else 0
         for out in pipe(te_test_texts, truncation=True, max_length=128)]
all_preds['BERT-base'] = preds
mr = recall_score(te_test_labels, preds, average='macro')
# zero_division=0: the neutral class is never predicted, so its precision is undefined.
mf1 = f1_score(te_test_labels, preds, average='macro', zero_division=0)
acc = accuracy_score(te_test_labels, preds)

preds_sst2 = [1 if out['label'].upper() in positive_tags else 0
              for out in pipe(sst2_texts, truncation=True, max_length=128)]
all_metrics['BERT-base'] = {'macro_recall': mr, 'macro_f1': mf1, 'accuracy': acc, 'params': '110M',
                            'sst2_accuracy': accuracy_score(sst2_labels, preds_sst2)}
# Report results like the other evaluation sections (previously silent).
log(f" BERT-base: MR={mr:.4f}, SST2={all_metrics['BERT-base']['sst2_accuracy']:.4f}")
del pipe
gc.collect()
|
|
| |
# Section header for the pairwise McNemar tests over every model in all_preds.
log("\nMcNemar Statistical Significance Tests:")
def mcnemar_test(y_true, p1, p2):
    """McNemar's test (with continuity correction) comparing two classifiers.

    Args:
        y_true: gold labels.
        p1: predictions of the first model, aligned with ``y_true``.
        p2: predictions of the second model, aligned with ``y_true``.

    Returns:
        ``(stat, p)``: the chi-square statistic (1 degree of freedom) and its
        p-value. Returns ``(0.0, 1.0)`` when the models never disagree on
        correctness.

    Raises:
        ValueError: if the three sequences differ in length (``zip`` would
            otherwise silently truncate and miscount).
    """
    if not (len(y_true) == len(p1) == len(p2)):
        raise ValueError("y_true, p1 and p2 must have the same length")

    # Discordant pairs: only model 1 correct vs. only model 2 correct.
    only_first = only_second = 0
    for truth, a, b in zip(y_true, p1, p2):
        first_ok, second_ok = a == truth, b == truth
        if first_ok and not second_ok:
            only_first += 1
        elif second_ok and not first_ok:
            only_second += 1

    n_discordant = only_first + only_second
    if n_discordant == 0:
        return 0.0, 1.0
    # Continuity correction, clamped at 0 so equal counts give stat 0 instead
    # of the spurious (0 - 1)**2 / n.
    stat = max(abs(only_first - only_second) - 1, 0) ** 2 / n_discordant
    # sf() rather than 1 - cdf(): keeps tiny p-values instead of rounding to 0.0.
    return stat, chi2.sf(stat, df=1)
|
|
# Test every unordered pair of models once, in registration order.
mcnemar = {}
names = list(all_preds.keys())
for i, first in enumerate(names):
    for j in range(i + 1, len(names)):
        second = names[j]
        stat, p = mcnemar_test(te_test_labels, all_preds[first], all_preds[second])
        if p < 0.001:
            sig = "***"
        elif p < 0.01:
            sig = "**"
        elif p < 0.05:
            sig = "*"
        else:
            sig = "ns"
        mcnemar[f"{first} vs {second}"] = {'chi2': round(stat, 2), 'p': round(p, 6), 'sig': sig}
        log(f" {first} vs {second}: chi2={stat:.2f}, p={p:.6f} {sig}")
|
|
| |
log("\nGenerating figures...")

# Figure 1: training/validation loss and validation macro-recall for the
# seed-1 run. The values are hard-coded here (presumably copied from the
# seed-1 training log -- confirm); they are not recomputed by this script.
log(" Fig 1: Training curves...")
# (epoch, training loss)
train_data = [
    (0.0004, 1.053), (0.035, 1.088), (0.070, 1.040), (0.105, 0.910), (0.140, 0.769),
    (0.175, 0.690), (0.211, 0.731), (0.246, 0.672), (0.281, 0.658), (0.316, 0.643),
    (0.351, 0.657), (0.386, 0.650), (0.421, 0.635), (0.456, 0.661), (0.491, 0.625),
    (0.526, 0.611), (0.561, 0.635), (0.596, 0.623), (0.631, 0.606), (0.666, 0.624),
    (0.702, 0.608), (0.737, 0.615), (0.772, 0.640), (0.807, 0.590), (0.842, 0.609),
    (0.877, 0.617), (0.912, 0.589), (0.947, 0.586), (0.982, 0.587),
    (1.017, 0.555), (1.052, 0.550), (1.087, 0.547), (1.122, 0.545), (1.157, 0.529),
    (1.193, 0.528), (1.228, 0.522), (1.263, 0.546), (1.298, 0.539), (1.333, 0.520),
    (1.368, 0.538), (1.403, 0.548), (1.438, 0.526), (1.473, 0.539), (1.508, 0.527),
    (1.543, 0.504), (1.578, 0.503), (1.613, 0.498), (1.649, 0.512), (1.684, 0.491),
    (1.719, 0.532), (1.754, 0.489), (1.789, 0.510), (1.824, 0.500), (1.859, 0.510),
    (1.894, 0.516), (1.929, 0.521), (1.964, 0.506), (1.999, 0.516),
]
# (epoch, validation loss, validation macro-recall in %)
val_data = [(1.0, 0.593, 72.34), (2.0, 0.588, 75.18)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
epochs_t, losses_t = zip(*train_data)
val_epochs, val_losses, val_recalls = zip(*val_data)

# Panel (a): per-step training loss with per-epoch validation loss on top.
ax1.plot(epochs_t, losses_t, color='#2196F3', alpha=0.8, linewidth=1.5, label='Training Loss')
ax1.plot(val_epochs, val_losses, 'ro-', markersize=8, linewidth=2, label='Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Cross-Entropy Loss')
ax1.set_title('(a) Training & Validation Loss', fontweight='bold')
ax1.legend()
ax1.grid(alpha=0.3)

# Panel (b): validation macro-recall per epoch, each point annotated.
ax2.plot(val_epochs, val_recalls, 'go-', markersize=10, linewidth=2.5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Macro-Recall (%)')
ax2.set_title('(b) Validation Macro-Recall', fontweight='bold')
ax2.set_ylim(70, 78)
ax2.grid(alpha=0.3)
for epoch, recall in zip(val_epochs, val_recalls):
    ax2.annotate(f'{recall:.2f}%', (epoch, recall), textcoords="offset points",
                 xytext=(10, 5), fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('/app/final_figures/fig1_training_curves.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig1_training_curves.pdf', bbox_inches='tight')
plt.close()
|
|
| |
log(" Fig 2: Confusion matrices...")
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
# (panel title, test-set predictions, colormap) per panel; Fig 4 reuses this list.
models_cm = [
    ('Twitter-RoBERTa\n(Pre-trained)', all_preds['Twitter-RoBERTa'], 'Blues'),
    ('RoBERTa-FT\n(Ours)', all_preds['RoBERTa-FT'], 'Greens'),
    ('DeBERTa-v3\n(SST-2 only)', all_preds['DeBERTa-v3'], 'Oranges'),
]
short_names = ['Neg', 'Neu', 'Pos']
for idx, (name, preds, cmap) in enumerate(models_cm):
    ax = axes[idx]
    # normalize='true': each row (true class) sums to 1, so the diagonal is per-class recall.
    cm = confusion_matrix(te_test_labels, preds, normalize='true')
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap, vmin=0, vmax=1)
    ax.set(xticks=range(3), yticks=range(3), xticklabels=short_names, yticklabels=short_names,
           title=name, ylabel='True' if idx == 0 else '', xlabel='Predicted')
    for (i, j), frac in np.ndenumerate(cm):
        # White text on dark (high-value) cells, black text on light ones.
        ax.text(j, i, f'{frac:.2f}', ha='center', va='center',
                color='white' if frac > 0.5 else 'black', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.savefig('/app/final_figures/fig2_confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig2_confusion_matrices.pdf', bbox_inches='tight')
plt.close()
|
|
| |
log(" Fig 3: Model comparison...")
model_order = ['BERT-base', 'DeBERTa-v3', 'RoBERTa-FT', 'Twitter-RoBERTa']
model_labels = ['BERT-base\n(110M)', 'DeBERTa-v3\n(184M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-RoBERTa\n(125M)']
# Models without an SST-2 score (None) get a 0-height bar, left unlabelled below.
sst2_accs = [(all_metrics[m].get('sst2_accuracy') or 0) * 100 for m in model_order]
tweet_mrs = [all_metrics[m]['macro_recall'] * 100 for m in model_order]

x = np.arange(len(model_labels))
width = 0.35
fig, ax = plt.subplots(figsize=(9, 5))
b1 = ax.bar(x - width / 2, sst2_accs, width, label='SST-2 Accuracy (%)',
            color='#2196F3', edgecolor='black', linewidth=0.5)
b2 = ax.bar(x + width / 2, tweet_mrs, width, label='TweetEval Macro-Recall (%)',
            color='#FF9800', edgecolor='black', linewidth=0.5)
ax.set_ylabel('Score (%)')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, fontsize=9)
ax.set_title('Model Comparison: SST-2 vs TweetEval', fontweight='bold', fontsize=13)
ax.legend(fontsize=10)
ax.set_ylim(0, 105)
ax.grid(axis='y', alpha=0.3)
ax.axhline(y=95, color='red', linestyle='--', alpha=0.5)
for bar in (*b1, *b2):
    height = bar.get_height()
    if height <= 0:
        continue  # placeholder bar for a missing score
    ax.text(bar.get_x() + bar.get_width() / 2., height + 1, f'{height:.1f}',
            ha='center', va='bottom', fontsize=8)
plt.tight_layout()
plt.savefig('/app/final_figures/fig3_model_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig3_model_comparison.pdf', bbox_inches='tight')
plt.close()
|
|
| |
log(" Fig 4: Per-class F1...")
fig, ax = plt.subplots(figsize=(8, 4.5))
# Class ids 0..K-1, displayed under label_names.
class_ids = list(range(len(label_names)))
x = np.arange(len(class_ids)); w = 0.25
for idx, (name, preds, cmap) in enumerate(models_cm):
    # Without target_names, classification_report keys its dict by str(label)
    # ('0', '1', '2'), so looking up 'Negative' etc. raised KeyError. Pin the
    # label ids (a binary model may never predict some class) and map them to
    # the display names used for the lookup.
    report = classification_report(te_test_labels, preds, labels=class_ids,
                                   target_names=label_names, output_dict=True,
                                   zero_division=0)
    f1s = [report[c]['f1-score'] * 100 for c in label_names]
    bars = ax.bar(x + idx * w, f1s, w, label=name.replace('\n', ' '),
                  edgecolor='black', linewidth=0.5)
    for bar, v in zip(bars, f1s):
        if v > 0:
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 1,
                    f'{v:.1f}', ha='center', fontsize=7)
# Center each tick under its group of bars (x + w for three models).
ax.set_xticks(x + w * (len(models_cm) - 1) / 2)
ax.set_xticklabels(label_names)
ax.set_ylabel('F1-Score (%)')
ax.set_title('Per-Class F1 on TweetEval Sentiment', fontweight='bold')
ax.legend(fontsize=8)
ax.set_ylim(0, 100)
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('/app/final_figures/fig4_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig4_per_class_f1.pdf', bbox_inches='tight')
plt.close()
|
|
| |
log(" Fig 5: Dataset distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))

# Left panel: class counts of the TweetEval test split.
te_counts = Counter(te_test_labels)
te_values = [te_counts[0], te_counts[1], te_counts[2]]
axes[0].bar(label_names, te_values, color=['#e74c3c', '#95a5a6', '#2ecc71'],
            edgecolor='black', linewidth=0.5)
# n is derived from the data rather than hard-coded, so the title always
# agrees with the plotted counts.
axes[0].set_title(f'TweetEval Test (n={len(te_test_labels):,})', fontweight='bold')
axes[0].set_ylabel('Count')
for i, v in enumerate(te_values):
    axes[0].text(i, v + 50, f'{v}\n({v / len(te_test_labels) * 100:.1f}%)', ha='center', fontsize=8)

# Right panel: class counts of the SST-2 validation split (binary).
sst2_counts = Counter(sst2_labels)
sst2_values = [sst2_counts[0], sst2_counts[1]]
axes[1].bar(['Negative', 'Positive'], sst2_values, color=['#e74c3c', '#2ecc71'],
            edgecolor='black', linewidth=0.5)
axes[1].set_title(f'SST-2 Validation (n={len(sst2_labels):,})', fontweight='bold')
axes[1].set_ylabel('Count')
for i, v in enumerate(sst2_values):
    axes[1].text(i, v + 5, f'{v}\n({v / len(sst2_labels) * 100:.1f}%)', ha='center', fontsize=8)

plt.tight_layout()
plt.savefig('/app/final_figures/fig5_data_distribution.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/final_figures/fig5_data_distribution.pdf', bbox_inches='tight')
plt.close()
|
|
| |
# Per-model metrics as percentages: float fractions (< 1.01) are scaled by 100
# and rounded to 2 dp; strings ('params') and None (no SST-2 score) pass through.
all_models_pct = {}
for model_key, metric_map in all_metrics.items():
    all_models_pct[model_key] = {
        metric: round(value * 100, 2) if isinstance(value, float) and value < 1.01 else value
        for metric, value in metric_map.items()
    }

results = {
    'roberta_ft_seed1': {
        'macro_recall': round(mr_ft * 100, 2),
        'macro_f1': round(mf1_ft * 100, 2),
        'accuracy': round(acc_ft * 100, 2),
        # Per-epoch validation scores are hard-coded, not recomputed here.
        'val_epoch1': {'macro_recall': 72.34, 'macro_f1': 72.76},
        'val_epoch2': {'macro_recall': 75.18, 'macro_f1': 74.28},
    },
    'all_models': all_models_pct,
    'mcnemar': mcnemar,
}
with open('/app/final_results.json', 'w') as fh:
    json.dump(results, fh, indent=2)
|
|
# Console summary: one row per evaluated model.
rule = "=" * 70
log("\n" + rule)
log("FINAL RESULTS")
log(rule)
log(f"\n{'Model':<22} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}")
log("-" * 70)
for name, m in all_metrics.items():
    sst2_acc = m.get('sst2_accuracy')
    sst2_str = f"{sst2_acc*100:.2f}%" if sst2_acc else "N/A"
    log(f"{name:<22} {m['params']:>7} {sst2_str:>10} {m['macro_recall']*100:>12.2f}% {m['macro_f1']*100:>13.2f}%")
log(rule)
log("\nDONE! All figures saved to /app/final_figures/")
|
|