# AirTrackLM / model.py
# Jdice27's picture
# Update model.py - fix heteroscedastic loss clamping
# a7372d1 verified
# raw
# history blame
# 30.2 kB
"""
AirTrackLM - Model Architecture
================================
Decoder-only transformer with 4 embedding types for air track next-state prediction.
Embedding types (following LLM4STP, adapted for aviation):
1. Geohash: 40-bit binary per ENU axis (120 bits total) → Linear projection → d_model
2. Temporal: Sinusoidal second-of-day + learned hour/dow/month embeddings
3. Uncertainty: Learned embedding from trajectory smoothness bins
4. Prompt: Learned tokens for task/aircraft/phase/region metadata
Core architecture:
- Additive embedding fusion (E_geo + E_feat + E_temp + E_uncert)
- Prompt tokens prepended to sequence
- Causal (GPT-style) multi-head self-attention
- Multi-head output: separate prediction per feature type
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
# ============================================================
# Configuration
# ============================================================
@dataclass
class AirTrackConfig:
    """Hyperparameters for AirTrackLM and its loss.

    Every field has a default, so ``AirTrackConfig()`` yields the baseline
    setup. Fields are grouped below by the component that consumes them.
    """

    # --- Transformer backbone ---
    d_model: int = 256  # width shared by all embeddings and hidden states
    n_heads: int = 8  # attention heads per transformer block
    n_layers: int = 8  # number of stacked transformer blocks
    d_ff: int = 1024  # hidden width of the feed-forward sublayer
    dropout: float = 0.1
    max_seq_len: int = 256  # longest supported input (prompt + trajectory)

    # --- Geohash embedding (LLM4STP style) ---
    geohash_bits: int = 120  # 40 bits × 3 axes (E, N, U)
    geohash_hidden: int = 64  # hidden width of the position-projection MLP

    # --- Discretized kinematic features (vocabulary sizes) ---
    n_cog_bins: int = 180  # course over ground: 2° steps over [0, 360)
    n_sog_bins: int = 300  # speed over ground: 2-knot steps over [0, 600]
    n_rot_bins: int = 120  # rate of turn: 0.1°/s steps over [-6, 6]
    n_alt_rate_bins: int = 120  # altitude rate: 100 ft/min steps over [-6000, 6000]

    # --- Temporal embedding ---
    n_hours: int = 24
    n_dow: int = 7
    n_months: int = 12
    time_sinusoidal_dim: int = 32  # number of frequencies; sin+cos doubles the feature count

    # --- Uncertainty embedding ---
    n_uncert_bins: int = 16
    n_uncert_methods: int = 4  # kinematic_var, pred_residual, spatial_density, phase_entropy
    use_multi_uncertainty: bool = True  # True → MultiUncertaintyEmbedding is used
    use_heteroscedastic: bool = True  # True → a learned log-variance head is added

    # --- Prompt embedding ---
    n_prompt_tokens: int = 23  # prompt vocabulary size (PromptTokens.VOCAB_SIZE)
    n_prompt_len: int = 5  # [BOS, TASK, AIRCRAFT, PHASE, REGION]

    # --- Output heads ---
    # Predicted quantities: geohash bits (binary classification per bit),
    # an optional continuous ENU delta (regression), and one bin index each
    # for COG, SOG, ROT and altitude rate (classification).
    predict_geohash: bool = True  # emit per-bit geohash logits
    predict_continuous: bool = True  # also emit a continuous ENU offset

    # --- Geohash ablation switch ---
    geohash_mode: str = 'absolute'  # 'absolute', 'none', 'relative', 'multi_res', 'continuous'
# ============================================================
# Embedding Modules
# ============================================================
class GeohashEmbedding(nn.Module):
    """
    Lift a binary geohash vector into model space with a two-layer MLP.

    Pipeline:
        Linear(geohash_bits → geohash_hidden) → ReLU → Linear(geohash_hidden → d_model)

    LLM4STP runs a Conv1d over the bits; an MLP is used here for simplicity
    because each timestep's bit vector is projected independently.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        bits_to_hidden = nn.Linear(config.geohash_bits, config.geohash_hidden)
        hidden_to_model = nn.Linear(config.geohash_hidden, config.d_model)
        self.projection = nn.Sequential(bits_to_hidden, nn.ReLU(), hidden_to_model)

    def forward(self, geohash_bits: torch.Tensor) -> torch.Tensor:
        """
        Project geohash bit vectors to embeddings.

        Args:
            geohash_bits: (batch, seq_len, geohash_bits) float tensor of binary geohash
        Returns:
            (batch, seq_len, d_model)
        """
        embedded = self.projection(geohash_bits)
        return embedded
