"""ICD-10-CM Clinical Coding Fine-tuning Script (GPU - Production).
================================================================

Fine-tunes Qwen2.5-1.5B-Instruct with LoRA on 366K synthetic EHR records
for ICD-10-CM code classification from clinical text.

Requirements:
    pip install torch transformers trl peft datasets trackio accelerate flash-attn

Hardware: A10G (24GB) or better. Training time: ~2-3 hours.

Based on:
- Recipe 3: Lenz et al. (arxiv:2510.13624) — Instruction-tuning for ICD coding
- Recipe 2: MERA (arxiv:2501.17326) — Code memorization improves accuracy
- FiscaAI/synth-ehr-icd10cm-prompt dataset (366K rows, 5071 ICD-10-CM codes)
- TRL SFTTrainer with prompt/completion format (loss on codes only)

To run:
    # On HF Jobs (A10G):
    hf_jobs run --script train_icd10_gpu.py --hardware a10g-large --timeout 4h \\
        --deps torch transformers trl peft datasets trackio accelerate flash-attn

    # Or locally with GPU:
    pip install torch transformers trl peft datasets trackio accelerate flash-attn
    python train_icd10_gpu.py
"""
import os, re, json, random, gc
from collections import Counter

import torch, trackio
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# --- Run configuration -------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
HUB_MODEL_ID = "Rakshithch/qwen2.5-1.5b-icd10cm-coder"
DATASET_ID = "Rakshithch/icd10cm-clinical-coding-sft"
OUTPUT_DIR = "./qwen2.5-1.5b-icd10cm-lora"
LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRAD_ACCUM = 8          # effective batch = BATCH_SIZE * GRAD_ACCUM = 32
MAX_LENGTH = 1024
LORA_R = 16
LORA_ALPHA = 32
SEED = 42

# Seed both the Python and torch RNGs so sampling AND torch-side shuffling /
# initialization are reproducible. (Bug fix: the original seeded only
# `random`, leaving torch ops unseeded despite a declared SEED.)
random.seed(SEED)
torch.manual_seed(SEED)

trackio.init(
    project="icd10-clinical-coding",
    name="qwen2.5-1.5b-lora-r16-full",
    config={
        "model": MODEL_NAME,
        "dataset": DATASET_ID,
        "lora_r": LORA_R,
        "lr": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "eff_batch": BATCH_SIZE * GRAD_ACCUM,
    },
)

print("Loading dataset...")
ds = load_dataset(DATASET_ID)
print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")


def to_pc(example):
    """Convert a 3-message chat example to TRL prompt/completion format.

    prompt = [system, user] messages; completion = [assistant] message,
    so SFT loss is computed on the generated codes only.
    """
    msgs = example["messages"]
    return {"prompt": msgs[:2], "completion": [msgs[2]]}


train_ds = ds["train"].map(to_pc, remove_columns=ds["train"].column_names, num_proc=4)
val_ds = ds["validation"].map(to_pc, remove_columns=ds["validation"].column_names, num_proc=4)
test_ds = ds["test"]  # kept in raw message format for generation-based eval

print(f"Loading {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA on all attention + MLP projections (standard full-coverage targets).
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="adamw_torch_fused",
    bf16=True,
    max_length=MAX_LENGTH,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing is required for compatibility with PEFT.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=25,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
    run_name="qwen2.5-1.5b-icd10cm-lora-r16-full",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    peft_config=peft_config,
    processing_class=tokenizer,
)
trainer.model.print_trainable_parameters()

result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")

trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
trainer.push_to_hub()

# --- Evaluation --------------------------------------------------------------
# Release trainer/model GPU memory before reloading the tuned model below.
del trainer, model
gc.collect()
torch.cuda.empty_cache()

from transformers import pipeline as hf_pipeline

# Reload the fine-tuned model for generation-based evaluation on the test set.
pipe = hf_pipeline(
    "text-generation",
    model=OUTPUT_DIR,
    tokenizer=tokenizer,
    device_map="auto",
    max_new_tokens=150,
)

# ICD-10-CM code shape: letter + 2 digits, optional ".dddd", optional 7th
# character. Compiled once instead of per-iteration inside the loop.
ICD_CODE_RE = re.compile(r'\b([A-Z]\d{2}(?:\.\d{1,4})?(?:[A-Z])?)\b')

# Sample up to 2000 test examples (seeded `random` above keeps this stable).
eval_size = min(2000, len(test_ds))
eval_indices = random.sample(range(len(test_ds)), eval_size)

correct_exact = correct_category = correct_chapter = total = 0
results = []
for idx in eval_indices:
    example = test_ds[idx]
    gt_code = example["icd_code"]
    try:
        out = pipe(example["messages"][:2], max_new_tokens=150, do_sample=False)
        generated = out[0]["generated_text"][-1]["content"]
    except Exception:
        # Bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. A failed generation still counts
        # toward the denominator (scored as wrong).
        total += 1
        continue

    total += 1
    pred_codes = ICD_CODE_RE.findall(generated)
    if not pred_codes:
        # Bug fix: the original scored the "NONE" placeholder literally, so
        # "NONE"[0] == "N" falsely credited a chapter match whenever the
        # ground-truth code was in the N chapter (genitourinary). An
        # unparsable output is now wrong at every granularity.
        results.append({"gt": gt_code, "pred": "NONE", "exact": False})
        continue

    pred = pred_codes[0]
    exact = pred == gt_code
    if exact:
        correct_exact += 1
    if pred[:3] == gt_code[:3]:   # category = first 3 characters
        correct_category += 1
    if pred[0] == gt_code[0]:     # chapter = leading letter
        correct_chapter += 1
    results.append({"gt": gt_code, "pred": pred, "exact": exact})

# Guard against a zero denominator if every generation failed.
exact_acc = correct_exact / max(total, 1) * 100
cat_acc = correct_category / max(total, 1) * 100
ch_acc = correct_chapter / max(total, 1) * 100
print(f"Exact: {exact_acc:.1f}% | Category: {cat_acc:.1f}% | Chapter: {ch_acc:.1f}%")
trackio.log({"eval/exact_match": exact_acc, "eval/category": cat_acc, "eval/chapter": ch_acc})
trackio.finish()