"""
Anchor generation and matching for SCRFD.
SCRFD uses 3-level anchors:
stride 8: anchor sizes [16, 32]
stride 16: anchor sizes [64, 128]
stride 32: anchor sizes [256, 512]
Matching: ATSS (Adaptive Training Sample Selection) from paper
"Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection" (Zhang et al., 2019)
"""
import torch
import torch.nn as nn
import math
from typing import List, Tuple, Optional
class AnchorGenerator:
    """
    Generate anchor boxes on feature-map grids.

    For SCRFD: 2 anchor sizes per location on each of 3 pyramid levels.
    Aspect ratio = 1.0 (square anchors work best for faces).

    Args:
        strides: Stride (in input pixels) of each feature level.
        anchor_sizes: Anchor sizes for each level; indexed in parallel
            with ``strides``.
        ratios: Width/height ratios applied to every anchor size.
    """

    def __init__(self,
                 strides: List[int] = [8, 16, 32],
                 anchor_sizes: List[List[int]] = [[16, 32], [64, 128], [256, 512]],
                 ratios: List[float] = [1.0]):
        # Copy the arguments instead of storing them by reference: the
        # default lists are created once at definition time, so keeping a
        # reference would let a mutation on one instance leak into the
        # defaults seen by every later instance.
        self.strides = list(strides)
        self.anchor_sizes = [list(sizes) for sizes in anchor_sizes]
        self.ratios = list(ratios)
        # Anchors per grid location on each level (sizes x ratios).
        self.num_anchors_per_level = [
            len(sizes) * len(self.ratios) for sizes in self.anchor_sizes
        ]

    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Anchors are centred on cell centres, i.e. ``(index + 0.5) * stride``.
        Within a level, locations are ordered row by row (y outer, x inner)
        and the anchors belonging to one location are contiguous.

        Args:
            feat_sizes: [(H, W)] for each level, in the same order as
                ``strides``
            device: target device

        Returns:
            List of [H * W * anchors_per_location, 4] tensors in
            (x1, y1, x2, y2) format, one per entry of ``feat_sizes``
        """
        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]
            sizes = self.anchor_sizes[i]

            # Grid centers
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            flat_x = shift_x.reshape(-1)
            flat_y = shift_y.reshape(-1)
            # Each centre is repeated so it can be added to (x1, y1, x2, y2).
            shifts = torch.stack([flat_x, flat_y, flat_x, flat_y], dim=1)

            # Base anchors for this level, centred on the origin
            base_anchors = []
            for size in sizes:
                for ratio in self.ratios:
                    # w / h == ratio while w * h == size ** 2
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w / 2, -h / 2, w / 2, h / 2])
            base_anchors = torch.tensor(base_anchors, device=device, dtype=torch.float32)

            # Broadcast: shifts [N, 1, 4] + base_anchors [1, K, 4] -> [N*K, 4]
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)
        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        """Return the number of anchors per grid location for each level."""
        return self.num_anchors_per_level
class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: For each GT, select top-k closest anchors from each pyramid level,
    compute their IoU with the GT, and use mean + std as the adaptive IoU threshold.
    Only candidate anchors with IoU >= threshold AND whose center is inside the
    GT are positive.

    SCRFD uses ATSS because it adapts to face scale automatically —
    tiny faces get lower thresholds (more positives), large faces get higher ones.

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated, (x1, y1, x2, y2)
            gt_boxes: [M, 4] ground truth boxes, (x1, y1, x2, y2)
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: TOTAL number of anchors on each feature
                level, in the order the levels are concatenated in
                ``anchors``. This is used to slice ``anchors`` per level, so
                it is not the per-location anchor count.

        Returns:
            assigned_labels: [N] (0 = background, otherwise the GT label)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]
        device = anchors.device

        # Everything starts as background / unassigned.
        assigned_labels = torch.zeros(num_anchors, dtype=torch.long, device=device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long, device=device)
        if num_gts == 0:
            return assigned_labels, assigned_gt_inds

        # Anchor centers
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)  # [N, 2]

        # GT centers
        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
        gt_centers = torch.stack([gt_cx, gt_cy], dim=1)  # [M, 2]

        # Distance from each anchor center to each GT center
        distances = torch.cdist(anchor_centers, gt_centers)  # [N, M]

        # IoU between anchors and GTs
        ious = self._compute_iou(anchors, gt_boxes)  # [N, M]

        # Best IoU seen so far per anchor. It must share the dtype of `ious`:
        # masked assignment between tensors of different dtypes (e.g. float64
        # or half inputs into a float32 buffer) is rejected by PyTorch.
        assigned_ious = torch.zeros(num_anchors, dtype=ious.dtype, device=device)

        # Process each GT
        for gt_idx in range(num_gts):
            gt_dists = distances[:, gt_idx]  # [N]
            gt_ious = ious[:, gt_idx]        # [N]

            # Select top-k closest anchors per level
            candidate_mask = torch.zeros(num_anchors, dtype=torch.bool, device=device)
            start = 0
            for num_per_level in num_anchors_per_level:
                end = start + num_per_level
                level_dists = gt_dists[start:end]
                k = min(self.topk, num_per_level)
                _, topk_inds = level_dists.topk(k, largest=False)
                candidate_mask[start + topk_inds] = True
                start = end

            # Compute adaptive threshold
            candidate_ious = gt_ious[candidate_mask]
            iou_mean = candidate_ious.mean()
            if candidate_ious.numel() > 1:
                iou_std = candidate_ious.std()
            else:
                # The sample std of a single value is NaN, which would turn
                # the threshold into NaN and silently drop this GT.
                iou_std = candidate_ious.new_zeros(())
            iou_threshold = iou_mean + iou_std

            # Filter: candidate, IoU >= threshold AND center inside GT box
            x1, y1, x2, y2 = gt_boxes[gt_idx]
            is_positive = (
                candidate_mask &
                (gt_ious >= iou_threshold) &
                (anchor_cx >= x1) &
                (anchor_cy >= y1) &
                (anchor_cx <= x2) &
                (anchor_cy <= y2)
            )

            # Assign (higher IoU wins if an anchor is positive for several GTs)
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]

        return assigned_labels, assigned_gt_inds

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] × [M,4] → [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # Corners of the pairwise intersection rectangles, each [N, M]
        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))

        # Non-overlapping pairs yield a negative extent; clamp it to zero.
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter
        # Epsilon avoids 0/0 for degenerate (zero-area) box pairs.
        return inter / (union + 1e-6)
|