| """ |
| Anchor generation and matching for SCRFD. |
| |
| SCRFD uses 3-level anchors: |
| stride 8: anchor sizes [16, 32] |
| stride 16: anchor sizes [64, 128] |
| stride 32: anchor sizes [256, 512] |
| |
| Matching: ATSS (Adaptive Training Sample Selection) from paper |
| "Bridging the Gap Between Anchor-based and Anchor-free Detection via |
| Adaptive Training Sample Selection" (Zhang et al., 2019) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import math |
| from typing import List, Tuple, Optional |
|
|
|
|
class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    For SCRFD: 2 anchors per location × 3 levels = 6 anchor configs.
    Aspect ratio = 1.0 (square anchors work best for faces).

    Args:
        strides: Downsampling stride of each feature level (default [8, 16, 32]).
        anchor_sizes: Anchor side lengths per level, one list per level
            (default [[16, 32], [64, 128], [256, 512]]).
        ratios: Aspect ratios; w = size*sqrt(r), h = size/sqrt(r) (default [1.0]).
    """

    def __init__(self,
                 strides: Optional[List[int]] = None,
                 anchor_sizes: Optional[List[List[int]]] = None,
                 ratios: Optional[List[float]] = None):
        # None-sentinels avoid the mutable-default-argument pitfall while
        # keeping the exact same effective defaults.
        self.strides = [8, 16, 32] if strides is None else strides
        self.anchor_sizes = ([[16, 32], [64, 128], [256, 512]]
                             if anchor_sizes is None else anchor_sizes)
        self.ratios = [1.0] if ratios is None else ratios
        self.num_anchors_per_level = [len(sizes) * len(self.ratios)
                                      for sizes in self.anchor_sizes]

        # Base anchors (centered at the origin) depend only on sizes/ratios,
        # not on the grid or device — precompute once instead of rebuilding
        # them on every grid_anchors() call.
        self._base_anchors: List[torch.Tensor] = []
        for sizes in self.anchor_sizes:
            bases = []
            for size in sizes:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    bases.append([-w / 2, -h / 2, w / 2, h / 2])
            self._base_anchors.append(torch.tensor(bases, dtype=torch.float32))

    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level
            device: target device

        Returns:
            List of [H*W*num_base, 4] tensors in (x1, y1, x2, y2) format,
            ordered row-major over grid locations, then base anchors.
        """
        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]

            # Grid cell centers in image coordinates (+0.5 centers the cell).
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                                  shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)

            base_anchors = self._base_anchors[i].to(device)

            # Broadcast [num_locs, 1, 4] + [1, num_bases, 4] and flatten.
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)

        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        """Number of anchors per grid location for each feature level."""
        return self.num_anchors_per_level
|
|
|
|
class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: For each GT, select top-k closest anchors (by center distance)
    from each pyramid level, compute their IoU with the GT, and use
    mean + std of those IoUs as the adaptive IoU threshold. Only anchors
    with IoU >= threshold AND whose center is inside the GT box are positive.

    SCRFD uses ATSS because it adapts to face scale automatically —
    tiny faces get lower thresholds (more positives), large faces get higher ones.

    Reference: "Bridging the Gap Between Anchor-based and Anchor-free Detection
    via Adaptive Training Sample Selection" (Zhang et al., 2019).

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated, (x1, y1, x2, y2)
            gt_boxes: [M, 4] ground truth boxes, (x1, y1, x2, y2)
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: number of anchors per feature level
                (must sum to N)

        Returns:
            assigned_labels: [N] (0 = background, 1 = face)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]

        # No GT boxes: everything is background.
        if num_gts == 0:
            return (torch.zeros(num_anchors, dtype=torch.long, device=anchors.device),
                    torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device))

        # Anchor centers — used both for distance ranking and for the
        # center-inside-GT positivity constraint below.
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)

        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2

        # [N, M] center-to-center Euclidean distances.
        distances = torch.cdist(anchor_centers, torch.stack([gt_cx, gt_cy], dim=1))

        # [N, M] pairwise IoU between all anchors and all GTs.
        ious = self._compute_iou(anchors, gt_boxes)

        assigned_labels = torch.zeros(num_anchors, dtype=torch.long, device=anchors.device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device)
        # Best IoU seen per anchor so far: an anchor contested by several GTs
        # ends up assigned to the highest-IoU one.
        assigned_ious = torch.zeros(num_anchors, device=anchors.device)

        for gt_idx in range(num_gts):
            gt_dists = distances[:, gt_idx]
            gt_ious = ious[:, gt_idx]

            # Candidate pool: top-k closest anchors on each pyramid level.
            candidate_mask = torch.zeros(num_anchors, dtype=torch.bool, device=anchors.device)
            start = 0
            for num_per_level in num_anchors_per_level:
                end = start + num_per_level
                level_dists = gt_dists[start:end]
                k = min(self.topk, num_per_level)
                _, topk_inds = level_dists.topk(k, largest=False)
                candidate_mask[start + topk_inds] = True
                start = end

            # Adaptive threshold = mean + std of the candidates' IoUs.
            candidate_ious = gt_ious[candidate_mask]
            iou_mean = candidate_ious.mean()
            # torch's std() is the unbiased estimator and returns NaN for a
            # single sample; a NaN threshold would silently make every anchor
            # negative for this GT. Fall back to zero std in that case.
            if candidate_ious.numel() > 1:
                iou_std = candidate_ious.std()
            else:
                iou_std = candidate_ious.new_zeros(())
            iou_threshold = iou_mean + iou_std

            # Positive = candidate AND IoU above threshold AND center inside GT.
            is_positive = (
                candidate_mask &
                (gt_ious >= iou_threshold) &
                (anchor_cx >= gt_boxes[gt_idx, 0]) &
                (anchor_cy >= gt_boxes[gt_idx, 1]) &
                (anchor_cx <= gt_boxes[gt_idx, 2]) &
                (anchor_cy <= gt_boxes[gt_idx, 3])
            )

            # Resolve multi-GT conflicts in favor of the higher-IoU GT.
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]

        return assigned_labels, assigned_gt_inds

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] × [M,4] → [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # Intersection corners via broadcasting: [N,1] against [1,M].
        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))

        # Clamp handles non-overlapping boxes (negative widths/heights).
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter

        # Epsilon guards against division by zero for degenerate boxes.
        return inter / (union + 1e-6)
|
|