---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---

# Flight-JEPA v8

A trajectory forecasting model for aircraft on terminal-area approach,
specialized for the **blindspot continuation** task: given an observed
past track, predict the trajectory through a coverage gap of variable
length and the reappearance distribution.

The headline contribution is a **JEPA-style past-track masked
pretraining recipe** that produces representations more robust to test-
time radar coverage gaps, with the gain *generalizing across airports*.
Pretrained-then-fine-tuned models maintain significantly lower
forecasting error and better-calibrated uncertainty when up to 70% of
past observations are missing — including on **completely held-out
airports the fine-tuning never saw**.

## v8 — leave-one-airport-out (LOAO) headline

Across 4 LOAO folds (held out: RKSIa / RKSId / ESSA / LSZH, n=3 seeds
each), pretraining beats training from scratch with statistical significance:

| Past-track dropout | Mean Δ FDE (pretrained vs scratch) | p |
|---|---:|---:|
| 0% (clean; no significant regression) | +2.8% | 0.41 |
| 30% | −6.6% | 0.04 ✓ |
| **50%** | **−23.4%** | **<0.001 ✓** |
| **70%** | **−22.5%** | **<0.001 ✓** |

11 of 12 comparisons at ≥30% dropout reach p<0.05. All 4 LOAO folds
pass the locked criterion independently.

![v8 summary](plots/summary_v8.png)
![v8 FDE per airport](plots/fde_per_airport.png)
![v8 coverage per airport](plots/coverage_per_airport.png)

The result generalizes across airports with very different runway
geometries and procedures (Korea, Sweden, Switzerland). The model
conditions on an airport-ID token (UniTraj recipe, arxiv:2403.15098).

## v7 — single-airport reference

The v7 study (single-airport, RKSIa only), the prerequisite for v8, showed
the same effect on a within-airport split:

| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |

95% coverage at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.

## Ablations

**Frozen-encoder probe** (pretrained encoder weights frozen; only the
decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse
than full fine-tune** (1752 m). The pretrained representation is
informative but suboptimal as-is for forecasting; the encoder must adapt
during fine-tuning.

**Held-out-class generalization** (training set with classes
6, 18, 28 excluded; evaluated on those held-out classes):

| Eval | Scratch FDE | Pretrained FDE | Δ | p |
|---|---:|---:|---:|---:|
| In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
| Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |

Pretraining helps across both in-distribution and held-out classes, and the
**generalization gap** (FDE degradation on held-out classes vs in-dist) is
statistically indistinguishable: 52.7% (scratch) vs 51.2% (pretrained),
p=0.57. So pretraining does not *specifically* improve transfer to unseen
procedures; it provides a roughly uniform ~7-8% boost everywhere. NLL is
significantly better for pretrained in-distribution (p=0.012) and trends
the same way on the held-out classes (p=0.067).

## Recipe

The architecture and training recipe, in three stages:

1. **Encoder/decoder** (no SSL, supervised baseline already strong)
   - Causal transformer encoder over 8 s patches of past 9-dim features
     `[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
   - HiVT-style parallel decoder (arxiv:2207.09588): single context vector
     + per-step learnable positional embedding → MLP emits `[T, 7]` =
     (μ, log σ, ρ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arxiv:2203.09168) — prevents the
     σ-collapse pathology of plain Gaussian NLL (a minimal sketch follows
     this list).
   - The decoder is non-autoregressive: ~15× lower seed-to-seed variance
     than teacher-forced GRU rollouts.

2. **JEPA pretraining stage** (the v7 contribution; sketched after this list)
   - Mask contiguous blocks of past-track patches at ratios sampled from
     U[0.3, 0.7] per example.
   - Encoder + EMA target encoder + small predictor MLP.
   - L1 loss in latent space on masked-patch latents.
   - Trained 60 epochs on the same RKSIa training set (no labels used).

3. **Fine-tuning + post-hoc calibration**
   - Load pretrained encoder + tokenizer into the v6 architecture.
   - Train forecaster end-to-end with β-NLL.
   - Optional post-hoc isotonic calibration (Kuleshov 2018) for joint
     coverage approaching nominal levels.
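
For concreteness, a minimal sketch of the step-1 decoder head and β-NLL
loss. The shapes, `d_model`, `horizon`, the concatenation of context and
step embedding, and the tanh squash on ρ are illustrative assumptions, and
the loss below treats the axes as independent Gaussians, ignoring the ρ
term for brevity; the repo's implementation may differ in these details.

```python
import torch
import torch.nn as nn

class ParallelGaussianDecoder(nn.Module):
    """HiVT-style parallel readout: one context vector, one learnable query per future step."""
    def __init__(self, d_model=256, horizon=120):
        super().__init__()
        self.step_emb = nn.Parameter(torch.randn(horizon, d_model))   # per-step positional embedding
        self.head = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.GELU(),
                                  nn.Linear(d_model, 7))              # (mu_xyz, log_sigma_xyz, rho)

    def forward(self, context):                 # context: (B, d_model) last-token readout
        steps = self.step_emb.unsqueeze(0).expand(context.shape[0], -1, -1)
        ctx = context.unsqueeze(1).expand_as(steps)                   # broadcast context to every step
        out = self.head(torch.cat([ctx, steps], dim=-1))              # (B, T, 7), all steps in parallel
        return out[..., :3], out[..., 3:6], torch.tanh(out[..., 6])   # mu, log_sigma, rho

def beta_nll_loss(mu, log_sigma, target, beta=0.5):
    """Gaussian NLL re-weighted per dimension by stop-gradient(sigma^2)^beta (Seitzer 2022).

    beta=0 is plain NLL; beta>0 counteracts the 1/sigma^2 factor that otherwise
    lets the model down-weight hard points by inflating sigma (sigma-collapse).
    """
    var = torch.exp(2.0 * log_sigma)
    nll = 0.5 * (torch.log(var) + (target - mu) ** 2 / var)
    return (var.detach() ** beta * nll).mean()
```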
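
A minimal sketch of one step-2 pretraining update: contiguous block masking
at a ratio drawn from U[0.3, 0.7], an EMA target encoder, a small
predictor, and an L1 loss on masked-patch latents. All names (`encoder`,
`target_encoder`, `predictor`, `patches`) are illustrative, and zeroing out
masked patches is a simplification of however `pretrain_v7.py` hides them
from the context encoder.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target, online, decay=0.996):
    # Target encoder tracks the online encoder as an exponential moving average.
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1.0 - decay)

def jepa_step(encoder, target_encoder, predictor, patches):
    # patches: (B, N, D) tokenized past-track patches.
    B, N, _ = patches.shape
    ratio = torch.empty(B).uniform_(0.3, 0.7)                 # mask ratio ~ U[0.3, 0.7]
    mask = torch.zeros(B, N, dtype=torch.bool)
    for b in range(B):
        n_mask = int(ratio[b] * N)
        start = torch.randint(0, N - n_mask + 1, (1,)).item()
        mask[b, start:start + n_mask] = True                  # one contiguous block per example

    with torch.no_grad():
        targets = target_encoder(patches)                     # latents of the unmasked track

    visible = patches.masked_fill(mask.unsqueeze(-1), 0.0)    # hide masked patches (simplified)
    preds = predictor(encoder(visible))                       # predict latents at every position

    return F.l1_loss(preds[mask], targets[mask])              # L1 only on masked-patch latents
```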

## Files

| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |

## Quick start — inference

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)

from train_v2_prod import FlightJEPAv2  # from the project source
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # switch to eval mode (disables dropout)

# Inputs (B=1, T_past=256, 9 features): 9-dim ADS-B-derived features.
# delta: (B,) — blindspot length in seconds, 30..120.
# last_pos: (B, 3) — position at the moment of going dark (normalized ENU).
with torch.no_grad():
    mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
                                       delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma:  (B, delta_max, 3) predicted std
# rho:    (B, delta_max)    correlation between x and y per step
```
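
For a quick sanity check, per-axis central intervals can be read off
`sigma` directly, treating the axes as independent Gaussians (this ignores
`rho` and any post-hoc recalibration):

```python
# Per-axis 95% bounds at each forecast step.
lower = mu_pos - 1.96 * sigma
upper = mu_pos + 1.96 * sigma
```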

## Limitations

These were measured carefully and are part of the honest story:

- **Single airport (released checkpoints).** The v7 checkpoints in this repo
  were trained and evaluated on Incheon (RKSI) arrivals. Cross-airport
  transfer of the forecaster is validated only by the v8 LOAO study above
  (v1 classification probes showed cross-airport "transfer" was a
  dataset-difficulty artifact).
- **Shuffled train/test split.** ATFMTraj's RKSIa split is a random
  shuffled split, not held-out by route or by time. Same flight routes
  appear in both partitions. Held-out-by-time is impossible without
  re-querying OpenSky for the original timestamps. We add a
  held-out-by-class generalization eval as a partial substitute.
- **The supervised baseline is already strong.** v6 baseline (no
  pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain
  on clean inputs is modest (~7%); the operationally meaningful gain shows
  up under input degradation (12-17%).
- **Calibration is partial without isotonic.** β-NLL training brings
  raw 95% coverage to ~97% in-distribution, but coverage degrades
  rapidly with input dropout. Post-hoc isotonic recovers most of the
  loss; the isotonic mapper from v3 (different architecture) gave 92.9%
  joint coverage. v7 isotonic mappers should be re-fit per checkpoint (a
  minimal re-fitting sketch follows this list).
- **Gain shrinks at extreme dropout.** Beyond 50% past-track gap, both
  models fail catastrophically (FDE > 6 km). The pretraining gain is
  most valuable in the 10-50% dropout regime — exactly the operational
  range where deployment cares.
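
A minimal Kuleshov-style recalibration sketch for re-fitting an isotonic
mapper per checkpoint, assuming per-axis Gaussian marginals. The arrays
`mu`, `sigma`, `y` (model outputs and matching ground truth on a held-out
calibration split) are hypothetical; the project's own mapper may be fit
jointly rather than per axis.

```python
import numpy as np
from scipy.stats import norm
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(mu, sigma, y):
    """Fit an isotonic map from nominal CDF levels to empirical frequencies."""
    pit = norm.cdf(y, loc=mu, scale=sigma).ravel()             # PIT values F_i(y_i)
    pit_sorted = np.sort(pit)
    empirical = np.arange(1, pit_sorted.size + 1) / pit_sorted.size
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    iso.fit(pit_sorted, empirical)
    return iso

def calibrated_level(iso, q, grid_size=999):
    """Invert the isotonic map: nominal level whose recalibrated coverage is q."""
    grid = np.linspace(1e-3, 1.0 - 1e-3, grid_size)
    idx = int(np.clip(np.searchsorted(iso.predict(grid), q), 0, grid_size - 1))
    return grid[idx]

# Usage: swap the nominal quantile levels for their calibrated counterparts,
# e.g. a 95% interval uses calibrated_level(iso, 0.025) and calibrated_level(iso, 0.975).
```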

## Compute spent

~$22 of T4 cloud compute across 7 iterations of the project (v1 through
v7), including pretraining, baseline runs, and ablations.

## Citation

If you use this checkpoint, cite the project artifacts:

- ATFMTraj dataset: arxiv:2407.20028
- Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
- HiVT decoder design: arxiv:2207.09588
- β-NLL: arxiv:2203.09168

## License

Apache-2.0.