Spaces:
Sleeping
Sleeping
"""Plant Disease Assistant — Hugging Face Space (CPU, DINOv2-only).

Downloads the DINOv2-L checkpoint from a Hugging Face model repo at startup,
then serves image classification plus template-based responses built from a
bundled knowledge file (treatment_knowledge.json). The class-index mapping is
read from the splits.json bundled next to this file.

Configurable via environment variables:
    DINOV2_REPO  HF model repo containing the checkpoint
                 (default: iamcode6/dinov2-l-ccmt-mi300x)
    DINOV2_CKPT  Filename of the checkpoint inside the repo
                 (default: best.pt)
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import timm | |
| import torch | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from timm.data import create_transform | |
# Bundled data files live next to this script.
HERE = Path(__file__).parent
KNOWLEDGE_PATH, SPLITS_PATH = (
    HERE / filename for filename in ("treatment_knowledge.json", "splits.json")
)

# Where the fine-tuned checkpoint is downloaded from; overridable via env vars.
DINOV2_REPO = os.getenv("DINOV2_REPO", "iamcode6/dinov2-l-ccmt-mi300x")
DINOV2_CKPT = os.getenv("DINOV2_CKPT", "best.pt")

# All inference in this app runs on the CPU.
DEVICE = "cpu"
class PlantClassifier:
    """DINOv2 image classifier restored from a fine-tuning checkpoint.

    Class names come from the ``class_to_idx`` mapping in ``splits.json``. The
    timm architecture name and input size come from the checkpoint's optional
    ``cfg["model"]`` section, falling back to
    ``vit_large_patch14_dinov2.lvd142m`` at 224 px.
    """

    def __init__(self, checkpoint_path: Path, splits_path: Path):
        """Build the model, load its weights and prepare the eval transform.

        Args:
            checkpoint_path: Local path to a ``torch.save``-d checkpoint: either a
                bare state dict, or a dict with ``"state_dict"`` and optional ``"cfg"``.
            splits_path: JSON file whose ``"class_to_idx"`` maps class name -> index.

        Raises:
            KeyError: if ``class_to_idx`` is missing or its indices are not
                exactly ``0..N-1``.
        """
        self.device = torch.device(DEVICE)

        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        splits = json.loads(splits_path.read_text(encoding="utf-8"))
        self.idx_to_class = {v: k for k, v in splits["class_to_idx"].items()}
        # Position i of class_names must line up with logit i, so indices are
        # expected to be contiguous 0..N-1 (a gap raises KeyError here).
        self.class_names = [self.idx_to_class[i] for i in range(len(self.idx_to_class))]
        self.num_classes = len(self.class_names)

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code. Only point DINOV2_REPO / DINOV2_CKPT at a repo
        # you trust; consider weights_only=True if "cfg" holds plain data only.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
            # `or {}` also covers a checkpoint that stored cfg=None.
            cfg = ckpt.get("cfg") or {}
        else:
            state_dict = ckpt
            cfg = {}
        # Drop the "_orig_mod." wrapper prefix (presumably from saving a
        # torch.compile'd model) so keys match a plain timm model.
        state_dict = {k.replace("_orig_mod.", "", 1): v for k, v in state_dict.items()}

        model_cfg = cfg.get("model") or {}
        model_name = model_cfg.get("name", "vit_large_patch14_dinov2.lvd142m")
        img_size = model_cfg.get("img_size", 224)
        self.model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=self.num_classes,
            img_size=img_size,
        )
        self.model.load_state_dict(state_dict)
        self.model.to(self.device).eval()

        # Eval-time preprocessing: bicubic resize + center crop, ImageNet mean/std.
        self.transform = create_transform(
            input_size=img_size,
            is_training=False,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            interpolation="bicubic",
            crop_pct=0.95,
        )

    def predict(self, image: Image.Image, top_k: int = 3) -> list[dict]:
        """Return the ``top_k`` most likely classes, highest confidence first.

        Args:
            image: PIL image; converted to RGB first if it is in another mode.
            top_k: Number of predictions to return (capped at the class count).

        Returns:
            A list of ``{"class": name, "confidence": softmax probability,
            "index": class index}`` dicts.
        """
        # The transform normalizes 3 channels; RGBA/greyscale input would not fit.
        if image.mode != "RGB":
            image = image.convert("RGB")
        x = self.transform(image).unsqueeze(0).to(self.device)
        # Pure inference: without this the output requires grad and .numpy()
        # raises RuntimeError; it also avoids building an autograd graph.
        with torch.inference_mode():
            logits = self.model(x)
            probs = F.softmax(logits, dim=-1).squeeze(0).float().cpu().numpy()
        top_indices = np.argsort(probs)[::-1][:top_k]
        return [
            {"class": self.class_names[i], "confidence": float(probs[i]), "index": int(i)}
            for i in top_indices
        ]
class KnowledgeResponder:
    """Render Markdown diagnosis reports from a bundled JSON knowledge base.

    The knowledge file is a JSON object keyed by class label. Entries read here
    provide ``crop``, ``disease`` and ``symptoms``. Entries whose ``disease`` is
    not ``"Healthy"`` must also provide ``severity_cues`` (a level ->
    description mapping), ``treatment`` and ``prevention``. They may also
    provide ``pathogen``.
    """

    def __init__(self, path: Path):
        """Load the knowledge base from the JSON file at ``path``."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        self.knowledge = json.loads(path.read_text(encoding="utf-8"))

    def format_label(self, label: str) -> str:
        """Return a display form of a class label: underscores to spaces, title case."""
        return label.replace("_", " ").title()

    def respond(self, predictions: list[dict], min_other_confidence: float = 0.05) -> str:
        """Build the Markdown report for the top prediction.

        Args:
            predictions: Ranked predictions, best first. Each needs ``"class"``
                and ``"confidence"`` keys, and at least one entry is required.
            min_other_confidence: Runner-ups are listed under "Other
                Possibilities" only when their confidence is strictly above
                this value (default 0.05, i.e. 5%).

        Returns:
            A Markdown string. A label missing from the knowledge base gets a
            short fallback message without the runner-up list.
        """
        top = predictions[0]
        label = top["class"]
        confidence = top["confidence"]

        if label not in self.knowledge:
            return (
                f"**Prediction:** {self.format_label(label)} "
                f"(confidence: {confidence:.1%})\n\n"
                "No detailed information available for this condition."
            )

        k = self.knowledge[label]
        if k["disease"] == "Healthy":
            lines = [
                f"## {k['crop']} — Healthy",
                f"**Confidence:** {confidence:.1%}\n",
                f"{k['symptoms']}",
                "\nKeep monitoring regularly and continue your current care routine.",
            ]
        else:
            lines = [
                f"## {k['crop']} — {k['disease']}",
                f"**Confidence:** {confidence:.1%}\n",
            ]
            if k.get("pathogen"):
                lines.append(f"**Pathogen:** *{k['pathogen']}*\n")
            lines += ["### Symptoms", f"{k['symptoms']}\n", "### Severity Guide"]
            lines += [
                f"- **{level.title()}:** {desc}" for level, desc in k["severity_cues"].items()
            ]
            lines += [
                "",
                "### Treatment",
                f"{k['treatment']}\n",
                "### Prevention",
                f"{k['prevention']}",
            ]

        # Emit the section only when some runner-up clears the threshold, so an
        # empty "Other Possibilities" heading is never rendered.
        others = [p for p in predictions[1:] if p["confidence"] > min_other_confidence]
        if others:
            lines.append("\n---\n### Other Possibilities")
            lines += [f"- {self.format_label(p['class'])} ({p['confidence']:.1%})" for p in others]
        return "\n".join(lines)
