Upload models/anchor.py with huggingface_hub
models/anchor.py +197 -0
models/anchor.py
ADDED
@@ -0,0 +1,197 @@
"""
Anchor generation and matching for SCRFD.

SCRFD uses 3-level anchors:
    stride 8:  anchor sizes [16, 32]
    stride 16: anchor sizes [64, 128]
    stride 32: anchor sizes [256, 512]

Matching: ATSS (Adaptive Training Sample Selection) from the paper
"Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection" (Zhang et al., 2019)
"""

import torch
import torch.nn as nn
import math
from typing import List, Tuple, Optional


class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    For SCRFD: 2 anchors per location × 3 levels = 6 anchor configs.
    Aspect ratio = 1.0 (square anchors work best for faces).
    """

    def __init__(self,
                 strides: List[int] = [8, 16, 32],
                 anchor_sizes: List[List[int]] = [[16, 32], [64, 128], [256, 512]],
                 ratios: List[float] = [1.0]):
        self.strides = strides
        self.anchor_sizes = anchor_sizes
        self.ratios = ratios
        self.num_anchors_per_level = [len(sizes) * len(ratios) for sizes in anchor_sizes]

    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level
            device: target device

        Returns:
            List of [num_anchors, 4] tensors in (x1, y1, x2, y2) format
        """
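        # Example (illustrative, not a fixed SCRFD setting): a 640x640 input with
        # strides [8, 16, 32] gives feat_sizes [(80, 80), (40, 40), (20, 20)],
        # and each grid location gets len(sizes) * len(ratios) = 2 anchors.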
        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]
            sizes = self.anchor_sizes[i]

            # Grid centers
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                                  shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)

            # Base anchors for this level
            base_anchors = []
            for size in sizes:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w/2, -h/2, w/2, h/2])
            base_anchors = torch.tensor(base_anchors, device=device, dtype=torch.float32)

            # Broadcast: shifts [N, 4] + base_anchors [K, 4] → [N*K, 4]
            num_locs = shifts.shape[0]
            num_bases = base_anchors.shape[0]
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)

        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        return self.num_anchors_per_level


class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: For each GT, select top-k closest anchors from each pyramid level,
    compute their IoU with the GT, and use mean + std as the adaptive IoU threshold.
    Only anchors with IoU > threshold AND whose center is inside GT are positive.

    SCRFD uses ATSS because it adapts to face scale automatically —
    tiny faces get lower thresholds (more positives), large faces get higher ones.

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated
            gt_boxes: [M, 4] ground truth boxes
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: number of anchors per feature level

        Returns:
            assigned_labels: [N] (0 = background, 1 = face)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]

        if num_gts == 0:
            return (torch.zeros(num_anchors, dtype=torch.long, device=anchors.device),
                    torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device))

        # Anchor centers
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)  # [N, 2]

        # GT centers
        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2

        # Distance from each anchor to each GT center
        distances = torch.cdist(anchor_centers, torch.stack([gt_cx, gt_cy], dim=1))  # [N, M]

        # IoU between anchors and GTs
        ious = self._compute_iou(anchors, gt_boxes)  # [N, M]

        assigned_labels = torch.zeros(num_anchors, dtype=torch.long, device=anchors.device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device)
        assigned_ious = torch.zeros(num_anchors, device=anchors.device)

        # Process each GT
        for gt_idx in range(num_gts):
            gt_dists = distances[:, gt_idx]  # [N]
            gt_ious = ious[:, gt_idx]  # [N]

            # Select top-k closest anchors per level
            candidate_mask = torch.zeros(num_anchors, dtype=torch.bool, device=anchors.device)
            start = 0
            for num_per_level in num_anchors_per_level:
                end = start + num_per_level
                level_dists = gt_dists[start:end]
                k = min(self.topk, num_per_level)
                _, topk_inds = level_dists.topk(k, largest=False)
                candidate_mask[start + topk_inds] = True
                start = end

            # Compute adaptive threshold
            candidate_ious = gt_ious[candidate_mask]
            iou_mean = candidate_ious.mean()
            iou_std = candidate_ious.std()
            iou_threshold = iou_mean + iou_std
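            # Illustrative numbers (assumed, not from the ATSS paper): if the
            # candidates' IoUs have mean 0.30 and std 0.15 the cutoff is 0.45,
            # while a tiny face whose candidates average 0.10 with std 0.05 only
            # needs 0.15, so the threshold adapts per ground-truth face.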

            # Filter: IoU > threshold AND center inside GT box
            is_positive = (
                candidate_mask &
                (gt_ious >= iou_threshold) &
                (anchor_cx >= gt_boxes[gt_idx, 0]) &
                (anchor_cy >= gt_boxes[gt_idx, 1]) &
                (anchor_cx <= gt_boxes[gt_idx, 2]) &
                (anchor_cy <= gt_boxes[gt_idx, 3])
            )

            # Assign (higher IoU wins if conflict)
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]

        return assigned_labels, assigned_gt_inds

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] × [M,4] → [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))

        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter

        return inter / (union + 1e-6)
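
Below is a minimal usage sketch, not part of the committed file: it assumes the module is importable as models.anchor, a 640x640 input whose three feature levels are 80x80, 40x40 and 20x20, and two made-up ground-truth boxes. The per-level anchor totals are read back from the generator, since that is what assign() expects.

import torch
from models.anchor import AnchorGenerator, ATSSAssigner

gen = AnchorGenerator()            # defaults: strides [8, 16, 32], two sizes per level
assigner = ATSSAssigner(topk=9)

# Assumed feature-map sizes for a 640x640 input (640/8, 640/16, 640/32).
feat_sizes = [(80, 80), (40, 40), (20, 20)]
device = torch.device('cpu')

anchors_per_level = gen.grid_anchors(feat_sizes, device)    # list of [Ni, 4] boxes
anchors = torch.cat(anchors_per_level, dim=0)               # 12800 + 3200 + 800 = 16800 anchors
num_per_level = [a.shape[0] for a in anchors_per_level]     # per-level totals for assign()

# Two made-up faces in (x1, y1, x2, y2) pixel coordinates, both labelled 1 (face).
gt_boxes = torch.tensor([[100., 100., 180., 180.],
                         [300., 200., 420., 340.]])
gt_labels = torch.ones(gt_boxes.shape[0], dtype=torch.long)

labels, gt_inds = assigner.assign(anchors, gt_boxes, gt_labels, num_per_level)
print(int((labels == 1).sum()), "anchors assigned as positive")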