---
license: mit
language:
- en
tags:
- text-classification
- medical
- nhs
- clinical-letters
- distilbert
pipeline_tag: text-classification
---

# NHS Medical Letter Classifier

Fine-tuned **DistilBERT** (`distilbert-base-uncased`) for classifying OCR'd NHS medical clinic letters into 49 letter-type categories.

## Model Details

| Parameter | Value |
|---|---|
| Base model | `distilbert-base-uncased` |
| Training samples | 13,672 |
| Classes | 49 |
| Epochs | 6 |
| Batch size | 16 |
| Learning rate | 2e-5 |
| Max sequence length | 512 tokens |
| Cleanlab corrections | 212 labels relabeled (1.6% of dataset) |

## How We Got Here: Experiment Journey

### 1. Baseline: TF-IDF + LinearSVC
- **Approach:** TfidfVectorizer (unigrams + bigrams, 50k features) with CalibratedClassifierCV(LinearSVC)
- **Result:** ~91% accuracy on the original label set
- **Takeaway:** Strong baseline, but limited by its bag-of-words representation
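
This baseline can be sketched in a few lines of scikit-learn. The toy letters and label names below are illustrative stand-ins, not the real dataset:

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

# Toy stand-ins for OCR'd letters -- the real corpus has 13,672 documents
texts = [
    "echocardiogram shows normal left ventricular function",
    "patient reviewed in cardiology clinic, ECG unremarkable",
    "fundus photography shows no diabetic retinopathy",
    "visual acuity stable, continue ophthalmology follow-up",
] * 4  # repeated so each calibration fold sees every class
labels = ["Cardiology", "Cardiology", "Ophthalmology", "Ophthalmology"] * 4

# Unigram+bigram TF-IDF feeding a calibrated linear SVM, as in the baseline
clf = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2), max_features=50_000),
    CalibratedClassifierCV(LinearSVC(), cv=2),
)
clf.fit(texts, labels)
print(clf.predict(["outpatient cardiology letter"])[0])
```

Wrapping the SVM in `CalibratedClassifierCV` gives the pipeline a `predict_proba`, which is what makes ranked (top-k) metrics possible later.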

### 2. Label Merging (Critical Improvement)
- **Approach:** Consolidated synonymous labels (e.g., "Nephrology" to "Renal", "Minor Illness Consultation" to "Pharmacy") and dropped ambiguous/administrative labels
- **Result:** Accuracy jumped from ~91% to ~96%
- **Takeaway:** Label quality matters more than model architecture. Reduced the label set from ~51 to 49 meaningful categories
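
The consolidation step amounts to a simple mapping pass. The two merges shown come from the text above; the `merge_map` name itself is ours:

```python
# Map synonymous labels onto a canonical name; unmapped labels pass through
merge_map = {
    "Nephrology": "Renal",
    "Minor Illness Consultation": "Pharmacy",
}

def merge_label(label: str) -> str:
    return merge_map.get(label, label)

raw = ["Nephrology", "Cardiology", "Minor Illness Consultation"]
print([merge_label(l) for l in raw])  # ['Renal', 'Cardiology', 'Pharmacy']
```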

### 3. DistilBERT Baseline (Our Core Model)
- **Approach:** Fine-tuned `distilbert-base-uncased`, 4 epochs, 512 tokens, 70/10/20 stratified split
- **Result:** Top-1: 95.76% | Top-3: 98.06% | Top-5: 98.61%
- **Takeaway:** Strong performance, established as the baseline for all further experiments
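
A 70/10/20 stratified split can be produced with two chained `train_test_split` calls (a sketch with toy data; variable names are ours):

```python
from sklearn.model_selection import train_test_split

# Toy data: 50 samples per class so the stratification is exact
texts = [f"letter {i}" for i in range(100)]
labels = ["Cardiology"] * 50 + ["Renal"] * 50

# First carve off the 20% test set, then split the remainder 70/10
train_val_x, test_x, train_val_y, test_y = train_test_split(
    texts, labels, test_size=0.20, stratify=labels, random_state=42
)
train_x, val_x, train_y, val_y = train_test_split(
    train_val_x, train_val_y,
    test_size=0.125,  # 0.125 * 0.8 = 0.10 of the total
    stratify=train_val_y, random_state=42,
)
print(len(train_x), len(val_x), len(test_x))  # 70 10 20
```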

### 4. ClinicalBERT & BioClinicalBERT
- **Approach:** Tested domain-specific models (`medicalai/ClinicalBERT`, `emilyalsentzer/Bio_ClinicalBERT`)
- **Result:** Similar to DistilBERT (~95-96%), no meaningful improvement
- **Takeaway:** General-purpose DistilBERT captures enough for this task; domain pre-training didn't help

### 5. Longformer (1024 tokens)
- **Approach:** `allenai/longformer-base-4096` at 1024 tokens with global attention on CLS, case-sensitive
- **Result:** Comparable to DistilBERT at 512 tokens
- **Takeaway:** Most discriminative information is in the first 512 tokens; longer context doesn't help

### 6. Hierarchical Architecture
- **Approach:** Two-stage: DistilBERT body for CLS embeddings, per-clinic LogisticRegression heads. 51 fine labels mapped to 25 broad categories
- **Result:** Did not outperform flat DistilBERT
- **Takeaway:** The flat classification space works well; hierarchical routing adds complexity without benefit
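
The two-stage routing can be sketched as follows. Random vectors stand in for the DistilBERT CLS embeddings, and the broad/fine label names here are toy examples:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Synthetic stand-ins for DistilBERT CLS embeddings (16-dim here for brevity)
embeddings = rng.normal(size=(40, 16))
# Broad category, assumed already assigned by a first-stage classifier
broad = np.array(["Cardiac"] * 20 + ["Eye"] * 20)
# Fine labels within each broad category
fine = np.array(["Cardiology"] * 10 + ["Echocardiogram"] * 10
                + ["Ophthalmology"] * 10 + ["Retinal Screening"] * 10)

# Stage 2: one LogisticRegression head per broad category
heads = {
    cat: LogisticRegression(max_iter=1000).fit(
        embeddings[broad == cat], fine[broad == cat]
    )
    for cat in np.unique(broad)
}

# Inference: route a document to its broad category, then apply that head
pred = heads["Cardiac"].predict(embeddings[:1])[0]
print(pred)
```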

### 7. LLM Relabeling (GPT-5-mini)
- **Approach:** Used the OpenAI Batch API to have GPT-5-mini reclassify all 13,672 samples, then trained DistilBERT on the LLM-assigned labels
- **Result:** Top-1: 86.22% vs original labels | 93.24% vs LLM labels
- **Takeaway:** The LLM agrees with the original labels ~85.7% of the time. LLM labels are different but not better: the original clinical labels carry domain knowledge the LLM lacks

### 8. Consensus Relabeling
- **Approach:** Only change labels where both BERT and GPT-5-mini agree the original label is wrong
- **Result:** Only 4 out of 9,569 samples met the consensus criteria
- **Takeaway:** BERT memorizes its training labels, so it almost never disagrees with originals on training data. Consensus is too strict
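
The consensus rule itself is tiny (function and variable names are ours):

```python
def consensus_relabel(original: str, bert_pred: str, llm_pred: str) -> str:
    """Replace the label only when both models agree it should be something else."""
    if bert_pred == llm_pred and bert_pred != original:
        return bert_pred
    return original

print(consensus_relabel("Renal", "Cardiology", "Cardiology"))  # Cardiology
print(consensus_relabel("Renal", "Cardiology", "Urology"))     # Renal
```

Because BERT rarely contradicts labels it was trained on, the first branch almost never fires on training data, which is why only 4 of 9,569 samples qualified.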

### 9. Soft Knowledge Distillation
- **Approach:** Collected GPT-5-mini top-5 predictions with confidence scores as soft labels, then trained with a blended loss: alpha * CE(hard) + (1 - alpha) * KL(soft || student), alpha = 0.5
- **Result:** Top-1: 95.32% (-0.44pp) | Top-3: 97.48% (-0.58pp)
- **Takeaway:** LLM self-reported confidence scores are too noisy/uniform; the soft KL loss stayed flat at ~3.5. Actual logprobs would be needed for this to work
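
The blended objective, written out in plain Python on single probability distributions (a sketch for clarity; in training this runs on batched logits):

```python
import math

def blended_loss(student_probs, hard_idx, soft_probs, alpha=0.5, eps=1e-12):
    """alpha * CE(hard label) + (1 - alpha) * KL(soft || student)."""
    # Cross-entropy against the hard (one-hot) label
    ce = -math.log(student_probs[hard_idx] + eps)
    # KL divergence from the teacher's soft distribution to the student's
    kl = sum(
        p * math.log((p + eps) / (q + eps))
        for p, q in zip(soft_probs, student_probs)
        if p > 0
    )
    return alpha * ce + (1 - alpha) * kl

student = [0.7, 0.2, 0.1]  # student's predicted distribution
teacher = [0.6, 0.3, 0.1]  # teacher confidences, renormalised
print(round(blended_loss(student, 0, teacher), 4))
```

With alpha = 1 this reduces to ordinary cross-entropy training; with alpha = 0 the student only imitates the teacher's distribution.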

### 10. Cleanlab: Remove Mislabeled Samples
- **Approach:** Confident learning (Northcutt et al., 2021): 3-fold cross-validation for out-of-sample probabilities, then `find_label_issues()` to detect mislabeled samples. Removed 142 flagged training samples and retrained
- **Result:** Top-1: 95.90% (+0.14pp) | Top-3: 97.70% (-0.36pp)
- **Takeaway:** Small top-1 gain, but removing ambiguous samples hurt ranked predictions. Manual inspection confirmed ~99% of flagged samples were genuinely mislabeled
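
The core detection idea can be approximated in a few lines: flag a sample when the out-of-sample probabilities confidently favour a class other than its given label. This is a deliberately simplified stand-in for cleanlab's `find_label_issues()`, not its actual algorithm:

```python
def flag_label_issues(pred_probs, given_labels, margin=0.5):
    """Flag sample i when the model puts prob >= margin on a class other
    than the given label -- a crude proxy for confident learning."""
    issues = []
    for i, (probs, label) in enumerate(zip(pred_probs, given_labels)):
        best = max(range(len(probs)), key=probs.__getitem__)
        if best != label and probs[best] >= margin:
            issues.append(i)
    return issues

# Out-of-sample probabilities for 3 samples over 2 classes
pred_probs = [[0.9, 0.1], [0.2, 0.8], [0.55, 0.45]]
given = [1, 1, 0]  # sample 0 is labelled class 1, but the model is 90% class 0
print(flag_label_issues(pred_probs, given))  # [0]
```

The probabilities must come from cross-validation (out-of-sample), otherwise the model's memorization of its own training labels hides the errors, as the consensus experiment above showed.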

### 11. Cleanlab: Relabel Instead of Remove
- **Approach:** Same cleanlab detection, but replaced wrong labels with the model's predicted label instead of removing samples
- **Result (vs original test labels):** Top-1: 95.80% | Top-3: 97.92% | Top-5: 98.46%
- **Result (vs corrected test labels):** Top-1: 98.06% | Top-3: 99.09% | Top-5: 99.38%
- **Takeaway:** The ~2pp gap between original and corrected evaluation reveals that the remaining "errors" are mostly test set noise, not model mistakes. True model performance is ~98% top-1

### 12. Production Model (This Model)
- **Approach:** Fresh 3-fold cleanlab pass on the **entire** dataset (13,672 samples). Found 212 mislabeled samples (1.6%) and relabeled them all, then trained on the full corrected dataset for 6 epochs
- **Sanity check:** 99.74% accuracy on training data (expected, since the model saw all of the data)
- **Estimated true accuracy:** ~98% top-1, ~99% top-3, based on corrected-label evaluation

## Key Findings

1. **Label quality > model architecture.** Label merging (+5pp) and cleanlab corrections (+2pp true accuracy) had more impact than any model change
2. **DistilBERT is sufficient.** Domain-specific models (ClinicalBERT, BioClinicalBERT) and longer context (Longformer) didn't help
3. **~1.6% of labels are wrong.** Discharge summary (9.1%), Paediatrics (7.2%), and Physiotherapy (6.8%) are the noisiest classes
4. **The model is better than naive metrics suggest.** Evaluated against corrected labels, top-1 jumps from ~96% to ~98%

## Labels (49 classes)

- `A&E`
- `Ambulance Notification`
- `Audiology`
- `Bowel Cancer Screening`
- `Breast Clinic`
- `Cancer Screening`
- `Cardiology`
- `Colposcopy`
- `Dermatology`
- `Diabetes & Endocrine`
- `Diet Services`
- `Discharge summary`
- `ENT`
- `Echocardiogram`
- `Elderly Care`
- `Gastroenterology`
- `General Surgery`
- `Genetics`
- `Haematology`
- `INR`
- `Immunology`
- `Mammogram`
- `Maternity`
- `Maxillofacial`
- `Mental Health`
- `Neurology`
- `Neurosurgery`
- `Obstetrics & Gynaecology`
- `Oncology`
- `Ophthalmology`
- `Orthopaedics`
- `Out of Hours`
- `Paediatrics`
- `Pain Management`
- `Pharmacy`
- `Physiotherapy`
- `Plastic Surgery`
- `Radiology`
- `Renal`
- `Respiratory`
- `Retinal Screening`
- `Rheumatology`
- `Sexual Health`
- `Speech and Language Therapy`
- `Stroke Services`
- `Urgent Care Centre`
- `Urology`
- `Vascular`
- `Walk in Centre`

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, json

model = AutoModelForSequenceClassification.from_pretrained("mansour94/kynoby-william-bert-classifier")
tokenizer = AutoTokenizer.from_pretrained("mansour94/kynoby-william-bert-classifier")

# Load the label map shipped with the model
from huggingface_hub import hf_hub_download
label_map = json.load(open(hf_hub_download("mansour94/kynoby-william-bert-classifier", "label_map.json")))
id2label = {int(k): v for k, v in label_map["id2label"].items()}

text = "Dear Dr Smith, I am writing to inform you about the patient's ophthalmology appointment..."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)

# Top-3 predictions
top3 = torch.topk(probs, 3)
for i in range(3):
    idx = top3.indices[0][i].item()
    conf = top3.values[0][i].item()
    print(f"  {id2label[idx]}: {conf:.1%}")
```