Sequence training: pairs→K-frame clips, mLSTM memory carries across frames
vil_tracker/data/dataset.py (+198 −132)
@@ -245,6 +245,10 @@ def compute_crop_params(bbox: np.ndarray, context_factor: float = 2.0) -> tuple:
 class SequenceDataset(Dataset):
     """Base class for tracking sequence datasets.
 
+    Returns K-frame clips: template + K consecutive search frames.
+    The mLSTM processes these as one long sequence where memory carries
+    information across frames — this is the core training paradigm.
+
     Subclasses must populate self.sequences with list of:
         {'frames': [path1, path2, ...], 'gt': [[x,y,w,h], ...]}
     """
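
For context on the docstring above: only the (template + K search frames) sample layout comes from this diff. The sketch below is a minimal illustration of what "memory carries across frames" means for the consumer of a clip; the toy cell is a hypothetical stand-in, not this repo's mLSTM block or model API.

    import torch
    import torch.nn as nn

    class ToyMemoryCell(nn.Module):
        """Stand-in for the mLSTM block: hidden state carries across frames."""
        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, state):
            state = torch.tanh(self.proj(x) + state)  # toy recurrence, not mLSTM math
            return state, state

    cell = ToyMemoryCell()
    feats = torch.randn(8, 3, 64)      # (B, K, D): one feature vector per search frame
    state = torch.zeros(8, 64)         # memory initialized once per clip
    outs = []
    for k in range(feats.shape[1]):    # walk the K frames of the clip in order
        out, state = cell(feats[:, k], state)  # state is NOT reset between frames
        outs.append(out)
    outs = torch.stack(outs, dim=1)    # (B, K, D): per-frame outputs for per-frame losses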

@@ -256,6 +260,7 @@ class SequenceDataset(Dataset):
         feat_size: int = 16,
         acl_difficulty: float = 1.0,
         max_gap: int = 100,
+        clip_length: int = 3,
         augmentation: bool = True,
     ):
         super().__init__()

@@ -264,6 +269,7 @@ class SequenceDataset(Dataset):
         self.feat_size = feat_size
         self.acl_difficulty = acl_difficulty
         self.max_gap = max_gap
+        self.clip_length = clip_length  # K search frames per sample
         self.sequences = []
 
         self.augmentation = TrackingAugmentation() if augmentation else None
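
Because clip_length is fixed per dataset, PyTorch's default collate can stack clip samples directly into (B, K, ...) batches. A sketch using the synthetic dataset defined later in this diff (its template_size default of 128 is assumed from the shape comments):

    from torch.utils.data import DataLoader
    from vil_tracker.data.dataset import SyntheticTrackingDataset

    loader = DataLoader(SyntheticTrackingDataset(length=32, clip_length=3), batch_size=8)
    batch = next(iter(loader))
    print(batch['template'].shape)   # torch.Size([8, 3, 128, 128])  - one template per clip
    print(batch['searches'].shape)   # torch.Size([8, 3, 3, 256, 256]), i.e. (B, K, C, H, W)
    print(batch['heatmaps'].shape)   # torch.Size([8, 3, 1, 16, 16])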

@@ -278,123 +284,168 @@ class SequenceDataset(Dataset):
             img = Image.open(path).convert('RGB')
             return np.array(img, dtype=np.float32)
         except ImportError:
-            # Fallback with OpenCV
             import cv2
             img = cv2.imread(path)
             if img is None:
                 return np.zeros((480, 640, 3), dtype=np.float32)
             return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
 
-    def _sample_pair(self, idx):
-        """Sample a (template, search) frame pair.
-
-        Temporal distance controlled by ACL difficulty:
-        - difficulty=0: template and search are very close
-        - difficulty=1: template and search can be up to max_gap apart
+    def _sample_clip(self, idx: int) -> list:
+        """Sample a clip: template frame + K consecutive search frames.
 
         Returns:
-            ...
+            list of frame indices: [template_idx, search_1_idx, ..., search_K_idx]
         """
         seq = self.sequences[idx]
         n_frames = len(seq['frames'])
+        K = self.clip_length
 
-        ...
+        valid = [i for i in range(n_frames)
+                 if seq['gt'][i] is not None and seq['gt'][i][2] > 0 and seq['gt'][i][3] > 0]
+        valid_set = set(valid)
 
-        if len(...):
-            ...
+        if len(valid) < K + 1:
+            # Not enough frames — repeat what we have
+            if len(valid) == 0:
+                return [0] * (K + 1)
+            return [valid[0]] + [valid[min(i, len(valid)-1)] for i in range(K)]
 
+        # Template: pick a random valid frame
+        t_idx = random.choice(valid)
 
-        # Search: ...
+        # Search frames: K consecutive valid frames AFTER template
+        # Temporal gap between template and first search controlled by ACL
         effective_gap = max(1, int(self.max_gap * self.acl_difficulty))
-        min_idx = max(0, t_idx - effective_gap)
-        max_idx = min(n_frames - 1, t_idx + effective_gap)
-
-        # Only pick valid indices
-        search_candidates = [i for i in range(min_idx, max_idx + 1)
-                             if i != t_idx and i in valid_indices]
-
-        if not search_candidates:
-            return t_idx, t_idx
-
-        s_idx = random.choice(search_candidates)
-        return t_idx, s_idx
-
-    def __getitem__(self, idx):
-        seq = self.sequences[idx % len(self.sequences)]
-        t_idx, s_idx = self._sample_pair(idx % len(self.sequences))
-
-        # Load images
-        t_img = self._load_image(seq['frames'][t_idx])
-        s_img = self._load_image(seq['frames'][s_idx])
-
-        t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32)  # [x, y, w, h]
-        s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32)
-
-        # Crop template (centered on target, 2x context)
-        t_center, t_crop_size = compute_crop_params(t_bbox, context_factor=2.0)
-        template = crop_and_resize(t_img, t_center, t_crop_size, self.template_size)
-
-        # Crop search region (centered on target with jitter, 4x context)
-        s_center, s_crop_size = compute_crop_params(s_bbox, context_factor=4.0)
-
-        # Add spatial jitter (controlled by ACL difficulty)
-        jitter = self.acl_difficulty * s_bbox[2:4].mean() * 0.3
-        s_center[0] += random.gauss(0, jitter) if jitter > 0 else 0
-        s_center[1] += random.gauss(0, jitter) if jitter > 0 else 0
-
-        ...
-
-        # ...
-        bbox_tensor = torch.tensor([cx_in_search, cy_in_search, w_in_search, h_in_search])
+        # Find the start of the search clip: somewhere after template
+        min_start = t_idx + 1
+        max_start = min(t_idx + effective_gap, n_frames - K)
+
+        if max_start < min_start:
+            # Try before template
+            max_start_before = t_idx - K
+            min_start_before = max(0, t_idx - effective_gap - K)
+            if max_start_before >= min_start_before and max_start_before >= 0:
+                clip_start = random.randint(min_start_before, max_start_before)
+            else:
+                # Fallback: just use whatever consecutive frames we can find
+                clip_start = max(0, min(n_frames - K, t_idx + 1))
+                # But ensure template is different from search frames
+        else:
+            clip_start = random.randint(min_start, max(min_start, max_start))
+
+        # Collect K consecutive frames, preferring valid ones
+        search_indices = []
+        for i in range(clip_start, min(clip_start + K * 3, n_frames)):
+            if i in valid_set and i != t_idx:
+                search_indices.append(i)
+            if len(search_indices) == K:
+                break
+
+        # Pad if we didn't find enough
+        while len(search_indices) < K:
+            search_indices.append(search_indices[-1] if search_indices else t_idx)
+
+        return [t_idx] + search_indices[:K]
+
+    def _process_frame(self, img: np.ndarray, bbox: np.ndarray, is_template: bool):
+        """Crop and preprocess a single frame.
+
+        Returns:
+            image_tensor: (3, H, W) float [0, 1]
+            bbox_in_crop: (4,) [cx, cy, w, h] in crop coordinates
+        """
+        if is_template:
+            center, crop_size = compute_crop_params(bbox, context_factor=2.0)
+            output_size = self.template_size
+        else:
+            center, crop_size = compute_crop_params(bbox, context_factor=4.0)
+            output_size = self.search_size
+            # Spatial jitter for search (controlled by ACL)
+            jitter = self.acl_difficulty * bbox[2:4].mean() * 0.3
+            if jitter > 0:
+                center[0] += random.gauss(0, jitter)
+                center[1] += random.gauss(0, jitter)
+
+        crop = crop_and_resize(img, center, crop_size, output_size)
+
+        # Compute GT in crop coordinates
+        scale = output_size / crop_size
+        cx = (bbox[0] + bbox[2] / 2 - center[0] + crop_size / 2) * scale
+        cy = (bbox[1] + bbox[3] / 2 - center[1] + crop_size / 2) * scale
+        w = bbox[2] * scale
+        h = bbox[3] * scale
+
+        cx = max(0, min(output_size, cx))
+        cy = max(0, min(output_size, cy))
+        w = max(1, min(output_size, w))
+        h = max(1, min(output_size, h))
+
+        tensor = torch.from_numpy(crop).float().permute(2, 0, 1) / 255.0
+        bbox_crop = torch.tensor([cx, cy, w, h])
+
+        return tensor, bbox_crop
+
+    def _make_heatmap(self, bbox: torch.Tensor):
+        """Generate GT heatmap from bbox in search crop coordinates."""
         stride = self.search_size / self.feat_size
-        cx_feat = ...
-        cy_feat = ...
+        cx_feat = bbox[0].item() / stride
+        cy_feat = bbox[1].item() / stride
+        w_search = bbox[2].item()
+        h_search = bbox[3].item()
 
         y = torch.arange(self.feat_size, dtype=torch.float32)
         x = torch.arange(self.feat_size, dtype=torch.float32)
         yy, xx = torch.meshgrid(y, x, indexing='ij')
 
-        sigma = max(1.0, min(3.0, (w_in_search + h_in_search) / (2 * stride * 4)))
+        sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
         dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
         heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
+        return heatmap
+
+    def __getitem__(self, idx):
+        seq = self.sequences[idx % len(self.sequences)]
+        clip_indices = self._sample_clip(idx % len(self.sequences))
+
+        t_idx = clip_indices[0]
+        s_indices = clip_indices[1:]
+        K = len(s_indices)
+
+        # Load and process template
+        t_img = self._load_image(seq['frames'][t_idx])
+        t_bbox = np.array(seq['gt'][t_idx], dtype=np.float32)
+        template, _ = self._process_frame(t_img, t_bbox, is_template=True)
+
+        # Load and process K search frames
+        searches = []
+        heatmaps = []
+        sizes = []
+        boxes = []
+
+        for s_idx in s_indices:
+            s_img = self._load_image(seq['frames'][s_idx])
+            s_bbox = np.array(seq['gt'][s_idx], dtype=np.float32)
+            search, bbox_crop = self._process_frame(s_img, s_bbox, is_template=False)
+
+            # Apply augmentation (same color transform for template+search consistency)
+            if self.augmentation is not None:
+                template_aug, search, bbox_crop = self.augmentation(template, search, bbox_crop)
+                # Only use augmented template from first search frame to keep consistency
+                if len(searches) == 0:
+                    template = template_aug
+
+            searches.append(search)
+            heatmaps.append(self._make_heatmap(bbox_crop))
+            sizes.append(torch.tensor([bbox_crop[2].item() / self.search_size,
+                                       bbox_crop[3].item() / self.search_size]))
+            boxes.append(bbox_crop)
 
         return {
-            'template': template,
-            ...
-            'boxes': bbox_tensor,
+            'template': template,                       # (3, 128, 128)
+            'searches': torch.stack(searches, dim=0),   # (K, 3, 256, 256)
+            'heatmaps': torch.stack(heatmaps, dim=0),   # (K, 1, 16, 16)
+            'sizes': torch.stack(sizes, dim=0),         # (K, 2)
+            'boxes': torch.stack(boxes, dim=0),         # (K, 4)
         }
 
     def set_acl_difficulty(self, difficulty: float):
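
A quick worked check of the label geometry above, using the defaults search_size=256 and feat_size=16 (so stride = 16): a box centered at (128, 128) with w = h = 64 lands at feature cell (8, 8), and sigma = clamp(128 / 128, 1, 3) = 1.0. Standalone:

    import torch

    search_size, feat_size = 256, 16
    stride = search_size / feat_size                 # 16.0 px per feature cell
    bbox = torch.tensor([128.0, 128.0, 64.0, 64.0])  # [cx, cy, w, h] in the search crop
    cx_feat, cy_feat = bbox[0] / stride, bbox[1] / stride  # (8.0, 8.0): grid center
    sigma = max(1.0, min(3.0, (bbox[2] + bbox[3]).item() / (2 * stride * 4)))  # 128/128 = 1.0
    y = torch.arange(feat_size, dtype=torch.float32)
    x = torch.arange(feat_size, dtype=torch.float32)
    yy, xx = torch.meshgrid(y, x, indexing='ij')
    heatmap = torch.exp(-((xx - cx_feat) ** 2 + (yy - cy_feat) ** 2) / (2 * sigma ** 2))
    assert heatmap.argmax().item() == 8 * feat_size + 8  # peak at cell (8, 8)
    assert abs(heatmap.max().item() - 1.0) < 1e-6        # Gaussian peak is 1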

@@ -686,8 +737,8 @@ class COCODetDataset(SequenceDataset):
 class SyntheticTrackingDataset(Dataset):
     """Synthetic tracking dataset for testing without real data.
 
-    Generates ...
-    ...
+    Generates K-frame clips: template + K search frames with a moving
+    colored rectangle target. Motion is linear with noise.
     """
 
     def __init__(

@@ -697,6 +748,7 @@ class SyntheticTrackingDataset(Dataset):
         search_size: int = 256,
         feat_size: int = 16,
         acl_difficulty: float = 1.0,
+        clip_length: int = 3,
     ):
         super().__init__()
         self.length = length

@@ -704,64 +756,78 @@ class SyntheticTrackingDataset(Dataset):
         self.search_size = search_size
         self.feat_size = feat_size
         self.acl_difficulty = acl_difficulty
+        self.clip_length = clip_length
 
     def __len__(self):
         return self.length
 
+    def _make_heatmap(self, cx, cy, w_search, h_search):
+        stride = self.search_size / self.feat_size
+        cx_feat = cx / stride
+        cy_feat = cy / stride
+        y = torch.arange(self.feat_size, dtype=torch.float32)
+        x = torch.arange(self.feat_size, dtype=torch.float32)
+        yy, xx = torch.meshgrid(y, x, indexing='ij')
+        sigma = max(1.0, min(3.0, (w_search + h_search) / (2 * stride * 4)))
+        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
+        return torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)
+
     def __getitem__(self, idx):
         rng = random.Random(idx)
+        K = self.clip_length
 
-        # ...
+        # Target appearance
+        color = torch.tensor([rng.random(), rng.random(), rng.random()]).view(3, 1, 1)
         target_w = rng.uniform(0.1, 0.5) * self.search_size
         target_h = rng.uniform(0.1, 0.5) * self.search_size
 
-        # ...
-        ...
-        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
-        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
-        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
+        # Initial position (center of search)
+        cx0 = self.search_size / 2
+        cy0 = self.search_size / 2
 
-        # ...
-        ...
+        # Velocity (pixels per frame, scaled by difficulty)
+        vx = rng.gauss(0, self.acl_difficulty * 15)
+        vy = rng.gauss(0, self.acl_difficulty * 15)
 
-        # ...
-        ...
+        # Template: target at center
+        template = torch.randn(3, self.template_size, self.template_size) * 0.1
+        t_hw = int(min(target_w / 2, self.template_size / 2 - 1))
+        t_hh = int(min(target_h / 2, self.template_size / 2 - 1))
         tc = self.template_size // 2
-        ...
+        template[:, tc - t_hh:tc + t_hh, tc - t_hw:tc + t_hw] = color
+
+        # K search frames with moving target
+        searches = []
+        heatmaps = []
+        sizes = []
+        boxes = []
+
+        for k in range(K):
+            # Position at frame k
+            cx = cx0 + vx * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
+            cy = cy0 + vy * (k + 1) + rng.gauss(0, self.acl_difficulty * 5)
+            cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
+            cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))
+
+            search = torch.randn(3, self.search_size, self.search_size) * 0.1
+            sx1 = max(0, int(cx - target_w / 2))
+            sy1 = max(0, int(cy - target_h / 2))
+            sx2 = min(self.search_size, int(cx + target_w / 2))
+            sy2 = min(self.search_size, int(cy + target_h / 2))
+            search[:, sy1:sy2, sx1:sx2] = color
+
+            searches.append(search)
+            heatmaps.append(self._make_heatmap(cx, cy, target_w, target_h))
+            sizes.append(torch.tensor([target_w / self.search_size,
+                                       target_h / self.search_size]))
+            boxes.append(torch.tensor([cx, cy, target_w, target_h]))
 
         return {
-            'template': template,
-            ...
-            'boxes': boxes,
+            'template': template,                       # (3, 128, 128)
+            'searches': torch.stack(searches, dim=0),   # (K, 3, 256, 256)
+            'heatmaps': torch.stack(heatmaps, dim=0),   # (K, 1, 16, 16)
+            'sizes': torch.stack(sizes, dim=0),         # (K, 2)
+            'boxes': torch.stack(boxes, dim=0),         # (K, 4)
         }
 
     def set_acl_difficulty(self, difficulty: float):
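
A hedged smoke test of the synthetic clips (assumes template_size defaults to 128, matching the shape comments above). Only the geometry is seeded by random.Random(idx); the torch.randn background noise is not, so boxes and heatmaps are reproducible per index while pixels are not:

    import torch
    from vil_tracker.data.dataset import SyntheticTrackingDataset

    ds = SyntheticTrackingDataset(length=16, clip_length=3)
    a, b = ds[5], ds[5]
    assert torch.equal(a['boxes'], b['boxes'])             # rng seeded by idx -> same trajectory
    assert not torch.equal(a['searches'], b['searches'])   # unseeded torch.randn noise differs
    print(a['searches'].shape)  # torch.Size([3, 3, 256, 256])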

@@ -1202,5 +1268,5 @@ def build_tracking_dataset(
 class TrackingDataset(SyntheticTrackingDataset):
     """Backward-compatible alias for SyntheticTrackingDataset."""
     def __init__(self, data_dir=None, split='train', synthetic=False,
-                 synthetic_length=10000, **kwargs):
-        super().__init__(length=synthetic_length, **kwargs)
+                 synthetic_length=10000, clip_length=3, **kwargs):
+        super().__init__(length=synthetic_length, clip_length=clip_length, **kwargs)
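
Call sites written before this commit keep working: per the alias above, data_dir, split, and synthetic are still accepted but not forwarded, and the new clip_length simply threads through to the synthetic dataset. For example:

    from vil_tracker.data.dataset import TrackingDataset

    ds_old = TrackingDataset(synthetic=True, synthetic_length=1000)  # pre-commit style, K defaults to 3
    ds_new = TrackingDataset(synthetic_length=1000, clip_length=5)   # opt in to longer clips
    assert ds_new[0]['searches'].shape[0] == 5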