omar-ah committed on
Commit
ccfb718
·
verified ·
1 Parent(s): bcd8770

Upload vil_tracker/models/heads.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/models/heads.py +215 -0
vil_tracker/models/heads.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction Heads for ViL Tracker.
3
+
4
+ CenterHead: Predicts center heatmap + bounding box size from search features
5
+ UncertaintyHead: Predicts aleatoric uncertainty for each prediction
6
+ decode_predictions: Converts heatmaps + sizes to bounding boxes
7
+
8
+ Architecture follows SUTrack/OSTrack corner-free head design:
9
+ - Search features (B, 256, D) → reshape to (B, D, 16, 16)
10
+ - Conv layers predict heatmap (B, 1, 16, 16) and size (B, 2, 16, 16)
11
+ - Peak detection gives center, size gives w/h relative to search region
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import rearrange
18
+
19
+
20
class CenterHead(nn.Module):
    """Center-based prediction head.

    A shared convolutional stem feeds three small branches. For an input of
    ``feat_size * feat_size`` tokens the head returns a dict with:

    - ``'heatmap'``: (B, 1, H, W) raw center logits. The branch ends in a
      plain ``Conv2d`` with no activation, so apply a sigmoid to turn the
      values into scores.
    - ``'size'``: (B, 2, H, W) sigmoid-activated values in [0, 1].
    - ``'offset'``: (B, 2, H, W) tanh-activated values scaled to [-0.5, 0.5].

    Here ``H = W = feat_size``.
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel width D of the incoming token features
            feat_size: side length of the square token grid
        """
        super().__init__()
        self.feat_size = feat_size

        # NOTE: attribute names and the layer order inside each nn.Sequential
        # define the state_dict keys, so they are kept exactly as they were.

        # Shared stem
        self.stem = nn.Sequential(
            nn.Conv2d(dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
        )

        # Center heatmap head (outputs logits, no final activation)
        self.heatmap = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        # Size head (w, h)
        self.size = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Sigmoid(),  # size in [0, 1] relative to search region
        )

        # Sub-pixel offset head
        self.offset = nn.Sequential(
            nn.Conv2d(256, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 2, 1),
            nn.Tanh(),  # offset in [-1, 1] before scaling in forward()
        )

    def forward(self, search_feat: torch.Tensor) -> dict:
        """
        Args:
            search_feat: (B, N, D) search region features in row-major token
                order, with N == feat_size * feat_size
        Returns:
            dict with 'heatmap', 'size', 'offset' tensors
        Raises:
            RuntimeError: if N does not equal feat_size * feat_size
        """
        batch, _, channels = search_feat.shape
        side = self.feat_size

        # (B, N, D) -> (B, D, N) -> (B, D, H, W). Token n lands at
        # (row n // W, col n % W), i.e. the same layout as the einops
        # pattern 'b (h w) d -> b d h w'.
        grid = search_feat.transpose(1, 2).reshape(batch, channels, side, side)

        feat = self.stem(grid)

        return {
            'heatmap': self.heatmap(feat),       # (B, 1, H, W)
            'size': self.size(feat),             # (B, 2, H, W)
            'offset': self.offset(feat) * 0.5,   # (B, 2, H, W) scaled to [-0.5, 0.5]
        }
