Text Generation
PEFT
Safetensors
Transformers
English
medical
icd-10
clinical-coding
healthcare
lora
sft
trl
conversational
Rakshithch commited on
Commit
39600f7
·
verified ·
1 Parent(s): 124157b

Add GPU training script

Browse files
Files changed (1) hide show
  1. train_icd10_gpu.py +375 -0
train_icd10_gpu.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ICD-10-CM Clinical Coding Fine-tuning Script
3
+ =============================================
4
+ Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on synthetic EHR data
5
+ for ICD-10-CM code classification from clinical text.
6
+
7
+ Based on:
8
+ - Recipe 3 from literature review (Lenz et al., arxiv:2510.13624)
9
+ - FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 codes)
10
+ - TRL SFTTrainer with prompt/completion format (loss on codes only)
11
+ """
12
+
13
+ import os
14
+ import re
15
+ import json
16
+ import random
17
+ import numpy as np
18
+ from collections import Counter
19
+
20
+ import torch
21
+ import trackio
22
+ from datasets import load_dataset, Dataset
23
+ from peft import LoraConfig
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+ from trl import SFTConfig, SFTTrainer
26
+
27
+ # ============================================================================
28
+ # Configuration
29
+ # ============================================================================
30
+ MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
31
+ HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
32
+ DATASET_NAME = "FiscaAI/synth-ehr-icd10cm-prompt"
33
+ OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"
34
+
35
+ # Training hyperparameters (from literature: LoRA SFT recipe)
36
+ LEARNING_RATE = 2e-4 # LoRA ~10x base LR
37
+ NUM_EPOCHS = 3
38
+ BATCH_SIZE = 4
39
+ GRAD_ACCUM = 8 # effective batch = 32
40
+ MAX_SEQ_LENGTH = 1024 # P95 of user+assistant text fits in ~512 tokens
41
+ LORA_R = 16
42
+ LORA_ALPHA = 32
43
+
44
+ # Data splits
45
+ TRAIN_SIZE = 0.90
46
+ VAL_SIZE = 0.05
47
+ TEST_SIZE = 0.05
48
+
49
+ SEED = 42
50
+
51
+ # ============================================================================
52
+ # Initialize trackio
53
+ # ============================================================================
54
+ trackio.init(
55
+ project="icd10-clinical-coding",
56
+ name="qwen2.5-1.5b-lora-r16-full",
57
+ config={
58
+ "model": MODEL_NAME,
59
+ "dataset": DATASET_NAME,
60
+ "lora_r": LORA_R,
61
+ "lora_alpha": LORA_ALPHA,
62
+ "lr": LEARNING_RATE,
63
+ "epochs": NUM_EPOCHS,
64
+ "batch_size": BATCH_SIZE,
65
+ "grad_accum": GRAD_ACCUM,
66
+ "max_seq_length": MAX_SEQ_LENGTH,
67
+ },
68
+ )
69
+
70
+ # ============================================================================
71
+ # 1. Load and prepare dataset
72
+ # ============================================================================
73
+ print("=" * 70)
74
+ print("Loading dataset...")
75
+ print("=" * 70)
76
+
77
+ raw_ds = load_dataset(DATASET_NAME, split="train")
78
+ print(f"Total rows: {len(raw_ds)}")
79
+
80
+ # Remove empty/null user fields
81
+ raw_ds = raw_ds.filter(lambda x: x["user"] and x["user"].strip() != "")
82
+ print(f"After filtering empties: {len(raw_ds)}")
83
+
84
+ # Improved system prompt for ICD-10-CM coding in healthcare claims context
85
+ SYSTEM_PROMPT = (
86
+ "You are an expert medical coder specializing in ICD-10-CM coding for "
87
+ "healthcare claims processing (X12 EDI 837 format). Given a clinical "
88
+ "note or symptom description, identify the correct ICD-10-CM diagnosis "
89
+ "code. Provide the code followed by a brief explanation."
90
+ )
91
+
92
+ def format_to_prompt_completion(example):
93
+ """Convert to prompt/completion format for loss on completion only."""
94
+ prompt = [
95
+ {"role": "system", "content": SYSTEM_PROMPT},
96
+ {"role": "user", "content": example["user"]},
97
+ ]
98
+ # Extract just the ICD code and explanation from assistant
99
+ completion = [
100
+ {"role": "assistant", "content": example["assistant"]},
101
+ ]
102
+ return {"prompt": prompt, "completion": completion}
103
+
104
+ print("Formatting dataset to prompt/completion...")
105
+ formatted_ds = raw_ds.map(
106
+ format_to_prompt_completion,
107
+ remove_columns=raw_ds.column_names,
108
+ num_proc=4,
109
+ desc="Formatting",
110
+ )
111
+
112
+ # Split into train/val/test
113
+ print("Splitting dataset...")
114
+ ds_split = formatted_ds.train_test_split(test_size=(VAL_SIZE + TEST_SIZE), seed=SEED)
115
+ val_test = ds_split["test"].train_test_split(test_size=TEST_SIZE / (VAL_SIZE + TEST_SIZE), seed=SEED)
116
+
117
+ train_ds = ds_split["train"]
118
+ val_ds = val_test["train"]
119
+ test_ds = val_test["test"]
120
+
121
+ print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
122
+
123
+ # ============================================================================
124
+ # 2. Model & LoRA setup
125
+ # ============================================================================
126
+ print("\n" + "=" * 70)
127
+ print("Loading model...")
128
+ print("=" * 70)
129
+
130
+ model = AutoModelForCausalLM.from_pretrained(
131
+ MODEL_NAME,
132
+ dtype=torch.bfloat16,
133
+ attn_implementation="flash_attention_2",
134
+ device_map="auto",
135
+ )
136
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
137
+
138
+ if tokenizer.pad_token is None:
139
+ tokenizer.pad_token = tokenizer.eos_token
140
+
141
+ print(f"Model loaded: {MODEL_NAME}")
142
+ print(f"Model dtype: {model.dtype}")
143
+ print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
144
+
145
+ peft_config = LoraConfig(
146
+ r=LORA_R,
147
+ lora_alpha=LORA_ALPHA,
148
+ lora_dropout=0.05,
149
+ bias="none",
150
+ task_type="CAUSAL_LM",
151
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
152
+ "gate_proj", "up_proj", "down_proj"], # all attention + MLP
153
+ )
154
+
155
+ # ============================================================================
156
+ # 3. Training
157
+ # ============================================================================
158
+ print("\n" + "=" * 70)
159
+ print("Setting up training...")
160
+ print("=" * 70)
161
+
162
+ training_args = SFTConfig(
163
+ output_dir=OUTPUT_DIR,
164
+ num_train_epochs=NUM_EPOCHS,
165
+ per_device_train_batch_size=BATCH_SIZE,
166
+ per_device_eval_batch_size=BATCH_SIZE,
167
+ gradient_accumulation_steps=GRAD_ACCUM,
168
+ learning_rate=LEARNING_RATE,
169
+ lr_scheduler_type="cosine",
170
+ warmup_ratio=0.05,
171
+ optim="adamw_torch_fused",
172
+ bf16=True,
173
+ max_length=MAX_SEQ_LENGTH,
174
+ gradient_checkpointing=True,
175
+ gradient_checkpointing_kwargs={"use_reentrant": False},
176
+
177
+ # Logging
178
+ logging_steps=25,
179
+ logging_first_step=True,
180
+ disable_tqdm=True,
181
+ report_to="trackio",
182
+ run_name="qwen2.5-1.5b-icd10cm-lora-r16",
183
+
184
+ # Evaluation
185
+ eval_strategy="steps",
186
+ eval_steps=500,
187
+ save_strategy="steps",
188
+ save_steps=500,
189
+ save_total_limit=3,
190
+ load_best_model_at_end=True,
191
+ metric_for_best_model="eval_loss",
192
+
193
+ # Push to Hub
194
+ push_to_hub=True,
195
+ hub_model_id=HUB_MODEL_ID,
196
+ hub_strategy="every_save",
197
+ )
198
+
199
+ trainer = SFTTrainer(
200
+ model=model,
201
+ args=training_args,
202
+ train_dataset=train_ds,
203
+ eval_dataset=val_ds,
204
+ peft_config=peft_config,
205
+ processing_class=tokenizer,
206
+ )
207
+
208
+ print(f"Trainable parameters: {trainer.model.print_trainable_parameters()}")
209
+ print(f"\nStarting training for {NUM_EPOCHS} epochs...")
210
+ train_result = trainer.train()
211
+
212
+ print("\n" + "=" * 70)
213
+ print("Training complete!")
214
+ print(f"Train loss: {train_result.training_loss:.4f}")
215
+ print("=" * 70)
216
+
217
+ # Save final model
218
+ trainer.save_model(OUTPUT_DIR)
219
+ trainer.push_to_hub()
220
+
221
+ # ============================================================================
222
+ # 4. Evaluation on test set
223
+ # ============================================================================
224
+ print("\n" + "=" * 70)
225
+ print("Evaluating on test set...")
226
+ print("=" * 70)
227
+
228
+ from transformers import pipeline
229
+
230
+ # Load fine-tuned model for inference
231
+ pipe = pipeline(
232
+ "text-generation",
233
+ model=OUTPUT_DIR,
234
+ tokenizer=tokenizer,
235
+ device_map="auto",
236
+ max_new_tokens=128,
237
+ )
238
+
239
+ # Evaluation metrics
240
+ correct_exact = 0
241
+ correct_partial = 0
242
+ correct_chapter = 0
243
+ correct_category = 0 # first 3 chars (e.g., J18)
244
+ total = 0
245
+ results = []
246
+
247
+ # Sample test set for evaluation (max 2000 for speed)
248
+ eval_size = min(2000, len(test_ds))
249
+ eval_indices = random.sample(range(len(test_ds)), eval_size)
250
+
251
+ print(f"Evaluating on {eval_size} test examples...")
252
+
253
+ for idx, i in enumerate(eval_indices):
254
+ example = test_ds[i]
255
+
256
+ # Build the prompt messages
257
+ messages = example["prompt"]
258
+
259
+ # Generate
260
+ output = pipe(messages, max_new_tokens=128, do_sample=False, temperature=None)
261
+ generated = output[0]["generated_text"][-1]["content"]
262
+
263
+ # Extract predicted ICD code from generated text
264
+ # Pattern: look for ICD-10-CM code format (letter + digits + optional dot + more chars)
265
+ pred_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', generated)
266
+
267
+ # Extract ground truth code from completion
268
+ gt_text = example["completion"][0]["content"]
269
+ gt_codes = re.findall(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b', gt_text)
270
+
271
+ if gt_codes and pred_codes:
272
+ gt_code = gt_codes[0]
273
+ pred_code = pred_codes[0]
274
+
275
+ # Exact match
276
+ if pred_code == gt_code:
277
+ correct_exact += 1
278
+
279
+ # Partial match (code without laterality suffix)
280
+ gt_base = gt_code.split('.')[0] + ('.' + gt_code.split('.')[1][:2] if '.' in gt_code else '')
281
+ pred_base = pred_code.split('.')[0] + ('.' + pred_code.split('.')[1][:2] if '.' in pred_code else '')
282
+ if pred_base == gt_base:
283
+ correct_partial += 1
284
+
285
+ # Category match (first 3 chars, e.g., J18, M24)
286
+ if pred_code[:3] == gt_code[:3]:
287
+ correct_category += 1
288
+
289
+ # Chapter match (first letter)
290
+ if pred_code[0] == gt_code[0]:
291
+ correct_chapter += 1
292
+
293
+ results.append({
294
+ "gt_code": gt_code,
295
+ "pred_code": pred_code,
296
+ "exact_match": pred_code == gt_code,
297
+ "category_match": pred_code[:3] == gt_code[:3],
298
+ })
299
+ else:
300
+ results.append({
301
+ "gt_code": gt_codes[0] if gt_codes else "NONE",
302
+ "pred_code": pred_codes[0] if pred_codes else "NONE",
303
+ "exact_match": False,
304
+ "category_match": False,
305
+ })
306
+
307
+ total += 1
308
+
309
+ if (idx + 1) % 200 == 0:
310
+ print(f" Evaluated {idx+1}/{eval_size} | "
311
+ f"Exact: {correct_exact/total*100:.1f}% | "
312
+ f"Category: {correct_category/total*100:.1f}%")
313
+
314
+ # Final metrics
315
+ print("\n" + "=" * 70)
316
+ print("EVALUATION RESULTS")
317
+ print("=" * 70)
318
+ exact_acc = correct_exact / total * 100
319
+ partial_acc = correct_partial / total * 100
320
+ category_acc = correct_category / total * 100
321
+ chapter_acc = correct_chapter / total * 100
322
+
323
+ print(f" Exact Match Accuracy: {exact_acc:.2f}% ({correct_exact}/{total})")
324
+ print(f" Partial Match Accuracy: {partial_acc:.2f}% ({correct_partial}/{total})")
325
+ print(f" Category (3-char) Acc: {category_acc:.2f}% ({correct_category}/{total})")
326
+ print(f" Chapter (1st letter): {chapter_acc:.2f}% ({correct_chapter}/{total})")
327
+
328
+ # Log to trackio
329
+ trackio.log({
330
+ "eval/exact_match_accuracy": exact_acc,
331
+ "eval/partial_match_accuracy": partial_acc,
332
+ "eval/category_accuracy": category_acc,
333
+ "eval/chapter_accuracy": chapter_acc,
334
+ "eval/total_samples": total,
335
+ })
336
+
337
+ # Error analysis: which chapters have lowest accuracy
338
+ print("\n--- Per-Chapter Accuracy ---")
339
+ chapter_stats = {}
340
+ for r in results:
341
+ ch = r["gt_code"][0] if r["gt_code"] != "NONE" else "?"
342
+ if ch not in chapter_stats:
343
+ chapter_stats[ch] = {"total": 0, "correct": 0}
344
+ chapter_stats[ch]["total"] += 1
345
+ if r["exact_match"]:
346
+ chapter_stats[ch]["correct"] += 1
347
+
348
+ for ch in sorted(chapter_stats.keys()):
349
+ s = chapter_stats[ch]
350
+ acc = s["correct"] / s["total"] * 100 if s["total"] > 0 else 0
351
+ print(f" Chapter {ch}: {acc:.1f}% ({s['correct']}/{s['total']})")
352
+
353
+ # Save results
354
+ with open(os.path.join(OUTPUT_DIR, "eval_results.json"), "w") as f:
355
+ json.dump({
356
+ "exact_match_accuracy": exact_acc,
357
+ "partial_match_accuracy": partial_acc,
358
+ "category_accuracy": category_acc,
359
+ "chapter_accuracy": chapter_acc,
360
+ "total_evaluated": total,
361
+ "per_chapter": chapter_stats,
362
+ }, f, indent=2)
363
+
364
+ # Sample predictions
365
+ print("\n--- Sample Predictions ---")
366
+ for r in results[:10]:
367
+ status = "✅" if r["exact_match"] else ("🟡" if r["category_match"] else "❌")
368
+ print(f" {status} GT: {r['gt_code']:<12} Pred: {r['pred_code']}")
369
+
370
+ trackio.finish()
371
+
372
+ print("\n" + "=" * 70)
373
+ print(f"Model saved to Hub: https://hf.co/{HUB_MODEL_ID}")
374
+ print(f"Training dashboard: trackio")
375
+ print("=" * 70)