---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---
# Flight-JEPA v8
A trajectory forecasting model for aircraft on terminal-area approach,
specialized for the **blindspot continuation** task: given an observed
past track, predict the trajectory through a coverage gap of variable
length and the reappearance distribution.
The headline contribution is a **JEPA-style past-track masked
pretraining recipe** that produces representations more robust to test-
time radar coverage gaps, with the gain *generalizing across airports*.
Pretrained-then-fine-tuned models maintain significantly lower
forecasting error and better-calibrated uncertainty when up to 70% of
past observations are missing — including on **completely held-out
airports the fine-tuning never saw**.
## v8 — leave-one-airport-out (LOAO) headline
Across 4 LOAO folds (held out: RKSIa / RKSId / ESSA / LSZH, n=3 seeds
each), the pretrained models beat the matched scratch baselines, with
the gap statistically significant under past-track dropout:
| Past-track dropout | Mean Δ FDE | p |
|---|---:|---:|
| 0% (clean — no significant regression) | +2.8% | 0.41 |
| 30% | −6.6% | 0.04 ✓ |
| **50%** | **−23.4%** | **<0.001 ✓** |
| **70%** | **−22.5%** | **<0.001 ✓** |
11 of 12 comparisons at ≥30% dropout reach p<0.05, and all 4 LOAO folds
independently pass the success criterion locked before the runs.



The result generalizes across very different airports (Korea, Sweden,
Switzerland — different runway geometries and procedures). The model is
conditioned on an airport-ID token (UniTraj recipe, arxiv:2403.15098);
a minimal sketch of that conditioning follows.
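The sketch below assumes a plain learned-embedding lookup prepended to
the encoder's patch-token sequence; the names (`airport_embed`,
`NUM_AIRPORTS`, `d_model`) are illustrative and not taken from the
released code:
```python
import torch
import torch.nn as nn

# Hypothetical airport-ID conditioning: one learned embedding per airport
# is prepended to the patch-token sequence so every encoder layer can
# attend to it. NUM_AIRPORTS and d_model are placeholder values.
NUM_AIRPORTS, d_model = 4, 256
airport_embed = nn.Embedding(NUM_AIRPORTS, d_model)

def prepend_airport_token(patch_tokens: torch.Tensor,
                          airport_id: torch.Tensor) -> torch.Tensor:
    """patch_tokens: (B, N, d_model); airport_id: (B,) int64."""
    tok = airport_embed(airport_id).unsqueeze(1)   # (B, 1, d_model)
    return torch.cat([tok, patch_tokens], dim=1)   # (B, N + 1, d_model)
```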
## v7 — single-airport reference
The v7 prerequisite experiment (single airport, RKSIa only) showed the
same effect on a within-airport split:
| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |
95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.
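The v7 table reports Welch's t-test p-values computed over seeds; a
minimal sketch of such a comparison, with made-up per-seed FDE values
rather than the reported ones:
```python
import numpy as np
from scipy.stats import ttest_ind

# Illustrative per-seed FDE values in metres (NOT the reported numbers).
scratch_fde    = np.array([3600.0, 3790.0, 3960.0])
pretrained_fde = np.array([3290.0, 3310.0, 3365.0])

# Welch's t-test: two-sample t-test without assuming equal variances.
t_stat, p_value = ttest_ind(pretrained_fde, scratch_fde, equal_var=False)
print(f"t = {t_stat:.2f}, p = {p_value:.3f}")
```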
## Ablations
**Frozen-encoder probe** (pretrained encoder weights frozen; only the
decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse
than the full fine-tune** (1752 m). The pretrained representation is
informative but suboptimal as-is for forecasting; the encoder must
adapt during fine-tuning.
**Held-out-class generalization** (training set with classes
6, 18, 28 excluded; evaluated on those held-out classes):
| Eval split | Scratch FDE (m) | Pretrained FDE (m) | Δ | p |
|---|---:|---:|---:|---:|
| In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
| Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |
Pretraining helps roughly equally on in-distribution and held-out
classes — the **generalization gap** (FDE degradation on held-out
classes vs in-distribution) is statistically indistinguishable: 52.7%
(scratch) vs 51.2% (pretrained), p=0.57. So pretraining does not
*specifically* improve transfer to unseen procedures; it provides a
fairly uniform boost (roughly 7-8% in this ablation) on both splits.
NLL favors the pretrained models (p=0.012 in-distribution; p=0.067,
marginal, on the held-out classes).
## Recipe
The architecture and training recipe in three layers:
1. **Encoder/decoder** (no SSL, supervised baseline already strong)
- Causal transformer encoder over 8 s patches of past 9-dim features
`[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
- HiVT-style parallel decoder (arxiv:2207.09588): single context vector
+ per-step learnable positional embedding → MLP emits `[T, 7]` per step
= (μ_x, μ_y, μ_z, log σ_x, log σ_y, log σ_z, ρ).
- β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168) — prevents the
σ-collapse pathology of plain Gaussian NLL; a loss sketch follows this
list.
- Decoder is non-autoregressive: ~15× tighter seed variance than
teacher-forced GRU rollouts.
2. **JEPA pretraining stage** (the v7 contribution)
- Mask contiguous blocks of past-track patches at ratios sampled from
U[0.3, 0.7] per example (see the masking/EMA sketch after this list).
- Encoder + EMA target encoder + small predictor MLP.
- L1 loss in latent space on masked-patch latents.
- Trained 60 epochs on the same RKSIa training set (no labels used).
3. **Fine-tuning + post-hoc calibration**
- Load pretrained encoder + tokenizer into the v6 architecture.
- Train forecaster end-to-end with β-NLL.
- Optional post-hoc isotonic calibration (Kuleshov 2018) for joint
coverage approaching nominal levels.
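The β-NLL sketch referenced in step 1: a minimal diagonal-Gaussian
version of the Seitzer et al. 2022 weighting, ignoring the xy
correlation term ρ that the full head also emits. It is illustrative,
not the repository's loss code:
```python
import torch

def beta_nll_loss(mu: torch.Tensor, log_sigma: torch.Tensor,
                  target: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """Diagonal-Gaussian β-NLL (Seitzer et al. 2022, arxiv:2203.09168).

    Each per-dimension NLL term is re-weighted by detach(σ²)^β, which keeps
    gradient signal on hard (high-variance) steps and stops the predicted
    variance from collapsing to explain away residual error.
    Shapes are illustrative: mu, log_sigma, target are (B, T, 3).
    """
    sigma_sq = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(sigma_sq) + (target - mu) ** 2 / sigma_sq)
    weight = sigma_sq.detach() ** beta      # stop-gradient σ²^β weighting
    return (weight * nll).mean()
```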
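And the masking/EMA sketch referenced in step 2. The two helpers below
are self-contained; `encoder`, `predictor`, `target_encoder`, and
`optimizer` in the schematic comment are placeholders for the real
modules, and the whole thing mirrors the recipe rather than the exact
training script:
```python
import torch

def sample_block_mask(num_patches: int, lo: float = 0.3,
                      hi: float = 0.7) -> torch.Tensor:
    """Mask one contiguous block of past-track patches; the masked fraction
    is drawn from U[lo, hi] per example. Returns a bool mask, True = hidden."""
    ratio = torch.empty(1).uniform_(lo, hi).item()
    length = max(1, int(round(ratio * num_patches)))
    start = torch.randint(0, num_patches - length + 1, (1,)).item()
    mask = torch.zeros(num_patches, dtype=torch.bool)
    mask[start:start + length] = True
    return mask

@torch.no_grad()
def ema_update(target_encoder: torch.nn.Module,
               online_encoder: torch.nn.Module, tau: float = 0.996) -> None:
    """EMA of the online weights into the target encoder (the target starts
    as a deep copy of the online encoder and never receives gradients)."""
    for p_t, p_o in zip(target_encoder.parameters(),
                        online_encoder.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1.0 - tau)

# One pretraining step, schematically:
#   z_ctx    = encoder(patches, visible=~mask)    # encode visible patches only
#   z_pred   = predictor(z_ctx, mask)             # predict masked-patch latents
#   z_target = target_encoder(patches)[:, mask]   # EMA target sees full track
#   loss     = (z_pred - z_target).abs().mean()   # L1 in latent space
#   loss.backward(); optimizer.step(); ema_update(target_encoder, encoder)
```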
## Files
| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |
## Quick start — inference
```python
import torch
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
from train_v2_prod import FlightJEPAv2 # from the project source
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # inference mode
# Inputs (B=1, T_past=256, 9 features): 9-dim ADS-B-derived features.
# delta: (B,) — blindspot length in seconds, 30..120.
# last_pos: (B, 3) — position at the moment of going dark (normalized ENU).
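# Illustrative dummy inputs with the documented shapes; replace with real,
# normalized ADS-B-derived features. Dtypes here are assumptions.
past_features = torch.randn(1, 256, 9)   # (B, T_past, 9)
past_length   = torch.tensor([256])      # (B,) valid length of each past track
delta         = torch.tensor([60.0])     # (B,) blindspot length in seconds
last_pos      = torch.zeros(1, 3)        # (B, 3) last observed position (ENU)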
mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma: (B, delta_max, 3) predicted std
# rho: (B, delta_max) correlation between x and y per step
```
## Limitations
These limitations were measured rather than glossed over, and they are
part of the honest story:
- **Single-airport checkpoints.** The released v7 checkpoints are
trained and evaluated on Incheon (RKSI) arrivals only; cross-airport
behavior is covered by the v8 LOAO runs above, not by these checkpoints
(v1 classification probes showed cross-airport "transfer" was a
dataset-difficulty artifact).
- **Shuffled train/test split.** ATFMTraj's RKSIa split is a random
shuffled split, not held-out by route or by time. Same flight routes
appear in both partitions. Held-out-by-time is impossible without
re-querying OpenSky for the original timestamps. We add a
held-out-by-class generalization eval as a partial substitute.
- **The supervised baseline is already strong.** v6 baseline (no
pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain
on clean inputs is modest (~7%); the operationally meaningful gain
is at degradation (12-17%).
- **Calibration is partial without isotonic.** β-NLL training brings
raw 95% coverage to ~97% in-distribution, but coverage degrades
rapidly with input dropout. Post-hoc isotonic recalibration recovers
most of the lost coverage; the isotonic mapper from v3 (different
architecture) gave 92.9% joint coverage. v7 isotonic mappers should be
re-fit per checkpoint (a recalibration sketch follows this list).
- **Gain shrinks at extreme dropout.** Beyond a 50% past-track gap,
both models fail catastrophically (FDE > 6 km). The pretraining gain is
most valuable in the 10-50% dropout regime — exactly the range that
matters operationally.
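As the calibration point above notes, isotonic mappers need to be
re-fit per checkpoint. A minimal sketch of what that post-hoc
recalibration (Kuleshov 2018) can look like, one coordinate at a time
on a held-out calibration set; an illustration under assumed names, not
the project's calibration code:
```python
import numpy as np
from scipy.stats import norm
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(mu: np.ndarray, sigma: np.ndarray,
                     y: np.ndarray) -> IsotonicRegression:
    """Fit a Kuleshov-2018-style recalibration map on a calibration set.

    mu, sigma, y are flattened 1-D arrays (one coordinate at a time). The
    fitted map sends a nominal CDF level to the empirically observed
    frequency, so recalibrated intervals approach nominal coverage.
    """
    p = norm.cdf(y, loc=mu, scale=sigma)                  # predicted CDF values
    p_sorted = np.sort(p)
    emp = np.arange(1, len(p_sorted) + 1) / len(p_sorted) # empirical CDF of p
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    iso.fit(p_sorted, emp)
    return iso

# Usage: to report a calibrated 95% central interval, pick the nominal
# levels whose recalibrated values are 0.025 and 0.975 (e.g. by inverting
# `iso` on a grid of candidate levels) and map them back via norm.ppf.
```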
## Compute spent
~$22 of T4 cloud compute across 7 iterations of the project (v1 through
v7), including pretraining, baseline runs, and ablations.
## Citation
If you use this checkpoint, cite the project artifacts:
- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168
## License
Apache-2.0.