| """
|
| inspect_dataset.py
|
| ------------------
|
| Inspects the OSF Ti-64 SEM fractography dataset after downloading.
|
| Run after download_osf.py.
|
|
|
| What this does:
|
| 1. Scans the data/ directory and reports what it finds
|
| 2. Detects mask format (grayscale int labels vs RGB color masks)
|
| 3. Prints unique class label values found in masks
|
| 4. Generates a visualization grid of image/mask pairs
|
| 5. Saves visualization to output/inspection_grid.png
|
|
|
| Usage:
|
| python inspect_dataset.py
|
| python inspect_dataset.py --data_dir path/to/your/data
|
| """
|
|
|
| import argparse
|
| import sys
|
| from pathlib import Path
|
|
|
| import matplotlib
|
| matplotlib.use("Agg")
|
| import matplotlib.pyplot as plt
|
| import matplotlib.patches as mpatches
|
| import numpy as np
|
| from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
# Known mask classes: mask value -> (display name, hex colour).
# Values not listed here are reported as unknown by the summary and drawn
# with a colour taken from CMAP.
LABEL_MAP = {
    0: ("Background", "#1a1a2e"),
    1: ("Lack of Fusion", "#e94560"),
    2: ("Keyhole", "#0f3460"),
    3: ("Other Defect", "#533483"),
}

# Fallback palette for mask values that are missing from LABEL_MAP.
# plt.get_cmap is used instead of plt.cm.get_cmap because the latter was
# deprecated in Matplotlib 3.7 and removed in 3.9.
CMAP = plt.get_cmap("tab10")
|
|
|
|
|
|
|
def _index_masks(masks_dir: Path, exts: tuple) -> dict:
    """Map each mask stem in *masks_dir* to its files, keyed by lower-cased suffix."""
    index = {}
    for path in sorted(masks_dir.iterdir()):
        suffix = path.suffix.lower()
        if suffix in exts:
            index.setdefault(path.stem, {}).setdefault(suffix, path)
    return index


def _pick_mask(candidates: dict, preferred_ext: str, exts: tuple):
    """Return the candidate matching *preferred_ext*, else the first by *exts* order, else None."""
    for ext in (preferred_ext, *exts):
        if ext in candidates:
            return candidates[ext]
    return None


def find_image_mask_pairs(data_dir: Path) -> list[tuple[Path, Path]]:
    """Scan *data_dir* recursively for image/mask pairs.

    Two layouts are recognised:

    1. Folder layout: a directory named ``images`` (or ``image``) with a
       sibling directory named ``masks`` (or ``mask``). An image is paired
       with the mask that has the same file stem.
    2. Flat layout (only tried when layout 1 yields nothing): files named
       ``<stem>_image.<ext>`` paired with ``<stem>_mask.<ext>`` in the same
       directory.

    When several masks share a stem, the one with the same extension as the
    image wins, then the first match in the fixed extension order below, so
    the result is deterministic. Mask extensions are matched
    case-insensitively in the folder layout.

    Args:
        data_dir: Root directory to scan.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, ordered by directory
        and then by image file name.
    """
    # A tuple (not a set) so that the probing order is stable between runs.
    image_exts = (".png", ".tif", ".tiff", ".jpg", ".jpeg", ".bmp")
    pairs = []

    # Layout 1: images/ (or image/) next to masks/ (or mask/).
    image_dirs = sorted({*data_dir.rglob("images"), *data_dir.rglob("image")})
    for images_dir in image_dirs:
        if not images_dir.is_dir():
            continue
        masks_dir = next(
            (
                candidate
                for candidate in (images_dir.parent / "masks", images_dir.parent / "mask")
                if candidate.is_dir()
            ),
            None,
        )
        if masks_dir is None:
            print(f"  ⚠️  Found images/ at {images_dir} but no masks/ sibling")
            continue

        # List the mask directory once instead of probing every extension
        # for every image.
        mask_index = _index_masks(masks_dir, image_exts)
        for img_path in sorted(images_dir.iterdir()):
            img_ext = img_path.suffix.lower()
            if img_ext not in image_exts:
                continue
            mask_path = _pick_mask(mask_index.get(img_path.stem, {}), img_ext, image_exts)
            if mask_path is None:
                print(f"  ⚠️  No mask found for {img_path.name}")
            else:
                pairs.append((img_path, mask_path))

    # Layout 2: <stem>_image.<ext> / <stem>_mask.<ext> side by side.
    if not pairs:
        for img_path in sorted(data_dir.rglob("*_image.*")):
            if img_path.suffix.lower() not in image_exts:
                continue
            stem = img_path.stem.replace("_image", "")
            for ext in (img_path.suffix, *image_exts):
                mask_path = img_path.parent / f"{stem}_mask{ext}"
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
                    break

    return pairs
