---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---

# Flight-JEPA v8

A trajectory forecasting model for aircraft on terminal-area approach, specialized for the **blindspot continuation** task: given an observed past track, predict the trajectory through a coverage gap of variable length and the reappearance distribution.

The headline contribution is a **JEPA-style past-track masked pretraining recipe** that produces representations more robust to test-time radar coverage gaps, with the gain *generalizing across airports*. Pretrained-then-fine-tuned models maintain significantly lower forecasting error and better-calibrated uncertainty when up to 70% of past observations are missing — including on **completely held-out airports the fine-tuning never saw**.

## v8 — leave-one-airport-out (LOAO) headline

Across 4 LOAO folds (held out: RKSIa / RKSId / ESSA / LSZH, n=3 seeds each), pretraining beats training from scratch with statistical significance:

| Past-track dropout | Mean Δ FDE | p |
|---|---:|---:|
| 0% (clean — no regression) | +2.8% | 0.41 |
| 30% | −6.6% | 0.04 ✓ |
| **50%** | **−23.4%** | **<0.001 ✓** |
| **70%** | **−22.5%** | **<0.001 ✓** |

11 of 12 comparisons at ≥30% dropout reach p<0.05. All 4 LOAO folds pass the locked criterion independently.

![v8 summary](plots/summary_v8.png)
![v8 FDE per airport](plots/fde_per_airport.png)
![v8 coverage per airport](plots/coverage_per_airport.png)

The result generalizes across very different airports (Korea, Sweden, Switzerland — different runway geometries and procedures). The model is conditioned on an airport-ID token (UniTraj recipe, arxiv:2403.15098).

## v7 — single-airport reference

The v7 prerequisite (single-airport, RKSIa-only) showed the same effect on a within-airport split:

| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |

95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.

## Ablations

**Frozen-encoder probe** (pretrained encoder weights frozen; only the decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse than full fine-tune** (1752 m). The pretrained representation is informative but suboptimal as-is for forecasting; the encoder must adapt.

**Held-out-class generalization** (training set with classes 6, 18, 28 excluded; evaluated on those held-out classes):

| Eval | Scratch FDE (m) | Pretrained FDE (m) | Δ | p |
|---|---:|---:|---:|---:|
| In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
| Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |

Pretraining helps uniformly across both in-distribution and held-out classes — the **generalization gap** (FDE degradation on held-out classes vs in-dist) is statistically indistinguishable: 52.7% (scratch) vs 51.2% (pretrained), p=0.57. So pretraining does not *specifically* improve transfer to unseen procedures; it provides a uniform ~7-8% boost everywhere. NLL is significantly better for pretrained (p=0.012 in-dist, p=0.067 generalization).

## Recipe

The architecture and training recipe, in three layers:

1. **Encoder/decoder** (no SSL; the supervised baseline is already strong)
   - Causal transformer encoder over 8 s patches of past 9-dim features `[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
   - HiVT-style parallel decoder (arxiv:2207.09588): single context vector + per-step learnable positional embedding → MLP emits `[T, 7]` = (μ, log σ, ρ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168) — prevents the σ-collapse pathology of plain Gaussian NLL (see the sketch after this list).
   - Decoder is non-autoregressive: ~15× tighter seed variance than teacher-forced GRU rollouts.
2. **JEPA pretraining stage** (the v7 contribution, also sketched after this list)
   - Mask contiguous blocks of past-track patches at ratios sampled from U[0.3, 0.7] per example.
   - Encoder + EMA target encoder + small predictor MLP.
   - L1 loss in latent space on masked-patch latents.
   - Trained 60 epochs on the same RKSIa training set (no labels used).
3. **Fine-tuning + post-hoc calibration**
   - Load the pretrained encoder + tokenizer into the v6 architecture.
   - Train the forecaster end-to-end with β-NLL.
   - Optional post-hoc isotonic calibration (Kuleshov 2018) for joint coverage approaching nominal levels.
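The β-NLL objective in layer 1 is standard enough to sketch. The snippet below is a minimal diagonal-Gaussian illustration, not this repo's loss code: the actual head also predicts a per-step x/y correlation ρ, which a diagonal version ignores, and `beta_nll_loss` is a hypothetical name.

```python
import torch

def beta_nll_loss(mu, log_sigma, target, beta=0.5):
    """Diagonal-Gaussian beta-NLL (Seitzer et al. 2022, arxiv:2203.09168).

    Each dimension's Gaussian NLL is weighted by a *detached* sigma^(2*beta),
    so poorly fit (high-variance) targets keep a useful gradient on mu instead
    of being explained away by inflating sigma; that is what prevents the
    sigma-collapse of plain Gaussian NLL. beta=0 recovers plain NLL, beta=1
    gives an MSE-like weighting; 0.5 is the value recommended in the paper.
    """
    sigma_sq = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(sigma_sq) + (target - mu) ** 2 / sigma_sq)
    weight = sigma_sq.detach() ** beta
    return (weight * nll).mean()
```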
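Layer 2, the pretraining stage, is the piece most readers will want to picture concretely. The sketch below is an assumption-level rendering of the bullets above: it masks a single contiguous block per example, uses zero-filling as the masking mechanism, and invents every function and tensor name. The released `pretrain_v7.py` is the actual implementation.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_enc, online_enc, tau=0.996):
    # Target encoder trails the online encoder as an exponential moving average.
    for p_t, p_o in zip(target_enc.parameters(), online_enc.parameters()):
        p_t.lerp_(p_o, 1.0 - tau)

def jepa_step(online_enc, target_enc, predictor, patches):
    """One block-masked latent-prediction step on (B, N, D) past-track patches."""
    B, N, _ = patches.shape
    ratio = torch.empty(B, 1).uniform_(0.3, 0.7)               # mask ratio per example
    length = (ratio * N).long().clamp(min=1)                   # block length per example
    start = (torch.rand(B, 1) * (N - length).float()).long()   # block start per example
    idx = torch.arange(N).expand(B, N)
    masked = (idx >= start) & (idx < start + length)           # (B, N) boolean mask

    # Context pass: the online encoder only sees the unmasked patches (zero-filled here).
    context = online_enc(patches.masked_fill(masked.unsqueeze(-1), 0.0))
    # Target pass: the EMA encoder sees the full, unmasked track; no gradients flow.
    with torch.no_grad():
        target = target_enc(patches)

    # Predict the target latents at the masked positions; L1 loss in latent space.
    pred = predictor(context)
    return F.l1_loss(pred[masked], target[masked])

# Typical wiring (illustrative): target_enc = copy.deepcopy(online_enc); after each
# optimizer step on jepa_step's loss, call ema_update(target_enc, online_enc).
```

At fine-tuning time only the encoder (and its tokenizer) is carried forward, consistent with layer 3 above; the target encoder and predictor are discarded.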
## Files

| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |

## Quick start — inference

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)

from train_v2_prod import FlightJEPAv2  # from the project source
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # inference mode

# past_features: (B=1, T_past=256, 9) — the 9-dim ADS-B-derived features.
# delta: (B,) — blindspot length in seconds, 30..120.
# last_pos: (B, 3) — position at the moment of going dark (normalized ENU).
mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
                                   delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma:  (B, delta_max, 3) predicted std
# rho:    (B, delta_max)    correlation between x and y per step
```
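A quick way to poke at the robustness claims with the loaded model is to degrade the past track yourself and re-run the rollout. The continuation below is only a sketch: the contiguous zero-fill dropout, the hypothetical `true_future` tensor, and treating the last predicted step as the reappearance point are all assumptions standing in for the project's actual evaluation scripts.

```python
def drop_contiguous(past, ratio=0.5):
    """Zero out one contiguous block covering `ratio` of the past track (B, T, 9)."""
    out = past.clone()
    T = out.shape[1]
    gap = max(1, int(ratio * T))
    start = torch.randint(0, T - gap + 1, (1,)).item()
    out[:, start:start + gap, :] = 0.0
    return out

degraded = drop_contiguous(past_features, ratio=0.5)
mu_deg, sigma_deg, rho_deg = model.rollout(degraded, past_length, delta, last_pos,
                                           delta_max=int(delta.max()))

# FDE against a ground-truth future track `true_future` (B, delta_max, 3):
# distance at the final (reappearance) step, in normalized ENU units;
# de-normalize before comparing with the metre-valued tables above.
fde = torch.linalg.vector_norm(mu_deg[:, -1, :] - true_future[:, -1, :], dim=-1)
print(f"FDE at 50% simulated dropout: {fde.mean().item():.3f}")
```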
## Limitations

These were measured carefully and are part of the honest story:

- **Single-airport checkpoints.** The released checkpoints are trained and evaluated on Incheon (RKSIa) arrivals; cross-airport forecasting transfer is assessed only through the v8 LOAO evaluation above (v1 classification probes showed cross-airport "transfer" was a dataset-difficulty artifact).
- **Shuffled train/test split.** ATFMTraj's RKSIa split is a random shuffled split, not held out by route or by time, so the same flight routes appear in both partitions. A held-out-by-time split is impossible without re-querying OpenSky for the original timestamps; we add the held-out-by-class generalization eval as a partial substitute.
- **The supervised baseline is already strong.** The v6 baseline (no pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain on clean inputs is modest (~7%); the operationally meaningful gain is under degradation (12-17%).
- **Calibration is partial without isotonic.** β-NLL training brings raw 95% coverage to ~97% in-distribution, but coverage degrades rapidly with input dropout. Post-hoc isotonic recovers most of the loss; the isotonic mapper from v3 (a different architecture) gave 92.9% joint coverage. v7 isotonic mappers should be re-fit per checkpoint.
- **Gain shrinks at extreme dropout.** Beyond a 50% past-track gap, both models fail catastrophically (FDE > 6 km). The pretraining gain is most valuable in the 10-50% dropout regime, exactly the range that matters operationally.

## Compute spent

~$22 of T4 cloud compute across 7 iterations of the project (v1 through v7), including pretraining, baseline runs, and ablations.

## Citation

If you use this checkpoint, cite the project artifacts:

- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168

## License

Apache-2.0.