File size: 1,317 Bytes
77c676b
 
 
 
 
2fea9f4
 
 
 
77c676b
2fea9f4
 
77c676b
2fea9f4
 
 
 
 
 
77c676b
2fea9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c676b
2fea9f4
 
 
 
 
 
 
 
77c676b
2fea9f4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Hub namespace that owns both the training dataset and the output repo.
HUB_ORG = 'narcolepticchicken'
# Base model to fine-tune.
MODEL = 'Qwen/Qwen3-1.7B'
# Source SFT dataset and destination model repo, both under HUB_ORG.
DATASET = HUB_ORG + '/speculative-actions-proposer-sft'
OUTPUT = HUB_ORG + '/speculative-proposer-qwen3-1.7b'

print('Loading dataset...')
ds = load_dataset(DATASET)

print('Configuring LoRA...')
# Rank-16 adapters on the four attention projections; embed_tokens and
# lm_head are listed in modules_to_save so full copies of those modules
# are kept with the adapter checkpoint.
_lora_kwargs = {
    'r': 16,
    'lora_alpha': 32,
    'target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    'modules_to_save': ['embed_tokens', 'lm_head'],
}
peft_config = LoraConfig(**_lora_kwargs)

print('Configuring SFT...')
# Training arguments, grouped by concern and unpacked into SFTConfig.
_sft_kwargs = {
    # Output location and Hub publishing target.
    'output_dir': '/tmp/proposer-out',
    'hub_model_id': OUTPUT,
    'push_to_hub': True,
    # Optimization: effective batch size is 4 * 4 = 16 per device.
    'learning_rate': 2e-4,
    'per_device_train_batch_size': 4,
    'gradient_accumulation_steps': 4,
    'num_train_epochs': 2,
    'bf16': True,
    'gradient_checkpointing': True,
    # Logging: step-based, quiet console, reported to trackio.
    'logging_strategy': 'steps',
    'logging_steps': 10,
    'logging_first_step': True,
    'disable_tqdm': True,
    'report_to': 'trackio',
    'run_name': 'proposer-sft-qwen3-1.7b',
    # Column of the dataset that holds the raw training text.
    'dataset_text_field': 'text',
}
config = SFTConfig(**_sft_kwargs)

print('Initializing trainer...')
# The model is passed as a Hub id string (trl resolves and loads it);
# peft_config applies the LoRA adapter on top of the loaded base model.
train_split = ds['train']
eval_split = ds['test']
trainer = SFTTrainer(
    model=MODEL,
    args=config,
    train_dataset=train_split,
    eval_dataset=eval_split,
    peft_config=peft_config,
)

print('Training proposer...')
trainer.train()
# Explicit final push in addition to the push_to_hub=True in the config.
trainer.push_to_hub()
print('Proposer training complete.')