| import argparse |
| import os |
| import pprint |
| import yaml |
| from typing import Tuple, List, Optional, Dict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.amp import autocast |
| from torch.amp import GradScaler |
| from tqdm import tqdm |
| import random |
| import torch.backends.cudnn as cudnn |
| import cv2 |
| from torch.utils.data import DataLoader |
| import time |
|
|
| from src.wireseghr.model import WireSegHR |
| from src.wireseghr.model.minmax import MinMaxLuminance |
| from src.wireseghr.data.dataset import WireSegDataset |
| from src.wireseghr.model.label_downsample import downsample_label_maxpool |
| from src.wireseghr.data.sampler import BalancedPatchSampler |
| from src.wireseghr.metrics import compute_metrics |
| from infer import _coarse_forward, _tiled_fine_forward |
| from pathlib import Path |
|
|
|
|
class SizeBatchSampler:
    """Batch sampler that yields index batches grouped by exact (H, W).

    Every emitted batch contains samples of one identical full resolution,
    which is what `_prepare_batch()` assumes. Partial batches within a size
    bin are dropped. Bin order and within-bin order are reshuffled on every
    pass through `__iter__`.
    """

    def __init__(self, dset: WireSegDataset, batch_size: int):
        self.dset = dset
        self.batch_size = batch_size
        # Total number of full batches across all size bins
        # (leftover samples in each bin are discarded).
        self._len = sum(
            len(indices) // self.batch_size
            for indices in self.dset.size_bins.values()
        )

    def __len__(self) -> int:
        return self._len

    def __iter__(self):
        size_bins = self.dset.size_bins
        shuffled_sizes = list(size_bins.keys())
        random.shuffle(shuffled_sizes)
        for size_key in shuffled_sizes:
            indices = list(size_bins[size_key])
            random.shuffle(indices)
            # Only yield complete batches; the trailing remainder is skipped.
            usable = len(indices) - (len(indices) % self.batch_size)
            for start in range(0, usable, self.batch_size):
                yield indices[start : start + self.batch_size]
|
|
|
|
def collate_train(batch: List[Dict]):
    """Collate a batch into two parallel lists: (images, masks).

    Items are intentionally left as raw numpy arrays (no tensor stacking),
    since `_prepare_batch()` consumes plain python lists downstream.
    """
    images: List = []
    labels: List = []
    for sample in batch:
        images.append(sample["image"])
        labels.append(sample["mask"])
    return images, labels
