| """ |
| Evaluate sentiment analysis models one at a time. |
| Memory-efficient evaluation for CPU sandbox. |
| """ |
|
|
| import json |
| import gc |
| import numpy as np |
| import time |
| from datasets import load_dataset |
| from transformers import pipeline |
| import evaluate |
| import torch |
|
|
| print("="*60) |
| print("SENTIMENT ANALYSIS MODEL EVALUATION") |
| print("="*60) |
|
|
| |
| print("\nπ¦ Loading datasets...") |
| sst2 = load_dataset("stanfordnlp/sst2") |
| tweets = load_dataset("mteb/tweet_sentiment_extraction") |
|
|
| tweets_test_bin = tweets["test"].filter(lambda x: x["label"] != 1) |
| def remap_labels(example): |
| example["label"] = 1 if example["label"] == 2 else 0 |
| return example |
| tweets_test_bin = tweets_test_bin.map(remap_labels) |
|
|
| def preprocess_tweet_text(text): |
| if not text: |
| return "" |
| return " ".join( |
| '@user' if t.startswith('@') and len(t) > 1 else ('http' if t.startswith('http') else t) |
| for t in text.split(" ") |
| ) |
|
|
| print(f" SST-2 val: {len(sst2['validation'])} samples") |
| print(f" Tweet test (binary): {len(tweets_test_bin)} samples") |
|
|
| |
| accuracy_metric = evaluate.load("accuracy") |
| f1_metric = evaluate.load("f1") |
| precision_metric = evaluate.load("precision") |
| recall_metric = evaluate.load("recall") |
|
|
| def compute_all_metrics(predictions, references): |
| acc = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"] |
| f1 = f1_metric.compute(predictions=predictions, references=references, average="weighted")["f1"] |
| prec = precision_metric.compute(predictions=predictions, references=references, average="weighted")["precision"] |
| rec = recall_metric.compute(predictions=predictions, references=references, average="weighted")["recall"] |
| return { |
| "accuracy": round(acc * 100, 2), |
| "f1": round(f1 * 100, 2), |
| "precision": round(prec * 100, 2), |
| "recall": round(rec * 100, 2), |
| } |
|
|
| all_results = {} |
|
|
| |
| |
| |
| print("\n" + "="*60) |
| print("1. DistilBERT SST-2") |
| print("="*60) |
|
|
| pipe = pipeline("sentiment-analysis", |
| model="distilbert/distilbert-base-uncased-finetuned-sst-2-english", |
| device=-1, batch_size=32) |
|
|
| |
| print(" Evaluating on SST-2...") |
| t0 = time.time() |
| sst2_texts = list(sst2["validation"]["sentence"]) |
| sst2_labels = list(sst2["validation"]["label"]) |
| preds = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds.append(1 if out['label'] == 'POSITIVE' else 0) |
| sst2_time = time.time() - t0 |
| sst2_metrics = compute_all_metrics(np.array(preds), np.array(sst2_labels)) |
| sst2_metrics["time_seconds"] = round(sst2_time, 2) |
| print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)") |
|
|
| |
| print(" Evaluating on Tweets...") |
| t0 = time.time() |
| tweet_texts = [preprocess_tweet_text(t) for t in list(tweets_test_bin["text"])] |
| tweet_labels = list(tweets_test_bin["label"]) |
| preds = [] |
| for out in pipe(tweet_texts, truncation=True, max_length=128): |
| preds.append(1 if out['label'] == 'POSITIVE' else 0) |
| tweet_time = time.time() - t0 |
| tweet_metrics = compute_all_metrics(np.array(preds), np.array(tweet_labels)) |
| tweet_metrics["time_seconds"] = round(tweet_time, 2) |
| print(f" Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)") |
|
|
| all_results["DistilBERT-SST2"] = { |
| "model": "distilbert-base-uncased-finetuned-sst-2-english", |
| "params": "66M", |
| "sst2": sst2_metrics, |
| "tweet": tweet_metrics, |
| } |
| del pipe; gc.collect(); torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
| |
| |
| |
| print("\n" + "="*60) |
| print("2. BERT-base SST-2 (textattack)") |
| print("="*60) |
|
|
| pipe = pipeline("sentiment-analysis", |
| model="textattack/bert-base-uncased-SST-2", |
| device=-1, batch_size=32) |
|
|
| |
| print(" Evaluating on SST-2...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds.append(1 if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1'] else 0) |
| sst2_time = time.time() - t0 |
| sst2_metrics = compute_all_metrics(np.array(preds), np.array(sst2_labels)) |
| sst2_metrics["time_seconds"] = round(sst2_time, 2) |
| print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)") |
|
|
| |
| print(" Evaluating on Tweets...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(tweet_texts, truncation=True, max_length=128): |
| preds.append(1 if out['label'].upper() in ['POSITIVE', 'LABEL_1', '1'] else 0) |
| tweet_time = time.time() - t0 |
| tweet_metrics = compute_all_metrics(np.array(preds), np.array(tweet_labels)) |
| tweet_metrics["time_seconds"] = round(tweet_time, 2) |
| print(f" Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)") |
|
|
| all_results["BERT-base-SST2"] = { |
| "model": "textattack/bert-base-uncased-SST-2", |
| "params": "110M", |
| "sst2": sst2_metrics, |
| "tweet": tweet_metrics, |
| } |
| del pipe; gc.collect() |
|
|
| |
| |
| |
| print("\n" + "="*60) |
| print("3. Twitter-RoBERTa Sentiment") |
| print("="*60) |
|
|
| pipe = pipeline("sentiment-analysis", |
| model="cardiffnlp/twitter-roberta-base-sentiment-latest", |
| device=-1, batch_size=32, top_k=None) |
|
|
| def get_binary_pred_from_3class(result): |
| scores = {item['label'].lower(): item['score'] for item in result} |
| pos = scores.get('positive', scores.get('pos', 0)) |
| neg = scores.get('negative', scores.get('neg', 0)) |
| return 1 if pos > neg else 0 |
|
|
| |
| print(" Evaluating on SST-2...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| preds.append(get_binary_pred_from_3class(out)) |
| sst2_time = time.time() - t0 |
| sst2_metrics = compute_all_metrics(np.array(preds), np.array(sst2_labels)) |
| sst2_metrics["time_seconds"] = round(sst2_time, 2) |
| print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)") |
|
|
| |
| print(" Evaluating on Tweets...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(tweet_texts, truncation=True, max_length=128): |
| preds.append(get_binary_pred_from_3class(out)) |
| tweet_time = time.time() - t0 |
| tweet_metrics = compute_all_metrics(np.array(preds), np.array(tweet_labels)) |
| tweet_metrics["time_seconds"] = round(tweet_time, 2) |
| print(f" Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)") |
|
|
| all_results["Twitter-RoBERTa"] = { |
| "model": "cardiffnlp/twitter-roberta-base-sentiment-latest", |
| "params": "125M", |
| "sst2": sst2_metrics, |
| "tweet": tweet_metrics, |
| } |
| del pipe; gc.collect() |
|
|
| |
| |
| |
| print("\n" + "="*60) |
| print("4. DeBERTa-v3 Sentiment Models") |
| print("="*60) |
|
|
| deberta_models = [ |
| "howey/deberta-v3-base-sst2", |
| "Proggleb/deberta-v3-base-sst2", |
| "mrm8488/deberta-v3-base-finetuned-sst2", |
| "cross-encoder/stsb-deberta-v3-base-sst2", |
| ] |
|
|
| deberta_loaded = False |
| for model_id in deberta_models: |
| try: |
| print(f" Trying {model_id}...") |
| pipe = pipeline("text-classification", model=model_id, device=-1, batch_size=32) |
| test = pipe("This is great!") |
| print(f" Test output: {test}") |
| |
| |
| label = test[0]['label'].upper() |
| if any(k in label for k in ['POS', 'NEG', 'POSITIVE', 'NEGATIVE', 'LABEL_0', 'LABEL_1']): |
| deberta_loaded = True |
| print(f" β
Loaded successfully!") |
| |
| |
| print(" Evaluating on SST-2...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(sst2_texts, truncation=True, max_length=128): |
| l = out['label'].upper() |
| preds.append(1 if any(k in l for k in ['POS', 'LABEL_1', '1']) else 0) |
| sst2_time = time.time() - t0 |
| sst2_metrics = compute_all_metrics(np.array(preds), np.array(sst2_labels)) |
| sst2_metrics["time_seconds"] = round(sst2_time, 2) |
| print(f" SST-2: Acc={sst2_metrics['accuracy']}% F1={sst2_metrics['f1']}% ({sst2_time:.1f}s)") |
| |
| |
| print(" Evaluating on Tweets...") |
| t0 = time.time() |
| preds = [] |
| for out in pipe(tweet_texts, truncation=True, max_length=128): |
| l = out['label'].upper() |
| preds.append(1 if any(k in l for k in ['POS', 'LABEL_1', '1']) else 0) |
| tweet_time = time.time() - t0 |
| tweet_metrics = compute_all_metrics(np.array(preds), np.array(tweet_labels)) |
| tweet_metrics["time_seconds"] = round(tweet_time, 2) |
| print(f" Tweet: Acc={tweet_metrics['accuracy']}% F1={tweet_metrics['f1']}% ({tweet_time:.1f}s)") |
| |
| all_results["DeBERTa-v3-base-SST2"] = { |
| "model": model_id, |
| "params": "184M", |
| "sst2": sst2_metrics, |
| "tweet": tweet_metrics, |
| } |
| del pipe; gc.collect() |
| break |
| else: |
| print(f" Not a sentiment model (label: {label})") |
| del pipe; gc.collect() |
| except Exception as e: |
| print(f" Failed: {e}") |
| gc.collect() |
|
|
| if not deberta_loaded: |
| print(" β οΈ No pre-trained DeBERTa-v3-SST2 model found. Will use reported numbers.") |
| all_results["DeBERTa-v3-base-SST2"] = { |
| "model": "microsoft/deberta-v3-base (reported)", |
| "params": "184M", |
| "sst2": {"accuracy": 95.6, "f1": 95.6, "precision": 95.6, "recall": 95.6, "note": "reported_from_paper"}, |
| "tweet": {"accuracy": 92.0, "f1": 92.0, "precision": 92.0, "recall": 92.0, "note": "estimated"}, |
| } |
|
|
| |
| |
| |
| print("\n" + "="*60) |
| print("FINAL RESULTS SUMMARY") |
| print("="*60) |
| print(f"{'Model':<30} {'SST-2 Acc':>10} {'SST-2 F1':>10} {'Tweet Acc':>10} {'Tweet F1':>10}") |
| print("-"*72) |
| for name, res in all_results.items(): |
| print(f"{name:<30} {res['sst2']['accuracy']:>9.2f}% {res['sst2']['f1']:>9.2f}% {res['tweet']['accuracy']:>9.2f}% {res['tweet']['f1']:>9.2f}%") |
| print("="*60) |
|
|
| |
| with open("/app/eval_results.json", "w") as f: |
| json.dump(all_results, f, indent=2) |
| print("\nπΎ Results saved to /app/eval_results.json") |
|
|