# Web-page header captured together with the script (not Python code):
#   rajvivan's picture
#   Add training script
#   47c017c verified
"""
Social Media Sentiment Analysis - End-to-End Training Pipeline
Fine-tunes DeBERTa-v3-base on SST-2 + Tweet Sentiment for 95%+ accuracy
Based on: DeBERTaV3 paper (arxiv:2111.09543) training recipe
Optimized for CPU training with gradient accumulation
"""
import os
import json
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
EarlyStoppingCallback,
)
import evaluate
import time
# ── Configuration ──────────────────────────────────────────────────
# Checkpoint that is fine-tuned and the Hub repository that receives it.
MODEL_NAME = "microsoft/deberta-v3-base"
HUB_MODEL_ID = "rajvivan/deberta-v3-sentiment-analysis"

# Tokenisation: inputs are truncated to this many tokens (short for CPU speed).
MAX_LENGTH = 128

# Optimisation schedule.
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
WARMUP_STEPS = 300
WEIGHT_DECAY = 0.01
SEED = 42

# Batching: a small per-step batch keeps CPU RAM usage low; with gradient
# accumulation the effective batch is 8 * 4 = 32.
TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 16
GRADIENT_ACCUM = 4

# fp16 is switched on exactly when a CUDA device is available.
use_fp16 = torch.cuda.is_available()
device = 'cuda' if use_fp16 else 'cpu'

# Start-up banner, one line per setting.
print(
    "\n".join(
        (
            f"πŸš€ Starting Sentiment Analysis Training",
            f" Model: {MODEL_NAME}",
            f" Hub ID: {HUB_MODEL_ID}",
            f" Device: {device}",
            f" Effective batch size: {TRAIN_BATCH_SIZE * GRADIENT_ACCUM}",
        )
    )
)

# Wall-clock reference used later to report the elapsed time.
start_time = time.time()
# ── 1. Load and Prepare Datasets ──────────────────────────────────
print("\nπŸ“¦ Loading datasets...")
# Two corpora are used: SST-2 and a tweet sentiment set. Both objects are
# indexed by split name below ("train", "validation", "test").
sst2 = load_dataset("stanfordnlp/sst2")
tweets = load_dataset("mteb/tweet_sentiment_extraction")
# Print the raw split sizes as a sanity check before any filtering.
print(f" SST-2 train: {len(sst2['train'])} | val: {len(sst2['validation'])}")
print(f" Tweet train: {len(tweets['train'])} | test: {len(tweets['test'])}")
# ── 2. Tokenizer ──────────────────────────────────────────────────
print("\nπŸ”€ Loading tokenizer...")
# Tokenizer belonging to MODEL_NAME; the preprocess_* helpers below and the
# data collator all use this single instance.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def preprocess_sst2(examples):
    """Tokenize a batch of SST-2 rows.

    Args:
        examples: Batch mapping whose "sentence" entry holds the input texts.

    Returns:
        The tokenizer output for those sentences, truncated to MAX_LENGTH
        tokens and left unpadded.
    """
    sentences = examples["sentence"]
    encoded = tokenizer(
        sentences,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
    )
    return encoded
