| """ |
| Train Speculative Verifier — v3 |
| ================================ |
| Trains Qwen3-4B as an ACCEPT/REJECT classifier for proposed actions. |
| |
| Given conversation context and a proposed action, the model outputs |
| either "ACCEPT" or "REJECT". |
| |
| Training: SFT on (context + proposal + "ACCEPT/REJECT?") -> "ACCEPT"/"REJECT" |
| Inference: generate single token, check if it's ACCEPT or REJECT. |
| |
| Usage: python train_verifier_v3.py |
| """ |
|
|
| import torch |
| from datasets import load_dataset, Dataset |
| from trl import SFTConfig, SFTTrainer |
|
|
# Hugging Face Hub organization/user that hosts both the raw verifier dataset
# (speculative-verifier-v3-main) and the finetuned checkpoint pushed by train().
HUB_ORG = "narcolepticchicken"
# System prompt prepended to every training example; it constrains the model to
# answer with exactly one of the two verdict tokens (ACCEPT / REJECT).
VERIFIER_SYSTEM = (
    "You are an action verifier. Given conversation context and a proposed next action, "
    "determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)
|
|
|
|
def _row_to_messages(row):
    """Convert one raw verifier row into a chat-format message list.

    Layout: system prompt, then the last 6 context turns (each content
    truncated to 400 chars), then a user turn asking for a verdict on the
    proposal, then an assistant turn holding the gold answer
    ("ACCEPT" when label == 1, otherwise "REJECT").
    """
    answer = "ACCEPT" if row["label"] == 1 else "REJECT"
    msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
    # Keep only the most recent turns to bound sequence length; 400-char cap
    # per turn keeps the example under max_seq_length.
    for m in row["context"][-6:]:
        msgs.append({"role": m["role"], "content": str(m["content"])[:400]})
    msgs.append({
        "role": "user",
        "content": f"Proposed next action: {row['proposal']}\n\nIs this the correct next action? ACCEPT or REJECT?"
    })
    msgs.append({"role": "assistant", "content": answer})
    return msgs


def build_verifier_sft_data():
    """Load verifier pairs and convert to conversational SFT format.

    Returns:
        (train_ds, test_ds): two `datasets.Dataset` objects with a single
        "messages" column in the chat format SFTTrainer expects.
    """
    dataset = load_dataset(f"{HUB_ORG}/speculative-verifier-v3-main")
    print(f"Loaded raw verifier data: {len(dataset['train'])} train, {len(dataset['test'])} test")

    # Build each split directly — no need to tag rows with a "split" key and
    # filter the combined list twice as the raw splits are already separate.
    split_rows = {
        split: [{"messages": _row_to_messages(row)} for row in dataset[split]]
        for split in ("train", "test")
    }
    train_ds = Dataset.from_list(split_rows["train"])
    test_ds = Dataset.from_list(split_rows["test"])

    print(f"SFT format: {len(split_rows['train'])} train, {len(split_rows['test'])} test")
    return train_ds, test_ds
|
|
|
|
def train():
    """Fine-tune Qwen3-4B as the ACCEPT/REJECT verifier and push it to the Hub.

    Trains on the conversational SFT data from build_verifier_sft_data(),
    saves/pushes the final checkpoint, then reports eval metrics on the
    held-out split.
    """
    train_ds, test_ds = build_verifier_sft_data()

    repo_id = f"{HUB_ORG}/speculative-verifier-v3-4b"

    # Effective batch size 16 (4 per device x 4 accumulation steps), 2 epochs
    # at a conservative full-finetune learning rate; checkpoints each epoch.
    cfg = SFTConfig(
        output_dir="./output",
        hub_model_id=repo_id,
        max_seq_length=2048,
        packing=False,
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        report_to="none",
    )

    verifier_trainer = SFTTrainer(
        model="Qwen/Qwen3-4B",
        args=cfg,
        train_dataset=train_ds,
        eval_dataset=test_ds,
    )

    print(f"Training Qwen3-4B verifier -> {repo_id}")
    verifier_trainer.train()

    # Persist the final weights locally, then upload to the Hub repo.
    verifier_trainer.save_model()
    verifier_trainer.push_to_hub()

    metrics = verifier_trainer.evaluate()
    print(f"Eval metrics: {metrics}")
|
|
|
|
# Script entry point: run a full train -> save -> push -> evaluate cycle.
if __name__ == "__main__":
    train()
|
|