Upload 10 files
Browse files- app.py +376 -0
- dataset.py +208 -0
- diagnose.py +420 -0
- download_osf.py +146 -0
- features.py +421 -0
- inference.py +230 -0
- inspect_dataset.py +264 -0
- setup.py +21 -0
- test.py +261 -0
- train.py +306 -0
app.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py
|
| 3 |
+
------
|
| 4 |
+
FailureGPT — Gradio web interface.
|
| 5 |
+
Drag-and-drop SEM image → segmentation → features → AI diagnosis.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
pip install gradio
|
| 9 |
+
python app.py
|
| 10 |
+
|
| 11 |
+
Then open http://127.0.0.1:7860 in your browser.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import gradio as gr
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from dataset import IMAGE_SIZE, NUM_CLASSES
|
| 25 |
+
from features import (
|
| 26 |
+
load_model, load_image_tensor, predict_mask,
|
| 27 |
+
extract_features,
|
| 28 |
+
)
|
| 29 |
+
from diagnose import call_claude, format_diagnosis_report
|
| 30 |
+
|
| 31 |
+
# ── Load all three models at startup ─────────────────────────────────────────
# One fine-tuned segmentation checkpoint per defect subset. Missing
# checkpoints are tolerated: the subset is simply absent from MODELS and
# run_pipeline() reports "No checkpoint" for it.
SUBSETS = ["all_defects", "lack_of_fusion", "keyhole"]
MODELS = {}  # subset name -> loaded model (only subsets whose checkpoint exists)

print("Loading checkpoints...")
for subset in SUBSETS:
    ckpt = Path("checkpoints") / subset / "best_model.pt"
    if ckpt.exists():
        MODELS[subset] = load_model(ckpt)
        print(f" ✅ {subset}")
    else:
        print(f" ⚠️ {subset} — checkpoint not found")

# ─────────────────────────────────────────────────────────────────────────────

# Color coding for the four risk levels returned by the Claude diagnosis.
# NOTE(review): currently unreferenced in this file — presumably intended for
# styling the risk badge; confirm before removing.
RISK_COLORS = {
    "low": "#2ecc71",
    "medium": "#f39c12",
    "high": "#e74c3c",
    "critical": "#8e44ad",
}
|
| 52 |
+
|
| 53 |
+
def run_pipeline(image: np.ndarray, subset: str) -> tuple:
    """Full inference pipeline: preprocess → segment → features → diagnose.

    Args:
        image: H×W×3 uint8 array from the Gradio image component (may also be
            H×W grayscale or H×W×4 RGBA; both are coerced to RGB), or None.
        subset: Key into MODELS ("all_defects", "lack_of_fusion", "keyhole").

    Returns:
        (overlay, features_text, diagnosis_text, risk_label) matching the
        four Gradio output components. On missing input/checkpoint the
        overlay is None and the text fields carry an explanatory message.
    """
    if image is None:
        return None, "No image provided.", "No image provided.", "—"
    if subset not in MODELS:
        return None, f"No checkpoint for '{subset}'.", "Train the model first.", "—"

    model = MODELS[subset]

    # ── Step 1: Preprocess ────────────────────────────────────────────────────
    # Gradio gives H×W×3 uint8; coerce grayscale/RGBA to 3 channels.
    arr = image.astype(np.float32)
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]

    # Min-max normalize to [0,1]; fall back to /255 for a constant image.
    arr_min, arr_max = arr.min(), arr.max()
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-8) if arr_max > arr_min else arr / 255.0

    # Display copy resized to the model input size (PIL resize takes (W, H)).
    display_pil = Image.fromarray(
        (arr_norm * 255).astype(np.uint8), mode="RGB"
    ).resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR)
    display_arr = np.array(display_pil, dtype=np.uint8)

    # ImageNet normalization for the model (matches dataset.py training stats).
    arr_model = np.array(display_pil, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    arr_model = (arr_model - mean) / std
    img_tensor = torch.from_numpy(arr_model).permute(2, 0, 1).float()

    # ── Step 2: Segment ───────────────────────────────────────────────────────
    mask = predict_mask(model, img_tensor, IMAGE_SIZE)

    # ── Step 3: Extract features ──────────────────────────────────────────────
    features = extract_features(mask, IMAGE_SIZE)

    # ── Step 4: Build overlay ─────────────────────────────────────────────────
    # Blend cyan onto defect pixels only; background stays untouched.
    # (Previously the code painted the defect pixels fully cyan into a copy
    # that was then discarded — that dead pre-pass is removed.)
    defect_mask = mask == 1
    overlay = display_arr.copy()
    overlay[defect_mask] = (
        display_arr[defect_mask].astype(float) * 0.3 +
        np.array([0, 212, 255], dtype=float) * 0.7
    ).clip(0, 255).astype(np.uint8)

    # Persist a debug copy; create the directory first so the save cannot
    # fail on a fresh checkout where output/ does not exist yet.
    os.makedirs("output", exist_ok=True)
    Image.fromarray(overlay).save("output/debug_overlay.png")

    # ── Step 5: Format features text ──────────────────────────────────────────
    feat_lines = [
        f"Defect Area: {features['defect_area_fraction']:.3f}%",
        f"Defect Count: {features['defect_count']} blobs",
        f"Mean Pore Area: {features.get('mean_pore_area_px', 0):.1f} px²",
        f"Max Pore Area: {features.get('max_pore_area_px', 0)} px²",
        f"Mean Aspect Ratio: {features['mean_aspect_ratio']:.3f}",
        f" (1.0=circular · >2.0=elongated)",
        f"Spatial Spread: {features['spatial_concentration']:.2f}",
        f"Size Std Dev: {features['size_std']:.1f}",
        f"",
        f"Quadrant Distribution:",
        f" TL {features['quadrant_distribution'][0]:.2f} "
        f"TR {features['quadrant_distribution'][1]:.2f}",
        f" BL {features['quadrant_distribution'][2]:.2f} "
        f"BR {features['quadrant_distribution'][3]:.2f}",
        f"",
        f"Rule-based type: {features['defect_type']}",
        f"Confidence: {features['confidence']}",
    ]
    features_text = "\n".join(feat_lines)

    # ── Step 6: AI Diagnosis ──────────────────────────────────────────────────
    # Without an API key we still return the extracted features so the tool
    # remains useful offline.
    if not os.environ.get("ANTHROPIC_API_KEY"):
        diagnosis_text = (
            "⚠️ ANTHROPIC_API_KEY not set.\n\n"
            "Set it in your terminal:\n"
            " $env:ANTHROPIC_API_KEY = 'sk-ant-...'\n\n"
            "Features extracted successfully:\n\n"
            + features_text
        )
        risk_label = features["defect_type"].upper()
    else:
        diagnosis = call_claude(features, "uploaded_image")
        diagnosis_text = format_diagnosis_report(features, diagnosis, "uploaded_image")
        risk = diagnosis.get("crack_initiation_risk", "unknown")
        mech = diagnosis.get("dominant_failure_mechanism", "unknown")
        risk_label = f"{risk.upper()} RISK — {mech}"

    # Gradio expects an H×W×3 uint8 array for the image output.
    overlay = overlay.astype(np.uint8)
    assert overlay.ndim == 3 and overlay.shape[2] == 3
    return overlay, features_text, diagnosis_text, risk_label
|
| 156 |
+
# ── Gradio UI ─────────────────────────────────────────────────────────────────

# Custom dark "terminal" theme injected via gr.Blocks(css=...). Uses Google
# Fonts (Space Mono / DM Sans); !important is required to override Gradio's
# built-in component styles.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');

body, .gradio-container {
    background: #080c14 !important;
    font-family: 'DM Sans', sans-serif !important;
    color: #c8d6e5 !important;
}

.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
}

/* Header */
#header {
    text-align: center;
    padding: 2rem 0 1rem;
    border-bottom: 1px solid #1e3a5f;
    margin-bottom: 1.5rem;
}

#header h1 {
    font-family: 'Space Mono', monospace !important;
    font-size: 2.4rem !important;
    font-weight: 700 !important;
    color: #00d4ff !important;
    letter-spacing: -1px;
    margin: 0;
}

#header p {
    color: #5a7a9a;
    font-size: 0.9rem;
    margin: 0.4rem 0 0;
    font-family: 'Space Mono', monospace;
}

/* Risk badge */
#risk_label textarea, #risk_label input {
    font-family: 'Space Mono', monospace !important;
    font-size: 1.1rem !important;
    font-weight: 700 !important;
    color: #00d4ff !important;
    background: #0d1825 !important;
    border: 2px solid #00d4ff !important;
    border-radius: 6px !important;
    text-align: center !important;
    padding: 0.6rem !important;
}

/* Textboxes */
textarea {
    font-family: 'Space Mono', monospace !important;
    font-size: 0.78rem !important;
    background: #0a1520 !important;
    color: #a8c4dc !important;
    border: 1px solid #1e3a5f !important;
    border-radius: 6px !important;
    line-height: 1.6 !important;
}

/* Labels */
label span {
    font-family: 'Space Mono', monospace !important;
    font-size: 0.72rem !important;
    color: #4a7a9a !important;
    letter-spacing: 1px !important;
    text-transform: uppercase !important;
}

