| """ |
| inference.py |
| ------------ |
| Loads a trained SegFormer checkpoint and runs inference on SEM images. |
| Saves a visualization grid showing: original image | ground truth | predicted mask | overlay. |
| |
| Usage: |
| # Run on a random sample of a specific subset's images (the dataset is loaded with split='all') |
| python inference.py --subset lack_of_fusion |
| |
| # Run on a specific image (NOTE: the argument parser in this file does not define --image yet) |
| python inference.py --image path/to/image.png --subset keyhole |
| |
| # Run all three subsets |
| python inference.py --subset all |
| """ |
|
|
| import argparse |
| import random |
| from pathlib import Path |
| from features import load_model, load_image_tensor |
|
|
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from transformers import SegformerForSemanticSegmentation |
|
|
| from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES, MASK_SCALE |
|
|
| |
# Default number of dataset samples drawn per subset; the CLI entry point
# overwrites this with the value of --n.
N_SAMPLES = 6

# Class index -> (display name, hex color) used for mask rendering and the legend.
LABEL_MAP = {
    0: ("Background", "#1a1a2e"),
    1: ("Defect", "#00d4ff"),
}

# Checkpoints are mapped onto this device when loaded.
DEVICE = torch.device("cpu")
| |
|
|
|
|
def load_model(checkpoint_path: Path) -> SegformerForSemanticSegmentation:
    """Build a background/defect SegFormer and load trained weights into it.

    The network is created from the ``nvidia/mit-b0`` pretrained weights with
    ``NUM_CLASSES`` labels, then the state dict stored at ``checkpoint_path``
    is loaded on top and the model is switched to evaluation mode.

    Note: this definition rebinds the ``load_model`` name that is also
    imported from ``features`` at the top of the file.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` file.

    Returns:
        The SegFormer model in eval mode.
    """
    class_names = ("background", "defect")
    id2label = dict(enumerate(class_names))
    label2id = {name: index for index, name in id2label.items()}

    # ignore_mismatched_sizes tolerates shape differences between the
    # pretrained weights and this label configuration.
    segformer = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # weights_only=True restricts unpickling to tensor data.
    weights = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    segformer.load_state_dict(weights)
    segformer.eval()
    return segformer
|
|
|
|
def load_raw_image(path: Path) -> np.ndarray:
    """Load an SEM image as a displayable uint8 RGB array.

    Pixel values are min-max normalized in float32 before being scaled to
    0-255, so inputs with a wider value range (e.g. 16-bit) are stretched
    into the displayable range rather than truncated.

    Images stored with several channels (RGB, RGBA, gray+alpha) are reduced
    to one intensity plane first; otherwise stacking would produce a 4-D
    array that cannot be displayed or passed to ``Image.fromarray``.

    Args:
        path: Location of the image file.

    Returns:
        Array of shape (H, W, 3), dtype uint8, with three identical channels.
    """
    # The context manager releases the file handle as soon as pixels are read.
    with Image.open(path) as img:
        arr = np.array(img, dtype=np.float32)

    if arr.ndim == 3:
        # Drop a trailing alpha channel (LA / RGBA) before averaging colors.
        if arr.shape[-1] in (2, 4):
            arr = arr[..., :-1]
        arr = arr.mean(axis=-1)

    lo, hi = arr.min(), arr.max()
    # The epsilon keeps a constant image from dividing by zero.
    arr = (arr - lo) / (hi - lo + 1e-8)
    rgb = np.stack([arr, arr, arr], axis=-1)
    return (rgb * 255).astype(np.uint8)
