# AirTrackLM / model.py
# Uploaded by Jdice27 with huggingface_hub (commit 2972ac7, verified).
"""
AirTrackLM - Model Architecture
================================
Decoder-only transformer with 4 embedding types for air track next-state prediction.
Embedding types:
1. Geohash: 120-bit binary (40 per ENU axis) → MLP → d_model
2. Kinematic: Learned embeddings for discretized COG/SOG/ROT/alt_rate
3. Temporal: Sinusoidal second-of-day (sub-second) + learned hour/dow/month + Δt
4. Uncertainty: Multi-method learned embeddings with attention fusion
Architecture:
- Additive embedding fusion
- Prompt tokens prepended
- Pre-norm decoder-only transformer with causal masking
- Multi-head output (geohash bits + kinematic bins + continuous ENU regression)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
@dataclass
class AirTrackConfig:
    """Hyperparameters for AirTrackLM: backbone size, embedding vocabularies and output switches."""

    # --- Transformer backbone ---
    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 8
    d_ff: int = 1024
    dropout: float = 0.1
    max_seq_len: int = 256

    # --- Geohash input ---
    geohash_bits: int = 120  # 40 bits for each of the 3 axes
    geohash_hidden: int = 64  # hidden width of the position MLPs

    # --- Kinematic feature vocabularies (number of discretization bins) ---
    n_cog_bins: int = 180  # 2° resolution
    n_sog_bins: int = 300  # 2-knot resolution
    n_rot_bins: int = 120  # 0.1°/s resolution
    n_alt_rate_bins: int = 120  # 100 ft/min resolution

    # --- Temporal features ---
    n_hours: int = 24
    n_dow: int = 7
    n_months: int = 12
    time_sinusoidal_dim: int = 32  # number of sinusoid frequencies (sin + cos doubles it)

    # --- Uncertainty features ---
    n_uncert_bins: int = 16
    n_uncert_methods: int = 4
    use_multi_uncertainty: bool = True
    use_heteroscedastic: bool = True

    # --- Prompt tokens ---
    n_prompt_tokens: int = 23  # prompt vocabulary size
    n_prompt_len: int = 5  # number of prompt tokens prepended to each sequence

    # --- Output heads ---
    predict_geohash: bool = True
    predict_continuous: bool = True
    # 'continuous' and 'none' select alternative position inputs; any other
    # value (including the default) selects the binary geohash MLP.
    geohash_mode: str = 'absolute'
# ============================================================
# Embedding Modules
# ============================================================
class GeohashEmbedding(nn.Module):
    """Embed a binary geohash vector with a two-layer MLP.

    The MLP maps ``config.geohash_bits`` -> ``config.geohash_hidden`` ->
    ``config.d_model`` with a ReLU between the two linear layers.
    """

    def __init__(self, config):
        super().__init__()
        layers = [
            nn.Linear(config.geohash_bits, config.geohash_hidden),
            nn.ReLU(),
            nn.Linear(config.geohash_hidden, config.d_model),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, geohash_bits):
        """Project ``geohash_bits`` (last dim ``config.geohash_bits``) to ``d_model``."""
        embedded = self.projection(geohash_bits)
        return embedded
class ContinuousPositionEmbedding(nn.Module):
    """Ablation variant: embed continuous ENU coordinates with a two-layer MLP.

    The three coordinates are stacked on a new trailing axis and mapped
    3 -> ``config.geohash_hidden`` -> ``config.d_model`` with a ReLU in between.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = config.geohash_hidden
        self.projection = nn.Sequential(
            nn.Linear(3, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, config.d_model),
        )

    def forward(self, east, north, up):
        """Embed same-shaped ``east`` / ``north`` / ``up`` tensors into ``d_model``."""
        enu = torch.stack((east, north, up), dim=-1)
        return self.projection(enu)
