# Text Generation
# PEFT
# Safetensors
# Transformers
# English
# medical
# icd-10
# clinical-coding
# healthcare
# lora
# sft
# trl
# conversational
# qwen2.5-0.5b-icd10cm-coder / train_icd10_gpu.py
# Rakshithch's picture
# Update GPU training script with all API fixes and comprehensive evaluation
# df41e97 verified
"""
ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production)
================================================================
Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records
for ICD-10-CM code classification from clinical text.
Requirements:
pip install torch transformers trl peft datasets trackio accelerate flash-attn
Hardware: A10G (24GB) or better. Training time: ~2-3 hours.
Based on:
- Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding
- Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy
- FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes)
- TRL SFTTrainer with prompt/completion format (loss on codes only)
To run:
# On HF Jobs (A10G):
hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \
--deps torch transformers trl peft datasets trackio accelerate flash-attn
# Or locally with GPU:
pip install torch transformers trl peft datasets trackio accelerate flash-attn
python train_icd10_gpu.py
"""
import os, re, json, random, gc
from collections import Counter
import torch, trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
# --- Where things come from and go to -------------------------------------
MODEL_NAME: str = "Qwen/Qwen2.5-1.5B-Instruct"              # base model the LoRA adapter is trained on
HUB_MODEL_ID: str = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"  # Hub repo that checkpoints are pushed to
DATASET_ID: str = "Rakshithch/icd10cm-clinical-coding-sft"   # dataset with train/validation/test splits
OUTPUT_DIR: str = "./qwen2.5-1.5b-icd10cm-lora"              # local checkpoint / final-save directory

# --- Optimisation ----------------------------------------------------------
LEARNING_RATE: float = 2e-4
NUM_EPOCHS: int = 3
# Per-device batch size and accumulation steps; effective batch = product.
BATCH_SIZE: int
GRAD_ACCUM: int
BATCH_SIZE, GRAD_ACCUM = 4, 8
MAX_LENGTH: int = 1024                                       # SFTConfig max_length (sequence truncation)

# --- LoRA adapter shape ----------------------------------------------------
LORA_R: int
LORA_ALPHA: int
LORA_R, LORA_ALPHA = 16, 32

# --- Reproducibility -------------------------------------------------------
SEED: int = 42
# Seed Python's global RNG up front.
random.seed(SEED)

# Hyperparameters recorded with the experiment-tracking run.
# NOTE(review): SFTConfig below also sets report_to="trackio", which may start
# its own run (possibly under a different project) — confirm where the
# training curves actually land.
run_config = {
    "model": MODEL_NAME,
    "dataset": DATASET_ID,
    "lora_r": LORA_R,
    "lr": LEARNING_RATE,
    "epochs": NUM_EPOCHS,
    "eff_batch": BATCH_SIZE * GRAD_ACCUM,
}
trackio.init(
    project="icd10-clinical-coding",
    name="qwen2.5-1.5b-lora-r16-full",
    config=run_config,
)

print("Loading dataset...")
ds = load_dataset(DATASET_ID)
split_sizes = {split: len(ds[split]) for split in ("train", "validation", "test")}
print(f"Train: {split_sizes['train']}, Val: {split_sizes['validation']}, Test: {split_sizes['test']}")
def to_pc(example):
    """Convert one chat-format row into TRL's prompt/completion layout.

    Args:
        example: a dataset row whose ``"messages"`` entry is a list of chat
            messages (at least three are required).

    Returns:
        ``{"prompt": messages[0:2], "completion": [messages[2]]}``. The first
        two messages form the conversational prompt, and the third is wrapped
        in a one-element list as the completion, so SFTTrainer trains on the
        completion only. Any messages after index 2 are ignored.

    Raises:
        IndexError: if the conversation has fewer than three messages.
    """
    conversation = example["messages"]
    prompt_part, final_turn = conversation[:2], conversation[2]
    return dict(prompt=prompt_part, completion=[final_turn])
# Training and validation splits are converted to prompt/completion rows (all
# original columns dropped). The test split keeps its raw columns, because the
# evaluation section reads "messages" and "icd_code" from it directly.
train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
test_ds = ds["test"]

print(f"Loading {MODEL_NAME}...")
# flash-attn is an optional compiled extension that frequently fails to
# install. Requesting "flash_attention_2" without it aborts the whole job, so
# fall back to PyTorch's built-in SDPA attention when the package is absent.
import importlib.util
if importlib.util.find_spec("flash_attn") is not None:
    attn_impl = "flash_attention_2"
else:
    attn_impl = "sdpa"
    print("flash-attn is not installed; falling back to attn_implementation='sdpa'")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16,
                                             attn_implementation=attn_impl, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Batched training needs a pad token; reuse EOS if the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA on every attention and MLP projection; the base weights stay frozen.
peft_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=0.05,
                         bias="none", task_type="CAUSAL_LM",
                         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                         "gate_proj", "up_proj", "down_proj"])
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="adamw_torch_fused",
    bf16=True,
    max_length=MAX_LENGTH,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=25,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
    run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
    # Evaluation and checkpointing share one cadence. load_best_model_at_end
    # requires the save schedule to line up with the eval schedule.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    # Tie the trainer's RNG (init, shuffling, dropout) to the script-level SEED
    # so that changing the constant actually changes/reproduces the run.
    seed=SEED,
)

trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_ds,
                     eval_dataset=val_ds, peft_config=peft_config, processing_class=tokenizer)
trainer.model.print_trainable_parameters()

result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")

# With load_best_model_at_end=True the trainer now holds the lowest-eval_loss
# checkpoint; save it (adapter + tokenizer) locally and push it to the Hub.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
trainer.push_to_hub()
# ---------------------------------------------------------------------------
# Evaluation: greedy-decode a fixed random sample of the test split with the
# saved model and score the first ICD-10-CM-shaped code in each completion at
# three granularities (exact code, 3-character category, first letter).
# ---------------------------------------------------------------------------
# Release the training objects before building the inference pipeline so
# their GPU memory can be reclaimed.
del trainer, model
gc.collect()
torch.cuda.empty_cache()

from transformers import pipeline as hf_pipeline

EVAL_MAX_NEW_TOKENS = 150
# Load in bf16 to match training; without an explicit dtype, from_pretrained
# may load float32 weights (version-dependent), doubling memory use.
pipe = hf_pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer,
                   device_map="auto", max_new_tokens=EVAL_MAX_NEW_TOKENS,
                   model_kwargs={"dtype": torch.bfloat16})

# ICD-10-CM code shape: a letter, a digit, then an alphanumeric character (the
# third position may be a letter, e.g. "C7A", "M1A"). These may be followed by
# up to four more alphanumeric characters, which covers the placeholder "X"
# and 7th-character extensions such as "T36.0X1A". The dot after the third
# character is optional, so both "E11.9" and "E119" are recognised.
ICD10_CODE_RE = re.compile(r"\b([A-Z]\d[0-9A-Z](?:\.?[0-9A-Z]{1,4})?)\b")


def _extract_code(text):
    """Return the first ICD-10-CM-shaped code found in *text*, or None."""
    match = ICD10_CODE_RE.search(text)
    return match.group(1) if match else None


def _normalize_code(code):
    """Return *code* as an upper-cased string without whitespace or dots.

    The dot in an ICD-10-CM code is formatting only, so "E11.9" and "E119"
    compare equal after normalisation.
    """
    return str(code).strip().upper().replace(".", "")


# A dedicated RNG makes the evaluated subset identical across runs, however
# much of the global `random` state training happened to consume.
eval_size = min(2000, len(test_ds))
eval_indices = random.Random(SEED).sample(range(len(test_ds)), eval_size)

correct_exact = correct_category = correct_chapter = total = 0
n_unparsed = n_errors = 0
results = []
for idx in eval_indices:
    example = test_ds[idx]
    gt_code = example["icd_code"]
    # Every sampled example counts in the denominator, so generation failures
    # and unparsable outputs score as wrong rather than being skipped.
    total += 1
    try:
        out = pipe(example["messages"][:2], max_new_tokens=EVAL_MAX_NEW_TOKENS, do_sample=False)
        generated = out[0]["generated_text"][-1]["content"]
    except Exception as err:  # keep evaluating past a single bad sample
        n_errors += 1
        if n_errors <= 5:
            print(f"Generation failed for test index {idx}: {err!r}")
        results.append({"idx": idx, "gt": gt_code, "pred": None, "exact": False,
                        "error": repr(err)})
        continue

    pred = _extract_code(generated)
    exact = False
    if pred is None:
        # No code-like token at all: wrong at every granularity. (A placeholder
        # string would spuriously match ground truths sharing its first letter.)
        n_unparsed += 1
    else:
        pred_norm, gt_norm = _normalize_code(pred), _normalize_code(gt_code)
        exact = pred_norm == gt_norm
        if exact:
            correct_exact += 1
        if pred_norm[:3] == gt_norm[:3]:
            correct_category += 1
        # First letter only approximates the chapter (some letters span
        # several chapters, e.g. D and H); kept for metric continuity.
        if pred_norm[:1] == gt_norm[:1]:
            correct_chapter += 1
    results.append({"idx": idx, "gt": gt_code, "pred": pred, "exact": exact})

denom = max(total, 1)
exact_acc = correct_exact / denom * 100
cat_acc = correct_category / denom * 100
ch_acc = correct_chapter / denom * 100
print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
print(f"Evaluated {total} samples: {n_unparsed} without a parsable code, "
      f"{n_errors} generation errors")
trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc,
             "eval/unparsed": n_unparsed, "eval/generation_errors": n_errors})

# Keep per-sample predictions next to the saved model for error analysis.
with open(os.path.join(OUTPUT_DIR, "eval_results.json"), "w", encoding="utf-8") as fh:
    json.dump({"exact_match": exact_acc, "category": cat_acc, "chapter": ch_acc,
               "total": total, "unparsed": n_unparsed, "generation_errors": n_errors,
               "samples": results}, fh, indent=2, default=str)
trackio.finish()