v7: model card with degradation curves, frozen + held-out ablations
---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---

# Flight-JEPA v7

A trajectory forecasting model for aircraft on terminal-area approach, specialized for the **blindspot continuation** task: given an observed past track, predict the trajectory through a coverage gap of variable length and the reappearance distribution.

The headline contribution of v7 is a **JEPA-style past-track masked pretraining recipe** that produces representations more robust to test-time radar coverage gaps. Pretrained-then-fine-tuned models maintain significantly lower forecasting error and better-calibrated uncertainty when up to half of the past observations are missing — the regime that matters for aviation deployment.

## Quick numbers

On RKSIa (Incheon arrivals, 8092 test trajectories, n=3 seeds):

| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |

95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.
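
The p-values above are two-sided Welch's t-tests over the three per-seed FDE values in each arm. A minimal sketch of that comparison; the seed-level arrays below are illustrative placeholders, not the project's actual per-seed results:

```python
# Welch's t-test over per-seed FDE values (n=3 seeds per arm).
# The arrays are placeholders for illustration, not real measurements.
from scipy.stats import ttest_ind

scratch_fde = [1850.0, 1930.0, 1890.0]     # illustrative seed-level FDE (m), scratch
pretrained_fde = [1730.0, 1770.0, 1756.0]  # illustrative seed-level FDE (m), pretrained

# equal_var=False selects Welch's t-test (unequal variances, small n)
t_stat, p_value = ttest_ind(pretrained_fde, scratch_fde, equal_var=False)
print(f"Welch t = {t_stat:.2f}, two-sided p = {p_value:.3f}")
```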

## Ablations

**Frozen-encoder probe** (pretrained encoder weights frozen; only the decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse than full fine-tune** (1752 m). The pretrained representation is informative but suboptimal as-is for forecasting; the encoder must adapt.

**Held-out-class generalization** (training set with classes 6, 18, 28 excluded; evaluated on those held-out classes):

| Eval | Scratch FDE (m) | Pretrained FDE (m) | Δ | p |
|---|---:|---:|---:|---:|
| In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
| Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |

Pretraining helps uniformly across both in-distribution and held-out classes — the **generalization gap** (relative FDE increase on held-out classes over in-distribution, i.e. held-out FDE / in-dist FDE − 1) is essentially identical: 52.7% (scratch) vs 51.2% (pretrained), p=0.57. So pretraining does not *specifically* improve transfer to unseen procedures; it provides a uniform ~8-10% boost everywhere. NLL is significantly better for pretrained (p=0.012 in-distribution, p=0.067 on held-out classes).

## Recipe

The architecture and training recipe, in three layers:

1. **Encoder/decoder** (no SSL; the supervised baseline is already strong)
   - Causal transformer encoder over 8 s patches of the past 9-dim features `[x, y, z, uₓ, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
   - HiVT-style parallel decoder (arxiv:2207.09588): a single context vector plus a per-step learnable positional embedding → MLP emits `[T, 7]` = (μ, log σ, ρ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168) — prevents the σ-collapse pathology of plain Gaussian NLL; a loss sketch follows this list.
   - The decoder is non-autoregressive: ~15× tighter seed variance than teacher-forced GRU rollouts.

2. **JEPA pretraining stage** (the v7 contribution; a training-step sketch follows this list)
   - Mask contiguous blocks of past-track patches at ratios sampled from U[0.3, 0.7] per example.
   - Encoder + EMA target encoder + small predictor MLP.
   - L1 loss in latent space on the masked-patch latents.
   - Trained 60 epochs on the same RKSIa training set (no labels used).

3. **Fine-tuning + post-hoc calibration**
   - Load the pretrained encoder + tokenizer into the v6 architecture.
   - Train the forecaster end-to-end with β-NLL.
   - Optional post-hoc isotonic calibration (Kuleshov 2018) for joint coverage approaching nominal levels.
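
A minimal β-NLL sketch for the per-step Gaussian head in layer 1. This is an illustrative re-implementation from the cited paper, not the project's exact loss code; it treats the head as a diagonal Gaussian (the real v7 head also predicts an x-y correlation ρ), and `beta` is a hyperparameter whose v7 value is not stated here:

```python
import torch

def beta_nll_loss(mu, log_sigma, target, beta=0.5):
    """Illustrative beta-NLL (Seitzer et al. 2022) for a diagonal Gaussian head.

    mu, log_sigma, target: (B, T, 3). beta=0 recovers plain Gaussian NLL;
    beta=1 recovers an MSE-like weighting. A faithful v7 loss would also use
    the predicted per-step x-y correlation rho in a bivariate NLL term.
    """
    var = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(var) + (target - mu) ** 2 / var)  # per-dim NLL, constant dropped
    weight = var.detach() ** beta                            # stop-gradient variance weighting
    return (weight * nll).mean()
```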
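
And a sketch of one JEPA pretraining step from layer 2. This is a schematic under assumptions, not the internals of `pretrain_v7.py`: both encoders are assumed to return per-token latents of the same width, masked slots are filled with a learned mask token (dropping them from the context instead is an equally plausible choice), and the EMA momentum is a placeholder:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_enc, online_enc, m=0.996):
    # Exponential moving average of the online encoder into the target encoder.
    for p_t, p_o in zip(target_enc.parameters(), online_enc.parameters()):
        p_t.mul_(m).add_(p_o.detach(), alpha=1.0 - m)

def jepa_step(online_enc, target_enc, predictor, patches, mask_token):
    """One masked-latent prediction step on a batch of past-track patch tokens.

    patches: (B, N, D). mask_token: learned (D,) embedding. A contiguous block
    covering 30-70% of the patches is masked per example; the online encoder
    sees mask tokens in those slots, the EMA target encoder sees the clean
    sequence, and a small predictor regresses the masked target latents (L1).
    """
    B, N, D = patches.shape
    ratio = torch.empty(B).uniform_(0.3, 0.7)        # per-example mask ratio
    mask = torch.zeros(B, N, dtype=torch.bool)
    for b in range(B):                               # contiguous block mask
        span = max(1, int(ratio[b] * N))
        start = torch.randint(0, N - span + 1, (1,)).item()
        mask[b, start:start + span] = True

    corrupted = torch.where(mask.unsqueeze(-1), mask_token.expand(B, N, D), patches)
    online_latent = online_enc(corrupted)            # (B, N, D_latent)
    with torch.no_grad():
        target_latent = target_enc(patches)          # EMA target, no gradients
    pred = predictor(online_latent)                  # small predictor MLP

    return F.l1_loss(pred[mask], target_latent[mask])  # L1 on masked positions only
```

In a full loop, `ema_update` would be called after each optimizer step on the online encoder.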

## Files

| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |

## Quick start — inference

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)

from train_v2_prod import FlightJEPAv2  # from the project source
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # inference mode

# Inputs: 9-dim ADS-B-derived features over the observed past track.
# past_features: (B, T_past=256, 9) normalized feature sequence.
# past_length:   (B,) number of valid past steps per example.
# delta:         (B,) blindspot length in seconds, 30..120.
# last_pos:      (B, 3) position at the moment of going dark (normalized ENU).
past_features = torch.randn(1, 256, 9)  # placeholder; use real preprocessed tracks
past_length = torch.tensor([256])
delta = torch.tensor([60.0])
last_pos = torch.zeros(1, 3)

with torch.no_grad():
    mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
                                       delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma:  (B, delta_max, 3) predicted std
# rho:    (B, delta_max)    correlation between x and y per step
```
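
The per-step outputs define a 2-D Gaussian over the horizontal plane. A hedged sketch of turning one step's `(mu_pos, sigma, rho)` into a covariance matrix and testing whether a point falls inside the nominal 95% ellipse; the `[x, y, z]` axis ordering is assumed from the comments above, and 5.991 is the chi-square 0.95 quantile with 2 degrees of freedom:

```python
import torch

def in_95_ellipse(point_xy, mu_xy, sigma_xy, rho):
    """True if point_xy lies inside the nominal 95% ellipse of one forecast step.

    point_xy, mu_xy, sigma_xy: (2,) tensors for the horizontal plane; rho: scalar tensor.
    """
    sx, sy = sigma_xy[0], sigma_xy[1]
    cov = torch.stack([
        torch.stack([sx * sx, rho * sx * sy]),
        torch.stack([rho * sx * sy, sy * sy]),
    ])                                              # (2, 2) covariance from (sigma, rho)
    d = (point_xy - mu_xy).unsqueeze(-1)            # (2, 1) residual
    m2 = (d.transpose(0, 1) @ torch.linalg.inv(cov) @ d).squeeze()  # squared Mahalanobis
    return bool(m2 <= 5.991)                        # chi2_{2, 0.95} ≈ 5.991

# Example for forecast step t of batch element 0, using the tensors from the block above:
t = 10
point = mu_pos[0, t, :2] + 0.5 * sigma[0, t, :2]    # a point half a sigma from the mean
inside = in_95_ellipse(point, mu_pos[0, t, :2], sigma[0, t, :2], rho[0, t])
```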

## Limitations

These were measured carefully and are part of the honest story:

- **Single airport.** Trained and evaluated on Incheon (RKSI) arrivals. Cross-airport transfer is not validated for forecasting (v1 classification probes showed cross-airport "transfer" was a dataset-difficulty artifact).
- **Shuffled train/test split.** ATFMTraj's RKSIa split is a random shuffled split, not held out by route or by time, so the same flight routes appear in both partitions. A held-out-by-time split is impossible without re-querying OpenSky for the original timestamps; the held-out-class generalization eval above is a partial substitute.
- **The supervised baseline is already strong.** The v6 baseline (no pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain on clean inputs is modest (~7%); the operationally meaningful gain is under degradation (12-17%).
- **Calibration is partial without isotonic.** β-NLL training brings raw 95% coverage to ~97% in-distribution, but coverage degrades rapidly with input dropout. Post-hoc isotonic recalibration recovers most of the loss; the isotonic mapper from v3 (different architecture) gave 92.9% joint coverage. v7 isotonic mappers should be re-fit per checkpoint (a re-fit sketch follows this list).
- **Gain shrinks at extreme dropout.** Beyond a 50% past-track gap, both models fail catastrophically (FDE > 6 km). The pretraining gain is most valuable in the 10-50% dropout regime — exactly the range that matters operationally.
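
A minimal re-fit sketch in the Kuleshov et al. (2018) recalibration style, using per-axis PIT values on a held-out calibration split. `IsotonicRegression` is from scikit-learn; this per-axis Gaussian mapping is an assumption about how a v7 mapper could be fit, not the project's calibration code:

```python
# Hypothetical per-axis recalibration in the style of Kuleshov et al. (2018).
# mu, sigma, y_true: (N, 3) numpy arrays from a held-out calibration split.
import numpy as np
from scipy.stats import norm
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(mu, sigma, y_true, axis=0):
    # PIT values: predicted Gaussian CDF evaluated at the true outcomes.
    pit = norm.cdf(y_true[:, axis], loc=mu[:, axis], scale=sigma[:, axis])
    pit_sorted = np.sort(pit)
    empirical = np.arange(1, len(pit_sorted) + 1) / len(pit_sorted)  # empirical CDF of PIT
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    iso.fit(pit_sorted, empirical)  # maps nominal CDF level -> observed coverage
    return iso

# To report a calibrated central 95% interval, find nominal levels whose mapped
# coverage brackets 0.025 and 0.975, then take the corresponding Gaussian quantiles.
```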

## Compute spent

~$22 of T4 cloud compute across the 7 iterations of the project (v1 through v7), including pretraining, baseline runs, and ablations.

## Citation

If you use this checkpoint, cite the project artifacts:

- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL + forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168

## License

Apache-2.0.