| """ |
| ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production) |
| ================================================================ |
| Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records |
| for ICD-10-CM code classification from clinical text. |
| |
| Requirements: |
| pip install torch transformers trl peft datasets trackio accelerate flash-attn |
| |
| Hardware: A10G (24GB) or better. Training time: ~2-3 hours. |
| |
| Based on: |
| - Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding |
| - Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy |
| - FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes) |
| - TRL SFTTrainer with prompt/completion format (loss on codes only) |
| |
| To run: |
| # On HF Jobs (A10G): |
| hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \ |
| --deps torch transformers trl peft datasets trackio accelerate flash-attn |
| |
| # Or locally with GPU: |
| pip install torch transformers trl peft datasets trackio accelerate flash-attn |
| python train_icd10_gpu.py |
| """ |
import re, json, random, gc
import torch, trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer


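# Hyperparameters: LoRA rank 16 / alpha 32; effective batch size = BATCH_SIZE * GRAD_ACCUM = 32.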
| MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" |
| HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder" |
| DATASET_ID = "Rakshithch/icd10cm-clinical-coding-sft" |
| OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora" |
| LEARNING_RATE = 2e-4 |
| NUM_EPOCHS = 3 |
| BATCH_SIZE = 4 |
| GRAD_ACCUM = 8 |
| MAX_LENGTH = 1024 |
| LORA_R = 16 |
| LORA_ALPHA = 32 |
| SEED = 42 |
| random.seed(SEED) |
|
|
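# Initialize a Trackio run so hyperparameters and training curves are tracked together.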
trackio.init(project="icd10-clinical-coding", name="qwen2.5-1.5b-lora-r16-full",
             config={"model": MODEL_NAME, "dataset": DATASET_ID, "lora_r": LORA_R,
                     "lr": LEARNING_RATE, "epochs": NUM_EPOCHS, "eff_batch": BATCH_SIZE*GRAD_ACCUM})


print("Loading dataset...")
ds = load_dataset(DATASET_ID)
print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")


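# Convert each chat example to TRL's prompt/completion format:
# prompt = [system, user] messages, completion = [assistant] message with the ICD-10-CM codes,
# so the loss is computed on the codes only.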
def to_pc(example):
    msgs = example["messages"]
    return {"prompt": msgs[:2], "completion": [msgs[2]]}


train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
test_ds = ds["test"]


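# Load the base model in bfloat16 with FlashAttention-2 (requires an Ampere-or-newer GPU).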
| print(f"Loading {MODEL_NAME}...") |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16, |
| attn_implementation="flash_attention_2", device_map="auto") |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token |
|
|
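# LoRA adapters on all attention and MLP projection layers.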
peft_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=0.05,
                         bias="none", task_type="CAUSAL_LM",
                         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])


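# Training configuration: cosine LR schedule with warmup, bf16, gradient checkpointing,
# eval/checkpoint every 500 steps, best checkpoint (by eval loss) restored at the end,
# and every checkpoint pushed to the Hub.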
training_args = SFTConfig(
    output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM, learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine", warmup_steps=100, optim="adamw_torch_fused",
    bf16=True, max_length=MAX_LENGTH, gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=25, logging_first_step=True, disable_tqdm=True,
    report_to="trackio", run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
    eval_strategy="steps", eval_steps=500, save_strategy="steps", save_steps=500,
    save_total_limit=3, load_best_model_at_end=True, metric_for_best_model="eval_loss",
    push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save",
)


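# SFTTrainer applies the chat template and, for prompt/completion datasets, computes the loss
# on completion tokens only, so the model is trained to emit just the ICD-10-CM codes.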
trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_ds,
                     eval_dataset=val_ds, peft_config=peft_config, processing_class=tokenizer)
trainer.model.print_trainable_parameters()
result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
trainer.push_to_hub()


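# Evaluation: free GPU memory, then reload the fine-tuned adapter through a text-generation pipeline.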
del trainer, model
gc.collect()
torch.cuda.empty_cache()
from transformers import pipeline as hf_pipeline
pipe = hf_pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer,
                   device_map="auto", max_new_tokens=150)


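# Score a random sample of up to 2,000 test examples on three metrics:
# exact code match, category match (first 3 characters), and chapter match (first letter).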
eval_size = min(2000, len(test_ds))
eval_indices = random.sample(range(len(test_ds)), eval_size)
correct_exact = correct_category = correct_chapter = total = 0
results = []


for idx in eval_indices:
    example = test_ds[idx]
    gt_code = example["icd_code"]
    try:
        out = pipe(example["messages"][:2], max_new_tokens=150, do_sample=False)
        generated = out[0]["generated_text"][-1]["content"]
    except Exception:
        total += 1
        continue
    # ICD-10-CM pattern: a letter, a digit, an alphanumeric, then optionally a dot and up to
    # four alphanumerics (covers placeholder 'X' and 7th characters, e.g. S52.521A, T36.0X1A).
    pred_codes = re.findall(r'\b([A-Z]\d[0-9A-Z](?:\.[0-9A-Z]{1,4})?)\b', generated)
    pred = pred_codes[0] if pred_codes else "NONE"
    exact = pred == gt_code
    if exact: correct_exact += 1
    if pred != "NONE" and pred[:3] == gt_code[:3]: correct_category += 1
    if pred != "NONE" and pred[0] == gt_code[0]: correct_chapter += 1
    total += 1
    results.append({"gt": gt_code, "pred": pred, "exact": exact})


exact_acc = correct_exact/max(total,1)*100
cat_acc = correct_category/max(total,1)*100
ch_acc = correct_chapter/max(total,1)*100
print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc})
trackio.finish()