| """ |
| COMPREHENSIVE EXPERIMENT SUITE for IEEE paper: |
| 1) Fine-tune roberta-base on TweetEval (3 seeds) |
| 2) Evaluate all models with multi-seed where applicable |
| 3) Ablation: learning rate sweep |
| 4) McNemar test between all model pairs |
| 5) Generate all figures including training curves |
| """ |
|
|
| import torch |
| torch.set_num_threads(2) |
|
|
| import os, json, gc, time, sys |
| import numpy as np |
| from collections import Counter, defaultdict |
| from datasets import load_dataset |
| from transformers import ( |
| AutoTokenizer, AutoModelForSequenceClassification, |
| TrainingArguments, Trainer, DataCollatorWithPadding, |
| EarlyStoppingCallback, pipeline |
| ) |
| from sklearn.metrics import ( |
| recall_score, f1_score, accuracy_score, precision_score, |
| confusion_matrix, classification_report |
| ) |
| from scipy.stats import chi2 |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
| os.makedirs("/app/figures_v2", exist_ok=True) |
| os.makedirs("/app/training_logs", exist_ok=True) |
| plt.rcParams['font.family'] = 'serif' |
| plt.rcParams['font.size'] = 10 |
|
|
| def log(msg): |
| print(msg, flush=True) |
|
|
| log("="*70) |
| log("IEEE PAPER: COMPREHENSIVE EXPERIMENT SUITE") |
| log("="*70) |
|
|
| |
| |
| |
| log("\n[1/7] Loading datasets...") |
| tweeteval = load_dataset("cardiffnlp/tweet_eval", "sentiment") |
| sst2 = load_dataset("stanfordnlp/sst2") |
|
|
| label_names = ['Negative', 'Neutral', 'Positive'] |
| log(f" TweetEval: train={len(tweeteval['train'])}, val={len(tweeteval['validation'])}, test={len(tweeteval['test'])}") |
| log(f" SST-2: train={len(sst2['train'])}, val={len(sst2['validation'])}") |
| log(f" TweetEval test distribution: {Counter(tweeteval['test']['label'])}") |
|
|
| def preprocess_tweet(text): |
| if not text: return "" |
| return " ".join( |
| '@user' if w.startswith('@') and len(w)>1 else ('http' if w.startswith('http') else w) |
| for w in text.split() |
| ) |
|
|
| |
| |
| |
| log("\n[2/7] Fine-tuning roberta-base on TweetEval (3 seeds)...") |
|
|
| MODEL_FT = "roberta-base" |
| tok_ft = AutoTokenizer.from_pretrained(MODEL_FT) |
|
|
| def tokenize_tweeteval(examples): |
| texts = [preprocess_tweet(t) for t in examples['text']] |
| return tok_ft(texts, truncation=True, max_length=128, padding=False) |
|
|
| te_enc = tweeteval.map(tokenize_tweeteval, batched=True, remove_columns=['text']) |
| collator_ft = DataCollatorWithPadding(tokenizer=tok_ft) |
|
|
| def compute_metrics_ft(eval_pred): |
| logits, labels = eval_pred |
| preds = np.argmax(logits, axis=-1) |
| return { |
| 'macro_recall': recall_score(labels, preds, average='macro'), |
| 'macro_f1': f1_score(labels, preds, average='macro'), |
| 'accuracy': accuracy_score(labels, preds), |
| } |
|
|
| seed_results_roberta = [] |
| seed_preds_roberta = [] |
| training_histories = [] |
|
|
| for seed in [1, 2, 3]: |
| log(f"\n --- Seed {seed}/3 ---") |
| model_ft = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3) |
| |
| args = TrainingArguments( |
| output_dir=f'/app/ft_roberta_seed{seed}', |
| num_train_epochs=2, |
| per_device_train_batch_size=16, |
| per_device_eval_batch_size=64, |
| learning_rate=1e-5, |
| weight_decay=0.01, |
| warmup_ratio=0.1, |
| lr_scheduler_type='linear', |
| eval_strategy='epoch', |
| save_strategy='epoch', |
| logging_strategy='steps', |
| logging_steps=100, |
| logging_first_step=True, |
| disable_tqdm=True, |
| load_best_model_at_end=True, |
| metric_for_best_model='macro_recall', |
| greater_is_better=True, |
| seed=seed, |
| dataloader_num_workers=0, |
| fp16=False, |
| gradient_checkpointing=False, |
| report_to='none', |
| save_total_limit=1, |
| ) |
| |
| trainer = Trainer( |
| model=model_ft, args=args, |
| train_dataset=te_enc['train'], |
| eval_dataset=te_enc['validation'], |
| data_collator=collator_ft, |
| compute_metrics=compute_metrics_ft, |
| callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], |
| ) |
| |
| t0 = time.time() |
| result = trainer.train() |
| train_time = time.time() - t0 |
| |
| |
| history = trainer.state.log_history |
| training_histories.append(history) |
| |
| |
| test_metrics = trainer.evaluate(te_enc['test']) |
| log(f" Seed {seed}: Macro-Recall={test_metrics['eval_macro_recall']:.4f}, " |
| f"Macro-F1={test_metrics['eval_macro_f1']:.4f}, Acc={test_metrics['eval_accuracy']:.4f}, " |
| f"Time={train_time/60:.1f}min") |
| |
| |
| preds_output = trainer.predict(te_enc['test']) |
| preds = np.argmax(preds_output.predictions, axis=-1) |
| seed_preds_roberta.append(preds.tolist()) |
| |
| seed_results_roberta.append({ |
| 'seed': seed, |
| 'macro_recall': test_metrics['eval_macro_recall'], |
| 'macro_f1': test_metrics['eval_macro_f1'], |
| 'accuracy': test_metrics['eval_accuracy'], |
| 'train_time_min': round(train_time/60, 1), |
| 'train_loss': result.training_loss, |
| }) |
| |
| del model_ft, trainer |
| gc.collect() |
|
|
| |
| recalls = [r['macro_recall'] for r in seed_results_roberta] |
| f1s = [r['macro_f1'] for r in seed_results_roberta] |
| accs = [r['accuracy'] for r in seed_results_roberta] |
|
|
| log(f"\n RoBERTa-base (3 seeds): Macro-Recall={np.mean(recalls):.4f}+/-{np.std(recalls):.4f}, " |
| f"Macro-F1={np.mean(f1s):.4f}+/-{np.std(f1s):.4f}, Acc={np.mean(accs):.4f}+/-{np.std(accs):.4f}") |
|
|
| |
| |
| |
| log("\n[3/7] Evaluating pre-trained models on TweetEval...") |
|
|
| te_test_texts = [preprocess_tweet(t) for t in list(tweeteval['test']['text'])] |
| te_test_labels = list(tweeteval['test']['label']) |
| sst2_texts = list(sst2['validation']['sentence']) |
| sst2_labels = list(sst2['validation']['label']) |
|
|
| all_model_preds = {} |
| all_model_metrics = {} |
|
|
| |
| log(" Twitter-RoBERTa...") |
| pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest", |
| device=-1, batch_size=32, top_k=None) |
| preds_twr = [] |
| for out in pipe(te_test_texts, truncation=True, max_length=128): |
| scores = {r['label'].lower(): r['score'] for r in out} |
| label_scores = [scores.get('negative',0), scores.get('neutral',0), scores.get('positive',0)] |
| preds_twr.append(np.argmax(label_scores)) |
|
|
| mr_twr = recall_score(te_test_labels, preds_twr, average='macro') |
| mf1_twr = f1_score(te_test_labels, preds_twr, average='macro') |
| acc_twr = accuracy_score(te_test_labels, preds_twr) |
| log(f" Twitter-RoBERTa: MR={mr_twr:.4f}, MF1={mf1_twr:.4f}, Acc={acc_twr:.4f}") |
| all_model_preds['Twitter-RoBERTa'] = preds_twr |
| all_model_metrics['Twitter-RoBERTa'] = {'macro_recall': mr_twr, 'macro_f1': mf1_twr, 'accuracy': acc_twr, 'params': '125M'} |
|
|
| |
| preds_twr_sst2 = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| scores = {r['label'].lower(): r['score'] for r in out} |
| preds_twr_sst2.append(1 if scores.get('positive',0) > scores.get('negative',0) else 0) |
| acc_twr_sst2 = accuracy_score(sst2_labels, preds_twr_sst2) |
| all_model_metrics['Twitter-RoBERTa']['sst2_accuracy'] = acc_twr_sst2 |
| del pipe; gc.collect() |
|
|
| |
| log(" DeBERTa-v3-base...") |
| pipe = pipeline("text-classification", model="cliang1453/deberta-v3-base-sst2", device=-1, batch_size=16) |
| preds_deb = [] |
| for out in pipe(te_test_texts, truncation=True, max_length=128): |
| preds_deb.append(2 if out['label'].lower() == 'positive' else 0) |
| mr_deb = recall_score(te_test_labels, preds_deb, average='macro') |
| mf1_deb = f1_score(te_test_labels, preds_deb, average='macro', zero_division=0) |
| acc_deb = accuracy_score(te_test_labels, preds_deb) |
| log(f" DeBERTa-v3: MR={mr_deb:.4f}, MF1={mf1_deb:.4f}, Acc={acc_deb:.4f}") |
| all_model_preds['DeBERTa-v3-base'] = preds_deb |
| all_model_metrics['DeBERTa-v3-base'] = {'macro_recall': mr_deb, 'macro_f1': mf1_deb, 'accuracy': acc_deb, 'params': '184M'} |
|
|
| preds_deb_sst2 = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds_deb_sst2.append(1 if out['label'].lower() == 'positive' else 0) |
| all_model_metrics['DeBERTa-v3-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_deb_sst2) |
| del pipe; gc.collect() |
|
|
| |
| log(" BERT-base...") |
| pipe = pipeline("text-classification", model="textattack/bert-base-uncased-SST-2", device=-1, batch_size=32) |
| preds_bert = [] |
| for out in pipe(te_test_texts, truncation=True, max_length=128): |
| preds_bert.append(2 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0) |
| mr_bert = recall_score(te_test_labels, preds_bert, average='macro') |
| mf1_bert = f1_score(te_test_labels, preds_bert, average='macro', zero_division=0) |
| acc_bert = accuracy_score(te_test_labels, preds_bert) |
| all_model_preds['BERT-base'] = preds_bert |
| all_model_metrics['BERT-base'] = {'macro_recall': mr_bert, 'macro_f1': mf1_bert, 'accuracy': acc_bert, 'params': '110M'} |
|
|
| preds_bert_sst2 = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds_bert_sst2.append(1 if out['label'].upper() in ['POSITIVE','LABEL_1','1'] else 0) |
| all_model_metrics['BERT-base']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_bert_sst2) |
| del pipe; gc.collect() |
|
|
| |
| log(" DistilBERT...") |
| pipe = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english", |
| device=-1, batch_size=32) |
| preds_distil = [] |
| for out in pipe(te_test_texts, truncation=True, max_length=128): |
| preds_distil.append(2 if out['label'] == 'POSITIVE' else 0) |
| mr_dist = recall_score(te_test_labels, preds_distil, average='macro') |
| mf1_dist = f1_score(te_test_labels, preds_distil, average='macro', zero_division=0) |
| acc_dist = accuracy_score(te_test_labels, preds_distil) |
| all_model_preds['DistilBERT'] = preds_distil |
| all_model_metrics['DistilBERT'] = {'macro_recall': mr_dist, 'macro_f1': mf1_dist, 'accuracy': acc_dist, 'params': '66M'} |
|
|
| preds_dist_sst2 = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds_dist_sst2.append(1 if out['label'] == 'POSITIVE' else 0) |
| all_model_metrics['DistilBERT']['sst2_accuracy'] = accuracy_score(sst2_labels, preds_dist_sst2) |
| del pipe; gc.collect() |
|
|
| |
| best_seed_idx = np.argmax(recalls) |
| all_model_preds['RoBERTa-FT'] = seed_preds_roberta[best_seed_idx] |
| all_model_metrics['RoBERTa-FT'] = { |
| 'macro_recall': np.mean(recalls), 'macro_recall_std': np.std(recalls), |
| 'macro_f1': np.mean(f1s), 'macro_f1_std': np.std(f1s), |
| 'accuracy': np.mean(accs), 'accuracy_std': np.std(accs), |
| 'params': '125M', |
| } |
|
|
| |
| |
| |
| log("\n[4/7] Ablation: Learning Rate Sweep (roberta-base, seed=1)...") |
|
|
| ablation_lrs = [5e-6, 1e-5, 3e-5] |
| ablation_results = [] |
|
|
| for lr in ablation_lrs: |
| log(f" LR={lr}...") |
| model_ab = AutoModelForSequenceClassification.from_pretrained(MODEL_FT, num_labels=3) |
| args_ab = TrainingArguments( |
| output_dir=f'/app/ablation_lr_{lr}', |
| num_train_epochs=2, |
| per_device_train_batch_size=16, |
| per_device_eval_batch_size=64, |
| learning_rate=lr, |
| weight_decay=0.01, |
| warmup_ratio=0.1, |
| eval_strategy='epoch', |
| save_strategy='epoch', |
| disable_tqdm=True, |
| load_best_model_at_end=True, |
| metric_for_best_model='macro_recall', |
| greater_is_better=True, |
| seed=1, |
| dataloader_num_workers=0, |
| fp16=False, |
| report_to='none', |
| save_total_limit=1, |
| logging_strategy='steps', |
| logging_steps=500, |
| logging_first_step=True, |
| ) |
| trainer_ab = Trainer( |
| model=model_ab, args=args_ab, |
| train_dataset=te_enc['train'], |
| eval_dataset=te_enc['validation'], |
| data_collator=collator_ft, |
| compute_metrics=compute_metrics_ft, |
| ) |
| trainer_ab.train() |
| test_ab = trainer_ab.evaluate(te_enc['test']) |
| ablation_results.append({ |
| 'lr': lr, |
| 'macro_recall': test_ab['eval_macro_recall'], |
| 'macro_f1': test_ab['eval_macro_f1'], |
| 'accuracy': test_ab['eval_accuracy'], |
| }) |
| log(f" LR={lr}: MR={test_ab['eval_macro_recall']:.4f}, MF1={test_ab['eval_macro_f1']:.4f}") |
| del model_ab, trainer_ab; gc.collect() |
|
|
| |
| |
| |
| log("\n[5/7] McNemar Statistical Significance Tests...") |
|
|
| def mcnemar_test(y_true, y_pred1, y_pred2): |
| b = sum(1 for yt,p1,p2 in zip(y_true,y_pred1,y_pred2) if p1==yt and p2!=yt) |
| c = sum(1 for yt,p1,p2 in zip(y_true,y_pred1,y_pred2) if p1!=yt and p2==yt) |
| if b + c == 0: return 0.0, 1.0 |
| stat = (abs(b - c) - 1)**2 / (b + c) |
| p_val = 1 - chi2.cdf(stat, df=1) |
| return stat, p_val |
|
|
| mcnemar_results = {} |
| model_names = list(all_model_preds.keys()) |
| for i in range(len(model_names)): |
| for j in range(i+1, len(model_names)): |
| m1, m2 = model_names[i], model_names[j] |
| stat, p = mcnemar_test(te_test_labels, all_model_preds[m1], all_model_preds[m2]) |
| sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else "ns")) |
| mcnemar_results[f"{m1} vs {m2}"] = {'statistic': round(stat, 2), 'p_value': round(p, 6), 'significance': sig} |
| log(f" {m1} vs {m2}: chi2={stat:.2f}, p={p:.6f} {sig}") |
|
|
| |
| |
| |
| log("\n[6/7] Generating figures...") |
|
|
| |
| log(" Fig 1: Training loss curves...") |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4)) |
| for sidx, hist in enumerate(training_histories): |
| train_losses = [(h['step'], h['loss']) for h in hist if 'loss' in h and 'eval_loss' not in h] |
| eval_losses = [(h['epoch'], h['eval_loss']) for h in hist if 'eval_loss' in h] |
| eval_recalls = [(h['epoch'], h['eval_macro_recall']) for h in hist if 'eval_macro_recall' in h] |
| |
| if train_losses: |
| steps, losses = zip(*train_losses) |
| ax1.plot(steps, losses, alpha=0.7, label=f'Seed {sidx+1}') |
| if eval_recalls: |
| epochs, recs = zip(*eval_recalls) |
| ax2.plot(epochs, [r*100 for r in recs], 'o-', alpha=0.8, label=f'Seed {sidx+1}') |
|
|
| ax1.set_xlabel('Training Step'); ax1.set_ylabel('Loss') |
| ax1.set_title('(a) Training Loss', fontweight='bold') |
| ax1.legend(); ax1.grid(alpha=0.3) |
| ax2.set_xlabel('Epoch'); ax2.set_ylabel('Macro-Recall (%)') |
| ax2.set_title('(b) Validation Macro-Recall', fontweight='bold') |
| ax2.legend(); ax2.grid(alpha=0.3) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_training_curves.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_training_curves.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| log(" Fig 2: Confusion matrices...") |
| fig, axes = plt.subplots(1, 3, figsize=(14, 4)) |
| models_cm = [('Twitter-RoBERTa', preds_twr), ('RoBERTa-FT (Ours)', seed_preds_roberta[best_seed_idx]), ('DeBERTa-v3', preds_deb)] |
| cmaps = ['Blues', 'Greens', 'Oranges'] |
| for idx, (name, preds) in enumerate(models_cm): |
| cm = confusion_matrix(te_test_labels, preds, normalize='true') |
| im = axes[idx].imshow(cm, interpolation='nearest', cmap=cmaps[idx], vmin=0, vmax=1) |
| axes[idx].set(xticks=range(3), yticks=range(3), |
| xticklabels=['Neg','Neu','Pos'], yticklabels=['Neg','Neu','Pos'], |
| title=name, ylabel='True' if idx==0 else '', xlabel='Predicted') |
| for i in range(3): |
| for j in range(3): |
| axes[idx].text(j, i, f'{cm[i,j]:.2f}', ha='center', va='center', |
| color='white' if cm[i,j]>0.5 else 'black', fontsize=11) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_confusion_matrices.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_confusion_matrices.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| log(" Fig 3: Model comparison...") |
| models_plot = ['DistilBERT\n(66M)', 'BERT-base\n(110M)', 'RoBERTa-FT\n(125M)\n[Ours]', 'Twitter-\nRoBERTa\n(125M)', 'DeBERTa-v3\n(184M)'] |
| sst2_vals = [ |
| all_model_metrics['DistilBERT'].get('sst2_accuracy',0)*100, |
| all_model_metrics['BERT-base'].get('sst2_accuracy',0)*100, |
| np.mean(accs)*100, |
| all_model_metrics['Twitter-RoBERTa'].get('sst2_accuracy',0)*100, |
| all_model_metrics['DeBERTa-v3-base'].get('sst2_accuracy',0)*100, |
| ] |
| tweet_vals = [ |
| all_model_metrics['DistilBERT']['macro_recall']*100, |
| all_model_metrics['BERT-base']['macro_recall']*100, |
| np.mean(recalls)*100, |
| all_model_metrics['Twitter-RoBERTa']['macro_recall']*100, |
| all_model_metrics['DeBERTa-v3-base']['macro_recall']*100, |
| ] |
| tweet_stds = [0, 0, np.std(recalls)*100, 0, 0] |
|
|
| x = np.arange(len(models_plot)); width = 0.35 |
| fig, ax = plt.subplots(figsize=(10, 5)) |
| bars1 = ax.bar(x - width/2, sst2_vals, width, label='SST-2 Accuracy (%)', color='#2196F3', edgecolor='black', linewidth=0.5) |
| bars2 = ax.bar(x + width/2, tweet_vals, width, yerr=tweet_stds, capsize=3, |
| label='TweetEval Macro-Recall (%)', color='#FF9800', edgecolor='black', linewidth=0.5) |
| ax.set_ylabel('Score (%)'); ax.set_xticks(x); ax.set_xticklabels(models_plot, fontsize=8) |
| ax.set_title('Model Comparison: SST-2 vs TweetEval Performance', fontweight='bold') |
| ax.legend(loc='upper left'); ax.set_ylim(0, 105); ax.grid(axis='y', alpha=0.3) |
| ax.axhline(y=95, color='red', linestyle='--', alpha=0.5); ax.text(4.3, 95.5, '95%', color='red', fontsize=7) |
| for bar in bars1+bars2: |
| h = bar.get_height() |
| if h > 0: ax.text(bar.get_x()+bar.get_width()/2., h+1, f'{h:.1f}', ha='center', va='bottom', fontsize=7) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_model_comparison.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_model_comparison.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| log(" Fig 4: Learning rate ablation...") |
| fig, ax = plt.subplots(figsize=(7, 4)) |
| lrs = [r['lr'] for r in ablation_results] |
| mr_vals = [r['macro_recall']*100 for r in ablation_results] |
| mf1_vals = [r['macro_f1']*100 for r in ablation_results] |
| ax.plot(range(len(lrs)), mr_vals, 'o-', color='#2196F3', label='Macro-Recall', linewidth=2, markersize=8) |
| ax.plot(range(len(lrs)), mf1_vals, 's--', color='#FF9800', label='Macro-F1', linewidth=2, markersize=8) |
| ax.set_xticks(range(len(lrs))); ax.set_xticklabels([f'{lr:.0e}' for lr in lrs]) |
| ax.set_xlabel('Learning Rate'); ax.set_ylabel('Score (%)') |
| ax.set_title('Ablation: Learning Rate Sensitivity (RoBERTa-base)', fontweight='bold') |
| ax.legend(); ax.grid(alpha=0.3) |
| for i, (mr, mf) in enumerate(zip(mr_vals, mf1_vals)): |
| ax.annotate(f'{mr:.1f}', (i, mr), textcoords="offset points", xytext=(0,8), ha='center', fontsize=7) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_lr_ablation.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_lr_ablation.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| log(" Fig 5: Per-class F1...") |
| fig, ax = plt.subplots(figsize=(8, 4.5)) |
| classes = ['Negative', 'Neutral', 'Positive'] |
| x = np.arange(3); w = 0.25 |
| for idx, (name, preds) in enumerate(models_cm): |
| report = classification_report(te_test_labels, preds, output_dict=True, zero_division=0) |
| f1s_class = [report[c]['f1-score']*100 for c in classes] |
| bars = ax.bar(x + idx*w, f1s_class, w, label=name, edgecolor='black', linewidth=0.5) |
| for bar, v in zip(bars, f1s_class): |
| ax.text(bar.get_x()+bar.get_width()/2., bar.get_height()+1, f'{v:.1f}', ha='center', fontsize=7) |
| ax.set_xticks(x + w); ax.set_xticklabels(classes) |
| ax.set_ylabel('F1-Score (%)'); ax.set_title('Per-Class F1 on TweetEval', fontweight='bold') |
| ax.legend(fontsize=9); ax.set_ylim(0, 100); ax.grid(axis='y', alpha=0.3) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_per_class_f1.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_per_class_f1.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| log(" Fig 6: Dataset distribution...") |
| fig, axes = plt.subplots(1, 2, figsize=(8, 3.5)) |
| te_counts = Counter(te_test_labels) |
| axes[0].bar(classes, [te_counts[0], te_counts[1], te_counts[2]], |
| color=['#e74c3c', '#95a5a6', '#2ecc71'], edgecolor='black', linewidth=0.5) |
| axes[0].set_title('TweetEval Test (n=12,284)', fontweight='bold'); axes[0].set_ylabel('Count') |
| for i, v in enumerate([te_counts[0], te_counts[1], te_counts[2]]): |
| axes[0].text(i, v+50, f'{v}\n({v/len(te_test_labels)*100:.1f}%)', ha='center', fontsize=8) |
| sst2_counts = Counter(sst2_labels) |
| axes[1].bar(['Negative', 'Positive'], [sst2_counts[0], sst2_counts[1]], |
| color=['#e74c3c', '#2ecc71'], edgecolor='black', linewidth=0.5) |
| axes[1].set_title('SST-2 Validation (n=872)', fontweight='bold'); axes[1].set_ylabel('Count') |
| for i, v in enumerate([sst2_counts[0], sst2_counts[1]]): |
| axes[1].text(i, v+5, f'{v}\n({v/len(sst2_labels)*100:.1f}%)', ha='center', fontsize=8) |
| plt.tight_layout() |
| plt.savefig('/app/figures_v2/fig_data_distribution.png', dpi=300, bbox_inches='tight') |
| plt.savefig('/app/figures_v2/fig_data_distribution.pdf', bbox_inches='tight') |
| plt.close() |
|
|
| |
| |
| |
| log("\n[7/7] Saving results...") |
|
|
| full_results = { |
| 'multi_seed_roberta_ft': seed_results_roberta, |
| 'multi_seed_summary': { |
| 'macro_recall_mean': round(np.mean(recalls), 4), |
| 'macro_recall_std': round(np.std(recalls), 4), |
| 'macro_f1_mean': round(np.mean(f1s), 4), |
| 'macro_f1_std': round(np.std(f1s), 4), |
| 'accuracy_mean': round(np.mean(accs), 4), |
| 'accuracy_std': round(np.std(accs), 4), |
| }, |
| 'all_models': {k: {kk: round(vv, 4) if isinstance(vv, float) else vv for kk, vv in v.items()} |
| for k, v in all_model_metrics.items()}, |
| 'ablation_lr': ablation_results, |
| 'mcnemar_tests': mcnemar_results, |
| } |
|
|
| with open('/app/full_experiment_results.json', 'w') as f: |
| json.dump(full_results, f, indent=2, default=str) |
|
|
| log("\n" + "="*70) |
| log("COMPLETE RESULTS SUMMARY") |
| log("="*70) |
| log(f"\n{'Model':<25} {'Params':>7} {'SST-2 Acc':>10} {'TweetEval MR':>13} {'TweetEval MF1':>14}") |
| log("-"*73) |
| for name, m in all_model_metrics.items(): |
| sst2_str = f"{m.get('sst2_accuracy',0)*100:.2f}%" if 'sst2_accuracy' in m else "N/A" |
| mr = m['macro_recall'] |
| mf1 = m['macro_f1'] |
| std_str = f"+/-{m.get('macro_recall_std',0)*100:.2f}" if 'macro_recall_std' in m else "" |
| log(f"{name:<25} {m['params']:>7} {sst2_str:>10} {mr*100:>9.2f}%{std_str:>4} {mf1*100:>10.2f}%") |
| log("="*73) |
|
|
| log(f"\nAblation (LR):") |
| for r in ablation_results: |
| log(f" LR={r['lr']:.0e}: MR={r['macro_recall']*100:.2f}%, MF1={r['macro_f1']*100:.2f}%") |
|
|
| log(f"\nMcNemar Tests:") |
| for pair, res in mcnemar_results.items(): |
| log(f" {pair}: p={res['p_value']:.6f} {res['significance']}") |
|
|
| log(f"\nAll figures saved to /app/figures_v2/") |
| log(f"Results saved to /app/full_experiment_results.json") |
| log("DONE!") |
|
|