rcrane4 committed on
Commit
7538d69
·
verified ·
1 Parent(s): 2d8d80e

Upload 10 files

Browse files
Files changed (10) hide show
  1. app.py +376 -0
  2. dataset.py +208 -0
  3. diagnose.py +420 -0
  4. download_osf.py +146 -0
  5. features.py +421 -0
  6. inference.py +230 -0
  7. inspect_dataset.py +264 -0
  8. setup.py +21 -0
  9. test.py +261 -0
  10. train.py +306 -0
app.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py
3
+ ------
4
+ FailureGPT — Gradio web interface.
5
+ Drag-and-drop SEM image → segmentation → features → AI diagnosis.
6
+
7
+ Usage:
8
+ pip install gradio
9
+ python app.py
10
+
11
+ Then open http://127.0.0.1:7860 in your browser.
12
+ """
13
+
14
+ import json
15
+ import os
16
+ from pathlib import Path
17
+
18
+ import gradio as gr
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from dataset import IMAGE_SIZE, NUM_CLASSES
25
+ from features import (
26
+ load_model, load_image_tensor, predict_mask,
27
+ extract_features,
28
+ )
29
+ from diagnose import call_claude, format_diagnosis_report
30
+
31
# ── Load all three models at startup ─────────────────────────────────────────
# One fine-tuned checkpoint per defect subset; subsets without a checkpoint
# are skipped here and reported as untrained by run_pipeline().
SUBSETS = ["all_defects", "lack_of_fusion", "keyhole"]
MODELS = {}  # subset name → loaded segmentation model

print("Loading checkpoints...")
for subset in SUBSETS:
    ckpt = Path("checkpoints") / subset / "best_model.pt"
    if ckpt.exists():
        MODELS[subset] = load_model(ckpt)
        print(f" ✅ {subset}")
    else:
        print(f" ⚠️ {subset} — checkpoint not found")

# ─────────────────────────────────────────────────────────────────────────────

# Hex color per risk level (same palette as diagnose.py's risk_colors).
# NOTE(review): not referenced elsewhere in this file — confirm it is still needed.
RISK_COLORS = {
    "low": "#2ecc71",
    "medium": "#f39c12",
    "high": "#e74c3c",
    "critical": "#8e44ad",
}
52
+
53
def run_pipeline(image: np.ndarray, subset: str) -> tuple:
    """Full analysis pipeline for one uploaded SEM image.

    Args:
        image: Array from the Gradio Image component, normally H×W×3 uint8;
            grayscale H×W and RGBA H×W×4 uploads are coerced to RGB. May be None.
        subset: Key into the module-level MODELS dict (e.g. "all_defects").

    Returns:
        (overlay, features_text, diagnosis_text, risk_label) — the four values
        wired to the Gradio outputs. On missing input/checkpoint the overlay
        is None and the text fields carry the error message.
    """
    if image is None:
        return None, "No image provided.", "No image provided.", "—"
    if subset not in MODELS:
        return None, f"No checkpoint for '{subset}'.", "Train the model first.", "—"

    model = MODELS[subset]

    # ── Step 1: Coerce to RGB float and normalize ─────────────────────────────
    arr = image.astype(np.float32)
    if arr.ndim == 2:  # grayscale → 3-channel
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.shape[2] == 4:  # RGBA → drop alpha
        arr = arr[:, :, :3]

    # Min-max stretch to [0,1]; constant images fall back to a plain /255.
    arr_min, arr_max = arr.min(), arr.max()
    if arr_max > arr_min:
        arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-8)
    else:
        arr_norm = arr / 255.0

    # Display copy at model resolution (PIL resize takes (W, H)).
    display_pil = Image.fromarray(
        (arr_norm * 255).astype(np.uint8), mode="RGB"
    ).resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR)
    display_arr = np.array(display_pil, dtype=np.uint8)

    # ImageNet normalization expected by the segmentation backbone.
    arr_model = np.array(display_pil, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    arr_model = (arr_model - mean) / std
    img_tensor = torch.from_numpy(arr_model).permute(2, 0, 1).float()

    # ── Step 2: Segment ───────────────────────────────────────────────────────
    mask = predict_mask(model, img_tensor, IMAGE_SIZE)

    # ── Step 3: Extract features ──────────────────────────────────────────────
    features = extract_features(mask, IMAGE_SIZE)

    # ── Step 4: Build overlay ─────────────────────────────────────────────────
    # Defect pixels are blended 30% image / 70% cyan; background stays intact.
    defect_mask = mask == 1
    overlay = display_arr.copy()
    overlay[defect_mask] = (
        display_arr[defect_mask].astype(float) * 0.3
        + np.array([0, 212, 255], dtype=float) * 0.7
    ).clip(0, 255).astype(np.uint8)

    # Persist a debug copy; create the directory first so a missing `output/`
    # folder can never crash the request.
    debug_path = Path("output") / "debug_overlay.png"
    debug_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(overlay).save(debug_path)

    # ── Step 5: Format features text ──────────────────────────────────────────
    feat_lines = [
        f"Defect Area: {features['defect_area_fraction']:.3f}%",
        f"Defect Count: {features['defect_count']} blobs",
        f"Mean Pore Area: {features.get('mean_pore_area_px', 0):.1f} px²",
        f"Max Pore Area: {features.get('max_pore_area_px', 0)} px²",
        f"Mean Aspect Ratio: {features['mean_aspect_ratio']:.3f}",
        " (1.0=circular · >2.0=elongated)",
        f"Spatial Spread: {features['spatial_concentration']:.2f}",
        f"Size Std Dev: {features['size_std']:.1f}",
        "",
        "Quadrant Distribution:",
        f" TL {features['quadrant_distribution'][0]:.2f} "
        f"TR {features['quadrant_distribution'][1]:.2f}",
        f" BL {features['quadrant_distribution'][2]:.2f} "
        f"BR {features['quadrant_distribution'][3]:.2f}",
        "",
        f"Rule-based type: {features['defect_type']}",
        f"Confidence: {features['confidence']}",
    ]
    features_text = "\n".join(feat_lines)

    # ── Step 6: AI Diagnosis ──────────────────────────────────────────────────
    if not os.environ.get("ANTHROPIC_API_KEY"):
        # No API key: still return the extracted features so the UI stays useful.
        diagnosis_text = (
            "⚠️ ANTHROPIC_API_KEY not set.\n\n"
            "Set it in your terminal:\n"
            " $env:ANTHROPIC_API_KEY = 'sk-ant-...'\n\n"
            "Features extracted successfully:\n\n"
            + features_text
        )
        risk_label = features["defect_type"].upper()
    else:
        diagnosis = call_claude(features, "uploaded_image")
        diagnosis_text = format_diagnosis_report(features, diagnosis, "uploaded_image")
        risk = diagnosis.get("crack_initiation_risk", "unknown")
        mech = diagnosis.get("dominant_failure_mechanism", "unknown")
        risk_label = f"{risk.upper()} RISK — {mech}"

    return overlay, features_text, diagnosis_text, risk_label
156
# ── Gradio UI ─────────────────────────────────────────────────────────────────

# Custom dark "terminal" theme injected into gr.Blocks(css=CSS) below.
# This is runtime data passed verbatim to the browser — edit with care.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;600&display=swap');

body, .gradio-container {
    background: #080c14 !important;
    font-family: 'DM Sans', sans-serif !important;
    color: #c8d6e5 !important;
}

.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
}

/* Header */
#header {
    text-align: center;
    padding: 2rem 0 1rem;
    border-bottom: 1px solid #1e3a5f;
    margin-bottom: 1.5rem;
}

#header h1 {
    font-family: 'Space Mono', monospace !important;
    font-size: 2.4rem !important;
    font-weight: 700 !important;
    color: #00d4ff !important;
    letter-spacing: -1px;
    margin: 0;
}

#header p {
    color: #5a7a9a;
    font-size: 0.9rem;
    margin: 0.4rem 0 0;
    font-family: 'Space Mono', monospace;
}

/* Risk badge */
#risk_label textarea, #risk_label input {
    font-family: 'Space Mono', monospace !important;
    font-size: 1.1rem !important;
    font-weight: 700 !important;
    color: #00d4ff !important;
    background: #0d1825 !important;
    border: 2px solid #00d4ff !important;
    border-radius: 6px !important;
    text-align: center !important;
    padding: 0.6rem !important;
}

/* Textboxes */
textarea {
    font-family: 'Space Mono', monospace !important;
    font-size: 0.78rem !important;
    background: #0a1520 !important;
    color: #a8c4dc !important;
    border: 1px solid #1e3a5f !important;
    border-radius: 6px !important;
    line-height: 1.6 !important;
}

/* Labels */
label span {
    font-family: 'Space Mono', monospace !important;
    font-size: 0.72rem !important;
    color: #4a7a9a !important;
    letter-spacing: 1px !important;
    text-transform: uppercase !important;
}

/* Buttons */
button.primary {
    background: linear-gradient(135deg, #003d66, #006699) !important;
    border: 1px solid #00d4ff !important;
    color: #00d4ff !important;
    font-family: 'Space Mono', monospace !important;
    font-weight: 700 !important;
    letter-spacing: 1px !important;
    border-radius: 6px !important;
    transition: all 0.2s !important;
}

button.primary:hover {
    background: linear-gradient(135deg, #006699, #00aacc) !important;
    box-shadow: 0 0 20px rgba(0, 212, 255, 0.3) !important;
}

button.secondary {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    color: #5a7a9a !important;
    font-family: 'Space Mono', monospace !important;
    border-radius: 6px !important;
}

/* Dropdown */
select, .wrap {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    color: #a8c4dc !important;
    font-family: 'Space Mono', monospace !important;
}

/* Image panels */
.image-container {
    border: 1px solid #1e3a5f !important;
    border-radius: 8px !important;
    overflow: hidden !important;
}

/* Panel blocks */
.block {
    background: #0a1520 !important;
    border: 1px solid #1e3a5f !important;
    border-radius: 8px !important;
}

/* Footer note */
#footer {
    text-align: center;
    padding: 1rem 0;
    color: #2a4a6a;
    font-family: 'Space Mono', monospace;
    font-size: 0.7rem;
    border-top: 1px solid #1e3a5f;
    margin-top: 1.5rem;
}
"""
287
+
288
# Three-column layout: inputs | segmentation output | AI diagnosis.
with gr.Blocks(css=CSS, title="FailSafe") as demo:

    gr.HTML("""
    <div id="header">
        <h1>⬡ FAILSAFE</h1>
        <p>Ti-6Al-4V · LPBF Defect Analysis · SEM Fractography · Powered by SegFormer + Claude</p>
    </div>
    """)

    with gr.Row():
        # Left column — inputs
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="SEM FRACTOGRAPH — drag & drop or click to upload",
                type="numpy",  # run_pipeline expects an ndarray
                height=500,
            )
            subset_input = gr.Dropdown(
                choices=SUBSETS,
                value="all_defects",
                label="MODEL SUBSET",
            )
            with gr.Row():
                run_btn = gr.Button("▶ ANALYZE", variant="primary", scale=3)
                clear_btn = gr.Button("✕ CLEAR", variant="secondary", scale=1)

            # Styled via the #risk_label CSS rule above.
            risk_output = gr.Textbox(
                label="CRACK INITIATION RISK",
                lines=1,
                interactive=False,
                elem_id="risk_label",
            )

        # Middle column — image output
        with gr.Column(scale=1):
            overlay_output = gr.Image(
                label="DEFECT SEGMENTATION MAP",
                height=500,
                interactive=False,
            )
            features_output = gr.Textbox(
                label="MORPHOLOGICAL FEATURES",
                lines=14,
                interactive=False,
            )

        # Right column — diagnosis
        with gr.Column(scale=1):
            diagnosis_output = gr.Textbox(
                label="AI FAILURE DIAGNOSIS — Claude",
                lines=28,
                interactive=False,
            )

    gr.HTML("""
    <div id="footer">
        FailureGPT · ASU Mechanical Engineering · OSF Ti-64 Dataset ·
        SegFormer-b0 fine-tuned · Claude Reasoning Layer
    </div>
    """)

    # Wire up: ANALYZE runs the full pipeline; CLEAR resets all five components.
    run_btn.click(
        fn=run_pipeline,
        inputs=[image_input, subset_input],
        outputs=[overlay_output, features_output, diagnosis_output, risk_output],
    )
    clear_btn.click(
        fn=lambda: (None, None, "", "", ""),
        outputs=[image_input, overlay_output, features_output, diagnosis_output, risk_output],
    )

    # Example images — only shown when the sample data is present on disk.
    example_images = list(Path("data/all_defects/images_8bit").glob("*.png"))[:3]
    if example_images:
        gr.Examples(
            examples=[[str(p), "all_defects"] for p in example_images],
            inputs=[image_input, subset_input],
            label="EXAMPLE IMAGES",
        )


if __name__ == "__main__":
    # Local-only server; set share=True for a temporary public Gradio link.
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=False,
        show_error=True,
    )
