"""HF Spaces Gradio app: chronic wound classifier (4-class), bilingual EN/FR.

Self-contained: no wound_classifier package install required. The model
architecture and transforms are inlined here, so this file, the .pt
checkpoint, and requirements.txt are everything the Space needs.

If the architecture or transform here drifts from
src/wound_classifier/{modeling/models.py, features.py}, the Space and the
training pipeline will silently disagree. Keep them in sync.

Theming approximates Hôpital Montfort (Ottawa) brand colors, sourced from
the live hopitalmontfort.com stylesheet: primary "Montfort blue" #00729a,
turquoise accent #47c9cd, warm cream surface #f1ede5.
"""

from __future__ import annotations

from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.models import efficientnet_b0

CKPT_PATH = Path(__file__).parent / "cv_baseline_fold5_best.pt"
IMAGE_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IDX_TO_CLASS = ["D", "P", "S", "V"]
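

# Sketch (hypothetical helper, not part of the original app and not called by it):
# one way to catch the silent drift the module docstring warns about. Both this
# Space and the training pipeline could compute a short fingerprint of the
# preprocessing and label constants and compare the two values, for example
# against a value saved alongside the checkpoint. The function name and the idea
# of storing the value with the checkpoint are assumptions.
def _preprocessing_fingerprint(
    image_size: int,
    mean: tuple[float, ...],
    std: tuple[float, ...],
    classes: list[str],
) -> str:
    import hashlib
    import json

    config = {"image_size": image_size, "mean": list(mean), "std": list(std), "classes": classes}
    # sort_keys makes the JSON text, and therefore the hash, independent of key order.
    payload = json.dumps(config, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:16]


# Usage: _preprocessing_fingerprint(IMAGE_SIZE, IMAGENET_MEAN, IMAGENET_STD, IDX_TO_CLASS)
# Reordering the class letters changes the fingerprint, which is the point: the
# index -> class mapping must match the one used in training.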


CLASS_NAMES: dict[str, dict[str, str]] = {
    "en": {
        "D": "Diabetic ulcer",
        "P": "Pressure ulcer",
        "S": "Surgical wound",
        "V": "Venous ulcer",
    },
    "fr": {
        "D": "Ulcère diabétique",
        "P": "Escarre",
        "S": "Plaie chirurgicale",
        "V": "Ulcère veineux",
    },
}

LOW_CONFIDENCE_THRESHOLD = 0.5
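

# Sketch (illustrative only; classify() below uses torch.softmax, not this):
# why a 0.5 threshold is meaningful for a 4-class softmax. Softmax outputs are
# positive and sum to 1, so with 4 classes the largest can never fall below 1/4.
# A photo that is not a wound still gets *some* top class; the threshold only
# flags predictions where no single class holds a majority.
def _softmax_sketch(logits: list[float]) -> list[float]:
    import math

    # Subtracting the max logit before exponentiating avoids overflow and does
    # not change the result.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Usage: _softmax_sketch([0.0, 0.0, 0.0, 0.0]) gives 0.25 for every class. That
# is below LOW_CONFIDENCE_THRESHOLD, so the low-confidence note would be shown.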


# ---------- Localized strings ---------------------------------------------------
SPACE_URL = "https://huggingface.co/spaces/jbobym/wound-classifier"

TITLE: dict[str, str] = {
    "en": (
        "# Chronic Wound Classifier\n"
        "*Developed, trained, and deployed by **John Boby Mesadieu**.*\n\n"
        "A model I trained to look at a wound photo and guess which of four types it is. "
        "It's right roughly 8 times out of 10; this page also tells you when not to trust it.\n\n"
        f"**Share this demo:** [{SPACE_URL}]({SPACE_URL})"
    ),
    "fr": (
        "# Classification des plaies chroniques\n"
        "*Conçu, entraîné et déployé par **John Boby Mesadieu**.*\n\n"
        "Un modèle que j'ai entraîné pour regarder une photo de plaie et deviner duquel des quatre "
        "types il s'agit. Il a raison environ 8 fois sur 10 ; cette page vous dit aussi quand ne "
        "pas lui faire confiance.\n\n"
        f"**Partager cette démo :** [{SPACE_URL}]({SPACE_URL})"
    ),
}

DESCRIPTION: dict[str, str] = {
    "en": """\
Upload a photo of a wound and the model picks one of four types (diabetic, pressure, surgical, or
venous), with a confidence percentage for each.

A few things to know before you try it:

- **Centre the wound in the photo.** The model only looks at a square in the middle of the image;
  anything near the edges of a non-square photo gets cropped out.
- **JPEG or PNG. That's it.**
- **Only upload wound photos.** The model has to pick one of the four types. If you give it
  something else, it will still call it a wound. Watch the confidence percentage: if it comes back
  under 50%, the model is probably guessing.
- **Pressure ulcers are the model's weak spot.** It gets them right roughly 4 times out of 10. When
  it says Pressure, take the answer with a grain of salt.

This is a research demo, not a medical device. It doesn't diagnose, triage, or replace a
clinician's judgement. The *Approach* section below has the methodology and the headline accuracy.
""",
    "fr": """\
Téléversez une photo de plaie ; le modèle choisit l'un des quatre types (ulcère diabétique,
escarre, plaie chirurgicale ou ulcère veineux) avec un pourcentage de confiance pour chacun.

À savoir avant d'essayer :

- **Centrez la plaie dans la photo.** Le modèle ne regarde qu'un carré au milieu de l'image ; tout
  ce qui se trouve près des bords d'une photo non carrée est coupé.
- **JPEG ou PNG. C'est tout.**
- **Téléversez seulement des photos de plaie.** Le modèle doit choisir l'un des quatre types. Si
  vous lui donnez autre chose, il le classera quand même comme une plaie. Surveillez le
  pourcentage de confiance : sous 50 %, le modèle devine probablement.
- **L'escarre est le point faible du modèle.** Il la reconnaît correctement environ 4 fois sur 10.
  Quand il dit Escarre, prenez la réponse avec précaution.

Ceci est une démonstration de recherche, pas un dispositif médical. Le modèle ne pose pas de
diagnostic, ne fait pas de triage et ne remplace pas le jugement clinique. La section *Approche*
ci-dessous donne la méthodologie et l'exactitude principale.
""",
}

ARTICLE: dict[str, str] = {
    "en": """\
### Approach

I trained an image classifier (EfficientNet-B0) on the AZH Chronic Wound Database, a public
research dataset of clinical wound photos. The training was set up so that the same patient's
photos never appeared in both the training and test sets. That detail matters more than it sounds,
because models on this dataset can otherwise inflate their accuracy by quietly memorising patients
instead of learning what wounds actually look like.

On the held-out test set of 184 photos, the version of the model running here gets the wound type
right 81 times out of 100. As a sanity check, I trained nine other versions of the same model on
slightly different slices of the data and averaged their predictions; that combined version scored
80 out of 100 on the same test, which suggests the headline number is not a fluke.

### Out of scope

Not for clinical decision-making. No claim of diagnostic accuracy on real patient cohorts. No
fairness audit across skin tones, which is a known gap.

### Author

John Boby Mesadieu.

### Dataset citation

Anisuzzaman et al. 2022, *Multi-modal wound classification using wound image and location by deep
neural network*, Sci. Rep. 12:20057.
""",
    "fr": """\
### Approche

J'ai entraîné un classifieur d'images (EfficientNet-B0) sur l'AZH Chronic Wound Database, un jeu
de données public de photos cliniques de plaies. L'entraînement a été configuré pour que les
photos d'un même patient n'apparaissent jamais à la fois dans le jeu d'entraînement et dans le jeu
de test. Ce détail compte plus qu'il n'y paraît, parce que les modèles entraînés sur ce jeu de
données peuvent autrement gonfler leur exactitude en mémorisant discrètement des patients plutôt
qu'en apprenant à quoi ressemble une plaie.

Sur le jeu de test mis de côté (184 photos), la version du modèle déployée ici trouve le bon type
de plaie 81 fois sur 100. Comme contrôle, j'ai entraîné neuf autres versions du même modèle sur des
sous-ensembles légèrement différents des données et fait la moyenne de leurs prédictions ; cette
version combinée a obtenu 80 sur 100 sur le même test, ce qui suggère que le chiffre principal
n'est pas un coup de chance.

### Hors champ

Pas pour la décision clinique. Aucune prétention d'exactitude diagnostique sur de vraies cohortes
de patients. Aucun audit d'équité par teinte de peau, ce qui constitue une limite connue.

### Auteur

John Boby Mesadieu.

### Référence du jeu de données

Anisuzzaman et coll. 2022, *Multi-modal wound classification using wound image and location by
deep neural network*, Sci. Rep. 12:20057.
""",
}
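

# Sketch (hypothetical; the real training pipeline lives outside this Space and
# may do this differently): the patient-level split described in ARTICLE.
# Photos are grouped by patient ID and whole groups are assigned to folds, so one
# patient's photos can never sit in both a training fold and its test fold. The
# greedy "largest group first, into the currently smallest fold" balancing is
# similar in spirit to scikit-learn's GroupKFold. The input format (photo ID ->
# patient ID) is an assumption.
def _patient_grouped_folds(photo_to_patient: dict[str, str], n_folds: int) -> list[list[str]]:
    from collections import defaultdict

    photos_by_patient: dict[str, list[str]] = defaultdict(list)
    for photo, patient in photo_to_patient.items():
        photos_by_patient[patient].append(photo)

    folds: list[list[str]] = [[] for _ in range(n_folds)]
    # Largest patients first; ties broken by patient ID so the split is deterministic.
    for patient in sorted(photos_by_patient, key=lambda p: (-len(photos_by_patient[p]), p)):
        smallest = min(range(n_folds), key=lambda i: len(folds[i]))
        folds[smallest].extend(photos_by_patient[patient])
    return folds


# Usage: each fold in turn is the test set and the rest are training data. Every
# patient appears in exactly one fold, so no patient leaks across the split.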

LABELS: dict[str, dict[str, str]] = {
    "en": {
        "lang_radio": "Language / Langue",
        "image_input": "Wound photograph (close-up, centered)",
        "label_output": "Predicted wound type",
        "notes_output": "Notes",
        "submit": "Classify",
        "clear": "Clear",
        "share": "Share this prediction",
    },
    "fr": {
        "lang_radio": "Language / Langue",
        "image_input": "Photographie de la plaie (gros plan, centrée)",
        "label_output": "Type de plaie prédit",
        "notes_output": "Remarques",
        "submit": "Classer",
        "clear": "Effacer",
        "share": "Partager cette prédiction",
    },
}

NOTE_LOW_CONFIDENCE: dict[str, str] = {
    "en": (
        "**Low confidence** (top class {top_label} at {top_pct}). "
        "Probably one of two things: the photo isn't a clear close-up of a wound, or it is a "
        "wound but not one of the four the model knows. Either way, the model still has to pick, "
        "so do not lean on this answer."
    ),
    "fr": (
        "**Faible confiance** (classe principale {top_label} à {top_pct}). "
        "Deux explications probables : la photo n'est pas un gros plan clair d'une plaie, "
        "ou c'est bien une plaie, mais pas l'un des quatre types que le modèle connaît. Dans les "
        "deux cas, le modèle est obligé de choisir quand même, alors ne vous appuyez pas sur "
        "cette réponse."
    ),
}

NOTE_PRESSURE: dict[str, str] = {
    "en": (
        "**Pressure ulcers are the model's weak spot.** "
        "It gets them right roughly 4 times out of 10. "
        "When it says Pressure, take the answer with a grain of salt."
    ),
    "fr": (
        "**L'escarre est le point faible du modèle.** "
        "Il la reconnaît correctement environ 4 fois sur 10. "
        "Quand il dit Escarre, prenez la réponse avec précaution."
    ),
}


# ---------- Model loading -------------------------------------------------------
def _build_model(num_classes: int = 4) -> nn.Module:
    # EfficientNet-B0 with its head replaced by a num_classes-way linear layer,
    # keeping the same dropout as torchvision's default head. Must match training.
    model: nn.Module = efficientnet_b0(weights=None)
    in_features = model.classifier[1].in_features  # type: ignore[index, union-attr]
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features, num_classes),
    )
    return model


def _load_model(path: Path) -> nn.Module:
    # weights_only=False lets torch.load unpickle arbitrary objects. That is only
    # acceptable because the checkpoint ships with this Space and is trusted.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model = _build_model(num_classes=4)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model


def _build_transform() -> transforms.Compose:
    # Resize scales the shorter side to IMAGE_SIZE (aspect ratio kept); CenterCrop
    # then keeps the central IMAGE_SIZE x IMAGE_SIZE square.
    return transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
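

# Sketch (hypothetical helper, not used by the app): which part of the original
# photo survives Resize(IMAGE_SIZE) + CenterCrop(IMAGE_SIZE). Resize scales the
# shorter side to IMAGE_SIZE while keeping the aspect ratio, and CenterCrop then
# keeps a centred square. In original-pixel terms, the model therefore sees
# roughly the central min(width, height) square, and equal strips are dropped
# from both ends of the longer side. This ignores the sub-pixel rounding that
# torchvision applies in both steps.
def _visible_region(width: int, height: int) -> tuple[int, int, int, int]:
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    # (left, upper, right, lower): the same box convention PIL's Image.crop uses.
    return left, top, left + side, top + side


# Usage: for a 640x480 landscape photo this returns (80, 0, 560, 480), so the
# 80-pixel strips at the left and right edges never reach the model.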


MODEL = _load_model(CKPT_PATH)
TRANSFORM = _build_transform()


def _lang_code(choice: str) -> str:
    return "fr" if choice == "Français" else "en"


def _format_pct(value: float, lang: str) -> str:
    pct = f"{value:.0%}"
    # French convention: decimal comma and a non-breaking space before "%"
    # ("42 %" rather than "42%"). With .0% there is no decimal part, but the comma
    # swap keeps this correct if the precision is ever raised.
    if lang == "fr":
        return pct.replace(".", ",").replace("%", "\u00a0%")
    return pct


def classify(image: Image.Image | None, language_choice: str) -> tuple[dict[str, float], str]:
    """Return per-class probabilities (keyed by localized name) and Markdown notes."""
    if image is None:
        return {}, ""
    lang = _lang_code(language_choice)
    rgb = image.convert("RGB")
    x = TRANSFORM(rgb).unsqueeze(0)  # add a batch dimension: (1, 3, IMAGE_SIZE, IMAGE_SIZE)
    with torch.inference_mode():
        logits = MODEL(x)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
    name_map = CLASS_NAMES[lang]
    label_probs = {name_map[letter]: float(probs[i]) for i, letter in enumerate(IDX_TO_CLASS)}
    # Look up the top class by index rather than by reverse-matching its localized name.
    top_idx = int(probs.argmax())
    top_letter = IDX_TO_CLASS[top_idx]
    top_label, top_prob = name_map[top_letter], float(probs[top_idx])
    notes: list[str] = []
    if top_prob < LOW_CONFIDENCE_THRESHOLD:
        notes.append(
            NOTE_LOW_CONFIDENCE[lang].format(
                top_label=top_label, top_pct=_format_pct(top_prob, lang)
            )
        )
    if top_letter == "P":
        notes.append(NOTE_PRESSURE[lang])
    return label_probs, "\n\n".join(notes)


# ---------- UI ------------------------------------------------------------------
# Custom theme using Hôpital Montfort brand colors (extracted from their stylesheet):
# primary "Montfort blue" #00729a, turquoise accent #47c9cd, warm cream #f1ede5.
montfort_blue = gr.themes.Color(
    name="montfort_blue",
    c50="#eef7fa",
    c100="#c6eafa",
    c200="#9bd9ed",
    c300="#6ec5dd",
    c400="#3aa5c4",
    c500="#00729a",
    c600="#005f81",
    c700="#004d68",
    c800="#003a4f",
    c900="#002836",
    c950="#001a25",
)
montfort_turquoise = gr.themes.Color(
    name="montfort_turquoise",
    c50="#e6fbfb",
    c100="#c6f4f5",
    c200="#9eeaeb",
    c300="#73dde0",
    c400="#47c9cd",
    c500="#23b6ba",
    c600="#1a9498",
    c700="#147576",
    c800="#0f5859",
    c900="#0a3c3d",
    c950="#062323",
)

theme = gr.themes.Soft(
    primary_hue=montfort_blue,
    secondary_hue=montfort_turquoise,
    neutral_hue=gr.themes.colors.stone,
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
).set(
    # Light mode (Montfort cream surface, white panels)
    body_background_fill="#f1ede5",
    block_background_fill="white",
    block_border_color="#dee2e6",
    button_primary_background_fill="#00729a",
    button_primary_background_fill_hover="#005f81",
    button_primary_text_color="white",
    # Dark mode (keeps the Montfort identity: deep blue surface, lighter blue panels)
    body_background_fill_dark="#002836",
    block_background_fill_dark="#003a4f",
    block_border_color_dark="#005f81",
    button_primary_background_fill_dark="#47c9cd",
    button_primary_background_fill_hover_dark="#23b6ba",
    button_primary_text_color_dark="#001a25",
    body_text_color_dark="#eef7fa",
    block_label_text_color_dark="#c6eafa",
    block_title_text_color_dark="#eef7fa",
)


def _localize_components(
    language_choice: str,
) -> tuple[gr.Markdown, gr.Markdown, gr.Image, gr.Label, gr.Markdown, gr.Button, gr.Button]:
    """Build freshly localized components for the chosen language.

    Not wired to any event: the UI's language switch goes through
    _on_language_change, which returns gr.update(...) dicts instead.
    """
    lang = _lang_code(language_choice)
    labels = LABELS[lang]
    return (
        gr.Markdown(value=TITLE[lang]),
        gr.Markdown(value=DESCRIPTION[lang]),
        gr.Image(label=labels["image_input"]),
        gr.Label(label=labels["label_output"]),
        gr.Markdown(value="", label=labels["notes_output"]),
        gr.Button(value=labels["submit"]),
        gr.Button(value=labels["clear"]),
    )


with gr.Blocks(theme=theme, title="Chronic Wound Classifier · Hôpital Montfort demo") as demo:
    language_radio = gr.Radio(
        choices=["English", "Français"],
        value="English",
        label=LABELS["en"]["lang_radio"],
        interactive=True,
    )
    title_md = gr.Markdown(TITLE["en"])
    description_md = gr.Markdown(DESCRIPTION["en"])

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label=LABELS["en"]["image_input"])
            with gr.Row():
                submit_btn = gr.Button(LABELS["en"]["submit"], variant="primary")
                clear_btn = gr.Button(LABELS["en"]["clear"])
        with gr.Column():
            label_output = gr.Label(num_top_classes=4, label=LABELS["en"]["label_output"])
            notes_output = gr.Markdown(label=LABELS["en"]["notes_output"])
            share_btn = gr.DeepLinkButton(value=LABELS["en"]["share"], icon=None)

    article_md = gr.Markdown(ARTICLE["en"])

    submit_btn.click(
        classify,
        inputs=[image_input, language_radio],
        outputs=[label_output, notes_output],
    )
    image_input.change(
        classify,
        inputs=[image_input, language_radio],
        outputs=[label_output, notes_output],
    )
    clear_btn.click(
        lambda: (None, {}, ""),
        inputs=[],
        outputs=[image_input, label_output, notes_output],
    )

    def _on_language_change(
        language_choice: str, current_image: Image.Image | None
    ) -> tuple[dict, dict, dict, dict, dict, dict, dict, dict, str]:
        lang = _lang_code(language_choice)
        labels = LABELS[lang]
        # Re-run inference so the on-screen probability labels switch languages too.
        new_probs, new_notes = classify(current_image, language_choice)
        return (
            gr.update(value=TITLE[lang]),
            gr.update(value=DESCRIPTION[lang]),
            gr.update(value=ARTICLE[lang]),
            gr.update(label=labels["image_input"]),
            gr.update(label=labels["label_output"], value=new_probs),
            gr.update(value=labels["submit"]),
            gr.update(value=labels["clear"]),
            gr.update(value=labels["share"]),
            new_notes,
        )

    language_radio.change(
        _on_language_change,
        inputs=[language_radio, image_input],
        outputs=[
            title_md,
            description_md,
            article_md,
            image_input,
            label_output,
            submit_btn,
            clear_btn,
            share_btn,
            notes_output,
        ],
    )


if __name__ == "__main__":
    demo.launch()