class FeatureEmbedding(nn.Module):
    """Sum of four learned embedding tables, one per discretized kinematic feature.

    Tables exist for COG, SOG, ROT and altitude rate; each is sized by the
    matching ``config.n_*_bins`` and has width ``config.d_model``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.cog_embed = nn.Embedding(config.n_cog_bins, width)
        self.sog_embed = nn.Embedding(config.n_sog_bins, width)
        self.rot_embed = nn.Embedding(config.n_rot_bins, width)
        self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, width)

    def forward(self, cog_bins, sog_bins, rot_bins, alt_rate_bins):
        """Look up each feature's bin indices and return the element-wise sum."""
        # Accumulate left to right so the floating-point summation order is fixed.
        total = self.cog_embed(cog_bins)
        total = total + self.sog_embed(sog_bins)
        total = total + self.rot_embed(rot_bins)
        total = total + self.alt_rate_embed(alt_rate_bins)
        return total
class TemporalEmbedding(nn.Module):
    """
    Temporal embedding: learned calendar lookups plus sinusoidal time features.

    Three parts are added together:
      * learned embeddings for hour, day-of-week and month,
      * a linear projection of sin/cos features of second-of-day,
      * a linear projection of sin/cos features of the time step Δt.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        n_freqs = config.time_sinusoidal_dim

        self.hour_embed = nn.Embedding(config.n_hours, width)
        self.dow_embed = nn.Embedding(config.n_dow, width)
        self.month_embed = nn.Embedding(config.n_months, width)
        self.time_sin_dim = n_freqs
        # Inputs are 2 * n_freqs wide because sin and cos are concatenated.
        self.time_projection = nn.Linear(n_freqs * 2, width)
        self.dt_projection = nn.Linear(n_freqs * 2, width)

        # Geometric frequency ladders: base 86400 (seconds per day) for the
        # second-of-day signal, base 3600 for Δt.
        steps = torch.arange(0, n_freqs, dtype=torch.float32)
        day_freqs = torch.exp(steps * -(math.log(86400.0) / n_freqs))
        self.register_buffer('time_freqs', day_freqs)
        step_freqs = torch.exp(steps * -(math.log(3600.0) / n_freqs))
        self.register_buffer('dt_freqs', step_freqs)

    def _sinusoidal(self, values, freqs):
        """Return ``[sin, cos]`` of ``values * freqs * 2π`` concatenated on the last axis."""
        phase = values[..., None] * freqs[None, None, :] * 2 * math.pi
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)

    def forward(self, hour, dow, month, second_of_day, dt):
        """Combine calendar, second-of-day and Δt embeddings into one tensor."""
        calendar = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)
        day_features = self._sinusoidal(second_of_day, self.time_freqs)
        step_features = self._sinusoidal(dt, self.dt_freqs)
        return calendar + self.time_projection(day_features) + self.dt_projection(step_features)
class UncertaintyEmbedding(nn.Module):
    """Single-table uncertainty embedding: one learned vector per uncertainty bin."""

    def __init__(self, config):
        super().__init__()
        n_bins, width = config.n_uncert_bins, config.d_model
        self.embed = nn.Embedding(n_bins, width)

    def forward(self, uncert_bins):
        """Return the learned vector for every integer bin index in ``uncert_bins``."""
        vectors = self.embed(uncert_bins)
        return vectors
class PromptEmbedding(nn.Module):
    """Learned embedding table for prompt token ids (``config.n_prompt_tokens`` entries)."""

    def __init__(self, config):
        super().__init__()
        vocab_size, width = config.n_prompt_tokens, config.d_model
        self.embed = nn.Embedding(vocab_size, width)

    def forward(self, prompt_tokens):
        """Return the learned vector for every integer token id in ``prompt_tokens``."""
        vectors = self.embed(prompt_tokens)
        return vectors
# ============================================================
# Positional Encoding
# ============================================================
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    A table for positions ``0 .. max_len - 1`` is precomputed and stored in the
    ``pe`` buffer with shape ``(1, max_len, d_model)``. Inputs longer than
    ``max_len`` are still supported: the missing positions are computed on the
    fly, so the buffer shape (and therefore saved checkpoints) is unaffected.

    Args:
        d_model: Width of the encoding; odd values are supported.
        max_len: Number of positions precomputed in the ``pe`` buffer.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        table = self._build_table(0, max_len, d_model)
        self.register_buffer('pe', table.unsqueeze(0))

    @staticmethod
    def _build_table(start, stop, d_model, device=None):
        """Return the ``(stop - start, d_model)`` encoding for positions ``start .. stop - 1``."""
        position = torch.arange(start, stop, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        table = torch.zeros(stop - start, d_model, device=device)
        table[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return table

    def forward(self, x):
        """Add positional encodings to ``x`` (batch, seq, d_model) and apply dropout."""
        seq_len = x.size(1)
        cached_len = self.pe.size(1)
        if seq_len <= cached_len:
            pe = self.pe[:, :seq_len]
        else:
            # Extend past the cached table instead of failing on a shape mismatch.
            tail = self._build_table(
                cached_len, seq_len, self.pe.size(2), device=self.pe.device
            ).to(self.pe.dtype)
            pe = torch.cat([self.pe, tail.unsqueeze(0)], dim=1)
        x = x + pe
        return self.dropout(x)
# ============================================================
# Transformer
# ============================================================
class TransformerBlock(nn.Module):
    """Pre-norm transformer decoder block: self-attention and a GELU feed-forward, each with a residual.

    Reads ``d_model``, ``n_heads``, ``d_ff`` and ``dropout`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=config.d_model, num_heads=config.n_heads,
            dropout=config.dropout, batch_first=True,
        )
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x`` (batch, seq, d_model).

        Args:
            x: Input hidden states.
            attn_mask: Optional attention mask forwarded to
                ``nn.MultiheadAttention``. When omitted, a causal mask is
                built so each position attends only to itself and earlier
                positions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if attn_mask is None:
            # nn.MultiheadAttention rejects is_causal=True without an explicit
            # mask, so the causal default has to be materialised here.
            seq_len = x.size(-2)
            attn_mask = torch.triu(
                torch.full((seq_len, seq_len), float('-inf'), device=x.device, dtype=x.dtype),
                diagonal=1,
            )
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask)
        x = x + h
        h = self.ln2(x)
        x = x + self.ffn(h)
        return x
# ============================================================
# Output Heads
# ============================================================
class NextStatePredictionHead(nn.Module):
    """Output heads applied to the decoder hidden states.

    Always present: one linear classifier per kinematic feature (COG, SOG,
    ROT, altitude rate). Optional, controlled by ``config``: a linear geohash
    head (``predict_geohash``) and a small MLP producing 3 continuous values
    (``predict_continuous``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.d_model
        if config.predict_geohash:
            self.geohash_head = nn.Linear(width, config.geohash_bits)
        if config.predict_continuous:
            half_width = width // 2
            self.continuous_head = nn.Sequential(
                nn.Linear(width, half_width),
                nn.GELU(),
                nn.Linear(half_width, 3),
            )
        self.cog_head = nn.Linear(width, config.n_cog_bins)
        self.sog_head = nn.Linear(width, config.n_sog_bins)
        self.rot_head = nn.Linear(width, config.n_rot_bins)
        self.alt_rate_head = nn.Linear(width, config.n_alt_rate_bins)

    def forward(self, hidden_states):
        """Return a dict mapping each enabled output name to that head's result."""
        outputs = {}
        if self.config.predict_geohash:
            outputs['geohash_logits'] = self.geohash_head(hidden_states)
        if self.config.predict_continuous:
            outputs['continuous_pred'] = self.continuous_head(hidden_states)
        for feature in ('cog', 'sog', 'rot', 'alt_rate'):
            head = getattr(self, f'{feature}_head')
            outputs[f'{feature}_logits'] = head(hidden_states)
        return outputs
class ClassificationHead(nn.Module):
    """Sequence classifier: a two-layer MLP applied to the hidden state at sequence index 0.

    Args:
        d_model: Width of the incoming hidden states.
        n_classes: Number of output logits.
        dropout: Dropout probability between the two linear layers.
    """

    def __init__(self, d_model, n_classes, dropout=0.1):
        super().__init__()
        bottleneck = d_model // 2
        self.head = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, n_classes),
        )

    def forward(self, hidden_states):
        """Return class logits computed from ``hidden_states[:, 0, :]``."""
        pooled = hidden_states[:, 0, :]
        return self.head(pooled)
# ============================================================
# Main Model
# ============================================================
class AirTrackLM(nn.Module):
    """Decoder-only transformer for air-track next-state prediction.

    Per-step embeddings (kinematic features, temporal, uncertainty, position)
    are summed and layer-normalised, prompt-token embeddings are prepended,
    and the result is run through a stack of causally masked transformer
    blocks. Predictions are produced for the state positions only.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Position input: continuous ENU MLP, disabled, or binary geohash MLP.
        if config.geohash_mode == 'continuous':
            self.geohash_embed = ContinuousPositionEmbedding(config)
        elif config.geohash_mode == 'none':
            self.geohash_embed = None
        else:
            self.geohash_embed = GeohashEmbedding(config)

        self.feature_embed = FeatureEmbedding(config)
        self.temporal_embed = TemporalEmbedding(config)

        # Uncertainty input: multi-method module (project-local `uncertainty`
        # package, imported lazily) or the single-table fallback.
        if config.use_multi_uncertainty and config.n_uncert_methods > 1:
            from uncertainty import MultiUncertaintyEmbedding
            self.uncertainty_embed = MultiUncertaintyEmbedding(
                config.d_model, config.n_uncert_methods, config.n_uncert_bins
            )
            self._multi_uncert = True
        else:
            self.uncertainty_embed = UncertaintyEmbedding(config)
            self._multi_uncert = False

        # Optional heteroscedastic (log-variance) head.
        self.heteroscedastic_head = None
        if config.use_heteroscedastic:
            from uncertainty import HeteroscedasticHead
            self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)

        self.prompt_embed = PromptEmbedding(config)
        self.fusion_ln = nn.LayerNorm(config.d_model)
        # NOTE: the table is sized by max_seq_len, while forward() encodes
        # prompt + state tokens, i.e. n_prompt + seq_len positions.
        self.pos_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.final_ln = nn.LayerNorm(config.d_model)
        self.prediction_head = NextStatePredictionHead(config)
        self.classification_head = None

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights with N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def attach_classification_head(self, n_classes):
        """Attach a sequence-classification head with ``n_classes`` outputs.

        The head is initialised with the same scheme as the rest of the model
        and placed on the model's current device and dtype, so it can be
        attached after the model has already been moved or cast.
        """
        head = ClassificationHead(self.config.d_model, n_classes, self.config.dropout)
        head.apply(self._init_weights)
        reference = next(self.parameters())
        self.classification_head = head.to(device=reference.device, dtype=reference.dtype)

    def get_causal_mask(self, seq_len, device, dtype=None):
        """Return a ``(seq_len, seq_len)`` additive mask: 0 on/below the diagonal, -inf above.

        ``dtype`` defaults to the default float dtype; pass the activation
        dtype when the model does not run in float32.
        """
        blocked = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
        return torch.triu(blocked, diagonal=1)

    def forward(self, batch):
        """Run the model on a batch dict and return a dict of predictions.

        Keys read from ``batch``: ``cog_bins``, ``sog_bins``, ``rot_bins``,
        ``alt_rate_bins``, ``hour``, ``dow``, ``month``, ``second_of_day``,
        ``dt``, ``prompt``; ``uncert_bins_multi`` (multi-method uncertainty,
        when present) or ``uncert_bins``; and ``geohash_bits`` or
        ``east``/``north``/``up`` depending on ``config.geohash_mode``.

        Returns the prediction-head outputs for the state positions, plus
        ``log_var`` when the heteroscedastic head is enabled and
        ``class_logits`` when a classification head is attached.
        """
        device = batch['cog_bins'].device

        # Feature embedding
        feat_emb = self.feature_embed(
            batch['cog_bins'], batch['sog_bins'],
            batch['rot_bins'], batch['alt_rate_bins']
        )

        # Temporal embedding
        temp_emb = self.temporal_embed(
            batch['hour'], batch['dow'], batch['month'],
            batch['second_of_day'], batch['dt']
        )

        # Uncertainty embedding
        if self._multi_uncert and 'uncert_bins_multi' in batch:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])
        else:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins'])

        # Geohash embedding
        if self.config.geohash_mode == 'continuous':
            geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
        elif self.geohash_embed is not None:
            geo_emb = self.geohash_embed(batch['geohash_bits'])
        else:
            geo_emb = torch.zeros_like(feat_emb)

        # Additive fusion
        state_emb = feat_emb + temp_emb + uncert_emb + geo_emb
        state_emb = self.fusion_ln(state_emb)

        # Prepend prompt
        prompt_emb = self.prompt_embed(batch['prompt'])
        x = torch.cat([prompt_emb, state_emb], dim=1)

        # Positional encoding + transformer
        x = self.pos_encoding(x)
        seq_len = x.size(1)
        causal_mask = self.get_causal_mask(seq_len, device, dtype=x.dtype)
        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)
        x = self.final_ln(x)

        # Split prompt / state outputs
        n_prompt = batch['prompt'].size(1)
        state_output = x[:, n_prompt:, :]

        # Predictions
        predictions = self.prediction_head(state_output)
        if self.heteroscedastic_head is not None:
            predictions['log_var'] = self.heteroscedastic_head(state_output)
        if self.classification_head is not None:
            # ClassificationHead pools index 0 of whatever it receives. Under
            # the causal mask, sequence position 0 has seen only the first
            # prompt token, so hand it the final position instead: that is
            # the only one that has attended to the whole prompt and track.
            predictions['class_logits'] = self.classification_head(x[:, -1:, :])
        return predictions

    def count_parameters(self):
        """Return parameter counts per major sub-module plus ``total`` and ``trainable``."""
        counts = {}
        for name, module in [
            ('geohash_embed', self.geohash_embed),
            ('feature_embed', self.feature_embed),
            ('temporal_embed', self.temporal_embed),
            ('uncertainty_embed', self.uncertainty_embed),
            ('prompt_embed', self.prompt_embed),
            ('transformer_blocks', self.blocks),
            ('prediction_head', self.prediction_head),
        ]:
            # geohash_embed is None when geohash_mode == 'none'.
            if module is not None:
                counts[name] = sum(p.numel() for p in module.parameters())
        counts['total'] = sum(p.numel() for p in self.parameters())
        counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts
# ============================================================
# Loss Function
# ============================================================
class NextStateLoss(nn.Module):
    """Weighted multi-task loss for next-state prediction.

    Predictions at step ``t`` are scored against targets at step ``t + 1``
    (predictions drop their last step, targets drop their first).

    Args:
        config: Object exposing ``predict_geohash`` and ``predict_continuous``.
        loss_weights: Optional mapping of term name -> weight. Any of
            ``geohash``, ``continuous``, ``cog``, ``sog``, ``rot``,
            ``alt_rate`` may be given; omitted terms keep their default
            weight. The mapping is copied, not stored by reference.
    """

    _DEFAULT_WEIGHTS = {
        'geohash': 1.0, 'continuous': 0.5,
        'cog': 1.0, 'sog': 1.0, 'rot': 1.0, 'alt_rate': 1.0,
    }

    def __init__(self, config, loss_weights=None):
        super().__init__()
        self.config = config
        # Merge over the defaults so a partial override does not raise
        # KeyError for the terms it leaves out.
        self.weights = {**self._DEFAULT_WEIGHTS, **(loss_weights or {})}
        self.ce = nn.CrossEntropyLoss(reduction='mean')
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.mse = nn.MSELoss(reduction='mean')

    def forward(self, predictions, batch):
        """Compute the total loss and a per-term log.

        Args:
            predictions: Dict of model outputs; ``*_logits`` for the four
                kinematic features are required, ``geohash_logits``,
                ``continuous_pred`` and ``log_var`` are optional.
            batch: Dict of targets (``*_bins``, and ``geohash_bits`` /
                ``east`` / ``north`` / ``up`` when the matching term is active).

        Returns:
            ``(total_loss, loss_log)`` where ``total_loss`` is a scalar tensor
            and ``loss_log`` maps each term name (and ``'total'``) to a float.
        """
        losses = {}

        if self.config.predict_geohash and 'geohash_logits' in predictions:
            pred_geo = predictions['geohash_logits'][:, :-1, :]
            target_geo = batch['geohash_bits'][:, 1:, :]
            losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']

        if self.config.predict_continuous and 'continuous_pred' in predictions:
            pred_cont = predictions['continuous_pred'][:, :-1, :]
            # Regression target: step-to-step ENU displacement scaled by 1/1000.
            deltas = [
                (batch[axis][:, 1:] - batch[axis][:, :-1]) / 1000.0
                for axis in ('east', 'north', 'up')
            ]
            target_delta = torch.stack(deltas, dim=-1)
            losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']

        for feat in ['cog', 'sog', 'rot', 'alt_rate']:
            pred = predictions[f'{feat}_logits'][:, :-1, :]
            target = batch[f'{feat}_bins'][:, 1:]
            losses[feat] = self.ce(pred.reshape(-1, pred.size(-1)), target.reshape(-1)) * self.weights[feat]

        if 'log_var' in predictions:
            # Only an L2 penalty on the clamped log-variances is applied here;
            # they do not enter any of the terms above.
            log_var = torch.clamp(predictions['log_var'][:, :-1, :], -5.0, 5.0)
            losses['log_var_reg'] = 0.1 * (log_var ** 2).mean()

        total_loss = sum(losses.values())
        loss_log = {k: v.item() for k, v in losses.items()}
        loss_log['total'] = total_loss.item()
        return total_loss, loss_log
if __name__ == '__main__':
    # Smoke test: build the default model, run one synthetic batch through it
    # and print parameter counts, output shapes and the loss breakdown.
    config = AirTrackConfig()
    model = AirTrackLM(config)
    counts = model.count_parameters()
    print("Parameter counts:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    B, L = 2, 65
    batch = {
        # Geohash bits are binary and double as BCE targets, so sample {0, 1}
        # rather than Gaussian noise.
        'geohash_bits': torch.randint(0, 2, (B, L, config.geohash_bits)).float(),
        'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
        'sog_bins': torch.randint(0, config.n_sog_bins, (B, L)),
        'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
        'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
        'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
        'uncert_bins_multi': torch.randint(0, config.n_uncert_bins, (B, L, config.n_uncert_methods)),
        # Calendar indices follow the configured vocabulary sizes.
        'hour': torch.randint(0, config.n_hours, (B, L)),
        'dow': torch.randint(0, config.n_dow, (B, L)),
        'month': torch.randint(0, config.n_months, (B, L)),
        'second_of_day': torch.rand(B, L) * 86400,
        'dt': torch.ones(B, L) * 5.0,
        'prompt': torch.randint(0, config.n_prompt_tokens, (B, config.n_prompt_len)),
        'east': torch.randn(B, L) * 1000,
        'north': torch.randn(B, L) * 1000,
        'up': torch.randn(B, L) * 1000,
    }

    predictions = model(batch)
    print("\nPrediction shapes:")
    for k, v in predictions.items():
        print(f" {k}: {v.shape}")

    loss_fn = NextStateLoss(config)
    total_loss, loss_log = loss_fn(predictions, batch)
    print(f"\nLoss: {loss_log}")