| """Retrain BERT as a 5-class tier router. |
| |
| Uses SPROUT data to predict optimal tier (1-5) directly. |
| """ |
| import os, json, random |
| import numpy as np |
| from datasets import load_dataset |
| from transformers import ( |
| AutoTokenizer, AutoModelForSequenceClassification, |
| TrainingArguments, Trainer, DataCollatorWithPadding, |
| ) |
| import torch |
|
|
REPO = "narcolepticchicken/agent-cost-optimizer"
print("BERT 5-Class Tier Router Training")
print("=" * 60)

# [1] Load the full SPROUT training split (per-model scores for many prompts).
print("\n[1] Loading SPROUT dataset...")
ds = load_dataset("CARROT-LLM-Routing/SPROUT", split="train", trust_remote_code=True)
print(f" Total rows: {len(ds)}")
print(f" Columns: {ds.column_names}")

# Map each known model name to a cost tier (1 = cheapest, 5 = most capable).
# Model names are matched as substrings of SPROUT column names below.
MODEL_TIER_MAP = {}
TIER_MODELS = {
    1: ["gemma-2-2b-it", "phi-3-mini-128k-instruct", "qwen2.5-3b-instruct",
        "llama-3.2-3b-instruct", "deepseek-v3.2"],
    2: ["gemma-2-9b-it", "mistral-7b-instruct-v0.3", "qwen2.5-7b-instruct",
        "llama-3.1-8b-instruct", "gpt-5-nano", "gpt-5-mini"],
    3: ["qwen2.5-32b-instruct", "mixtral-8x7b-instruct-v0.1",
        "gemma-2-27b-it", "gemini-2.5-pro"],
    4: ["claude-opus-4.7", "gpt-5.2", "llama-3.1-70b-instruct",
        "qwen2.5-72b-instruct"],
    5: ["gemini-3-pro", "deepseek-v4-flash"],
}
for tier, models in TIER_MODELS.items():
    for m in models:
        MODEL_TIER_MAP[m.lower()] = tier
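# Note: MODEL_TIER_MAP (lower-cased model name -> tier) is not referenced again
# in this script; presumably it is kept for downstream routing code that needs
# to map a chosen model name back to its tier.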

# [2] Build (prompt, tier) training pairs: label each prompt with the cheapest
# tier containing a model that scored positively on it.
print("\n[2] Building training data...")
texts = []
labels = []
skipped = 0

for row in ds:
    # Default to the most capable tier; downgrade to the cheapest tier whose
    # model scored positively on this prompt.
    best_tier = 5
    found = False

    # Scan tiers from cheapest (1) to most expensive (5), stopping at the first hit.
    for tier in range(1, 6):
        for m in TIER_MODELS.get(tier, []):
            m_lower = m.lower()
            # A model's score sits in a column whose name contains the model name.
            for col in ds.column_names:
                if m_lower in col.lower():
                    val = row.get(col)
                    if isinstance(val, (int, float)) and val > 0:
                        best_tier = tier
                        found = True
                        break
            if found:
                break
        if found:
            break

    # Pull the prompt text from the first recognized text column.
    prompt = ""
    for col in ["prompt", "question", "input", "query", "problem_statement", "instruction"]:
        if col in ds.column_names:
            prompt = str(row[col] or "")  # Guard against None so the fallback below can run.
            break

    if not prompt:
        # Fall back to the first reasonably long string field in the row.
        for col in ds.column_names:
            if isinstance(row[col], str) and len(row[col]) > 20:
                prompt = row[col]
                break

    if prompt and len(prompt) > 10:
        texts.append(prompt[:2000])  # Truncate very long prompts.
        labels.append(best_tier - 1)  # 0-indexed classes: tier N -> class N-1.
    else:
        skipped += 1

| print(f" Training samples: {len(texts)}") |
| print(f" Skipped: {skipped}") |
| print(f" Label distribution:") |
| from collections import Counter |
| label_dist = Counter(labels) |
| for label in sorted(label_dist): |
| print(f" Tier {label+1}: {label_dist[label]} ({label_dist[label]/len(labels)*100:.1f}%)") |
|

# [3] Shuffle, then tokenize. Shuffling up front keeps the sequential 90/10
# split below from being biased by any ordering in the dataset.
print("\n[3] Tokenizing...")
pairs = list(zip(texts, labels))
random.seed(42)
random.shuffle(pairs)
texts, labels = [p[0] for p in pairs], [p[1] for p in pairs]

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# No padding here; DataCollatorWithPadding pads each batch dynamically.
encodings = tokenizer(texts, truncation=True, max_length=512, padding=False)

class TierDataset(torch.utils.data.Dataset):
    """Minimal Dataset over pre-tokenized encodings and integer class labels."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# 90/10 train/eval split over the already-shuffled examples.
split = int(0.9 * len(texts))
train_ds = TierDataset(
    {k: v[:split] for k, v in encodings.items()},
    labels[:split],
)
eval_ds = TierDataset(
    {k: v[split:] for k, v in encodings.items()},
    labels[split:],
)
print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")


# [4] Fine-tune DistilBERT with a 5-way classification head (classes 0-4 = tiers 1-5).
print("\n[4] Training 5-class BERT router...")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5
)

training_args = TrainingArguments(
    output_dir="/tmp/bert_5class",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,  # Reload the checkpoint with the best eval accuracy.
    metric_for_best_model="eval_accuracy",
    logging_steps=50,
    disable_tqdm=True,
)
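# Note: on older transformers releases (pre-4.41) the eval_strategy argument
# above is named evaluation_strategy.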

def compute_metrics(eval_pred):
    # Plain accuracy; Trainer reports it prefixed as "eval_accuracy".
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = np.mean(preds == labels)
    return {"accuracy": acc}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)


trainer.train()

# [5] Save the fine-tuned model and tokenizer together.
print("\n[5] Saving model...")
save_dir = "/tmp/bert_5class_final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f" Saved to {save_dir}")


# Upload every saved file into the router_models/bert_5class/ folder of the Hub
# repo (requires an authenticated Hugging Face token in the environment).
from huggingface_hub import HfApi
api = HfApi()
for fname in os.listdir(save_dir):
    fpath = os.path.join(save_dir, fname)
    if os.path.isfile(fpath):
        api.upload_file(
            path_or_fileobj=fpath,
            path_in_repo=f"router_models/bert_5class/{fname}",
            repo_id=REPO,
            repo_type="model",
        )
        print(f" Uploaded {fname}")

print("\nDONE! 5-class BERT router saved to router_models/bert_5class/")
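
# Usage sketch (illustrative, not part of the training run): load the saved
# router and map a new prompt to a tier. The local path and example prompt are
# assumptions mirroring the choices above.
#
#   from transformers import AutoTokenizer, AutoModelForSequenceClassification
#   tok = AutoTokenizer.from_pretrained("/tmp/bert_5class_final")
#   router = AutoModelForSequenceClassification.from_pretrained("/tmp/bert_5class_final")
#   enc = tok("Summarize this paragraph in one sentence.",
#             return_tensors="pt", truncation=True, max_length=512)
#   tier = int(router(**enc).logits.argmax(-1).item()) + 1  # classes 0-4 -> tiers 1-5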