# rajvivan — commit a6a1051 (verified): Add DeBERTa eval script
"""
Evaluate DeBERTa-v3 models fine-tuned on SST-2
"""
import json
import gc
import numpy as np
import time
from datasets import load_dataset
from transformers import pipeline
import evaluate
print("Loading datasets...")
sst2 = load_dataset("stanfordnlp/sst2")
tweets = load_dataset("mteb/tweet_sentiment_extraction")
# Drop every test tweet whose label is 1 (presumably "neutral" in this
# dataset's 3-way scheme — confirm against the dataset card); the remaining
# labels are collapsed to binary further down.
tweets_test_bin = tweets["test"].filter(
    lambda row: row["label"] != 1
)
def remap_labels(example):
    """Collapse the tweet label into SST-2's binary convention, in place.

    A label of 2 becomes 1; any other label becomes 0. The same ``example``
    mapping is mutated and then returned, so it can be used with
    ``Dataset.map``.
    """
    example["label"] = int(example["label"] == 2)
    return example
# Rewrite each remaining tweet label with remap_labels (2 -> 1, anything else -> 0).
tweets_test_bin = tweets_test_bin.map(function=remap_labels)
def preprocess_tweet_text(text):
    """Normalise user mentions and links in a tweet.

    The text is split on single spaces, so runs of spaces survive the
    re-join unchanged. A token that starts with ``@`` and is longer than one
    character becomes ``@user``. Otherwise, a token that starts with
    ``http`` becomes ``http``. Falsy input (``None`` or ``""``) yields ``""``.
    """
    if not text:
        return ""

    def _normalise(token):
        if token.startswith("@") and len(token) > 1:
            return "@user"
        if token.startswith("http"):
            return "http"
        return token

    return " ".join(map(_normalise, text.split(" ")))
# Materialise the evaluation inputs and gold labels as plain Python lists.
sst2_texts = list(sst2["validation"]["sentence"])
sst2_labels = list(sst2["validation"]["label"])
# Tweets get the same mention/URL normalisation before inference; the column
# is iterated directly rather than copied into a throwaway list first.
tweet_texts = [preprocess_tweet_text(t) for t in tweets_test_bin["text"]]
tweet_labels = list(tweets_test_bin["label"])
# Load each `evaluate` metric once, in this order; they are reused for every
# model/dataset combination scored by compute_all_metrics.
accuracy_metric, f1_metric, precision_metric, recall_metric = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall")
)
def compute_all_metrics(predictions, references):
    """Score predictions against references, returning percentages.

    F1, precision and recall use ``average="weighted"`` (weighted over both
    classes), not the positive-class-only binary variant.

    Args:
        predictions: predicted class ids.
        references: gold class ids, aligned with ``predictions``.

    Returns:
        dict with keys ``accuracy``, ``f1``, ``precision``, ``recall``. Each
        value is the score times 100, rounded to two decimals.
    """
    shared = {"predictions": predictions, "references": references}
    scores = {
        "accuracy": accuracy_metric.compute(**shared)["accuracy"],
        "f1": f1_metric.compute(**shared, average="weighted")["f1"],
        "precision": precision_metric.compute(**shared, average="weighted")["precision"],
        "recall": recall_metric.compute(**shared, average="weighted")["recall"],
    }
    return {key: round(value * 100, 2) for key, value in scores.items()}
# Load results from earlier evaluation runs so this script adds to them
# instead of replacing them. A missing file simply means there are no
# earlier runs. Any other I/O error propagates: silently starting from {}
# would let the save step overwrite the previous results.
try:
    with open("/app/eval_results.json", encoding="utf-8") as f:
        all_results = json.load(f)
except FileNotFoundError:
    all_results = {}
except json.JSONDecodeError as err:
    # Unparseable file: warn loudly, since it will be overwritten on save.
    print(f"Warning: could not parse /app/eval_results.json ({err}); starting fresh.")
    all_results = {}
