# speculative-tool-actions / train_verifier_v3.py
# (header reconstructed from Hugging Face upload-page residue: uploader
# narcolepticchicken, commit 51dec95)
"""
Train Speculative Verifier — v3
================================
Trains Qwen3-4B as an ACCEPT/REJECT classifier for proposed actions.
Given conversation context and a proposed action, the model outputs
either "ACCEPT" or "REJECT".
Training: SFT on (context + proposal + "ACCEPT/REJECT?") -> "ACCEPT"/"REJECT"
Inference: generate single token, check if it's ACCEPT or REJECT.
Usage: python train_verifier_v3.py
"""
import torch
from datasets import load_dataset, Dataset
from trl import SFTConfig, SFTTrainer
# Hugging Face Hub namespace that both the source dataset and the trained
# model are read from / pushed to.
HUB_ORG = "narcolepticchicken"
# System prompt prepended to every training conversation; constrains the
# model to emit exactly one of the two class tokens.
VERIFIER_SYSTEM = (
"You are an action verifier. Given conversation context and a proposed next action, "
"determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)
def _verifier_messages(row):
    """Build one chat transcript (system + context + query + gold answer) from a raw row.

    Args:
        row: mapping with keys "context" (list of {"role", "content"} turns),
            "proposal" (str), and "label" (int, 1=ACCEPT, 0=REJECT).

    Returns:
        List of chat messages ending with the assistant's "ACCEPT"/"REJECT" label.
    """
    answer = "ACCEPT" if row["label"] == 1 else "REJECT"
    msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
    # Keep only the last 6 context turns, each truncated to 400 chars,
    # to bound the tokenized sequence length.
    for m in row["context"][-6:]:
        msgs.append({"role": m["role"], "content": str(m["content"])[:400]})
    # Proposal query the verifier must answer.
    msgs.append({
        "role": "user",
        "content": f"Proposed next action: {row['proposal']}\n\nIs this the correct next action? ACCEPT or REJECT?"
    })
    msgs.append({"role": "assistant", "content": answer})
    return msgs


def build_verifier_sft_data():
    """Load verifier pairs and convert to conversational SFT format.

    Returns:
        (train_ds, test_ds): `datasets.Dataset` objects whose rows each hold
        a "messages" list in chat format, ready for `SFTTrainer`.
    """
    dataset = load_dataset(f"{HUB_ORG}/speculative-verifier-v3-main")
    print(f"Loaded raw verifier data: {len(dataset['train'])} train, {len(dataset['test'])} test")
    # Build each split directly instead of tagging rows with a throwaway
    # "split" key and re-filtering the combined list afterwards.
    per_split = {
        split_name: Dataset.from_list(
            [{"messages": _verifier_messages(row)} for row in dataset[split_name]]
        )
        for split_name in ("train", "test")
    }
    print(f"SFT format: {len(per_split['train'])} train, {len(per_split['test'])} test")
    return per_split["train"], per_split["test"]
def train():
    """Fine-tune Qwen3-4B as an ACCEPT/REJECT verifier and push it to the Hub.

    Runs SFT on the chat-formatted pairs from ``build_verifier_sft_data``,
    saves/pushes the model, then reports held-out eval metrics.
    """
    train_ds, test_ds = build_verifier_sft_data()
    repo_id = f"{HUB_ORG}/speculative-verifier-v3-4b"
    # Effective batch size: 4 per device x 4 accumulation steps = 16.
    config = SFTConfig(
        output_dir="./output",
        hub_model_id=repo_id,
        max_seq_length=2048,
        packing=False,
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        report_to="none",
    )
    verifier_trainer = SFTTrainer(
        model="Qwen/Qwen3-4B",
        args=config,
        train_dataset=train_ds,
        eval_dataset=test_ds,
    )
    print(f"Training Qwen3-4B verifier -> {repo_id}")
    verifier_trainer.train()
    verifier_trainer.save_model()
    verifier_trainer.push_to_hub()
    # Final held-out evaluation on the test split.
    eval_metrics = verifier_trainer.evaluate()
    print(f"Eval metrics: {eval_metrics}")
# Script entry point: run the full train -> save -> push -> eval pipeline.
if __name__ == "__main__":
train()