omar-ah committed on
Commit
69641c6
·
verified ·
1 Parent(s): 8237685

Fix vil_tracker/data/dataset.py: audit corrections

Browse files
Files changed (1) hide show
  1. vil_tracker/data/dataset.py +769 -60
vil_tracker/data/dataset.py CHANGED
@@ -1,80 +1,714 @@
1
  """
2
- Tracking dataset with synthetic fallback for testing.
3
 
4
  Supports:
5
- - GOT-10k, LaSOT, TrackingNet, COCO formats
 
 
 
6
  - Synthetic data generation for testing (no external data needed)
7
  - ACL (Adaptive Curriculum Learning) difficulty scaling
8
- - Standard tracking augmentations: jitter, flip, color aug
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  """
10
 
11
  import os
12
  import math
 
13
  import random
14
  import torch
15
  import numpy as np
16
- from torch.utils.data import Dataset
 
 
17
 
 
 
 
18
 
19
- class TrackingDataset(Dataset):
20
- """Tracking dataset for ViL Tracker training.
21
 
22
- Each sample provides:
23
- - template: (3, 128, 128) template crop
24
- - search: (3, 256, 256) search region crop
25
- - heatmap: (1, 16, 16) GT center heatmap
26
- - size: (2,) GT normalized [w, h]
27
- - boxes: (4,) GT [cx, cy, w, h] in search region pixels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  """
29
 
30
  def __init__(
31
  self,
32
- data_dir: str = None,
33
- split: str = 'train',
34
  template_size: int = 128,
35
  search_size: int = 256,
36
  feat_size: int = 16,
37
  acl_difficulty: float = 1.0,
38
- synthetic: bool = False,
39
- synthetic_length: int = 10000,
40
  ):
41
  super().__init__()
42
  self.template_size = template_size
43
  self.search_size = search_size
44
  self.feat_size = feat_size
45
  self.acl_difficulty = acl_difficulty
46
- self.synthetic = synthetic
47
- self.synthetic_length = synthetic_length
48
-
49
- if synthetic:
50
- self.samples = list(range(synthetic_length))
51
- else:
52
- self.samples = self._load_dataset(data_dir, split)
53
-
54
- def _load_dataset(self, data_dir, split):
55
- """Load dataset file list. Returns list of sample dicts."""
56
- samples = []
57
- if data_dir and os.path.exists(data_dir):
58
- # Load real dataset
59
- ann_file = os.path.join(data_dir, f'{split}.json')
60
- if os.path.exists(ann_file):
61
- import json
62
- with open(ann_file, 'r') as f:
63
- samples = json.load(f)
64
-
65
- if not samples:
66
- print(f"Warning: No data found at {data_dir}, using synthetic data")
67
- self.synthetic = True
68
- self.synthetic_length = 10000
69
- return list(range(self.synthetic_length))
70
-
71
- return samples
72
 
73
  def __len__(self):
74
- return len(self.samples) if not self.synthetic else self.synthetic_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- def _generate_synthetic_sample(self, idx):
77
- """Generate a synthetic template/search pair with GT annotations."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  rng = random.Random(idx)
79
 
80
  # Random target size (relative to search region)
@@ -88,7 +722,7 @@ class TrackingDataset(Dataset):
88
  cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
89
  cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
90
 
91
- # Create synthetic images (colored rectangles on noise background)
92
  template = torch.randn(3, self.template_size, self.template_size) * 0.1
93
  search = torch.randn(3, self.search_size, self.search_size) * 0.1
94
 
@@ -97,7 +731,7 @@ class TrackingDataset(Dataset):
97
  t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
98
  tc = self.template_size // 2
99
  color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
100
- template[:, tc-t_half_h:tc+t_half_h, tc-t_half_w:tc+t_half_w] = color
101
 
102
  # Draw target in search region
103
  sx1 = max(0, int(cx - target_w / 2))
@@ -119,10 +753,7 @@ class TrackingDataset(Dataset):
119
  dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
120
  heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
121
 
122
- # Normalized size
123
  size = torch.tensor([target_w / self.search_size, target_h / self.search_size])
124
-
125
- # Box in pixels
126
  boxes = torch.tensor([cx, cy, target_w, target_h])
127
 
128
  return {
@@ -133,15 +764,93 @@ class TrackingDataset(Dataset):
133
  'boxes': boxes,
134
  }
135
 
136
- def __getitem__(self, idx):
137
- if self.synthetic:
138
- return self._generate_synthetic_sample(idx)
139
-
140
- # Real data loading would go here
141
- sample = self.samples[idx]
142
- # ... load images, compute crops, generate targets
143
- return self._generate_synthetic_sample(idx) # fallback
144
-
145
  def set_acl_difficulty(self, difficulty: float):
146
- """Update ACL difficulty level (0.0 = easy, 1.0 = hard)."""
147
  self.acl_difficulty = min(1.0, max(0.0, difficulty))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Tracking dataset with real dataset loaders and synthetic fallback.
3
 
4
  Supports:
5
+ - GOT-10k: train split (~10k sequences, annotations in groundtruth.txt)
6
+ - LaSOT: training split (1120 sequences, 70 categories)
7
+ - TrackingNet: training split (30k+ sequences, annotations in anno/)
8
+ - COCO detection: for static pair pretraining (bbox crops as pseudo-sequences)
9
  - Synthetic data generation for testing (no external data needed)
10
  - ACL (Adaptive Curriculum Learning) difficulty scaling
11
+ - Standard tracking augmentations: spatial jitter, horizontal flip, color jitter,
12
+ grayscale, Gaussian blur, brightness/contrast
13
+
14
+ Each sample produces a (template, search) pair from the same video sequence
15
+ with controlled temporal distance, plus GT annotations.
16
+
17
+ Dataset directory structure expected:
18
+ GOT-10k/
19
+ train/
20
+ GOT-10k_Train_000001/
21
+ 00000001.jpg, 00000002.jpg, ...
22
+ groundtruth.txt # x,y,w,h per line
23
+ ...
24
+ LaSOT/
25
+ airplane/
26
+ airplane-1/
27
+ img/
28
+ 00000001.jpg, ...
29
+ groundtruth.txt # x,y,w,h per line
30
+ ...
31
+ TrackingNet/
32
+ TRAIN_0/
33
+ frames/
34
+ video_name/
35
+ 0.jpg, 1.jpg, ...
36
+ anno/
37
+ video_name.txt # x,y,w,h per line
38
+ ...
39
+ COCO/
40
+ train2017/
41
+ *.jpg
42
+ annotations/
43
+ instances_train2017.json
44
  """
