# speculative-tool-actions / train_verifier.py
# Uploaded by narcolepticchicken ("Upload train_verifier.py", commit cb2bd28, verified).
import torch
from datasets import load_dataset
from trl import RewardTrainer, RewardConfig
from peft import LoraConfig
# Hub namespace that owns both the input dataset and the trained model repo.
HUB_ORG: str = "narcolepticchicken"
# Base checkpoint that is fine-tuned into the verifier/reward model.
MODEL: str = "Qwen/Qwen3-4B"
# Input dataset repo and output model repo, both under HUB_ORG.
DATASET: str = f"{HUB_ORG}/speculative-actions-verifier-pref"
OUTPUT: str = f"{HUB_ORG}/speculative-verifier-qwen3-4b"
# Download the dataset from the Hub; the trainer further down indexes
# both ds['train'] and ds['test'], so both splits must exist.
print('Loading dataset...')
ds = load_dataset(path=DATASET)
print('Configuring LoRA...')
# LoRA adapters on all four attention projections of the base model.
#
# RewardTrainer trains a sequence-classification model whose reward head
# (named `score` in Qwen3) is newly initialised. A PEFT wrap freezes every
# non-LoRA parameter, so the head must be listed explicitly:
#   * task_type='SEQ_CLS' makes PEFT wrap the model as a classifier, and
#   * modules_to_save=['score'] keeps the reward head trainable and saves it
#     alongside the adapter.
# Being explicit avoids relying on version-specific library defaults. It also
# prevents a frozen, randomly initialised head that would make the rewards
# meaningless.
peft_config = LoraConfig(
    task_type='SEQ_CLS',
    r=16,
    lora_alpha=32,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    modules_to_save=['score'],
)
print('Configuring Reward Training...')
# Training hyperparameters. Effective batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 4 = 16.
config = RewardConfig(
    output_dir='/tmp/verifier-out',
    # With push_to_hub=True the model is pushed to this repo whenever it is saved.
    hub_model_id=OUTPUT,
    push_to_hub=True,
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    bf16=True,
    gradient_checkpointing=True,
    # The trainer receives ds['test'] as eval_dataset. The default
    # eval_strategy is 'no', which would leave that split unused, so run an
    # evaluation pass at the end of every epoch.
    eval_strategy='epoch',
    logging_strategy='steps',
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to='trackio',
    run_name='verifier-reward-qwen3-4b',
)
print('Initializing Reward Trainer...')
# Combine the base model id, hyperparameters, LoRA config and dataset
# splits into a single RewardTrainer.
trainer = RewardTrainer(
    model=MODEL,
    args=config,
    peft_config=peft_config,
    # Data: train on the 'train' split; 'test' is the held-out split.
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
)

print('Training verifier...')
trainer.train()

# Final upload to the hub_model_id configured in RewardConfig.
trainer.push_to_hub()
print('Verifier training complete.')