| """ |
| Tracking dataset with synthetic fallback for testing. |
| |
| Supports: |
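| - Annotation index loading from <data_dir>/<split>.json (intended for GOT-10k, LaSOT, TrackingNet, COCO; real-frame decoding is not implemented yet, so samples are synthetic) |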
| - Synthetic data generation for testing (no external data needed) |
| - ACL (Adaptive Curriculum Learning) difficulty scaling |
| - Standard tracking augmentations: jitter, flip, color aug |
| """ |
|
|
| import os |
| import math |
| import random |
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset |
|
|
|
|
class TrackingDataset(Dataset):
    """Tracking dataset for ViL Tracker training.

    Each sample is a dict (shapes in brackets are for the default sizes):
        - template: (3, template_size, template_size) template crop [(3, 128, 128)]
        - search: (3, search_size, search_size) search region crop [(3, 256, 256)]
        - heatmap: (1, feat_size, feat_size) GT center heatmap [(1, 16, 16)]
        - size: (2,) GT [w, h] normalized by search_size
        - boxes: (4,) GT [cx, cy, w, h] in search region pixels

    NOTE: decoding of real annotated frames is not implemented. When an
    annotation file is loaded, its entries only determine ``len(dataset)``
    and index bounds; every returned sample is generated synthetically
    (a warning is emitted at construction time in that case).
    """

    def __init__(
        self,
        data_dir: str = None,
        split: str = 'train',
        template_size: int = 128,
        search_size: int = 256,
        feat_size: int = 16,
        acl_difficulty: float = 1.0,
        synthetic: bool = False,
        synthetic_length: int = 10000,
    ):
        """Create the dataset.

        Args:
            data_dir: Directory holding ``<split>.json``; may be None. If no
                usable annotations are found there, the dataset falls back to
                synthetic mode.
            split: Annotation file stem (e.g. 'train').
            template_size: Side length of the square template crop.
            search_size: Side length of the square search crop.
            feat_size: Side length of the GT heatmap grid.
            acl_difficulty: Initial curriculum difficulty, clamped to [0, 1].
            synthetic: Force synthetic samples without reading data_dir.
            synthetic_length: Number of samples in synthetic mode (also used
                by the no-data fallback).
        """
        super().__init__()
        self.template_size = template_size
        self.search_size = search_size
        self.feat_size = feat_size
        # Clamp through the setter so the constructor and set_acl_difficulty
        # agree on the valid range.
        self.set_acl_difficulty(acl_difficulty)
        self.synthetic = synthetic
        self.synthetic_length = synthetic_length

        if synthetic:
            self.samples = list(range(synthetic_length))
        else:
            self.samples = self._load_dataset(data_dir, split)
            if not self.synthetic:
                # Real annotations were found, but __getitem__ cannot decode
                # them yet; make that visible instead of silently training on
                # synthetic data.
                import warnings
                warnings.warn(
                    "TrackingDataset: real-frame loading is not implemented; "
                    f"returning synthetic samples for {len(self.samples)} "
                    "annotation entries.",
                    stacklevel=2,
                )

    def _load_dataset(self, data_dir, split):
        """Load the annotation index for ``split``.

        Reads ``<data_dir>/<split>.json`` when both the directory and the file
        exist. If nothing usable is found (missing path or an empty JSON
        document), this prints a warning, switches the dataset to synthetic
        mode, and returns a synthetic index of ``self.synthetic_length``
        entries.

        Returns:
            The parsed JSON content (presumably a list of per-sample records;
            NOTE(review): a top-level JSON object would break integer
            indexing in __getitem__), or ``list(range(self.synthetic_length))``
            in the fallback case.
        """
        samples = []
        if data_dir and os.path.exists(data_dir):
            ann_file = os.path.join(data_dir, f'{split}.json')
            if os.path.exists(ann_file):
                import json
                with open(ann_file, 'r', encoding='utf-8') as f:
                    samples = json.load(f)

        if not samples:
            print(f"Warning: No data found at {data_dir}, using synthetic data")
            self.synthetic = True
            # Honour the caller's synthetic_length rather than a hard-coded
            # 10000; the default only applies if it was never set.
            self.synthetic_length = getattr(self, 'synthetic_length', 10000)
            return list(range(self.synthetic_length))

        return samples

    def __len__(self):
        return len(self.samples) if not self.synthetic else self.synthetic_length

    def _generate_synthetic_sample(self, idx):
        """Generate a synthetic template/search pair with GT annotations.

        Box geometry and color are drawn from ``random.Random(idx)``, so they
        are deterministic per index (for a fixed acl_difficulty). The
        background noise comes from the global torch RNG and is not seeded
        per index.
        """
        rng = random.Random(idx)

        # Target size: 10%-50% of the search region on each axis.
        target_w = rng.uniform(0.1, 0.5) * self.search_size
        target_h = rng.uniform(0.1, 0.5) * self.search_size

        # Center offset from the search-region middle; its std grows with the
        # curriculum difficulty. Clamp so the box stays inside the region.
        jitter = self.acl_difficulty * 0.3
        cx = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))

        # Low-amplitude noise backgrounds (global torch RNG, see docstring).
        template = torch.randn(3, self.template_size, self.template_size) * 0.1
        search = torch.randn(3, self.search_size, self.search_size) * 0.1

        # Paint a solid-color target centered in the template, shrunk if
        # needed so it fits inside the template crop.
        t_half_w = int(min(target_w / 2, self.template_size / 2 - 1))
        t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
        tc = self.template_size // 2
        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
        template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color

        # Paint the same color at the GT box in the search region.
        sx1 = max(0, int(cx - target_w / 2))
        sy1 = max(0, int(cy - target_h / 2))
        sx2 = min(self.search_size, int(cx + target_w / 2))
        sy2 = min(self.search_size, int(cy + target_h / 2))
        search[:, sy1:sy2, sx1:sx2] = color

        # Gaussian heatmap on the feature grid. The center is cx / stride
        # compared directly against integer grid indices (no half-cell
        # offset). NOTE(review): confirm this matches the decoder's convention.
        stride = self.search_size / self.feat_size
        cx_feat = cx / stride
        cy_feat = cy / stride

        y = torch.arange(self.feat_size, dtype=torch.float32)
        x = torch.arange(self.feat_size, dtype=torch.float32)
        yy, xx = torch.meshgrid(y, x, indexing='ij')

        sigma = 2.0
        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
        heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)

        # [w, h] normalized by the search-region side length.
        size = torch.tensor([target_w / self.search_size, target_h / self.search_size])

        # [cx, cy, w, h] in search-region pixels.
        boxes = torch.tensor([cx, cy, target_w, target_h])

        return {
            'template': template,
            'search': search,
            'heatmap': heatmap,
            'size': size,
            'boxes': boxes,
        }

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (always synthetic, see class doc)."""
        if self.synthetic:
            return self._generate_synthetic_sample(idx)

        # Real-frame decoding is not implemented. Index the annotation list
        # only so that out-of-range indices still raise, then fall back to a
        # synthetic sample.
        _ = self.samples[idx]
        return self._generate_synthetic_sample(idx)

    def set_acl_difficulty(self, difficulty: float):
        """Update ACL difficulty level (0.0 = easy, 1.0 = hard), clamped to [0, 1].

        NOTE(review): this only updates the dataset object it is called on;
        confirm that copies held by data-loading worker processes (if any)
        see the change.
        """
        self.acl_difficulty = min(1.0, max(0.0, difficulty))
|
|