| """ |
| Comprehensive evaluation on TweetEval (real social media data) + Figure generation |
| Generates: confusion matrices, bar charts, model comparison plots |
| """ |
|
|
| import json |
| import gc |
| import numpy as np |
| import time |
| from collections import Counter |
| from datasets import load_dataset |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
| import evaluate |
| import torch |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| from sklearn.metrics import confusion_matrix, classification_report |
| import os |
|
|
# Output directory for all generated figures (PNG + PDF pairs).
os.makedirs("/app/figures", exist_ok=True)


print("="*60)
print("COMPREHENSIVE EVALUATION + FIGURE GENERATION")
print("="*60)


# --- Dataset 1: TweetEval sentiment (real tweets; labels 0=neg, 1=neu, 2=pos) ---
print("\n1. Loading TweetEval Sentiment (real tweets, 3-class)...")
tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment")
print(f" Train: {len(tweeteval['train'])}")
print(f" Val: {len(tweeteval['validation'])}")
print(f" Test: {len(tweeteval['test'])}")


# Report class balance so skew is visible before interpreting accuracy.
train_labels = tweeteval['train']['label']
test_labels = tweeteval['test']['label']
label_names = ['Negative', 'Neutral', 'Positive']
print(f" Train distribution: {Counter(train_labels)}")
print(f" Test distribution: {Counter(test_labels)}")


# --- Dataset 2: SST-2 (binary sentiment; labels 0=neg, 1=pos) ---
sst2 = load_dataset("stanfordnlp/sst2")
|
|
def preprocess_tweet(text):
    """Normalize a tweet into the form CardiffNLP Twitter models expect.

    Whitespace-split tokens that look like user mentions (start with '@'
    and have more than one character) are replaced with the placeholder
    '@user'; tokens starting with 'http' become the placeholder 'http'.
    Falsy input (None / empty string) yields "".
    """
    if not text:
        return ""
    normalized = []
    for token in text.split(" "):
        if token.startswith('@') and len(token) > 1:
            normalized.append('@user')
        elif token.startswith('http'):
            normalized.append('http')
        else:
            normalized.append(token)
    return " ".join(normalized)
|
|
| |
# HF `evaluate` metric objects, loaded once and reused by compute_metrics()
# (downloaded/cached on first use).
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
|
|
def compute_metrics(preds, labels, average='macro'):
    """Score predictions against labels with the shared `evaluate` metrics.

    Returns accuracy, F1, precision and recall as percentages rounded to
    two decimals. `average` is forwarded to the F1/precision/recall
    metrics ('macro' for multi-class, 'weighted' for binary runs).
    """
    inputs = {"predictions": preds, "references": labels}
    raw = {
        "accuracy": accuracy_metric.compute(**inputs)["accuracy"],
        "f1": f1_metric.compute(average=average, **inputs)["f1"],
        "precision": precision_metric.compute(average=average, **inputs)["precision"],
        "recall": recall_metric.compute(average=average, **inputs)["recall"],
    }
    return {name: round(score * 100, 2) for name, score in raw.items()}
|
|
| |
# Materialize the evaluation splits once; tweets get the CardiffNLP-style
# normalization so every model below sees identical inputs.
tweeteval_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])]
tweeteval_test_labels = list(tweeteval['test']['label'])
sst2_val_texts = list(sst2['validation']['sentence'])
sst2_val_labels = list(sst2['validation']['label'])


all_results = {}      # model name -> metric dicts (dumped to JSON at the end)
all_predictions = {}  # model name -> raw per-example predictions (in-memory only)
|
|
| |
| |
| |
print("\n" + "="*60)
print("MODEL 1: DeBERTa-v3-base-SST2")
print("="*60)


# Binary (pos/neg) SST-2 fine-tune, evaluated out-of-domain on tweets.
pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16)


# A binary model cannot predict "neutral": every non-positive prediction is
# mapped to negative (TweetEval class 0), positives to class 2.
print(" TweetEval (binary model: non-positive predictions mapped to negative)...")
t0 = time.time()
tweeteval_preds_deberta = []

for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    if out['label'].lower() == 'positive':
        tweeteval_preds_deberta.append(2)
    else:
        tweeteval_preds_deberta.append(0)
te_time = time.time() - t0
# Report the measured inference time (previously computed but never shown).
print(f" TweetEval inference time: {te_time:.1f}s")


