narcolepticchicken commited on
Commit
77c676b
·
verified ·
1 Parent(s): 2aced14

Add proposer training script

Browse files
Files changed (1) hide show
  1. train_proposer.py +60 -0
train_proposer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Cheap Proposer (Next-Action Predictor)
3
+ =============================================
4
+ SFT on Qwen3-1.7B to predict next action type given conversation state.
5
+ Dataset: narcolepticchicken/speculative-actions-proposer-sft
6
+ """
7
+ import torch
8
+ from datasets import load_dataset
9
+ from trl import SFTTrainer, SFTConfig
10
+ from peft import LoraConfig
11
+
12
+ MODEL = "Qwen/Qwen3-1.7B"
13
+ DATASET = "narcolepticchicken/speculative-actions-proposer-sft"
14
+ OUTPUT = "narcolepticchicken/speculative-proposer-qwen3-1.7b"
15
+
16
+ def main():
17
+ ds = load_dataset(DATASET)
18
+ train_ds = ds["train"]
19
+ eval_ds = ds["test"]
20
+
21
+ peft_config = LoraConfig(
22
+ r=16,
23
+ lora_alpha=32,
24
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
25
+ modules_to_save=["embed_tokens", "lm_head"],
26
+ )
27
+
28
+ config = SFTConfig(
29
+ output_dir="/tmp/proposer-out",
30
+ hub_model_id=OUTPUT,
31
+ push_to_hub=True,
32
+ learning_rate=2e-4,
33
+ per_device_train_batch_size=4,
34
+ gradient_accumulation_steps=4,
35
+ num_train_epochs=3,
36
+ max_seq_length=4096,
37
+ bf16=True,
38
+ gradient_checkpointing=True,
39
+ logging_strategy="steps",
40
+ logging_steps=10,
41
+ logging_first_step=True,
42
+ disable_tqdm=True,
43
+ report_to="trackio",
44
+ run_name="proposer-sft-qwen3-1.7b",
45
+ )
46
+
47
+ trainer = SFTTrainer(
48
+ model=MODEL,
49
+ train_dataset=train_ds,
50
+ eval_dataset=eval_ds,
51
+ args=config,
52
+ peft_config=peft_config,
53
+ )
54
+
55
+ trainer.train()
56
+ trainer.push_to_hub()
57
+ print("Proposer training complete.")
58
+
59
+ if __name__ == "__main__":
60
+ main()