# qwen3-1.7b-math-sft / train_math.py
# Uploaded by GuizMeuh (commit 8af5454, verified)
"""
Fine-tune Qwen3-1.7B for mathematical reasoning using TRL SFTTrainer.
Tokenization done manually to avoid max_seq_length API issues.
"""
import os
import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio
# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-1.7B"  # base checkpoint to fine-tune
OUTPUT_DIR = "./qwen3-1.7b-math-sft"  # local directory for checkpoints / final model
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "GuizMeuh/qwen3-1.7b-math-sft")  # Hub repo, env-overridable
EPOCHS = 3
LR = 2e-4  # learning rate for the LoRA adapters
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 32  # effective batch size = 4 * 32 = 128 per device
MAX_SEQ_LENGTH = 4096  # truncation length applied during manual tokenization
WARMUP_STEPS = 500
LORA_R = 32  # LoRA rank
LORA_ALPHA = 16  # LoRA scaling factor
LORA_DROPOUT = 0.05
def format_messages(example):
    """Normalize a raw example into the conversational ``messages`` schema.

    MetaMathQA rows use ``query``/``response`` keys; GSM8K and Orca rows use
    ``question``/``answer``. Either way the output is a two-turn conversation.
    """
    if "query" in example:
        user_text, assistant_text = example["query"], example["response"]
    else:
        user_text, assistant_text = example["question"], example["answer"]
    return {
        "messages": [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]
    }
def tokenize_messages(example, tokenizer):
    """Render one conversation through the chat template and tokenize it.

    Truncates to ``MAX_SEQ_LENGTH`` tokens and returns plain Python lists
    (no tensors, no padding) so downstream collation can batch freely.
    """
    # Flatten the messages into a single training string via the model's
    # chat template; no generation prompt since the assistant turn is present.
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    encoded = tokenizer(
        rendered,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
        return_tensors=None,
    )
    # Labels mirror the inputs, so loss is computed over the full sequence,
    # prompt tokens included. NOTE(review): if completion-only loss is wanted,
    # the prompt span would need masking with -100 — confirm intent.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded
def prepare_datasets(tokenizer):
    """Load the three math corpora, convert to messages, tokenize, and shuffle.

    Returns a single concatenated, shuffled dataset ready for SFTTrainer.
    """
    print("Loading datasets...")
    ds_metamath = load_dataset("meta-math/MetaMathQA", split="train")
    ds_gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    ds_orca = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
    # Rewrite every source into the shared `messages` schema, dropping each
    # dataset's original columns so they can be concatenated.
    parts = [
        ds.map(format_messages, remove_columns=ds.column_names)
        for ds in (ds_metamath, ds_gsm8k, ds_orca)
    ]
    ds_metamath, ds_gsm8k, ds_orca = parts
    print(f"MetaMathQA: {len(ds_metamath)} | GSM8K: {len(ds_gsm8k)} | Orca: {len(ds_orca)}")
    combined = concatenate_datasets(parts)
    print(f"Combined: {len(combined)} samples")
    print("Tokenizing with chat template...")
    combined = combined.map(
        lambda ex: tokenize_messages(ex, tokenizer),
        batched=False,
        remove_columns=["messages"],
    )
    # Shuffle after tokenization so the three sources are interleaved.
    combined = combined.shuffle(seed=42)
    return combined
def main():
    """Run the full fine-tuning pipeline: tracking, data, LoRA SFT, Hub push."""
    # Start an experiment-tracking run for metric logging.
    run = trackio.init(project="qwen3-math-sft", name="qwen3-1.7b-math-lora")
    print(f"Trackio run started: {run}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Causal LMs often ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model {MODEL_NAME}...")
    # NOTE(review): the `dtype=` kwarg is only accepted by recent transformers
    # releases (older versions use `torch_dtype=`) — confirm the pinned version.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, dtype=torch.bfloat16, trust_remote_code=True,
    )
    # Gradient checkpointing is also requested via SFTConfig below; enabling it
    # here as well is redundant but harmless.
    model.gradient_checkpointing_enable()
    train_dataset = prepare_datasets(tokenizer)
    print(f"Final train dataset columns: {train_dataset.column_names}")
    # LoRA adapters on the attention projections only (no MLP modules).
    peft_config = LoraConfig(
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR, num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LR, bf16=True,
        lr_scheduler_type="cosine", warmup_steps=WARMUP_STEPS,
        logging_steps=10, save_strategy="epoch", save_total_limit=2,
        gradient_checkpointing=True, push_to_hub=True, hub_model_id=HUB_MODEL_ID,
        hub_private_repo=False, report_to="trackio", disable_tqdm=True,
        logging_strategy="steps", logging_first_step=True, seed=42,
    )
    print("Initializing SFTTrainer...")
    # Dataset is pre-tokenized (input_ids/attention_mask/labels columns), so
    # the trainer should skip its own tokenization pass.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_args,
    )
    print("Starting training...")
    trainer.train()
    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    # Explicit push in addition to push_to_hub=True, to upload the final state.
    trainer.push_to_hub()
    trackio.log({"training_status": "completed"})
    trackio.finish()
    print("Training complete!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()