omar-ah committed on
Commit
1322185
·
verified ·
1 Parent(s): 51d2470

Add UAV dataset loaders: VisDrone-SOT, UAVDT, WebUAV-3M

Browse files
Files changed (1) hide show
  1. vil_tracker/data/dataset.py +350 -0
vil_tracker/data/dataset.py CHANGED
@@ -768,6 +768,319 @@ class SyntheticTrackingDataset(Dataset):
768
  self.acl_difficulty = min(1.0, max(0.0, difficulty))
769
 
770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
  # ============================================================
772
  # Convenience: build combined dataset
773
  # ============================================================
@@ -781,12 +1094,27 @@ def build_tracking_dataset(
781
  ) -> Dataset:
782
  """Build a combined tracking dataset from multiple sources.
783
 
 
 
 
 
 
 
784
  Args:
785
  data_config: dict with optional keys:
 
786
  - 'got10k_root': path to GOT-10k dataset
787
  - 'lasot_root': path to LaSOT dataset
788
  - 'trackingnet_root': path to TrackingNet dataset
789
  - 'coco_root': path to COCO train2017 images
 
 
 
 
 
 
 
 
790
  - 'synthetic_length': number of synthetic samples (fallback)
791
  template_size: template crop size
792
  search_size: search region crop size
@@ -828,6 +1156,28 @@ def build_tracking_dataset(
828
  datasets.append(ds)
829
  print(f"COCO: {len(ds)} pseudo-sequences")
830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
  if datasets:
832
  combined = ConcatDataset(datasets)
833
  print(f"\nTotal training samples: {len(combined)}")
 
768
  self.acl_difficulty = min(1.0, max(0.0, difficulty))
769
 
770
 
771
+ # ============================================================
772
+ # VisDrone-SOT dataset loader (UAV)
773
+ # ============================================================
774
+
775
class VisDroneSOTDataset(SequenceDataset):
    """VisDrone-SOT single object tracking dataset (drone/UAV perspective).

    Structure:
        root/
            VisDrone2019-SOT-train/
                sequences/
                    uav0000001_00000_s/
                        0000001.jpg, 0000002.jpg, ...
                    ...
                annotations/
                    uav0000001_00000_s.txt   # x,y,w,h per line
                    ...

    Splits: train (86 sequences, ~70K frames), val (11 sequences),
    test-dev (35 sequences), test-challenge (35 sequences)

    Key for our tracker: real drone footage with small targets, fast motion,
    viewpoint changes, and camera ego-motion — the exact conditions we deploy in.

    Args:
        root: dataset root; either the directory that contains the split
            folders, or a split folder itself (one holding ``sequences/``).
        split: 'train', 'val' or 'test'; any other value is tried verbatim
            as a sub-directory name of ``root``.
        **kwargs: forwarded unchanged to ``SequenceDataset.__init__``.
    """

    # Directory names tried, in order, for each logical split.
    _SPLIT_DIR_NAMES = {
        'train': ['VisDrone2019-SOT-train', 'VisDrone2018-SOT-train', 'train'],
        'val': ['VisDrone2019-SOT-val', 'VisDrone2018-SOT-val', 'val'],
        'test': ['VisDrone2019-SOT-test-dev', 'VisDrone2018-SOT-test', 'test-dev', 'test'],
    }

    def __init__(self, root: str, split: str = 'train', **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_sequences(split)

    @staticmethod
    def _parse_box(line):
        """Parse one annotation line into ``[x, y, w, h]`` floats.

        Returns None for blank lines, lines with fewer than four fields, or
        lines whose first four fields are not numeric, so that a truncated
        row never yields a box with missing coordinates.
        """
        fields = line.replace(',', ' ').split()
        if len(fields) < 4:
            return None
        try:
            return [float(value) for value in fields[:4]]
        except ValueError:
            return None

    def _load_sequences(self, split):
        """Locate the split directory and append its sequences to ``self.sequences``.

        Each appended entry is ``{'frames': [str, ...], 'gt': [box_or_None, ...]}``
        with both lists truncated to the same length. Sequences with fewer
        than two usable frames are skipped. Missing directories only produce
        a printed warning.
        """
        # NOTE(review): assumes SequenceDataset.__init__ created self.sequences.
        split_dir = None
        for name in self._SPLIT_DIR_NAMES.get(split, [split]):
            candidate = self.root / name
            if candidate.exists():
                split_dir = candidate
                break
            # Also accept root itself being the split directory.
            if (self.root / 'sequences').exists():
                split_dir = self.root
                break

        if split_dir is None:
            print(f"Warning: VisDrone-SOT {split} not found at {self.root}")
            return

        seq_dir = split_dir / 'sequences'
        anno_dir = split_dir / 'annotations'

        if not seq_dir.exists() or not anno_dir.exists():
            print(f"Warning: VisDrone-SOT missing sequences/ or annotations/ at {split_dir}")
            return

        total_seqs = 0
        for anno_file in sorted(anno_dir.glob('*.txt')):
            # The annotation file name (without extension) is the sequence name.
            frames_dir = seq_dir / anno_file.stem
            if not frames_dir.exists():
                continue

            # One entry per annotation line so indices stay aligned with frames.
            with open(anno_file, 'r') as f:
                gt_boxes = [self._parse_box(line) for line in f]

            frames = sorted(glob.glob(str(frames_dir / '*.jpg')))
            if not frames:
                frames = sorted(glob.glob(str(frames_dir / '*.png')))

            # Keep only the prefix for which both a frame and an annotation exist.
            usable = min(len(frames), len(gt_boxes))
            frames = frames[:usable]
            gt_boxes = gt_boxes[:usable]

            if usable >= 2:
                self.sequences.append({'frames': frames, 'gt': gt_boxes})
                total_seqs += 1

        print(f" Loaded {total_seqs} VisDrone-SOT {split} sequences")
866
+
867
+
868
+ # ============================================================
869
+ # UAVDT dataset loader (UAV)
870
+ # ============================================================
871
+
872
class UAVDTDataset(SequenceDataset):
    """UAVDT (Unmanned Aerial Vehicle Detection and Tracking) dataset.

    Structure:
        root/
            UAV-benchmark-S/            # SOT annotations
                {seq_name}/
                    {seq_name}_gt.txt   # x,y,w,h per line (or comma-separated)
            UAV-benchmark-M/            # Frames
                {seq_name}/
                    img000001.jpg, img000002.jpg, ...

    Alternative structure (simpler):
        root/
            sequences/
                {seq_name}/
                    img000001.jpg, ...
            annotations/
                {seq_name}_gt.txt

    50 sequences total, typically 30 train / 20 test.
    Contains vehicle tracking from drone perspective — complementary to VisDrone.

    Args:
        root: dataset root directory.
        split: 'train' selects the first 60% of sequences (sorted by name);
            any other value selects the remaining 40%.
        **kwargs: forwarded unchanged to ``SequenceDataset.__init__``.
    """

    def __init__(self, root: str, split: str = 'train', **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_sequences(split)

    @staticmethod
    def _parse_box(line):
        """Parse one annotation line into ``[x, y, w, h]`` floats.

        Fields may be separated by commas, tabs or spaces. Returns None for
        blank lines, lines with fewer than four fields, or non-numeric fields.
        """
        # str.split() with no argument already splits on tabs and spaces.
        fields = line.replace(',', ' ').split()
        if len(fields) < 4:
            return None
        try:
            return [float(value) for value in fields[:4]]
        except ValueError:
            return None

    @staticmethod
    def _list_frames(directory):
        """Return sorted ``*.jpg`` paths in ``directory``, or ``*.png`` if no jpg exists."""
        frames = sorted(glob.glob(str(directory / '*.jpg')))
        if not frames:
            frames = sorted(glob.glob(str(directory / '*.png')))
        return frames

    def _load_sequences(self, split):
        """Discover annotated sequences and append the requested split to ``self.sequences``.

        Each appended entry is ``{'frames': [str, ...], 'gt': [box_or_None, ...]}``
        with both lists truncated to the same length. Sequences with fewer
        than two usable frames are skipped.
        """
        # NOTE(review): assumes SequenceDataset.__init__ created self.sequences.
        # Try the standard UAVDT structure first.
        anno_dir = self.root / 'UAV-benchmark-S'
        frame_dir = self.root / 'UAV-benchmark-M'

        if not anno_dir.exists():
            # Alternative structure.
            anno_dir = self.root / 'annotations'
            frame_dir = self.root / 'sequences'

        if not anno_dir.exists():
            # Root directly holds the sequence directories.
            anno_dir = self.root
            frame_dir = self.root

        if not anno_dir.exists():
            print(f"Warning: UAVDT not found at {self.root}")
            return

        gt_files = sorted(anno_dir.rglob('*_gt.txt'))
        if not gt_files:
            gt_files = sorted(anno_dir.rglob('*.txt'))

        all_seqs = []
        for gt_file in gt_files:
            # Strip only a trailing '_gt'; a name that contains '_gt' elsewhere
            # must be left intact.
            seq_name = gt_file.stem
            if seq_name.endswith('_gt'):
                seq_name = seq_name[:-len('_gt')]

            # Use the first candidate directory that actually contains frames.
            # Testing mere existence would always stop at frame_dir/seq_name
            # and never reach its 'img' sub-directory.
            frames = []
            for candidate in (
                frame_dir / seq_name,
                frame_dir / seq_name / 'img',
                self.root / seq_name,
            ):
                if candidate.is_dir():
                    frames = self._list_frames(candidate)
                    if frames:
                        break

            if not frames:
                continue

            # One entry per annotation line so indices stay aligned with frames.
            with open(gt_file, 'r') as f:
                gt_boxes = [self._parse_box(line) for line in f]

            # Keep only the prefix for which both a frame and an annotation exist.
            usable = min(len(frames), len(gt_boxes))
            frames = frames[:usable]
            gt_boxes = gt_boxes[:usable]

            if usable >= 2:
                all_seqs.append({'frames': frames, 'gt': gt_boxes, 'name': seq_name})

        # Split by sorted sequence name: first 60% train, remaining 40% test.
        all_seqs.sort(key=lambda seq: seq['name'])
        split_idx = int(len(all_seqs) * 0.6)
        selected = all_seqs[:split_idx] if split == 'train' else all_seqs[split_idx:]

        for seq in selected:
            self.sequences.append({'frames': seq['frames'], 'gt': seq['gt']})

        print(f" Loaded {len(self.sequences)} UAVDT {split} sequences "
              f"(from {len(all_seqs)} total)")
984
+
985
+
986
+ # ============================================================
987
+ # WebUAV-3M dataset loader (UAV, large-scale)
988
+ # ============================================================
989
+
990
class WebUAV3MDataset(SequenceDataset):
    """WebUAV-3M: million-scale multi-modal UAV tracking dataset.

    Structure:
        root/
            {superclass}/               # e.g., person, vehicle, animal
                {seq_name}/
                    img/
                        000001.jpg, 000002.jpg, ...
                    groundtruth_rect.txt   # x,y,w,h per line
        OR:
            {seq_name}/
                *.jpg
                groundtruth_rect.txt

    4,500 sequences, 3.3M frames, 12 superclasses, 223 target classes.
    Average video length: 710 frames (23.7 seconds at 30 FPS).

    This is the largest UAV tracking dataset. All sequences are from real
    drone footage. Purpose-built for training deep UAV trackers.

    Args:
        root: dataset root directory (searched recursively).
        split: 'train' selects the first 80% of discovered sequences (in
            sorted ground-truth path order); any other value selects the rest.
        max_sequences: optional cap on the number of sequences kept from the
            split. None, 0 or a negative value means no cap.
        **kwargs: forwarded unchanged to ``SequenceDataset.__init__``.
    """

    def __init__(self, root: str, split: str = 'train', max_sequences: int = None, **kwargs):
        super().__init__(**kwargs)
        self.root = Path(root)
        self._load_sequences(split, max_sequences)

    @staticmethod
    def _parse_box(line):
        """Parse one annotation line into ``[x, y, w, h]`` floats.

        Fields may be separated by commas, tabs or spaces. Returns None for
        blank lines, lines with fewer than four fields, or non-numeric fields.
        """
        # str.split() with no argument already splits on tabs and spaces.
        fields = line.replace(',', ' ').split()
        if len(fields) < 4:
            return None
        try:
            return [float(value) for value in fields[:4]]
        except ValueError:
            return None

    def _load_sequences(self, split, max_sequences):
        """Discover sequences under ``self.root`` and append the split to ``self.sequences``.

        Each appended entry is ``{'frames': [str, ...], 'gt': [box_or_None, ...]}``
        with both lists truncated to the same length. Sequences with fewer
        than two usable frames are skipped.
        """
        # NOTE(review): assumes SequenceDataset.__init__ created self.sequences.
        if not self.root.exists():
            print(f"Warning: WebUAV-3M not found at {self.root}")
            return

        # A sequence is any directory that holds a ground-truth file.
        gt_files = sorted(self.root.rglob('groundtruth_rect.txt'))
        if not gt_files:
            gt_files = sorted(self.root.rglob('groundtruth.txt'))

        all_seq_dirs = []
        for gt_file in gt_files:
            seq_dir = gt_file.parent
            # Frames live in an 'img' sub-directory, or directly in the sequence dir.
            img_dir = seq_dir / 'img'
            if not img_dir.exists():
                img_dir = seq_dir

            frames = sorted(glob.glob(str(img_dir / '*.jpg')))
            if not frames:
                frames = sorted(glob.glob(str(img_dir / '*.png')))

            if len(frames) >= 2:
                all_seq_dirs.append((gt_file, frames))

        print(f"WebUAV-3M: found {len(all_seq_dirs)} sequences total")

        # Train/test split (80/20) over the sorted sequence list.
        split_idx = int(len(all_seq_dirs) * 0.8)
        selected = all_seq_dirs[:split_idx] if split == 'train' else all_seq_dirs[split_idx:]

        # Optionally limit sequences (WebUAV-3M is huge).
        if max_sequences and 0 < max_sequences < len(selected):
            # Evenly spaced indices across the whole list. An integer stride
            # (len // max) collapses to 1 whenever len < 2 * max and would
            # then just keep the first max_sequences entries.
            total = len(selected)
            selected = [selected[i * total // max_sequences] for i in range(max_sequences)]

        # Annotations are parsed only for the sequences that were kept.
        for gt_file, frames in selected:
            with open(gt_file, 'r') as f:
                gt_boxes = [self._parse_box(line) for line in f]

            # Keep only the prefix for which both a frame and an annotation exist.
            usable = min(len(frames), len(gt_boxes))
            frames = frames[:usable]
            gt_boxes = gt_boxes[:usable]

            if usable >= 2:
                self.sequences.append({'frames': frames, 'gt': gt_boxes})

        print(f" Loaded {len(self.sequences)} WebUAV-3M {split} sequences")
1082
+
1083
+
1084
  # ============================================================
1085
  # Convenience: build combined dataset
1086
  # ============================================================
 
1094
  ) -> Dataset:
1095
  """Build a combined tracking dataset from multiple sources.
