v7: add --held-out-classes and --keep-only-classes for generalization eval
Browse files- train_v7_finetune.py +30 -1
train_v7_finetune.py
CHANGED
|
@@ -112,7 +112,10 @@ class BlindspotDataset(Dataset):
|
|
| 112 |
def __init__(self, airport, mode, data_dir,
|
| 113 |
past_max=256, past_min=60,
|
| 114 |
delta_min=30, delta_max=120,
|
| 115 |
-
seed=0, epoch_multiplier=4
|
|
|
|
|
|
|
|
|
|
| 116 |
ensure_data(airport, data_dir)
|
| 117 |
airport_dir = os.path.join(data_dir, airport)
|
| 118 |
raw, labels = load_atfm(airport, mode, airport_dir)
|
|
@@ -134,6 +137,17 @@ class BlindspotDataset(Dataset):
|
|
| 134 |
raise RuntimeError(
|
| 135 |
f"No trajectories of length >= {min_required} in {airport}/{mode}"
|
| 136 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
raw = raw[keep]
|
| 138 |
lengths = lengths[keep]
|
| 139 |
self.labels = labels[keep].astype(np.int64)
|
|
@@ -838,6 +852,10 @@ def main():
|
|
| 838 |
help="If --pretrained-encoder is a HF repo, name of the file in it.")
|
| 839 |
p.add_argument("--freeze-encoder", action="store_true",
|
| 840 |
help="Freeze tokenizer + encoder weights after loading pretrained.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
p.add_argument("--trackio-name", default=None)
|
| 842 |
args = p.parse_args()
|
| 843 |
|
|
@@ -862,23 +880,34 @@ def main():
|
|
| 862 |
trackio.init(project="flight-jepa-v2", name=args.trackio_name,
|
| 863 |
config=vars(args))
|
| 864 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
train_ds = BlindspotDataset(
|
| 866 |
airport=args.airport, mode="TRAIN", data_dir=args.data_dir,
|
| 867 |
past_max=args.past_max, past_min=args.past_min,
|
| 868 |
delta_min=args.delta_min, delta_max=args.delta_max,
|
| 869 |
seed=args.seed, epoch_multiplier=args.epoch_multiplier,
|
|
|
|
| 870 |
)
|
| 871 |
test_ds = BlindspotDataset(
|
| 872 |
airport=args.airport, mode="TEST", data_dir=args.data_dir,
|
| 873 |
past_max=args.past_max, past_min=args.past_min,
|
| 874 |
delta_min=args.delta_min, delta_max=args.delta_max,
|
| 875 |
seed=args.seed + 1, epoch_multiplier=1,
|
|
|
|
| 876 |
)
|
| 877 |
extrap_ds = BlindspotDataset(
|
| 878 |
airport=args.airport, mode="TEST", data_dir=args.data_dir,
|
| 879 |
past_max=args.past_max, past_min=args.past_min,
|
| 880 |
delta_min=args.delta_min, delta_max=args.extrap_delta_max,
|
| 881 |
seed=args.seed + 99, epoch_multiplier=1,
|
|
|
|
| 882 |
)
|
| 883 |
|
| 884 |
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
|
|
|
|
| 112 |
def __init__(self, airport, mode, data_dir,
|
| 113 |
past_max=256, past_min=60,
|
| 114 |
delta_min=30, delta_max=120,
|
| 115 |
+
seed=0, epoch_multiplier=4,
|
| 116 |
+
held_out_classes=None, # None = no filter; list = exclude these classes
|
| 117 |
+
keep_only_classes=None, # None = no filter; list = keep ONLY these classes (overrides held_out)
|
| 118 |
+
):
|
| 119 |
ensure_data(airport, data_dir)
|
| 120 |
airport_dir = os.path.join(data_dir, airport)
|
| 121 |
raw, labels = load_atfm(airport, mode, airport_dir)
|
|
|
|
| 137 |
raise RuntimeError(
|
| 138 |
f"No trajectories of length >= {min_required} in {airport}/{mode}"
|
| 139 |
)
|
| 140 |
+
|
| 141 |
+
# Class-based filtering (held-out generalization eval)
|
| 142 |
+
if keep_only_classes is not None:
|
| 143 |
+
keep_set = set(int(c) for c in keep_only_classes)
|
| 144 |
+
class_keep = np.array([int(c) in keep_set for c in labels])
|
| 145 |
+
keep = keep & class_keep
|
| 146 |
+
elif held_out_classes is not None:
|
| 147 |
+
held = set(int(c) for c in held_out_classes)
|
| 148 |
+
class_keep = np.array([int(c) not in held for c in labels])
|
| 149 |
+
keep = keep & class_keep
|
| 150 |
+
|
| 151 |
raw = raw[keep]
|
| 152 |
lengths = lengths[keep]
|
| 153 |
self.labels = labels[keep].astype(np.int64)
|
|
|
|
| 852 |
help="If --pretrained-encoder is a HF repo, name of the file in it.")
|
| 853 |
p.add_argument("--freeze-encoder", action="store_true",
|
| 854 |
help="Freeze tokenizer + encoder weights after loading pretrained.")
|
| 855 |
+
p.add_argument("--held-out-classes", default=None,
|
| 856 |
+
help="Comma-separated class IDs to EXCLUDE from training (e.g., '6,18,28').")
|
| 857 |
+
p.add_argument("--keep-only-classes", default=None,
|
| 858 |
+
help="Comma-separated class IDs to KEEP for evaluation (eval on these only).")
|
| 859 |
p.add_argument("--trackio-name", default=None)
|
| 860 |
args = p.parse_args()
|
| 861 |
|
|
|
|
| 880 |
trackio.init(project="flight-jepa-v2", name=args.trackio_name,
|
| 881 |
config=vars(args))
|
| 882 |
|
| 883 |
+
held_out = (
|
| 884 |
+
[int(c) for c in args.held_out_classes.split(",")]
|
| 885 |
+
if args.held_out_classes else None
|
| 886 |
+
)
|
| 887 |
+
keep_only = (
|
| 888 |
+
[int(c) for c in args.keep_only_classes.split(",")]
|
| 889 |
+
if args.keep_only_classes else None
|
| 890 |
+
)
|
| 891 |
train_ds = BlindspotDataset(
|
| 892 |
airport=args.airport, mode="TRAIN", data_dir=args.data_dir,
|
| 893 |
past_max=args.past_max, past_min=args.past_min,
|
| 894 |
delta_min=args.delta_min, delta_max=args.delta_max,
|
| 895 |
seed=args.seed, epoch_multiplier=args.epoch_multiplier,
|
| 896 |
+
held_out_classes=held_out, keep_only_classes=keep_only,
|
| 897 |
)
|
| 898 |
test_ds = BlindspotDataset(
|
| 899 |
airport=args.airport, mode="TEST", data_dir=args.data_dir,
|
| 900 |
past_max=args.past_max, past_min=args.past_min,
|
| 901 |
delta_min=args.delta_min, delta_max=args.delta_max,
|
| 902 |
seed=args.seed + 1, epoch_multiplier=1,
|
| 903 |
+
held_out_classes=held_out, keep_only_classes=keep_only,
|
| 904 |
)
|
| 905 |
extrap_ds = BlindspotDataset(
|
| 906 |
airport=args.airport, mode="TEST", data_dir=args.data_dir,
|
| 907 |
past_max=args.past_max, past_min=args.past_min,
|
| 908 |
delta_min=args.delta_min, delta_max=args.extrap_delta_max,
|
| 909 |
seed=args.seed + 99, epoch_multiplier=1,
|
| 910 |
+
held_out_classes=held_out, keep_only_classes=keep_only,
|
| 911 |
)
|
| 912 |
|
| 913 |
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
|