/* Buttons */
button.primary {
    background: linear-gradient(135deg, #003d66, #006699) !important;
    border: 1px solid #00d4ff !important;
    color: #00d4ff !important;
    font-family: 'Space Mono', monospace !important;
    font-weight: 700 !important;
    letter-spacing: 1px !important;
    border-radius: 6px !important;
    transition: all 0.2s !important;
}

button.primary:hover {
    background: linear-gradient(135deg, #006699, #00aacc) !important;
    box-shadow: 0 0 20px rgba(0, 212, 255, 0.3) !important;
}

button.secondary {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    color: #5a7a9a !important;
    font-family: 'Space Mono', monospace !important;
    border-radius: 6px !important;
}

/* Dropdown */
select, .wrap {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    color: #a8c4dc !important;
    font-family: 'Space Mono', monospace !important;
}

/* Image panels */
.image-container {
    border: 1px solid #1e3a5f !important;
    border-radius: 8px !important;
    overflow: hidden !important;
}

/* Panel blocks */
.block {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    border-radius: 8px !important;
}

/* Footer note */
#footer {
    text-align: center;
    padding: 1rem 0;
    color: #2a4a6a;
    font-family: 'Space Mono', monospace;
    font-size: 0.7rem;
    border-top: 1px solid #1e3a5f;
    margin-top: 1.5rem;
}
"""
|
| 287 |
+
|
| 288 |
+
# Three-column layout: inputs | segmentation overlay + features | diagnosis.
with gr.Blocks(css=CSS, title="FailSafe") as demo:

    gr.HTML("""
    <div id="header">
    <h1>⬡ FAILSAFE</h1>
    <p>Ti-6Al-4V · LPBF Defect Analysis · SEM Fractography · Powered by SegFormer + Claude</p>
    </div>
    """)

    with gr.Row():
        # Left column — inputs
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="SEM FRACTOGRAPH — drag & drop or click to upload",
                type="numpy",  # run_pipeline expects an ndarray
                height=500,
            )
            subset_input = gr.Dropdown(
                choices=SUBSETS,
                value="all_defects",
                label="MODEL SUBSET",
            )
            with gr.Row():
                run_btn = gr.Button("▶ ANALYZE", variant="primary", scale=3)
                clear_btn = gr.Button("✕ CLEAR", variant="secondary", scale=1)

            risk_output = gr.Textbox(
                label="CRACK INITIATION RISK",
                lines=1,
                interactive=False,
                elem_id="risk_label",  # styled as a badge in CSS
            )

        # Middle column — image output
        with gr.Column(scale=1):
            overlay_output = gr.Image(
                label="DEFECT SEGMENTATION MAP",
                height=500,
                interactive=False,
            )
            features_output = gr.Textbox(
                label="MORPHOLOGICAL FEATURES",
                lines=14,
                interactive=False,
            )

        # Right column — diagnosis
        with gr.Column(scale=1):
            diagnosis_output = gr.Textbox(
                label="AI FAILURE DIAGNOSIS — Claude",
                lines=28,
                interactive=False,
            )

    gr.HTML("""
    <div id="footer">
    FailureGPT · ASU Mechanical Engineering · OSF Ti-64 Dataset ·
    SegFormer-b0 fine-tuned · Claude Reasoning Layer
    </div>
    """)

    # Wire up
    run_btn.click(
        fn=run_pipeline,
        inputs=[image_input, subset_input],
        outputs=[overlay_output, features_output, diagnosis_output, risk_output],
    )
    # Resets all five components (input image included) to blank.
    clear_btn.click(
        fn=lambda: (None, None, "", "", ""),
        outputs=[image_input, overlay_output, features_output, diagnosis_output, risk_output],
    )

    # Example images — only shown if the sample dataset is present on disk.
    example_images = list(Path("data/all_defects/images_8bit").glob("*.png"))[:3]
    if example_images:
        gr.Examples(
            examples=[[str(p), "all_defects"] for p in example_images],
            inputs=[image_input, subset_input],
            label="EXAMPLE IMAGES",
        )
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
if __name__ == "__main__":
    # Local-only server (no public share link); show_error surfaces Python
    # tracebacks in the browser instead of a generic failure message.
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=False,
        show_error=True,
    )
|
dataset.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
dataset.py
|
| 3 |
+
----------
|
| 4 |
+
PyTorch Dataset class for the OSF Ti-64 SEM fractography dataset.
|
| 5 |
+
Use this after running inspect_dataset.py to confirm your mask format.
|
| 6 |
+
|
| 7 |
+
Key decisions you may need to make after inspection:
|
| 8 |
+
- If masks are binary (0/255): set NUM_CLASSES=2, update MASK_SCALE
|
| 9 |
+
- If masks are RGB color: set COLOR_MASK=True and define COLOR_TO_LABEL
|
| 10 |
+
- If masks are integer labels (0..N): use as-is (ideal case)
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
from dataset import FractographyDataset
|
| 14 |
+
ds = FractographyDataset("data/", split="train")
|
| 15 |
+
img, mask = ds[0]
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Callable, Optional
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 25 |
+
import torchvision.transforms.functional as TF
|
| 26 |
+
import random
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ── Config — update after running inspect_dataset.py ────────────────────────
|
| 30 |
+
NUM_CLASSES = 2 # update once you know how many classes are in your masks
|
| 31 |
+
IMAGE_SIZE = (512, 512) # resize target; SegFormer-b0 default input
|
| 32 |
+
MASK_SCALE = 255
|
| 33 |
+
# If masks use RGB color encoding instead of integer labels, set this to True
|
| 34 |
+
# and populate COLOR_TO_LABEL below.
|
| 35 |
+
COLOR_MASK = False
|
| 36 |
+
COLOR_TO_LABEL: dict[tuple, int] = {
|
| 37 |
+
# (R, G, B): class_index
|
| 38 |
+
# e.g. (255, 0, 0): 1,
|
| 39 |
+
}
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def rgb_mask_to_label(mask_rgb: np.ndarray, color_to_label: dict) -> np.ndarray:
    """Convert an H×W×3 RGB mask to an H×W integer label mask.

    Pixels whose color is not present in ``color_to_label`` are left at
    class 0 (background).
    """
    out = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
    for rgb, cls in color_to_label.items():
        hits = (mask_rgb == np.asarray(rgb)).all(axis=-1)
        out[hits] = cls
    return out
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FractographyDataset(Dataset):
    """
    OSF Ti-64 SEM Fractography Dataset.

    Pairs every image under an ``images_8bit`` directory with the same-named
    file in a sibling ``masks_8bit`` directory, anywhere below ``data_dir``.

    Args:
        data_dir: Root of downloaded data (contains subfolders with images/ + masks/).
        split: "train", "val", or "all" (no splitting, returns everything).
        transform: Optional callable applied to both image and mask (augmentation).
        image_size: Resize target (H, W).
    """

    # Extensions accepted when scanning images_8bit/ for images.
    IMAGE_EXTS = {".png", ".tif", ".tiff", ".jpg", ".jpeg"}

    def __init__(
        self,
        data_dir: str | Path,
        split: str = "all",
        transform: Optional[Callable] = None,
        image_size: tuple[int, int] = IMAGE_SIZE,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.pairs = self._find_pairs()

        if not self.pairs:
            raise FileNotFoundError(
                f"No image/mask pairs found in {self.data_dir}. "
                "Run inspect_dataset.py to diagnose."
            )

    def _find_pairs(self) -> list[tuple[Path, Path]]:
        """Return sorted (image_path, mask_path) pairs found under data_dir.

        Images without a same-named mask file are silently skipped.
        """
        pairs = []
        for images_dir in sorted(self.data_dir.rglob("images_8bit")):
            if not images_dir.is_dir():
                continue
            masks_dir = images_dir.parent / "masks_8bit"
            if not masks_dir.exists():
                continue
            for img_path in sorted(images_dir.iterdir()):
                if img_path.suffix.lower() not in self.IMAGE_EXTS:
                    continue
                mask_path = masks_dir / img_path.name
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
        return pairs

    def _load_image(self, path: Path) -> torch.Tensor:
        """Load, resize, and ImageNet-normalize an image → C×H×W float tensor."""
        img = Image.open(path).convert("RGB")
        # PIL resize takes (W, H); image_size is (H, W).
        img = img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        arr = (arr - mean) / std
        return torch.from_numpy(arr).permute(2, 0, 1).float()

    def _load_mask(self, path: Path) -> torch.Tensor:
        """Load a mask and return it as an H×W integer label tensor."""
        mask_pil = Image.open(path)

        if COLOR_MASK:
            # RGB-encoded masks: map each color to its class index.
            mask_arr = np.array(mask_pil.convert("RGB"))
            mask_arr = rgb_mask_to_label(mask_arr, COLOR_TO_LABEL)
        else:
            mask_arr = np.array(mask_pil.convert("L"), dtype=np.int64)
            if MASK_SCALE > 1:
                mask_arr = mask_arr // MASK_SCALE  # e.g. 0/255 → 0/1

        mask_pil_resized = Image.fromarray(mask_arr.astype(np.uint8)).resize(
            (self.image_size[1], self.image_size[0]), Image.NEAREST  # NEAREST preserves labels
        )
        mask_arr = np.array(mask_pil_resized, dtype=np.int64)
        return torch.from_numpy(mask_arr).long()  # H×W

    def _augment(self, image: torch.Tensor, mask: torch.Tensor):
        """Shared spatial augmentations (applied identically to image and mask)."""
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            # TF.hflip expects a channel dim; masks are H×W, so add/remove one.
            mask = TF.hflip(mask.unsqueeze(0)).squeeze(0)

        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask.unsqueeze(0)).squeeze(0)

        # Random 90° rotation
        k = random.choice([0, 1, 2, 3])
        if k:
            image = torch.rot90(image, k, dims=[1, 2])
            mask = torch.rot90(mask, k, dims=[0, 1])

        return image, mask

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        img_path, mask_path = self.pairs[idx]
        image = self._load_image(img_path)
        mask = self._load_mask(mask_path)

        # Built-in augmentation only fires on the train split when no custom
        # transform is supplied; a custom transform replaces it entirely.
        if self.split == "train" and self.transform is None:
            image, mask = self._augment(image, mask)
        elif self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

    def __repr__(self) -> str:
        return (
            f"FractographyDataset("
            f"n={len(self)}, split='{self.split}', "
            f"image_size={self.image_size}, classes={NUM_CLASSES})"
        )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_dataloaders(
    data_dir: str | Path,
    batch_size: int = 4,
    train_frac: float = 0.8,
    num_workers: int = 2,
) -> tuple[DataLoader, DataLoader]:
    """
    Returns (train_loader, val_loader) with an 80/20 split (by default).

    Args:
        data_dir: Root passed through to FractographyDataset.
        batch_size: Batch size for both loaders.
        train_frac: Fraction of samples assigned to the train split.
        num_workers: DataLoader worker processes.
    """
    # BUG FIX: the previous implementation used random_split() and then set
    # `train_ds.dataset.split = "train"`. random_split returns Subset views
    # over ONE shared dataset object, so that assignment also switched the
    # validation subset into augmentation mode — validation images were
    # being randomly flipped/rotated. Build two independent dataset
    # instances instead, and split indices over them explicitly.
    train_base = FractographyDataset(data_dir, split="train")  # augments
    val_base = FractographyDataset(data_dir, split="val")      # no augmentation

    n_total = len(train_base)
    n_train = int(n_total * train_frac)
    indices = torch.randperm(n_total).tolist()
    train_ds = torch.utils.data.Subset(train_base, indices[:n_train])
    val_ds = torch.utils.data.Subset(val_base, indices[n_train:])

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    print(f"Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
    return train_loader, val_loader
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ── Quick sanity check ────────────────────────────────────────────────────────
# Usage: python dataset.py [data_dir]   (data_dir defaults to "data")
if __name__ == "__main__":
    import sys
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "data"

    try:
        ds = FractographyDataset(data_dir)
        print(ds)
        # Load one sample to verify shapes, dtypes, and label values.
        img, mask = ds[0]
        print(f"Image tensor: {img.shape} dtype={img.dtype} range=[{img.min():.2f}, {img.max():.2f}]")
        print(f"Mask tensor: {mask.shape} dtype={mask.dtype} unique={mask.unique().tolist()}")
        print("\n✅ Dataset loads correctly.")
    except FileNotFoundError as e:
        # Raised by FractographyDataset when no image/mask pairs are found.
        print(f"❌ {e}")
|
diagnose.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
diagnose.py
|
| 3 |
+
-----------
|
| 4 |
+
Week 4: Generative reasoning layer.
|
| 5 |
+
|
| 6 |
+
Takes the feature dict output from features.py and calls the Claude API
|
| 7 |
+
to produce a structured engineering failure diagnosis.
|
| 8 |
+
|
| 9 |
+
The LLM receives:
|
| 10 |
+
- Quantitative morphological features from the segmentation
|
| 11 |
+
- Material context (Ti-6Al-4V, LPBF process)
|
| 12 |
+
- Defect type classification
|
| 13 |
+
And returns:
|
| 14 |
+
- Natural language diagnosis
|
| 15 |
+
- Crack initiation risk assessment
|
| 16 |
+
- Recommended follow-up actions
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
# Single image full pipeline (segment → extract → diagnose)
|
| 20 |
+
python diagnose.py --image data/all_defects/images/001-Overview-EP04V24.png
|
| 21 |
+
--subset all_defects
|
| 22 |
+
|
| 23 |
+
# From existing feature JSON
|
| 24 |
+
python diagnose.py --json output/features/all_defects_features.json
|
| 25 |
+
|
| 26 |
+
# Interactive mode
|
| 27 |
+
python diagnose.py --interactive --subset all_defects
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import json
|
| 32 |
+
import time
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
import numpy as np
|
| 38 |
+
import matplotlib
|
| 39 |
+
matplotlib.use("Agg")
|
| 40 |
+
import matplotlib.pyplot as plt
|
| 41 |
+
from PIL import Image
|
| 42 |
+
from transformers import SegformerForSemanticSegmentation
|
| 43 |
+
|
| 44 |
+
from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES
|
| 45 |
+
from features import (
|
| 46 |
+
load_model, load_image_tensor, predict_mask,
|
| 47 |
+
extract_features, visualize_features
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# ── Anthropic API ─────────────────────────────────────────────────────────────
# Optional dependency: import is guarded so the rest of the pipeline
# (segmentation + feature extraction) still works without the SDK installed.
try:
    import anthropic
    HAS_ANTHROPIC = True
except ImportError:
    HAS_ANTHROPIC = False
    print("⚠️ anthropic package not found. Run: pip install anthropic")
# ─────────────────────────────────────────────────────────────────────────────

# Fixed material/process context prepended to every diagnosis prompt.
MATERIAL_CONTEXT = """
Material: Ti-6Al-4V (Grade 5 titanium alloy)
Process: Laser Powder Bed Fusion (LPBF) additive manufacturing
Application context: High-performance structural components (aerospace/defense)
Specimen type: Bend test bar, fractured in four-point bending
"""
|
| 65 |
+
|
| 66 |
+
SYSTEM_PROMPT = """You are an expert materials engineer specializing in fractography
|
| 67 |
+
and failure analysis of additively manufactured aerospace components.
|
| 68 |
+
You analyze quantitative defect features extracted from SEM (Scanning Electron Microscope)
|
| 69 |
+
images of Ti-6Al-4V fracture surfaces produced by Laser Powder Bed Fusion (LPBF).
|
| 70 |
+
|
| 71 |
+
Your role is to:
|
| 72 |
+
1. Interpret morphological defect features in the context of LPBF process physics
|
| 73 |
+
2. Assess crack initiation and propagation risk based on defect characteristics
|
| 74 |
+
3. Provide actionable engineering recommendations
|
| 75 |
+
4. Be precise and quantitative — reference the actual feature values in your diagnosis
|
| 76 |
+
|
| 77 |
+
Always structure your response as valid JSON with these exact keys:
|
| 78 |
+
{
|
| 79 |
+
"diagnosis_summary": "2-3 sentence plain English summary",
|
| 80 |
+
"defect_interpretation": "detailed interpretation of the morphological features",
|
| 81 |
+
"crack_initiation_risk": "low | medium | high | critical",
|
| 82 |
+
"risk_rationale": "why you assigned this risk level, referencing specific features",
|
| 83 |
+
"dominant_failure_mechanism": "e.g. lack of fusion porosity, keyhole porosity, mixed",
|
| 84 |
+
"critical_regions": "which quadrants or regions pose highest risk",
|
| 85 |
+
"recommendations": ["recommendation 1", "recommendation 2", "recommendation 3"],
|
| 86 |
+
"confidence": "low | medium | high",
|
| 87 |
+
"confidence_rationale": "why"
|
| 88 |
+
}
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def build_user_prompt(features: dict, image_name: str = "") -> str:
    """Build the user-turn prompt for the LLM from extracted defect features.

    Args:
        features: Feature dict produced by features.extract_features().
        image_name: Original image filename, included for traceability.

    Returns:
        A formatted prompt string embedding the material context and all
        quantitative morphological features.
    """
    # Fix: the quadrant list was indexed [0]..[3] with a .get() default that
    # only protects against a *missing* key — a present-but-short (or None)
    # list raised IndexError. Hoist it once and pad defensively.
    quads = list(features.get("quadrant_distribution") or [])
    quads = (quads + [0.0] * 4)[:4]
    return f"""
Analyze the following defect features extracted from an SEM fractograph of a
Ti-6Al-4V LPBF test bar.

Material & Process Context:
{MATERIAL_CONTEXT}

Image: {image_name}

Extracted Morphological Features:
- Defect area fraction: {features.get('defect_area_fraction', 0):.3f}% of fracture surface
- Defect blob count: {features.get('defect_count', 0)} distinct pores/defects
- Mean pore area: {features.get('mean_pore_area_px', 0):.1f} px² (at 256×256 resolution)
- Max pore area: {features.get('max_pore_area_px', 0)} px²
- Mean aspect ratio: {features.get('mean_aspect_ratio', 0):.3f}
  (1.0 = perfectly circular/keyhole, >2.0 = elongated/lack-of-fusion)
- Spatial spread (std): {features.get('spatial_concentration', 0):.2f} px
- Size heterogeneity: {features.get('size_std', 0):.1f} px² std dev
- Quadrant distribution:
    Top-left: {quads[0]:.3f}
    Top-right: {quads[1]:.3f}
    Bottom-left: {quads[2]:.3f}
    Bottom-right: {quads[3]:.3f}
- Rule-based defect type: {features.get('defect_type', 'unknown')}
  (confidence: {features.get('confidence', 'unknown')})

Provide a structured engineering diagnosis as JSON.
"""
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def call_claude(features: dict, image_name: str = "") -> dict:
    """Query the Claude API for a structured engineering diagnosis.

    Returns the parsed JSON diagnosis on success; on failure returns a dict
    with an "error" key (and the raw model text when JSON parsing fails).
    """
    if not HAS_ANTHROPIC:
        return {"error": "anthropic package not installed"}

    # The client picks up ANTHROPIC_API_KEY from the environment.
    client = anthropic.Anthropic()
    user_prompt = build_user_prompt(features, image_name)

    try:
        response = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=1000,
            system=SYSTEM_PROMPT,
            messages=[{"role": "user", "content": user_prompt}],
        )
        text = response.content[0].text.strip()

        # The model sometimes wraps its JSON in a markdown code fence,
        # optionally tagged "json" — peel both off before parsing.
        if text.startswith("```"):
            text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
        text = text.strip()

        return json.loads(text)

    except json.JSONDecodeError as exc:
        # `text` is always bound here: JSONDecodeError can only come from
        # json.loads, which runs after the assignment.
        return {"error": f"JSON parse error: {exc}", "raw": text}
    except Exception as exc:
        return {"error": str(exc)}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def format_diagnosis_report(features: dict, diagnosis: dict, image_name: str = "") -> str:
    """Render the features and AI diagnosis as a plain-text report.

    If the diagnosis dict carries an "error" key, the report stops after
    the quantitative section with the error message.
    """
    sep = "=" * 60
    report = [
        sep,
        "FAILURE ANALYSIS REPORT",
        f"Image: {image_name}",
        "Material: Ti-6Al-4V (LPBF)",
        sep,
        "",
        "QUANTITATIVE FEATURES",
        f"  Defect area:      {features.get('defect_area_fraction', 0):.3f}%",
        f"  Defect count:     {features.get('defect_count', 0)}",
        f"  Mean aspect ratio:{features.get('mean_aspect_ratio', 0):.3f}",
        f"  Rule-based type:  {features.get('defect_type', 'unknown')}",
        "",
    ]

    # API failure: truncate here rather than print empty diagnosis fields.
    if "error" in diagnosis:
        report.append(f"⚠️ Diagnosis error: {diagnosis['error']}")
        return "\n".join(report)

    report.extend([
        "AI DIAGNOSIS",
        f"  Failure mechanism: {diagnosis.get('dominant_failure_mechanism', 'N/A')}",
        f"  Crack init. risk:  {diagnosis.get('crack_initiation_risk', 'N/A').upper()}",
        f"  Critical regions:  {diagnosis.get('critical_regions', 'N/A')}",
        f"  Confidence:        {diagnosis.get('confidence', 'N/A')}",
        "",
        "SUMMARY",
        f"  {diagnosis.get('diagnosis_summary', '')}",
        "",
        "DEFECT INTERPRETATION",
        f"  {diagnosis.get('defect_interpretation', '')}",
        "",
        "RISK RATIONALE",
        f"  {diagnosis.get('risk_rationale', '')}",
        "",
        "RECOMMENDATIONS",
    ])
    for idx, rec in enumerate(diagnosis.get("recommendations", []), start=1):
        report.append(f"  {idx}. {rec}")
    report.append(sep)

    return "\n".join(report)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def visualize_diagnosis(
    image_path: Path,
    mask: np.ndarray,
    features: dict,
    diagnosis: dict,
    out_path: Path,
):
    """Save a full diagnosis visualization.

    Renders a 3-panel dark-themed figure — raw SEM image, defect overlay,
    and AI diagnosis text (border color-coded by crack-initiation risk) —
    and writes it to ``out_path``, creating parent directories as needed.

    Args:
        image_path: Path to the original SEM image on disk.
        mask: Binary prediction mask (1 = defect) at IMAGE_SIZE resolution.
        features: Feature dict from features.extract_features(); the keys
            'defect_area_fraction', 'defect_count', 'mean_aspect_ratio'
            are accessed directly and must be present.
        diagnosis: Parsed diagnosis dict from call_claude(); may contain
            an "error" key, in which case the text panel shows the error.
        out_path: Destination PNG path.
    """
    # Min-max normalize to [0, 1], then resize to the model's working
    # resolution so the image aligns pixel-for-pixel with `mask`.
    # NOTE(review): assumes a single-channel (grayscale) SEM image; an RGB
    # input would make the later np.stack produce a 4-D array — confirm.
    raw = np.array(Image.open(image_path), dtype=np.float32)
    raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
    raw_resized = np.array(
        Image.fromarray((raw * 255).astype(np.uint8)).resize(
            (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
        )
    )

    # Risk color: accent color for the text-panel border, keyed to the
    # assessed risk level (gray fallback for unexpected values).
    risk_colors = {
        "low": "#2ecc71", "medium": "#f39c12",
        "high": "#e74c3c", "critical": "#8e44ad"
    }
    risk = diagnosis.get("crack_initiation_risk", "medium")
    risk_color = risk_colors.get(risk, "#888888")

    fig = plt.figure(figsize=(18, 8))
    fig.patch.set_facecolor("#0d0d1a")

    # Title
    mech = diagnosis.get("dominant_failure_mechanism", "Unknown")
    fig.suptitle(
        f"FailureGPT — {image_path.name}\n"
        f"Mechanism: {mech} | Crack Risk: {risk.upper()}",
        fontsize=12, fontweight="bold", color="white", y=1.01
    )

    # Image panel (1/3): raw SEM fractograph.
    ax1 = fig.add_subplot(1, 3, 1)
    ax1.imshow(raw_resized, cmap="gray")
    ax1.set_title("SEM Fractograph", color="white", fontsize=9)
    ax1.axis("off")
    ax1.set_facecolor("#0d0d1a")

    # Segmentation overlay (2/3): defect pixels painted cyan on top of a
    # 3-channel copy of the grayscale image.
    ax2 = fig.add_subplot(1, 3, 2)
    overlay = np.stack([raw_resized]*3, axis=-1).copy()
    overlay[mask == 1] = [0, 212, 255]
    ax2.imshow(overlay)
    ax2.set_title(
        f"Defect Map\n{features['defect_area_fraction']:.2f}% | "
        f"{features['defect_count']} blobs | AR={features['mean_aspect_ratio']:.2f}",
        color="white", fontsize=9
    )
    ax2.axis("off")
    ax2.set_facecolor("#0d0d1a")

    # Diagnosis text panel (3/3).
    ax3 = fig.add_subplot(1, 3, 3)
    ax3.set_facecolor("#0d0d1a")
    ax3.axis("off")

    if "error" not in diagnosis:
        summary = diagnosis.get("diagnosis_summary", "")
        interp = diagnosis.get("defect_interpretation", "")
        recs = diagnosis.get("recommendations", [])
        conf = diagnosis.get("confidence", "")

        # Word wrap helper: greedy fill to `width` characters per line.
        def wrap(text, width=42):
            words, lines, line = text.split(), [], ""
            for w in words:
                if len(line) + len(w) + 1 <= width:
                    line += (" " if line else "") + w
                else:
                    lines.append(line)
                    line = w
            if line:
                lines.append(line)
            return "\n".join(lines)

        # Long free-text fields are truncated (200 / 80 chars) so the
        # panel stays legible at the chosen font size.
        report = (
            f"RISK: {risk.upper()}\n"
            f"{'─'*38}\n\n"
            f"SUMMARY\n{wrap(summary)}\n\n"
            f"INTERPRETATION\n{wrap(interp[:200])}\n\n"
            f"RECOMMENDATIONS\n"
        )
        for i, r in enumerate(recs[:3], 1):
            report += f"{i}. {wrap(r[:80])}\n"
        report += f"\nConfidence: {conf}"

        ax3.text(
            0.05, 0.97, report,
            transform=ax3.transAxes,
            fontsize=7.5, verticalalignment="top",
            fontfamily="monospace", color="white",
            bbox=dict(
                boxstyle="round", facecolor="#1a1a2e",
                alpha=0.9, edgecolor=risk_color, linewidth=2
            )
        )
    else:
        ax3.text(
            0.1, 0.5, f"API Error:\n{diagnosis['error']}",
            transform=ax3.transAxes, color="red", fontsize=9
        )

    ax3.set_title("AI Diagnosis", color="white", fontsize=9)

    plt.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150, bbox_inches="tight",
                facecolor="#0d0d1a")
    plt.close()
    print(f" Visualization → {out_path.resolve()}")
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def run_full_pipeline(image_path: Path, subset: str, save_vis: bool = True) -> dict:
    """Run the end-to-end chain: image → segmentation → features → diagnosis.

    Prints a progress banner and the formatted report, writes the
    visualization (optional) and a JSON result to output/diagnosis/, and
    returns {"image", "features", "diagnosis"} (empty dict if no checkpoint).
    """
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"❌ No checkpoint at {ckpt_path}")
        return {}

    banner = "=" * 60
    print(f"\n{banner}")
    print("FailureGPT Pipeline")
    print(f"Image: {image_path.name}")
    print(f"Subset: {subset}")
    print(banner)

    # Step 1: segment the fractograph with the trained SegFormer.
    print("Step 1/3: Segmenting...")
    model = load_model(ckpt_path)
    pixels = load_image_tensor(image_path, IMAGE_SIZE)
    mask = predict_mask(model, pixels, IMAGE_SIZE)

    # Step 2: morphological feature extraction from the binary mask.
    print("Step 2/3: Extracting features...")
    features = extract_features(mask, IMAGE_SIZE)
    print(f" → {features['defect_count']} blobs, "
          f"{features['defect_area_fraction']:.2f}% defect, "
          f"AR={features['mean_aspect_ratio']:.2f}")

    # Step 3: LLM reasoning over the features.
    print("Step 3/3: Generating diagnosis...")
    diagnosis = call_claude(features, image_path.name)

    # Human-readable report to stdout.
    print(format_diagnosis_report(features, diagnosis, image_path.name))

    out_dir = Path("output/diagnosis")
    if save_vis:
        visualize_diagnosis(
            image_path, mask, features, diagnosis,
            out_dir / f"{image_path.stem}_diagnosis.png",
        )

    # Persist the structured result as JSON next to the visualization.
    result = {"image": str(image_path), "features": features, "diagnosis": diagnosis}
    json_out = out_dir / f"{image_path.stem}_diagnosis.json"
    json_out.parent.mkdir(parents=True, exist_ok=True)
    with open(json_out, "w") as fh:
        json.dump(result, fh, indent=2)
    print(f" JSON → {json_out.resolve()}")

    return result
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def interactive_mode(subset: str, data_dir: Path):
    """Interactive CLI: list the subset's images, prompt for one, diagnose it."""
    ds = FractographyDataset(data_dir / subset, split="all", image_size=IMAGE_SIZE)

    # Show at most the first 20 images so the menu stays scannable.
    print(f"\nAvailable images in '{subset}':")
    for idx, (path, _) in enumerate(ds.pairs[:20]):
        print(f" [{idx:2d}] {path.name}")

    try:
        chosen = int(input("\nEnter image index: "))
        img_path, _ = ds.pairs[chosen]
        run_full_pipeline(img_path, subset)
    except (ValueError, IndexError) as exc:
        print(f"Invalid selection: {exc}")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
