Text Generation
PEFT
Safetensors
Transformers
English
medical
icd-10
clinical-coding
healthcare
lora
sft
trl
conversational
"""
ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production)
================================================================
Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records
for ICD-10-CM code classification from clinical text.

Requirements:
  pip install torch transformers trl peft datasets trackio accelerate flash-attn

Hardware: A10G (24GB) or better. Training time: ~2-3 hours.

Based on:
- Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding
- Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy
- FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes)
- TRL SFTTrainer with prompt/completion format (loss on codes only)

To run:
  # On HF Jobs (A10G):
  hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \
    --deps torch transformers trl peft datasets trackio accelerate flash-attn
  
  # Or locally with GPU:
  pip install torch transformers trl peft datasets trackio accelerate flash-attn
  python train_icd10_gpu.py
"""
import os, re, json, random, gc
from collections import Counter
import torch, trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
DATASET_ID = "Rakshithch/icd10cm-clinical-coding-sft"
OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"
LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRAD_ACCUM = 8
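# Effective batch size = BATCH_SIZE * GRAD_ACCUM = 32 sequences per optimizer step.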
MAX_LENGTH = 1024
LORA_R = 16
LORA_ALPHA = 32
SEED = 42
random.seed(SEED)

trackio.init(project="icd10-clinical-coding", name="qwen2.5-1.5b-lora-r16-full",
    config={"model": MODEL_NAME, "dataset": DATASET_ID, "lora_r": LORA_R,
            "lr": LEARNING_RATE, "epochs": NUM_EPOCHS, "eff_batch": BATCH_SIZE*GRAD_ACCUM})

print("Loading dataset...")
ds = load_dataset(DATASET_ID)
print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")

def to_pc(example):
    msgs = example["messages"]
    return {"prompt": msgs[:2], "completion": [msgs[2]]}

train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
test_ds = ds["test"]
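# Test split is left unmapped so the eval loop below can read the raw "messages" and "icd_code" fields.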

print(f"Loading {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16,
    attn_implementation="flash_attention_2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
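# Some tokenizers ship without a dedicated pad token; reuse EOS so padded batches work.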

peft_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=0.05,
    bias="none", task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])

training_args = SFTConfig(
    output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM, learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine", warmup_steps=100, optim="adamw_torch_fused",
    bf16=True, max_length=MAX_LENGTH, gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=25, logging_first_step=True, disable_tqdm=True,
    report_to="trackio", run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
    eval_strategy="steps", eval_steps=500, save_strategy="steps", save_steps=500,
    save_total_limit=3, load_best_model_at_end=True, metric_for_best_model="eval_loss",
    push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save",
)
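# Evaluate and checkpoint every 500 steps, keep at most 3 checkpoints, reload the
# lowest-eval-loss checkpoint at the end, and push every save to HUB_MODEL_ID.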

trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_ds,
    eval_dataset=val_ds, peft_config=peft_config, processing_class=tokenizer)
trainer.model.print_trainable_parameters()
result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
trainer.save_model(OUTPUT_DIR); tokenizer.save_pretrained(OUTPUT_DIR)
trainer.push_to_hub()

# Evaluation
del trainer, model; gc.collect(); torch.cuda.empty_cache()
from transformers import pipeline as hf_pipeline
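# OUTPUT_DIR holds only the LoRA adapter; with peft installed, recent transformers versions can
# load it directly in pipeline() (the base model is resolved via adapter_config.json).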
pipe = hf_pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer,
    device_map="auto", max_new_tokens=150)

eval_size = min(2000, len(test_ds))
eval_indices = random.sample(range(len(test_ds)), eval_size)
correct_exact = correct_category = correct_chapter = total = 0
results = []
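# Score three tiers per example: exact code match, 3-character category match, and a rough
# chapter proxy based on the code's leading letter.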

for idx in eval_indices:
    example = test_ds[idx]
    gt_code = example["icd_code"]
    try:
        out = pipe(example["messages"][:2], max_new_tokens=150, do_sample=False)
        generated = out[0]["generated_text"][-1]["content"]
    except Exception:
        total += 1
        continue
    pred_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', generated)
    pred = pred_codes[0] if pred_codes else "NONE"
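    # NOTE: the pattern above captures letter + two digits, an optional ".digits" block and one
    # trailing letter; codes with a placeholder letter mid-code (e.g. T36.0X1A) only match their
    # 3-character category, which understates exact-match accuracy for those codes.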
    exact = pred == gt_code
    if exact: correct_exact += 1
    if pred[:3] == gt_code[:3]: correct_category += 1
    if pred != "NONE" and pred[0] == gt_code[0]: correct_chapter += 1
    total += 1
    results.append({"gt": gt_code, "pred": pred, "exact": exact})

exact_acc = correct_exact/max(total,1)*100
cat_acc = correct_category/max(total,1)*100
ch_acc = correct_chapter/max(total,1)*100
print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc})
trackio.finish()