narcolepticchicken committed on
Commit
e95c4a3
·
verified ·
1 Parent(s): 96c57cf

Upload training/train_bert_5class.py

Browse files
Files changed (1) hide show
  1. training/train_bert_5class.py +187 -0
training/train_bert_5class.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retrain BERT as a 5-class tier router.
2
+
3
+ Uses SPROUT data to predict optimal tier (1-5) directly.
4
+ """
5
import json
import os
import random
from collections import Counter

import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
13
+
14
# Hub repository that receives the trained router files at the end of the run.
REPO = "narcolepticchicken/agent-cost-optimizer"
print("BERT 5-Class Tier Router Training")
print("="*60)

# ── Load SPROUT ──
print("\n[1] Loading SPROUT dataset...")
# NOTE(review): trust_remote_code=True allows a dataset loading script from the
# Hub to execute locally. Confirm it is still required for this dataset and
# drop it if the dataset loads without it.
ds = load_dataset("CARROT-LLM-Routing/SPROUT", split="train", trust_remote_code=True)
print(f" Total rows: {len(ds)}")
print(f" Columns: {ds.column_names}")
23
+
24
# ── Model tier mapping ──
# SPROUT models → tiers (same as v11 training)
# Tier number -> model names assigned to that tier.
TIER_MODELS = {
    1: ["gemma-2-2b-it", "phi-3-mini-128k-instruct", "qwen2.5-3b-instruct",
        "llama-3.2-3b-instruct", "deepseek-v3.2"],
    2: ["gemma-2-9b-it", "mistral-7b-instruct-v0.3", "qwen2.5-7b-instruct",
        "llama-3.1-8b-instruct", "gpt-5-nano", "gpt-5-mini"],
    3: ["qwen2.5-32b-instruct", "mixtral-8x7b-instruct-v0.1",
        "gemma-2-27b-it", "gemini-2.5-pro"],
    4: ["claude-opus-4.7", "gpt-5.2", "llama-3.1-70b-instruct",
        "qwen2.5-72b-instruct"],
    5: ["gemini-3-pro", "deepseek-v4-flash"],
}

# Reverse lookup: lower-cased model name -> tier number.
MODEL_TIER_MAP = {
    model_name.lower(): tier_id
    for tier_id, model_names in TIER_MODELS.items()
    for model_name in model_names
}
41
+
42
# ── Build training data ──
print("\n[2] Building training data...")
texts = []
labels = []
skipped = 0

# Which columns belong to which tier depends only on the dataset schema, so
# resolve it once here instead of re-scanning every column name for every row.
# Ordering (tier, then model, then column) matches the per-row scan it replaces.
column_names = ds.column_names
tier_columns = {
    tier: [
        col
        for m in TIER_MODELS.get(tier, [])
        for col in column_names
        if m.lower() in col.lower()
    ]
    for tier in range(1, 6)
}

# The prompt is read from the first of these candidate columns that exists.
prompt_col = next(
    (
        col
        for col in ["prompt", "question", "input", "query", "problem_statement", "instruction"]
        if col in column_names
    ),
    None,
)


def _is_positive_number(val):
    """Return True when *val* is an int/float strictly greater than zero."""
    return isinstance(val, (int, float)) and val > 0


for row in ds:
    # Find the cheapest tier that succeeded; rows with no positive numeric
    # value in any matched model column keep the default of tier 5.
    best_tier = 5  # default
    for tier in range(1, 6):
        if any(_is_positive_number(row.get(col)) for col in tier_columns[tier]):
            best_tier = tier
            break

    # Get the prompt/question text
    prompt = str(row[prompt_col]) if prompt_col is not None else ""

    if not prompt:
        # Fall back to the first string column holding more than 20 characters.
        for col in column_names:
            if isinstance(row[col], str) and len(row[col]) > 20:
                prompt = row[col]
                break

    if prompt and len(prompt) > 10:
        texts.append(prompt[:2000])  # truncate long texts
        labels.append(best_tier - 1)  # 0-indexed for classification
    else:
        skipped += 1

print(f" Training samples: {len(texts)}")
print(f" Skipped: {skipped}")
print(" Label distribution:")
label_dist = Counter(labels)
for label in sorted(label_dist):
    print(f" Tier {label+1}: {label_dist[label]} ({label_dist[label]/len(labels)*100:.1f}%)")
97
+
98
# ── Tokenize ──
print("\n[3] Tokenizing...")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# No padding at this stage: sequences are only truncated to 512 tokens here,
# and the DataCollatorWithPadding given to the Trainer pads each batch.
encodings = tokenizer(texts, truncation=True, max_length=512, padding=False)
103
+
104
class TierDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with tier labels.

    Args:
        encodings: Mapping from feature name to a per-example sequence of
            values; each value is indexed by example position.
        labels: Sequence of class labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example *idx* as a dict of tensors, including ``labels``."""
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples (the number of labels)."""
        return len(self.labels)
114
+
115
# Split
# Shuffle example indices with a fixed seed before the 90/10 cut, so the eval
# set is not simply the tail of the dataset in its stored order and the split
# is reproducible across runs.
indices = list(range(len(texts)))
random.Random(42).shuffle(indices)
split = int(0.9 * len(indices))


def _subset(selected):
    """Build a TierDataset holding only the examples at positions *selected*."""
    return TierDataset(
        {k: [v[i] for i in selected] for k, v in encodings.items()},
        [labels[i] for i in selected],
    )


train_ds = _subset(indices[:split])
eval_ds = _subset(indices[split:])
print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")
126
+
127
# ── Train ──
print("\n[4] Training 5-class BERT router...")
# Five output labels: one per tier, using the 0-indexed labels built above.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5
)

training_args = TrainingArguments(
    output_dir="/tmp/bert_5class",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluate and checkpoint on the same schedule so the best checkpoint
    # (by eval accuracy) can be reloaded when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    logging_steps=50,
    disable_tqdm=True,
)
147
+
148
def compute_metrics(eval_pred):
    """Compute classification accuracy from an evaluation prediction pair.

    Args:
        eval_pred: A pair ``(logits, labels)``; the predicted class for each
            example is the argmax of ``logits`` over the last axis.

    Returns:
        A dict ``{"accuracy": value}`` where ``value`` is the fraction of
        predictions equal to the corresponding label.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # Plain Python float keeps the metric trivially serialisable.
    return {"accuracy": float(np.mean(preds == labels))}
153
+
154
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # NOTE(review): newer transformers releases prefer `processing_class=` over
    # `tokenizer=`; confirm against the pinned version before switching.
    tokenizer=tokenizer,
    # Encodings were produced without padding, so each batch is padded here.
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()
165
+
166
# ── Save and upload ──
print("\n[5] Saving model...")
save_dir = "/tmp/bert_5class_final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f" Saved to {save_dir}")

# Upload to Hub
# One upload_folder call puts every saved file into a single commit, so a
# failure part-way through cannot leave a half-updated model in the repo.
api = HfApi()
api.upload_folder(
    folder_path=save_dir,
    path_in_repo="router_models/bert_5class",
    repo_id=REPO,
    repo_type="model",
)
for fname in sorted(os.listdir(save_dir)):
    if os.path.isfile(os.path.join(save_dir, fname)):
        print(f" Uploaded {fname}")

print("\nDONE! 5-class BERT router saved to router_models/bert_5class/")