|
|
|
|
def main():
    """Train WireSegHR: coarse global branch + fine patch branch.

    Reads a YAML config, builds a size-bucketed training loader plus optional
    val/test datasets, then runs a poly-LR AdamW loop (optional fp16/bf16
    AMP) with periodic validation, best/interval checkpointing, and a final
    test pass that also dumps visualizations.
    """
    parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
    parser.add_argument(
        "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
    )
    args = parser.parse_args()

    # Resolve relative config paths against the current working directory.
    cfg_path = args.config
    if not Path(cfg_path).is_absolute():
        cfg_path = str(Path.cwd() / cfg_path)

    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)

    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[WireSegHR][train] Device: {device}")

    # --- Core hyperparameters from the config ---
    coarse_train = int(cfg["coarse"]["train_size"])
    coarse_test = int(cfg["coarse"]["test_size"])
    patch_size = int(cfg["fine"]["patch_size"])
    overlap = int(cfg["fine"]["overlap"])
    eval_patch_size = int(cfg["inference"]["fine_patch_size"])
    eval_cfg = cfg.get("eval", {})
    eval_fine_batch = int(eval_cfg.get("fine_batch", 16))
    assert eval_fine_batch >= 1
    eval_max_samples = int(eval_cfg.get("max_samples", 16))
    assert eval_max_samples >= 1
    iters = int(cfg["optim"]["iters"])
    batch_size = int(cfg["optim"]["batch_size"])
    base_lr = float(cfg["optim"]["lr"])
    weight_decay = float(cfg["optim"]["weight_decay"])
    power = float(cfg["optim"]["power"])
    precision = str(cfg["optim"].get("precision", "fp32")).lower()
    assert precision in ("fp32", "fp16", "bf16")

    # AMP runs only on CUDA and only for reduced precisions.
    amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16"))

    if amp_enabled:
        # Refuse to run reduced precision on GPUs without native support.
        cc_major, cc_minor = torch.cuda.get_device_capability()
        if precision == "fp16":
            assert cc_major >= 7, (
                f"fp16 requires Volta (SM 7.0)+; current SM {cc_major}.{cc_minor}"
            )
        elif precision == "bf16":
            assert cc_major >= 8, (
                f"bf16 requires Ampere (SM 8.0)+; current SM {cc_major}.{cc_minor}"
            )
    # None for fp32; autocast below is also disabled in that case.
    amp_dtype = (
        torch.float16
        if precision == "fp16"
        else (torch.bfloat16 if precision == "bf16" else None)
    )

    # --- Run bookkeeping ---
    seed = int(cfg.get("seed", 42))
    out_dir = cfg.get("out_dir", "runs/wireseghr")
    eval_interval = int(cfg["eval_interval"])
    ckpt_interval = int(cfg["ckpt_interval"])
    os.makedirs(out_dir, exist_ok=True)
    set_seed(seed)

    # --- Training data; batches are size-homogeneous via SizeBatchSampler ---
    train_images = cfg["data"]["train_images"]
    train_masks = cfg["data"]["train_masks"]
    dset = WireSegDataset(train_images, train_masks, split="train")

    loader_cfg = cfg.get("loader", {})
    num_workers = int(loader_cfg.get("num_workers", 4))
    prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
    pin_memory = bool(loader_cfg.get("pin_memory", True))
    persistent_workers = (
        bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
    )
    batch_sampler = SizeBatchSampler(dset, batch_size)
    loader_kwargs = dict(
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        collate_fn=collate_train,
    )
    if num_workers > 0:
        # prefetch_factor is only accepted when worker processes exist.
        loader_kwargs["prefetch_factor"] = prefetch_factor
    train_loader = DataLoader(dset, **loader_kwargs)

    # --- Optional val/test datasets (only when both paths are set) ---
    val_images = cfg["data"].get("val_images", None)
    val_masks = cfg["data"].get("val_masks", None)
    test_images = cfg["data"].get("test_images", None)
    test_masks = cfg["data"].get("test_masks", None)
    dset_val = (
        WireSegDataset(val_images, val_masks, split="val")
        if val_images and val_masks
        else None
    )
    dset_test = (
        WireSegDataset(test_images, test_masks, split="test")
        if test_images and test_masks
        else None
    )
    sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
    minmax = (
        MinMaxLuminance(kernel=cfg["minmax"]["kernel"])
        if cfg["minmax"]["enable"]
        else None
    )

    # --- Inference-time settings reused during eval ---
    prob_thresh = float(cfg["inference"]["prob_threshold"])
    mm_enable = bool(cfg["minmax"]["enable"])
    mm_kernel = int(cfg["minmax"]["kernel"])

    # --- Model: 6-channel input (RGB + Ymin + Ymax + condition channel) ---
    pretrained_flag = bool(cfg.get("pretrained", False))
    model = WireSegHR(
        backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag
    )
    model = model.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    # Loss scaling is only needed for fp16 on CUDA; a disabled scaler is a no-op.
    scaler = GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
    ce = nn.CrossEntropyLoss()

    # --- Optional resume from checkpoint ---
    start_step = 0
    best_f1 = -1.0
    resume_path = cfg.get("resume", None)
    if resume_path and Path(resume_path).is_file():
        print(f"[WireSegHR][train] Resuming from {resume_path}")
        start_step, best_f1 = _load_checkpoint(
            resume_path, model, optim, scaler, device
        )

    # --- Training loop (iteration-based, loader is re-cycled on exhaustion) ---
    model.train()
    step = start_step
    pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
    data_iter = iter(train_loader)
    while step < iters:
        optim.zero_grad(set_to_none=True)
        try:
            imgs, masks = next(data_iter)
        except StopIteration:
            # Epoch boundary: restart the loader (reshuffles via the sampler).
            data_iter = iter(train_loader)
            imgs, masks = next(data_iter)
        batch = _prepare_batch(
            imgs, masks, coarse_train, patch_size, sampler, minmax, device
        )

        # Coarse pass: global low-res prediction + condition feature map.
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_coarse, cond_map = model.forward_coarse(
                batch["x_coarse"]
            )

        # Fine pass: full-res patches conditioned on the coarse features.
        B, _, hc4, wc4 = cond_map.shape
        x_fine = _build_fine_inputs(batch, cond_map, device)
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits_fine = model.forward_fine(x_fine)

        # Targets are max-pool-downsampled to each head's output resolution.
        y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
        y_fine = _build_fine_targets(
            batch["mask_patches"],
            logits_fine.shape[2],
            logits_fine.shape[3],
            device,
        )

        # Equal-weighted sum of coarse and fine CE losses.
        loss_coarse = ce(logits_coarse, y_coarse)
        loss_fine = ce(logits_fine, y_fine)
        loss = loss_coarse + loss_fine

        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        # Poly learning-rate decay: lr = base * (1 - step/iters)^power.
        lr = base_lr * ((1.0 - float(step) / float(iters)) ** power)
        for pg in optim.param_groups:
            pg["lr"] = lr

        if step % 50 == 0:
            print(f"[Iter {step}/{iters}] lr={lr:.6e}")

        # Periodic validation (also triggers at step 0 when resuming from 0).
        if (step % eval_interval == 0) and (dset_val is not None):
            # Drop training tensors before eval to lower peak GPU memory.
            del (
                x_fine,
                logits_coarse,
                cond_map,
                logits_fine,
                y_coarse,
                y_fine,
                loss_coarse,
                loss_fine,
                loss,
            )
            torch.cuda.empty_cache()
            model.eval()
            print(
                f"[WireSegHR][train] Eval starting... val_size={len(dset_val)} max={eval_max_samples} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
                flush=True,
            )
            val_stats = validate(
                model,
                dset_val,
                coarse_test,
                device,
                amp_enabled,
                amp_dtype,
                prob_thresh,
                mm_enable,
                mm_kernel,
                eval_patch_size,
                overlap,
                eval_fine_batch,
                eval_max_samples,
            )
            print(
                f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
            )
            print(
                f"[Val @ {step}][Coarse] IoU={val_stats['iou_coarse']:.4f} F1={val_stats['f1_coarse']:.4f} P={val_stats['precision_coarse']:.4f} R={val_stats['recall_coarse']:.4f}"
            )
            # Track and snapshot the best fine-branch F1.
            if val_stats["f1"] > best_f1:
                best_f1 = val_stats["f1"]
                _save_checkpoint(
                    str(Path(out_dir) / "best.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )
            # NOTE(review): interval checkpoints and test visuals only fire on
            # eval steps (this branch), so ckpt_interval is effectively
            # coupled to eval_interval — confirm intended.
            if ckpt_interval > 0 and (step % ckpt_interval == 0):
                _save_checkpoint(
                    str(Path(out_dir) / f"ckpt_{step}.pt"),
                    step,
                    model,
                    optim,
                    scaler,
                    best_f1,
                )
            # Quick qualitative check on a few test images (coarse-only).
            if dset_test is not None:
                save_test_visuals(
                    model,
                    dset_test,
                    coarse_test,
                    device,
                    str(Path(out_dir) / f"test_vis_{step}"),
                    amp_enabled,
                    mm_enable,
                    mm_kernel,
                    prob_thresh,
                    max_samples=8,
                )
            model.train()

        step += 1
        pbar.update(1)

    # Always persist the final state.
    _save_checkpoint(
        str(Path(out_dir) / f"ckpt_{iters}.pt"), step, model, optim, scaler, best_f1
    )

    # --- Final full-test evaluation + visual dump ---
    if dset_test is not None:
        torch.cuda.empty_cache()
        model.eval()
        print(
            f"[WireSegHR][train] Final test starting... test_size={len(dset_test)} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
            flush=True,
        )
        test_stats = validate(
            model,
            dset_test,
            coarse_test,
            device,
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
            len(dset_test),
        )
        print(
            f"[Test Final][Fine] IoU={test_stats['iou']:.4f} F1={test_stats['f1']:.4f} P={test_stats['precision']:.4f} R={test_stats['recall']:.4f}"
        )
        print(
            f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
        )
        # Persist final metrics alongside the visualizations.
        final_out = Path(out_dir) / f"final_vis_{step}"
        final_out.mkdir(parents=True, exist_ok=True)

        with open(final_out / "metrics.yaml", "w") as f:
            yaml.safe_dump({**test_stats, "step": step}, f, sort_keys=False)

        save_final_visuals(
            model,
            dset_test,
            coarse_test,
            device,
            str(final_out),
            amp_enabled,
            amp_dtype,
            prob_thresh,
            mm_enable,
            mm_kernel,
            eval_patch_size,
            overlap,
            eval_fine_batch,
        )
        model.train()

    print("[WireSegHR][train] Done.")
|
|
|
|
|
|
def _prepare_batch(
    imgs: List[np.ndarray],
    masks: List[np.ndarray],
    coarse_train: int,
    patch_size: int,
    sampler: BalancedPatchSampler,
    minmax: Optional[MinMaxLuminance],
    device: torch.device,
):
    """Build coarse-branch inputs and sampled fine patches for one batch.

    Per image: computes a luminance channel, optional 6x6 min/max luminance
    maps, a resized 6-channel coarse input (RGB + Ymin + Ymax + zeros
    placeholder for the condition channel), and one balanced-sampled
    patch (RGB/mask/Ymin/Ymax crops plus its top-left (y0, x0)).

    Returns a dict with: "x_coarse" (B, 6, coarse_train, coarse_train) tensor,
    "full_h"/"full_w" ints, per-image lists "rgb_patches", "mask_patches",
    "ymin_patches", "ymax_patches", "patch_yx", and "mask_full" (the input
    masks, untouched).
    """
    B = len(imgs)
    assert B == len(masks)

    # SizeBatchSampler guarantees one resolution per batch; enforce it here.
    full_h = imgs[0].shape[0]
    full_w = imgs[0].shape[1]
    for im, m in zip(imgs, masks):
        assert im.shape[0] == full_h and im.shape[1] == full_w
        assert m.shape[0] == full_h and m.shape[1] == full_w

    xs_coarse = []
    patches_rgb = []
    patches_mask = []
    patches_min = []
    patches_max = []
    yx_list: List[tuple[int, int]] = []

    for img, mask in zip(imgs, masks):
        # Normalize to [0, 1] and move to (1, C, H, W) on the target device.
        imgf = img.astype(np.float32) / 255.0
        t_img = (
            torch.from_numpy(np.transpose(imgf, (2, 0, 1))).unsqueeze(0).to(device)
        )

        # ITU-R BT.601 luma from RGB.
        y_t = (
            0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
        )
        if minmax is not None:
            # NOTE(review): pad (2,3,2,3) + kernel-6 pooling implements a
            # fixed 6x6 min/max window; the configured minmax.kernel is not
            # consulted here — confirm intended.
            y_p = F.pad(y_t, (2, 3, 2, 3), mode="replicate")
            y_max_full = F.max_pool2d(y_p, kernel_size=6, stride=1)
            # Min-pool expressed as negated max-pool.
            y_min_full = -F.max_pool2d(-y_p, kernel_size=6, stride=1)
        else:
            y_min_full = y_t
            y_max_full = y_t

        # Coarse input: bilinear-resize RGB and min/max luma to train size.
        rgb_coarse_t = F.interpolate(
            t_img,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        y_min_c_t = F.interpolate(
            y_min_full,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        y_max_c_t = F.interpolate(
            y_max_full,
            size=(coarse_train, coarse_train),
            mode="bilinear",
            align_corners=False,
        )[0]
        # Channel 6 is all-zero for the coarse branch (no condition input).
        zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
        c_t = torch.cat(
            [rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse], dim=0
        )
        xs_coarse.append(c_t)

        # Fine branch: one wire-balanced patch per image.
        y0, x0 = sampler.sample(imgf, mask)
        patch_rgb = imgf[y0 : y0 + patch_size, x0 : x0 + patch_size, :]
        patch_mask = mask[y0 : y0 + patch_size, x0 : x0 + patch_size]
        patches_rgb.append(patch_rgb)
        patches_mask.append(patch_mask)
        # Crop the full-res min/max luma at the same location (back to numpy
        # so _build_fine_inputs can re-upload alongside the RGB patch).
        ymin_patch = (
            y_min_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
            .detach()
            .cpu()
            .numpy()
        )
        ymax_patch = (
            y_max_full[0, 0, y0 : y0 + patch_size, x0 : x0 + patch_size]
            .detach()
            .cpu()
            .numpy()
        )
        patches_min.append(ymin_patch)
        patches_max.append(ymax_patch)
        yx_list.append((y0, x0))

    x_coarse = torch.stack(xs_coarse, dim=0)

    return {
        "x_coarse": x_coarse,
        "full_h": full_h,
        "full_w": full_w,
        "rgb_patches": patches_rgb,
        "mask_patches": patches_mask,
        "ymin_patches": patches_min,
        "ymax_patches": patches_max,
        "patch_yx": yx_list,
        "mask_full": masks,
    }
|
|
|
|
def _build_fine_inputs(
    batch, cond_map: torch.Tensor, device: torch.device
) -> torch.Tensor:
    """Assemble 6-channel fine-branch inputs for each sampled patch.

    For patch i at full-res location (y0, x0): crops the matching region of
    the coarse condition map `cond_map` (B, C, hc4, wc4), upsamples it to the
    patch size P, and concatenates [RGB(3), Ymin(1), Ymax(1), cond] along the
    channel axis. Returns a (B, 6, P, P) tensor on `device`.
    """
    B = cond_map.shape[0]
    P = batch["rgb_patches"][0].shape[0]
    full_h, full_w = batch["full_h"], batch["full_w"]
    hc4, wc4 = cond_map.shape[2], cond_map.shape[3]

    xs: List[torch.Tensor] = []
    for i in range(B):
        rgb = batch["rgb_patches"][i]
        ymin = batch["ymin_patches"][i]
        ymax = batch["ymax_patches"][i]
        y0, x0 = batch["patch_yx"][i]

        # Map the full-res patch rectangle into cond-map coordinates:
        # floor for the start, ceil (via +den-1) for the end, so the crop
        # always covers the patch footprint.
        y1, x1 = y0 + P, x0 + P
        y0c = (y0 * hc4) // full_h
        y1c = ((y1 * hc4) + full_h - 1) // full_h
        x0c = (x0 * wc4) // full_w
        x1c = ((x1 * wc4) + full_w - 1) // full_w
        cond_sub = cond_map[i : i + 1, :, y0c:y1c, x0c:x1c].float()
        # NOTE(review): .squeeze(1) assumes cond_map has a single channel
        # (C == 1) so the result is (1, P, P) and acts as one channel in the
        # cat below — confirm against WireSegHR.forward_coarse's output.
        cond_patch = F.interpolate(
            cond_sub, size=(P, P), mode="bilinear", align_corners=False
        ).squeeze(1)

        # Re-upload the numpy crops and stack channels: 3 + 1 + 1 + 1 = 6.
        rgb_t = (
            torch.from_numpy(np.transpose(rgb, (2, 0, 1))).to(device).float()
        )
        ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float()
        ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float()
        x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0)
        xs.append(x)
    x_fine = torch.stack(xs, dim=0)
    return x_fine
|
|
|
|
def _build_coarse_targets(
    masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample full-resolution masks to (out_h, out_w) int64 label maps.

    Max-pool label downsampling keeps thin wire pixels alive that plain
    nearest/area resizing would drop. Returns a (B, out_h, out_w) tensor.
    """
    targets = [
        torch.from_numpy(downsample_label_maxpool(m, out_h, out_w).astype(np.int64))
        for m in masks
    ]
    return torch.stack(targets, dim=0).to(device)
|
|
|
|
def _build_fine_targets(
    mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device
) -> torch.Tensor:
    """Downsample per-patch masks to the fine head's output resolution.

    The transform is identical to `_build_coarse_targets` (max-pool label
    downsampling -> int64 (B, out_h, out_w) tensor on `device`), so delegate
    to it instead of duplicating the loop — keeps the two target builders
    consistent by construction.
    """
    return _build_coarse_targets(mask_patches, out_h, out_w, device)
|
|
|
|
def set_seed(seed: int):
    """Seed the python, numpy, and torch RNGs for reproducible runs.

    cuDNN autotuning stays on (benchmark=True, deterministic=False), trading
    strict bitwise determinism for convolution throughput.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer speed over exact reproducibility for conv algorithm selection.
    cudnn.benchmark = True
    cudnn.deterministic = False
|
|
|
|
| def _save_checkpoint( |
| path: str, |
| step: int, |
| model: nn.Module, |
| optim: torch.optim.Optimizer, |
| scaler: GradScaler, |
| best_f1: float, |
| ): |
| Path(path).parent.mkdir(parents=True, exist_ok=True) |
| state = { |
| "step": step, |
| "model": model.state_dict(), |
| "optim": optim.state_dict(), |
| "scaler": scaler.state_dict(), |
| "best_f1": best_f1, |
| } |
| torch.save(state, path) |
| print(f"[WireSegHR][train] Saved checkpoint: {path}") |
|
|
|
|
| def _load_checkpoint( |
| path: str, |
| model: nn.Module, |
| optim: torch.optim.Optimizer, |
| scaler: GradScaler, |
| device: torch.device, |
| ) -> Tuple[int, float]: |
| ckpt = torch.load(path, map_location=device) |
| model.load_state_dict(ckpt["model"]) |
| optim.load_state_dict(ckpt["optim"]) |
| try: |
| scaler.load_state_dict(ckpt["scaler"]) |
| except Exception: |
| pass |
| step = int(ckpt.get("step", 0)) |
| best_f1 = float(ckpt.get("best_f1", -1.0)) |
| return step, best_f1 |
|
|
|
|
@torch.no_grad()
def validate(
    model: WireSegHR,
    dset_val: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
    max_images: int,
) -> Dict[str, float]:
    """Evaluate both branches on up to `max_images` random samples.

    Runs the coarse forward then the overlapping-tile fine forward per image,
    thresholds probabilities at `prob_thresh`, and averages metrics over the
    sampled images. Returns fine-branch keys ("iou", "f1", "precision",
    "recall") plus the same with a "_coarse" suffix. Caller is expected to
    have put `model` in eval mode.
    """
    model = model.to(device)
    metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    coarse_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    n = 0
    t0 = time.time()
    total_tiles = 0
    # Random subset (without replacement); uses the global RNG, so eval
    # sampling advances the training seed stream.
    target_n = min(len(dset_val), max_images)
    idxs = random.sample(range(len(dset_val)), k=target_n)
    print(
        f"[Eval] Started: N={target_n}/{len(dset_val)} coarse={coarse_size} patch={fine_patch_size} overlap={fine_overlap} stride={fine_patch_size - fine_overlap} fine_batch={fine_batch}",
        flush=True,
    )
    for j, i in enumerate(idxs):
        if (j % 2) == 0:
            print(f"[Eval] Running... {j}/{target_n}", flush=True)
        item = dset_val[i]
        img = item["image"].astype(np.float32) / 255.0
        mask = item["mask"].astype(np.uint8)
        H, W = mask.shape
        # Coarse pass: full-image low-res prediction + condition features.
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            img,
            coarse_size,
            minmax_enable,
            int(minmax_kernel),
            device,
            amp_flag,
            amp_dtype,
        )
        pred_coarse = (prob_up > prob_thresh).to(torch.uint8).cpu().numpy()
        m_c = compute_metrics(pred_coarse, mask)
        for k in coarse_sum:
            coarse_sum[k] += m_c[k]

        # Fine pass over overlapping tiles of the full-resolution image.
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            amp_flag,
            amp_dtype,
        )
        # Recompute the tile grid (matching _tiled_fine_forward's layout)
        # purely for throughput reporting.
        # NOTE(review): ys/xs[-1] raises IndexError when H < P or W < P —
        # presumably images are always at least one patch large; confirm.
        P = int(fine_patch_size)
        stride = P - int(fine_overlap)
        ys = list(range(0, H - P + 1, stride))
        if ys[-1] != (H - P):
            ys.append(H - P)
        xs = list(range(0, W - P + 1, stride))
        if xs[-1] != (W - P):
            xs.append(W - P)
        total_tiles += len(ys) * len(xs)
        pred_fine = (prob_full > prob_thresh).to(torch.uint8).cpu().numpy()
        m_f = compute_metrics(pred_fine, mask)
        for k in metrics_sum:
            metrics_sum[k] += m_f[k]
        n += 1
    # Average the accumulated per-image metrics.
    if n > 0:
        for k in metrics_sum:
            metrics_sum[k] /= n
        for k in coarse_sum:
            coarse_sum[k] /= n
    dt = time.time() - t0
    tp_img = (n / dt) if dt > 0 else 0.0
    tp_tile = (total_tiles / dt) if dt > 0 else 0.0
    print(
        f"[Eval] Done in {dt:.2f}s | imgs={n}, tiles={total_tiles}, imgs/s={tp_img:.2f}, tiles/s={tp_tile:.2f}",
        flush=True,
    )
    out = {k: v for k, v in metrics_sum.items()}
    out.update(
        {
            "iou_coarse": coarse_sum["iou"],
            "f1_coarse": coarse_sum["f1"],
            "precision_coarse": coarse_sum["precision"],
            "recall_coarse": coarse_sum["recall"],
        }
    )
    return out
|
|
|
|
@torch.no_grad()
def save_test_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    minmax_enable: bool,
    minmax_kernel: int,
    prob_thresh: float,
    max_samples: int = 8,
):
    """Dump coarse-branch prediction images for a few test samples.

    For up to `max_samples` images, runs only the coarse forward pass and
    writes `<idx>_input.jpg` (BGR) plus `<idx>_pred.png` (0/255 mask) into
    `out_dir`. The AMP dtype is left at the helper's default (None).
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for idx in range(min(max_samples, len(dset_test))):
        sample = dset_test[idx]
        rgb = sample["image"].astype(np.float32) / 255.0
        prob_up, _cond_map, _t_img, _ymin, _ymax = _coarse_forward(
            model,
            rgb,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            None,
        )
        pred_img = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # OpenCV expects BGR channel order.
        bgr = (rgb[..., ::-1] * 255.0).astype(np.uint8)
        cv2.imwrite(str(out_path / f"{idx:03d}_input.jpg"), bgr)
        cv2.imwrite(str(out_path / f"{idx:03d}_pred.png"), pred_img)
|
|
|
|
@torch.no_grad()
def save_final_visuals(
    model: WireSegHR,
    dset_test: WireSegDataset,
    coarse_size: int,
    device: torch.device,
    out_dir: str,
    amp_flag: bool,
    amp_dtype,
    prob_thresh: float,
    minmax_enable: bool,
    minmax_kernel: int,
    fine_patch_size: int,
    fine_overlap: int,
    fine_batch: int,
):
    """Dump input + coarse + fine prediction images for the whole test set.

    Per image, writes `<idx>_input.jpg` (BGR), `<idx>_coarse_pred.png`, and
    `<idx>_fine_pred.png` (0/255 masks thresholded at `prob_thresh`) into
    `out_dir`, running the coarse pass and the tiled fine pass.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for idx in range(len(dset_test)):
        sample = dset_test[idx]
        rgb = sample["image"].astype(np.float32) / 255.0
        # Coarse pass: global prediction + condition features for tiling.
        prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
            model,
            rgb,
            int(coarse_size),
            bool(minmax_enable),
            int(minmax_kernel),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        coarse_mask = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # Fine pass over overlapping full-resolution tiles.
        prob_full = _tiled_fine_forward(
            model,
            t_img,
            cond_map,
            y_min_full,
            y_max_full,
            int(fine_patch_size),
            int(fine_overlap),
            int(fine_batch),
            device,
            bool(amp_flag),
            amp_dtype,
        )
        fine_mask = ((prob_full > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
        # OpenCV expects BGR channel order.
        bgr = (rgb[..., ::-1] * 255.0).astype(np.uint8)
        stem = f"{idx:03d}"
        cv2.imwrite(str(out_path / f"{stem}_input.jpg"), bgr)
        cv2.imwrite(str(out_path / f"{stem}_coarse_pred.png"), coarse_mask)
        cv2.imwrite(str(out_path / f"{stem}_fine_pred.png"), fine_mask)
|
|
|
|
# Script entry point: train according to the YAML config given via --config.
if __name__ == "__main__":
    main()
|
|