File size: 2,126 Bytes
444fd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Train Speculative Proposer — v3
===============================
Fine-tunes a Qwen3 model to predict agent action types
from conversation context.

Usage: python train_sft.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b
       python train_sft.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b
"""

import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Hugging Face Hub namespace that owns both the training dataset and the
# pushed model repos (train() pushes to f"{HUB_ORG}/{hub_model_id}").
HUB_ORG = "narcolepticchicken"
# Hub dataset repo loaded by train(); its "train" split is used for training
# and its "test" split for evaluation.
DATASET = f"{HUB_ORG}/speculative-sft-v3-main"


def _seq_length_kwarg(max_seq_length):
    """Return the SFTConfig keyword argument carrying the sequence-length limit.

    Newer TRL releases renamed ``SFTConfig.max_seq_length`` to ``max_length``
    and later dropped the old name, so passing the old keyword raises
    ``TypeError`` there. Select whichever field the installed TRL defines.
    """
    import dataclasses

    field_names = {f.name for f in dataclasses.fields(SFTConfig)}
    key = "max_length" if "max_length" in field_names else "max_seq_length"
    return {key: max_seq_length}


def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5):
    """Fine-tune ``model_name`` with TRL's SFTTrainer and push it to the Hub.

    Loads ``DATASET`` from the Hugging Face Hub and trains for 3 epochs
    (per-device batch size 4, 4 gradient-accumulation steps, bf16, gradient
    checkpointing, a checkpoint saved each epoch). It then evaluates on the
    "test" split when the dataset has one, and pushes the final model to
    ``f"{HUB_ORG}/{hub_model_id}"``.

    Args:
        model_name: Hub id or local path of the base model, passed directly
            to ``SFTTrainer(model=...)`` (e.g. ``"Qwen/Qwen3-1.7B"``).
        hub_model_id: Repository name, without the org prefix, to push to.
        max_seq_length: Maximum tokenized sequence length used for training.
        lr: Learning rate handed to ``SFTConfig(learning_rate=...)``.

    Returns:
        The metrics dict from ``trainer.evaluate()``, or ``None`` when the
        dataset has no "test" split.
    """
    dataset = load_dataset(DATASET)
    train_split = dataset["train"]
    # A missing held-out split should not abort the run before training starts.
    eval_split = dataset.get("test")
    eval_size = len(eval_split) if eval_split is not None else 0
    print(f"Loaded dataset: {DATASET}")
    print(f"Train: {len(train_split)} examples, Test: {eval_size} examples")

    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=f"{HUB_ORG}/{hub_model_id}",
        packing=False,  # conversational format
        learning_rate=lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
        **_seq_length_kwarg(max_seq_length),
    )

    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
    )

    print(f"Training {model_name} -> {HUB_ORG}/{hub_model_id}")
    trainer.train()

    # Evaluate before pushing. This puts the eval metrics into the trainer's
    # log history (which the auto-generated model card draws on), and keeps
    # them even if the upload fails.
    metrics = None
    if eval_split is not None:
        metrics = trainer.evaluate()
        print(f"Eval metrics: {metrics}")

    trainer.save_model()
    trainer.push_to_hub()
    return metrics


if __name__ == "__main__":
    import argparse

    # Reuse the module docstring (which contains usage examples) as --help text.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--model",
        required=True,
        help="Base model Hub id or local path, e.g. Qwen/Qwen3-1.7B",
    )
    parser.add_argument(
        "--hub-id",
        required=True,
        help=f"Target model repo name; pushed as {HUB_ORG}/<hub-id>",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum tokenized sequence length (default: %(default)s)",
    )
    args = parser.parse_args()
    # Forward by keyword so a change in train()'s parameter order cannot
    # silently swap values.
    train(
        model_name=args.model,
        hub_model_id=args.hub_id,
        max_seq_length=args.max_seq_length,
        lr=args.lr,
    )