"""Evaluate a point-detection model against ground-truth point annotations.

For every annotated image the script predicts points, computes count errors
(MAE/MSE) and radius-based precision/recall, and writes a CSV, a JSON summary
and per-image visualisation overlays to ``--output_dir``.
"""
import argparse
import contextlib
import csv
import json
import os
import warnings

import cv2
import numpy as np
import torch
import torchvision.transforms as standard_transforms
from PIL import Image
from scipy.optimize import linear_sum_assignment
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist

from models import build_model


class Args:
    """Minimal stand-in for the argparse namespace that ``build_model`` reads.

    ``row`` and ``line`` are forwarded as-is; presumably they define the anchor
    grid layout -- TODO confirm against ``models.build_model``.
    """

    backbone = "vgg16_bn"
    row = 2
    line = 2


def load_model(weight_path):
    """Build the model, load weights if present, and return inference helpers.

    Args:
        weight_path: Path to a checkpoint containing a ``"model"`` state dict.

    Returns:
        ``(model, device, transform)`` -- the model in eval mode on ``device``,
        and the ImageNet-normalising transform applied to each patch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(Args()).to(device).eval()
    if os.path.exists(weight_path):
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(weight_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
    else:
        # Previously this fell through silently and evaluated random weights.
        warnings.warn(
            f"weights not found at {weight_path!r}; evaluating an untrained model",
            RuntimeWarning,
            stacklevel=2,
        )
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, device, transform


def _round_up(value, multiple):
    """Round ``value`` up to the next multiple of ``multiple``."""
    return ((value + multiple - 1) // multiple) * multiple


def _working_size(orig_w, orig_h, magnification, max_side):
    """Return the magnified image size, capped so its longest side is <= ``max_side``."""
    work_w = max(1, int(orig_w * magnification))
    work_h = max(1, int(orig_h * magnification))
    cap = min(1.0, max_side / float(max(work_w, work_h)))
    return max(1, int(work_w * cap)), max(1, int(work_h * cap))


def _suppress_duplicates(points, scores, radius):
    """Greedy NMS: keep the most confident point, drop neighbours within ``radius``.

    Processing order is by descending score (ties broken by index), so the result
    is deterministic. Survivors are returned in their original index order as
    lists of Python floats.
    """
    if len(points) == 0:
        return []
    tree = cKDTree(points)
    suppressed = np.zeros(len(points), dtype=bool)
    keep = []
    for idx in np.argsort(-scores, kind="stable"):
        if suppressed[idx]:
            continue
        keep.append(int(idx))
        # The ball includes idx itself; that is harmless since it is already kept.
        suppressed[tree.query_ball_point(points[idx], r=radius)] = True
    return [points[i].tolist() for i in sorted(keep)]


def infer_points(image, model, device, transform, confidence=0.5, magnification=1.5, batch_size=8,
                 nms_radius=8.0):
    """Predict point locations for a PIL image using overlapping tiled inference.

    The image is magnified (longest side capped at 3840 px), zero-padded, and cut
    into 512x512 tiles with 50% overlap. Detections above ``confidence`` are mapped
    back to original-image coordinates. Duplicates from overlapping tiles are
    merged by greedy NMS within ``nms_radius`` pixels (original-image scale).

    Args:
        image: RGB ``PIL.Image``.
        model: Model returning ``pred_logits`` and ``pred_points``; class 1 of the
            softmax is treated as the "point" score.
        device: Torch device the model lives on.
        transform: Per-patch tensor transform (see ``load_model``).
        confidence: Score threshold for keeping a detection.
        magnification: Upscale factor applied before tiling.
        batch_size: Tiles per forward pass (values < 1 are treated as 1).
        nms_radius: Suppression radius for duplicate detections.

    Returns:
        List of ``[x, y]`` float coordinates inside the original image.
    """
    orig_w, orig_h = image.size
    patch_size = 512
    pad = 256
    stride = patch_size // 2
    batch_size = max(1, int(batch_size))

    work_w, work_h = _working_size(orig_w, orig_h, magnification, max_side=3840)
    # Integer truncation makes the per-axis factors differ; using the x factor
    # for both axes (as before) skewed y coordinates.
    scale_x = work_w / float(orig_w)
    scale_y = work_h / float(orig_h)

    resample = getattr(Image, "Resampling", Image).LANCZOS
    resized = image.resize((work_w, work_h), resample)
    padded_w = _round_up(work_w + 2 * pad, patch_size)
    padded_h = _round_up(work_h + 2 * pad, patch_size)
    padded = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    padded.paste(resized, (pad, pad))

    # Padded sizes are multiples of patch_size, so every origin yields a full tile.
    origins = [
        (x, y)
        for y in range(0, padded_h - patch_size + 1, stride)
        for x in range(0, padded_w - patch_size + 1, stride)
    ]

    point_chunks = []
    score_chunks = []
    for start in range(0, len(origins), batch_size):
        batch = origins[start:start + batch_size]
        # Crop lazily per batch instead of holding every tile in memory.
        samples = torch.stack([
            transform(padded.crop((x, y, x + patch_size, y + patch_size))) for x, y in batch
        ]).to(device)
        amp = torch.cuda.amp.autocast() if device.type == "cuda" else contextlib.nullcontext()
        with torch.inference_mode():
            with amp:
                out = model(samples)
            probs = torch.nn.functional.softmax(out["pred_logits"].float(), -1)[:, :, 1]
            coords = out["pred_points"].float()
            for idx, (x, y) in enumerate(batch):
                mask = probs[idx] > confidence
                pts = coords[idx][mask].cpu().numpy()
                conf = probs[idx][mask].cpu().numpy()
                xs = (pts[:, 0] + (x - pad)) / scale_x
                ys = (pts[:, 1] + (y - pad)) / scale_y
                inside = (xs >= 0) & (xs < orig_w) & (ys >= 0) & (ys < orig_h)
                point_chunks.append(np.stack([xs, ys], axis=1)[inside])
                score_chunks.append(conf[inside])

    if not point_chunks:
        return []
    points = np.concatenate(point_chunks).astype(np.float32, copy=False)
    scores = np.concatenate(score_chunks)
    return _suppress_duplicates(points, scores, nms_radius)


def load_gt(path):
    """Load ground-truth annotations as a list of ``{"image", "points"}`` dicts.

    Accepts a top-level list, an ``{"annotations": ...}`` wrapper, or a mapping
    of image name -> points (converted to the list form).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "annotations" in data:
        data = data["annotations"]
    if isinstance(data, dict):
        return [{"image": image, "points": points} for image, points in data.items()]
    return data