dataset.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dataset.py
3
+ ----------
4
+ PyTorch Dataset class for the OSF Ti-64 SEM fractography dataset.
5
+ Use this after running inspect_dataset.py to confirm your mask format.
6
+
7
+ Key decisions you may need to make after inspection:
8
+ - If masks are binary (0/255): set NUM_CLASSES=2, update MASK_SCALE
9
+ - If masks are RGB color: set COLOR_MASK=True and define COLOR_TO_LABEL
10
+ - If masks are integer labels (0..N): use as-is (ideal case)
11
+
12
+ Usage:
13
+ from dataset import FractographyDataset
14
+ ds = FractographyDataset("data/", split="train")
15
+ img, mask = ds[0]
16
+ """
17
+
18
+ from pathlib import Path
19
+ from typing import Callable, Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from torch.utils.data import Dataset, DataLoader, random_split
25
+ import torchvision.transforms.functional as TF
26
+ import random
27
+
28
+
29
+ # ── Config — update after running inspect_dataset.py ────────────────────────
30
+ NUM_CLASSES = 2 # update once you know how many classes are in your masks
31
+ IMAGE_SIZE = (512, 512) # resize target; SegFormer-b0 default input
32
+ MASK_SCALE = 255
33
+ # If masks use RGB color encoding instead of integer labels, set this to True
34
+ # and populate COLOR_TO_LABEL below.
35
+ COLOR_MASK = False
36
+ COLOR_TO_LABEL: dict[tuple, int] = {
37
+ # (R, G, B): class_index
38
+ # e.g. (255, 0, 0): 1,
39
+ }
40
+ # ─────────────────────────────────────────────────────────────────────────────
41
+
42
+
43
def rgb_mask_to_label(mask_rgb: np.ndarray, color_to_label: dict) -> np.ndarray:
    """Convert an H×W×3 RGB mask to an H×W integer label mask.

    Pixels whose color is not listed in ``color_to_label`` remain 0
    (background); later entries overwrite earlier ones on overlap.
    """
    labels = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
    for rgb, class_index in color_to_label.items():
        hit = (mask_rgb == np.asarray(rgb)).all(axis=-1)
        labels[hit] = class_index
    return labels
50
+
51
+
52
class FractographyDataset(Dataset):
    """
    OSF Ti-64 SEM Fractography Dataset.

    Recursively pairs every image under an ``images_8bit/`` folder with the
    identically named file in the sibling ``masks_8bit/`` folder, then yields
    (ImageNet-normalized C×H×W image tensor, H×W int64 label mask).

    Args:
        data_dir: Root of downloaded data (contains subfolders with images/ + masks/).
        split: "train", "val", or "all" (no splitting, returns everything).
            Built-in augmentation fires only for "train" and only when no
            explicit ``transform`` is supplied (see __getitem__).
        transform: Optional callable applied to both image and mask (augmentation).
        image_size: Resize target (H, W).
    """

    # File suffixes accepted as images (compared lower-cased).
    IMAGE_EXTS = {".png", ".tif", ".tiff", ".jpg", ".jpeg"}

    def __init__(
        self,
        data_dir: str | Path,
        split: str = "all",
        transform: Optional[Callable] = None,
        image_size: tuple[int, int] = IMAGE_SIZE,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.pairs = self._find_pairs()

        if not self.pairs:
            raise FileNotFoundError(
                f"No image/mask pairs found in {self.data_dir}. "
                "Run inspect_dataset.py to diagnose."
            )

    def _find_pairs(self) -> list[tuple[Path, Path]]:
        """Collect (image_path, mask_path) pairs; an image without a
        same-named mask file is silently skipped."""
        pairs = []
        for images_dir in sorted(self.data_dir.rglob("images_8bit")):
            if not images_dir.is_dir():
                continue
            masks_dir = images_dir.parent / "masks_8bit"
            if not masks_dir.exists():
                continue
            for img_path in sorted(images_dir.iterdir()):
                if img_path.suffix.lower() not in self.IMAGE_EXTS:
                    continue
                mask_path = masks_dir / img_path.name
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
        return pairs

    def _load_image(self, path: Path) -> torch.Tensor:
        """Load an image as a C×H×W float tensor, resized and ImageNet-normalized."""
        img = Image.open(path).convert("RGB")
        # PIL resize takes (width, height), hence the reversed tuple.
        img = img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        arr = (arr - mean) / std
        return torch.from_numpy(arr).permute(2, 0, 1).float()

    def _load_mask(self, path: Path) -> torch.Tensor:
        """Load a mask as an H×W int64 label tensor (decoded per COLOR_MASK)."""
        mask_pil = Image.open(path)

        if COLOR_MASK:
            mask_arr = np.array(mask_pil.convert("RGB"))
            mask_arr = rgb_mask_to_label(mask_arr, COLOR_TO_LABEL)
        else:
            mask_arr = np.array(mask_pil.convert("L"), dtype=np.int64)
            if MASK_SCALE > 1:
                # Floor division maps 0/255 → 0/1. NOTE(review): any
                # intermediate gray values (1..254) collapse to 0 — confirm
                # masks are strictly bi-level.
                mask_arr = mask_arr // MASK_SCALE  # e.g. 0/255 → 0/1

        mask_pil_resized = Image.fromarray(mask_arr.astype(np.uint8)).resize(
            (self.image_size[1], self.image_size[0]), Image.NEAREST  # NEAREST preserves labels
        )
        mask_arr = np.array(mask_pil_resized, dtype=np.int64)
        return torch.from_numpy(mask_arr).long()  # H×W

    def _augment(self, image: torch.Tensor, mask: torch.Tensor):
        """Shared spatial augmentations (applied identically to image and mask)."""
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask.unsqueeze(0)).squeeze(0)

        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask.unsqueeze(0)).squeeze(0)

        # Random 90° rotation — image is C×H×W (spatial dims 1,2),
        # mask is H×W (spatial dims 0,1).
        k = random.choice([0, 1, 2, 3])
        if k:
            image = torch.rot90(image, k, dims=[1, 2])
            mask = torch.rot90(mask, k, dims=[0, 1])

        return image, mask

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        img_path, mask_path = self.pairs[idx]
        image = self._load_image(img_path)
        mask = self._load_mask(mask_path)

        # Built-in augmentation only on the train split when no transform was
        # given; an explicit transform takes precedence regardless of split.
        if self.split == "train" and self.transform is None:
            image, mask = self._augment(image, mask)
        elif self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

    def __repr__(self) -> str:
        return (
            f"FractographyDataset("
            f"n={len(self)}, split='{self.split}', "
            f"image_size={self.image_size}, classes={NUM_CLASSES})"
        )
164
+
165
+
166
def get_dataloaders(
    data_dir: str | Path,
    batch_size: int = 4,
    train_frac: float = 0.8,
    num_workers: int = 2,
) -> tuple[DataLoader, DataLoader]:
    """
    Returns (train_loader, val_loader) with a train_frac / (1 - train_frac) split.

    Bug fix vs. the previous version: `random_split` subsets share ONE
    underlying dataset, so mutating `train_ds.dataset.split = "train"` also
    switched augmentation on for the validation subset. We now back the two
    loaders with two independent dataset instances over the same files and
    split by a shared shuffled index permutation instead.

    Args:
        data_dir: Dataset root passed to FractographyDataset.
        batch_size: Batch size for both loaders.
        train_frac: Fraction of samples assigned to training.
        num_workers: DataLoader worker processes.
    """
    from torch.utils.data import Subset

    # Two independent instances over the same file pairs: only the first augments.
    train_base = FractographyDataset(data_dir, split="train")
    val_base = FractographyDataset(data_dir, split="val")

    n_total = len(train_base)
    n_train = int(n_total * train_frac)
    # One shuffled permutation shared by both subsets — disjoint, exhaustive.
    perm = torch.randperm(n_total).tolist()
    train_ds = Subset(train_base, perm[:n_train])
    val_ds = Subset(val_base, perm[n_train:])

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    print(f"Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
    return train_loader, val_loader
193
+
194
+
195
# ── Quick sanity check ────────────────────────────────────────────────────────
# Run directly (`python dataset.py [data_dir]`) to verify pairs are found and
# to print the tensor shapes/ranges of one sample.
if __name__ == "__main__":
    import sys
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "data"

    try:
        ds = FractographyDataset(data_dir)
        print(ds)
        img, mask = ds[0]
        print(f"Image tensor: {img.shape} dtype={img.dtype} range=[{img.min():.2f}, {img.max():.2f}]")
        print(f"Mask tensor: {mask.shape} dtype={mask.dtype} unique={mask.unique().tolist()}")
        print("\n✅ Dataset loads correctly.")
    except FileNotFoundError as e:
        # Raised by the constructor when no image/mask pairs are found.
        print(f"❌ {e}")
diagnose.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ diagnose.py
3
+ -----------
4
+ Week 4: Generative reasoning layer.
5
+
6
+ Takes the feature dict output from features.py and calls the Claude API
7
+ to produce a structured engineering failure diagnosis.
8
+
9
+ The LLM receives:
10
+ - Quantitative morphological features from the segmentation
11
+ - Material context (Ti-6Al-4V, LPBF process)
12
+ - Defect type classification
13
+ And returns:
14
+ - Natural language diagnosis
15
+ - Crack initiation risk assessment
16
+ - Recommended follow-up actions
17
+
18
+ Usage:
19
+ # Single image full pipeline (segment → extract → diagnose)
20
+ python diagnose.py --image data/all_defects/images/001-Overview-EP04V24.png
21
+ --subset all_defects
22
+
23
+ # From existing feature JSON
24
+ python diagnose.py --json output/features/all_defects_features.json
25
+
26
+ # Interactive mode
27
+ python diagnose.py --interactive --subset all_defects
28
+ """
29
+
30
+ import argparse
31
+ import json
32
+ import time
33
+ from pathlib import Path
34
+
35
+ import torch
36
+ import torch.nn.functional as F
37
+ import numpy as np
38
+ import matplotlib
39
+ matplotlib.use("Agg")
40
+ import matplotlib.pyplot as plt
41
+ from PIL import Image
42
+ from transformers import SegformerForSemanticSegmentation
43
+
44
+ from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES
45
+ from features import (
46
+ load_model, load_image_tensor, predict_mask,
47
+ extract_features, visualize_features
48
+ )
49
+
50
# ── Anthropic API ─────────────────────────────────────────────────────────────
# Guarded import: segmentation and feature extraction still work without the
# SDK installed; call_claude() checks HAS_ANTHROPIC before touching the API.
try:
    import anthropic
    HAS_ANTHROPIC = True
except ImportError:
    HAS_ANTHROPIC = False
    print("⚠️ anthropic package not found. Run: pip install anthropic")
# ─────────────────────────────────────────────────────────────────────────────
58
+
59
# Static material/process context interpolated into every user prompt.
MATERIAL_CONTEXT = """
Material: Ti-6Al-4V (Grade 5 titanium alloy)
Process: Laser Powder Bed Fusion (LPBF) additive manufacturing
Application context: High-performance structural components (aerospace/defense)
Specimen type: Bend test bar, fractured in four-point bending
"""

