Text Generation
PEFT
Safetensors
Transformers
English
medical
icd-10
clinical-coding
healthcare
lora
sft
trl
conversational
Rakshithch committed on
Commit
df41e97
·
verified ·
1 Parent(s): 4046e5e

Update GPU training script with all API fixes and comprehensive evaluation

Browse files
Files changed (1) hide show
  1. train_icd10_gpu.py +87 -333
train_icd10_gpu.py CHANGED
@@ -1,375 +1,129 @@
1
  """
2
- ICD-10-CM Clinical Coding Fine-tuning Script
3
- =============================================
4
- Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on synthetic EHR data
5
  for ICD-10-CM code classification from clinical text.
6
 
 
 
 
 
 
7
  Based on:
8
- - Recipe 3 from literature review (Lenz et al., arxiv:2510.13624)
9
- - FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 codes)
 
10
  - TRL SFTTrainer with prompt/completion format (loss on codes only)
11
- """
12
 
13
- import os
14
- import re
15
- import json
16
- import random
17
- import numpy as np
 
 
 
 
 
18
  from collections import Counter
19
-
20
- import torch
21
- import trackio
22
- from datasets import load_dataset, Dataset
23
  from peft import LoraConfig
24
  from transformers import AutoModelForCausalLM, AutoTokenizer
25
  from trl import SFTConfig, SFTTrainer
26
 
27
- # ============================================================================
28
- # Configuration
29
- # ============================================================================
30
  MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
31
  HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
32
- DATASET_NAME = "FiscaAI/synth-ehr-icd10cm-prompt"
33
  OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"
34
-
35
- # Training hyperparameters (from literature: LoRA SFT recipe)
36
- LEARNING_RATE = 2e-4 # LoRA ~10x base LR
37
  NUM_EPOCHS = 3
38
  BATCH_SIZE = 4
39
- GRAD_ACCUM = 8 # effective batch = 32
40
- MAX_SEQ_LENGTH = 1024 # P95 of user+assistant text fits in ~512 tokens
41
  LORA_R = 16
42
  LORA_ALPHA = 32
43
-
44
- # Data splits
45
- TRAIN_SIZE = 0.90
46
- VAL_SIZE = 0.05
47
- TEST_SIZE = 0.05
48
-
49
  SEED = 42
 
50
 
51
- # ============================================================================
52
- # Initialize trackio
53
- # ============================================================================
54
- trackio.init(
55
- project="icd10-clinical-coding",
56
- name="qwen2.5-1.5b-lora-r16-full",
57
- config={
58
- "model": MODEL_NAME,
59
- "dataset": DATASET_NAME,
60
- "lora_r": LORA_R,
61
- "lora_alpha": LORA_ALPHA,
62
- "lr": LEARNING_RATE,
63
- "epochs": NUM_EPOCHS,
64
- "batch_size": BATCH_SIZE,
65
- "grad_accum": GRAD_ACCUM,
66
- "max_seq_length": MAX_SEQ_LENGTH,
67
- },
68
- )
69
 
70
- # ============================================================================
71
- # 1. Load and prepare dataset
72
- # ============================================================================
73
- print("=" * 70)
74
  print("Loading dataset...")
75
- print("=" * 70)
76
-
77
- raw_ds = load_dataset(DATASET_NAME, split="train")
78
- print(f"Total rows: {len(raw_ds)}")
79
 
80
- # Remove empty/null user fields
81
- raw_ds = raw_ds.filter(lambda x: x["user"] and x["user"].strip() != "")
82
- print(f"After filtering empties: {len(raw_ds)}")
83
-
84
- # Improved system prompt for ICD-10-CM coding in healthcare claims context
85
- SYSTEM_PROMPT = (
86
- "You are an expert medical coder specializing in ICD-10-CM coding for "
87
- "healthcare claims processing (X12 EDI 837 format). Given a clinical "
88
- "note or symptom description, identify the correct ICD-10-CM diagnosis "
89
- "code. Provide the code followed by a brief explanation."
90
- )
91
-
92
- def format_to_prompt_completion(example):
93
- """Convert to prompt/completion format for loss on completion only."""
94
- prompt = [
95
- {"role": "system", "content": SYSTEM_PROMPT},
96
- {"role": "user", "content": example["user"]},
97
- ]
98
- # Extract just the ICD code and explanation from assistant
99
- completion = [
100
- {"role": "assistant", "content": example["assistant"]},
101
- ]
102
- return {"prompt": prompt, "completion": completion}
103
-
104
- print("Formatting dataset to prompt/completion...")
105
- formatted_ds = raw_ds.map(
106
- format_to_prompt_completion,
107
- remove_columns=raw_ds.column_names,
108
- num_proc=4,
109
- desc="Formatting",
110
- )
111
 
