""" Tracking dataset with real dataset loaders and synthetic fallback. Supports: - GOT-10k: train split (~10k sequences, annotations in groundtruth.txt) - LaSOT: training split (1120 sequences, 14 categories) - TrackingNet: training split (30k+ sequences, annotations in anno/) - COCO detection: for static pair pretraining (bbox crops as pseudo-sequences) - Synthetic data generation for testing (no external data needed) - ACL (Adaptive Curriculum Learning) difficulty scaling - Standard tracking augmentations: spatial jitter, horizontal flip, color jitter, grayscale, Gaussian blur, brightness/contrast Each sample produces a (template, search) pair from the same video sequence with controlled temporal distance, plus GT annotations. Dataset directory structure expected: GOT-10k/ train/ GOT-10k_Train_000001/ 00000001.jpg, 00000002.jpg, ... groundtruth.txt # x,y,w,h per line ... LaSOT/ airplane/ airplane-1/ img/ 00000001.jpg, ... groundtruth.txt # x,y,w,h per line ... TrackingNet/ TRAIN_0/ frames/ video_name/ 0.jpg, 1.jpg, ... anno/ video_name.txt # x,y,w,h per line ... COCO/ train2017/ *.jpg annotations/ instances_train2017.json """ import os import math import glob import random import torch import numpy as np from pathlib import Path from torch.utils.data import Dataset, ConcatDataset # ============================================================ # Augmentations (no torchvision dependency, works with tensors) # ============================================================ class TrackingAugmentation: """Standard tracking augmentations applied to (template, search) pairs. Augmentations preserve the spatial relationship between search region and GT bounding box by applying augmentations consistently. """ def __init__( self, brightness: float = 0.2, contrast: float = 0.2, saturation: float = 0.2, grayscale_prob: float = 0.05, horizontal_flip_prob: float = 0.5, blur_prob: float = 0.1, blur_sigma: tuple = (0.1, 2.0), ): self.brightness = brightness self.contrast = contrast self.saturation = saturation self.grayscale_prob = grayscale_prob self.horizontal_flip_prob = horizontal_flip_prob self.blur_prob = blur_prob self.blur_sigma = blur_sigma def __call__(self, template: torch.Tensor, search: torch.Tensor, bbox: torch.Tensor) -> tuple: """ Args: template: (3, H_t, W_t) tensor in [0, 1] search: (3, H_s, W_s) tensor in [0, 1] bbox: (4,) tensor [cx, cy, w, h] in search region pixels Returns: template, search, bbox (augmented) """ # Color jitter (same for template and search to maintain appearance consistency) if random.random() < 0.8: # Brightness factor = 1.0 + random.uniform(-self.brightness, self.brightness) template = (template * factor).clamp(0, 1) search = (search * factor).clamp(0, 1) # Contrast factor = 1.0 + random.uniform(-self.contrast, self.contrast) t_mean = template.mean() s_mean = search.mean() template = ((template - t_mean) * factor + t_mean).clamp(0, 1) search = ((search - s_mean) * factor + s_mean).clamp(0, 1) # Grayscale if random.random() < self.grayscale_prob: t_gray = template.mean(dim=0, keepdim=True).expand_as(template) s_gray = search.mean(dim=0, keepdim=True).expand_as(search) template = t_gray search = s_gray # Horizontal flip (must also flip bbox cx) if random.random() < self.horizontal_flip_prob: template = template.flip(-1) search = search.flip(-1) W_s = search.shape[-1] bbox = bbox.clone() bbox[0] = W_s - bbox[0] # flip cx # Gaussian blur (search only — simulates motion blur) if random.random() < self.blur_prob: sigma = random.uniform(*self.blur_sigma) kernel_size = int(2 * round(3 * sigma) + 1) if kernel_size >= 3: search = self._gaussian_blur(search, kernel_size, sigma) return template, search, bbox @staticmethod def _gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: """Apply Gaussian blur to a (C, H, W) tensor.""" import torch.nn.functional as F # Create 1D Gaussian kernel x = torch.arange(kernel_size, dtype=img.dtype, device=img.device) - kernel_size // 2 kernel_1d = torch.exp(-0.5 * (x / sigma) ** 2) kernel_1d = kernel_1d / kernel_1d.sum() # Apply separable 2D blur pad = kernel_size // 2 img = img.unsqueeze(0) # (1, C, H, W) # Horizontal k_h = kernel_1d.view(1, 1, 1, -1).expand(img.shape[1], -1, -1, -1) img = F.conv2d(F.pad(img, (pad, pad, 0, 0), mode='reflect'), k_h, groups=img.shape[1]) # Vertical k_v = kernel_1d.view(1, 1, -1, 1).expand(img.shape[1], -1, -1, -1) img = F.conv2d(F.pad(img, (0, 0, pad, pad), mode='reflect'), k_v, groups=img.shape[1]) return img.squeeze(0) # ============================================================ # Crop utilities # ============================================================ def crop_and_resize(image: np.ndarray, center: np.ndarray, size: float, output_size: int) -> np.ndarray: """Crop a square region from image, centered at center, with given size. Args: image: (H, W, 3) numpy array, uint8 or float center: (2,) [cx, cy] in image coordinates size: side length of the square crop output_size: resize crop to (output_size, output_size) Returns: (output_size, output_size, 3) numpy array """ H, W = image.shape[:2] half = size / 2 x1 = int(round(center[0] - half)) y1 = int(round(center[1] - half)) x2 = int(round(center[0] + half)) y2 = int(round(center[1] + half)) # Boundary padding pad_left = max(0, -x1) pad_top = max(0, -y1) pad_right = max(0, x2 - W) pad_bottom = max(0, y2 - H) x1c = max(0, x1) y1c = max(0, y1) x2c = min(W, x2) y2c = min(H, y2) crop = image[y1c:y2c, x1c:x2c] if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0: mean_color = image.mean(axis=(0, 1)) padded = np.full((crop.shape[0] + pad_top + pad_bottom, crop.shape[1] + pad_left + pad_right, 3), mean_color, dtype=crop.dtype) padded[pad_top:pad_top + crop.shape[0], pad_left:pad_left + crop.shape[1]] = crop crop = padded # Resize if crop.shape[0] > 0 and crop.shape[1] > 0: import torch.nn.functional as F crop_t = torch.from_numpy(crop.copy()).float().permute(2, 0, 1).unsqueeze(0) crop_t = F.interpolate(crop_t, size=(output_size, output_size), mode='bilinear', align_corners=False) crop = crop_t.squeeze(0).permute(1, 2, 0).numpy() else: crop = np.zeros((output_size, output_size, 3), dtype=np.float32) return crop def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple: """Compute crop center and size from bbox with context. Args: bbox: [x, y, w, h] bounding box context_factor: how much context around bbox (2.0 = 2x target size) Returns: center: (2,) [cx, cy] crop_size: scalar side length """ x, y, w, h = bbox cx = x + w / 2 cy = y + h / 2 # Context amount following STARK/OSTrack convention: # s = sqrt((w + 2p) * (h + 2p)), where p = (w + h) / 2 p = (w + h) / 2 crop_size = math.sqrt((w + p) * (h + p)) * context_factor crop_size = max(crop_size, 10) return np.array([cx, cy]), crop_size # ============================================================ # Base sequence dataset # ============================================================ class SequenceDataset(Dataset): """Base class for tracking sequence datasets. Returns K-frame clips: template + K consecutive search frames. The mLSTM processes these as one long sequence where memory carries information across frames — this is the core training paradigm. Subclasses must populate self.sequences with list of: {'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]} """ def __init__( self, template_size: int = 128, search_size: int = 256, feat_size: int = 16, acl_difficulty: float = 1.0, max_gap: int = 100, clip_length: int = 3, augmentation: bool = True, ): super().__init__() self.template_size = template_size self.search_size = search_size self.feat_size = feat_size self.acl_difficulty = acl_difficulty self.max_gap = max_gap self.clip_length = clip_length # K search frames per sample self.sequences = [] self.augmentation = TrackingAugmentation() if augmentation else None def __len__(self): return len(self.sequences) def _load_image(self, path: str) -> np.ndarray: """Load image from path. Returns (H, W, 3) float32 in [0, 255].""" try: from PIL import Image img = Image.open(path).convert('RGB') return np.array(img, dtype=np.float32) except ImportError: import cv2 img = cv2.imread(path) if img is None: return np.zeros((480, 640, 3), dtype=np.float32) return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) def _sample_clip(self, idx: int) -> list: """Sample a clip: template frame + K consecutive search frames. Returns: list of frame indices: [template_idx, search_1_idx, ..., search_K_idx] """ seq = self.sequences[idx] n_frames = len(seq['frames']) K = self.clip_length valid = [i for i in range(n_frames) if seq['gt'][i] is not None and seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0] valid_set = set(valid) if len(valid) < K + 1: # Not enough frames — repeat what we have if len(valid) == 0: return [0] * (K + 1) return [valid[0]] + [valid[min(i, len(valid)-1)] for i in range(K)] # Template: pick a random valid frame t_idx = random.choice(valid) # Search frames: K consecutive valid frames AFTER template # Temporal gap between template and first search controlled by ACL effective_gap = max(1, int(self.max_gap * self.acl_difficulty)) # Find the start of the search clip: somewhere after template min_start = t_idx + 1 max_start = min(t_idx + effective_gap, n_frames - K) if max_start < min_start: # Try before template max_start_before = t_idx - K min_start_before = max(0, t_idx - effective_gap - K) if max_start_before >= min_start_before and max_start_before >= 0: clip_start = random.randint(min_start_before, max_start_before) else: # Fallback: just use whatever consecutive frames we can find clip_start = max(0, min(n_frames - K, t_idx + 1)) # But ensure template is different from search frames else: clip_start = random.randint(min_start, max(min_start, max_start)) # Collect K consecutive frames, preferring valid ones search_indices = [] for i in range(clip_start, min(clip_start + K * 3, n_frames)): if i in valid_set and i != t_idx: search_indices.append(i) if len(search_indices) == K: break # Pad if we didn't find enough while len(search_indices) < K: search_indices.append(search_indices[-1] if search_indices else t_idx) return [t_idx] + search_indices[:K] def _process_frame(self, img: np.ndarray, bbox: np.ndarray, is_template: bool): """Crop and preprocess a single frame. Returns: image_tensor: (3, H, W) float [0, 1] bbox_in_crop: (4,) [cx, cy, w, h] in crop coordinates """ if is_template: center, crop_size = compute_crop_params(bbox, context_factor=2.0) output_size = self.template_size else: center, crop_size = compute_crop_params(bbox, context_factor=4.0) output_size = self.search_size # Spatial jitter for search (controlled by ACL) jitter = self.acl_difficulty * bbox[2:4].mean() * 0.3 if jitter > 0: center[0] += random.gauss(0, jitter) center[1] += random.gauss(0, jitter) crop = crop_and_resize(img, center, crop_size, output_size) # Compute GT in crop coordinates scale = output_size / crop_size cx = (bbox[0] + bbox[2] / 2 - center[0] + crop_size / 2) * scale cy = (bbox[1] + bbox[3] / 2 - center[1] + crop_size / 2) * scale w = bbox[2] * scale h = bbox[3] * scale cx = max(0, min(output_size, cx)) cy = max(0, min(output_size, cy)) w = max(1, min(output_size, w)) h = max(1, min(output_size, h)) tensor = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0 bbox_crop = torch.tensor([cx, cy, w, h]) return tensor, bbox_crop def _make_heatmap(self, bbox: torch.Tensor): """Generate GT heatmap from bbox in search crop coordinates.""" stride = self.search_size / self.feat_size cx_feat = bbox[0].item() / stride cy_feat = bbox[1].item() / stride w_search = bbox[2].item() h_search = bbox[3].item() y = torch.arange(self.feat_size, dtype=torch.float32) x = torch.arange(self.feat_size, dtype=torch.float32) yy, xx = torch.meshgrid(y, x, indexing='ij') sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4))) dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2 heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0) return heatmap def __getitem__(self, idx): seq = self.sequences[idx % len(self.sequences)] clip_indices = self._sample_clip(idx % len(self.sequences)) t_idx = clip_indices[0] s_indices = clip_indices[1:] K = len(s_indices) # Load and process template t_img = self._load_image(seq['frames'][t_idx]) t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32) template, _ = self._process_frame(t_img, t_bbox, is_template=True) # Load and process K search frames searches = [] heatmaps = [] sizes = [] boxes = [] for s_idx in s_indices: s_img = self._load_image(seq['frames'][s_idx]) s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32) search, bbox_crop = self._process_frame(s_img, s_bbox, is_template=False) # Apply augmentation (same color transform for template+search consistency) if self.augmentation is not None: template_aug, search, bbox_crop = self.augmentation(template, search, bbox_crop) # Only use augmented template from first search frame to keep consistency if len(searches) == 0: template = template_aug searches.append(search) heatmaps.append(self._make_heatmap(bbox_crop)) sizes.append(torch.tensor([bbox_crop[2].item() / self.search_size, bbox_crop[3].item() / self.search_size])) boxes.append(bbox_crop) return { 'template': template, # (3, 128, 128) 'searches': torch.stack(searches, dim=0), # (K, 3, 256, 256) 'heatmaps': torch.stack(heatmaps, dim=0), # (K, 1, 16, 16) 'sizes': torch.stack(sizes, dim=0), # (K, 2) 'boxes': torch.stack(boxes, dim=0), # (K, 4) } def set_acl_difficulty(self, difficulty: float): """Update ACL difficulty level (0.0 = easy, 1.0 = hard).""" self.acl_difficulty = min(1.0, max(0.0, difficulty)) # ============================================================ # GOT-10k dataset loader # ============================================================ class GOT10kDataset(SequenceDataset): """GOT-10k tracking dataset. Structure: root/train/GOT-10k_Train_NNNNNN/ 00000001.jpg, 00000002.jpg, ... groundtruth.txt # x,y,w,h per line """ def __init__(self, root: str, split: str = 'train', **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_sequences(split) def _load_sequences(self, split): split_dir = self.root / split if not split_dir.exists(): print(f"Warning: GOT-10k {split} not found at {split_dir}") return seq_dirs = sorted([d for d in split_dir.iterdir() if d.is_dir() and 'Train' in d.name]) print(f"Loading GOT-10k {split}: found {len(seq_dirs)} sequences") for seq_dir in seq_dirs: gt_file = seq_dir / 'groundtruth.txt' if not gt_file.exists(): continue # Load annotations gt_boxes = [] with open(gt_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except ValueError: gt_boxes.append(None) # Get frame paths frames = sorted(glob.glob(str(seq_dir / '*.jpg'))) if not frames: frames = sorted(glob.glob(str(seq_dir / '*.png'))) if len(frames) != len(gt_boxes): # Trim to shorter min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: self.sequences.append({'frames': frames, 'gt': gt_boxes}) print(f" Loaded {len(self.sequences)} GOT-10k sequences") # ============================================================ # LaSOT dataset loader # ============================================================ class LaSOTDataset(SequenceDataset): """LaSOT tracking dataset. Structure: root/ airplane/ airplane-1/ img/ 00000001.jpg, ... groundtruth.txt # x,y,w,h per line ... """ def __init__(self, root: str, split: str = 'train', **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_sequences(split) def _load_sequences(self, split): if not self.root.exists(): print(f"Warning: LaSOT not found at {self.root}") return # LaSOT train/test split defined by sequence names # Training: first 80% of sequences per category categories = sorted([d for d in self.root.iterdir() if d.is_dir()]) total_seqs = 0 for cat_dir in categories: seq_dirs = sorted([d for d in cat_dir.iterdir() if d.is_dir()]) # Train/test split if split == 'train': seq_dirs = seq_dirs[:int(len(seq_dirs) * 0.8)] else: seq_dirs = seq_dirs[int(len(seq_dirs) * 0.8):] for seq_dir in seq_dirs: gt_file = seq_dir / 'groundtruth.txt' img_dir = seq_dir / 'img' if not gt_file.exists() or not img_dir.exists(): continue # Load annotations gt_boxes = [] with open(gt_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except ValueError: gt_boxes.append(None) frames = sorted(glob.glob(str(img_dir / '*.jpg'))) if len(frames) != len(gt_boxes): min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: self.sequences.append({'frames': frames, 'gt': gt_boxes}) total_seqs += 1 print(f" Loaded {total_seqs} LaSOT {split} sequences across {len(categories)} categories") # ============================================================ # TrackingNet dataset loader # ============================================================ class TrackingNetDataset(SequenceDataset): """TrackingNet tracking dataset. Structure: root/ TRAIN_0/ frames/ video_name/ 0.jpg, 1.jpg, ... anno/ video_name.txt # x,y,w,h per line TRAIN_1/ ... """ def __init__(self, root: str, chunks: list = None, **kwargs): super().__init__(**kwargs) self.root = Path(root) if chunks is None: chunks = list(range(12)) # TRAIN_0 through TRAIN_11 self._load_sequences(chunks) def _load_sequences(self, chunks): if not self.root.exists(): print(f"Warning: TrackingNet not found at {self.root}") return total_seqs = 0 for chunk_idx in chunks: chunk_dir = self.root / f'TRAIN_{chunk_idx}' if not chunk_dir.exists(): continue anno_dir = chunk_dir / 'anno' frames_dir = chunk_dir / 'frames' if not anno_dir.exists() or not frames_dir.exists(): continue for anno_file in sorted(anno_dir.glob('*.txt')): seq_name = anno_file.stem seq_frames_dir = frames_dir / seq_name if not seq_frames_dir.exists(): continue # Load annotations gt_boxes = [] with open(anno_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except ValueError: gt_boxes.append(None) frames = sorted(glob.glob(str(seq_frames_dir / '*.jpg'))) if not frames: frames = sorted(glob.glob(str(seq_frames_dir / '*.png'))) if len(frames) != len(gt_boxes): min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: self.sequences.append({'frames': frames, 'gt': gt_boxes}) total_seqs += 1 print(f" Loaded {total_seqs} TrackingNet sequences from {len(chunks)} chunks") # ============================================================ # COCO detection as pseudo-sequences # ============================================================ class COCODetDataset(SequenceDataset): """COCO detection images as pseudo-sequences for pretraining. Each image with a valid bounding box becomes a length-1 "sequence" where template and search are crops from the same image. """ def __init__(self, root: str, ann_file: str = None, **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_annotations(ann_file) def _load_annotations(self, ann_file): if ann_file is None: ann_file = str(self.root.parent / 'annotations' / 'instances_train2017.json') if not os.path.exists(ann_file): print(f"Warning: COCO annotations not found at {ann_file}") return try: import json with open(ann_file, 'r') as f: coco = json.load(f) # Build image lookup images = {img['id']: img for img in coco['images']} # Create pseudo-sequences from annotations for ann in coco['annotations']: if ann.get('iscrowd', 0): continue bbox = ann['bbox'] # [x, y, w, h] if bbox[2] < 10 or bbox[3] < 10: continue img_info = images.get(ann['image_id']) if img_info is None: continue img_path = str(self.root / img_info['file_name']) if os.path.exists(img_path): # Pseudo-sequence: same frame for template and search self.sequences.append({ 'frames': [img_path, img_path], 'gt': [bbox, bbox], }) print(f" Loaded {len(self.sequences)} COCO pseudo-sequences") except Exception as e: print(f"Warning: Failed to load COCO annotations: {e}") # ============================================================ # Synthetic dataset (for testing / no-data development) # ============================================================ class SyntheticTrackingDataset(Dataset): """Synthetic tracking dataset for testing without real data. Generates K-frame clips: template + K search frames with a moving colored rectangle target. Motion is linear with noise. """ def __init__( self, length: int = 10000, template_size: int = 128, search_size: int = 256, feat_size: int = 16, acl_difficulty: float = 1.0, clip_length: int = 3, ): super().__init__() self.length = length self.template_size = template_size self.search_size = search_size self.feat_size = feat_size self.acl_difficulty = acl_difficulty self.clip_length = clip_length def __len__(self): return self.length def _make_heatmap(self, cx, cy, w_search, h_search): stride = self.search_size / self.feat_size cx_feat = cx / stride cy_feat = cy / stride y = torch.arange(self.feat_size, dtype=torch.float32) x = torch.arange(self.feat_size, dtype=torch.float32) yy, xx = torch.meshgrid(y, x, indexing='ij') sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4))) dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2 return torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0) def __getitem__(self, idx): rng = random.Random(idx) K = self.clip_length # Target appearance color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1) target_w = rng.uniform(0.1, 0.5) * self.search_size target_h = rng.uniform(0.1, 0.5) * self.search_size # Initial position (center of search) cx0 = self.search_size / 2 cy0 = self.search_size / 2 # Velocity (pixels per frame, scaled by difficulty) vx = rng.gauss(0, self.acl_difficulty * 15) vy = rng.gauss(0, self.acl_difficulty * 15) # Template: target at center template = torch.randn(3, self.template_size, self.template_size) * 0.1 t_hw = int(min(target_w / 2, self.template_size / 2 - 1)) t_hh = int(min(target_h / 2, self.template_size / 2 - 1)) tc = self.template_size // 2 template[:, tc - t_hh:tc + t_hh, tc - t_hw:tc + t_hw] = color # K search frames with moving target searches = [] heatmaps = [] sizes = [] boxes = [] for k in range(K): # Position at frame k cx = cx0 + vx * (k + 1) + rng.gauss(0, self.acl_difficulty * 5) cy = cy0 + vy * (k + 1) + rng.gauss(0, self.acl_difficulty * 5) cx = max(target_w / 2, min(self.search_size - target_w / 2, cx)) cy = max(target_h / 2, min(self.search_size - target_h / 2, cy)) search = torch.randn(3, self.search_size, self.search_size) * 0.1 sx1 = max(0, int(cx - target_w / 2)) sy1 = max(0, int(cy - target_h / 2)) sx2 = min(self.search_size, int(cx + target_w / 2)) sy2 = min(self.search_size, int(cy + target_h / 2)) search[:, sy1:sy2, sx1:sx2] = color searches.append(search) heatmaps.append(self._make_heatmap(cx, cy, target_w, target_h)) sizes.append(torch.tensor([target_w / self.search_size, target_h / self.search_size])) boxes.append(torch.tensor([cx, cy, target_w, target_h])) return { 'template': template, # (3, 128, 128) 'searches': torch.stack(searches, dim=0), # (K, 3, 256, 256) 'heatmaps': torch.stack(heatmaps, dim=0), # (K, 1, 16, 16) 'sizes': torch.stack(sizes, dim=0), # (K, 2) 'boxes': torch.stack(boxes, dim=0), # (K, 4) } def set_acl_difficulty(self, difficulty: float): self.acl_difficulty = min(1.0, max(0.0, difficulty)) # ============================================================ # VisDrone-SOT dataset loader (UAV) # ============================================================ class VisDroneSOTDataset(SequenceDataset): """VisDrone-SOT single object tracking dataset (drone/UAV perspective). Structure: root/ VisDrone2019-SOT-train/ sequences/ uav0000001_00000_s/ 0000001.jpg, 0000002.jpg, ... ... annotations/ uav0000001_00000_s.txt # x,y,w,h per line ... Splits: train (86 sequences, ~70K frames), val (11 sequences), test-dev (35 sequences), test-challenge (35 sequences) Key for our tracker: real drone footage with small targets, fast motion, viewpoint changes, and camera ego-motion — the exact conditions we deploy in. """ def __init__(self, root: str, split: str = 'train', **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_sequences(split) def _load_sequences(self, split): # Try multiple directory naming conventions split_names = { 'train': ['VisDrone2019-SOT-train', 'VisDrone2018-SOT-train', 'train'], 'val': ['VisDrone2019-SOT-val', 'VisDrone2018-SOT-val', 'val'], 'test': ['VisDrone2019-SOT-test-dev', 'VisDrone2018-SOT-test', 'test-dev', 'test'], } split_dir = None for name in split_names.get(split, [split]): candidate = self.root / name if candidate.exists(): split_dir = candidate break # Also check if root itself is the split dir if (self.root / 'sequences').exists(): split_dir = self.root break if split_dir is None: print(f"Warning: VisDrone-SOT {split} not found at {self.root}") return seq_dir = split_dir / 'sequences' anno_dir = split_dir / 'annotations' if not seq_dir.exists() or not anno_dir.exists(): print(f"Warning: VisDrone-SOT missing sequences/ or annotations/ at {split_dir}") return total_seqs = 0 for anno_file in sorted(anno_dir.glob('*.txt')): seq_name = anno_file.stem frames_dir = seq_dir / seq_name if not frames_dir.exists(): continue gt_boxes = [] with open(anno_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except ValueError: gt_boxes.append(None) frames = sorted(glob.glob(str(frames_dir / '*.jpg'))) if not frames: frames = sorted(glob.glob(str(frames_dir / '*.png'))) if len(frames) != len(gt_boxes): min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: self.sequences.append({'frames': frames, 'gt': gt_boxes}) total_seqs += 1 print(f" Loaded {total_seqs} VisDrone-SOT {split} sequences") # ============================================================ # UAVDT dataset loader (UAV) # ============================================================ class UAVDTDataset(SequenceDataset): """UAVDT (Unmanned Aerial Vehicle Detection and Tracking) dataset. Structure: root/ UAV-benchmark-S/ # SOT annotations {seq_name}/ {seq_name}_gt.txt # x,y,w,h per line (or comma-separated) UAV-benchmark-M/ # Frames {seq_name}/ img000001.jpg, img000002.jpg, ... Alternative structure (simpler): root/ sequences/ {seq_name}/ img000001.jpg, ... annotations/ {seq_name}_gt.txt 50 sequences total, typically 30 train / 20 test. Contains vehicle tracking from drone perspective — complementary to VisDrone. """ def __init__(self, root: str, split: str = 'train', **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_sequences(split) def _load_sequences(self, split): # Try standard UAVDT structure anno_dir = self.root / 'UAV-benchmark-S' frame_dir = self.root / 'UAV-benchmark-M' if not anno_dir.exists(): # Alternative structure anno_dir = self.root / 'annotations' frame_dir = self.root / 'sequences' if not anno_dir.exists(): # Try root directly having sequence dirs anno_dir = self.root frame_dir = self.root if not anno_dir.exists(): print(f"Warning: UAVDT not found at {self.root}") return # Collect all sequences all_seqs = [] # Find annotation files gt_files = sorted(anno_dir.rglob('*_gt.txt')) if not gt_files: gt_files = sorted(anno_dir.rglob('*.txt')) for gt_file in gt_files: seq_name = gt_file.stem.replace('_gt', '') # Find frames directory frames_path = None for candidate in [ frame_dir / seq_name, frame_dir / seq_name / 'img', self.root / seq_name, ]: if candidate.exists(): frames_path = candidate break if frames_path is None: continue gt_boxes = [] with open(gt_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').replace('\t', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except (ValueError, IndexError): gt_boxes.append(None) frames = sorted(glob.glob(str(frames_path / '*.jpg'))) if not frames: frames = sorted(glob.glob(str(frames_path / '*.png'))) if len(frames) != len(gt_boxes): min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: all_seqs.append({'frames': frames, 'gt': gt_boxes, 'name': seq_name}) # Split: first 60% train, last 40% test (standard UAVDT protocol) all_seqs.sort(key=lambda x: x['name']) split_idx = int(len(all_seqs) * 0.6) if split == 'train': selected = all_seqs[:split_idx] else: selected = all_seqs[split_idx:] for seq in selected: self.sequences.append({'frames': seq['frames'], 'gt': seq['gt']}) print(f" Loaded {len(self.sequences)} UAVDT {split} sequences " f"(from {len(all_seqs)} total)") # ============================================================ # WebUAV-3M dataset loader (UAV, large-scale) # ============================================================ class WebUAV3MDataset(SequenceDataset): """WebUAV-3M: million-scale multi-modal UAV tracking dataset. Structure: root/ {superclass}/ # e.g., person, vehicle, animal {seq_name}/ img/ 000001.jpg, 000002.jpg, ... groundtruth_rect.txt # x,y,w,h per line OR: {seq_name}/ *.jpg groundtruth_rect.txt 4,500 sequences, 3.3M frames, 12 superclasses, 223 target classes. Average video length: 710 frames (23.7 seconds at 30 FPS). This is the largest UAV tracking dataset. All sequences are from real drone footage. Purpose-built for training deep UAV trackers. """ def __init__(self, root: str, split: str = 'train', max_sequences: int = None, **kwargs): super().__init__(**kwargs) self.root = Path(root) self._load_sequences(split, max_sequences) def _load_sequences(self, split, max_sequences): if not self.root.exists(): print(f"Warning: WebUAV-3M not found at {self.root}") return # Find all sequences recursively all_seq_dirs = [] # Look for groundtruth files recursively gt_files = sorted(self.root.rglob('groundtruth_rect.txt')) if not gt_files: gt_files = sorted(self.root.rglob('groundtruth.txt')) for gt_file in gt_files: seq_dir = gt_file.parent # Check for img subdirectory or direct frames img_dir = seq_dir / 'img' if not img_dir.exists(): img_dir = seq_dir # frames directly in seq dir frames = sorted(glob.glob(str(img_dir / '*.jpg'))) if not frames: frames = sorted(glob.glob(str(img_dir / '*.png'))) if len(frames) >= 2: all_seq_dirs.append((gt_file, frames)) print(f"WebUAV-3M: found {len(all_seq_dirs)} sequences total") # Train/test split (80/20) split_idx = int(len(all_seq_dirs) * 0.8) if split == 'train': selected = all_seq_dirs[:split_idx] else: selected = all_seq_dirs[split_idx:] # Optionally limit sequences (WebUAV-3M is huge) if max_sequences and len(selected) > max_sequences: # Sample uniformly to maintain diversity step = len(selected) // max_sequences selected = selected[::step][:max_sequences] for gt_file, frames in selected: gt_boxes = [] with open(gt_file, 'r') as f: for line in f: line = line.strip() if not line: gt_boxes.append(None) continue parts = line.replace(',', ' ').replace('\t', ' ').split() try: gt_boxes.append([float(x) for x in parts[:4]]) except (ValueError, IndexError): gt_boxes.append(None) if len(frames) != len(gt_boxes): min_len = min(len(frames), len(gt_boxes)) frames = frames[:min_len] gt_boxes = gt_boxes[:min_len] if len(frames) >= 2: self.sequences.append({'frames': frames, 'gt': gt_boxes}) print(f" Loaded {len(self.sequences)} WebUAV-3M {split} sequences") # ============================================================ # Convenience: build combined dataset # ============================================================ def build_tracking_dataset( data_config: dict, template_size: int = 128, search_size: int = 256, feat_size: int = 16, acl_difficulty: float = 0.0, ) -> Dataset: """Build a combined tracking dataset from multiple sources. Standard ground-level datasets provide general tracking capability. UAV-specific datasets provide drone-perspective specialization. The ACL curriculum bridges the gap: it starts training on easy pairs from ground-level data, then progressively incorporates harder pairs including UAV sequences with fast motion, small targets, and viewpoint changes. Args: data_config: dict with optional keys: Ground-level (standard tracking training data): - 'got10k_root': path to GOT-10k dataset - 'lasot_root': path to LaSOT dataset - 'trackingnet_root': path to TrackingNet dataset - 'coco_root': path to COCO train2017 images UAV-specific (drone perspective — the deployment domain): - 'visdrone_root': path to VisDrone-SOT dataset - 'uavdt_root': path to UAVDT dataset - 'webuav3m_root': path to WebUAV-3M dataset - 'webuav3m_max_sequences': limit WebUAV-3M sequences (default: None = all) Fallback: - 'synthetic_length': number of synthetic samples (fallback) template_size: template crop size search_size: search region crop size feat_size: feature map spatial size acl_difficulty: initial ACL difficulty Returns: ConcatDataset or SyntheticTrackingDataset """ common_kwargs = dict( template_size=template_size, search_size=search_size, feat_size=feat_size, acl_difficulty=acl_difficulty, ) datasets = [] if 'got10k_root' in data_config and os.path.exists(data_config['got10k_root']): ds = GOT10kDataset(data_config['got10k_root'], split='train', **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"GOT-10k: {len(ds)} sequences") if 'lasot_root' in data_config and os.path.exists(data_config['lasot_root']): ds = LaSOTDataset(data_config['lasot_root'], split='train', **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"LaSOT: {len(ds)} sequences") if 'trackingnet_root' in data_config and os.path.exists(data_config['trackingnet_root']): ds = TrackingNetDataset(data_config['trackingnet_root'], **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"TrackingNet: {len(ds)} sequences") if 'coco_root' in data_config and os.path.exists(data_config['coco_root']): ds = COCODetDataset(data_config['coco_root'], **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"COCO: {len(ds)} pseudo-sequences") # --- UAV-specific datasets (drone perspective) --- if 'visdrone_root' in data_config and os.path.exists(data_config['visdrone_root']): ds = VisDroneSOTDataset(data_config['visdrone_root'], split='train', **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"VisDrone-SOT: {len(ds)} UAV sequences") if 'uavdt_root' in data_config and os.path.exists(data_config['uavdt_root']): ds = UAVDTDataset(data_config['uavdt_root'], split='train', **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"UAVDT: {len(ds)} UAV sequences") if 'webuav3m_root' in data_config and os.path.exists(data_config['webuav3m_root']): max_seq = data_config.get('webuav3m_max_sequences', None) ds = WebUAV3MDataset(data_config['webuav3m_root'], split='train', max_sequences=max_seq, **common_kwargs) if len(ds) > 0: datasets.append(ds) print(f"WebUAV-3M: {len(ds)} UAV sequences") if datasets: combined = ConcatDataset(datasets) print(f"\nTotal training samples: {len(combined)}") return combined # Fallback to synthetic syn_len = data_config.get('synthetic_length', 10000) print(f"No real data found, using {syn_len} synthetic samples") return SyntheticTrackingDataset( length=syn_len, template_size=template_size, search_size=search_size, feat_size=feat_size, acl_difficulty=acl_difficulty, ) # ============================================================ # Legacy alias for backward compatibility # ============================================================ class TrackingDataset(SyntheticTrackingDataset): """Backward-compatible alias for SyntheticTrackingDataset.""" def __init__(self, data_dir=None, split='train', synthetic=False, synthetic_length=10000, clip_length=3, **kwargs): super().__init__(length=synthetic_length, clip_length=clip_length, **kwargs)