---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---
# Flight-JEPA v8
A trajectory forecasting model for aircraft on terminal-area approach,
specialized for the **blindspot continuation** task: given an observed
past track, predict the trajectory through a coverage gap of variable
length and the distribution over where the aircraft reappears.
The headline contribution is a **JEPA-style past-track masked
pretraining recipe** that produces representations more robust to test-
time radar coverage gaps, with the gain *generalizing across airports*.
Pretrained-then-fine-tuned models maintain significantly lower
forecasting error and better-calibrated uncertainty when up to 70% of
past observations are missing — including on **completely held-out
airports the fine-tuning never saw**.
## v8 — leave-one-airport-out (LOAO) headline
Across 4 LOAO folds (held-out airports: RKSIa / RKSId / ESSA / LSZH;
3 seeds each), pretraining outperforms training from scratch at every
non-zero dropout level:
| Past-track dropout | Mean Δ FDE (pretrained vs scratch; negative = pretrained better) | p |
|---|---:|---:|
| 0% (clean — no regression) | +2.8% | 0.41 |
| 30% | −6.6% | 0.04 ✓ |
| **50%** | **−23.4%** | **<0.001 ✓** |
| **70%** | **−22.5%** | **<0.001 ✓** |
11 of 12 comparisons at ≥30% dropout reach p<0.05. All 4 LOAO folds
pass the locked criterion independently.
![v8 summary](plots/summary_v8.png)
![v8 FDE per airport](plots/fde_per_airport.png)
![v8 coverage per airport](plots/coverage_per_airport.png)
The result generalizes across very different airports (Korea, Sweden,
Switzerland), each with different runway geometries and procedures. The
model conditions on an airport-ID token (UniTraj recipe, arxiv:2403.15098).
## v7 — single-airport reference
The v7 reference run (single-airport, RKSIa-only) showed the same effect
on a within-airport split:
| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |
95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.
## Ablations
**Frozen-encoder probe** (pretrained encoder weights frozen; only the
decoder trains, 3 seeds): FDE 2541 m at 0% dropout, **clearly worse
than full fine-tune** (1752 m). The pretrained representation is
informative but suboptimal as-is for forecasting; the encoder must adapt
during fine-tuning (a minimal freezing sketch follows).
**Held-out-class generalization** (training set with classes
6, 18, 28 excluded; evaluated on those held-out classes):
| Eval | Scratch FDE | Pretrained FDE | Δ | p |
|---|---:|---:|---:|---:|
| In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
| Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |
Pretraining helps about equally across both in-distribution and held-out
classes; the **generalization gap** (FDE degradation on held-out classes
vs in-distribution) is statistically indistinguishable: 52.7% (scratch)
vs 51.2% (pretrained), p=0.57. So pretraining does not *specifically*
improve transfer to unseen procedures; it provides a roughly uniform
7-8% boost everywhere. NLL is significantly better for pretrained
in-distribution (p=0.012) and trends better on the held-out classes
(p=0.067).
## Recipe
The architecture and training recipe, in three parts:
1. **Encoder/decoder** (no SSL, supervised baseline already strong)
   - Causal transformer encoder over 8 s patches of past 9-dim features
     `[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
- HiVT-style parallel decoder (arxiv:2207.09588): single context vector
+ per-step learnable positional embedding → MLP emits `[T, 7]` =
(μ, log σ, ρ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168), which prevents
     the σ-collapse pathology of plain Gaussian NLL (a minimal sketch
     follows this list).
- Decoder is non-autoregressive: ~15× tighter seed variance than
teacher-forced GRU rollouts.
2. **JEPA pretraining stage** (the v7 contribution; a training-step sketch follows this list)
- Mask contiguous blocks of past-track patches at ratios sampled from
U[0.3, 0.7] per example.
- Encoder + EMA target encoder + small predictor MLP.
- L1 loss in latent space on masked-patch latents.
- Trained 60 epochs on the same RKSIa training set (no labels used).
3. **Fine-tuning + post-hoc calibration**
- Load pretrained encoder + tokenizer into the v6 architecture.
- Train forecaster end-to-end with β-NLL.
   - Optional post-hoc isotonic calibration (Kuleshov 2018) brings joint
     coverage close to nominal levels (a recalibration sketch follows this
     list).
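For step 1, a minimal sketch of the β-NLL objective (Seitzer 2022),
assuming per-axis independent Gaussians; the repo's exact loss also
predicts the x/y correlation ρ and may differ in detail:

```python
import torch

def beta_nll_loss(mu, log_sigma, target, beta=0.5):
    """Gaussian NLL weighted by a stop-gradient variance factor sigma^(2*beta).

    beta=0 recovers plain NLL (prone to sigma collapse);
    beta=1 gives MSE-like gradients for the mean.
    """
    sigma2 = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(sigma2) + (target - mu) ** 2 / sigma2)
    weight = sigma2.detach() ** beta  # the stop-gradient weight is what prevents sigma collapse
    return (weight * nll).mean()
```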
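For step 2, a minimal sketch of one pretraining step under the masking
recipe above; the module names (`encoder`, `target_encoder`, `predictor`),
the zero-fill masking, and the hyperparameters are illustrative
assumptions, not the repo's implementation:

```python
import torch

@torch.no_grad()
def ema_update(target_encoder, encoder, decay=0.996):
    # Target encoder tracks an exponential moving average of the online encoder.
    for p_t, p_o in zip(target_encoder.parameters(), encoder.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1.0 - decay)

def jepa_step(encoder, target_encoder, predictor, patches, optimizer):
    # patches: (B, N, D) patch embeddings of the past track.
    B, N, _ = patches.shape
    ratio = torch.empty(B).uniform_(0.3, 0.7)          # mask ratio ~ U[0.3, 0.7]
    mask = torch.zeros(B, N, dtype=torch.bool)
    for b in range(B):                                  # one contiguous block per example
        length = int(ratio[b] * N)
        start = int(torch.randint(0, N - length + 1, (1,)))
        mask[b, start:start + length] = True

    with torch.no_grad():                               # targets come from the EMA encoder
        target_latents = target_encoder(patches)       # (B, N, H)
    context = encoder(patches.masked_fill(mask.unsqueeze(-1), 0.0))
    pred = predictor(context)                           # (B, N, H)

    loss = (pred[mask] - target_latents[mask]).abs().mean()  # L1 on masked-patch latents
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    ema_update(target_encoder, encoder)
    return loss.item()
```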
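For step 3, a minimal sketch of post-hoc recalibration in the spirit of
Kuleshov et al. 2018, assuming per-axis Gaussian marginals and a held-out
calibration split; the repo's isotonic mapper may be fit differently:

```python
import numpy as np
from scipy.stats import norm
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(mu, sigma, y):
    # mu, sigma, y: flat arrays of predicted means/stds and observed values
    # on a calibration split.
    pit = norm.cdf(y, loc=mu, scale=sigma)   # predicted CDF evaluated at each observation
    levels = np.sort(pit)
    empirical = np.arange(1, len(levels) + 1) / len(levels)
    iso = IsotonicRegression(out_of_bounds="clip")
    iso.fit(levels, empirical)               # maps predicted coverage level -> observed frequency
    return iso
```

At inference time, a nominal level q is mapped back through the fitted
curve (pick the raw quantile level whose observed coverage equals q)
before converting to a Gaussian quantile.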
## Files
| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |
## Quick start — inference
```python
import torch
from huggingface_hub import hf_hub_download
from train_v2_prod import FlightJEPAv2  # from the project source

ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # inference mode

# Placeholder inputs; replace with real ADS-B-derived features.
past_features = torch.zeros(1, 256, 9)  # (B, T_past, 9) past-track features
past_length = torch.tensor([256])       # (B,) valid past-track length
delta = torch.tensor([90.0])            # (B,) blindspot length in seconds, 30..120
last_pos = torch.zeros(1, 3)            # (B, 3) position at the moment of going dark (normalized ENU)

mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
                                   delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma:  (B, delta_max, 3) predicted std
# rho:    (B, delta_max)    correlation between x and y per step
```
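The rollout outputs parameterize a per-step Gaussian over horizontal
position. A small follow-up sketch (not part of the repo API) assembles
the 2×2 covariance from `sigma` and `rho` and samples candidate
continuations:

```python
# Assumes rho couples only the x and y axes, per the comments above.
sx, sy = sigma[..., 0], sigma[..., 1]                  # (B, delta_max)
cov_xy = rho * sx * sy
cov = torch.stack([
    torch.stack([sx ** 2, cov_xy], dim=-1),
    torch.stack([cov_xy, sy ** 2], dim=-1),
], dim=-2)                                             # (B, delta_max, 2, 2)
dist = torch.distributions.MultivariateNormal(mu_pos[..., :2], covariance_matrix=cov)
samples = dist.sample((100,))                          # (100, B, delta_max, 2) candidate tracks
```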
## Limitations
These were measured carefully and are part of the honest story:
- **Single-airport scope of v7 and the ablations.** The v7 reference and
  all ablations are trained and evaluated on Incheon (RKSI) arrivals only;
  cross-airport transfer for forecasting is validated only by the v8 LOAO
  folds above (v1 classification probes had shown cross-airport "transfer"
  was a dataset-difficulty artifact).
- **Shuffled train/test split.** ATFMTraj's RKSIa split is a random
shuffled split, not held-out by route or by time. Same flight routes
appear in both partitions. Held-out-by-time is impossible without
re-querying OpenSky for the original timestamps. We add a
held-out-by-class generalization eval as a partial substitute.
- **The supervised baseline is already strong.** v6 baseline (no
pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain
on clean inputs is modest (~7%); the operationally meaningful gain
is at degradation (12-17%).
- **Calibration is partial without isotonic.** β-NLL training brings
raw 95% coverage to ~97% in-distribution, but coverage degrades
rapidly with input dropout. Post-hoc isotonic recovers most of the
loss; the isotonic mapper from v3 (different architecture) gave 92.9%
joint coverage. v7 isotonic mappers should be re-fit per checkpoint.
- **Gain shrinks at extreme dropout in the v7 runs.** Beyond a 50%
  past-track gap, both v7 models fail catastrophically (FDE > 6 km at 70%
  dropout). The pretraining gain is most valuable in the 10-50% dropout
  regime, exactly the range that matters operationally.
## Compute spent
~$22 of T4 cloud compute across 7 iterations of the project (v1 through
v7), including pretraining, baseline runs, and ablations.
## Citation
If you use this checkpoint, cite the project artifacts:
- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168
- UniTraj airport-ID conditioning: arxiv:2403.15098
## License
Apache-2.0.