cledouxluma commited on
Commit
20e9cd1
·
verified ·
1 Parent(s): a73c9ef

Upload models/anchor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/anchor.py +197 -0
models/anchor.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Anchor generation and matching for SCRFD.
3
+
4
+ SCRFD uses 3-level anchors:
5
+ stride 8: anchor sizes [16, 32]
6
+ stride 16: anchor sizes [64, 128]
7
+ stride 32: anchor sizes [256, 512]
8
+
9
+ Matching: ATSS (Adaptive Training Sample Selection) from paper
10
+ "Bridging the Gap Between Anchor-based and Anchor-free Detection via
11
+ Adaptive Training Sample Selection" (Zhang et al., 2019)
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import math
17
+ from typing import List, Tuple, Optional
18
+
19
+
20
class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    For SCRFD: 2 anchors per location × 3 levels = 6 anchor configs.
    Aspect ratio = 1.0 (square anchors work best for faces).
    """

    def __init__(self,
                 strides: Tuple[int, ...] = (8, 16, 32),
                 anchor_sizes: Tuple[Tuple[int, ...], ...] = ((16, 32), (64, 128), (256, 512)),
                 ratios: Tuple[float, ...] = (1.0,)):
        """
        Args:
            strides: downsampling stride of each feature level.
            anchor_sizes: per-level anchor side lengths (square at ratio 1.0).
            ratios: aspect ratios applied to every size.
        """
        # Defaults are tuples (not lists) to avoid the shared mutable-default
        # pitfall; copy into lists so attribute types stay list-like for callers.
        self.strides = list(strides)
        self.anchor_sizes = [list(sizes) for sizes in anchor_sizes]
        self.ratios = list(ratios)
        self.num_anchors_per_level = [len(sizes) * len(self.ratios)
                                      for sizes in self.anchor_sizes]

    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level, one entry per stride.
            device: target device for the returned tensors.

        Returns:
            List of [H*W*K, 4] tensors in (x1, y1, x2, y2) format, one per
            level, laid out location-major (each location's K anchors are
            contiguous).

        Raises:
            ValueError: if len(feat_sizes) != number of configured strides.
        """
        if len(feat_sizes) != len(self.strides):
            raise ValueError(
                f"expected {len(self.strides)} feature levels, got {len(feat_sizes)}")

        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]

            # Grid centers in image coordinates; the +0.5 offset centers each
            # anchor inside its stride-sized cell.
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                                  shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)

            # Base anchors centered at the origin for this level.
            base_anchors = []
            for size in self.anchor_sizes[i]:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w / 2, -h / 2, w / 2, h / 2])
            base_anchors = torch.tensor(base_anchors, device=device, dtype=torch.float32)

            # Broadcast: shifts [N, 1, 4] + base_anchors [1, K, 4] → [N*K, 4]
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)

        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        """Return the number of anchors per grid location, one entry per level."""
        return self.num_anchors_per_level
80
+
81
+
82
class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    For each ground-truth box, the top-k anchors by center distance are
    gathered from every pyramid level; the mean + std of their IoUs with the
    GT forms a per-GT adaptive threshold. An anchor becomes positive only if
    it is a candidate, clears the threshold, and its center lies inside the
    GT box.

    SCRFD uses ATSS because the threshold adapts to face scale — small faces
    get lower thresholds (more positives), large faces higher ones.

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated across levels
            gt_boxes: [M, 4] ground truth boxes
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: anchor count for each feature level

        Returns:
            assigned_labels: [N] (0 = background, 1 = face)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        device = anchors.device
        n_anchors = anchors.shape[0]
        n_gts = gt_boxes.shape[0]

        labels_out = torch.zeros(n_anchors, dtype=torch.long, device=device)
        gt_inds_out = torch.full((n_anchors,), -1, dtype=torch.long, device=device)

        # No ground truth: everything is background.
        if n_gts == 0:
            return labels_out, gt_inds_out

        # Centers of anchors and GTs (midpoint of the two corners).
        anchor_ctr = 0.5 * (anchors[:, :2] + anchors[:, 2:])     # [N, 2]
        gt_ctr = 0.5 * (gt_boxes[:, :2] + gt_boxes[:, 2:])       # [M, 2]

        center_dists = torch.cdist(anchor_ctr, gt_ctr)           # [N, M]
        overlaps = self._compute_iou(anchors, gt_boxes)          # [N, M]

        # IoU of the currently assigned GT per anchor, for conflict resolution.
        best_iou = torch.zeros(n_anchors, device=device)

        # Precompute [start, end) slice bounds of each pyramid level once.
        level_bounds = []
        offset = 0
        for count in num_anchors_per_level:
            level_bounds.append((offset, offset + count))
            offset += count

        for g in range(n_gts):
            dists_g = center_dists[:, g]
            ious_g = overlaps[:, g]

            # Candidate pool: top-k nearest anchors within each level.
            cand = torch.zeros(n_anchors, dtype=torch.bool, device=device)
            for lo, hi in level_bounds:
                k = min(self.topk, hi - lo)
                nearest = dists_g[lo:hi].topk(k, largest=False).indices
                cand[lo + nearest] = True

            # Adaptive threshold from candidate IoU statistics.
            pool = ious_g[cand]
            thr = pool.mean() + pool.std()

            # Positive iff candidate, IoU clears threshold, center inside GT.
            inside = ((anchor_ctr[:, 0] >= gt_boxes[g, 0]) &
                      (anchor_ctr[:, 1] >= gt_boxes[g, 1]) &
                      (anchor_ctr[:, 0] <= gt_boxes[g, 2]) &
                      (anchor_ctr[:, 1] <= gt_boxes[g, 3]))
            positive = cand & (ious_g >= thr) & inside

            # On conflict, the GT with the higher IoU keeps the anchor.
            take = positive & (ious_g > best_iou)
            labels_out[take] = gt_labels[g]
            gt_inds_out[take] = g
            best_iou[take] = ious_g[take]

        return labels_out, gt_inds_out

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] × [M,4] → [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # Intersection corners via broadcasting: [N, 1, 2] vs [1, M, 2].
        top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
        bot_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])

        wh = (bot_right - top_left).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        union = area1[:, None] + area2[None, :] - inter

        return inter / (union + 1e-6)