guychuk committed
Commit bf02000 · verified · 1 Parent(s): b2363f8

v7: model card with degradation curves, frozen + held-out ablations

Files changed (1):
  1. README.md +173 -0

README.md ADDED
@@ -0,0 +1,173 @@
---
license: apache-2.0
tags:
- trajectory-forecasting
- aviation
- ads-b
- self-supervised
- jepa
- pytorch
datasets:
- petchthwr/ATFMTraj
language:
- en
library_name: pytorch
---

# Flight-JEPA v7

A trajectory forecasting model for aircraft on terminal-area approach, specialized for the **blindspot continuation** task: given an observed past track, predict the trajectory through a coverage gap of variable length, together with the distribution over where the aircraft reappears.

The headline contribution of v7 is a **JEPA-style past-track masked pretraining recipe** that produces representations more robust to test-time radar coverage gaps. Pretrained-then-fine-tuned models maintain significantly lower forecasting error and better-calibrated uncertainty when up to half of the past observations are missing — the regime aviation deployment cares about.
## Quick numbers

On RKSIa (Incheon arrivals, 8092 test trajectories, n=3 seeds):

| Past-track dropout | Scratch FDE | Pretrained FDE | Δ | p (Welch) |
|---|---:|---:|---:|---:|
| 0% (clean) | 1890 ± 40 m | **1752 ± 21 m** | **−7.3%** | **0.012** |
| 10% | 2381 ± 51 m | **2048 ± 62 m** | **−14.0%** | **0.002** |
| 30% | 3108 ± 87 m | **2591 ± 94 m** | **−16.6%** | **0.002** |
| 50% | 3784 ± 186 m | **3322 ± 42 m** | **−12.2%** | **0.044** |
| 70% | 6607 ± 308 m | 6325 ± 79 m | −4.3% | 0.249 |

Empirical coverage of the nominal 95% interval at 50% dropout: **85.1% (pretrained) vs 77.1% (scratch)**, p=0.026.
## Ablations

**Frozen-encoder probe** (pretrained encoder weights frozen; only the decoder trains, 3 seeds): FDE 2541 m at 0% dropout — **clearly worse than full fine-tune** (1752 m). The pretrained representation is informative but suboptimal as-is for forecasting; the encoder must adapt.
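A probe of this kind is typically set up by freezing the encoder and optimizing only the head. In the sketch below, `model` is a loaded checkpoint as in the Quick start section further down; the attribute name `model.encoder` and the learning rate are assumptions, not the project's exact configuration.

```python
import torch

# Freeze the pretrained encoder; train only the decoder head.
for p in model.encoder.parameters():
    p.requires_grad_(False)
model.encoder.eval()  # disable dropout in the frozen part

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # decoder-only parameters
    lr=1e-3,
)
```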
51
+
52
+ **Held-out-class generalization** (training set with classes
53
+ 6, 18, 28 excluded; evaluated on those held-out classes):
54
+
55
+ | Eval | Scratch FDE | Pretrained FDE | Δ | p |
56
+ |---|---:|---:|---:|---:|
57
+ | In-distribution (without held-out classes) | 2118 ± 108 | 1960 ± 58 | −7.4% | 0.110 |
58
+ | Held-out classes only | 3235 ± 207 | 2966 ± 157 | −8.3% | 0.152 |
59
+
60
+ Pretraining helps uniformly across both in-distribution and held-out
61
+ classes — the **generalization gap** (FDE degradation on held-out
62
+ classes vs in-dist) is identical: 52.7% (scratch) vs 51.2% (pretrained),
63
+ p=0.57. So pretraining does not *specifically* improve transfer to
64
+ unseen procedures; it provides a uniform ~8-10% boost everywhere. NLL
65
+ is significantly better for pretrained (p=0.012 in-dist, p=0.067
66
+ generalization).
67
+

## Recipe

The architecture and training recipe, in three layers (minimal sketches of the key pieces follow the list):

1. **Encoder/decoder** (no SSL; the supervised baseline is already strong)
   - Causal transformer encoder over 8 s patches of the past 9-dim features `[x, y, z, u_x, u_y, u_z, r, sin θ, cos θ]`. Last-token readout.
   - HiVT-style parallel decoder (arXiv:2207.09588): a single context vector plus per-step learnable positional embeddings → an MLP emits `[T, 7]` = (μ, log σ, ρ) per step.
   - β-NLL Gaussian loss (Seitzer 2022, arXiv:2203.09168) — prevents the σ-collapse pathology of plain Gaussian NLL.
   - The decoder is non-autoregressive: ~15× tighter seed variance than teacher-forced GRU rollouts.

2. **JEPA pretraining stage** (the v7 contribution)
   - Mask contiguous blocks of past-track patches at ratios sampled from U[0.3, 0.7] per example.
   - Encoder + EMA target encoder + small predictor MLP.
   - L1 loss in latent space on the masked-patch latents.
   - Trained 60 epochs on the same RKSIa training set (no labels used).

3. **Fine-tuning + post-hoc calibration**
   - Load the pretrained encoder + tokenizer into the v6 architecture.
   - Train the forecaster end-to-end with β-NLL.
   - Optional post-hoc isotonic calibration (Kuleshov 2018) brings joint coverage close to nominal levels.
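Layer 1's loss, as a minimal sketch. This covers the diagonal part only (per-axis μ and log σ); the real head also predicts the x-y correlation ρ, and `beta=0.5` is the paper's default rather than a value read from the training script.

```python
import torch

def beta_nll_loss(mu, log_sigma, target, beta=0.5):
    """beta-NLL (Seitzer et al. 2022, arXiv:2203.09168): per-step Gaussian NLL
    reweighted by a stop-gradient of sigma^(2*beta), which keeps the optimizer
    from silencing hard steps by inflating sigma (the sigma-collapse pathology).
    mu, log_sigma, target: (B, T, 3) tensors of per-axis predictions/targets."""
    sigma_sq = torch.exp(2.0 * log_sigma)
    nll = 0.5 * ((target - mu) ** 2 / sigma_sq + 2.0 * log_sigma)  # constant dropped
    weight = sigma_sq.detach() ** beta                             # stop-gradient weight
    return (weight * nll).mean()
```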
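A sketch of one pretraining step from layer 2. The module names and the predictor interface (`tokenizer`, `encoder`, `target_encoder`, `predictor`) are illustrative assumptions; the actual, batch-vectorized implementation is in `pretrain_v7.py`.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_encoder, encoder, tau=0.996):
    # EMA target encoder; tau is an assumed value, not taken from the script.
    for p_t, p_o in zip(target_encoder.parameters(), encoder.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1.0 - tau)

def jepa_step(tokenizer, encoder, target_encoder, predictor, past):
    """One step: mask a contiguous block of past-track patches per example,
    predict the latents of the masked patches from the visible context,
    and take an L1 loss in latent space."""
    tokens = tokenizer(past)                        # (B, N, D) patch tokens
    B, N, _ = tokens.shape
    ratios = torch.empty(B).uniform_(0.3, 0.7)      # per-example mask ratio U[0.3, 0.7]
    loss = 0.0
    for b in range(B):                              # looped here for clarity only
        n_mask = max(1, int(ratios[b] * N))
        start = int(torch.randint(0, N - n_mask + 1, (1,)))
        masked = torch.zeros(N, dtype=torch.bool)
        masked[start:start + n_mask] = True         # one contiguous block
        with torch.no_grad():                       # targets come from the EMA encoder
            target_lat = target_encoder(tokens[b:b + 1])[:, masked]
        context_lat = encoder(tokens[b:b + 1, ~masked])
        pred = predictor(context_lat, masked)       # assumed: one latent per masked patch
        loss = loss + F.l1_loss(pred, target_lat)
    return loss / B
```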
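The post-hoc calibration from layer 3 follows the Kuleshov et al. 2018 recipe: fit an isotonic map from the model's predicted CDF values to their empirical frequencies on a held-out calibration split. The sketch below is per output dimension and shows the idea, not the project's exact mapper.

```python
import numpy as np
from scipy.stats import norm
from sklearn.isotonic import IsotonicRegression

def fit_recalibrator(mu, sigma, y):
    """mu, sigma, y: 1-D arrays over a calibration split, one output dimension."""
    p_pred = norm.cdf(y, loc=mu, scale=sigma)   # model CDF evaluated at the target
    # Empirical frequency of {model CDF value <= p} for each predicted p.
    p_emp = np.searchsorted(np.sort(p_pred), p_pred, side="right") / len(p_pred)
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    iso.fit(p_pred, p_emp)
    return iso  # recalibrated CDF = iso.predict(model CDF)
```

Calibrated intervals are then read off the recalibrated CDF instead of the raw Gaussian quantiles.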

## Files

| File | Description |
|------|-------------|
| `v7-pretrain.pt` | Pretrained encoder + tokenizer checkpoint (used for fine-tuning) |
| `v7-pre-finetune-s{0,1,2}.pt` | Fine-tuned models, 3 seeds (recommended for inference) |
| `v7-scratch-s{0,1,2}.pt` | Matched-compute scratch baselines (no pretraining) |
| `v6-lam0p0-s{0,1,2}.pt` | v6 reference checkpoints (older training pipeline) |
| `train_v7_finetune.py` | Single-file training script with `--pretrained-encoder` support |
| `pretrain_v7.py` | Self-contained pretraining script |

## Quick start — inference

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("guychuk/flight-jepa-v2", "v7-pre-finetune-s0.pt")
state = torch.load(ckpt_path, map_location="cpu", weights_only=False)

from train_v2_prod import FlightJEPAv2  # from the project source
model = FlightJEPAv2(state["config"])
model.load_state_dict(state["state_dict"])
model.eval()  # inference mode

# Inputs (B=1, T_past=256):
#   past_features: (B, T_past, 9) — the 9-dim ADS-B-derived features
#   delta:         (B,)           — blindspot length in seconds, 30..120
#   last_pos:      (B, 3)         — position at the moment of going dark (normalized ENU)
mu_pos, sigma, rho = model.rollout(past_features, past_length, delta, last_pos,
                                   delta_max=int(delta.max()))
# mu_pos: (B, delta_max, 3) predicted positions
# sigma:  (B, delta_max, 3) predicted std
# rho:    (B, delta_max)    correlation between x and y per step
```
131
+ ## Limitations
132
+
133
+ These were measured carefully and are part of the honest story:
134
+
135
+ - **Single airport.** Trained and evaluated on Incheon (RKSI) arrivals.
136
+ Cross-airport transfer not validated for forecasting (v1 classification
137
+ probes showed cross-airport "transfer" was a dataset-difficulty artifact).
138
+ - **Shuffled train/test split.** ATFMTraj's RKSIa split is a random
139
+ shuffled split, not held-out by route or by time. Same flight routes
140
+ appear in both partitions. Held-out-by-time is impossible without
141
+ re-querying OpenSky for the original timestamps. We add a
142
+ held-out-by-class generalization eval as a partial substitute.
143
+ - **The supervised baseline is already strong.** v6 baseline (no
144
+ pretraining) achieves FDE 1865 m at 0% dropout. The pretraining gain
145
+ on clean inputs is modest (~7%); the operationally meaningful gain
146
+ is at degradation (12-17%).
147
+ - **Calibration is partial without isotonic.** β-NLL training brings
148
+ raw 95% coverage to ~97% in-distribution, but coverage degrades
149
+ rapidly with input dropout. Post-hoc isotonic recovers most of the
150
+ loss; the isotonic mapper from v3 (different architecture) gave 92.9%
151
+ joint coverage. v7 isotonic mappers should be re-fit per checkpoint.
152
+ - **Gain shrinks at extreme dropout.** Beyond 50% past-track gap, both
153
+ models fail catastrophically (FDE > 6 km). The pretraining gain is
154
+ most valuable in the 10-50% dropout regime — exactly the operational
155
+ range where deployment cares.
156
+
157
+ ## Compute spent
158
+
159
+ ~$22 of T4 cloud compute across 7 iterations of the project (v1 through
160
+ v7), including pretraining, baseline runs, and ablations.
161
+
162
+ ## Citation
163
+
164
+ If you use this checkpoint, cite the project artifacts:
165
+
166
+ - ATFMTraj dataset: arxiv:2407.20028
167
+ - Forecast-MAE precedent for SSL+forecasting: arxiv:2308.09882
168
+ - HiVT decoder design: arxiv:2207.09588
169
+ - β-NLL: arxiv:2203.09168
170
+
171
+ ## License
172
+
173
+ Apache-2.0.