# System prompt: fixes the assistant persona and — critically — the exact JSON
# schema that call_claude() parses with json.loads(); keep these keys in sync
# with the fields read by format_diagnosis_report().
SYSTEM_PROMPT = """You are an expert materials engineer specializing in fractography
and failure analysis of additively manufactured aerospace components.
You analyze quantitative defect features extracted from SEM (Scanning Electron Microscope)
images of Ti-6Al-4V fracture surfaces produced by Laser Powder Bed Fusion (LPBF).

Your role is to:
1. Interpret morphological defect features in the context of LPBF process physics
2. Assess crack initiation and propagation risk based on defect characteristics
3. Provide actionable engineering recommendations
4. Be precise and quantitative — reference the actual feature values in your diagnosis

Always structure your response as valid JSON with these exact keys:
{
  "diagnosis_summary": "2-3 sentence plain English summary",
  "defect_interpretation": "detailed interpretation of the morphological features",
  "crack_initiation_risk": "low | medium | high | critical",
  "risk_rationale": "why you assigned this risk level, referencing specific features",
  "dominant_failure_mechanism": "e.g. lack of fusion porosity, keyhole porosity, mixed",
  "critical_regions": "which quadrants or regions pose highest risk",
  "recommendations": ["recommendation 1", "recommendation 2", "recommendation 3"],
  "confidence": "low | medium | high",
  "confidence_rationale": "why"
}
"""
90
+
91
def build_user_prompt(features: dict, image_name: str = "") -> str:
    """Render the extracted feature dict as the user message sent to Claude.

    Every feature is read with .get() and a default, so a partial feature
    dict never raises; missing values render as 0 / 'unknown'.
    """
    return f"""
Analyze the following defect features extracted from an SEM fractograph of a
Ti-6Al-4V LPBF test bar.

Material & Process Context:
{MATERIAL_CONTEXT}

Image: {image_name}

Extracted Morphological Features:
- Defect area fraction: {features.get('defect_area_fraction', 0):.3f}% of fracture surface
- Defect blob count: {features.get('defect_count', 0)} distinct pores/defects
- Mean pore area: {features.get('mean_pore_area_px', 0):.1f} px² (at 256×256 resolution)
- Max pore area: {features.get('max_pore_area_px', 0)} px²
- Mean aspect ratio: {features.get('mean_aspect_ratio', 0):.3f}
  (1.0 = perfectly circular/keyhole, >2.0 = elongated/lack-of-fusion)
- Spatial spread (std): {features.get('spatial_concentration', 0):.2f} px
- Size heterogeneity: {features.get('size_std', 0):.1f} px² std dev
- Quadrant distribution:
  Top-left: {features.get('quadrant_distribution', [0,0,0,0])[0]:.3f}
  Top-right: {features.get('quadrant_distribution', [0,0,0,0])[1]:.3f}
  Bottom-left: {features.get('quadrant_distribution', [0,0,0,0])[2]:.3f}
  Bottom-right: {features.get('quadrant_distribution', [0,0,0,0])[3]:.3f}
- Rule-based defect type: {features.get('defect_type', 'unknown')}
  (confidence: {features.get('confidence', 'unknown')})

Provide a structured engineering diagnosis as JSON.
"""
120
+
121
+
122
+ def call_claude(features: dict, image_name: str = "") -> dict:
123
+ """Call Claude API and return parsed diagnosis dict."""
124
+ if not HAS_ANTHROPIC:
125
+ return {"error": "anthropic package not installed"}
126
+
127
+ client = anthropic.Anthropic() # uses ANTHROPIC_API_KEY env var
128
+ prompt = build_user_prompt(features, image_name)
129
+
130
+ try:
131
+ response = client.messages.create(
132
+ model="claude-sonnet-4-20250514",
133
+ max_tokens=1000,
134
+ system=SYSTEM_PROMPT,
135
+ messages=[{"role": "user", "content": prompt}]
136
+ )
137
+ raw_text = response.content[0].text.strip()
138
+
139
+ # Strip markdown code fences if present
140
+ if raw_text.startswith("```"):
141
+ raw_text = raw_text.split("```")[1]
142
+ if raw_text.startswith("json"):
143
+ raw_text = raw_text[4:]
144
+ raw_text = raw_text.strip()
145
+
146
+ diagnosis = json.loads(raw_text)
147
+ return diagnosis
148
+
149
+ except json.JSONDecodeError as e:
150
+ return {"error": f"JSON parse error: {e}", "raw": raw_text}
151
+ except Exception as e:
152
+ return {"error": str(e)}
153
+
154
+
155
def format_diagnosis_report(features: dict, diagnosis: dict, image_name: str = "") -> str:
    """Render the features and AI diagnosis as a plain-text report.

    If *diagnosis* carries an "error" key, the report stops after the
    quantitative section with the error message appended.
    """
    sep = "=" * 60
    report = [
        sep,
        "FAILURE ANALYSIS REPORT",
        f"Image: {image_name}",
        "Material: Ti-6Al-4V (LPBF)",
        sep,
        "",
        "QUANTITATIVE FEATURES",
        f" Defect area: {features.get('defect_area_fraction', 0):.3f}%",
        f" Defect count: {features.get('defect_count', 0)}",
        f" Mean aspect ratio:{features.get('mean_aspect_ratio', 0):.3f}",
        f" Rule-based type: {features.get('defect_type', 'unknown')}",
        "",
    ]

    # Early exit: API/parse failures are surfaced instead of a diagnosis.
    if "error" in diagnosis:
        report.append(f"⚠️ Diagnosis error: {diagnosis['error']}")
        return "\n".join(report)

    report.extend([
        "AI DIAGNOSIS",
        f" Failure mechanism: {diagnosis.get('dominant_failure_mechanism', 'N/A')}",
        f" Crack init. risk: {diagnosis.get('crack_initiation_risk', 'N/A').upper()}",
        f" Critical regions: {diagnosis.get('critical_regions', 'N/A')}",
        f" Confidence: {diagnosis.get('confidence', 'N/A')}",
        "",
        "SUMMARY",
        f" {diagnosis.get('diagnosis_summary', '')}",
        "",
        "DEFECT INTERPRETATION",
        f" {diagnosis.get('defect_interpretation', '')}",
        "",
        "RISK RATIONALE",
        f" {diagnosis.get('risk_rationale', '')}",
        "",
        "RECOMMENDATIONS",
    ])
    for idx, rec in enumerate(diagnosis.get("recommendations", []), 1):
        report.append(f" {idx}. {rec}")
    report.append(sep)

    return "\n".join(report)
200
+
201
+
202
def visualize_diagnosis(
    image_path: Path,
    mask: np.ndarray,
    features: dict,
    diagnosis: dict,
    out_path: Path,
):
    """Save a full diagnosis visualization: SEM image | defect overlay | AI report.

    Args:
        image_path: Path to the raw SEM image on disk.
        mask: H×W integer segmentation mask (1 = defect pixel).
        features: Feature dict from extract_features().
        diagnosis: Parsed dict from call_claude(); may contain "error".
        out_path: Destination PNG; parent directories are created as needed.
    """
    # Min-max normalize the raw image, then resize to model resolution.
    # NOTE(review): the 3-channel stack below assumes a single-channel (H×W)
    # image — an RGB file would mis-shape the overlay; confirm against the dataset.
    raw = np.array(Image.open(image_path), dtype=np.float32)
    raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
    raw_resized = np.array(
        Image.fromarray((raw * 255).astype(np.uint8)).resize(
            (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
        )
    )

    # Risk color (gray fallback for unknown levels)
    risk_colors = {
        "low": "#2ecc71", "medium": "#f39c12",
        "high": "#e74c3c", "critical": "#8e44ad"
    }
    risk = diagnosis.get("crack_initiation_risk", "medium")
    risk_color = risk_colors.get(risk, "#888888")

    fig = plt.figure(figsize=(18, 8))
    fig.patch.set_facecolor("#0d0d1a")

    # Title
    mech = diagnosis.get("dominant_failure_mechanism", "Unknown")
    fig.suptitle(
        f"FailureGPT — {image_path.name}\n"
        f"Mechanism: {mech} | Crack Risk: {risk.upper()}",
        fontsize=12, fontweight="bold", color="white", y=1.01
    )

    # Panel 1: raw SEM image
    ax1 = fig.add_subplot(1, 3, 1)
    ax1.imshow(raw_resized, cmap="gray")
    ax1.set_title("SEM Fractograph", color="white", fontsize=9)
    ax1.axis("off")
    ax1.set_facecolor("#0d0d1a")

    # Panel 2: segmentation overlay — defect pixels painted solid cyan
    ax2 = fig.add_subplot(1, 3, 2)
    overlay = np.stack([raw_resized]*3, axis=-1).copy()
    overlay[mask == 1] = [0, 212, 255]
    ax2.imshow(overlay)
    ax2.set_title(
        f"Defect Map\n{features['defect_area_fraction']:.2f}% | "
        f"{features['defect_count']} blobs | AR={features['mean_aspect_ratio']:.2f}",
        color="white", fontsize=9
    )
    ax2.axis("off")
    ax2.set_facecolor("#0d0d1a")

    # Panel 3: diagnosis text
    ax3 = fig.add_subplot(1, 3, 3)
    ax3.set_facecolor("#0d0d1a")
    ax3.axis("off")

    if "error" not in diagnosis:
        summary = diagnosis.get("diagnosis_summary", "")
        interp = diagnosis.get("defect_interpretation", "")
        recs = diagnosis.get("recommendations", [])
        conf = diagnosis.get("confidence", "")

        # Word wrap helper: greedy fill to `width` characters per line.
        def wrap(text, width=42):
            words, lines, line = text.split(), [], ""
            for w in words:
                if len(line) + len(w) + 1 <= width:
                    line += (" " if line else "") + w
                else:
                    lines.append(line)
                    line = w
            if line:
                lines.append(line)
            return "\n".join(lines)

        # Interpretation is truncated to 200 chars and recommendations to the
        # first three (80 chars each) to keep the panel readable.
        report = (
            f"RISK: {risk.upper()}\n"
            f"{'─'*38}\n\n"
            f"SUMMARY\n{wrap(summary)}\n\n"
            f"INTERPRETATION\n{wrap(interp[:200])}\n\n"
            f"RECOMMENDATIONS\n"
        )
        for i, r in enumerate(recs[:3], 1):
            report += f"{i}. {wrap(r[:80])}\n"
        report += f"\nConfidence: {conf}"

        ax3.text(
            0.05, 0.97, report,
            transform=ax3.transAxes,
            fontsize=7.5, verticalalignment="top",
            fontfamily="monospace", color="white",
            bbox=dict(
                boxstyle="round", facecolor="#1a1a2e",
                alpha=0.9, edgecolor=risk_color, linewidth=2
            )
        )
    else:
        # Failed API call: show the error in place of the report.
        ax3.text(
            0.1, 0.5, f"API Error:\n{diagnosis['error']}",
            transform=ax3.transAxes, color="red", fontsize=9
        )

    ax3.set_title("AI Diagnosis", color="white", fontsize=9)

    plt.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150, bbox_inches="tight",
                facecolor="#0d0d1a")
    plt.close()
    print(f" Visualization → {out_path.resolve()}")
316
+
317
+
318
def run_full_pipeline(image_path: Path, subset: str, save_vis: bool = True) -> dict:
    """Full pipeline: image → segmentation → features → diagnosis.

    Args:
        image_path: SEM image to analyze.
        subset: checkpoint subset name; selects
            checkpoints/<subset>/best_model.pt.
        save_vis: when True, also write a PNG visualization.

    Returns:
        Dict with "image", "features", and "diagnosis" keys (also written
        to output/diagnosis/<stem>_diagnosis.json), or {} when no
        checkpoint exists for the subset.
    """
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"❌ No checkpoint at {ckpt_path}")
        return {}

    print(f"\n{'='*60}")
    print(f"FailureGPT Pipeline")
    print(f"Image: {image_path.name}")
    print(f"Subset: {subset}")
    print(f"{'='*60}")

    # Step 1: Segment
    print("Step 1/3: Segmenting...")
    model = load_model(ckpt_path)
    img_tensor = load_image_tensor(image_path, IMAGE_SIZE)
    mask = predict_mask(model, img_tensor, IMAGE_SIZE)

    # Step 2: Extract morphological features from the predicted mask
    print("Step 2/3: Extracting features...")
    features = extract_features(mask, IMAGE_SIZE)
    print(f" → {features['defect_count']} blobs, "
          f"{features['defect_area_fraction']:.2f}% defect, "
          f"AR={features['mean_aspect_ratio']:.2f}")

    # Step 3: Generate diagnosis via the LLM reasoning layer
    print("Step 3/3: Generating diagnosis...")
    diagnosis = call_claude(features, image_path.name)

    # Print report
    report = format_diagnosis_report(features, diagnosis, image_path.name)
    print(report)

    # Save visualization
    if save_vis:
        out_path = Path("output/diagnosis") / f"{image_path.stem}_diagnosis.png"
        visualize_diagnosis(image_path, mask, features, diagnosis, out_path)

    # Save JSON result alongside the visualization
    result = {"image": str(image_path), "features": features, "diagnosis": diagnosis}
    json_out = Path("output/diagnosis") / f"{image_path.stem}_diagnosis.json"
    json_out.parent.mkdir(parents=True, exist_ok=True)
    with open(json_out, "w") as f:
        json.dump(result, f, indent=2)
    print(f" JSON → {json_out.resolve()}")

    return result
366
+
367
+
368
def interactive_mode(subset: str, data_dir: Path):
    """Interactive CLI: pick an image, get a diagnosis.

    Lists up to the first 20 images of the subset, reads an index from
    stdin, and runs the full pipeline on the chosen image.
    """
    subset_dir = data_dir / subset
    ds = FractographyDataset(subset_dir, split="all", image_size=IMAGE_SIZE)

    print(f"\nAvailable images in '{subset}':")
    for i, (img_path, _) in enumerate(ds.pairs[:20]):
        print(f" [{i:2d}] {img_path.name}")

    try:
        idx = int(input("\nEnter image index: "))
        img_path, _ = ds.pairs[idx]
        run_full_pipeline(img_path, subset)
    except (ValueError, IndexError) as e:
        # non-integer input or out-of-range index
        print(f"Invalid selection: {e}")
383
+
384
+
385
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--subset", type=str, default="all_defects")
    parser.add_argument("--json", type=str, default=None,
                        help="Path to existing features JSON from features.py")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--n", type=int, default=3,
                        help="Number of images to process in batch mode")
    args = parser.parse_args()

    # Dispatch priority: interactive > existing-JSON > single image > batch
    if args.interactive:
        interactive_mode(args.subset, Path(args.data_dir))

    elif args.json:
        # Diagnose from existing feature JSON (skips segmentation entirely)
        with open(args.json) as f:
            feature_list = json.load(f)
        if isinstance(feature_list, list):
            for item in feature_list[:args.n]:
                diagnosis = call_claude(item, item.get("image", ""))
                print(format_diagnosis_report(item, diagnosis, item.get("image", "")))
        else:
            diagnosis = call_claude(feature_list)
            print(format_diagnosis_report(feature_list, diagnosis))

    elif args.image:
        run_full_pipeline(Path(args.image), args.subset)

    else:
        # Batch: run on first n images of subset
        subset_dir = Path(args.data_dir) / args.subset
        ds = FractographyDataset(subset_dir, split="all", image_size=IMAGE_SIZE)
        for img_path, _ in list(ds.pairs)[:args.n]:
            run_full_pipeline(img_path, args.subset)
