omar-ah committed (verified)
Commit 9bef6c8 · 1 Parent(s): be1f14e

Sequence training: pairs→K-frame clips, mLSTM memory carries across frames

Files changed (1):
  vil_tracker/data/dataset.py (+198 −132)
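
Review note: downstream, a training step consumes the new batch format by rolling the mLSTM through the K search frames while carrying its memory state forward, which is what makes the clip format useful. A minimal sketch of that consumer side (the model interface encode_template / init_memory / step and the criterion signature are hypothetical; only the batch keys and shapes come from this diff):

def train_step(model, batch, criterion, optimizer):
    # batch comes from the new SequenceDataset.__getitem__, collated by a
    # DataLoader: template (B, 3, 128, 128), searches (B, K, 3, 256, 256),
    # heatmaps (B, K, 1, 16, 16), sizes (B, K, 2)
    template = batch['template']
    searches = batch['searches']
    B, K = searches.shape[:2]

    # Hypothetical model interface: encode the template once, then step
    # through the K search frames, carrying the mLSTM memory across frames
    z = model.encode_template(template)
    memory = model.init_memory(B)

    loss = 0.0
    for k in range(K):
        pred_heatmap, pred_size, memory = model.step(z, searches[:, k], memory)
        loss = loss + criterion(pred_heatmap, batch['heatmaps'][:, k],
                                pred_size, batch['sizes'][:, k])

    optimizer.zero_grad()
    (loss / K).backward()  # backprop through time across the whole clip
    optimizer.step()
    return loss.item() / K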
vil_tracker/data/dataset.py CHANGED
@@ -245,6 +245,10 @@ def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple:
 class SequenceDataset(Dataset):
     """Base class for tracking sequence datasets.
 
+    Returns K-frame clips: template + K consecutive search frames.
+    The mLSTM processes these as one long sequence in which memory carries
+    information across frames; this is the core training paradigm.
+
     Subclasses must populate self.sequences with list of:
     {'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]}
     """
@@ -256,6 +260,7 @@ class SequenceDataset(Dataset):
         feat_size: int = 16,
         acl_difficulty: float = 1.0,
         max_gap: int = 100,
+        clip_length: int = 3,
         augmentation: bool = True,
     ):
         super().__init__()
@@ -264,6 +269,7 @@ class SequenceDataset(Dataset):
         self.feat_size = feat_size
         self.acl_difficulty = acl_difficulty
         self.max_gap = max_gap
+        self.clip_length = clip_length  # K search frames per sample
         self.sequences = []
 
         self.augmentation = TrackingAugmentation() if augmentation else None
@@ -278,123 +284,168 @@ class SequenceDataset(Dataset):
             img = Image.open(path).convert('RGB')
             return np.array(img, dtype=np.float32)
         except ImportError:
-            # Fallback with OpenCV
             import cv2
             img = cv2.imread(path)
             if img is None:
                 return np.zeros((480, 640, 3), dtype=np.float32)
             return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
 
-    def _sample_pair(self, idx: int) -> tuple:
-        """Sample a (template_frame_idx, search_frame_idx) pair.
-
-        Temporal distance controlled by ACL difficulty:
-        - difficulty=0: template and search are very close
-        - difficulty=1: template and search can be up to max_gap apart
+    def _sample_clip(self, idx: int) -> list:
+        """Sample a clip: template frame + K consecutive search frames.
 
         Returns:
-            (template_idx, search_idx) frame indices
+            list of frame indices: [template_idx, search_1_idx, ..., search_K_idx]
         """
         seq = self.sequences[idx]
         n_frames = len(seq['frames'])
+        K = self.clip_length
 
-        # Template: sample random frame with valid annotation
-        valid_indices = [i for i in range(n_frames) if seq['gt'][i] is not None and
-                         seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0]
+        valid = [i for i in range(n_frames)
+                 if seq['gt'][i] is not None and seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0]
+        valid_set = set(valid)
 
-        if len(valid_indices) < 2:
-            t_idx = valid_indices[0] if valid_indices else 0
-            return t_idx, t_idx
+        if len(valid) < K + 1:
+            # Not enough valid frames; repeat what we have
+            if len(valid) == 0:
+                return [0] * (K + 1)
+            return [valid[0]] + [valid[min(i, len(valid) - 1)] for i in range(K)]
 
-        t_idx = random.choice(valid_indices)
+        # Template: pick a random valid frame
+        t_idx = random.choice(valid)
 
-        # Search: within difficulty-scaled temporal gap
+        # Search frames: K consecutive valid frames AFTER the template.
+        # The template-to-first-search gap is controlled by ACL difficulty.
         effective_gap = max(1, int(self.max_gap * self.acl_difficulty))
-        min_idx = max(0, t_idx - effective_gap)
-        max_idx = min(n_frames - 1, t_idx + effective_gap)
-
-        # Only pick valid indices
-        search_candidates = [i for i in range(min_idx, max_idx + 1)
-                             if i != t_idx and i in valid_indices]
-
-        if not search_candidates:
-            return t_idx, t_idx
-
-        s_idx = random.choice(search_candidates)
-        return t_idx, s_idx
-
-    def __getitem__(self, idx):
-        seq = self.sequences[idx % len(self.sequences)]
-        t_idx, s_idx = self._sample_pair(idx % len(self.sequences))
-
-        # Load images
-        t_img = self._load_image(seq['frames'][t_idx])
-        s_img = self._load_image(seq['frames'][s_idx])
-
-        t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32)  # [x, y, w, h]
-        s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32)
-
-        # Crop template (centered on target, 2x context)
-        t_center, t_crop_size = compute_crop_params(t_bbox, context_factor=2.0)
-        template = crop_and_resize(t_img, t_center, t_crop_size, self.template_size)
-
-        # Crop search region (centered on target with jitter, 4x context)
-        s_center, s_crop_size = compute_crop_params(s_bbox, context_factor=4.0)
-
-        # Add spatial jitter (controlled by ACL difficulty)
-        jitter = self.acl_difficulty * s_bbox[2:4].mean() * 0.3
-        s_center[0] += random.gauss(0, jitter) if jitter > 0 else 0
-        s_center[1] += random.gauss(0, jitter) if jitter > 0 else 0
-
-        search = crop_and_resize(s_img, s_center, s_crop_size, self.search_size)
-
-        # Compute GT in search crop coordinates
-        # Target center relative to crop center, then scaled to search_size
-        scale = self.search_size / s_crop_size
-        cx_in_search = (s_bbox[0] + s_bbox[2] / 2 - s_center[0] + s_crop_size / 2) * scale
-        cy_in_search = (s_bbox[1] + s_bbox[3] / 2 - s_center[1] + s_crop_size / 2) * scale
-        w_in_search = s_bbox[2] * scale
-        h_in_search = s_bbox[3] * scale
-
-        # Clamp to search region
-        cx_in_search = max(0, min(self.search_size, cx_in_search))
-        cy_in_search = max(0, min(self.search_size, cy_in_search))
-        w_in_search = max(1, min(self.search_size, w_in_search))
-        h_in_search = max(1, min(self.search_size, h_in_search))
-
-        # Convert to tensors [0, 1]
-        template = torch.from_numpy(template).float().permute(2, 0, 1) / 255.0
-        search = torch.from_numpy(search).float().permute(2, 0, 1) / 255.0
-        bbox_tensor = torch.tensor([cx_in_search, cy_in_search, w_in_search, h_in_search])
-
-        # Apply augmentations
-        if self.augmentation is not None:
-            template, search, bbox_tensor = self.augmentation(template, search, bbox_tensor)
-
-        # Generate GT heatmap
+        # Find the start of the search clip: somewhere after the template
+        min_start = t_idx + 1
+        max_start = min(t_idx + effective_gap, n_frames - K)
+
+        if max_start < min_start:
+            # Try before the template
+            max_start_before = t_idx - K
+            min_start_before = max(0, t_idx - effective_gap - K)
+            if max_start_before >= min_start_before and max_start_before >= 0:
+                clip_start = random.randint(min_start_before, max_start_before)
+            else:
+                # Fallback: frames right after the template; the collection
+                # loop below skips t_idx, so the template is never reused
+                clip_start = max(0, min(n_frames - K, t_idx + 1))
+        else:
+            clip_start = random.randint(min_start, max(min_start, max_start))
+
+        # Collect K consecutive frames, preferring valid ones
+        search_indices = []
+        for i in range(clip_start, min(clip_start + K * 3, n_frames)):
+            if i in valid_set and i != t_idx:
+                search_indices.append(i)
+            if len(search_indices) == K:
+                break
+
+        # Pad if we didn't find enough
+        while len(search_indices) < K:
+            search_indices.append(search_indices[-1] if search_indices else t_idx)
+
+        return [t_idx] + search_indices[:K]
+
+    def _process_frame(self, img: np.ndarray, bbox: np.ndarray, is_template: bool):
+        """Crop and preprocess a single frame.
+
+        Returns:
+            image_tensor: (3, H, W) float [0, 1]
+            bbox_in_crop: (4,) [cx, cy, w, h] in crop coordinates
+        """
+        if is_template:
+            center, crop_size = compute_crop_params(bbox, context_factor=2.0)
+            output_size = self.template_size
+        else:
+            center, crop_size = compute_crop_params(bbox, context_factor=4.0)
+            output_size = self.search_size
+            # Spatial jitter for search (controlled by ACL)
+            jitter = self.acl_difficulty * bbox[2:4].mean() * 0.3
+            if jitter > 0:
+                center[0] += random.gauss(0, jitter)
+                center[1] += random.gauss(0, jitter)
+
+        crop = crop_and_resize(img, center, crop_size, output_size)
+
+        # Compute GT in crop coordinates
+        scale = output_size / crop_size
+        cx = (bbox[0] + bbox[2] / 2 - center[0] + crop_size / 2) * scale
+        cy = (bbox[1] + bbox[3] / 2 - center[1] + crop_size / 2) * scale
+        w = bbox[2] * scale
+        h = bbox[3] * scale
+
+        cx = max(0, min(output_size, cx))
+        cy = max(0, min(output_size, cy))
+        w = max(1, min(output_size, w))
+        h = max(1, min(output_size, h))
+
+        tensor = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0
+        bbox_crop = torch.tensor([cx, cy, w, h])
+
+        return tensor, bbox_crop
+
+    def _make_heatmap(self, bbox: torch.Tensor):
+        """Generate GT heatmap from bbox in search crop coordinates."""
         stride = self.search_size / self.feat_size
-        cx_feat = bbox_tensor[0].item() / stride
-        cy_feat = bbox_tensor[1].item() / stride
+        cx_feat = bbox[0].item() / stride
+        cy_feat = bbox[1].item() / stride
+        w_search = bbox[2].item()
+        h_search = bbox[3].item()
 
         y = torch.arange(self.feat_size, dtype=torch.float32)
         x = torch.arange(self.feat_size, dtype=torch.float32)
         yy, xx = torch.meshgrid(y, x, indexing='ij')
 
-        # Adaptive sigma based on target size (smaller targets = sharper heatmap)
-        sigma = max(1.0, min(3.0, (w_in_search + h_in_search) / (2 * stride * 4)))
+        sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
         dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
         heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
+        return heatmap
+
+    def __getitem__(self, idx):
+        seq = self.sequences[idx % len(self.sequences)]
+        clip_indices = self._sample_clip(idx % len(self.sequences))
 
-        # Normalized size
-        size = torch.tensor([bbox_tensor[2].item() / self.search_size,
-                             bbox_tensor[3].item() / self.search_size])
+        t_idx = clip_indices[0]
+        s_indices = clip_indices[1:]
+        K = len(s_indices)
+
+        # Load and process the template
+        t_img = self._load_image(seq['frames'][t_idx])
+        t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32)
+        template, _ = self._process_frame(t_img, t_bbox, is_template=True)
+
+        # Load and process the K search frames
+        searches = []
+        heatmaps = []
+        sizes = []
+        boxes = []
+
+        for s_idx in s_indices:
+            s_img = self._load_image(seq['frames'][s_idx])
+            s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32)
+            search, bbox_crop = self._process_frame(s_img, s_bbox, is_template=False)
+
+            # Apply augmentation (same color transform for template+search consistency)
+            if self.augmentation is not None:
+                template_aug, search, bbox_crop = self.augmentation(template, search, bbox_crop)
+                # Use the augmented template from the first search frame only
+                if len(searches) == 0:
+                    template = template_aug
+
+            searches.append(search)
+            heatmaps.append(self._make_heatmap(bbox_crop))
+            sizes.append(torch.tensor([bbox_crop[2].item() / self.search_size,
+                                       bbox_crop[3].item() / self.search_size]))
+            boxes.append(bbox_crop)
 
         return {
-            'template': template,
-            'search': search,
-            'heatmap': heatmap,
-            'size': size,
-            'boxes': bbox_tensor,
+            'template': template,                      # (3, 128, 128)
+            'searches': torch.stack(searches, dim=0),  # (K, 3, 256, 256)
+            'heatmaps': torch.stack(heatmaps, dim=0),  # (K, 1, 16, 16)
+            'sizes': torch.stack(sizes, dim=0),        # (K, 2)
+            'boxes': torch.stack(boxes, dim=0),        # (K, 4)
         }
 
     def set_acl_difficulty(self, difficulty: float):
@@ -686,8 +737,8 @@ class COCODetDataset(SequenceDataset):
 class SyntheticTrackingDataset(Dataset):
     """Synthetic tracking dataset for testing without real data.
 
-    Generates colored rectangles on noise backgrounds with controlled
-    position jitter based on ACL difficulty.
+    Generates K-frame clips: template + K search frames with a moving
+    colored rectangle target. Motion is linear with noise.
     """
 
     def __init__(
@@ -697,6 +748,7 @@ class SyntheticTrackingDataset(Dataset):
         search_size: int = 256,
         feat_size: int = 16,
         acl_difficulty: float = 1.0,
+        clip_length: int = 3,
     ):
         super().__init__()
        self.length = length
@@ -704,64 +756,78 @@ class SyntheticTrackingDataset(Dataset):
         self.search_size = search_size
         self.feat_size = feat_size
         self.acl_difficulty = acl_difficulty
+        self.clip_length = clip_length
 
     def __len__(self):
         return self.length
 
+    def _make_heatmap(self, cx, cy, w_search, h_search):
+        stride = self.search_size / self.feat_size
+        cx_feat = cx / stride
+        cy_feat = cy / stride
+        y = torch.arange(self.feat_size, dtype=torch.float32)
+        x = torch.arange(self.feat_size, dtype=torch.float32)
+        yy, xx = torch.meshgrid(y, x, indexing='ij')
+        sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
+        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
+        return torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
+
     def __getitem__(self, idx):
         rng = random.Random(idx)
+        K = self.clip_length
 
-        # Random target size (relative to search region)
+        # Target appearance
+        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
         target_w = rng.uniform(0.1, 0.5) * self.search_size
         target_h = rng.uniform(0.1, 0.5) * self.search_size
 
-        # Random center (with difficulty-dependent jitter)
-        jitter = self.acl_difficulty * 0.3
-        cx = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
-        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
-        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
-        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
+        # Initial position (center of search)
+        cx0 = self.search_size / 2
+        cy0 = self.search_size / 2
 
-        # Create synthetic images
-        template = torch.randn(3, self.template_size, self.template_size) * 0.1
-        search = torch.randn(3, self.search_size, self.search_size) * 0.1
+        # Velocity (pixels per frame, scaled by difficulty)
+        vx = rng.gauss(0, self.acl_difficulty * 15)
+        vy = rng.gauss(0, self.acl_difficulty * 15)
 
-        # Draw target in template (centered)
-        t_half_w = int(min(target_w / 2, self.template_size / 2 - 1))
-        t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
+        # Template: target at center
+        template = torch.randn(3, self.template_size, self.template_size) * 0.1
+        t_hw = int(min(target_w / 2, self.template_size / 2 - 1))
+        t_hh = int(min(target_h / 2, self.template_size / 2 - 1))
         tc = self.template_size // 2
-        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
-        template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color
-
-        # Draw target in search region
-        sx1 = max(0, int(cx - target_w / 2))
-        sy1 = max(0, int(cy - target_h / 2))
-        sx2 = min(self.search_size, int(cx + target_w / 2))
-        sy2 = min(self.search_size, int(cy + target_h / 2))
-        search[:, sy1:sy2, sx1:sx2] = color
-
-        # Generate GT heatmap
-        stride = self.search_size / self.feat_size
-        cx_feat = cx / stride
-        cy_feat = cy / stride
-
-        y = torch.arange(self.feat_size, dtype=torch.float32)
-        x = torch.arange(self.feat_size, dtype=torch.float32)
-        yy, xx = torch.meshgrid(y, x, indexing='ij')
-
-        sigma = 2.0
-        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
-        heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
-
-        size = torch.tensor([target_w / self.search_size, target_h / self.search_size])
-        boxes = torch.tensor([cx, cy, target_w, target_h])
+        template[:, tc - t_hh:tc + t_hh, tc - t_hw:tc + t_hw] = color
+
+        # K search frames with a moving target
+        searches = []
+        heatmaps = []
+        sizes = []
+        boxes = []
+
+        for k in range(K):
+            # Position at frame k
+            cx = cx0 + vx * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
+            cy = cy0 + vy * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
+            cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
+            cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
+
+            search = torch.randn(3, self.search_size, self.search_size) * 0.1
+            sx1 = max(0, int(cx - target_w / 2))
+            sy1 = max(0, int(cy - target_h / 2))
+            sx2 = min(self.search_size, int(cx + target_w / 2))
+            sy2 = min(self.search_size, int(cy + target_h / 2))
+            search[:, sy1:sy2, sx1:sx2] = color
+
+            searches.append(search)
+            heatmaps.append(self._make_heatmap(cx, cy, target_w, target_h))
+            sizes.append(torch.tensor([target_w / self.search_size,
                                        target_h / self.search_size]))
+            boxes.append(torch.tensor([cx, cy, target_w, target_h]))
 
         return {
-            'template': template,
-            'search': search,
-            'heatmap': heatmap,
-            'size': size,
-            'boxes': boxes,
+            'template': template,                      # (3, 128, 128)
+            'searches': torch.stack(searches, dim=0),  # (K, 3, 256, 256)
+            'heatmaps': torch.stack(heatmaps, dim=0),  # (K, 1, 16, 16)
+            'sizes': torch.stack(sizes, dim=0),        # (K, 2)
+            'boxes': torch.stack(boxes, dim=0),        # (K, 4)
         }
 
     def set_acl_difficulty(self, difficulty: float):
@@ -1202,5 +1268,5 @@ def build_tracking_dataset(
 class TrackingDataset(SyntheticTrackingDataset):
     """Backward-compatible alias for SyntheticTrackingDataset."""
     def __init__(self, data_dir=None, split='train', synthetic=False,
-                 synthetic_length=10000, **kwargs):
-        super().__init__(length=synthetic_length, **kwargs)
+                 synthetic_length=10000, clip_length=3, **kwargs):
+        super().__init__(length=synthetic_length, clip_length=clip_length, **kwargs)
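
A quick way to sanity-check the new clip format end to end (a sketch; the module import path is inferred from the file path in this commit, and the printed shapes assume the defaults shown above, template_size=128, search_size=256, feat_size=16):

import torch
from torch.utils.data import DataLoader
from vil_tracker.data.dataset import SyntheticTrackingDataset

# K=3 search frames per clip, batched by the default collate function
dataset = SyntheticTrackingDataset(length=8, clip_length=3)
loader = DataLoader(dataset, batch_size=4)

batch = next(iter(loader))
print(batch['template'].shape)  # expected: torch.Size([4, 3, 128, 128])
print(batch['searches'].shape)  # expected: torch.Size([4, 3, 3, 256, 256])  (B, K, C, H, W)
print(batch['heatmaps'].shape)  # expected: torch.Size([4, 3, 1, 16, 16])
print(batch['sizes'].shape)     # expected: torch.Size([4, 3, 2])
print(batch['boxes'].shape)     # expected: torch.Size([4, 3, 4])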