def preprocess_tweet_text(text):
    """Normalize user mentions and links in a tweet (TimeLM paper).

    The text is split on single spaces. A token that starts with "@" and is
    longer than one character becomes "@user"; a token that starts with
    "http" becomes "http"; every other token is kept unchanged.

    Args:
        text: Raw tweet text. Falsy values (``None``, ``""``) are accepted.

    Returns:
        The normalized text, or ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""
    normalized = []
    for token in text.split(" "):
        if len(token) > 1 and token.startswith('@'):
            normalized.append('@user')
        elif token.startswith('http'):
            normalized.append('http')
        else:
            normalized.append(token)
    return " ".join(normalized)
def preprocess_tweets(examples):
    """Normalize and tokenize a batch of tweets.

    Args:
        examples: Batch mapping whose "text" entry holds the raw tweets.

    Returns:
        The tokenizer output for the tweets after ``preprocess_tweet_text``
        has been applied to each one, truncated to MAX_LENGTH tokens and
        left unpadded.
    """
    cleaned = list(map(preprocess_tweet_text, examples["text"]))
    return tokenizer(
        cleaned,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
    )
# Tokenize SST-2 and drop the columns that are not needed afterwards.
print(" Tokenizing SST-2...")
sst2_tok = sst2.map(preprocess_sst2, batched=True, num_proc=2).remove_columns(
    ["idx", "sentence"]
)
# Prepare tweets: binary (remove neutral, remap labels).
# Rows with label 1 are discarded from both the train and the test split.
print(" Preparing tweet binary data...")
tweets_train_bin, tweets_test_bin = (
    tweets[split].filter(lambda row: row["label"] != 1)
    for split in ("train", "test")
)
def remap_labels(example):
    """Rewrite a tweet label in place to the binary scheme.

    Args:
        example: Row mapping with a "label" entry.

    Returns:
        The same mapping object, with "label" set to 1 when it was 2 and to
        0 for any other value.
    """
    was_positive = example["label"] == 2
    example["label"] = int(was_positive)
    return example
# Collapse the tweet labels to the binary scheme (see remap_labels).
tweets_train_bin = tweets_train_bin.map(remap_labels)
tweets_test_bin = tweets_test_bin.map(remap_labels)
# Tokenize each split exactly once, then drop the columns that are not
# needed afterwards. (The test split used to be tokenized twice, with the
# second pass simply overwriting the first.)
tweets_train_tok = tweets_train_bin.map(preprocess_tweets, batched=True, num_proc=2)
tweets_train_tok = tweets_train_tok.remove_columns(["id", "text", "label_text"])
tweets_test_tok = tweets_test_bin.map(preprocess_tweets, batched=True, num_proc=2)
tweets_test_tok = tweets_test_tok.remove_columns(["id", "text", "label_text"])
# Cast label types to match SST-2 ClassLabel, so both corpora share one
# schema before they are concatenated.
from datasets import ClassLabel, Features, Value
target_features = sst2_tok["train"].features
tweets_train_tok = tweets_train_tok.cast(target_features)
tweets_test_tok = tweets_test_tok.cast(target_features)
# Combine for training; the SST-2 validation and tweet test splits stay
# separate and are used for evaluation only.
combined_train = concatenate_datasets([sst2_tok["train"], tweets_train_tok])
print(f" Combined train: {len(combined_train)} samples")
print(f" SST-2 val: {len(sst2_tok['validation'])} samples")
print(f" Tweet test: {len(tweets_test_tok)} samples")
# ── 3. Model ──────────────────────────────────────────────────────
print("\n🧠 Loading model...")
# Label names passed to the model; label2id is the inverse of id2label.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {name: index for index, name in id2label.items()}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
# Parameter counts: all parameters vs. those with requires_grad set.
total_params = model.num_parameters()
trainable = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f" Total params: {total_params:,}")
print(f" Trainable: {trainable:,}")
# ── 4. Metrics ────────────────────────────────────────────────────
# Load the four metric objects used by compute_metrics, in a fixed order.
accuracy_metric, f1_metric, precision_metric, recall_metric = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall")
)
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 / precision / recall for one evaluation.

    The Trainer in this script is created with
    ``preprocess_logits_for_metrics``, which already reduces the logits to
    class ids, so the predictions arriving here are normally 1-D. Applying
    ``argmax`` a second time would collapse them to a single scalar, so it is
    only applied when a class axis is still present (raw logits).

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` is either
            a 1-D array of class ids or an array of logits whose last axis
            is the class axis.

    Returns:
        Dict with the keys "accuracy", "f1", "precision" and "recall".
    """
    predictions, labels = eval_pred
    # Mirror preprocess_logits_for_metrics: some models return a tuple whose
    # first element holds the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    if predictions.ndim > 1:
        preds = np.argmax(predictions, axis=-1)
    else:
        preds = predictions
    return {
        "accuracy": accuracy_metric.compute(predictions=preds, references=labels)["accuracy"],
        "f1": f1_metric.compute(predictions=preds, references=labels, average="weighted")["f1"],
        "precision": precision_metric.compute(predictions=preds, references=labels, average="weighted")["precision"],
        "recall": recall_metric.compute(predictions=preds, references=labels, average="weighted")["recall"],
    }
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted class ids.

    Args:
        logits: Tensor of logits, or a tuple whose first element is that
            tensor.
        labels: Unused; part of the signature expected by the caller.

    Returns:
        Tensor holding the argmax over the last dimension of the logits.
    """
    primary = logits[0] if isinstance(logits, tuple) else logits
    return primary.argmax(dim=-1)
# ── 5. Training ───────────────────────────────────────────────────
print("\nβš™οΈ Configuring training...")
# Pads each batch with the tokenizer when the batch is assembled; the
# preprocess_* helpers tokenize with padding=False.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    # Output location and Hub publishing.
    output_dir="./deberta-sentiment-output",
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    report_to="none",
    # Optimisation schedule.
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="linear",
    max_grad_norm=1.0,
    seed=SEED,
    # Batching and memory.
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUM,
    gradient_checkpointing=True,  # Save memory
    dataloader_num_workers=2,
    # Precision: fp16 only when CUDA is available (see use_fp16).
    fp16=use_fp16,
    bf16=False,
    # Evaluate and checkpoint once per epoch; keep the most accurate model.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    # Logging.
    logging_strategy="steps",
    logging_steps=100,
    logging_first_step=True,
    disable_tqdm=True,
)

early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

# Train on the combined corpus; the SST-2 validation split is the
# evaluation set used during training.
trainer = Trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=combined_train,
    eval_dataset=sst2_tok["validation"],
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
# ── 6. Train ──────────────────────────────────────────────────────
print("\nπŸ‹οΈ Starting training...")
# Rough number of optimizer steps per epoch (integer division by the
# effective batch size).
approx_steps_per_epoch = len(combined_train) // (TRAIN_BATCH_SIZE * GRADIENT_ACCUM)
print(f" Steps per epoch: ~{approx_steps_per_epoch}")
train_result = trainer.train()
# Measured from start_time, which is set before the datasets are loaded.
train_time = time.time() - start_time


def _print_eval_report(scores):
    """Print accuracy, F1, precision and recall from an evaluation result."""
    print(f" Accuracy: {scores['eval_accuracy']:.4f}")
    print(f" F1: {scores['eval_f1']:.4f}")
    print(f" Precision: {scores['eval_precision']:.4f}")
    print(f" Recall: {scores['eval_recall']:.4f}")


# ── 7. Evaluate ───────────────────────────────────────────────────
print("\nπŸ“Š Evaluating on SST-2 validation...")
sst2_results = trainer.evaluate(sst2_tok["validation"])
_print_eval_report(sst2_results)

print("\nπŸ“Š Evaluating on Tweet Sentiment test set...")
tweet_results = trainer.evaluate(tweets_test_tok)
_print_eval_report(tweet_results)
# ── 8. Save results ──────────────────────────────────────────────
os.makedirs("./deberta-sentiment-output", exist_ok=True)

# Hyperparameters of this run, copied from the configuration constants.
training_config = {
    "learning_rate": LEARNING_RATE,
    "effective_batch_size": TRAIN_BATCH_SIZE * GRADIENT_ACCUM,
    "per_device_batch_size": TRAIN_BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUM,
    "epochs": NUM_EPOCHS,
    "warmup_steps": WARMUP_STEPS,
    "weight_decay": WEIGHT_DECAY,
    "max_length": MAX_LENGTH,
    "seed": SEED,
    "gradient_checkpointing": True,
}

# Number of rows in every split that was used.
dataset_sizes = {
    "sst2_train": len(sst2_tok["train"]),
    "tweet_train_binary": len(tweets_train_tok),
    "combined_train": len(combined_train),
    "sst2_val": len(sst2_tok["validation"]),
    "tweet_test_binary": len(tweets_test_tok),
}

# Evaluation metrics are stored with the "eval_" part removed from each key.
results = {
    "model": MODEL_NAME,
    "hub_model_id": HUB_MODEL_ID,
    "total_parameters": total_params,
    "training_config": training_config,
    "datasets": dataset_sizes,
    "sst2_eval_results": {
        name.replace("eval_", ""): value for name, value in sst2_results.items()
    },
    "tweet_eval_results": {
        name.replace("eval_", ""): value for name, value in tweet_results.items()
    },
    "training_loss": train_result.training_loss,
    "training_time_seconds": train_time,
}

with open("./deberta-sentiment-output/results.json", "w") as f:
    json.dump(results, f, indent=2)
print(f"\nπŸ’Ύ Results saved")
# ── 9. Push to Hub ────────────────────────────────────────────────
print("\nπŸš€ Pushing model to Hub...")
trainer.push_to_hub(
    tags=["sentiment-analysis", "deberta-v3", "social-media", "text-classification"],
    commit_message="DeBERTa-v3-base sentiment analysis - SST-2 + Tweet binary",
)
print(f"βœ… Model published: https://huggingface.co/{HUB_MODEL_ID}")

# ── Summary ───────────────────────────────────────────────────────
# Final recap: elapsed time, training loss and the headline metrics of both
# evaluation sets, framed by separator rules.
separator = "="*60
summary_lines = (
    "\n" + separator,
    "TRAINING COMPLETE - FINAL RESULTS",
    separator,
    f"Training time: {train_time/60:.1f} minutes",
    f"Training loss: {train_result.training_loss:.4f}",
    f"SST-2 Accuracy: {sst2_results['eval_accuracy']:.4f}",
    f"SST-2 F1: {sst2_results['eval_f1']:.4f}",
    f"Tweet Accuracy: {tweet_results['eval_accuracy']:.4f}",
    f"Tweet F1: {tweet_results['eval_f1']:.4f}",
    separator,
)
print("\n".join(summary_lines))