# Binary view of TweetEval: neutral examples (label 1) are marked -1 and
# excluded via binary_mask; these filtered lists are reused by Models 2-4.
tweeteval_binary_labels = [0 if l == 0 else (1 if l == 2 else -1) for l in tweeteval_test_labels]
tweeteval_binary_preds = [0 if p == 0 else 1 for p in tweeteval_preds_deberta]

binary_mask = [l != -1 for l in tweeteval_binary_labels]
binary_labels_filtered = [l for l, m in zip(tweeteval_binary_labels, binary_mask) if m]
binary_preds_filtered = [p for p, m in zip(tweeteval_binary_preds, binary_mask) if m]


binary_metrics = compute_metrics(binary_preds_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary (excl neutral): Acc={binary_metrics['accuracy']}% F1={binary_metrics['f1']}%")


# 3-class scoring: since the model never predicts neutral, neutral recall is 0
# by construction, which depresses macro-F1.
three_class_metrics = compute_metrics(tweeteval_preds_deberta, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics['accuracy']}% Macro-F1={three_class_metrics['f1']}%")


# In-domain evaluation on SST-2 validation.
print(" SST-2...")
sst2_preds = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds.append(1 if out['label'].lower() == 'positive' else 0)
sst2_metrics = compute_metrics(sst2_preds, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}%")


all_results["DeBERTa-v3-base"] = {
    "params": "184M",
    "sst2": sst2_metrics,
    "tweeteval_3class": three_class_metrics,
    "tweeteval_binary": binary_metrics,
}
all_predictions["DeBERTa-v3-base"] = {
    "tweeteval": tweeteval_preds_deberta,
    "sst2": sst2_preds,
}
# Free the model before loading the next one (CPU memory).
del pipe; gc.collect()
|
|
| |
| |
| |
print("\n" + "="*60)
print("MODEL 2: Twitter-RoBERTa Sentiment (3-class)")
print("="*60)


# Native 3-class tweet-sentiment model; top_k=None makes the pipeline return
# the score for every class (a list of {label, score} dicts per input).
pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                device=-1, batch_size=16, top_k=None)


print(" TweetEval 3-class...")
t0 = time.time()  # NOTE(review): t0 is never read after this — dead timing start
tweeteval_preds_roberta = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    # Order matches TweetEval label ids: 0=negative, 1=neutral, 2=positive.
    # The .get() fallbacks cover short label variants ('neg'/'neu'/'pos').
    label_scores = [
        scores.get('negative', scores.get('neg', 0)),
        scores.get('neutral', scores.get('neu', 0)),
        scores.get('positive', scores.get('pos', 0)),
    ]
    tweeteval_preds_roberta.append(np.argmax(label_scores))


three_class_metrics_rob = compute_metrics(tweeteval_preds_roberta, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_rob['accuracy']}% Macro-F1={three_class_metrics_rob['f1']}%")


# Binary view: reuse the neutral-exclusion mask built in the Model 1 section.
binary_preds_rob = [0 if p == 0 else 1 for p in tweeteval_preds_roberta]
binary_preds_rob_filtered = [p for p, m in zip(binary_preds_rob, binary_mask) if m]
binary_metrics_rob = compute_metrics(binary_preds_rob_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_rob['accuracy']}% F1={binary_metrics_rob['f1']}%")


# SST-2 is binary, so decide positive vs negative by comparing those two
# class scores directly (the model's neutral score is ignored).
sst2_preds_rob = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    scores = {r['label'].lower(): r['score'] for r in out}
    pos = scores.get('positive', scores.get('pos', 0))
    neg = scores.get('negative', scores.get('neg', 0))
    sst2_preds_rob.append(1 if pos > neg else 0)
sst2_metrics_rob = compute_metrics(sst2_preds_rob, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_rob['accuracy']}% F1={sst2_metrics_rob['f1']}%")


all_results["Twitter-RoBERTa"] = {
    "params": "125M",
    "sst2": sst2_metrics_rob,
    "tweeteval_3class": three_class_metrics_rob,
    "tweeteval_binary": binary_metrics_rob,
}
all_predictions["Twitter-RoBERTa"] = {
    "tweeteval": tweeteval_preds_roberta,
    "sst2": sst2_preds_rob,
}
# Free the model before loading the next one.
del pipe; gc.collect()
|
|
| |
| |
| |
print("\n" + "="*60)
print("MODEL 3: BERT-base-SST2")
print("="*60)