1096
 
1097
+ Standard ground-level datasets provide general tracking capability.
1098
+ UAV-specific datasets provide drone-perspective specialization.
1099
+ The ACL curriculum bridges the gap: it starts training on easy pairs
1100
+ from ground-level data, then progressively incorporates harder pairs
1101
+ including UAV sequences with fast motion, small targets, and viewpoint changes.
1102
+
1103
  Args:
1104
  data_config: dict with optional keys:
1105
+ Ground-level (standard tracking training data):
1106
  - 'got10k_root': path to GOT-10k dataset
1107
  - 'lasot_root': path to LaSOT dataset
1108
  - 'trackingnet_root': path to TrackingNet dataset
1109
  - 'coco_root': path to COCO train2017 images
1110
+
1111
+ UAV-specific (drone perspective — the deployment domain):
1112
+ - 'visdrone_root': path to VisDrone-SOT dataset
1113
+ - 'uavdt_root': path to UAVDT dataset
1114
+ - 'webuav3m_root': path to WebUAV-3M dataset
1115
+ - 'webuav3m_max_sequences': limit WebUAV-3M sequences (default: None = all)
1116
+
1117
+ Fallback:
1118
  - 'synthetic_length': number of synthetic samples (fallback)
1119
  template_size: template crop size
1120
  search_size: search region crop size
 
