guychuk commited on
Commit
41e79d0
·
verified ·
1 Parent(s): 0edaa2f

v7: add --held-out-classes and --keep-only-classes for generalization eval

Browse files
Files changed (1) hide show
  1. train_v7_finetune.py +30 -1
train_v7_finetune.py CHANGED
@@ -112,7 +112,10 @@ class BlindspotDataset(Dataset):
112
  def __init__(self, airport, mode, data_dir,
113
  past_max=256, past_min=60,
114
  delta_min=30, delta_max=120,
115
- seed=0, epoch_multiplier=4):
 
 
 
116
  ensure_data(airport, data_dir)
117
  airport_dir = os.path.join(data_dir, airport)
118
  raw, labels = load_atfm(airport, mode, airport_dir)
@@ -134,6 +137,17 @@ class BlindspotDataset(Dataset):
134
  raise RuntimeError(
135
  f"No trajectories of length >= {min_required} in {airport}/{mode}"
136
  )
 
 
 
 
 
 
 
 
 
 
 
137
  raw = raw[keep]
138
  lengths = lengths[keep]
139
  self.labels = labels[keep].astype(np.int64)
@@ -838,6 +852,10 @@ def main():
838
  help="If --pretrained-encoder is a HF repo, name of the file in it.")
839
  p.add_argument("--freeze-encoder", action="store_true",
840
  help="Freeze tokenizer + encoder weights after loading pretrained.")
 
 
 
 
841
  p.add_argument("--trackio-name", default=None)
842
  args = p.parse_args()
843
 
@@ -862,23 +880,34 @@ def main():
862
  trackio.init(project="flight-jepa-v2", name=args.trackio_name,
863
  config=vars(args))
864
 
 
 
 
 
 
 
 
 
865
  train_ds = BlindspotDataset(
866
  airport=args.airport, mode="TRAIN", data_dir=args.data_dir,
867
  past_max=args.past_max, past_min=args.past_min,
868
  delta_min=args.delta_min, delta_max=args.delta_max,
869
  seed=args.seed, epoch_multiplier=args.epoch_multiplier,
 
870
  )
871
  test_ds = BlindspotDataset(
872
  airport=args.airport, mode="TEST", data_dir=args.data_dir,
873
  past_max=args.past_max, past_min=args.past_min,
874
  delta_min=args.delta_min, delta_max=args.delta_max,
875
  seed=args.seed + 1, epoch_multiplier=1,
 
876
  )
877
  extrap_ds = BlindspotDataset(
878
  airport=args.airport, mode="TEST", data_dir=args.data_dir,
879
  past_max=args.past_max, past_min=args.past_min,
880
  delta_min=args.delta_min, delta_max=args.extrap_delta_max,
881
  seed=args.seed + 99, epoch_multiplier=1,
 
882
  )
883
 
884
  train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
 
112
  def __init__(self, airport, mode, data_dir,
113
  past_max=256, past_min=60,
114
  delta_min=30, delta_max=120,
115
+ seed=0, epoch_multiplier=4,
116
+ held_out_classes=None, # None = no filter; list = exclude these classes
117
+ keep_only_classes=None, # None = no filter; list = keep ONLY these classes (overrides held_out)
118
+ ):
119
  ensure_data(airport, data_dir)
120
  airport_dir = os.path.join(data_dir, airport)
121
  raw, labels = load_atfm(airport, mode, airport_dir)
 
137
  raise RuntimeError(
138
  f"No trajectories of length >= {min_required} in {airport}/{mode}"
139
  )
140
+
141
+ # Class-based filtering (held-out generalization eval)
142
+ if keep_only_classes is not None:
143
+ keep_set = set(int(c) for c in keep_only_classes)
144
+ class_keep = np.array([int(c) in keep_set for c in labels])
145
+ keep = keep & class_keep
146
+ elif held_out_classes is not None:
147
+ held = set(int(c) for c in held_out_classes)
148
+ class_keep = np.array([int(c) not in held for c in labels])
149
+ keep = keep & class_keep
150
+
151
  raw = raw[keep]
152
  lengths = lengths[keep]
153
  self.labels = labels[keep].astype(np.int64)
 
852
  help="If --pretrained-encoder is a HF repo, name of the file in it.")
853
  p.add_argument("--freeze-encoder", action="store_true",
854
  help="Freeze tokenizer + encoder weights after loading pretrained.")
855
+ p.add_argument("--held-out-classes", default=None,
856
+ help="Comma-separated class IDs to EXCLUDE from training (e.g., '6,18,28').")
857
+ p.add_argument("--keep-only-classes", default=None,
858
+ help="Comma-separated class IDs to KEEP for evaluation (eval on these only).")
859
  p.add_argument("--trackio-name", default=None)
860
  args = p.parse_args()
861
 
 
880
  trackio.init(project="flight-jepa-v2", name=args.trackio_name,
881
  config=vars(args))
882
 
883
+ held_out = (
884
+ [int(c) for c in args.held_out_classes.split(",")]
885
+ if args.held_out_classes else None
886
+ )
887
+ keep_only = (
888
+ [int(c) for c in args.keep_only_classes.split(",")]
889
+ if args.keep_only_classes else None
890
+ )
891
  train_ds = BlindspotDataset(
892
  airport=args.airport, mode="TRAIN", data_dir=args.data_dir,
893
  past_max=args.past_max, past_min=args.past_min,
894
  delta_min=args.delta_min, delta_max=args.delta_max,
895
  seed=args.seed, epoch_multiplier=args.epoch_multiplier,
896
+ held_out_classes=held_out, keep_only_classes=keep_only,
897
  )
898
  test_ds = BlindspotDataset(
899
  airport=args.airport, mode="TEST", data_dir=args.data_dir,
900
  past_max=args.past_max, past_min=args.past_min,
901
  delta_min=args.delta_min, delta_max=args.delta_max,
902
  seed=args.seed + 1, epoch_multiplier=1,
903
+ held_out_classes=held_out, keep_only_classes=keep_only,
904
  )
905
  extrap_ds = BlindspotDataset(
906
  airport=args.airport, mode="TEST", data_dir=args.data_dir,
907
  past_max=args.past_max, past_min=args.past_min,
908
  delta_min=args.delta_min, delta_max=args.extrap_delta_max,
909
  seed=args.seed + 99, epoch_multiplier=1,
910
+ held_out_classes=held_out, keep_only_classes=keep_only,
911
  )
912
 
913
  train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,