|
|
|
|
|
def inspect_mask(mask_path: Path) -> dict:
    """Return basic statistics about a mask file.

    Args:
        mask_path: Path to the mask image.

    Returns:
        A dict with the keys ``shape``, ``dtype``, ``mode`` (the PIL image
        mode), ``unique_values`` (ascending list), ``min`` and ``max``.
    """
    # Open the file once and close it deterministically; np.array() forces
    # the pixel data to be read while the file is still open.
    with Image.open(mask_path) as mask_img:
        mode = mask_img.mode
        mask = np.array(mask_img)

    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "mode": mode,
        # np.unique already returns its values in ascending order.
        "unique_values": np.unique(mask).tolist(),
        "min": int(mask.min()),
        "max": int(mask.max()),
    }
|
|
|
|
|
def _class_rgb(val) -> tuple[int, int, int]:
    """Return the 0-255 RGB colour used to draw mask value *val*.

    Values listed in LABEL_MAP use their configured hex colour. Any other
    value is taken from CMAP by integer index (wrapping at ``CMAP.N``), so a
    given value gets the same colour in every mask and in the legend.
    """
    if val in LABEL_MAP:
        hex_color = LABEL_MAP[val][1].lstrip("#")
        return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    # Calling a colormap with a Python int indexes its lookup table directly.
    rgba = CMAP(int(val) % CMAP.N)
    return tuple(int(channel * 255) for channel in rgba[:3])


