File size: 4,625 Bytes
"""
Fine-tune Qwen3-1.7B for mathematical reasoning using TRL SFTTrainer.
Tokenization is done manually to avoid max_seq_length API issues.
"""
import os
import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio
# --- Model and output locations ---------------------------------------------
MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "./qwen3-1.7b-math-sft"
# The Hub repository can be overridden through the HUB_MODEL_ID env variable.
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "GuizMeuh/qwen3-1.7b-math-sft")

# --- Optimisation schedule ---------------------------------------------------
EPOCHS = 3
LR = 2e-4
PER_DEVICE_BATCH = 4
# Samples per optimizer step per device = PER_DEVICE_BATCH * this value.
GRADIENT_ACCUMULATION = 32
MAX_SEQ_LENGTH = 4_096  # token cap applied when tokenizing examples
WARMUP_STEPS = 500

# --- LoRA adapter hyper-parameters -------------------------------------------
LORA_R = 32
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
def format_messages(example):
    """Turn one raw QA row into a two-turn chat transcript.

    Rows that contain a ``query`` column are read as ``query``/``response``
    (the MetaMathQA layout); every other row is read as
    ``question``/``answer`` (the layout used for GSM8K and Orca-Math here).

    Args:
        example: a single dataset row, as a mapping from column name to value.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}`` where each turn is a
        ``{"role": ..., "content": ...}`` dict.
    """
    prompt_col, reply_col = (
        ("query", "response") if "query" in example else ("question", "answer")
    )
    user_turn = {"role": "user", "content": example[prompt_col]}
    assistant_turn = {"role": "assistant", "content": example[reply_col]}
    return {"messages": [user_turn, assistant_turn]}
def tokenize_messages(example, tokenizer):
    """Render a chat example with the tokenizer's template and tokenize it.

    The ``messages`` list is first rendered to plain text with no generation
    prompt appended. That text is then tokenized without padding and
    truncated to ``MAX_SEQ_LENGTH`` tokens.

    Args:
        example: mapping holding a ``"messages"`` list of role/content dicts.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``__call__``.

    Returns:
        The tokenizer's encoding plus a ``"labels"`` entry holding a separate
        copy of ``"input_ids"``. No prompt masking is applied here.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    encoded = tokenizer(
        rendered,
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        padding=False,
        return_tensors=None,
    )
    # Independent copy so later edits to labels never alias input_ids.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded
def prepare_datasets(tokenizer):
    """Build the shuffled, tokenized training set from three math QA corpora.

    Loads the ``train`` splits of MetaMathQA, GSM8K (``main`` config) and
    Orca-Math. Each split is converted to chat ``messages`` with
    ``format_messages``, then the splits are concatenated, tokenized with
    ``tokenize_messages`` and shuffled with seed 42.

    Args:
        tokenizer: tokenizer forwarded to ``tokenize_messages``.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer output columns plus
        ``labels``; the ``messages`` column is dropped.
    """
    print("Loading datasets...")
    raw_splits = {
        "MetaMathQA": load_dataset("meta-math/MetaMathQA", split="train"),
        "GSM8K": load_dataset("openai/gsm8k", "main", split="train"),
        "Orca": load_dataset("microsoft/orca-math-word-problems-200k", split="train"),
    }

    # Keep only the normalised "messages" column so the schemas match.
    chat_splits = {
        label: split.map(format_messages, remove_columns=split.column_names)
        for label, split in raw_splits.items()
    }
    print(" | ".join(f"{label}: {len(split)}" for label, split in chat_splits.items()))

    merged = concatenate_datasets(list(chat_splits.values()))
    print(f"Combined: {len(merged)} samples")

    print("Tokenizing with chat template...")
    tokenized = merged.map(
        tokenize_messages,
        fn_kwargs={"tokenizer": tokenizer},
        batched=False,
        remove_columns=["messages"],
    )
    return tokenized.shuffle(seed=42)
def main():
    """Fine-tune MODEL_NAME with LoRA adapters on the combined math dataset.

    Steps:
      1. Start a trackio run.
      2. Load the tokenizer and a bfloat16 model.
      3. Build the tokenized training set via ``prepare_datasets``.
      4. Train with ``SFTTrainer``.
      5. Save to ``OUTPUT_DIR`` and push to ``HUB_MODEL_ID``.
      6. Log completion and close the trackio run.
    """
    run = trackio.init(project="qwen3-math-sft", name="qwen3-1.7b-math-lora")
    print(f"Trackio run started: {run}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Padding needs a pad token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, dtype=torch.bfloat16, trust_remote_code=True,
    )
    # NOTE(review): gradient checkpointing is also requested through
    # `gradient_checkpointing=True` in SFTConfig below; this explicit call is
    # presumably redundant but harmless — confirm.
    model.gradient_checkpointing_enable()

    train_dataset = prepare_datasets(tokenizer)
    print(f"Final train dataset columns: {train_dataset.column_names}")

    # LoRA adapters on the attention projections only.
    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        # Give the Trainer-reported run the same name as the trackio run
        # opened above instead of the Trainer's default run name.
        run_name="qwen3-1.7b-math-lora",
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LR,
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_steps=WARMUP_STEPS,
        # SFTTrainer truncates even pre-tokenized datasets to `max_length`.
        # TRL's default for it (1024) is shorter than the MAX_SEQ_LENGTH used
        # in tokenize_messages, so pass the intended limit explicitly.
        max_length=MAX_SEQ_LENGTH,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_strategy="epoch",
        save_total_limit=2,
        gradient_checkpointing=True,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_private_repo=False,
        report_to="trackio",
        disable_tqdm=True,
        seed=42,
    )

    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_args,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub()

    # NOTE(review): with report_to="trackio" the Trainer's trackio
    # integration may open/close its own run; confirm this final log lands
    # in the run started at the top of main().
    trackio.log({"training_status": "completed"})
    trackio.finish()
    print("Training complete!")


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|