class ContinuousPositionEmbedding(nn.Module):
    """Ablation variant V5: embed raw continuous ENU coordinates with a small MLP."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        enu_to_hidden = nn.Linear(3, config.geohash_hidden)
        hidden_to_model = nn.Linear(config.geohash_hidden, config.d_model)
        self.projection = nn.Sequential(enu_to_hidden, nn.ReLU(), hidden_to_model)

    def forward(self, east: torch.Tensor, north: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        """
        Stack the three axes on a new last dimension and project them.

        Args:
            east, north, up: (batch, seq_len) each
        Returns:
            (batch, seq_len, d_model)
        """
        coords = torch.stack((east, north, up), dim=-1)  # (B, L, 3)
        return self.projection(coords)
class FeatureEmbedding(nn.Module):
    """
    Embedding tables for the four discretized kinematic features.

    COG, SOG, ROT and altitude rate each get their own table; the four
    looked-up vectors are added together.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        width = config.d_model
        self.cog_embed = nn.Embedding(config.n_cog_bins, width)
        self.sog_embed = nn.Embedding(config.n_sog_bins, width)
        self.rot_embed = nn.Embedding(config.n_rot_bins, width)
        self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, width)

    def forward(
        self,
        cog_bins: torch.Tensor,
        sog_bins: torch.Tensor,
        rot_bins: torch.Tensor,
        alt_rate_bins: torch.Tensor,
    ) -> torch.Tensor:
        """
        Look up and sum the four feature embeddings.

        Args:
            *_bins: (batch, seq_len) long tensors of bin indices
        Returns:
            (batch, seq_len, d_model) — sum of all feature embeddings
        """
        # Accumulate left to right: cog, then sog, then rot, then alt_rate.
        fused = self.cog_embed(cog_bins)
        fused = fused + self.sog_embed(sog_bins)
        fused = fused + self.rot_embed(rot_bins)
        fused = fused + self.alt_rate_embed(alt_rate_bins)
        return fused