def _log(message: str) -> None:
    """Print one startup progress line with the ``[app]`` prefix."""
    print(f"[app] {message}")


# Startup: fetch the checkpoint, build the classifier, load the knowledge base.
_log(f"Downloading DINOv2-L checkpoint from {DINOV2_REPO}...")
checkpoint_path = Path(
    hf_hub_download(repo_id=DINOV2_REPO, filename=DINOV2_CKPT)
)
_log("Loading classifier on CPU (~30s)...")
classifier = PlantClassifier(checkpoint_path, SPLITS_PATH)
_log(f"Loaded {classifier.num_classes} classes")
knowledge = KnowledgeResponder(KNOWLEDGE_PATH)
def diagnose(image: Image.Image | None):
    """Gradio callback: classify ``image`` and build both Markdown panels.

    Returns a ``(classification_table, knowledge_report)`` pair of strings. When
    no image is present it returns an upload prompt and an empty report.
    """
    if image is None:
        return "Please upload an image.", ""

    predictions = classifier.predict(image.convert("RGB"), top_k=3)

    parts = [
        "**DINOv2-L Classification (97% accuracy)**\n\n",
        "| Rank | Disease | Confidence |\n",
        "|------|---------|------------|\n",
    ]
    for rank, pred in enumerate(predictions, start=1):
        # Highlight the top-ranked row.
        flag = " β" if rank == 1 else ""
        name = knowledge.format_label(pred["class"])
        parts.append(f"| {rank} | {name} | {pred['confidence']:.1%}{flag} |\n")

    return "".join(parts), knowledge.respond(predictions)
