Jdice27 commited on
Commit
e43dca4
·
verified ·
1 Parent(s): faf2651

Add ARCHITECTURE.md

Files changed (1):
  1. ARCHITECTURE.md +638 -0
# AirTrackLM: LLM4STP Adapted for ADS-B Air Track Prediction

## Complete Architecture & Implementation Plan

---

## 1. Executive Summary

We adapt the LLM4STP multi-feature fusion architecture (originally for maritime AIS ship trajectory prediction) to work with **ADS-B air track data**. The model uses a **decoder-only transformer** with four specialized embedding types — Prompt, Uncertainty, Geohash, and Temporal — fused together for **next-state prediction** pretraining. Once pretrained, the model is adaptable to downstream tasks like activity classification.

This design is grounded in published results from:
- **FTP-LLM** (arXiv:2501.17459) — LLaMA-3.1-8B for flight trajectory prediction
- **H3-CLM** (arXiv:2405.09596) — H3 geohash + causal LM for maritime trajectories
- **GeoFormer** (arXiv:2311.05092) — GPT-style geospatial tokenization
- **TrAISFormer** (arXiv:2109.03958) — Discrete tokenization of AIS features

---
## 2. System Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────┐
│                          RAW ADS-B INPUT                            │
│            (timestamp, latitude, longitude, altitude)               │
└─────────────────────────┬───────────────────────────────────────────┘
                          │
                          ▼
┌─────────────────────────────────────────────────────────────────────┐
│                    FEATURE DERIVATION PIPELINE                      │
│                                                                     │
│  Raw:     lat, lon, alt                                             │
│  Derived: COG, SOG, ROT, altitude_rate                              │
│  Meta:    timestamp → (hour, day_of_week, month)                    │
│                                                                     │
│  Output per timestep:                                               │
│    state_t = [lat, lon, alt, COG, SOG, ROT, alt_rate]               │
└─────────────────────────┬───────────────────────────────────────────┘
                          │
                          ▼
┌─────────────────────────────────────────────────────────────────────┐
│                      TOKENIZATION / ENCODING                        │
│                                                                     │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐               │
│  │   Geohash    │  │  Continuous  │  │   Temporal   │               │
│  │  Tokenizer   │  │ Discretizer  │  │   Encoder    │               │
│  │              │  │              │  │              │               │
│  │ lat,lon,alt  │  │ COG,SOG,ROT  │  │  hour,dow,   │               │
│  │ → H3 cell +  │  │  alt_rate    │  │    month     │               │
│  │  alt_band    │  │  → bin IDs   │  │ → time IDs   │               │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘               │
│         │                 │                 │                       │
│         ▼                 ▼                 ▼                       │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐               │
│  │   Geohash    │  │   Feature    │  │   Temporal   │               │
│  │  Embedding   │  │  Embeddings  │  │  Embedding   │               │
│  │    Table     │  │    Tables    │  │    Table     │               │
│  │  (d_model)   │  │  (d_model)   │  │  (d_model)   │               │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘               │
│         │                 │                 │                       │
└─────────┼─────────────────┼─────────────────┼───────────────────────┘
          │                 │                 │
          ▼                 ▼                 ▼
┌─────────────────────────────────────────────────────────────────────┐
│                      EMBEDDING FUSION LAYER                         │
│                                                                     │
│  ┌────────────┐ ┌────────────┐ ┌────────────┐ ┌──────────────┐      │
│  │  Geohash   │ │  Feature   │ │  Temporal  │ │ Uncertainty  │      │
│  │   Embed    │ │   Embed    │ │   Embed    │ │    Embed     │      │
│  │ (d_model)  │ │ (d_model)  │ │ (d_model)  │ │  (d_model)   │      │
│  └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ └──────┬───────┘      │
│        │              │              │               │              │
│        └──────────┬───┴──────┬───────┘               │              │
│                   │          │                       │              │
│                   ▼          ▼                       ▼              │
│        E_state = E_geo + E_feat + E_temp + E_uncert                 │
│                          │                                          │
│                          ▼                                          │
│        ┌───────────────────────────────────────────┐                │
│        │  Prompt Embedding (prepended prefix)      │                │
│        │  [PROMPT_1, PROMPT_2, ..., PROMPT_k]      │                │
│        └───────────────────┬───────────────────────┘                │
│                            │                                        │
│                            ▼                                        │
│   Input: [PROMPT_TOKENS | STATE_1 | STATE_2 | ... | STATE_T]        │
│                            │                                        │
│                            ▼                                        │
│                Linear Projection → d_model                          │
│                            │                                        │
│                            ▼                                        │
│              + Positional Encoding (sinusoidal)                     │
│                                                                     │
└───────────────────────┬─────────────────────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────────────────────┐
│                DECODER-ONLY TRANSFORMER BACKBONE                    │
│                                                                     │
│   ┌─────────────────────────────────────────────────────┐           │
│   │            Transformer Block ×N_layers              │           │
│   │                                                     │           │
│   │   ┌─────────────────────────────────────────┐       │           │
│   │   │   Causal Multi-Head Self-Attention      │       │           │
│   │   │   (masked: each position attends only   │       │           │
│   │   │    to itself and earlier positions)     │       │           │
│   │   └──────────────────┬──────────────────────┘       │           │
│   │                      │                              │           │
│   │                      ▼                              │           │
│   │   ┌─────────────────────────────────────────┐       │           │
│   │   │   LayerNorm + Residual Connection       │       │           │
│   │   └──────────────────┬──────────────────────┘       │           │
│   │                      │                              │           │
│   │                      ▼                              │           │
│   │   ┌─────────────────────────────────────────┐       │           │
│   │   │   Feed-Forward Network                  │       │           │
│   │   │   (Linear → GELU → Linear)              │       │           │
│   │   │   d_model → 4*d_model → d_model         │       │           │
│   │   └──────────────────┬──────────────────────┘       │           │
│   │                      │                              │           │
│   │                      ▼                              │           │
│   │   ┌─────────────────────────────────────────┐       │           │
│   │   │   LayerNorm + Residual Connection       │       │           │
│   │   └─────────────────────────────────────────┘       │           │
│   │                                                     │           │
│   └─────────────────────────────────────────────────────┘           │
│                                                                     │
└───────────────────────┬─────────────────────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────────────────────┐
│                           OUTPUT HEADS                              │
│                                                                     │
│  ┌─────────────────────────────────────────────────────────┐        │
│  │  PRETRAINING: Next-State Prediction Head                │        │
│  │                                                         │        │
│  │  For each position t, predict state at t+1:             │        │
│  │                                                         │        │
│  │  h_t → Linear → softmax → P(geohash_token_{t+1})        │        │
│  │  h_t → Linear → softmax → P(COG_bin_{t+1})              │        │
│  │  h_t → Linear → softmax → P(SOG_bin_{t+1})              │        │
│  │  h_t → Linear → softmax → P(ROT_bin_{t+1})              │        │
│  │  h_t → Linear → softmax → P(alt_rate_bin_{t+1})         │        │
│  │  h_t → Linear → softmax → P(alt_band_{t+1})             │        │
│  │                                                         │        │
│  │  Loss = Σ CrossEntropy(predicted_feature, true_feature) │        │
│  └─────────────────────────────────────────────────────────┘        │
│                                                                     │
│  ┌─────────────────────────────────────────────────────────┐        │
│  │  DOWNSTREAM: Activity Classification Head               │        │
│  │  (attached after pretraining, frozen or fine-tuned)     │        │
│  │                                                         │        │
│  │  h_[BOS] or mean(h_1:T) → MLP → softmax → class label   │        │
│  └─────────────────────────────────────────────────────────┘        │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```

---
## 3. The Four Embedding Types (Detailed)

### 3.1 Geohash Embeddings — Spatial Position Encoding

**Purpose**: Encode the aircraft's 3D geographic position as a discrete token.

**Method**: We use the **H3 hexagonal hierarchical spatial index** (Uber's H3) at resolution 5 (hex area ≈ 252 km², edge ≈ 9.85 km) for en-route flight, with an option to use resolution 7 (≈ 5.16 km², edge ≈ 1.22 km) for terminal areas. This follows the H3-CLM paper's approach, adapted for aviation's larger spatial scale.

**3D Extension**: Since aircraft operate in 3D, we combine the H3 cell with an **altitude band**:
```
Geohash Token = H3_cell_index × N_alt_bands + alt_band_index

Altitude bands (1,000 ft increments):
  Band 0:  0 - 1,000 ft       (ground / taxi)
  Band 1:  1,000 - 2,000 ft   (initial climb / approach)
  ...
  Band 45: 45,000 - 46,000 ft (high cruise)

N_alt_bands = 46
```

**Vocabulary size**: At H3 resolution 5, roughly 100K-200K unique cells cover typical airspace. Combined with altitude bands: `~200K × 46 ≈ 9.2M` — too large for a direct embedding table.

**Solution — Factored Embedding**:
```
E_geohash = E_h3[h3_cell_id] + E_alt[alt_band_id]

E_h3:  learned embedding table, vocab = N_h3_cells (~200K, or hashing trick to 50K)
E_alt: learned embedding table, vocab = 46

Both project to d_model dimensions.
```

The **hashing trick**: map H3 cell indices through a hash function into a fixed vocabulary of ~50,000 buckets. This bounds memory while largely preserving spatial discrimination.

**Why H3 over traditional geohash**: H3 hexagons have near-uniform area (no polar distortion), hierarchical nesting, and consistent neighbor relationships — critical for trajectory continuity.
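The factored scheme and hashing trick can be sketched without any geospatial dependency; here the H3 cell is represented by its id string, and `N_BUCKETS` plus the choice of `blake2b` are illustrative, not part of the design.

```python
import hashlib

N_BUCKETS = 50_000   # hashed H3 vocabulary (illustrative size from the plan)
N_ALT_BANDS = 46     # 1,000 ft bands covering 0-46,000 ft

def alt_band(altitude_ft: float) -> int:
    """Map altitude to a 1,000 ft band index, clipped to [0, 45]."""
    return min(max(int(altitude_ft // 1000), 0), N_ALT_BANDS - 1)

def h3_bucket(h3_cell: str) -> int:
    """Hashing trick: stable hash of the H3 cell id into a fixed vocabulary."""
    digest = hashlib.blake2b(h3_cell.encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") % N_BUCKETS

def geohash_ids(h3_cell: str, altitude_ft: float) -> tuple:
    """Indices for the two factored embedding tables E_h3 and E_alt."""
    return h3_bucket(h3_cell), alt_band(altitude_ft)

cell_id, band_id = geohash_ids("85283473fffffff", 36_000.0)
```

Because the hash is deterministic, the same H3 cell always maps to the same bucket, so the embedding row it trains is reused at inference time.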
### 3.2 Temporal Embeddings — When Is the Aircraft Flying?

**Purpose**: Encode temporal context — time of day affects traffic density, routes, and behavior.

**Method**: Additive composition of multiple temporal scales:
```
E_temporal = E_hour[hour_of_day] + E_dow[day_of_week] + E_month[month]

E_hour:  24 entries (captures rush-hour vs. night patterns)
E_dow:   7 entries  (weekday vs. weekend traffic)
E_month: 12 entries (seasonal routes, weather patterns)

All project to d_model dimensions.
```

**Optional — Sinusoidal Minute Encoding**: For finer-than-hour resolution, encode the minute of the hour:
```
E_minute = [sin(2π × minute / 60), cos(2π × minute / 60)] → linear → d_model
```
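Extracting the three lookup indices plus the optional sinusoidal minute features from a Unix timestamp is straightforward; this sketch assumes UTC timestamps, and the function names are illustrative:

```python
import math
from datetime import datetime, timezone

def temporal_ids(unix_ts: float) -> tuple:
    """Indices into E_hour (0-23), E_dow (0=Mon .. 6=Sun), E_month (0-11)."""
    t = datetime.fromtimestamp(unix_ts, tz=timezone.utc)
    return t.hour, t.weekday(), t.month - 1

def minute_features(unix_ts: float) -> tuple:
    """Optional sinusoidal encoding of minute-of-hour (fed to a linear layer)."""
    minute = datetime.fromtimestamp(unix_ts, tz=timezone.utc).minute
    angle = 2 * math.pi * minute / 60
    return math.sin(angle), math.cos(angle)
```

The sin/cos pair keeps minute 59 close to minute 0, which a raw integer index would not.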
### 3.3 Uncertainty Embeddings — How Confident Are We?

**Purpose**: Encode the model's uncertainty about the current trajectory state. Aircraft in straight-and-level cruise have low uncertainty; aircraft maneuvering near airports have high uncertainty.

**Method**: Compute a **trajectory smoothness score** from recent states, then discretize:

```
Uncertainty sources (sliding window of k=5 recent states):

1. Position variance: σ²_pos = var(Δlat) + var(Δlon)
2. Heading variance:  σ²_COG = circular_var(COG_{t-k:t})
3. Speed variance:    σ²_SOG = var(SOG_{t-k:t})
4. Altitude variance: σ²_alt = var(alt_rate_{t-k:t})

Combined uncertainty score:
  U_t = w1·σ²_pos + w2·σ²_COG + w3·σ²_SOG + w4·σ²_alt

Discretize into N_uncert = 16 bins (quantile binning on training data)

E_uncertainty = E_uncert_table[bin(U_t)] → d_model
```

**Weights w1-w4**: Hyperparameters tuned on validation data, or learned as part of the model.

**During inference**: For multi-step prediction, uncertainty can be updated using MC-Dropout or ensemble disagreement.
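The smoothness score can be sketched in NumPy as follows; the equal weights are a placeholder for the tuned w1-w4, and circular variance uses the standard 1 − |mean resultant vector| definition:

```python
import numpy as np

def circular_var(deg):
    """Circular variance of angles in degrees: 1 - |mean resultant vector|."""
    rad = np.radians(np.asarray(deg, dtype=float))
    return 1.0 - np.hypot(np.mean(np.sin(rad)), np.mean(np.cos(rad)))

def uncertainty_score(lats, lons, cogs, sogs, alt_rates, w=(1.0, 1.0, 1.0, 1.0)):
    """U_t over the most recent k states (all inputs of equal length k)."""
    s_pos = np.var(np.diff(lats)) + np.var(np.diff(lons))   # position variance
    return (w[0] * s_pos
            + w[1] * circular_var(cogs)                     # heading variance
            + w[2] * np.var(sogs)                           # speed variance
            + w[3] * np.var(alt_rates))                     # altitude variance
```

A steady cruise window (constant heading, speed, and climb rate) scores near zero; a turning window scores higher, which is exactly the signal the uncertainty bins discretize.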
### 3.4 Prompt Embeddings — Task and Context Metadata

**Purpose**: Provide metadata context about the flight, analogous to system prompts in LLMs. Enables task conditioning and multi-task learning.

**Method**: Learnable prompt tokens prepended to the trajectory:

```
Prompt token vocabulary:
- Aircraft category: [HEAVY, LARGE, SMALL, ROTORCRAFT, GLIDER, UAV, UNKNOWN] (7)
- Flight phase:      [CLIMB, CRUISE, DESCENT, APPROACH, GROUND, UNKNOWN]     (6)
- Region:            [CONUS, EUROPE, ASIA, OTHER]                            (4)
- Task:              [PREDICT, CLASSIFY, DETECT_ANOMALY]                     (3)
- Special:           [BOS, EOS, PAD, MASK]                                   (4)

Total prompt vocab: 24 tokens

Prompt sequence (prepended):
[BOS, TASK_TOKEN, AIRCRAFT_TOKEN, PHASE_TOKEN, REGION_TOKEN]

Each has a learned embedding of dimension d_model.
```

**For downstream classification**: Change TASK_TOKEN to CLASSIFY; the output at the BOS position is used for classification.
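Building the 24-token vocabulary and the five-token prefix is mechanical; the dict layout and UNKNOWN/OTHER fallbacks below are illustrative choices, with token order following the list above:

```python
# Flat id space for the 24 prompt tokens listed above.
GROUPS = {
    "aircraft": ["HEAVY", "LARGE", "SMALL", "ROTORCRAFT", "GLIDER", "UAV", "UNKNOWN"],
    "phase":    ["CLIMB", "CRUISE", "DESCENT", "APPROACH", "GROUND", "UNKNOWN"],
    "region":   ["CONUS", "EUROPE", "ASIA", "OTHER"],
    "task":     ["PREDICT", "CLASSIFY", "DETECT_ANOMALY"],
    "special":  ["BOS", "EOS", "PAD", "MASK"],
}
PROMPT_VOCAB = {}
for group, names in GROUPS.items():
    for name in names:
        PROMPT_VOCAB[f"{group}:{name}"] = len(PROMPT_VOCAB)

def prompt_ids(task, aircraft="UNKNOWN", phase="UNKNOWN", region="OTHER"):
    """[BOS, TASK, AIRCRAFT, PHASE, REGION] as embedding-table indices."""
    return [PROMPT_VOCAB["special:BOS"],
            PROMPT_VOCAB[f"task:{task}"],
            PROMPT_VOCAB[f"aircraft:{aircraft}"],
            PROMPT_VOCAB[f"phase:{phase}"],
            PROMPT_VOCAB[f"region:{region}"]]
```

Namespacing tokens as `group:NAME` keeps the two UNKNOWN entries (aircraft and phase) distinct in the shared id space.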
---

## 4. Feature Derivation Pipeline

### 4.1 Raw Input
```
timestamp  (Unix epoch seconds)
latitude   (degrees, WGS84)
longitude  (degrees, WGS84)
altitude   (feet, barometric or geometric)
```

### 4.2 Derived Features

```python
import numpy as np

def derive_features(timestamps, lats, lons, alts):
    """
    Derive COG, SOG, ROT, and altitude rate from raw position data.
    All inputs: numpy arrays of shape (N,) for a single trajectory.
    Returns arrays of shape (N,); the leading elements (one for COG, SOG,
    and alt_rate, two for ROT) are NaN because the differences are
    undefined there.
    """
    dt = np.diff(timestamps)          # seconds
    dt = np.maximum(dt, 1e-6)         # avoid division by zero

    # --- Course Over Ground (COG): initial great-circle bearing ---
    lat1, lat2 = np.radians(lats[:-1]), np.radians(lats[1:])
    dlon = np.radians(np.diff(lons))

    x = np.sin(dlon) * np.cos(lat2)
    y = np.cos(lat1) * np.sin(lat2) - np.sin(lat1) * np.cos(lat2) * np.cos(dlon)
    COG = np.degrees(np.arctan2(x, y)) % 360          # [0, 360)

    # --- Speed Over Ground (SOG): haversine distance / elapsed time ---
    dlat = np.radians(np.diff(lats))
    a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    distance_nm = 3440.065 * c                        # Earth radius in nautical miles
    SOG = distance_nm / (dt / 3600)                   # knots

    # --- Rate of Turn (ROT) ---
    dCOG = np.diff(COG)
    dCOG = (dCOG + 180) % 360 - 180                   # normalize to [-180, 180]
    ROT = np.full(len(lats), np.nan)
    ROT[2:] = dCOG / dt[1:]                           # degrees per second

    # --- Rate of Altitude Change ---
    dalt = np.diff(alts)                              # feet
    alt_rate = dalt / (dt / 60)                       # feet per minute

    # Pad first elements so all outputs align with the input timestamps
    COG_full = np.concatenate([[np.nan], COG])
    SOG_full = np.concatenate([[np.nan], SOG])
    alt_rate_full = np.concatenate([[np.nan], alt_rate])

    return COG_full, SOG_full, ROT, alt_rate_full
```

### 4.3 Feature Discretization

| Feature       | Range             | Bin Width  | N_bins | Notes               |
|---------------|-------------------|------------|--------|---------------------|
| COG           | [0, 360)          | 5°         | 72     | Circular            |
| SOG           | [0, 600] kts      | 5 knots    | 121    | Capped at ~Mach 1   |
| ROT           | [-6, 6] °/s       | 0.25 °/s   | 49     | Capped at ±6 °/s    |
| Altitude Rate | [-6000, 6000] fpm | 200 ft/min | 61     | Capped at ±6000 fpm |

Outliers beyond the caps are clipped to the boundary bin.
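Under the table's conventions, discretization reduces to clip-and-round for the bounded features and floor-divide for the circular one. Centering bins on multiples of the width is an inference from the stated counts (121, 49, 61 match centers at 0, 5, …, 600 and so on), not something the table states explicitly:

```python
import numpy as np

def to_bin(values, low, high, width):
    """Clip to [low, high], then index the bin centered at the nearest multiple of width."""
    v = np.clip(np.asarray(values, dtype=float), low, high)
    return np.round((v - low) / width).astype(int)

def cog_bin(values):
    """Circular feature: 72 bins of 5 degrees, with 360 wrapping back to bin 0."""
    return np.floor(np.asarray(values, dtype=float) % 360 / 5).astype(int)

sog_bins = to_bin([0, 451.3, 9999], 0, 600, 5)   # outlier 9999 clips to the top bin
rot_bins = to_bin([-7.0, 0.13], -6, 6, 0.25)     # outlier -7 clips to the bottom bin
```

Bin ids are then direct indices into the corresponding embedding tables of Section 5.2.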
### 4.4 Trajectory Preprocessing Pipeline

```
1.  Segment raw ADS-B by ICAO24 + temporal gaps > 15 min → individual flights
2.  Resample to fixed Δt = 60 seconds (linear interp for position, circular for heading)
3.  Derive features (COG, SOG, ROT, alt_rate)
4.  Drop first 2 points per trajectory (NaN from derivation)
5.  Filter: remove trajectories with < 20 points (< 20 minutes)
6.  Compute H3 cell (res 5) + altitude band for each point
7.  Discretize all continuous features into bins
8.  Compute uncertainty scores (sliding window k=5)
9.  Extract temporal features (hour, dow, month)
10. Construct prompt tokens from metadata (if available)
```
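Step 1, the gap-based segmentation, can be sketched as follows, assuming the records are already sorted by (icao24, timestamp); the record layout is illustrative and the 15-minute threshold is the one from the list above:

```python
GAP_S = 15 * 60  # split a flight whenever more than 15 min elapse between reports

def segment_flights(records):
    """Split (icao24, timestamp)-sorted records into per-flight segments.

    `records` is an iterable of dicts with at least 'icao24' and 'timestamp' keys.
    """
    segments, current = [], []
    for rec in records:
        # Start a new segment on a new aircraft or on a temporal gap.
        if current and (rec["icao24"] != current[-1]["icao24"]
                        or rec["timestamp"] - current[-1]["timestamp"] > GAP_S):
            segments.append(current)
            current = []
        current.append(rec)
    if current:
        segments.append(current)
    return segments
```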
---

## 5. Model Hyperparameters

### 5.1 Model Dimensions

| Parameter       | Value | Rationale                             |
|-----------------|-------|---------------------------------------|
| d_model         | 256   | H3-CLM found 256-1024 effective       |
| n_heads         | 8     | head_dim = 32                         |
| n_layers        | 8     | Moderate depth for a ~10M-param model |
| d_ff            | 1024  | 4× d_model (standard)                 |
| max_seq_len     | 128   | 128 states × 60 s ≈ 2 hours of flight |
| n_prompt_tokens | 5     | [BOS, TASK, AIRCRAFT, PHASE, REGION]  |
| dropout         | 0.1   |                                       |

**Total parameters**: ~8-12M (trainable on a single GPU in hours)

### 5.2 Vocabulary Sizes

| Embedding        | Vocab  | Dim |
|------------------|--------|-----|
| H3 cells         | 50,000 | 256 |
| Altitude bands   | 46     | 256 |
| COG bins         | 72     | 256 |
| SOG bins         | 121    | 256 |
| ROT bins         | 49     | 256 |
| Alt rate bins    | 61     | 256 |
| Hour of day      | 24     | 256 |
| Day of week      | 7      | 256 |
| Month            | 12     | 256 |
| Uncertainty bins | 16     | 256 |
| Prompt tokens    | 24     | 256 |

### 5.3 State Token Composition

Each timestep becomes a single state token via additive fusion:

```
E_state_t = E_h3[h3_id_t] + E_alt_band[alt_band_t]           # Geohash (3D position)
          + E_COG[cog_bin_t] + E_SOG[sog_bin_t]              # Kinematics
          + E_ROT[rot_bin_t] + E_alt_rate[alt_rate_bin_t]    # Dynamics
          + E_hour[hour_t] + E_dow[dow_t] + E_month[month_t] # Temporal
          + E_uncert[uncert_bin_t]                           # Uncertainty

E_state_t ∈ R^{d_model}
```

This additive fusion follows BERT (token + segment + position) and TrAISFormer.
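A framework-agnostic sketch of the additive fusion, with NumPy arrays standing in for trainable embedding layers (the random initialization and table names are illustrative; vocabulary sizes are those of Section 5.2):

```python
import numpy as np

D_MODEL = 256
VOCABS = {"h3": 50_000, "alt_band": 46, "cog": 72, "sog": 121, "rot": 49,
          "alt_rate": 61, "hour": 24, "dow": 7, "month": 12, "uncert": 16}

rng = np.random.default_rng(0)
tables = {name: rng.normal(0, 0.02, (size, D_MODEL)) for name, size in VOCABS.items()}

def fuse_state(ids):
    """E_state = sum of one embedding row per feature; ids maps feature -> index."""
    return sum(tables[name][idx] for name, idx in ids.items())

e_state = fuse_state({"h3": 1234, "alt_band": 36, "cog": 9, "sog": 90,
                      "rot": 24, "alt_rate": 30, "hour": 14, "dow": 2,
                      "month": 6, "uncert": 3})
```

Because every table maps into the same d_model space, adding rows keeps the sequence width constant regardless of how many features are fused; concatenation would not.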
---

## 6. Training Recipe

### 6.1 Pretraining: Next-State Prediction (Causal LM)

**Objective**: Given states 1..T, predict the state at T+1 (applied autoregressively at every position).

**Loss**:
```
L = Σ_{t=1}^{T-1} [ λ_geo  · CE(ŷ_geo_t,      y_geo_{t+1})
                  + λ_COG  · CE(ŷ_COG_t,      y_COG_{t+1})
                  + λ_SOG  · CE(ŷ_SOG_t,      y_SOG_{t+1})
                  + λ_ROT  · CE(ŷ_ROT_t,      y_ROT_{t+1})
                  + λ_alt  · CE(ŷ_alt_rate_t, y_alt_rate_{t+1})
                  + λ_altb · CE(ŷ_alt_band_t, y_alt_band_{t+1}) ]

λ values default to 1.0 (equal weighting).
```
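The summed multi-task loss can be sketched for a single sequence, with per-feature logits of shape (T-1, vocab) and λ weights defaulting to 1.0 as above (the dict-based interface is an illustrative choice, not the plan's API):

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean token-level CE: -log softmax(logits)[target], averaged over positions."""
    z = logits - logits.max(axis=-1, keepdims=True)   # shift for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def multi_task_loss(logits_by_feature, targets_by_feature, lambdas=None):
    """L = sum over features of lambda_f * CE(predictions_f, next-state targets_f)."""
    lambdas = lambdas or {}
    return sum(lambdas.get(f, 1.0) * cross_entropy(logits_by_feature[f],
                                                   targets_by_feature[f])
               for f in logits_by_feature)
```

Setting a feature's λ to 0 drops it from the objective, which is also how the ablation variants of Section 8 can disable the geohash head without changing the code path.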
**Training hyperparameters** (based on FTP-LLM + H3-CLM):

| Parameter             | Value                             |
|-----------------------|-----------------------------------|
| Optimizer             | AdamW                             |
| Learning rate         | 5e-4                              |
| LR schedule           | Cosine + 5% warmup                |
| Batch size (per GPU)  | 64                                |
| Gradient accumulation | 4 (effective = 256)               |
| Max epochs            | 30 (early stopping, patience = 5) |
| Weight decay          | 0.01                              |
| Gradient clipping     | 1.0                               |
| Mixed precision       | bf16                              |

**Data windowing**: Sliding window, size = 128, stride = 64 (50% overlap).

### 6.2 Downstream: Activity Classification

After pretraining, attach a classification head:
```
h_BOS → Linear(256, 128) → GELU → Dropout(0.1) → Linear(128, N_classes)
```

**Fine-tuning options**:
- **A**: Freeze backbone, train head only (fast, suited to small data)
- **B**: Full fine-tune, backbone lr = 1e-5, head lr = 1e-3
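The head's forward pass is just a two-layer MLP on the BOS hidden state. A NumPy sketch (GELU via the common tanh approximation; dropout omitted as at inference; the 5-class count and random weights are purely illustrative):

```python
import numpy as np

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def classify(h_bos, W1, b1, W2, b2):
    """h_BOS (256,) -> Linear(256, 128) -> GELU -> Linear(128, N_classes) logits."""
    return gelu(h_bos @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
W1, b1 = rng.normal(0, 0.02, (256, 128)), np.zeros(128)
W2, b2 = rng.normal(0, 0.02, (128, 5)), np.zeros(5)   # 5 activity classes (illustrative)
logits = classify(rng.normal(size=256), W1, b1, W2, b2)
```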
---

## 7. Dataset Strategy

### 7.1 Prototyping — `traffic` Python Library

```python
from traffic.data.samples import landing_zurich_2019
# ~2,000 flights near Zurich
# Columns: timestamp, icao24, callsign, latitude, longitude, altitude,
#          groundspeed, track, vertical_rate, ...
```

Instant access, clean, well documented — but a single airport with limited diversity.

### 7.2 Training — OpenSky Network

```python
from pyopensky.trino import Trino

trino = Trino()
df = trino.rawquery("""
    SELECT time, icao24, lat, lon, baroaltitude, velocity, heading, vertrate
    FROM state_vectors_data4
    WHERE hour >= '2024-01-15 00:00:00'
      AND hour < '2024-01-15 12:00:00'
      AND lat BETWEEN 40 AND 55
      AND lon BETWEEN -10 AND 20
    ORDER BY icao24, time
""")
```

**Target**:
- **Region A** (train): Europe, 1 month → ~500K-1M flights
- **Region B** (OOD test): US CONUS, 1 week → ~200K flights
- **Region C** (far test): East Asia, 1 week → ~100K flights

### 7.3 Alternative: SCAT Dataset

~170K en-route flights over Sweden, hosted on Zenodo. Pre-segmented and clean.

### 7.4 Data Split

```
Training:    70% of Region A flights
Validation:  15% of Region A flights
Test (IID):  15% of Region A flights
Test (OOD):  100% of Region B flights
Test (Far):  100% of Region C flights
```

Split by **flight** (not by time window) to avoid data leakage: overlapping windows cut from the same flight must never straddle the train/test boundary.
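A deterministic flight-level split can be implemented by hashing each flight id into [0, 1), so every window of a flight lands in the same partition on every run; `blake2b` is an illustrative choice of stable hash, and the 70/15/15 thresholds are those of the table above:

```python
import hashlib

def split_of(flight_id: str) -> str:
    """Deterministically assign a flight to train / val / test_iid (70/15/15)."""
    h = hashlib.blake2b(flight_id.encode(), digest_size=8).digest()
    u = int.from_bytes(h, "big") / 2**64        # approximately uniform in [0, 1)
    if u < 0.70:
        return "train"
    return "val" if u < 0.85 else "test_iid"
```

Hash-based assignment avoids storing a split table and stays stable when new flights are added to Region A.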
---

## 8. Ablation Study: Geohash Geographic Dependency

### 8.1 Hypothesis

> Geohash embeddings encode **absolute geographic position**, causing the model to memorize region-specific patterns (airways, approach paths, airspace structure). This improves in-distribution performance but degrades transfer to unseen regions.

### 8.2 Experimental Variants

| Variant | Geohash Type | Description |
|---------|--------------|-------------|
| **V1: Full Model** | H3 absolute | Complete architecture as described |
| **V2: No Geohash** | None | Remove geohash entirely; the model sees only kinematics + temporal + uncertainty |
| **V3: Relative Geohash** | H3 relative | H3 cell of (Δlat, Δlon) from the trajectory start — position-invariant |
| **V4: Multi-Resolution** | H3 res 3+5+7 | Three resolutions summed (coarse → fine) |
| **V5: Continuous Position** | Linear projection | `Linear([lat, lon, alt] → d_model)` — no discretization |

### 8.3 Evaluation Metrics

For each variant × each test set (IID, OOD, Far):

| Metric | Description |
|--------|-------------|
| Geo accuracy | % of correctly predicted H3 cells |
| Position MAE | Mean absolute position error in km |
| COG MAE | Heading error in degrees |
| SOG MAE | Speed error in knots |
| Multi-step ADE | Average displacement error over 5 predicted steps |
| Multi-step FDE | Final displacement error at step 5 |

### 8.4 Key Comparisons

| Comparison | Tests |
|------------|-------|
| V1 vs V2 (IID) | How much geohash helps when the test region equals the training region |
| V1 vs V2 (OOD) | If V2 > V1 on OOD → geohash causes geographic overfitting |
| V1 vs V3 (OOD) | If V3 performs well on both IID and OOD → relative geohash is the sweet spot |
| V4 (all) | Multi-resolution: do coarse cells transfer while fine cells specialize? |
| V5 (all) | Does continuous encoding avoid discretization issues? |

### 8.5 Expected Outcomes

- **V1**: Best IID, worst OOD (per the hypothesis)
- **V3**: Best compromise — predicted winner
- **V5**: May struggle (loses the discrete token structure transformers excel at)
- **V2**: Strong OOD baseline, but sacrifices IID accuracy

### 8.6 Additional Analysis

- **Attention visualization**: V1 vs. V3 attention patterns
- **Embedding clustering**: t-SNE of geohash embeddings, colored by region
- **Learning curves**: IID vs. OOD performance as a function of training-data size

---

## 9. Implementation Phases

### Phase 1: Data Pipeline (Week 1)
- Set up the `traffic` library, extract sample trajectories
- Implement feature derivation (COG, SOG, ROT, alt_rate)
- Implement H3 geohash encoding + altitude banding
- Implement feature discretization (binning)
- Implement uncertainty-score computation
- Build a PyTorch Dataset class with sliding windows
- Unit tests for all derivation functions

### Phase 2: Model Architecture (Weeks 1-2)
- Implement all embedding tables
- Implement the additive fusion layer
- Implement prompt-token prepending
- Implement the decoder-only transformer backbone
- Implement the multi-head output (6 prediction heads)
- Implement the classification head (for downstream tasks)
- Forward-pass test with dummy data

### Phase 3: Pretraining (Weeks 2-3)
- Implement the training loop with the multi-task loss
- Prototyping run on `traffic` data (small, fast iteration)
- Scale to OpenSky data
- Monitor loss curves, validate convergence
- Save the best checkpoint

### Phase 4: Downstream Adaptation (Weeks 3-4)
- Implement the classification fine-tuning pipeline
- Test on the activity classification task
- Compare frozen vs. fine-tuned backbone

### Phase 5: Ablation Study (Weeks 4-5)
- Implement all 5 geohash variants
- Train each variant with identical hyperparameters
- Evaluate on the IID, OOD, and Far test sets
- Generate comparison tables and visualizations
- Write up the geographic-dependency findings

---

## 10. Key Design Decisions & Rationale

| Decision | Choice | Why |
|----------|--------|-----|
| Custom model vs. pretrained LLM | Custom ~10M-param transformer | FTP-LLM showed text-tokenized LLMs work, but a custom model allows proper multi-feature fusion; 10M params trains in hours. |
| H3 vs. traditional geohash | H3 | Uniform hexagonal cells, no polar distortion, hierarchical. Proven by H3-CLM. |
| Additive vs. concatenative fusion | Additive | BERT/TrAISFormer paradigm; keeps d_model constant. Concatenation would grow the width to d_model × N_features. |
| Time resolution | 60 seconds | FTP-LLM validated 1-min aggregation; 128 steps ≈ 2+ hours. |
| Factored geohash (H3 + alt) | Separate tables, summed | Avoids combinatorial explosion (9.2M → 50K + 46). |
| Multi-head output | Separate softmax per feature | More interpretable; allows per-feature analysis. |
| Uncertainty from smoothness | Variance-based | Computable at data-preparation time, no inference overhead. |

---

## 11. Risk Analysis

| Risk | Likelihood | Impact | Mitigation |
|------|------------|--------|------------|
| Geohash overfits to region | High | High | Ablation study; V3 (relative) is the fallback |
| OpenSky access issues | Medium | High | Fallback: `traffic` samples + SCAT |
| 60 s too coarse for terminal areas | Medium | Low | Separate terminal-area model at 10 s |
| Model too small | Low | Medium | Scale up: d_model → 512, n_layers → 16 (~40M params) |
| Altitude discretization too coarse | Low | Low | Refine to 500 ft bands (92 bands) |

---

## 12. Monitoring & Evaluation

**During training** (Trackio):
- Total loss + per-feature loss curves
- Validation loss each epoch
- LR schedule, GPU utilization

**After training**:
- Next-state accuracy (top-1, top-5, per feature)
- Position error in km
- Multi-step prediction (1, 5, 10, 20 steps ahead)
- Downstream classification F1 / precision / recall

---

*Grounded in: FTP-LLM, H3-CLM, GeoFormer, TrAISFormer, and LLM4STP (reconstructed). Ready for implementation upon approval.*