download_osf.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ download_osf.py
3
+ ---------------
4
+ Downloads the OSF Ti-64 SEM fractography dataset (osf.io/gdwyb).
5
+ The dataset has 3 sub-components:
6
+ - Lack of Fusion defects
7
+ - Keyhole defects
8
+ - All Defects (combined)
9
+
10
+ Each sub-dataset contains SEM images + ground truth segmentation masks.
11
+
12
+ Usage:
13
+ python download_osf.py
14
+
15
+ Output structure:
16
+ data/
17
+ lack_of_fusion/
18
+ images/
19
+ masks/
20
+ keyhole/
21
+ images/
22
+ masks/
23
+ all_defects/
24
+ images/
25
+ masks/
26
+ """
27
+
28
+ import os
29
+ import requests
30
+ from pathlib import Path
31
+ from tqdm import tqdm
32
+
33
+ # OSF project GUIDs for each sub-dataset
34
+ # Inspect at: https://osf.io/gdwyb/
35
+ OSF_API = "https://api.osf.io/v2"
36
+ OSF_PROJECT_ID = "gdwyb" # top-level fractography project
37
+
38
+ DATA_DIR = Path("data")
39
+
40
+
41
def list_osf_files(node_id: str) -> list[dict]:
    """Recursively list all files in an OSF node.

    Thin wrapper over list_osf_folder: the node's osfstorage root is just
    a paginated folder listing, so the folder walker handles pagination
    and subfolder recursion for both entry points. (Replaces a previous
    duplicated copy of the traversal loop.)

    Returns a list of dicts with keys: name, path, download, size.
    """
    return list_osf_folder(f"{OSF_API}/nodes/{node_id}/files/osfstorage/")
63
+
64
+
65
def list_osf_folder(url: str) -> list[dict]:
    """Recursively list files inside an OSF folder URL.

    Follows OSF API v2 pagination via links["next"] and recurses into
    subfolders. Each returned dict has: name, path (materialized path),
    download (direct download URL), and size (bytes).
    """
    files = []
    while url:  # walk the pagination chain; "next" is None on the last page
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        for item in data["data"]:
            if item["attributes"]["kind"] == "file":
                files.append({
                    "name": item["attributes"]["name"],
                    "path": item["attributes"]["materialized_path"],
                    "download": item["links"]["download"],
                    "size": item["attributes"]["size"],
                })
            elif item["attributes"]["kind"] == "folder":
                # recurse into the subfolder's own listing endpoint
                folder_url = item["relationships"]["files"]["links"]["related"]["href"]
                files.extend(list_osf_folder(folder_url))
        url = data["links"].get("next")
    return files
85
+
86
+
87
def download_file(url: str, dest: Path):
    """Download a file with a progress bar.

    Creates parent directories as needed and skips the download when the
    destination already exists (no size/checksum validation is done, so a
    partial previous download is also skipped — delete it to retry).
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print(f" [skip] {dest.name} already exists")
        return
    resp = requests.get(url, stream=True, timeout=60)
    resp.raise_for_status()
    # content-length may be absent (0) — tqdm then shows no total
    total = int(resp.headers.get("content-length", 0))
    with open(dest, "wb") as f, tqdm(
        desc=dest.name, total=total, unit="B", unit_scale=True, leave=False
    ) as bar:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)
            bar.update(len(chunk))
102
+
103
+
104
def download_osf_project(node_id: str, local_root: Path):
    """Download all files from an OSF node into local_root, preserving folder structure.

    Returns the file-metadata list from list_osf_files, or [] when the
    listing itself failed. Individual file-download failures are printed
    but do not abort the run.
    """
    print(f"\n📂 Fetching file list from OSF node: {node_id}")
    try:
        files = list_osf_files(node_id)
    except Exception as e:
        # deliberate best-effort: fall back to manual-download instructions
        print(f" ⚠️ Could not list files: {e}")
        print(" → You may need to download manually from https://osf.io/gdwyb/")
        return []

    print(f" Found {len(files)} files")
    for f in files:
        # strip leading slash from materialized path so joins stay relative
        rel_path = f["path"].lstrip("/")
        dest = local_root / rel_path
        print(f" ↓ {rel_path} ({f['size'] / 1024:.1f} KB)")
        try:
            download_file(f["download"], dest)
        except Exception as e:
            print(f" ⚠️ Failed: {e}")
    return files
125
+
126
+
127
if __name__ == "__main__":
    DATA_DIR.mkdir(exist_ok=True)
    print("=" * 60)
    print("OSF Ti-64 Fractography Dataset Downloader")
    print("Project: https://osf.io/gdwyb/")
    print("=" * 60)

    files = download_osf_project(OSF_PROJECT_ID, DATA_DIR)

    # Empty list means the API listing failed → print manual fallback steps
    if files:
        print(f"\n✅ Download complete. Files saved to: {DATA_DIR.resolve()}")
    else:
        print("\n⚠️ Automatic download failed.")
        print("Manual download steps:")
        print(" 1. Go to https://osf.io/gdwyb/")
        print(" 2. Click each sub-component (Lack of Fusion, Key Hole, All Defects)")
        print(" 3. Download the zip and extract into data/<subfolder>/")
        print(" Expected structure:")
        print(" data/lack_of_fusion/images/*.png (or .tif)")
        print(" data/lack_of_fusion/masks/*.png")