83
+
84
+
85
class UncertaintyHead(nn.Module):
    """Predicts aleatoric uncertainty (log variance) for predictions.

    A small convolutional network maps the search tokens, laid out as a
    ``feat_size x feat_size`` grid, to one unbounded value per grid cell.
    There is no final activation.

    Used for:
    1. Weighting loss contributions (uncertain predictions get lower weight)
    2. Online tracking confidence (skip update when uncertain)
    3. Kalman filter measurement noise adaptation
    """

    def __init__(self, dim: int = 384, feat_size: int = 16):
        """
        Args:
            dim: channel width D of the incoming token features
            feat_size: side length of the square token grid
        """
        super().__init__()
        self.feat_size = feat_size
        # Layer order is kept as-is so state_dict keys (net.0, net.1, ...) are stable.
        self.net = nn.Sequential(
            nn.Conv2d(dim, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

    def forward(self, search_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            search_feat: (B, N, D) search features in row-major token order,
                with N == feat_size * feat_size
        Returns:
            log_variance: (B, 1, H, W) predicted log variance
        Raises:
            RuntimeError: if N does not equal feat_size * feat_size
        """
        batch, _, channels = search_feat.shape
        side = self.feat_size

        # (B, N, D) -> (B, D, H, W); same layout as 'b (h w) d -> b d h w'.
        grid = search_feat.transpose(1, 2).reshape(batch, channels, side, side)
        return self.net(grid)
115
+
116
+
117
def decode_predictions(
    heatmap: torch.Tensor,
    size: torch.Tensor,
    offset: torch.Tensor,
    search_size: int = 256,
    feat_size: int = 16,
) -> tuple:
    """Decode head outputs to bounding boxes.

    The highest-scoring heatmap cell is selected per sample. The size and
    offset predicted at that cell are then converted to a pixel-space box.
    Cell ``i`` is treated as centred at ``(i + 0.5) * stride``.

    Args:
        heatmap: (B, 1, H, W) center heatmap logits
        size: (B, 2, H, W) predicted w/h relative to search region
        offset: (B, 2, H, W) sub-pixel offset, in units of grid cells
        search_size: pixel size of the (square) search region
        feat_size: nominal spatial size of the feature map. Kept for
            backward compatibility; the grid dimensions are read from
            ``heatmap`` itself, so a heatmap whose size differs from this
            value is still decoded correctly.

    Returns:
        boxes: (B, 4) predicted boxes in [cx, cy, w, h] format, in pixels
        scores: (B,) confidence scores (sigmoid of the peak logit)
    """
    batch = heatmap.shape[0]
    height, width = heatmap.shape[-2:]
    stride_x = search_size / width
    stride_y = search_size / height

    # Find the peak. reshape (not view) so non-contiguous inputs also work.
    scores, indices = heatmap.reshape(batch, -1).max(dim=-1)  # (B,)
    scores = scores.sigmoid()

    # Flat row-major index -> (row, col) using the heatmap's real width.
    cy_idx = indices // width
    cx_idx = indices % width

    peak = indices.unsqueeze(1)  # (B, 1)

    def _at_peak(channel_map: torch.Tensor) -> torch.Tensor:
        # Value of a (B, H, W) map at each sample's peak cell -> (B,)
        return channel_map.reshape(batch, -1).gather(1, peak).squeeze(1)

    pred_w = _at_peak(size[:, 0])
    pred_h = _at_peak(size[:, 1])
    off_x = _at_peak(offset[:, 0])
    off_y = _at_peak(offset[:, 1])

    # Convert to pixel coordinates
    cx = (cx_idx.float() + 0.5 + off_x) * stride_x
    cy = (cy_idx.float() + 0.5 + off_y) * stride_y
    w = pred_w * search_size
    h = pred_h * search_size

    boxes = torch.stack([cx, cy, w, h], dim=-1)  # (B, 4)
    return boxes, scores
163
+
164
+
165
def generate_heatmap(
    center: torch.Tensor,
    feat_size: int = 16,
    search_size: int = 256,
    sigma: float = 2.0,
) -> torch.Tensor:
    """Generate ground truth Gaussian heatmap for center supervision.

    Grid cell ``i`` is evaluated at its centre, ``i + 0.5`` in feature-map
    units, which is the convention ``(idx + 0.5 + offset) * stride`` used
    when decoding predictions. The brightest cell is therefore the one that
    contains the target centre, and the residual to that cell's centre
    always lies within [-0.5, 0.5].

    Args:
        center: (B, 2) target center in pixel coords (cx, cy) in search region
        feat_size: spatial size of feature map
        search_size: pixel size of search region
        sigma: Gaussian standard deviation in feature map units
    Returns:
        heatmap: (B, 1, feat_size, feat_size) ground truth heatmap with
            values in (0, 1]
    """
    B = center.shape[0]
    stride = search_size / feat_size

    # Pixel center -> continuous feature-map coordinates, shaped for broadcast.
    center_feat = (center / stride).view(B, 1, 1, 2)

    # Coordinates of the cell centres (0.5, 1.5, ...), not the cell corners.
    cell_centers = torch.arange(
        feat_size, device=center.device, dtype=center_feat.dtype
    ) + 0.5
    yy, xx = torch.meshgrid(cell_centers, cell_centers, indexing='ij')
    grid = torch.stack([xx, yy], dim=-1).unsqueeze(0)  # (1, H, W, 2) as (x, y)

    dist_sq = ((grid - center_feat) ** 2).sum(dim=-1)  # (B, H, W)
    heatmap = torch.exp(-dist_sq / (2 * sigma ** 2))

    return heatmap.unsqueeze(1)  # (B, 1, H, W)
201
+
202
+
203
def generate_size_target(
    size: torch.Tensor,
    search_size: int = 256,
) -> torch.Tensor:
    """Generate ground truth size target.

    Sizes are clamped to ``[1, search_size]`` pixels before normalising.
    The lower bound avoids degenerate zero-size targets. The upper bound
    keeps the target inside the [0, 1] range stated below, so a box larger
    than the search region saturates at 1.0.

    Args:
        size: (B, 2) target [width, height] in pixels
        search_size: pixel size of search region
    Returns:
        size_norm: (B, 2) normalized to [1 / search_size, 1] relative to
            search region
    """
    return size.clamp(min=1, max=search_size) / search_size