File size: 3,309 Bytes
5a6b595
 
 
51dec95
5a6b595
51dec95
 
 
 
 
5a6b595
 
 
 
 
51dec95
 
5a6b595
 
51dec95
 
 
 
5a6b595
 
51dec95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a6b595
51dec95
 
5a6b595
51dec95
 
 
 
 
5a6b595
 
51dec95
 
 
5a6b595
 
 
 
 
 
 
 
 
 
 
 
 
51dec95
5a6b595
 
51dec95
 
5a6b595
 
 
 
 
 
 
 
51dec95
5a6b595
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Train Speculative Verifier — v3
================================
Trains Qwen3-4B as an ACCEPT/REJECT classifier for proposed actions.

Given conversation context and a proposed action, the model outputs
either "ACCEPT" or "REJECT".

Training: SFT on (system prompt + up to the last 6 context messages + proposal + "ACCEPT or REJECT?") -> "ACCEPT"/"REJECT"
Inference: generate single token, check if it's ACCEPT or REJECT.

Usage: python train_verifier_v3.py
"""

import torch
from datasets import load_dataset, Dataset
from trl import SFTConfig, SFTTrainer

# Hugging Face Hub namespace used as the prefix for the dataset and model repo ids.
HUB_ORG = "narcolepticchicken"

# System prompt placed at the start of every verifier conversation.
VERIFIER_SYSTEM = (
    "You are an action verifier. "
    "Given conversation context and a proposed next action, "
    "determine if the proposal is correct. "
    "Respond with exactly ACCEPT or REJECT."
)


def _verifier_messages(row, max_context_messages, max_content_chars):
    """Convert one raw verifier row into a chat-format message list.

    The result is: system prompt, the most recent context messages (each
    truncated), a user message carrying the proposal query, and the
    assistant's one-word answer.
    """
    proposal = row["proposal"]
    # Label convention in the raw data: 1 = ACCEPT, anything else = REJECT.
    answer = "ACCEPT" if row["label"] == 1 else "REJECT"

    # Tolerate a missing context. A zero limit must be handled explicitly
    # because ``seq[-0:]`` returns the whole sequence, not an empty one.
    context = row["context"] or []
    recent = context[-max_context_messages:] if max_context_messages > 0 else []

    messages = [{"role": "system", "content": VERIFIER_SYSTEM}]
    for turn in recent:
        content = turn["content"]
        # Avoid leaking the literal text "None" into the training data when a
        # message has no content.
        text = "" if content is None else str(content)
        messages.append({"role": turn["role"], "content": text[:max_content_chars]})

    messages.append({
        "role": "user",
        "content": f"Proposed next action: {proposal}\n\nIs this the correct next action? ACCEPT or REJECT?"
    })
    messages.append({"role": "assistant", "content": answer})
    return messages


def build_verifier_sft_data(*, max_context_messages=6, max_content_chars=400):
    """Load verifier pairs and convert them to conversational SFT format.

    Downloads ``{HUB_ORG}/speculative-verifier-v3-main`` and turns every row
    of its ``train`` and ``test`` splits into a ``{"messages": [...]}`` record.

    Args:
        max_context_messages: How many of the most recent context messages to
            keep per example (0 keeps none).
        max_content_chars: Maximum number of characters kept from each
            context message.

    Returns:
        A ``(train_ds, test_ds)`` tuple of ``Dataset`` objects, each with a
        single ``messages`` column.

    Raises:
        ValueError: If either limit is negative.
    """
    if max_context_messages < 0 or max_content_chars < 0:
        raise ValueError("max_context_messages and max_content_chars must be non-negative")

    dataset = load_dataset(f"{HUB_ORG}/speculative-verifier-v3-main")
    print(f"Loaded raw verifier data: {len(dataset['train'])} train, {len(dataset['test'])} test")

    # Build each split directly instead of tagging rows and re-filtering them.
    converted = {}
    for split_name in ("train", "test"):
        converted[split_name] = Dataset.from_list([
            {"messages": _verifier_messages(row, max_context_messages, max_content_chars)}
            for row in dataset[split_name]
        ])

    train_ds = converted["train"]
    test_ds = converted["test"]

    print(f"SFT format: {len(train_ds)} train, {len(test_ds)} test")
    return train_ds, test_ds


def train():
    """Fine-tune Qwen/Qwen3-4B as the ACCEPT/REJECT verifier.

    Builds the SFT datasets, trains with ``SFTTrainer``, saves the model,
    pushes it to the Hub, and finally prints evaluation metrics computed on
    the test split.
    """
    train_split, eval_split = build_verifier_sft_data()

    hub_repo = f"{HUB_ORG}/speculative-verifier-v3-4b"

    hyperparams = dict(
        output_dir="./output",
        hub_model_id=hub_repo,
        max_seq_length=2048,
        packing=False,
        learning_rate=2e-5,
        # 4 sequences per device x 4 accumulation steps = 16 per optimizer step per device.
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        report_to="none",
    )
    config = SFTConfig(**hyperparams)

    sft_trainer = SFTTrainer(
        model="Qwen/Qwen3-4B",
        args=config,
        train_dataset=train_split,
        eval_dataset=eval_split,
    )

    print(f"Training Qwen3-4B verifier -> {hub_repo}")
    sft_trainer.train()

    # Persist locally first, then publish to the Hub.
    sft_trainer.save_model()
    sft_trainer.push_to_hub()

    # Evaluation runs after the push, on the test split passed as eval_dataset.
    eval_metrics = sft_trainer.evaluate()
    print(f"Eval metrics: {eval_metrics}")


# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    train()