""" Comprehensive evaluation on TweetEval (real social media data) + Figure generation Generates: confusion matrices, bar charts, model comparison plots """ import json import gc import numpy as np import time from collections import Counter from datasets import load_dataset from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification import evaluate import torch import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.patches as mpatches from sklearn.metrics import confusion_matrix, classification_report import os os.makedirs("/app/figures", exist_ok=True) print("="*60) print("COMPREHENSIVE EVALUATION + FIGURE GENERATION") print("="*60) # ── Load TweetEval Sentiment (the REAL social media benchmark) ── print("\n1. Loading TweetEval Sentiment (real tweets, 3-class)...") tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment") print(f" Train: {len(tweeteval['train'])}") print(f" Val: {len(tweeteval['validation'])}") print(f" Test: {len(tweeteval['test'])}") # Label distribution train_labels = tweeteval['train']['label'] test_labels = tweeteval['test']['label'] label_names = ['Negative', 'Neutral', 'Positive'] print(f" Train distribution: {Counter(train_labels)}") print(f" Test distribution: {Counter(test_labels)}") # Also load SST-2 for comparison sst2 = load_dataset("stanfordnlp/sst2") def preprocess_tweet(text): if not text: return "" return " ".join( '@user' if t.startswith('@') and len(t) > 1 else ('http' if t.startswith('http') else t) for t in text.split(" ") ) # ── Metrics ── accuracy_metric = evaluate.load("accuracy") f1_metric = evaluate.load("f1") precision_metric = evaluate.load("precision") recall_metric = evaluate.load("recall") def compute_metrics(preds, labels, average='macro'): acc = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"] f1 = f1_metric.compute(predictions=preds, references=labels, average=average)["f1"] prec = precision_metric.compute(predictions=preds, references=labels, average=average)["precision"] rec = recall_metric.compute(predictions=preds, references=labels, average=average)["recall"] return { "accuracy": round(acc * 100, 2), "f1": round(f1 * 100, 2), "precision": round(prec * 100, 2), "recall": round(rec * 100, 2), } # ── Prepare test data ── tweeteval_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])] tweeteval_test_labels = list(tweeteval['test']['label']) sst2_val_texts = list(sst2['validation']['sentence']) sst2_val_labels = list(sst2['validation']['label']) all_results = {} all_predictions = {} # Store for confusion matrices # ══════════════════════════════════════════════════════════════════ # MODEL 1: DeBERTa-v3-base (SST-2 fine-tuned) # ══════════════════════════════════════════════════════════════════ print("\n" + "="*60) print("MODEL 1: DeBERTa-v3-base-SST2") print("="*60) pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16) # TweetEval 3-class: DeBERTa is binary, so we map neutral to lower-confidence class print(" TweetEval (binary mapping: neutral->negative if neg_score>pos_score)...") t0 = time.time() tweeteval_preds_deberta = [] # For binary model on 3-class: predict pos/neg, mark neutral based on low confidence for out in pipe(tweeteval_test_texts, truncation=True, max_length=128): if out['label'].lower() == 'positive': tweeteval_preds_deberta.append(2) # positive else: tweeteval_preds_deberta.append(0) # negative te_time = time.time() - t0 # Binary evaluation (collapse neutral) tweeteval_binary_labels = [0 if l == 0 else (1 if l == 2 else -1) for l in tweeteval_test_labels] tweeteval_binary_preds = [0 if p == 0 else 1 for p in tweeteval_preds_deberta] # Filter out neutral ground truth for binary eval binary_mask = [l != -1 for l in tweeteval_binary_labels] binary_labels_filtered = [l for l, m in zip(tweeteval_binary_labels, binary_mask) if m] binary_preds_filtered = [p for p, m in zip(tweeteval_binary_preds, binary_mask) if m] binary_metrics = compute_metrics(binary_preds_filtered, binary_labels_filtered, 'weighted') print(f" TweetEval Binary (excl neutral): Acc={binary_metrics['accuracy']}% F1={binary_metrics['f1']}%") # 3-class evaluation (no neutral class for binary model, so poor on neutral) three_class_metrics = compute_metrics(tweeteval_preds_deberta, tweeteval_test_labels, 'macro') print(f" TweetEval 3-class: Acc={three_class_metrics['accuracy']}% Macro-F1={three_class_metrics['f1']}%") # SST-2 print(" SST-2...") sst2_preds = [] for out in pipe(sst2_val_texts, truncation=True, max_length=128): sst2_preds.append(1 if out['label'].lower() == 'positive' else 0) sst2_metrics = compute_metrics(sst2_preds, sst2_val_labels, 'weighted') print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}%") all_results["DeBERTa-v3-base"] = { "params": "184M", "sst2": sst2_metrics, "tweeteval_3class": three_class_metrics, "tweeteval_binary": binary_metrics, } all_predictions["DeBERTa-v3-base"] = { "tweeteval": tweeteval_preds_deberta, "sst2": sst2_preds, } del pipe; gc.collect() # ══════════════════════════════════════════════════════════════════ # MODEL 2: Twitter-RoBERTa (3-class native) # ══════════════════════════════════════════════════════════════════ print("\n" + "="*60) print("MODEL 2: Twitter-RoBERTa Sentiment (3-class)") print("="*60) pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=-1, batch_size=16, top_k=None) print(" TweetEval 3-class...") t0 = time.time() tweeteval_preds_roberta = [] for out in pipe(tweeteval_test_texts, truncation=True, max_length=128): scores = {r['label'].lower(): r['score'] for r in out} # Map: negative=0, neutral=1, positive=2 label_scores = [ scores.get('negative', scores.get('neg', 0)), scores.get('neutral', scores.get('neu', 0)), scores.get('positive', scores.get('pos', 0)), ] tweeteval_preds_roberta.append(np.argmax(label_scores)) three_class_metrics_rob = compute_metrics(tweeteval_preds_roberta, tweeteval_test_labels, 'macro') print(f" TweetEval 3-class: Acc={three_class_metrics_rob['accuracy']}% Macro-F1={three_class_metrics_rob['f1']}%") # Binary binary_preds_rob = [0 if p == 0 else 1 for p in tweeteval_preds_roberta] binary_preds_rob_filtered = [p for p, m in zip(binary_preds_rob, binary_mask) if m] binary_metrics_rob = compute_metrics(binary_preds_rob_filtered, binary_labels_filtered, 'weighted') print(f" TweetEval Binary: Acc={binary_metrics_rob['accuracy']}% F1={binary_metrics_rob['f1']}%") # SST-2 (binary: take positive vs negative ignoring neutral) sst2_preds_rob = [] for out in pipe(sst2_val_texts, truncation=True, max_length=128): scores = {r['label'].lower(): r['score'] for r in out} pos = scores.get('positive', scores.get('pos', 0)) neg = scores.get('negative', scores.get('neg', 0)) sst2_preds_rob.append(1 if pos > neg else 0) sst2_metrics_rob = compute_metrics(sst2_preds_rob, sst2_val_labels, 'weighted') print(f" SST-2: Acc={sst2_metrics_rob['accuracy']}% F1={sst2_metrics_rob['f1']}%") all_results["Twitter-RoBERTa"] = { "params": "125M", "sst2": sst2_metrics_rob, "tweeteval_3class": three_class_metrics_rob, "tweeteval_binary": binary_metrics_rob, } all_predictions["Twitter-RoBERTa"] = { "tweeteval": tweeteval_preds_roberta, "sst2": sst2_preds_rob, } del pipe; gc.collect() # ══════════════════════════════════════════════════════════════════ # MODEL 3: BERT-base SST-2 # ══════════════════════════════════════════════════════════════════ print("\n" + "="*60) print("MODEL 3: BERT-base-SST2") print("="*60) pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=16) print(" TweetEval (binary model on 3-class)...") tweeteval_preds_bert = [] for out in pipe(tweeteval_test_texts, truncation=True, max_length=128): if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1']: tweeteval_preds_bert.append(2) else: tweeteval_preds_bert.append(0) three_class_metrics_bert = compute_metrics(tweeteval_preds_bert, tweeteval_test_labels, 'macro') print(f" TweetEval 3-class: Acc={three_class_metrics_bert['accuracy']}% Macro-F1={three_class_metrics_bert['f1']}%") binary_preds_bert = [0 if p == 0 else 1 for p in tweeteval_preds_bert] binary_preds_bert_filtered = [p for p, m in zip(binary_preds_bert, binary_mask) if m] binary_metrics_bert = compute_metrics(binary_preds_bert_filtered, binary_labels_filtered, 'weighted') print(f" TweetEval Binary: Acc={binary_metrics_bert['accuracy']}% F1={binary_metrics_bert['f1']}%") sst2_preds_bert = [] for out in pipe(sst2_val_texts, truncation=True, max_length=128): sst2_preds_bert.append(1 if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1'] else 0) sst2_metrics_bert = compute_metrics(sst2_preds_bert, sst2_val_labels, 'weighted') print(f" SST-2: Acc={sst2_metrics_bert['accuracy']}% F1={sst2_metrics_bert['f1']}%") all_results["BERT-base"] = { "params": "110M", "sst2": sst2_metrics_bert, "tweeteval_3class": three_class_metrics_bert, "tweeteval_binary": binary_metrics_bert, } all_predictions["BERT-base"] = { "tweeteval": tweeteval_preds_bert, "sst2": sst2_preds_bert, } del pipe; gc.collect() # ══════════════════════════════════════════════════════════════════ # MODEL 4: DistilBERT SST-2 # ══════════════════════════════════════════════════════════════════ print("\n" + "="*60) print("MODEL 4: DistilBERT-SST2") print("="*60) pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english", device=-1, batch_size=16) tweeteval_preds_distil = [] for out in pipe(tweeteval_test_texts, truncation=True, max_length=128): tweeteval_preds_distil.append(2 if out['label'] == 'POSITIVE' else 0) three_class_metrics_distil = compute_metrics(tweeteval_preds_distil, tweeteval_test_labels, 'macro') print(f" TweetEval 3-class: Acc={three_class_metrics_distil['accuracy']}% Macro-F1={three_class_metrics_distil['f1']}%") binary_preds_distil = [0 if p == 0 else 1 for p in tweeteval_preds_distil] binary_preds_distil_filtered = [p for p, m in zip(binary_preds_distil, binary_mask) if m] binary_metrics_distil = compute_metrics(binary_preds_distil_filtered, binary_labels_filtered, 'weighted') print(f" TweetEval Binary: Acc={binary_metrics_distil['accuracy']}% F1={binary_metrics_distil['f1']}%") sst2_preds_distil = [] for out in pipe(sst2_val_texts, truncation=True, max_length=128): sst2_preds_distil.append(1 if out['label'] == 'POSITIVE' else 0) sst2_metrics_distil = compute_metrics(sst2_preds_distil, sst2_val_labels, 'weighted') print(f" SST-2: Acc={sst2_metrics_distil['accuracy']}% F1={sst2_metrics_distil['f1']}%") all_results["DistilBERT"] = { "params": "66M", "sst2": sst2_metrics_distil, "tweeteval_3class": three_class_metrics_distil, "tweeteval_binary": binary_metrics_distil, } all_predictions["DistilBERT"] = { "tweeteval": tweeteval_preds_distil, "sst2": sst2_preds_distil, } del pipe; gc.collect() # ══════════════════════════════════════════════════════════════════ # GENERATE FIGURES # ══════════════════════════════════════════════════════════════════ print("\n" + "="*60) print("GENERATING FIGURES") print("="*60) plt.rcParams['font.family'] = 'serif' plt.rcParams['font.size'] = 10 # ── FIGURE 1: Confusion Matrix for Twitter-RoBERTa on TweetEval ── print(" Fig 1: Confusion Matrix (Twitter-RoBERTa on TweetEval)...") cm = confusion_matrix(tweeteval_test_labels, tweeteval_preds_roberta, normalize='true') fig, ax = plt.subplots(figsize=(5, 4)) im = ax.imshow(cm, interpolation='nearest', cmap='Blues') ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04) classes = ['Negative', 'Neutral', 'Positive'] ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]), xticklabels=classes, yticklabels=classes, ylabel='True Label', xlabel='Predicted Label', title='(a) Twitter-RoBERTa on TweetEval') plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], '.2f'), ha="center", va="center", color="white" if cm[i, j] > 0.5 else "black", fontsize=12) plt.tight_layout() plt.savefig('/app/figures/fig1_confusion_roberta.png', dpi=300, bbox_inches='tight') plt.savefig('/app/figures/fig1_confusion_roberta.pdf', bbox_inches='tight') plt.close() # ── FIGURE 2: Confusion Matrix for DeBERTa on TweetEval ── print(" Fig 2: Confusion Matrix (DeBERTa on TweetEval)...") cm2 = confusion_matrix(tweeteval_test_labels, tweeteval_preds_deberta, normalize='true') fig, ax = plt.subplots(figsize=(5, 4)) im = ax.imshow(cm2, interpolation='nearest', cmap='Oranges') ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04) ax.set(xticks=np.arange(cm2.shape[1]), yticks=np.arange(cm2.shape[0]), xticklabels=classes, yticklabels=classes, ylabel='True Label', xlabel='Predicted Label', title='(b) DeBERTa-v3-base on TweetEval') plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") for i in range(cm2.shape[0]): for j in range(cm2.shape[1]): ax.text(j, i, format(cm2[i, j], '.2f'), ha="center", va="center", color="white" if cm2[i, j] > 0.5 else "black", fontsize=12) plt.tight_layout() plt.savefig('/app/figures/fig2_confusion_deberta.png', dpi=300, bbox_inches='tight') plt.savefig('/app/figures/fig2_confusion_deberta.pdf', bbox_inches='tight') plt.close() # ── FIGURE 3: Model Comparison Bar Chart ── print(" Fig 3: Model Comparison Bar Chart...") models = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)'] sst2_accs = [ all_results['DistilBERT']['sst2']['accuracy'], all_results['BERT-base']['sst2']['accuracy'], all_results['Twitter-RoBERTa']['sst2']['accuracy'], all_results['DeBERTa-v3-base']['sst2']['accuracy'], ] tweet_f1s = [ all_results['DistilBERT']['tweeteval_3class']['f1'], all_results['BERT-base']['tweeteval_3class']['f1'], all_results['Twitter-RoBERTa']['tweeteval_3class']['f1'], all_results['DeBERTa-v3-base']['tweeteval_3class']['f1'], ] x = np.arange(len(models)) width = 0.35 fig, ax = plt.subplots(figsize=(8, 5)) bars1 = ax.bar(x - width/2, sst2_accs, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5) bars2 = ax.bar(x + width/2, tweet_f1s, width, label='TweetEval Macro-F1 (%)', color='#FF9800', edgecolor='black', linewidth=0.5) ax.set_ylabel('Score (%)', fontsize=12) ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontsize=13, fontweight='bold') ax.set_xticks(x) ax.set_xticklabels(models, fontsize=9) ax.legend(loc='upper left', fontsize=10) ax.set_ylim(0, 105) ax.grid(axis='y', alpha=0.3) # Add value labels on bars for bar in bars1: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8) for bar in bars2: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8) # Add 95% target line ax.axhline(y=95, color='red', linestyle='--', alpha=0.7, linewidth=1) ax.text(3.5, 95.5, '95% Target', color='red', fontsize=8, ha='right') plt.tight_layout() plt.savefig('/app/figures/fig3_model_comparison.png', dpi=300, bbox_inches='tight') plt.savefig('/app/figures/fig3_model_comparison.pdf', bbox_inches='tight') plt.close() # ── FIGURE 4: TweetEval Class Distribution ── print(" Fig 4: Dataset class distribution...") fig, axes = plt.subplots(1, 2, figsize=(8, 3.5)) # TweetEval te_counts = Counter(tweeteval_test_labels) axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]], color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5) axes[0].set_title('TweetEval Test Set Distribution', fontsize=10, fontweight='bold') axes[0].set_ylabel('Count') for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]): axes[0].text(i, v + 20, str(v), ha='center', fontsize=9) # SST-2 sst2_counts = Counter(sst2_val_labels) axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]], color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5) axes[1].set_title('SST-2 Validation Set Distribution', fontsize=10, fontweight='bold') axes[1].set_ylabel('Count') for i, v in enumerate([sst2_counts[0], sst2_counts[1]]): axes[1].text(i, v + 5, str(v), ha='center', fontsize=9) plt.tight_layout() plt.savefig('/app/figures/fig4_data_distribution.png', dpi=300, bbox_inches='tight') plt.savefig('/app/figures/fig4_data_distribution.pdf', bbox_inches='tight') plt.close() # ── FIGURE 5: Per-class F1 comparison ── print(" Fig 5: Per-class F1 comparison...") # Get per-class metrics for Twitter-RoBERTa (the 3-class model) report_rob = classification_report(tweeteval_test_labels, tweeteval_preds_roberta, target_names=classes, output_dict=True) # DeBERTa (binary mapped to 3-class) report_deb = classification_report(tweeteval_test_labels, tweeteval_preds_deberta, target_names=classes, output_dict=True) fig, ax = plt.subplots(figsize=(7, 4)) x = np.arange(3) width = 0.35 f1_rob = [report_rob[c]['f1-score']*100 for c in classes] f1_deb = [report_deb[c]['f1-score']*100 for c in classes] bars1 = ax.bar(x - width/2, f1_rob, width, label='Twitter-RoBERTa', color='#2196F3', edgecolor='black', linewidth=0.5) bars2 = ax.bar(x + width/2, f1_deb, width, label='DeBERTa-v3-base', color='#FF5722', edgecolor='black', linewidth=0.5) ax.set_ylabel('F1-Score (%)') ax.set_title('Per-Class F1 Scores on TweetEval Sentiment', fontsize=12, fontweight='bold') ax.set_xticks(x) ax.set_xticklabels(classes) ax.legend() ax.set_ylim(0, 100) ax.grid(axis='y', alpha=0.3) for bar in bars1: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8) for bar in bars2: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8) plt.tight_layout() plt.savefig('/app/figures/fig5_per_class_f1.png', dpi=300, bbox_inches='tight') plt.savefig('/app/figures/fig5_per_class_f1.pdf', bbox_inches='tight') plt.close() # ══════════════════════════════════════════════════════════════════ # SAVE ALL RESULTS # ══════════════════════════════════════════════════════════════════ with open("/app/comprehensive_results.json", "w") as f: json.dump(all_results, f, indent=2) print("\n" + "="*60) print("COMPREHENSIVE RESULTS SUMMARY") print("="*60) print(f"\n{'Model':<20} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval Acc':>14} {'TweetEval F1':>13}") print("-"*68) for name, res in all_results.items(): print(f"{name:<20} {res['params']:>7} {res['sst2']['accuracy']:>9.2f}% {res['tweeteval_3class']['accuracy']:>13.2f}% {res['tweeteval_3class']['f1']:>12.2f}%") print("="*68) print(f"\nFigures saved to /app/figures/") print(f" fig1_confusion_roberta.png/pdf") print(f" fig2_confusion_deberta.png/pdf") print(f" fig3_model_comparison.png/pdf") print(f" fig4_data_distribution.png/pdf") print(f" fig5_per_class_f1.png/pdf") print(f"\nResults saved to /app/comprehensive_results.json")