"""
AirTrackLM - Model Architecture
================================

Decoder-only transformer with 4 embedding types for air track next-state prediction.

Embedding types (following LLM4STP, adapted for aviation):
    1. Geohash: 40-bit binary per ENU axis (120 bits total) → Linear projection → d_model
    2. Temporal: Sinusoidal second-of-day + learned hour/dow/month embeddings
    3. Uncertainty: Learned embedding from trajectory smoothness bins
    4. Prompt: Learned tokens for task/aircraft/phase/region metadata

Discretized kinematic features (COG, SOG, ROT, altitude rate) are embedded by
learned tables (E_feat) and fused additively with the per-state embeddings above.

Core architecture:
    - Additive embedding fusion (E_geo + E_feat + E_temp + E_uncert)
    - Prompt tokens prepended to sequence
    - Causal (GPT-style) multi-head self-attention
    - Multi-head output: separate prediction per feature type
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple
from dataclasses import dataclass


# ============================================================
# Configuration
# ============================================================

@dataclass
class AirTrackConfig:
    """Model configuration."""

    # Transformer backbone
    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 8
    d_ff: int = 1024
    dropout: float = 0.1
    max_seq_len: int = 256  # max sequence length (prompt + trajectory)

    # Geohash embedding (LLM4STP style)
    geohash_bits: int = 120  # 40 bits × 3 axes (E, N, U)
    geohash_hidden: int = 64  # intermediate projection dim

    # Feature bins (discretized kinematic features)
    n_cog_bins: int = 180  # 2° resolution over [0, 360)
    n_sog_bins: int = 300  # 2-knot resolution over [0, 600]
    n_rot_bins: int = 120  # 0.1°/s over [-6, 6]
    n_alt_rate_bins: int = 120  # 100 ft/min over [-6000, 6000]

    # Temporal embedding
    n_hours: int = 24
    n_dow: int = 7
    n_months: int = 12
    time_sinusoidal_dim: int = 32  # dimension for sinusoidal second-of-day encoding

    # Uncertainty embedding
    n_uncert_bins: int = 16
    n_uncert_methods: int = 4  # kinematic_var, pred_residual, spatial_density, phase_entropy
    use_multi_uncertainty: bool = True  # if True, use MultiUncertaintyEmbedding
    use_heteroscedastic: bool = True  # if True, add learned uncertainty head

    # Prompt embedding
    n_prompt_tokens: int = 23  # PromptTokens.VOCAB_SIZE
    n_prompt_len: int = 5  # [BOS, TASK, AIRCRAFT, PHASE, REGION]

    # Output heads
    # We predict: geohash bits (binary classification per bit), an optional
    # continuous ENU offset (regression), and COG / SOG / ROT / alt_rate bins.
    predict_geohash: bool = True  # if True, predict geohash bits (binary classification per bit)
    predict_continuous: bool = True  # if True, also predict continuous ENU offset (regression)

    # Ablation variants for geohash
    geohash_mode: str = 'absolute'  # 'absolute', 'none', 'relative', 'multi_res', 'continuous'


# ============================================================
# Embedding Modules
# ============================================================

class GeohashEmbedding(nn.Module):
    """
    Binary geohash embedding following LLM4STP.

    Projects 120-bit binary vector through:
        Linear(120 → geohash_hidden) → ReLU → Linear(geohash_hidden → d_model)

    LLM4STP uses Conv1d on the bits, but we use MLP for simplicity
    since each timestep's 120 bits are independent.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(config.geohash_bits, config.geohash_hidden),
            nn.ReLU(),
            nn.Linear(config.geohash_hidden, config.d_model),
        )

    def forward(self, geohash_bits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            geohash_bits: (batch, seq_len, 120) float tensor of binary geohash

        Returns:
            (batch, seq_len, d_model)
        """
        return self.projection(geohash_bits)


class ContinuousPositionEmbedding(nn.Module):
    """Ablation variant V5: direct linear projection of continuous ENU coordinates."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(3, config.geohash_hidden),
            nn.ReLU(),
            nn.Linear(config.geohash_hidden, config.d_model),
        )

    def forward(self, east: torch.Tensor, north: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        """
        Args:
            east, north, up: (batch, seq_len) each

        Returns:
            (batch, seq_len, d_model)
        """
        pos = torch.stack([east, north, up], dim=-1)  # (B, L, 3)
        return self.projection(pos)


class FeatureEmbedding(nn.Module):
    """
    Learned embedding tables for discretized kinematic features.

    Each feature has its own embedding table, all outputs summed.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.cog_embed = nn.Embedding(config.n_cog_bins, config.d_model)
        self.sog_embed = nn.Embedding(config.n_sog_bins, config.d_model)
        self.rot_embed = nn.Embedding(config.n_rot_bins, config.d_model)
        self.alt_rate_embed = nn.Embedding(config.n_alt_rate_bins, config.d_model)

    def forward(
        self,
        cog_bins: torch.Tensor,
        sog_bins: torch.Tensor,
        rot_bins: torch.Tensor,
        alt_rate_bins: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            *_bins: (batch, seq_len) long tensors of bin indices

        Returns:
            (batch, seq_len, d_model) — sum of all feature embeddings
        """
        return (
            self.cog_embed(cog_bins)
            + self.sog_embed(sog_bins)
            + self.rot_embed(rot_bins)
            + self.alt_rate_embed(alt_rate_bins)
        )


class TemporalEmbedding(nn.Module):
    """
    Temporal embedding combining:
        1. Sinusoidal encoding of second-of-day (sub-second resolution)
        2. Learned embeddings for hour, day-of-week, month
        3. Sinusoidal encoding of delta-t (time since previous state)

    The sinusoidal encoding gives sub-second precision since it operates on
    continuous float seconds, not discrete bins.
    """

    def __init__(self, config: AirTrackConfig):
        super().__init__()

        # Learned calendar embeddings
        self.hour_embed = nn.Embedding(config.n_hours, config.d_model)
        self.dow_embed = nn.Embedding(config.n_dow, config.d_model)
        self.month_embed = nn.Embedding(config.n_months, config.d_model)

        # Sinusoidal projection for continuous time features
        # second_of_day → sinusoidal features → linear → d_model
        self.time_sin_dim = config.time_sinusoidal_dim
        self.time_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)

        # Delta-t projection
        self.dt_projection = nn.Linear(config.time_sinusoidal_dim * 2, config.d_model)

        # Pre-compute frequency bases for sinusoidal encoding
        # Multiple frequencies to capture different time scales
        freqs = torch.exp(
            torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32)
            * -(math.log(86400.0) / config.time_sinusoidal_dim)
        )
        self.register_buffer('time_freqs', freqs)

        dt_freqs = torch.exp(
            torch.arange(0, config.time_sinusoidal_dim, dtype=torch.float32)
            * -(math.log(3600.0) / config.time_sinusoidal_dim)
        )
        self.register_buffer('dt_freqs', dt_freqs)

    def sinusoidal_encode(self, values: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        """
        Encode continuous values with multiple sinusoidal frequencies.

        Args:
            values: (batch, seq_len) — continuous values
            freqs: (dim,) — frequency bases

        Returns:
            (batch, seq_len, dim*2) — sin and cos features
        """
        # (B, L, 1) * (1, 1, dim) → (B, L, dim)
        angles = values.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0) * 2 * math.pi
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(
        self,
        hour: torch.Tensor,
        dow: torch.Tensor,
        month: torch.Tensor,
        second_of_day: torch.Tensor,
        dt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hour: (B, L) long — hour of day [0, 23]
            dow: (B, L) long — day of week [0, 6]
            month: (B, L) long — month [0, 11]
            second_of_day: (B, L) float — seconds within day [0, 86400)
            dt: (B, L) float — delta-t in seconds

        Returns:
            (B, L, d_model)
        """
        # Learned calendar embeddings
        cal = self.hour_embed(hour) + self.dow_embed(dow) + self.month_embed(month)

        # Sinusoidal second-of-day (sub-second resolution)
        time_sin = self.sinusoidal_encode(second_of_day, self.time_freqs)  # (B, L, dim*2)
        time_emb = self.time_projection(time_sin)  # (B, L, d_model)

        # Sinusoidal delta-t
        dt_sin = self.sinusoidal_encode(dt, self.dt_freqs)  # (B, L, dim*2)
        dt_emb = self.dt_projection(dt_sin)  # (B, L, d_model)

        return cal + time_emb + dt_emb


class UncertaintyEmbedding(nn.Module):
    """Learned embedding for trajectory uncertainty bins."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(config.n_uncert_bins, config.d_model)

    def forward(self, uncert_bins: torch.Tensor) -> torch.Tensor:
        """
        Args:
            uncert_bins: (B, L) long — uncertainty bin indices

        Returns:
            (B, L, d_model)
        """
        return self.embed(uncert_bins)


class PromptEmbedding(nn.Module):
    """Learned prompt token embeddings for task/metadata conditioning."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.embed = nn.Embedding(config.n_prompt_tokens, config.d_model)

    def forward(self, prompt_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            prompt_tokens: (B, n_prompt_len) long — prompt token IDs

        Returns:
            (B, n_prompt_len, d_model)
        """
        return self.embed(prompt_tokens)


# ============================================================
# Positional Encoding
# ============================================================

class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns;
        # slicing is a no-op when d_model is even.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to ``x`` and apply dropout.

        Args:
            x: (B, L, d_model)

        Returns:
            (B, L, d_model)

        Raises:
            ValueError: if L exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Fail with an actionable message instead of an opaque broadcast error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


# ============================================================
# Transformer Backbone
# ============================================================

class TransformerBlock(nn.Module):
    """Single transformer decoder block with causal attention."""

    def __init__(self, config: AirTrackConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=config.d_model,
            num_heads=config.n_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model),
            nn.Dropout(config.dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, L, d_model)
            attn_mask: (L, L) additive attention mask. When omitted, a causal
                mask is built so the block never attends to future positions.

        Returns:
            (B, L, d_model)
        """
        if attn_mask is None:
            # nn.MultiheadAttention treats ``is_causal`` only as a hint and
            # rejects it without an explicit mask, so build the mask here.
            seq_len = x.size(1)
            future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)
            attn_mask = torch.zeros(seq_len, seq_len, dtype=x.dtype, device=x.device)
            attn_mask = attn_mask.masked_fill(future, float('-inf'))

        # Pre-norm architecture (like GPT-2)
        h = self.ln1(x)
        # Attention weights are never used, so skip computing them.
        h, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + h

        h = self.ln2(x)
        h = self.ffn(h)
        x = x + h
        return x
# ============================================================
# Output Heads
# ============================================================

class NextStatePredictionHead(nn.Module):
    """
    Multi-head output for next-state prediction.

    Predicts each feature type independently.
    """

    def __init__(self, config: "AirTrackConfig"):
        super().__init__()

        # Geohash: predict 120 binary bits (sigmoid per bit)
        if config.predict_geohash:
            self.geohash_head = nn.Linear(config.d_model, config.geohash_bits)

        # Continuous ENU regression (optional secondary objective)
        if config.predict_continuous:
            self.continuous_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Linear(config.d_model // 2, 3),  # (Δeast, Δnorth, Δup)
            )

        # Kinematic feature bin classification
        self.cog_head = nn.Linear(config.d_model, config.n_cog_bins)
        self.sog_head = nn.Linear(config.d_model, config.n_sog_bins)
        self.rot_head = nn.Linear(config.d_model, config.n_rot_bins)
        self.alt_rate_head = nn.Linear(config.d_model, config.n_alt_rate_bins)

        self.config = config

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            hidden_states: (B, L, d_model) — transformer output

        Returns:
            dict of logits/predictions for each feature
        """
        out = {}

        if self.config.predict_geohash:
            out['geohash_logits'] = self.geohash_head(hidden_states)  # (B, L, 120)

        if self.config.predict_continuous:
            out['continuous_pred'] = self.continuous_head(hidden_states)  # (B, L, 3)

        out['cog_logits'] = self.cog_head(hidden_states)  # (B, L, n_cog_bins)
        out['sog_logits'] = self.sog_head(hidden_states)  # (B, L, n_sog_bins)
        out['rot_logits'] = self.rot_head(hidden_states)  # (B, L, n_rot_bins)
        out['alt_rate_logits'] = self.alt_rate_head(hidden_states)  # (B, L, n_alt_rate_bins)

        return out


class ClassificationHead(nn.Module):
    """
    Downstream classification head (attached after pretraining).

    The backbone uses causal attention, so position ``t`` only sees positions
    ``<= t``. The first (BOS) position therefore carries no trajectory
    information; by default this head pools the LAST position, which has
    attended to the whole sequence.
    """

    POOL_MODES = ('last', 'mean', 'first')

    def __init__(self, d_model: int, n_classes: int, dropout: float = 0.1, pool: str = 'last'):
        """
        Args:
            d_model: hidden size of the incoming representations
            n_classes: number of output classes
            dropout: dropout probability inside the MLP
            pool: which sequence summary to classify —
                'last'  final position (default),
                'mean'  average over all positions,
                'first' position 0 (legacy BOS pooling; under causal
                        attention this ignores every later position)

        Raises:
            ValueError: if ``pool`` is not one of ``POOL_MODES``.
        """
        super().__init__()
        if pool not in self.POOL_MODES:
            raise ValueError(f"pool must be one of {self.POOL_MODES}, got {pool!r}")
        self.pool = pool
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, n_classes),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Pool the sequence according to ``self.pool`` and classify it.

        Args:
            hidden_states: (B, L, d_model)

        Returns:
            (B, n_classes)
        """
        if self.pool == 'last':
            # NOTE(review): assumes sequences are not right-padded — confirm
            # against the dataset/collate function.
            pooled = hidden_states[:, -1, :]
        elif self.pool == 'mean':
            pooled = hidden_states.mean(dim=1)
        else:
            pooled = hidden_states[:, 0, :]
        return self.head(pooled)


# ============================================================
# Main Model
# ============================================================

class AirTrackLM(nn.Module):
    """
    AirTrackLM: Decoder-only transformer for air track next-state prediction.

    Architecture:
        Input → [4 Embedding Types fused additively] → Positional Encoding
        → N × TransformerBlock (causal attention)
        → Multi-head output (geohash + kinematic features)
    """

    GEOHASH_MODES = ('absolute', 'none', 'relative', 'multi_res', 'continuous')

    def __init__(self, config: "AirTrackConfig"):
        """
        Build all embedding layers, the transformer stack and the output heads.

        Raises:
            ValueError: if ``config.geohash_mode`` is not a known mode.
        """
        super().__init__()
        self.config = config

        # === Embedding layers ===

        # Geohash (spatial position)
        mode = config.geohash_mode
        if mode not in self.GEOHASH_MODES:
            # A typo would otherwise silently fall back to the absolute geohash.
            raise ValueError(
                f"Unknown geohash_mode {mode!r}; expected one of {self.GEOHASH_MODES}"
            )
        if mode == 'continuous':
            self.geohash_embed = ContinuousPositionEmbedding(config)
        elif mode == 'none':
            self.geohash_embed = None
        else:
            # 'absolute'; 'relative' and 'multi_res' use the same base as absolute
            self.geohash_embed = GeohashEmbedding(config)

        # Kinematic features
        self.feature_embed = FeatureEmbedding(config)

        # Temporal
        self.temporal_embed = TemporalEmbedding(config)

        # Uncertainty — single or multi-method
        if config.use_multi_uncertainty and config.n_uncert_methods > 1:
            from uncertainty import MultiUncertaintyEmbedding
            self.uncertainty_embed = MultiUncertaintyEmbedding(
                config.d_model, config.n_uncert_methods, config.n_uncert_bins
            )
            self._multi_uncert = True
        else:
            self.uncertainty_embed = UncertaintyEmbedding(config)
            self._multi_uncert = False

        # Heteroscedastic uncertainty head (learned aleatoric)
        self.heteroscedastic_head = None
        if config.use_heteroscedastic:
            from uncertainty import HeteroscedasticHead
            self.heteroscedastic_head = HeteroscedasticHead(config.d_model, n_outputs=6)

        # Prompt
        self.prompt_embed = PromptEmbedding(config)

        # === Fusion projection ===
        # After additive fusion, project through LayerNorm
        self.fusion_ln = nn.LayerNorm(config.d_model)

        # === Positional encoding ===
        self.pos_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # === Transformer blocks ===
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Final layer norm
        self.final_ln = nn.LayerNorm(config.d_model)

        # === Output heads ===
        self.prediction_head = NextStatePredictionHead(config)

        # Classification head (optional, for downstream)
        self.classification_head = None

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) for Linear/Embedding weights, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def attach_classification_head(self, n_classes: int):
        """Attach a classification head for downstream fine-tuning."""
        head = ClassificationHead(self.config.d_model, n_classes, self.config.dropout)
        # Place the new head next to the existing weights so a model already
        # moved to an accelerator does not hit a device mismatch in forward().
        self.classification_head = head.to(self.final_ln.weight.device)

    def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Generate causal attention mask.

        Returns:
            (seq_len, seq_len) float tensor: 0 where attention is allowed,
            -inf strictly above the diagonal (future positions).
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            batch: dict from AirTrackDataset.__getitem__ (batched)

        Returns:
            dict with prediction logits and optionally classification logits
        """
        device = batch['cog_bins'].device

        # --- Build state embeddings ---

        # Kinematic feature embedding
        feat_emb = self.feature_embed(
            batch['cog_bins'], batch['sog_bins'],
            batch['rot_bins'], batch['alt_rate_bins']
        )  # (B, L, d_model)

        # Temporal embedding
        temp_emb = self.temporal_embed(
            batch['hour'], batch['dow'], batch['month'],
            batch['second_of_day'], batch['dt']
        )  # (B, L, d_model)

        # Uncertainty embedding (single or multi-method)
        if self._multi_uncert and 'uncert_bins_multi' in batch:
            uncert_emb = self.uncertainty_embed(batch['uncert_bins_multi'])  # (B, L, d_model)
        else:
            # NOTE(review): when the multi-method embedding is active but the
            # batch lacks 'uncert_bins_multi', the single-method bins are fed
            # to it — confirm MultiUncertaintyEmbedding accepts that input.
            uncert_emb = self.uncertainty_embed(batch['uncert_bins'])  # (B, L, d_model)

        # Geohash embedding
        if self.config.geohash_mode == 'continuous':
            geo_emb = self.geohash_embed(batch['east'], batch['north'], batch['up'])
        elif self.geohash_embed is not None:
            geo_emb = self.geohash_embed(batch['geohash_bits'])  # (B, L, d_model)
        else:
            geo_emb = torch.zeros_like(feat_emb)

        # --- Additive fusion ---
        state_emb = feat_emb + temp_emb + uncert_emb + geo_emb  # (B, L, d_model)
        state_emb = self.fusion_ln(state_emb)

        # --- Prepend prompt tokens ---
        prompt_emb = self.prompt_embed(batch['prompt'])  # (B, n_prompt, d_model)

        # Concatenate: [PROMPT | STATE_1 | STATE_2 | ... | STATE_T]
        x = torch.cat([prompt_emb, state_emb], dim=1)  # (B, n_prompt + L, d_model)

        # --- Positional encoding ---
        x = self.pos_encoding(x)

        # --- Causal transformer ---
        seq_len = x.size(1)
        causal_mask = self.get_causal_mask(seq_len, device)

        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)

        x = self.final_ln(x)

        # --- Split output: drop the prompt positions, keep the states ---
        n_prompt = batch['prompt'].size(1)
        state_output = x[:, n_prompt:, :]  # (B, L, d_model)

        # --- Prediction heads (on state output) ---
        predictions = self.prediction_head(state_output)

        # --- Heteroscedastic uncertainty (learned aleatoric) ---
        if self.heteroscedastic_head is not None:
            predictions['log_var'] = self.heteroscedastic_head(state_output)  # (B, L, 6)

        # --- Classification (optional) ---
        if self.classification_head is not None:
            # The head pools over the full (prompt + state) sequence.
            predictions['class_logits'] = self.classification_head(x)

        return predictions

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
        counts = {}
        for name, module in [
            ('geohash_embed', self.geohash_embed),
            ('feature_embed', self.feature_embed),
            ('temporal_embed', self.temporal_embed),
            ('uncertainty_embed', self.uncertainty_embed),
            ('prompt_embed', self.prompt_embed),
            ('transformer_blocks', self.blocks),
            ('prediction_head', self.prediction_head),
        ]:
            if module is not None:
                counts[name] = sum(p.numel() for p in module.parameters())
        counts['total'] = sum(p.numel() for p in self.parameters())
        counts['trainable'] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts


