"""Plant Disease Assistant — Hugging Face Space (CPU, DINOv2-only). Loads the DINOv2-L checkpoint from a HF model repo at startup, then runs classification + template-based responses from a bundled knowledge file. Configurable via environment variables: DINOV2_REPO HF model repo containing best.pt and splits.json (default: iamcode6/dinov2-l-ccmt-mi300x) DINOV2_CKPT Filename of the checkpoint inside the repo (default: best.pt) """ from __future__ import annotations import json import os from pathlib import Path import gradio as gr import numpy as np import timm import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from PIL import Image from timm.data import create_transform HERE = Path(__file__).parent KNOWLEDGE_PATH = HERE / "treatment_knowledge.json" SPLITS_PATH = HERE / "splits.json" DINOV2_REPO = os.environ.get("DINOV2_REPO", "iamcode6/dinov2-l-ccmt-mi300x") DINOV2_CKPT = os.environ.get("DINOV2_CKPT", "best.pt") DEVICE = "cpu" class PlantClassifier: def __init__(self, checkpoint_path: Path, splits_path: Path): self.device = torch.device(DEVICE) splits = json.loads(splits_path.read_text()) self.idx_to_class = {v: k for k, v in splits["class_to_idx"].items()} self.class_names = [self.idx_to_class[i] for i in range(len(self.idx_to_class))] self.num_classes = len(self.class_names) ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) if isinstance(ckpt, dict) and "state_dict" in ckpt: state_dict = ckpt["state_dict"] cfg = ckpt.get("cfg", {}) else: state_dict = ckpt cfg = {} state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()} model_name = cfg.get("model", {}).get("name", "vit_large_patch14_dinov2.lvd142m") img_size = cfg.get("model", {}).get("img_size", 224) self.model = timm.create_model( model_name, pretrained=False, num_classes=self.num_classes, img_size=img_size, ) self.model.load_state_dict(state_dict) self.model.to(self.device).eval() self.transform = create_transform( input_size=img_size, is_training=False, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), interpolation="bicubic", crop_pct=0.95, ) @torch.no_grad() def predict(self, image: Image.Image, top_k: int = 3) -> list[dict]: x = self.transform(image).unsqueeze(0).to(self.device) logits = self.model(x) probs = F.softmax(logits, dim=-1).squeeze(0).float().cpu().numpy() top_indices = np.argsort(probs)[::-1][:top_k] return [ {"class": self.class_names[i], "confidence": float(probs[i]), "index": int(i)} for i in top_indices ] class KnowledgeResponder: def __init__(self, path: Path): self.knowledge = json.loads(path.read_text()) def format_label(self, label: str) -> str: return label.replace("_", " ").title() def respond(self, predictions: list[dict]) -> str: top = predictions[0] label = top["class"] confidence = top["confidence"] if label not in self.knowledge: return ( f"**Prediction:** {self.format_label(label)} " f"(confidence: {confidence:.1%})\n\n" "No detailed information available for this condition." ) k = self.knowledge[label] is_healthy = k["disease"] == "Healthy" lines = [] if is_healthy: lines.append(f"## {k['crop']} — Healthy") lines.append(f"**Confidence:** {confidence:.1%}\n") lines.append(f"{k['symptoms']}") lines.append("\nKeep monitoring regularly and continue your current care routine.") else: lines.append(f"## {k['crop']} — {k['disease']}") lines.append(f"**Confidence:** {confidence:.1%}\n") if k.get("pathogen"): lines.append(f"**Pathogen:** *{k['pathogen']}*\n") lines.append("### Symptoms") lines.append(f"{k['symptoms']}\n") lines.append("### Severity Guide") for level, desc in k["severity_cues"].items(): lines.append(f"- **{level.title()}:** {desc}") lines.append("") lines.append("### Treatment") lines.append(f"{k['treatment']}\n") lines.append("### Prevention") lines.append(f"{k['prevention']}") if len(predictions) > 1: lines.append("\n---\n### Other Possibilities") for p in predictions[1:]: if p["confidence"] > 0.05: lines.append(f"- {self.format_label(p['class'])} ({p['confidence']:.1%})") return "\n".join(lines) print(f"[app] Downloading DINOv2-L checkpoint from {DINOV2_REPO}...") checkpoint_path = Path(hf_hub_download(repo_id=DINOV2_REPO, filename=DINOV2_CKPT)) print("[app] Loading classifier on CPU (~30s)...") classifier = PlantClassifier(checkpoint_path, SPLITS_PATH) print(f"[app] Loaded {classifier.num_classes} classes") knowledge = KnowledgeResponder(KNOWLEDGE_PATH) def diagnose(image: Image.Image | None): if image is None: return "Please upload an image.", "" image = image.convert("RGB") predictions = classifier.predict(image, top_k=3) table = "**DINOv2-L Classification (97% accuracy)**\n\n" table += "| Rank | Disease | Confidence |\n" table += "|------|---------|------------|\n" for i, p in enumerate(predictions): marker = " ←" if i == 0 else "" table += ( f"| {i+1} | {knowledge.format_label(p['class'])} | " f"{p['confidence']:.1%}{marker} |\n" ) return table, knowledge.respond(predictions) CUSTOM_CSS = """ .prose, .prose *, [class*="markdown"], [class*="markdown"] * { color: #1a1a1a !important; opacity: 1 !important; } .prose strong, .prose h1, .prose h2, .prose h3 { color: #000 !important; font-weight: 700 !important; } .dark .prose, .dark .prose *, .dark [class*="markdown"], .dark [class*="markdown"] * { color: #f5f5f5 !important; } .dark .prose strong, .dark .prose h1, .dark .prose h2, .dark .prose h3 { color: #ffffff !important; } .prose table { border-collapse: collapse; } .prose th, .prose td { padding: 6px 10px; border: 1px solid #888; } """ with gr.Blocks(title="Plant Disease Assistant", css=CUSTOM_CSS) as app: gr.Markdown( "# 🌱 Plant Disease Assistant\n" "Upload a photo of a plant leaf to get an instant diagnosis, " "severity assessment, and treatment recommendations.\n\n" "*DINOv2-Large fine-tuned on AMD Instinct MI300X (ROCm) — " "97.06% accuracy on the CCMT crop disease dataset.*" ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil", label="Upload a plant leaf photo") diagnose_btn = gr.Button("Diagnose", variant="primary", size="lg") example_paths = sorted(str(p) for p in (HERE / "examples").glob("*.jpg")) if example_paths: gr.Examples( examples=[[p] for p in example_paths], inputs=image_input, label="Or try one of these (click a thumbnail)", examples_per_page=11, ) with gr.Column(scale=2): classification_output = gr.Markdown() response_output = gr.Markdown() diagnose_btn.click( fn=diagnose, inputs=image_input, outputs=[classification_output, response_output], ) image_input.change( fn=diagnose, inputs=image_input, outputs=[classification_output, response_output], ) gr.Markdown( "---\n" "**Model:** DINOv2-Large (304M params) — 97.06% accuracy, 0.9713 macro F1\n\n" "**Hardware:** Fine-tuned on AMD Instinct MI300X (192 GB HBM3) via AMD Developer Cloud\n\n" "**Dataset:** CCMT Crop Pest and Disease Detection — 22 classes across cashew, cassava, maize, and tomato\n\n" "*Built for the lablab.ai AMD Developer Hackathon*" ) if __name__ == "__main__": app.launch(server_name="0.0.0.0", server_port=7860, show_api=False)