features.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ features.py
3
+ -----------
4
+ Week 3: Feature extraction + defect type classification.
5
+
6
+ Takes a trained SegFormer checkpoint, runs inference on an image,
7
+ and extracts quantitative morphological features from the predicted mask.
8
+ These features feed into:
9
+ 1. A rule-based defect classifier (lack_of_fusion vs keyhole vs clean)
10
+ 2. A structured feature dict consumed by the generative reasoning layer (Week 4)
11
+
12
+ Extracted features:
13
+ - defect_area_fraction : % of image that is defect
14
+ - defect_count : number of distinct defect regions
15
+ - mean_pore_area : mean area of individual defect blobs (px²)
16
+ - max_pore_area : largest single defect region
17
+ - mean_aspect_ratio : mean of (major_axis / minor_axis) per blob
18
+ → circular pores ≈ 1.0 (keyhole)
19
+ → elongated pores > 2.0 (lack of fusion)
20
+ - spatial_concentration : std of defect centroid positions (spread)
21
+ - size_std : std of pore areas (heterogeneity)
22
+ - quadrant_distribution : defect fraction per image quadrant
23
+
24
+ Usage:
25
+ python features.py --image data/all_defects/images/001-Overview-EP04V24.png
26
+ --subset all_defects
27
+
28
+ python features.py --subset all_defects --all # run on all images in subset
29
+ """
30
+
31
+ import argparse
32
+ import json
33
+ import math
34
+ from pathlib import Path
35
+
36
+ import matplotlib
37
+ matplotlib.use("Agg")
38
+ import matplotlib.pyplot as plt
39
+ import numpy as np
40
+ import torch
41
+ import torch.nn.functional as F
42
+ from PIL import Image
43
+ from transformers import SegformerForSemanticSegmentation
44
+
45
+ from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES, MASK_SCALE
46
+
47
+ # ── Config ────────────────────────────────────────────────────────────────────
48
+ DEVICE = torch.device("cpu")
49
+
50
+ # Rule-based classification thresholds (tunable)
51
+ # Lack of fusion: many small irregular pores, high aspect ratio
52
+ # Keyhole: fewer larger circular pores, low aspect ratio
53
+ THRESHOLDS = {
54
+ "min_defect_fraction_to_classify": 0.002,
55
+ "keyhole_max_aspect_ratio": 1.6, # wider keyhole band
56
+ "lof_min_count": 20, # need many blobs for LoF
57
+ }
58
+ # ─────────────────────────────────────────────────────────────────────────────
59
def load_model(checkpoint_path: Path) -> SegformerForSemanticSegmentation:
    """Build a SegFormer (MiT-B0 architecture) and load fine-tuned weights.

    The model is instantiated from the nvidia/mit-b0 config only (no
    pretrained backbone weights), then the checkpoint's state dict is
    loaded with strict=True so any architecture mismatch fails loudly.
    (Fixed: the unused `result` local from load_state_dict is dropped.)

    Args:
        checkpoint_path: path to a state-dict checkpoint (best_model.pt).

    Returns:
        The model in eval mode on DEVICE (CPU).
    """
    from transformers import SegformerConfig

    config = SegformerConfig.from_pretrained("nvidia/mit-b0")
    config.num_labels = NUM_CLASSES
    config.id2label = {0: "background", 1: "defect"}
    config.label2id = {"background": 0, "defect": 1}

    model = SegformerForSemanticSegmentation(config)

    # weights_only=True avoids executing arbitrary pickled code
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model
73
def load_image_tensor(path: Path, image_size: tuple) -> torch.Tensor:
    """Load an SEM image as a normalized CHW float tensor.

    The image is forced to RGB, resized to (H, W) = image_size with
    bilinear interpolation (PIL takes width first), scaled to [0, 1],
    and standardized with the ImageNet channel statistics expected by
    the pretrained SegFormer backbone.
    """
    rgb = Image.open(path).convert("RGB").resize(
        (image_size[1], image_size[0]), Image.BILINEAR
    )
    pixels = np.array(rgb, dtype=np.float32) / 255.0
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (pixels - imagenet_mean) / imagenet_std
    # HWC → CHW, as torch models expect
    return torch.from_numpy(normalized).permute(2, 0, 1).float()
81
@torch.no_grad()
def predict_mask(model, image_tensor: torch.Tensor, target_size: tuple) -> np.ndarray:
    """Run the segmentation model on one image and return a (H, W) uint8 mask.

    The model's low-resolution logits are bilinearly upsampled to
    ``target_size`` before the per-pixel argmax, so the returned class
    map matches the requested output resolution.
    """
    batch = image_tensor.unsqueeze(0)  # add batch dim → (1, C, H, W)
    logits = model(pixel_values=batch).logits
    full_res = F.interpolate(
        logits, size=target_size, mode="bilinear", align_corners=False
    )
    class_map = full_res.squeeze(0).argmax(dim=0).numpy()
    return class_map.astype(np.uint8)
90
+
91
def connected_components(mask: np.ndarray) -> tuple[np.ndarray, int]:
    """
    Label 4-connected components of a binary mask via iterative flood fill.

    Pure numpy/stdlib implementation (no scipy dependency). Pixels equal
    to 1 are grouped into components; component labels start at 1 and
    background stays 0.
    Returns (labeled_mask, num_components).
    """
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    n_found = 0

    for seed_r in range(h):
        for seed_c in range(w):
            # start a new fill only from unlabeled defect pixels
            if mask[seed_r, seed_c] != 1 or labels[seed_r, seed_c] != 0:
                continue
            n_found += 1
            labels[seed_r, seed_c] = n_found
            frontier = [(seed_r, seed_c)]
            while frontier:
                r, c = frontier.pop()
                # visit the 4-neighborhood: up, down, left, right
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    if (
                        0 <= nr < h and 0 <= nc < w
                        and mask[nr, nc] == 1
                        and labels[nr, nc] == 0
                    ):
                        labels[nr, nc] = n_found
                        frontier.append((nr, nc))

    return labels, n_found
120
+
121
+
122
def blob_properties(labels: np.ndarray, num_blobs: int) -> list[dict]:
    """Compute area, centroid, bbox, and aspect ratio for each labeled blob.

    ``labels`` is an int array where blob k occupies pixels equal to k
    (k = 1..num_blobs). The aspect ratio is the bounding-box major/minor
    side ratio — a cheap shape descriptor where 1.0 ≈ square/circular and
    larger values indicate elongation.
    """
    results = []
    for blob_id in range(1, num_blobs + 1):
        rows, cols = np.where(labels == blob_id)
        if rows.size == 0:
            continue  # label id unused — skip defensively

        r0, r1 = int(rows.min()), int(rows.max())
        c0, c1 = int(cols.min()), int(cols.max())
        box_h = r1 - r0 + 1
        box_w = c1 - c0 + 1
        long_side = max(box_h, box_w)
        short_side = min(box_h, box_w)

        results.append({
            "area": len(rows),
            "centroid": (float(rows.mean()), float(cols.mean())),
            "aspect_ratio": float(long_side / short_side if short_side > 0 else 1.0),
            "bbox": (r0, c0, r1, c1),
        })
    return results
146
+
147
+
148
def extract_features(mask: np.ndarray, image_size: tuple) -> dict:
    """Extract quantitative morphological features from a binary prediction mask.

    Args:
        mask: (H, W) array where 1 marks defect pixels.
        image_size: (H, W) tuple; assumed to match mask.shape.

    Returns:
        Dict of rounded features — area fraction (%), blob count, pore
        areas (px²), mean aspect ratio, spatial spread, size std, and a
        per-quadrant defect split — plus the rule-based "defect_type" and
        "confidence" labels from classify_defect(). A fixed all-zero
        "clean" dict is returned when the mask has no defect pixels.
    """
    H, W = image_size
    total_px = H * W
    defect_px = int((mask == 1).sum())
    defect_frac = defect_px / total_px

    # Fast path: nothing detected → canonical "clean" feature dict
    if defect_px == 0:
        return {
            "defect_area_fraction": 0.0,
            "defect_count": 0,
            "mean_pore_area_px": 0.0,
            "max_pore_area_px": 0,
            "mean_aspect_ratio": 0.0,
            "spatial_concentration": 0.0,
            "size_std": 0.0,
            "quadrant_distribution": [0.0, 0.0, 0.0, 0.0],
            "defect_type": "clean",
            "confidence": "high",
        }

    # Connected components (note: slow for large masks — acceptable at 256×256)
    labels, n_blobs = connected_components(mask)
    props = blob_properties(labels, n_blobs)

    areas = [p["area"] for p in props]
    aspect_ratios = [p["aspect_ratio"] for p in props]
    centroids = [p["centroid"] for p in props]

    mean_area = float(np.mean(areas)) if areas else 0.0
    max_area = int(max(areas)) if areas else 0
    mean_ar = float(np.mean(aspect_ratios)) if aspect_ratios else 0.0
    size_std = float(np.std(areas)) if areas else 0.0

    # Spatial concentration: std of centroid distances from image center
    if centroids:
        cy_center, cx_center = H / 2, W / 2
        dists = [math.sqrt((c[0]-cy_center)**2 + (c[1]-cx_center)**2)
                 for c in centroids]
        spatial_conc = float(np.std(dists))
    else:
        spatial_conc = 0.0

    # Quadrant distribution: defect-pixel share per image quadrant
    half_h, half_w = H // 2, W // 2
    quads = [
        float((mask[:half_h, :half_w] == 1).sum()),   # top-left
        float((mask[:half_h, half_w:] == 1).sum()),   # top-right
        float((mask[half_h:, :half_w] == 1).sum()),   # bottom-left
        float((mask[half_h:, half_w:] == 1).sum()),   # bottom-right
    ]
    total_defect = sum(quads) + 1e-8  # epsilon avoids divide-by-zero
    quad_dist = [q / total_defect for q in quads]

    # ── Rule-based classification ─────────────────────────────────────────────
    defect_type, confidence = classify_defect(defect_frac, n_blobs, mean_ar, mean_area)

    return {
        "defect_area_fraction": round(defect_frac * 100, 3),  # as %
        "defect_count": n_blobs,
        "mean_pore_area_px": round(mean_area, 1),
        "max_pore_area_px": max_area,
        "mean_aspect_ratio": round(mean_ar, 3),
        "spatial_concentration": round(spatial_conc, 2),
        "size_std": round(size_std, 1),
        "quadrant_distribution": [round(q, 3) for q in quad_dist],
        "defect_type": defect_type,
        "confidence": confidence,
    }
217
+
218
+
219
def classify_defect(
    defect_frac: float,
    count: int,
    mean_ar: float,
    mean_area: float,
    thresholds=None,
) -> tuple[str, str]:
    """
    Rule-based defect classifier.

    Args:
        defect_frac: fraction of image pixels that are defect (0–1 scale).
        count: number of distinct defect blobs.
        mean_ar: mean bounding-box aspect ratio of the blobs.
        mean_area: mean blob area in px² (currently unused by the rules;
            kept for interface stability).
        thresholds: optional override of the module-level THRESHOLDS dict
            (same keys) — useful for tuning and unit testing.

    Returns (defect_type, confidence).

    Lack of fusion: many small irregular pores, higher aspect ratio
    Keyhole: fewer larger circular pores, lower aspect ratio
    Mixed: both morphologies present
    Clean: below detection threshold
    """
    # Lazy fallback: THRESHOLDS is only looked up when no override given
    t = THRESHOLDS if thresholds is None else thresholds
    if defect_frac < t["min_defect_fraction_to_classify"]:
        return "clean", "high"

    is_circular = mean_ar <= t["keyhole_max_aspect_ratio"]
    is_many = count >= t["lof_min_count"]

    if is_circular and not is_many:
        return "keyhole_porosity", "high"
    elif not is_circular and is_many:
        return "lack_of_fusion", "high"
    elif is_circular and is_many:
        # both circular morphology and high count → ambiguous
        return "mixed", "medium"
    else:
        # elongated but few blobs → lean lack-of-fusion, lower confidence
        return "lack_of_fusion", "medium"
249
+
250
+
251
def visualize_features(
    image_path: Path,
    mask: np.ndarray,
    features: dict,
    out_path: Path,
):
    """Save a single-image feature visualization.

    Three panels: the raw SEM image, the prediction overlay (defects in
    cyan), and a monospaced summary of every extracted feature plus the
    rule-based classification.

    Args:
        image_path: original SEM image on disk.
        mask: (H, W) binary prediction mask (1 = defect).
        features: feature dict from extract_features().
        out_path: destination PNG (parent directories are created).
    """
    # Min-max normalize so 16-bit SEM images display correctly
    raw = np.array(Image.open(image_path), dtype=np.float32)
    raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
    # NOTE: PIL resize takes (width, height); IMAGE_SIZE is (H, W)
    raw_resized = np.array(
        Image.fromarray((raw * 255).astype(np.uint8)).resize(
            (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
        )
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(
        f"Feature Extraction — {image_path.name}\n"
        f"Defect Type: {features['defect_type'].upper()} "
        f"(confidence: {features['confidence']})",
        fontsize=11, fontweight="bold"
    )

    # Image
    axes[0].imshow(raw_resized, cmap="gray")
    axes[0].set_title("SEM Image", fontsize=9)
    axes[0].axis("off")

    # Mask with blob labels
    overlay = np.stack([raw_resized, raw_resized, raw_resized], axis=-1).copy()
    overlay[mask == 1] = [0, 212, 255]  # cyan defects
    axes[1].imshow(overlay)
    axes[1].set_title(
        f"Prediction\n{features['defect_area_fraction']:.2f}% defect | "
        f"{features['defect_count']} blobs",
        fontsize=9
    )
    axes[1].axis("off")

    # Feature summary text
    axes[2].axis("off")
    feature_text = (
        f"Defect Area: {features['defect_area_fraction']:.3f}%\n"
        f"Defect Count: {features['defect_count']}\n"
        f"Mean Pore Area: {features['mean_pore_area_px']:.1f} px²\n"
        f"Max Pore Area: {features['max_pore_area_px']} px²\n"
        f"Mean Aspect Ratio: {features['mean_aspect_ratio']:.3f}\n"
        f" (1.0=circle, >2=elongated)\n"
        f"Spatial Spread: {features['spatial_concentration']:.2f}\n"
        f"Size Std Dev: {features['size_std']:.1f}\n\n"
        f"Quadrant Distribution:\n"
        f" TL:{features['quadrant_distribution'][0]:.2f} "
        f"TR:{features['quadrant_distribution'][1]:.2f}\n"
        f" BL:{features['quadrant_distribution'][2]:.2f} "
        f"BR:{features['quadrant_distribution'][3]:.2f}\n\n"
        f"─────────────────────────\n"
        f"DEFECT TYPE: {features['defect_type']}\n"
        f"CONFIDENCE: {features['confidence']}"
    )
    axes[2].text(
        0.05, 0.95, feature_text,
        transform=axes[2].transAxes,
        fontsize=9, verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="#1a1a2e", alpha=0.8, edgecolor="#00d4ff"),
        color="white"
    )
    axes[2].set_title("Extracted Features", fontsize=9)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f" Saved → {out_path.resolve()}")
325
+
326
+
327
def run_on_image(image_path: Path, subset: str) -> dict:
    """Segment one image, extract its features, and save a visualization.

    Returns the feature dict, or {} when no trained checkpoint exists for
    the given subset.
    """
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"❌ No checkpoint at {ckpt_path}")
        return {}

    print(f"\nImage: {image_path.name}")
    print(f"Subset: {subset}")

    # segment → feature-extract pipeline
    model = load_model(ckpt_path)
    img_tensor = load_image_tensor(image_path, IMAGE_SIZE)
    mask = predict_mask(model, img_tensor, IMAGE_SIZE)
    features = extract_features(mask, IMAGE_SIZE)

    print(f"Defect type: {features['defect_type']} ({features['confidence']} confidence)")
    print(f"Defect area: {features['defect_area_fraction']:.3f}%")
    print(f"Blob count: {features['defect_count']}")
    print(f"Mean AR: {features['mean_aspect_ratio']:.3f}")
    print(json.dumps(features, indent=2))

    out_path = Path("output/features") / f"{image_path.stem}_features.png"
    visualize_features(image_path, mask, features, out_path)

    return features
351
+
352
+
353
def run_on_subset(subset: str, data_dir: Path, n: int = 6):
    """Run feature extraction on n images from a subset and print summary.

    Saves one visualization per image under output/features/<subset>/ and
    a combined feature JSON at output/features/<subset>_features.json.
    Silently returns when the data directory or checkpoint is missing.
    """
    subset_dir = data_dir / subset
    if not subset_dir.exists():
        print(f"⚠️ {subset_dir} not found")
        return

    ds = FractographyDataset(subset_dir, split="all", image_size=IMAGE_SIZE)
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    if not ckpt_path.exists():
        print(f"⚠️ No checkpoint for {subset}")
        return

    model = load_model(ckpt_path)
    results = []

    print(f"\n{'='*60}")
    print(f"Feature extraction: {subset} ({min(n, len(ds))} images)")
    print(f"{'='*60}")

    for idx in range(min(n, len(ds))):
        img_path, _ = ds.pairs[idx]
        img_tensor = load_image_tensor(img_path, IMAGE_SIZE)
        mask = predict_mask(model, img_tensor, IMAGE_SIZE)
        features = extract_features(mask, IMAGE_SIZE)
        features["image"] = img_path.name  # tag each result with its source file
        results.append(features)

        out_path = Path("output/features") / subset / f"{img_path.stem}_features.png"
        visualize_features(img_path, mask, features, out_path)

    # Summary: count of each predicted defect type
    print(f"\n Classification summary:")
    from collections import Counter
    counts = Counter(r["defect_type"] for r in results)
    for dtype, count in counts.items():
        print(f" {dtype:25s}: {count}")

    # Save results JSON (consumed by diagnose.py --json)
    json_out = Path("output/features") / f"{subset}_features.json"
    json_out.parent.mkdir(parents=True, exist_ok=True)
    with open(json_out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n Feature JSON → {json_out.resolve()}")
397
+
398
+
399
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default=None,
                        help="Path to a single SEM image")
    parser.add_argument("--subset", type=str, default="all_defects",
                        help="lack_of_fusion | keyhole | all_defects")
    parser.add_argument("--all", action="store_true",
                        help="Run on all images in subset (up to --n)")
    parser.add_argument("--n", type=int, default=6,
                        help="Number of images to process in --all mode")
    parser.add_argument("--data_dir", type=str, default="data")
    args = parser.parse_args()

    # Single-image mode takes precedence; otherwise run per-subset batches
    if args.image:
        run_on_image(Path(args.image), args.subset)
    else:
        subsets = (
            ["lack_of_fusion", "keyhole", "all_defects"]
            if args.subset == "all"
            else [args.subset]
        )
        for subset in subsets:
            run_on_subset(subset, Path(args.data_dir), n=args.n)
inference.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py
3
+ ------------
4
+ Loads a trained SegFormer checkpoint and runs inference on SEM images.
5
+ Saves a visualization grid showing: original image | predicted mask | overlay.
6
+
7
+ Usage:
8
+ # Run on a specific subset's val images
9
+ python inference.py --subset lack_of_fusion
10
+
11
+ # Run on a specific image
12
+ python inference.py --image path/to/image.png --subset keyhole
13
+
14
+ # Run all three subsets
15
+ python inference.py --subset all
16
+ """
17
+
18
+ import argparse
19
+ import random
20
+ from pathlib import Path
21
+ from features import load_model, load_image_tensor
22
+
23
+
24
+ import matplotlib
25
+ matplotlib.use("Agg")
26
+ import matplotlib.pyplot as plt
27
+ import matplotlib.patches as mpatches
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from PIL import Image
32
+ from transformers import SegformerForSemanticSegmentation
33
+
34
+ from dataset import FractographyDataset, IMAGE_SIZE, NUM_CLASSES, MASK_SCALE
35
+
36
+ # ── Config ────────────────────────────────────────────────────────────────────
37
+ DEVICE = torch.device("cpu")
38
+ N_SAMPLES = 6 # images to visualize per subset
39
+ LABEL_MAP = {0: ("Background", "#1a1a2e"), 1: ("Defect", "#00d4ff")}
40
+ # ─────────────────────────────────────────────────────────────────────────────
41
+
42
+
43
def load_model(checkpoint_path: Path) -> SegformerForSemanticSegmentation:
    """Instantiate SegFormer MiT-B0 and load fine-tuned weights from disk.

    NOTE(review): this local definition shadows the `load_model` imported
    from features.py at the top of this file, and it builds the model via
    from_pretrained (fetching backbone weights) rather than from a bare
    config as features.load_model does — confirm which variant is
    intended and remove the other.
    """
    id2label = {0: "background", 1: "defect"}
    label2id = {v: k for k, v in id2label.items()}
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        # tolerate head-size mismatch vs the hub checkpoint before loading ours
        ignore_mismatched_sizes=True,
    )
    # weights_only=True avoids executing arbitrary pickled code
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model
+
58
+
59
def load_raw_image(path: Path) -> np.ndarray:
    """Load SEM image as a displayable uint8 RGB array (handles 16-bit)."""
    pixels = np.array(Image.open(path), dtype=np.float32)
    # Min-max normalize so high-bit-depth SEM images display as 8-bit
    lo, hi = pixels.min(), pixels.max()
    scaled = (pixels - lo) / (hi - lo + 1e-8)
    # Replicate the single channel into RGB and quantize
    return (np.stack([scaled, scaled, scaled], axis=-1) * 255).astype(np.uint8)
