"""
AirTrackLM - Model Architecture
================================
Decoder-only transformer for air track next-state prediction.

State embedding types (following LLM4STP, adapted for aviation):
1. Geohash: 40-bit binary code per ENU axis (120 bits total) → MLP → d_model
2. Feature: learned tables for binned kinematics (COG, SOG, ROT, altitude rate)
3. Temporal: sinusoidal second-of-day and delta-t + learned hour/dow/month embeddings
4. Uncertainty: learned embedding from trajectory-smoothness bins

Prompt tokens (task/aircraft/phase/region metadata) are embedded separately and
prepended to the sequence rather than fused.

Core architecture:
- Additive embedding fusion (E_geo + E_feat + E_temp + E_uncert)
- Prompt tokens prepended to the sequence
- Causal (GPT-style) multi-head self-attention
- Multi-head output: a separate prediction head per feature type
"""
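
# Sequence layout (see AirTrackLM.forward below):
#
#   tokens:  [ p_1 ... p_P | s_1 ... s_T ]    P = n_prompt_len, T = track length
#   s_t    = LayerNorm(E_geo + E_feat + E_temp + E_uncert)    (per timestep t)
#
# The hidden state at each position is trained to predict the state one step
# ahead (see NextStateLoss).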
|
|
import math
import torch
import torch.nn as nn
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
|
|
@dataclass
class AirTrackConfig:
    """Model configuration."""

    # Transformer dimensions
    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 8
    d_ff: int = 1024
    dropout: float = 0.1
    max_seq_len: int = 256

    # Geohash embedding (40 binary bits per ENU axis)
    geohash_bits: int = 120
    geohash_hidden: int = 64

    # Bin counts for the discretized kinematic features
    n_cog_bins: int = 180
    n_sog_bins: int = 300
    n_rot_bins: int = 120
    n_alt_rate_bins: int = 120

    # Temporal embedding
    n_hours: int = 24
    n_dow: int = 7
    n_months: int = 12
    time_sinusoidal_dim: int = 32

    # Uncertainty embedding
    n_uncert_bins: int = 16
    n_uncert_methods: int = 4
    use_multi_uncertainty: bool = True
    use_heteroscedastic: bool = True

    # Prompt tokens: vocabulary size and number of prepended tokens
    n_prompt_tokens: int = 23
    n_prompt_len: int = 5

    # Output heads
    predict_geohash: bool = True
    predict_continuous: bool = True

    # Position input mode: 'absolute' (binary geohash), 'continuous' (raw ENU),
    # or 'none' (no position embedding)
    geohash_mode: str = 'absolute'
|
|
|
|
# ---------------------------------------------------------------------------
# Embedding modules
# ---------------------------------------------------------------------------
|
|
class GeohashEmbedding(nn.Module):
    """
    Binary geohash embedding following LLM4STP.
    Projects the 120-bit binary vector through:
        Linear(120 → geohash_hidden) → ReLU → Linear(geohash_hidden → d_model)

    LLM4STP applies a Conv1d over the bits; we use an MLP for simplicity,
    treating each timestep's 120 bits as a single flat feature vector.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(config.geohash_bits, config.geohash_hidden),
            nn.ReLU(),
            nn.Linear(config.geohash_hidden, config.d_model),
        )

    def forward(self, geohash_bits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            geohash_bits: (batch, seq_len, 120) float tensor of binary geohash bits
        Returns:
            (batch, seq_len, d_model)
        """
        return self.projection(geohash_bits)
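
    # Usage sketch (shapes follow the AirTrackConfig defaults):
    #   emb = GeohashEmbedding(AirTrackConfig())
    #   bits = torch.randint(0, 2, (4, 64, 120)).float()
    #   emb(bits).shape  # torch.Size([4, 64, 256])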
|
|
|
|
class ContinuousPositionEmbedding(nn.Module):
    """Ablation variant V5: a small MLP over raw continuous ENU coordinates,
    used in place of the binary geohash embedding."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(3, config.geohash_hidden),
            nn.ReLU(),
            nn.Linear(config.geohash_hidden, config.d_model),
        )

    def forward(self, east: torch.Tensor, north: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        """
        Args:
            east, north, up: (batch, seq_len) each
        Returns:
            (batch, seq_len, d_model)
        """
        pos = torch.stack([east, north, up], dim=-1)
        return self.projection(pos)
|
|
|
|
class FeatureEmbedding(nn.Module):
    """
    Learned embedding tables for discretized kinematic features.
    Each feature has its own embedding table; all outputs are summed.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.cog_embed = nn.Embedding(config.n_cog_bins, config.d_model)
        self.sog_embed = nn.Embedding(config.n_sog_bins, config.d_model)
        self.rot_embed = nn.Embedding(config.n_rot_bins, config.d_model)
        self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, config.d_model)

    def forward(
        self,
        cog_bins: torch.Tensor,
        sog_bins: torch.Tensor,
        rot_bins: torch.Tensor,
        alt_rate_bins: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            *_bins: (batch, seq_len) long tensors of bin indices
        Returns:
            (batch, seq_len, d_model) — sum of all feature embeddings
        """
        return (
            self.cog_embed(cog_bins) +
            self.sog_embed(sog_bins) +
            self.rot_embed(rot_bins) +
            self.alt_rate_embed(alt_rate_bins)
        )
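
    # Design note: the four tables are summed rather than concatenated, which
    # keeps the fused embedding at d_model and mirrors the additive fusion
    # applied across embedding types in AirTrackLM.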
|
|
|
|
class TemporalEmbedding(nn.Module):
    """
    Temporal embedding combining:
    1. Sinusoidal encoding of second-of-day (sub-second resolution)
    2. Learned embeddings for hour, day-of-week, month
    3. Sinusoidal encoding of delta-t (time since the previous state)

    The sinusoidal encoding gives sub-second precision since it operates
    on continuous float seconds, not discrete bins.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()

        # Learned calendar embeddings
        self.hour_embed = nn.Embedding(config.n_hours, config.d_model)
        self.dow_embed = nn.Embedding(config.n_dow, config.d_model)
        self.month_embed = nn.Embedding(config.n_months, config.d_model)

        # Sinusoidal second-of-day features (dim sin + dim cos), projected to d_model
        self.time_sin_dim = config.time_sinusoidal_dim
        self.time_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)

        # Sinusoidal delta-t features, projected to d_model
        self.dt_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)

        # Geometrically spaced frequency bases: sinusoid periods span ~1 s up to
        # one day (86400 s) for second-of-day, and ~1 s up to one hour (3600 s)
        # for delta-t.
        freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
                          -(math.log(86400.0) / config.time_sinusoidal_dim))
        self.register_buffer('time_freqs', freqs)

        dt_freqs = torch.exp(torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32) *
                             -(math.log(3600.0) / config.time_sinusoidal_dim))
        self.register_buffer('dt_freqs', dt_freqs)

    def sinusoidal_encode(self, values: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        """
        Encode continuous values with multiple sinusoidal frequencies.

        Args:
            values: (batch, seq_len) — continuous values
            freqs: (dim,) — frequency bases
        Returns:
            (batch, seq_len, dim*2) — sin and cos features
        """
        angles = values.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) * 2 * math.pi
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
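
    # Example (illustrative): encoding noon (43200 s) with the default
    # time_sinusoidal_dim = 32:
    #   enc = self.sinusoidal_encode(torch.full((1, 1), 43200.0), self.time_freqs)
    #   enc.shape  # torch.Size([1, 1, 64]), i.e. 32 sin + 32 cos features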

    def forward(
        self,
        hour: torch.Tensor,
        dow: torch.Tensor,
        month: torch.Tensor,
        second_of_day: torch.Tensor,
        dt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hour: (B, L) long — hour of day [0, 23]
            dow: (B, L) long — day of week [0, 6]
            month: (B, L) long — month [0, 11]
            second_of_day: (B, L) float — seconds within the day [0, 86400)
            dt: (B, L) float — delta-t in seconds
        Returns:
            (B, L, d_model)
        """
        # Calendar embeddings
        cal = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)

        # Continuous second-of-day
        time_sin = self.sinusoidal_encode(second_of_day, self.time_freqs)
        time_emb = self.time_projection(time_sin)

        # Time since the previous state
        dt_sin = self.sinusoidal_encode(dt, self.dt_freqs)
        dt_emb = self.dt_projection(dt_sin)

        return cal + time_emb + dt_emb
|
|
|
|
class UncertaintyEmbedding(nn.Module):
    """Learned embedding for trajectory uncertainty bins."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(config.n_uncert_bins, config.d_model)

    def forward(self, uncert_bins: torch.Tensor) -> torch.Tensor:
        """
        Args:
            uncert_bins: (B, L) long — uncertainty bin indices
        Returns:
            (B, L, d_model)
        """
        return self.embed(uncert_bins)
|
|
|
|
class PromptEmbedding(nn.Module):
    """Learned prompt token embeddings for task/metadata conditioning."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(config.n_prompt_tokens, config.d_model)

    def forward(self, prompt_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            prompt_tokens: (B, n_prompt_len) long — prompt token IDs
        Returns:
            (B, n_prompt_len, d_model)
        """
        return self.embed(prompt_tokens)
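
    # Usage sketch: a (B, n_prompt_len) tensor of ids drawn from the 23-token
    # prompt vocabulary. The exact id layout of the task/aircraft/phase/region
    # slots is defined by the data pipeline, not in this module.
    #   self.embed(torch.randint(0, 23, (2, 5))).shape  # (2, 5, d_model)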
|
|
|
|
# ---------------------------------------------------------------------------
# Positional encoding
# ---------------------------------------------------------------------------
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, L, d_model)"""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
|
|
|
# ---------------------------------------------------------------------------
# Transformer block
# ---------------------------------------------------------------------------
|
|
class TransformerBlock(nn.Module):
    """Single pre-LN transformer decoder block with causal self-attention."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()

        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=config.d_model,
            num_heads=config.n_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model),
            nn.Dropout(config.dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, L, d_model)
            attn_mask: (L, L) additive causal mask (-inf above the diagonal)
        Returns:
            (B, L, d_model)
        """
        # Pre-LN self-attention with residual connection. The causal mask is
        # passed explicitly: nn.MultiheadAttention treats is_causal only as a
        # hint, so is_causal=True without a mask would silently run unmasked.
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + h

        # Pre-LN feed-forward with residual connection
        h = self.ln2(x)
        h = self.ffn(h)
        x = x + h

        return x
|
|
|
|
# ---------------------------------------------------------------------------
# Output heads
# ---------------------------------------------------------------------------
|
|
class NextStatePredictionHead(nn.Module):
    """
    Multi-head output for next-state prediction.
    Predicts each feature type independently.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()

        # Per-bit geohash logits (trained with BCE)
        if config.predict_geohash:
            self.geohash_head = nn.Linear(config.d_model, config.geohash_bits)

        # Continuous ENU displacement head (paired with the MSE term in NextStateLoss)
        if config.predict_continuous:
            self.continuous_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Linear(config.d_model // 2, 3),
            )

        # Classification heads over the kinematic feature bins
        self.cog_head = nn.Linear(config.d_model, config.n_cog_bins)
        self.sog_head = nn.Linear(config.d_model, config.n_sog_bins)
        self.rot_head = nn.Linear(config.d_model, config.n_rot_bins)
        self.alt_rate_head = nn.Linear(config.d_model, config.n_alt_rate_bins)

        self.config = config

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            hidden_states: (B, L, d_model) — transformer output
        Returns:
            dict of logits/predictions for each feature
        """
        out = {}

        if self.config.predict_geohash:
            out['geohash_logits'] = self.geohash_head(hidden_states)

        if self.config.predict_continuous:
            out['continuous_pred'] = self.continuous_head(hidden_states)

        out['cog_logits'] = self.cog_head(hidden_states)
        out['sog_logits'] = self.sog_head(hidden_states)
        out['rot_logits'] = self.rot_head(hidden_states)
        out['alt_rate_logits'] = self.alt_rate_head(hidden_states)

        return out
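
    # Decoding sketch (hedged: the bin-to-value mapping lives in the data
    # pipeline, so `bin_centers` below is a hypothetical lookup):
    #   out  = head(hidden)                                   # dict of logits
    #   bits = (torch.sigmoid(out['geohash_logits']) > 0.5)   # geohash bits
    #   sog  = out['sog_logits'].argmax(dim=-1)               # SOG bin index
    #   sog_value = bin_centers['sog'][sog]                   # hypothetical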
|
|
|
|
class ClassificationHead(nn.Module):
    """Downstream classification head (attached after pretraining)."""

    def __init__(self, d_model: int, n_classes: int, dropout: float = 0.1):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, n_classes),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Uses the first-position representation (the first prompt token, which
        plays a BOS-like role) for classification.

        Args:
            hidden_states: (B, L, d_model)
        Returns:
            (B, n_classes)
        """
        cls_repr = hidden_states[:, 0, :]
        return self.head(cls_repr)
|
|
|
|
# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------
|
|
class AirTrackLM(nn.Module):
    """
    AirTrackLM: Decoder-only transformer for air track next-state prediction.

    Architecture:
        Input → [4 embedding types fused additively] → Positional Encoding
        → N × TransformerBlock (causal attention)
        → Multi-head output (geohash + kinematic features)
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.config = config

        # --- Embedding modules ---

        # Position embedding, selected by geohash_mode
        if config.geohash_mode == 'absolute':
            self.geohash_embed = GeohashEmbedding(config)
        elif config.geohash_mode == 'continuous':
            self.geohash_embed = ContinuousPositionEmbedding(config)
        elif config.geohash_mode == 'none':
            self.geohash_embed = None
        else:
            # Any other mode falls back to the binary geohash embedding
            self.geohash_embed = GeohashEmbedding(config)

        # Discretized kinematic features
        self.feature_embed = FeatureEmbedding(config)

        # Calendar + sinusoidal time-of-day / delta-t
        self.temporal_embed = TemporalEmbedding(config)

        # Uncertainty embedding; the multi-method variant lives in
        # uncertainty.py and is imported lazily so this module does not
        # require it unless configured to use it.
        if config.use_multi_uncertainty and config.n_uncert_methods > 1:
            from uncertainty import MultiUncertaintyEmbedding
            self.uncertainty_embed = MultiUncertaintyEmbedding(
                config.d_model, config.n_uncert_methods, config.n_uncert_bins
            )
            self._multi_uncert = True
        else:
            self.uncertainty_embed = UncertaintyEmbedding(config)
            self._multi_uncert = False

        # Optional heteroscedastic head (per-output log-variance, 6 outputs)
        self.heteroscedastic_head = None
        if config.use_heteroscedastic:
            from uncertainty import HeteroscedasticHead
            self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)

        # Prompt token embedding (prepended, not fused)
        self.prompt_embed = PromptEmbedding(config)

        # LayerNorm over the additively fused state embedding
        self.fusion_ln = nn.LayerNorm(config.d_model)

        # Positional encoding must cover the prompt tokens plus the state sequence
        self.pos_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len + config.n_prompt_len, config.dropout
        )

        # Transformer decoder stack
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        self.final_ln = nn.LayerNorm(config.d_model)

        # Next-state prediction heads
        self.prediction_head = NextStatePredictionHead(config)

        # Attached later via attach_classification_head() for fine-tuning
        self.classification_head = None

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def attach_classification_head(self, n_classes: int):
        """Attach a classification head for downstream fine-tuning."""
        self.classification_head = ClassificationHead(
            self.config.d_model, n_classes, self.config.dropout
        )

    def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate an additive causal attention mask (-inf above the diagonal)."""
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

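    # For seq_len = 3, get_causal_mask returns
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position t attends only to positions <= t.
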
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            batch: dict from AirTrackDataset.__getitem__ (batched)

        Returns:
            dict with prediction logits and optionally classification logits
        """
        device = batch['cog_bins'].device

        # --- Fuse the four state embedding types additively ---

        # Kinematic feature bins
        feat_emb = self.feature_embed(
            batch['cog_bins'], batch['sog_bins'],
            batch['rot_bins'], batch['alt_rate_bins']
        )

        # Calendar + time-of-day + delta-t
        temp_emb = self.temporal_embed(
            batch['hour'], batch['dow'], batch['month'],
            batch['second_of_day'], batch['dt']
        )

        # Uncertainty (multi-method bins when available)
        if self._multi_uncert and 'uncert_bins_multi' in batch:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])
        else:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins'])

        # Position, according to geohash_mode; zeros when disabled
        if self.config.geohash_mode == 'continuous':
            geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
        elif self.geohash_embed is not None:
            geo_emb = self.geohash_embed(batch['geohash_bits'])
        else:
            geo_emb = torch.zeros_like(feat_emb)

        # Additive fusion, then LayerNorm to stabilize the summed scale
        state_emb = feat_emb + temp_emb + uncert_emb + geo_emb
        state_emb = self.fusion_ln(state_emb)

        # Prepend prompt tokens to the state sequence
        prompt_emb = self.prompt_embed(batch['prompt'])
        x = torch.cat([prompt_emb, state_emb], dim=1)

        # Positional encoding over the full (prompt + state) sequence
        x = self.pos_encoding(x)

        # Causal transformer stack
        seq_len = x.size(1)
        causal_mask = self.get_causal_mask(seq_len, device)

        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)

        x = self.final_ln(x)

        # Drop prompt positions; the prediction heads run on state positions only
        n_prompt = batch['prompt'].size(1)
        state_output = x[:, n_prompt:, :]

        predictions = self.prediction_head(state_output)

        # Optional per-output log-variance (heteroscedastic uncertainty)
        if self.heteroscedastic_head is not None:
            predictions['log_var'] = self.heteroscedastic_head(state_output)

        # Optional downstream classification; the head reads position 0,
        # i.e. the first prompt token
        if self.classification_head is not None:
            predictions['class_logits'] = self.classification_head(x)

        return predictions

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
        counts = {}
        for name, module in [
            ('geohash_embed', self.geohash_embed),
            ('feature_embed', self.feature_embed),
            ('temporal_embed', self.temporal_embed),
            ('uncertainty_embed', self.uncertainty_embed),
            ('prompt_embed', self.prompt_embed),
            ('heteroscedastic_head', self.heteroscedastic_head),
            ('transformer_blocks', self.blocks),
            ('prediction_head', self.prediction_head),
        ]:
            if module is not None:
                counts[name] = sum(p.numel() for p in module.parameters())

        counts['total'] = sum(p.numel() for p in self.parameters())
        counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts
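
    # Autoregressive rollout sketch (hedged; not implemented here). Each step
    # re-encodes the window after appending the greedy prediction;
    # `append_predicted_state` is hypothetical glue that maps head outputs
    # back into the batch's input encoding.
    #   model.eval()
    #   with torch.no_grad():
    #       for _ in range(horizon):
    #           preds = model(batch)
    #           batch = append_predicted_state(batch, preds)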
|
|
|
|
# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------
|
|
class NextStateLoss(nn.Module):
    """
    Multi-task loss for next-state prediction.

    For each position t, the model predicts the features at t+1.
    Losses:
    - Geohash: binary cross-entropy per bit
    - Kinematic features (COG, SOG, ROT, alt_rate): cross-entropy per feature
    - Continuous ENU: MSE (optional)
    """

    def __init__(self, config: AirTrackConfig, loss_weights: Optional[Dict[str, float]] = None):
        super().__init__()
        self.config = config

        # Per-task loss weights
        self.weights = loss_weights or {
            'geohash': 1.0,
            'continuous': 0.5,
            'cog': 1.0,
            'sog': 1.0,
            'rot': 1.0,
            'alt_rate': 1.0,
        }

        self.ce = nn.CrossEntropyLoss(reduction='mean')
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.mse = nn.MSELoss(reduction='mean')

    def forward(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the loss. Targets are shifted by one step (predict the next
        state): the prediction at position t estimates the state at t+1, so
        predictions[:, :-1] is compared against targets[:, 1:].
        """
        losses = {}

        # Geohash: independent BCE per bit
        if self.config.predict_geohash and 'geohash_logits' in predictions:
            pred_geo = predictions['geohash_logits'][:, :-1, :]
            target_geo = batch['geohash_bits'][:, 1:, :]
            losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']

        # Continuous ENU: MSE on the next-step displacement, scaled by 1000
        # (m → km, assuming metric ENU coordinates)
        if self.config.predict_continuous and 'continuous_pred' in predictions:
            pred_cont = predictions['continuous_pred'][:, :-1, :]
            delta_east = (batch['east'][:, 1:] - batch['east'][:, :-1]) / 1000.0
            delta_north = (batch['north'][:, 1:] - batch['north'][:, :-1]) / 1000.0
            delta_up = (batch['up'][:, 1:] - batch['up'][:, :-1]) / 1000.0
            target_delta = torch.stack([delta_east, delta_north, delta_up], dim=-1)
            losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']

        # Kinematic feature bins: cross-entropy per feature
        pred_cog = predictions['cog_logits'][:, :-1, :]
        target_cog = batch['cog_bins'][:, 1:]
        losses['cog'] = self.ce(pred_cog.reshape(-1, pred_cog.size(-1)), target_cog.reshape(-1)) * self.weights['cog']

        pred_sog = predictions['sog_logits'][:, :-1, :]
        target_sog = batch['sog_bins'][:, 1:]
        losses['sog'] = self.ce(pred_sog.reshape(-1, pred_sog.size(-1)), target_sog.reshape(-1)) * self.weights['sog']

        pred_rot = predictions['rot_logits'][:, :-1, :]
        target_rot = batch['rot_bins'][:, 1:]
        losses['rot'] = self.ce(pred_rot.reshape(-1, pred_rot.size(-1)), target_rot.reshape(-1)) * self.weights['rot']

        pred_ar = predictions['alt_rate_logits'][:, :-1, :]
        target_ar = batch['alt_rate_bins'][:, 1:]
        losses['alt_rate'] = self.ce(pred_ar.reshape(-1, pred_ar.size(-1)), target_ar.reshape(-1)) * self.weights['alt_rate']

        # Heteroscedastic log-variance: clamp for stability, then penalize its
        # magnitude. Note this term only regularizes log_var; it is not an
        # NLL-style coupling with the prediction error.
        if 'log_var' in predictions:
            log_var = predictions['log_var'][:, :-1, :]
            log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
            losses['log_var_reg'] = 0.1 * (log_var_clamped ** 2).mean()

        # Weighted sum of all terms
        total_loss = sum(losses.values())

        # Scalar values for logging
        loss_log = {k: v.item() for k, v in losses.items()}
        loss_log['total'] = total_loss.item()

        return total_loss, loss_log
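
# Typical training step (sketch; dataset/optimizer choices are assumptions):
#   model = AirTrackLM(config)
#   loss_fn = NextStateLoss(config)
#   opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   for batch in loader:
#       total, log = loss_fn(model(batch), batch)
#       opt.zero_grad()
#       total.backward()
#       opt.step()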
|
|
|
|
# ---------------------------------------------------------------------------
# Smoke test
# ---------------------------------------------------------------------------
|
|
if __name__ == '__main__':
    # Self-contained smoke test: disable the components that live in
    # uncertainty.py so this file runs on its own.
    config = AirTrackConfig(use_multi_uncertainty=False, use_heteroscedastic=False)
    model = AirTrackLM(config)

    counts = model.count_parameters()
    print("Parameter counts:")
    for name, count in counts.items():
        print(f"  {name}: {count:,}")

    # Dummy batch with type- and range-correct random inputs
    B, L = 2, 65
    n_prompt = config.n_prompt_len

    batch = {
        # geohash bits must be binary floats: they double as BCE targets
        'geohash_bits': torch.randint(0, 2, (B, L, config.geohash_bits)).float(),
        'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
        'sog_bins': torch.randint(0, config.n_sog_bins, (B, L)),
        'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
        'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
        'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
        'hour': torch.randint(0, 24, (B, L)),
        'dow': torch.randint(0, 7, (B, L)),
        'month': torch.randint(0, 12, (B, L)),
        'second_of_day': torch.rand(B, L) * 86400,
        'dt': torch.ones(B, L) * 5.0,
        'prompt': torch.randint(0, config.n_prompt_tokens, (B, n_prompt)),
        'east': torch.randn(B, L) * 1000,
        'north': torch.randn(B, L) * 1000,
        'up': torch.randn(B, L) * 1000,
    }

    predictions = model(batch)

    print("\nPrediction shapes:")
    for k, v in predictions.items():
        print(f"  {k}: {v.shape}")

    loss_fn = NextStateLoss(config)
    total_loss, loss_log = loss_fn(predictions, batch)
    print(f"\nLoss: {loss_log}")
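
    # Extra check: attach a downstream classification head (n_classes=4 is
    # arbitrary here) and confirm the logits shape.
    model.attach_classification_head(n_classes=4)
    cls_logits = model(batch)['class_logits']
    print(f"class_logits: {cls_logits.shape}")  # expected: (B, 4)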
|
|