"""
Comprehensive evaluation on TweetEval (real social media data) + Figure generation
Generates: confusion matrices, bar charts, model comparison plots
"""
import json
import gc
import numpy as np
import time
from collections import Counter
from datasets import load_dataset
from transformers import pipeline
import evaluate
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are written to disk, never shown
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import os
os.makedirs("/app/figures", exist_ok=True)
print("="*60)
print("COMPREHENSIVE EVALUATION + FIGURE GENERATION")
print("="*60)
# ── Load TweetEval Sentiment (the REAL social media benchmark) ──
print("\n1. Loading TweetEval Sentiment (real tweets, 3-class)...")
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
print(f" Train: {len(tweeteval['train'])}")
print(f" Val: {len(tweeteval['validation'])}")
print(f" Test: {len(tweeteval['test'])}")
# Label distribution
train_labels = tweeteval['train']['label']
test_labels = tweeteval['test']['label']
label_names = ['Negative', 'Neutral', 'Positive']
print(f" Train distribution: {Counter(train_labels)}")
print(f" Test distribution: {Counter(test_labels)}")
# Also load SST-2 for comparison
sst2 = load_dataset("stanfordnlp/sst2")
# Standard TweetEval-style preprocessing: mask user handles and URLs
def preprocess_tweet(text):
    if not text:
        return ""
    return " ".join(
        '@user' if t.startswith('@') and len(t) > 1 else ('http' if t.startswith('http') else t)
        for t in text.split(" ")
    )
# ── Metrics ──
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
def compute_metrics(preds, labels, average='macro'):
    acc = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=preds, references=labels, average=average)["f1"]
    prec = precision_metric.compute(predictions=preds, references=labels, average=average)["precision"]
    rec = recall_metric.compute(predictions=preds, references=labels, average=average)["recall"]
    return {
        "accuracy": round(acc * 100, 2),
        "f1": round(f1 * 100, 2),
        "precision": round(prec * 100, 2),
        "recall": round(rec * 100, 2),
    }
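# Worked example (hypothetical inputs): with preds=[0, 1, 1] and labels=[0, 1, 0],
# both classes get F1 = 2/3, so macro-F1 equals accuracy here:
# compute_metrics([0, 1, 1], [0, 1, 0])
#   -> {'accuracy': 66.67, 'f1': 66.67, 'precision': 75.0, 'recall': 75.0}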
# ── Prepare test data ──
tweeteval_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
tweeteval_test_labels = list(tweeteval['test']['label'])
sst2_val_texts = list(sst2['validation']['sentence'])
sst2_val_labels = list(sst2['validation']['label'])
all_results = {}
all_predictions = {} # Store for confusion matrices
# ══════════════════════════════════════════════════════════════════
# MODEL 1: DeBERTa-v3-base (SST-2 fine-tuned)
# ══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("MODEL 1: DeBERTa-v3-base-SST2")
print("="*60)
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)
# TweetEval is 3-class but this DeBERTa checkpoint is binary: every tweet is
# forced to positive or negative, so the model can never predict neutral.
print(" TweetEval (binary model: every prediction forced to pos/neg)...")
t0 = time.time()
tweeteval_preds_deberta = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    if out['label'].lower() == 'positive':
        tweeteval_preds_deberta.append(2)  # positive
    else:
        tweeteval_preds_deberta.append(0)  # negative
te_time = time.time() - t0
print(f" Inference time: {te_time:.1f}s")
# Binary evaluation: exclude neutral ground truth, keep pos/neg only
tweeteval_binary_labels = [0 if l == 0 else (1 if l == 2 else -1) for l in tweeteval_test_labels]
tweeteval_binary_preds = [0 if p == 0 else 1 for p in tweeteval_preds_deberta]
# Filter out neutral ground truth for binary eval
binary_mask = [l != -1 for l in tweeteval_binary_labels]
binary_labels_filtered = [l for l, m in zip(tweeteval_binary_labels, binary_mask) if m]
binary_preds_filtered = [p for p, m in zip(tweeteval_binary_preds, binary_mask) if m]
binary_metrics = compute_metrics(binary_preds_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary (excl neutral): Acc={binary_metrics['accuracy']}% F1={binary_metrics['f1']}%")
# 3-class evaluation: the binary model never predicts neutral, so neutral recall is zero by construction
three_class_metrics = compute_metrics(tweeteval_preds_deberta, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics['accuracy']}% Macro-F1={three_class_metrics['f1']}%")
# SST-2
print(" SST-2...")
sst2_preds = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds.append(1 if out['label'].lower() == 'positive' else 0)
sst2_metrics = compute_metrics(sst2_preds, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}%")
all_results["DeBERTa-v3-base"] = {
"params": "184M",
"sst2": sst2_metrics,
"tweeteval_3class": three_class_metrics,
"tweeteval_binary": binary_metrics,
}
all_predictions["DeBERTa-v3-base"] = {
"tweeteval": tweeteval_preds_deberta,
"sst2": sst2_preds,
}
del pipe; gc.collect()
# ══════════════════════════════════════════════════════════════════
# MODEL 2: Twitter-RoBERTa (3-class native)
# ══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("MODEL 2: Twitter-RoBERTa Sentiment (3-class)")
print("="*60)
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest",
device=-1, batch_size=16, top_k=None)
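# With top_k=None the pipeline returns all class scores per input, e.g.
# (scores illustrative): [{'label': 'negative', 'score': 0.05},
#                         {'label': 'neutral', 'score': 0.15},
#                         {'label': 'positive', 'score': 0.80}]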
print(" TweetEval 3-class...")
t0 = time.time()
tweeteval_preds_roberta = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
scores = {r['label'].lower(): r['score'] for r in out}
# Map: negative=0, neutral=1, positive=2
label_scores = [
scores.get('negative', scores.get('neg', 0)),
scores.get('neutral', scores.get('neu', 0)),
scores.get('positive', scores.get('pos', 0)),
]
tweeteval_preds_roberta.append(np.argmax(label_scores))
three_class_metrics_rob = compute_metrics(tweeteval_preds_roberta, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_rob['accuracy']}% Macro-F1={three_class_metrics_rob['f1']}%")
# Binary
binary_preds_rob = [0 if p == 0 else 1 for p in tweeteval_preds_roberta]
binary_preds_rob_filtered = [p for p, m in zip(binary_preds_rob, binary_mask) if m]
binary_metrics_rob = compute_metrics(binary_preds_rob_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_rob['accuracy']}% F1={binary_metrics_rob['f1']}%")
# SST-2 (binary: take positive vs negative ignoring neutral)
sst2_preds_rob = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    pos = scores.get('positive', scores.get('pos', 0))
    neg = scores.get('negative', scores.get('neg', 0))
    sst2_preds_rob.append(1 if pos > neg else 0)
sst2_metrics_rob = compute_metrics(sst2_preds_rob, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_rob['accuracy']}% F1={sst2_metrics_rob['f1']}%")
all_results["Twitter-RoBERTa"] = {
"params": "125M",
"sst2": sst2_metrics_rob,
"tweeteval_3class": three_class_metrics_rob,
"tweeteval_binary": binary_metrics_rob,
}
all_predictions["Twitter-RoBERTa"] = {
"tweeteval": tweeteval_preds_roberta,
"sst2": sst2_preds_rob,
}
del pipe; gc.collect()
# ══════════════════════════════════════════════════════════════════
# MODEL 3: BERT-base SST-2
# ══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("MODEL 3: BERT-base-SST2")
print("="*60)
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=16)
print(" TweetEval (binary model on 3-class)...")
tweeteval_preds_bert = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1']:
        tweeteval_preds_bert.append(2)
    else:
        tweeteval_preds_bert.append(0)
three_class_metrics_bert = compute_metrics(tweeteval_preds_bert, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_bert['accuracy']}% Macro-F1={three_class_metrics_bert['f1']}%")
binary_preds_bert = [0 if p == 0 else 1 for p in tweeteval_preds_bert]
binary_preds_bert_filtered = [p for p, m in zip(binary_preds_bert, binary_mask) if m]
binary_metrics_bert = compute_metrics(binary_preds_bert_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_bert['accuracy']}% F1={binary_metrics_bert['f1']}%")
sst2_preds_bert = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds_bert.append(1 if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1'] else 0)
sst2_metrics_bert = compute_metrics(sst2_preds_bert, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_bert['accuracy']}% F1={sst2_metrics_bert['f1']}%")
all_results["BERT-base"] = {
"params": "110M",
"sst2": sst2_metrics_bert,
"tweeteval_3class": three_class_metrics_bert,
"tweeteval_binary": binary_metrics_bert,
}
all_predictions["BERT-base"] = {
"tweeteval": tweeteval_preds_bert,
"sst2": sst2_preds_bert,
}
del pipe; gc.collect()
# ══════════════════════════════════════════════════════════════════
# MODEL 4: DistilBERT SST-2
# ══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("MODEL 4: DistilBERT-SST2")
print("="*60)
pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
device=-1, batch_size=16)
tweeteval_preds_distil = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
tweeteval_preds_distil.append(2 if out['label'] == 'POSITIVE' else 0)
three_class_metrics_distil = compute_metrics(tweeteval_preds_distil, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_distil['accuracy']}% Macro-F1={three_class_metrics_distil['f1']}%")
binary_preds_distil = [0 if p == 0 else 1 for p in tweeteval_preds_distil]
binary_preds_distil_filtered = [p for p, m in zip(binary_preds_distil, binary_mask) if m]
binary_metrics_distil = compute_metrics(binary_preds_distil_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_distil['accuracy']}% F1={binary_metrics_distil['f1']}%")
sst2_preds_distil = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds_distil.append(1 if out['label'] == 'POSITIVE' else 0)
sst2_metrics_distil = compute_metrics(sst2_preds_distil, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_distil['accuracy']}% F1={sst2_metrics_distil['f1']}%")
all_results["DistilBERT"] = {
"params": "66M",
"sst2": sst2_metrics_distil,
"tweeteval_3class": three_class_metrics_distil,
"tweeteval_binary": binary_metrics_distil,
}
all_predictions["DistilBERT"] = {
"tweeteval": tweeteval_preds_distil,
"sst2": sst2_preds_distil,
}
del pipe; gc.collect()
# ══════════════════════════════════════════════════════════════════
# GENERATE FIGURES
# ══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("GENERATING FIGURES")
print("="*60)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10
# ── FIGURE 1: Confusion Matrix for Twitter-RoBERTa on TweetEval ──
print(" Fig 1: Confusion Matrix (Twitter-RoBERTa on TweetEval)...")
cm = confusion_matrix(tweeteval_test_labels, tweeteval_preds_roberta, normalize='true')
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
classes = ['Negative', 'Neutral', 'Positive']
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       xticklabels=classes, yticklabels=classes,
       ylabel='True Label', xlabel='Predicted Label',
       title='(a) Twitter-RoBERTa on TweetEval')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], '.2f'),
                ha="center", va="center",
                color="white" if cm[i, j] > 0.5 else "black", fontsize=12)
plt.tight_layout()
plt.savefig('/app/figures/fig1_confusion_roberta.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig1_confusion_roberta.pdf', bbox_inches='tight')
plt.close()
# ── FIGURE 2: Confusion Matrix for DeBERTa on TweetEval ──
print(" Fig 2: Confusion Matrix (DeBERTa on TweetEval)...")
cm2 = confusion_matrix(tweeteval_test_labels, tweeteval_preds_deberta, normalize='true')
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(cm2, interpolation='nearest', cmap='Oranges')
ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
ax.set(xticks=np.arange(cm2.shape[1]),
       yticks=np.arange(cm2.shape[0]),
       xticklabels=classes, yticklabels=classes,
       ylabel='True Label', xlabel='Predicted Label',
       title='(b) DeBERTa-v3-base on TweetEval')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
for i in range(cm2.shape[0]):
    for j in range(cm2.shape[1]):
        ax.text(j, i, format(cm2[i, j], '.2f'),
                ha="center", va="center",
                color="white" if cm2[i, j] > 0.5 else "black", fontsize=12)
plt.tight_layout()
plt.savefig('/app/figures/fig2_confusion_deberta.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig2_confusion_deberta.pdf', bbox_inches='tight')
plt.close()
# ── FIGURE 3: Model Comparison Bar Chart ──
print(" Fig 3: Model Comparison Bar Chart...")
models = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)']
sst2_accs = [
    all_results['DistilBERT']['sst2']['accuracy'],
    all_results['BERT-base']['sst2']['accuracy'],
    all_results['Twitter-RoBERTa']['sst2']['accuracy'],
    all_results['DeBERTa-v3-base']['sst2']['accuracy'],
]
tweet_f1s = [
    all_results['DistilBERT']['tweeteval_3class']['f1'],
    all_results['BERT-base']['tweeteval_3class']['f1'],
    all_results['Twitter-RoBERTa']['tweeteval_3class']['f1'],
    all_results['DeBERTa-v3-base']['tweeteval_3class']['f1'],
]
x = np.arange(len(models))
width = 0.35
fig, ax = plt.subplots(figsize=(8, 5))
bars1 = ax.bar(x - width/2, sst2_accs, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5)
bars2 = ax.bar(x + width/2, tweet_f1s, width, label='TweetEval Macro-F1 (%)', color='#FF9800', edgecolor='black', linewidth=0.5)
ax.set_ylabel('Score (%)', fontsize=12)
ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=9)
ax.legend(loc='upper left', fontsize=10)
ax.set_ylim(0, 105)
ax.grid(axis='y', alpha=0.3)
# Add value labels on bars
for bar in bars1:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
for bar in bars2:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
# Add 95% target line
ax.axhline(y=95, color='red', linestyle='--', alpha=0.7, linewidth=1)
ax.text(3.5, 95.5, '95% Target', color='red', fontsize=8, ha='right')
plt.tight_layout()
plt.savefig('/app/figures/fig3_model_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig3_model_comparison.pdf', bbox_inches='tight')
plt.close()
# ── FIGURE 4: TweetEval Class Distribution ──
print(" Fig 4: Dataset class distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
# TweetEval
te_counts = Counter(tweeteval_test_labels)
axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]],
            color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[0].set_title('TweetEval Test Set Distribution', fontsize=10, fontweight='bold')
axes[0].set_ylabel('Count')
for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]):
    axes[0].text(i, v + 20, str(v), ha='center', fontsize=9)
# SST-2
sst2_counts = Counter(sst2_val_labels)
axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]],
            color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[1].set_title('SST-2 Validation Set Distribution', fontsize=10, fontweight='bold')
axes[1].set_ylabel('Count')
for i, v in enumerate([sst2_counts[0], sst2_counts[1]]):
    axes[1].text(i, v + 5, str(v), ha='center', fontsize=9)
plt.tight_layout()
plt.savefig('/app/figures/fig4_data_distribution.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig4_data_distribution.pdf', bbox_inches='tight')
plt.close()
# ── FIGURE 5: Per-class F1 comparison ──
print(" Fig 5: Per-class F1 comparison...")
# Get per-class metrics for Twitter-RoBERTa (the 3-class model)
report_rob = classification_report(tweeteval_test_labels, tweeteval_preds_roberta,
                                   target_names=classes, output_dict=True)
# DeBERTa (binary mapped to 3-class)
report_deb = classification_report(tweeteval_test_labels, tweeteval_preds_deberta,
                                   target_names=classes, output_dict=True)
fig, ax = plt.subplots(figsize=(7, 4))
x = np.arange(3)
width = 0.35
f1_rob = [report_rob[c]['f1-score']*100 for c in classes]
f1_deb = [report_deb[c]['f1-score']*100 for c in classes]
bars1 = ax.bar(x - width/2, f1_rob, width, label='Twitter-RoBERTa', color='#2196F3', edgecolor='black', linewidth=0.5)
bars2 = ax.bar(x + width/2, f1_deb, width, label='DeBERTa-v3-base', color='#FF5722', edgecolor='black', linewidth=0.5)
ax.set_ylabel('F1-Score (%)')
ax.set_title('Per-Class F1 Scores on TweetEval Sentiment', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(classes)
ax.legend()
ax.set_ylim(0, 100)
ax.grid(axis='y', alpha=0.3)
for bar in bars1:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
for bar in bars2:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
plt.tight_layout()
plt.savefig('/app/figures/fig5_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig5_per_class_f1.pdf', bbox_inches='tight')
plt.close()
# ══════════════════════════════════════════════════════════════════
# SAVE ALL RESULTS
# ══════════════════════════════════════════════════════════════════
with open("/app/comprehensive_results.json", "w") as f:
json.dump(all_results, f, indent=2)
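# To reload the results in a later analysis step (hypothetical downstream use):
# with open("/app/comprehensive_results.json") as f:
#     results = json.load(f)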
print("\n" + "="*60)
print("COMPREHENSIVE RESULTS SUMMARY")
print("="*60)
print(f"\n{'Model':<20} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval Acc':>14} {'TweetEval F1':>13}")
print("-"*68)
for name, res in all_results.items():
    print(f"{name:<20} {res['params']:>7} {res['sst2']['accuracy']:>9.2f}% {res['tweeteval_3class']['accuracy']:>13.2f}% {res['tweeteval_3class']['f1']:>12.2f}%")
print("="*68)
print(f"\nFigures saved to /app/figures/")
print(f" fig1_confusion_roberta.png/pdf")
print(f" fig2_confusion_deberta.png/pdf")
print(f" fig3_model_comparison.png/pdf")
print(f" fig4_data_distribution.png/pdf")
print(f" fig5_per_class_f1.png/pdf")
print(f"\nResults saved to /app/comprehensive_results.json")