65
+
66
+
67
def predict(model, image_tensor: torch.Tensor, target_size: tuple) -> np.ndarray:
    """Run inference and return (H, W) prediction mask as numpy array."""
    with torch.no_grad():
        raw_out = model(pixel_values=image_tensor.unsqueeze(0))
    # Upsample the coarse (1, C, H/4, W/4) logits to the requested size,
    # then take the per-pixel argmax over classes.
    full_res = F.interpolate(
        raw_out.logits, size=target_size, mode="bilinear", align_corners=False
    )
    return full_res.squeeze(0).argmax(dim=0).numpy()
77
+
78
+
79
def colorize(mask: np.ndarray) -> np.ndarray:
    """Map integer class labels to an RGB uint8 image using LABEL_MAP colors."""
    colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_val, (_, hex_color) in LABEL_MAP.items():
        # "#rrggbb" → (r, g, b) integer triple
        digits = hex_color.lstrip("#")
        colored[mask == class_val] = tuple(
            int(digits[i:i + 2], 16) for i in (0, 2, 4)
        )
    return colored
85
+
86
+
87
def compute_stats(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Compute per-image defect IoU and coverage percentages.

    IoU is over the defect class (label 1) only and is NaN when neither
    mask contains any defect pixels (empty union).
    """
    pred_hits = pred == 1
    gt_hits = gt == 1
    union = (pred_hits | gt_hits).sum()
    iou = (pred_hits & gt_hits).sum() / union if union > 0 else float("nan")
    return {
        "iou": iou,
        "pred_coverage": pred_hits.sum() / pred.size * 100,
        "gt_coverage": gt_hits.sum() / gt.size * 100,
    }
97
+
98
+
99
def run_inference(subset: str, args):
    """Visualize predictions vs ground truth for one subset.

    Loads the subset's best checkpoint, samples N_SAMPLES image/mask
    pairs with a seeded shuffle (reproducible across runs), and saves a
    4-column grid (image | ground truth | prediction | overlay) with
    per-image IoU to output/inference/<subset>_inference.png. Skips with
    a warning when the data directory or checkpoint is missing.
    """
    data_dir = Path(args.data_dir) / subset
    ckpt_path = Path("checkpoints") / subset / "best_model.pt"
    out_dir = Path("output") / "inference"
    out_dir.mkdir(parents=True, exist_ok=True)

    if not data_dir.exists():
        print(f"⚠️ Skipping '{subset}' — data not found at {data_dir}")
        return
    if not ckpt_path.exists():
        print(f"⚠️ Skipping '{subset}' — no checkpoint at {ckpt_path}")
        return

    print(f"\n{'='*60}")
    print(f"Inference: {subset}")
    print(f"Checkpoint: {ckpt_path}")
    print(f"{'='*60}")

    model = load_model(ckpt_path)

    # Load dataset to get image/mask pairs
    ds = FractographyDataset(data_dir, split="all", image_size=IMAGE_SIZE)
    indices = list(range(len(ds)))
    random.seed(42)  # deterministic sample selection
    random.shuffle(indices)
    sample_indices = indices[:N_SAMPLES]

    # Build figure
    n = len(sample_indices)
    fig, axes = plt.subplots(n, 4, figsize=(16, n * 4))
    if n == 1:
        axes = [axes]  # keep axes indexable as axes[row][col]
    fig.suptitle(
        f"SegFormer Inference — {subset.replace('_', ' ').title()}",
        fontsize=13, fontweight="bold"
    )

    ious = []
    for row, idx in enumerate(sample_indices):
        img_path, mask_path = ds.pairs[idx]
        img_tensor, gt_mask = ds[idx]

        # Raw image for display (16-bit safe)
        raw_img = load_raw_image(img_path)

        # GT mask (undo MASK_SCALE)
        gt_arr = gt_mask.numpy()  # already scaled by dataset

        # Predict
        pred = predict(model, img_tensor, target_size=IMAGE_SIZE)

        # Resize raw image to match prediction size for display
        raw_resized = np.array(
            Image.fromarray(raw_img).resize(
                (IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.BILINEAR
            )
        )

        # Stats
        stats = compute_stats(pred, gt_arr)
        ious.append(stats["iou"])

        # Colorize; overlay is a 60/40 alpha blend of image and prediction
        pred_colored = colorize(pred)
        gt_colored = colorize(gt_arr)
        overlay = (raw_resized.astype(float) * 0.6 +
                   pred_colored.astype(float) * 0.4).astype(np.uint8)

        # Plot
        axes[row][0].imshow(raw_resized, cmap="gray")
        axes[row][0].set_title(f"Image\n{img_path.name}", fontsize=7)
        axes[row][0].axis("off")

        axes[row][1].imshow(gt_colored)
        axes[row][1].set_title(
            f"Ground Truth\n{stats['gt_coverage']:.1f}% defect", fontsize=7
        )
        axes[row][1].axis("off")

        axes[row][2].imshow(pred_colored)
        axes[row][2].set_title(
            f"Prediction\n{stats['pred_coverage']:.1f}% defect", fontsize=7
        )
        axes[row][2].axis("off")

        axes[row][3].imshow(overlay)
        # IoU may be NaN for images with empty union → show N/A
        iou_str = f"{stats['iou']:.3f}" if not np.isnan(stats["iou"]) else "N/A"
        axes[row][3].set_title(f"Overlay\nIoU={iou_str}", fontsize=7)
        axes[row][3].axis("off")

    # Legend
    patches = [
        mpatches.Patch(color=LABEL_MAP[0][1], label="Background"),
        mpatches.Patch(color=LABEL_MAP[1][1], label="Defect"),
    ]
    fig.legend(handles=patches, loc="lower center", ncol=2,
               bbox_to_anchor=(0.5, -0.01), fontsize=9)

    # nanmean skips images whose IoU was undefined
    mean_iou = np.nanmean(ious)
    fig.text(0.5, -0.03, f"Mean IoU (these samples): {mean_iou:.4f}",
             ha="center", fontsize=10, fontweight="bold")

    plt.tight_layout()
    out_path = out_dir / f"{subset}_inference.png"
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()

    print(f" Mean IoU (sampled): {mean_iou:.4f}")
    print(f" Saved → {out_path.resolve()}")
+
209
+
210
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="all",
                        help="lack_of_fusion | keyhole | all_defects | all")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--n", type=int, default=6,
                        help="Number of images to visualize")
    args = parser.parse_args()

    # Override the module-level default read by run_inference()
    N_SAMPLES = args.n

    subsets = (
        ["lack_of_fusion", "keyhole", "all_defects"]
        if args.subset == "all"
        else [args.subset]
    )

    for subset in subsets:
        run_inference(subset, args)

    print("\n✅ Done. Check output/inference/")
inspect_dataset.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inspect_dataset.py
3
+ ------------------
4
+ Inspects the OSF Ti-64 SEM fractography dataset after downloading.
5
+ Run after download_osf.py.
6
+
7
+ What this does:
8
+ 1. Scans the data/ directory and reports what it finds
9
+ 2. Detects mask format (grayscale int labels vs RGB color masks)
10
+ 3. Prints unique class label values found in masks
11
+ 4. Generates a visualization grid of image/mask pairs
12
+ 5. Saves visualization to output/inspection_grid.png
13
+
14
+ Usage:
15
+ python inspect_dataset.py
16
+ python inspect_dataset.py --data_dir path/to/your/data
17
+ """
18
+
19
+ import argparse
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import matplotlib
24
+ matplotlib.use("Agg") # headless-safe; switch to "TkAgg" if you want interactive
25
+ import matplotlib.pyplot as plt
26
+ import matplotlib.patches as mpatches
27
+ import numpy as np
28
+ from PIL import Image
29
+
30
+
31
# ── Configurable label map ────────────────────────────────────────────────────
# Update this once you've inspected the actual class values in your masks.
# Keys = integer pixel values in mask PNGs; values = (human label, hex color).
LABEL_MAP = {
    0: ("Background", "#1a1a2e"),
    1: ("Lack of Fusion", "#e94560"),
    2: ("Keyhole", "#0f3460"),
    3: ("Other Defect", "#533483"),
    # Add more if you find additional class values
}

# Fallback colormap for unknown labels.
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# prefer the colormap registry, falling back on older Matplotlib versions.
try:
    CMAP = matplotlib.colormaps["tab10"]
except AttributeError:  # Matplotlib < 3.6
    CMAP = plt.cm.get_cmap("tab10")
# ─────────────────────────────────────────────────────────────────────────────
45
+
46
+
47
+ def find_image_mask_pairs(data_dir: Path) -> list[tuple[Path, Path]]:
48
+ """
49
+ Scan data_dir for image/mask pairs.
50
+ Assumes masks live in a folder named 'masks' or 'mask',
51
+ and images in 'images' or 'image', or are paired by filename.
52
+ """
53
+ pairs = []
54
+ image_exts = {".png", ".tif", ".tiff", ".jpg", ".jpeg", ".bmp"}
55
+
56
+ # Strategy 1: look for images/ and masks/ sibling folders
57
+ for images_dir in sorted(data_dir.rglob("images")):
58
+ if not images_dir.is_dir():
59
+ continue
60
+ masks_dir = images_dir.parent / "masks"
61
+ if not masks_dir.exists():
62
+ masks_dir = images_dir.parent / "mask"
63
+ if not masks_dir.exists():
64
+ print(f" ⚠️ Found images/ at {images_dir} but no masks/ sibling")
65
+ continue
66
+ for img_path in sorted(images_dir.iterdir()):
67
+ if img_path.suffix.lower() not in image_exts:
68
+ continue
69
+ # Try matching by stem
70
+ for ext in image_exts:
71
+ mask_path = masks_dir / (img_path.stem + ext)
72
+ if mask_path.exists():
73
+ pairs.append((img_path, mask_path))
74
+ break
75
+ else:
76
+ print(f" ⚠️ No mask found for {img_path.name}")
77
+
78
+ # Strategy 2: flat folder — files named *_image.* and *_mask.*
79
+ if not pairs:
80
+ for img_path in sorted(data_dir.rglob("*_image.*")):
81
+ if img_path.suffix.lower() not in image_exts:
82
+ continue
83
+ stem = img_path.stem.replace("_image", "")
84
+ for ext in image_exts:
85
+ mask_path = img_path.parent / f"{stem}_mask{ext}"
86
+ if mask_path.exists():
87
+ pairs.append((img_path, mask_path))
88
+ break
89
+
90
+ return pairs
91
+
92
+
93
def inspect_mask(mask_path: Path) -> dict:
    """Return format statistics about a mask file.

    Fix over the original: the file is opened once (the old code opened it a
    second time just to read the PIL mode).

    Returns a dict with shape, dtype, PIL mode, sorted unique pixel values,
    and the min/max pixel value.
    """
    with Image.open(mask_path) as mask_img:
        mode = mask_img.mode
        mask = np.array(mask_img)
    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "mode": mode,
        "unique_values": sorted(np.unique(mask).tolist()),
        "min": int(mask.min()),
        "max": int(mask.max()),
    }
