narcolepticchicken committed on
Commit
2fea9f4
·
verified ·
1 Parent(s): e7867a3

Upload train_proposer.py

Browse files
Files changed (1) hide show
  1. train_proposer.py +43 -50
train_proposer.py CHANGED
@@ -1,60 +1,53 @@
1
- """
2
- Train Cheap Proposer (Next-Action Predictor)
3
- =============================================
4
- SFT on Qwen3-1.7B to predict next action type given conversation state.
5
- Dataset: narcolepticchicken/speculative-actions-proposer-sft
6
- """
7
  import torch
8
  from datasets import load_dataset
9
  from trl import SFTTrainer, SFTConfig
10
  from peft import LoraConfig
11
 
12
- MODEL = "Qwen/Qwen3-1.7B"
13
- DATASET = "narcolepticchicken/speculative-actions-proposer-sft"
14
- OUTPUT = "narcolepticchicken/speculative-proposer-qwen3-1.7b"
 
15
 
16
- def main():
17
- ds = load_dataset(DATASET)
18
- train_ds = ds["train"]
19
- eval_ds = ds["test"]
20
 
21
- peft_config = LoraConfig(
22
- r=16,
23
- lora_alpha=32,
24
- target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
25
- modules_to_save=["embed_tokens", "lm_head"],
26
- )
27
 
28
- config = SFTConfig(
29
- output_dir="/tmp/proposer-out",
30
- hub_model_id=OUTPUT,
31
- push_to_hub=True,
32
- learning_rate=2e-4,
33
- per_device_train_batch_size=4,
34
- gradient_accumulation_steps=4,
35
- num_train_epochs=3,
36
- max_seq_length=4096,
37
- bf16=True,
38
- gradient_checkpointing=True,
39
- logging_strategy="steps",
40
- logging_steps=10,
41
- logging_first_step=True,
42
- disable_tqdm=True,
43
- report_to="trackio",
44
- run_name="proposer-sft-qwen3-1.7b",
45
- )
 
46
 
47
- trainer = SFTTrainer(
48
- model=MODEL,
49
- train_dataset=train_ds,
50
- eval_dataset=eval_ds,
51
- args=config,
52
- peft_config=peft_config,
53
- )
 
54
 
55
- trainer.train()
56
- trainer.push_to_hub()
57
- print("Proposer training complete.")
58
-
59
- if __name__ == "__main__":
60
- main()
 
 
 
 
 
 
 
"""
Train Cheap Proposer (Next-Action Predictor)
=============================================
SFT on Qwen3-1.7B to predict the next action type given conversation state.
Dataset: narcolepticchicken/speculative-actions-proposer-sft
"""
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Hub coordinates: base model, SFT dataset, and destination model repo.
HUB_ORG = 'narcolepticchicken'
MODEL = 'Qwen/Qwen3-1.7B'
DATASET = f'{HUB_ORG}/speculative-actions-proposer-sft'
OUTPUT = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'


def main():
    """Run LoRA SFT of the proposer and push the trained adapter to the Hub."""
    print('Loading dataset...')
    ds = load_dataset(DATASET)

    print('Configuring LoRA...')
    # LoRA adapters on the attention projections; embed_tokens/lm_head are
    # trained fully (modules_to_save) so the output head can adapt.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
        modules_to_save=['embed_tokens', 'lm_head'],
    )

    print('Configuring SFT...')
    config = SFTConfig(
        output_dir='/tmp/proposer-out',
        hub_model_id=OUTPUT,
        push_to_hub=True,
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch of 16 per device
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,  # trade compute for activation memory
        logging_strategy='steps',
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to='trackio',
        run_name='proposer-sft-qwen3-1.7b',
        dataset_text_field='text',
    )

    print('Initializing trainer...')
    trainer = SFTTrainer(
        model=MODEL,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        args=config,
        peft_config=peft_config,
    )

    print('Training proposer...')
    trainer.train()
    trainer.push_to_hub()
    print('Proposer training complete.')


# Guard restored (dropped in this refactor): without it, importing this
# module would download the dataset and start training as a side effect.
if __name__ == '__main__':
    main()