| import argparse |
| import os |
| import pprint |
| import time |
| from typing import List, Tuple, Optional, Dict, Any |
| import numpy as np |
| import cv2 |
| import torch |
| import torch.nn.functional as F |
| from torch.amp import autocast |
| from tqdm import tqdm |
| from safetensors.torch import load_file as safe_load_file |
| import yaml |
|
|
| from src.wireseghr.model import WireSegHR |
| from pathlib import Path |
| from src.wireseghr.metrics import compute_metrics |
|
|
|
|
| def _pad_for_minmax(kernel: int) -> Tuple[int, int, int, int]: |
| |
| if (kernel % 2) == 0: |
| return (kernel // 2 - 1, kernel // 2, kernel // 2 - 1, kernel // 2) |
| else: |
| return (kernel // 2, kernel // 2, kernel // 2, kernel // 2) |
|
|
|
|
| @torch.no_grad() |
| def _coarse_forward( |
| model: WireSegHR, |
| img_rgb: np.ndarray, |
| coarse_size: int, |
| minmax_enable: bool, |
| minmax_kernel: int, |
| device: torch.device, |
| amp_flag: bool, |
| amp_dtype, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| |
| t_img = ( |
| torch.from_numpy(np.transpose(img_rgb, (2, 0, 1))) |
| .unsqueeze(0) |
| .to(device) |
| .float() |
| ) |
| H = img_rgb.shape[0] |
| W = img_rgb.shape[1] |
|
|
| rgb_c = F.interpolate( |
| t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False |
| )[0] |
| y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3] |
| if minmax_enable: |
| pad = _pad_for_minmax(minmax_kernel) |
| y_p = F.pad(y_t, pad, mode="replicate") |
| y_max_full = F.max_pool2d(y_p, kernel_size=minmax_kernel, stride=1) |
| y_min_full = -F.max_pool2d(-y_p, kernel_size=minmax_kernel, stride=1) |
| else: |
| y_min_full = y_t |
| y_max_full = y_t |
| y_min_c = F.interpolate( |
| y_min_full, |
| size=(coarse_size, coarse_size), |
| mode="bilinear", |
| align_corners=False, |
| )[0] |
| y_max_c = F.interpolate( |
| y_max_full, |
| size=(coarse_size, coarse_size), |
| mode="bilinear", |
| align_corners=False, |
| )[0] |
| zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device) |
| x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0) |
|
|
| with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag): |
| logits_c, cond_map = model.forward_coarse(x_t) |
| prob = torch.softmax(logits_c, dim=1)[:, 1:2] |
| prob_up = ( |
| F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0] |
| .detach() |
| .cpu() |
| .float() |
| ) |
| return prob_up, cond_map, t_img, y_min_full, y_max_full |
|
|
|
|
| @torch.no_grad() |
| def _tiled_fine_forward( |
| model: WireSegHR, |
| t_img: torch.Tensor, |
| cond_map: torch.Tensor, |
| y_min_full: torch.Tensor, |
| y_max_full: torch.Tensor, |
| patch_size: int, |
| overlap: int, |
| fine_batch: int, |
| device: torch.device, |
| amp_flag: bool, |
| amp_dtype, |
| ) -> torch.Tensor: |
| H = int(t_img.shape[2]) |
| W = int(t_img.shape[3]) |
| P = patch_size |
| stride = P - overlap |
| assert stride > 0 |
| assert H >= P and W >= P |
|
|
| prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32) |
| weight_t = torch.zeros((H, W), device=device, dtype=torch.float32) |
|
|
| hc4, wc4 = cond_map.shape[2], cond_map.shape[3] |
|
|
| ys = list(range(0, H - P + 1, stride)) |
| if ys[-1] != (H - P): |
| ys.append(H - P) |
| xs = list(range(0, W - P + 1, stride)) |
| if xs[-1] != (W - P): |
| xs.append(W - P) |
|
|
| coords: List[Tuple[int, int]] = [] |
| for y0 in ys: |
| for x0 in xs: |
| coords.append((y0, x0)) |
|
|
| for i0 in range(0, len(coords), fine_batch): |
| batch_coords = coords[i0 : i0 + fine_batch] |
| xs_list: List[torch.Tensor] = [] |
| for y0, x0 in batch_coords: |
| y1, x1 = y0 + P, x0 + P |
| |
| y0c = (y0 * hc4) // H |
| y1c = ((y1 * hc4) + H - 1) // H |
| x0c = (x0 * wc4) // W |
| x1c = ((x1 * wc4) + W - 1) // W |
| cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float() |
| cond_patch = F.interpolate( |
| cond_sub, size=(P, P), mode="bilinear", align_corners=False |
| ).squeeze(1) |
|
|
| rgb_t = t_img[0, :, y0:y1, x0:x1] |
| ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) |
| ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) |
| x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0) |
| xs_list.append(x_f) |
|
|
| x_f_batch = torch.cat(xs_list, dim=0) |
| with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_flag): |
| logits_f = model.forward_fine(x_f_batch) |
| prob_f = torch.softmax(logits_f, dim=1)[:, 1:2] |
| prob_f_up = F.interpolate( |
| prob_f, size=(P, P), mode="bilinear", align_corners=False |
| )[:, 0, :, :] |
|
|
| for bi, (y0, x0) in enumerate(batch_coords): |
| y1, x1 = y0 + P, x0 + P |
| prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi] |
| weight_t[y0:y1, x0:x1] += 1.0 |
|
|
| prob_full = (prob_sum_t / weight_t).detach().cpu().float() |
| return prob_full |
|
|
|
|
| def _build_model_from_cfg(cfg: dict, device: torch.device) -> WireSegHR: |
| pretrained_flag = bool(cfg.get("pretrained", False)) |
| model = WireSegHR( |
| backbone=cfg["backbone"], in_channels=6, pretrained=pretrained_flag |
| ) |
| model = model.to(device) |
| return model |
|
|
|
|
| @torch.no_grad() |
| def infer_image( |
| model: WireSegHR, |
| img_path: str, |
| cfg: dict, |
| device: torch.device, |
| amp_flag: bool, |
| amp_dtype, |
| out_dir: Optional[str] = None, |
| save_prob: bool = False, |
| prob_thresh: Optional[float] = None, |
| ) -> Tuple[np.ndarray, np.ndarray]: |
| assert Path(img_path).is_file(), f"Image not found: {img_path}" |
| bgr = cv2.imread(img_path, cv2.IMREAD_COLOR) |
| assert bgr is not None, f"Failed to read {img_path}" |
| rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
|
|
| coarse_size = int(cfg["coarse"]["test_size"]) |
| patch_size = int(cfg["inference"]["fine_patch_size"]) |
| overlap = int(cfg["fine"]["overlap"]) |
| minmax_enable = bool(cfg["minmax"]["enable"]) |
| minmax_kernel = int(cfg["minmax"]["kernel"]) |
| if prob_thresh is None: |
| prob_thresh = float(cfg["inference"]["prob_threshold"]) |
|
|
| prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_forward( |
| model, |
| rgb, |
| coarse_size, |
| minmax_enable, |
| minmax_kernel, |
| device, |
| amp_flag, |
| amp_dtype, |
| ) |
|
|
| prob_f = _tiled_fine_forward( |
| model, |
| t_img, |
| cond_map, |
| y_min_full, |
| y_max_full, |
| patch_size, |
| overlap, |
| int(cfg.get("eval", {}).get("fine_batch", 16)), |
| device, |
| amp_flag, |
| amp_dtype, |
| ) |
|
|
| |
| pred_t = (prob_f > prob_thresh).to(torch.uint8) * 255 |
| pred = pred_t.detach().cpu().numpy() |
|
|
| if out_dir is not None: |
| os.makedirs(out_dir, exist_ok=True) |
| stem = Path(img_path).stem |
| out_mask = Path(out_dir) / f"{stem}_pred.png" |
| cv2.imwrite(str(out_mask), pred) |
| if save_prob: |
| out_prob = Path(out_dir) / f"{stem}_prob.npy" |
| np.save(out_prob, prob_f.detach().cpu().float().numpy()) |
|
|
| |
| return pred, prob_f.detach().cpu().numpy() |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="WireSegHR inference") |
| parser.add_argument( |
| "--config", type=str, default="configs/default.yaml", help="Path to YAML config" |
| ) |
| parser.add_argument("--image", type=str, required=False, help="Path to input image") |
| parser.add_argument( |
| "--images_dir", |
| type=str, |
| required=False, |
| help="Directory with .jpg/.jpeg images", |
| ) |
| parser.add_argument( |
| "--out", type=str, default="outputs/infer", help="Directory to save predictions" |
| ) |
| parser.add_argument( |
| "--ckpt", |
| type=str, |
| default="", |
| help="Optional checkpoint (.pt with {'model': state_dict} or .safetensors with pure state_dict)", |
| ) |
| parser.add_argument( |
| "--save_prob", action="store_true", help="Also save probability .npy" |
| ) |
| |
| parser.add_argument( |
| "--metrics", |
| action="store_true", |
| help="Compute IoU, F1, Precision, Recall if ground-truth masks are provided", |
| ) |
| parser.add_argument( |
| "--mask", |
| type=str, |
| default="", |
| help="Path to ground-truth mask (.png) for --image when --metrics is set", |
| ) |
| parser.add_argument( |
| "--masks_dir", |
| type=str, |
| default="", |
| help="Directory with ground-truth masks (.png) for --images_dir when --metrics is set", |
| ) |
| |
| parser.add_argument( |
| "--benchmark", |
| action="store_true", |
| help="Run benchmarking on a directory (defaults to cfg.data.test_images)", |
| ) |
| parser.add_argument( |
| "--bench_images_dir", |
| type=str, |
| default="", |
| help="Images dir for benchmark (overrides cfg.data.test_images if set)", |
| ) |
| parser.add_argument( |
| "--bench_masks_dir", |
| type=str, |
| default="", |
| help="Masks dir for benchmark (overrides cfg.data.test_masks if set; used with --metrics)", |
| ) |
| parser.add_argument( |
| "--bench_limit", |
| type=int, |
| default=0, |
| help="Limit number of images for benchmark (0 means all)", |
| ) |
| parser.add_argument( |
| "--bench_warmup", |
| type=int, |
| default=2, |
| help="Number of warmup images (excluded from stats)", |
| ) |
| parser.add_argument( |
| "--bench_size_filter", |
| type=str, |
| default="", |
| help="Only benchmark images matching HxW, e.g. 3000x4000", |
| ) |
| parser.add_argument( |
| "--bench_report_json", |
| type=str, |
| default="", |
| help="Optional path to save JSON report of timings", |
| ) |
|
|
| args = parser.parse_args() |
|
|
| cfg_path = args.config |
| if not Path(cfg_path).is_absolute(): |
| cfg_path = str(Path.cwd() / cfg_path) |
|
|
| with open(cfg_path, "r") as f: |
| cfg = yaml.safe_load(f) |
|
|
| print("[WireSegHR][infer] Loaded config from:", cfg_path) |
| pprint.pprint(cfg) |
|
|
| |
| if not args.benchmark: |
| assert (args.image is not None) ^ (args.images_dir is not None), ( |
| "Provide exactly one of --image or --images_dir" |
| ) |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| precision = str(cfg["optim"].get("precision", "fp32")).lower() |
| assert precision in ("fp32", "fp16", "bf16") |
| amp_enabled = (device.type == "cuda") and (precision in ("fp16", "bf16")) |
| amp_dtype = ( |
| torch.float16 |
| if precision == "fp16" |
| else (torch.bfloat16 if precision == "bf16" else None) |
| ) |
|
|
| model = _build_model_from_cfg(cfg, device) |
|
|
| ckpt_path = args.ckpt if args.ckpt else cfg.get("resume", "") |
| if ckpt_path: |
| assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}" |
| print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}") |
| suffix = Path(ckpt_path).suffix.lower() |
| if suffix == ".safetensors": |
| |
| state_dict = safe_load_file(ckpt_path) |
| model.load_state_dict(state_dict) |
| else: |
| print( |
| "[WireSegHR][infer][WARN] Loading a PyTorch checkpoint. Prefer .safetensors for inference-only weights." |
| ) |
| |
| state = torch.load(ckpt_path, map_location=device) |
| assert "model" in state, ( |
| "Expected a dict with key 'model' for PyTorch checkpoint. " |
| "Use scripts/strip_checkpoint.py or provide a .safetensors file." |
| ) |
| model.load_state_dict(state["model"]) |
| model.eval() |
|
|
| |
| if args.benchmark: |
| |
| bench_dir = args.bench_images_dir or cfg["data"]["test_images"] |
| assert Path(bench_dir).is_dir(), f"Not a directory: {bench_dir}" |
| if args.metrics: |
| bench_masks_dir = args.bench_masks_dir or cfg["data"]["test_masks"] |
| assert Path(bench_masks_dir).is_dir(), f"Not a directory: {bench_masks_dir}" |
|
|
| |
| size_filter: Optional[Tuple[int, int]] = None |
| if args.bench_size_filter: |
| try: |
| h_str, w_str = args.bench_size_filter.lower().split("x") |
| size_filter = (int(h_str), int(w_str)) |
| except Exception: |
| raise AssertionError( |
| f"Invalid --bench_size_filter format: {args.bench_size_filter} (use HxW)" |
| ) |
|
|
| |
| img_files = sorted( |
| [ |
| str(Path(bench_dir) / p) |
| for p in os.listdir(bench_dir) |
| if p.lower().endswith((".jpg", ".jpeg")) |
| ] |
| ) |
| assert len(img_files) > 0, f"No .jpg/.jpeg in {bench_dir}" |
|
|
| |
| if size_filter is not None: |
| filt_files: List[str] = [] |
| for p in img_files: |
| bgr = cv2.imread(p, cv2.IMREAD_COLOR) |
| assert bgr is not None, f"Failed to read {p}" |
| if bgr.shape[0] == size_filter[0] and bgr.shape[1] == size_filter[1]: |
| filt_files.append(p) |
| img_files = filt_files |
| assert len(img_files) > 0, ( |
| f"No images matching {size_filter[0]}x{size_filter[1]} in {bench_dir}" |
| ) |
|
|
| |
| if args.bench_limit > 0: |
| img_files = img_files[: args.bench_limit] |
|
|
| print(f"[WireSegHR][bench] Images: {len(img_files)} from {bench_dir}") |
| print(f"[WireSegHR][bench] Warmup: {args.bench_warmup}") |
|
|
| def _sync(): |
| if device.type == "cuda": |
| torch.cuda.synchronize() |
|
|
| timings: List[Dict[str, Any]] = [] |
| |
| if args.metrics: |
| fine_sum: Dict[str, float] = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} |
| coarse_sum: Dict[str, float] = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} |
| n_metrics = 0 |
|
|
| |
| for i in tqdm(range(min(args.bench_warmup, len(img_files))), desc="[bench] Warmup"): |
| _ = infer_image( |
| model, |
| img_files[i], |
| cfg, |
| device, |
| amp_enabled, |
| amp_dtype, |
| out_dir=None, |
| save_prob=False, |
| ) |
|
|
| |
| for p in tqdm(img_files[args.bench_warmup :], desc="[bench] Timed"): |
| |
| bgr = cv2.imread(p, cv2.IMREAD_COLOR) |
| assert bgr is not None, f"Failed to read {p}" |
| rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
|
|
| coarse_size = int(cfg["coarse"]["test_size"]) |
| minmax_enable = bool(cfg["minmax"]["enable"]) |
| minmax_kernel = int(cfg["minmax"]["kernel"]) |
|
|
| _sync(); t0 = time.perf_counter() |
| prob_c, cond_map, t_img, y_min_full, y_max_full = _coarse_forward( |
| model, |
| rgb, |
| coarse_size, |
| minmax_enable, |
| minmax_kernel, |
| device, |
| amp_enabled, |
| amp_dtype, |
| ) |
| _sync(); t1 = time.perf_counter() |
|
|
| patch_size = int(cfg["inference"]["fine_patch_size"]) |
| overlap = int(cfg["fine"]["overlap"]) |
|
|
| prob_f = _tiled_fine_forward( |
| model, |
| t_img, |
| cond_map, |
| y_min_full, |
| y_max_full, |
| patch_size, |
| overlap, |
| int(cfg.get("eval", {}).get("fine_batch", 16)), |
| device, |
| amp_enabled, |
| amp_dtype, |
| ) |
| _sync(); t2 = time.perf_counter() |
|
|
| |
| if args.metrics: |
| stem = Path(p).stem |
| gt_path = Path(bench_masks_dir) / f"{stem}.png" |
| assert gt_path.is_file(), f"Missing mask for {stem}: {gt_path}" |
| gt = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE) |
| assert gt is not None, f"Failed to read mask: {gt_path}" |
| gt_bin = (gt > 0).astype(np.uint8) |
| prob_thresh = float(cfg["inference"]["prob_threshold"]) |
| pred_coarse = (prob_c > prob_thresh).to(torch.uint8).cpu().numpy() |
| pred_fine = (prob_f > prob_thresh).to(torch.uint8).cpu().numpy() |
| m_c = compute_metrics(pred_coarse, gt_bin) |
| m_f = compute_metrics(pred_fine, gt_bin) |
| for k in coarse_sum: |
| coarse_sum[k] += m_c[k] |
| for k in fine_sum: |
| fine_sum[k] += m_f[k] |
| n_metrics += 1 |
|
|
| timings.append( |
| { |
| "path": p, |
| "H": int(t_img.shape[2]), |
| "W": int(t_img.shape[3]), |
| "t_coarse_ms": (t1 - t0) * 1000.0, |
| "t_fine_ms": (t2 - t1) * 1000.0, |
| "t_total_ms": (t2 - t0) * 1000.0, |
| } |
| ) |
|
|
| if len(timings) == 0: |
| print("[WireSegHR][bench] Nothing to benchmark after warmup.") |
| return |
|
|
| |
| def _agg(key: str) -> Tuple[float, float, float]: |
| vals = sorted([t[key] for t in timings]) |
| n = len(vals) |
| p50 = vals[n // 2] |
| p95 = vals[min(n - 1, int(0.95 * (n - 1)))] |
| avg = sum(vals) / n |
| return avg, p50, p95 |
|
|
| avg_c, p50_c, p95_c = _agg("t_coarse_ms") |
| avg_f, p50_f, p95_f = _agg("t_fine_ms") |
| avg_t, p50_t, p95_t = _agg("t_total_ms") |
|
|
| print("[WireSegHR][bench] Results (ms):") |
| print(f" Coarse avg={avg_c:.2f} p50={p50_c:.2f} p95={p95_c:.2f}") |
| print(f" Fine avg={avg_f:.2f} p50={p50_f:.2f} p95={p95_f:.2f}") |
| print(f" Total avg={avg_t:.2f} p50={p50_t:.2f} p95={p95_t:.2f}") |
| print(f" Target < 1000 ms per 3000x4000 image: {'YES' if p50_t < 1000.0 else 'NO'}") |
|
|
| if args.bench_report_json: |
| import json |
|
|
| report = { |
| "summary": { |
| "avg_ms": avg_t, |
| "p50_ms": p50_t, |
| "p95_ms": p95_t, |
| "avg_coarse_ms": avg_c, |
| "avg_fine_ms": avg_f, |
| "images": len(timings), |
| }, |
| "per_image": timings, |
| } |
| report_path = args.bench_report_json |
| Path(report_path).parent.mkdir(parents=True, exist_ok=True) |
| with open(report_path, "w") as f: |
| json.dump(report, f, indent=2) |
|
|
| |
| if args.metrics: |
| if n_metrics > 0: |
| for k in fine_sum: |
| fine_sum[k] /= n_metrics |
| for k in coarse_sum: |
| coarse_sum[k] /= n_metrics |
| print( |
| f"[WireSegHR][bench][Fine] IoU={fine_sum['iou']:.4f} F1={fine_sum['f1']:.4f} P={fine_sum['precision']:.4f} R={fine_sum['recall']:.4f}" |
| ) |
| print( |
| f"[WireSegHR][bench][Coarse] IoU={coarse_sum['iou']:.4f} F1={coarse_sum['f1']:.4f} P={coarse_sum['precision']:.4f} R={coarse_sum['recall']:.4f}" |
| ) |
|
|
| return |
|
|
| |
| if args.image is not None: |
| pred, _ = infer_image( |
| model, |
| args.image, |
| cfg, |
| device, |
| amp_enabled, |
| amp_dtype, |
| out_dir=args.out, |
| save_prob=args.save_prob, |
| ) |
| if args.metrics: |
| assert args.mask, "--mask is required with --image when --metrics is set" |
| assert Path(args.mask).is_file(), f"Mask not found: {args.mask}" |
| gt = cv2.imread(args.mask, cv2.IMREAD_GRAYSCALE) |
| assert gt is not None, f"Failed to read mask: {args.mask}" |
| gt_bin = (gt > 0).astype(np.uint8) |
| pred_bin = (pred > 0).astype(np.uint8) |
| m = compute_metrics(pred_bin, gt_bin) |
| print( |
| f"[Infer] IoU={m['iou']:.4f} F1={m['f1']:.4f} P={m['precision']:.4f} R={m['recall']:.4f}" |
| ) |
| print("[WireSegHR][infer] Done.") |
| return |
|
|
| |
| img_dir = args.images_dir |
| assert Path(img_dir).is_dir(), f"Not a directory: {img_dir}" |
| img_files = sorted( |
| [p for p in os.listdir(img_dir) if p.lower().endswith((".jpg", ".jpeg"))] |
| ) |
| assert len(img_files) > 0, f"No .jpg/.jpeg in {img_dir}" |
| os.makedirs(args.out, exist_ok=True) |
| if args.metrics: |
| assert args.masks_dir, ( |
| "--masks_dir is required with --images_dir when --metrics is set" |
| ) |
| assert Path(args.masks_dir).is_dir(), f"Not a directory: {args.masks_dir}" |
| metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0} |
| n_eval = 0 |
| for name in tqdm(img_files, desc="[infer] Dir"): |
| path = str(Path(img_dir) / name) |
| pred, _ = infer_image( |
| model, |
| path, |
| cfg, |
| device, |
| amp_enabled, |
| amp_dtype, |
| out_dir=args.out, |
| save_prob=args.save_prob, |
| ) |
| if args.metrics: |
| stem = Path(name).stem |
| mask_path = Path(args.masks_dir) / f"{stem}.png" |
| assert mask_path.is_file(), f"Missing mask for {stem}: {mask_path}" |
| gt = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) |
| assert gt is not None, f"Failed to read mask: {mask_path}" |
| gt_bin = (gt > 0).astype(np.uint8) |
| pred_bin = (pred > 0).astype(np.uint8) |
| m = compute_metrics(pred_bin, gt_bin) |
| for k in metrics_sum: |
| metrics_sum[k] += m[k] |
| n_eval += 1 |
| if args.metrics and n_eval > 0: |
| for k in metrics_sum: |
| metrics_sum[k] /= n_eval |
| print( |
| f"[Infer][Avg over {n_eval}] IoU={metrics_sum['iou']:.4f} F1={metrics_sum['f1']:.4f} P={metrics_sum['precision']:.4f} R={metrics_sum['recall']:.4f}" |
| ) |
| print("[WireSegHR][infer] Done.") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|