def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Convert an integer label mask to an RGB image for visualization.

    Args:
        mask: Array of class labels, one label per pixel.

    Returns:
        A ``uint8`` array of shape ``(*mask.shape[:2], 3)`` in which every
        pixel carries the colour of its class.
    """
    rgb = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    for val in np.unique(mask):
        rgb[mask == val] = _class_rgb(val)
    return rgb


def make_legend(unique_vals: list[int]) -> list[mpatches.Patch]:
    """Build one legend patch per class value.

    Each patch uses the same colour that colorize_mask() paints for that
    value, including values that are missing from LABEL_MAP, which are
    labelled ``Class <value>``.

    Args:
        unique_vals: Class values to include, in display order.

    Returns:
        A list of matplotlib patches labelled ``"<value>: <name>"``.
    """
    patches = []
    for val in unique_vals:
        label = LABEL_MAP[val][0] if val in LABEL_MAP else f"Class {val}"
        red, green, blue = _class_rgb(val)
        patches.append(
            mpatches.Patch(
                color=f"#{red:02x}{green:02x}{blue:02x}",
                label=f"{val}: {label}",
            )
        )
    return patches
|
|
|
|
|
def visualize_pairs(
    pairs: list[tuple[Path, Path]],
    n: int = 6,
    output_path: Path = Path("output/inspection_grid.png"),
):
    """Save a grid of image / mask / overlay triplets.

    One row is drawn for each of the first ``n`` pairs. Multi-channel masks
    are reduced to a single channel with PIL's ``"L"`` conversion, and a mask
    whose size differs from its image is resized with nearest-neighbour
    sampling so that the overlay can be blended.

    Args:
        pairs: ``(image_path, mask_path)`` tuples.
        n: Maximum number of pairs to draw. Values <= 0 draw nothing.
        output_path: Destination file; parent directories are created.

    Returns:
        None. The figure is written to ``output_path``.
    """
    n = min(n, len(pairs))
    if n <= 0:
        print("  No pairs to visualize.")
        return

    # squeeze=False always yields a 2-D (n, 3) array of axes, even for n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(12, n * 4), squeeze=False)
    fig.suptitle(
        "OSF Ti-64 SEM Dataset — Inspection Grid\n(Image | Mask | Overlay)",
        fontsize=13,
        fontweight="bold",
        y=1.01,
    )

    all_unique = set()

    for row, (img_path, mask_path) in zip(axes, pairs[:n]):
        with Image.open(img_path) as img_pil:
            img = np.array(img_pil.convert("RGB"))
        height, width = img.shape[:2]

        with Image.open(mask_path) as mask_pil:
            # Multi-band masks (RGB, RGBA, ...) are flattened to one channel
            # for display purposes only.
            mask_img = mask_pil.convert("L") if len(mask_pil.getbands()) > 1 else mask_pil
            if mask_img.size != (width, height):
                print(
                    f"  ⚠️  {mask_path.name} is {mask_img.size}, image is "
                    f"{(width, height)}; resizing mask for display"
                )
                # Nearest-neighbour keeps label values intact.
                mask_img = mask_img.resize((width, height), Image.NEAREST)
            mask_arr = np.array(mask_img)

        unique_vals = np.unique(mask_arr).tolist()
        all_unique.update(unique_vals)
        mask_rgb = colorize_mask(mask_arr)

        # 50/50 blend of the image and the colourised mask.
        overlay = (img.astype(float) * 0.5 + mask_rgb.astype(float) * 0.5).astype(np.uint8)

        row[0].imshow(img)
        row[0].set_title(f"Image\n{img_path.name}", fontsize=8)
        row[0].axis("off")

        row[1].imshow(mask_rgb)
        row[1].set_title(
            f"Mask (classes: {unique_vals})\n{mask_path.name}", fontsize=8
        )
        row[1].axis("off")

        row[2].imshow(overlay)
        row[2].set_title("Overlay", fontsize=8)
        row[2].axis("off")

    legend_patches = make_legend(sorted(all_unique))
    fig.legend(
        handles=legend_patches,
        loc="lower center",
        ncol=len(legend_patches),
        bbox_to_anchor=(0.5, -0.02),
        fontsize=9,
        title="Mask Classes Found",
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"\n✅ Visualization saved to: {output_path.resolve()}")
|
|
|
|
|
def print_dataset_summary(
    data_dir: Path,
    pairs: list[tuple[Path, Path]],
    n_samples: int = 5,
):
    """Print a human-readable summary of the dataset to stdout.

    Reports the number of pairs found. When there are none, the expected
    folder layout is printed instead. Otherwise the first ``n_samples``
    masks are inspected with inspect_mask() and their mode, shape, dtype and
    unique values are listed, followed by the interpretation of every value
    according to LABEL_MAP.

    Args:
        data_dir: Dataset root; only used for display.
        pairs: ``(image_path, mask_path)`` tuples.
        n_samples: Number of masks to inspect (default 5).
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Dataset Summary — {data_dir.resolve()}")
    print(rule)
    print(f"Total image/mask pairs found: {len(pairs)}")

    if not pairs:
        print("\n⚠️  No pairs found. Check your data/ folder structure.")
        print("Expected layout:")
        print("  data/")
        print("    <subset>/")
        print("      images/   ← SEM images (.png or .tif)")
        print("      masks/    ← segmentation masks (.png)")
        return

    sampled = pairs[:max(n_samples, 0)]
    print(f"\nSampling first {len(sampled)} masks for format inspection:")
    all_unique = set()
    for _img_path, mask_path in sampled:
        info = inspect_mask(mask_path)
        print(f"\n  {mask_path.name}")
        print(f"    Mode:          {info['mode']}")
        print(f"    Shape:         {info['shape']}")
        print(f"    Dtype:         {info['dtype']}")
        print(f"    Unique values: {info['unique_values']}")
        print(f"    Value range:   [{info['min']}, {info['max']}]")
        all_unique.update(info["unique_values"])

    print(f"\n{'─' * 40}")
    print(f"All unique class values across sampled masks: {sorted(all_unique)}")
    print("\nLabel interpretation:")
    for value in sorted(all_unique):
        label, _ = LABEL_MAP.get(value, ("UNKNOWN — update LABEL_MAP in this script", "#888"))
        # Right-align via str() so that non-int mask values (e.g. floats)
        # do not raise the way a ":3d" format spec would.
        print(f"  {value!s:>3} → {label}")

    print("\n⚠️  NOTE: If all unique values are {0, 255}, masks are binary (defect/no-defect).")
    print("   If values are 0–N, masks are multi-class integer labels — ideal for SegFormer.")
    print("   If mode is 'RGB', masks encode class as color — you'll need to remap.")
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: scan, summarise and visualise the dataset.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv``.

    Exits with status 1 when the data directory is missing, and with
    argparse's usage error when ``--n_vis`` is not positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Path to downloaded dataset root")
    parser.add_argument("--n_vis", type=int, default=6,
                        help="Number of pairs to visualize")
    parser.add_argument("--output", type=str, default="output/inspection_grid.png",
                        help="Where to save the visualization grid")
    args = parser.parse_args(argv)

    if args.n_vis < 1:
        parser.error("--n_vis must be a positive integer")

    data_dir = Path(args.data_dir)
    # is_dir() also rejects a regular file passed as --data_dir.
    if not data_dir.is_dir():
        print(f"❌ data_dir '{data_dir}' does not exist.", file=sys.stderr)
        print("Run download_osf.py first, or set --data_dir to your data folder.",
              file=sys.stderr)
        sys.exit(1)

    print("Scanning for image/mask pairs...")
    pairs = find_image_mask_pairs(data_dir)

    print_dataset_summary(data_dir, pairs)

    if pairs:
        print(f"\nGenerating visualization grid ({min(args.n_vis, len(pairs))} samples)...")
        visualize_pairs(pairs, n=args.n_vis, output_path=Path(args.output))


if __name__ == "__main__":
    main()