class TemporalEmbedding(nn.Module):
    """
    Time embedding built from three summed parts:

    1. Learned lookup tables for hour, day-of-week and month
    2. Sinusoidal features of second-of-day, linearly projected to d_model
    3. Sinusoidal features of delta-t (time since previous state), linearly
       projected to d_model

    The sinusoidal parts act on continuous float seconds rather than on
    discrete bins, so fractional seconds are represented.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        width = config.d_model
        n_freqs = config.time_sinusoidal_dim

        # Learned calendar embeddings
        self.hour_embed = nn.Embedding(config.n_hours, width)
        self.dow_embed = nn.Embedding(config.n_dow, width)
        self.month_embed = nn.Embedding(config.n_months, width)

        # Each sinusoidal encoding yields n_freqs sines plus n_freqs cosines,
        # hence the doubled input width of both projections.
        self.time_sin_dim = n_freqs
        self.time_projection = nn.Linear(n_freqs * 2, width)
        self.dt_projection = nn.Linear(n_freqs * 2, width)

        # Geometrically spaced frequency bases: 86400 ** (-k / n_freqs) for
        # second-of-day and 3600 ** (-k / n_freqs) for delta-t, k = 0..n_freqs-1.
        time_exponents = torch.arange(0, n_freqs, dtype=torch.float32)
        time_freqs = torch.exp(time_exponents * -(math.log(86400.0) / n_freqs))
        self.register_buffer('time_freqs', time_freqs)

        dt_exponents = torch.arange(0, n_freqs, dtype=torch.float32)
        dt_freqs = torch.exp(dt_exponents * -(math.log(3600.0) / n_freqs))
        self.register_buffer('dt_freqs', dt_freqs)

    def sinusoidal_encode(self, values: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        """
        Encode continuous values as sines and cosines at several frequencies.

        Args:
            values: (batch, seq_len) — continuous values
            freqs: (dim,) — frequency bases
        Returns:
            (batch, seq_len, dim*2) — all sines first, then all cosines
        """
        # (B, L, 1) * (1, 1, dim) → (B, L, dim), then scaled to radians
        phase = values[..., None] * freqs[None, None, :]
        phase = phase * 2
        phase = phase * math.pi
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)

    def forward(
        self,
        hour: torch.Tensor,
        dow: torch.Tensor,
        month: torch.Tensor,
        second_of_day: torch.Tensor,
        dt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build the summed temporal embedding.

        Args:
            hour: (B, L) long — hour of day [0, 23]
            dow: (B, L) long — day of week [0, 6]
            month: (B, L) long — month [0, 11]
            second_of_day: (B, L) float — seconds within day [0, 86400)
            dt: (B, L) float — delta-t in seconds
        Returns:
            (B, L, d_model)
        """
        # Calendar lookups
        calendar = self.hour_embed(hour)
        calendar = calendar + self.dow_embed(dow)
        calendar = calendar + self.month_embed(month)

        # Continuous second-of-day → sinusoidal features → d_model
        clock_features = self.sinusoidal_encode(second_of_day, self.time_freqs)
        clock = self.time_projection(clock_features)

        # Continuous delta-t → sinusoidal features → d_model
        gap_features = self.sinusoidal_encode(dt, self.dt_freqs)
        gap = self.dt_projection(gap_features)

        combined = calendar + clock
        combined = combined + gap
        return combined
class UncertaintyEmbedding(nn.Module):
    """Single lookup table mapping an uncertainty bin index to a d_model vector."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=config.n_uncert_bins,
            embedding_dim=config.d_model,
        )

    def forward(self, uncert_bins: torch.Tensor) -> torch.Tensor:
        """
        Look up the embedding of each uncertainty bin.

        Args:
            uncert_bins: (B, L) long — uncertainty bin indices
        Returns:
            (B, L, d_model)
        """
        vectors = self.embed(uncert_bins)
        return vectors
class PromptEmbedding(nn.Module):
    """Lookup table turning prompt token IDs (task/metadata conditioning) into vectors."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=config.n_prompt_tokens,
            embedding_dim=config.d_model,
        )

    def forward(self, prompt_tokens: torch.Tensor) -> torch.Tensor:
        """
        Look up the embedding of each prompt token.

        Args:
            prompt_tokens: (B, n_prompt_len) long — prompt token IDs
        Returns:
            (B, n_prompt_len, d_model)
        """
        vectors = self.embed(prompt_tokens)
        return vectors
