"""
Tracking dataset with real dataset loaders and synthetic fallback.
Supports:
- GOT-10k: train split (~10k sequences, annotations in groundtruth.txt)
- LaSOT: training split (1120 sequences, 70 categories)
- TrackingNet: training split (30k+ sequences, annotations in anno/)
- COCO detection: for static pair pretraining (bbox crops as pseudo-sequences)
- VisDrone-SOT, UAVDT, WebUAV-3M: UAV/drone-perspective tracking datasets
- Synthetic data generation for testing (no external data needed)
- ACL (Adaptive Curriculum Learning) difficulty scaling
- Standard tracking augmentations: spatial jitter, horizontal flip, color jitter,
grayscale, Gaussian blur, brightness/contrast
Each sample is a clip: one template plus K consecutive search frames from the
same video sequence, with a controlled temporal gap, plus GT annotations.
Dataset directory structure expected:
GOT-10k/
train/
GOT-10k_Train_000001/
00000001.jpg, 00000002.jpg, ...
groundtruth.txt # x,y,w,h per line
...
LaSOT/
airplane/
airplane-1/
img/
00000001.jpg, ...
groundtruth.txt # x,y,w,h per line
...
TrackingNet/
TRAIN_0/
frames/
video_name/
0.jpg, 1.jpg, ...
anno/
video_name.txt # x,y,w,h per line
...
COCO/
train2017/
*.jpg
annotations/
instances_train2017.json
"""
import os
import math
import glob
import random
import torch
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, ConcatDataset
# ============================================================
# Augmentations (no torchvision dependency, works with tensors)
# ============================================================
class TrackingAugmentation:
"""Standard tracking augmentations applied to (template, search) pairs.
Augmentations preserve the spatial relationship between search region
and GT bounding box by applying augmentations consistently.
"""
def __init__(
self,
brightness: float = 0.2,
contrast: float = 0.2,
saturation: float = 0.2,
grayscale_prob: float = 0.05,
horizontal_flip_prob: float = 0.5,
blur_prob: float = 0.1,
blur_sigma: tuple = (0.1, 2.0),
):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.grayscale_prob = grayscale_prob
self.horizontal_flip_prob = horizontal_flip_prob
self.blur_prob = blur_prob
self.blur_sigma = blur_sigma
def __call__(self, template: torch.Tensor, search: torch.Tensor,
bbox: torch.Tensor) -> tuple:
"""
Args:
template: (3, H_t, W_t) tensor in [0, 1]
search: (3, H_s, W_s) tensor in [0, 1]
bbox: (4,) tensor [cx, cy, w, h] in search region pixels
Returns:
template, search, bbox (augmented)
"""
        # Color jitter (same for template and search to maintain appearance consistency)
        if random.random() < 0.8:
            # Brightness
            factor = 1.0 + random.uniform(-self.brightness, self.brightness)
            template = (template * factor).clamp(0, 1)
            search = (search * factor).clamp(0, 1)
            # Contrast
            factor = 1.0 + random.uniform(-self.contrast, self.contrast)
            t_mean = template.mean()
            s_mean = search.mean()
            template = ((template - t_mean) * factor + t_mean).clamp(0, 1)
            search = ((search - s_mean) * factor + s_mean).clamp(0, 1)
            # Saturation (blend each pixel toward its grayscale value)
            factor = 1.0 + random.uniform(-self.saturation, self.saturation)
            t_gray = template.mean(dim=0, keepdim=True)
            s_gray = search.mean(dim=0, keepdim=True)
            template = (t_gray + (template - t_gray) * factor).clamp(0, 1)
            search = (s_gray + (search - s_gray) * factor).clamp(0, 1)
# Grayscale
if random.random() < self.grayscale_prob:
t_gray = template.mean(dim=0, keepdim=True).expand_as(template)
s_gray = search.mean(dim=0, keepdim=True).expand_as(search)
template = t_gray
search = s_gray
# Horizontal flip (must also flip bbox cx)
if random.random() < self.horizontal_flip_prob:
template = template.flip(-1)
search = search.flip(-1)
W_s = search.shape[-1]
bbox = bbox.clone()
bbox[0] = W_s - bbox[0] # flip cx
# Gaussian blur (search only — simulates motion blur)
if random.random() < self.blur_prob:
sigma = random.uniform(*self.blur_sigma)
kernel_size = int(2 * round(3 * sigma) + 1)
if kernel_size >= 3:
search = self._gaussian_blur(search, kernel_size, sigma)
return template, search, bbox
@staticmethod
def _gaussian_blur(img: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
"""Apply Gaussian blur to a (C, H, W) tensor."""
import torch.nn.functional as F
# Create 1D Gaussian kernel
x = torch.arange(kernel_size, dtype=img.dtype, device=img.device) - kernel_size // 2
kernel_1d = torch.exp(-0.5 * (x / sigma) ** 2)
kernel_1d = kernel_1d / kernel_1d.sum()
# Apply separable 2D blur
pad = kernel_size // 2
img = img.unsqueeze(0) # (1, C, H, W)
# Horizontal
k_h = kernel_1d.view(1, 1, 1, -1).expand(img.shape[1], -1, -1, -1)
img = F.conv2d(F.pad(img, (pad, pad, 0, 0), mode='reflect'),
k_h, groups=img.shape[1])
# Vertical
k_v = kernel_1d.view(1, 1, -1, 1).expand(img.shape[1], -1, -1, -1)
img = F.conv2d(F.pad(img, (0, 0, pad, pad), mode='reflect'),
k_v, groups=img.shape[1])
return img.squeeze(0)
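# TrackingAugmentation usage sketch (illustrative; random tensors stand in for
# real crops):
#   aug = TrackingAugmentation()
#   template = torch.rand(3, 128, 128)
#   search = torch.rand(3, 256, 256)
#   bbox = torch.tensor([128.0, 128.0, 40.0, 30.0])  # [cx, cy, w, h] in search px
#   template, search, bbox = aug(template, search, bbox)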
# ============================================================
# Crop utilities
# ============================================================
def crop_and_resize(image: np.ndarray, center: np.ndarray, size: float,
output_size: int) -> np.ndarray:
"""Crop a square region from image, centered at center, with given size.
Args:
image: (H, W, 3) numpy array, uint8 or float
center: (2,) [cx, cy] in image coordinates
size: side length of the square crop
output_size: resize crop to (output_size, output_size)
Returns:
(output_size, output_size, 3) numpy array
"""
H, W = image.shape[:2]
half = size / 2
x1 = int(round(center[0] - half))
y1 = int(round(center[1] - half))
x2 = int(round(center[0] + half))
y2 = int(round(center[1] + half))
# Boundary padding
pad_left = max(0, -x1)
pad_top = max(0, -y1)
pad_right = max(0, x2 - W)
pad_bottom = max(0, y2 - H)
x1c = max(0, x1)
y1c = max(0, y1)
x2c = min(W, x2)
y2c = min(H, y2)
crop = image[y1c:y2c, x1c:x2c]
if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
mean_color = image.mean(axis=(0, 1))
padded = np.full((crop.shape[0] + pad_top + pad_bottom,
crop.shape[1] + pad_left + pad_right, 3),
mean_color, dtype=crop.dtype)
padded[pad_top:pad_top + crop.shape[0], pad_left:pad_left + crop.shape[1]] = crop
crop = padded
# Resize
if crop.shape[0] > 0 and crop.shape[1] > 0:
import torch.nn.functional as F
crop_t = torch.from_numpy(crop.copy()).float().permute(2, 0, 1).unsqueeze(0)
crop_t = F.interpolate(crop_t, size=(output_size, output_size),
mode='bilinear', align_corners=False)
crop = crop_t.squeeze(0).permute(1, 2, 0).numpy()
else:
crop = np.zeros((output_size, output_size, 3), dtype=np.float32)
return crop
def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple:
"""Compute crop center and size from bbox with context.
Args:
bbox: [x, y, w, h] bounding box
context_factor: how much context around bbox (2.0 = 2x target size)
Returns:
center: (2,) [cx, cy]
crop_size: scalar side length
"""
x, y, w, h = bbox
cx = x + w / 2
cy = y + h / 2
    # Context amount following the SiamFC/STARK convention:
    # s = sqrt((w + p) * (h + p)), where p = (w + h) / 2
p = (w + h) / 2
crop_size = math.sqrt((w + p) * (h + p)) * context_factor
crop_size = max(crop_size, 10)
return np.array([cx, cy]), crop_size
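# compute_crop_params worked example (illustrative numbers): for
# bbox = [100, 100, 50, 30] and context_factor = 2.0:
#   p = (50 + 30) / 2 = 40
#   crop_size = sqrt((50 + 40) * (30 + 40)) * 2.0 = sqrt(6300) * 2.0 ≈ 158.7
#   center = [125, 115]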
# ============================================================
# Base sequence dataset
# ============================================================
class SequenceDataset(Dataset):
"""Base class for tracking sequence datasets.
Returns K-frame clips: template + K consecutive search frames.
The mLSTM processes these as one long sequence where memory carries
information across frames — this is the core training paradigm.
Subclasses must populate self.sequences with list of:
{'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]}
"""
def __init__(
self,
template_size: int = 128,
search_size: int = 256,
feat_size: int = 16,
acl_difficulty: float = 1.0,
max_gap: int = 100,
clip_length: int = 3,
augmentation: bool = True,
):
super().__init__()
self.template_size = template_size
self.search_size = search_size
self.feat_size = feat_size
self.acl_difficulty = acl_difficulty
self.max_gap = max_gap
self.clip_length = clip_length # K search frames per sample
self.sequences = []
self.augmentation = TrackingAugmentation() if augmentation else None
def __len__(self):
return len(self.sequences)
def _load_image(self, path: str) -> np.ndarray:
"""Load image from path. Returns (H, W, 3) float32 in [0, 255]."""
try:
from PIL import Image
img = Image.open(path).convert('RGB')
return np.array(img, dtype=np.float32)
except ImportError:
import cv2
img = cv2.imread(path)
if img is None:
return np.zeros((480, 640, 3), dtype=np.float32)
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
def _sample_clip(self, idx: int) -> list:
"""Sample a clip: template frame + K consecutive search frames.
Returns:
list of frame indices: [template_idx, search_1_idx, ..., search_K_idx]
"""
seq = self.sequences[idx]
n_frames = len(seq['frames'])
K = self.clip_length
valid = [i for i in range(n_frames)
if seq['gt'][i] is not None and seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0]
valid_set = set(valid)
if len(valid) < K + 1:
# Not enough frames — repeat what we have
if len(valid) == 0:
return [0] * (K + 1)
return [valid[0]] + [valid[min(i, len(valid)-1)] for i in range(K)]
# Template: pick a random valid frame
t_idx = random.choice(valid)
# Search frames: K consecutive valid frames AFTER template
# Temporal gap between template and first search controlled by ACL
effective_gap = max(1, int(self.max_gap * self.acl_difficulty))
# Find the start of the search clip: somewhere after template
min_start = t_idx + 1
max_start = min(t_idx + effective_gap, n_frames - K)
if max_start < min_start:
# Try before template
max_start_before = t_idx - K
min_start_before = max(0, t_idx - effective_gap - K)
if max_start_before >= min_start_before and max_start_before >= 0:
clip_start = random.randint(min_start_before, max_start_before)
else:
# Fallback: just use whatever consecutive frames we can find
clip_start = max(0, min(n_frames - K, t_idx + 1))
                # (overlap with the template is avoided below: the collection loop skips t_idx)
else:
clip_start = random.randint(min_start, max(min_start, max_start))
# Collect K consecutive frames, preferring valid ones
search_indices = []
for i in range(clip_start, min(clip_start + K * 3, n_frames)):
if i in valid_set and i != t_idx:
search_indices.append(i)
if len(search_indices) == K:
break
# Pad if we didn't find enough
while len(search_indices) < K:
search_indices.append(search_indices[-1] if search_indices else t_idx)
return [t_idx] + search_indices[:K]
def _process_frame(self, img: np.ndarray, bbox: np.ndarray, is_template: bool):
"""Crop and preprocess a single frame.
Returns:
image_tensor: (3, H, W) float [0, 1]
bbox_in_crop: (4,) [cx, cy, w, h] in crop coordinates
"""
if is_template:
center, crop_size = compute_crop_params(bbox, context_factor=2.0)
output_size = self.template_size
else:
center, crop_size = compute_crop_params(bbox, context_factor=4.0)
output_size = self.search_size
# Spatial jitter for search (controlled by ACL)
jitter = self.acl_difficulty * bbox[2:4].mean() * 0.3
if jitter > 0:
center[0] += random.gauss(0, jitter)
center[1] += random.gauss(0, jitter)
crop = crop_and_resize(img, center, crop_size, output_size)
# Compute GT in crop coordinates
scale = output_size / crop_size
cx = (bbox[0] + bbox[2] / 2 - center[0] + crop_size / 2) * scale
cy = (bbox[1] + bbox[3] / 2 - center[1] + crop_size / 2) * scale
w = bbox[2] * scale
h = bbox[3] * scale
cx = max(0, min(output_size, cx))
cy = max(0, min(output_size, cy))
w = max(1, min(output_size, w))
h = max(1, min(output_size, h))
tensor = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0
bbox_crop = torch.tensor([cx, cy, w, h])
return tensor, bbox_crop
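    # _process_frame GT transform, worked example (illustrative numbers): for a
    # frame-space bbox [x, y, w, h] = [90, 90, 20, 20], a crop centred at
    # (100, 100) with crop_size = 160 and output_size = 256:
    #   scale = 256 / 160 = 1.6
    #   cx = (90 + 20 / 2 - 100 + 160 / 2) * 1.6 = 128.0  (target centred in crop)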
def _make_heatmap(self, bbox: torch.Tensor):
"""Generate GT heatmap from bbox in search crop coordinates."""
stride = self.search_size / self.feat_size
cx_feat = bbox[0].item() / stride
cy_feat = bbox[1].item() / stride
w_search = bbox[2].item()
h_search = bbox[3].item()
y = torch.arange(self.feat_size, dtype=torch.float32)
x = torch.arange(self.feat_size, dtype=torch.float32)
yy, xx = torch.meshgrid(y, x, indexing='ij')
sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
return heatmap
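    # _make_heatmap sanity check (illustrative): with search_size=256 and
    # feat_size=16 the stride is 16 px, so a box centred at (128, 128) in the
    # search crop yields a Gaussian peaked at feature cell (8, 8).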
def __getitem__(self, idx):
seq = self.sequences[idx % len(self.sequences)]
clip_indices = self._sample_clip(idx % len(self.sequences))
t_idx = clip_indices[0]
s_indices = clip_indices[1:]
K = len(s_indices)
# Load and process template
t_img = self._load_image(seq['frames'][t_idx])
t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32)
template, _ = self._process_frame(t_img, t_bbox, is_template=True)
# Load and process K search frames
searches = []
heatmaps = []
sizes = []
boxes = []
for s_idx in s_indices:
s_img = self._load_image(seq['frames'][s_idx])
s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32)
search, bbox_crop = self._process_frame(s_img, s_bbox, is_template=False)
            # Apply augmentation; each call uses the same photometric parameters
            # for template and search, re-sampled per search frame
            if self.augmentation is not None:
                template_aug, search, bbox_crop = self.augmentation(template, search, bbox_crop)
                # Keep only the template augmented with the first search frame,
                # so the template stays fixed across the clip
                if len(searches) == 0:
                    template = template_aug
searches.append(search)
heatmaps.append(self._make_heatmap(bbox_crop))
sizes.append(torch.tensor([bbox_crop[2].item() / self.search_size,
bbox_crop[3].item() / self.search_size]))
boxes.append(bbox_crop)
return {
'template': template, # (3, 128, 128)
'searches': torch.stack(searches, dim=0), # (K, 3, 256, 256)
'heatmaps': torch.stack(heatmaps, dim=0), # (K, 1, 16, 16)
'sizes': torch.stack(sizes, dim=0), # (K, 2)
'boxes': torch.stack(boxes, dim=0), # (K, 4)
}
def set_acl_difficulty(self, difficulty: float):
"""Update ACL difficulty level (0.0 = easy, 1.0 = hard)."""
self.acl_difficulty = min(1.0, max(0.0, difficulty))
# ============================================================
# GOT-10k dataset loader
# ============================================================
class GOT10kDataset(SequenceDataset):
"""GOT-10k tracking dataset.
Structure:
root/train/GOT-10k_Train_NNNNNN/
00000001.jpg, 00000002.jpg, ...
groundtruth.txt # x,y,w,h per line
"""
def __init__(self, root: str, split: str = 'train', **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_sequences(split)
def _load_sequences(self, split):
split_dir = self.root / split
if not split_dir.exists():
print(f"Warning: GOT-10k {split} not found at {split_dir}")
return
        seq_dirs = sorted([d for d in split_dir.iterdir()
                           if d.is_dir() and d.name.startswith('GOT-10k')])
print(f"Loading GOT-10k {split}: found {len(seq_dirs)} sequences")
for seq_dir in seq_dirs:
gt_file = seq_dir / 'groundtruth.txt'
if not gt_file.exists():
continue
# Load annotations
gt_boxes = []
with open(gt_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except ValueError:
gt_boxes.append(None)
# Get frame paths
frames = sorted(glob.glob(str(seq_dir / '*.jpg')))
if not frames:
frames = sorted(glob.glob(str(seq_dir / '*.png')))
if len(frames) != len(gt_boxes):
# Trim to shorter
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
self.sequences.append({'frames': frames, 'gt': gt_boxes})
print(f" Loaded {len(self.sequences)} GOT-10k sequences")
# ============================================================
# LaSOT dataset loader
# ============================================================
class LaSOTDataset(SequenceDataset):
"""LaSOT tracking dataset.
Structure:
root/
airplane/
airplane-1/
img/
00000001.jpg, ...
groundtruth.txt # x,y,w,h per line
...
"""
def __init__(self, root: str, split: str = 'train', **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_sequences(split)
def _load_sequences(self, split):
if not self.root.exists():
print(f"Warning: LaSOT not found at {self.root}")
return
        # The official LaSOT protocol-II split uses 16 of the 20 sequences per
        # category for training; approximate it with the first 80% per category
categories = sorted([d for d in self.root.iterdir() if d.is_dir()])
total_seqs = 0
for cat_dir in categories:
seq_dirs = sorted([d for d in cat_dir.iterdir() if d.is_dir()])
# Train/test split
if split == 'train':
seq_dirs = seq_dirs[:int(len(seq_dirs) * 0.8)]
else:
seq_dirs = seq_dirs[int(len(seq_dirs) * 0.8):]
for seq_dir in seq_dirs:
gt_file = seq_dir / 'groundtruth.txt'
img_dir = seq_dir / 'img'
if not gt_file.exists() or not img_dir.exists():
continue
# Load annotations
gt_boxes = []
with open(gt_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except ValueError:
gt_boxes.append(None)
frames = sorted(glob.glob(str(img_dir / '*.jpg')))
if len(frames) != len(gt_boxes):
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
self.sequences.append({'frames': frames, 'gt': gt_boxes})
total_seqs += 1
print(f" Loaded {total_seqs} LaSOT {split} sequences across {len(categories)} categories")
# ============================================================
# TrackingNet dataset loader
# ============================================================
class TrackingNetDataset(SequenceDataset):
"""TrackingNet tracking dataset.
Structure:
root/
TRAIN_0/
frames/
video_name/
0.jpg, 1.jpg, ...
anno/
video_name.txt # x,y,w,h per line
TRAIN_1/
...
"""
def __init__(self, root: str, chunks: list = None, **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
if chunks is None:
chunks = list(range(12)) # TRAIN_0 through TRAIN_11
self._load_sequences(chunks)
def _load_sequences(self, chunks):
if not self.root.exists():
print(f"Warning: TrackingNet not found at {self.root}")
return
total_seqs = 0
for chunk_idx in chunks:
chunk_dir = self.root / f'TRAIN_{chunk_idx}'
if not chunk_dir.exists():
continue
anno_dir = chunk_dir / 'anno'
frames_dir = chunk_dir / 'frames'
if not anno_dir.exists() or not frames_dir.exists():
continue
for anno_file in sorted(anno_dir.glob('*.txt')):
seq_name = anno_file.stem
seq_frames_dir = frames_dir / seq_name
if not seq_frames_dir.exists():
continue
# Load annotations
gt_boxes = []
with open(anno_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except ValueError:
gt_boxes.append(None)
frames = sorted(glob.glob(str(seq_frames_dir / '*.jpg')))
if not frames:
frames = sorted(glob.glob(str(seq_frames_dir / '*.png')))
if len(frames) != len(gt_boxes):
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
self.sequences.append({'frames': frames, 'gt': gt_boxes})
total_seqs += 1
print(f" Loaded {total_seqs} TrackingNet sequences from {len(chunks)} chunks")
# ============================================================
# COCO detection as pseudo-sequences
# ============================================================
class COCODetDataset(SequenceDataset):
"""COCO detection images as pseudo-sequences for pretraining.
Each image with a valid bounding box becomes a length-1 "sequence"
where template and search are crops from the same image.
"""
def __init__(self, root: str, ann_file: str = None, **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_annotations(ann_file)
def _load_annotations(self, ann_file):
if ann_file is None:
ann_file = str(self.root.parent / 'annotations' / 'instances_train2017.json')
if not os.path.exists(ann_file):
print(f"Warning: COCO annotations not found at {ann_file}")
return
try:
import json
with open(ann_file, 'r') as f:
coco = json.load(f)
# Build image lookup
images = {img['id']: img for img in coco['images']}
# Create pseudo-sequences from annotations
for ann in coco['annotations']:
if ann.get('iscrowd', 0):
continue
bbox = ann['bbox'] # [x, y, w, h]
if bbox[2] < 10 or bbox[3] < 10:
continue
img_info = images.get(ann['image_id'])
if img_info is None:
continue
img_path = str(self.root / img_info['file_name'])
if os.path.exists(img_path):
# Pseudo-sequence: same frame for template and search
self.sequences.append({
'frames': [img_path, img_path],
'gt': [bbox, bbox],
})
print(f" Loaded {len(self.sequences)} COCO pseudo-sequences")
except Exception as e:
print(f"Warning: Failed to load COCO annotations: {e}")
# ============================================================
# Synthetic dataset (for testing / no-data development)
# ============================================================
class SyntheticTrackingDataset(Dataset):
"""Synthetic tracking dataset for testing without real data.
Generates K-frame clips: template + K search frames with a moving
colored rectangle target. Motion is linear with noise.
"""
def __init__(
self,
length: int = 10000,
template_size: int = 128,
search_size: int = 256,
feat_size: int = 16,
acl_difficulty: float = 1.0,
clip_length: int = 3,
):
super().__init__()
self.length = length
self.template_size = template_size
self.search_size = search_size
self.feat_size = feat_size
self.acl_difficulty = acl_difficulty
self.clip_length = clip_length
def __len__(self):
return self.length
def _make_heatmap(self, cx, cy, w_search, h_search):
stride = self.search_size / self.feat_size
cx_feat = cx / stride
cy_feat = cy / stride
y = torch.arange(self.feat_size, dtype=torch.float32)
x = torch.arange(self.feat_size, dtype=torch.float32)
yy, xx = torch.meshgrid(y, x, indexing='ij')
sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
return torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
def __getitem__(self, idx):
rng = random.Random(idx)
K = self.clip_length
# Target appearance
color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
target_w = rng.uniform(0.1, 0.5) * self.search_size
target_h = rng.uniform(0.1, 0.5) * self.search_size
# Initial position (center of search)
cx0 = self.search_size / 2
cy0 = self.search_size / 2
# Velocity (pixels per frame, scaled by difficulty)
vx = rng.gauss(0, self.acl_difficulty * 15)
vy = rng.gauss(0, self.acl_difficulty * 15)
# Template: target at center
template = torch.randn(3, self.template_size, self.template_size) * 0.1
t_hw = int(min(target_w / 2, self.template_size / 2 - 1))
t_hh = int(min(target_h / 2, self.template_size / 2 - 1))
tc = self.template_size // 2
template[:, tc - t_hh:tc + t_hh, tc - t_hw:tc + t_hw] = color
# K search frames with moving target
searches = []
heatmaps = []
sizes = []
boxes = []
for k in range(K):
# Position at frame k
cx = cx0 + vx * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
cy = cy0 + vy * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
search = torch.randn(3, self.search_size, self.search_size) * 0.1
sx1 = max(0, int(cx - target_w / 2))
sy1 = max(0, int(cy - target_h / 2))
sx2 = min(self.search_size, int(cx + target_w / 2))
sy2 = min(self.search_size, int(cy + target_h / 2))
search[:, sy1:sy2, sx1:sx2] = color
searches.append(search)
heatmaps.append(self._make_heatmap(cx, cy, target_w, target_h))
sizes.append(torch.tensor([target_w / self.search_size,
target_h / self.search_size]))
boxes.append(torch.tensor([cx, cy, target_w, target_h]))
return {
'template': template, # (3, 128, 128)
'searches': torch.stack(searches, dim=0), # (K, 3, 256, 256)
'heatmaps': torch.stack(heatmaps, dim=0), # (K, 1, 16, 16)
'sizes': torch.stack(sizes, dim=0), # (K, 2)
'boxes': torch.stack(boxes, dim=0), # (K, 4)
}
def set_acl_difficulty(self, difficulty: float):
self.acl_difficulty = min(1.0, max(0.0, difficulty))
# ============================================================
# VisDrone-SOT dataset loader (UAV)
# ============================================================
class VisDroneSOTDataset(SequenceDataset):
"""VisDrone-SOT single object tracking dataset (drone/UAV perspective).
Structure:
root/
VisDrone2019-SOT-train/
sequences/
uav0000001_00000_s/
0000001.jpg, 0000002.jpg, ...
...
annotations/
uav0000001_00000_s.txt # x,y,w,h per line
...
Splits: train (86 sequences, ~70K frames), val (11 sequences),
test-dev (35 sequences), test-challenge (35 sequences)
Key for our tracker: real drone footage with small targets, fast motion,
viewpoint changes, and camera ego-motion — the exact conditions we deploy in.
"""
def __init__(self, root: str, split: str = 'train', **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_sequences(split)
def _load_sequences(self, split):
# Try multiple directory naming conventions
split_names = {
'train': ['VisDrone2019-SOT-train', 'VisDrone2018-SOT-train', 'train'],
'val': ['VisDrone2019-SOT-val', 'VisDrone2018-SOT-val', 'val'],
'test': ['VisDrone2019-SOT-test-dev', 'VisDrone2018-SOT-test', 'test-dev', 'test'],
}
        split_dir = None
        for name in split_names.get(split, [split]):
            candidate = self.root / name
            if candidate.exists():
                split_dir = candidate
                break
        # Fall back to root itself being the split directory
        if split_dir is None and (self.root / 'sequences').exists():
            split_dir = self.root
if split_dir is None:
print(f"Warning: VisDrone-SOT {split} not found at {self.root}")
return
seq_dir = split_dir / 'sequences'
anno_dir = split_dir / 'annotations'
if not seq_dir.exists() or not anno_dir.exists():
print(f"Warning: VisDrone-SOT missing sequences/ or annotations/ at {split_dir}")
return
total_seqs = 0
for anno_file in sorted(anno_dir.glob('*.txt')):
seq_name = anno_file.stem
frames_dir = seq_dir / seq_name
if not frames_dir.exists():
continue
gt_boxes = []
with open(anno_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except ValueError:
gt_boxes.append(None)
frames = sorted(glob.glob(str(frames_dir / '*.jpg')))
if not frames:
frames = sorted(glob.glob(str(frames_dir / '*.png')))
if len(frames) != len(gt_boxes):
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
self.sequences.append({'frames': frames, 'gt': gt_boxes})
total_seqs += 1
print(f" Loaded {total_seqs} VisDrone-SOT {split} sequences")
# ============================================================
# UAVDT dataset loader (UAV)
# ============================================================
class UAVDTDataset(SequenceDataset):
"""UAVDT (Unmanned Aerial Vehicle Detection and Tracking) dataset.
Structure:
root/
UAV-benchmark-S/ # SOT annotations
{seq_name}/
{seq_name}_gt.txt # x,y,w,h per line (or comma-separated)
UAV-benchmark-M/ # Frames
{seq_name}/
img000001.jpg, img000002.jpg, ...
Alternative structure (simpler):
root/
sequences/
{seq_name}/
img000001.jpg, ...
annotations/
{seq_name}_gt.txt
50 sequences total, typically 30 train / 20 test.
Contains vehicle tracking from drone perspective — complementary to VisDrone.
"""
def __init__(self, root: str, split: str = 'train', **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_sequences(split)
def _load_sequences(self, split):
# Try standard UAVDT structure
anno_dir = self.root / 'UAV-benchmark-S'
frame_dir = self.root / 'UAV-benchmark-M'
if not anno_dir.exists():
# Alternative structure
anno_dir = self.root / 'annotations'
frame_dir = self.root / 'sequences'
if not anno_dir.exists():
# Try root directly having sequence dirs
anno_dir = self.root
frame_dir = self.root
if not anno_dir.exists():
print(f"Warning: UAVDT not found at {self.root}")
return
# Collect all sequences
all_seqs = []
# Find annotation files
gt_files = sorted(anno_dir.rglob('*_gt.txt'))
if not gt_files:
gt_files = sorted(anno_dir.rglob('*.txt'))
for gt_file in gt_files:
seq_name = gt_file.stem.replace('_gt', '')
# Find frames directory
frames_path = None
for candidate in [
frame_dir / seq_name,
frame_dir / seq_name / 'img',
self.root / seq_name,
]:
if candidate.exists():
frames_path = candidate
break
if frames_path is None:
continue
gt_boxes = []
with open(gt_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').replace('\t', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except (ValueError, IndexError):
gt_boxes.append(None)
frames = sorted(glob.glob(str(frames_path / '*.jpg')))
if not frames:
frames = sorted(glob.glob(str(frames_path / '*.png')))
if len(frames) != len(gt_boxes):
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
all_seqs.append({'frames': frames, 'gt': gt_boxes, 'name': seq_name})
# Split: first 60% train, last 40% test (standard UAVDT protocol)
all_seqs.sort(key=lambda x: x['name'])
split_idx = int(len(all_seqs) * 0.6)
if split == 'train':
selected = all_seqs[:split_idx]
else:
selected = all_seqs[split_idx:]
for seq in selected:
self.sequences.append({'frames': seq['frames'], 'gt': seq['gt']})
print(f" Loaded {len(self.sequences)} UAVDT {split} sequences "
f"(from {len(all_seqs)} total)")
# ============================================================
# WebUAV-3M dataset loader (UAV, large-scale)
# ============================================================
class WebUAV3MDataset(SequenceDataset):
"""WebUAV-3M: million-scale multi-modal UAV tracking dataset.
Structure:
root/
{superclass}/ # e.g., person, vehicle, animal
{seq_name}/
img/
000001.jpg, 000002.jpg, ...
groundtruth_rect.txt # x,y,w,h per line
OR:
{seq_name}/
*.jpg
groundtruth_rect.txt
4,500 sequences, 3.3M frames, 12 superclasses, 223 target classes.
Average video length: 710 frames (23.7 seconds at 30 FPS).
This is the largest UAV tracking dataset. All sequences are from real
drone footage. Purpose-built for training deep UAV trackers.
"""
def __init__(self, root: str, split: str = 'train', max_sequences: int = None, **kwargs):
super().__init__(**kwargs)
self.root = Path(root)
self._load_sequences(split, max_sequences)
def _load_sequences(self, split, max_sequences):
if not self.root.exists():
print(f"Warning: WebUAV-3M not found at {self.root}")
return
# Find all sequences recursively
all_seq_dirs = []
# Look for groundtruth files recursively
gt_files = sorted(self.root.rglob('groundtruth_rect.txt'))
if not gt_files:
gt_files = sorted(self.root.rglob('groundtruth.txt'))
for gt_file in gt_files:
seq_dir = gt_file.parent
# Check for img subdirectory or direct frames
img_dir = seq_dir / 'img'
if not img_dir.exists():
img_dir = seq_dir # frames directly in seq dir
frames = sorted(glob.glob(str(img_dir / '*.jpg')))
if not frames:
frames = sorted(glob.glob(str(img_dir / '*.png')))
if len(frames) >= 2:
all_seq_dirs.append((gt_file, frames))
print(f"WebUAV-3M: found {len(all_seq_dirs)} sequences total")
# Train/test split (80/20)
split_idx = int(len(all_seq_dirs) * 0.8)
if split == 'train':
selected = all_seq_dirs[:split_idx]
else:
selected = all_seq_dirs[split_idx:]
# Optionally limit sequences (WebUAV-3M is huge)
if max_sequences and len(selected) > max_sequences:
# Sample uniformly to maintain diversity
step = len(selected) // max_sequences
selected = selected[::step][:max_sequences]
for gt_file, frames in selected:
gt_boxes = []
with open(gt_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
gt_boxes.append(None)
continue
parts = line.replace(',', ' ').replace('\t', ' ').split()
try:
gt_boxes.append([float(x) for x in parts[:4]])
except (ValueError, IndexError):
gt_boxes.append(None)
if len(frames) != len(gt_boxes):
min_len = min(len(frames), len(gt_boxes))
frames = frames[:min_len]
gt_boxes = gt_boxes[:min_len]
if len(frames) >= 2:
self.sequences.append({'frames': frames, 'gt': gt_boxes})
print(f" Loaded {len(self.sequences)} WebUAV-3M {split} sequences")
# ============================================================
# Convenience: build combined dataset
# ============================================================
def build_tracking_dataset(
data_config: dict,
template_size: int = 128,
search_size: int = 256,
feat_size: int = 16,
acl_difficulty: float = 0.0,
) -> Dataset:
"""Build a combined tracking dataset from multiple sources.
Standard ground-level datasets provide general tracking capability.
UAV-specific datasets provide drone-perspective specialization.
    The ACL curriculum bridges the gap: at low difficulty, temporal gaps and
    spatial jitter are kept small; as difficulty rises, both grow, exposing the
    model to the larger displacements typical of UAV sequences with fast motion,
    small targets, and viewpoint changes.
Args:
data_config: dict with optional keys:
Ground-level (standard tracking training data):
- 'got10k_root': path to GOT-10k dataset
- 'lasot_root': path to LaSOT dataset
- 'trackingnet_root': path to TrackingNet dataset
- 'coco_root': path to COCO train2017 images
UAV-specific (drone perspective — the deployment domain):
- 'visdrone_root': path to VisDrone-SOT dataset
- 'uavdt_root': path to UAVDT dataset
- 'webuav3m_root': path to WebUAV-3M dataset
- 'webuav3m_max_sequences': limit WebUAV-3M sequences (default: None = all)
Fallback:
- 'synthetic_length': number of synthetic samples (fallback)
template_size: template crop size
search_size: search region crop size
feat_size: feature map spatial size
acl_difficulty: initial ACL difficulty
Returns:
ConcatDataset or SyntheticTrackingDataset
"""
common_kwargs = dict(
template_size=template_size,
search_size=search_size,
feat_size=feat_size,
acl_difficulty=acl_difficulty,
)
datasets = []
if 'got10k_root' in data_config and os.path.exists(data_config['got10k_root']):
ds = GOT10kDataset(data_config['got10k_root'], split='train', **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"GOT-10k: {len(ds)} sequences")
if 'lasot_root' in data_config and os.path.exists(data_config['lasot_root']):
ds = LaSOTDataset(data_config['lasot_root'], split='train', **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"LaSOT: {len(ds)} sequences")
if 'trackingnet_root' in data_config and os.path.exists(data_config['trackingnet_root']):
ds = TrackingNetDataset(data_config['trackingnet_root'], **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"TrackingNet: {len(ds)} sequences")
if 'coco_root' in data_config and os.path.exists(data_config['coco_root']):
ds = COCODetDataset(data_config['coco_root'], **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"COCO: {len(ds)} pseudo-sequences")
# --- UAV-specific datasets (drone perspective) ---
if 'visdrone_root' in data_config and os.path.exists(data_config['visdrone_root']):
ds = VisDroneSOTDataset(data_config['visdrone_root'], split='train', **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"VisDrone-SOT: {len(ds)} UAV sequences")
if 'uavdt_root' in data_config and os.path.exists(data_config['uavdt_root']):
ds = UAVDTDataset(data_config['uavdt_root'], split='train', **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"UAVDT: {len(ds)} UAV sequences")
if 'webuav3m_root' in data_config and os.path.exists(data_config['webuav3m_root']):
max_seq = data_config.get('webuav3m_max_sequences', None)
ds = WebUAV3MDataset(data_config['webuav3m_root'], split='train',
max_sequences=max_seq, **common_kwargs)
if len(ds) > 0:
datasets.append(ds)
print(f"WebUAV-3M: {len(ds)} UAV sequences")
if datasets:
combined = ConcatDataset(datasets)
print(f"\nTotal training samples: {len(combined)}")
return combined
# Fallback to synthetic
syn_len = data_config.get('synthetic_length', 10000)
print(f"No real data found, using {syn_len} synthetic samples")
return SyntheticTrackingDataset(
length=syn_len,
template_size=template_size,
search_size=search_size,
feat_size=feat_size,
acl_difficulty=acl_difficulty,
)
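# build_tracking_dataset usage sketch (illustrative; the paths are placeholders):
#   dataset = build_tracking_dataset({'got10k_root': '/data/GOT-10k',
#                                     'visdrone_root': '/data/VisDrone'})
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
#                                        num_workers=4, pin_memory=True)
#   batch = next(iter(loader))  # batch['searches']: (8, K, 3, 256, 256)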
# ============================================================
# Legacy alias for backward compatibility
# ============================================================
class TrackingDataset(SyntheticTrackingDataset):
"""Backward-compatible alias for SyntheticTrackingDataset."""
def __init__(self, data_dir=None, split='train', synthetic=False,
synthetic_length=10000, clip_length=3, **kwargs):
super().__init__(length=synthetic_length, clip_length=clip_length, **kwargs)
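if __name__ == '__main__':
    # Minimal smoke test (no external data needed): build a synthetic dataset
    # and check that clip tensors match the shapes documented in __getitem__.
    ds = SyntheticTrackingDataset(length=4, clip_length=3)
    sample = ds[0]
    for key, value in sample.items():
        print(f"{key}: {tuple(value.shape)}")
    # Expected: template (3, 128, 128), searches (3, 3, 256, 256),
    # heatmaps (3, 1, 16, 16), sizes (3, 2), boxes (3, 4)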