45
 
46
  import os
47
  import math
48
+ import glob
49
  import random
50
  import torch
51
  import numpy as np
52
+ from pathlib import Path
53
+ from torch.utils.data import Dataset, ConcatDataset
54
+
55
 
56
+ # ============================================================
57
+ # Augmentations (no torchvision dependency, works with tensors)
58
+ # ============================================================
59
 
60
class TrackingAugmentation:
    """Standard tracking augmentations applied to (template, search) pairs.

    Colour jitter (brightness, contrast, saturation), grayscale conversion and
    horizontal flip are applied to both crops with the same random draw; the
    flip also mirrors the box centre. Gaussian blur is applied to the search
    crop only.

    Args:
        brightness: half-width of the uniform range for the brightness factor.
        contrast: half-width of the uniform range for the contrast factor.
        saturation: half-width of the uniform range for the saturation factor;
            0 disables the saturation step.
        grayscale_prob: probability of converting both crops to grayscale.
        horizontal_flip_prob: probability of mirroring both crops.
        blur_prob: probability of blurring the search crop.
        blur_sigma: (low, high) range the blur sigma is drawn from.
        color_jitter_prob: probability of applying the colour-jitter stage.
    """

    def __init__(
        self,
        brightness: float = 0.2,
        contrast: float = 0.2,
        saturation: float = 0.2,
        grayscale_prob: float = 0.05,
        horizontal_flip_prob: float = 0.5,
        blur_prob: float = 0.1,
        blur_sigma: tuple = (0.1, 2.0),
        color_jitter_prob: float = 0.8,
    ):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.grayscale_prob = grayscale_prob
        self.horizontal_flip_prob = horizontal_flip_prob
        self.blur_prob = blur_prob
        self.blur_sigma = blur_sigma
        self.color_jitter_prob = color_jitter_prob

    def __call__(self, template: torch.Tensor, search: torch.Tensor,
                 bbox: torch.Tensor) -> tuple:
        """Augment one (template, search) pair.

        Args:
            template: (3, H_t, W_t) tensor in [0, 1]
            search: (3, H_s, W_s) tensor in [0, 1]
            bbox: (4,) tensor [cx, cy, w, h] in search region pixels
        Returns:
            template, search, bbox (augmented). The input bbox is not modified
            in place; a clone is returned when the flip changes it.
        """
        # Colour jitter: one factor per step, shared by both crops so the
        # target keeps a consistent appearance across the pair.
        if random.random() < self.color_jitter_prob:
            # Brightness
            factor = 1.0 + random.uniform(-self.brightness, self.brightness)
            template = (template * factor).clamp(0, 1)
            search = (search * factor).clamp(0, 1)

            # Contrast: scale around each crop's own global mean
            factor = 1.0 + random.uniform(-self.contrast, self.contrast)
            template = self._blend(template, template.mean(), factor)
            search = self._blend(search, search.mean(), factor)

            # Saturation: scale around the per-pixel channel mean.
            # Skipped when disabled so no random number is consumed.
            if self.saturation > 0:
                factor = 1.0 + random.uniform(-self.saturation, self.saturation)
                template = self._blend(
                    template, template.mean(dim=0, keepdim=True), factor)
                search = self._blend(
                    search, search.mean(dim=0, keepdim=True), factor)

        # Grayscale. clone() materialises the expanded view so the three
        # channels do not alias one memory location.
        if random.random() < self.grayscale_prob:
            template = template.mean(dim=0, keepdim=True).expand_as(template).clone()
            search = search.mean(dim=0, keepdim=True).expand_as(search).clone()

        # Horizontal flip (must also flip bbox cx)
        if random.random() < self.horizontal_flip_prob:
            template = template.flip(-1)
            search = search.flip(-1)
            bbox = bbox.clone()
            bbox[0] = search.shape[-1] - bbox[0]

        # Gaussian blur (search only)
        if random.random() < self.blur_prob:
            sigma = random.uniform(*self.blur_sigma)
            kernel_size = int(2 * round(3 * sigma) + 1)
            # Reflect padding needs the pad width to be smaller than the
            # padded dimension; skip the blur on crops that are too small.
            fits = kernel_size // 2 < min(search.shape[-2], search.shape[-1])
            if kernel_size >= 3 and fits:
                search = self._gaussian_blur(search, kernel_size, sigma)

        return template, search, bbox

    @staticmethod
    def _blend(img: torch.Tensor, base, factor: float) -> torch.Tensor:
        """Scale the deviation of img from base by factor and clamp to [0, 1]."""
        return ((img - base) * factor + base).clamp(0, 1)

    @staticmethod
    def _gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Apply a separable Gaussian blur to a (C, H, W) tensor."""
        import torch.nn.functional as F

        # Normalised 1D Gaussian kernel centred on the middle tap
        offsets = torch.arange(kernel_size, dtype=img.dtype, device=img.device) - kernel_size // 2
        kernel_1d = torch.exp(-0.5 * (offsets / sigma) ** 2)
        kernel_1d = kernel_1d / kernel_1d.sum()

        pad = kernel_size // 2
        img = img.unsqueeze(0)  # (1, C, H, W)
        channels = img.shape[1]

        # Horizontal pass (depthwise: one kernel copy per channel)
        k_h = kernel_1d.view(1, 1, 1, -1).expand(channels, -1, -1, -1)
        img = F.conv2d(F.pad(img, (pad, pad, 0, 0), mode='reflect'),
                       k_h, groups=channels)

        # Vertical pass
        k_v = kernel_1d.view(1, 1, -1, 1).expand(channels, -1, -1, -1)
        img = F.conv2d(F.pad(img, (0, 0, pad, pad), mode='reflect'),
                       k_v, groups=channels)

        return img.squeeze(0)
158
+
159
+
160
+ # ============================================================
161
+ # Crop utilities
162
+ # ============================================================
163
+
164
def crop_and_resize(image: np.ndarray, center: np.ndarray, size: float,
                    output_size: int) -> np.ndarray:
    """Crop a square window from image and resize it to a fixed resolution.

    Parts of the window that fall outside the image are filled with the
    image's per-channel mean colour, including the case where the window does
    not overlap the image at all.

    Args:
        image: (H, W, 3) numpy array, uint8 or float
        center: (2,) [cx, cy] in image coordinates
        size: side length of the square crop
        output_size: resize crop to (output_size, output_size)
    Returns:
        (output_size, output_size, 3) float32 numpy array. All zeros when the
        rounded window has no area.
    """
    import torch.nn.functional as F

    img_h, img_w = image.shape[:2]
    half = size / 2

    x1 = int(round(center[0] - half))
    y1 = int(round(center[1] - half))
    x2 = int(round(center[0] + half))
    y2 = int(round(center[1] + half))

    crop_w, crop_h = x2 - x1, y2 - y1
    if crop_w <= 0 or crop_h <= 0:
        return np.zeros((output_size, output_size, 3), dtype=np.float32)

    # Part of the window that lies inside the image
    sx1, sy1 = max(0, x1), max(0, y1)
    sx2, sy2 = min(img_w, x2), min(img_h, y2)

    if (sx1, sy1, sx2, sy2) == (x1, y1, x2, y2):
        # Window fully inside the image: plain slice, no padding needed
        crop = image[y1:y2, x1:x2]
    else:
        # Build the full-size window first, then paste the overlapping part.
        # Working from the window size avoids negative slice bounds wrapping
        # around when the window lies entirely outside the image.
        mean_color = image.mean(axis=(0, 1))
        crop = np.full((crop_h, crop_w, 3), mean_color, dtype=image.dtype)
        if sx2 > sx1 and sy2 > sy1:
            crop[sy1 - y1:sy2 - y1, sx1 - x1:sx2 - x1] = image[sy1:sy2, sx1:sx2]

    # Bilinear resize through torch (no OpenCV/PIL dependency)
    crop_t = torch.from_numpy(np.ascontiguousarray(crop)).float()
    crop_t = crop_t.permute(2, 0, 1).unsqueeze(0)
    crop_t = F.interpolate(crop_t, size=(output_size, output_size),
                           mode='bilinear', align_corners=False)
    return crop_t.squeeze(0).permute(1, 2, 0).numpy()
216
+
217
+
218
def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple:
    """Compute crop center and size from bbox with context.

    Args:
        bbox: [x, y, w, h] bounding box (top-left corner plus size)
        context_factor: multiplier applied to the context-padded target size
    Returns:
        center: (2,) [cx, cy]
        crop_size: scalar side length, never smaller than 10
    """
    x, y, w, h = bbox
    cx = x + w / 2
    cy = y + h / 2

    # Degenerate annotations can carry a negative side; clamp so the product
    # under the square root cannot become negative.
    w_pos = max(w, 0)
    h_pos = max(h, 0)

    # Context-padded target size: each side is extended by the mean side
    # length m = (w + h) / 2 and the geometric mean is taken:
    #   s = sqrt((w + m) * (h + m))
    # This is the SiamFC-style formula s = sqrt((w + 2p) * (h + 2p)) with
    # p = (w + h) / 4.
    margin = (w_pos + h_pos) / 2
    crop_size = math.sqrt((w_pos + margin) * (h_pos + margin)) * context_factor
    crop_size = max(crop_size, 10)

    return np.array([cx, cy]), crop_size
239
+
240
+
241
+ # ============================================================
242
+ # Base sequence dataset
243
+ # ============================================================
244
+
245
class SequenceDataset(Dataset):
    """Base class for tracking sequence datasets.

    Subclasses must populate self.sequences with list of:
        {'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]}
    where a 'gt' entry may be None for frames without a usable annotation.

    Each item is a dict with keys 'template', 'search', 'heatmap', 'size'
    and 'boxes' built from two frames of one sequence.
    """

    def __init__(
        self,
        template_size: int = 128,
        search_size: int = 256,
        feat_size: int = 16,
        acl_difficulty: float = 1.0,
        max_gap: int = 100,
        augmentation: bool = True,
    ):
        super().__init__()
        self.template_size = template_size
        self.search_size = search_size
        self.feat_size = feat_size
        self.acl_difficulty = acl_difficulty
        self.max_gap = max_gap
        self.sequences = []

        self.augmentation = TrackingAugmentation() if augmentation else None

    def __len__(self):
        return len(self.sequences)

    @staticmethod
    def _is_valid_box(box) -> bool:
        """Return True for an [x, y, w, h] entry with positive width and height."""
        return box is not None and len(box) >= 4 and box[2] > 0 and box[3] > 0

    def _load_image(self, path: str) -> np.ndarray:
        """Load image from path. Returns (H, W, 3) float32 in [0, 255].

        Uses Pillow when it is importable, otherwise OpenCV. The OpenCV path
        returns a black 480x640 image for unreadable files.
        """
        try:
            from PIL import Image
        except ImportError:
            # Fallback with OpenCV
            import cv2
            img = cv2.imread(path)
            if img is None:
                return np.zeros((480, 640, 3), dtype=np.float32)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)

        # Context manager closes the file handle as soon as pixels are read
        with Image.open(path) as img:
            return np.array(img.convert('RGB'), dtype=np.float32)

    def _sample_pair(self, idx: int) -> tuple:
        """Sample a (template_frame_idx, search_frame_idx) pair.

        Temporal distance controlled by ACL difficulty:
        - difficulty=0: template and search are at most one frame apart
        - difficulty=1: template and search can be up to max_gap apart

        Returns:
            (template_idx, search_idx) frame indices. Both are equal when no
            second valid frame lies within the allowed gap; both are 0 when
            the sequence has no valid annotation at all.
        """
        seq = self.sequences[idx]
        gt = seq['gt']
        n_frames = len(seq['frames'])

        # Frames that carry a usable annotation (ascending order)
        valid_indices = [i for i in range(min(n_frames, len(gt)))
                         if self._is_valid_box(gt[i])]

        if len(valid_indices) < 2:
            t_idx = valid_indices[0] if valid_indices else 0
            return t_idx, t_idx

        t_idx = random.choice(valid_indices)

        # Search: within difficulty-scaled temporal gap
        effective_gap = max(1, int(self.max_gap * self.acl_difficulty))
        min_idx = max(0, t_idx - effective_gap)
        max_idx = min(n_frames - 1, t_idx + effective_gap)

        # Filter the valid list directly instead of scanning it once per
        # candidate frame.
        search_candidates = [i for i in valid_indices
                             if min_idx <= i <= max_idx and i != t_idx]

        if not search_candidates:
            return t_idx, t_idx

        return t_idx, random.choice(search_candidates)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with
            - 'template': (3, template_size, template_size) tensor in [0, 1]
            - 'search': (3, search_size, search_size) tensor in [0, 1]
            - 'heatmap': (1, feat_size, feat_size) Gaussian centred on target
            - 'size': (2,) target [w, h] divided by search_size
            - 'boxes': (4,) float32 [cx, cy, w, h] in search-crop pixels
        Raises:
            ValueError: if the chosen sequence has no valid ground-truth box.
        """
        seq_idx = idx % len(self.sequences)
        seq = self.sequences[seq_idx]
        t_idx, s_idx = self._sample_pair(seq_idx)

        t_box = seq['gt'][t_idx] if t_idx < len(seq['gt']) else None
        s_box = seq['gt'][s_idx] if s_idx < len(seq['gt']) else None
        if not (self._is_valid_box(t_box) and self._is_valid_box(s_box)):
            raise ValueError(
                f"Sequence {seq_idx} has no frame with a valid ground-truth box")

        # Load images
        t_img = self._load_image(seq['frames'][t_idx])
        s_img = self._load_image(seq['frames'][s_idx])

        t_bbox = np.array(t_box[:4], dtype=np.float32)  # [x, y, w, h]
        s_bbox = np.array(s_box[:4], dtype=np.float32)

        # Crop template (centered on target, context_factor=2)
        t_center, t_crop_size = compute_crop_params(t_bbox, context_factor=2.0)
        template = crop_and_resize(t_img, t_center, t_crop_size, self.template_size)

        # Crop search region (centered on target with jitter, context_factor=4)
        s_center, s_crop_size = compute_crop_params(s_bbox, context_factor=4.0)

        # Spatial jitter of the crop centre, scaled by ACL difficulty and by
        # the mean target side length
        jitter = float(self.acl_difficulty * s_bbox[2:4].mean() * 0.3)
        if jitter > 0:
            s_center[0] += random.gauss(0, jitter)
            s_center[1] += random.gauss(0, jitter)

        search = crop_and_resize(s_img, s_center, s_crop_size, self.search_size)

        # Target box in search-crop pixels: offset from the crop's top-left
        # corner, then scaled from crop size to search_size
        scale = self.search_size / s_crop_size
        half_crop = s_crop_size / 2
        cx_in_search = (s_bbox[0] + s_bbox[2] / 2 - s_center[0] + half_crop) * scale
        cy_in_search = (s_bbox[1] + s_bbox[3] / 2 - s_center[1] + half_crop) * scale
        w_in_search = s_bbox[2] * scale
        h_in_search = s_bbox[3] * scale

        # Clamp to search region
        cx_in_search = float(max(0, min(self.search_size, cx_in_search)))
        cy_in_search = float(max(0, min(self.search_size, cy_in_search)))
        w_in_search = float(max(1, min(self.search_size, w_in_search)))
        h_in_search = float(max(1, min(self.search_size, h_in_search)))

        # Convert to tensors [0, 1]
        template = torch.from_numpy(template).float().permute(2, 0, 1) / 255.0
        search = torch.from_numpy(search).float().permute(2, 0, 1) / 255.0
        # Explicit dtype: otherwise it depends on NumPy scalar promotion rules
        bbox_tensor = torch.tensor(
            [cx_in_search, cy_in_search, w_in_search, h_in_search],
            dtype=torch.float32)

        # Apply augmentations
        if self.augmentation is not None:
            template, search, bbox_tensor = self.augmentation(template, search, bbox_tensor)

        # Generate GT heatmap on the feat_size x feat_size grid.
        # NOTE(review): the centre is compared against integer cell indices
        # (no half-cell offset) -- confirm this matches the head's decoding.
        stride = self.search_size / self.feat_size
        cx_feat = bbox_tensor[0].item() / stride
        cy_feat = bbox_tensor[1].item() / stride

        grid = torch.arange(self.feat_size, dtype=torch.float32)
        yy, xx = torch.meshgrid(grid, grid, indexing='ij')

        # Sigma grows with target size, limited to [1, 3] feature cells
        sigma = max(1.0, min(3.0, (w_in_search + h_in_search) / (2 * stride * 4)))
        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
        heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)

        # Normalized size
        size = torch.tensor([bbox_tensor[2].item() / self.search_size,
                             bbox_tensor[3].item() / self.search_size])

        return {
            'template': template,
            'search': search,
            'heatmap': heatmap,
            'size': size,
            'boxes': bbox_tensor,
        }

    def set_acl_difficulty(self, difficulty: float):
        """Update ACL difficulty level (0.0 = easy, 1.0 = hard); clamped to [0, 1]."""
        self.acl_difficulty = min(1.0, max(0.0, difficulty))
