# Medical LLM Fine-tuning & RAG Pipeline (Qwen3-8B)
Fine-tuning Qwen3-8B on medical domain data using QLoRA, with a Retrieval-Augmented Generation (RAG) system built on top for knowledge-grounded medical question answering.
**Model on Hugging Face:** `XinyuanWang/qwen3-8b-medical-lora`
## Overview
This project consists of two main components:
1. **QLoRA fine-tuning** – Qwen3-8B is fine-tuned on medical Q&A data using 4-bit NF4 quantization and LoRA adapters, reducing GPU memory requirements while preserving model quality.
2. **RAG pipeline** – A FAISS-based retrieval system indexes 125,847 chunks from 18 classic medical textbooks (MedRAG/textbooks). At inference time, relevant passages are retrieved and injected into the prompt before generation.
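The injection step amounts to assembling retrieved passages into the prompt before the question reaches the model. A minimal sketch, assuming a hypothetical `build_prompt` helper and an illustrative template (the real template lives in `scripts/rag_inference.py`):

```python
def build_prompt(passages: list[str], question: str) -> str:
    """Inject retrieved textbook passages into the generation prompt.

    The wording of this template is an illustrative assumption; see
    scripts/rag_inference.py for the actual one.
    """
    context = "\n\n".join(f"[{i + 1}] {p}" for i, p in enumerate(passages))
    return (
        "You are a medical assistant. Use the reference passages below "
        "when they are relevant.\n\n"
        f"Reference passages:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )


prompt = build_prompt(
    ["Appendicitis typically presents with periumbilical pain that "
     "migrates to the right lower quadrant."],
    "What are the symptoms of appendicitis?",
)
print(prompt)
```

Numbering the passages (`[1]`, `[2]`, …) keeps the context readable when several chunks are retrieved for one query.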
## Results
| Metric | Value |
|---|---|
| Final training loss | 0.8553 |
| Final eval loss | 0.9455 |
| Train/eval gap | 0.090 (~10%; no sign of overfitting) |
| LoRA adapter size | 174 MB (vs. 16 GB full model) |
| RAG knowledge base | 125,847 chunks, 18 textbooks |
| Retrieval speed | ~145 docs/sec (BGE-M3 on GPU) |
## Setup

### Requirements

```bash
# Fine-tuning
cd LLaMA-Factory && pip install -e ".[torch,bitsandbytes]"

# RAG + inference
pip install torch transformers peft bitsandbytes accelerate
pip install sentence-transformers faiss-cpu datasets
```
### 1. Prepare Training Data

```bash
python scripts/download_dataset.py   # downloads 1,000 samples from OpenMed/Medical-Reasoning-SFT-Mega
python scripts/prepare_dataset.py    # converts to Alpaca format → data/medical_train.json
```
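The conversion step maps each raw record to the Alpaca schema (`instruction` / `input` / `output`) that LLaMA-Factory consumes. A minimal sketch, assuming hypothetical raw field names (the actual mapping from OpenMed/Medical-Reasoning-SFT-Mega lives in `scripts/prepare_dataset.py`):

```python
import json


def to_alpaca(record: dict) -> dict:
    """Map one raw Q&A record to an Alpaca-format record.

    "question" and "answer" are assumed field names for illustration;
    see scripts/prepare_dataset.py for the real ones.
    """
    return {
        "instruction": record["question"],
        "input": "",
        "output": record["answer"],
    }


raw = [{"question": "What causes appendicitis?",
        "answer": "Most commonly, luminal obstruction of the appendix."}]
alpaca = [to_alpaca(r) for r in raw]

with open("medical_train_sample.json", "w") as f:
    json.dump(alpaca, f, indent=2, ensure_ascii=False)
```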
### 2. Fine-tune the Model

```bash
python scripts/download_base_model.py   # downloads Qwen3-8B locally
cp configs/dataset_info.json LLaMA-Factory/data/
cd LLaMA-Factory
llamafactory-cli train ../configs/train_medical_ft.yaml   # output: saves/qwen3-8b-med-lora/
```
### 3. Build the RAG Index

```bash
python scripts/download_medrag_textbooks.py   # downloads MedRAG/textbooks → data/medrag_textbooks.jsonl
python scripts/build_rag_index.py             # builds FAISS index → rag_index/
```
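Before anything is embedded, each textbook is split into chunks; the 125,847 chunks in the results table come out of a step like this. A dependency-free sketch of word-window chunking with overlap (the window size and overlap here are illustrative assumptions, not the values used in `scripts/build_rag_index.py`):

```python
def chunk_text(text: str, chunk_words: int = 200, overlap: int = 50) -> list[str]:
    """Split a document into overlapping fixed-size word windows."""
    words = text.split()
    step = chunk_words - overlap  # advance by window minus overlap
    chunks = []
    for start in range(0, len(words), step):
        window = words[start:start + chunk_words]
        chunks.append(" ".join(window))
        if start + chunk_words >= len(words):
            break  # last window reached the end of the document
    return chunks


doc = "word " * 450  # a 450-word toy document
print(len(chunk_text(doc)))  # → 3
```

Overlap between neighboring chunks reduces the chance that a fact straddling a chunk boundary is lost to retrieval.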
### 4. Run Inference

```bash
# Interactive mode
python scripts/rag_inference.py --interactive

# Single query
python scripts/rag_inference.py --query "What are the symptoms of appendicitis?"

# Batch inference (50 samples, saves to data/ragas_input.json)
python scripts/run_rag_inference.py
```
## Model Details

**Base model:** `Qwen/Qwen3-8B`

**Fine-tuning configuration:**
| Parameter | Value |
|---|---|
| Method | QLoRA (SFT) |
| Quantization | 4-bit NF4 + double quantization |
| LoRA rank / alpha | 16 / 32 |
| LoRA dropout | 0.05 |
| Target modules | q/k/v/o_proj, gate/up/down_proj |
| Epochs | 3 |
| Effective batch size | 8 (2 × 4 gradient accumulation) |
| Learning rate | 5e-5 (cosine, 10% warmup) |
| Sequence length | 512 tokens (packing enabled) |
| Precision | bfloat16 |
| Framework | LLaMA-Factory |
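The table above maps onto a LLaMA-Factory YAML config. An illustrative reconstruction of `configs/train_medical_ft.yaml` from those values, using standard LLaMA-Factory key names (the actual file may differ in paths and detail):

```yaml
model_name_or_path: models/Qwen3-8B   # local path is an assumption
stage: sft
finetuning_type: lora
quantization_bit: 4                   # 4-bit NF4 (QLoRA)
lora_rank: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj
num_train_epochs: 3
per_device_train_batch_size: 2
gradient_accumulation_steps: 4        # effective batch size 2 × 4 = 8
learning_rate: 5.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
cutoff_len: 512
packing: true
bf16: true
output_dir: saves/qwen3-8b-med-lora
```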
**Training data:** OpenMed/Medical-Reasoning-SFT-Mega – 1,000 source samples, expanded to 4,966 Alpaca records after conversion.
## RAG Details

**Knowledge base:** MedRAG/textbooks – 18 medical textbooks including Harrison's Internal Medicine, Schwartz's Surgery, Adams' Neurology, Katzung Pharmacology, Robbins Pathology, and more.

**Pipeline:**

```
Query → BGE-M3 Embedding → FAISS Top-5 Retrieval → Prompt → Qwen3-8B → Answer
```
| Component | Choice |
|---|---|
| Embedding model | BAAI/bge-m3 (1024-dim) |
| Index type | faiss.IndexFlatIP (exact inner-product search; cosine on normalized embeddings) |
| Retrieval threshold | 0.45 cosine similarity |
| Generation | 4-bit NF4, greedy decoding, enable_thinking=False |
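`IndexFlatIP` scores by inner product, which equals cosine similarity only when embeddings are L2-normalized (BGE-M3 embeddings are typically normalized at encode time). A dependency-free sketch of that scoring and the 0.45 cutoff; the 3-dim vectors are toy stand-ins for 1024-dim BGE-M3 embeddings:

```python
import math


def l2_normalize(v: list[float]) -> list[float]:
    """Scale a vector to unit length so inner product == cosine."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def inner_product(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


THRESHOLD = 0.45  # retrieval cutoff from the table above

query = l2_normalize([0.2, 0.9, 0.1])
chunks = {
    "relevant":   l2_normalize([0.25, 0.85, 0.05]),
    "irrelevant": l2_normalize([0.9, -0.3, 0.2]),
}

for name, vec in chunks.items():
    score = inner_product(query, vec)  # cosine, since both are unit vectors
    print(f"{name}: {score:.3f} kept={score >= THRESHOLD}")
```

Chunks scoring below the threshold are dropped rather than padded into the prompt, which keeps low-relevance passages from diluting the context.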
## Repository Structure

```
medical-llm-finetune/
├── configs/
│   ├── train_medical_ft.yaml        # Training configuration
│   └── dataset_info.json            # LLaMA-Factory dataset registry
├── data/
│   ├── medical_reasoning_1k.json    # Raw downloaded samples
│   ├── medical_train.json           # Alpaca format training set
│   └── ragas_input.json             # Batch inference results (50 samples)
├── saves/
│   └── qwen3-8b-med-lora/           # LoRA adapter + tokenizer
├── rag_index/                       # FAISS index + chunk metadata (local only)
├── models/                          # Base model weights (local only)
└── scripts/
    ├── download_dataset.py
    ├── prepare_dataset.py
    ├── download_base_model.py
    ├── download_medrag_textbooks.py
    ├── build_rag_index.py
    ├── rag_inference.py
    └── run_rag_inference.py
```