File size: 7,855 Bytes
20e9cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Anchor generation and matching for SCRFD.

SCRFD uses 3-level anchors:
  stride 8:  anchor sizes [16, 32]
  stride 16: anchor sizes [64, 128]
  stride 32: anchor sizes [256, 512]

Matching: ATSS (Adaptive Training Sample Selection) from paper
"Bridging the Gap Between Anchor-based and Anchor-free Detection via
 Adaptive Training Sample Selection" (Zhang et al., 2019)
"""

import torch
import torch.nn as nn
import math
from typing import List, Tuple, Optional


class AnchorGenerator:
    """
    Generate anchors on feature map grids.

    For SCRFD: 2 anchors per location × 3 levels = 6 anchor configs.
    Aspect ratio = 1.0 (square anchors work best for faces).

    Args:
        strides: Downsampling stride of each feature level.
            Defaults to [8, 16, 32].
        anchor_sizes: Anchor side lengths per level, one inner list per stride.
            Defaults to [[16, 32], [64, 128], [256, 512]].
        ratios: Aspect ratios applied to every size. Defaults to [1.0].
    """

    def __init__(self,
                 strides: Optional[List[int]] = None,
                 anchor_sizes: Optional[List[List[int]]] = None,
                 ratios: Optional[List[float]] = None):
        # None sentinels avoid mutable default arguments; the materialized
        # defaults are the SCRFD configuration documented above.
        self.strides = [8, 16, 32] if strides is None else strides
        self.anchor_sizes = ([[16, 32], [64, 128], [256, 512]]
                             if anchor_sizes is None else anchor_sizes)
        self.ratios = [1.0] if ratios is None else ratios
        self.num_anchors_per_level = [len(sizes) * len(self.ratios)
                                      for sizes in self.anchor_sizes]

    def grid_anchors(self, feat_sizes: List[Tuple[int, int]],
                     device: torch.device) -> List[torch.Tensor]:
        """
        Generate anchor boxes for each feature level.

        Args:
            feat_sizes: [(H, W)] for each level
            device: target device

        Returns:
            List of [num_anchors, 4] tensors in (x1, y1, x2, y2) format.
            Anchors are ordered location-major: all base anchors for grid
            cell 0, then all for cell 1, etc.
        """
        all_anchors = []
        for i, (feat_h, feat_w) in enumerate(feat_sizes):
            stride = self.strides[i]
            sizes = self.anchor_sizes[i]

            # Grid centers in image coordinates (+0.5 centers the anchor
            # on the pixel, not its top-left corner).
            shift_x = (torch.arange(0, feat_w, device=device) + 0.5) * stride
            shift_y = (torch.arange(0, feat_h, device=device) + 0.5) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts = torch.stack([shift_x.reshape(-1), shift_y.reshape(-1),
                                  shift_x.reshape(-1), shift_y.reshape(-1)], dim=1)

            # Base anchors for this level, centered at the origin.
            base_anchors = []
            for size in sizes:
                for ratio in self.ratios:
                    w = size * math.sqrt(ratio)
                    h = size / math.sqrt(ratio)
                    base_anchors.append([-w/2, -h/2, w/2, h/2])
            base_anchors = torch.tensor(base_anchors, device=device, dtype=torch.float32)

            # Broadcast: shifts [N, 4] + base_anchors [K, 4] → [N*K, 4]
            anchors = (shifts.unsqueeze(1) + base_anchors.unsqueeze(0)).reshape(-1, 4)
            all_anchors.append(anchors)

        return all_anchors

    def num_anchors_per_loc(self) -> List[int]:
        """Number of anchors per grid location, one entry per level."""
        return self.num_anchors_per_level


class ATSSAssigner:
    """
    Adaptive Training Sample Selection (ATSS) for anchor-GT matching.

    Key idea: For each GT, select top-k closest anchors from each pyramid level,
    compute their IoU with the GT, and use mean + std as the adaptive IoU threshold.
    Only anchors with IoU >= threshold AND whose center is inside GT are positive.

    SCRFD uses ATSS because it adapts to face scale automatically —
    tiny faces get lower thresholds (more positives), large faces get higher ones.

    Args:
        topk: Number of closest anchors to consider per level (default: 9)
    """

    def __init__(self, topk: int = 9):
        self.topk = topk

    @torch.no_grad()
    def assign(self, anchors: torch.Tensor, gt_boxes: torch.Tensor,
               gt_labels: torch.Tensor, num_anchors_per_level: List[int]
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign anchors to GT boxes using ATSS.

        Args:
            anchors: [N, 4] all anchors concatenated, (x1, y1, x2, y2)
            gt_boxes: [M, 4] ground truth boxes, (x1, y1, x2, y2)
            gt_labels: [M] ground truth labels (all 1 for face)
            num_anchors_per_level: number of anchors per feature level
                (must sum to N, in the same order as `anchors`)

        Returns:
            assigned_labels: [N] (0 = background, 1 = face)
            assigned_gt_inds: [N] (index of assigned GT, -1 for negatives)
        """
        num_anchors = anchors.shape[0]
        num_gts = gt_boxes.shape[0]

        # No GT: everything is background.
        if num_gts == 0:
            return (torch.zeros(num_anchors, dtype=torch.long, device=anchors.device),
                    torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device))

        # Anchor centers
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2
        anchor_centers = torch.stack([anchor_cx, anchor_cy], dim=1)  # [N, 2]

        # GT centers
        gt_cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
        gt_cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2

        # Distance from each anchor to each GT center
        distances = torch.cdist(anchor_centers, torch.stack([gt_cx, gt_cy], dim=1))  # [N, M]

        # IoU between anchors and GTs
        ious = self._compute_iou(anchors, gt_boxes)  # [N, M]

        assigned_labels = torch.zeros(num_anchors, dtype=torch.long, device=anchors.device)
        assigned_gt_inds = torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device)
        # Tracks the IoU of the currently assigned GT so that when an anchor
        # is claimed by several GTs, the highest-IoU one wins.
        assigned_ious = torch.zeros(num_anchors, device=anchors.device)

        # Process each GT
        for gt_idx in range(num_gts):
            gt_dists = distances[:, gt_idx]  # [N]
            gt_ious = ious[:, gt_idx]        # [N]

            # Select top-k closest anchors per level (per-level selection is
            # the core of ATSS: every scale contributes candidates).
            candidate_mask = torch.zeros(num_anchors, dtype=torch.bool, device=anchors.device)
            start = 0
            for num_per_level in num_anchors_per_level:
                end = start + num_per_level
                level_dists = gt_dists[start:end]
                k = min(self.topk, num_per_level)
                _, topk_inds = level_dists.topk(k, largest=False)
                candidate_mask[start + topk_inds] = True
                start = end

            # Compute adaptive threshold (mean + std of candidate IoUs).
            candidate_ious = gt_ious[candidate_mask]
            iou_mean = candidate_ious.mean()
            # torch's .std() is the unbiased estimator and returns NaN for a
            # single element, which would make the threshold NaN and silently
            # discard all positives for this GT. Fall back to 0 so a lone
            # candidate can still match.
            if candidate_ious.numel() > 1:
                iou_std = candidate_ious.std()
            else:
                iou_std = candidate_ious.new_zeros(())
            iou_threshold = iou_mean + iou_std

            # Filter: IoU >= threshold AND anchor center inside the GT box.
            is_positive = (
                candidate_mask &
                (gt_ious >= iou_threshold) &
                (anchor_cx >= gt_boxes[gt_idx, 0]) &
                (anchor_cy >= gt_boxes[gt_idx, 1]) &
                (anchor_cx <= gt_boxes[gt_idx, 2]) &
                (anchor_cy <= gt_boxes[gt_idx, 3])
            )

            # Assign (higher IoU wins if conflict)
            better = is_positive & (gt_ious > assigned_ious)
            assigned_labels[better] = gt_labels[gt_idx]
            assigned_gt_inds[better] = gt_idx
            assigned_ious[better] = gt_ious[better]

        return assigned_labels, assigned_gt_inds

    @staticmethod
    def _compute_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute pairwise IoU between two sets of boxes. [N,4] × [M,4] → [N,M]"""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        inter_x1 = torch.max(boxes1[:, 0].unsqueeze(1), boxes2[:, 0].unsqueeze(0))
        inter_y1 = torch.max(boxes1[:, 1].unsqueeze(1), boxes2[:, 1].unsqueeze(0))
        inter_x2 = torch.min(boxes1[:, 2].unsqueeze(1), boxes2[:, 2].unsqueeze(0))
        inter_y2 = torch.min(boxes1[:, 3].unsqueeze(1), boxes2[:, 3].unsqueeze(0))

        # Clamp at 0 so disjoint boxes get zero intersection, not negative.
        inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
        # Epsilon guards against division by zero for degenerate boxes.
        union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter

        return inter / (union + 1e-6)