# NFNet-F1: Multi-Crop Plant Disease Classification
This repository contains a high-accuracy CNN model for multi-crop plant disease classification from leaf images.
The model is built on the NFNet-F1 architecture, trained from scratch, and evaluated under a strict, non-leaky validation protocol.
## Task Description
- Task: Image-based plant / crop disease classification
- Input: RGB leaf images
- Output: One of 88 crop-disease classes, e.g.:
  - `Tomato__late_blight`
  - `Rice__hispa`
  - `Wheat__yellow_rust`
  - `Apple__black_rot`
This formulation is commonly referred to as crop disease detection in agricultural computer vision literature.
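Class names follow a `Crop__disease` pattern with a double-underscore separator. A tiny hypothetical helper (`parse_label` is not part of this repo) shows how to split a label back into its crop and disease parts:

```python
def parse_label(label: str) -> tuple[str, str]:
    """Split a combined 'Crop__disease' class name into (crop, disease)."""
    # maxsplit=1 keeps any further underscores inside the disease name.
    crop, disease = label.split("__", 1)
    return crop, disease

print(parse_label("Tomato__late_blight"))  # ('Tomato', 'late_blight')
print(parse_label("Wheat__yellow_rust"))   # ('Wheat', 'yellow_rust')
```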
## Dataset
- Source: Merged multi-crop plant disease dataset
- Total images: ~79,000
- Number of classes: 88
- Crops included: Tomato, Rice, Wheat, Apple, Grape, Tea, Corn, Soybean, Potato, and more
- Split:
  - Train: ~71k images
  - Validation: ~8k images
The dataset contains real-world variation (lighting, background, leaf morphology) and class imbalance, making it significantly more challenging than lab-style benchmarks (e.g. PlantVillage).
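With this degree of class imbalance, a stratified hold-out is one way to build a non-leaky split that preserves per-class proportions. A minimal sketch with synthetic file names (the actual split tooling used for this release is not documented):

```python
from sklearn.model_selection import train_test_split

# Synthetic stand-ins: 100 "image paths" across two imbalanced classes.
paths = [f"img_{i:03d}.jpg" for i in range(100)]
labels = ["Tomato__late_blight"] * 60 + ["Rice__hispa"] * 40

# stratify=labels keeps the 60/40 class ratio in both subsets,
# so rare classes are not accidentally absent from validation.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    paths, labels, test_size=0.1, stratify=labels, random_state=42
)

print(len(train_paths), len(val_paths))  # 90 10
```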
## Model Architecture
- Backbone: NFNet-F1
- Parameters: ~130M
- Training: From scratch (no ImageNet pretraining)
- Precision: bfloat16
- Optimizer: AdamW (fused)
- Stabilization: Adaptive Gradient Clipping (AGC)
- Scheduler: Linear warmup + cosine decay
- Input resolution: 512 × 512
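The stabilization and schedule pieces above can be sketched in plain PyTorch. The hyperparameter values and the simplified per-tensor AGC below are illustrative assumptions, not the release configuration (timm also ships an `adaptive_clip_grad` utility):

```python
import torch

def adaptive_grad_clip(params, clip_factor=0.01, eps=1e-3):
    # AGC (Brock et al., 2021): cap each gradient norm at a fraction of the
    # matching parameter norm (per-tensor here; the paper applies it unit-wise).
    for p in params:
        if p.grad is None:
            continue
        max_norm = clip_factor * p.detach().norm().clamp_min(eps)
        g_norm = p.grad.detach().norm()
        if g_norm > max_norm:
            p.grad.mul_(max_norm / (g_norm + 1e-6))

# Tiny stand-in model; the actual run uses timm's NFNet-F1 in bfloat16.
model = torch.nn.Linear(8, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

# Linear warmup (5 epochs) chained into cosine decay (25 epochs).
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=25)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, cosine], milestones=[5])

# One illustrative step: AGC replaces global-norm clipping before the update.
x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
adaptive_grad_clip(model.parameters())
opt.step()
sched.step()
```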
## Performance (Validation)
Best checkpoint: Epoch 25
| Metric | Value |
|---|---|
| Validation Accuracy | 95.40% |
| Weighted F1 | 95.38% |
| Macro F1 | ~88.6% |
These results were obtained on a strict validation split with no data leakage, and macro F1 is reported alongside accuracy, providing a realistic estimate of real-world performance.
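Macro F1 weights every class equally, so rare, poorly-predicted classes pull it below weighted F1, which explains the gap between the two rows above. A small synthetic example of the effect (toy labels, not the real validation set):

```python
from sklearn.metrics import f1_score

# Class 0 has 9 samples, class 1 has 1, and the model never predicts class 1.
y_true = [0] * 9 + [1]
y_pred = [0] * 10

weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

print(f"weighted F1: {weighted:.3f}")  # 0.853 -- dominated by the majority class
print(f"macro F1:    {macro:.3f}")     # 0.474 -- the rare class drags it down
```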
## Repository Files
- `model.safetensors`: Secure, production-ready model weights (preferred format)
- `classification_report.json`: Full per-class precision, recall, and F1 score
- `metrics.json`: Summary validation metrics for the best checkpoint
- `config.json`: Model configuration, class names, and preprocessing details
## Why Accuracy Is Below "99%+" Benchmarks
Some papers report 99%+ accuracy on PlantVillage, which uses controlled lab images with uniform backgrounds and fewer classes.
This model is evaluated on a much harder setting:
- Multi-crop (88 classes)
- Real-world images
- Strong class imbalance
- Macro-F1 reporting
Under these conditions, ~95% accuracy represents strong, competitive performance.
## Intended Use
- Crop disease diagnosis support systems
- Agricultural decision-support tools
- Research and benchmarking
- Backend inference APIs
**Note:** This model is not a medical or agronomic diagnostic tool and should be used as an assistive system only.
## How to Use the Model (Inference)
```python
import json

import timm
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# -------------------------
# Load config
# -------------------------
with open("config.json", "r") as f:
    config = json.load(f)

class_names = config["class_names"]
img_size = config["input_size"]

# -------------------------
# Build model
# -------------------------
model = timm.create_model(
    config["architecture"],
    pretrained=False,
    num_classes=len(class_names),
)

# Load safetensors weights
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# -------------------------
# Preprocessing
# -------------------------
mean = config["normalization"]["mean"]
std = config["normalization"]["std"]

transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# -------------------------
# Load image
# -------------------------
image = Image.open("leaf.jpg").convert("RGB")
x = transform(image).unsqueeze(0)

# -------------------------
# Inference
# -------------------------
with torch.no_grad():
    logits = model(x)

pred_idx = logits.argmax(dim=1).item()
prediction = class_names[pred_idx]

print("Predicted class:", prediction)
```
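The snippet above reports only the arg-max class. For decision-support use it often helps to surface the top-k classes with their softmax probabilities; a hedged extension using stand-in logits and class names:

```python
import torch

# Stand-ins: in practice `logits` comes from the model above and
# `class_names` from config.json (88 entries).
class_names = ["Tomato__late_blight", "Rice__hispa", "Wheat__yellow_rust"]
logits = torch.tensor([[2.0, 0.5, -1.0]])

# Softmax turns raw logits into probabilities summing to 1 per image.
probs = torch.softmax(logits, dim=1)
top_probs, top_idx = probs.topk(k=2, dim=1)

for p, i in zip(top_probs[0], top_idx[0]):
    print(f"{class_names[int(i)]}: {p.item():.3f}")
```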
## License
This project follows the dataset's original license: CC BY-NC-SA 4.0.
Commercial usage may require additional permissions.
## Author
Arko007