112
- # Split into train/val/test
113
- print("Splitting dataset...")
114
- ds_split = formatted_ds.train_test_split(test_size=(VAL_SIZE + TEST_SIZE), seed=SEED)
115
- val_test = ds_split["test"].train_test_split(test_size=TEST_SIZE / (VAL_SIZE + TEST_SIZE), seed=SEED)
116
 
117
- train_ds = ds_split["train"]
118
- val_ds = val_test["train"]
119
- test_ds = val_test["test"]
120
-
121
- print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
122
-
123
- # ============================================================================
124
- # 2. Model & LoRA setup
125
- # ============================================================================
126
- print("\n" + "=" * 70)
127
- print("Loading model...")
128
- print("=" * 70)
129
-
130
- model = AutoModelForCausalLM.from_pretrained(
131
- MODEL_NAME,
132
- dtype=torch.bfloat16,
133
- attn_implementation="flash_attention_2",
134
- device_map="auto",
135
- )
136
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
137
 
138
- if tokenizer.pad_token is None:
139
- tokenizer.pad_token = tokenizer.eos_token
140
-
141
- print(f"Model loaded: {MODEL_NAME}")
142
- print(f"Model dtype: {model.dtype}")
143
- print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
144
-
145
- peft_config = LoraConfig(
146
- r=LORA_R,
147
- lora_alpha=LORA_ALPHA,
148
- lora_dropout=0.05,
149
- bias="none",
150
- task_type="CAUSAL_LM",
151
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
152
- "gate_proj", "up_proj", "down_proj"], # all attention + MLP
153
- )
154
-
155
- # ============================================================================
156
- # 3. Training
157
- # ============================================================================
158
- print("\n" + "=" * 70)
159
- print("Setting up training...")
160
- print("=" * 70)
161
 
162
  training_args = SFTConfig(
163
- output_dir=OUTPUT_DIR,
164
- num_train_epochs=NUM_EPOCHS,
165
- per_device_train_batch_size=BATCH_SIZE,
166
- per_device_eval_batch_size=BATCH_SIZE,
167
- gradient_accumulation_steps=GRAD_ACCUM,
168
- learning_rate=LEARNING_RATE,
169
- lr_scheduler_type="cosine",
170
- warmup_ratio=0.05,
171
- optim="adamw_torch_fused",
172
- bf16=True,
173
- max_length=MAX_SEQ_LENGTH,
174
- gradient_checkpointing=True,
175
  gradient_checkpointing_kwargs={"use_reentrant": False},
176
-
177
- # Logging
178
- logging_steps=25,
179
- logging_first_step=True,
180
- disable_tqdm=True,
181
- report_to="trackio",
182
- run_name="qwen2.5-1.5b-icd10cm-lora-r16",
183
-
184
- # Evaluation
185
- eval_strategy="steps",
186
- eval_steps=500,
187
- save_strategy="steps",
188
- save_steps=500,
189
- save_total_limit=3,
190
- load_best_model_at_end=True,
191
- metric_for_best_model="eval_loss",
192
-
193
- # Push to Hub
194
- push_to_hub=True,
195
- hub_model_id=HUB_MODEL_ID,
196
- hub_strategy="every_save",
197
  )
198
 
