| """ |
| Train Speculative Proposer — v3 |
| =============================== |
| Fine-tunes a Qwen3 model to predict agent action types |
| from conversation context. |
| |
| Usage: python train_sft.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b |
| python train_sft.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b |
| """ |
|
|
| import torch |
| from datasets import load_dataset |
| from trl import SFTConfig, SFTTrainer |
|
|
| HUB_ORG = "narcolepticchicken" |
| DATASET = f"{HUB_ORG}/speculative-sft-v3-main" |
|
|
|
|
def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5,
          output_dir="./output"):
    """Fine-tune ``model_name`` with TRL's ``SFTTrainer`` and push it to the Hub.

    Loads the ``train`` and ``test`` splits of ``DATASET``, trains for three
    epochs, saves the model locally, pushes it to
    ``{HUB_ORG}/{hub_model_id}``, and finally evaluates on the ``test`` split.

    Args:
        model_name: Model identifier handed to ``SFTTrainer`` as ``model``
            (e.g. ``"Qwen/Qwen3-1.7B"``).
        hub_model_id: Repository name (without the organisation prefix) that
            the trained model is pushed to.
        max_seq_length: Maximum sequence length given to ``SFTConfig``.
        lr: Learning rate given to ``SFTConfig``.
        output_dir: Local directory for checkpoints and the final model.
            Defaults to the previously hard-coded ``"./output"``. Pass a
            distinct directory per model when training several models on the
            same machine so their saved files do not end up mixed together.

    Returns:
        Whatever ``trainer.evaluate()`` returns for the ``test`` split.
    """
    # Local import keeps this function self-contained; only needed to inspect
    # which fields the installed SFTConfig declares.
    from dataclasses import fields

    dataset = load_dataset(DATASET)
    print(f"Loaded dataset: {DATASET}")
    print(f"Train: {len(dataset['train'])} examples, Test: {len(dataset['test'])} examples")

    # TRL renamed SFTConfig's ``max_seq_length`` to ``max_length`` and later
    # dropped the old name. Use whichever field this installation declares so
    # the script works on both older and newer TRL releases.
    config_fields = {f.name for f in fields(SFTConfig)}
    if "max_length" in config_fields:
        length_kwarg = {"max_length": max_seq_length}
    else:
        length_kwarg = {"max_seq_length": max_seq_length}

    training_args = SFTConfig(
        output_dir=output_dir,
        hub_model_id=f"{HUB_ORG}/{hub_model_id}",
        packing=False,
        learning_rate=lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
        **length_kwarg,
    )

    # The model is passed by name; SFTTrainer is left to load it.
    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    print(f"Training {model_name} -> {HUB_ORG}/{hub_model_id}")
    trainer.train()

    # Persist and publish first, so a failure during evaluation cannot lose
    # the trained weights.
    trainer.save_model()
    trainer.push_to_hub()

    metrics = trainer.evaluate()
    print(f"Eval metrics: {metrics}")
    return metrics
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the run settings from the command line and
    # hand them to train().
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--model", required=True)
    cli.add_argument("--hub-id", required=True)
    cli.add_argument("--lr", type=float, default=2e-5)
    cli.add_argument("--max-seq-length", type=int, default=2048)
    opts = cli.parse_args()

    train(
        model_name=opts.model,
        hub_model_id=opts.hub_id,
        max_seq_length=opts.max_seq_length,
        lr=opts.lr,
    )
|
|