# ============================================================
# Positional Encoding
# ============================================================
class SinusoidalPositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    A fixed (1, max_len, d_model) table is precomputed and stored in the
    ``pe`` buffer: even feature columns hold sines, odd columns hold cosines.
    ``forward`` adds the first ``seq_len`` rows to the input and applies dropout.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        """
        Args:
            d_model: feature width of the inputs (even or odd)
            max_len: number of positions in the precomputed table
            dropout: dropout probability applied after adding the encoding
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so drop the surplus frequency to keep shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to ``x`` and apply dropout.

        Args:
            x: (B, L, d_model), with L no larger than ``max_len``
        Returns:
            (B, L, d_model)
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
# ============================================================
# Transformer Backbone
# ============================================================
class TransformerBlock(nn.Module):
    """
    Single pre-norm transformer decoder block with causal self-attention.

    Layout: x + Attn(LN(x)), followed by x + FFN(LN(x)).
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=config.d_model,
            num_heads=config.n_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model),
            nn.Dropout(config.dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply self-attention and the feed-forward network with residuals.

        Args:
            x: (B, L, d_model)
            attn_mask: (L, L) additive attention mask. When omitted, a causal
                mask (``-inf`` strictly above the diagonal) is built here.
        Returns:
            (B, L, d_model)
        """
        # Pre-norm architecture (like GPT-2)
        h = self.ln1(x)
        if attn_mask is None:
            # nn.MultiheadAttention treats ``is_causal`` only as a hint about
            # the mask it is given; passing the hint with no mask either
            # raises or runs unmasked, so build the mask explicitly.
            seq_len = h.size(1)
            attn_mask = torch.triu(
                torch.full((seq_len, seq_len), float('-inf'), device=h.device, dtype=h.dtype),
                diagonal=1,
            )
        h, _ = self.attn(h, h, h, attn_mask=attn_mask)
        x = x + h
        h = self.ln2(x)
        h = self.ffn(h)
        x = x + h
        return x
# ============================================================
# Output Heads
# ============================================================
class NextStatePredictionHead(nn.Module):
    """
    Bundle of independent output heads for next-state prediction.

    Always present: one linear classifier per kinematic feature (COG, SOG,
    ROT, altitude rate). Optional, controlled by the config: per-bit geohash
    logits and a small MLP regressing a 3-component continuous offset.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        # Geohash: one logit per bit (sigmoid is applied by the loss)
        if config.predict_geohash:
            self.geohash_head = nn.Linear(width, config.geohash_bits)

        # Continuous ENU regression (optional secondary objective)
        if config.predict_continuous:
            half = width // 2
            self.continuous_head = nn.Sequential(
                nn.Linear(width, half),
                nn.GELU(),
                nn.Linear(half, 3),  # (Δeast, Δnorth, Δup)
            )

        # Kinematic feature bin classifiers
        self.cog_head = nn.Linear(width, config.n_cog_bins)
        self.sog_head = nn.Linear(width, config.n_sog_bins)
        self.rot_head = nn.Linear(width, config.n_rot_bins)
        self.alt_rate_head = nn.Linear(width, config.n_alt_rate_bins)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run every enabled head on the same hidden states.

        Args:
            hidden_states: (B, L, d_model) — transformer output
        Returns:
            dict with keys 'geohash_logits' (B, L, geohash_bits) and
            'continuous_pred' (B, L, 3) when enabled, plus 'cog_logits',
            'sog_logits', 'rot_logits', 'alt_rate_logits' (B, L, n_bins each)
        """
        cfg = self.config
        results: Dict[str, torch.Tensor] = {}

        if cfg.predict_geohash:
            results['geohash_logits'] = self.geohash_head(hidden_states)
        if cfg.predict_continuous:
            results['continuous_pred'] = self.continuous_head(hidden_states)

        kinematic_heads = (
            ('cog_logits', self.cog_head),
            ('sog_logits', self.sog_head),
            ('rot_logits', self.rot_head),
            ('alt_rate_logits', self.alt_rate_head),
        )
        for key, head in kinematic_heads:
            results[key] = head(hidden_states)
        return results
class ClassificationHead(nn.Module):
    """Two-layer MLP classifier intended to be attached after pretraining."""

    def __init__(self, d_model: int, n_classes: int, dropout: float = 0.1):
        super().__init__()
        bottleneck = d_model // 2
        self.head = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, n_classes),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Classify each sequence from the vector at its first position (BOS).

        Args:
            hidden_states: (B, L, d_model)
        Returns:
            (B, n_classes)
        """
        # NOTE(review): if hidden_states come from a causally masked
        # transformer, position 0 has attended only to itself — confirm the
        # caller supplies a position that has seen the whole sequence.
        first_token = hidden_states[:, 0]
        logits = self.head(first_token)
        return logits
# ============================================================
# Main Model
# ============================================================
class AirTrackLM(nn.Module):
    """
    AirTrackLM: decoder-only transformer for air track next-state prediction.

    Architecture:
        Input → [geohash + kinematic + temporal + uncertainty embeddings, summed]
              → LayerNorm → prompt tokens prepended → positional encoding
              → N × TransformerBlock (causal attention) → LayerNorm
              → multi-head output (geohash + kinematic features)
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.config = config

        # === Embedding layers ===
        # Geohash (spatial position)
        if config.geohash_mode == 'absolute':
            self.geohash_embed = GeohashEmbedding(config)
        elif config.geohash_mode == 'continuous':
            self.geohash_embed = ContinuousPositionEmbedding(config)
        elif config.geohash_mode == 'none':
            self.geohash_embed = None
        else:
            # relative and multi_res use same base as absolute
            self.geohash_embed = GeohashEmbedding(config)

        # Kinematic features
        self.feature_embed = FeatureEmbedding(config)

        # Temporal
        self.temporal_embed = TemporalEmbedding(config)

        # Uncertainty — single or multi-method
        if config.use_multi_uncertainty and config.n_uncert_methods > 1:
            from uncertainty import MultiUncertaintyEmbedding
            self.uncertainty_embed = MultiUncertaintyEmbedding(
                config.d_model, config.n_uncert_methods, config.n_uncert_bins
            )
            self._multi_uncert = True
        else:
            self.uncertainty_embed = UncertaintyEmbedding(config)
            self._multi_uncert = False

        # Heteroscedastic uncertainty head (learned aleatoric)
        self.heteroscedastic_head = None
        if config.use_heteroscedastic:
            from uncertainty import HeteroscedasticHead
            self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)

        # Prompt
        self.prompt_embed = PromptEmbedding(config)

        # === Fusion ===
        # The summed state embeddings are normalized with a LayerNorm
        self.fusion_ln = nn.LayerNorm(config.d_model)

        # === Positional encoding ===
        self.pos_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # === Transformer blocks ===
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Final layer norm
        self.final_ln = nn.LayerNorm(config.d_model)

        # === Output heads ===
        self.prediction_head = NextStatePredictionHead(config)

        # Classification head (optional, for downstream); see
        # attach_classification_head
        self.classification_head = None

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights as N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def attach_classification_head(self, n_classes: int):
        """
        Attach a classification head for downstream fine-tuning.

        The new head is moved to the device and dtype of the existing
        parameters, so it can be attached after the model has already been
        moved (e.g. ``model.to('cuda')``) without a device mismatch in forward.
        """
        head = ClassificationHead(
            self.config.d_model, n_classes, self.config.dropout
        )
        reference = next(self.parameters())
        self.classification_head = head.to(device=reference.device, dtype=reference.dtype)

    def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a (seq_len, seq_len) additive mask: 0 on/below the diagonal, -inf above."""
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            batch: dict from AirTrackDataset.__getitem__ (batched). Keys read
                here: 'cog_bins', 'sog_bins', 'rot_bins', 'alt_rate_bins',
                'hour', 'dow', 'month', 'second_of_day', 'dt', 'prompt',
                'uncert_bins_multi' or 'uncert_bins', and either
                'geohash_bits' or ('east', 'north', 'up') depending on
                ``config.geohash_mode``.
        Returns:
            dict with the prediction-head outputs, plus 'log_var' when the
            heteroscedastic head exists and 'class_logits' when a
            classification head is attached
        """
        device = batch['cog_bins'].device

        # --- Build state embeddings ---
        # Kinematic feature embedding
        feat_emb = self.feature_embed(
            batch['cog_bins'], batch['sog_bins'],
            batch['rot_bins'], batch['alt_rate_bins']
        )  # (B, L, d_model)

        # Temporal embedding
        temp_emb = self.temporal_embed(
            batch['hour'], batch['dow'], batch['month'],
            batch['second_of_day'], batch['dt']
        )  # (B, L, d_model)

        # Uncertainty embedding (single or multi-method)
        if self._multi_uncert and 'uncert_bins_multi' in batch:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])  # (B, L, d_model)
        else:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins'])  # (B, L, d_model)

        # Geohash embedding
        if self.config.geohash_mode == 'continuous':
            geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
        elif self.geohash_embed is not None:
            geo_emb = self.geohash_embed(batch['geohash_bits'])  # (B, L, d_model)
        else:
            geo_emb = torch.zeros_like(feat_emb)

        # --- Additive fusion ---
        state_emb = feat_emb + temp_emb + uncert_emb + geo_emb  # (B, L, d_model)
        state_emb = self.fusion_ln(state_emb)

        # --- Prepend prompt tokens ---
        prompt_emb = self.prompt_embed(batch['prompt'])  # (B, n_prompt, d_model)
        # Concatenate: [PROMPT | STATE_1 | STATE_2 | ... | STATE_T]
        x = torch.cat([prompt_emb, state_emb], dim=1)  # (B, n_prompt + L, d_model)

        # --- Positional encoding ---
        x = self.pos_encoding(x)

        # --- Causal transformer ---
        seq_len = x.size(1)
        causal_mask = self.get_causal_mask(seq_len, device)
        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)
        x = self.final_ln(x)

        # --- Split output: drop the prompt positions ---
        n_prompt = batch['prompt'].size(1)
        state_output = x[:, n_prompt:, :]  # (B, L, d_model)

        # --- Prediction heads (on state output) ---
        predictions = self.prediction_head(state_output)

        # --- Heteroscedastic uncertainty (learned aleatoric) ---
        if self.heteroscedastic_head is not None:
            predictions['log_var'] = self.heteroscedastic_head(state_output)  # (B, L, 6)

        # --- Classification (optional) ---
        if self.classification_head is not None:
            # Under the causal mask, position 0 (BOS) attends only to itself,
            # so its representation carries no trajectory information. The
            # head reads position 0 of whatever it is given, so hand it the
            # final position, which has attended to the prompt and every state.
            # NOTE(review): assumes sequences are not right-padded — no
            # padding mask is used in this forward pass.
            predictions['class_logits'] = self.classification_head(x[:, -1:, :])

        return predictions

    def count_parameters(self) -> Dict[str, int]:
        """
        Count parameters by component.

        Returns:
            dict mapping each present component to its parameter count, plus
            'total' and 'trainable' for the whole model. The two LayerNorms
            and the positional encoding are only included in the totals.
        """
        counts = {}
        for name, module in [
            ('geohash_embed', self.geohash_embed),
            ('feature_embed', self.feature_embed),
            ('temporal_embed', self.temporal_embed),
            ('uncertainty_embed', self.uncertainty_embed),
            ('prompt_embed', self.prompt_embed),
            ('transformer_blocks', self.blocks),
            ('prediction_head', self.prediction_head),
            ('heteroscedastic_head', self.heteroscedastic_head),
            ('classification_head', self.classification_head),
        ]:
            if module is not None:
                counts[name] = sum(p.numel() for p in module.parameters())
        counts['total'] = sum(p.numel() for p in self.parameters())
        counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts
# ============================================================
# Loss Function
# ============================================================
class NextStateLoss(nn.Module):
    """
    Multi-task loss for next-state prediction.

    For each position t, the model predicts features at t+1.
    Losses:
        - Geohash: binary cross-entropy per bit
        - Kinematic features (COG, SOG, ROT, alt_rate): cross-entropy per feature
        - Continuous ENU: MSE on the position delta (optional)
        - log_var: L2 penalty pulling the predicted log-variance toward 0
          (only when 'log_var' is among the predictions)
    """

    # Defaults applied to any weight the caller does not supply.
    _DEFAULT_WEIGHTS = {
        'geohash': 1.0,
        'continuous': 0.5,
        'cog': 1.0,
        'sog': 1.0,
        'rot': 1.0,
        'alt_rate': 1.0,
        'log_var_reg': 0.1,
    }

    def __init__(self, config: AirTrackConfig, loss_weights: Optional[Dict[str, float]] = None):
        """
        Args:
            config: reads ``predict_geohash`` and ``predict_continuous``
            loss_weights: optional per-term weights; may be partial, missing
                keys fall back to the defaults
        """
        super().__init__()
        self.config = config
        # Merge over the defaults so a partial dict cannot cause a KeyError
        # when a term it did not mention is computed.
        self.weights = {**self._DEFAULT_WEIGHTS, **(loss_weights or {})}
        self.ce = nn.CrossEntropyLoss(reduction='mean')
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.mse = nn.MSELoss(reduction='mean')

    def _shifted_ce(self, logits: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between logits at positions [0, L-2] and bin targets at [1, L-1]."""
        pred = logits[:, :-1, :]  # (B, L-1, n_bins)
        target = bins[:, 1:]  # (B, L-1)
        return self.ce(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

    def forward(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the loss. Targets are shifted by 1 (predict next state).

        predictions[key] covers positions [0, 1, ..., L-1]; the prediction at
        position t is scored against batch[key] at position t+1, i.e.
        predictions[:, :-1] is compared with targets[:, 1:].

        Returns:
            (total_loss, loss_log) — the summed weighted loss tensor and a
            dict of Python floats for each weighted term plus 'total'
        """
        losses = {}

        # --- Geohash binary prediction ---
        if self.config.predict_geohash and 'geohash_logits' in predictions:
            pred_geo = predictions['geohash_logits'][:, :-1, :]  # (B, L-1, bits)
            # BCE needs floating-point targets; cast in case bits arrive as ints/bools
            target_geo = batch['geohash_bits'][:, 1:, :].to(pred_geo.dtype)
            losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']

        # --- Continuous ENU regression (delta scaled by 1/1000, e.g. m → km) ---
        if self.config.predict_continuous and 'continuous_pred' in predictions:
            pred_cont = predictions['continuous_pred'][:, :-1, :]  # (B, L-1, 3)
            # Target is delta-ENU: position(t+1) - position(t), divided by 1000
            target_delta = torch.stack(
                [
                    (batch[axis][:, 1:] - batch[axis][:, :-1]) / 1000.0
                    for axis in ('east', 'north', 'up')
                ],
                dim=-1,
            )
            losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']

        # --- Kinematic bins: COG, SOG, ROT, altitude rate ---
        for name in ('cog', 'sog', 'rot', 'alt_rate'):
            term = self._shifted_ce(predictions[f'{name}_logits'], batch[f'{name}_bins'])
            losses[name] = term * self.weights[name]

        # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
        if 'log_var' in predictions:
            log_var = predictions['log_var'][:, :-1, :]  # (B, L-1, 6)
            # Clamp to [-5, 5] so extreme values cannot dominate the penalty
            # (values outside the range receive zero gradient from this term)
            log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
            # Regularize toward 0 (unit variance prior)
            losses['log_var_reg'] = self.weights['log_var_reg'] * (log_var_clamped ** 2).mean()

        # Total loss
        total_loss = sum(losses.values())

        # Log individual losses
        loss_log = {k: v.item() for k, v in losses.items()}
        loss_log['total'] = total_loss.item()
        return total_loss, loss_log
# ============================================================
# Quick test
# ============================================================
if __name__ == '__main__':
    # Smoke test: build the default model, run one forward pass on random
    # data and evaluate the loss once.
    config = AirTrackConfig()
    model = AirTrackLM(config)

    # Print parameter counts
    counts = model.count_parameters()
    print("Parameter counts:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    # Test forward pass with dummy data
    B, L = 2, 65  # batch=2, seq_len=65 (64 states + 1 for target shift)
    n_prompt = config.n_prompt_len
    batch = {
        # Geohash bits double as BCE targets in the loss, so they must be
        # 0/1 values rather than unbounded Gaussian noise.
        'geohash_bits': torch.randint(0, 2, (B, L, config.geohash_bits)).float(),
        'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
        'sog_bins': torch.randint(0, config.n_sog_bins, (B, L)),
        'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
        'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
        # No 'uncert_bins_multi' entry is supplied, so the model's forward
        # pass falls back to 'uncert_bins' for the uncertainty embedding.
        'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
        'hour': torch.randint(0, 24, (B, L)),
        'dow': torch.randint(0, 7, (B, L)),
        'month': torch.randint(0, 12, (B, L)),
        'second_of_day': torch.rand(B, L) * 86400,
        'dt': torch.ones(B, L) * 5.0,
        'prompt': torch.randint(0, config.n_prompt_tokens, (B, n_prompt)),
        'east': torch.randn(B, L) * 1000,
        'north': torch.randn(B, L) * 1000,
        'up': torch.randn(B, L) * 1000,
    }
    predictions = model(batch)
    print("\nPrediction shapes:")
    for k, v in predictions.items():
        print(f" {k}: {v.shape}")

    # Test loss
    loss_fn = NextStateLoss(config)
    total_loss, loss_log = loss_fn(predictions, batch)
    print(f"\nLoss: {loss_log}")