File size: 6,143 Bytes
b5d1ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Consolidated Training: all 3 speculative-action models in one A100 job.

Trains sequentially with explicit cleanup between models:
  1. Qwen3-1.7B proposer (SFT on action prediction)
  2. Qwen3-4B verifier (SFT on ACCEPT/REJECT)
  3. Qwen3-8B proposer (SFT on action prediction)

Each model is: loaded → trained → saved+pushed → evaluated → deleted → CUDA cache cleared.

Requires: transformers>=4.51, trl, torch, datasets, accelerate, peft, huggingface_hub
"""

import gc
import torch

# Hugging Face Hub namespace: prefix of the dataset repos this script loads
# and of the hub_model_id of every model it pushes.
HUB = "narcolepticchicken"

# ═══════════════════════════════════════════════════════════════
# 1. TRAIN 1.7B PROPOSER
# ═══════════════════════════════════════════════════════════════
# Fine-tunes Qwen/Qwen3-1.7B with TRL's SFTTrainer on the proposer SFT
# dataset, saves and pushes the result, reports the eval loss, and then drops
# the trainer so the next model starts from a cleared CUDA cache.
print("\n" + "=" * 60)
print("1/3: TRAINING 1.7B PROPOSER")
print("=" * 60)

# Imported mid-script; the later sections reuse these names.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Proposer SFT data; the "train" and "test" splits are reused unchanged by the
# 8B proposer run further down.
sft_ds = load_dataset(f"{HUB}/speculative-sft-v3-main")
print(f"SFT data: {len(sft_ds['train'])} train, {len(sft_ds['test'])} test")

args_1b = SFTConfig(
    output_dir="./out_1b",
    hub_model_id=f"{HUB}/speculative-proposer-v3-1.7b",
    max_length=2048,
    packing=False,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step
    num_train_epochs=3,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=5,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)

# Passing the model id as a string lets SFTTrainer load the model itself, so
# the trainer is the only object in this script that references it.
trainer = SFTTrainer(
    model="Qwen/Qwen3-1.7B",
    args=args_1b,
    train_dataset=sft_ds["train"],
    eval_dataset=sft_ds["test"],
)

trainer.train()
trainer.save_model()
trainer.push_to_hub()

metrics = trainer.evaluate()
print(f"1.7B eval loss: {metrics.get('eval_loss', 'N/A')}")

# Explicit cleanup between models: drop the trainer, collect garbage, then
# release cached CUDA memory before the 4B model is loaded.
del trainer
gc.collect()
torch.cuda.empty_cache()
# The check mark was previously emitted as mis-decoded UTF-8 bytes.
print("1.7B proposer ✓")

# ═══════════════════════════════════════════════════════════════
# 2. TRAIN 4B VERIFIER
# ═══════════════════════════════════════════════════════════════
# Converts the labelled verifier dataset into chat-style SFT examples whose
# assistant turn is ACCEPT or REJECT, fine-tunes Qwen/Qwen3-4B on them, saves
# and pushes the result, reports the eval loss, and frees the trainer.
print("\n" + "=" * 60)
print("2/3: TRAINING 4B VERIFIER")
print("=" * 60)

from datasets import Dataset

VERIFIER_SYSTEM = (
    "You are an action verifier. Given conversation context and a proposed next action, "
    "determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)

verif_raw = load_dataset(f"{HUB}/speculative-verifier-v3-main")


def _build_verifier_messages(row):
    """Return the chat messages for one verifier training example.

    Args:
        row: Mapping with a "context" sequence of messages (each providing
            "role" and "content"), a "proposal", and a "label".

    Returns:
        A list of role/content dicts: the verifier system prompt, the last
        six context messages with content clipped to 400 characters, a user
        turn presenting the proposal, and an assistant turn that is "ACCEPT"
        when the label equals 1 and "REJECT" otherwise.
    """
    context = row["context"]
    proposal = row["proposal"]
    verdict = "ACCEPT" if row["label"] == 1 else "REJECT"

    messages = [{"role": "system", "content": VERIFIER_SYSTEM}]
    # Only the tail of the conversation is kept, and each message is clipped;
    # presumably this keeps examples within max_length — TODO confirm.
    messages.extend(
        {"role": turn["role"], "content": str(turn["content"])[:400]}
        for turn in context[-6:]
    )
    messages.append({
        "role": "user",
        "content": f"Proposed next action: {proposal}\n\nIs this the correct next action? ACCEPT or REJECT?"
    })
    messages.append({"role": "assistant", "content": verdict})
    return messages


# Build each split directly rather than tagging every row with its split name
# and filtering the combined list twice.
v_train = Dataset.from_list(
    [{"messages": _build_verifier_messages(row)} for row in verif_raw["train"]]
)
v_test = Dataset.from_list(
    [{"messages": _build_verifier_messages(row)} for row in verif_raw["test"]]
)
print(f"Verifier SFT: {len(v_train)} train, {len(v_test)} test")

args_4b = SFTConfig(
    output_dir="./out_4b",
    hub_model_id=f"{HUB}/speculative-verifier-v3-4b",
    max_length=2048,
    packing=False,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step
    num_train_epochs=2,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    logging_first_step=True,
    save_strategy="epoch",
    push_to_hub=True,
    disable_tqdm=True,
    report_to="none",
)

trainer = SFTTrainer(
    model="Qwen/Qwen3-4B",
    args=args_4b,
    train_dataset=v_train,
    eval_dataset=v_test,
)

trainer.train()
trainer.save_model()
trainer.push_to_hub()

metrics = trainer.evaluate()
print(f"4B verifier eval loss: {metrics.get('eval_loss', 'N/A')}")

# Explicit cleanup between models: drop the trainer, collect garbage, then
# release cached CUDA memory before the 8B model is loaded.
del trainer
gc.collect()
torch.cuda.empty_cache()
# The check mark was previously emitted as mis-decoded UTF-8 bytes.
print("4B verifier ✓")

# ═══════════════════════════════════════════════════════════════
# 3. TRAIN 8B PROPOSER
# ═══════════════════════════════════════════════════════════════
# Last stage: fine-tune Qwen/Qwen3-8B on the same proposer SFT splits used for
# the 1.7B model, save and push it, report the eval loss, and print a summary
# of the three Hub repos.
_rule = "=" * 60
print(f"\n{_rule}")
print("3/3: TRAINING 8B PROPOSER")
print(_rule)

# Keyword arguments for the 8B run, collected in a mapping and expanded into
# SFTConfig below.
_config_8b = {
    "output_dir": "./out_8b",
    "hub_model_id": f"{HUB}/speculative-proposer-v3-8b",
    "max_length": 2048,
    "packing": False,
    "learning_rate": 2e-5,
    # smaller batch to fit 8B on 80GB
    "per_device_train_batch_size": 2,
    # effective batch = 16 (same as others)
    "gradient_accumulation_steps": 8,
    "num_train_epochs": 3,
    "bf16": True,
    "gradient_checkpointing": True,
    "logging_steps": 5,
    "logging_first_step": True,
    "save_strategy": "epoch",
    "push_to_hub": True,
    "disable_tqdm": True,
    "report_to": "none",
}
args_8b = SFTConfig(**_config_8b)

_train_split = sft_ds["train"]
_eval_split = sft_ds["test"]
trainer = SFTTrainer(
    model="Qwen/Qwen3-8B",
    args=args_8b,
    train_dataset=_train_split,
    eval_dataset=_eval_split,
)

# Same sequence as the earlier stages: train, save, push, then evaluate.
trainer.train()
trainer.save_model()
trainer.push_to_hub()

metrics = trainer.evaluate()
_eval_loss_8b = metrics.get("eval_loss", "N/A")
print(f"8B eval loss: {_eval_loss_8b}")

# Closing summary listing the Hub repo of each trained model.
print(f"\n{_rule}")
print("ALL THREE MODELS TRAINED SUCCESSFULLY!")
for _repo_line in (
    f"  {HUB}/speculative-proposer-v3-1.7b",
    f"  {HUB}/speculative-verifier-v3-4b",
    f"  {HUB}/speculative-proposer-v3-8b",
):
    print(_repo_line)
print(_rule)