"""Retrain BERT as a 5-class tier router. Uses SPROUT data to predict optimal tier (1-5) directly. """ import os, json, random import numpy as np from datasets import load_dataset from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, ) import torch REPO = "narcolepticchicken/agent-cost-optimizer" print("BERT 5-Class Tier Router Training") print("="*60) # ── Load SPROUT ── print("\n[1] Loading SPROUT dataset...") ds = load_dataset("CARROT-LLM-Routing/SPROUT", split="train", trust_remote_code=True) print(f" Total rows: {len(ds)}") print(f" Columns: {ds.column_names}") # ── Model tier mapping ── # SPROUT models → tiers (same as v11 training) MODEL_TIER_MAP = {} TIER_MODELS = { 1: ["gemma-2-2b-it","phi-3-mini-128k-instruct","qwen2.5-3b-instruct", "llama-3.2-3b-instruct","deepseek-v3.2"], 2: ["gemma-2-9b-it","mistral-7b-instruct-v0.3","qwen2.5-7b-instruct", "llama-3.1-8b-instruct","gpt-5-nano","gpt-5-mini"], 3: ["qwen2.5-32b-instruct","mixtral-8x7b-instruct-v0.1", "gemma-2-27b-it","gemini-2.5-pro"], 4: ["claude-opus-4.7","gpt-5.2","llama-3.1-70b-instruct", "qwen2.5-72b-instruct"], 5: ["gemini-3-pro","deepseek-v4-flash"], } for tier, models in TIER_MODELS.items(): for m in models: MODEL_TIER_MAP[m.lower()] = tier # ── Build training data ── print("\n[2] Building training data...") texts = [] labels = [] skipped = 0 for row in ds: # Find the cheapest tier that succeeded best_tier = 5 # default found = False # Try to find model results in the row for tier in range(1, 6): for m in TIER_MODELS.get(tier, []): m_lower = m.lower() # Check various possible column names for col in ds.column_names: if m_lower in col.lower(): val = row.get(col) if isinstance(val, (int, float)) and val > 0: best_tier = tier found = True break if found: break if found: break # Get the prompt/question text prompt = "" for col in ["prompt", "question", "input", "query", "problem_statement", "instruction"]: if col in ds.column_names: prompt = str(row[col]) break if not prompt: # Try first string column for col in ds.column_names: if isinstance(row[col], str) and len(row[col]) > 20: prompt = row[col] break if prompt and len(prompt) > 10: texts.append(prompt[:2000]) # truncate long texts labels.append(best_tier - 1) # 0-indexed for classification else: skipped += 1 print(f" Training samples: {len(texts)}") print(f" Skipped: {skipped}") print(f" Label distribution:") from collections import Counter label_dist = Counter(labels) for label in sorted(label_dist): print(f" Tier {label+1}: {label_dist[label]} ({label_dist[label]/len(labels)*100:.1f}%)") # ── Tokenize ── print("\n[3] Tokenizing...") tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") encodings = tokenizer(texts, truncation=True, max_length=512, padding=False) class TierDataset(torch.utils.data.Dataset): def __init__(self, encodings, labels): self.encodings = encodings self.labels = labels def __getitem__(self, idx): item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()} item["labels"] = torch.tensor(self.labels[idx]) return item def __len__(self): return len(self.labels) # Split split = int(0.9 * len(texts)) train_ds = TierDataset( {k: v[:split] for k, v in encodings.items()}, labels[:split], ) eval_ds = TierDataset( {k: v[split:] for k, v in encodings.items()}, labels[split:], ) print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}") # ── Train ── print("\n[4] Training 5-class BERT router...") model = AutoModelForSequenceClassification.from_pretrained( "distilbert-base-uncased", num_labels=5 ) training_args = TrainingArguments( output_dir="/tmp/bert_5class", num_train_epochs=3, per_device_train_batch_size=32, per_device_eval_batch_size=64, learning_rate=2e-5, weight_decay=0.01, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="eval_accuracy", logging_steps=50, disable_tqdm=True, ) def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) acc = np.mean(preds == labels) return {"accuracy": acc} trainer = Trainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds, tokenizer=tokenizer, data_collator=DataCollatorWithPadding(tokenizer), compute_metrics=compute_metrics, ) trainer.train() # ── Save and upload ── print("\n[5] Saving model...") save_dir = "/tmp/bert_5class_final" model.save_pretrained(save_dir) tokenizer.save_pretrained(save_dir) print(f" Saved to {save_dir}") # Upload to Hub from huggingface_hub import HfApi api = HfApi() for fname in os.listdir(save_dir): fpath = os.path.join(save_dir, fname) if os.path.isfile(fpath): api.upload_file( path_or_fileobj=fpath, path_in_repo=f"router_models/bert_5class/{fname}", repo_id=REPO, repo_type="model", ) print(f" Uploaded {fname}") print("\nDONE! 5-class BERT router saved to router_models/bert_5class/")