if __name__ == "__main__":
    # CLI: four mutually-exclusive modes, checked in priority order —
    # interactive > existing-JSON > single image > batch over the subset.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--subset", type=str, default="all_defects")
    parser.add_argument("--json", type=str, default=None,
                        help="Path to existing features JSON from features.py")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--n", type=int, default=3,
                        help="Number of images to process in batch mode")
    args = parser.parse_args()

    if args.interactive:
        interactive_mode(args.subset, Path(args.data_dir))

    elif args.json:
        # Diagnose directly from a features JSON produced by features.py.
        payload = json.loads(Path(args.json).read_text())
        if isinstance(payload, list):
            for entry in payload[:args.n]:
                name = entry.get("image", "")
                verdict = call_claude(entry, name)
                print(format_diagnosis_report(entry, verdict, name))
        else:
            verdict = call_claude(payload)
            print(format_diagnosis_report(payload, verdict))

    elif args.image:
        run_full_pipeline(Path(args.image), args.subset)

    else:
        # Batch: first n image/mask pairs of the chosen subset.
        subset_dir = Path(args.data_dir) / args.subset
        ds = FractographyDataset(subset_dir, split="all", image_size=IMAGE_SIZE)
        for img_path, _ in list(ds.pairs)[:args.n]:
            run_full_pipeline(img_path, args.subset)