199
- trainer = SFTTrainer(
200
- model=model,
201
- args=training_args,
202
- train_dataset=train_ds,
203
- eval_dataset=val_ds,
204
- peft_config=peft_config,
205
- processing_class=tokenizer,
206
- )
207
-
208
- print(f"Trainable parameters: {trainer.model.print_trainable_parameters()}")
209
- print(f"\nStarting training for {NUM_EPOCHS} epochs...")
210
- train_result = trainer.train()
211
-
212
- print("\n" + "=" * 70)
213
- print("Training complete!")
214
- print(f"Train loss: {train_result.training_loss:.4f}")
215
- print("=" * 70)
216
-
217
- # Save final model
218
- trainer.save_model(OUTPUT_DIR)
219
  trainer.push_to_hub()
220
 
221
- # ============================================================================
222
- # 4. Evaluation on test set
223
- # ============================================================================
224
- print("\n" + "=" * 70)
225
- print("Evaluating on test set...")
226
- print("=" * 70)
227
-
228
- from transformers import pipeline
229
-
230
- # Load fine-tuned model for inference
231
- pipe = pipeline(
232
- "text-generation",
233
- model=OUTPUT_DIR,
234
- tokenizer=tokenizer,
235
- device_map="auto",
236
- max_new_tokens=128,
237
- )
238
 
239
- # Evaluation metrics
240
- correct_exact = 0
241
- correct_partial = 0
242
- correct_chapter = 0
243
- correct_category = 0 # first 3 chars (e.g., J18)
244
- total = 0
245
- results = []
246
-
247
- # Sample test set for evaluation (max 2000 for speed)
248
  eval_size = min(2000, len(test_ds))
249
  eval_indices = random.sample(range(len(test_ds)), eval_size)
 
 
250
 
251
- print(f"Evaluating on {eval_size} test examples...")
252
-
253
- for idx, i in enumerate(eval_indices):
254
- example = test_ds[i]
255
-
256
- # Build the prompt messages
257
- messages = example["prompt"]
258
-
259
- # Generate
260
- output = pipe(messages, max_new_tokens=128, do_sample=False, temperature=None)
261
- generated = output[0]["generated_text"][-1]["content"]
262
-
263
- # Extract predicted ICD code from generated text
264
- # Pattern: look for ICD-10-CM code format (letter + digits + optional dot + more chars)
265
  pred_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', generated)
266
-
267
- # Extract ground truth code from completion
268
- gt_text = example["completion"][0]["content"]
269
- gt_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', gt_text)
270
-
271
- if gt_codes and pred_codes:
272
- gt_code = gt_codes[0]
273
- pred_code = pred_codes[0]
274
-
275
- # Exact match
276
- if pred_code == gt_code:
277
- correct_exact += 1
278
-
279
- # Partial match (code without laterality suffix)
280
- gt_base = gt_code.split('.')[0] + ('.' + gt_code.split('.')[1][:2] if '.' in gt_code else '')
281
- pred_base = pred_code.split('.')[0] + ('.' + pred_code.split('.')[1][:2] if '.' in pred_code else '')
282
- if pred_base == gt_base:
283
- correct_partial += 1
284
-
285
- # Category match (first 3 chars, e.g., J18, M24)
286
- if pred_code[:3] == gt_code[:3]:
287
- correct_category += 1
288
-
289
- # Chapter match (first letter)
290
- if pred_code[0] == gt_code[0]:
291
- correct_chapter += 1
292
-
293
- results.append({
294
- "gt_code": gt_code,
295
- "pred_code": pred_code,
296
- "exact_match": pred_code == gt_code,
297
- "category_match": pred_code[:3] == gt_code[:3],
298
- })
299
- else:
300
- results.append({
301
- "gt_code": gt_codes[0] if gt_codes else "NONE",
302
- "pred_code": pred_codes[0] if pred_codes else "NONE",
303
- "exact_match": False,
304
- "category_match": False,
305
- })
306
-
307
  total += 1
 
308
 
