# facedet/models/anchor.py
# (uploaded via huggingface_hub by cledouxluma, commit 20e9cd1)
"""
Anchor generation and matching for SCRFD.
SCRFD uses 3-level anchors:
stride 8: anchor sizes [16, 32]
stride 16: anchor sizes [64, 128]
stride 32: anchor sizes [256, 512]
Matching: ATSS (Adaptive Training Sample Selection) from paper
"Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection" (Zhang et al., 2019)
"""
import torch
import torch.nn as nn
import math
from typing import List, Tuple, Optional
class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    For SCRFD: 2 anchors per location x 3 levels = 6 anchor configs.
    Aspect ratio = 1.0 (square anchors work best for faces).

    Args:
        strides: Downsampling stride of each pyramid level.
            Defaults to [8, 16, 32].
        anchor_sizes: Anchor side lengths per level, one list per stride.
            Defaults to [[16, 32], [64, 128], [256, 512]].
        ratios: Aspect ratios applied to every size. Defaults to [1.0].
    """
    def __init__(self,
                 strides: Optional[List[int]] = None,
                 anchor_sizes: Optional[List[List[int]]] = None,
                 ratios: Optional[List[float]] = None):
        # None sentinels instead of mutable list defaults: a shared default
        # list would be aliased across every AnchorGenerator instance.
        self.strides = [8, 16, 32] if strides is None else strides
        self.anchor_sizes = ([[16, 32], [64, 128], [256, 512]]
                             if anchor_sizes is None else anchor_sizes)
        self.ratios = [1.0] if ratios is None else ratios
        # Anchors per spatial location, per level (sizes x ratios).
        self.num_anchors_per_level = [len(sizes) * len(self.ratios)
                                      for sizes in self.anchor_sizes]
    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level
            device: target device

        Returns:
            List of [num_anchors, 4] tensors in (x1, y1, x2, y2) format,
            one per level, ordered location-major (all anchors of a cell
            are contiguous).
        """
        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]
            sizes = self.anchor_sizes[i]
            # Grid cell centers in input-image coordinates (+0.5 centers
            # the anchor inside its stride x stride cell).
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                                  shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)
            # Base anchors for this level, centered at the origin.
            base_anchors = []
            for size in sizes:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w/2, -h/2, w/2, h/2])
            base_anchors = torch.tensor(base_anchors, device=device, dtype=torch.float32)
            # Broadcast: shifts [N, 1, 4] + base_anchors [1, K, 4] -> [N*K, 4]
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)
        return all_anchors
    def num_anchors_per_loc(self) -> List[int]:
        """Return the number of anchors per spatial location for each level."""
        return self.num_anchors_per_level
class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: for each GT, select the top-k closest anchors (by center
    distance) from each pyramid level, compute their IoU with the GT, and
    use mean + std of those IoUs as the adaptive IoU threshold. Only
    candidates with IoU >= threshold AND whose center lies inside the GT
    box become positive.

    SCRFD uses ATSS because it adapts to face scale automatically —
    tiny faces get lower thresholds (more positives), large faces higher.

    Reference: "Bridging the Gap Between Anchor-based and Anchor-free
    Detection via Adaptive Training Sample Selection" (Zhang et al., 2019).

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """
    def __init__(self, topk: int = 9):
        self.topk = topk
    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated, (x1, y1, x2, y2)
            gt_boxes: [M, 4] ground truth boxes
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: number of anchors per feature level
                (must sum to N, in the same order as `anchors`)

        Returns:
            assigned_labels: [N] (0 = background, 1 = face)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]
        device = anchors.device
        if num_gts == 0:
            # No faces in this image: every anchor is background.
            return (torch.zeros(num_anchors, dtype=torch.long, device=device),
                    torch.full((num_anchors,), -1, dtype=torch.long, device=device))
        # Anchor centers
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)  # [N, 2]
        # GT centers
        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
        # Distance from each anchor center to each GT center
        distances = torch.cdist(anchor_centers, torch.stack([gt_cx, gt_cy], dim=1))  # [N, M]
        # IoU between anchors and GTs
        ious = self._compute_iou(anchors, gt_boxes)  # [N, M]
        assigned_labels = torch.zeros(num_anchors, dtype=torch.long, device=device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long, device=device)
        assigned_ious = torch.zeros(num_anchors, device=device)
        # Process each GT independently; conflicts resolved by IoU below.
        for gt_idx in range(num_gts):
            gt_dists = distances[:, gt_idx]  # [N]
            gt_ious = ious[:, gt_idx]        # [N]
            # Select the top-k closest anchors per pyramid level
            candidate_mask = torch.zeros(num_anchors, dtype=torch.bool, device=device)
            start = 0
            for num_per_level in num_anchors_per_level:
                end = start + num_per_level
                level_dists = gt_dists[start:end]
                k = min(self.topk, num_per_level)
                if k > 0:
                    _, topk_inds = level_dists.topk(k, largest=False)
                    candidate_mask[start + topk_inds] = True
                start = end
            # Adaptive threshold = mean + std of candidate IoUs.
            candidate_ious = gt_ious[candidate_mask]
            iou_threshold = candidate_ious.mean()
            if candidate_ious.numel() > 1:
                # std() of a single sample is NaN (Bessel's correction),
                # which would silently drop all positives for this GT;
                # with one candidate, the mean alone is the threshold.
                iou_threshold = iou_threshold + candidate_ious.std()
            # Filter: IoU >= threshold AND anchor center inside the GT box
            is_positive = (
                candidate_mask &
                (gt_ious >= iou_threshold) &
                (anchor_cx >= gt_boxes[gt_idx, 0]) &
                (anchor_cy >= gt_boxes[gt_idx, 1]) &
                (anchor_cx <= gt_boxes[gt_idx, 2]) &
                (anchor_cy <= gt_boxes[gt_idx, 3])
            )
            # Assign; if an anchor matches several GTs, the higher IoU wins.
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]
        return assigned_labels, assigned_gt_inds
    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] x [M,4] -> [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        # Broadcast to pairwise intersection corners.
        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))
        # clamp(min=0) handles disjoint boxes (negative overlap -> 0).
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter
        # Epsilon guards against division by zero for degenerate boxes.
        return inter / (union + 1e-6)