# speculative-tool-actions / train_sft_v3.py
# Uploaded by narcolepticchicken (commit 444fd8b, verified)
"""
Train Speculative Proposer — v3
===============================
Fine-tunes a Qwen3 model to predict agent action types
from conversation context.
Usage: python train_sft_v3.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b
       python train_sft_v3.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b
"""
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
HUB_ORG = "narcolepticchicken"
DATASET = f"{HUB_ORG}/speculative-sft-v3-main"
def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5):
    """Fine-tune *model_name* on the speculative-SFT dataset and push it to the Hub.

    Args:
        model_name: Hub id or local path of the base model
            (e.g. ``"Qwen/Qwen3-1.7B"``). Passed as a string to ``SFTTrainer``,
            which loads the model/tokenizer itself.
        hub_model_id: Repository name (without the org prefix) that the
            fine-tuned model is pushed to under ``HUB_ORG``.
        max_seq_length: Maximum tokenized sequence length.
        lr: Peak learning rate.

    Returns:
        dict: Metrics from the final ``trainer.evaluate()`` call on the
        ``test`` split (previously computed but discarded).
    """
    # Dataset is expected to have "train" and "test" splits in
    # conversational format (hence packing=False below).
    dataset = load_dataset(DATASET)
    print(f"Loaded dataset: {DATASET}")
    print(f"Train: {len(dataset['train'])} examples, Test: {len(dataset['test'])} examples")

    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=f"{HUB_ORG}/{hub_model_id}",
        max_seq_length=max_seq_length,
        packing=False,  # conversational format
        learning_rate=lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch size 16 per device
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,  # trade compute for activation memory
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    print(f"Training {model_name} -> {HUB_ORG}/{hub_model_id}")
    trainer.train()
    trainer.save_model()
    # Explicit final push (also uploads the model card) on top of the
    # per-epoch pushes triggered by push_to_hub=True above.
    trainer.push_to_hub()

    # Eval on the held-out split; return the metrics so callers can use them.
    metrics = trainer.evaluate()
    print(f"Eval metrics: {metrics}")
    return metrics
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--hub-id", required=True)
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--max-seq-length", type=int, default=2048)
args = parser.parse_args()
train(args.model, args.hub_id, args.max_seq_length, args.lr)