narcolepticchicken committed on
Commit
cb2bd28
·
verified ·
1 Parent(s): 2fea9f4

Upload train_verifier.py

Browse files
Files changed (1) hide show
  1. train_verifier.py +41 -50
train_verifier.py CHANGED
@@ -1,60 +1,51 @@
1
- """
2
- Train Verifier / Judge (Outcome Reward Model)
3
- ===============================================
4
- RewardTrainer on Qwen3-4B using preference pairs.
5
- Dataset: narcolepticchicken/speculative-actions-verifier-pref
6
- """
7
  import torch
8
  from datasets import load_dataset
9
  from trl import RewardTrainer, RewardConfig
10
  from peft import LoraConfig
11
 
12
- MODEL = "Qwen/Qwen3-4B"
13
- DATASET = "narcolepticchicken/speculative-actions-verifier-pref"
14
- OUTPUT = "narcolepticchicken/speculative-verifier-qwen3-4b"
 
15
 
16
- def main():
17
- ds = load_dataset(DATASET)
18
- train_ds = ds["train"]
19
- eval_ds = ds["test"]
20
 
21
- peft_config = LoraConfig(
22
- r=16,
23
- lora_alpha=32,
24
- target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
25
- modules_to_save=["score"],
26
- )
27
 
28
- config = RewardConfig(
29
- output_dir="/tmp/verifier-out",
30
- hub_model_id=OUTPUT,
31
- push_to_hub=True,
32
- learning_rate=1e-3,
33
- per_device_train_batch_size=2,
34
- gradient_accumulation_steps=8,
35
- num_train_epochs=2,
36
- max_seq_length=4096,
37
- bf16=True,
38
- gradient_checkpointing=True,
39
- logging_strategy="steps",
40
- logging_steps=10,
41
- logging_first_step=True,
42
- disable_tqdm=True,
43
- report_to="trackio",
44
- run_name="verifier-reward-qwen3-4b",
45
- )
46
 
47
- trainer = RewardTrainer(
48
- model=MODEL,
49
- train_dataset=train_ds,
50
- eval_dataset=eval_ds,
51
- args=config,
52
- peft_config=peft_config,
53
- )
 
54
 
55
- trainer.train()
56
- trainer.push_to_hub()
57
- print("Verifier training complete.")
58
-
59
- if __name__ == "__main__":
60
- main()
 
 
 
 
 
 
 
1
  import torch
2
  from datasets import load_dataset
3
  from trl import RewardTrainer, RewardConfig
4
  from peft import LoraConfig
5
 
6
"""
Train Verifier / Judge (Outcome Reward Model)
=============================================
RewardTrainer on Qwen3-4B using preference pairs.
Dataset: narcolepticchicken/speculative-actions-verifier-pref
"""

HUB_ORG = 'narcolepticchicken'
MODEL = 'Qwen/Qwen3-4B'
DATASET = f'{HUB_ORG}/speculative-actions-verifier-pref'
OUTPUT = f'{HUB_ORG}/speculative-verifier-qwen3-4b'

print('Loading dataset...')
# Expects a preference-pair dataset with 'train' and 'test' splits
# (chosen/rejected columns, per TRL RewardTrainer convention).
ds = load_dataset(DATASET)

print('Configuring LoRA...')
peft_config = LoraConfig(
    r=16, lora_alpha=32,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    # FIX: the reward model's classification head ('score') is freshly
    # initialized, so it must be fully trained and saved alongside the
    # LoRA adapters. Without this (it was dropped relative to the previous
    # version of this script) the pushed adapter ships a random, untrained
    # reward head.
    modules_to_save=['score'],
)

print('Configuring Reward Training...')
config = RewardConfig(
    output_dir='/tmp/verifier-out',
    hub_model_id=OUTPUT,
    push_to_hub=True,
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size 16 per device
    num_train_epochs=2,
    bf16=True,
    gradient_checkpointing=True,  # trade compute for memory on a 4B model
    logging_strategy='steps',
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to='trackio',
    run_name='verifier-reward-qwen3-4b',
)

print('Initializing Reward Trainer...')
# Passing the model id as a string lets RewardTrainer load it (and attach a
# sequence-classification head) itself; PEFT wrapping happens via peft_config.
trainer = RewardTrainer(
    model=MODEL,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    args=config,
    peft_config=peft_config,
)

print('Training verifier...')
trainer.train()
trainer.push_to_hub()
print('Verifier training complete.')