# Binary SST-2 fine-tune; same non-positive -> negative mapping as Model 1.
pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=16)


print(" TweetEval (binary model on 3-class)...")
tweeteval_preds_bert = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    # textattack checkpoints may emit 'LABEL_1'-style names; treat any of
    # these variants as the positive class.
    if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1']:
        tweeteval_preds_bert.append(2)
    else:
        tweeteval_preds_bert.append(0)


# Neutral is never predicted, so macro-F1 is depressed by construction.
three_class_metrics_bert = compute_metrics(tweeteval_preds_bert, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_bert['accuracy']}% Macro-F1={three_class_metrics_bert['f1']}%")


# Binary view: reuse the neutral-exclusion mask from the Model 1 section.
binary_preds_bert = [0 if p == 0 else 1 for p in tweeteval_preds_bert]
binary_preds_bert_filtered = [p for p, m in zip(binary_preds_bert, binary_mask) if m]
binary_metrics_bert = compute_metrics(binary_preds_bert_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_bert['accuracy']}% F1={binary_metrics_bert['f1']}%")


# In-domain evaluation on SST-2 validation.
sst2_preds_bert = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds_bert.append(1 if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1'] else 0)
sst2_metrics_bert = compute_metrics(sst2_preds_bert, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_bert['accuracy']}% F1={sst2_metrics_bert['f1']}%")


all_results["BERT-base"] = {
    "params": "110M",
    "sst2": sst2_metrics_bert,
    "tweeteval_3class": three_class_metrics_bert,
    "tweeteval_binary": binary_metrics_bert,
}
all_predictions["BERT-base"] = {
    "tweeteval": tweeteval_preds_bert,
    "sst2": sst2_preds_bert,
}
# Free the model before loading the next one.
del pipe; gc.collect()
|
|
| |
| |
| |
print("\n" + "="*60)
print("MODEL 4: DistilBERT-SST2")
print("="*60)


# Smallest model in the comparison; binary SST-2 fine-tune.
pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
                device=-1, batch_size=16)


# This checkpoint emits 'POSITIVE'/'NEGATIVE' labels; non-positive -> class 0.
tweeteval_preds_distil = []
for out in pipe(tweeteval_test_texts, truncation=True, max_length=128):
    tweeteval_preds_distil.append(2 if out['label'] == 'POSITIVE' else 0)


three_class_metrics_distil = compute_metrics(tweeteval_preds_distil, tweeteval_test_labels, 'macro')
print(f" TweetEval 3-class: Acc={three_class_metrics_distil['accuracy']}% Macro-F1={three_class_metrics_distil['f1']}%")


# Binary view: reuse the neutral-exclusion mask from the Model 1 section.
binary_preds_distil = [0 if p == 0 else 1 for p in tweeteval_preds_distil]
binary_preds_distil_filtered = [p for p, m in zip(binary_preds_distil, binary_mask) if m]
binary_metrics_distil = compute_metrics(binary_preds_distil_filtered, binary_labels_filtered, 'weighted')
print(f" TweetEval Binary: Acc={binary_metrics_distil['accuracy']}% F1={binary_metrics_distil['f1']}%")


# In-domain evaluation on SST-2 validation.
sst2_preds_distil = []
for out in pipe(sst2_val_texts, truncation=True, max_length=128):
    sst2_preds_distil.append(1 if out['label'] == 'POSITIVE' else 0)
sst2_metrics_distil = compute_metrics(sst2_preds_distil, sst2_val_labels, 'weighted')
print(f" SST-2: Acc={sst2_metrics_distil['accuracy']}% F1={sst2_metrics_distil['f1']}%")


all_results["DistilBERT"] = {
    "params": "66M",
    "sst2": sst2_metrics_distil,
    "tweeteval_3class": three_class_metrics_distil,
    "tweeteval_binary": binary_metrics_distil,
}
all_predictions["DistilBERT"] = {
    "tweeteval": tweeteval_preds_distil,
    "sst2": sst2_preds_distil,
}
# Free the last model before figure generation.
del pipe; gc.collect()
|
|
| |
| |
| |
print("\n" + "="*60)
print("GENERATING FIGURES")
print("="*60)