1156
  datasets.append(ds)
1157
  print(f"COCO: {len(ds)} pseudo-sequences")
1158
 
1159
+ # --- UAV-specific datasets (drone perspective) ---
1160
+
1161
+ if 'visdrone_root' in data_config and os.path.exists(data_config['visdrone_root']):
1162
+ ds = VisDroneSOTDataset(data_config['visdrone_root'], split='train', **common_kwargs)
1163
+ if len(ds) > 0:
1164
+ datasets.append(ds)
1165
+ print(f"VisDrone-SOT: {len(ds)} UAV sequences")
1166
+
1167
+ if 'uavdt_root' in data_config and os.path.exists(data_config['uavdt_root']):
1168
+ ds = UAVDTDataset(data_config['uavdt_root'], split='train', **common_kwargs)
1169
+ if len(ds) > 0:
1170
+ datasets.append(ds)
1171
+ print(f"UAVDT: {len(ds)} UAV sequences")
1172
+
1173
+ if 'webuav3m_root' in data_config and os.path.exists(data_config['webuav3m_root']):
1174
+ max_seq = data_config.get('webuav3m_max_sequences', None)
1175
+ ds = WebUAV3MDataset(data_config['webuav3m_root'], split='train',
1176
+ max_sequences=max_seq, **common_kwargs)
1177
+ if len(ds) > 0:
1178
+ datasets.append(ds)
1179
+ print(f"WebUAV-3M: {len(ds)} UAV sequences")
1180
+
1181
  if datasets:
1182
  combined = ConcatDataset(datasets)
1183
  print(f"\nTotal training samples: {len(combined)}")