"""
train.py
--------
Fine-tunes SegFormer-b0 on the OSF Ti-64 SEM fractography dataset.
Trains all three subsets (lack_of_fusion, keyhole, all_defects) separately.
CPU-optimized: small image size, small batch, few epochs.

Usage:
    python train.py
    python train.py --epochs 10 --image_size 256

Outputs (per subset):
    checkpoints/<subset>/best_model.pt   <- best checkpoint by val mIoU
    checkpoints/<subset>/last_model.pt   <- final epoch checkpoint
    checkpoints/<subset>/history.json    <- loss/mIoU per epoch
"""

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import SegformerForSemanticSegmentation

from dataset import FractographyDataset

# Number of segmentation classes (0 = background, 1 = defect).
NUM_CLASSES = 2
# Default square input size handed to the dataset as ``image_size``;
# the script entry point replaces it with the --image_size value.
IMAGE_SIZE = (256, 256)
# Samples per batch for both the training and the validation loader.
BATCH_SIZE = 2
# Default for --epochs.
EPOCHS = 15
# Default for --lr (base learning rate of the optimizer).
LR = 6e-5
# Fraction of each subset used for training; the remainder is validation.
TRAIN_FRAC = 0.8
# AdamW weight decay.
WEIGHT_DECAY = 0.01

# Sub-folders of the data directory; each one is trained separately.
SUBSETS = ["lack_of_fusion", "keyhole", "all_defects"]


def compute_miou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> float:
    """Return the mean Intersection-over-Union over the classes that occur.

    Args:
        preds: Predicted class indices, any shape.
        targets: Ground-truth class indices with the same number of elements.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The average IoU over every class that appears in ``preds`` or
        ``targets``. A class absent from both is skipped rather than counted
        as 0 or 1. Returns 0.0 when no class is present at all (for example
        for empty tensors).
    """
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (transposed or sliced inputs), reshape() copies only when it has to.
    flat_preds = preds.reshape(-1)
    flat_targets = targets.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_mask = flat_preds == cls
        target_mask = flat_targets == cls
        union = (pred_mask | target_mask).sum().item()
        if union == 0:
            # Class is neither predicted nor present: IoU is undefined.
            continue
        intersection = (pred_mask & target_mask).sum().item()
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else 0.0