|
download_osf.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
download_osf.py
|
| 3 |
+
---------------
|
| 4 |
+
Downloads the OSF Ti-64 SEM fractography dataset (osf.io/gdwyb).
|
| 5 |
+
The dataset has 3 sub-components:
|
| 6 |
+
- Lack of Fusion defects
|
| 7 |
+
- Keyhole defects
|
| 8 |
+
- All Defects (combined)
|
| 9 |
+
|
| 10 |
+
Each sub-dataset contains SEM images + ground truth segmentation masks.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python download_osf.py
|
| 14 |
+
|
| 15 |
+
Output structure:
|
| 16 |
+
data/
|
| 17 |
+
lack_of_fusion/
|
| 18 |
+
images/
|
| 19 |
+
masks/
|
| 20 |
+
keyhole/
|
| 21 |
+
images/
|
| 22 |
+
masks/
|
| 23 |
+
all_defects/
|
| 24 |
+
images/
|
| 25 |
+
masks/
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import os
|
| 29 |
+
import requests
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
|
| 33 |
+
# OSF project GUIDs for each sub-dataset
|
| 34 |
+
# Inspect at: https://osf.io/gdwyb/
|
| 35 |
+
OSF_API = "https://api.osf.io/v2"
|
| 36 |
+
OSF_PROJECT_ID = "gdwyb" # top-level fractography project
|
| 37 |
+
|
| 38 |
+
DATA_DIR = Path("data")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def list_osf_files(node_id: str) -> list[dict]:
    """Recursively list all files stored in an OSF node's osfstorage.

    Delegates to list_osf_folder(), which already implements the identical
    pagination loop and folder recursion — the previous body here was a
    line-for-line duplicate of it (and recursed into it anyway).

    Args:
        node_id: OSF node GUID (e.g. "gdwyb").

    Returns:
        A list of dicts with keys "name", "path" (materialized path),
        "download" (direct download URL), and "size" (bytes).
    """
    return list_osf_folder(f"{OSF_API}/nodes/{node_id}/files/osfstorage/")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def list_osf_folder(url: str) -> list[dict]:
    """Recursively collect file entries under an OSF folder-listing URL.

    Follows pagination via the "next" link and descends into sub-folders.
    Returns dicts with "name", "path", "download", and "size".
    """
    collected: list[dict] = []
    page_url = url
    while page_url:
        resp = requests.get(page_url, timeout=30)
        resp.raise_for_status()
        payload = resp.json()

        for entry in payload["data"]:
            attrs = entry["attributes"]
            if attrs["kind"] == "file":
                collected.append({
                    "name": attrs["name"],
                    "path": attrs["materialized_path"],
                    "download": entry["links"]["download"],
                    "size": attrs["size"],
                })
            elif attrs["kind"] == "folder":
                # Depth-first descent into the sub-folder's listing URL.
                child_url = entry["relationships"]["files"]["links"]["related"]["href"]
                collected.extend(list_osf_folder(child_url))

        # None/missing "next" link terminates the pagination loop.
        page_url = payload["links"].get("next")
    return collected
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def download_file(url: str, dest: Path):
    """Stream a file to `dest` with a tqdm progress bar.

    Creates parent directories; skips the download if `dest` already exists.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print(f" [skip] {dest.name} already exists")
        return

    resp = requests.get(url, stream=True, timeout=60)
    resp.raise_for_status()
    # content-length may be absent; tqdm handles total=0 gracefully.
    total_bytes = int(resp.headers.get("content-length", 0))
    bar = tqdm(desc=dest.name, total=total_bytes, unit="B",
               unit_scale=True, leave=False)
    with bar, open(dest, "wb") as out:
        for chunk in resp.iter_content(chunk_size=8192):
            out.write(chunk)
            bar.update(len(chunk))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def download_osf_project(node_id: str, local_root: Path):
    """Download every file from an OSF node into local_root, preserving layout.

    Returns the listed file metadata; an empty list signals that the file
    listing itself failed (individual download failures are only logged).
    """
    print(f"\n📂 Fetching file list from OSF node: {node_id}")
    try:
        entries = list_osf_files(node_id)
    except Exception as exc:
        # Best-effort: the caller falls back to manual-download instructions.
        print(f" ⚠️ Could not list files: {exc}")
        print(" → You may need to download manually from https://osf.io/gdwyb/")
        return []

    print(f" Found {len(entries)} files")
    for entry in entries:
        # Materialized paths are absolute within the node; drop the slash
        # so they nest under local_root.
        rel = entry["path"].lstrip("/")
        target = local_root / rel
        print(f" ↓ {rel} ({entry['size'] / 1024:.1f} KB)")
        try:
            download_file(entry["download"], target)
        except Exception as exc:
            print(f" ⚠️ Failed: {exc}")
    return entries
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
    # Script entry: download the whole OSF project, or print manual steps
    # when the automatic listing/download fails.
    DATA_DIR.mkdir(exist_ok=True)
    banner = "=" * 60
    print(banner)
    print("OSF Ti-64 Fractography Dataset Downloader")
    print("Project: https://osf.io/gdwyb/")
    print(banner)

    files = download_osf_project(OSF_PROJECT_ID, DATA_DIR)

    if files:
        print(f"\n✅ Download complete. Files saved to: {DATA_DIR.resolve()}")
    else:
        print("\n⚠️ Automatic download failed.")
        print("Manual download steps:")
        print(" 1. Go to https://osf.io/gdwyb/")
        print(" 2. Click each sub-component (Lack of Fusion, Key Hole, All Defects)")
        print(" 3. Download the zip and extract into data/<subfolder>/")
        print(" Expected structure:")
        print(" data/lack_of_fusion/images/*.png (or .tif)")
        print(" data/lack_of_fusion/masks/*.png")
|
features.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
features.py
|
| 3 |
+
-----------
|
| 4 |
+
Week 3: Feature extraction + defect type classification.
|
| 5 |
+
|
| 6 |
+
Takes a trained SegFormer checkpoint, runs inference on an image,
|
| 7 |
+
and extracts quantitative morphological features from the predicted mask.
|
| 8 |
+
These features feed into:
|
| 9 |
+
1. A rule-based defect classifier (lack_of_fusion vs keyhole vs clean)
|
| 10 |
+
2. A structured feature dict consumed by the generative reasoning layer (Week 4)
|
| 11 |
+
|
| 12 |
+
Extracted features:
|
| 13 |
+
- defect_area_fraction : % of image that is defect
|
| 14 |
+
- defect_count : number of distinct defect regions
|
| 15 |
+
- mean_pore_area : mean area of individual defect blobs (px²)
|
| 16 |
+
- max_pore_area : largest single defect region
|
| 17 |
+
- mean_aspect_ratio : mean of (major_axis / minor_axis) per blob
|
| 18 |
+
→ circular pores ≈ 1.0 (keyhole)
|
| 19 |
+
→ elongated pores > 2.0 (lack of fusion)
|
| 20 |
+
- spatial_concentration : std of defect centroid positions (spread)
|
| 21 |
+
- size_std : std of pore areas (heterogeneity)
|
| 22 |
+
- quadrant_distribution : defect fraction per image quadrant
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
python features.py --image data/all_defects/images/001-Overview-EP04V24.png
|
| 26 |
+
--subset all_defects
|
| 27 |
+
|
| 28 |
+
python features.py --subset all_defects --all # run on all images in subset
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import argparse
|
| 32 |
+
import json
|
| 33 |
+
import math
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
import matplotlib
|
| 37 |
+
matplotlib.use("Agg")
|
| 38 |
+
import matplotlib.pyplot as plt
|
| 39 |
+
import numpy as np
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
from PIL import Image
|
| 43 |
+
from transformers import SegformerForSemanticSegmentation
|
| 44 |
+
|
| 45 |
+
from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES, MASK_SCALE
|
| 46 |
+
|
| 47 |
+
# ── Config ────────────────────────────────────────────────────────────────────
|
| 48 |
+
DEVICE = torch.device("cpu")
|
| 49 |
+
|
| 50 |
+
# Rule-based classification thresholds (tunable)
|
| 51 |
+
# Lack of fusion: many small irregular pores, high aspect ratio
|
| 52 |
+
# Keyhole: fewer larger circular pores, low aspect ratio
|
| 53 |
+
THRESHOLDS = {
|
| 54 |
+
"min_defect_fraction_to_classify": 0.002,
|
| 55 |
+
"keyhole_max_aspect_ratio": 1.6, # wider keyhole band
|
| 56 |
+
"lof_min_count": 20, # need many blobs for LoF
|
| 57 |
+
}
|
| 58 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 59 |
+
def load_model(checkpoint_path: Path) -> SegformerForSemanticSegmentation:
    """Build a 2-class SegFormer (mit-b0 config) and load trained weights.

    The checkpoint is loaded onto DEVICE with weights_only=True and must
    match the architecture exactly (strict=True). Returns the model in
    eval mode.
    """
    from transformers import SegformerConfig

    cfg = SegformerConfig.from_pretrained("nvidia/mit-b0")
    cfg.num_labels = NUM_CLASSES
    cfg.id2label = {0: "background", 1: "defect"}
    cfg.label2id = {"background": 0, "defect": 1}

    net = SegformerForSemanticSegmentation(cfg)
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    net.load_state_dict(state_dict, strict=True)
    net.eval()
    return net
|
| 73 |
+
def load_image_tensor(path: Path, image_size: tuple) -> torch.Tensor:
    """Load an image as an ImageNet-normalized CHW float tensor.

    image_size is (H, W); PIL's resize takes (W, H), hence the swap.
    """
    pil_img = Image.open(path).convert("RGB")
    pil_img = pil_img.resize((image_size[1], image_size[0]), Image.BILINEAR)
    pixels = np.asarray(pil_img, dtype=np.float32) / 255.0
    # Standard ImageNet channel statistics, matching training-time transforms.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    pixels = (pixels - imagenet_mean) / imagenet_std
    return torch.from_numpy(pixels).permute(2, 0, 1).float()
|
| 81 |
+
@torch.no_grad()
def predict_mask(model, image_tensor: torch.Tensor, target_size: tuple) -> np.ndarray:
    """Run the model on one CHW image and return an (H, W) uint8 class mask.

    Logits are bilinearly upsampled to target_size before the per-pixel
    argmax so the mask matches the requested resolution.
    """
    batch = image_tensor.unsqueeze(0)          # add batch dim → (1, C, H, W)
    logits = model(pixel_values=batch).logits
    full_res = F.interpolate(
        logits, size=target_size, mode="bilinear", align_corners=False
    )
    class_map = full_res.squeeze(0).argmax(dim=0)
    return class_map.numpy().astype(np.uint8)
|
| 90 |
+
|
| 91 |
+
def connected_components(mask: np.ndarray) -> tuple[np.ndarray, int]:
    """Label 4-connected regions of 1-pixels with an iterative flood fill.

    Pure numpy/stdlib (no scipy). Returns (labels, count) where labels is
    an int32 array with 0 = background and 1..count identifying regions in
    row-major discovery order.
    """
    height, width = mask.shape
    labels = np.zeros((height, width), dtype=np.int32)
    n_regions = 0
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))  # N, S, W, E

    for row in range(height):
        for col in range(width):
            # Seed a new region at every unlabeled defect pixel.
            if mask[row, col] != 1 or labels[row, col] != 0:
                continue
            n_regions += 1
            labels[row, col] = n_regions
            frontier = [(row, col)]
            # Explicit stack instead of recursion: safe for large blobs.
            while frontier:
                y, x = frontier.pop()
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if (0 <= ny < height and 0 <= nx < width
                            and mask[ny, nx] == 1 and labels[ny, nx] == 0):
                        labels[ny, nx] = n_regions
                        frontier.append((ny, nx))

    return labels, n_regions
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def blob_properties(labels: np.ndarray, num_blobs: int) -> list[dict]:
    """Per-blob area, centroid, bounding box, and bbox aspect ratio.

    aspect_ratio is the bounding-box major/minor side ratio — ~1.0 for
    round (keyhole-like) pores, >2.0 for elongated (lack-of-fusion) ones.
    Label ids with no pixels are silently skipped.
    """
    results = []
    for lab in range(1, num_blobs + 1):
        rows, cols = np.where(labels == lab)
        if rows.size == 0:
            continue

        y0, y1 = int(rows.min()), int(rows.max())
        x0, x1 = int(cols.min()), int(cols.max())
        span_y = y1 - y0 + 1
        span_x = x1 - x0 + 1
        long_side = max(span_y, span_x)
        short_side = min(span_y, span_x)

        results.append({
            "area": rows.size,
            "centroid": (float(rows.mean()), float(cols.mean())),
            # short_side is always >= 1 here, but keep the guard for parity.
            "aspect_ratio": float(long_side / short_side) if short_side > 0 else 1.0,
            "bbox": (y0, x0, y1, x1),
        })
    return results
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def extract_features(mask: np.ndarray, image_size: tuple) -> dict:
    """Extract quantitative morphological features from a binary prediction mask.

    Args:
        mask: (H, W) integer array where 1 marks defect pixels.
        image_size: (H, W) of the mask, used for area normalization and quadrants.

    Returns:
        Dict of rounded scalar features plus a rule-based defect_type/confidence.
    """
    H, W = image_size
    total_px = H * W
    defect_px = int((mask == 1).sum())
    defect_frac = defect_px / total_px

    # Short-circuit: no defect pixels → all-zero feature set, classified "clean".
    if defect_px == 0:
        return {
            "defect_area_fraction": 0.0,
            "defect_count": 0,
            "mean_pore_area_px": 0.0,
            "max_pore_area_px": 0,
            "mean_aspect_ratio": 0.0,
            "spatial_concentration": 0.0,
            "size_std": 0.0,
            "quadrant_distribution": [0.0, 0.0, 0.0, 0.0],
            "defect_type": "clean",
            "confidence": "high",
        }

    # Connected components (note: slow for large masks — acceptable at 256×256)
    labels, n_blobs = connected_components(mask)
    props = blob_properties(labels, n_blobs)

    areas = [p["area"] for p in props]
    aspect_ratios = [p["aspect_ratio"] for p in props]
    centroids = [p["centroid"] for p in props]

    # Guards are defensive: props is non-empty whenever defect_px > 0.
    mean_area = float(np.mean(areas)) if areas else 0.0
    max_area = int(max(areas)) if areas else 0
    mean_ar = float(np.mean(aspect_ratios)) if aspect_ratios else 0.0
    size_std = float(np.std(areas)) if areas else 0.0

    # Spatial concentration: std of centroid distances from image center
    if centroids:
        cy_center, cx_center = H / 2, W / 2
        dists = [math.sqrt((c[0]-cy_center)**2 + (c[1]-cx_center)**2)
                 for c in centroids]
        spatial_conc = float(np.std(dists))
    else:
        spatial_conc = 0.0

    # Quadrant distribution: fraction of defect pixels in each image quarter.
    half_h, half_w = H // 2, W // 2
    quads = [
        float((mask[:half_h, :half_w] == 1).sum()),  # top-left
        float((mask[:half_h, half_w:] == 1).sum()),  # top-right
        float((mask[half_h:, :half_w] == 1).sum()),  # bottom-left
        float((mask[half_h:, half_w:] == 1).sum()),  # bottom-right
    ]
    # Epsilon avoids 0/0; unreachable here since defect_px > 0 at this point.
    total_defect = sum(quads) + 1e-8
    quad_dist = [q / total_defect for q in quads]

    # ── Rule-based classification ─────────────────────────────────────────────
    defect_type, confidence = classify_defect(defect_frac, n_blobs, mean_ar, mean_area)

    return {
        "defect_area_fraction": round(defect_frac * 100, 3),  # as %
        "defect_count": n_blobs,
        "mean_pore_area_px": round(mean_area, 1),
        "max_pore_area_px": max_area,
        "mean_aspect_ratio": round(mean_ar, 3),
        "spatial_concentration": round(spatial_conc, 2),
        "size_std": round(size_std, 1),
        "quadrant_distribution": [round(q, 3) for q in quad_dist],
        "defect_type": defect_type,
        "confidence": confidence,
    }
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def classify_defect(
    defect_frac: float,
    count: int,
    mean_ar: float,
    mean_area: float,
) -> tuple[str, str]:
    """
    Rule-based defect classifier driven by the global THRESHOLDS dict.

    Returns (defect_type, confidence).

    Heuristics:
      * keyhole porosity — few, roughly circular pores (low aspect ratio)
      * lack of fusion   — many, elongated/irregular pores
      * mixed            — many circular pores (both signatures present)
      * clean            — defect fraction below the classification threshold
    """
    cfg = THRESHOLDS
    if defect_frac < cfg["min_defect_fraction_to_classify"]:
        return "clean", "high"

    circular = mean_ar <= cfg["keyhole_max_aspect_ratio"]
    numerous = count >= cfg["lof_min_count"]

    # Map the two boolean morphology signals directly onto a verdict.
    verdicts = {
        (True, False): ("keyhole_porosity", "high"),
        (False, True): ("lack_of_fusion", "high"),
        (True, True): ("mixed", "medium"),
        (False, False): ("lack_of_fusion", "medium"),
    }
    return verdicts[(circular, numerous)]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def visualize_features(
    image_path: Path,
    mask: np.ndarray,
    features: dict,
    out_path: Path,
):
    """Save a single-image feature visualization (image | overlay | feature text).

    Args:
        image_path: path to the raw SEM image (8- or 16-bit; min-max normalized).
        mask: (H, W) binary prediction mask at IMAGE_SIZE resolution.
        features: dict produced by extract_features().
        out_path: destination PNG; parent dirs are created as needed.
    """
    # Min-max normalize so 16-bit sources display correctly as 8-bit.
    raw = np.array(Image.open(image_path), dtype=np.float32)
    raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
    # PIL resize takes (width, height), hence the (W, H) swap.
    raw_resized = np.array(
        Image.fromarray((raw * 255).astype(np.uint8)).resize(
            (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
        )
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(
        f"Feature Extraction — {image_path.name}\n"
        f"Defect Type: {features['defect_type'].upper()} "
        f"(confidence: {features['confidence']})",
        fontsize=11, fontweight="bold"
    )

    # Panel 1: raw image
    axes[0].imshow(raw_resized, cmap="gray")
    axes[0].set_title("SEM Image", fontsize=9)
    axes[0].axis("off")

    # Panel 2: prediction overlay (defect pixels painted over grayscale image)
    overlay = np.stack([raw_resized, raw_resized, raw_resized], axis=-1).copy()
    overlay[mask == 1] = [0, 212, 255]  # cyan defects
    axes[1].imshow(overlay)
    axes[1].set_title(
        f"Prediction\n{features['defect_area_fraction']:.2f}% defect | "
        f"{features['defect_count']} blobs",
        fontsize=9
    )
    axes[1].axis("off")

    # Panel 3: monospace feature summary text
    axes[2].axis("off")
    feature_text = (
        f"Defect Area: {features['defect_area_fraction']:.3f}%\n"
        f"Defect Count: {features['defect_count']}\n"
        f"Mean Pore Area: {features['mean_pore_area_px']:.1f} px²\n"
        f"Max Pore Area: {features['max_pore_area_px']} px²\n"
        f"Mean Aspect Ratio: {features['mean_aspect_ratio']:.3f}\n"
        f" (1.0=circle, >2=elongated)\n"
        f"Spatial Spread: {features['spatial_concentration']:.2f}\n"
        f"Size Std Dev: {features['size_std']:.1f}\n\n"
        f"Quadrant Distribution:\n"
        f" TL:{features['quadrant_distribution'][0]:.2f} "
        f"TR:{features['quadrant_distribution'][1]:.2f}\n"
        f" BL:{features['quadrant_distribution'][2]:.2f} "
        f"BR:{features['quadrant_distribution'][3]:.2f}\n\n"
        f"─────────────────────────\n"
        f"DEFECT TYPE: {features['defect_type']}\n"
        f"CONFIDENCE: {features['confidence']}"
    )
    axes[2].text(
        0.05, 0.95, feature_text,
        transform=axes[2].transAxes,
        fontsize=9, verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="#1a1a2e", alpha=0.8, edgecolor="#00d4ff"),
        color="white"
    )
    axes[2].set_title("Extracted Features", fontsize=9)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f" Saved → {out_path.resolve()}")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def run_on_image(image_path: Path, subset: str) -> dict:
    """Run segmentation + feature extraction on a single SEM image.

    Loads the subset's checkpoint, predicts the defect mask, extracts
    morphological features, prints them, and saves a visualization under
    output/features/. Returns the feature dict, or {} if the checkpoint
    for `subset` is missing.
    """
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"❌ No checkpoint at {ckpt_path}")
        return {}

    print(f"\nImage: {image_path.name}")
    print(f"Subset: {subset}")

    model = load_model(ckpt_path)
    img_tensor = load_image_tensor(image_path, IMAGE_SIZE)
    mask = predict_mask(model, img_tensor, IMAGE_SIZE)
    features = extract_features(mask, IMAGE_SIZE)

    # Human-readable highlights, then the full feature dict as JSON.
    print(f"Defect type: {features['defect_type']} ({features['confidence']} confidence)")
    print(f"Defect area: {features['defect_area_fraction']:.3f}%")
    print(f"Blob count: {features['defect_count']}")
    print(f"Mean AR: {features['mean_aspect_ratio']:.3f}")
    print(json.dumps(features, indent=2))

    out_path = Path("output/features") / f"{image_path.stem}_features.png"
    visualize_features(image_path, mask, features, out_path)

    return features
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def run_on_subset(subset: str, data_dir: Path, n: int = 6):
    """Run feature extraction on n images from a subset and print summary.

    Saves a per-image visualization under output/features/<subset>/ and a
    combined JSON of all feature dicts at output/features/<subset>_features.json.
    Silently returns if the subset's data directory or checkpoint is missing.
    """
    subset_dir = data_dir / subset
    if not subset_dir.exists():
        print(f"⚠️ {subset_dir} not found")
        return

    ds = FractographyDataset(subset_dir, split="all", image_size=IMAGE_SIZE)
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"⚠️ No checkpoint for {subset}")
        return

    model = load_model(ckpt_path)
    results = []

    print(f"\n{'='*60}")
    print(f"Feature extraction: {subset} ({min(n, len(ds))} images)")
    print(f"{'='*60}")

    for idx in range(min(n, len(ds))):
        img_path, _ = ds.pairs[idx]  # second element (mask path) unused here
        img_tensor = load_image_tensor(img_path, IMAGE_SIZE)
        mask = predict_mask(model, img_tensor, IMAGE_SIZE)
        features = extract_features(mask, IMAGE_SIZE)
        features["image"] = img_path.name  # tag each result with its source file
        results.append(features)

        out_path = Path("output/features") / subset / f"{img_path.stem}_features.png"
        visualize_features(img_path, mask, features, out_path)

    # Summary: count images per predicted defect type
    print(f"\n Classification summary:")
    from collections import Counter
    counts = Counter(r["defect_type"] for r in results)
    for dtype, count in counts.items():
        print(f" {dtype:25s}: {count}")

    # Save results JSON
    json_out = Path("output/features") / f"{subset}_features.json"
    json_out.parent.mkdir(parents=True, exist_ok=True)
    with open(json_out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n Feature JSON → {json_out.resolve()}")
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
if __name__ == "__main__":
    # CLI: single-image mode (--image) or batch mode over one/all subsets.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default=None,
                        help="Path to a single SEM image")
    parser.add_argument("--subset", type=str, default="all_defects",
                        help="lack_of_fusion | keyhole | all_defects")
    parser.add_argument("--all", action="store_true",
                        help="Run on all images in subset (up to --n)")
    parser.add_argument("--n", type=int, default=6,
                        help="Number of images to process in --all mode")
    parser.add_argument("--data_dir", type=str, default="data")
    args = parser.parse_args()

    # Single-image mode takes precedence over batch mode.
    if args.image:
        run_on_image(Path(args.image), args.subset)
    else:
        subsets = (
            ["lack_of_fusion", "keyhole", "all_defects"]
            if args.subset == "all"
            else [args.subset]
        )
        for subset in subsets:
            run_on_subset(subset, Path(args.data_dir), n=args.n)
|
inference.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
------------
|
| 4 |
+
Loads a trained SegFormer checkpoint and runs inference on SEM images.
|
| 5 |
+
Saves a visualization grid showing: original image | predicted mask | overlay.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# Run on a specific subset's val images
|
| 9 |
+
python inference.py --subset lack_of_fusion
|
| 10 |
+
|
| 11 |
+
# Run on a specific image
|
| 12 |
+
python inference.py --image path/to/image.png --subset keyhole
|
| 13 |
+
|
| 14 |
+
# Run all three subsets
|
| 15 |
+
python inference.py --subset all
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import random
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from features import load_model, load_image_tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
import matplotlib
|
| 25 |
+
matplotlib.use("Agg")
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
import matplotlib.patches as mpatches
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from transformers import SegformerForSemanticSegmentation
|
| 33 |
+
|
| 34 |
+
from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES, MASK_SCALE
|
| 35 |
+
|
| 36 |
+
# ── Config ────────────────────────────────────────────────────────────────────
|
| 37 |
+
DEVICE = torch.device("cpu")
|
| 38 |
+
N_SAMPLES = 6 # images to visualize per subset
|
| 39 |
+
LABEL_MAP = {0: ("Background", "#1a1a2e"), 1: ("Defect", "#00d4ff")}
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_model(checkpoint_path: Path) -> SegformerForSemanticSegmentation:
    """Instantiate a 2-class SegFormer (mit-b0) and load fine-tuned weights.

    NOTE(review): this definition shadows the `load_model` imported from
    `features` at the top of this module — confirm which one is intended.
    """
    id2label = {0: "background", 1: "defect"}
    label2id = {v: k for k, v in id2label.items()}
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        # Pretrained head has a different class count; re-initialize it.
        ignore_mismatched_sizes=True,
    )
    # weights_only=True restricts unpickling to tensors/primitives (safer load).
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()  # inference mode: disable dropout / BN updates
    return model
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_raw_image(path: Path) -> np.ndarray:
    """Read an SEM image and return it as a uint8 RGB array (16-bit safe)."""
    gray = np.array(Image.open(path), dtype=np.float32)
    # Min-max normalize to [0, 1] so 16-bit inputs display like 8-bit ones.
    lo, hi = gray.min(), gray.max()
    gray = (gray - lo) / (hi - lo + 1e-8)
    scaled = (gray * 255).astype(np.uint8)
    # Replicate the single channel into R, G, B.
    return np.stack((scaled, scaled, scaled), axis=-1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def predict(model, image_tensor: torch.Tensor, target_size: tuple) -> np.ndarray:
    """Forward one image through the model and return the (H, W) argmax label map."""
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0)          # add batch dim → (1, C, H, W)
        logits = model(pixel_values=batch).logits  # low-res class scores (1, C, H/4, W/4)
        # SegFormer emits logits at reduced resolution; upsample before argmax.
        full_res = F.interpolate(
            logits, size=target_size, mode="bilinear", align_corners=False
        )
    return full_res.squeeze(0).argmax(dim=0).numpy()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def colorize(mask: np.ndarray) -> np.ndarray:
    """Map integer class labels to the RGB colors declared in LABEL_MAP."""
    canvas = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, (_, hex_color) in LABEL_MAP.items():
        digits = hex_color.lstrip("#")
        # Split "rrggbb" into three 0-255 channel values.
        channel = [int(digits[k:k + 2], 16) for k in range(0, 6, 2)]
        canvas[mask == class_id] = channel
    return canvas
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def compute_stats(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Per-image defect IoU plus predicted/ground-truth coverage (percent)."""
    p_mask, g_mask = (pred == 1), (gt == 1)
    overlap = (p_mask & g_mask).sum()
    combined = (p_mask | g_mask).sum()
    return {
        # IoU is undefined when neither mask contains a defect pixel.
        "iou": overlap / combined if combined > 0 else float("nan"),
        "pred_coverage": p_mask.sum() / pred.size * 100,
        "gt_coverage": g_mask.sum() / gt.size * 100,
    }
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def run_inference(subset: str, args):
    """Run the trained checkpoint for `subset` on randomly sampled images.

    Saves a grid (image | ground truth | prediction | overlay per row) to
    output/inference/<subset>_inference.png. Skips silently-with-warning if
    the data directory or checkpoint is missing.
    """
    data_dir = Path(args.data_dir) / subset
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    out_dir = Path("output") / "inference"
    out_dir.mkdir(parents=True, exist_ok=True)

    if not data_dir.exists():
        print(f"⚠️ Skipping '{subset}' — data not found at {data_dir}")
        return
    if not ckpt_path.exists():
        print(f"⚠️ Skipping '{subset}' — no checkpoint at {ckpt_path}")
        return

    print(f"\n{'='*60}")
    print(f"Inference: {subset}")
    print(f"Checkpoint: {ckpt_path}")
    print(f"{'='*60}")

    model = load_model(ckpt_path)

    # Load dataset to get image/mask pairs
    ds = FractographyDataset(data_dir, split="all", image_size=IMAGE_SIZE)
    indices = list(range(len(ds)))
    random.seed(42)  # fixed seed → deterministic sampling across runs
    random.shuffle(indices)
    sample_indices = indices[:N_SAMPLES]

    # Build figure
    n = len(sample_indices)
    fig, axes = plt.subplots(n, 4, figsize=(16, n * 4))
    if n == 1:
        # subplots() returns a 1-D axes array when n == 1; wrap so axes[row][col] works.
        axes = [axes]
    fig.suptitle(
        f"SegFormer Inference — {subset.replace('_', ' ').title()}",
        fontsize=13, fontweight="bold"
    )

    ious = []
    for row, idx in enumerate(sample_indices):
        img_path, mask_path = ds.pairs[idx]
        img_tensor, gt_mask = ds[idx]

        # Raw image for display (16-bit safe)
        raw_img = load_raw_image(img_path)

        # GT mask (undo MASK_SCALE)
        gt_arr = gt_mask.numpy()  # already scaled by dataset

        # Predict
        pred = predict(model, img_tensor, target_size=IMAGE_SIZE)

        # Resize raw image to match prediction size for display.
        # PIL's resize takes (width, height), hence the (W, H) swap.
        raw_resized = np.array(
            Image.fromarray(raw_img).resize(
                (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
            )
        )

        # Stats
        stats = compute_stats(pred, gt_arr)
        ious.append(stats["iou"])

        # Colorize masks and blend prediction over the raw image
        pred_colored = colorize(pred)
        gt_colored = colorize(gt_arr)
        overlay = (raw_resized.astype(float) * 0.6 +
                   pred_colored.astype(float) * 0.4).astype(np.uint8)

        # Plot the four columns for this sample
        axes[row][0].imshow(raw_resized, cmap="gray")
        axes[row][0].set_title(f"Image\n{img_path.name}", fontsize=7)
        axes[row][0].axis("off")

        axes[row][1].imshow(gt_colored)
        axes[row][1].set_title(
            f"Ground Truth\n{stats['gt_coverage']:.1f}% defect", fontsize=7
        )
        axes[row][1].axis("off")

        axes[row][2].imshow(pred_colored)
        axes[row][2].set_title(
            f"Prediction\n{stats['pred_coverage']:.1f}% defect", fontsize=7
        )
        axes[row][2].axis("off")

        axes[row][3].imshow(overlay)
        iou_str = f"{stats['iou']:.3f}" if not np.isnan(stats["iou"]) else "N/A"
        axes[row][3].set_title(f"Overlay\nIoU={iou_str}", fontsize=7)
        axes[row][3].axis("off")

    # Legend
    patches = [
        mpatches.Patch(color=LABEL_MAP[0][1], label="Background"),
        mpatches.Patch(color=LABEL_MAP[1][1], label="Defect"),
    ]
    fig.legend(handles=patches, loc="lower center", ncol=2,
               bbox_to_anchor=(0.5, -0.01), fontsize=9)

    # nanmean ignores samples whose IoU was undefined (no defect in pred or GT)
    mean_iou = np.nanmean(ious)
    fig.text(0.5, -0.03, f"Mean IoU (these samples): {mean_iou:.4f}",
             ha="center", fontsize=10, fontweight="bold")

    plt.tight_layout()
    out_path = out_dir / f"{subset}_inference.png"
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()

    print(f" Mean IoU (sampled): {mean_iou:.4f}")
    print(f" Saved → {out_path.resolve()}")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
    # CLI: pick one subset or run all three; --n controls samples per grid.
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="all",
                        help="lack_of_fusion | keyhole | all_defects | all")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--n", type=int, default=6,
                        help="Number of images to visualize")
    args = parser.parse_args()

    # Rebind the module-level sample count so run_inference() picks it up.
    N_SAMPLES = args.n

    subsets = (
        ["lack_of_fusion", "keyhole", "all_defects"]
        if args.subset == "all"
        else [args.subset]
    )

    for subset in subsets:
        run_inference(subset, args)

    print("\n✅ Done. Check output/inference/")