def precision_recall(pred_points, gt_points, radius):
    """One-to-one match predictions to ground truth within ``radius`` pixels.

    The assignment maximises the number of pairs within ``radius`` (then minimises
    their total distance). Plain Hungarian matching on raw distances can pick a
    globally cheaper assignment that contains fewer in-radius pairs.

    Returns:
        ``(precision, recall, matches, false_positives, false_negatives)``.
        Both empty -> ``(1.0, 1.0, 0, 0, 0)``.
    """
    pred = np.asarray(pred_points, dtype=np.float64)
    gt = np.asarray(gt_points, dtype=np.float64)
    if len(pred) == 0 and len(gt) == 0:
        return 1.0, 1.0, 0, 0, 0
    if len(pred) == 0:
        return 0.0, 0.0, 0, 0, len(gt)
    if len(gt) == 0:
        return 0.0, 0.0, 0, len(pred), 0

    dist = cdist(pred, gt)
    within = dist <= radius
    # The penalty exceeds any achievable sum of in-radius distances, so gaining one
    # more in-radius pair always beats any distance saving.
    penalty = float(radius) * min(len(pred), len(gt)) + 1.0
    rows, cols = linear_sum_assignment(np.where(within, dist, penalty))
    matches = int(np.count_nonzero(within[rows, cols]))

    fp = len(pred) - matches
    fn = len(gt) - matches
    precision = matches / (matches + fp) if matches + fp else 0.0
    recall = matches / (matches + fn) if matches + fn else 0.0
    return precision, recall, matches, fp, fn


def draw_visual(image_path, gt_points, pred_points, output_path):
    """Save ``image_path`` with GT (filled green dots) and predictions (red rings).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"OpenCV could not read image: {image_path}")
    for x, y in gt_points:
        cv2.circle(img, (int(x), int(y)), 4, (0, 255, 0), -1)
    for x, y in pred_points:
        cv2.circle(img, (int(x), int(y)), 3, (0, 0, 255), 1)
    if not cv2.imwrite(output_path, img):
        warnings.warn(f"failed to write visualization: {output_path}", RuntimeWarning, stacklevel=2)


def main():
    """Evaluate every annotated image and write CSV, JSON summary and overlays."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", required=True)
    parser.add_argument("--gt_json", required=True)
    parser.add_argument("--weights", default=os.path.join("weights", "SHTechA.pth"))
    parser.add_argument("--output_dir", default="eval_results")
    parser.add_argument("--confidence", type=float, default=0.5)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    vis_dir = os.path.join(args.output_dir, "visualizations")
    os.makedirs(vis_dir, exist_ok=True)

    model, device, transform = load_model(args.weights)

    rows = []
    errors = []
    for item in load_gt(args.gt_json):
        image_name = item["image"]
        gt_points = item.get("points", [])
        image_path = image_name if os.path.isabs(image_name) else os.path.join(args.images_dir, image_name)
        # The context manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pred_points = infer_points(image, model, device, transform, args.confidence)

        err = abs(len(pred_points) - len(gt_points))
        errors.append(err)
        row = {
            "image": os.path.basename(image_path),
            "gt_count": len(gt_points),
            "pred_count": len(pred_points),
            "abs_error": err,
            "sq_error": err ** 2,
        }
        for radius in (5, 10, 15, 20):
            p, r, m, fp, fn = precision_recall(pred_points, gt_points, radius)
            row[f"precision_{radius}px"] = round(p, 4)
            row[f"recall_{radius}px"] = round(r, 4)
            row[f"matches_{radius}px"] = m
            row[f"fp_{radius}px"] = fp
            row[f"fn_{radius}px"] = fn
        rows.append(row)

        stem = os.path.splitext(os.path.basename(image_path))[0]
        draw_visual(image_path, gt_points, pred_points, os.path.join(vis_dir, stem + "_eval.png"))

    summary = {
        "mae": round(float(np.mean(errors)), 4) if errors else 0,
        "mse": round(float(np.mean([e ** 2 for e in errors])), 4) if errors else 0,
        "images": len(rows),
    }

    csv_path = os.path.join(args.output_dir, "evaluation.csv")
    json_path = os.path.join(args.output_dir, "evaluation_summary.json")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()) if rows else ["image"])
        writer.writeheader()
        writer.writerows(rows)
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({"summary": summary, "rows": rows}, f, indent=2)
    print(json.dumps({"csv": csv_path, "json": json_path, "visualizations": vis_dir, "summary": summary}, indent=2))


if __name__ == "__main__":
    main()