# speculative-tool-actions / train_proposer.py
# Uploaded by narcolepticchicken ("Upload train_proposer.py", commit 2fea9f4).
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
# Hub namespace that owns both the SFT dataset and the published model repo.
HUB_ORG = "narcolepticchicken"
# Base checkpoint the LoRA adapter is trained on top of.
MODEL = "Qwen/Qwen3-1.7B"
# Input dataset repo id and output model repo id, both under HUB_ORG.
DATASET = f"{HUB_ORG}/speculative-actions-proposer-sft"
OUTPUT = f"{HUB_ORG}/speculative-proposer-qwen3-1.7b"

print("Loading dataset...")
# Loads all splits of the dataset repo; later code indexes splits by name.
ds = load_dataset(path=DATASET)
print('Configuring LoRA...')
# LoRA adapters (rank 16, alpha 32) on all four attention projections.
# task_type='CAUSAL_LM' makes PEFT wrap the model as a causal-LM PeftModel,
# which is what TRL's SFT examples use.
#
# modules_to_save=['embed_tokens', 'lm_head'] was removed. That setting trains
# and saves FULL copies of both vocab-sized matrices at the LoRA learning rate.
# That is only needed when new tokens are added to the tokenizer, and this
# script adds none. Qwen3-1.7B is believed to tie these two weights (check its
# config.json), and separate trainable copies would break that tying.
# Re-add them only if the tokenizer is extended.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    task_type='CAUSAL_LM',
)
# Held-out split, if the dataset provides one. DatasetDict is a dict, so .get()
# returns None instead of raising KeyError when there is no 'test' split.
eval_ds = ds.get('test')

print('Configuring SFT...')
config = SFTConfig(
    output_dir='/tmp/proposer-out',
    hub_model_id=OUTPUT,
    push_to_hub=True,
    learning_rate=2e-4,
    # Effective batch per device = 4 * 4 = 16 sequences per optimizer step.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    bf16=True,
    gradient_checkpointing=True,
    logging_strategy='steps',
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to='trackio',
    run_name='proposer-sft-qwen3-1.7b',
    dataset_text_field='text',
    # Without an explicit eval_strategy the default is 'no', and the eval split
    # passed to the trainer would never be evaluated. Evaluate once per epoch
    # when a held-out split exists. Trainer rejects a non-'no' strategy with
    # no eval_dataset, so fall back to 'no' in that case.
    eval_strategy='epoch' if eval_ds is not None else 'no',
)

print('Initializing trainer...')
trainer = SFTTrainer(
    model=MODEL,
    train_dataset=ds['train'],
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)
# Run the fine-tuning loop, then publish the result to hub_model_id.
# Per the transformers Trainer docs, push_to_hub() saves the model to
# output_dir before uploading it.
print('Training proposer...')
train_output = trainer.train()
trainer.push_to_hub()
print('Proposer training complete.')