| """ |
| Tracking dataset with real dataset loaders and synthetic fallback. |
| |
| Supports: |
- GOT-10k: train split (~10k sequences, annotations in groundtruth.txt)
- LaSOT: training split (1120 sequences across 70 categories)
- TrackingNet: training split (30k+ sequences, annotations in anno/)
- COCO detection: for static-pair pretraining (bbox crops as pseudo-sequences)
- UAV datasets: VisDrone-SOT, UAVDT, WebUAV-3M (drone-perspective footage;
  directory layouts are documented in the respective class docstrings)
- Synthetic data generation for testing (no external data needed)
- ACL (Adaptive Curriculum Learning) difficulty scaling
- Standard tracking augmentations: spatial jitter, horizontal flip,
  brightness/contrast/saturation jitter, grayscale, Gaussian blur
| |
Each sample produces a template plus a clip of K search frames from the same
video sequence with controlled temporal distance, plus GT annotations.
| |
| Dataset directory structure expected: |
| GOT-10k/ |
| train/ |
| GOT-10k_Train_000001/ |
| 00000001.jpg, 00000002.jpg, ... |
| groundtruth.txt # x,y,w,h per line |
| ... |
| LaSOT/ |
| airplane/ |
| airplane-1/ |
| img/ |
| 00000001.jpg, ... |
| groundtruth.txt # x,y,w,h per line |
| ... |
| TrackingNet/ |
| TRAIN_0/ |
| frames/ |
| video_name/ |
| 0.jpg, 1.jpg, ... |
| anno/ |
| video_name.txt # x,y,w,h per line |
| ... |
| COCO/ |
| train2017/ |
| *.jpg |
| annotations/ |
| instances_train2017.json |
| """ |
|
|
| import os |
| import math |
| import glob |
| import random |
| import torch |
| import numpy as np |
| from pathlib import Path |
| from torch.utils.data import Dataset, ConcatDataset |
|
|
|
|
| |
| |
| |
|
|
| class TrackingAugmentation: |
| """Standard tracking augmentations applied to (template, search) pairs. |
| |
| Augmentations preserve the spatial relationship between search region |
| and GT bounding box by applying augmentations consistently. |
| """ |
| |
| def __init__( |
| self, |
| brightness: float = 0.2, |
| contrast: float = 0.2, |
| saturation: float = 0.2, |
| grayscale_prob: float = 0.05, |
| horizontal_flip_prob: float = 0.5, |
| blur_prob: float = 0.1, |
| blur_sigma: tuple = (0.1, 2.0), |
| ): |
| self.brightness = brightness |
| self.contrast = contrast |
| self.saturation = saturation |
| self.grayscale_prob = grayscale_prob |
| self.horizontal_flip_prob = horizontal_flip_prob |
| self.blur_prob = blur_prob |
| self.blur_sigma = blur_sigma |
| |
| def __call__(self, template: torch.Tensor, search: torch.Tensor, |
| bbox: torch.Tensor) -> tuple: |
| """ |
| Args: |
| template: (3, H_t, W_t) tensor in [0, 1] |
| search: (3, H_s, W_s) tensor in [0, 1] |
| bbox: (4,) tensor [cx, cy, w, h] in search region pixels |
| Returns: |
| template, search, bbox (augmented) |
| """ |
| |
        # Photometric jitter (same factors applied to template and search so
        # the pair stays photometrically consistent)
        if random.random() < 0.8:
            # Brightness
            factor = 1.0 + random.uniform(-self.brightness, self.brightness)
            template = (template * factor).clamp(0, 1)
            search = (search * factor).clamp(0, 1)

            # Contrast
            factor = 1.0 + random.uniform(-self.contrast, self.contrast)
            t_mean = template.mean()
            s_mean = search.mean()
            template = ((template - t_mean) * factor + t_mean).clamp(0, 1)
            search = ((search - s_mean) * factor + s_mean).clamp(0, 1)

            # Saturation: blend each image with its grayscale version
            # (previously the `saturation` parameter was stored but unused)
            factor = 1.0 + random.uniform(-self.saturation, self.saturation)
            t_gray = template.mean(dim=0, keepdim=True)
            s_gray = search.mean(dim=0, keepdim=True)
            template = (t_gray + (template - t_gray) * factor).clamp(0, 1)
            search = (s_gray + (search - s_gray) * factor).clamp(0, 1)

        # Random grayscale
        if random.random() < self.grayscale_prob:
            template = template.mean(dim=0, keepdim=True).expand_as(template)
            search = search.mean(dim=0, keepdim=True).expand_as(search)

        # Horizontal flip: mirror both crops and the bbox center x
        if random.random() < self.horizontal_flip_prob:
            template = template.flip(-1)
            search = search.flip(-1)
            W_s = search.shape[-1]
            bbox = bbox.clone()
            bbox[0] = W_s - bbox[0]

        # Gaussian blur (search region only, mimicking motion/defocus)
        if random.random() < self.blur_prob:
            sigma = random.uniform(*self.blur_sigma)
            kernel_size = int(2 * round(3 * sigma) + 1)
            if kernel_size >= 3:
                search = self._gaussian_blur(search, kernel_size, sigma)

        return template, search, bbox
| |
| @staticmethod |
| def _gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: |
| """Apply Gaussian blur to a (C, H, W) tensor.""" |
| import torch.nn.functional as F |
| |
| |
| x = torch.arange(kernel_size, dtype=img.dtype, device=img.device) - kernel_size // 2 |
| kernel_1d = torch.exp(-0.5 * (x / sigma) ** 2) |
| kernel_1d = kernel_1d / kernel_1d.sum() |
| |
| |
| pad = kernel_size // 2 |
| img = img.unsqueeze(0) |
| |
| |
| k_h = kernel_1d.view(1, 1, 1, -1).expand(img.shape[1], -1, -1, -1) |
| img = F.conv2d(F.pad(img, (pad, pad, 0, 0), mode='reflect'), |
| k_h, groups=img.shape[1]) |
| |
| |
| k_v = kernel_1d.view(1, 1, -1, 1).expand(img.shape[1], -1, -1, -1) |
| img = F.conv2d(F.pad(img, (0, 0, pad, pad), mode='reflect'), |
| k_v, groups=img.shape[1]) |
| |
| return img.squeeze(0) |
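

# Usage sketch: the augmenter is stateless, so one instance can be shared
# across dataloader workers. bbox must be center-format [cx, cy, w, h] in
# search-crop pixels, since the flip branch mirrors only the x-center:
#
#     aug = TrackingAugmentation()
#     template, search, bbox = aug(template, search, bbox)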
|
|
|
|
| |
| |
| |
|
|
| def crop_and_resize(image: np.ndarray, center: np.ndarray, size: float, |
| output_size: int) -> np.ndarray: |
| """Crop a square region from image, centered at center, with given size. |
| |
| Args: |
| image: (H, W, 3) numpy array, uint8 or float |
| center: (2,) [cx, cy] in image coordinates |
| size: side length of the square crop |
| output_size: resize crop to (output_size, output_size) |
| Returns: |
| (output_size, output_size, 3) numpy array |
| """ |
| H, W = image.shape[:2] |
| half = size / 2 |
| |
| x1 = int(round(center[0] - half)) |
| y1 = int(round(center[1] - half)) |
| x2 = int(round(center[0] + half)) |
| y2 = int(round(center[1] + half)) |
| |
| |
| pad_left = max(0, -x1) |
| pad_top = max(0, -y1) |
| pad_right = max(0, x2 - W) |
| pad_bottom = max(0, y2 - H) |
| |
| x1c = max(0, x1) |
| y1c = max(0, y1) |
| x2c = min(W, x2) |
| y2c = min(H, y2) |
| |
| crop = image[y1c:y2c, x1c:x2c] |
| |
| if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0: |
| mean_color = image.mean(axis=(0, 1)) |
| padded = np.full((crop.shape[0] + pad_top + pad_bottom, |
| crop.shape[1] + pad_left + pad_right, 3), |
| mean_color, dtype=crop.dtype) |
| padded[pad_top:pad_top + crop.shape[0], pad_left:pad_left + crop.shape[1]] = crop |
| crop = padded |
| |
| |
| if crop.shape[0] > 0 and crop.shape[1] > 0: |
| import torch.nn.functional as F |
| crop_t = torch.from_numpy(crop.copy()).float().permute(2, 0, 1).unsqueeze(0) |
| crop_t = F.interpolate(crop_t, size=(output_size, output_size), |
| mode='bilinear', align_corners=False) |
| crop = crop_t.squeeze(0).permute(1, 2, 0).numpy() |
| else: |
| crop = np.zeros((output_size, output_size, 3), dtype=np.float32) |
| |
| return crop |
|
|
|
|
| def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple: |
| """Compute crop center and size from bbox with context. |
| |
| Args: |
| bbox: [x, y, w, h] bounding box |
| context_factor: how much context around bbox (2.0 = 2x target size) |
| Returns: |
| center: (2,) [cx, cy] |
| crop_size: scalar side length |
| """ |
| x, y, w, h = bbox |
| cx = x + w / 2 |
| cy = y + h / 2 |
| |
| |
| |
    # SiamFC-style context: pad the target by p = (w + h) / 2, then take the
    # geometric mean of the padded sides so elongated boxes get a square crop
    p = (w + h) / 2
    crop_size = math.sqrt((w + p) * (h + p)) * context_factor
    crop_size = max(crop_size, 10)  # guard against degenerate tiny boxes
| |
| return np.array([cx, cy]), crop_size |
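

# Worked example of the context rule above: for a 100x100 box with
# context_factor=2.0, p = (100 + 100) / 2 = 100 and
# crop_size = sqrt(200 * 200) * 2.0 = 400, so the crop side is
# context_factor times the padded target side.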
|
|
|
|
| |
| |
| |
|
|
| class SequenceDataset(Dataset): |
| """Base class for tracking sequence datasets. |
| |
| Returns K-frame clips: template + K consecutive search frames. |
| The mLSTM processes these as one long sequence where memory carries |
| information across frames — this is the core training paradigm. |
| |
| Subclasses must populate self.sequences with list of: |
| {'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]} |
| """ |
| |
| def __init__( |
| self, |
| template_size: int = 128, |
| search_size: int = 256, |
| feat_size: int = 16, |
| acl_difficulty: float = 1.0, |
| max_gap: int = 100, |
| clip_length: int = 3, |
| augmentation: bool = True, |
| ): |
| super().__init__() |
| self.template_size = template_size |
| self.search_size = search_size |
| self.feat_size = feat_size |
| self.acl_difficulty = acl_difficulty |
| self.max_gap = max_gap |
| self.clip_length = clip_length |
| self.sequences = [] |
| |
| self.augmentation = TrackingAugmentation() if augmentation else None |
| |
| def __len__(self): |
| return len(self.sequences) |
| |
| def _load_image(self, path: str) -> np.ndarray: |
| """Load image from path. Returns (H, W, 3) float32 in [0, 255].""" |
| try: |
| from PIL import Image |
| img = Image.open(path).convert('RGB') |
| return np.array(img, dtype=np.float32) |
| except ImportError: |
| import cv2 |
| img = cv2.imread(path) |
| if img is None: |
| return np.zeros((480, 640, 3), dtype=np.float32) |
| return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) |
| |
| def _sample_clip(self, idx: int) -> list: |
| """Sample a clip: template frame + K consecutive search frames. |
| |
| Returns: |
| list of frame indices: [template_idx, search_1_idx, ..., search_K_idx] |
| """ |
| seq = self.sequences[idx] |
| n_frames = len(seq['frames']) |
| K = self.clip_length |
| |
| valid = [i for i in range(n_frames) |
| if seq['gt'][i] is not None and seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0] |
| valid_set = set(valid) |
| |
        if len(valid) < K + 1:
            # Degenerate sequence: too few annotated frames, repeat what we have
            if len(valid) == 0:
                return [0] * (K + 1)
            return [valid[0]] + [valid[min(i, len(valid) - 1)] for i in range(K)]
| |
| |
        # Pick a random template frame among the annotated ones
        t_idx = random.choice(valid)

        # ACL controls the temporal gap: higher difficulty lets the clip
        # start farther from the template frame
        effective_gap = max(1, int(self.max_gap * self.acl_difficulty))

        # Prefer a clip that starts after the template frame
        min_start = t_idx + 1
        max_start = min(t_idx + effective_gap, n_frames - K)

        if max_start < min_start:
            # No room after the template; try a clip that ends before it
            max_start_before = t_idx - K
            min_start_before = max(0, t_idx - effective_gap - K)
            if max_start_before >= min_start_before and max_start_before >= 0:
                clip_start = random.randint(min_start_before, max_start_before)
            else:
                # Last resort: clamp to any valid window
                clip_start = max(0, min(n_frames - K, t_idx + 1))
        else:
            clip_start = random.randint(min_start, max_start)
| |
| |
        # Collect K valid search frames from clip_start onward, skipping the
        # template frame and any unannotated frames
        search_indices = []
        for i in range(clip_start, min(clip_start + K * 3, n_frames)):
            if i in valid_set and i != t_idx:
                search_indices.append(i)
            if len(search_indices) == K:
                break

        # Pad by repetition if the sequence ran out of valid frames
        while len(search_indices) < K:
            search_indices.append(search_indices[-1] if search_indices else t_idx)
| |
| return [t_idx] + search_indices[:K] |
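
    # Example: with max_gap=100 and acl_difficulty=0.3, effective_gap is 30,
    # so the clip starts within 30 frames of the template; at difficulty 1.0
    # the full 100-frame window is available.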
| |
| def _process_frame(self, img: np.ndarray, bbox: np.ndarray, is_template: bool): |
| """Crop and preprocess a single frame. |
| |
| Returns: |
| image_tensor: (3, H, W) float [0, 1] |
| bbox_in_crop: (4,) [cx, cy, w, h] in crop coordinates |
| """ |
| if is_template: |
| center, crop_size = compute_crop_params(bbox, context_factor=2.0) |
| output_size = self.template_size |
| else: |
| center, crop_size = compute_crop_params(bbox, context_factor=4.0) |
| output_size = self.search_size |
| |
        # ACL-scaled center jitter so the target is not always perfectly
        # centered (applied to template and search crops alike)
        jitter = self.acl_difficulty * bbox[2:4].mean() * 0.3
        if jitter > 0:
            center[0] += random.gauss(0, jitter)
            center[1] += random.gauss(0, jitter)
| |
| crop = crop_and_resize(img, center, crop_size, output_size) |
| |
| |
        # Map the GT box into crop coordinates: [cx, cy, w, h] in pixels
        scale = output_size / crop_size
| cx = (bbox[0] + bbox[2] / 2 - center[0] + crop_size / 2) * scale |
| cy = (bbox[1] + bbox[3] / 2 - center[1] + crop_size / 2) * scale |
| w = bbox[2] * scale |
| h = bbox[3] * scale |
| |
| cx = max(0, min(output_size, cx)) |
| cy = max(0, min(output_size, cy)) |
| w = max(1, min(output_size, w)) |
| h = max(1, min(output_size, h)) |
| |
| tensor = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0 |
| bbox_crop = torch.tensor([cx, cy, w, h]) |
| |
| return tensor, bbox_crop |
| |
| def _make_heatmap(self, bbox: torch.Tensor): |
| """Generate GT heatmap from bbox in search crop coordinates.""" |
| stride = self.search_size / self.feat_size |
| cx_feat = bbox[0].item() / stride |
| cy_feat = bbox[1].item() / stride |
| w_search = bbox[2].item() |
| h_search = bbox[3].item() |
| |
| y = torch.arange(self.feat_size, dtype=torch.float32) |
| x = torch.arange(self.feat_size, dtype=torch.float32) |
| yy, xx = torch.meshgrid(y, x, indexing='ij') |
| |
| sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4))) |
| dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2 |
| heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0) |
| return heatmap |
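
    # Example: with search_size=256 and feat_size=16 the stride is 16; a
    # 64x64 target gives sigma = (64 + 64) / (2 * 16 * 4) = 1.0, and sigma is
    # clamped to [1, 3] so labels never collapse or smear too far.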
| |
| def __getitem__(self, idx): |
| seq = self.sequences[idx % len(self.sequences)] |
| clip_indices = self._sample_clip(idx % len(self.sequences)) |
| |
| t_idx = clip_indices[0] |
| s_indices = clip_indices[1:] |
| K = len(s_indices) |
| |
| |
| t_img = self._load_image(seq['frames'][t_idx]) |
| t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32) |
| template, _ = self._process_frame(t_img, t_bbox, is_template=True) |
| |
| |
| searches = [] |
| heatmaps = [] |
| sizes = [] |
| boxes = [] |
| |
| for s_idx in s_indices: |
| s_img = self._load_image(seq['frames'][s_idx]) |
| s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32) |
| search, bbox_crop = self._process_frame(s_img, s_bbox, is_template=False) |
| |
| |
            if self.augmentation is not None:
                template_aug, search, bbox_crop = self.augmentation(template, search, bbox_crop)
                # Adopt the augmented template only on the first frame so every
                # search frame in the clip sees the same template (previously
                # template_aug was referenced even when augmentation was None,
                # raising a NameError)
                if len(searches) == 0:
                    template = template_aug
| |
| searches.append(search) |
| heatmaps.append(self._make_heatmap(bbox_crop)) |
| sizes.append(torch.tensor([bbox_crop[2].item() / self.search_size, |
| bbox_crop[3].item() / self.search_size])) |
| boxes.append(bbox_crop) |
| |
| return { |
| 'template': template, |
| 'searches': torch.stack(searches, dim=0), |
| 'heatmaps': torch.stack(heatmaps, dim=0), |
| 'sizes': torch.stack(sizes, dim=0), |
| 'boxes': torch.stack(boxes, dim=0), |
| } |
| |
| def set_acl_difficulty(self, difficulty: float): |
| """Update ACL difficulty level (0.0 = easy, 1.0 = hard).""" |
| self.acl_difficulty = min(1.0, max(0.0, difficulty)) |
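

# Sample contract (defaults: clip_length K=3, template 128, search 256,
# feat 16): template (3, 128, 128), searches (K, 3, 256, 256),
# heatmaps (K, 1, 16, 16), sizes (K, 2), boxes (K, 4).
# SyntheticTrackingDataset below follows the same contract.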
|
|
|
|
| |
| |
| |
|
|
| class GOT10kDataset(SequenceDataset): |
| """GOT-10k tracking dataset. |
| |
| Structure: |
| root/train/GOT-10k_Train_NNNNNN/ |
| 00000001.jpg, 00000002.jpg, ... |
| groundtruth.txt # x,y,w,h per line |
| """ |
| |
| def __init__(self, root: str, split: str = 'train', **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_sequences(split) |
| |
| def _load_sequences(self, split): |
| split_dir = self.root / split |
| if not split_dir.exists(): |
| print(f"Warning: GOT-10k {split} not found at {split_dir}") |
| return |
| |
        # Match both GOT-10k_Train_* and GOT-10k_Val_* directory names
        seq_dirs = sorted([d for d in split_dir.iterdir()
                           if d.is_dir() and d.name.startswith('GOT-10k')])
| print(f"Loading GOT-10k {split}: found {len(seq_dirs)} sequences") |
| |
| for seq_dir in seq_dirs: |
| gt_file = seq_dir / 'groundtruth.txt' |
| if not gt_file.exists(): |
| continue |
| |
| |
| gt_boxes = [] |
| with open(gt_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except ValueError: |
| gt_boxes.append(None) |
| |
| |
| frames = sorted(glob.glob(str(seq_dir / '*.jpg'))) |
| if not frames: |
| frames = sorted(glob.glob(str(seq_dir / '*.png'))) |
| |
| if len(frames) != len(gt_boxes): |
| |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| self.sequences.append({'frames': frames, 'gt': gt_boxes}) |
| |
| print(f" Loaded {len(self.sequences)} GOT-10k sequences") |
|
|
|
|
| |
| |
| |
|
|
| class LaSOTDataset(SequenceDataset): |
| """LaSOT tracking dataset. |
| |
| Structure: |
| root/ |
| airplane/ |
| airplane-1/ |
| img/ |
| 00000001.jpg, ... |
| groundtruth.txt # x,y,w,h per line |
| ... |
| """ |
| |
| def __init__(self, root: str, split: str = 'train', **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_sequences(split) |
| |
| def _load_sequences(self, split): |
| if not self.root.exists(): |
| print(f"Warning: LaSOT not found at {self.root}") |
| return |
| |
| |
| |
| categories = sorted([d for d in self.root.iterdir() if d.is_dir()]) |
| total_seqs = 0 |
| |
| for cat_dir in categories: |
| seq_dirs = sorted([d for d in cat_dir.iterdir() if d.is_dir()]) |
| |
| |
            # Simple 80/20 per-category split; note this is not the official
            # LaSOT protocol split
            if split == 'train':
                seq_dirs = seq_dirs[:int(len(seq_dirs) * 0.8)]
            else:
                seq_dirs = seq_dirs[int(len(seq_dirs) * 0.8):]
| |
| for seq_dir in seq_dirs: |
| gt_file = seq_dir / 'groundtruth.txt' |
| img_dir = seq_dir / 'img' |
| |
| if not gt_file.exists() or not img_dir.exists(): |
| continue |
| |
| |
| gt_boxes = [] |
| with open(gt_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except ValueError: |
| gt_boxes.append(None) |
| |
| frames = sorted(glob.glob(str(img_dir / '*.jpg'))) |
| |
| if len(frames) != len(gt_boxes): |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| self.sequences.append({'frames': frames, 'gt': gt_boxes}) |
| total_seqs += 1 |
| |
| print(f" Loaded {total_seqs} LaSOT {split} sequences across {len(categories)} categories") |
|
|
|
|
| |
| |
| |
|
|
| class TrackingNetDataset(SequenceDataset): |
| """TrackingNet tracking dataset. |
| |
| Structure: |
| root/ |
| TRAIN_0/ |
| frames/ |
| video_name/ |
| 0.jpg, 1.jpg, ... |
| anno/ |
| video_name.txt # x,y,w,h per line |
| TRAIN_1/ |
| ... |
| """ |
| |
| def __init__(self, root: str, chunks: list = None, **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
        if chunks is None:
            chunks = list(range(12))  # TrackingNet train set ships as TRAIN_0..TRAIN_11
| self._load_sequences(chunks) |
| |
| def _load_sequences(self, chunks): |
| if not self.root.exists(): |
| print(f"Warning: TrackingNet not found at {self.root}") |
| return |
| |
| total_seqs = 0 |
| for chunk_idx in chunks: |
| chunk_dir = self.root / f'TRAIN_{chunk_idx}' |
| if not chunk_dir.exists(): |
| continue |
| |
| anno_dir = chunk_dir / 'anno' |
| frames_dir = chunk_dir / 'frames' |
| |
| if not anno_dir.exists() or not frames_dir.exists(): |
| continue |
| |
| for anno_file in sorted(anno_dir.glob('*.txt')): |
| seq_name = anno_file.stem |
| seq_frames_dir = frames_dir / seq_name |
| |
| if not seq_frames_dir.exists(): |
| continue |
| |
| |
| gt_boxes = [] |
| with open(anno_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except ValueError: |
| gt_boxes.append(None) |
| |
| frames = sorted(glob.glob(str(seq_frames_dir / '*.jpg'))) |
| if not frames: |
| frames = sorted(glob.glob(str(seq_frames_dir / '*.png'))) |
| |
| if len(frames) != len(gt_boxes): |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| self.sequences.append({'frames': frames, 'gt': gt_boxes}) |
| total_seqs += 1 |
| |
| print(f" Loaded {total_seqs} TrackingNet sequences from {len(chunks)} chunks") |
|
|
|
|
| |
| |
| |
|
|
| class COCODetDataset(SequenceDataset): |
| """COCO detection images as pseudo-sequences for pretraining. |
| |
| Each image with a valid bounding box becomes a length-1 "sequence" |
| where template and search are crops from the same image. |
| """ |
| |
| def __init__(self, root: str, ann_file: str = None, **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_annotations(ann_file) |
| |
| def _load_annotations(self, ann_file): |
| if ann_file is None: |
| ann_file = str(self.root.parent / 'annotations' / 'instances_train2017.json') |
| |
| if not os.path.exists(ann_file): |
| print(f"Warning: COCO annotations not found at {ann_file}") |
| return |
| |
| try: |
| import json |
| with open(ann_file, 'r') as f: |
| coco = json.load(f) |
| |
| |
| images = {img['id']: img for img in coco['images']} |
| |
| |
| for ann in coco['annotations']: |
| if ann.get('iscrowd', 0): |
| continue |
| bbox = ann['bbox'] |
| if bbox[2] < 10 or bbox[3] < 10: |
| continue |
| |
| img_info = images.get(ann['image_id']) |
| if img_info is None: |
| continue |
| |
| img_path = str(self.root / img_info['file_name']) |
| if os.path.exists(img_path): |
| |
| self.sequences.append({ |
| 'frames': [img_path, img_path], |
| 'gt': [bbox, bbox], |
| }) |
| |
| print(f" Loaded {len(self.sequences)} COCO pseudo-sequences") |
| |
| except Exception as e: |
| print(f"Warning: Failed to load COCO annotations: {e}") |
|
|
|
|
| |
| |
| |
|
|
| class SyntheticTrackingDataset(Dataset): |
| """Synthetic tracking dataset for testing without real data. |
| |
| Generates K-frame clips: template + K search frames with a moving |
| colored rectangle target. Motion is linear with noise. |
| """ |
| |
| def __init__( |
| self, |
| length: int = 10000, |
| template_size: int = 128, |
| search_size: int = 256, |
| feat_size: int = 16, |
| acl_difficulty: float = 1.0, |
| clip_length: int = 3, |
| ): |
| super().__init__() |
| self.length = length |
| self.template_size = template_size |
| self.search_size = search_size |
| self.feat_size = feat_size |
| self.acl_difficulty = acl_difficulty |
| self.clip_length = clip_length |
| |
| def __len__(self): |
| return self.length |
| |
| def _make_heatmap(self, cx, cy, w_search, h_search): |
| stride = self.search_size / self.feat_size |
| cx_feat = cx / stride |
| cy_feat = cy / stride |
| y = torch.arange(self.feat_size, dtype=torch.float32) |
| x = torch.arange(self.feat_size, dtype=torch.float32) |
| yy, xx = torch.meshgrid(y, x, indexing='ij') |
| sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4))) |
| dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2 |
| return torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0) |
| |
| def __getitem__(self, idx): |
| rng = random.Random(idx) |
| K = self.clip_length |
| |
| |
| color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1) |
| target_w = rng.uniform(0.1, 0.5) * self.search_size |
| target_h = rng.uniform(0.1, 0.5) * self.search_size |
| |
| |
| cx0 = self.search_size / 2 |
| cy0 = self.search_size / 2 |
| |
| |
        # ACL-scaled linear velocity: harder curriculum means faster targets
        vx = rng.gauss(0, self.acl_difficulty * 15)
        vy = rng.gauss(0, self.acl_difficulty * 15)
| |
| |
| template = torch.randn(3, self.template_size, self.template_size) * 0.1 |
| t_hw = int(min(target_w / 2, self.template_size / 2 - 1)) |
| t_hh = int(min(target_h / 2, self.template_size / 2 - 1)) |
| tc = self.template_size // 2 |
| template[:, tc - t_hh:tc + t_hh, tc - t_hw:tc + t_hw] = color |
| |
| |
| searches = [] |
| heatmaps = [] |
| sizes = [] |
| boxes = [] |
| |
| for k in range(K): |
| |
| cx = cx0 + vx * (k + 1) + rng.gauss(0, self.acl_difficulty * 5) |
| cy = cy0 + vy * (k + 1) + rng.gauss(0, self.acl_difficulty * 5) |
| cx = max(target_w / 2, min(self.search_size - target_w / 2, cx)) |
| cy = max(target_h / 2, min(self.search_size - target_h / 2, cy)) |
| |
| search = torch.randn(3, self.search_size, self.search_size) * 0.1 |
| sx1 = max(0, int(cx - target_w / 2)) |
| sy1 = max(0, int(cy - target_h / 2)) |
| sx2 = min(self.search_size, int(cx + target_w / 2)) |
| sy2 = min(self.search_size, int(cy + target_h / 2)) |
| search[:, sy1:sy2, sx1:sx2] = color |
| |
| searches.append(search) |
| heatmaps.append(self._make_heatmap(cx, cy, target_w, target_h)) |
| sizes.append(torch.tensor([target_w / self.search_size, |
| target_h / self.search_size])) |
| boxes.append(torch.tensor([cx, cy, target_w, target_h])) |
| |
| return { |
| 'template': template, |
| 'searches': torch.stack(searches, dim=0), |
| 'heatmaps': torch.stack(heatmaps, dim=0), |
| 'sizes': torch.stack(sizes, dim=0), |
| 'boxes': torch.stack(boxes, dim=0), |
| } |
| |
| def set_acl_difficulty(self, difficulty: float): |
| self.acl_difficulty = min(1.0, max(0.0, difficulty)) |
|
|
|
|
| |
| |
| |
|
|
| class VisDroneSOTDataset(SequenceDataset): |
| """VisDrone-SOT single object tracking dataset (drone/UAV perspective). |
| |
| Structure: |
| root/ |
| VisDrone2019-SOT-train/ |
| sequences/ |
| uav0000001_00000_s/ |
| 0000001.jpg, 0000002.jpg, ... |
| ... |
| annotations/ |
| uav0000001_00000_s.txt # x,y,w,h per line |
| ... |
| |
| Splits: train (86 sequences, ~70K frames), val (11 sequences), |
| test-dev (35 sequences), test-challenge (35 sequences) |
| |
| Key for our tracker: real drone footage with small targets, fast motion, |
| viewpoint changes, and camera ego-motion — the exact conditions we deploy in. |
| """ |
| |
| def __init__(self, root: str, split: str = 'train', **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_sequences(split) |
| |
| def _load_sequences(self, split): |
| |
| split_names = { |
| 'train': ['VisDrone2019-SOT-train', 'VisDrone2018-SOT-train', 'train'], |
| 'val': ['VisDrone2019-SOT-val', 'VisDrone2018-SOT-val', 'val'], |
| 'test': ['VisDrone2019-SOT-test-dev', 'VisDrone2018-SOT-test', 'test-dev', 'test'], |
| } |
| |
| split_dir = None |
| for name in split_names.get(split, [split]): |
| candidate = self.root / name |
| if candidate.exists(): |
| split_dir = candidate |
| break |
| |
        # Flat layout: sequences/ and annotations/ directly under root
        # (the previous version used a `break` outside the loop, a SyntaxError)
        if split_dir is None and (self.root / 'sequences').exists():
            split_dir = self.root
| |
| if split_dir is None: |
| print(f"Warning: VisDrone-SOT {split} not found at {self.root}") |
| return |
| |
| seq_dir = split_dir / 'sequences' |
| anno_dir = split_dir / 'annotations' |
| |
| if not seq_dir.exists() or not anno_dir.exists(): |
| print(f"Warning: VisDrone-SOT missing sequences/ or annotations/ at {split_dir}") |
| return |
| |
| total_seqs = 0 |
| for anno_file in sorted(anno_dir.glob('*.txt')): |
| seq_name = anno_file.stem |
| frames_dir = seq_dir / seq_name |
| |
| if not frames_dir.exists(): |
| continue |
| |
| gt_boxes = [] |
| with open(anno_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except ValueError: |
| gt_boxes.append(None) |
| |
| frames = sorted(glob.glob(str(frames_dir / '*.jpg'))) |
| if not frames: |
| frames = sorted(glob.glob(str(frames_dir / '*.png'))) |
| |
| if len(frames) != len(gt_boxes): |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| self.sequences.append({'frames': frames, 'gt': gt_boxes}) |
| total_seqs += 1 |
| |
| print(f" Loaded {total_seqs} VisDrone-SOT {split} sequences") |
|
|
|
|
| |
| |
| |
|
|
| class UAVDTDataset(SequenceDataset): |
| """UAVDT (Unmanned Aerial Vehicle Detection and Tracking) dataset. |
| |
| Structure: |
| root/ |
| UAV-benchmark-S/ # SOT annotations |
| {seq_name}/ |
| {seq_name}_gt.txt # x,y,w,h per line (or comma-separated) |
| UAV-benchmark-M/ # Frames |
| {seq_name}/ |
| img000001.jpg, img000002.jpg, ... |
| |
| Alternative structure (simpler): |
| root/ |
| sequences/ |
| {seq_name}/ |
| img000001.jpg, ... |
| annotations/ |
| {seq_name}_gt.txt |
| |
| 50 sequences total, typically 30 train / 20 test. |
| Contains vehicle tracking from drone perspective — complementary to VisDrone. |
| """ |
| |
| def __init__(self, root: str, split: str = 'train', **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_sequences(split) |
| |
| def _load_sequences(self, split): |
| |
| anno_dir = self.root / 'UAV-benchmark-S' |
| frame_dir = self.root / 'UAV-benchmark-M' |
| |
| if not anno_dir.exists(): |
| |
| anno_dir = self.root / 'annotations' |
| frame_dir = self.root / 'sequences' |
| |
| if not anno_dir.exists(): |
| |
| anno_dir = self.root |
| frame_dir = self.root |
| |
| if not anno_dir.exists(): |
| print(f"Warning: UAVDT not found at {self.root}") |
| return |
| |
| |
| all_seqs = [] |
| |
| |
| gt_files = sorted(anno_dir.rglob('*_gt.txt')) |
| if not gt_files: |
| gt_files = sorted(anno_dir.rglob('*.txt')) |
| |
| for gt_file in gt_files: |
| seq_name = gt_file.stem.replace('_gt', '') |
| |
| |
| frames_path = None |
| for candidate in [ |
| frame_dir / seq_name, |
| frame_dir / seq_name / 'img', |
| self.root / seq_name, |
| ]: |
| if candidate.exists(): |
| frames_path = candidate |
| break |
| |
| if frames_path is None: |
| continue |
| |
| gt_boxes = [] |
| with open(gt_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').replace('\t', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except (ValueError, IndexError): |
| gt_boxes.append(None) |
| |
| frames = sorted(glob.glob(str(frames_path / '*.jpg'))) |
| if not frames: |
| frames = sorted(glob.glob(str(frames_path / '*.png'))) |
| |
| if len(frames) != len(gt_boxes): |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| all_seqs.append({'frames': frames, 'gt': gt_boxes, 'name': seq_name}) |
| |
| |
        # Deterministic 60/40 split by name (about 30 train / 20 test for
        # UAVDT's 50 SOT sequences)
        all_seqs.sort(key=lambda x: x['name'])
        split_idx = int(len(all_seqs) * 0.6)
| |
| if split == 'train': |
| selected = all_seqs[:split_idx] |
| else: |
| selected = all_seqs[split_idx:] |
| |
| for seq in selected: |
| self.sequences.append({'frames': seq['frames'], 'gt': seq['gt']}) |
| |
| print(f" Loaded {len(self.sequences)} UAVDT {split} sequences " |
| f"(from {len(all_seqs)} total)") |
|
|
|
|
| |
| |
| |
|
|
| class WebUAV3MDataset(SequenceDataset): |
| """WebUAV-3M: million-scale multi-modal UAV tracking dataset. |
| |
| Structure: |
| root/ |
| {superclass}/ # e.g., person, vehicle, animal |
| {seq_name}/ |
| img/ |
| 000001.jpg, 000002.jpg, ... |
| groundtruth_rect.txt # x,y,w,h per line |
| OR: |
| {seq_name}/ |
| *.jpg |
| groundtruth_rect.txt |
| |
| 4,500 sequences, 3.3M frames, 12 superclasses, 223 target classes. |
| Average video length: 710 frames (23.7 seconds at 30 FPS). |
| |
| This is the largest UAV tracking dataset. All sequences are from real |
| drone footage. Purpose-built for training deep UAV trackers. |
| """ |
| |
| def __init__(self, root: str, split: str = 'train', max_sequences: int = None, **kwargs): |
| super().__init__(**kwargs) |
| self.root = Path(root) |
| self._load_sequences(split, max_sequences) |
| |
| def _load_sequences(self, split, max_sequences): |
| if not self.root.exists(): |
| print(f"Warning: WebUAV-3M not found at {self.root}") |
| return |
| |
| |
| all_seq_dirs = [] |
| |
| |
| gt_files = sorted(self.root.rglob('groundtruth_rect.txt')) |
| if not gt_files: |
| gt_files = sorted(self.root.rglob('groundtruth.txt')) |
| |
| for gt_file in gt_files: |
| seq_dir = gt_file.parent |
| |
| img_dir = seq_dir / 'img' |
| if not img_dir.exists(): |
| img_dir = seq_dir |
| |
| frames = sorted(glob.glob(str(img_dir / '*.jpg'))) |
| if not frames: |
| frames = sorted(glob.glob(str(img_dir / '*.png'))) |
| |
| if len(frames) >= 2: |
| all_seq_dirs.append((gt_file, frames)) |
| |
| print(f"WebUAV-3M: found {len(all_seq_dirs)} sequences total") |
| |
| |
        # Deterministic 80/20 train/val split over the sorted sequence list
        split_idx = int(len(all_seq_dirs) * 0.8)
| if split == 'train': |
| selected = all_seq_dirs[:split_idx] |
| else: |
| selected = all_seq_dirs[split_idx:] |
| |
| |
        # Subsample evenly so a cap still covers the whole dataset
        if max_sequences and len(selected) > max_sequences:
            step = len(selected) // max_sequences
            selected = selected[::step][:max_sequences]
| |
| for gt_file, frames in selected: |
| gt_boxes = [] |
| with open(gt_file, 'r') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| gt_boxes.append(None) |
| continue |
| parts = line.replace(',', ' ').replace('\t', ' ').split() |
| try: |
| gt_boxes.append([float(x) for x in parts[:4]]) |
| except (ValueError, IndexError): |
| gt_boxes.append(None) |
| |
| if len(frames) != len(gt_boxes): |
| min_len = min(len(frames), len(gt_boxes)) |
| frames = frames[:min_len] |
| gt_boxes = gt_boxes[:min_len] |
| |
| if len(frames) >= 2: |
| self.sequences.append({'frames': frames, 'gt': gt_boxes}) |
| |
| print(f" Loaded {len(self.sequences)} WebUAV-3M {split} sequences") |
|
|
|
|
| |
| |
| |
|
|
| def build_tracking_dataset( |
| data_config: dict, |
| template_size: int = 128, |
| search_size: int = 256, |
| feat_size: int = 16, |
| acl_difficulty: float = 0.0, |
| ) -> Dataset: |
| """Build a combined tracking dataset from multiple sources. |
| |
    Standard ground-level datasets provide general tracking capability;
    UAV-specific datasets provide drone-perspective specialization. The ACL
    curriculum applies uniformly to every dataset: at low difficulty, clips
    use small temporal gaps and little spatial jitter, and raising the
    difficulty (via set_acl_difficulty) widens both, which bites hardest on
    UAV sequences with fast motion, small targets, and viewpoint changes.
| |
| Args: |
| data_config: dict with optional keys: |
| Ground-level (standard tracking training data): |
| - 'got10k_root': path to GOT-10k dataset |
| - 'lasot_root': path to LaSOT dataset |
| - 'trackingnet_root': path to TrackingNet dataset |
| - 'coco_root': path to COCO train2017 images |
| |
| UAV-specific (drone perspective — the deployment domain): |
| - 'visdrone_root': path to VisDrone-SOT dataset |
| - 'uavdt_root': path to UAVDT dataset |
| - 'webuav3m_root': path to WebUAV-3M dataset |
| - 'webuav3m_max_sequences': limit WebUAV-3M sequences (default: None = all) |
| |
| Fallback: |
| - 'synthetic_length': number of synthetic samples (fallback) |
| template_size: template crop size |
| search_size: search region crop size |
| feat_size: feature map spatial size |
| acl_difficulty: initial ACL difficulty |
| Returns: |
| ConcatDataset or SyntheticTrackingDataset |
| """ |
| common_kwargs = dict( |
| template_size=template_size, |
| search_size=search_size, |
| feat_size=feat_size, |
| acl_difficulty=acl_difficulty, |
| ) |
| |
| datasets = [] |
| |
| if 'got10k_root' in data_config and os.path.exists(data_config['got10k_root']): |
| ds = GOT10kDataset(data_config['got10k_root'], split='train', **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"GOT-10k: {len(ds)} sequences") |
| |
| if 'lasot_root' in data_config and os.path.exists(data_config['lasot_root']): |
| ds = LaSOTDataset(data_config['lasot_root'], split='train', **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"LaSOT: {len(ds)} sequences") |
| |
| if 'trackingnet_root' in data_config and os.path.exists(data_config['trackingnet_root']): |
| ds = TrackingNetDataset(data_config['trackingnet_root'], **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"TrackingNet: {len(ds)} sequences") |
| |
| if 'coco_root' in data_config and os.path.exists(data_config['coco_root']): |
| ds = COCODetDataset(data_config['coco_root'], **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"COCO: {len(ds)} pseudo-sequences") |
| |
| |
| |
    # UAV-specific datasets (drone perspective, the deployment domain)
    if 'visdrone_root' in data_config and os.path.exists(data_config['visdrone_root']):
| ds = VisDroneSOTDataset(data_config['visdrone_root'], split='train', **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"VisDrone-SOT: {len(ds)} UAV sequences") |
| |
| if 'uavdt_root' in data_config and os.path.exists(data_config['uavdt_root']): |
| ds = UAVDTDataset(data_config['uavdt_root'], split='train', **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"UAVDT: {len(ds)} UAV sequences") |
| |
| if 'webuav3m_root' in data_config and os.path.exists(data_config['webuav3m_root']): |
| max_seq = data_config.get('webuav3m_max_sequences', None) |
| ds = WebUAV3MDataset(data_config['webuav3m_root'], split='train', |
| max_sequences=max_seq, **common_kwargs) |
| if len(ds) > 0: |
| datasets.append(ds) |
| print(f"WebUAV-3M: {len(ds)} UAV sequences") |
| |
| if datasets: |
| combined = ConcatDataset(datasets) |
| print(f"\nTotal training samples: {len(combined)}") |
| return combined |
| |
| |
| syn_len = data_config.get('synthetic_length', 10000) |
| print(f"No real data found, using {syn_len} synthetic samples") |
| return SyntheticTrackingDataset( |
| length=syn_len, |
| template_size=template_size, |
| search_size=search_size, |
| feat_size=feat_size, |
| acl_difficulty=acl_difficulty, |
| ) |
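

# ConcatDataset does not forward set_acl_difficulty, so a training loop has to
# fan the difficulty out to the member datasets itself. A minimal sketch (the
# helper name is ours, not part of the original API):
def set_dataset_difficulty(dataset, difficulty: float):
    """Propagate ACL difficulty through a ConcatDataset or a single dataset."""
    if isinstance(dataset, ConcatDataset):
        for sub in dataset.datasets:
            sub.set_acl_difficulty(difficulty)
    else:
        dataset.set_acl_difficulty(difficulty)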
|
|
|
|
| |
| |
| |
|
|
class TrackingDataset(SyntheticTrackingDataset):
    """Backward-compatible alias for SyntheticTrackingDataset.

    `data_dir`, `split`, and `synthetic` are accepted for API compatibility
    but ignored; this alias always generates synthetic data.
    """
    def __init__(self, data_dir=None, split='train', synthetic=False,
                 synthetic_length=10000, clip_length=3, **kwargs):
        super().__init__(length=synthetic_length, clip_length=clip_length, **kwargs)
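

if __name__ == '__main__':
    # Smoke test on synthetic data (no external datasets required).
    # Shape assumptions: default sizes with clip_length K = 3.
    ds = build_tracking_dataset({'synthetic_length': 8})
    sample = ds[0]
    assert sample['template'].shape == (3, 128, 128)
    assert sample['searches'].shape == (3, 3, 256, 256)
    assert sample['heatmaps'].shape == (3, 1, 16, 16)
    assert sample['sizes'].shape == (3, 2)
    assert sample['boxes'].shape == (3, 4)
    print('synthetic smoke test passed')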
|
|