"""
ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production)
================================================================
Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records
for ICD-10-CM code classification from clinical text.
Requirements:
    pip install torch transformers trl peft datasets trackio accelerate flash-attn

Hardware: A10G (24GB) or better. Training time: ~2-3 hours.

Based on:
- Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding
- Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy
- FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes)
- TRL SFTTrainer with prompt/completion format (loss on codes only)

To run:
    # On HF Jobs (A10G):
    hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \
        --deps torch transformers trl peft datasets trackio accelerate flash-attn

    # Or locally with GPU:
    pip install torch transformers trl peft datasets trackio accelerate flash-attn
    python train_icd10_gpu.py
"""
import gc
import json
import random
import re

import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
# Configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
DATASET_ID = "Rakshithch/icd10cm-clinical-coding-sft"
OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"

LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
BATCH_SIZE = 4   # per device
GRAD_ACCUM = 8   # effective batch size = BATCH_SIZE * GRAD_ACCUM = 32
MAX_LENGTH = 1024
LORA_R = 16
LORA_ALPHA = 32
SEED = 42
random.seed(SEED)
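# random.seed covers the eval sampling below; seeding torch as well keeps any
# non-Trainer randomness reproducible (SFTConfig's own seed defaults to 42).
torch.manual_seed(SEED)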
trackio.init(
    project="icd10-clinical-coding",
    name="qwen2.5-1.5b-lora-r16-full",
    config={"model": MODEL_NAME, "dataset": DATASET_ID, "lora_r": LORA_R,
            "lr": LEARNING_RATE, "epochs": NUM_EPOCHS,
            "eff_batch": BATCH_SIZE * GRAD_ACCUM},
)
print("Loading dataset...")
ds = load_dataset(DATASET_ID)
print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")
def to_pc(example):
    # Split each 3-turn chat (system, user, assistant) into TRL's
    # prompt/completion format, so loss is computed only on the completion
    # (the ICD codes), not on the clinical note itself.
    msgs = example["messages"]
    return {"prompt": msgs[:2], "completion": [msgs[2]]}
train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
test_ds = ds["test"]
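# Quick sanity check (a minimal sketch; assumes the dataset's messages use
# the standard "role"/"content" keys): confirm one mapped row has the
# expected prompt/completion shape before committing to a multi-hour run.
sample = train_ds[0]
assert len(sample["prompt"]) == 2 and len(sample["completion"]) == 1
print("Prompt roles:", [m["role"] for m in sample["prompt"]])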
print(f"Loading {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
    r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=0.05,
    bias="none", task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
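# With r=16 across all attention and MLP projections, the adapter trains on
# the order of 1% of the 1.5B base weights (the exact count is printed by
# print_trainable_parameters() below); the base model stays frozen in bf16.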
training_args = SFTConfig(
    output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM, learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine", warmup_steps=100, optim="adamw_torch_fused",
    bf16=True, max_length=MAX_LENGTH, gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=25, logging_first_step=True, disable_tqdm=True,
    report_to="trackio", run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
    eval_strategy="steps", eval_steps=500, save_strategy="steps", save_steps=500,
    save_total_limit=3, load_best_model_at_end=True, metric_for_best_model="eval_loss",
    push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save",
)
trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_ds,
                     eval_dataset=val_ds, peft_config=peft_config,
                     processing_class=tokenizer)
trainer.model.print_trainable_parameters()
result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
trainer.push_to_hub()
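# Later, the pushed adapter can be applied on top of the base model
# (a minimal sketch, assuming peft is installed in the serving environment):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
#   merged = PeftModel.from_pretrained(base, HUB_MODEL_ID).merge_and_unload()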
# Evaluation: free the training objects, then reload the saved adapter
# through a fresh text-generation pipeline.
del trainer, model
gc.collect()
torch.cuda.empty_cache()

from transformers import pipeline as hf_pipeline

pipe = hf_pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer,
                   device_map="auto", max_new_tokens=150)
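# Smoke test on a single held-out note before the full sweep (a minimal
# sanity check; reuses the same chat-format call as the loop below).
smoke = pipe(test_ds[0]["messages"][:2], max_new_tokens=150, do_sample=False)
print("Smoke test output:", smoke[0]["generated_text"][-1]["content"][:200])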
eval_size = min(2000, len(test_ds))
eval_indices = random.sample(range(len(test_ds)), eval_size)

# Three accuracy tiers: exact code match, 3-character category match,
# and first-letter (chapter-level) match.
correct_exact = correct_category = correct_chapter = total = 0
results = []
for idx in eval_indices:
    example = test_ds[idx]
    gt_code = example["icd_code"]
    try:
        out = pipe(example["messages"][:2], max_new_tokens=150, do_sample=False)
        generated = out[0]["generated_text"][-1]["content"]
    except Exception:
        total += 1
        continue
    # ICD-10-CM codes: a letter, two alphanumerics, then up to four
    # alphanumerics after the dot (covers codes like S52.501A or T36.0X1A).
    pred_codes = re.findall(r'\b([A-Z]\d[0-9A-Z](?:\.[0-9A-Z]{1,4})?)\b', generated)
    pred = pred_codes[0] if pred_codes else ""
    exact = pred == gt_code
    if exact: correct_exact += 1
    if pred[:3] == gt_code[:3]: correct_category += 1
    if pred and pred[0] == gt_code[0]: correct_chapter += 1
    total += 1
    results.append({"gt": gt_code, "pred": pred, "exact": exact})
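# Persist per-example predictions for offline error analysis (the filename
# "eval_results.json" is an arbitrary local choice).
with open("eval_results.json", "w") as f:
    json.dump(results, f, indent=2)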
exact_acc = correct_exact/max(total,1)*100
cat_acc = correct_category/max(total,1)*100
ch_acc = correct_chapter/max(total,1)*100
print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc})
trackio.finish()