|
|
|
|
def predict(model, image_tensor: torch.Tensor, target_size: tuple) -> np.ndarray:
    """Segment a single image and return its per-pixel class indices.

    Args:
        model: Callable accepting ``pixel_values=`` and returning an object
            with a ``logits`` attribute.
        image_tensor: One un-batched image tensor; a batch axis is added here.
        target_size: Spatial size the logits are resized to before the argmax.

    Returns:
        NumPy array of shape ``target_size`` holding the winning class index
        for every pixel.
    """
    batch = image_tensor.unsqueeze(0)
    with torch.no_grad():
        logits = model(pixel_values=batch).logits
        # Resize the logits to the requested size, then pick the best class.
        resized = F.interpolate(
            logits,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
        class_map = torch.argmax(resized.squeeze(0), dim=0)
    return class_map.numpy()
|
|
|
|
def colorize(mask: np.ndarray) -> np.ndarray:
    """Turn a class-index mask into an RGB image using the ``LABEL_MAP`` colors.

    Pixels whose value is not a key of ``LABEL_MAP`` stay black.

    Args:
        mask: Array of class indices.

    Returns:
        uint8 array with the shape of ``mask`` plus a trailing axis of 3.
    """
    canvas = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, (_name, hex_color) in LABEL_MAP.items():
        digits = hex_color.lstrip("#")
        # Parse "rrggbb" two hex digits at a time.
        red, green, blue = (int(digits[k:k + 2], 16) for k in range(0, 6, 2))
        canvas[mask == class_id] = (red, green, blue)
    return canvas
|
|
|
|
def compute_stats(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Summarize how a predicted mask compares to its ground truth.

    Only class index 1 (defect) is scored.

    Args:
        pred: Predicted class-index mask.
        gt: Ground-truth class-index mask.

    Returns:
        Dict with ``iou`` (NaN when neither mask contains a defect pixel),
        ``pred_coverage`` and ``gt_coverage`` (percentage of pixels labelled
        defect in each mask).
    """
    pred_hits = np.equal(pred, 1)
    gt_hits = np.equal(gt, 1)

    overlap = np.logical_and(pred_hits, gt_hits).sum()
    combined = np.logical_or(pred_hits, gt_hits).sum()

    if combined > 0:
        iou = overlap / combined
    else:
        # IoU is undefined when there is no defect in either mask.
        iou = float("nan")

    return {
        "iou": iou,
        "pred_coverage": pred_hits.sum() / pred.size * 100,
        "gt_coverage": gt_hits.sum() / gt.size * 100,
    }
|
|
|
|
def _show_panel(ax, image: np.ndarray, title: str) -> None:
    """Render ``image`` on ``ax`` with a small title and hidden axes."""
    ax.imshow(image)
    ax.set_title(title, fontsize=7)
    ax.axis("off")


def run_inference(subset: str, args: argparse.Namespace) -> None:
    """Run the subset's checkpoint on sampled images and save a comparison grid.

    For up to ``N_SAMPLES`` randomly chosen dataset items (fixed seed 42) the
    figure gets one row of four panels: input image, ground-truth mask,
    predicted mask, and a 60/40 blend of image and prediction. The figure is
    written to ``output/inference/<subset>_inference.png`` and the mean IoU
    over the sampled images is printed.

    The subset is skipped with a message when its data directory or
    checkpoint is missing, or when there is nothing to sample.

    NOTE(review): the dataset is built with ``split="all"``, so the samples
    may include images seen during training -- confirm against
    ``FractographyDataset``.

    Args:
        subset: Sub-directory name under ``args.data_dir`` and ``checkpoints``.
        args: Parsed CLI arguments; only ``args.data_dir`` is read here.
    """
    data_dir = Path(args.data_dir) / subset
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    out_dir = Path("output") / "inference"
    out_dir.mkdir(parents=True, exist_ok=True)

    if not data_dir.exists():
        print(f"⚠️ Skipping '{subset}' — data not found at {data_dir}")
        return
    if not ckpt_path.exists():
        print(f"⚠️ Skipping '{subset}' — no checkpoint at {ckpt_path}")
        return

    print(f"\n{'='*60}")
    print(f"Inference: {subset}")
    print(f"Checkpoint: {ckpt_path}")
    print(f"{'='*60}")

    model = load_model(ckpt_path)

    # Fixed seed so repeated runs visualize the same images.
    ds = FractographyDataset(data_dir, split="all", image_size=IMAGE_SIZE)
    indices = list(range(len(ds)))
    random.seed(42)
    random.shuffle(indices)
    # Clamp so a negative N_SAMPLES cannot turn into a "drop from the end" slice.
    sample_indices = indices[:max(N_SAMPLES, 0)]

    if not sample_indices:
        # plt.subplots() rejects zero rows, so bail out before plotting.
        print(f"⚠️ Skipping '{subset}' — no samples to visualize in {data_dir}")
        return

    n = len(sample_indices)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n, 4, figsize=(16, n * 4), squeeze=False)
    fig.suptitle(
        f"SegFormer Inference — {subset.replace('_', ' ').title()}",
        fontsize=13, fontweight="bold"
    )

    ious = []
    for row, idx in enumerate(sample_indices):
        img_path, _mask_path = ds.pairs[idx]
        img_tensor, gt_mask = ds[idx]

        raw_img = load_raw_image(img_path)
        gt_arr = gt_mask.numpy()
        pred = predict(model, img_tensor, target_size=IMAGE_SIZE)

        # predict() uses IMAGE_SIZE as (height, width); PIL's resize expects
        # (width, height), hence the swapped indices.
        raw_resized = np.array(
            Image.fromarray(raw_img).resize(
                (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
            )
        )

        stats = compute_stats(pred, gt_arr)
        ious.append(stats["iou"])

        pred_colored = colorize(pred)
        gt_colored = colorize(gt_arr)
        overlay = (raw_resized.astype(float) * 0.6 +
                   pred_colored.astype(float) * 0.4).astype(np.uint8)

        iou_str = f"{stats['iou']:.3f}" if not np.isnan(stats["iou"]) else "N/A"

        _show_panel(axes[row][0], raw_resized, f"Image\n{img_path.name}")
        _show_panel(
            axes[row][1], gt_colored,
            f"Ground Truth\n{stats['gt_coverage']:.1f}% defect",
        )
        _show_panel(
            axes[row][2], pred_colored,
            f"Prediction\n{stats['pred_coverage']:.1f}% defect",
        )
        _show_panel(axes[row][3], overlay, f"Overlay\nIoU={iou_str}")

    patches = [
        mpatches.Patch(color=LABEL_MAP[0][1], label="Background"),
        mpatches.Patch(color=LABEL_MAP[1][1], label="Defect"),
    ]
    fig.legend(handles=patches, loc="lower center", ncol=2,
               bbox_to_anchor=(0.5, -0.01), fontsize=9)

    # Average only the defined IoUs; avoids numpy's "mean of empty slice"
    # warning when every sampled image has an undefined (NaN) IoU.
    valid_ious = [value for value in ious if not np.isnan(value)]
    mean_iou = float(np.mean(valid_ious)) if valid_ious else float("nan")
    fig.text(0.5, -0.03, f"Mean IoU (these samples): {mean_iou:.4f}",
             ha="center", fontsize=10, fontweight="bold")

    fig.tight_layout()
    out_path = out_dir / f"{subset}_inference.png"
    # bbox_inches="tight" keeps the legend/text placed below the axes area.
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    print(f" Mean IoU (sampled): {mean_iou:.4f}")
    print(f" Saved → {out_path.resolve()}")
|
|
|
|
def main() -> None:
    """Parse command-line options and run inference for the requested subset(s).

    ``--subset all`` expands to the three known subsets; any other value is
    passed through as a single subset name. ``--n`` must be at least 1.
    """
    global N_SAMPLES

    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="all",
                        help="lack_of_fusion | keyhole | all_defects | all")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--n", type=int, default=6,
                        help="Number of images to visualize")
    args = parser.parse_args()

    if args.n < 1:
        # A figure needs at least one row; fail early with a usage message.
        parser.error("--n must be a positive integer")

    # run_inference() reads the module-level N_SAMPLES, so publish the CLI value there.
    N_SAMPLES = args.n

    subsets = (
        ["lack_of_fusion", "keyhole", "all_defects"]
        if args.subset == "all"
        else [args.subset]
    )

    for subset in subsets:
        run_inference(subset, args)

    print("\n✅ Done. Check output/inference/")


if __name__ == "__main__":
    main()
|
|