|
inspect_dataset.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inspect_dataset.py
|
| 3 |
+
------------------
|
| 4 |
+
Inspects the OSF Ti-64 SEM fractography dataset after downloading.
|
| 5 |
+
Run after download_osf.py.
|
| 6 |
+
|
| 7 |
+
What this does:
|
| 8 |
+
1. Scans the data/ directory and reports what it finds
|
| 9 |
+
2. Detects mask format (grayscale int labels vs RGB color masks)
|
| 10 |
+
3. Prints unique class label values found in masks
|
| 11 |
+
4. Generates a visualization grid of image/mask pairs
|
| 12 |
+
5. Saves visualization to output/inspection_grid.png
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python inspect_dataset.py
|
| 16 |
+
python inspect_dataset.py --data_dir path/to/your/data
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import matplotlib
|
| 24 |
+
matplotlib.use("Agg") # headless-safe; switch to "TkAgg" if you want interactive
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import matplotlib.patches as mpatches
|
| 27 |
+
import numpy as np
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Configurable label map ────────────────────────────────────────────────────
|
| 32 |
+
# Update this once you've inspected the actual class values in your masks.
|
| 33 |
+
# Keys = integer pixel values in mask PNGs.
|
| 34 |
+
# Keys = integer pixel values in mask PNGs; values = (display name, hex color).
LABEL_MAP = {
    0: ("Background", "#1a1a2e"),
    1: ("Lack of Fusion", "#e94560"),
    2: ("Keyhole", "#0f3460"),
    3: ("Other Defect", "#533483"),
    # Add more if you find additional class values
}

