narcolepticchicken committed on
Commit
444fd8b
·
verified ·
1 Parent(s): 26e44c5

Upload train_sft_v3.py

Browse files
Files changed (1) hide show
  1. train_sft_v3.py +70 -0
train_sft_v3.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Speculative Proposer — v3
3
+ ===============================
4
+ Fine-tunes a Qwen3 model to predict agent action types
5
+ from conversation context.
6
+
7
+ Usage: python train_sft_v3.py --model Qwen/Qwen3-1.7B --hub-id speculative-proposer-v3-1.7b
8
+ python train_sft_v3.py --model Qwen/Qwen3-8B --hub-id speculative-proposer-v3-8b
9
+ """
10
+
11
+ import torch
12
+ from datasets import load_dataset
13
+ from trl import SFTConfig, SFTTrainer
14
+
15
+ HUB_ORG = "narcolepticchicken"
16
+ DATASET = f"{HUB_ORG}/speculative-sft-v3-main"
17
+
18
+
19
def train(model_name, hub_model_id, max_seq_length=2048, lr=2e-5):
    """Fine-tune ``model_name`` on the SFT dataset, push it to the Hub, and evaluate.

    Args:
        model_name: Model identifier passed to ``SFTTrainer`` as ``model``.
        hub_model_id: Target repository. A bare name is placed under
            ``HUB_ORG``; a value that already contains ``/`` is used unchanged.
        max_seq_length: Maximum sequence length handed to ``SFTConfig``.
        lr: Learning rate handed to ``SFTConfig``.

    Returns:
        The metrics object returned by ``trainer.evaluate()`` on the ``test``
        split, so callers can use it programmatically instead of parsing stdout.
    """
    import inspect  # local import: only needed for the SFTConfig signature probe

    dataset = load_dataset(DATASET)
    print(f"Loaded dataset: {DATASET}")
    print(f"Train: {len(dataset['train'])} examples, Test: {len(dataset['test'])} examples")

    # Do not prepend the organisation twice when the caller already supplied a
    # fully qualified "org/name" id.
    if "/" in hub_model_id:
        hub_repo = hub_model_id
    else:
        hub_repo = f"{HUB_ORG}/{hub_model_id}"

    # NOTE(review): TRL renamed this SFTConfig option from `max_seq_length` to
    # `max_length` in later releases -- confirm against the installed version.
    # Probing the signature keeps the script working with either name.
    config_params = inspect.signature(SFTConfig).parameters
    length_kwarg = "max_length" if "max_length" in config_params else "max_seq_length"

    training_args = SFTConfig(
        output_dir="./output",
        hub_model_id=hub_repo,
        packing=False,  # conversational format
        learning_rate=lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=5,
        logging_first_step=True,
        save_strategy="epoch",
        push_to_hub=True,
        disable_tqdm=True,
        dataloader_num_workers=2,
        report_to="none",
        **{length_kwarg: max_seq_length},
    )

    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    print(f"Training {model_name} -> {hub_repo}")
    trainer.train()

    # Persist and upload before evaluating, so a failure during evaluation
    # cannot lose the trained weights.
    trainer.save_model()
    trainer.push_to_hub()

    # Eval
    metrics = trainer.evaluate()
    print(f"Eval metrics: {metrics}")
    return metrics
60
+
61
+
62
if __name__ == "__main__":
    import argparse

    # Command-line entry point: gather the run settings, then launch training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True)
    cli.add_argument("--hub-id", required=True)
    cli.add_argument("--lr", type=float, default=2e-5)
    cli.add_argument("--max-seq-length", type=int, default=2048)
    opts = cli.parse_args()

    # Keyword arguments make the flag-to-parameter mapping explicit.
    train(
        model_name=opts.model,
        hub_model_id=opts.hub_id,
        max_seq_length=opts.max_seq_length,
        lr=opts.lr,
    )