omar-ah committed on
Commit 99c5702 · verified · 1 Parent(s): 908010b

Upload vil_tracker/inference/online_tracker.py with huggingface_hub

vil_tracker/inference/online_tracker.py ADDED
@@ -0,0 +1,219 @@
+ """
+ Online Tracker: Full inference pipeline for ViL Tracker.
+
+ Pipeline per frame:
+ 1. Crop search region around predicted position
+ 2. Run model: template + search → heatmap, size, offset
+ 3. Decode predictions → candidate box
+ 4. Apply Kalman filter for temporal smoothing
+ 5. Update search region for next frame
+
+ Features:
+ - Adaptive search region scaling
+ - Confidence-based template update (skip when uncertain)
+ - Kalman filter with uncertainty-adaptive noise
+ """
+
+ import torch
+ import torch.nn.functional as F  # used in _crop_and_preprocess; hoisted from the function body
+ import numpy as np
+ from .kalman import KalmanFilter
+
+
+ class OnlineTracker:
+     """Online single-object tracker using ViL backbone.
+
+     Usage:
+         tracker = OnlineTracker(model, device='cuda')
+         tracker.initialize(first_frame, init_bbox)  # [x, y, w, h]
+         for frame in video[1:]:
+             bbox = tracker.track(frame)  # returns [x, y, w, h]
+     """
+
+     def __init__(
+         self,
+         model,
+         device: str = 'cuda',
+         template_size: int = 128,
+         search_size: int = 256,
+         search_scale: float = 4.0,
+         confidence_threshold: float = 0.3,
+         template_update_threshold: float = 0.8,
+     ):
+         self.model = model
+         self.device = device
+         self.template_size = template_size
+         self.search_size = search_size
+         self.search_scale = search_scale
+         self.confidence_threshold = confidence_threshold
+         self.template_update_threshold = template_update_threshold
+
+         self.model.eval()
+
+         # State
+         self.template = None
+         self.kalman = KalmanFilter()
+         self.target_pos = None  # [cx, cy]
+         self.target_sz = None   # [w, h]
+         self.frame_count = 0
+
+     def initialize(self, frame: np.ndarray, bbox: list):
+         """Initialize tracker with first frame and bounding box.
+
+         Args:
+             frame: (H, W, 3) BGR or RGB numpy array
+             bbox: [x, y, w, h] initial bounding box (top-left format)
+         """
+         x, y, w, h = bbox
+         self.target_pos = np.array([x + w / 2, y + h / 2])
+         self.target_sz = np.array([w, h])
+
+         # Crop and embed template
+         self.template = self._crop_and_preprocess(
+             frame, self.target_pos, self.target_sz,
+             output_size=self.template_size,
+             scale_factor=2.0,
+         )
+
+         # Initialize Kalman filter
+         self.kalman.initialize(np.array([
+             self.target_pos[0], self.target_pos[1],
+             self.target_sz[0], self.target_sz[1],
+         ]))
+
+         # Reset temporal modulation
+         self.model.reset_temporal()
+         self.frame_count = 0
+
+     def track(self, frame: np.ndarray) -> list:
+         """Track target in new frame.
+
+         Args:
+             frame: (H, W, 3) numpy array
+         Returns:
+             [x, y, w, h] predicted bounding box (top-left format)
+         """
+         self.frame_count += 1
+
+         # Kalman predict
+         kf_pred = self.kalman.predict()
+         pred_pos = kf_pred[:2]
+         pred_sz = kf_pred[2:]
+
+         # Crop search region around predicted position
+         search = self._crop_and_preprocess(
+             frame, pred_pos, pred_sz,
+             output_size=self.search_size,
+             scale_factor=self.search_scale,
+         )
+
+         # Run model
+         with torch.no_grad():
+             output = self.model(
+                 self.template.to(self.device),
+                 search.to(self.device),
+                 use_temporal=(self.frame_count > 1),
+             )
+
+         # Extract predictions
+         boxes = output['boxes'].cpu().numpy()[0]  # [cx, cy, w, h] in search-region pixels
+         score = output['scores'].cpu().item()
+
+         # Map back to original frame coordinates
+         scale_factor = self.search_scale * max(pred_sz) / self.search_size
+         cx = (boxes[0] - self.search_size / 2) * scale_factor + pred_pos[0]
+         cy = (boxes[1] - self.search_size / 2) * scale_factor + pred_pos[1]
+         w = boxes[2] * scale_factor
+         h = boxes[3] * scale_factor
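+         # Sanity check on the mapping (illustrative numbers, not from the model):
+         # with search_scale=4.0, max(pred_sz)=64 and search_size=256, one search
+         # pixel spans 4.0 * 64 / 256 = 1.0 frame pixel, so a box predicted at
+         # the center of the search image (128, 128) maps back onto pred_pos.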
+
+         # Confidence-based update
+         if score > self.confidence_threshold:
+             # Get uncertainty for Kalman noise adaptation
+             uncertainty = 1.0
+             if 'log_variance' in output:
+                 log_var = output['log_variance'].mean().cpu().item()
+                 uncertainty = max(0.5, min(3.0, np.exp(log_var / 2)))
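+                 # exp(log_var / 2) converts the predicted log-variance to a
+                 # standard deviation (e.g. log_var = 1.0 gives std ≈ 1.65);
+                 # the clamp to [0.5, 3.0] bounds the uncertainty-adaptive
+                 # noise term fed into the Kalman update below.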
+
+             self.kalman.update(np.array([cx, cy, w, h]), uncertainty)
+
+             # Update template if very confident
+             if score > self.template_update_threshold and self.frame_count % 10 == 0:
+                 self.template = self._crop_and_preprocess(
+                     frame, np.array([cx, cy]), np.array([w, h]),
+                     output_size=self.template_size,
+                     scale_factor=2.0,
+                 )
+
+         # Use Kalman-smoothed state
+         state = self.kalman.get_state()
+         self.target_pos = state[:2]
+         self.target_sz = state[2:]
+
+         # Return top-left format [x, y, w, h]
+         return [
+             self.target_pos[0] - self.target_sz[0] / 2,
+             self.target_pos[1] - self.target_sz[1] / 2,
+             self.target_sz[0],
+             self.target_sz[1],
+         ]
+
+     def _crop_and_preprocess(
+         self,
+         frame: np.ndarray,
+         center: np.ndarray,
+         size: np.ndarray,
+         output_size: int,
+         scale_factor: float,
+     ) -> torch.Tensor:
+         """Crop and preprocess an image region.
+
+         Args:
+             frame: (H, W, 3) numpy array
+             center: [cx, cy] crop center
+             size: [w, h] target size
+             output_size: desired output size
+             scale_factor: how much larger than the target to crop
+         Returns:
+             (1, 3, output_size, output_size) preprocessed tensor
+         """
+         H, W = frame.shape[:2]
+
+         # Compute crop size
+         crop_size = max(size[0], size[1]) * scale_factor
+         crop_size = max(crop_size, 10)  # minimum crop size
+
+         # Crop coordinates
+         x1 = int(center[0] - crop_size / 2)
+         y1 = int(center[1] - crop_size / 2)
+         x2 = int(x1 + crop_size)
+         y2 = int(y1 + crop_size)
+
+         # Handle boundaries: zero-pad so the crop stays centered even when the
+         # predicted position lies partially outside the frame
+         pad_left = max(0, -x1)
+         pad_top = max(0, -y1)
+         pad_right = max(0, x2 - W)
+         pad_bottom = max(0, y2 - H)
+
+         x1 = max(0, x1)
+         y1 = max(0, y1)
+         x2 = min(W, x2)
+         y2 = min(H, y2)
+
+         crop = frame[y1:y2, x1:x2]
+
+         if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
+             crop = np.pad(crop, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
+                           mode='constant', constant_values=0)
+
+         # Resize to output_size
+         if crop.shape[0] > 0 and crop.shape[1] > 0:
+             crop_tensor = torch.from_numpy(crop).float().permute(2, 0, 1).unsqueeze(0)
+             crop_tensor = F.interpolate(crop_tensor, size=(output_size, output_size),
+                                         mode='bilinear', align_corners=False)
+         else:
+             crop_tensor = torch.zeros(1, 3, output_size, output_size)
+
+         # Normalize to [0, 1]
+         crop_tensor = crop_tensor / 255.0
+
+         return crop_tensor
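+
+
+ if __name__ == '__main__':
+     # Minimal usage sketch (illustrative only, not part of the tracker API).
+     # It assumes an OpenCV video source and a hypothetical `build_model()`
+     # helper for constructing the ViL model; substitute your own model
+     # loading code and input path.
+     import cv2
+     from vil_tracker.models import build_model  # hypothetical import
+
+     model = build_model().to('cuda')
+     tracker = OnlineTracker(model, device='cuda')
+
+     cap = cv2.VideoCapture('video.mp4')
+     ok, frame = cap.read()
+     tracker.initialize(frame, [100, 100, 50, 80])  # [x, y, w, h]
+
+     while True:
+         ok, frame = cap.read()
+         if not ok:
+             break
+         x, y, w, h = tracker.track(frame)
+         cv2.rectangle(frame, (int(x), int(y)),
+                       (int(x + w), int(y + h)), (0, 255, 0), 2)
+         cv2.imshow('ViL Tracker', frame)
+         if cv2.waitKey(1) == 27:  # Esc to quit
+             break
+
+     cap.release()
+     cv2.destroyAllWindows()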