# agent-cost-optimizer / training/train_bert_5class.py
# (Hub page residue, commented out so the file parses as Python:
#  uploaded by "narcolepticchicken", commit e95c4a3 verified,
#  "Upload training/train_bert_5class.py")
"""Retrain BERT as a 5-class tier router.
Uses SPROUT data to predict optimal tier (1-5) directly.
"""
import os, json, random
import numpy as np
from datasets import load_dataset
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer, DataCollatorWithPadding,
)
import torch
# Hub repo that receives the trained router weights at the end of the run.
REPO = "narcolepticchicken/agent-cost-optimizer"
print("BERT 5-Class Tier Router Training")
print("="*60)
# ── Load SPROUT ──
print("\n[1] Loading SPROUT dataset...")
# NOTE(review): trust_remote_code executes loading code shipped with the
# dataset repo — confirm the CARROT-LLM-Routing/SPROUT repo is trusted.
ds = load_dataset("CARROT-LLM-Routing/SPROUT", split="train", trust_remote_code=True)
print(f" Total rows: {len(ds)}")
print(f" Columns: {ds.column_names}")
# ── Model tier mapping ──
# SPROUT model names grouped by cost tier, cheapest (1) to most capable (5);
# same grouping as the v11 training run.
TIER_MODELS = {
    1: ["gemma-2-2b-it", "phi-3-mini-128k-instruct", "qwen2.5-3b-instruct",
        "llama-3.2-3b-instruct", "deepseek-v3.2"],
    2: ["gemma-2-9b-it", "mistral-7b-instruct-v0.3", "qwen2.5-7b-instruct",
        "llama-3.1-8b-instruct", "gpt-5-nano", "gpt-5-mini"],
    3: ["qwen2.5-32b-instruct", "mixtral-8x7b-instruct-v0.1",
        "gemma-2-27b-it", "gemini-2.5-pro"],
    4: ["claude-opus-4.7", "gpt-5.2", "llama-3.1-70b-instruct",
        "qwen2.5-72b-instruct"],
    5: ["gemini-3-pro", "deepseek-v4-flash"],
}

# Invert the tier table into a lowercase model-name → tier lookup.
MODEL_TIER_MAP = {
    name.lower(): tier
    for tier, names in TIER_MODELS.items()
    for name in names
}
# ── Build training data ──
print("\n[2] Building training data...")

# The original scanned every model name against every column name for every
# row (O(rows × models × columns) substring tests). That matching is
# loop-invariant, so resolve it once up front: for each tier, the dataset
# columns whose name contains one of that tier's model names.
_TIER_COLUMNS = {}
for _tier, _models in TIER_MODELS.items():
    _cols = []
    for _m in _models:
        _ml = _m.lower()
        _cols.extend(c for c in ds.column_names if _ml in c.lower())
    _TIER_COLUMNS[_tier] = _cols

# Prompt-bearing columns actually present in this dataset, in preference order.
_PROMPT_CANDIDATES = [
    c for c in ["prompt", "question", "input", "query", "problem_statement", "instruction"]
    if c in ds.column_names
]

def _cheapest_tier(row):
    """Return the lowest tier whose model-result column holds a positive
    number for this row, defaulting to 5 when none does.
    NOTE(review): assumes a positive numeric cell means "model succeeded" —
    confirm against the SPROUT schema."""
    for tier in range(1, 6):
        for col in _TIER_COLUMNS[tier]:
            val = row.get(col)
            if isinstance(val, (int, float)) and val > 0:
                return tier
    return 5

def _extract_prompt(row):
    """Return the row's prompt text, or "" when nothing usable is found.

    Fix over the original: a candidate column whose value was None produced
    the literal string "None" (truthy), which blocked the fallback scan and
    silently dropped the row. Now only non-empty strings are accepted, and
    the remaining candidates / fallback columns still get a chance.
    """
    for col in _PROMPT_CANDIDATES:
        val = row.get(col)
        if isinstance(val, str) and val:
            return val
    # Fallback: first reasonably long string column.
    for col in ds.column_names:
        val = row[col]
        if isinstance(val, str) and len(val) > 20:
            return val
    return ""

texts = []
labels = []
skipped = 0
for row in ds:
    prompt = _extract_prompt(row)
    if prompt and len(prompt) > 10:
        texts.append(prompt[:2000])  # truncate long texts
        labels.append(_cheapest_tier(row) - 1)  # 0-indexed for classification
    else:
        skipped += 1

print(f" Training samples: {len(texts)}")
print(f" Skipped: {skipped}")
print(f" Label distribution:")
from collections import Counter
label_dist = Counter(labels)
for label in sorted(label_dist):
    print(f" Tier {label+1}: {label_dist[label]} ({label_dist[label]/len(labels)*100:.1f}%)")
# ── Tokenize ──
print("\n[3] Tokenizing...")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# padding=False: sequences are left ragged here and padded per-batch later by
# DataCollatorWithPadding; 512 is DistilBERT's maximum sequence length.
encodings = tokenizer(texts, truncation=True, max_length=512, padding=False)
class TierDataset(torch.utils.data.Dataset):
    """Pairs pre-tokenized encodings with integer tier labels for the Trainer."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {key: torch.tensor(seqs[idx]) for key, seqs in self.encodings.items()}
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample
# Split 90/10 into train/eval along the sample axis (no shuffling; the order
# is whatever the dataset iteration produced above).
cut = int(0.9 * len(texts))
train_ds = TierDataset({key: seqs[:cut] for key, seqs in encodings.items()}, labels[:cut])
eval_ds = TierDataset({key: seqs[cut:] for key, seqs in encodings.items()}, labels[cut:])
print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")
# ── Train ──
print("\n[4] Training 5-class BERT router...")
# Fresh randomly-initialized classification head on DistilBERT, one logit per
# tier (labels 0-4 map to tiers 1-5).
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5
)
training_args = TrainingArguments(
    output_dir="/tmp/bert_5class",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch so load_best_model_at_end can
    # restore the checkpoint with the best eval accuracy after training.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    logging_steps=50,
    disable_tqdm=True,
)
def compute_metrics(eval_pred):
    """Return top-1 accuracy for a Trainer-style (logits, labels) pair."""
    logits, gold = eval_pred
    hits = np.argmax(logits, axis=-1) == gold
    return {"accuracy": np.mean(hits)}
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    # Pads each batch dynamically to its longest sequence (tokenization above
    # used padding=False).
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)
# Runs the full fine-tuning loop; with load_best_model_at_end=True the
# best-accuracy checkpoint is restored into `model` when this returns.
trainer.train()
# ── Save and upload ──
print("\n[5] Saving model...")
save_dir = "/tmp/bert_5class_final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f" Saved to {save_dir}")

# Upload to Hub: push every regular file in save_dir one at a time.
from huggingface_hub import HfApi
api = HfApi()
for fname in os.listdir(save_dir):
    fpath = os.path.join(save_dir, fname)
    if not os.path.isfile(fpath):
        continue  # skip subdirectories, if any
    api.upload_file(
        path_or_fileobj=fpath,
        path_in_repo=f"router_models/bert_5class/{fname}",
        repo_id=REPO,
        repo_type="model",
    )
    print(f" Uploaded {fname}")
print("\nDONE! 5-class BERT router saved to router_models/bert_5class/")