# Source: vil-tracker / vil_tracker/data/dataset.py
# (Hugging Face Hub upload by omar-ah: "Upload vil_tracker/data/dataset.py
#  with huggingface_hub", commit 823a1a3 (verified), 5.44 kB)
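"""
Tracking dataset with a synthetic fallback for testing.

Intended to support:
- GOT-10k, LaSOT, TrackingNet and COCO formats (not yet implemented: only a
  ``<data_dir>/<split>.json`` annotation list is read, and real-sample loading
  currently falls back to synthetic samples)
- Synthetic data generation for testing (no external data needed)
- ACL (Adaptive Curriculum Learning) difficulty scaling
- Standard tracking augmentations: jitter, flip, color aug (so far only the
  center jitter of synthetic targets is implemented)
"""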
import os
import math
import random
import torch
import numpy as np
from torch.utils.data import Dataset
class TrackingDataset(Dataset):
    """Tracking dataset for ViL Tracker training.

    Each sample is a dict (defaults shown in brackets):
        - template: (3, template_size, template_size) template crop [3, 128, 128]
        - search: (3, search_size, search_size) search region crop [3, 256, 256]
        - heatmap: (1, feat_size, feat_size) GT center heatmap [1, 16, 16]
        - size: (2,) GT [w, h] normalized by search_size
        - boxes: (4,) GT [cx, cy, w, h] in search-region pixels

    Args:
        data_dir: Directory expected to contain ``<split>.json``. If it is
            missing or yields no samples, the dataset switches to synthetic mode.
        split: Annotation file stem, e.g. ``'train'``.
        template_size: Side length of the template crop in pixels.
        search_size: Side length of the search crop in pixels.
        feat_size: Side length of the GT heatmap grid.
        acl_difficulty: Initial curriculum difficulty, clamped to [0, 1].
        synthetic: Generate synthetic samples instead of loading data.
        synthetic_length: Number of samples in synthetic mode; also used when
            falling back to synthetic data because no real data was found.
    """

    def __init__(
        self,
        data_dir: str = None,
        split: str = 'train',
        template_size: int = 128,
        search_size: int = 256,
        feat_size: int = 16,
        acl_difficulty: float = 1.0,
        synthetic: bool = False,
        synthetic_length: int = 10000,
    ):
        super().__init__()
        self.template_size = template_size
        self.search_size = search_size
        self.feat_size = feat_size
        # Clamp through the setter so the constructor and later updates agree.
        self.set_acl_difficulty(acl_difficulty)
        self.synthetic = synthetic
        self.synthetic_length = synthetic_length
        if synthetic:
            self.samples = list(range(synthetic_length))
        else:
            # May switch self.synthetic to True when no real data is found.
            self.samples = self._load_dataset(data_dir, split)

    def _load_dataset(self, data_dir, split):
        """Load the annotation list for ``split`` from ``data_dir``.

        Reads ``<data_dir>/<split>.json`` when it exists. If nothing is found
        (missing directory or file, or an empty annotation list), prints a
        warning, switches the dataset into synthetic mode and returns
        ``list(range(self.synthetic_length))``.

        Returns:
            The parsed JSON content (presumably a list of per-sample records;
            the schema is not validated here) or the synthetic index list.
        """
        samples = []
        if data_dir and os.path.exists(data_dir):
            # Load real dataset annotations.
            ann_file = os.path.join(data_dir, f'{split}.json')
            if os.path.exists(ann_file):
                import json
                with open(ann_file, 'r', encoding='utf-8') as f:
                    samples = json.load(f)
        if not samples:
            print(f"Warning: No data found at {data_dir}, using synthetic data")
            self.synthetic = True
            # Honor the configured synthetic_length (set in __init__) rather
            # than silently overriding it with a hard-coded size.
            return list(range(self.synthetic_length))
        return samples

    def __len__(self):
        """Return the synthetic length in synthetic mode, else the annotation count."""
        if self.synthetic:
            return self.synthetic_length
        return len(self.samples)

    def _generate_synthetic_sample(self, idx):
        """Generate a synthetic template/search pair with GT annotations.

        The target size, position and color come from ``random.Random(idx)``
        and are reproducible per index. The background noise uses torch's
        global RNG and is not.
        """
        rng = random.Random(idx)
        # Random target size (relative to search region).
        target_w = rng.uniform(0.1, 0.5) * self.search_size
        target_h = rng.uniform(0.1, 0.5) * self.search_size
        # Random center with difficulty-dependent jitter, then clamped so the
        # whole target box stays inside the search region.
        jitter = self.acl_difficulty * 0.3
        cx = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
        # Synthetic images: a solid-colored rectangle on low-amplitude noise.
        template = torch.randn(3, self.template_size, self.template_size) * 0.1
        search = torch.randn(3, self.search_size, self.search_size) * 0.1
        # Draw the target centered in the template; half-extents are capped so
        # the rectangle fits inside the template.
        t_half_w = int(min(target_w / 2, self.template_size / 2 - 1))
        t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
        tc = self.template_size // 2
        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
        template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color
        # Draw the target at (cx, cy) in the search region.
        sx1 = max(0, int(cx - target_w / 2))
        sy1 = max(0, int(cy - target_h / 2))
        sx2 = min(self.search_size, int(cx + target_w / 2))
        sy2 = min(self.search_size, int(cy + target_h / 2))
        search[:, sy1:sy2, sx1:sx2] = color
        # GT heatmap: Gaussian (sigma = 2 grid cells) centered on the target
        # center projected onto the feat_size x feat_size grid.
        stride = self.search_size / self.feat_size
        cx_feat = cx / stride
        cy_feat = cy / stride
        y = torch.arange(self.feat_size, dtype=torch.float32)
        x = torch.arange(self.feat_size, dtype=torch.float32)
        yy, xx = torch.meshgrid(y, x, indexing='ij')
        sigma = 2.0
        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
        heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
        # Size normalized by the search-region side length.
        size = torch.tensor([target_w / self.search_size, target_h / self.search_size])
        # Box [cx, cy, w, h] in search-region pixels.
        boxes = torch.tensor([cx, cy, target_w, target_h])
        return {
            'template': template,
            'search': search,
            'heatmap': heatmap,
            'size': size,
            'boxes': boxes,
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError``, which is what lets plain iteration (``for s in ds``,
        ``list(ds)``) terminate: map-style datasets are iterated through the
        sequence protocol, and the synthetic generator itself never fails.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        if self.synthetic:
            return self._generate_synthetic_sample(idx)
        # TODO: real data loading (read the images referenced by
        # self.samples[idx], compute template/search crops, build targets).
        # Until that exists, deliberately fall back to a synthetic sample.
        return self._generate_synthetic_sample(idx)

    def set_acl_difficulty(self, difficulty: float):
        """Update ACL difficulty level (0.0 = easy, 1.0 = hard), clamped to [0, 1]."""
        self.acl_difficulty = min(1.0, max(0.0, difficulty))