# Fallback colormap for unknown labels.
# plt.cm.get_cmap() was deprecated in matplotlib 3.7 and removed in 3.9;
# plt.get_cmap() is the stable accessor and returns the same colormap.
CMAP = plt.get_cmap("tab10")
|
| 44 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def find_image_mask_pairs(data_dir: Path) -> list[tuple[Path, Path]]:
    """
    Discover (image, mask) path pairs under data_dir.

    Strategy 1: sibling 'images'/'masks' (or 'mask') directories, matched by
    filename stem. Strategy 2 (fallback when nothing is found): a flat layout
    with '*_image.*' / '*_mask.*' filename pairs.
    """
    valid_exts = {".png", ".tif", ".tiff", ".jpg", ".jpeg", ".bmp"}
    found: list[tuple[Path, Path]] = []

    # Strategy 1: every images/ directory with a masks/ (or mask/) sibling.
    for img_dir in sorted(data_dir.rglob("images")):
        if not img_dir.is_dir():
            continue
        mask_dir = img_dir.parent / "masks"
        if not mask_dir.exists():
            mask_dir = img_dir.parent / "mask"
        if not mask_dir.exists():
            print(f" ⚠️ Found images/ at {img_dir} but no masks/ sibling")
            continue
        for candidate in sorted(img_dir.iterdir()):
            if candidate.suffix.lower() not in valid_exts:
                continue
            # Accept the first mask whose stem matches, regardless of extension.
            for ext in valid_exts:
                partner = mask_dir / (candidate.stem + ext)
                if partner.exists():
                    found.append((candidate, partner))
                    break
            else:
                print(f" ⚠️ No mask found for {candidate.name}")

    # Strategy 2: flat folder with *_image.* / *_mask.* naming.
    if not found:
        for candidate in sorted(data_dir.rglob("*_image.*")):
            if candidate.suffix.lower() not in valid_exts:
                continue
            base = candidate.stem.replace("_image", "")
            for ext in valid_exts:
                partner = candidate.parent / f"{base}_mask{ext}"
                if partner.exists():
                    found.append((candidate, partner))
                    break

    return found
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def inspect_mask(mask_path: Path) -> dict:
    """Return format statistics for a mask file.

    Fix: the original opened the file twice (once for the pixel data and a
    second time just to read the PIL mode) and never closed either handle.
    Open once under a context manager and reuse the object.

    Returns:
        Dict with shape, dtype, PIL mode, sorted unique values, and min/max.
    """
    with Image.open(mask_path) as im:
        mode = im.mode
        mask = np.array(im)
    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "mode": mode,
        "unique_values": sorted(np.unique(mask).tolist()),
        "min": int(mask.min()),
        "max": int(mask.max()),
    }
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Render an integer label mask as an RGB uint8 image.

    Known class values use the colors in LABEL_MAP; any other value falls
    back to the CMAP colormap indexed by the value normalized to the mask max.
    """
    present = np.unique(mask)
    canvas = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    for label_val in present:
        if label_val in LABEL_MAP:
            digits = LABEL_MAP[label_val][1].lstrip("#")
            color = tuple(int(digits[j:j + 2], 16) for j in (0, 2, 4))
        else:
            # fallback: sample the matplotlib colormap
            rgba = CMAP(label_val / max(present.max(), 1))
            color = tuple(int(ch * 255) for ch in rgba[:3])
        canvas[mask == label_val] = color
    return canvas
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def make_legend(unique_vals: list[int]) -> list[mpatches.Patch]:
    """Build one legend swatch per class value; unknown values get a gray fallback."""
    def _swatch(class_id: int) -> mpatches.Patch:
        name, color = LABEL_MAP.get(class_id, (f"Class {class_id}", "#888888"))
        return mpatches.Patch(color=color, label=f"{class_id}: {name}")

    return [_swatch(v) for v in unique_vals]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def visualize_pairs(
    pairs: list[tuple[Path, Path]],
    n: int = 6,
    output_path: Path = Path("output/inspection_grid.png"),
):
    """Save a grid of n image/mask/overlay triplets."""
    n = min(n, len(pairs))
    if n == 0:
        print(" No pairs to visualize.")
        return

    fig, axes = plt.subplots(n, 3, figsize=(12, n * 4))
    if n == 1:
        # subplots() returns a 1-D axes array when n == 1; wrap so axes[i][j] works.
        axes = [axes]

    fig.suptitle("OSF Ti-64 SEM Dataset — Inspection Grid\n(Image | Mask | Overlay)",
                 fontsize=13, fontweight="bold", y=1.01)

    # Union of class values seen across all sampled masks, for the legend.
    all_unique = set()

    for i, (img_path, mask_path) in enumerate(pairs[:n]):

        # Min-max normalize so 8-bit and 16-bit sources display alike.
        raw = np.array(Image.open(img_path), dtype=np.float32)
        raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
        img = np.stack([raw, raw, raw], axis=-1)
        mask_pil = Image.open(mask_path)
        mask_arr = np.array(mask_pil)

        # If mask is RGB, convert to grayscale for inspection
        if mask_arr.ndim == 3:
            mask_arr = np.array(mask_pil.convert("L"))

        unique_vals = sorted(np.unique(mask_arr).tolist())
        all_unique.update(unique_vals)
        mask_rgb = colorize_mask(mask_arr)

        # Overlay: blend image and mask (60/40 weighting)
        img_display = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
        overlay = (img_display.astype(float) * 0.6 + mask_rgb.astype(float) * 0.4).astype(np.uint8)
        # NOTE(review): img is always 3-channel here, so the cmap branch is inert.
        axes[i][0].imshow(img, cmap="gray" if img.ndim == 2 else None)
        axes[i][0].set_title(f"Image\n{img_path.name}", fontsize=8)
        axes[i][0].axis("off")

        axes[i][1].imshow(mask_rgb)
        axes[i][1].set_title(
            f"Mask (classes: {unique_vals})\n{mask_path.name}", fontsize=8
        )
        axes[i][1].axis("off")

        axes[i][2].imshow(overlay)
        axes[i][2].set_title("Overlay", fontsize=8)
        axes[i][2].axis("off")

    # Add legend covering every class value encountered above
    legend_patches = make_legend(sorted(all_unique))
    fig.legend(handles=legend_patches, loc="lower center", ncol=len(legend_patches),
               bbox_to_anchor=(0.5, -0.02), fontsize=9, title="Mask Classes Found")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n✅ Visualization saved to: {output_path.resolve()}")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def print_dataset_summary(data_dir: Path, pairs: list[tuple[Path, Path]]):
|
| 198 |
+
print(f"\n{'='*60}")
|
| 199 |
+
print(f"Dataset Summary — {data_dir.resolve()}")
|
| 200 |
+
print(f"{'='*60}")
|
| 201 |
+
print(f"Total image/mask pairs found: {len(pairs)}")
|
| 202 |
+
|
| 203 |
+
if not pairs:
|
| 204 |
+
print("\n⚠️ No pairs found. Check your data/ folder structure.")
|
| 205 |
+
print("Expected layout:")
|
| 206 |
+
print(" data/")
|
| 207 |
+
print(" <subset>/")
|
| 208 |
+
print(" images/ ← SEM images (.png or .tif)")
|
| 209 |
+
print(" masks/ ← segmentation masks (.png)")
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
# Sample first few masks
|
| 213 |
+
print(f"\nSampling first 5 masks for format inspection:")
|
| 214 |
+
all_unique = set()
|
| 215 |
+
for img_path, mask_path in pairs[:5]:
|
| 216 |
+
info = inspect_mask(mask_path)
|
| 217 |
+
print(f"\n {mask_path.name}")
|
| 218 |
+
print(f" Mode: {info['mode']}")
|
| 219 |
+
print(f" Shape: {info['shape']}")
|
| 220 |
+
print(f" Dtype: {info['dtype']}")
|
| 221 |
+
print(f" Unique values: {info['unique_values']}")
|
| 222 |
+
print(f" Value range: [{info['min']}, {info['max']}]")
|
| 223 |
+
all_unique.update(info["unique_values"])
|
| 224 |
+
|
| 225 |
+
print(f"\n{'─'*40}")
|
| 226 |
+
print(f"All unique class values across sampled masks: {sorted(all_unique)}")
|
| 227 |
+
print("\nLabel interpretation:")
|
| 228 |
+
for v in sorted(all_unique):
|
| 229 |
+
label, _ = LABEL_MAP.get(v, (f"UNKNOWN — update LABEL_MAP in this script", "#888"))
|
| 230 |
+
print(f" {v:3d} → {label}")
|
| 231 |
+
|
| 232 |
+
print(f"\n⚠️ NOTE: If all unique values are {{0, 255}}, masks are binary (defect/no-defect).")
|
| 233 |
+
print(" If values are 0–N, masks are multi-class integer labels — ideal for SegFormer.")
|
| 234 |
+
print(" If mode is 'RGB', masks encode class as color — you'll need to remap.")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
|
| 238 |
+
parser = argparse.ArgumentParser()
|
| 239 |
+
parser.add_argument("--data_dir", type=str, default="data",
|
| 240 |
+
help="Path to downloaded dataset root")
|
| 241 |
+
parser.add_argument("--n_vis", type=int, default=6,
|
| 242 |
+
help="Number of pairs to visualize")
|
| 243 |
+
parser.add_argument("--output", type=str, default="output/inspection_grid.png",
|
| 244 |
+
help="Where to save the visualization grid")
|
| 245 |
+
args = parser.parse_args()
|
| 246 |
+
|
| 247 |
+
data_dir = Path(args.data_dir)
|
| 248 |
+
if not data_dir.exists():
|
| 249 |
+
print(f"❌ data_dir '{data_dir}' does not exist.")
|
| 250 |
+
print("Run download_osf.py first, or set --data_dir to your data folder.")
|
| 251 |
+
sys.exit(1)
|
| 252 |
+
|
| 253 |
+
print("Scanning for image/mask pairs...")
|
| 254 |
+
pairs = find_image_mask_pairs(data_dir)
|
| 255 |
+
|
| 256 |
+
print_dataset_summary(data_dir, pairs)
|
| 257 |
+
|
| 258 |
+
if pairs:
|
| 259 |
+
print(f"\nGenerating visualization grid ({min(args.n_vis, len(pairs))} samples)...")
|
| 260 |
+
visualize_pairs(pairs, n=args.n_vis, output_path=Path(args.output))
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
main()
|
setup.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FailureGPT — OSF Ti-64 Dataset Inspector
|
| 3 |
+
Run this first to install dependencies.
|
| 4 |
+
"""
|
| 5 |
+
import subprocess, sys
|
| 6 |
+
|
| 7 |
+
packages = [
|
| 8 |
+
"torch",
|
| 9 |
+
"torchvision",
|
| 10 |
+
"Pillow",
|
| 11 |
+
"matplotlib",
|
| 12 |
+
"numpy",
|
| 13 |
+
"requests",
|
| 14 |
+
"osfclient",
|
| 15 |
+
"tqdm",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
for pkg in packages:
|
| 19 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "-q"])
|
| 20 |
+
|
| 21 |
+
print("✅ All dependencies installed.")
|
test.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inspect_dataset.py
|
| 3 |
+
------------------
|
| 4 |
+
Inspects the OSF Ti-64 SEM fractography dataset after downloading.
|
| 5 |
+
Run after download_osf.py.
|
| 6 |
+
|
| 7 |
+
What this does:
|
| 8 |
+
1. Scans the data/ directory and reports what it finds
|
| 9 |
+
2. Detects mask format (grayscale int labels vs RGB color masks)
|
| 10 |
+
3. Prints unique class label values found in masks
|
| 11 |
+
4. Generates a visualization grid of image/mask pairs
|
| 12 |
+
5. Saves visualization to output/inspection_grid.png
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python inspect_dataset.py
|
| 16 |
+
python inspect_dataset.py --data_dir path/to/your/data
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import matplotlib
|
| 24 |
+
matplotlib.use("Agg") # headless-safe; switch to "TkAgg" if you want interactive
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import matplotlib.patches as mpatches
|
| 27 |
+
import numpy as np
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Configurable label map ────────────────────────────────────────────────────
|
| 32 |
+
# Update this once you've inspected the actual class values in your masks.
|
| 33 |
+
# Keys = integer pixel values in mask PNGs.
|
| 34 |
+
LABEL_MAP = {
|
| 35 |
+
0: ("Background", "#1a1a2e"),
|
| 36 |
+
1: ("Lack of Fusion", "#e94560"),
|
| 37 |
+
2: ("Keyhole", "#0f3460"),
|
| 38 |
+
3: ("Other Defect", "#533483"),
|
| 39 |
+
# Add more if you find additional class values
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# Fallback colormap for unknown labels
|
| 43 |
+
CMAP = plt.cm.get_cmap("tab10")
|
| 44 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def find_image_mask_pairs(data_dir: Path) -> list[tuple[Path, Path]]:
    """
    Scan data_dir for image/mask pairs.
    Assumes masks live in a folder named 'masks' or 'mask',
    and images in 'images' or 'image', or are paired by filename.
    """
    image_exts = {".png", ".tif", ".tiff", ".jpg", ".jpeg", ".bmp"}
    found: list[tuple[Path, Path]] = []

    # Strategy 1: look for images/ with a masks/ (or mask/) sibling folder.
    for images_dir in sorted(data_dir.rglob("images")):
        if not images_dir.is_dir():
            continue
        masks_dir = images_dir.parent / "masks"
        if not masks_dir.exists():
            masks_dir = images_dir.parent / "mask"
        if not masks_dir.exists():
            print(f" ⚠️ Found images/ at {images_dir} but no masks/ sibling")
            continue
        for img_path in sorted(images_dir.iterdir()):
            if img_path.suffix.lower() not in image_exts:
                continue
            # Match mask by identical stem, trying every known extension.
            partner = next(
                (
                    masks_dir / (img_path.stem + ext)
                    for ext in image_exts
                    if (masks_dir / (img_path.stem + ext)).exists()
                ),
                None,
            )
            if partner is None:
                print(f" ⚠️ No mask found for {img_path.name}")
            else:
                found.append((img_path, partner))

    # Strategy 2: flat layout with *_image.* / *_mask.* filename pairing.
    if not found:
        for img_path in sorted(data_dir.rglob("*_image.*")):
            if img_path.suffix.lower() not in image_exts:
                continue
            base = img_path.stem.replace("_image", "")
            for ext in image_exts:
                candidate = img_path.parent / f"{base}_mask{ext}"
                if candidate.exists():
                    found.append((img_path, candidate))
                    break

    return found
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def inspect_mask(mask_path: Path) -> dict:
    """Return statistics about a mask file.

    Reports shape, dtype, PIL mode, sorted unique pixel values, and the
    min/max — enough to distinguish binary, multi-class, and RGB masks.

    Args:
        mask_path: path to the mask image.

    Returns:
        dict with keys: shape, dtype, mode, unique_values, min, max.
    """
    # Open the file ONCE and close the handle deterministically.
    # (The original called Image.open twice — once for pixel data, once just
    # to read .mode — leaking two file handles per call.)
    with Image.open(mask_path) as im:
        mode = im.mode
        mask = np.array(im)
    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "mode": mode,
        "unique_values": sorted(np.unique(mask).tolist()),
        "min": int(mask.min()),
        "max": int(mask.max()),
    }
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Map an integer label mask to an RGB uint8 image for display."""
    labels = np.unique(mask)
    out = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    for lbl in labels:
        entry = LABEL_MAP.get(lbl)
        if entry is not None:
            # Known class: decode its '#rrggbb' hex string.
            hex_str = entry[1].lstrip("#")
            rgb_tuple = tuple(int(hex_str[j:j + 2], 16) for j in (0, 2, 4))
        else:
            # Unknown class: sample the fallback matplotlib colormap,
            # normalized by the largest label present.
            rgba = CMAP(lbl / max(labels.max(), 1))
            rgb_tuple = tuple(int(c * 255) for c in rgba[:3])
        out[mask == lbl] = rgb_tuple
    return out
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def make_legend(unique_vals: list[int]) -> list[mpatches.Patch]:
    """Build one legend patch per class value, colored via LABEL_MAP.

    Unknown values get a generic 'Class N' label in neutral grey.
    """
    legend = []
    for v in unique_vals:
        name, color = LABEL_MAP.get(v, (f"Class {v}", "#888888"))
        legend.append(mpatches.Patch(color=color, label=f"{v}: {name}"))
    return legend
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def visualize_pairs(
    pairs: list[tuple[Path, Path]],
    n: int = 6,
    output_path: Path = Path("output/inspection_grid.png"),
):
    """Save a grid of n image/mask/overlay triplets.

    Each row shows: raw image | colorized mask | 50/50 blend of the two.
    The figure is written to output_path (parent dirs created as needed).
    """
    n = min(n, len(pairs))
    if n == 0:
        print(" No pairs to visualize.")
        return

    fig, axes = plt.subplots(n, 3, figsize=(12, n * 4))
    if n == 1:
        # A single subplot row comes back as a flat array of 3 axes;
        # wrap it so the axes[i][j] indexing below works for every n.
        axes = [axes]

    fig.suptitle("OSF Ti-64 SEM Dataset — Inspection Grid\n(Image | Mask | Overlay)",
                 fontsize=13, fontweight="bold", y=1.01)

    all_unique = set()  # union of class values across all sampled masks

    for i, (img_path, mask_path) in enumerate(pairs[:n]):
        img = np.array(Image.open(img_path).convert("RGB"))
        mask_pil = Image.open(mask_path)
        mask_arr = np.array(mask_pil)

        # If mask is RGB, convert to grayscale for inspection
        if mask_arr.ndim == 3:
            mask_arr = np.array(mask_pil.convert("L"))

        unique_vals = sorted(np.unique(mask_arr).tolist())
        all_unique.update(unique_vals)
        mask_rgb = colorize_mask(mask_arr)

        # Overlay: blend image and mask
        overlay = (img.astype(float) * 0.5 + mask_rgb.astype(float) * 0.5).astype(np.uint8)

        # img is always 3-channel here (converted to RGB above), so the
        # cmap fallback below is effectively inert.
        axes[i][0].imshow(img, cmap="gray" if img.ndim == 2 else None)
        axes[i][0].set_title(f"Image\n{img_path.name}", fontsize=8)
        axes[i][0].axis("off")

        axes[i][1].imshow(mask_rgb)
        axes[i][1].set_title(
            f"Mask (classes: {unique_vals})\n{mask_path.name}", fontsize=8
        )
        axes[i][1].axis("off")

        axes[i][2].imshow(overlay)
        axes[i][2].set_title("Overlay", fontsize=8)
        axes[i][2].axis("off")

    # Add legend
    legend_patches = make_legend(sorted(all_unique))
    fig.legend(handles=legend_patches, loc="lower center", ncol=len(legend_patches),
               bbox_to_anchor=(0.5, -0.02), fontsize=9, title="Mask Classes Found")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n✅ Visualization saved to: {output_path.resolve()}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def print_dataset_summary(data_dir: Path, pairs: list[tuple[Path, Path]]):
    """Print a console report for the dataset.

    Shows the pair count, per-mask format stats for the first 5 masks
    (via inspect_mask), and a LABEL_MAP interpretation of the class values.
    Console output only — no files are written.
    """
    print(f"\n{'='*60}")
    print(f"Dataset Summary — {data_dir.resolve()}")
    print(f"{'='*60}")
    print(f"Total image/mask pairs found: {len(pairs)}")

    if not pairs:
        # Nothing found: show the layout the scanner expects and bail out.
        print("\n⚠️ No pairs found. Check your data/ folder structure.")
        print("Expected layout:")
        print(" data/")
        print(" <subset>/")
        print(" images/ ← SEM images (.png or .tif)")
        print(" masks/ ← segmentation masks (.png)")
        return

    # Sample first few masks
    print(f"\nSampling first 5 masks for format inspection:")
    all_unique = set()
    for img_path, mask_path in pairs[:5]:
        info = inspect_mask(mask_path)
        print(f"\n {mask_path.name}")
        print(f" Mode: {info['mode']}")
        print(f" Shape: {info['shape']}")
        print(f" Dtype: {info['dtype']}")
        print(f" Unique values: {info['unique_values']}")
        print(f" Value range: [{info['min']}, {info['max']}]")
        all_unique.update(info["unique_values"])

    print(f"\n{'─'*40}")
    print(f"All unique class values across sampled masks: {sorted(all_unique)}")
    print("\nLabel interpretation:")
    for v in sorted(all_unique):
        # Values missing from LABEL_MAP get a sentinel label prompting an update.
        label, _ = LABEL_MAP.get(v, (f"UNKNOWN — update LABEL_MAP in this script", "#888"))
        print(f" {v:3d} → {label}")

    print(f"\n⚠️ NOTE: If all unique values are {{0, 255}}, masks are binary (defect/no-defect).")
    print(" If values are 0–N, masks are multi-class integer labels — ideal for SegFormer.")
    print(" If mode is 'RGB', masks encode class as color — you'll need to remap.")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def main():
    """CLI entry point: parse args, scan for pairs, summarize, visualize."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--data_dir", type=str, default="data",
                            help="Path to downloaded dataset root")
    arg_parser.add_argument("--n_vis", type=int, default=6,
                            help="Number of pairs to visualize")
    arg_parser.add_argument("--output", type=str, default="output/inspection_grid.png",
                            help="Where to save the visualization grid")
    opts = arg_parser.parse_args()

    root = Path(opts.data_dir)
    if not root.exists():
        # Fail fast with guidance rather than crashing deeper in the scan.
        print(f"❌ data_dir '{root}' does not exist.")
        print("Run download_osf.py first, or set --data_dir to your data folder.")
        sys.exit(1)

    print("Scanning for image/mask pairs...")
    found = find_image_mask_pairs(root)

    print_dataset_summary(root, found)

    if found:
        print(f"\nGenerating visualization grid ({min(opts.n_vis, len(found))} samples)...")
        visualize_pairs(found, n=opts.n_vis, output_path=Path(opts.output))


