narcolepticchicken committed on
Commit
cf57590
·
verified ·
1 Parent(s): b5986fe

Add proposer training job script

Browse files
Files changed (1) hide show
  1. train_proposer_job.py +58 -0
train_proposer_job.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
HF Jobs script: Train cheap proposer (Qwen3-1.7B) for next-action prediction.
Uses SFTTrainer with LoRA on speculative-actions-proposer-sft dataset.
"""
import dataclasses

import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

MODEL = "Qwen/Qwen3-1.7B"
DATASET = "narcolepticchicken/speculative-actions-proposer-sft"
OUTPUT = "narcolepticchicken/speculative-proposer-qwen3-1.7b"


def _sequence_length_kwarg(limit):
    """Return the SFTConfig keyword that caps the sequence length at *limit*.

    TRL renamed ``max_seq_length`` to ``max_length``; passing the wrong one to
    a given release raises ``TypeError``. Inspecting the dataclass fields of
    the installed ``SFTConfig`` keeps the script working on either side of
    the rename.
    """
    field_names = {field.name for field in dataclasses.fields(SFTConfig)}
    key = "max_length" if "max_length" in field_names else "max_seq_length"
    return {key: limit}


def main():
    """Fine-tune MODEL with LoRA on DATASET and push the result to OUTPUT."""
    ds = load_dataset(DATASET)
    train_ds = ds["train"]
    eval_ds = ds["test"]

    # NOTE(review): modules_to_save marks these modules as fully trainable in
    # addition to the LoRA adapters on the attention projections — confirm
    # that training the embeddings and LM head is intended for this job.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        modules_to_save=["embed_tokens", "lm_head"],
    )

    # NOTE(review): no eval_strategy is set, so with the library default the
    # eval split is presumably never evaluated during training — confirm
    # whether periodic evaluation is wanted.
    config = SFTConfig(
        output_dir="/tmp/proposer-out",
        hub_model_id=OUTPUT,
        push_to_hub=True,
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        bf16=True,
        gradient_checkpointing=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="trackio",
        run_name="proposer-sft-qwen3-1.7b",
        # Version-appropriate spelling of the 4096-token sequence cap.
        **_sequence_length_kwarg(4096),
    )

    trainer = SFTTrainer(
        model=MODEL,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=config,
        peft_config=peft_config,
    )

    trainer.train()
    trainer.push_to_hub()
    print("Proposer training complete.")


if __name__ == "__main__":
    main()