# Stylesheet passed to gr.Blocks(css=...). It forces dark, fully opaque text on
# prose/Markdown output and bold black headings. Its `.dark`-scoped rules are
# more specific, so they win over the light rules and switch to light text
# (presumably Gradio sets `.dark` in dark mode — confirm). It also gives
# Markdown tables collapsed, visible cell borders.
CUSTOM_CSS = """
.prose, .prose *, [class*="markdown"], [class*="markdown"] * {
color: #1a1a1a !important;
opacity: 1 !important;
}
.prose strong, .prose h1, .prose h2, .prose h3 {
color: #000 !important;
font-weight: 700 !important;
}
.dark .prose, .dark .prose *,
.dark [class*="markdown"], .dark [class*="markdown"] * {
color: #f5f5f5 !important;
}
.dark .prose strong, .dark .prose h1, .dark .prose h2, .dark .prose h3 {
color: #ffffff !important;
}
.prose table { border-collapse: collapse; }
.prose th, .prose td { padding: 6px 10px; border: 1px solid #888; }
"""
# Gradio UI: image input, button and optional examples on the left; the
# classification table and knowledge-base report on the right.
with gr.Blocks(title="Plant Disease Assistant", css=CUSTOM_CSS) as app:
    # Header. The emoji and em dash were mis-encoded ("π±", "β") and are restored here.
    gr.Markdown(
        "# 🌱 Plant Disease Assistant\n"
        "Upload a photo of a plant leaf to get an instant diagnosis, "
        "severity assessment, and treatment recommendations.\n\n"
        "*DINOv2-Large fine-tuned on AMD Instinct MI300X (ROCm) — "
        "97.06% accuracy on the CCMT crop disease dataset.*"
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a plant leaf photo")
            diagnose_btn = gr.Button("Diagnose", variant="primary", size="lg")
            # Sample images bundled in ./examples; the gallery only appears if any exist.
            example_paths = sorted(str(p) for p in (HERE / "examples").glob("*.jpg"))
            if example_paths:
                gr.Examples(
                    examples=[[p] for p in example_paths],
                    inputs=image_input,
                    label="Or try one of these (click a thumbnail)",
                    examples_per_page=11,
                )
        with gr.Column(scale=2):
            classification_output = gr.Markdown()
            response_output = gr.Markdown()

    # Diagnose on an explicit click and also whenever the image value changes.
    # A cleared image arrives as None, which diagnose answers with an upload prompt.
    diagnose_btn.click(
        fn=diagnose,
        inputs=image_input,
        outputs=[classification_output, response_output],
    )
    image_input.change(
        fn=diagnose,
        inputs=image_input,
        outputs=[classification_output, response_output],
    )

    gr.Markdown(
        "---\n"
        "**Model:** DINOv2-Large (304M params) — 97.06% accuracy, 0.9713 macro F1\n\n"
        "**Hardware:** Fine-tuned on AMD Instinct MI300X (192 GB HBM3) via AMD Developer Cloud\n\n"
        "**Dataset:** CCMT Crop Pest and Disease Detection — 22 classes across cashew, cassava, maize, and tomato\n\n"
        "*Built for the lablab.ai AMD Developer Hackathon*"
    )
if __name__ == "__main__":
    # Listen on every network interface at port 7860, with Gradio's API page disabled.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_api=False)
    app.launch(**launch_options)