# Paper-style defaults for all figures below.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10


# Fig 1: row-normalized confusion matrix (each row sums to 1) for the
# 3-class Twitter-RoBERTa predictions on the TweetEval test set.
print(" Fig 1: Confusion Matrix (Twitter-RoBERTa on TweetEval)...")
cm = confusion_matrix(tweeteval_test_labels, tweeteval_preds_roberta, normalize='true')
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
classes = ['Negative', 'Neutral', 'Positive']
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       xticklabels=classes, yticklabels=classes,
       ylabel='True Label', xlabel='Predicted Label',
       title='(a) Twitter-RoBERTa on TweetEval')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# Annotate each cell; switch to white text on dark (high-value) cells.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], '.2f'),
                ha="center", va="center",
                color="white" if cm[i, j] > 0.5 else "black", fontsize=12)
plt.tight_layout()
plt.savefig('/app/figures/fig1_confusion_roberta.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig1_confusion_roberta.pdf', bbox_inches='tight')
plt.close()
|
|
| |
# Fig 2: same row-normalized confusion matrix for the binary DeBERTa model
# (expect an empty "Neutral" predicted column — it never predicts class 1).
print(" Fig 2: Confusion Matrix (DeBERTa on TweetEval)...")
cm2 = confusion_matrix(tweeteval_test_labels, tweeteval_preds_deberta, normalize='true')
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(cm2, interpolation='nearest', cmap='Oranges')
ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
ax.set(xticks=np.arange(cm2.shape[1]),
       yticks=np.arange(cm2.shape[0]),
       xticklabels=classes, yticklabels=classes,
       ylabel='True Label', xlabel='Predicted Label',
       title='(b) DeBERTa-v3-base on TweetEval')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# Annotate each cell; white text on dark cells for readability.
for i in range(cm2.shape[0]):
    for j in range(cm2.shape[1]):
        ax.text(j, i, format(cm2[i, j], '.2f'),
                ha="center", va="center",
                color="white" if cm2[i, j] > 0.5 else "black", fontsize=12)
plt.tight_layout()
plt.savefig('/app/figures/fig2_confusion_deberta.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig2_confusion_deberta.pdf', bbox_inches='tight')
plt.close()
|
|
| |
# Fig 3: grouped bar chart — in-domain SST-2 accuracy vs out-of-domain
# TweetEval macro-F1, models ordered by parameter count.
print(" Fig 3: Model Comparison Bar Chart...")
models = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)']
sst2_accs = [
    all_results['DistilBERT']['sst2']['accuracy'],
    all_results['BERT-base']['sst2']['accuracy'],
    all_results['Twitter-RoBERTa']['sst2']['accuracy'],
    all_results['DeBERTa-v3-base']['sst2']['accuracy'],
]
tweet_f1s = [
    all_results['DistilBERT']['tweeteval_3class']['f1'],
    all_results['BERT-base']['tweeteval_3class']['f1'],
    all_results['Twitter-RoBERTa']['tweeteval_3class']['f1'],
    all_results['DeBERTa-v3-base']['tweeteval_3class']['f1'],
]


x = np.arange(len(models))
width = 0.35  # bar width; two bars per model, offset by +/- width/2


fig, ax = plt.subplots(figsize=(8, 5))
bars1 = ax.bar(x - width/2, sst2_accs, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5)
bars2 = ax.bar(x + width/2, tweet_f1s, width, label='TweetEval Macro-F1 (%)', color='#FF9800', edgecolor='black', linewidth=0.5)


ax.set_ylabel('Score (%)', fontsize=12)
ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=9)
ax.legend(loc='upper left', fontsize=10)
ax.set_ylim(0, 105)  # headroom above 100 for the value labels
ax.grid(axis='y', alpha=0.3)


# Print each bar's value just above it.
for bar in bars1:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
for bar in bars2:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 0.5, f'{h:.1f}', ha='center', va='bottom', fontsize=8)


# Dashed reference line at the 95% target (x=3.5 is the right edge for 4 models).
ax.axhline(y=95, color='red', linestyle='--', alpha=0.7, linewidth=1)
ax.text(3.5, 95.5, '95% Target', color='red', fontsize=8, ha='right')