# ── DeBERTa-v3-base-sst2 ─────────────────────────────────────────
# Evaluate cliang1453/deberta-v3-base-sst2 on the SST-2 validation split and
# the binarised tweet test split, then record the metrics in all_results.
print("\n" + "="*60)
print("DeBERTa-v3-base-SST2 (cliang1453)")
print("="*60)
pipe = pipeline("text-classification",
                model="cliang1453/deberta-v3-base-sst2",
                device=-1, batch_size=16)

# Sanity-check the raw pipeline output on two unambiguous sentences.
test = pipe("This is great!")
print(f"Test: {test}")
test2 = pipe("This is terrible!")
print(f"Test2: {test2}")
# The config's id -> label-name table is the authoritative list of strings
# the pipeline can emit; _label_to_int below must recognise them.
print(f"Label map: {pipe.model.config.id2label}")

# Label strings accepted for each SST-2 class (0 = negative, 1 = positive).
_POSITIVE_LABELS = frozenset({"LABEL_1", "POSITIVE", "positive", "1"})
_NEGATIVE_LABELS = frozenset({"LABEL_0", "NEGATIVE", "negative", "0"})


def _label_to_int(label):
    """Convert a pipeline label string to an SST-2 class id (0=negative, 1=positive).

    Exact matches against the known label names are tried first. Any other
    string falls back to a heuristic: 1 if it contains "1" or "POS"
    (case-insensitive), otherwise 0.
    """
    if label in _POSITIVE_LABELS:
        return 1
    if label in _NEGATIVE_LABELS:
        return 0
    return 1 if "1" in label or "POS" in label.upper() else 0


def _evaluate(classifier, texts, labels):
    """Classify ``texts`` with ``classifier`` and score against ``labels``.

    Inputs are truncated to ``max_length=128``.

    Returns:
        ``(metrics, seconds)``. ``metrics`` is the dict from
        ``compute_all_metrics``. ``seconds`` is the wall-clock time spent on
        inference and label conversion; metric computation is not timed.
    """
    start = time.perf_counter()
    preds = [
        _label_to_int(out["label"])
        for out in classifier(texts, truncation=True, max_length=128)
    ]
    seconds = time.perf_counter() - start
    return compute_all_metrics(np.array(preds), np.array(labels)), seconds


# SST-2
print("Evaluating SST-2...")
sst2_metrics, sst2_time = _evaluate(pipe, sst2_texts, sst2_labels)
print(f"SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)")
# Tweets
print("Evaluating Tweets...")
tweet_metrics, tweet_time = _evaluate(pipe, tweet_texts, tweet_labels)
print(f"Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)")
all_results["DeBERTa-v3-base-SST2"] = {
    "model": "cliang1453/deberta-v3-base-sst2",
    "params": "184M",
    "sst2": sst2_metrics,
    "tweet": tweet_metrics,
}
# Drop the model reference and collect garbage to free its memory.
del pipe
gc.collect()
# Save updated results
with open("/app/eval_results.json", "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)

# Print a comparison table over every model in all_results, including any
# loaded from the existing results file.
header = f"{'Model':<30} {'SST-2 Acc':>10} {'SST-2 F1':>10} {'Tweet Acc':>10} {'Tweet F1':>10}"
print("\n" + "="*60)
print("UPDATED RESULTS SUMMARY")
print("="*60)
print(header)
# Derive the rule width from the header so the two always line up.
print("-" * len(header))
for name, res in all_results.items():
    try:
        sst2_acc = res["sst2"]["accuracy"]
        sst2_f1 = res["sst2"]["f1"]
        tweet_acc = res["tweet"]["accuracy"]
        tweet_f1 = res["tweet"]["f1"]
    except (KeyError, TypeError):
        # Entries loaded from an older results file may not follow this
        # schema; report and skip them rather than abort the summary.
        print(f"{name:<30} (missing sst2/tweet metrics; skipped)")
        continue
    print(f"{name:<30} {sst2_acc:>9.2f}% {sst2_f1:>9.2f}% {tweet_acc:>9.2f}% {tweet_f1:>9.2f}%")
print("="*60)
print("\n💾 Results saved to /app/eval_results.json")