if __name__ == "__main__":
    main()
|
train.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py
|
| 3 |
+
--------
|
| 4 |
+
Fine-tunes SegFormer-b0 on the OSF Ti-64 SEM fractography dataset.
|
| 5 |
+
Trains all three subsets (lack_of_fusion, keyhole, all_defects) separately.
|
| 6 |
+
CPU-optimized: small image size, small batch, few epochs.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python train.py
|
| 10 |
+
python train.py --epochs 10 --image_size 256
|
| 11 |
+
|
| 12 |
+
Outputs (per subset):
|
| 13 |
+
checkpoints/<subset>/best_model.pt <- best checkpoint by val mIoU
|
| 14 |
+
checkpoints/<subset>/last_model.pt <- final epoch checkpoint
|
| 15 |
+
checkpoints/<subset>/history.json <- loss/mIoU per epoch
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import time
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from torch.utils.data import DataLoader, random_split
|
| 27 |
+
from transformers import SegformerForSemanticSegmentation
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
from dataset import FractographyDataset
|
| 31 |
+
|
| 32 |
+
# ── Config ────────────────────────────────────────────────────────────────────
|
| 33 |
+
NUM_CLASSES = 2 # background (0) + defect (1)
|
| 34 |
+
IMAGE_SIZE = (256, 256) # smaller = faster on CPU; increase if you have time
|
| 35 |
+
BATCH_SIZE = 2
|
| 36 |
+
EPOCHS = 15
|
| 37 |
+
LR = 6e-5
|
| 38 |
+
TRAIN_FRAC = 0.8
|
| 39 |
+
WEIGHT_DECAY = 0.01
|
| 40 |
+
|
| 41 |
+
SUBSETS = ["lack_of_fusion", "keyhole", "all_defects"]
|
| 42 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def compute_miou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> float:
    """Mean Intersection over Union across all classes.

    Classes absent from both prediction and target in this batch are skipped
    so they do not distort the mean.

    Args:
        preds: predicted class indices, any shape.
        targets: ground-truth class indices, same shape as preds.
        num_classes: classes to evaluate, labeled 0..num_classes-1.

    Returns:
        Mean IoU as a float; 0.0 if no class is present at all.
    """
    # reshape(-1) instead of view(-1): view() raises on non-contiguous
    # tensors (e.g. slices or permuted model outputs); reshape handles both.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    ious = []
    for cls in range(num_classes):
        pred_mask = preds == cls
        target_mask = targets == cls
        intersection = (pred_mask & target_mask).sum().item()
        union = (pred_mask | target_mask).sum().item()
        if union == 0:
            continue  # class not present in this batch
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else 0.0
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def dice_loss(logits: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """
    Soft Dice loss for binary segmentation, computed on the defect channel.

    Directly optimizes overlap — critical for imbalanced datasets.
    logits: (B, num_classes, H, W)
    targets: (B, H, W) integer labels
    Returns 1 - mean Dice coefficient over the batch.
    """
    # Softmax over channels, then keep only the defect class (index 1).
    defect_prob = torch.softmax(logits, dim=1)[:, 1]      # (B, H, W)
    defect_true = (targets == 1).float()

    # Per-sample overlap and size sums over spatial dims.
    overlap = (defect_prob * defect_true).sum(dim=(1, 2))
    totals = defect_prob.sum(dim=(1, 2)) + defect_true.sum(dim=(1, 2))
    dice_coeff = (2.0 * overlap + smooth) / (totals + smooth)
    return 1.0 - dice_coeff.mean()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def combined_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    defect_weight: float = 10.0,
    dice_weight: float = 0.5,
) -> torch.Tensor:
    """
    Weighted cross-entropy blended with Dice loss.

    defect_weight: extra penalty on defect-pixel errors
        (start at 10x given ~6% defect pixels).
    dice_weight: blend factor (0 = CE only, 1 = Dice only).
    """
    # Bring logits up to the label resolution before either loss term.
    resized = F.interpolate(
        logits, size=targets.shape[-2:], mode="bilinear", align_corners=False
    )

    # Weighted cross-entropy: background weight 1, defect weight `defect_weight`.
    class_weights = torch.tensor([1.0, defect_weight], device=logits.device)
    ce_term = F.cross_entropy(resized, targets, weight=class_weights)

    dice_term = dice_loss(resized, targets)
    return ce_term * (1.0 - dice_weight) + dice_term * dice_weight