309
- if (idx + 1) % 200 == 0:
310
- print(f" Evaluated {idx+1}/{eval_size} | "
311
- f"Exact: {correct_exact/total*100:.1f}% | "
312
- f"Category: {correct_category/total*100:.1f}%")
313
-
314
- # Final metrics
315
- print("\n" + "=" * 70)
316
- print("EVALUATION RESULTS")
317
- print("=" * 70)
318
- exact_acc = correct_exact / total * 100
319
- partial_acc = correct_partial / total * 100
320
- category_acc = correct_category / total * 100
321
- chapter_acc = correct_chapter / total * 100
322
-
323
- print(f" Exact Match Accuracy: {exact_acc:.2f}% ({correct_exact}/{total})")
324
- print(f" Partial Match Accuracy: {partial_acc:.2f}% ({correct_partial}/{total})")
325
- print(f" Category (3-char) Acc: {category_acc:.2f}% ({correct_category}/{total})")
326
- print(f" Chapter (1st letter): {chapter_acc:.2f}% ({correct_chapter}/{total})")
327
-
328
- # Log to trackio
329
- trackio.log({
330
- "eval/exact_match_accuracy": exact_acc,
331
- "eval/partial_match_accuracy": partial_acc,
332
- "eval/category_accuracy": category_acc,
333
- "eval/chapter_accuracy": chapter_acc,
334
- "eval/total_samples": total,
335
- })
336
-
337
- # Error analysis: which chapters have lowest accuracy
338
- print("\n--- Per-Chapter Accuracy ---")
339
- chapter_stats = {}
340
- for r in results:
341
- ch = r["gt_code"][0] if r["gt_code"] != "NONE" else "?"
342
- if ch not in chapter_stats:
343
- chapter_stats[ch] = {"total": 0, "correct": 0}
344
- chapter_stats[ch]["total"] += 1
345
- if r["exact_match"]:
346
- chapter_stats[ch]["correct"] += 1
347
-
348
- for ch in sorted(chapter_stats.keys()):
349
- s = chapter_stats[ch]
350
- acc = s["correct"] / s["total"] * 100 if s["total"] > 0 else 0
351
- print(f" Chapter {ch}: {acc:.1f}% ({s['correct']}/{s['total']})")
352
-
353
- # Save results
354
- with open(os.path.join(OUTPUT_DIR, "eval_results.json"), "w") as f:
355
- json.dump({
356
- "exact_match_accuracy": exact_acc,
357
- "partial_match_accuracy": partial_acc,
358
- "category_accuracy": category_acc,
359
- "chapter_accuracy": chapter_acc,
360
- "total_evaluated": total,
361
- "per_chapter": chapter_stats,
362
- }, f, indent=2)
363
-
364
- # Sample predictions
365
- print("\n--- Sample Predictions ---")
366
- for r in results[:10]:
367
- status = "✅" if r["exact_match"] else ("🟡" if r["category_match"] else "❌")
368
- print(f" {status} GT: {r['gt_code']:<12} Pred: {r['pred_code']}")
369
-
370
  trackio.finish()
371
-
372
- print("\n" + "=" * 70)
373
- print(f"Model saved to Hub: https://hf.co/{HUB_MODEL_ID}")
374
- print(f"Training dashboard: trackio")
375
- print("=" * 70)
 
1
  """
2
+ ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production)
3
+ ================================================================
4
+ Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records
5
  for ICD-10-CM code classification from clinical text.
6
 
7
+ Requirements:
8
+ pip install torch transformers trl peft datasets trackio accelerate flash-attn
9
+
10
+ Hardware: A10G (24GB) or better. Training time: ~2-3 hours.
11
+
12
  Based on:
13
+ - Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding
14
+ - Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy
15
+ - FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes)
16
  - TRL SFTTrainer with prompt/completion format (loss on codes only)
 
17
 
18
+ To run:
19
+ # On HF Jobs (A10G):
20
+ hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \
21
+ --deps torch transformers trl peft datasets trackio accelerate flash-attn
22
+
23
+ # Or locally with GPU:
24
+ pip install torch transformers trl peft datasets trackio accelerate flash-attn
25
+ python train_icd10_gpu.py
26
+ """
27
+ import os, re, json, random, gc
28
  from collections import Counter
29
+ import torch, trackio
30
+ from datasets import load_dataset
 
 
31
  from peft import LoraConfig
32
  from transformers import AutoModelForCausalLM, AutoTokenizer
33
  from trl import SFTConfig, SFTTrainer
34
 
 
 
 
35
  MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
36
  HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
37
+ DATASET_ID = "Rakshithch/icd10cm-clinical-coding-sft"
38
  OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"
39
+ LEARNING_RATE = 2e-4
 
 
40
  NUM_EPOCHS = 3
