File size: 4,625 Bytes
9399763
 
8af5454
9399763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3528da
9399763
 
 
 
 
8af5454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9399763
 
8af5454
9399763
 
 
 
8af5454
 
 
 
 
 
9399763
 
8af5454
 
 
 
 
 
 
 
9399763
 
 
 
 
12f3894
86d41c6
9399763
 
 
 
 
 
 
c12a7c8
9399763
 
 
8af5454
 
9399763
 
 
 
 
 
 
 
 
 
 
c12a7c8
a3528da
9399763
 
 
 
 
 
 
 
8af5454
 
 
 
 
9399763
 
 
 
 
 
 
 
 
12f3894
9399763
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Fine-tune Qwen3-1.7B for mathematical reasoning using TRL SFTTrainer.
Tokenization done manually to avoid max_seq_length API issues.
"""
import os
import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio

# Base checkpoint to fine-tune and the local directory the trainer writes to.
MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "./qwen3-1.7b-math-sft"
# Hub repo the trainer pushes to; overridable via the HUB_MODEL_ID env variable.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "GuizMeuh/qwen3-1.7b-math-sft")

# Optimization hyperparameters (passed to SFTConfig in main()).
EPOCHS = 3
LR = 2e-4
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 32  # 4 * 32 = 128 sequences per optimizer step, per device
MAX_SEQ_LENGTH = 4096  # truncation limit (tokens) applied in tokenize_messages()
WARMUP_STEPS = 500
# LoRA adapter hyperparameters (passed to LoraConfig in main()).
LORA_R = 32
LORA_ALPHA = 16
LORA_DROPOUT = 0.05


def format_messages(example):
    """Build a two-turn chat record (user prompt, assistant reply) from a raw row.

    Rows that carry a ``query`` key are read as ``query``/``response``; every
    other row is read as ``question``/``answer``.
    """
    prompt_key, reply_key = (
        ("query", "response") if "query" in example else ("question", "answer")
    )
    user_turn = {"role": "user", "content": example[prompt_key]}
    assistant_turn = {"role": "assistant", "content": example[reply_key]}
    return {"messages": [user_turn, assistant_turn]}


def tokenize_messages(example, tokenizer):
    """Render one conversation with the chat template and tokenize it.

    Args:
        example: Row holding a ``messages`` list of ``{"role", "content"}``
            dicts.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``__call__``.

    Returns:
        The tokenizer output (unpadded, truncated to ``MAX_SEQ_LENGTH``
        tokens) with an added ``labels`` entry that is a copy of
        ``input_ids``.
    """
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    result = tokenizer(
        text,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
        return_tensors=None,
        # The rendered template already carries every special token it needs;
        # letting the tokenizer add its own would duplicate them on
        # tokenizers that prepend BOS or append EOS.
        add_special_tokens=False,
    )
    # Labels are a copy of the inputs so the two lists stay independent.
    result["labels"] = result["input_ids"].copy()
    return result


def prepare_datasets(tokenizer):
    """Load, merge, tokenize and shuffle the three math SFT corpora.

    Args:
        tokenizer: Tokenizer forwarded to ``tokenize_messages``.

    Returns:
        The shuffled (seed 42) concatenation of the MetaMathQA, GSM8K and
        Orca-Math train splits, tokenized, with the ``messages`` column
        dropped.
    """
    print("Loading datasets...")
    # Positional arguments for load_dataset: (path,) or (path, config name).
    sources = (
        ("meta-math/MetaMathQA",),
        ("openai/gsm8k", "main"),
        ("microsoft/orca-math-word-problems-200k",),
    )
    raw_splits = [load_dataset(*spec, split="train") for spec in sources]

    # Replace each corpus' own columns with a single "messages" column.
    chat_splits = [
        raw.map(format_messages, remove_columns=raw.column_names)
        for raw in raw_splits
    ]
    n_metamath, n_gsm8k, n_orca = (len(split) for split in chat_splits)
    print(f"MetaMathQA: {n_metamath} | GSM8K: {n_gsm8k} | Orca: {n_orca}")

    merged = concatenate_datasets(chat_splits)
    print(f"Combined: {len(merged)} samples")

    print("Tokenizing with chat template...")

    def _tokenize(row):
        return tokenize_messages(row, tokenizer)

    tokenized = merged.map(
        _tokenize,
        batched=False,
        remove_columns=["messages"],
    )
    return tokenized.shuffle(seed=42)


def main():
    """Run the LoRA SFT job end to end: track, load, train, save, publish."""
    run = trackio.init(project="qwen3-math-sft", name="qwen3-1.7b-math-lora")
    print(f"Trackio run started: {run}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.gradient_checkpointing_enable()

    train_dataset = prepare_datasets(tokenizer)
    print(f"Final train dataset columns: {train_dataset.column_names}")

    # LoRA adapters on the attention projections only.
    lora_settings = {
        "r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "lora_dropout": LORA_DROPOUT,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    }
    peft_config = LoraConfig(**lora_settings)

    sft_settings = {
        # Schedule and optimization.
        "output_dir": OUTPUT_DIR,
        "num_train_epochs": EPOCHS,
        "per_device_train_batch_size": PER_DEVICE_BATCH,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION,
        "learning_rate": LR,
        "bf16": True,
        "lr_scheduler_type": "cosine",
        "warmup_steps": WARMUP_STEPS,
        "gradient_checkpointing": True,
        "seed": 42,
        # Logging and checkpointing.
        "logging_steps": 10,
        "logging_strategy": "steps",
        "logging_first_step": True,
        "save_strategy": "epoch",
        "save_total_limit": 2,
        "report_to": "trackio",
        "disable_tqdm": True,
        # Hub publication.
        "push_to_hub": True,
        "hub_model_id": HUB_MODEL_ID,
        "hub_private_repo": False,
    }
    training_args = SFTConfig(**sft_settings)

    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_args,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub()

    # Mark the tracked run as done before closing it.
    trackio.log({"training_status": "completed"})
    trackio.finish()
    print("Training complete!")


# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    main()