403
+
404
+
405
+ # ============================================================
406
+ # GOT-10k dataset loader
407
+ # ============================================================
408
+
409
class GOT10kDataset(SequenceDataset):
    """GOT-10k tracking dataset.

    Structure:
        root/<split>/<sequence_dir>/
            00000001.jpg, 00000002.jpg, ...
            groundtruth.txt  # x,y,w,h per line

    Every sub-directory of root/<split> that contains a groundtruth.txt is
    treated as a sequence.
    """

    def __init__(self, root: str, split: str = 'train', **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_sequences(split)

    @staticmethod
    def _read_boxes(gt_file) -> list:
        """Parse one annotation file into a list of [x, y, w, h] or None.

        Blank lines, lines with fewer than four fields and lines with
        non-numeric fields become None so that the list stays aligned with
        the frame list.
        """
        boxes = []
        with open(gt_file, 'r') as f:
            for line in f:
                parts = line.strip().replace(',', ' ').split()
                if len(parts) < 4:
                    boxes.append(None)
                    continue
                try:
                    boxes.append([float(v) for v in parts[:4]])
                except ValueError:
                    boxes.append(None)
        return boxes

    def _load_sequences(self, split):
        """Populate self.sequences from root/<split>."""
        split_dir = self.root / split
        if not split_dir.exists():
            print(f"Warning: GOT-10k {split} not found at {split_dir}")
            return

        # No name filter here: directories without groundtruth.txt are
        # skipped below, so splits other than 'train' load as well.
        seq_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        print(f"Loading GOT-10k {split}: found {len(seq_dirs)} sequences")

        for seq_dir in seq_dirs:
            gt_file = seq_dir / 'groundtruth.txt'
            if not gt_file.exists():
                continue

            gt_boxes = self._read_boxes(gt_file)

            # Frame paths; zero-padded names sort correctly as strings
            frames = sorted(glob.glob(str(seq_dir / '*.jpg')))
            if not frames:
                frames = sorted(glob.glob(str(seq_dir / '*.png')))

            # Trim both lists to the shorter one so indices stay aligned
            min_len = min(len(frames), len(gt_boxes))
            frames = frames[:min_len]
            gt_boxes = gt_boxes[:min_len]

            if len(frames) >= 2:
                self.sequences.append({'frames': frames, 'gt': gt_boxes})

        print(f" Loaded {len(self.sequences)} GOT-10k sequences")
466
+
467
+
468
+ # ============================================================
469
+ # LaSOT dataset loader
470
+ # ============================================================
471
+
472
class LaSOTDataset(SequenceDataset):
    """LaSOT tracking dataset.

    Structure:
        root/
            airplane/
                airplane-1/
                    img/
                        00000001.jpg, ...
                    groundtruth.txt  # x,y,w,h per line
            ...

    The split is taken per category from the sorted sequence directories:
    'train' keeps the first 80%, any other value keeps the remaining 20%.
    NOTE(review): this is not read from an official split list -- confirm it
    is the intended protocol.
    """

    def __init__(self, root: str, split: str = 'train', **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_sequences(split)

    @staticmethod
    def _read_boxes(gt_file) -> list:
        """Parse one annotation file into a list of [x, y, w, h] or None.

        Blank lines, lines with fewer than four fields and lines with
        non-numeric fields become None so that the list stays aligned with
        the frame list.
        """
        boxes = []
        with open(gt_file, 'r') as f:
            for line in f:
                parts = line.strip().replace(',', ' ').split()
                if len(parts) < 4:
                    boxes.append(None)
                    continue
                try:
                    boxes.append([float(v) for v in parts[:4]])
                except ValueError:
                    boxes.append(None)
        return boxes

    def _load_sequences(self, split):
        """Populate self.sequences from the category directories under root."""
        if not self.root.exists():
            print(f"Warning: LaSOT not found at {self.root}")
            return

        categories = sorted(d for d in self.root.iterdir() if d.is_dir())
        total_seqs = 0

        for cat_dir in categories:
            seq_dirs = sorted(d for d in cat_dir.iterdir() if d.is_dir())

            # Train/test split: first 80% vs. the rest, per category
            cut = int(len(seq_dirs) * 0.8)
            seq_dirs = seq_dirs[:cut] if split == 'train' else seq_dirs[cut:]

            for seq_dir in seq_dirs:
                gt_file = seq_dir / 'groundtruth.txt'
                img_dir = seq_dir / 'img'

                if not gt_file.exists() or not img_dir.exists():
                    continue

                gt_boxes = self._read_boxes(gt_file)
                frames = sorted(glob.glob(str(img_dir / '*.jpg')))

                # Trim both lists to the shorter one so indices stay aligned
                min_len = min(len(frames), len(gt_boxes))
                frames = frames[:min_len]
                gt_boxes = gt_boxes[:min_len]

                if len(frames) >= 2:
                    self.sequences.append({'frames': frames, 'gt': gt_boxes})
                    total_seqs += 1

        print(f" Loaded {total_seqs} LaSOT {split} sequences across {len(categories)} categories")
542
+
543
+
544
+ # ============================================================
545
+ # TrackingNet dataset loader
546
+ # ============================================================
547
+
548
class TrackingNetDataset(SequenceDataset):
    """TrackingNet tracking dataset.

    Structure:
        root/
            TRAIN_0/
                frames/
                    video_name/
                        0.jpg, 1.jpg, ...
                anno/
                    video_name.txt  # x,y,w,h per line
            TRAIN_1/
            ...

    Args:
        root: dataset root containing the TRAIN_<k> chunk directories.
        chunks: chunk indices to load; defaults to 0..11.
    """

    def __init__(self, root: str, chunks: list = None, **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        if chunks is None:
            chunks = list(range(12))  # TRAIN_0 through TRAIN_11
        self._load_sequences(chunks)

    @staticmethod
    def _read_boxes(anno_file) -> list:
        """Parse one annotation file into a list of [x, y, w, h] or None.

        Blank lines, lines with fewer than four fields and lines with
        non-numeric fields become None so that the list stays aligned with
        the frame list.
        """
        boxes = []
        with open(anno_file, 'r') as f:
            for line in f:
                parts = line.strip().replace(',', ' ').split()
                if len(parts) < 4:
                    boxes.append(None)
                    continue
                try:
                    boxes.append([float(v) for v in parts[:4]])
                except ValueError:
                    boxes.append(None)
        return boxes

    @staticmethod
    def _frame_sort_key(path: str) -> tuple:
        """Order frames by numeric stem so that 2.jpg precedes 10.jpg.

        Non-numeric names sort after numeric ones, alphabetically.
        """
        stem = Path(path).stem
        if stem.isdigit():
            return (0, int(stem), stem)
        return (1, 0, stem)

    def _load_sequences(self, chunks):
        """Populate self.sequences from the requested TRAIN_<k> chunks."""
        if not self.root.exists():
            print(f"Warning: TrackingNet not found at {self.root}")
            return

        total_seqs = 0
        for chunk_idx in chunks:
            chunk_dir = self.root / f'TRAIN_{chunk_idx}'
            if not chunk_dir.exists():
                continue

            anno_dir = chunk_dir / 'anno'
            frames_dir = chunk_dir / 'frames'

            if not anno_dir.exists() or not frames_dir.exists():
                continue

            for anno_file in sorted(anno_dir.glob('*.txt')):
                seq_frames_dir = frames_dir / anno_file.stem

                if not seq_frames_dir.exists():
                    continue

                gt_boxes = self._read_boxes(anno_file)

                # Frame names are not zero-padded, so a plain string sort
                # would put 10.jpg before 2.jpg and misalign frames with
                # their annotation rows.
                frames = sorted(glob.glob(str(seq_frames_dir / '*.jpg')),
                                key=self._frame_sort_key)
                if not frames:
                    frames = sorted(glob.glob(str(seq_frames_dir / '*.png')),
                                    key=self._frame_sort_key)

                # Trim both lists to the shorter one so indices stay aligned
                min_len = min(len(frames), len(gt_boxes))
                frames = frames[:min_len]
                gt_boxes = gt_boxes[:min_len]

                if len(frames) >= 2:
                    self.sequences.append({'frames': frames, 'gt': gt_boxes})
                    total_seqs += 1

        print(f" Loaded {total_seqs} TrackingNet sequences from {len(chunks)} chunks")
622
+
623
+
624
+ # ============================================================
625
+ # COCO detection as pseudo-sequences
626
+ # ============================================================
627
+
628
class COCODetDataset(SequenceDataset):
    """COCO detection images as pseudo-sequences for pretraining.

    Each non-crowd annotation whose box is at least 10 pixels wide and high
    becomes a two-frame "sequence" in which both frames are the same image
    with the same box, so template and search are crops of one image.

    Args:
        root: directory containing the images (e.g. COCO/train2017).
        ann_file: path to the instances JSON; defaults to
            <root>/../annotations/instances_train2017.json.
    """

    def __init__(self, root: str, ann_file: str = None, **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_annotations(ann_file)

    def _load_annotations(self, ann_file):
        """Populate self.sequences from a COCO instances annotation file.

        Loading is best-effort: a missing file or a failure while reading it
        prints a warning and leaves whatever was loaded so far.
        """
        if ann_file is None:
            ann_file = str(self.root.parent / 'annotations' / 'instances_train2017.json')

        if not os.path.exists(ann_file):
            print(f"Warning: COCO annotations not found at {ann_file}")
            return

        try:
            import json
            with open(ann_file, 'r') as f:
                coco = json.load(f)

            # Build image lookup
            images = {img['id']: img for img in coco['images']}

            # image_id -> resolved path, or None when the file is missing.
            # An image usually has several annotations, so the filesystem is
            # checked once per image rather than once per annotation.
            resolved = {}

            for ann in coco['annotations']:
                if ann.get('iscrowd', 0):
                    continue
                bbox = ann.get('bbox')  # [x, y, w, h]
                if bbox is None or len(bbox) < 4:
                    continue
                if bbox[2] < 10 or bbox[3] < 10:
                    continue

                image_id = ann.get('image_id')
                if image_id not in resolved:
                    img_info = images.get(image_id)
                    path = None
                    if img_info is not None:
                        candidate = str(self.root / img_info['file_name'])
                        if os.path.exists(candidate):
                            path = candidate
                    resolved[image_id] = path

                img_path = resolved[image_id]
                if img_path is None:
                    continue

                # Pseudo-sequence: same frame for template and search
                self.sequences.append({
                    'frames': [img_path, img_path],
                    'gt': [bbox, bbox],
                })

            print(f" Loaded {len(self.sequences)} COCO pseudo-sequences")

        except Exception as e:
            # Deliberate best-effort: COCO is optional pretraining data
            print(f"Warning: Failed to load COCO annotations: {e}")
680
+
681
+
682
+ # ============================================================
683
+ # Synthetic dataset (for testing / no-data development)
684
+ # ============================================================
685
+
686
+ class SyntheticTrackingDataset(Dataset):
687
+ """Synthetic tracking dataset for testing without real data.
688
+
689
+ Generates colored rectangles on noise backgrounds with controlled
690
+ position jitter based on ACL difficulty.
691
+ """
692
+
693
+ def __init__(
694
+ self,
695
+ length: int = 10000,
696
+ template_size: int = 128,
697
+ search_size: int = 256,
698
+ feat_size: int = 16,
699
+ acl_difficulty: float = 1.0,
700
+ ):
701
+ super().__init__()
702
+ self.length = length
703
+ self.template_size = template_size
704
+ self.search_size = search_size
705
+ self.feat_size = feat_size
706
+ self.acl_difficulty = acl_difficulty
707
+
708
+ def __len__(self):
709
+ return self.length
710
+
711
+ def __getitem__(self, idx):
712
  rng = random.Random(idx)
713
 
714
  # Random target size (relative to search region)
 
722
  cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
723
  cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
724
 
725
+ # Create synthetic images
726
  template = torch.randn(3, self.template_size, self.template_size) * 0.1
727
  search = torch.randn(3, self.search_size, self.search_size) * 0.1
728
 
 
731
  t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
732
  tc = self.template_size // 2
733
  color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
734
+ template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color
735
 
736
  # Draw target in search region
737
  sx1 = max(0, int(cx - target_w / 2))
 
753
  dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
754
  heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
755
 
 
756
  size = torch.tensor([target_w / self.search_size, target_h / self.search_size])
 
 
757
  boxes = torch.tensor([cx, cy, target_w, target_h])
758
 
759
  return {
 
764
  'boxes': boxes,
765
  }
766
 
 
 
 
 
 
 
 
 
 
767
  def set_acl_difficulty(self, difficulty: float):
 
768
  self.acl_difficulty = min(1.0, max(0.0, difficulty))
769
+
770
+
771
+ # ============================================================
772
+ # Convenience: build combined dataset
773
+ # ============================================================
774
+
775
class ACLConcatDataset(ConcatDataset):
    """ConcatDataset that forwards ACL difficulty updates to its members.

    A plain ConcatDataset has no set_acl_difficulty(), so curriculum updates
    written against a single dataset would fail on the combined dataset.
    """

    def __init__(self, datasets, acl_difficulty: float = 0.0):
        super().__init__(datasets)
        self.acl_difficulty = min(1.0, max(0.0, acl_difficulty))

    def set_acl_difficulty(self, difficulty: float):
        """Clamp difficulty to [0, 1] and pass it to every member that supports it."""
        self.acl_difficulty = min(1.0, max(0.0, difficulty))
        for ds in self.datasets:
            setter = getattr(ds, 'set_acl_difficulty', None)
            if callable(setter):
                setter(self.acl_difficulty)


def build_tracking_dataset(
    data_config: dict,
    template_size: int = 128,
    search_size: int = 256,
    feat_size: int = 16,
    acl_difficulty: float = 0.0,
) -> Dataset:
    """Build a combined tracking dataset from multiple sources.

    Args:
        data_config: dict with optional keys:
            - 'got10k_root': path to GOT-10k dataset
            - 'lasot_root': path to LaSOT dataset
            - 'trackingnet_root': path to TrackingNet dataset
            - 'coco_root': path to COCO train2017 images
            - 'synthetic_length': number of synthetic samples (fallback)
            A root that is missing, empty/None, or does not exist on disk is
            skipped.
        template_size: template crop size
        search_size: search region crop size
        feat_size: feature map spatial size
        acl_difficulty: initial ACL difficulty
    Returns:
        A ConcatDataset (with set_acl_difficulty support) over every source
        that yielded at least one sequence, or a SyntheticTrackingDataset
        when none did.
    """
    common_kwargs = dict(
        template_size=template_size,
        search_size=search_size,
        feat_size=feat_size,
        acl_difficulty=acl_difficulty,
    )

    # (config key, label, unit, factory). The loader classes are looked up
    # when the factory runs, i.e. only for sources that are configured.
    sources = (
        ('got10k_root', 'GOT-10k', 'sequences',
         lambda root: GOT10kDataset(root, split='train', **common_kwargs)),
        ('lasot_root', 'LaSOT', 'sequences',
         lambda root: LaSOTDataset(root, split='train', **common_kwargs)),
        ('trackingnet_root', 'TrackingNet', 'sequences',
         lambda root: TrackingNetDataset(root, **common_kwargs)),
        ('coco_root', 'COCO', 'pseudo-sequences',
         lambda root: COCODetDataset(root, **common_kwargs)),
    )

    datasets = []
    for key, label, unit, factory in sources:
        root = data_config.get(key)
        if not root or not os.path.exists(root):
            continue
        ds = factory(root)
        if len(ds) > 0:
            datasets.append(ds)
            print(f"{label}: {len(ds)} {unit}")

    if datasets:
        combined = ACLConcatDataset(datasets, acl_difficulty=acl_difficulty)
        print(f"\nTotal training samples: {len(combined)}")
        return combined

    # Fallback to synthetic
    syn_len = data_config.get('synthetic_length', 10000)
    print(f"No real data found, using {syn_len} synthetic samples")
    return SyntheticTrackingDataset(
        length=syn_len,
        template_size=template_size,
        search_size=search_size,
        feat_size=feat_size,
        acl_difficulty=acl_difficulty,
    )
846
+
847
+
848
+ # ============================================================
849
+ # Legacy alias for backward compatibility
850
+ # ============================================================
851
+
852
class TrackingDataset(SyntheticTrackingDataset):
    """Backward-compatible alias for SyntheticTrackingDataset.

    The legacy arguments data_dir, split and synthetic are accepted so old
    call sites keep working, but they are not used: this class always
    generates synthetic samples. A warning is issued when data_dir is given
    so the caller notices that no real data is loaded.

    Args:
        data_dir: ignored (legacy).
        split: ignored (legacy).
        synthetic: ignored (legacy).
        synthetic_length: number of synthetic samples.
        **kwargs: forwarded to SyntheticTrackingDataset.
    """

    def __init__(self, data_dir=None, split='train', synthetic=False,
                 synthetic_length=10000, **kwargs):
        if data_dir is not None:
            import warnings
            warnings.warn(
                "TrackingDataset ignores 'data_dir' and 'split' and always "
                "generates synthetic samples; use build_tracking_dataset() "
                "to load real datasets.",
                stacklevel=2,
            )
        super().__init__(length=synthetic_length, **kwargs)