"""
Anchor generation and matching for SCRFD.

SCRFD uses 3-level anchors:
    stride 8:  anchor sizes [16, 32]
    stride 16: anchor sizes [64, 128]
    stride 32: anchor sizes [256, 512]

Matching: ATSS (Adaptive Training Sample Selection) from paper
"Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection" (Zhang et al., 2019)
"""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

# Immutable defaults; AnchorGenerator copies them into per-instance lists so
# that no two instances ever share (and accidentally mutate) the same object.
_DEFAULT_STRIDES = (8, 16, 32)
_DEFAULT_ANCHOR_SIZES = ((16, 32), (64, 128), (256, 512))
_DEFAULT_RATIOS = (1.0,)


class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    With the default configuration: 2 anchor sizes per location x 3 levels,
    aspect ratio 1.0 (square anchors).

    Args:
        strides: Pixel stride of each feature level. Defaults to [8, 16, 32].
        anchor_sizes: Anchor side lengths for each level. Defaults to
            [[16, 32], [64, 128], [256, 512]].
        ratios: Aspect ratios applied to every size. Defaults to [1.0].

    Attributes:
        num_anchors_per_level: For each level, the number of anchors placed
            at EACH grid location (len(sizes) * len(ratios)). Despite the
            name, this is not the total anchor count of the level.
    """

    def __init__(self,
                 strides: Optional[List[int]] = None,
                 anchor_sizes: Optional[List[List[int]]] = None,
                 ratios: Optional[List[float]] = None):
        # Copy the inputs so later mutation of the caller's lists (or of one
        # instance's attributes) cannot leak into other instances.
        self.strides = list(_DEFAULT_STRIDES if strides is None else strides)
        self.anchor_sizes = [
            list(sizes)
            for sizes in (_DEFAULT_ANCHOR_SIZES if anchor_sizes is None
                          else anchor_sizes)
        ]
        self.ratios = list(_DEFAULT_RATIOS if ratios is None else ratios)
        self.num_anchors_per_level = [
            len(sizes) * len(self.ratios) for sizes in self.anchor_sizes
        ]

    def grid_anchors(self,
                     feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level, in the same order as
                ``self.strides``.
            device: target device

        Returns:
            List of [H * W * K, 4] tensors in (x1, y1, x2, y2) format, one
            per level, where K is the number of base anchors of that level.
            Rows are ordered location-major (row by row over the grid), with
            the K base anchors of one location stored consecutively.
        """
        all_anchors = []
        for level, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[level]
            sizes = self.anchor_sizes[level]

            # Grid cell centers in input-image pixels.
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            flat_x = shift_x.reshape(-1)
            flat_y = shift_y.reshape(-1)
            shifts = torch.stack([flat_x, flat_y, flat_x, flat_y], dim=1)

            # Base anchors for this level, centered on the origin.
            base_anchors = []
            for size in sizes:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w / 2, -h / 2, w / 2, h / 2])
            base_anchors = torch.tensor(base_anchors, device=device,
                                        dtype=torch.float32)

            # Broadcast: shifts [N, 1, 4] + base_anchors [1, K, 4] -> [N*K, 4]
            anchors = (shifts.unsqueeze(1)
                       + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)

        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        """Return, for each level, the number of anchors per grid location."""
        return self.num_anchors_per_level


