"""
Train Speculative Verifier — v3
================================

Trains Qwen3-4B as an ACCEPT/REJECT classifier for proposed actions.
Given conversation context and a proposed action, the model outputs
either "ACCEPT" or "REJECT".

Training:  SFT on (context + proposal + "ACCEPT/REJECT?") -> "ACCEPT"/"REJECT"
Inference: generate a single token, check if it's ACCEPT or REJECT.

NOTE(review): Qwen3's chat template may render an empty think block before the
final assistant answer, and "ACCEPT"/"REJECT" are not guaranteed to be single
tokens — confirm the inference path accounts for both.

Usage:
    python train_verifier_v3.py
"""

import torch
from datasets import load_dataset, Dataset
from trl import SFTConfig, SFTTrainer

HUB_ORG = "narcolepticchicken"

VERIFIER_SYSTEM = (
    "You are an action verifier. Given conversation context and a proposed next action, "
    "determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)

# Only the most recent context turns are kept, each truncated, to bound prompt length.
_MAX_CONTEXT_TURNS = 6
_MAX_TURN_CHARS = 400

# Raw dataset label -> assistant answer the model is trained to emit.
_LABEL_TO_ANSWER = {1: "ACCEPT", 0: "REJECT"}


def _turn_text(content):
    """Return a context turn's content as text, mapping ``None`` to ``""``.

    Without this, ``str(None)`` would inject the literal word "None" into the
    training prompt for turns that carry no text content.
    """
    return "" if content is None else str(content)


def _row_to_messages(row):
    """Convert one raw verifier row into a conversational SFT example.

    Args:
        row: mapping with ``context`` (sequence of turns, each with ``role`` and
            ``content``), ``proposal`` (the proposed next action, formatted with
            ``str``) and ``label`` (1 = ACCEPT, 0 = REJECT).

    Returns:
        List of chat messages: the system prompt, the last
        ``_MAX_CONTEXT_TURNS`` context turns (each truncated to
        ``_MAX_TURN_CHARS`` characters), the proposal query, and the
        ACCEPT/REJECT assistant answer.

    Raises:
        ValueError: if ``label`` is neither 0 nor 1 (previously any such value
            was silently trained as REJECT).
    """
    label = row["label"]
    if label not in _LABEL_TO_ANSWER:
        raise ValueError(f"Unexpected verifier label {label!r}; expected 0 or 1")

    msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
    # Add context
    for m in row["context"][-_MAX_CONTEXT_TURNS:]:
        msgs.append({
            "role": m["role"],
            "content": _turn_text(m.get("content"))[:_MAX_TURN_CHARS],
        })
    # Add proposal query
    msgs.append({
        "role": "user",
        "content": f"Proposed next action: {row['proposal']}\n\nIs this the correct next action? ACCEPT or REJECT?",
    })
    msgs.append({"role": "assistant", "content": _LABEL_TO_ANSWER[label]})
    return msgs


def build_verifier_sft_data():
    """Load verifier pairs and convert them to conversational SFT format.

    Returns:
        ``(train_ds, test_ds)``: ``datasets.Dataset`` objects with a single
        ``messages`` column, built from the hub dataset's ``train`` and
        ``test`` splits respectively.

    Raises:
        ValueError: if any row has a label other than 0 or 1.
    """
    dataset = load_dataset(f"{HUB_ORG}/speculative-verifier-v3-main")
    print(f"Loaded raw verifier data: {len(dataset['train'])} train, {len(dataset['test'])} test")

    train_rows = [{"messages": _row_to_messages(row)} for row in dataset["train"]]
    test_rows = [{"messages": _row_to_messages(row)} for row in dataset["test"]]

    train_ds = Dataset.from_list(train_rows)
    test_ds = Dataset.from_list(test_rows)
    print(f"SFT format: {len(train_rows)} train, {len(test_rows)} test")
    return train_ds, test_ds


def train():
    """Fine-tune Qwen3-4B on the verifier SFT data, evaluate it, and push to the Hub."""
    train_ds, test_ds = build_verifier_sft_data()

    # NOTE(review): unless completion-only loss is configured, the SFT loss may
    # cover prompt tokens as well as the ACCEPT/REJECT answer — confirm this is
    # intended for the installed TRL version.
    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=f"{HUB_ORG}/speculative-verifier-v3-4b",
        # NOTE(review): newer TRL releases rename this option to ``max_length``;
        # confirm against the installed TRL version.
        max_seq_length=2048,
        packing=False,
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        report_to="none",
    )

    trainer = SFTTrainer(
        model="Qwen/Qwen3-4B",
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
    )

    print(f"Training Qwen3-4B verifier -> {HUB_ORG}/speculative-verifier-v3-4b")
    trainer.train()

    # Save first so the trained weights are persisted even if evaluation fails.
    trainer.save_model()

    # Eval — run before the final push so the logged eval metrics are available
    # when push_to_hub() regenerates the model card.
    metrics = trainer.evaluate()
    print(f"Eval metrics: {metrics}")

    trainer.push_to_hub()


if __name__ == "__main__":
    train()