| --- |
| license: apache-2.0 |
| tags: |
| - trajectory-forecasting |
| - aviation |
| - ads-b |
| - self-supervised |
| - jepa |
| - pytorch |
| datasets: |
| - petchthwr/ATFMTraj |
| language: |
| - en |
| library_name: pytorch |
| --- |
| |
| # Flight-JEPA v8 |
|
|
| A trajectory forecasting model for aircraft on terminal-area approach, |
| specialized for the **blindspot continuation** task: given an observed |
| past track, predict the trajectory through a coverage gap of variable |
| length and the reappearance distribution. |
|
|
| The headline contribution is a **JEPA-style past-track masked |
| pretraining recipe** that produces representations more robust to test- |
| time radar coverage gaps, with the gain *generalizing across airports*. |
| Pretrained-then-fine-tuned models maintain significantly lower |
| forecasting error and better-calibrated uncertainty when up to 70% of |
| past observations are missing — including on **completely held-out |
| airports the fine-tuning never saw**. |
|
|
| ## v8 — leave-one-airport-out (LOAO) headline |
|
|
Across 4 LOAO folds (held out: RKSIa / RKSId / ESSA / LSZH, n=3 seeds
each), the pretrained model beats the scratch baseline significantly:
|
|
| Past-track dropout | Mean Δ FDE (pretrained vs scratch) | p |
|---|---:|---:|
| 0% (clean; no significant regression) | +2.8% | 0.41 |
| | 30% | −6.6% | 0.04 ✓ | |
| | **50%** | **−23.4%** | **<0.001 ✓** | |
| | **70%** | **−22.5%** | **<0.001 ✓** | |
|
|
| 11 of 12 comparisons at ≥30% dropout reach p<0.05. All 4 LOAO folds |
| pass the locked criterion independently. |
|
|
|  |
|  |
|  |
|
|
The result generalizes across very different airports (Korea, Sweden,
Switzerland, with different runway geometries and procedures). The model is
conditioned on an airport-ID token (UniTraj recipe, arxiv:2403.15098).
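
As a rough illustration of this UniTraj-style conditioning (the module and
names below are illustrative, not the repo's actual code), a learned
per-airport embedding can be prepended as an extra token to the encoder's
patch sequence:

```python
import torch
import torch.nn as nn

class AirportConditioner(nn.Module):
    """Illustrative UniTraj-style conditioning: prepend a learned
    airport-ID token to the past-track patch embeddings."""

    def __init__(self, num_airports: int, d_model: int):
        super().__init__()
        self.airport_emb = nn.Embedding(num_airports, d_model)

    def forward(self, patch_tokens: torch.Tensor, airport_id: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (B, N, d_model); airport_id: (B,) integer airport codes
        tok = self.airport_emb(airport_id).unsqueeze(1)  # (B, 1, d_model)
        return torch.cat([tok, patch_tokens], dim=1)     # (B, N+1, d_model)
```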
|
|
| ## v7 — single-airport reference |
|
|
The v7 predecessor experiment (single-airport, RKSIa-only) showed the same
effect on a within-airport split:
|
|
| | Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) | |
| |---|---:|---:|---:|---:| |
| | 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** | |
| | 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** | |
| | 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** | |
| | 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** | |
| | 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 | |
|
|
| 95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026. |
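
Coverage here is the fraction of ground truth that falls inside the model's
95% intervals; the exact definition (joint vs per-step) is fixed in the
project's evaluation code. For orientation only, a minimal per-step,
per-axis version under a Gaussian assumption:

```python
import torch

def coverage_95(mu: torch.Tensor, sigma: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of steps whose ground truth lies inside the per-axis 95% interval.

    mu, sigma, target: (B, T, 3). This is a per-step marginal coverage;
    the numbers in this card may use a joint definition instead.
    """
    z = 1.96  # two-sided 95% quantile of the standard normal
    inside = (target - mu).abs() <= z * sigma        # (B, T, 3) bool
    return inside.all(dim=-1).float().mean().item()  # all three axes inside
```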
|
|
| ## Ablations |
|
|
| **Frozen-encoder probe** (pretrained encoder weights frozen; only the |
| decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse |
than the full fine-tune** (1752 m). The pretrained representation is
informative but suboptimal as-is for forecasting; the encoder must adapt
during fine-tuning.
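
In code, the probe amounts to freezing the encoder parameters and building
the optimizer over what remains. A minimal sketch, assuming the fine-tuning
model exposes its encoder as an `encoder` attribute (the attribute name and
learning rate are assumptions, not the project's actual config):

```python
import torch
from torch import nn

def build_frozen_probe_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    """Freeze `model.encoder` and optimize only the remaining (decoder) parameters."""
    for p in model.encoder.parameters():
        p.requires_grad = False
    model.encoder.eval()  # keep any normalization/dropout layers fixed while frozen
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)
```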
|
|
**Held-out-class generalization** (models trained with classes 6, 18, 28
excluded from the training set, then evaluated both in-distribution and on
the held-out classes):
|
|
| Eval | Scratch FDE (m) | Pretrained FDE (m) | Δ | p |
| |---|---:|---:|---:|---:| |
| | In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 | |
| | Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 | |
|
|
Pretraining helps uniformly across both in-distribution and held-out
classes. The **generalization gap** (FDE degradation on held-out classes vs
in-distribution) is statistically indistinguishable between the two: 52.7%
(scratch) vs 51.2% (pretrained), p=0.57. So pretraining does not
*specifically* improve transfer to unseen procedures; it provides a roughly
uniform ~7-8% boost everywhere. NLL is significantly better for the
pretrained model in-distribution (p=0.012) and marginally better on the
held-out classes (p=0.067).
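
The p-values in these tables compare a small number of seeds per condition;
the v7 table labels its test as Welch's. For reference, that comparison is
a one-liner with SciPy (the per-seed FDE arrays below are placeholders for
illustration, not the actual experiment logs):

```python
from scipy.stats import ttest_ind

# Placeholder per-seed FDE values (metres), for illustration only.
scratch_fde    = [3300.0, 3150.0, 3255.0]
pretrained_fde = [2900.0, 3050.0, 2948.0]

# equal_var=False selects Welch's t-test (no equal-variance assumption).
t_stat, p_value = ttest_ind(pretrained_fde, scratch_fde, equal_var=False)
print(f"Welch t = {t_stat:.2f}, p = {p_value:.3f}")
```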
|
|
| ## Recipe |
|
|
| The architecture and training recipe in three layers: |
|
|
| 1. **Encoder/decoder** (no SSL, supervised baseline already strong) |
| - Causal transformer encoder over 8 s patches of past 9-dim features |
     `[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
   - HiVT-style parallel decoder (arxiv:2207.09588): single context vector
     + per-step learnable positional embedding → MLP emits `[T, 7]` =
     (μ ∈ ℝ³, log σ ∈ ℝ³, ρ ∈ ℝ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168) — prevents the
     σ-collapse pathology of plain Gaussian NLL. A minimal loss sketch
     follows this list.
| - Decoder is non-autoregressive: ~15× tighter seed variance than |
| teacher-forced GRU rollouts. |
| |
2. **JEPA pretraining stage** (the v7 contribution; a pretraining-step sketch follows this list)
| - Mask contiguous blocks of past-track patches at ratios sampled from |
| U[0.3, 0.7] per example. |
| - Encoder + EMA target encoder + small predictor MLP. |
| - L1 loss in latent space on masked-patch latents. |
| - Trained 60 epochs on the same RKSIa training set (no labels used). |
| |
| 3. **Fine-tuning + post-hoc calibration** |
| - Load pretrained encoder + tokenizer into the v6 architecture. |
| - Train forecaster end-to-end with β-NLL. |
   - Optional post-hoc isotonic calibration (Kuleshov 2018) to bring joint
     coverage close to nominal levels.
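
For reference, a minimal sketch of the β-NLL objective from stage 1
(Seitzer 2022, arxiv:2203.09168): the per-step Gaussian NLL is reweighted by
a stop-gradient factor σ^(2β), which keeps the variance head from
collapsing. The diagonal per-axis form is shown for simplicity and is an
illustration, not the project's exact loss code (the actual head also emits
a ρ term):

```python
import torch

def beta_nll_loss(mu: torch.Tensor, log_sigma: torch.Tensor,
                  target: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """β-NLL (Seitzer et al., 2022) for per-axis Gaussian predictions.

    mu, log_sigma, target: (B, T, 3). beta=0 recovers plain Gaussian NLL;
    beta=1 weights every point roughly like an MSE.
    """
    var = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(var) + (target - mu) ** 2 / var)  # up to a constant
    weight = var.detach() ** beta                            # stop-gradient reweighting
    return (weight * nll).mean()
```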
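
And a sketch of one pretraining step from stage 2: contiguous-block masking,
an EMA target encoder, a small predictor, and an L1 loss on the masked-patch
latents. Shapes, module interfaces, and the zero-fill masking are
illustrative; `pretrain_v7.py` is the authoritative implementation (it may,
for example, drop masked patches from the context encoder rather than
zero-filling them):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_encoder, encoder, decay: float = 0.996):
    """Exponential-moving-average update of the target encoder."""
    for pt, po in zip(target_encoder.parameters(), encoder.parameters()):
        pt.mul_(decay).add_(po, alpha=1.0 - decay)

def contiguous_block_mask(batch: int, n_patches: int, device) -> torch.Tensor:
    """(B, N) boolean mask with one contiguous masked block per example,
    mask ratio drawn from U[0.3, 0.7]."""
    mask = torch.zeros(batch, n_patches, dtype=torch.bool, device=device)
    for b in range(batch):
        ratio = torch.empty(1).uniform_(0.3, 0.7).item()
        length = max(1, int(round(n_patches * ratio)))
        start = torch.randint(0, n_patches - length + 1, (1,)).item()
        mask[b, start:start + length] = True
    return mask

def jepa_step(encoder, target_encoder, predictor, patches: torch.Tensor) -> torch.Tensor:
    """One JEPA loss evaluation. patches: (B, N, D_in) past-track patch features."""
    B, N, _ = patches.shape
    mask = contiguous_block_mask(B, N, patches.device)

    with torch.no_grad():                      # targets come from the EMA encoder
        target_lat = target_encoder(patches)   # (B, N, d_model)

    visible = patches.masked_fill(mask.unsqueeze(-1), 0.0)
    pred = predictor(encoder(visible))         # (B, N, d_model)

    # L1 loss in latent space, on masked patches only.
    return F.l1_loss(pred[mask], target_lat[mask])

# Per batch: loss = jepa_step(...); loss.backward(); optimizer.step();
# then ema_update(target_encoder, encoder).
```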
| |
| ## Files |
|
|
| | File | Description | |
| |------|-------------| |
| | `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) | |
| | `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) | |
| | `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) | |
| | `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) | |
| | `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support | |
| | `pretrain_v7.py` | Self-contained pretraining script | |
|
|
| ## Quick start — inference |
|
|
| ```python |
| import torch |
| from huggingface_hub import hf_hub_download |
| |
| ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt") |
| state = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
| |
| from train_v2_prod import FlightJEPAv2 # from the project source |
| model = FlightJEPAv2(state["config"]) |
| model.load_state_dict(state["state_dict"]) |
model.eval() # inference mode
| |
# past_features: (B, T_past=256, 9) — 9-dim ADS-B-derived features (see Recipe).
# past_length: (B,) — valid past-track length per example (exact semantics in the project source).
| # delta: (B,) — blindspot length in seconds, 30..120. |
| # last_pos: (B, 3) — position at the moment of going dark (normalized ENU). |
| mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos, |
| delta_max=int(delta.max())) |
| # mu_pos: (B, delta_max, 3) predicted positions |
| # sigma: (B, delta_max, 3) predicted std |
| # rho: (B, delta_max) correlation between x and y per step |
| ``` |
|
|
| ## Limitations |
|
|
| These were measured carefully and are part of the honest story: |
|
|
- **Single airport (shipped checkpoints).** The v6/v7 checkpoints in this
  repo were trained and evaluated on Incheon (RKSI) arrivals only;
  cross-airport transfer is addressed by the v8 LOAO results above, but those
  checkpoints are not listed in the Files table. (v1 classification probes
  showed cross-airport "transfer" was a dataset-difficulty artifact.)
| - **Shuffled train/test split.** ATFMTraj's RKSIa split is a random |
| shuffled split, not held-out by route or by time. Same flight routes |
| appear in both partitions. Held-out-by-time is impossible without |
| re-querying OpenSky for the original timestamps. We add a |
| held-out-by-class generalization eval as a partial substitute. |
| - **The supervised baseline is already strong.** v6 baseline (no |
| pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain |
| on clean inputs is modest (~7%); the operationally meaningful gain |
| is at degradation (12-17%). |
| - **Calibration is partial without isotonic.** β-NLL training brings |
| raw 95% coverage to ~97% in-distribution, but coverage degrades |
| rapidly with input dropout. Post-hoc isotonic recovers most of the |
| loss; the isotonic mapper from v3 (different architecture) gave 92.9% |
| joint coverage. v7 isotonic mappers should be re-fit per checkpoint. |
| - **Gain shrinks at extreme dropout.** Beyond 50% past-track gap, both |
| models fail catastrophically (FDE > 6 km). The pretraining gain is |
  most valuable in the 10-50% dropout regime — exactly the range that
  matters operationally.
|
|
| ## Compute spent |
|
|
| ~$22 of T4 cloud compute across 7 iterations of the project (v1 through |
| v7), including pretraining, baseline runs, and ablations. |
|
|
| ## Citation |
|
|
| If you use this checkpoint, cite the project artifacts: |
|
|
- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168
- UniTraj airport-ID conditioning: arxiv:2403.15098
|
|
| ## License |
|
|
| Apache-2.0. |
|
|