| """ |
| Fine-tune Qwen3-1.7B for mathematical reasoning using TRL SFTTrainer. |
| Tokenization done manually to avoid max_seq_length API issues. |
| """ |
| import os |
| import torch |
| from datasets import load_dataset, concatenate_datasets |
| from peft import LoraConfig |
| from trl import SFTTrainer, SFTConfig |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import trackio |
|
|
| MODEL_NAME = "Qwen/Qwen3-1.7B" |
| OUTPUT_DIR = "./qwen3-1.7b-math-sft" |
| HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "GuizMeuh/qwen3-1.7b-math-sft") |
|
|
| EPOCHS = 3 |
| LR = 2e-4 |
| PER_DEVICE_BATCH = 4 |
| GRADIENT_ACCUMULATION = 32 |
| MAX_SEQ_LENGTH = 4096 |
| WARMUP_STEPS = 500 |
| LORA_R = 32 |
| LORA_ALPHA = 16 |
| LORA_DROPOUT = 0.05 |
|
|
|
|
def format_messages(example):
    """Normalize a raw dataset row into the chat ``messages`` format.

    Rows with a ``query`` key use the query/response field pair
    (MetaMathQA schema); otherwise the question/answer pair is used
    (GSM8K / Orca schema).
    """
    if "query" in example:
        user_text, assistant_text = example["query"], example["response"]
    else:
        user_text, assistant_text = example["question"], example["answer"]
    return {
        "messages": [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]
    }
|
|
|
|
def tokenize_messages(example, tokenizer, max_length=None):
    """Apply the tokenizer's chat template, then tokenize with truncation.

    Args:
        example: Mapping with a ``messages`` list of chat-format dicts.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``__call__``.
        max_length: Truncation length; defaults to the module-level
            ``MAX_SEQ_LENGTH`` when ``None`` (backward compatible with the
            previous hard-coded behavior).

    Returns:
        The tokenizer output dict (``input_ids`` etc.) with an added
        ``labels`` key holding an independent copy of ``input_ids``.
        NOTE: labels cover the full sequence, so the loss is computed on
        the user prompt tokens as well as the assistant response.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    result = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding=False,           # padding handled later by the trainer's collator
        return_tensors=None,     # keep plain Python lists so .copy() works
    )
    result["labels"] = result["input_ids"].copy()
    return result
|
|
|
|
def prepare_datasets(tokenizer):
    """Load the three math corpora, normalize to chat format, tokenize, shuffle.

    Returns a single shuffled, tokenized dataset ready for the trainer.
    """
    print("Loading datasets...")
    metamath = load_dataset("meta-math/MetaMathQA", split="train")
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    orca = load_dataset("microsoft/orca-math-word-problems-200k", split="train")

    # Drop the original columns so every dataset ends up with only "messages".
    metamath, gsm8k, orca = (
        ds.map(format_messages, remove_columns=ds.column_names)
        for ds in (metamath, gsm8k, orca)
    )

    print(f"MetaMathQA: {len(metamath)} | GSM8K: {len(gsm8k)} | Orca: {len(orca)}")

    merged = concatenate_datasets([metamath, gsm8k, orca])
    print(f"Combined: {len(merged)} samples")

    print("Tokenizing with chat template...")
    merged = merged.map(
        lambda ex: tokenize_messages(ex, tokenizer),
        batched=False,
        remove_columns=["messages"],
    )
    return merged.shuffle(seed=42)
|
|
|
|
def main():
    """Run the full SFT pipeline: load model/data, train LoRA adapters, push to Hub."""
    # Start experiment tracking before any heavy work so the run is recorded.
    run = trackio.init(project="qwen3-math-sft", name="qwen3-1.7b-math-lora")
    print(f"Trackio run started: {run}")


    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


    print(f"Loading model {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, dtype=torch.bfloat16, trust_remote_code=True,
    )
    # NOTE(review): gradient checkpointing is also enabled via SFTConfig below;
    # calling it here as well looks redundant — confirm against the TRL version in use.
    model.gradient_checkpointing_enable()


    train_dataset = prepare_datasets(tokenizer)
    print(f"Final train dataset columns: {train_dataset.column_names}")


    # LoRA on the attention projections only (no MLP modules targeted).
    peft_config = LoraConfig(
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )


    training_args = SFTConfig(
        output_dir=OUTPUT_DIR, num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LR, bf16=True,
        lr_scheduler_type="cosine", warmup_steps=WARMUP_STEPS,
        logging_steps=10, save_strategy="epoch", save_total_limit=2,
        gradient_checkpointing=True, push_to_hub=True, hub_model_id=HUB_MODEL_ID,
        hub_private_repo=False, report_to="trackio", disable_tqdm=True,
        logging_strategy="steps", logging_first_step=True, seed=42,
    )


    print("Initializing SFTTrainer...")
    # Dataset is pre-tokenized (input_ids/labels), so the trainer skips its
    # own tokenization; tokenizer is passed as processing_class for collation/saving.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_args,
    )


    print("Starting training...")
    trainer.train()


    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    # Explicit push in addition to push_to_hub=True ensures the final state lands on the Hub.
    trainer.push_to_hub()


    trackio.log({"training_status": "completed"})
    trackio.finish()
    print("Training complete!")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|