plt.tight_layout()
plt.savefig('/app/figures/fig3_model_comparison.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig3_model_comparison.pdf', bbox_inches='tight')
plt.close()
|
|
| |
# Fig 4: side-by-side class-count bars for the two evaluation splits.
print(" Fig 4: Dataset class distribution...")
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))


# Left panel: TweetEval test set (3 classes, indexed 0=neg, 1=neu, 2=pos).
te_counts = Counter(tweeteval_test_labels)
axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]],
            color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[0].set_title('TweetEval Test Set Distribution', fontsize=10, fontweight='bold')
axes[0].set_ylabel('Count')
for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]):
    axes[0].text(i, v + 20, str(v), ha='center', fontsize=9)


# Right panel: SST-2 validation set (binary).
sst2_counts = Counter(sst2_val_labels)
axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]],
            color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5)
axes[1].set_title('SST-2 Validation Set Distribution', fontsize=10, fontweight='bold')
axes[1].set_ylabel('Count')
for i, v in enumerate([sst2_counts[0], sst2_counts[1]]):
    axes[1].text(i, v + 5, str(v), ha='center', fontsize=9)


plt.tight_layout()
plt.savefig('/app/figures/fig4_data_distribution.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig4_data_distribution.pdf', bbox_inches='tight')
plt.close()
|
|
| |
# Fig 5: per-class F1 for the two strongest contrasting models
# (tweet-native RoBERTa vs out-of-domain binary DeBERTa).
print(" Fig 5: Per-class F1 comparison...")

# classification_report with output_dict=True gives per-class precision/
# recall/f1 keyed by the names in `classes`.
report_rob = classification_report(tweeteval_test_labels, tweeteval_preds_roberta,
                                   target_names=classes, output_dict=True)

report_deb = classification_report(tweeteval_test_labels, tweeteval_preds_deberta,
                                   target_names=classes, output_dict=True)


fig, ax = plt.subplots(figsize=(7, 4))
x = np.arange(3)
width = 0.35


# Convert fractional F1 scores to percentages for plotting.
f1_rob = [report_rob[c]['f1-score']*100 for c in classes]
f1_deb = [report_deb[c]['f1-score']*100 for c in classes]


bars1 = ax.bar(x - width/2, f1_rob, width, label='Twitter-RoBERTa', color='#2196F3', edgecolor='black', linewidth=0.5)
bars2 = ax.bar(x + width/2, f1_deb, width, label='DeBERTa-v3-base', color='#FF5722', edgecolor='black', linewidth=0.5)


ax.set_ylabel('F1-Score (%)')
ax.set_title('Per-Class F1 Scores on TweetEval Sentiment', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(classes)
ax.legend()
ax.set_ylim(0, 100)
ax.grid(axis='y', alpha=0.3)


# Print each bar's value just above it.
for bar in bars1:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)
for bar in bars2:
    h = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., h + 1, f'{h:.1f}', ha='center', va='bottom', fontsize=8)


plt.tight_layout()
plt.savefig('/app/figures/fig5_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.savefig('/app/figures/fig5_per_class_f1.pdf', bbox_inches='tight')
plt.close()
|
|
| |
| |
| |
# Persist all metrics (floats from round() serialize fine; the raw
# predictions in all_predictions are intentionally not dumped).
with open("/app/comprehensive_results.json", "w") as f:
    json.dump(all_results, f, indent=2)


# Console summary table: one row per model, insertion order of all_results.
print("\n" + "="*60)
print("COMPREHENSIVE RESULTS SUMMARY")
print("="*60)
print(f"\n{'Model':<20} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval Acc':>14} {'TweetEval F1':>13}")
print("-"*68)
for name, res in all_results.items():
    print(f"{name:<20} {res['params']:>7} {res['sst2']['accuracy']:>9.2f}% {res['tweeteval_3class']['accuracy']:>13.2f}% {res['tweeteval_3class']['f1']:>12.2f}%")
print("="*68)


print(f"\nFigures saved to /app/figures/")
print(f" fig1_confusion_roberta.png/pdf")
print(f" fig2_confusion_deberta.png/pdf")
print(f" fig3_model_comparison.png/pdf")
print(f" fig4_data_distribution.png/pdf")
print(f" fig5_per_class_f1.png/pdf")
print(f"\nResults saved to /app/comprehensive_results.json")
|
|