File size: 2,126 Bytes
"""
Train Speculative Proposer — v3
===============================

Fine-tunes a Qwen3 model to predict agent action types
from conversation context.

Usage:
    python train_sft.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b
    python train_sft.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b
"""
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
# Hub namespace shared by the training dataset id and the repo id that
# the fine-tuned model is pushed to.
HUB_ORG = "narcolepticchicken"
# Fully qualified dataset id, "<org>/<dataset-name>".
DATASET = "/".join((HUB_ORG, "speculative-sft-v3-main"))
def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5):
    """Fine-tune ``model_name`` on the SFT dataset, push it, and evaluate it.

    Loads ``DATASET``, runs supervised fine-tuning with ``SFTTrainer`` on its
    ``train`` split, saves and pushes the result to
    ``{HUB_ORG}/{hub_model_id}``, then evaluates on the ``test`` split.

    Args:
        model_name: Model identifier handed to ``SFTTrainer`` as ``model``
            (e.g. ``"Qwen/Qwen3-1.7B"``).
        hub_model_id: Repository name (without the organisation prefix) that
            the trained model is pushed to.
        max_seq_length: Value forwarded to ``SFTConfig(max_seq_length=...)``.
        lr: Learning rate forwarded to ``SFTConfig(learning_rate=...)``.

    Returns:
        The metrics mapping produced by ``trainer.evaluate()``. It is also
        printed, so command-line use behaves as before.

    Raises:
        KeyError: If the loaded dataset lacks a ``train`` or ``test`` split.
    """
    dataset = load_dataset(DATASET)
    print(f"Loaded dataset: {DATASET}")
    print(f"Train: {len(dataset['train'])} examples, Test: {len(dataset['test'])} examples")

    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=f"{HUB_ORG}/{hub_model_id}",
        max_seq_length=max_seq_length,
        packing=False,  # conversational format
        learning_rate=lr,
        # 4 sequences per device x 4 accumulation steps per optimizer update.
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    print(f"Training {model_name} -> {HUB_ORG}/{hub_model_id}")
    trainer.train()
    trainer.save_model()
    trainer.push_to_hub()

    # Eval
    metrics = trainer.evaluate()
    print(f"Eval metrics: {metrics}")
    # Hand the metrics back so callers can use them, not just read the log.
    return metrics
if __name__ == "__main__":
    import argparse

    # Command-line entry point: --model and --hub-id are mandatory, the
    # hyper-parameter flags fall back to the same defaults as train().
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True)
    cli.add_argument("--hub-id", required=True)
    cli.add_argument("--lr", type=float, default=2e-5)
    cli.add_argument("--max-seq-length", type=int, default=2048)
    opts = cli.parse_args()

    train(
        model_name=opts.model,
        hub_model_id=opts.hub_id,
        max_seq_length=opts.max_seq_length,
        lr=opts.lr,
    )
|