"""
Train Speculative Proposer — v3
===============================

Fine-tunes a Qwen3 model to predict agent action types from conversation context.

Usage:
    python train_sft.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b
    python train_sft.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b
"""

import argparse
import dataclasses

import torch  # noqa: F401  (not referenced directly in this script)
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

HUB_ORG = "narcolepticchicken"
DATASET = f"{HUB_ORG}/speculative-sft-v3-main"


def _max_length_kwarg(max_seq_length):
    """Return the SFTConfig keyword argument that caps the sequence length.

    TRL renamed ``SFTConfig.max_seq_length`` to ``max_length`` (the old name
    was deprecated and later removed), so pick whichever field the installed
    ``SFTConfig`` dataclass actually defines.
    """
    field_names = {field.name for field in dataclasses.fields(SFTConfig)}
    key = "max_length" if "max_length" in field_names else "max_seq_length"
    return {key: max_seq_length}


def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5):
    """Fine-tune ``model_name`` on ``DATASET`` and push the result to the Hub.

    Args:
        model_name: Base model id or path, passed as-is to ``SFTTrainer``.
        hub_model_id: Repository name (without org); pushed to
            ``f"{HUB_ORG}/{hub_model_id}"``.
        max_seq_length: Maximum sequence length handed to ``SFTConfig``.
        lr: Learning rate.

    Returns:
        The dict returned by ``trainer.evaluate()``, or ``None`` when the
        dataset has no ``"test"`` split and evaluation is skipped.
    """
    dataset = load_dataset(DATASET)
    train_dataset = dataset["train"]
    # A missing test split should not abort training; evaluation is skipped instead.
    eval_dataset = dataset.get("test")

    print(f"Loaded dataset: {DATASET}")
    if eval_dataset is not None:
        print(f"Train: {len(train_dataset)} examples, Test: {len(eval_dataset)} examples")
    else:
        print(f"Train: {len(train_dataset)} examples, Test: none (evaluation skipped)")

    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=f"{HUB_ORG}/{hub_model_id}",
        **_max_length_kwarg(max_seq_length),
        packing=False,  # keep each conversation as its own example
        learning_rate=lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    print(f"Training {model_name} -> {HUB_ORG}/{hub_model_id}")
    trainer.train()

    metrics = None
    if eval_dataset is not None:
        # Evaluate BEFORE saving/pushing so the metrics are logged and written
        # to output_dir as part of the upload, not computed after the push.
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        print(f"Eval metrics: {metrics}")

    trainer.save_model()
    trainer.push_to_hub()
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fine-tune a speculative action proposer with TRL SFT."
    )
    parser.add_argument("--model", required=True, help="Base model id or path.")
    parser.add_argument(
        "--hub-id", required=True, help=f"Target repo name under {HUB_ORG}."
    )
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate.")
    parser.add_argument(
        "--max-seq-length", type=int, default=2048, help="Maximum sequence length."
    )
    args = parser.parse_args()

    train(args.model, args.hub_id, max_seq_length=args.max_seq_length, lr=args.lr)