class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: For each GT, select top-k closest anchors from each pyramid
    level, compute their IoU with the GT, and use mean + std as the adaptive
    IoU threshold. Only candidate anchors with IoU >= threshold AND whose
    center is inside the GT box are positive.

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self,
               anchors: torch.Tensor,
               gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor,
               num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated level by level,
                (x1, y1, x2, y2).
            gt_boxes: [M, 4] ground truth boxes, (x1, y1, x2, y2).
            gt_labels: [M] ground truth labels.
            num_anchors_per_level: TOTAL number of anchors in each feature
                level (i.e. the row count of each per-level anchor tensor),
                in concatenation order. Must sum to N. Note this is not the
                per-location count stored on
                ``AnchorGenerator.num_anchors_per_level``.

        Returns:
            assigned_labels: [N] long tensor; 0 for background, otherwise the
                label of the assigned GT.
            assigned_gt_inds: [N] long tensor; index of the assigned GT,
                -1 for negatives.

        Raises:
            ValueError: if ``num_anchors_per_level`` does not sum to N.
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]
        device = anchors.device

        # The per-level slicing below silently mis-selects candidates when
        # the level layout does not match the anchor tensor, so fail loudly.
        total = int(sum(num_anchors_per_level))
        if total != num_anchors:
            raise ValueError(
                f"num_anchors_per_level sums to {total}, but anchors has "
                f"{num_anchors} rows; pass the total anchor count of each "
                f"level, not the per-location count."
            )

        assigned_labels = torch.zeros(num_anchors, dtype=torch.long,
                                      device=device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long,
                                      device=device)
        if num_gts == 0:
            return assigned_labels, assigned_gt_inds

        # Anchor centers
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)  # [N, 2]

        # GT centers
        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
        gt_centers = torch.stack([gt_cx, gt_cy], dim=1)  # [M, 2]

        # Distance from each anchor center to each GT center
        distances = torch.cdist(anchor_centers, gt_centers)  # [N, M]

        # IoU between anchors and GTs
        ious = self._compute_iou(anchors, gt_boxes)  # [N, M]

        # Best IoU assigned so far per anchor. Uses the IoU dtype so the
        # masked write below also works for non-float32 inputs.
        assigned_ious = torch.zeros(num_anchors, dtype=ious.dtype,
                                    device=device)

        for gt_idx in range(num_gts):
            gt_ious = ious[:, gt_idx]  # [N]
            candidate_mask = self._select_candidates(
                distances[:, gt_idx], num_anchors_per_level)

            candidate_ious = gt_ious[candidate_mask]
            if candidate_ious.numel() == 0:
                continue

            # Adaptive threshold = mean + std of the candidate IoUs. The
            # (unbiased) std of a single sample is NaN, which would reject
            # every anchor, so treat the spread as zero in that case.
            iou_mean = candidate_ious.mean()
            if candidate_ious.numel() > 1:
                iou_std = candidate_ious.std()
            else:
                iou_std = torch.zeros_like(iou_mean)
            iou_threshold = iou_mean + iou_std

            # Filter: IoU >= threshold AND anchor center inside the GT box
            x1, y1, x2, y2 = gt_boxes[gt_idx]
            is_positive = (
                candidate_mask
                & (gt_ious >= iou_threshold)
                & (anchor_cx >= x1)
                & (anchor_cy >= y1)
                & (anchor_cx <= x2)
                & (anchor_cy <= y2)
            )

            # Assign (higher IoU wins if conflict; on a tie the earlier GT
            # keeps the anchor because the comparison is strict).
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]

        return assigned_labels, assigned_gt_inds

    def _select_candidates(self,
                           gt_dists: torch.Tensor,
                           num_anchors_per_level: List[int]) -> torch.Tensor:
        """Return an [N] bool mask of the top-k closest anchors per level.

        ``gt_dists`` holds the center distance of every anchor to one GT;
        levels are consecutive slices whose lengths are given by
        ``num_anchors_per_level``.
        """
        candidate_mask = torch.zeros(gt_dists.shape[0], dtype=torch.bool,
                                     device=gt_dists.device)
        start = 0
        for num_per_level in num_anchors_per_level:
            end = start + num_per_level
            k = min(self.topk, num_per_level)
            _, topk_inds = gt_dists[start:end].topk(k, largest=False)
            candidate_mask[start + topk_inds] = True
            start = end
        return candidate_mask

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor,
                     boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] x [M,4] -> [N,M]

        Boxes are (x1, y1, x2, y2). A small epsilon (1e-6) is added to the
        union to avoid division by zero.
        """
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))

        inter = ((inter_x2 - inter_x1).clamp(min=0)
                 * (inter_y2 - inter_y1).clamp(min=0))
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter

        return inter / (union + 1e-6)