Jdice27 commited on
Commit
f967e70
·
verified ·
1 Parent(s): e8142ba

Add model.py

Browse files
Files changed (1) hide show
  1. model.py +791 -0
model.py ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AirTrackLM - Model Architecture
3
+ ================================
4
+ Decoder-only transformer with 4 embedding types for air track next-state prediction.
5
+
6
+ Embedding types (following LLM4STP, adapted for aviation):
7
+ 1. Geohash: 40-bit binary per ENU axis (120 bits total) → Linear projection → d_model
8
+ 2. Temporal: Sinusoidal second-of-day + learned hour/dow/month embeddings
9
+ 3. Uncertainty: Learned embedding from trajectory smoothness bins
10
+ 4. Prompt: Learned tokens for task/aircraft/phase/region metadata
11
+
12
+ Core architecture:
13
+ - Additive embedding fusion (E_geo + E_feat + E_temp + E_uncert)
14
+ - Prompt tokens prepended to sequence
15
+ - Causal (GPT-style) multi-head self-attention
16
+ - Multi-head output: separate prediction per feature type
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from typing import Optional, Dict, Tuple
24
+ from dataclasses import dataclass
25
+
26
+
27
+ # ============================================================
28
+ # Configuration
29
+ # ============================================================
30
+
31
+ @dataclass
32
+ class AirTrackConfig:
33
+ """Model configuration."""
34
+
35
+ # Transformer backbone
36
+ d_model: int = 256
37
+ n_heads: int = 8
38
+ n_layers: int = 8
39
+ d_ff: int = 1024
40
+ dropout: float = 0.1
41
+ max_seq_len: int = 256 # max sequence length (prompt + trajectory)
42
+
43
+ # Geohash embedding (LLM4STP style)
44
+ geohash_bits: int = 120 # 40 bits × 3 axes (E, N, U)
45
+ geohash_hidden: int = 64 # intermediate projection dim
46
+
47
+ # Feature bins (discretized kinematic features)
48
+ n_cog_bins: int = 180 # 2° resolution over [0, 360)
49
+ n_sog_bins: int = 300 # 2-knot resolution over [0, 600]
50
+ n_rot_bins: int = 120 # 0.1°/s over [-6, 6]
51
+ n_alt_rate_bins: int = 120 # 100 ft/min over [-6000, 6000]
52
+
53
+ # Temporal embedding
54
+ n_hours: int = 24
55
+ n_dow: int = 7
56
+ n_months: int = 12
57
+ time_sinusoidal_dim: int = 32 # dimension for sinusoidal second-of-day encoding
58
+
59
+ # Uncertainty embedding
60
+ n_uncert_bins: int = 16
61
+ n_uncert_methods: int = 4 # kinematic_var, pred_residual, spatial_density, phase_entropy
62
+ use_multi_uncertainty: bool = True # if True, use MultiUncertaintyEmbedding
63
+ use_heteroscedastic: bool = True # if True, add learned uncertainty head
64
+
65
+ # Prompt embedding
66
+ n_prompt_tokens: int = 23 # PromptTokens.VOCAB_SIZE
67
+ n_prompt_len: int = 5 # [BOS, TASK, AIRCRAFT, PHASE, REGION]
68
+
69
+ # Output heads
70
+ # We predict: geohash (regression), COG bin, SOG bin, ROT bin, alt_rate bin
71
+ predict_geohash: bool = True # if True, predict geohash bits (binary classification per bit)
72
+ predict_continuous: bool = True # if True, also predict continuous ENU offset (regression)
73
+
74
+ # Ablation variants for geohash
75
+ geohash_mode: str = 'absolute' # 'absolute', 'none', 'relative', 'multi_res', 'continuous'
76
+
77
+
78
+ # ============================================================
79
+ # Embedding Modules
80
+ # ============================================================
81
+
82
+ class GeohashEmbedding(nn.Module):
83
+ """
84
+ Binary geohash embedding following LLM4STP.
85
+ Projects 120-bit binary vector through:
86
+ Linear(120 → geohash_hidden) → ReLU → Linear(geohash_hidden → d_model)
87
+
88
+ LLM4STP uses Conv1d on the bits, but we use MLP for simplicity
89
+ since each timestep's 120 bits are independent.
90
+ """
91
+
92
+ def __init__(self, config: AirTrackConfig):
93
+ super().__init__()
94
+ self.projection = nn.Sequential(
95
+ nn.Linear(config.geohash_bits, config.geohash_hidden),
96
+ nn.ReLU(),
97
+ nn.Linear(config.geohash_hidden, config.d_model),
98
+ )
99
+
100
+ def forward(self, geohash_bits: torch.Tensor) -> torch.Tensor:
101
+ """
102
+ Args:
103
+ geohash_bits: (batch, seq_len, 120) float tensor of binary geohash
104
+ Returns:
105
+ (batch, seq_len, d_model)
106
+ """
107
+ return self.projection(geohash_bits)
108
+
109
+
110
+ class ContinuousPositionEmbedding(nn.Module):
111
+ """Ablation variant V5: direct linear projection of continuous ENU coordinates."""
112
+
113
+ def __init__(self, config: AirTrackConfig):
114
+ super().__init__()
115
+ self.projection = nn.Sequential(
116
+ nn.Linear(3, config.geohash_hidden),
117
+ nn.ReLU(),
118
+ nn.Linear(config.geohash_hidden, config.d_model),
119
+ )
120
+
121
+ def forward(self, east: torch.Tensor, north: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ Args:
124
+ east, north, up: (batch, seq_len) each
125
+ Returns:
126
+ (batch, seq_len, d_model)
127
+ """
128
+ pos = torch.stack([east, north, up], dim=-1) # (B, L, 3)
129
+ return self.projection(pos)
130
+
131
+
132
+ class FeatureEmbedding(nn.Module):
133
+ """
134
+ Learned embedding tables for discretized kinematic features.
135
+ Each feature has its own embedding table, all outputs summed.
136
+ """
137
+
138
+ def __init__(self, config: AirTrackConfig):
139
+ super().__init__()
140
+ self.cog_embed = nn.Embedding(config.n_cog_bins, config.d_model)
141
+ self.sog_embed = nn.Embedding(config.n_sog_bins, config.d_model)
142
+ self.rot_embed = nn.Embedding(config.n_rot_bins, config.d_model)
143
+ self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, config.d_model)
144
+
145
+ def forward(
146
+ self,
147
+ cog_bins: torch.Tensor,
148
+ sog_bins: torch.Tensor,
149
+ rot_bins: torch.Tensor,
150
+ alt_rate_bins: torch.Tensor,
151
+ ) -> torch.Tensor:
152
+ """
153
+ Args:
154
+ *_bins: (batch, seq_len) long tensors of bin indices
155
+ Returns:
156
+ (batch, seq_len, d_model) — sum of all feature embeddings
157
+ """
158
+ return (
159
+ self.cog_embed(cog_bins) +
160
+ self.sog_embed(sog_bins) +
161
+ self.rot_embed(rot_bins) +
162
+ self.alt_rate_embed(alt_rate_bins)
163
+ )
164
+
165
+
166
+ class TemporalEmbedding(nn.Module):
167
+ """
168
+ Temporal embedding combining:
169
+ 1. Sinusoidal encoding of second-of-day (sub-second resolution)
170
+ 2. Learned embeddings for hour, day-of-week, month
171
+ 3. Sinusoidal encoding of delta-t (time since previous state)
172
+
173
+ The sinusoidal encoding gives sub-second precision since it operates
174
+ on continuous float seconds, not discrete bins.
175
+ """
176
+
177
+ def __init__(self, config: AirTrackConfig):
178
+ super().__init__()
179
+
180
+ # Learned calendar embeddings
181
+ self.hour_embed = nn.Embedding(config.n_hours, config.d_model)
182
+ self.dow_embed = nn.Embedding(config.n_dow, config.d_model)
183
+ self.month_embed = nn.Embedding(config.n_months, config.d_model)
184
+
185
+ # Sinusoidal projection for continuous time features
186
+ # second_of_day → sinusoidal features → linear → d_model
187
+ self.time_sin_dim = config.time_sinusoidal_dim
188
+ self.time_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
189
+
190
+ # Delta-t projection
191
+ self.dt_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
192
+
193
+ # Pre-compute frequency bases for sinusoidal encoding
194
+ # Multiple frequencies to capture different time scales
195
+ freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
196
+ -(math.log(86400.0) / config.time_sinusoidal_dim))
197
+ self.register_buffer('time_freqs', freqs)
198
+
199
+ dt_freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
200
+ -(math.log(3600.0) / config.time_sinusoidal_dim))
201
+ self.register_buffer('dt_freqs', dt_freqs)
202
+
203
+ def sinusoidal_encode(self, values: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
204
+ """
205
+ Encode continuous values with multiple sinusoidal frequencies.
206
+
207
+ Args:
208
+ values: (batch, seq_len) — continuous values
209
+ freqs: (dim,) — frequency bases
210
+ Returns:
211
+ (batch, seq_len, dim*2) — sin and cos features
212
+ """
213
+ # (B, L, 1) * (1, 1, dim) → (B, L, dim)
214
+ angles = values.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) * 2 * math.pi
215
+ return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
216
+
217
+ def forward(
218
+ self,
219
+ hour: torch.Tensor,
220
+ dow: torch.Tensor,
221
+ month: torch.Tensor,
222
+ second_of_day: torch.Tensor,
223
+ dt: torch.Tensor,
224
+ ) -> torch.Tensor:
225
+ """
226
+ Args:
227
+ hour: (B, L) long — hour of day [0, 23]
228
+ dow: (B, L) long — day of week [0, 6]
229
+ month: (B, L) long — month [0, 11]
230
+ second_of_day: (B, L) float — seconds within day [0, 86400)
231
+ dt: (B, L) float — delta-t in seconds
232
+ Returns:
233
+ (B, L, d_model)
234
+ """
235
+ # Learned calendar embeddings
236
+ cal = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)
237
+
238
+ # Sinusoidal second-of-day (sub-second resolution)
239
+ time_sin = self.sinusoidal_encode(second_of_day, self.time_freqs) # (B, L, dim*2)
240
+ time_emb = self.time_projection(time_sin) # (B, L, d_model)
241
+
242
+ # Sinusoidal delta-t
243
+ dt_sin = self.sinusoidal_encode(dt, self.dt_freqs) # (B, L, dim*2)
244
+ dt_emb = self.dt_projection(dt_sin) # (B, L, d_model)
245
+
246
+ return cal + time_emb + dt_emb
247
+
248
+
249
+ class UncertaintyEmbedding(nn.Module):
250
+ """Learned embedding for trajectory uncertainty bins."""
251
+
252
+ def __init__(self, config: AirTrackConfig):
253
+ super().__init__()
254
+ self.embed = nn.Embedding(config.n_uncert_bins, config.d_model)
255
+
256
+ def forward(self, uncert_bins: torch.Tensor) -> torch.Tensor:
257
+ """
258
+ Args:
259
+ uncert_bins: (B, L) long — uncertainty bin indices
260
+ Returns:
261
+ (B, L, d_model)
262
+ """
263
+ return self.embed(uncert_bins)
264
+
265
+
266
+ class PromptEmbedding(nn.Module):
267
+ """Learned prompt token embeddings for task/metadata conditioning."""
268
+
269
+ def __init__(self, config: AirTrackConfig):
270
+ super().__init__()
271
+ self.embed = nn.Embedding(config.n_prompt_tokens, config.d_model)
272
+
273
+ def forward(self, prompt_tokens: torch.Tensor) -> torch.Tensor:
274
+ """
275
+ Args:
276
+ prompt_tokens: (B, n_prompt_len) long — prompt token IDs
277
+ Returns:
278
+ (B, n_prompt_len, d_model)
279
+ """
280
+ return self.embed(prompt_tokens)
281
+
282
+
283
+ # ============================================================
284
+ # Positional Encoding
285
+ # ============================================================
286
+
287
+ class SinusoidalPositionalEncoding(nn.Module):
288
+ """Standard sinusoidal positional encoding."""
289
+
290
+ def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
291
+ super().__init__()
292
+ self.dropout = nn.Dropout(p=dropout)
293
+
294
+ pe = torch.zeros(max_len, d_model)
295
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
296
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
297
+ pe[:, 0::2] = torch.sin(position * div_term)
298
+ pe[:, 1::2] = torch.cos(position * div_term)
299
+ pe = pe.unsqueeze(0) # (1, max_len, d_model)
300
+ self.register_buffer('pe', pe)
301
+
302
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
+ """x: (B, L, d_model)"""
304
+ x = x + self.pe[:, :x.size(1)]
305
+ return self.dropout(x)
306
+
307
+
308
+ # ============================================================
309
+ # Transformer Backbone
310
+ # ============================================================
311
+
312
+ class TransformerBlock(nn.Module):
313
+ """Single transformer decoder block with causal attention."""
314
+
315
+ def __init__(self, config: AirTrackConfig):
316
+ super().__init__()
317
+
318
+ self.ln1 = nn.LayerNorm(config.d_model)
319
+ self.attn = nn.MultiheadAttention(
320
+ embed_dim=config.d_model,
321
+ num_heads=config.n_heads,
322
+ dropout=config.dropout,
323
+ batch_first=True,
324
+ )
325
+ self.ln2 = nn.LayerNorm(config.d_model)
326
+ self.ffn = nn.Sequential(
327
+ nn.Linear(config.d_model, config.d_ff),
328
+ nn.GELU(),
329
+ nn.Dropout(config.dropout),
330
+ nn.Linear(config.d_ff, config.d_model),
331
+ nn.Dropout(config.dropout),
332
+ )
333
+
334
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
335
+ """
336
+ Args:
337
+ x: (B, L, d_model)
338
+ attn_mask: (L, L) causal mask
339
+ Returns:
340
+ (B, L, d_model)
341
+ """
342
+ # Pre-norm architecture (like GPT-2)
343
+ h = self.ln1(x)
344
+ h, _ = self.attn(h, h, h, attn_mask=attn_mask, is_causal=(attn_mask is None))
345
+ x = x + h
346
+
347
+ h = self.ln2(x)
348
+ h = self.ffn(h)
349
+ x = x + h
350
+
351
+ return x
352
+
353
+
354
+ # ============================================================
355
+ # Output Heads
356
+ # ============================================================
357
+
358
+ class NextStatePredictionHead(nn.Module):
359
+ """
360
+ Multi-head output for next-state prediction.
361
+ Predicts each feature type independently.
362
+ """
363
+
364
+ def __init__(self, config: AirTrackConfig):
365
+ super().__init__()
366
+
367
+ # Geohash: predict 120 binary bits (sigmoid per bit)
368
+ if config.predict_geohash:
369
+ self.geohash_head = nn.Linear(config.d_model, config.geohash_bits)
370
+
371
+ # Continuous ENU regression (optional secondary objective)
372
+ if config.predict_continuous:
373
+ self.continuous_head = nn.Sequential(
374
+ nn.Linear(config.d_model, config.d_model // 2),
375
+ nn.GELU(),
376
+ nn.Linear(config.d_model // 2, 3), # (Δeast, Δnorth, Δup)
377
+ )
378
+
379
+ # Kinematic feature bin classification
380
+ self.cog_head = nn.Linear(config.d_model, config.n_cog_bins)
381
+ self.sog_head = nn.Linear(config.d_model, config.n_sog_bins)
382
+ self.rot_head = nn.Linear(config.d_model, config.n_rot_bins)
383
+ self.alt_rate_head = nn.Linear(config.d_model, config.n_alt_rate_bins)
384
+
385
+ self.config = config
386
+
387
+ def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
388
+ """
389
+ Args:
390
+ hidden_states: (B, L, d_model) — transformer output
391
+ Returns:
392
+ dict of logits/predictions for each feature
393
+ """
394
+ out = {}
395
+
396
+ if self.config.predict_geohash:
397
+ out['geohash_logits'] = self.geohash_head(hidden_states) # (B, L, 120)
398
+
399
+ if self.config.predict_continuous:
400
+ out['continuous_pred'] = self.continuous_head(hidden_states) # (B, L, 3)
401
+
402
+ out['cog_logits'] = self.cog_head(hidden_states) # (B, L, n_cog_bins)
403
+ out['sog_logits'] = self.sog_head(hidden_states) # (B, L, n_sog_bins)
404
+ out['rot_logits'] = self.rot_head(hidden_states) # (B, L, n_rot_bins)
405
+ out['alt_rate_logits'] = self.alt_rate_head(hidden_states) # (B, L, n_alt_rate_bins)
406
+
407
+ return out
408
+
409
+
410
+ class ClassificationHead(nn.Module):
411
+ """Downstream classification head (attached after pretraining)."""
412
+
413
+ def __init__(self, d_model: int, n_classes: int, dropout: float = 0.1):
414
+ super().__init__()
415
+ self.head = nn.Sequential(
416
+ nn.Linear(d_model, d_model // 2),
417
+ nn.GELU(),
418
+ nn.Dropout(dropout),
419
+ nn.Linear(d_model // 2, n_classes),
420
+ )
421
+
422
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
423
+ """
424
+ Uses the BOS token representation (first position) for classification.
425
+
426
+ Args:
427
+ hidden_states: (B, L, d_model)
428
+ Returns:
429
+ (B, n_classes)
430
+ """
431
+ cls_repr = hidden_states[:, 0, :] # BOS position
432
+ return self.head(cls_repr)
433
+
434
+
435
+ # ============================================================
436
+ # Main Model
437
+ # ============================================================
438
+
439
+ class AirTrackLM(nn.Module):
440
+ """
441
+ AirTrackLM: Decoder-only transformer for air track next-state prediction.
442
+
443
+ Architecture:
444
+ Input → [4 Embedding Types fused additively] → Positional Encoding
445
+ → N × TransformerBlock (causal attention)
446
+ → Multi-head output (geohash + kinematic features)
447
+ """
448
+
449
+ def __init__(self, config: AirTrackConfig):
450
+ super().__init__()
451
+ self.config = config
452
+
453
+ # === Embedding layers ===
454
+
455
+ # Geohash (spatial position)
456
+ if config.geohash_mode == 'absolute':
457
+ self.geohash_embed = GeohashEmbedding(config)
458
+ elif config.geohash_mode == 'continuous':
459
+ self.geohash_embed = ContinuousPositionEmbedding(config)
460
+ elif config.geohash_mode == 'none':
461
+ self.geohash_embed = None
462
+ else:
463
+ # relative and multi_res use same base as absolute
464
+ self.geohash_embed = GeohashEmbedding(config)
465
+
466
+ # Kinematic features
467
+ self.feature_embed = FeatureEmbedding(config)
468
+
469
+ # Temporal
470
+ self.temporal_embed = TemporalEmbedding(config)
471
+
472
+ # Uncertainty — single or multi-method
473
+ if config.use_multi_uncertainty and config.n_uncert_methods > 1:
474
+ from uncertainty import MultiUncertaintyEmbedding
475
+ self.uncertainty_embed = MultiUncertaintyEmbedding(
476
+ config.d_model, config.n_uncert_methods, config.n_uncert_bins
477
+ )
478
+ self._multi_uncert = True
479
+ else:
480
+ self.uncertainty_embed = UncertaintyEmbedding(config)
481
+ self._multi_uncert = False
482
+
483
+ # Heteroscedastic uncertainty head (learned aleatoric)
484
+ self.heteroscedastic_head = None
485
+ if config.use_heteroscedastic:
486
+ from uncertainty import HeteroscedasticHead
487
+ self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)
488
+
489
+ # Prompt
490
+ self.prompt_embed = PromptEmbedding(config)
491
+
492
+ # === Fusion projection ===
493
+ # After additive fusion, project through LayerNorm
494
+ self.fusion_ln = nn.LayerNorm(config.d_model)
495
+
496
+ # === Positional encoding ===
497
+ self.pos_encoding = SinusoidalPositionalEncoding(
498
+ config.d_model, config.max_seq_len, config.dropout
499
+ )
500
+
501
+ # === Transformer blocks ===
502
+ self.blocks = nn.ModuleList([
503
+ TransformerBlock(config) for _ in range(config.n_layers)
504
+ ])
505
+
506
+ # Final layer norm
507
+ self.final_ln = nn.LayerNorm(config.d_model)
508
+
509
+ # === Output heads ===
510
+ self.prediction_head = NextStatePredictionHead(config)
511
+
512
+ # Classification head (optional, for downstream)
513
+ self.classification_head = None
514
+
515
+ # Initialize weights
516
+ self.apply(self._init_weights)
517
+
518
+ def _init_weights(self, module):
519
+ if isinstance(module, nn.Linear):
520
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
521
+ if module.bias is not None:
522
+ torch.nn.init.zeros_(module.bias)
523
+ elif isinstance(module, nn.Embedding):
524
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
525
+ elif isinstance(module, nn.LayerNorm):
526
+ torch.nn.init.ones_(module.weight)
527
+ torch.nn.init.zeros_(module.bias)
528
+
529
+ def attach_classification_head(self, n_classes: int):
530
+ """Attach a classification head for downstream fine-tuning."""
531
+ self.classification_head = ClassificationHead(
532
+ self.config.d_model, n_classes, self.config.dropout
533
+ )
534
+
535
+ def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
536
+ """Generate causal attention mask."""
537
+ mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
538
+ mask = mask.masked_fill(mask == 1, float('-inf'))
539
+ return mask
540
+
541
+ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
542
+ """
543
+ Forward pass.
544
+
545
+ Args:
546
+ batch: dict from AirTrackDataset.__getitem__ (batched)
547
+
548
+ Returns:
549
+ dict with prediction logits and optionally classification logits
550
+ """
551
+ device = batch['cog_bins'].device
552
+ B = batch['cog_bins'].size(0)
553
+
554
+ # --- Build state embeddings ---
555
+
556
+ # Kinematic feature embedding
557
+ feat_emb = self.feature_embed(
558
+ batch['cog_bins'], batch['sog_bins'],
559
+ batch['rot_bins'], batch['alt_rate_bins']
560
+ ) # (B, L, d_model)
561
+
562
+ # Temporal embedding
563
+ temp_emb = self.temporal_embed(
564
+ batch['hour'], batch['dow'], batch['month'],
565
+ batch['second_of_day'], batch['dt']
566
+ ) # (B, L, d_model)
567
+
568
+ # Uncertainty embedding (single or multi-method)
569
+ if self._multi_uncert and 'uncert_bins_multi' in batch:
570
+ uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi']) # (B, L, d_model)
571
+ else:
572
+ uncert_emb = self.uncertainty_embed(batch['uncert_bins']) # (B, L, d_model)
573
+
574
+ # Geohash embedding
575
+ if self.config.geohash_mode == 'continuous':
576
+ geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
577
+ elif self.geohash_embed is not None:
578
+ geo_emb = self.geohash_embed(batch['geohash_bits']) # (B, L, d_model)
579
+ else:
580
+ geo_emb = torch.zeros_like(feat_emb)
581
+
582
+ # --- Additive fusion ---
583
+ state_emb = feat_emb + temp_emb + uncert_emb + geo_emb # (B, L, d_model)
584
+ state_emb = self.fusion_ln(state_emb)
585
+
586
+ # --- Prepend prompt tokens ---
587
+ prompt_emb = self.prompt_embed(batch['prompt']) # (B, n_prompt, d_model)
588
+
589
+ # Concatenate: [PROMPT | STATE_1 | STATE_2 | ... | STATE_T]
590
+ x = torch.cat([prompt_emb, state_emb], dim=1) # (B, n_prompt + L, d_model)
591
+
592
+ # --- Positional encoding ---
593
+ x = self.pos_encoding(x)
594
+
595
+ # --- Causal transformer ---
596
+ seq_len = x.size(1)
597
+ causal_mask = self.get_causal_mask(seq_len, device)
598
+
599
+ for block in self.blocks:
600
+ x = block(x, attn_mask=causal_mask)
601
+
602
+ x = self.final_ln(x)
603
+
604
+ # --- Split output ---
605
+ n_prompt = batch['prompt'].size(1)
606
+ prompt_output = x[:, :n_prompt, :] # (B, n_prompt, d_model)
607
+ state_output = x[:, n_prompt:, :] # (B, L, d_model)
608
+
609
+ # --- Prediction heads (on state output) ---
610
+ predictions = self.prediction_head(state_output)
611
+
612
+ # --- Heteroscedastic uncertainty (learned aleatoric) ---
613
+ if self.heteroscedastic_head is not None:
614
+ predictions['log_var'] = self.heteroscedastic_head(state_output) # (B, L, 6)
615
+
616
+ # --- Classification (optional) ---
617
+ if self.classification_head is not None:
618
+ predictions['class_logits'] = self.classification_head(x) # uses BOS at position 0
619
+
620
+ return predictions
621
+
622
+ def count_parameters(self) -> Dict[str, int]:
623
+ """Count parameters by component."""
624
+ counts = {}
625
+ for name, module in [
626
+ ('geohash_embed', self.geohash_embed),
627
+ ('feature_embed', self.feature_embed),
628
+ ('temporal_embed', self.temporal_embed),
629
+ ('uncertainty_embed', self.uncertainty_embed),
630
+ ('prompt_embed', self.prompt_embed),
631
+ ('transformer_blocks', self.blocks),
632
+ ('prediction_head', self.prediction_head),
633
+ ]:
634
+ if module is not None:
635
+ counts[name] = sum(p.numel() for p in module.parameters())
636
+
637
+ counts['total'] = sum(p.numel() for p in self.parameters())
638
+ counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
639
+ return counts
640
+
641
+
642
+ # ============================================================
643
+ # Loss Function
644
+ # ============================================================
645
+
646
+ class NextStateLoss(nn.Module):
647
+ """
648
+ Multi-task loss for next-state prediction.
649
+
650
+ For each position t, the model predicts features at t+1.
651
+ Losses:
652
+ - Geohash: Binary cross-entropy per bit
653
+ - Kinematic features (COG, SOG, ROT, alt_rate): Cross-entropy per feature
654
+ - Continuous ENU: MSE (optional)
655
+ """
656
+
657
+ def __init__(self, config: AirTrackConfig, loss_weights: Optional[Dict[str, float]] = None):
658
+ super().__init__()
659
+ self.config = config
660
+
661
+ # Default loss weights (equal)
662
+ self.weights = loss_weights or {
663
+ 'geohash': 1.0,
664
+ 'continuous': 0.5,
665
+ 'cog': 1.0,
666
+ 'sog': 1.0,
667
+ 'rot': 1.0,
668
+ 'alt_rate': 1.0,
669
+ }
670
+
671
+ self.ce = nn.CrossEntropyLoss(reduction='mean')
672
+ self.bce = nn.BCEWithLogitsLoss(reduction='mean')
673
+ self.mse = nn.MSELoss(reduction='mean')
674
+
675
+ def forward(
676
+ self,
677
+ predictions: Dict[str, torch.Tensor],
678
+ batch: Dict[str, torch.Tensor],
679
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
680
+ """
681
+ Compute loss. Targets are shifted by 1 (predict next state).
682
+
683
+ predictions[key] is at positions [0, 1, ..., L-1]
684
+ targets are batch[key] at positions [1, 2, ..., L]
685
+
686
+ So we compare predictions[:, :-1, :] with targets[:, 1:, :]
687
+ """
688
+ losses = {}
689
+
690
+ # --- Geohash binary prediction ---
691
+ if self.config.predict_geohash and 'geohash_logits' in predictions:
692
+ # predictions: (B, L, 120), targets: (B, L, 120) float
693
+ pred_geo = predictions['geohash_logits'][:, :-1, :] # (B, L-1, 120)
694
+ target_geo = batch['geohash_bits'][:, 1:, :] # (B, L-1, 120)
695
+ losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']
696
+
697
+ # --- Continuous ENU regression (predict delta in km, not raw meters) ---
698
+ if self.config.predict_continuous and 'continuous_pred' in predictions:
699
+ pred_cont = predictions['continuous_pred'][:, :-1, :] # (B, L-1, 3)
700
+ # Target is delta-ENU: position(t+1) - position(t), normalized to km
701
+ delta_east = (batch['east'][:, 1:] - batch['east'][:, :-1]) / 1000.0
702
+ delta_north = (batch['north'][:, 1:] - batch['north'][:, :-1]) / 1000.0
703
+ delta_up = (batch['up'][:, 1:] - batch['up'][:, :-1]) / 1000.0
704
+ target_delta = torch.stack([delta_east, delta_north, delta_up], dim=-1)
705
+ losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']
706
+
707
+ # --- COG ---
708
+ pred_cog = predictions['cog_logits'][:, :-1, :] # (B, L-1, n_cog_bins)
709
+ target_cog = batch['cog_bins'][:, 1:] # (B, L-1)
710
+ losses['cog'] = self.ce(pred_cog.reshape(-1, pred_cog.size(-1)), target_cog.reshape(-1)) * self.weights['cog']
711
+
712
+ # --- SOG ---
713
+ pred_sog = predictions['sog_logits'][:, :-1, :]
714
+ target_sog = batch['sog_bins'][:, 1:]
715
+ losses['sog'] = self.ce(pred_sog.reshape(-1, pred_sog.size(-1)), target_sog.reshape(-1)) * self.weights['sog']
716
+
717
+ # --- ROT ---
718
+ pred_rot = predictions['rot_logits'][:, :-1, :]
719
+ target_rot = batch['rot_bins'][:, 1:]
720
+ losses['rot'] = self.ce(pred_rot.reshape(-1, pred_rot.size(-1)), target_rot.reshape(-1)) * self.weights['rot']
721
+
722
+ # --- Alt rate ---
723
+ pred_ar = predictions['alt_rate_logits'][:, :-1, :]
724
+ target_ar = batch['alt_rate_bins'][:, 1:]
725
+ losses['alt_rate'] = self.ce(pred_ar.reshape(-1, pred_ar.size(-1)), target_ar.reshape(-1)) * self.weights['alt_rate']
726
+
727
+ # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
728
+ if 'log_var' in predictions:
729
+ log_var = predictions['log_var'][:, :-1, :] # (B, L-1, 6)
730
+ # Regularize: penalize overly high uncertainty (prevent collapse)
731
+ # The individual heads already implicitly learn to attend to uncertainty
732
+ # via the gradient signal, but we add a mild KL-like penalty
733
+ log_var_penalty = 0.01 * log_var.mean()
734
+ losses['log_var_reg'] = log_var_penalty
735
+
736
+ # Total loss
737
+ total_loss = sum(losses.values())
738
+
739
+ # Log individual losses
740
+ loss_log = {k: v.item() for k, v in losses.items()}
741
+ loss_log['total'] = total_loss.item()
742
+
743
+ return total_loss, loss_log
744
+
745
+
746
+ # ============================================================
747
+ # Quick test
748
+ # ============================================================
749
+
750
+ if __name__ == '__main__':
751
+ config = AirTrackConfig()
752
+ model = AirTrackLM(config)
753
+
754
+ # Print parameter counts
755
+ counts = model.count_parameters()
756
+ print("Parameter counts:")
757
+ for name, count in counts.items():
758
+ print(f" {name}: {count:,}")
759
+
760
+ # Test forward pass with dummy data
761
+ B, L = 2, 65 # batch=2, seq_len=65 (64 states + 1 for target shift)
762
+ n_prompt = config.n_prompt_len
763
+
764
+ batch = {
765
+ 'geohash_bits': torch.randn(B, L, config.geohash_bits),
766
+ 'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
767
+ 'sog_bins': torch.randint(0, config.n_sog_bins, (B, L)),
768
+ 'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
769
+ 'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
770
+ 'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
771
+ 'hour': torch.randint(0, 24, (B, L)),
772
+ 'dow': torch.randint(0, 7, (B, L)),
773
+ 'month': torch.randint(0, 12, (B, L)),
774
+ 'second_of_day': torch.rand(B, L) * 86400,
775
+ 'dt': torch.ones(B, L) * 5.0,
776
+ 'prompt': torch.randint(0, config.n_prompt_tokens, (B, n_prompt)),
777
+ 'east': torch.randn(B, L) * 1000,
778
+ 'north': torch.randn(B, L) * 1000,
779
+ 'up': torch.randn(B, L) * 1000,
780
+ }
781
+
782
+ predictions = model(batch)
783
+
784
+ print("\nPrediction shapes:")
785
+ for k, v in predictions.items():
786
+ print(f" {k}: {v.shape}")
787
+
788
+ # Test loss
789
+ loss_fn = NextStateLoss(config)
790
+ total_loss, loss_log = loss_fn(predictions, batch)
791
+ print(f"\nLoss: {loss_log}")