41
  BATCH_SIZE = 4
42
+ GRAD_ACCUM = 8
43
+ MAX_LENGTH = 1024
44
  LORA_R = 16
45
  LORA_ALPHA = 32
 
 
 
 
 
 
46
  SEED = 42
47
+ random.seed(SEED)
48
 
49
+ trackio.init(project="icd10-clinical-coding", name="qwen2.5-1.5b-lora-r16-full",
50
+ config={"model": MODEL_NAME, "dataset": DATASET_ID, "lora_r": LORA_R,
51
+ "lr": LEARNING_RATE, "epochs": NUM_EPOCHS, "eff_batch": BATCH_SIZE*GRAD_ACCUM})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
 
 
 
53
  print("Loading dataset...")
54
+ ds = load_dataset(DATASET_ID)
55
+ print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")
 
 
56
 
57
+ def to_pc(example):
58
+ msgs = example["messages"]
59
+ return {"prompt": msgs[:2], "completion": [msgs[2]]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
62
+ val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
63
+ test_ds = ds["test"]
 
64
 
65
+ print(f"Loading {MODEL_NAME}...")
66
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16,
67
+ attn_implementation="flash_attention_2", device_map="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
69
+ if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
70
 
71
+ peft_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=0.05,
72
+ bias="none", task_type="CAUSAL_LM",
73
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  training_args = SFTConfig(
76
+ output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS,
77
+ per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE,
78
+ gradient_accumulation_steps=GRAD_ACCUM, learning_rate=LEARNING_RATE,
79
+ lr_scheduler_type="cosine", warmup_steps=100, optim="adamw_torch_fused",
80
+ bf16=True, max_length=MAX_LENGTH, gradient_checkpointing=True,
 
 
 
 
 
 
 
81
  gradient_checkpointing_kwargs={"use_reentrant": False},
82
+ logging_steps=25, logging_first_step=True, disable_tqdm=True,
83
+ report_to="trackio", run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
84
+ eval_strategy="steps", eval_steps=500, save_strategy="steps", save_steps=500,
85
+ save_total_limit=3, load_best_model_at_end=True, metric_for_best_model="eval_loss",
86
+ push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
88
 
89
+ trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_ds,
90
+ eval_dataset=val_ds, peft_config=peft_config, processing_class=tokenizer)
91
+ trainer.model.print_trainable_parameters()
92
+ result = trainer.train()
93
+ print(f"Training loss: {result.training_loss:.4f}")
94
+ trainer.save_model(OUTPUT_DIR); tokenizer.save_pretrained(OUTPUT_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  trainer.push_to_hub()
96
 
97
+ # Evaluation
98
+ del trainer, model; gc.collect(); torch.cuda.empty_cache()
99
+ from transformers import pipeline as hf_pipeline
100
+ pipe = hf_pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer,
101
+ device_map="auto", max_new_tokens=150)
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
 
 
 
 
 
103
  eval_size = min(2000, len(test_ds))
104
  eval_indices = random.sample(range(len(test_ds)), eval_size)
105
+ correct_exact = correct_category = correct_chapter = total = 0
106
+ results = []
107
 
108
+ for idx in eval_indices:
109
+ example = test_ds[idx]
110
+ gt_code = example["icd_code"]
111
+ try:
112
+ out = pipe(example["messages"][:2], max_new_tokens=150, do_sample=False)
113
+ generated = out[0]["generated_text"][-1]["content"]
114
+ except: total += 1; continue
 
 
 
 
 
 
 
115
  pred_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', generated)
116
+ pred = pred_codes[0] if pred_codes else "NONE"
117
+ exact = pred == gt_code
118
+ if exact: correct_exact += 1
119
+ if pred[:3] == gt_code[:3]: correct_category += 1
120
+ if pred[0] == gt_code[0]: correct_chapter += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  total += 1
122
+ results.append({"gt": gt_code, "pred": pred, "exact": exact})
123
 
124
+ exact_acc = correct_exact/max(total,1)*100
125
+ cat_acc = correct_category/max(total,1)*100
126
+ ch_acc = correct_chapter/max(total,1)*100
127
+ print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
128
+ trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  trackio.finish()