"""
Tracking dataset with synthetic fallback for testing.

Scope:
- Real data: an annotation list is read from ``{data_dir}/{split}.json``
  (intended for GOT-10k / LaSOT / TrackingNet / COCO-derived lists), but
  per-sample image loading and cropping are not implemented yet; real-data
  indices currently fall back to synthetic samples.
- Synthetic data generation for testing (no external data needed).
- ACL (Adaptive Curriculum Learning) difficulty scaling of the target-center
  jitter.
- Augmentation: only the difficulty-scaled center jitter of synthetic samples
  is implemented here; flip / color augmentation are not applied in this
  module.
"""
import os
import math
import random
import json

import torch
import numpy as np
from torch.utils.data import Dataset


class TrackingDataset(Dataset):
    """Tracking dataset for ViL Tracker training.

    Each sample is a dict with:
        - 'template': (3, template_size, template_size) template crop
        - 'search':   (3, search_size, search_size) search region crop
        - 'heatmap':  (1, feat_size, feat_size) GT center heatmap
        - 'size':     (2,) GT [w, h] normalized by search_size
        - 'boxes':    (4,) GT [cx, cy, w, h] in search-region pixels

    With the default sizes these are (3, 128, 128), (3, 256, 256) and
    (1, 16, 16).

    Args:
        data_dir: directory expected to contain ``{split}.json``. If it is
            missing, or the annotation list is empty, the dataset switches
            to synthetic mode.
        split: annotation split name (file stem of the JSON list).
        template_size: side length of the square template crop, in pixels.
        search_size: side length of the square search crop, in pixels.
        feat_size: side length of the GT heatmap grid.
        acl_difficulty: curriculum difficulty; clamped to [0, 1].
        synthetic: force synthetic data instead of reading ``data_dir``.
        synthetic_length: number of synthetic samples. This is also used
            when falling back to synthetic data.
        heatmap_sigma: std-dev, in feature cells, of the Gaussian GT heatmap.
    """

    def __init__(
        self,
        data_dir: str = None,
        split: str = 'train',
        template_size: int = 128,
        search_size: int = 256,
        feat_size: int = 16,
        acl_difficulty: float = 1.0,
        synthetic: bool = False,
        synthetic_length: int = 10000,
        heatmap_sigma: float = 2.0,
    ):
        super().__init__()
        self.template_size = template_size
        self.search_size = search_size
        self.feat_size = feat_size
        self.heatmap_sigma = heatmap_sigma
        # Clamp exactly like set_acl_difficulty so both entry points agree.
        self.set_acl_difficulty(acl_difficulty)
        self.synthetic = synthetic
        # Must be set before _load_dataset, whose fallback reads it.
        self.synthetic_length = synthetic_length

        if synthetic:
            self.samples = list(range(synthetic_length))
        else:
            self.samples = self._load_dataset(data_dir, split)

    def _load_dataset(self, data_dir, split):
        """Load the annotation list stored at ``{data_dir}/{split}.json``.

        Returns:
            The decoded JSON content. It is presumably a list of per-sample
            dicts; the content is not validated here.

        If no annotations are found, the dataset is switched to synthetic mode
        (``self.synthetic = True``). In that case
        ``list(range(self.synthetic_length))`` is returned.
        """
        samples = []
        if data_dir and os.path.exists(data_dir):
            ann_file = os.path.join(data_dir, f'{split}.json')
            if os.path.exists(ann_file):
                with open(ann_file, 'r', encoding='utf-8') as f:
                    samples = json.load(f)

        if not samples:
            print(f"Warning: No data found at {data_dir}, using synthetic data")
            self.synthetic = True
            # Honor the caller's synthetic_length rather than a fixed 10000.
            return list(range(self.synthetic_length))
        return samples

    def __len__(self):
        return len(self.samples) if not self.synthetic else self.synthetic_length

    def _center_heatmap(self, cx, cy):
        """Return a (1, feat_size, feat_size) Gaussian peaked near pixel (cx, cy)."""
        stride = self.search_size / self.feat_size
        cx_feat = cx / stride
        cy_feat = cy / stride
        # NOTE(review): grid cell i is evaluated at coordinate i (pixel
        # i * stride), not at its center (i + 0.5) * stride. Kept unchanged so
        # it stays consistent with whatever decodes these heatmaps — confirm.
        coords = torch.arange(self.feat_size, dtype=torch.float32)
        yy, xx = torch.meshgrid(coords, coords, indexing='ij')
        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
        return torch.exp(-dist_sq / (2 * self.heatmap_sigma ** 2)).unsqueeze(0)

    def _generate_synthetic_sample(self, idx):
        """Generate a synthetic template/search pair with GT annotations.

        The template and search region each show one solid-colored rectangle
        over a low-amplitude noise background. Target size, center and color
        come from ``random.Random(idx)``, so they are deterministic per
        index. The background noise comes from ``torch.randn`` (the global
        torch RNG), so it is not deterministic per index.
        """
        rng = random.Random(idx)

        # Random target size, 10%-50% of the search region per side.
        target_w = rng.uniform(0.1, 0.5) * self.search_size
        target_h = rng.uniform(0.1, 0.5) * self.search_size

        # Center offset from the search-region center. Its std-dev grows with
        # ACL difficulty. The center is then clamped so the box stays inside
        # the search region.
        jitter = self.acl_difficulty * 0.3
        cx = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))

        # Noise backgrounds.
        template = torch.randn(3, self.template_size, self.template_size) * 0.1
        search = torch.randn(3, self.search_size, self.search_size) * 0.1

        # Target in the template: centered, with half-extent capped to fit.
        t_half_w = int(min(target_w / 2, self.template_size / 2 - 1))
        t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
        tc = self.template_size // 2
        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
        template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color

        # The same target at (cx, cy) in the search region.
        sx1 = max(0, int(cx - target_w / 2))
        sy1 = max(0, int(cy - target_h / 2))
        sx2 = min(self.search_size, int(cx + target_w / 2))
        sy2 = min(self.search_size, int(cy + target_h / 2))
        search[:, sy1:sy2, sx1:sx2] = color

        return {
            'template': template,
            'search': search,
            'heatmap': self._center_heatmap(cx, cy),
            'size': torch.tensor([target_w / self.search_size,
                                  target_h / self.search_size]),
            'boxes': torch.tensor([cx, cy, target_w, target_h]),
        }

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising this lets plain ``for sample in dataset`` iteration
                terminate; without it, synthetic mode would iterate forever.
        """
        n = len(self)
        if idx >= n or idx < -n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        # TODO: real-data loading (read self.samples[idx], load frames, crop
        # template/search, build targets). Until that exists, real-data
        # indices fall back to a synthetic sample with the same index.
        return self._generate_synthetic_sample(idx)

    def set_acl_difficulty(self, difficulty: float):
        """Update ACL difficulty level, clamped to [0, 1] (0.0 = easy, 1.0 = hard)."""
        self.acl_difficulty = min(1.0, max(0.0, difficulty))