# /// script
# requires-python = ">=3.10"
# dependencies = ["torch>=2.1","numpy","pandas","scikit-learn","huggingface-hub","trackio"]
# ///
"""
Flight-JEPA v7 — past-track masked JEPA pretraining.

Adapted from Forecast-MAE (arxiv:2308.09882) and I-JEPA (arxiv:2301.08243):
mask contiguous blocks of *past-track* patches and train an encoder + EMA
target + predictor to reconstruct masked-patch latents from visible context.
Encoder weights then transfer to v6 fine-tuning.

Key differences from v6's JEPA aux:
- Pretraining-only objective (no forecasting head, no Δ conditioning).
- Masks past-track patches, not future segments.
- Trains on the same RKSIa data — this is a small-scale demo, not OpenSky-scale.
- Output: a `pretrained_encoder.pt` checkpoint loadable by v6 fine-tune.

Decision criteria at fine-tune time:
- Significant FDE improvement at ≥30% past-track dropout (test-time).
- No regression at 0% dropout.
"""
from __future__ import annotations

import argparse
import copy
import json
import math
import os
import shutil
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import trackio
    HAS_TRACKIO = True
except ImportError:
    HAS_TRACKIO = False


# ============================================================================
# DATA UTILITIES (inlined from train_v2_prod.py for self-contained job)
# ============================================================================

def load_atfm(dset_name, mode, path):
    variables = ["X", "Y", "Z"]
    data, labels = [], None
    for var in variables:
        df = pd.read_csv(os.path.join(path, f"{dset_name}_{mode}_{var}.tsv"),
                         sep="\t", header=None, na_values="NaN")
        if labels is None:
            labels = df.values[:, 0]
        data.append(df.values[:, 1:])
    return np.stack(data, axis=-1), labels.astype(int)


def compute_features(traj_xyz: np.ndarray) -> np.ndarray:
    if traj_xyz.shape[0] < 2:
        T = traj_xyz.shape[0]
        return np.concatenate([
            traj_xyz,
            np.zeros((T, 3), dtype=traj_xyz.dtype),
            np.zeros((T, 3), dtype=traj_xyz.dtype)
        ], axis=1)
    x, y, z = traj_xyz[:, 0], traj_xyz[:, 1], traj_xyz[:, 2]
    diff = np.diff(traj_xyz, axis=0)
    norms = np.maximum(np.linalg.norm(diff, axis=1, keepdims=True), 1e-8)
    u = diff / norms
    u = np.vstack([u, u[-1:]])
    r = np.sqrt(x ** 2 + y ** 2)
    theta = np.arctan2(y, x)
    return np.column_stack([
        traj_xyz, u, r[:, None], np.sin(theta)[:, None], np.cos(theta)[:, None]
    ]).astype(np.float32)
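
# Per-timestep feature layout produced by compute_features (9 dims):
#   [x, y, z, ux, uy, uz, r, sin(theta), cos(theta)]
# where (ux, uy, uz) is the unit step direction between consecutive points
# and r = sqrt(x**2 + y**2) is the horizontal range from the origin.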


def ensure_data(airport: str, data_dir: str = "data"):
    target = os.path.join(data_dir, airport)
    if os.path.isdir(target) and any(f.endswith(".tsv") for f in os.listdir(target)):
        return target
    print(f"[data] downloading {airport} from HF ...")
    from huggingface_hub import snapshot_download
    snap = snapshot_download(
        "petchthwr/ATFMTraj", repo_type="dataset",
        allow_patterns=[f"{airport}/*"],
    )
    os.makedirs(data_dir, exist_ok=True)
    src = os.path.join(snap, airport)
    if not os.path.isdir(target):
        shutil.copytree(src, target)
    return target


# ============================================================================
# DATASET — full past-track windows (no Δ / target)
# ============================================================================

class PastTrackDataset(Dataset):
    """
    Yields fixed-length past-track windows for masked-prediction pretraining.

    Per __getitem__:
      - sample a random window of length past_len from a trajectory
      - return its 9-dim features padded to past_max
    """

    def __init__(self, airport, mode, data_dir, past_max=256, past_min=128,
                 seed=0, epoch_multiplier=4):
        # `airport` may be a single string ("RKSIa") or comma-separated
        # ("RKSId,ESSA,LSZH") for v8 multi-airport pretraining. We union
        # the trajectories without preserving airport-id (pretraining only
        # uses features, not labels or airport tokens).
        airports = [a.strip() for a in airport.split(",")]
        all_positions = []
        for ap in airports:
            ensure_data(ap, data_dir)
            ap_dir = os.path.join(data_dir, ap)
            raw, _ = load_atfm(ap, mode, ap_dir)
            lengths = np.array(
                [int(np.sum(~np.isnan(raw[i, :, 0]))) for i in range(raw.shape[0])],
                dtype=np.int64,
            )
            keep = lengths >= past_min + 1
            raw = raw[keep]
            lengths = lengths[keep]
            n_kept = 0
            for i in range(raw.shape[0]):
                L = int(lengths[i])
                all_positions.append(
                    np.nan_to_num(raw[i, :L], nan=0.0).astype(np.float32)
                )
                n_kept += 1
            print(f"[data] {ap}/{mode}: {n_kept} trajectories")

        self.past_max = past_max
        self.past_min = past_min
        self.epoch_multiplier = epoch_multiplier
        self.rng_seed = seed
        self.positions = all_positions
        self.n_traj = len(self.positions)
        if len(airports) > 1:
            print(f"[data] union total: {self.n_traj} trajectories from "
                  f"{airports}")

    def __len__(self):
        return self.n_traj * self.epoch_multiplier

    def __getitem__(self, idx):
        traj_idx = idx % self.n_traj
        rng = np.random.default_rng(self.rng_seed + idx * 9173)
        positions = self.positions[traj_idx]
        L = positions.shape[0]
        past_len = min(self.past_max, L)
        start = int(rng.integers(0, max(1, L - past_len + 1)))
        window = positions[start:start + past_len]
        feats = compute_features(window)
        T = feats.shape[0]
        feat_pad = np.zeros((self.past_max, 9), dtype=np.float32)
        feat_pad[:T] = feats
        return {
            "features": torch.from_numpy(feat_pad),
            "length": torch.tensor(T, dtype=torch.long),
        }


# ============================================================================
# MODEL — encoder + EMA target + predictor (no decoder, no Δ)
# ============================================================================

class LearnablePosEnc(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class PatchTokenizer(nn.Module):
    def __init__(self, in_channels=9, d_model=256, patch_size=8, max_patches=64):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        self.embed = nn.Sequential(
            nn.Conv1d(in_channels, d_model // 2, 5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model // 2, d_model, 3, padding=1),
            nn.GELU(),
        )
        self.pos_enc = LearnablePosEnc(max_patches, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, features, lengths):
        B, T, C = features.shape
        h = self.embed(features.transpose(1, 2))
        N = max(1, T // self.patch_size)
        h = h[:, :, :N * self.patch_size]
        h = h.reshape(B, self.d_model, N, self.patch_size).mean(-1)
        h = h.transpose(1, 2)
        h = self.norm(self.pos_enc(h))
        patch_lengths = (lengths.float() / self.patch_size).clamp(min=1).long()
        patch_lengths = patch_lengths.clamp(max=N)
        return h, patch_lengths
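
# Shape walkthrough for PatchTokenizer under the default config (past_max=256,
# patch_size=8, d_model=256), for illustration: features (B, 256, 9) are
# transposed to (B, 9, 256) for the convs, embedded to (B, 256, 256),
# mean-pooled over each patch of 8 steps to (B, 256, 32), then transposed to
# patch tokens (B, 32, 256); patch_lengths[b] = floor(length[b] / 8), clamped
# to [1, 32].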


class CausalEncoder(nn.Module):
    def __init__(self, d_model=256, n_heads=8, n_layers=4, d_ff=1024, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True,
        )
        self.tf = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask, attn_mask=None):
        N = x.size(1)
        if attn_mask is None:
            # Default: causal. For pretraining the caller may pass a full
            # bidirectional (all-False) mask instead.
            attn_mask = torch.triu(
                torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
            )
        return self.norm(
            self.tf(x, mask=attn_mask, src_key_padding_mask=key_padding_mask)
        )


class JEPAPredictor(nn.Module):
    """Predict target patch latents from context patch latents.
    Adds a query token per masked position via positional embedding."""

    def __init__(self, d_model=256, pred_dim=128, max_patches=64, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_model, pred_dim)
        self.target_pe = nn.Parameter(torch.randn(1, max_patches, pred_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=pred_dim, nhead=4, dim_feedforward=pred_dim * 2,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True,
        )
        self.tf = nn.TransformerEncoder(layer, num_layers=2)
        self.proj_out = nn.Linear(pred_dim, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, ctx_latents, ctx_idx, tgt_idx):
        """
        ctx_latents: (B, N_ctx, d_model)
        ctx_idx:     (B, N_ctx) original positions of context patches (currently
                     unused — context tokens already carry positional info from
                     the tokenizer's pos_enc)
        tgt_idx:     (B, N_tgt) original positions of target patches
        returns:     (B, N_tgt, d_model) predicted target latents
        """
        B = ctx_latents.size(0)
        d_pred = self.target_pe.size(-1)
        # Project context to pred_dim and add target positional embeddings
        ctx_p = self.proj_in(ctx_latents)                  # (B, N_ctx, d_pred)
        # Gather target PEs at the masked positions
        tgt_pe = self.target_pe.expand(B, -1, -1)          # (B, max_patches, d_pred)
        tgt_idx_expanded = tgt_idx.unsqueeze(-1).expand(-1, -1, d_pred)
        tgt_q = torch.gather(tgt_pe, 1, tgt_idx_expanded)  # (B, N_tgt, d_pred)
        # Concatenate and run transformer
        h = torch.cat([ctx_p, tgt_q], dim=1)               # (B, N_ctx+N_tgt, d_pred)
        h = self.tf(h)
        # Take only the target-position outputs
        N_ctx = ctx_p.size(1)
        h_tgt = h[:, N_ctx:]
        return self.norm(self.proj_out(h_tgt))


def make_block_mask(B: int, N: int, mask_ratio: float, rng: np.random.Generator,
                    device, min_visible: int = 4):
    """
    Sample a contiguous block mask per batch element.
    Returns:
        ctx_idx: list of LongTensors (variable length per sample)
        tgt_idx: list of LongTensors
    For batched processing we'll right-pad and provide separate masks.
    """
    ctx_idxs = []
    tgt_idxs = []
    for _ in range(B):
        n_mask = max(1, int(round(N * mask_ratio)))
        n_mask = min(n_mask, N - min_visible)
        # Random contiguous block start
        if N - n_mask <= 0:
            start = 0
            n_mask = N - min_visible
        else:
            start = int(rng.integers(0, N - n_mask + 1))
        all_idx = np.arange(N)
        tgt_mask = (all_idx >= start) & (all_idx < start + n_mask)
        ctx_idxs.append(torch.tensor(all_idx[~tgt_mask], dtype=torch.long, device=device))
        tgt_idxs.append(torch.tensor(all_idx[tgt_mask], dtype=torch.long, device=device))
    return ctx_idxs, tgt_idxs
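
# Worked example of the block mask (illustration only): with N=32 patches and
# mask_ratio=0.5, n_mask=16; a draw of start=10 masks patches 10..25, giving
# tgt_idx = [10, ..., 25] and ctx_idx = [0..9, 26..31] (16 visible patches).
# FlightJEPAPretrain.forward below applies the same scheme per sample,
# restricted to that sample's valid (non-padded) patches.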


def gather_by_indices(x: torch.Tensor, idx_list: list[torch.Tensor], pad_value=0.0):
    """x: (B, N, D). idx_list: per-batch index tensors.
    Returns (B, N_max, D) padded plus a (B, N_max) mask of which entries are real."""
    B = x.size(0)
    D = x.size(-1)
    N_max = max((idx.numel() for idx in idx_list), default=1)
    out = torch.full((B, N_max, D), pad_value, device=x.device, dtype=x.dtype)
    mask = torch.zeros((B, N_max), dtype=torch.bool, device=x.device)
    for b in range(B):
        idx = idx_list[b]
        n = idx.numel()
        if n > 0:
            out[b, :n] = x[b, idx]
            mask[b, :n] = True
    return out, mask


def gather_indices_only(idx_list: list[torch.Tensor], device):
    """Pack a list of LongTensors into (B, N_max) padded with zeros."""
    B = len(idx_list)
    N_max = max((idx.numel() for idx in idx_list), default=1)
    out = torch.zeros((B, N_max), dtype=torch.long, device=device)
    mask = torch.zeros((B, N_max), dtype=torch.bool, device=device)
    for b in range(B):
        idx = idx_list[b]
        n = idx.numel()
        if n > 0:
            out[b, :n] = idx
            mask[b, :n] = True
    return out, mask


# ============================================================================
# PRETRAIN MODULE
# ============================================================================

class FlightJEPAPretrain(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d = cfg.get("d_model", 256)
        h_ = cfg.get("n_heads", 8)
        n_l = cfg.get("n_layers", 4)
        d_ff = cfg.get("d_ff", 1024)
        dr = cfg.get("dropout", 0.1)
        ps = cfg.get("patch_size", 8)
        past_max = cfg.get("past_max", 256)
        max_patches = past_max // ps
        pred_dim = cfg.get("pred_dim", 128)
        self.ema_decay = cfg.get("ema_decay", 0.998)
        self.max_patches = max_patches

        self.tokenizer = PatchTokenizer(9, d, ps, max_patches)
        self.encoder = CausalEncoder(d, h_, n_l, d_ff, dr)
        self.predictor = JEPAPredictor(d, pred_dim, max_patches, dr)

        self.target_tokenizer = copy.deepcopy(self.tokenizer)
        self.target_encoder = copy.deepcopy(self.encoder)
        for p in self.target_tokenizer.parameters():
            p.requires_grad = False
        for p in self.target_encoder.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def update_ema(self):
        m = self.ema_decay
        for online, target in [(self.tokenizer, self.target_tokenizer),
                               (self.encoder, self.target_encoder)]:
            for po, pt in zip(online.parameters(), target.parameters()):
                pt.data.mul_(m).add_(po.data, alpha=1.0 - m)

    def forward(self, features, lengths, mask_ratio: float, rng: np.random.Generator):
        """
        Mask a contiguous block of patches per sample, encode the visible
        context, predict the masked-patch latents, and compare them to the EMA
        target encoder's latents at the masked positions (the target encoder
        runs over the full sequence).

        Pretraining uses *bidirectional* attention for the online context
        encoder (no causal mask) — at fine-tune time we restore the causal
        mask. This gives the encoder more signal during pretraining; the
        encoder's transformer layers are not architecturally causal, only the
        mask passed in changes the mode.
        """
        B = features.size(0)
        device = features.device

        # Tokenize for online (will be partially masked) and target (full).
        patches_full, patch_lens = self.tokenizer(features, lengths)
        N = patches_full.size(1)

        # Padding mask (token absent because past sequence shorter than N*patch).
        pad_mask = (torch.arange(N, device=device).unsqueeze(0)
                    >= patch_lens.unsqueeze(1))  # True where padded

        # Sample contiguous masks per sample (drawing only from valid patches).
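        # Illustration with the defaults: past_max=256 and patch_size=8 give up
        # to N=32 patches per window; mask_ratio=0.5 masks a contiguous block
        # of ~16 patches while keeping at least 4 visible as context.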
        ctx_idx_list, tgt_idx_list = [], []
        for b in range(B):
            n_valid = int(patch_lens[b].item())
            if n_valid < 8:
                # too short to mask meaningfully
                ctx_idx_list.append(torch.arange(n_valid, device=device))
                tgt_idx_list.append(torch.tensor([], dtype=torch.long, device=device))
                continue
            n_mask = max(2, int(round(n_valid * mask_ratio)))
            n_mask = min(n_mask, n_valid - 4)  # keep at least 4 visible
            start = int(rng.integers(0, n_valid - n_mask + 1))
            all_idx = torch.arange(n_valid, device=device)
            tgt_mask = (all_idx >= start) & (all_idx < start + n_mask)
            ctx_idx_list.append(all_idx[~tgt_mask])
            tgt_idx_list.append(all_idx[tgt_mask])

        # Skip the batch if no element produced targets (e.g., all sequences too short).
        n_targets = sum(int(t.numel()) for t in tgt_idx_list)
        if n_targets == 0:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Pack context indices and gather context tokens. We could pass the full
        # sequence through the encoder with a key_padding_mask that hides the
        # masked positions plus the original padding; it is easier to re-run the
        # encoder on a *new* tensor consisting only of context tokens with
        # bidirectional attention.
        N_ctx_max = max((idx.numel() for idx in ctx_idx_list), default=1)
        ctx_tokens = torch.zeros((B, N_ctx_max, patches_full.size(-1)),
                                 device=device, dtype=patches_full.dtype)
        ctx_kpm = torch.ones((B, N_ctx_max), dtype=torch.bool, device=device)  # True = pad
        for b in range(B):
            idx = ctx_idx_list[b]
            n = idx.numel()
            if n > 0:
                ctx_tokens[b, :n] = patches_full[b, idx]
                ctx_kpm[b, :n] = False

        # Bidirectional attention for pretraining (all-False mask = no restriction).
        bi_mask = torch.zeros((N_ctx_max, N_ctx_max), dtype=torch.bool, device=device)
        ctx_encoded = self.encoder(ctx_tokens, key_padding_mask=ctx_kpm, attn_mask=bi_mask)

        # Build target-index packed tensors.
        tgt_idx_packed, tgt_idx_mask = gather_indices_only(tgt_idx_list, device)

        # The predictor only needs target PEs — the context tokens already carry
        # positional info via the patch tokenizer's pos_enc, so we don't need to
        # pass context indices.
        # Predict target latents (already right-padded to N_tgt_max by
        # gather_indices_only).
        pred = self.predictor(ctx_encoded, ctx_idx=None, tgt_idx=tgt_idx_packed)

        # Targets: run the EMA target encoder on the *full* sequence
        # (causal mask, like fine-tune time) and gather at target positions.
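        # Gradients never reach the target branch: it runs under torch.no_grad()
        # and its weights move only through update_ema(), the JEPA-style
        # stop-gradient / EMA scheme used to avoid latent-target collapse.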
        with torch.no_grad():
            tgt_patches, _ = self.target_tokenizer(features, lengths)
            tgt_encoded = self.target_encoder(tgt_patches, key_padding_mask=pad_mask)
        tgt_latents, _ = gather_by_indices(tgt_encoded, tgt_idx_list)

        # L1 loss in latent space, masked over valid targets
        loss_per = F.l1_loss(pred, tgt_latents, reduction="none").mean(-1)  # (B, N_tgt_max)
        loss = (loss_per * tgt_idx_mask.float()).sum() / tgt_idx_mask.sum().clamp(min=1)
        return loss


# ============================================================================
# TRAIN LOOP
# ============================================================================

def device_pick(arg=None):
    if arg:
        return arg
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def train_one_epoch(model, loader, optimizer, device, mask_ratio_lo, mask_ratio_hi,
                    log_every=50, grad_clip=1.0, rng=None):
    model.train()
    sums = {"loss": 0.0, "n": 0}
    t0 = time.time()
    rng = rng or np.random.default_rng()
    n_batches = len(loader) if hasattr(loader, "__len__") else 0
    for bi, batch in enumerate(loader):
        feats = batch["features"].to(device)
        lens = batch["length"].to(device)
        mr = float(rng.uniform(mask_ratio_lo, mask_ratio_hi))
        loss = model(feats, lens, mask_ratio=mr, rng=rng)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        model.update_ema()
        bs = feats.size(0)
        sums["loss"] += loss.item() * bs
        sums["n"] += bs
        if (bi + 1) % log_every == 0 or bi == 0:
            dt = time.time() - t0
            print(f" [batch {bi+1}/{n_batches}] {dt:.1f}s elapsed, "
                  f"mr={mr:.2f}, loss={loss.item():.4f}", flush=True)
    n = max(sums["n"], 1)
    return {"loss": sums["loss"] / n}


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--airport", default="RKSIa")
    p.add_argument("--data-dir", default="data")
    p.add_argument("--tag", default="v7-pretrain")
    p.add_argument("--out-dir", default="runs")
    p.add_argument("--epochs", type=int, default=60)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--lr", type=float, default=1.5e-4)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--past-max", type=int, default=256)
    p.add_argument("--past-min", type=int, default=128)
    p.add_argument("--epoch-multiplier", type=int, default=2)
    p.add_argument("--ema-decay", type=float, default=0.998)
    p.add_argument("--d-model", type=int, default=256)
    p.add_argument("--n-layers", type=int, default=4)
    p.add_argument("--n-heads", type=int, default=8)
    p.add_argument("--patch-size", type=int, default=8)
    p.add_argument("--pred-dim", type=int, default=128)
    p.add_argument("--mask-ratio-lo", type=float, default=0.3)
    p.add_argument("--mask-ratio-hi", type=float, default=0.7)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--push-to-hub", action="store_true")
    p.add_argument("--hub-model-id", default=None)
    p.add_argument("--trackio-name", default=None)
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    rng = np.random.default_rng(args.seed)
    device = device_pick()
    print(f"[v7-pretrain] device={device} tag={args.tag}", flush=True)
    if device == "cuda":
        print(f"[v7-pretrain] cuda: {torch.cuda.get_device_name(0)} "
              f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB",
              flush=True)

    if HAS_TRACKIO and args.trackio_name:
        trackio.init(project="flight-jepa-v7-pretrain", name=args.trackio_name,
                     config=vars(args))

    train_ds = PastTrackDataset(
        airport=args.airport,
        mode="TRAIN",
        data_dir=args.data_dir,
        past_max=args.past_max,
        past_min=args.past_min,
        seed=args.seed,
        epoch_multiplier=args.epoch_multiplier,
    )
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.num_workers, pin_memory=True,
                          drop_last=True)

    cfg = {
        "d_model": args.d_model,
        "n_heads": args.n_heads,
        "n_layers": args.n_layers,
        "d_ff": args.d_model * 4,
        "dropout": 0.1,
        "patch_size": args.patch_size,
        "past_max": args.past_max,
        "ema_decay": args.ema_decay,
        "pred_dim": args.pred_dim,
    }
    model = FlightJEPAPretrain(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[v7-pretrain] params={n_params/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    os.makedirs(args.out_dir, exist_ok=True)
    history = []
    for epoch in range(args.epochs):
        t0 = time.time()
        stats = train_one_epoch(
            model, train_dl, optimizer, device,
            mask_ratio_lo=args.mask_ratio_lo,
            mask_ratio_hi=args.mask_ratio_hi,
            rng=rng,
        )
        scheduler.step()
        elapsed = time.time() - t0
        print(f"[v7-pretrain] ep {epoch+1:03d} loss={stats['loss']:.4f} | {elapsed:.0f}s",
              flush=True)
        history.append({"epoch": epoch + 1, "loss": stats["loss"], "elapsed_s": elapsed})
        if HAS_TRACKIO and args.trackio_name:
            trackio.log({"pretrain/loss": stats["loss"]}, step=epoch + 1)

    out_path = os.path.join(args.out_dir, f"{args.tag}.pt")
    torch.save({
        "encoder_state_dict": model.encoder.state_dict(),
        "tokenizer_state_dict": model.tokenizer.state_dict(),
        "config": cfg,
        "args": vars(args),
        "history": history,
    }, out_path)
    print(f"[v7-pretrain] saved {out_path}")

    if args.push_to_hub and args.hub_model_id:
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(args.hub_model_id, exist_ok=True)
            api.upload_file(path_or_fileobj=out_path,
                            path_in_repo=f"{args.tag}.pt",
                            repo_id=args.hub_model_id)
            print(f"[v7-pretrain] uploaded to {args.hub_model_id}")
        except Exception as e:
            print(f"[v7-pretrain] hub upload failed: {e}")

    if HAS_TRACKIO and args.trackio_name:
        trackio.finish()


if __name__ == "__main__":
    main()
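

# ============================================================================
# FINE-TUNE HAND-OFF (sketch)
# ============================================================================
# Sketch of how the v6 fine-tune script might consume this checkpoint. The
# `FlightV6` class name and its `tokenizer` / `encoder` attributes are
# assumptions for illustration only; the checkpoint keys and the default path
# (--out-dir "runs", --tag "v7-pretrain") are exactly what main() writes.
#
#   ckpt = torch.load("runs/v7-pretrain.pt", map_location="cpu")
#   v6 = FlightV6(ckpt["config"])                   # hypothetical v6 model
#   v6.tokenizer.load_state_dict(ckpt["tokenizer_state_dict"])
#   v6.encoder.load_state_dict(ckpt["encoder_state_dict"])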