"""Retrain BERT as a 5-class tier router.
Uses SPROUT data to predict optimal tier (1-5) directly.
"""
import os, json, random
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding,
)
import torch
REPO = "narcolepticchicken/agent-cost-optimizer"
print("BERT 5-Class Tier Router Training")
print("="*60)
# ── Load SPROUT ──
print("\n[1] Loading SPROUT dataset...")
ds = load_dataset("CARROT-LLM-Routing/SPROUT", split="train", trust_remote_code=True)
print(f" Total rows: {len(ds)}")
print(f" Columns: {ds.column_names}")
# ── Model tier mapping ──
# SPROUT models → tiers (same as v11 training)
MODEL_TIER_MAP = {}
TIER_MODELS = {
1: ["gemma-2-2b-it","phi-3-mini-128k-instruct","qwen2.5-3b-instruct",
"llama-3.2-3b-instruct","deepseek-v3.2"],
2: ["gemma-2-9b-it","mistral-7b-instruct-v0.3","qwen2.5-7b-instruct",
"llama-3.1-8b-instruct","gpt-5-nano","gpt-5-mini"],
3: ["qwen2.5-32b-instruct","mixtral-8x7b-instruct-v0.1",
"gemma-2-27b-it","gemini-2.5-pro"],
4: ["claude-opus-4.7","gpt-5.2","llama-3.1-70b-instruct",
"qwen2.5-72b-instruct"],
5: ["gemini-3-pro","deepseek-v4-flash"],
}
for tier, models in TIER_MODELS.items():
    for m in models:
        MODEL_TIER_MAP[m.lower()] = tier
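# After this loop, a lowercased model name maps to its tier, e.g.
#   MODEL_TIER_MAP["gpt-5-mini"]   -> 2
#   MODEL_TIER_MAP["gemini-3-pro"] -> 5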
# ── Build training data ──
print("\n[2] Building training data...")
texts = []
labels = []
skipped = 0
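# Labeling heuristic for the loop below: scan tiers cheapest-first and take the first
# tier whose model has a positive numeric value in a matching column for this row.
# This assumes the SPROUT schema exposes per-model result/score columns named after
# the model and that a value > 0 means success; rows with no match fall back to tier 5.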
for row in ds:
    # Find the cheapest tier that succeeded
    best_tier = 5  # default
    found = False
    # Try to find model results in the row
    for tier in range(1, 6):
        for m in TIER_MODELS.get(tier, []):
            m_lower = m.lower()
            # Check various possible column names
            for col in ds.column_names:
                if m_lower in col.lower():
                    val = row.get(col)
                    if isinstance(val, (int, float)) and val > 0:
                        best_tier = tier
                        found = True
                        break
            if found:
                break
        if found:
            break
    # Get the prompt/question text
    prompt = ""
    for col in ["prompt", "question", "input", "query", "problem_statement", "instruction"]:
        if col in ds.column_names:
            prompt = str(row[col])
            break
    if not prompt:
        # Try first string column
        for col in ds.column_names:
            if isinstance(row[col], str) and len(row[col]) > 20:
                prompt = row[col]
                break
    if prompt and len(prompt) > 10:
        texts.append(prompt[:2000])  # truncate long texts
        labels.append(best_tier - 1)  # 0-indexed for classification
    else:
        skipped += 1
print(f" Training samples: {len(texts)}")
print(f" Skipped: {skipped}")
print(f" Label distribution:")
from collections import Counter
label_dist = Counter(labels)
for label in sorted(label_dist):
print(f" Tier {label+1}: {label_dist[label]} ({label_dist[label]/len(labels)*100:.1f}%)")
# ── Tokenize ──
print("\n[3] Tokenizing...")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encodings = tokenizer(texts, truncation=True, max_length=512, padding=False)
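# padding=False: sequences are left unpadded here; DataCollatorWithPadding (below)
# pads each batch dynamically at training time.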
class TierDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
# Split
split = int(0.9 * len(texts))
train_ds = TierDataset(
    {k: v[:split] for k, v in encodings.items()},
    labels[:split],
)
eval_ds = TierDataset(
    {k: v[split:] for k, v in encodings.items()},
    labels[split:],
)
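# NOTE: this is a sequential 90/10 split with no shuffling, which assumes SPROUT row
# order is not systematically biased; shuffle texts and labels together first if needed.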
print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")
# ── Train ──
print("\n[4] Training 5-class BERT router...")
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=5
)
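# num_labels=5 attaches a freshly initialized 5-way classification head to the
# pretrained DistilBERT encoder, so the "newly initialized weights" warning is expected.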
training_args = TrainingArguments(
    output_dir="/tmp/bert_5class",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    logging_steps=50,
    disable_tqdm=True,
)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = np.mean(preds == labels)
    return {"accuracy": acc}
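# Trainer logs this metric as "eval_accuracy", which metric_for_best_model /
# load_best_model_at_end above use to pick the best checkpoint.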
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)
trainer.train()
# ── Save and upload ──
print("\n[5] Saving model...")
save_dir = "/tmp/bert_5class_final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f" Saved to {save_dir}")
# Upload to Hub
from huggingface_hub import HfApi
api = HfApi()
for fname in os.listdir(save_dir):
    fpath = os.path.join(save_dir, fname)
    if os.path.isfile(fpath):
        api.upload_file(
            path_or_fileobj=fpath,
            path_in_repo=f"router_models/bert_5class/{fname}",
            repo_id=REPO,
            repo_type="model",
        )
        print(f" Uploaded {fname}")
print("\nDONE! 5-class BERT router saved to router_models/bert_5class/")