def dice_loss(logits: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Return the soft Dice loss for binary segmentation.

    Directly optimizes overlap, which is critical for imbalanced datasets.
    Only the defect class (channel / label 1) enters the computation.

    Args:
        logits: Raw class scores of shape (B, num_classes, H, W).
        targets: Integer labels of shape (B, H, W), same spatial size as
            ``logits``.
        smooth: Constant added to numerator and denominator; it keeps the
            ratio defined when a sample has no defect pixels.

    Returns:
        Scalar tensor ``1 - mean(dice)``, with Dice computed per sample and
        then averaged over the batch.
    """
    probs = torch.softmax(logits, dim=1)

    prob_defect = probs[:, 1]
    target_defect = (targets == 1).float()

    # Per-sample sums over the two spatial dimensions.
    intersection = (prob_defect * target_defect).sum(dim=(1, 2))
    total_area = prob_defect.sum(dim=(1, 2)) + target_defect.sum(dim=(1, 2))
    dice = (2.0 * intersection + smooth) / (total_area + smooth)
    return 1.0 - dice.mean()


def combined_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    defect_weight: float = 10.0,
    dice_weight: float = 0.5,
) -> torch.Tensor:
    """Return a blend of class-weighted cross-entropy and Dice loss.

    Args:
        logits: Raw class scores of shape (B, 2, h, w); they are resized to
            the spatial size of ``targets`` before either loss is computed.
        targets: Integer labels of shape (B, H, W).
        defect_weight: How much extra to penalize missing defect pixels.
            Start at 10x given ~6% defect pixels.
        dice_weight: Blend factor for Dice loss (0 = CE only, 1 = Dice only).

    Returns:
        Scalar tensor ``(1 - dice_weight) * ce + dice_weight * dice``.
    """
    # cross_entropy requires integer (long) class indices.
    targets = targets.long()

    # Bring the logits to the mask resolution so both losses compare
    # pixel-for-pixel.
    logits_up = F.interpolate(
        logits, size=targets.shape[-2:], mode="bilinear", align_corners=False
    )

    # Match dtype as well as device: a float32 weight combined with
    # non-float32 logits is rejected by cross_entropy.
    class_weights = torch.tensor(
        [1.0, defect_weight], dtype=logits_up.dtype, device=logits_up.device
    )
    ce = F.cross_entropy(logits_up, targets, weight=class_weights)

    dl = dice_loss(logits_up, targets)
    return (1.0 - dice_weight) * ce + dice_weight * dl


def train_one_epoch(model, loader, optimizer, device, loss_fn=None):
    """Run one optimization pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(pixel_values=..., labels=...)`` whose
            output exposes ``.loss`` and ``.logits``.
        loader: Sized iterable yielding ``(images, masks)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        loss_fn: Optional callable ``loss_fn(logits, masks) -> scalar tensor``.
            When omitted, the loss computed by the model itself is used.
            ``combined_loss`` in this module has the required signature.

    Returns:
        The average of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches (the average would be
            undefined).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_one_epoch() received an empty loader")

    model.train()
    total_loss = 0.0
    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        if loss_fn is None:
            loss = model(pixel_values=images, labels=masks).loss
        else:
            # Skip the model's built-in loss when a custom one is supplied.
            loss = loss_fn(model(pixel_values=images).logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / num_batches


@torch.no_grad()
def evaluate(model, loader, device, num_classes, loss_fn=None):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module called as ``model(pixel_values=..., labels=...)`` whose
            output exposes ``.loss`` and ``.logits``.
        loader: Sized iterable yielding ``(images, masks)`` batches.
        device: Device the batches are moved to.
        num_classes: Number of classes passed on to ``compute_miou``.
        loss_fn: Optional callable ``loss_fn(logits, masks) -> scalar tensor``.
            When omitted, the loss computed by the model itself is used.

    Returns:
        ``(mean_loss, mean_miou)``. Both are averages of per-batch values, so
        ``mean_miou`` is a mean of batch-level mIoU scores, not a single mIoU
        over all validation pixels. For an empty loader both values are NaN.
    """
    model.eval()

    num_batches = len(loader)
    if num_batches == 0:
        # Nothing to average; NaN marks "no validation data" without the
        # ZeroDivisionError the division below would otherwise raise.
        return float("nan"), float("nan")

    total_loss = 0.0
    batch_mious = []

    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        if loss_fn is None:
            outputs = model(pixel_values=images, labels=masks)
            loss = outputs.loss
        else:
            outputs = model(pixel_values=images)
            loss = loss_fn(outputs.logits, masks)

        # Resize the logits to the mask resolution before taking the argmax
        # so predictions and labels line up pixel-for-pixel.
        upsampled = F.interpolate(
            outputs.logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        preds = upsampled.argmax(dim=1)

        total_loss += loss.item()
        batch_mious.append(compute_miou(preds.cpu(), masks.cpu(), num_classes))

    return total_loss / num_batches, float(np.mean(batch_mious))


def train_subset(subset: str, data_root: Path, args):
    """Fine-tune SegFormer-b0 on one subset and write its checkpoints.

    Args:
        subset: Name of the sub-folder of ``data_root`` to train on.
        data_root: Directory containing one folder per subset.
        args: Parsed command-line namespace. ``args.epochs`` and ``args.lr``
            are read; ``args.image_size`` is used when present, otherwise the
            module-level ``IMAGE_SIZE`` applies.

    Returns:
        A dict with the per-epoch lists ``train_loss``, ``val_loss`` and
        ``val_miou``, or ``None`` when the subset is skipped (folder missing,
        or too few image/mask pairs to form a train/val split).

    Side effects:
        Writes ``best_model.pt``, ``last_model.pt`` and ``history.json`` to
        ``checkpoints/<subset>/`` and prints progress to stdout.
    """
    subset_dir = data_root / subset
    if not subset_dir.exists():
        print(f"\n⚠️ Skipping '{subset}' — folder not found at {subset_dir}")
        return None

    print(f"\n{'='*60}")
    print(f"Training on subset: {subset}")
    print(f"{'='*60}")

    # Prefer the size carried by args so the function also works when it is
    # called from another module, where the script entry point has not
    # updated the module-level IMAGE_SIZE.
    size = getattr(args, "image_size", None)
    image_size = (size, size) if size else IMAGE_SIZE

    full_ds = FractographyDataset(
        subset_dir,
        split="all",
        image_size=image_size,
    )
    n_samples = len(full_ds)
    if n_samples == 0:
        print(f" ⚠️ No image/mask pairs found in {subset_dir}")
        return None
    if n_samples < 2:
        # One sample cannot be split: validation would be empty and the
        # per-epoch evaluation would have nothing to average.
        print(
            f" ⚠️ Skipping '{subset}' — need at least 2 image/mask pairs "
            f"for a train/val split, found {n_samples}"
        )
        return None

    # Keep at least one sample on each side of the split.
    n_train = min(n_samples - 1, max(1, int(n_samples * TRAIN_FRAC)))
    n_val = n_samples - n_train
    train_ds, val_ds = random_split(
        full_ds,
        [n_train, n_val],
        # Fixed seed so the split is the same on every run.
        generator=torch.Generator().manual_seed(42),
    )
    print(f" Train: {len(train_ds)} | Val: {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # This script always trains on the CPU.
    device = torch.device("cpu")
    id2label = {0: "background", 1: "defect"}
    label2id = {v: k for k, v in id2label.items()}

    print(" Loading SegFormer-b0 from HuggingFace...")
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    ).to(device)

    # Two parameter groups: the backbone at the base rate, the decode head at
    # 50x (presumably because the head is not part of the pretrained encoder
    # checkpoint and starts from scratch; confirm against the checkpoint).
    optimizer = torch.optim.AdamW([
        {"params": model.segformer.parameters(), "lr": args.lr},
        {"params": model.decode_head.parameters(), "lr": args.lr * 50},
    ], weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    ckpt_dir = Path("checkpoints") / subset
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    history = {"train_loss": [], "val_loss": [], "val_miou": []}
    best_miou = 0.0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_miou = evaluate(model, val_loader, device, NUM_CLASSES)
        scheduler.step()

        elapsed = time.time() - t0
        print(
            f" Epoch {epoch:02d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_mIoU={val_miou:.4f} | "
            f"{elapsed:.1f}s"
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_miou"].append(val_miou)

        # ">=" so a checkpoint is written even while mIoU is still 0.0;
        # on a tie the later epoch replaces the earlier one.
        if val_miou >= best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f" ✅ New best mIoU: {best_miou:.4f} — checkpoint saved")

        # Refreshed every epoch so an interrupted run still leaves a usable
        # checkpoint and history; after the last epoch they hold the final
        # state.
        torch.save(model.state_dict(), ckpt_dir / "last_model.pt")
        with open(ckpt_dir / "history.json", "w") as f:
            json.dump(history, f, indent=2)

    print(f"\n Done. Best val mIoU: {best_miou:.4f}")
    print(f" Checkpoints saved to: {ckpt_dir.resolve()}")
    return history


def plot_histories(histories: dict):
    """Save a training-curve plot for all subsets.

    Best effort: any failure (matplotlib missing, file not writable, ...) is
    reported on stdout and otherwise ignored, so plotting can never abort a
    finished training run.

    Args:
        histories: Maps subset name to a history dict holding the per-epoch
            lists ``train_loss``, ``val_loss`` and ``val_miou``.

    Side effects:
        Writes ``checkpoints/training_curves.png`` (relative to the current
        working directory).
    """
    fig = None
    plt = None
    try:
        import matplotlib
        # Non-interactive backend: the figure is only written to disk.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        fig.suptitle("SegFormer-b0 Training — OSF Ti-64", fontweight="bold")
        loss_ax, miou_ax = axes

        for subset, h in histories.items():
            epochs = range(1, len(h["train_loss"]) + 1)
            loss_ax.plot(epochs, h["train_loss"], label=f"{subset} train")
            loss_ax.plot(epochs, h["val_loss"], label=f"{subset} val", linestyle="--")
            miou_ax.plot(epochs, h["val_miou"], label=subset)

        loss_ax.set_title("Loss")
        loss_ax.set_xlabel("Epoch")
        loss_ax.legend(fontsize=7)
        miou_ax.set_title("Val mIoU")
        miou_ax.set_xlabel("Epoch")
        miou_ax.legend(fontsize=7)

        out = Path("checkpoints/training_curves.png")
        # Do not rely on a training run having created the directory.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out, dpi=150)
        print(f"\n📊 Training curves saved to: {out.resolve()}")
    except Exception as e:
        print(f" (Could not save plot: {e})")
    finally:
        # Release the figure even when plotting or saving failed.
        if fig is not None and plt is not None:
            plt.close(fig)


if __name__ == "__main__":
    # Script entry point: parse the options, train every subset in SUBSETS,
    # then plot and summarise whatever was trained.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--epochs", type=int, default=EPOCHS)
    parser.add_argument("--lr", type=float, default=LR)
    parser.add_argument("--image_size", type=int, default=256,
                        help="Square image size (256 recommended for CPU)")
    args = parser.parse_args()

    # Replace the module-level default with the requested square size.
    IMAGE_SIZE = (args.image_size, args.image_size)

    # Mirror the size into the dataset module as well (presumably it keeps
    # its own IMAGE_SIZE default; confirm against dataset.py).
    import dataset as ds_module
    ds_module.IMAGE_SIZE = IMAGE_SIZE

    data_root = Path(args.data_dir)
    histories = {}

    for subset in SUBSETS:
        h = train_subset(subset, data_root, args)
        if h:
            histories[subset] = h

    if histories:
        plot_histories(histories)
        print("\n✅ All subsets complete.")
        print("\nSummary:")
        for subset, h in histories.items():
            # default=NaN: with --epochs 0 the history lists are empty and a
            # bare max() would raise ValueError.
            best = max(h["val_miou"], default=float("nan"))
            print(f" {subset:20s} best mIoU = {best:.4f}")