"""Consolidated Training: all 3 speculative-action models in one A100 job.
Trains sequentially with explicit cleanup between models:
1. Qwen3-1.7B proposer (SFT on action prediction)
2. Qwen3-4B verifier (SFT on ACCEPT/REJECT)
3. Qwen3-8B proposer (SFT on action prediction)
Each model is: loaded → trained → saved+pushed → evaluated → deleted → CUDA cache cleared.
Requires: transformers>=4.51, trl, torch, datasets, accelerate, peft, huggingface_hub
"""
import gc
import torch
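# Pre-flight sanity checks (a minimal sketch; assumes HF_TOKEN is set or
# `huggingface-cli login` was run beforehand, since push_to_hub=True would
# otherwise fail hours into the job).
from huggingface_hub import whoami

assert torch.cuda.is_available(), "No CUDA device visible; this job expects an A100."
whoami()  # raises if no Hub token is configured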
HUB = "narcolepticchicken"
# ═══════════════════════════════════════════════════════════════
# 1. TRAIN 1.7B PROPOSER
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 60)
print("1/3: TRAINING 1.7B PROPOSER")
print("=" * 60)
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
sft_ds = load_dataset(f"{HUB}/speculative-sft-v3-main")
print(f"SFT data: {len(sft_ds['train'])} train, {len(sft_ds['test'])} test")
args_1b = SFTConfig(
output_dir="./out_1b",
hub_model_id=f"{HUB}/speculative-proposer-v3-1.7b",
max_length=2048,
packing=False,
learning_rate=2e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=3,
bf16=True,
gradient_checkpointing=True,
logging_steps=5,
logging_first_step=True,
save_strategy="epoch",
push_to_hub=True,
disable_tqdm=True,
report_to="none",
)
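# SFTTrainer accepts a model id string and instantiates the checkpoint itself
# (recent TRL versions); precision and checkpointing come from the SFTConfig above.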
trainer = SFTTrainer(
model="Qwen/Qwen3-1.7B",
args=args_1b,
train_dataset=sft_ds["train"],
eval_dataset=sft_ds["test"],
)
trainer.train()
trainer.save_model()
trainer.push_to_hub()
metrics = trainer.evaluate()
print(f"1.7B eval loss: {metrics.get('eval_loss', 'N/A')}")
del trainer
gc.collect()
torch.cuda.empty_cache()
print("1.7B proposer βœ“")
# ═══════════════════════════════════════════════════════════════
# 2. TRAIN 4B VERIFIER
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 60)
print("2/3: TRAINING 4B VERIFIER")
print("=" * 60)
from datasets import Dataset
VERIFIER_SYSTEM = (
"You are an action verifier. Given conversation context and a proposed next action, "
"determine if the proposal is correct. Respond with exactly ACCEPT or REJECT."
)
verif_raw = load_dataset(f"{HUB}/speculative-verifier-v3-main")
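# Flatten (context, proposal, label) triples into chat-format SFT rows:
# system prompt + last 6 context turns (each truncated to 400 chars) +
# the ACCEPT/REJECT question, with the gold answer as the assistant turn.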
sft_rows = []
for split_name in ["train", "test"]:
for row in verif_raw[split_name]:
ctx = row["context"]
proposal = row["proposal"]
label = row["label"]
answer = "ACCEPT" if label == 1 else "REJECT"
msgs = [{"role": "system", "content": VERIFIER_SYSTEM}]
for m in ctx[-6:]:
msgs.append({"role": m["role"], "content": str(m["content"])[:400]})
msgs.append({
"role": "user",
"content": f"Proposed next action: {proposal}\n\nIs this the correct next action? ACCEPT or REJECT?"
})
msgs.append({"role": "assistant", "content": answer})
sft_rows.append({"messages": msgs, "split": split_name})
v_train = Dataset.from_list([{"messages": r["messages"]} for r in sft_rows if r["split"] == "train"])
v_test = Dataset.from_list([{"messages": r["messages"]} for r in sft_rows if r["split"] == "test"])
print(f"Verifier SFT: {len(v_train)} train, {len(v_test)} test")
args_4b = SFTConfig(
output_dir="./out_4b",
hub_model_id=f"{HUB}/speculative-verifier-v3-4b",
max_length=2048,
packing=False,
learning_rate=2e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=2,
bf16=True,
gradient_checkpointing=True,
logging_steps=10,
logging_first_step=True,
save_strategy="epoch",
push_to_hub=True,
disable_tqdm=True,
report_to="none",
)
trainer = SFTTrainer(
model="Qwen/Qwen3-4B",
args=args_4b,
train_dataset=v_train,
eval_dataset=v_test,
)
trainer.train()
trainer.save_model()
trainer.push_to_hub()
metrics = trainer.evaluate()
print(f"4B verifier eval loss: {metrics.get('eval_loss', 'N/A')}")
del trainer
gc.collect()
torch.cuda.empty_cache()
print("4B verifier βœ“")
# ═══════════════════════════════════════════════════════════════
# 3. TRAIN 8B PROPOSER
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 60)
print("3/3: TRAINING 8B PROPOSER")
print("=" * 60)
args_8b = SFTConfig(
output_dir="./out_8b",
hub_model_id=f"{HUB}/speculative-proposer-v3-8b",
max_length=2048,
packing=False,
learning_rate=2e-5,
per_device_train_batch_size=2, # smaller batch to fit 8B on 80GB
gradient_accumulation_steps=8, # effective batch = 16 (same as others)
num_train_epochs=3,
bf16=True,
gradient_checkpointing=True,
logging_steps=5,
logging_first_step=True,
save_strategy="epoch",
push_to_hub=True,
disable_tqdm=True,
report_to="none",
)
trainer = SFTTrainer(
model="Qwen/Qwen3-8B",
args=args_8b,
train_dataset=sft_ds["train"],
eval_dataset=sft_ds["test"],
)
trainer.train()
trainer.save_model()
trainer.push_to_hub()
metrics = trainer.evaluate()
print(f"8B eval loss: {metrics.get('eval_loss', 'N/A')}")
print("\n" + "=" * 60)
print("ALL THREE MODELS TRAINED SUCCESSFULLY!")
print(f" {HUB}/speculative-proposer-v3-1.7b")
print(f" {HUB}/speculative-verifier-v3-4b")
print(f" {HUB}/speculative-proposer-v3-8b")
print("=" * 60)