"""
Fine-tune Qwen3-1.7B for mathematical reasoning using TRL SFTTrainer.

Tokenization is done manually (chat template + truncation to MAX_SEQ_LENGTH)
to avoid max_seq_length API issues. The same cap is also passed to SFTConfig
as ``max_length``: SFTTrainer truncates even pre-tokenized datasets to
``max_length``, so leaving it unset would silently fall back to TRL's default
(1024 in current releases) and discard the manual 4096-token limit.
"""

import os

import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio

MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "./qwen3-1.7b-math-sft"
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "GuizMeuh/qwen3-1.7b-math-sft")
RUN_NAME = "qwen3-1.7b-math-lora"

EPOCHS = 3
LR = 2e-4
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 32
MAX_SEQ_LENGTH = 4096
WARMUP_STEPS = 500

LORA_R = 32
LORA_ALPHA = 16
LORA_DROPOUT = 0.05


def format_messages(example):
    """Convert one raw example into the conversational ``messages`` format.

    Rows that have a ``query`` field use ``query``/``response``; every other
    row is read as ``question``/``answer``.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``.
    """
    if "query" in example:
        prompt, reply = example["query"], example["response"]
    else:
        prompt, reply = example["question"], example["answer"]
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": reply},
        ]
    }


def tokenize_messages(example, tokenizer, max_length=MAX_SEQ_LENGTH):
    """Render ``example["messages"]`` with the chat template and tokenize it.

    Args:
        example: Row holding a ``messages`` list of ``{"role", "content"}`` dicts.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        max_length: Token cap; longer transcripts are truncated.

    Returns:
        The tokenizer output (e.g. ``input_ids`` / ``attention_mask``) plus a
        ``labels`` list copied from ``input_ids``. The loss therefore covers the
        whole transcript, prompt included.
    """
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    result = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_tensors=None,
    )
    # NOTE(review): TRL's data collator may rebuild labels from input_ids;
    # confirm this column is actually consumed before relying on it.
    result["labels"] = result["input_ids"].copy()
    return result


def prepare_datasets(tokenizer):
    """Load, normalize, merge, tokenize and shuffle the three math corpora.

    Returns:
        A shuffled dataset of tokenized chat transcripts. The intermediate
        ``messages`` column is removed after tokenization.
    """
    print("Loading datasets...")
    ds_metamath = load_dataset("meta-math/MetaMathQA", split="train")
    ds_gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    ds_orca = load_dataset("microsoft/orca-math-word-problems-200k", split="train")

    # Drop every source column so all three share the single ``messages``
    # schema required by concatenate_datasets.
    ds_metamath = ds_metamath.map(format_messages, remove_columns=ds_metamath.column_names)
    ds_gsm8k = ds_gsm8k.map(format_messages, remove_columns=ds_gsm8k.column_names)
    ds_orca = ds_orca.map(format_messages, remove_columns=ds_orca.column_names)
    print(f"MetaMathQA: {len(ds_metamath)} | GSM8K: {len(ds_gsm8k)} | Orca: {len(ds_orca)}")

    combined = concatenate_datasets([ds_metamath, ds_gsm8k, ds_orca])
    print(f"Combined: {len(combined)} samples")

    print("Tokenizing with chat template...")
    combined = combined.map(
        tokenize_messages,
        fn_kwargs={"tokenizer": tokenizer},
        batched=False,
        remove_columns=["messages"],
    )
    combined = combined.shuffle(seed=42)
    return combined


def main():
    """Run LoRA SFT of MODEL_NAME on the combined math corpus and push to the Hub."""
    # NOTE(review): report_to="trackio" makes the Trainer start its own Trackio
    # run as well. Confirm this manual run and the final trackio.log() still
    # target the intended project/run (the trainer may finish its run on
    # train end).
    run = trackio.init(project="qwen3-math-sft", name=RUN_NAME)
    print(f"Trackio run started: {run}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # Gradient checkpointing is not enabled here: SFTConfig(gradient_checkpointing=True)
    # makes the Trainer enable it on the final (PEFT-wrapped) model.

    train_dataset = prepare_datasets(tokenizer)
    print(f"Final train dataset columns: {train_dataset.column_names}")

    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        run_name=RUN_NAME,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LR,
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_steps=WARMUP_STEPS,
        # Must match the manual tokenization cap; otherwise SFTTrainer
        # re-truncates the pre-tokenized dataset to its own default length.
        max_length=MAX_SEQ_LENGTH,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        gradient_checkpointing=True,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_private_repo=False,
        report_to="trackio",
        disable_tqdm=True,
        logging_strategy="steps",
        logging_first_step=True,
        seed=42,
    )

    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_args,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub()

    trackio.log({"training_status": "completed"})
    trackio.finish()
    print("Training complete!")


if __name__ == "__main__":
    main()