Jdice27 committed · verified
Commit 2972ac7 · 1 Parent(s): 744a6a7

Upload model.py with huggingface_hub

Files changed (1):
  model.py: +134 −434
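The commit message says model.py was uploaded with huggingface_hub. For reference, a minimal sketch of how such an upload is typically done with that library; the repo id below is a placeholder, not taken from this page:

    # Sketch only: upload a single file via the huggingface_hub client.
    # "Jdice27/AirTrackLM" is a hypothetical repo id, not confirmed by this page.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="model.py",
        path_in_repo="model.py",
        repo_id="Jdice27/AirTrackLM",
        commit_message="Upload model.py with huggingface_hub",
    )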
model.py CHANGED
@@ -3,17 +3,17 @@ AirTrackLM - Model Architecture
  ================================
  Decoder-only transformer with 4 embedding types for air track next-state prediction.
 
- Embedding types (following LLM4STP, adapted for aviation):
- 1. Geohash: 40-bit binary per ENU axis (120 bits total) → Linear projection → d_model
- 2. Temporal: Sinusoidal second-of-day + learned hour/dow/month embeddings
- 3. Uncertainty: Learned embedding from trajectory smoothness bins
- 4. Prompt: Learned tokens for task/aircraft/phase/region metadata
-
- Core architecture:
- - Additive embedding fusion (E_geo + E_feat + E_temp + E_uncert)
- - Prompt tokens prepended to sequence
- - Causal (GPT-style) multi-head self-attention
- - Multi-head output: separate prediction per feature type
  """
 
  import math
@@ -24,55 +24,45 @@ from typing import Optional, Dict, Tuple
  from dataclasses import dataclass
 
 
- # ============================================================
- # Configuration
- # ============================================================
-
  @dataclass
  class AirTrackConfig:
-     """Model configuration."""
-
-     # Transformer backbone
      d_model: int = 256
      n_heads: int = 8
      n_layers: int = 8
      d_ff: int = 1024
      dropout: float = 0.1
-     max_seq_len: int = 256  # max sequence length (prompt + trajectory)
 
-     # Geohash embedding (LLM4STP style)
-     geohash_bits: int = 120   # 40 bits × 3 axes (E, N, U)
-     geohash_hidden: int = 64  # intermediate projection dim
 
-     # Feature bins (discretized kinematic features)
-     n_cog_bins: int = 180       # 2° resolution over [0, 360)
-     n_sog_bins: int = 300       # 2-knot resolution over [0, 600]
-     n_rot_bins: int = 120       # 0.1°/s over [-6, 6]
-     n_alt_rate_bins: int = 120  # 100 ft/min over [-6000, 6000]
 
-     # Temporal embedding
      n_hours: int = 24
      n_dow: int = 7
      n_months: int = 12
-     time_sinusoidal_dim: int = 32  # dimension for sinusoidal second-of-day encoding
 
-     # Uncertainty embedding
      n_uncert_bins: int = 16
-     n_uncert_methods: int = 4           # kinematic_var, pred_residual, spatial_density, phase_entropy
-     use_multi_uncertainty: bool = True  # if True, use MultiUncertaintyEmbedding
-     use_heteroscedastic: bool = True    # if True, add learned uncertainty head
-
-     # Prompt embedding
-     n_prompt_tokens: int = 23  # PromptTokens.VOCAB_SIZE
-     n_prompt_len: int = 5      # [BOS, TASK, AIRCRAFT, PHASE, REGION]
 
-     # Output heads
-     # We predict: geohash (regression), COG bin, SOG bin, ROT bin, alt_rate bin
-     predict_geohash: bool = True     # if True, predict geohash bits (binary classification per bit)
-     predict_continuous: bool = True  # if True, also predict continuous ENU offset (regression)
 
-     # Ablation variants for geohash
-     geohash_mode: str = 'absolute'  # 'absolute', 'none', 'relative', 'multi_res', 'continuous'
 
 
  # ============================================================
@@ -80,16 +70,8 @@ class AirTrackConfig:
  # ============================================================
 
  class GeohashEmbedding(nn.Module):
-     """
-     Binary geohash embedding following LLM4STP.
-     Projects 120-bit binary vector through:
-         Linear(120 → geohash_hidden) → ReLU → Linear(geohash_hidden → d_model)
-
-     LLM4STP uses Conv1d on the bits, but we use MLP for simplicity
-     since each timestep's 120 bits are independent.
-     """
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.projection = nn.Sequential(
              nn.Linear(config.geohash_bits, config.geohash_hidden),
@@ -97,20 +79,13 @@ class GeohashEmbedding(nn.Module):
              nn.Linear(config.geohash_hidden, config.d_model),
          )
 
-     def forward(self, geohash_bits: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             geohash_bits: (batch, seq_len, 120) float tensor of binary geohash
-         Returns:
-             (batch, seq_len, d_model)
-         """
          return self.projection(geohash_bits)
 
 
  class ContinuousPositionEmbedding(nn.Module):
-     """Ablation variant V5: direct linear projection of continuous ENU coordinates."""
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.projection = nn.Sequential(
              nn.Linear(3, config.geohash_hidden),
@@ -118,81 +93,41 @@ class ContinuousPositionEmbedding(nn.Module):
              nn.Linear(config.geohash_hidden, config.d_model),
          )
 
-     def forward(self, east: torch.Tensor, north: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             east, north, up: (batch, seq_len) each
-         Returns:
-             (batch, seq_len, d_model)
-         """
-         pos = torch.stack([east, north, up], dim=-1)  # (B, L, 3)
          return self.projection(pos)
 
 
  class FeatureEmbedding(nn.Module):
-     """
-     Learned embedding tables for discretized kinematic features.
-     Each feature has its own embedding table, all outputs summed.
-     """
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.cog_embed = nn.Embedding(config.n_cog_bins, config.d_model)
          self.sog_embed = nn.Embedding(config.n_sog_bins, config.d_model)
          self.rot_embed = nn.Embedding(config.n_rot_bins, config.d_model)
          self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, config.d_model)
 
-     def forward(
-         self,
-         cog_bins: torch.Tensor,
-         sog_bins: torch.Tensor,
-         rot_bins: torch.Tensor,
-         alt_rate_bins: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-         Args:
-             *_bins: (batch, seq_len) long tensors of bin indices
-         Returns:
-             (batch, seq_len, d_model) — sum of all feature embeddings
-         """
-         return (
-             self.cog_embed(cog_bins) +
-             self.sog_embed(sog_bins) +
-             self.rot_embed(rot_bins) +
-             self.alt_rate_embed(alt_rate_bins)
-         )
 
 
  class TemporalEmbedding(nn.Module):
      """
-     Temporal embedding combining:
-     1. Sinusoidal encoding of second-of-day (sub-second resolution)
-     2. Learned embeddings for hour, day-of-week, month
-     3. Sinusoidal encoding of delta-t (time since previous state)
-
-     The sinusoidal encoding gives sub-second precision since it operates
-     on continuous float seconds, not discrete bins.
      """
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
-
-         # Learned calendar embeddings
          self.hour_embed = nn.Embedding(config.n_hours, config.d_model)
          self.dow_embed = nn.Embedding(config.n_dow, config.d_model)
          self.month_embed = nn.Embedding(config.n_months, config.d_model)
 
-         # Sinusoidal projection for continuous time features
-         # second_of_day → sinusoidal features → linear → d_model
          self.time_sin_dim = config.time_sinusoidal_dim
         self.time_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
-
-         # Delta-t projection
          self.dt_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
 
-         # Pre-compute frequency bases for sinusoidal encoding
-         # Multiple frequencies to capture different time scales
-         freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
                            -(math.log(86400.0) / config.time_sinusoidal_dim))
          self.register_buffer('time_freqs', freqs)
@@ -200,83 +135,32 @@ class TemporalEmbedding(nn.Module):
                             -(math.log(3600.0) / config.time_sinusoidal_dim))
          self.register_buffer('dt_freqs', dt_freqs)
 
-     def sinusoidal_encode(self, values: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-         """
-         Encode continuous values with multiple sinusoidal frequencies.
-
-         Args:
-             values: (batch, seq_len) — continuous values
-             freqs: (dim,) — frequency bases
-         Returns:
-             (batch, seq_len, dim*2) — sin and cos features
-         """
-         # (B, L, 1) * (1, 1, dim) → (B, L, dim)
          angles = values.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) * 2 * math.pi
          return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
 
-     def forward(
-         self,
-         hour: torch.Tensor,
-         dow: torch.Tensor,
-         month: torch.Tensor,
-         second_of_day: torch.Tensor,
-         dt: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-         Args:
-             hour: (B, L) long — hour of day [0, 23]
-             dow: (B, L) long — day of week [0, 6]
-             month: (B, L) long — month [0, 11]
-             second_of_day: (B, L) float — seconds within day [0, 86400)
-             dt: (B, L) float — delta-t in seconds
-         Returns:
-             (B, L, d_model)
-         """
-         # Learned calendar embeddings
          cal = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)
-
-         # Sinusoidal second-of-day (sub-second resolution)
-         time_sin = self.sinusoidal_encode(second_of_day, self.time_freqs)  # (B, L, dim*2)
-         time_emb = self.time_projection(time_sin)  # (B, L, d_model)
-
-         # Sinusoidal delta-t
-         dt_sin = self.sinusoidal_encode(dt, self.dt_freqs)  # (B, L, dim*2)
-         dt_emb = self.dt_projection(dt_sin)  # (B, L, d_model)
-
          return cal + time_emb + dt_emb
 
 
  class UncertaintyEmbedding(nn.Module):
-     """Learned embedding for trajectory uncertainty bins."""
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.embed = nn.Embedding(config.n_uncert_bins, config.d_model)
 
-     def forward(self, uncert_bins: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             uncert_bins: (B, L) long — uncertainty bin indices
-         Returns:
-             (B, L, d_model)
-         """
          return self.embed(uncert_bins)
 
 
  class PromptEmbedding(nn.Module):
-     """Learned prompt token embeddings for task/metadata conditioning."""
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.embed = nn.Embedding(config.n_prompt_tokens, config.d_model)
 
-     def forward(self, prompt_tokens: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             prompt_tokens: (B, n_prompt_len) long — prompt token IDs
-         Returns:
-             (B, n_prompt_len, d_model)
-         """
          return self.embed(prompt_tokens)
 
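Aside on the sinusoidal time encoding in the hunk above: the frequency bases satisfy freqs[i] = 86400^(-i/D), so the encoded periods sweep from about one second (i = 0) up to a full day (i near D), which is what gives the second-of-day feature its sub-second resolution. A minimal, self-contained sketch of the same computation outside the module:

    # Sketch: the second-of-day sinusoidal encoding, reproduced standalone.
    import math
    import torch

    D = 32  # time_sinusoidal_dim
    freqs = torch.exp(torch.arange(0, D, dtype=torch.float32) * -(math.log(86400.0) / D))
    # freqs[i] = 86400 ** (-i / D): periods span ~1 s (i = 0) to ~1 day (i = D - 1).

    second_of_day = torch.tensor([[0.0, 43200.0, 86399.5]])  # (B=1, L=3)
    angles = second_of_day.unsqueeze(-1) * freqs * 2 * math.pi  # (1, 3, D)
    features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    print(features.shape)  # torch.Size([1, 3, 64]) -> fed to a Linear(64, d_model)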
 
@@ -285,42 +169,32 @@ class PromptEmbedding(nn.Module):
  # ============================================================
 
  class SinusoidalPositionalEncoding(nn.Module):
-     """Standard sinusoidal positional encoding."""
-
-     def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
          super().__init__()
          self.dropout = nn.Dropout(p=dropout)
-
          pe = torch.zeros(max_len, d_model)
          position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
          div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
          pe[:, 0::2] = torch.sin(position * div_term)
          pe[:, 1::2] = torch.cos(position * div_term)
-         pe = pe.unsqueeze(0)  # (1, max_len, d_model)
-         self.register_buffer('pe', pe)
 
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """x: (B, L, d_model)"""
          x = x + self.pe[:, :x.size(1)]
          return self.dropout(x)
 
 
  # ============================================================
- # Transformer Backbone
  # ============================================================
 
  class TransformerBlock(nn.Module):
-     """Single transformer decoder block with causal attention."""
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
-
          self.ln1 = nn.LayerNorm(config.d_model)
          self.attn = nn.MultiheadAttention(
-             embed_dim=config.d_model,
-             num_heads=config.n_heads,
-             dropout=config.dropout,
-             batch_first=True,
          )
          self.ln2 = nn.LayerNorm(config.d_model)
          self.ffn = nn.Sequential(
@@ -331,23 +205,12 @@ class TransformerBlock(nn.Module):
              nn.Dropout(config.dropout),
          )
 
-     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-         """
-         Args:
-             x: (B, L, d_model)
-             attn_mask: (L, L) causal mask
-         Returns:
-             (B, L, d_model)
-         """
-         # Pre-norm architecture (like GPT-2)
          h = self.ln1(x)
          h, _ = self.attn(h, h, h, attn_mask=attn_mask, is_causal=(attn_mask is None))
          x = x + h
-
          h = self.ln2(x)
-         h = self.ffn(h)
-         x = x + h
-
          return x
 
@@ -356,80 +219,45 @@ class TransformerBlock(nn.Module):
  # ============================================================
 
  class NextStatePredictionHead(nn.Module):
-     """
-     Multi-head output for next-state prediction.
-     Predicts each feature type independently.
-     """
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
-
-         # Geohash: predict 120 binary bits (sigmoid per bit)
          if config.predict_geohash:
              self.geohash_head = nn.Linear(config.d_model, config.geohash_bits)
-
-         # Continuous ENU regression (optional secondary objective)
          if config.predict_continuous:
              self.continuous_head = nn.Sequential(
                  nn.Linear(config.d_model, config.d_model // 2),
                  nn.GELU(),
-                 nn.Linear(config.d_model // 2, 3),  # (Δeast, Δnorth, Δup)
              )
-
-         # Kinematic feature bin classification
          self.cog_head = nn.Linear(config.d_model, config.n_cog_bins)
          self.sog_head = nn.Linear(config.d_model, config.n_sog_bins)
          self.rot_head = nn.Linear(config.d_model, config.n_rot_bins)
          self.alt_rate_head = nn.Linear(config.d_model, config.n_alt_rate_bins)
-
-         self.config = config
 
-     def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
-         """
-         Args:
-             hidden_states: (B, L, d_model) — transformer output
-         Returns:
-             dict of logits/predictions for each feature
-         """
          out = {}
-
          if self.config.predict_geohash:
-             out['geohash_logits'] = self.geohash_head(hidden_states)  # (B, L, 120)
-
          if self.config.predict_continuous:
-             out['continuous_pred'] = self.continuous_head(hidden_states)  # (B, L, 3)
-
-         out['cog_logits'] = self.cog_head(hidden_states)            # (B, L, n_cog_bins)
-         out['sog_logits'] = self.sog_head(hidden_states)            # (B, L, n_sog_bins)
-         out['rot_logits'] = self.rot_head(hidden_states)            # (B, L, n_rot_bins)
-         out['alt_rate_logits'] = self.alt_rate_head(hidden_states)  # (B, L, n_alt_rate_bins)
-
          return out
 
 
  class ClassificationHead(nn.Module):
-     """Downstream classification head (attached after pretraining)."""
-
-     def __init__(self, d_model: int, n_classes: int, dropout: float = 0.1):
          super().__init__()
          self.head = nn.Sequential(
-             nn.Linear(d_model, d_model // 2),
-             nn.GELU(),
-             nn.Dropout(dropout),
-             nn.Linear(d_model // 2, n_classes),
          )
 
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         """
-         Uses the BOS token representation (first position) for classification.
-
-         Args:
-             hidden_states: (B, L, d_model)
-         Returns:
-             (B, n_classes)
-         """
-         cls_repr = hidden_states[:, 0, :]  # BOS position
-         return self.head(cls_repr)
 
 
  # ============================================================
@@ -437,39 +265,22 @@ class ClassificationHead(nn.Module):
  # ============================================================
 
  class AirTrackLM(nn.Module):
-     """
-     AirTrackLM: Decoder-only transformer for air track next-state prediction.
-
-     Architecture:
-         Input → [4 Embedding Types fused additively] → Positional Encoding
-         → N × TransformerBlock (causal attention)
-         → Multi-head output (geohash + kinematic features)
-     """
-
-     def __init__(self, config: AirTrackConfig):
          super().__init__()
          self.config = config
 
-         # === Embedding layers ===
-
-         # Geohash (spatial position)
-         if config.geohash_mode == 'absolute':
-             self.geohash_embed = GeohashEmbedding(config)
-         elif config.geohash_mode == 'continuous':
              self.geohash_embed = ContinuousPositionEmbedding(config)
          elif config.geohash_mode == 'none':
              self.geohash_embed = None
          else:
-             # relative and multi_res use same base as absolute
              self.geohash_embed = GeohashEmbedding(config)
 
-         # Kinematic features
          self.feature_embed = FeatureEmbedding(config)
-
-         # Temporal
          self.temporal_embed = TemporalEmbedding(config)
 
-         # Uncertainty — single or multi-method
          if config.use_multi_uncertainty and config.n_uncert_methods > 1:
              from uncertainty import MultiUncertaintyEmbedding
              self.uncertainty_embed = MultiUncertaintyEmbedding(
@@ -480,119 +291,79 @@ class AirTrackLM(nn.Module):
              self.uncertainty_embed = UncertaintyEmbedding(config)
              self._multi_uncert = False
 
-         # Heteroscedastic uncertainty head (learned aleatoric)
          self.heteroscedastic_head = None
          if config.use_heteroscedastic:
              from uncertainty import HeteroscedasticHead
              self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)
 
-         # Prompt
          self.prompt_embed = PromptEmbedding(config)
-
-         # === Fusion projection ===
-         # After additive fusion, project through LayerNorm
          self.fusion_ln = nn.LayerNorm(config.d_model)
-
-         # === Positional encoding ===
-         self.pos_encoding = SinusoidalPositionalEncoding(
-             config.d_model, config.max_seq_len, config.dropout
-         )
-
-         # === Transformer blocks ===
-         self.blocks = nn.ModuleList([
-             TransformerBlock(config) for _ in range(config.n_layers)
-         ])
-
-         # Final layer norm
          self.final_ln = nn.LayerNorm(config.d_model)
-
-         # === Output heads ===
          self.prediction_head = NextStatePredictionHead(config)
-
-         # Classification head (optional, for downstream)
          self.classification_head = None
 
-         # Initialize weights
          self.apply(self._init_weights)
 
      def _init_weights(self, module):
          if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
              if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
          elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
          elif isinstance(module, nn.LayerNorm):
-             torch.nn.init.ones_(module.weight)
-             torch.nn.init.zeros_(module.bias)
 
-     def attach_classification_head(self, n_classes: int):
-         """Attach a classification head for downstream fine-tuning."""
-         self.classification_head = ClassificationHead(
-             self.config.d_model, n_classes, self.config.dropout
-         )
 
-     def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
-         """Generate causal attention mask."""
          mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
-         mask = mask.masked_fill(mask == 1, float('-inf'))
-         return mask
 
-     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-         """
-         Forward pass.
-
-         Args:
-             batch: dict from AirTrackDataset.__getitem__ (batched)
-
-         Returns:
-             dict with prediction logits and optionally classification logits
-         """
          device = batch['cog_bins'].device
-         B = batch['cog_bins'].size(0)
-
-         # --- Build state embeddings ---
 
-         # Kinematic feature embedding
          feat_emb = self.feature_embed(
-             batch['cog_bins'], batch['sog_bins'],
              batch['rot_bins'], batch['alt_rate_bins']
-         )  # (B, L, d_model)
 
          # Temporal embedding
          temp_emb = self.temporal_embed(
              batch['hour'], batch['dow'], batch['month'],
              batch['second_of_day'], batch['dt']
-         )  # (B, L, d_model)
 
-         # Uncertainty embedding (single or multi-method)
          if self._multi_uncert and 'uncert_bins_multi' in batch:
-             uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])  # (B, L, d_model)
          else:
-             uncert_emb = self.uncertainty_embed(batch['uncert_bins'])  # (B, L, d_model)
 
          # Geohash embedding
          if self.config.geohash_mode == 'continuous':
              geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
          elif self.geohash_embed is not None:
-             geo_emb = self.geohash_embed(batch['geohash_bits'])  # (B, L, d_model)
          else:
              geo_emb = torch.zeros_like(feat_emb)
 
-         # --- Additive fusion ---
-         state_emb = feat_emb + temp_emb + uncert_emb + geo_emb  # (B, L, d_model)
          state_emb = self.fusion_ln(state_emb)
 
-         # --- Prepend prompt tokens ---
-         prompt_emb = self.prompt_embed(batch['prompt'])  # (B, n_prompt, d_model)
 
-         # Concatenate: [PROMPT | STATE_1 | STATE_2 | ... | STATE_T]
-         x = torch.cat([prompt_emb, state_emb], dim=1)  # (B, n_prompt + L, d_model)
-
-         # --- Positional encoding ---
          x = self.pos_encoding(x)
-
-         # --- Causal transformer ---
          seq_len = x.size(1)
          causal_mask = self.get_causal_mask(seq_len, device)
@@ -601,26 +372,22 @@ class AirTrackLM(nn.Module):
 
          x = self.final_ln(x)
 
-         # --- Split output ---
          n_prompt = batch['prompt'].size(1)
-         prompt_output = x[:, :n_prompt, :]  # (B, n_prompt, d_model)
-         state_output = x[:, n_prompt:, :]   # (B, L, d_model)
 
-         # --- Prediction heads (on state output) ---
          predictions = self.prediction_head(state_output)
 
-         # --- Heteroscedastic uncertainty (learned aleatoric) ---
          if self.heteroscedastic_head is not None:
-             predictions['log_var'] = self.heteroscedastic_head(state_output)  # (B, L, 6)
 
-         # --- Classification (optional) ---
          if self.classification_head is not None:
-             predictions['class_logits'] = self.classification_head(x)  # uses BOS at position 0
 
          return predictions
 
-     def count_parameters(self) -> Dict[str, int]:
-         """Count parameters by component."""
          counts = {}
          for name, module in [
              ('geohash_embed', self.geohash_embed),
@@ -633,7 +400,6 @@ class AirTrackLM(nn.Module):
          ]:
              if module is not None:
                  counts[name] = sum(p.numel() for p in module.parameters())
-
          counts['total'] = sum(p.numel() for p in self.parameters())
          counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
          return counts
@@ -644,122 +410,57 @@ class AirTrackLM(nn.Module):
  # ============================================================
 
  class NextStateLoss(nn.Module):
-     """
-     Multi-task loss for next-state prediction.
-
-     For each position t, the model predicts features at t+1.
-     Losses:
-       - Geohash: Binary cross-entropy per bit
-       - Kinematic features (COG, SOG, ROT, alt_rate): Cross-entropy per feature
-       - Continuous ENU: MSE (optional)
-     """
-
-     def __init__(self, config: AirTrackConfig, loss_weights: Optional[Dict[str, float]] = None):
          super().__init__()
          self.config = config
-
-         # Default loss weights (equal)
          self.weights = loss_weights or {
-             'geohash': 1.0,
-             'continuous': 0.5,
-             'cog': 1.0,
-             'sog': 1.0,
-             'rot': 1.0,
-             'alt_rate': 1.0,
          }
-
          self.ce = nn.CrossEntropyLoss(reduction='mean')
          self.bce = nn.BCEWithLogitsLoss(reduction='mean')
          self.mse = nn.MSELoss(reduction='mean')
 
-     def forward(
-         self,
-         predictions: Dict[str, torch.Tensor],
-         batch: Dict[str, torch.Tensor],
-     ) -> Tuple[torch.Tensor, Dict[str, float]]:
-         """
-         Compute loss. Targets are shifted by 1 (predict next state).
-
-         predictions[key] is at positions [0, 1, ..., L-1]
-         targets are batch[key] at positions [1, 2, ..., L]
-
-         So we compare predictions[:, :-1, :] with targets[:, 1:, :]
-         """
          losses = {}
 
-         # --- Geohash binary prediction ---
          if self.config.predict_geohash and 'geohash_logits' in predictions:
-             # predictions: (B, L, 120), targets: (B, L, 120) float
-             pred_geo = predictions['geohash_logits'][:, :-1, :]  # (B, L-1, 120)
-             target_geo = batch['geohash_bits'][:, 1:, :]         # (B, L-1, 120)
              losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']
 
-         # --- Continuous ENU regression (predict delta in km, not raw meters) ---
          if self.config.predict_continuous and 'continuous_pred' in predictions:
-             pred_cont = predictions['continuous_pred'][:, :-1, :]  # (B, L-1, 3)
-             # Target is delta-ENU: position(t+1) - position(t), normalized to km
              delta_east = (batch['east'][:, 1:] - batch['east'][:, :-1]) / 1000.0
              delta_north = (batch['north'][:, 1:] - batch['north'][:, :-1]) / 1000.0
              delta_up = (batch['up'][:, 1:] - batch['up'][:, :-1]) / 1000.0
              target_delta = torch.stack([delta_east, delta_north, delta_up], dim=-1)
              losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']
 
-         # --- COG ---
-         pred_cog = predictions['cog_logits'][:, :-1, :]  # (B, L-1, n_cog_bins)
-         target_cog = batch['cog_bins'][:, 1:]            # (B, L-1)
-         losses['cog'] = self.ce(pred_cog.reshape(-1, pred_cog.size(-1)), target_cog.reshape(-1)) * self.weights['cog']
 
-         # --- SOG ---
-         pred_sog = predictions['sog_logits'][:, :-1, :]
-         target_sog = batch['sog_bins'][:, 1:]
-         losses['sog'] = self.ce(pred_sog.reshape(-1, pred_sog.size(-1)), target_sog.reshape(-1)) * self.weights['sog']
-
-         # --- ROT ---
-         pred_rot = predictions['rot_logits'][:, :-1, :]
-         target_rot = batch['rot_bins'][:, 1:]
-         losses['rot'] = self.ce(pred_rot.reshape(-1, pred_rot.size(-1)), target_rot.reshape(-1)) * self.weights['rot']
-
-         # --- Alt rate ---
-         pred_ar = predictions['alt_rate_logits'][:, :-1, :]
-         target_ar = batch['alt_rate_bins'][:, 1:]
-         losses['alt_rate'] = self.ce(pred_ar.reshape(-1, pred_ar.size(-1)), target_ar.reshape(-1)) * self.weights['alt_rate']
-
-         # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
          if 'log_var' in predictions:
-             log_var = predictions['log_var'][:, :-1, :]  # (B, L-1, 6)
-             # Clamp log_var to prevent collapse: [-5, 5] range
-             log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
-             # Regularize toward 0 (unit variance prior)
-             losses['log_var_reg'] = 0.1 * (log_var_clamped ** 2).mean()
 
-         # Total loss
          total_loss = sum(losses.values())
-
-         # Log individual losses
          loss_log = {k: v.item() for k, v in losses.items()}
          loss_log['total'] = total_loss.item()
-
          return total_loss, loss_log
 
 
- # ============================================================
- # Quick test
- # ============================================================
-
  if __name__ == '__main__':
      config = AirTrackConfig()
      model = AirTrackLM(config)
-
-     # Print parameter counts
      counts = model.count_parameters()
      print("Parameter counts:")
      for name, count in counts.items():
          print(f"  {name}: {count:,}")
 
-     # Test forward pass with dummy data
-     B, L = 2, 65  # batch=2, seq_len=65 (64 states + 1 for target shift)
-     n_prompt = config.n_prompt_len
-
      batch = {
          'geohash_bits': torch.randn(B, L, config.geohash_bits),
          'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
@@ -767,24 +468,23 @@ if __name__ == '__main__':
          'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
          'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
          'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
          'hour': torch.randint(0, 24, (B, L)),
          'dow': torch.randint(0, 7, (B, L)),
          'month': torch.randint(0, 12, (B, L)),
          'second_of_day': torch.rand(B, L) * 86400,
          'dt': torch.ones(B, L) * 5.0,
-         'prompt': torch.randint(0, config.n_prompt_tokens, (B, n_prompt)),
          'east': torch.randn(B, L) * 1000,
          'north': torch.randn(B, L) * 1000,
          'up': torch.randn(B, L) * 1000,
      }
 
      predictions = model(batch)
-
      print("\nPrediction shapes:")
      for k, v in predictions.items():
          print(f"  {k}: {v.shape}")
 
-     # Test loss
      loss_fn = NextStateLoss(config)
      total_loss, loss_log = loss_fn(predictions, batch)
      print(f"\nLoss: {loss_log}")
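The removed NextStateLoss docstring above spells out the target alignment: the prediction made at position t is scored against the state at t + 1, i.e. predictions[:, :-1] against targets[:, 1:]. A toy sketch of that alignment for one of the bin heads (shapes only; not the project's training loop):

    # Sketch: shift-by-one alignment between next-state logits and targets.
    import torch
    import torch.nn as nn

    B, L, n_bins = 2, 5, 180
    logits = torch.randn(B, L, n_bins)           # head output at positions 0..L-1
    targets = torch.randint(0, n_bins, (B, L))   # observed bins at positions 0..L-1

    pred = logits[:, :-1, :]    # prediction emitted at t ...      (B, L-1, n_bins)
    target = targets[:, 1:]     # ... scored against state t + 1   (B, L-1)
    loss = nn.CrossEntropyLoss()(pred.reshape(-1, n_bins), target.reshape(-1))
    print(loss.item())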
 
  ================================
  Decoder-only transformer with 4 embedding types for air track next-state prediction.
 
+ Embedding types:
+ 1. Geohash: 120-bit binary (40 per ENU axis) → MLP → d_model
+ 2. Kinematic: Learned embeddings for discretized COG/SOG/ROT/alt_rate
+ 3. Temporal: Sinusoidal second-of-day (sub-second) + learned hour/dow/month + Δt
+ 4. Uncertainty: Multi-method learned embeddings with attention fusion
+
+ Architecture:
+ - Additive embedding fusion
+ - Prompt tokens prepended
+ - Pre-norm decoder-only transformer with causal masking
+ - Multi-head output (geohash bits + kinematic bins + continuous ENU regression)
  """
 
  import math
 
  from dataclasses import dataclass
 
 
  @dataclass
  class AirTrackConfig:
      d_model: int = 256
      n_heads: int = 8
      n_layers: int = 8
      d_ff: int = 1024
      dropout: float = 0.1
+     max_seq_len: int = 256
 
+     # Geohash
+     geohash_bits: int = 120  # 40 × 3 axes
+     geohash_hidden: int = 64
 
+     # Feature bins
+     n_cog_bins: int = 180        # 2° resolution
+     n_sog_bins: int = 300        # 2-knot resolution
+     n_rot_bins: int = 120        # 0.1°/s resolution
+     n_alt_rate_bins: int = 120   # 100 ft/min resolution
 
+     # Temporal
      n_hours: int = 24
      n_dow: int = 7
      n_months: int = 12
+     time_sinusoidal_dim: int = 32
 
+     # Uncertainty
      n_uncert_bins: int = 16
+     n_uncert_methods: int = 4
+     use_multi_uncertainty: bool = True
+     use_heteroscedastic: bool = True
 
+     # Prompt
+     n_prompt_tokens: int = 23
+     n_prompt_len: int = 5
 
+     # Output
+     predict_geohash: bool = True
+     predict_continuous: bool = True
+     geohash_mode: str = 'absolute'
 
 
  # ============================================================
  # ============================================================
 
  class GeohashEmbedding(nn.Module):
+     """Binary geohash → MLP → d_model."""
+     def __init__(self, config):
          super().__init__()
          self.projection = nn.Sequential(
              nn.Linear(config.geohash_bits, config.geohash_hidden),
              nn.Linear(config.geohash_hidden, config.d_model),
          )
 
+     def forward(self, geohash_bits):
          return self.projection(geohash_bits)
 
 
  class ContinuousPositionEmbedding(nn.Module):
+     """Ablation: direct linear projection of continuous ENU."""
+     def __init__(self, config):
          super().__init__()
          self.projection = nn.Sequential(
              nn.Linear(3, config.geohash_hidden),
              nn.Linear(config.geohash_hidden, config.d_model),
          )
 
+     def forward(self, east, north, up):
+         pos = torch.stack([east, north, up], dim=-1)
          return self.projection(pos)
 
 
  class FeatureEmbedding(nn.Module):
+     """Learned embeddings for discretized kinematic features, summed."""
+     def __init__(self, config):
          super().__init__()
          self.cog_embed = nn.Embedding(config.n_cog_bins, config.d_model)
          self.sog_embed = nn.Embedding(config.n_sog_bins, config.d_model)
          self.rot_embed = nn.Embedding(config.n_rot_bins, config.d_model)
          self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, config.d_model)
 
+     def forward(self, cog_bins, sog_bins, rot_bins, alt_rate_bins):
+         return (self.cog_embed(cog_bins) + self.sog_embed(sog_bins) +
+                 self.rot_embed(rot_bins) + self.alt_rate_embed(alt_rate_bins))
 
  class TemporalEmbedding(nn.Module):
      """
+     Temporal: sinusoidal second-of-day (sub-second precision) + learned calendar + Δt.
      """
+     def __init__(self, config):
          super().__init__()
          self.hour_embed = nn.Embedding(config.n_hours, config.d_model)
          self.dow_embed = nn.Embedding(config.n_dow, config.d_model)
          self.month_embed = nn.Embedding(config.n_months, config.d_model)
 
          self.time_sin_dim = config.time_sinusoidal_dim
          self.time_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
          self.dt_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)
 
+         # Multiple frequency bases for sub-second precision
+         freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
                            -(math.log(86400.0) / config.time_sinusoidal_dim))
          self.register_buffer('time_freqs', freqs)
 
                             -(math.log(3600.0) / config.time_sinusoidal_dim))
          self.register_buffer('dt_freqs', dt_freqs)
 
+     def _sinusoidal(self, values, freqs):
          angles = values.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) * 2 * math.pi
          return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
 
+     def forward(self, hour, dow, month, second_of_day, dt):
          cal = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)
+         time_emb = self.time_projection(self._sinusoidal(second_of_day, self.time_freqs))
+         dt_emb = self.dt_projection(self._sinusoidal(dt, self.dt_freqs))
          return cal + time_emb + dt_emb
 
 
  class UncertaintyEmbedding(nn.Module):
+     def __init__(self, config):
          super().__init__()
          self.embed = nn.Embedding(config.n_uncert_bins, config.d_model)
 
+     def forward(self, uncert_bins):
          return self.embed(uncert_bins)
 
 
  class PromptEmbedding(nn.Module):
+     def __init__(self, config):
          super().__init__()
          self.embed = nn.Embedding(config.n_prompt_tokens, config.d_model)
 
+     def forward(self, prompt_tokens):
          return self.embed(prompt_tokens)
 
 
  # ============================================================
 
  class SinusoidalPositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=512, dropout=0.1):
          super().__init__()
          self.dropout = nn.Dropout(p=dropout)
          pe = torch.zeros(max_len, d_model)
          position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
          div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
          pe[:, 0::2] = torch.sin(position * div_term)
          pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe.unsqueeze(0))
 
+     def forward(self, x):
          x = x + self.pe[:, :x.size(1)]
          return self.dropout(x)
 
 
  # ============================================================
+ # Transformer
  # ============================================================
 
  class TransformerBlock(nn.Module):
+     def __init__(self, config):
          super().__init__()
          self.ln1 = nn.LayerNorm(config.d_model)
          self.attn = nn.MultiheadAttention(
+             embed_dim=config.d_model, num_heads=config.n_heads,
+             dropout=config.dropout, batch_first=True,
          )
          self.ln2 = nn.LayerNorm(config.d_model)
          self.ffn = nn.Sequential(
              nn.Dropout(config.dropout),
          )
 
+     def forward(self, x, attn_mask=None):
          h = self.ln1(x)
          h, _ = self.attn(h, h, h, attn_mask=attn_mask, is_causal=(attn_mask is None))
          x = x + h
          h = self.ln2(x)
+         x = x + self.ffn(h)
          return x
 
 
  # ============================================================
 
  class NextStatePredictionHead(nn.Module):
+     def __init__(self, config):
          super().__init__()
+         self.config = config
          if config.predict_geohash:
              self.geohash_head = nn.Linear(config.d_model, config.geohash_bits)
          if config.predict_continuous:
              self.continuous_head = nn.Sequential(
                  nn.Linear(config.d_model, config.d_model // 2),
                  nn.GELU(),
+                 nn.Linear(config.d_model // 2, 3),
              )
          self.cog_head = nn.Linear(config.d_model, config.n_cog_bins)
          self.sog_head = nn.Linear(config.d_model, config.n_sog_bins)
          self.rot_head = nn.Linear(config.d_model, config.n_rot_bins)
          self.alt_rate_head = nn.Linear(config.d_model, config.n_alt_rate_bins)
 
+     def forward(self, hidden_states):
          out = {}
          if self.config.predict_geohash:
+             out['geohash_logits'] = self.geohash_head(hidden_states)
          if self.config.predict_continuous:
+             out['continuous_pred'] = self.continuous_head(hidden_states)
+         out['cog_logits'] = self.cog_head(hidden_states)
+         out['sog_logits'] = self.sog_head(hidden_states)
+         out['rot_logits'] = self.rot_head(hidden_states)
+         out['alt_rate_logits'] = self.alt_rate_head(hidden_states)
          return out
 
 
  class ClassificationHead(nn.Module):
+     def __init__(self, d_model, n_classes, dropout=0.1):
          super().__init__()
          self.head = nn.Sequential(
+             nn.Linear(d_model, d_model // 2), nn.GELU(),
+             nn.Dropout(dropout), nn.Linear(d_model // 2, n_classes),
          )
 
+     def forward(self, hidden_states):
+         return self.head(hidden_states[:, 0, :])
 
 
  # ============================================================
  # ============================================================
 
  class AirTrackLM(nn.Module):
268
+ def __init__(self, config):
 
 
 
 
 
 
 
 
 
269
  super().__init__()
270
  self.config = config
271
 
272
+ # Geohash embedding
273
+ if config.geohash_mode == 'continuous':
 
 
 
 
274
  self.geohash_embed = ContinuousPositionEmbedding(config)
275
  elif config.geohash_mode == 'none':
276
  self.geohash_embed = None
277
  else:
 
278
  self.geohash_embed = GeohashEmbedding(config)
279
 
 
280
  self.feature_embed = FeatureEmbedding(config)
 
 
281
  self.temporal_embed = TemporalEmbedding(config)
282
 
283
+ # Uncertainty embedding
284
  if config.use_multi_uncertainty and config.n_uncert_methods > 1:
285
  from uncertainty import MultiUncertaintyEmbedding
286
  self.uncertainty_embed = MultiUncertaintyEmbedding(
 
291
  self.uncertainty_embed = UncertaintyEmbedding(config)
292
  self._multi_uncert = False
293
 
294
+ # Heteroscedastic head
295
  self.heteroscedastic_head = None
296
  if config.use_heteroscedastic:
297
  from uncertainty import HeteroscedasticHead
298
  self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)
299
 
 
300
  self.prompt_embed = PromptEmbedding(config)
 
 
 
301
  self.fusion_ln = nn.LayerNorm(config.d_model)
302
+ self.pos_encoding = SinusoidalPositionalEncoding(config.d_model, config.max_seq_len, config.dropout)
303
+ self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
 
 
 
 
 
 
 
 
 
 
304
  self.final_ln = nn.LayerNorm(config.d_model)
 
 
305
  self.prediction_head = NextStatePredictionHead(config)
 
 
306
  self.classification_head = None
307
 
 
308
  self.apply(self._init_weights)
309
 
310
  def _init_weights(self, module):
311
  if isinstance(module, nn.Linear):
312
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
313
  if module.bias is not None:
314
+ nn.init.zeros_(module.bias)
315
  elif isinstance(module, nn.Embedding):
316
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
317
  elif isinstance(module, nn.LayerNorm):
318
+ nn.init.ones_(module.weight)
319
+ nn.init.zeros_(module.bias)
320
 
321
+ def attach_classification_head(self, n_classes):
322
+ self.classification_head = ClassificationHead(self.config.d_model, n_classes, self.config.dropout)
 
 
 
323
 
324
+ def get_causal_mask(self, seq_len, device):
 
325
  mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
326
+ return mask.masked_fill(mask == 1, float('-inf'))
 
327
 
328
+ def forward(self, batch):
 
 
 
 
 
 
 
 
 
329
  device = batch['cog_bins'].device
 
 
 
330
 
331
+ # Feature embedding
332
  feat_emb = self.feature_embed(
333
+ batch['cog_bins'], batch['sog_bins'],
334
  batch['rot_bins'], batch['alt_rate_bins']
335
+ )
336
 
337
  # Temporal embedding
338
  temp_emb = self.temporal_embed(
339
  batch['hour'], batch['dow'], batch['month'],
340
  batch['second_of_day'], batch['dt']
341
+ )
342
 
343
+ # Uncertainty embedding
344
  if self._multi_uncert and 'uncert_bins_multi' in batch:
345
+ uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])
346
  else:
347
+ uncert_emb = self.uncertainty_embed(batch['uncert_bins'])
348
 
349
  # Geohash embedding
350
  if self.config.geohash_mode == 'continuous':
351
  geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
352
  elif self.geohash_embed is not None:
353
+ geo_emb = self.geohash_embed(batch['geohash_bits'])
354
  else:
355
  geo_emb = torch.zeros_like(feat_emb)
356
 
357
+ # Additive fusion
358
+ state_emb = feat_emb + temp_emb + uncert_emb + geo_emb
359
  state_emb = self.fusion_ln(state_emb)
360
 
361
+ # Prepend prompt
362
+ prompt_emb = self.prompt_embed(batch['prompt'])
363
+ x = torch.cat([prompt_emb, state_emb], dim=1)
364
 
365
+ # Positional encoding + transformer
 
 
 
366
  x = self.pos_encoding(x)
 
 
367
  seq_len = x.size(1)
368
  causal_mask = self.get_causal_mask(seq_len, device)
369
 
 
372
 
373
  x = self.final_ln(x)
374
 
375
+ # Split prompt / state outputs
376
  n_prompt = batch['prompt'].size(1)
377
+ state_output = x[:, n_prompt:, :]
 
378
 
379
+ # Predictions
380
  predictions = self.prediction_head(state_output)
381
 
 
382
  if self.heteroscedastic_head is not None:
383
+ predictions['log_var'] = self.heteroscedastic_head(state_output)
384
 
 
385
  if self.classification_head is not None:
386
+ predictions['class_logits'] = self.classification_head(x)
387
 
388
  return predictions
389
 
390
+ def count_parameters(self):
 
391
  counts = {}
392
  for name, module in [
393
  ('geohash_embed', self.geohash_embed),
 
400
  ]:
401
  if module is not None:
402
  counts[name] = sum(p.numel() for p in module.parameters())
 
403
  counts['total'] = sum(p.numel() for p in self.parameters())
404
  counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
405
  return counts
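For reference, get_causal_mask above builds the standard additive causal mask: 0 on and below the diagonal, -inf above it, so each position (prompt tokens included) can only attend to itself and earlier positions. A quick sketch for seq_len = 4:

    # Sketch: the additive causal mask produced by get_causal_mask(4, device).
    import torch

    mask = torch.triu(torch.ones(4, 4), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    print(mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])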
 
  # ============================================================
 
  class NextStateLoss(nn.Module):
+     def __init__(self, config, loss_weights=None):
          super().__init__()
          self.config = config
          self.weights = loss_weights or {
+             'geohash': 1.0, 'continuous': 0.5,
+             'cog': 1.0, 'sog': 1.0, 'rot': 1.0, 'alt_rate': 1.0,
          }
          self.ce = nn.CrossEntropyLoss(reduction='mean')
          self.bce = nn.BCEWithLogitsLoss(reduction='mean')
          self.mse = nn.MSELoss(reduction='mean')
 
+     def forward(self, predictions, batch):
          losses = {}
 
          if self.config.predict_geohash and 'geohash_logits' in predictions:
+             pred_geo = predictions['geohash_logits'][:, :-1, :]
+             target_geo = batch['geohash_bits'][:, 1:, :]
              losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']
 
          if self.config.predict_continuous and 'continuous_pred' in predictions:
+             pred_cont = predictions['continuous_pred'][:, :-1, :]
              delta_east = (batch['east'][:, 1:] - batch['east'][:, :-1]) / 1000.0
              delta_north = (batch['north'][:, 1:] - batch['north'][:, :-1]) / 1000.0
              delta_up = (batch['up'][:, 1:] - batch['up'][:, :-1]) / 1000.0
              target_delta = torch.stack([delta_east, delta_north, delta_up], dim=-1)
              losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']
 
+         for feat in ['cog', 'sog', 'rot', 'alt_rate']:
+             pred = predictions[f'{feat}_logits'][:, :-1, :]
+             target = batch[f'{feat}_bins'][:, 1:]
+             losses[feat] = self.ce(pred.reshape(-1, pred.size(-1)), target.reshape(-1)) * self.weights[feat]
 
          if 'log_var' in predictions:
+             log_var = torch.clamp(predictions['log_var'][:, :-1, :], -5.0, 5.0)
+             losses['log_var_reg'] = 0.1 * (log_var ** 2).mean()
 
          total_loss = sum(losses.values())
          loss_log = {k: v.item() for k, v in losses.items()}
          loss_log['total'] = total_loss.item()
          return total_loss, loss_log
 
 
  if __name__ == '__main__':
      config = AirTrackConfig()
      model = AirTrackLM(config)
      counts = model.count_parameters()
      print("Parameter counts:")
      for name, count in counts.items():
          print(f"  {name}: {count:,}")
 
+     B, L = 2, 65
      batch = {
          'geohash_bits': torch.randn(B, L, config.geohash_bits),
          'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
          'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
          'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
          'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
+         'uncert_bins_multi': torch.randint(0, config.n_uncert_bins, (B, L, config.n_uncert_methods)),
          'hour': torch.randint(0, 24, (B, L)),
          'dow': torch.randint(0, 7, (B, L)),
          'month': torch.randint(0, 12, (B, L)),
          'second_of_day': torch.rand(B, L) * 86400,
          'dt': torch.ones(B, L) * 5.0,
+         'prompt': torch.randint(0, config.n_prompt_tokens, (B, config.n_prompt_len)),
          'east': torch.randn(B, L) * 1000,
          'north': torch.randn(B, L) * 1000,
          'up': torch.randn(B, L) * 1000,
      }
 
      predictions = model(batch)
      print("\nPrediction shapes:")
      for k, v in predictions.items():
          print(f"  {k}: {v.shape}")
 
      loss_fn = NextStateLoss(config)
      total_loss, loss_log = loss_fn(predictions, batch)
      print(f"\nLoss: {loss_log}")
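The docstrings describe the spatial input as 40 binary bits per ENU axis (120 bits total), but the quantizer itself lives outside this file (the test batch above just uses torch.randn as a stand-in). A sketch of one plausible scheme, fixed-point binary expansion of each normalized axis; the axis ranges and MSB-first ordering here are assumptions, not taken from the dataset code:

    # Sketch: a hypothetical 40-bit-per-axis binary position encoding.
    # The real quantizer is in the dataset pipeline, not in model.py.
    import torch

    def axis_to_bits(values, lo, hi, n_bits=40):
        """Map (B, L) coordinates in [lo, hi] to (B, L, n_bits) binary expansions."""
        frac = ((values - lo) / (hi - lo)).clamp(0.0, 1.0)        # normalize to [0, 1]
        ints = (frac.double() * (2 ** n_bits - 1)).long()         # fixed-point integer
        shifts = torch.arange(n_bits - 1, -1, -1)                 # MSB first (assumed)
        return ((ints.unsqueeze(-1) >> shifts) & 1).float()       # (B, L, n_bits)

    B, L = 2, 65
    east, north, up = (torch.randn(B, L) * 1000 for _ in range(3))
    geohash_bits = torch.cat([
        axis_to_bits(east, -5e5, 5e5),    # assumed +/-500 km horizontal extent
        axis_to_bits(north, -5e5, 5e5),
        axis_to_bits(up, -1e4, 3e4),      # assumed -10 km .. +30 km vertical extent
    ], dim=-1)                            # (B, L, 120) == config.geohash_bits
    print(geohash_bits.shape)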