# Qwen2.5-3B: Medical Assistant with SFT → DPO → GRPO Pipeline

A fine-tuned version of Qwen2.5-3B trained through a full three-stage alignment pipeline: Supervised Fine-Tuning (SFT) → Direct Preference Optimization (DPO) → Group Relative Policy Optimization (GRPO).


## 🗂️ Repository Structure

```
├── sft-checkpoint/          # Stage 1: SFT LoRA adapter (AlpaCare medical dataset)
├── dpo-checkpoint/          # Stage 2: DPO LoRA adapter (Anthropic hh-rlhf)
├── grpo-checkpoint/         # Stage 3: GRPO LoRA adapter (GSM8K reasoning)
└── README.md
```

Each folder contains LoRA adapter weights only. Load them on top of `unsloth/Qwen2.5-3B` as described below.


πŸ‹οΈ Training Pipeline

### Stage 1 — Supervised Fine-Tuning (SFT)

📓 Notebook: Supervised Fine-tuning using Qwen2.5_3B

| Config | Value |
|---|---|
| Base model | `unsloth/Qwen2.5-3B` |
| Dataset | `lavita/AlpaCare-MedInstruct-52k` |
| LoRA rank | 16 |
| Learning rate | 5e-6 |
| Max steps | 20,000 |
| Batch size | 1 (grad accum 2) |
| Optimizer | AdamW 8-bit |
| Training mask | Response-only (`train_on_responses_only`) |
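`train_on_responses_only` masks the prompt tokens out of the loss so gradients flow only through the assistant's reply. A minimal sketch of the idea with toy token IDs (not Unsloth's actual implementation):

```python
# Response-only loss masking: label prompt tokens with -100, which
# cross-entropy ignores, so only response tokens contribute to the loss.
IGNORE_INDEX = -100

def mask_prompt_labels(input_ids, prompt_len):
    """Copy input_ids into labels, masking the first prompt_len tokens."""
    labels = list(input_ids)
    for i in range(min(prompt_len, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

# Example: 4 prompt tokens followed by a 3-token response.
ids = [101, 102, 103, 104, 201, 202, 203]
print(mask_prompt_labels(ids, prompt_len=4))
# [-100, -100, -100, -100, 201, 202, 203]
```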

Purpose: Teach the model to follow medical instructions in ChatML format.
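Qwen2.5's chat template follows the ChatML convention, so a formatted example looks like the hand-rolled sketch below. This is an approximation for illustration only; in practice `tokenizer.apply_chat_template` produces the exact string.

```python
# Approximate ChatML formatting: each turn wrapped in
# <|im_start|>role ... <|im_end|>, ending with an open assistant turn.
def to_chatml(messages):
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages]
    parts.append("<|im_start|>assistant")  # generation prompt
    return "\n".join(parts)

prompt = to_chatml([
    {"role": "system", "content": "You are a helpful medical assistant."},
    {"role": "user", "content": "What are the symptoms of diabetes?"},
])
print(prompt)
```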


### Stage 2 — Direct Preference Optimization (DPO)

📓 Notebook: Applying DPO on SFT Qwen2.5_3B

| Config | Value |
|---|---|
| Base | SFT checkpoint |
| Dataset | `Anthropic/hh-rlhf` (1,000 samples) |
| LoRA rank | 64 |
| Learning rate | 5e-6 |
| Beta | 0.1 |
| Loss type | sigmoid |
| Max length | 1024 |

Purpose: Align the model's responses with human preferences — reduce harmful, unhelpful, or unsafe outputs.
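With `loss_type="sigmoid"` and `beta=0.1`, the per-pair DPO objective is `-log σ(β · ((chosen policy log-prob − chosen reference log-prob) − (rejected policy log-prob − rejected reference log-prob)))`. A standalone sketch with toy log-probabilities (not tied to the TRL trainer internals):

```python
import math

def dpo_sigmoid_loss(policy_chosen_lp, policy_rejected_lp,
                     ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Per-pair DPO loss: -log(sigmoid(beta * (chosen margin - rejected margin)))."""
    logits = beta * ((policy_chosen_lp - ref_chosen_lp)
                     - (policy_rejected_lp - ref_rejected_lp))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))  # -log sigmoid(logits)

# Policy prefers the chosen answer more than the reference does -> smaller loss.
print(dpo_sigmoid_loss(-10.0, -14.0, -12.0, -12.0))
# Policy drifted toward the rejected answer -> larger loss.
print(dpo_sigmoid_loss(-12.0, -10.0, -12.0, -12.0))
```

At a zero margin the loss is exactly `log 2`, and it decreases monotonically as the policy separates chosen from rejected relative to the reference.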


### Stage 3 — Group Relative Policy Optimization (GRPO)

📓 Notebook: Applying GRPO on Qwen2.5_3B

| Config | Value |
|---|---|
| Base | DPO checkpoint |
| Dataset | `openai/gsm8k` |
| Learning rate | 5e-6 |
| Max steps | 500 |
| Num generations | 2 |
| Max completion length | 300 |
| Optimizer | AdamW 8-bit |

Reward functions:

- `correctness_reward_func` — exact answer match (+2.0)
- `int_reward_func` — integer answer reward (+0.5)
- `strict_format_reward_func` — strict XML format (+0.5)
- `soft_format_reward_func` — soft XML format (+0.5)
- `xmlcount_reward_func` — XML tag count reward (+0.5)
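These rewards largely reduce to string checks on the generated completion. A simplified illustration of two of them (the notebook's actual functions may differ in details such as whitespace handling):

```python
import re

def correctness_reward(completion, gold):
    """+2.0 if the text inside <answer>...</answer> exactly matches the gold answer."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    return 2.0 if m and m.group(1).strip() == gold else 0.0

def soft_format_reward(completion):
    """+0.5 if <reasoning> and <answer> blocks appear in order."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    return 0.5 if re.search(pattern, completion, re.DOTALL) else 0.0

text = ("<reasoning>\n24 / 6 = 4 doses, 4 x 500mg = 2000mg.\n</reasoning>\n"
        "<answer>\n2000mg\n</answer>")
print(correctness_reward(text, "2000mg"), soft_format_reward(text))  # 2.0 0.5
```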

Purpose: Teach the model to reason step-by-step using `<reasoning>` and `<answer>` XML tags before answering.
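The "group relative" part of GRPO: for each prompt, a group of completions is sampled (num generations = 2 here), each is scored by the reward functions, and the within-group normalized reward serves as the advantage. A minimal sketch of that normalization:

```python
def group_relative_advantages(rewards, eps=1e-4):
    """Advantage of each completion = (reward - group mean) / (group std + eps)."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# Group of 4 sampled completions for one prompt; above-average rewards
# get positive advantages, below-average rewards get negative ones.
print(group_relative_advantages([3.0, 0.5, 2.5, 0.0]))
```

Because advantages are centered within each group, they always sum to (approximately) zero: the policy is pushed toward the better completions for that specific prompt rather than toward an absolute reward scale.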


## 🚀 How to Use

### Load SFT Checkpoint

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

base_model_id = "unsloth/Qwen2.5-3B"

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

model = PeftModel.from_pretrained(
    base_model,
    "diaaessam/Qwen2.5-3B-Medical-SFT-DPO-GRPO/sft-checkpoint",
    is_trainable=False,
)
model.eval()
```

### Load GRPO Checkpoint (with Reasoning)

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

base_model_id = "unsloth/Qwen2.5-3B"

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

model = PeftModel.from_pretrained(
    base_model,
    "diaaessam/Qwen2.5-3B-Medical-SFT-DPO-GRPO/grpo-checkpoint",
    is_trainable=False,
)
model.eval()
```

### Inference — Medical Question (SFT/DPO)

```python
import torch

def ask(model, tokenizer, question, system="You are a helpful medical assistant."):
    messages = [
        {"role": "system", "content": system},
        {"role": "user",   "content": question},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(ask(model, tokenizer, "What are the symptoms of diabetes?"))
```

### Inference — Reasoning Question (GRPO)

```python
import torch

GRPO_SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def ask_with_reasoning(model, tokenizer, question):
    messages = [
        {"role": "system", "content": GRPO_SYSTEM_PROMPT},
        {"role": "user",   "content": question},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(ask_with_reasoning(model, tokenizer, "If a patient takes 500mg of paracetamol every 6 hours, what is the total daily dose?"))
```

Expected output:

```
<reasoning>
The patient takes 500mg every 6 hours.
There are 24 hours in a day, so 24 / 6 = 4 doses per day.
Total = 500mg × 4 = 2000mg per day.
</reasoning>
<answer>
2000mg
</answer>
```
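Downstream code can pull the final answer out of the XML-tagged output with a small parser. A possible helper (the name and behavior are illustrative, not part of this repo):

```python
import re

def extract_answer(text):
    """Return the content of the first <answer>...</answer> block, or None."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return m.group(1).strip() if m else None

output = """<reasoning>
24 / 6 = 4 doses per day, 4 x 500mg = 2000mg.
</reasoning>
<answer>
2000mg
</answer>"""
print(extract_answer(output))  # 2000mg
```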

## 📊 Datasets Used

| Stage | Dataset | Size |
|---|---|---|
| SFT | `lavita/AlpaCare-MedInstruct-52k` | 52k medical instructions |
| DPO | `Anthropic/hh-rlhf` | 1,000 preference pairs |
| GRPO | `openai/gsm8k` | 7,473 math word problems |

## ⚙️ Hardware & Framework

- GPU: NVIDIA Tesla T4 (Kaggle/Colab)
- Framework: Hugging Face Transformers + TRL + PEFT
- Quantization: 4-bit NF4 (bitsandbytes)
- LoRA: PEFT with rank 16 (SFT) / 64 (DPO, GRPO)
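The adapter shapes above correspond roughly to a PEFT config like the following sketch. The `lora_alpha`, `lora_dropout`, and `target_modules` values are assumptions (the usual Qwen attention/MLP projections); verify them against the notebooks before reuse.

```python
from peft import LoraConfig

# Approximate SFT-stage adapter config; the DPO/GRPO stages used r=64.
lora_config = LoraConfig(
    r=16,                      # LoRA rank (16 for SFT per the table above)
    lora_alpha=16,             # scaling factor (assumed; check the notebook)
    lora_dropout=0.0,          # assumed
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],  # assumed targets
)
```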

## 📝 Citation

If you use this model, please cite the base model and the datasets:

```bibtex
@misc{qwen2.5,
  title  = {Qwen2.5: A Party of Foundation Models},
  author = {Qwen Team},
  year   = {2024},
  url    = {https://qwenlm.github.io/blog/qwen2.5/}
}
```

## ⚠️ Limitations

- The model was trained on a subset of hh-rlhf (1,000 samples) and GSM8K (500 GRPO steps) — not a full training run.
- Medical outputs should not be used as clinical advice.
- GRPO reasoning is trained on math problems — medical reasoning chains may not always be accurate.
- The model is 3B parameters — larger models will generally outperform it on complex tasks.