| import argparse |
| import csv |
| import json |
| import os |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torchvision.transforms as standard_transforms |
| from PIL import Image |
| from scipy.optimize import linear_sum_assignment |
| from scipy.spatial import cKDTree |
|
|
| from models import build_model |
|
|
|
|
| class Args: |
| backbone = "vgg16_bn" |
| row = 2 |
| line = 2 |
|
|
|
|
| def load_model(weight_path): |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model = build_model(Args()).to(device).eval() |
| if os.path.exists(weight_path): |
| checkpoint = torch.load(weight_path, map_location=device) |
| model.load_state_dict(checkpoint["model"]) |
| transform = standard_transforms.Compose([ |
| standard_transforms.ToTensor(), |
| standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]), |
| ]) |
| return model, device, transform |
|
|
|
|
| def infer_points(image, model, device, transform, confidence=0.5, magnification=1.5, batch_size=8): |
| orig_w, orig_h = image.size |
| patch_size = 512 |
| pad = 256 |
| work_w, work_h = int(orig_w * magnification), int(orig_h * magnification) |
| scale = min(1.0, 3840 / float(max(work_w, work_h))) |
| work_w, work_h = int(work_w * scale), int(work_h * scale) |
| magnification = work_w / float(orig_w) |
| resample_filter = getattr(Image, "Resampling", Image).LANCZOS if hasattr(Image, "Resampling") else getattr(Image, "ANTIALIAS", 1) |
| image = image.resize((work_w, work_h), resample_filter) |
| padded_w = ((work_w + pad * 2 + patch_size - 1) // patch_size) * patch_size |
| padded_h = ((work_h + pad * 2 + patch_size - 1) // patch_size) * patch_size |
| padded = Image.new("RGB", (padded_w, padded_h), (0, 0, 0)) |
| padded.paste(image, (pad, pad)) |
| stride = patch_size // 2 |
| jobs = [] |
| for y in range(0, padded_h - stride + 1, stride): |
| for x in range(0, padded_w - stride + 1, stride): |
| if x + patch_size <= padded_w and y + patch_size <= padded_h: |
| jobs.append((x, y, padded.crop((x, y, x + patch_size, y + patch_size)))) |
| all_points = [] |
| for start in range(0, len(jobs), batch_size): |
| batch = jobs[start:start + batch_size] |
| samples = torch.stack([transform(patch) for _, _, patch in batch]).to(device) |
| with torch.inference_mode(): |
| if device.type == "cuda": |
| with torch.cuda.amp.autocast(): |
| out = model(samples) |
| else: |
| out = model(samples) |
| scores = torch.nn.functional.softmax(out["pred_logits"].float(), -1)[:, :, 1] |
| pred = out["pred_points"].float() |
| for idx, (x, y, _) in enumerate(batch): |
| pts = pred[idx][scores[idx] > confidence].detach().cpu().numpy() |
| if len(pts): |
| pts[:, 0] += x - pad |
| pts[:, 1] += y - pad |
| pts /= float(magnification) |
| all_points.extend([p.tolist() for p in pts if 0 <= p[0] < orig_w and 0 <= p[1] < orig_h]) |
| if not all_points: |
| return [] |
| pts = np.array(all_points, dtype=np.float32) |
| tree = cKDTree(pts) |
| suppressed = set() |
| for i, j in tree.query_pairs(r=8.0): |
| if i not in suppressed and j not in suppressed: |
| suppressed.add(j) |
| return [pts[i].tolist() for i in range(len(pts)) if i not in suppressed] |
|
|
|
|
| def load_gt(path): |
| with open(path, "r", encoding="utf-8") as f: |
| data = json.load(f) |
| if isinstance(data, dict) and "annotations" in data: |
| data = data["annotations"] |
| if isinstance(data, dict): |
| return [{"image": image, "points": points} for image, points in data.items()] |
| return data |
|
|
|
|
| def precision_recall(pred_points, gt_points, radius): |
| pred = np.array(pred_points, dtype=np.float32) |
| gt = np.array(gt_points, dtype=np.float32) |
| if len(pred) == 0 and len(gt) == 0: |
| return 1.0, 1.0, 0, 0, 0 |
| if len(pred) == 0: |
| return 0.0, 0.0, 0, 0, len(gt) |
| if len(gt) == 0: |
| return 0.0, 0.0, 0, len(pred), 0 |
| dist = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=2) |
| rows, cols = linear_sum_assignment(dist) |
| matches = sum(1 for r, c in zip(rows, cols) if dist[r, c] <= radius) |
| fp = len(pred) - matches |
| fn = len(gt) - matches |
| precision = matches / (matches + fp) if matches + fp else 0.0 |
| recall = matches / (matches + fn) if matches + fn else 0.0 |
| return precision, recall, matches, fp, fn |
|
|
|
|
| def draw_visual(image_path, gt_points, pred_points, output_path): |
| img = cv2.imread(image_path) |
| for x, y in gt_points: |
| cv2.circle(img, (int(x), int(y)), 4, (0, 255, 0), -1) |
| for x, y in pred_points: |
| cv2.circle(img, (int(x), int(y)), 3, (0, 0, 255), 1) |
| cv2.imwrite(output_path, img) |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--images_dir", required=True) |
| parser.add_argument("--gt_json", required=True) |
| parser.add_argument("--weights", default=os.path.join("weights", "SHTechA.pth")) |
| parser.add_argument("--output_dir", default="eval_results") |
| parser.add_argument("--confidence", type=float, default=0.5) |
| args = parser.parse_args() |
| os.makedirs(args.output_dir, exist_ok=True) |
| vis_dir = os.path.join(args.output_dir, "visualizations") |
| os.makedirs(vis_dir, exist_ok=True) |
| model, device, transform = load_model(args.weights) |
| rows = [] |
| errors = [] |
| squared_errors = [] |
| for item in load_gt(args.gt_json): |
| image_name = item["image"] |
| gt_points = item.get("points", []) |
| image_path = image_name if os.path.isabs(image_name) else os.path.join(args.images_dir, image_name) |
| pred_points = infer_points(Image.open(image_path).convert("RGB"), model, device, transform, args.confidence) |
| err = abs(len(pred_points) - len(gt_points)) |
| errors.append(err) |
| squared_errors.append(err ** 2) |
| row = {"image": os.path.basename(image_path), "gt_count": len(gt_points), "pred_count": len(pred_points), "abs_error": err, "sq_error": err ** 2} |
| for radius in [5, 10, 15, 20]: |
| p, r, m, fp, fn = precision_recall(pred_points, gt_points, radius) |
| row[f"precision_{radius}px"] = round(p, 4) |
| row[f"recall_{radius}px"] = round(r, 4) |
| row[f"matches_{radius}px"] = m |
| row[f"fp_{radius}px"] = fp |
| row[f"fn_{radius}px"] = fn |
| rows.append(row) |
| draw_visual(image_path, gt_points, pred_points, os.path.join(vis_dir, os.path.splitext(os.path.basename(image_path))[0] + "_eval.png")) |
| summary = {"mae": round(float(np.mean(errors)), 4) if errors else 0, "mse": round(float(np.mean(squared_errors)), 4) if squared_errors else 0, "images": len(rows)} |
| csv_path = os.path.join(args.output_dir, "evaluation.csv") |
| json_path = os.path.join(args.output_dir, "evaluation_summary.json") |
| with open(csv_path, "w", newline="", encoding="utf-8") as f: |
| writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()) if rows else ["image"]) |
| writer.writeheader() |
| writer.writerows(rows) |
| with open(json_path, "w", encoding="utf-8") as f: |
| json.dump({"summary": summary, "rows": rows}, f, indent=2) |
| print(json.dumps({"csv": csv_path, "json": json_path, "visualizations": vis_dir, "summary": summary}, indent=2)) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|