"""Fine-tune Qwen3-1.7B as a speculative-action proposer with LoRA SFT.

Loads the proposer SFT dataset from the Hugging Face Hub, trains LoRA
adapters with TRL's ``SFTTrainer`` and pushes the result to ``OUTPUT``.
"""
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig


HUB_ORG = 'narcolepticchicken'
MODEL = 'Qwen/Qwen3-1.7B'
DATASET = f'{HUB_ORG}/speculative-actions-proposer-sft'
OUTPUT = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'


def main():
    """Load the dataset, configure LoRA + SFT, train, and push to the Hub."""
    print('Loading dataset...')
    ds = load_dataset(DATASET)
    # The held-out split is optional: train without evaluation instead of
    # failing with a KeyError when the dataset has no 'test' split.
    eval_ds = ds.get('test')

    print('Configuring LoRA...')
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
        # NOTE(review): this trains and saves full copies of the embedding
        # and output-head matrices (large adapter, full-rank updates at the
        # LoRA learning rate). Qwen3-1.7B reportedly ties these weights
        # (tie_word_embeddings=True) -- confirm this is intended, e.g.
        # because the dataset introduces new tokens.
        modules_to_save=['embed_tokens', 'lm_head'],
        # Without a task type PEFT wraps the model as a generic PeftModel
        # rather than PeftModelForCausalLM, and the saved adapter config
        # records no task.
        task_type='CAUSAL_LM',
    )

    print('Configuring SFT...')
    config = SFTConfig(
        output_dir='/tmp/proposer-out',
        hub_model_id=OUTPUT,
        push_to_hub=True,
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        # Non-reentrant checkpointing is the variant recommended for PEFT:
        # it does not require the checkpointed inputs to carry gradients.
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # Without an explicit strategy the Trainer defaults to 'no', so a
        # passed eval_dataset would never actually be evaluated.
        eval_strategy='epoch' if eval_ds is not None else 'no',
        per_device_eval_batch_size=4,
        logging_strategy='steps',
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to='trackio',
        run_name='proposer-sft-qwen3-1.7b',
        dataset_text_field='text',
    )

    print('Initializing trainer...')
    trainer = SFTTrainer(
        model=MODEL,
        train_dataset=ds['train'],
        eval_dataset=eval_ds,
        args=config,
        peft_config=peft_config,
    )

    print('Training proposer...')
    trainer.train()
    # Final push of the trained adapter, model card and tokenizer.
    trainer.push_to_hub()
    print('Proposer training complete.')


if __name__ == '__main__':
    main()
|
|