| """ |
| Inference module for counting wheat heads in field images using a DeepLabV3+ semantic |
| segmentation model trained on the GWFSS dataset. |
| |
| The model performs multi-class segmentation (Background, Leaf, Stem, Head) to accurately |
| distinguish wheat heads from other plant organs, then uses connected component analysis |
| to count individual heads. |
| """ |
|
|
| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import numpy as np |
| import segmentation_models_pytorch as smp |
| from scipy import ndimage |
| from skimage.feature import peak_local_max |
|
|
| |
| IMAGENET_MEAN = [0.485, 0.456, 0.406] |
| IMAGENET_STD = [0.229, 0.224, 0.225] |
|
|
| |
| MASK_COLORS = [ |
| (0, 0, 0), |
| (214, 255, 50), |
| (50, 132, 255), |
| (50, 255, 132), |
| ] |
|
|
| class GWFSSModel: |
| def __init__(self, model_path, device=None): |
| if device is None: |
| if torch.cuda.is_available(): |
| self.device = torch.device("cuda") |
| elif torch.backends.mps.is_available(): |
| self.device = torch.device("mps") |
| else: |
| self.device = torch.device("cpu") |
| else: |
| self.device = device |
| |
| |
| self.model = smp.DeepLabV3Plus( |
| encoder_name="resnet50", |
| encoder_weights=None, |
| in_channels=3, |
| classes=4, |
| ) |
| |
| |
| checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) |
| self.model.load_state_dict(checkpoint['model_state_dict']) |
| self.model = self.model.to(self.device) |
| self.model.eval() |
| |
| |
| self.transform = transforms.Compose([ |
| transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) |
| ]) |
| |
| def preprocess_image(self, image): |
| if isinstance(image, np.ndarray): |
| image = Image.fromarray(image) |
| |
| if image.mode != 'RGB': |
| image = image.convert('RGB') |
| |
| image_tensor = self.transform(image).unsqueeze(0) |
| return image_tensor.to(self.device) |
|
|
| def predict(self, image): |
| if isinstance(image, str): |
| image = Image.open(image) |
| |
| image_tensor = self.preprocess_image(image) |
| |
| with torch.no_grad(): |
| logits = self.model(image_tensor) |
| |
| predictions = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy() |
| return predictions |
|
|
| def count_heads(self, predictions, min_distance=15): |
| head_mask = (predictions == 3).astype(np.uint8) |
| |
| if head_mask.sum() == 0: |
| return 0 |
| |
| |
| distance = ndimage.distance_transform_edt(head_mask) |
| |
| |
| coords = peak_local_max(distance, min_distance=min_distance, labels=head_mask) |
| |
| |
| num_heads = len(coords) |
| |
| return num_heads |
|
|
| def create_colored_mask(self, predictions): |
| h, w = predictions.shape |
| mask_rgb = np.zeros((h, w, 3), dtype=np.uint8) |
| |
| for class_id, color in enumerate(MASK_COLORS): |
| mask_rgb[predictions == class_id] = color |
| |
| return Image.fromarray(mask_rgb) |
|
|
| def overlay_mask(self, image, predictions, alpha=0.5, heads_only=True): |
| if isinstance(image, np.ndarray): |
| image = Image.fromarray(image) |
| |
| if image.size != (512, 512): |
| image = image.resize((512, 512), Image.Resampling.BILINEAR) |
| |
| |
| h, w = predictions.shape |
| mask_rgb = np.zeros((h, w, 3), dtype=np.uint8) |
| |
| if heads_only: |
| |
| mask_rgb[predictions == 3] = (50, 255, 132) |
| else: |
| |
| for class_id, color in enumerate(MASK_COLORS): |
| mask_rgb[predictions == class_id] = color |
| |
| mask_img = Image.fromarray(mask_rgb) |
| overlay = Image.blend(image.convert('RGB'), mask_img, alpha) |
| return overlay |
|
|
| def predict_and_overlay(self, image, alpha=0.5, heads_only=True): |
| predictions = self.predict(image) |
| overlay = self.overlay_mask(image, predictions, alpha=alpha, heads_only=heads_only) |
| return overlay |
|
|
| if __name__ == "__main__": |
| import sys |
| |
| if len(sys.argv) < 2: |
| print("Usage: python inference.py <image_path> [model_path]") |
| sys.exit(1) |
|
|
| image_path = sys.argv[1] |
| model_path = sys.argv[2] if len(sys.argv) > 2 else "cache/02_dice_stem.pth" |
| |
| print(f"Loading model from {model_path}...") |
| model = GWFSSModel(model_path) |
| |
| print(f"Processing image: {image_path}") |
| image = Image.open(image_path) |
| predictions = model.predict(image) |
| |
| |
| num_heads = model.count_heads(predictions) |
| print(f"\n🌾 {num_heads} heads detected") |
| |
| |
| print("\nGenerating visualisations...") |
| overlay_heads = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True) |
| overlay_all = model.overlay_mask(image, predictions, alpha=0.5, heads_only=False) |
| |
| |
| output_heads = image_path.rsplit('.', 1)[0] + '_heads_only.png' |
| output_all = image_path.rsplit('.', 1)[0] + '_all_classes.png' |
| |
| overlay_heads.save(output_heads) |
| overlay_all.save(output_all) |
| |
| print(f"✓ Saved head overlay to: {output_heads}") |
| print(f"✓ Saved full segmentation to: {output_all}") |