105
+
106
+
107
def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Map an integer label mask to an RGB uint8 image for display."""
    labels = np.unique(mask)
    out = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    denom = max(labels.max(), 1)  # normalizer for colormap fallback
    for label in labels:
        if label in LABEL_MAP:
            hex_str = LABEL_MAP[label][1].lstrip("#")
            color = tuple(int(hex_str[j:j + 2], 16) for j in (0, 2, 4))
        else:
            # Unknown class value: sample the fallback matplotlib colormap.
            color = tuple(int(channel * 255) for channel in CMAP(label / denom)[:3])
        out[mask == label] = color
    return out
122
+
123
+
124
def make_legend(unique_vals: list[int]) -> list[mpatches.Patch]:
    """One legend patch per class value; unknown values get a grey fallback."""
    resolved = [(v, *LABEL_MAP.get(v, (f"Class {v}", "#888888"))) for v in unique_vals]
    return [
        mpatches.Patch(color=hex_color, label=f"{v}: {label}")
        for v, label, hex_color in resolved
    ]
130
+
131
+
132
def visualize_pairs(
    pairs: list[tuple[Path, Path]],
    n: int = 6,
    output_path: Path = Path("output/inspection_grid.png"),
):
    """Save a grid of n image/mask/overlay triplets.

    Fix over the original: images that are already RGB/RGBA no longer break
    the display path — the old code unconditionally np.stack-ed the raw array
    into three channels, turning a (H, W, 3) image into a broken 4-D array.
    Grayscale images are min-max normalized for display as before.
    """
    n = min(n, len(pairs))
    if n == 0:
        print("  No pairs to visualize.")
        return

    fig, axes = plt.subplots(n, 3, figsize=(12, n * 4))
    if n == 1:
        axes = [axes]  # keep 2-D indexing uniform for a single row

    fig.suptitle("OSF Ti-64 SEM Dataset — Inspection Grid\n(Image | Mask | Overlay)",
                 fontsize=13, fontweight="bold", y=1.01)

    all_unique = set()  # every class value seen, for the shared legend

    for row, (img_path, mask_path) in enumerate(pairs[:n]):
        raw = np.array(Image.open(img_path), dtype=np.float32)
        # Min-max normalize for display (SEM images often use odd ranges).
        raw = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
        if raw.ndim == 2:
            img = np.stack([raw, raw, raw], axis=-1)  # grayscale -> 3 channels
        else:
            img = raw[..., :3]  # already multi-channel; drop alpha if present

        mask_pil = Image.open(mask_path)
        mask_arr = np.array(mask_pil)
        # If mask is RGB, collapse to grayscale for class inspection
        if mask_arr.ndim == 3:
            mask_arr = np.array(mask_pil.convert("L"))

        unique_vals = sorted(np.unique(mask_arr).tolist())
        all_unique.update(unique_vals)
        mask_rgb = colorize_mask(mask_arr)

        # Overlay: 60/40 blend of image and colorized mask
        img_display = (img * 255).astype(np.uint8)
        overlay = (img_display.astype(float) * 0.6
                   + mask_rgb.astype(float) * 0.4).astype(np.uint8)

        axes[row][0].imshow(img)
        axes[row][0].set_title(f"Image\n{img_path.name}", fontsize=8)
        axes[row][0].axis("off")

        axes[row][1].imshow(mask_rgb)
        axes[row][1].set_title(
            f"Mask (classes: {unique_vals})\n{mask_path.name}", fontsize=8
        )
        axes[row][1].axis("off")

        axes[row][2].imshow(overlay)
        axes[row][2].set_title("Overlay", fontsize=8)
        axes[row][2].axis("off")

    # Add legend
    legend_patches = make_legend(sorted(all_unique))
    fig.legend(handles=legend_patches, loc="lower center", ncol=len(legend_patches),
               bbox_to_anchor=(0.5, -0.02), fontsize=9, title="Mask Classes Found")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n✅ Visualization saved to: {output_path.resolve()}")
195
+
196
+
197
def print_dataset_summary(data_dir: Path, pairs: list[tuple[Path, Path]]):
    """Print an overview of the dataset plus a format probe of the first masks."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Dataset Summary — {data_dir.resolve()}")
    print(f"{rule}")
    print(f"Total image/mask pairs found: {len(pairs)}")

    if not pairs:
        # Nothing to inspect — show the layout we expect and bail out.
        print("\n⚠️ No pairs found. Check your data/ folder structure.")
        print("Expected layout:")
        print("  data/")
        print("    <subset>/")
        print("      images/   ← SEM images (.png or .tif)")
        print("      masks/    ← segmentation masks (.png)")
        return

    # Probe the first few masks so the user can see the label encoding.
    print(f"\nSampling first 5 masks for format inspection:")
    seen_values: set = set()
    for _img_path, mask_path in pairs[:5]:
        info = inspect_mask(mask_path)
        print(f"\n  {mask_path.name}")
        print(f"    Mode: {info['mode']}")
        print(f"    Shape: {info['shape']}")
        print(f"    Dtype: {info['dtype']}")
        print(f"    Unique values: {info['unique_values']}")
        print(f"    Value range: [{info['min']}, {info['max']}]")
        seen_values.update(info["unique_values"])

    print(f"\n{'─'*40}")
    print(f"All unique class values across sampled masks: {sorted(seen_values)}")
    print("\nLabel interpretation:")
    for v in sorted(seen_values):
        label, _ = LABEL_MAP.get(v, (f"UNKNOWN — update LABEL_MAP in this script", "#888"))
        print(f"  {v:3d} → {label}")

    print(f"\n⚠️ NOTE: If all unique values are {{0, 255}}, masks are binary (defect/no-defect).")
    print("   If values are 0–N, masks are multi-class integer labels — ideal for SegFormer.")
    print("   If mode is 'RGB', masks encode class as color — you'll need to remap.")
235
+
236
+
237
def main():
    """CLI entry point: scan, summarize, and optionally visualize the dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Path to downloaded dataset root")
    parser.add_argument("--n_vis", type=int, default=6,
                        help="Number of pairs to visualize")
    parser.add_argument("--output", type=str, default="output/inspection_grid.png",
                        help="Where to save the visualization grid")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    if not data_dir.exists():
        # Fail fast with a pointer to the download step.
        print(f"❌ data_dir '{data_dir}' does not exist.")
        print("Run download_osf.py first, or set --data_dir to your data folder.")
        sys.exit(1)

    print("Scanning for image/mask pairs...")
    pairs = find_image_mask_pairs(data_dir)

    print_dataset_summary(data_dir, pairs)

    if pairs:
        print(f"\nGenerating visualization grid ({min(args.n_vis, len(pairs))} samples)...")
        visualize_pairs(pairs, n=args.n_vis, output_path=Path(args.output))


if __name__ == "__main__":
    main()
setup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FailureGPT — OSF Ti-64 Dataset Inspector
3
+ Run this first to install dependencies.
4
+ """
5
+ import subprocess, sys
6
+
7
+ packages = [
8
+ "torch",
9
+ "torchvision",
10
+ "Pillow",
11
+ "matplotlib",
12
+ "numpy",
13
+ "requests",
14
+ "osfclient",
15
+ "tqdm",
16
+ ]
17
+
18
+ for pkg in packages:
19
+ subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "-q"])
20
+
21
+ print("✅ All dependencies installed.")
test.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inspect_dataset.py
3
+ ------------------
4
+ Inspects the OSF Ti-64 SEM fractography dataset after downloading.
5
+ Run after download_osf.py.
6
+
7
+ What this does:
8
+ 1. Scans the data/ directory and reports what it finds
9
+ 2. Detects mask format (grayscale int labels vs RGB color masks)
10
+ 3. Prints unique class label values found in masks
11
+ 4. Generates a visualization grid of image/mask pairs
12
+ 5. Saves visualization to output/inspection_grid.png
13
+
14
+ Usage:
15
+ python inspect_dataset.py
16
+ python inspect_dataset.py --data_dir path/to/your/data
17
+ """
18
+
19
+ import argparse
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import matplotlib
24
+ matplotlib.use("Agg") # headless-safe; switch to "TkAgg" if you want interactive
25
+ import matplotlib.pyplot as plt
26
+ import matplotlib.patches as mpatches
27
+ import numpy as np
28
+ from PIL import Image
29
+
30
+
31
# ── Configurable label map ────────────────────────────────────────────────────
# Update this once you've inspected the actual class values in your masks.
# Keys = integer pixel values in mask PNGs; values = (human label, hex color).
LABEL_MAP = {
    0: ("Background", "#1a1a2e"),
    1: ("Lack of Fusion", "#e94560"),
    2: ("Keyhole", "#0f3460"),
    3: ("Other Defect", "#533483"),
    # Add more if you find additional class values
}

# Fallback colormap for unknown labels.
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# prefer the colormap registry, falling back on older Matplotlib versions.
try:
    CMAP = matplotlib.colormaps["tab10"]
except AttributeError:  # Matplotlib < 3.6
    CMAP = plt.cm.get_cmap("tab10")
# ─────────────────────────────────────────────────────────────────────────────
45
+
46
+
47
+ def find_image_mask_pairs(data_dir: Path) -> list[tuple[Path, Path]]:
48
+ """
49
+ Scan data_dir for image/mask pairs.
50
+ Assumes masks live in a folder named 'masks' or 'mask',
51
+ and images in 'images' or 'image', or are paired by filename.
52
+ """
53
+ pairs = []
54
+ image_exts = {".png", ".tif", ".tiff", ".jpg", ".jpeg", ".bmp"}
55
+
56
+ # Strategy 1: look for images/ and masks/ sibling folders
57
+ for images_dir in sorted(data_dir.rglob("images")):
58
+ if not images_dir.is_dir():
59
+ continue
60
+ masks_dir = images_dir.parent / "masks"
61
+ if not masks_dir.exists():
62
+ masks_dir = images_dir.parent / "mask"
63
+ if not masks_dir.exists():
64
+ print(f" ⚠️ Found images/ at {images_dir} but no masks/ sibling")
65
+ continue
66
+ for img_path in sorted(images_dir.iterdir()):
67
+ if img_path.suffix.lower() not in image_exts:
68
+ continue
69
+ # Try matching by stem
70
+ for ext in image_exts:
71
+ mask_path = masks_dir / (img_path.stem + ext)
72
+ if mask_path.exists():
73
+ pairs.append((img_path, mask_path))
74
+ break
75
+ else:
76
+ print(f" ⚠️ No mask found for {img_path.name}")
77
+
78
+ # Strategy 2: flat folder — files named *_image.* and *_mask.*
79
+ if not pairs:
80
+ for img_path in sorted(data_dir.rglob("*_image.*")):
81
+ if img_path.suffix.lower() not in image_exts:
82
+ continue
83
+ stem = img_path.stem.replace("_image", "")
84
+ for ext in image_exts:
85
+ mask_path = img_path.parent / f"{stem}_mask{ext}"
86
+ if mask_path.exists():
87
+ pairs.append((img_path, mask_path))
88
+ break
89
+
90
+ return pairs
91
+
92
+
93
def inspect_mask(mask_path: Path) -> dict:
    """Return format statistics about a mask file.

    Fix over the original: the file is opened once (the old code opened it a
    second time just to read the PIL mode).

    Returns a dict with shape, dtype, PIL mode, sorted unique pixel values,
    and the min/max pixel value.
    """
    with Image.open(mask_path) as mask_img:
        mode = mask_img.mode
        mask = np.array(mask_img)
    return {
        "shape": mask.shape,
        "dtype": str(mask.dtype),
        "mode": mode,
        "unique_values": sorted(np.unique(mask).tolist()),
        "min": int(mask.min()),
        "max": int(mask.max()),
    }
105
+
106
+
107
def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Map an integer label mask to an RGB uint8 image for display."""
    labels = np.unique(mask)
    out = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    denom = max(labels.max(), 1)  # normalizer for colormap fallback
    for label in labels:
        if label in LABEL_MAP:
            hex_str = LABEL_MAP[label][1].lstrip("#")
            color = tuple(int(hex_str[j:j + 2], 16) for j in (0, 2, 4))
        else:
            # Unknown class value: sample the fallback matplotlib colormap.
            color = tuple(int(channel * 255) for channel in CMAP(label / denom)[:3])
        out[mask == label] = color
    return out
122
+
123
+
124
def make_legend(unique_vals: list[int]) -> list[mpatches.Patch]:
    """One legend patch per class value; unknown values get a grey fallback."""
    resolved = [(v, *LABEL_MAP.get(v, (f"Class {v}", "#888888"))) for v in unique_vals]
    return [
        mpatches.Patch(color=hex_color, label=f"{v}: {label}")
        for v, label, hex_color in resolved
    ]
130
+
131
+
132
def visualize_pairs(
    pairs: list[tuple[Path, Path]],
    n: int = 6,
    output_path: Path = Path("output/inspection_grid.png"),
):
    """Render and save a grid of n rows, each Image | Mask | Overlay."""
    n = min(n, len(pairs))
    if n == 0:
        print("  No pairs to visualize.")
        return

    fig, axes = plt.subplots(n, 3, figsize=(12, n * 4))
    if n == 1:
        axes = [axes]  # keep row indexing uniform for a single pair

    fig.suptitle("OSF Ti-64 SEM Dataset — Inspection Grid\n(Image | Mask | Overlay)",
                 fontsize=13, fontweight="bold", y=1.01)

    seen_classes: set = set()  # all class values encountered, for the legend

    for row, (img_path, mask_path) in enumerate(pairs[:n]):
        img = np.array(Image.open(img_path).convert("RGB"))
        mask_pil = Image.open(mask_path)
        mask_arr = np.array(mask_pil)

        # RGB-coded masks are collapsed to grayscale for class inspection.
        if mask_arr.ndim == 3:
            mask_arr = np.array(mask_pil.convert("L"))

        unique_vals = sorted(np.unique(mask_arr).tolist())
        seen_classes.update(unique_vals)
        mask_rgb = colorize_mask(mask_arr)

        # Overlay: 50/50 blend of image and colorized mask.
        overlay = (img.astype(float) * 0.5 + mask_rgb.astype(float) * 0.5).astype(np.uint8)

        ax_img, ax_mask, ax_overlay = axes[row]
        ax_img.imshow(img, cmap="gray" if img.ndim == 2 else None)
        ax_img.set_title(f"Image\n{img_path.name}", fontsize=8)
        ax_img.axis("off")

        ax_mask.imshow(mask_rgb)
        ax_mask.set_title(
            f"Mask (classes: {unique_vals})\n{mask_path.name}", fontsize=8
        )
        ax_mask.axis("off")

        ax_overlay.imshow(overlay)
        ax_overlay.set_title("Overlay", fontsize=8)
        ax_overlay.axis("off")

    # Shared legend across the whole figure
    legend_patches = make_legend(sorted(seen_classes))
    fig.legend(handles=legend_patches, loc="lower center", ncol=len(legend_patches),
               bbox_to_anchor=(0.5, -0.02), fontsize=9, title="Mask Classes Found")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\n✅ Visualization saved to: {output_path.resolve()}")
