| """ |
| Evaluate DeBERTa-v3 models fine-tuned on SST-2 |
| """ |
| import json |
| import gc |
| import numpy as np |
| import time |
| from datasets import load_dataset |
| from transformers import pipeline |
| import evaluate |


print("Loading datasets...")
sst2 = load_dataset("stanfordnlp/sst2")
tweets = load_dataset("mteb/tweet_sentiment_extraction")

# Keep only binary tweets: drop neutral (label 1), then remap
# positive (2) -> 1 so labels match SST-2's 0/1 scheme.
tweets_test_bin = tweets["test"].filter(lambda x: x["label"] != 1)


def remap_labels(example):
    example["label"] = 1 if example["label"] == 2 else 0
    return example


tweets_test_bin = tweets_test_bin.map(remap_labels)
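
# Sanity check: after filtering and remapping, only binary labels remain.
assert set(tweets_test_bin["label"]) <= {0, 1}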


def preprocess_tweet_text(text):
    """Normalize mentions to '@user' and URLs to 'http', token by token."""
    if not text:
        return ""
    return " ".join(
        "@user" if t.startswith("@") and len(t) > 1
        else ("http" if t.startswith("http") else t)
        for t in text.split(" ")
    )
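
# Illustrative check of the normalizer on a made-up tweet:
assert preprocess_tweet_text("@bob check https://x.co") == "@user check http"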


sst2_texts = list(sst2["validation"]["sentence"])
sst2_labels = list(sst2["validation"]["label"])
tweet_texts = [preprocess_tweet_text(t) for t in tweets_test_bin["text"]]
tweet_labels = list(tweets_test_bin["label"])

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")


def compute_all_metrics(predictions, references):
    """Return accuracy/F1/precision/recall as percentages rounded to 2 dp."""
    acc = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=references, average="weighted")["f1"]
    prec = precision_metric.compute(predictions=predictions, references=references, average="weighted")["precision"]
    rec = recall_metric.compute(predictions=predictions, references=references, average="weighted")["recall"]
    return {
        "accuracy": round(acc * 100, 2),
        "f1": round(f1 * 100, 2),
        "precision": round(prec * 100, 2),
        "recall": round(rec * 100, 2),
    }
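
# Toy example of the returned dict (values worked out by hand for this input,
# assuming sklearn-backed `evaluate` metrics with weighted averaging):
#   compute_all_metrics(np.array([1, 0, 1]), np.array([1, 1, 1]))
#   -> {"accuracy": 66.67, "f1": 80.0, "precision": 100.0, "recall": 66.67}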


# Resume from any previous run; start fresh if the file is missing or invalid.
try:
    with open("/app/eval_results.json") as f:
        all_results = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    all_results = {}


print("\n" + "=" * 60)
print("DeBERTa-v3-base-SST2 (cliang1453)")
print("=" * 60)

pipe = pipeline(
    "text-classification",
    model="cliang1453/deberta-v3-base-sst2",
    device=-1,
    batch_size=16,
)
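# device=-1 keeps inference on CPU; pass device=0 to use the first CUDA GPU
# if one is available.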

# Smoke-test the pipeline and discover which label strings this checkpoint emits.
test = pipe("This is great!")
print(f"Test: {test}")
test2 = pipe("This is terrible!")
print(f"Test2: {test2}")

label_map = {}
for item in [test[0], pipe("awful")[0], pipe("amazing")[0]]:
    label_map[item["label"]] = item["score"]
print(f"Label map: {label_map}")


def label_to_int(label):
    """Map a pipeline label string (e.g. 'LABEL_1', 'positive') to 0/1."""
    if label in ("LABEL_1", "POSITIVE", "positive", "1"):
        return 1
    if label in ("LABEL_0", "NEGATIVE", "negative", "0"):
        return 0
    # Fallback for unusual label names.
    return 1 if "1" in label or "POS" in label.upper() else 0


print("Evaluating SST-2...")
t0 = time.time()
preds = [label_to_int(out["label"]) for out in pipe(sst2_texts, truncation=True, max_length=128)]
sst2_time = time.time() - t0
sst2_metrics = compute_all_metrics(np.array(preds), np.array(sst2_labels))
print(f"SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)")

print("Evaluating Tweets...")
t0 = time.time()
preds = [label_to_int(out["label"]) for out in pipe(tweet_texts, truncation=True, max_length=128)]
tweet_time = time.time() - t0
tweet_metrics = compute_all_metrics(np.array(preds), np.array(tweet_labels))
print(f"Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)")

all_results["DeBERTa-v3-base-SST2"] = {
    "model": "cliang1453/deberta-v3-base-sst2",
    "params": "184M",
    "sst2": sst2_metrics,
    "tweet": tweet_metrics,
}

# Free the pipeline before any further models are loaded.
del pipe
gc.collect()

with open("/app/eval_results.json", "w") as f:
    json.dump(all_results, f, indent=2)

print("\n" + "=" * 60)
print("UPDATED RESULTS SUMMARY")
print("=" * 60)
print(f"{'Model':<30} {'SST-2 Acc':>10} {'SST-2 F1':>10} {'Tweet Acc':>10} {'Tweet F1':>10}")
print("-" * 74)
for name, res in all_results.items():
    print(f"{name:<30} {res['sst2']['accuracy']:>9.2f}% {res['sst2']['f1']:>9.2f}% {res['tweet']['accuracy']:>9.2f}% {res['tweet']['f1']:>9.2f}%")
print("=" * 60)
print("\n💾 Results saved to /app/eval_results.json")