omar-ah committed on
Commit
823a1a3
·
verified ·
1 Parent(s): fc9248f

Upload vil_tracker/data/dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vil_tracker/data/dataset.py +147 -0
vil_tracker/data/dataset.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tracking dataset with synthetic fallback for testing.
3
+
4
+ Supports:
5
+ - GOT-10k, LaSOT, TrackingNet, COCO formats
6
+ - Synthetic data generation for testing (no external data needed)
7
+ - ACL (Adaptive Curriculum Learning) difficulty scaling
8
+ - Standard tracking augmentations: jitter, flip, color aug
9
+ """
10
+
11
+ import os
12
+ import math
13
+ import random
14
+ import torch
15
+ import numpy as np
16
+ from torch.utils.data import Dataset
17
+
18
+
19
class TrackingDataset(Dataset):
    """Tracking dataset for ViL Tracker training, with a synthetic fallback.

    Each sample is a dict with:
        - template: (3, template_size, template_size) template crop
        - search: (3, search_size, search_size) search region crop
        - heatmap: (1, feat_size, feat_size) Gaussian GT center heatmap
        - size: (2,) GT [w, h] normalized by ``search_size``
        - boxes: (4,) GT [cx, cy, w, h] in search-region pixels

    NOTE: decoding of real annotations is not implemented yet. Even when an
    annotation file is loaded, ``__getitem__`` currently returns a synthetic
    sample for the requested index.
    """

    def __init__(
        self,
        data_dir: str = None,
        split: str = 'train',
        template_size: int = 128,
        search_size: int = 256,
        feat_size: int = 16,
        acl_difficulty: float = 1.0,
        synthetic: bool = False,
        synthetic_length: int = 10000,
    ):
        """Configure crop sizes and select the sample source.

        Args:
            data_dir: Directory expected to contain ``{split}.json``.
                Ignored when ``synthetic`` is True.
            split: Name of the annotation file (without ``.json``).
            template_size: Side length in pixels of the square template crop.
            search_size: Side length in pixels of the square search crop.
            feat_size: Side length of the square GT heatmap.
            acl_difficulty: Scales the spread of the target center around the
                middle of the search region (see ``set_acl_difficulty``).
            synthetic: Force purely synthetic samples.
            synthetic_length: Number of synthetic samples; also used when
                falling back to synthetic data because nothing was found.
        """
        super().__init__()
        self.template_size = template_size
        self.search_size = search_size
        self.feat_size = feat_size
        self.acl_difficulty = acl_difficulty
        self.synthetic = synthetic
        # Must be set before _load_dataset(), which reads it on fallback.
        self.synthetic_length = synthetic_length

        if synthetic:
            self.samples = list(range(synthetic_length))
        else:
            self.samples = self._load_dataset(data_dir, split)

    def _load_dataset(self, data_dir, split):
        """Load the annotation list for ``split`` from ``data_dir``.

        Returns the parsed content of ``{data_dir}/{split}.json``. If the
        directory or file is missing, or the file parses to an empty value,
        the dataset switches to synthetic mode and a list of
        ``self.synthetic_length`` integer indices is returned instead.
        """
        samples = []
        if data_dir and os.path.exists(data_dir):
            ann_file = os.path.join(data_dir, f'{split}.json')
            if os.path.exists(ann_file):
                import json
                with open(ann_file, 'r', encoding='utf-8') as f:
                    samples = json.load(f)

        if not samples:
            print(f"Warning: No data found at {data_dir}, using synthetic data")
            self.synthetic = True
            # Keep the length the caller asked for instead of overwriting it
            # with a hard-coded constant.
            return list(range(self.synthetic_length))

        return samples

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        if self.synthetic:
            return self.synthetic_length
        return len(self.samples)

    def _generate_synthetic_sample(self, idx):
        """Generate a synthetic template/search pair with GT annotations.

        The sample is a pure function of ``idx`` and the dataset
        configuration: geometry and color come from a ``random.Random``
        seeded with ``idx``, and the noise backgrounds from a
        ``torch.Generator`` seeded the same way, so the global torch RNG
        state does not affect the result.
        """
        seed = int(idx)
        rng = random.Random(seed)
        noise_gen = torch.Generator().manual_seed(seed)

        # Random target size (relative to search region)
        target_w = rng.uniform(0.1, 0.5) * self.search_size
        target_h = rng.uniform(0.1, 0.5) * self.search_size

        # Random center (with difficulty-dependent jitter), clamped so the
        # whole box stays inside the search region.
        jitter = self.acl_difficulty * 0.3
        cx = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cy = self.search_size / 2 + rng.gauss(0, jitter * self.search_size)
        cx = max(target_w / 2, min(self.search_size - target_w / 2, cx))
        cy = max(target_h / 2, min(self.search_size - target_h / 2, cy))

        # Create synthetic images (colored rectangles on noise background)
        template = torch.randn(
            3, self.template_size, self.template_size, generator=noise_gen
        ) * 0.1
        search = torch.randn(
            3, self.search_size, self.search_size, generator=noise_gen
        ) * 0.1

        # Draw target in template (centered); the half extents are capped so
        # the rectangle never reaches the template border.
        t_half_w = int(min(target_w / 2, self.template_size / 2 - 1))
        t_half_h = int(min(target_h / 2, self.template_size / 2 - 1))
        tc = self.template_size // 2
        color = torch.tensor(
            [rng.random(), rng.random(), rng.random()]
        ).view(3, 1, 1)
        template[:, tc - t_half_h:tc + t_half_h, tc - t_half_w:tc + t_half_w] = color

        # Draw target in search region
        sx1 = max(0, int(cx - target_w / 2))
        sy1 = max(0, int(cy - target_h / 2))
        sx2 = min(self.search_size, int(cx + target_w / 2))
        sy2 = min(self.search_size, int(cy + target_h / 2))
        search[:, sy1:sy2, sx1:sx2] = color

        # Generate GT heatmap: a Gaussian centered on the target position
        # expressed in feature-map cells.
        stride = self.search_size / self.feat_size
        cx_feat = cx / stride
        cy_feat = cy / stride

        y = torch.arange(self.feat_size, dtype=torch.float32)
        x = torch.arange(self.feat_size, dtype=torch.float32)
        yy, xx = torch.meshgrid(y, x, indexing='ij')

        sigma = 2.0
        dist_sq = (xx - cx_feat) ** 2 + (yy - cy_feat) ** 2
        heatmap = torch.exp(-dist_sq / (2 * sigma ** 2)).unsqueeze(0)

        # Normalized size
        size = torch.tensor(
            [target_w / self.search_size, target_h / self.search_size]
        )

        # Box in pixels
        boxes = torch.tensor([cx, cy, target_w, target_h])

        return {
            'template': template,
            'search': search,
            'heatmap': heatmap,
            'size': size,
            'boxes': boxes,
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                This is what lets plain iteration over the dataset stop.
        """
        index = int(idx)
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {length}"
            )

        # TODO: real data loading (load images, compute crops, generate
        # targets from self.samples[index]) is not implemented yet; until
        # then every mode returns the synthetic sample for this index.
        return self._generate_synthetic_sample(index)

    def set_acl_difficulty(self, difficulty: float):
        """Update ACL difficulty level (0.0 = easy, 1.0 = hard).

        Values outside [0.0, 1.0] are clamped to that range.
        """
        self.acl_difficulty = min(1.0, max(0.0, difficulty))