# Qwen2.5-3B: Medical Assistant with SFT → DPO → GRPO Pipeline

A fine-tuned version of Qwen2.5-3B trained through a full 3-stage alignment pipeline: Supervised Fine-Tuning (SFT) → Direct Preference Optimization (DPO) → Group Relative Policy Optimization (GRPO).
## Repository Structure

```
├── sft-checkpoint/    # Stage 1: SFT LoRA adapter (AlpaCare medical dataset)
├── dpo-checkpoint/    # Stage 2: DPO LoRA adapter (Anthropic hh-rlhf)
├── grpo-checkpoint/   # Stage 3: GRPO LoRA adapter (GSM8K reasoning)
└── README.md
```

Each folder contains LoRA adapter weights only. Load them on top of `unsloth/Qwen2.5-3B` as described below.
## Training Pipeline

### Stage 1: Supervised Fine-Tuning (SFT)

Notebook: Supervised Fine-tuning using Qwen2.5_3B
| Config | Value |
|---|---|
| Base model | unsloth/Qwen2.5-3B |
| Dataset | lavita/AlpaCare-MedInstruct-52k |
| LoRA rank | 16 |
| Learning rate | 5e-6 |
| Max steps | 20,000 |
| Batch size | 1 (grad accum 2) |
| Optimizer | AdamW 8-bit |
| Training mask | Response-only (train_on_responses_only) |
Purpose: Teach the model to follow medical instructions in ChatML format.
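The "response-only" training mask can be sketched as follows. This is a minimal illustration of the idea behind `train_on_responses_only`, not the unsloth implementation; the token ids and the response start index are hypothetical:

```python
IGNORE_INDEX = -100  # label value that HuggingFace loss functions skip

def mask_non_response(input_ids, response_start_idx, ignore_index=IGNORE_INDEX):
    """Copy input_ids into labels, masking everything before the response
    so the cross-entropy loss is computed only on the assistant's tokens."""
    labels = list(input_ids)
    for i in range(response_start_idx):
        labels[i] = ignore_index
    return labels

# Prompt tokens (first 3) are masked; loss is computed on the response only.
labels = mask_non_response([101, 102, 103, 104, 105], response_start_idx=3)
# labels == [-100, -100, -100, 104, 105]
```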
### Stage 2: Direct Preference Optimization (DPO)

Notebook: Applying DPO on SFT Qwen2.5_3B
| Config | Value |
|---|---|
| Base | SFT checkpoint |
| Dataset | Anthropic/hh-rlhf (1,000 samples) |
| LoRA rank | 64 |
| Learning rate | 5e-6 |
| Beta | 0.1 |
| Loss type | sigmoid |
| Max length | 1024 |
Purpose: Align the model's responses with human preferences, reducing harmful, unhelpful, or unsafe outputs.
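The `sigmoid` loss type with beta = 0.1 corresponds to the standard DPO objective. A minimal pure-Python sketch of the per-example loss; the `pi_*`/`ref_*` arguments (summed log-probabilities of the chosen/rejected responses under the policy and the frozen reference model) are hypothetical names, not TRL's API:

```python
import math

def dpo_sigmoid_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Implicit reward margin: how much more the policy prefers the chosen
    # response over the rejected one, relative to the reference model.
    margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # -log sigmoid(margin): small when the policy already prefers `chosen`.
    return math.log(1.0 + math.exp(-margin))

# With no preference margin the loss is log(2) ≈ 0.693.
print(dpo_sigmoid_loss(0.0, 0.0, 0.0, 0.0))
```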
### Stage 3: Group Relative Policy Optimization (GRPO)

Notebook: Applying GRPO on Qwen2.5_3B
| Config | Value |
|---|---|
| Base | DPO checkpoint |
| Dataset | openai/gsm8k |
| Learning rate | 5e-6 |
| Max steps | 500 |
| Num generations | 2 |
| Max completion length | 300 |
| Optimizer | AdamW 8-bit |
Reward functions:

- `correctness_reward_func` – exact answer match (+2.0)
- `int_reward_func` – integer answer reward (+0.5)
- `strict_format_reward_func` – strict XML format (+0.5)
- `soft_format_reward_func` – soft XML format (+0.5)
- `xmlcount_reward_func` – XML tag count reward (+0.5)
Purpose: Teach the model to reason step-by-step using `<reasoning>` and `<answer>` XML tags before answering.
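A correctness-style reward that parses this format can be sketched as follows. This is an illustration of the idea, not the notebook's exact `correctness_reward_func`:

```python
import re

def correctness_reward(completion: str, gold_answer: str) -> float:
    """Return +2.0 if the text inside <answer>...</answer> matches the
    gold answer exactly (after stripping whitespace), else 0.0."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    if match and match.group(1).strip() == gold_answer.strip():
        return 2.0
    return 0.0

completion = "<reasoning>4 doses of 500mg per day.</reasoning>\n<answer>2000mg</answer>"
print(correctness_reward(completion, "2000mg"))  # 2.0
```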
## How to Use

### Load SFT Checkpoint
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

base_model_id = "unsloth/Qwen2.5-3B"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# The adapter lives in a subfolder of the Hub repo, so pass `subfolder=`
# rather than appending it to the repo id.
model = PeftModel.from_pretrained(
    base_model,
    "diaaessam/Qwen2.5-3B-Medical-SFT-DPO-GRPO",
    subfolder="sft-checkpoint",
    is_trainable=False,
)
model.eval()
```
### Load GRPO Checkpoint (with Reasoning)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

base_model_id = "unsloth/Qwen2.5-3B"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# As above, the adapter subfolder is passed via `subfolder=`.
model = PeftModel.from_pretrained(
    base_model,
    "diaaessam/Qwen2.5-3B-Medical-SFT-DPO-GRPO",
    subfolder="grpo-checkpoint",
    is_trainable=False,
)
model.eval()
```
### Inference: Medical Question (SFT/DPO)
```python
def ask(model, tokenizer, question, system="You are a helpful medical assistant."):
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": question},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(ask(model, tokenizer, "What are the symptoms of diabetes?"))
```
### Inference: Reasoning Question (GRPO)
```python
GRPO_SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def ask_with_reasoning(model, tokenizer, question):
    messages = [
        {"role": "system", "content": GRPO_SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(ask_with_reasoning(model, tokenizer, "If a patient takes 500mg of paracetamol every 6 hours, what is the total daily dose?"))
```
Expected output:

```
<reasoning>
The patient takes 500mg every 6 hours.
There are 24 hours in a day, so 24 / 6 = 4 doses per day.
Total = 500mg × 4 = 2000mg per day.
</reasoning>
<answer>
2000mg
</answer>
```
## Datasets Used
| Stage | Dataset | Size |
|---|---|---|
| SFT | lavita/AlpaCare-MedInstruct-52k | 52k medical instructions |
| DPO | Anthropic/hh-rlhf | 1,000 preference pairs |
| GRPO | openai/gsm8k | 7,473 math word problems |
## Hardware & Framework
- GPU: NVIDIA Tesla T4 (Kaggle/Colab)
- Framework: HuggingFace Transformers + TRL + PEFT
- Quantization: 4-bit NF4 (bitsandbytes)
- LoRA: PEFT with rank 16 (SFT) / 64 (DPO, GRPO)
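As a rough sanity check on why 4-bit loading fits a 16 GB T4: at roughly 3.1B parameters and 4 bits per weight (an approximation that ignores quantization constants, activations, the KV cache, and the LoRA weights), the base weights alone take about 1.5 GB:

```python
params = 3.1e9          # approximate parameter count of Qwen2.5-3B
bytes_per_param = 0.5   # 4-bit NF4 weights: half a byte per parameter
weight_gb = params * bytes_per_param / 1e9
print(f"~{weight_gb:.2f} GB")  # ~1.55 GB
```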
## Citation
If you use this model, please cite the base model and the datasets:
```bibtex
@misc{qwen2.5,
    title  = {Qwen2.5: A Party of Foundation Models},
    author = {Qwen Team},
    year   = {2024},
    url    = {https://qwenlm.github.io/blog/qwen2.5/}
}
```
## Limitations
- The model was trained on a subset of hh-rlhf (1,000 samples) and GSM8K (500 GRPO steps); this is not a full training run.
- Medical outputs should not be used as clinical advice.
- GRPO reasoning is trained on math problems; medical reasoning chains may not always be accurate.
- The model has 3B parameters; larger models will generally outperform it on complex tasks.