192
+
193
+
194
def print_dataset_summary(data_dir: Path, pairs: list[tuple[Path, Path]]):
    """Print an overview of the dataset plus a format probe of the first masks."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Dataset Summary — {data_dir.resolve()}")
    print(f"{rule}")
    print(f"Total image/mask pairs found: {len(pairs)}")

    if not pairs:
        # Nothing to inspect — show the layout we expect and bail out.
        print("\n⚠️ No pairs found. Check your data/ folder structure.")
        print("Expected layout:")
        print("  data/")
        print("    <subset>/")
        print("      images/   ← SEM images (.png or .tif)")
        print("      masks/    ← segmentation masks (.png)")
        return

    # Probe the first few masks so the user can see the label encoding.
    print(f"\nSampling first 5 masks for format inspection:")
    seen_values: set = set()
    for _img_path, mask_path in pairs[:5]:
        info = inspect_mask(mask_path)
        print(f"\n  {mask_path.name}")
        print(f"    Mode: {info['mode']}")
        print(f"    Shape: {info['shape']}")
        print(f"    Dtype: {info['dtype']}")
        print(f"    Unique values: {info['unique_values']}")
        print(f"    Value range: [{info['min']}, {info['max']}]")
        seen_values.update(info["unique_values"])

    print(f"\n{'─'*40}")
    print(f"All unique class values across sampled masks: {sorted(seen_values)}")
    print("\nLabel interpretation:")
    for v in sorted(seen_values):
        label, _ = LABEL_MAP.get(v, (f"UNKNOWN — update LABEL_MAP in this script", "#888"))
        print(f"  {v:3d} → {label}")

    print(f"\n⚠️ NOTE: If all unique values are {{0, 255}}, masks are binary (defect/no-defect).")
    print("   If values are 0–N, masks are multi-class integer labels — ideal for SegFormer.")
    print("   If mode is 'RGB', masks encode class as color — you'll need to remap.")
232
+
233
+
234
def main():
    """CLI entry point: scan, summarize, and optionally visualize the dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Path to downloaded dataset root")
    parser.add_argument("--n_vis", type=int, default=6,
                        help="Number of pairs to visualize")
    parser.add_argument("--output", type=str, default="output/inspection_grid.png",
                        help="Where to save the visualization grid")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    if not data_dir.exists():
        # Fail fast with a pointer to the download step.
        print(f"❌ data_dir '{data_dir}' does not exist.")
        print("Run download_osf.py first, or set --data_dir to your data folder.")
        sys.exit(1)

    print("Scanning for image/mask pairs...")
    pairs = find_image_mask_pairs(data_dir)

    print_dataset_summary(data_dir, pairs)

    if pairs:
        print(f"\nGenerating visualization grid ({min(args.n_vis, len(pairs))} samples)...")
        visualize_pairs(pairs, n=args.n_vis, output_path=Path(args.output))


if __name__ == "__main__":
    main()
train.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py
3
+ --------
4
+ Fine-tunes SegFormer-b0 on the OSF Ti-64 SEM fractography dataset.
5
+ Trains all three subsets (lack_of_fusion, keyhole, all_defects) separately.
6
+ CPU-optimized: small image size, small batch, few epochs.
7
+
8
+ Usage:
9
+ python train.py
10
+ python train.py --epochs 10 --image_size 256
11
+
12
+ Outputs (per subset):
13
+ checkpoints/<subset>/best_model.pt <- best checkpoint by val mIoU
14
+ checkpoints/<subset>/last_model.pt <- final epoch checkpoint
15
+ checkpoints/<subset>/history.json <- loss/mIoU per epoch
16
+ """
17
+
18
+ import argparse
19
+ import json
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data import DataLoader, random_split
27
+ from transformers import SegformerForSemanticSegmentation
28
+ import torch.nn.functional as F
29
+
30
+ from dataset import FractographyDataset
31
+
32
# ── Config ────────────────────────────────────────────────────────────────────
NUM_CLASSES = 2          # background (0) + defect (1)
IMAGE_SIZE = (256, 256)  # smaller = faster on CPU; increase if you have time
                         # (overridden at runtime by --image_size in __main__)
BATCH_SIZE = 2           # tiny batch keeps CPU memory footprint low
EPOCHS = 15              # default; override with --epochs
LR = 6e-5                # base (encoder) LR; decode head gets 50x in train_subset
TRAIN_FRAC = 0.8         # train/val split fraction for random_split
WEIGHT_DECAY = 0.01      # AdamW weight decay

SUBSETS = ["lack_of_fusion", "keyhole", "all_defects"]  # one model per subset
# ─────────────────────────────────────────────────────────────────────────────
43
+
44
+
45
def compute_miou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> float:
    """Mean Intersection-over-Union across classes present in the batch.

    Classes whose union is empty (absent from both prediction and target)
    are excluded from the mean; returns 0.0 if no class contributes.
    """
    flat_preds = preds.reshape(-1)
    flat_targets = targets.reshape(-1)
    per_class = []
    for class_id in range(num_classes):
        in_pred = flat_preds == class_id
        in_target = flat_targets == class_id
        union = (in_pred | in_target).sum().item()
        if union == 0:
            continue  # class not present in this batch
        inter = (in_pred & in_target).sum().item()
        per_class.append(inter / union)
    return float(np.mean(per_class)) if per_class else 0.0
59
+
60
+
61
def dice_loss(logits: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """
    Soft Dice loss on the defect channel (class index 1).

    Directly optimizes overlap — critical for imbalanced datasets.
    logits:  (B, num_classes, H, W) raw scores
    targets: (B, H, W) integer labels
    Returns 1 - mean per-sample Dice, so 0 means perfect overlap.
    """
    defect_prob = torch.softmax(logits, dim=1)[:, 1]   # (B, H, W)
    defect_gt = (targets == 1).float()

    overlap = (defect_prob * defect_gt).sum(dim=(1, 2))
    total = defect_prob.sum(dim=(1, 2)) + defect_gt.sum(dim=(1, 2))
    dice_per_sample = (2.0 * overlap + smooth) / (total + smooth)
    return 1.0 - dice_per_sample.mean()
77
+
78
+
79
def combined_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    defect_weight: float = 10.0,
    dice_weight: float = 0.5,
) -> torch.Tensor:
    """
    Blend of class-weighted cross-entropy and Dice loss.

    defect_weight: extra CE penalty for missing defect pixels
                   (10x is a reasonable start at ~6% defect coverage).
    dice_weight:   mixing factor (0 = CE only, 1 = Dice only).
    """
    # SegFormer logits come out below mask resolution; upsample to mask size.
    full_res = F.interpolate(
        logits, size=targets.shape[-2:], mode="bilinear", align_corners=False
    )
    # Weighted cross-entropy term
    class_weights = torch.tensor([1.0, defect_weight], device=logits.device)
    ce_term = F.cross_entropy(full_res, targets, weight=class_weights)
    # Dice term
    dice_term = dice_loss(full_res, targets)
    return (1.0 - dice_weight) * ce_term + dice_weight * dice_term
101
+
102
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimization pass over `loader`; return the mean batch loss."""
    model.train()
    running = 0.0
    for batch_images, batch_masks in loader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        # HuggingFace computes the loss internally when labels are passed;
        # SegFormer downsamples the labels itself to match the logit size.
        loss = model(pixel_values=batch_images, labels=batch_masks).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += loss.item()

    return running / len(loader)
121
+
122
@torch.no_grad()
def evaluate(model, loader, device, num_classes):
    """Evaluate `model` on `loader`; return (mean loss, mean mIoU).

    Fix over the original: an empty loader (possible when the dataset is so
    small that the validation split gets 0 samples) caused a
    ZeroDivisionError; it now returns (0.0, 0.0).
    """
    model.eval()
    if len(loader) == 0:
        return 0.0, 0.0

    total_loss = 0.0
    all_miou = []

    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(pixel_values=images, labels=masks)
        logits = outputs.logits  # lower resolution than the masks

        # Upsample logits to mask size before taking the per-pixel argmax.
        upsampled = F.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        preds = upsampled.argmax(dim=1)  # (B, H, W)

        total_loss += outputs.loss.item()
        all_miou.append(compute_miou(preds.cpu(), masks.cpu(), num_classes))

    return total_loss / len(loader), float(np.mean(all_miou))
149
+
150
+
151
def train_subset(subset: str, data_root: Path, args):
    """Fine-tune one SegFormer-b0 model on a single dataset subset.

    Loads data/<subset>, splits it 80/20, trains for args.epochs, and writes
    best/last checkpoints plus a per-epoch history JSON under
    checkpoints/<subset>/. Returns the history dict, or None if the subset
    folder is missing or empty.
    """
    subset_dir = data_root / subset
    if not subset_dir.exists():
        print(f"\n⚠️ Skipping '{subset}' — folder not found at {subset_dir}")
        return

    print(f"\n{'='*60}")
    print(f"Training on subset: {subset}")
    print(f"{'='*60}")

    # Dataset — reads the module-level IMAGE_SIZE (patched by __main__).
    full_ds = FractographyDataset(
        subset_dir,
        split="all",
        image_size=IMAGE_SIZE,
    )
    if len(full_ds) == 0:
        print(f"  ⚠️ No image/mask pairs found in {subset_dir}")
        return

    # Fixed-seed split for reproducibility.
    # NOTE(review): for a tiny dataset n_val can be 0 — evaluate() then sees an
    # empty loader; confirm evaluate guards against division by zero.
    n_train = max(1, int(len(full_ds) * TRAIN_FRAC))
    n_val = len(full_ds) - n_train
    train_ds, val_ds = random_split(
        full_ds, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"  Train: {len(train_ds)} | Val: {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model — CPU-only training by design.
    device = torch.device("cpu")
    id2label = {0: "background", 1: "defect"}
    label2id = {v: k for k, v in id2label.items()}

    print(f"  Loading SegFormer-b0 from HuggingFace...")
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,  # mit-b0 is a backbone; head is re-initialized
    ).to(device)

    # Discriminative LRs: pretrained encoder at base LR, fresh decode head 50x.
    optimizer = torch.optim.AdamW([
        {"params": model.segformer.parameters(), "lr": args.lr},
        {"params": model.decode_head.parameters(), "lr": args.lr * 50},
    ], weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Checkpoint dir
    ckpt_dir = Path("checkpoints") / subset
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    history = {"train_loss": [], "val_loss": [], "val_miou": []}
    best_miou = 0.0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_miou = evaluate(model, val_loader, device, NUM_CLASSES)
        scheduler.step()

        elapsed = time.time() - t0
        print(
            f"  Epoch {epoch:02d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_mIoU={val_miou:.4f} | "
            f"{elapsed:.1f}s"
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        # Save best (>= keeps the latest checkpoint among ties)
        if val_miou >= best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f"    ✅ New best mIoU: {best_miou:.4f} — checkpoint saved")

    # Save last + history
    torch.save(model.state_dict(), ckpt_dir / "last_model.pt")
    with open(ckpt_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)

    print(f"\n  Done. Best val mIoU: {best_miou:.4f}")
    print(f"  Checkpoints saved to: {ckpt_dir.resolve()}")
    return history
243
+
244
+
245
def plot_histories(histories: dict):
    """Save loss/mIoU training curves for all subsets to checkpoints/.

    Plotting is best-effort: any failure (matplotlib missing, filesystem
    error) is reported and swallowed so it never aborts a finished run.

    Fix over the original: the output directory is created if needed, so the
    plot no longer fails when checkpoints/ does not exist yet.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        fig.suptitle("SegFormer-b0 Training — OSF Ti-64", fontweight="bold")

        for subset, h in histories.items():
            epochs = range(1, len(h["train_loss"]) + 1)
            axes[0].plot(epochs, h["train_loss"], label=f"{subset} train")
            axes[0].plot(epochs, h["val_loss"], label=f"{subset} val", linestyle="--")
            axes[1].plot(epochs, h["val_miou"], label=subset)

        axes[0].set_title("Loss")
        axes[0].set_xlabel("Epoch")
        axes[0].legend(fontsize=7)
        axes[1].set_title("Val mIoU")
        axes[1].set_xlabel("Epoch")
        axes[1].legend(fontsize=7)

        out = Path("checkpoints/training_curves.png")
        out.parent.mkdir(parents=True, exist_ok=True)  # fix: ensure dir exists
        plt.tight_layout()
        plt.savefig(out, dpi=150)
        plt.close()
        print(f"\n📈 Training curves saved to: {out.resolve()}")
    except Exception as e:
        print(f"  (Could not save plot: {e})")
275
+
276
+
277
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--epochs", type=int, default=EPOCHS)
    parser.add_argument("--lr", type=float, default=LR)
    parser.add_argument("--image_size", type=int, default=256,
                        help="Square image size (256 recommended for CPU)")
    args = parser.parse_args()

    # Override IMAGE_SIZE from arg.
    # This rebinds this module's global, which train_subset reads when it
    # constructs the dataset.
    IMAGE_SIZE = (args.image_size, args.image_size)
    # Patch dataset module so it uses the right size
    # NOTE(review): assumes dataset.py reads its module-level IMAGE_SIZE at
    # call time, not at import time — confirm in dataset.py.
    import dataset as ds_module
    ds_module.IMAGE_SIZE = IMAGE_SIZE

    data_root = Path(args.data_dir)
    histories = {}

    # Train each subset independently; missing/empty subsets return None.
    for subset in SUBSETS:
        h = train_subset(subset, data_root, args)
        if h:
            histories[subset] = h

    if histories:
        plot_histories(histories)
        print("\n✅ All subsets complete.")
        print("\nSummary:")
        # Report the best validation mIoU reached per subset.
        for subset, h in histories.items():
            best = max(h["val_miou"])
            print(f"  {subset:20s} best mIoU = {best:.4f}")