"""Fine-tune a Qwen3-1.7B "proposer" model with LoRA via TRL's SFTTrainer.

Loads the speculative-actions SFT dataset from the Hub, attaches a LoRA
adapter to the attention projections (with the embedding and LM-head layers
trained in full), runs supervised fine-tuning, and pushes the result back
to the Hugging Face Hub.
"""

import torch  # noqa: F401 — unused directly; kept so a broken torch install fails fast, before any download

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

HUB_ORG = 'narcolepticchicken'
MODEL = 'Qwen/Qwen3-1.7B'
DATASET = f'{HUB_ORG}/speculative-actions-proposer-sft'
OUTPUT = f'{HUB_ORG}/speculative-proposer-qwen3-1.7b'


def main() -> None:
    """Run the full SFT pipeline: load data, configure LoRA + trainer, train, push."""
    print('Loading dataset...')
    ds = load_dataset(DATASET)

    print('Configuring LoRA...')
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
        # Embedding and output layers are fully trained and saved, not LoRA-adapted.
        modules_to_save=['embed_tokens', 'lm_head'],
    )

    print('Configuring SFT...')
    config = SFTConfig(
        output_dir='/tmp/proposer-out',
        hub_model_id=OUTPUT,
        push_to_hub=True,
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective per-device batch of 16
        num_train_epochs=2,
        bf16=True,
        gradient_checkpointing=True,
        # FIX: an eval_dataset was passed to the trainer, but the default
        # eval_strategy is 'no', so evaluation never actually ran.
        eval_strategy='epoch',
        logging_strategy='steps',
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to='trackio',
        run_name='proposer-sft-qwen3-1.7b',
        dataset_text_field='text',
    )

    print('Initializing trainer...')
    trainer = SFTTrainer(
        model=MODEL,  # model id string: SFTTrainer loads the checkpoint itself
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        args=config,
        peft_config=peft_config,
    )

    print('Training proposer...')
    trainer.train()
    trainer.push_to_hub()
    print('Proposer training complete.')


if __name__ == '__main__':
    main()