# ============================================================
# Loss Function
# ============================================================

class NextStateLoss(nn.Module):
    """
    Multi-task loss for next-state prediction.

    For each position t, the model predicts features at t+1.

    Losses:
        - Geohash: Binary cross-entropy per bit
        - Kinematic features (COG, SOG, ROT, alt_rate): Cross-entropy per feature
        - Continuous ENU: MSE (optional)
    """

    # Default loss weights (equal, except the secondary continuous objective)
    DEFAULT_WEIGHTS = {
        'geohash': 1.0,
        'continuous': 0.5,
        'cog': 1.0,
        'sog': 1.0,
        'rot': 1.0,
        'alt_rate': 1.0,
    }

    # Kinematic features trained with cross-entropy; order fixes the order of
    # entries in the returned loss log.
    KINEMATIC_FEATURES = ('cog', 'sog', 'rot', 'alt_rate')

    def __init__(self, config: "AirTrackConfig", loss_weights: Optional[Dict[str, float]] = None):
        """
        Args:
            config: model configuration (reads predict_geohash / predict_continuous)
            loss_weights: optional per-term weight overrides; terms that are
                not listed keep their default weight
        """
        super().__init__()
        self.config = config

        # Merge overrides onto the defaults so a partial dict does not cause
        # a KeyError for the terms it omits.
        self.weights = {**self.DEFAULT_WEIGHTS, **(loss_weights or {})}

        self.ce = nn.CrossEntropyLoss(reduction='mean')
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.mse = nn.MSELoss(reduction='mean')

    def forward(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute loss. Targets are shifted by 1 (predict next state).

        predictions[key] is at positions [0, 1, ..., L-1]
        targets are batch[key] at positions [1, 2, ..., L]
        So we compare predictions[:, :-1, :] with targets[:, 1:, :]

        Returns:
            (total_loss, loss_log) — the summed weighted loss tensor and a dict
            of Python floats for each term plus 'total'.

        Raises:
            ValueError: if the sequence has fewer than 2 states, since there is
                then no next state to predict and every mean would be NaN.
        """
        if predictions['cog_logits'].size(1) < 2:
            raise ValueError("NextStateLoss needs sequences of at least 2 states")

        losses = {}

        # --- Geohash binary prediction ---
        if self.config.predict_geohash and 'geohash_logits' in predictions:
            # predictions: (B, L, 120), targets: (B, L, 120) float
            pred_geo = predictions['geohash_logits'][:, :-1, :]  # (B, L-1, 120)
            target_geo = batch['geohash_bits'][:, 1:, :]  # (B, L-1, 120)
            losses['geohash'] = self.bce(pred_geo, target_geo) * self.weights['geohash']

        # --- Continuous ENU regression (predict delta in km, not raw meters) ---
        if self.config.predict_continuous and 'continuous_pred' in predictions:
            pred_cont = predictions['continuous_pred'][:, :-1, :]  # (B, L-1, 3)
            # Target is delta-ENU: position(t+1) - position(t), normalized to km
            delta_east = (batch['east'][:, 1:] - batch['east'][:, :-1]) / 1000.0
            delta_north = (batch['north'][:, 1:] - batch['north'][:, :-1]) / 1000.0
            delta_up = (batch['up'][:, 1:] - batch['up'][:, :-1]) / 1000.0
            target_delta = torch.stack([delta_east, delta_north, delta_up], dim=-1)
            losses['continuous'] = self.mse(pred_cont, target_delta) * self.weights['continuous']

        # --- COG / SOG / ROT / alt rate: cross-entropy over bins ---
        for name in self.KINEMATIC_FEATURES:
            logits = predictions[f'{name}_logits'][:, :-1, :]  # (B, L-1, n_bins)
            targets = batch[f'{name}_bins'][:, 1:]  # (B, L-1)
            losses[name] = self.ce(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            ) * self.weights[name]

        # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
        if 'log_var' in predictions:
            log_var = predictions['log_var'][:, :-1, :]  # (B, L-1, 6)
            # Clamp log_var to prevent collapse: [-5, 5] range
            log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
            # Regularize toward 0 (unit variance prior)
            losses['log_var_reg'] = 0.1 * (log_var_clamped ** 2).mean()

        # Total loss
        total_loss = sum(losses.values())

        # Log individual losses
        loss_log = {k: v.item() for k, v in losses.items()}
        loss_log['total'] = total_loss.item()

        return total_loss, loss_log


# ============================================================
# Quick test
# ============================================================

if __name__ == '__main__':
    config = AirTrackConfig()
    model = AirTrackLM(config)

    # Print parameter counts
    counts = model.count_parameters()
    print("Parameter counts:")
    for name, count in counts.items():
        print(f" {name}: {count:,}")

    # Test forward pass with dummy data
    B, L = 2, 65  # batch=2, seq_len=65 (64 states + 1 for target shift)
    n_prompt = config.n_prompt_len

    batch = {
        # Geohash bits double as BCE targets, so they must be 0/1 values.
        'geohash_bits': torch.randint(0, 2, (B, L, config.geohash_bits)).float(),
        'cog_bins': torch.randint(0, config.n_cog_bins, (B, L)),
        'sog_bins': torch.randint(0, config.n_sog_bins, (B, L)),
        'rot_bins': torch.randint(0, config.n_rot_bins, (B, L)),
        'alt_rate_bins': torch.randint(0, config.n_alt_rate_bins, (B, L)),
        'uncert_bins': torch.randint(0, config.n_uncert_bins, (B, L)),
        'hour': torch.randint(0, 24, (B, L)),
        'dow': torch.randint(0, 7, (B, L)),
        'month': torch.randint(0, 12, (B, L)),
        'second_of_day': torch.rand(B, L) * 86400,
        'dt': torch.ones(B, L) * 5.0,
        'prompt': torch.randint(0, config.n_prompt_tokens, (B, n_prompt)),
        'east': torch.randn(B, L) * 1000,
        'north': torch.randn(B, L) * 1000,
        'up': torch.randn(B, L) * 1000,
    }

    predictions = model(batch)
    print("\nPrediction shapes:")
    for k, v in predictions.items():
        print(f" {k}: {v.shape}")

    # Test loss
    loss_fn = NextStateLoss(config)
    total_loss, loss_log = loss_fn(predictions, batch)
    print(f"\nLoss: {loss_log}")