|
| 101 |
+
|
| 102 |
+
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimization pass over loader; return the mean batch loss.

    Relies on the model's built-in loss: labels are passed at native
    resolution alongside pixel_values (SegFormer handles the resolution
    mismatch internally).
    """
    model.train()
    running = 0.0
    for batch_images, batch_masks in loader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        out = model(pixel_values=batch_images, labels=batch_masks)

        optimizer.zero_grad()
        out.loss.backward()
        optimizer.step()

        running += out.loss.item()

    return running / len(loader)
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
def evaluate(model, loader, device, num_classes):
    """Evaluate on loader; return (mean loss, mean per-batch mIoU).

    Model logits come out at reduced resolution and are bilinearly
    upsampled to the mask size before the argmax prediction.
    """
    model.eval()
    loss_sum = 0.0
    miou_scores = []

    for batch_images, batch_masks in loader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        out = model(pixel_values=batch_images, labels=batch_masks)

        # Upsample logits to the ground-truth resolution.
        full_res = F.interpolate(
            out.logits,
            size=batch_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        predictions = full_res.argmax(dim=1)  # (B, H, W)

        loss_sum += out.loss.item()
        miou_scores.append(compute_miou(predictions.cpu(), batch_masks.cpu(), num_classes))

    return loss_sum / len(loader), float(np.mean(miou_scores))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def train_subset(subset: str, data_root: Path, args):
    """Fine-tune SegFormer-b0 on one dataset subset.

    Loads data_root/<subset>, does a seeded 80/20 train/val split, trains
    for args.epochs, and writes best/last checkpoints plus a per-epoch
    history.json under checkpoints/<subset>/.

    Returns:
        The history dict on success; None (implicit) if the subset folder
        is missing or contains no image/mask pairs.
    """
    subset_dir = data_root / subset
    if not subset_dir.exists():
        print(f"\n⚠️ Skipping '{subset}' — folder not found at {subset_dir}")
        return

    print(f"\n{'='*60}")
    print(f"Training on subset: {subset}")
    print(f"{'='*60}")

    # Dataset
    full_ds = FractographyDataset(
        subset_dir,
        split="all",
        image_size=IMAGE_SIZE,
    )
    if len(full_ds) == 0:
        print(f" ⚠️ No image/mask pairs found in {subset_dir}")
        return

    # Fixed-seed split: the same train/val partition on every run.
    n_train = max(1, int(len(full_ds) * TRAIN_FRAC))
    n_val = len(full_ds) - n_train
    train_ds, val_ds = random_split(
        full_ds, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )
    print(f" Train: {len(train_ds)} | Val: {len(val_ds)}")

    # num_workers=0: single-process loading (this script targets CPU).
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model
    device = torch.device("cpu")
    id2label = {0: "background", 1: "defect"}
    label2id = {v: k for k, v in id2label.items()}

    print(f" Loading SegFormer-b0 from HuggingFace...")
    # ignore_mismatched_sizes lets the pretrained checkpoint load even though
    # the classification head is resized to NUM_CLASSES outputs.
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    ).to(device)

    # Two LR groups: pretrained encoder at args.lr, decode head at 50x.
    # NOTE(review): presumably the head is freshly initialized for the new
    # class count, justifying the higher LR — confirm against HF behavior.
    optimizer = torch.optim.AdamW([
        {"params": model.segformer.parameters(), "lr": args.lr},
        {"params": model.decode_head.parameters(), "lr": args.lr * 50},
    ], weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Checkpoint dir
    ckpt_dir = Path("checkpoints") / subset
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    history = {"train_loss": [], "val_loss": [], "val_miou": []}
    best_miou = 0.0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_miou = evaluate(model, val_loader, device, NUM_CLASSES)
        scheduler.step()

        elapsed = time.time() - t0
        print(
            f" Epoch {epoch:02d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_mIoU={val_miou:.4f} | "
            f"{elapsed:.1f}s"
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        # Save best (>= also saves on ties and guarantees an epoch-1 checkpoint)
        if val_miou >= best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f" ✅ New best mIoU: {best_miou:.4f} — checkpoint saved")

    # Save last + history
    torch.save(model.state_dict(), ckpt_dir / "last_model.pt")
    with open(ckpt_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)

    print(f"\n Done. Best val mIoU: {best_miou:.4f}")
    print(f" Checkpoints saved to: {ckpt_dir.resolve()}")
    return history
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def plot_histories(histories: dict):
    """Save a training curve plot for all subsets.

    Writes checkpoints/training_curves.png; any plotting failure is
    reported and swallowed so a missing matplotlib never kills a run.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, (loss_ax, miou_ax) = plt.subplots(1, 2, figsize=(12, 4))
        fig.suptitle("SegFormer-b0 Training — OSF Ti-64", fontweight="bold")

        for name, hist in histories.items():
            xs = range(1, len(hist["train_loss"]) + 1)
            loss_ax.plot(xs, hist["train_loss"], label=f"{name} train")
            loss_ax.plot(xs, hist["val_loss"], label=f"{name} val", linestyle="--")
            miou_ax.plot(xs, hist["val_miou"], label=name)

        loss_ax.set_title("Loss")
        loss_ax.set_xlabel("Epoch")
        loss_ax.legend(fontsize=7)
        miou_ax.set_title("Val mIoU")
        miou_ax.set_xlabel("Epoch")
        miou_ax.legend(fontsize=7)

        out = Path("checkpoints/training_curves.png")
        plt.tight_layout()
        plt.savefig(out, dpi=150)
        plt.close()
        print(f"\n📈 Training curves saved to: {out.resolve()}")
    except Exception as e:
        print(f" (Could not save plot: {e})")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
parser = argparse.ArgumentParser()
|
| 279 |
+
parser.add_argument("--data_dir", type=str, default="data")
|
| 280 |
+
parser.add_argument("--epochs", type=int, default=EPOCHS)
|
| 281 |
+
parser.add_argument("--lr", type=float, default=LR)
|
| 282 |
+
parser.add_argument("--image_size", type=int, default=256,
|
| 283 |
+
help="Square image size (256 recommended for CPU)")
|
| 284 |
+
args = parser.parse_args()
|
| 285 |
+
|
| 286 |
+
# Override IMAGE_SIZE from arg
|
| 287 |
+
IMAGE_SIZE = (args.image_size, args.image_size)
|
| 288 |
+
# Patch dataset module so it uses the right size
|
| 289 |
+
import dataset as ds_module
|
| 290 |
+
ds_module.IMAGE_SIZE = IMAGE_SIZE
|
| 291 |
+
|
| 292 |
+
data_root = Path(args.data_dir)
|
| 293 |
+
histories = {}
|
| 294 |
+
|
| 295 |
+
for subset in SUBSETS:
|
| 296 |
+
h = train_subset(subset, data_root, args)
|
| 297 |
+
if h:
|
| 298 |
+
histories[subset] = h
|
| 299 |
+
|
| 300 |
+
if histories:
|
| 301 |
+
plot_histories(histories)
|
| 302 |
+
print("\n✅ All subsets complete.")
|
| 303 |
+
print("\nSummary:")
|
| 304 |
+
for subset, h in histories.items():
|
| 305 |
+
best = max(h["val_miou"])
|
| 306 |
+
print(f" {subset:20s} best mIoU = {best:.4f}")
|