"""
AirTrackLM - Uncertainty Methods
=================================
Multiple uncertainty quantification approaches for trajectory prediction.

Methods:
    1. Kinematic Variance      - sliding window variance of COG/SOG/ROT/alt_rate
    2. Prediction Residual     - deviation from constant-velocity prediction model
    3. Spatial Density         - data coverage proxy
    4. Flight Phase Entropy    - entropy of phase classification in a window
    5. Learned Heteroscedastic - model predicts its own uncertainty (aleatoric)
    6. MC-Dropout              - Monte Carlo dropout at inference (epistemic)
    7. Temporal Irregularity   - relative deviation of sampling intervals from
                                 their median (optional, disabled by default)
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn


def _check_window(window: int) -> None:
    """Raise ``ValueError`` for a window that would produce empty slices."""
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")


def uncertainty_kinematic_variance(cog, sog, rot, alt_rate, window: int = 5) -> np.ndarray:
    """Score local kinematic variability over a trailing window.

    For each index ``i`` the window is ``[max(0, i - window + 1), i]``. The
    score is the sum of:

    * the circular variance of ``cog`` (``1 - mean resultant length``); ``cog``
      is converted with ``np.radians``, so it is treated as degrees;
    * the windowed variances of ``sog``, ``rot`` and ``alt_rate``, each divided
      by that signal's variance over the whole trajectory (floored at 1e-10 so
      constant signals do not divide by zero).

    Args:
        cog, sog, rot, alt_rate: 1-D array-likes of equal length.
        window: Trailing window length in samples (must be >= 1).

    Returns:
        Float array of shape ``(len(cog),)``. Empty input yields an empty array.

    Raises:
        ValueError: if ``window < 1``.
    """
    _check_window(window)
    cog = np.asarray(cog, dtype=float)
    sog = np.asarray(sog, dtype=float)
    rot = np.asarray(rot, dtype=float)
    alt_rate = np.asarray(alt_rate, dtype=float)

    n_points = len(cog)
    scores = np.zeros(n_points)
    if n_points == 0:
        return scores

    global_sog_var = max(np.var(sog), 1e-10)
    global_rot_var = max(np.var(rot), 1e-10)
    global_alt_var = max(np.var(alt_rate), 1e-10)
    cog_rad = np.radians(cog)

    for i in range(n_points):
        w = slice(max(0, i - window + 1), i + 1)
        mean_resultant = np.sqrt(
            np.mean(np.cos(cog_rad[w])) ** 2 + np.mean(np.sin(cog_rad[w])) ** 2
        )
        # A single-sample window has variance 0, so no special case is needed.
        scores[i] = (
            (1 - mean_resultant)
            + np.var(sog[w]) / global_sog_var
            + np.var(rot[w]) / global_rot_var
            + np.var(alt_rate[w]) / global_alt_var
        )
    return scores


def uncertainty_prediction_residual(east, north, up, timestamps, horizon: int = 3) -> np.ndarray:
    """Score deviation from a constant-velocity extrapolation.

    At each index ``i`` in ``[1, N - horizon)``, the velocity is estimated from
    samples ``i-1`` and ``i``. The position is then extrapolated to
    ``timestamps[i + horizon]``, and the score is the Euclidean (E, N, U)
    distance between the extrapolation and the observed position there.
    Indices without a defined residual copy the nearest computed value.

    Args:
        east, north, up: 1-D position array-likes of equal length.
        timestamps: 1-D array-like of sample times. Intervals are floored at
            1e-6 when estimating velocity.
        horizon: Number of samples ahead to extrapolate.

    Returns:
        Float array of shape ``(N,)``. It is all zeros when
        ``N <= horizon + 1``, because no residual can be computed then.
    """
    pos = np.stack(
        [np.asarray(east, dtype=float),
         np.asarray(north, dtype=float),
         np.asarray(up, dtype=float)],
        axis=1,
    )  # (N, 3)
    t = np.asarray(timestamps, dtype=float)
    n_points = len(pos)
    scores = np.zeros(n_points)
    if n_points <= horizon + 1:
        return scores

    dt = np.maximum(np.diff(t), 1e-6)
    idx = np.arange(1, n_points - horizon)
    velocity = (pos[idx] - pos[idx - 1]) / dt[idx - 1, None]
    dt_pred = t[idx + horizon] - t[idx]
    predicted = pos[idx] + velocity * dt_pred[:, None]
    scores[idx] = np.sqrt(np.sum((predicted - pos[idx + horizon]) ** 2, axis=1))

    # Pad the first sample and the last `horizon` samples with their neighbours.
    scores[0] = scores[1]
    scores[n_points - horizon:] = scores[n_points - horizon - 1]
    return scores


def uncertainty_spatial_density(
    east,
    north,
    up,
    all_trajectories_enu=None,
    radius_m: float = 5000.0,
    max_reference_points: int = 50000,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Score how sparsely covered each point's neighbourhood is.

    With a non-empty reference set, the score is ``1 / max(count, 1)``, where
    ``count`` is the number of reference points within ``radius_m``
    (Euclidean ENU distance) of the sample. Without one, the score falls back
    to the step distance from the previous sample, a crude sampling-density
    proxy.

    Args:
        east, north, up: 1-D position array-likes of equal length.
        all_trajectories_enu: Sequence of items ``t`` where ``t[0]``, ``t[1]``
            and ``t[2]`` are the east/north/up arrays of a reference
            trajectory. ``None`` or an empty sequence selects the fallback.
        radius_m: Neighbourhood radius, in the same units as the inputs
            (presumably metres, per the name).
        max_reference_points: If the pooled reference set is larger, it is
            subsampled without replacement to this size. Counts are NOT
            rescaled after subsampling.
        rng: Generator used for subsampling. ``None`` uses the global
            ``np.random`` state, as before; pass a seeded generator for
            reproducible scores.

    Returns:
        Float array of shape ``(N,)``.
    """
    east = np.asarray(east, dtype=float)
    north = np.asarray(north, dtype=float)
    up = np.asarray(up, dtype=float)
    n_points = len(east)
    scores = np.zeros(n_points)

    if all_trajectories_enu is not None and len(all_trajectories_enu) > 0:
        ref_e = np.concatenate([t[0] for t in all_trajectories_enu])
        ref_n = np.concatenate([t[1] for t in all_trajectories_enu])
        ref_u = np.concatenate([t[2] for t in all_trajectories_enu])
        if len(ref_e) > max_reference_points:
            choose = rng.choice if rng is not None else np.random.choice
            keep = choose(len(ref_e), max_reference_points, replace=False)
            ref_e, ref_n, ref_u = ref_e[keep], ref_n[keep], ref_u[keep]
        for i in range(n_points):
            dists = np.sqrt(
                (ref_e - east[i]) ** 2 + (ref_n - north[i]) ** 2 + (ref_u - up[i]) ** 2
            )
            scores[i] = 1.0 / max(np.sum(dists < radius_m), 1)
    elif n_points > 1:
        scores[1:] = np.sqrt(np.diff(east) ** 2 + np.diff(north) ** 2 + np.diff(up) ** 2)
        scores[0] = scores[1]
    return scores


def uncertainty_flight_phase_entropy(sog, alt_rate, up, window: int = 10) -> np.ndarray:
    """Score phase instability as Shannon entropy (bits) of phase labels.

    Each sample gets one of 6 phase labels from threshold rules. The first
    matching rule wins:

    0: ``sog < 30`` and ``up < 500``
    1: ``alt_rate > 300``
    2: ``alt_rate < -300``
    3: ``|alt_rate| <= 300`` and ``up > 5000``
    4: ``alt_rate < -100`` and ``up < 3000`` and ``sog < 200``
    5: otherwise

    The label names (presumably ground/climb/descent/cruise/approach/other)
    and units (e.g. knots, ft/min, m) are NOT established here; confirm them
    against the feature pipeline.

    The score at ``i`` is the entropy of the label histogram over the
    trailing window ``[max(0, i - window + 1), i]``.

    Raises:
        ValueError: if ``window < 1``.
    """
    _check_window(window)
    sog = np.asarray(sog, dtype=float)
    alt_rate = np.asarray(alt_rate, dtype=float)
    up = np.asarray(up, dtype=float)

    # np.select picks the first true condition, matching an if/elif chain.
    conditions = [
        (sog < 30) & (up < 500),
        alt_rate > 300,
        alt_rate < -300,
        (np.abs(alt_rate) <= 300) & (up > 5000),
        (alt_rate < -100) & (up < 3000) & (sog < 200),
    ]
    phases = np.select(conditions, [0, 1, 2, 3, 4], default=5).astype(np.int64)

    n_phases = 6
    scores = np.zeros(len(phases))
    for i in range(len(phases)):
        counts = np.bincount(phases[max(0, i - window + 1):i + 1], minlength=n_phases)
        probs = counts[counts > 0] / counts.sum()
        scores[i] = -np.sum(probs * np.log2(probs))
    return scores


def uncertainty_temporal_irregularity(timestamps) -> np.ndarray:
    """Score sampling-interval irregularity.

    The interval ending at sample ``i`` is ``t[i] - t[i-1]``. The score is
    ``|interval - median_interval| / median_interval``, with the median
    floored at 1e-6. Sample 0 copies sample 1. The score is scale-invariant
    in time.

    Returns:
        Float array of shape ``(N,)``. It is all zeros when ``N < 2``.
    """
    t = np.asarray(timestamps, dtype=float)
    scores = np.zeros(len(t))
    if len(t) < 2:
        return scores
    dt = np.diff(t)
    median_dt = max(float(np.median(dt)), 1e-6)
    scores[1:] = np.abs(dt - median_dt) / median_dt
    scores[0] = scores[1]
    return scores


@dataclass
class UncertaintyConfig:
    """Selects which heuristic uncertainty scores to compute.

    ``n_methods`` equals the number of arrays ``compute_all_uncertainties``
    returns for this config.
    """

    use_kinematic_variance: bool = True
    use_prediction_residual: bool = True
    use_spatial_density: bool = True
    use_flight_phase_entropy: bool = True
    use_temporal_irregularity: bool = False
    n_bins: int = 16  # bins per score for discretize_scores / embeddings
    window: int = 5   # kinematic window; phase entropy uses 2 * window

    @property
    def n_methods(self) -> int:
        """Number of enabled methods."""
        return sum([
            self.use_kinematic_variance,
            self.use_prediction_residual,
            self.use_spatial_density,
            self.use_flight_phase_entropy,
            self.use_temporal_irregularity,
        ])


def compute_all_uncertainties(east, north, up, timestamps, cog, sog, rot, alt_rate,
                              config: Optional[UncertaintyConfig] = None,
                              raw_timestamps=None,
                              all_trajectories_enu=None) -> Dict[str, np.ndarray]:
    """Compute every score enabled in ``config``.

    Keys are inserted in a fixed order:
    ``kinematic_var``, ``pred_residual``, ``spatial_density``,
    ``phase_entropy``, ``temporal_irregularity``.
    Only enabled keys are present, so ``len(result) == config.n_methods``.

    ``raw_timestamps`` is used for the temporal-irregularity score when given,
    otherwise ``timestamps``. Presumably ``timestamps`` may be normalized
    upstream (confirm); the score is scale-invariant either way.
    """
    if config is None:
        config = UncertaintyConfig()
    results = {}
    if config.use_kinematic_variance:
        results['kinematic_var'] = uncertainty_kinematic_variance(
            cog, sog, rot, alt_rate, window=config.window)
    if config.use_prediction_residual:
        results['pred_residual'] = uncertainty_prediction_residual(
            east, north, up, timestamps, horizon=3)
    if config.use_spatial_density:
        results['spatial_density'] = uncertainty_spatial_density(
            east, north, up, all_trajectories_enu)
    if config.use_flight_phase_entropy:
        results['phase_entropy'] = uncertainty_flight_phase_entropy(
            sog, alt_rate, up, window=config.window * 2)
    if config.use_temporal_irregularity:
        source = raw_timestamps if raw_timestamps is not None else timestamps
        results['temporal_irregularity'] = uncertainty_temporal_irregularity(source)
    return results


def discretize_scores(scores, n_bins: int = 16) -> np.ndarray:
    """Map continuous scores to integer bins ``0 .. n_bins - 1``.

    If there are fewer than ``n_bins`` distinct values, equal-width bins span
    ``[min, max]``; otherwise quantile edges are used. The top edge is nudged
    by 1e-10, and the result is clipped so the maximum lands in the last bin.
    Empty input returns an empty integer array.
    """
    scores = np.asarray(scores, dtype=float)
    if scores.size == 0:
        return np.zeros(0, dtype=np.intp)
    if len(np.unique(scores)) < n_bins:
        edges = np.linspace(scores.min(), scores.max() + 1e-10, n_bins + 1)
    else:
        edges = np.quantile(scores, np.linspace(0, 1, n_bins + 1))
        edges[-1] += 1e-10
    return np.clip(np.digitize(scores, edges) - 1, 0, n_bins - 1)


class HeteroscedasticHead(nn.Module):
    """MLP head predicting per-output log-variance, clamped to ``[-5, 5]``."""

    def __init__(self, d_model, n_outputs=6):
        super().__init__()
        self.log_var_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_outputs),
        )

    def forward(self, hidden_states):
        """Map ``(..., d_model)`` features to ``(..., n_outputs)`` log-variances."""
        # Clamping keeps exp(log_var) in a numerically safe range.
        return torch.clamp(self.log_var_head(hidden_states), -5.0, 5.0)


class MultiUncertaintyEmbedding(nn.Module):
    """Embed discretized scores from several methods and fuse them.

    Each method has its own ``nn.Embedding(n_bins, d_model)``. With more than
    one method, the embeddings are concatenated and mapped through a
    Linear + Softmax to per-method weights. The output is the weighted sum of
    the per-method embeddings.
    """

    def __init__(self, d_model, n_methods, n_bins=16):
        super().__init__()
        self.n_methods = n_methods
        self.n_bins = n_bins
        self.embeds = nn.ModuleList([nn.Embedding(n_bins, d_model) for _ in range(n_methods)])
        if n_methods > 1:
            self.method_attention = nn.Sequential(
                nn.Linear(d_model * n_methods, n_methods),
                nn.Softmax(dim=-1),
            )

    def forward(self, uncert_bins):
        """Fuse bins of shape ``(B, L, n_methods)`` into ``(B, L, d_model)``.

        Raises:
            ValueError: if the last dimension differs from ``n_methods`` or is 0.
        """
        _, _, n_cols = uncert_bins.shape
        if n_cols != self.n_methods or n_cols == 0:
            raise ValueError(
                f"expected {self.n_methods} uncertainty methods in the last "
                f"dimension, got {n_cols}")
        embeds = [self.embeds[m](uncert_bins[:, :, m]) for m in range(n_cols)]
        if n_cols == 1:
            return embeds[0]
        weights = self.method_attention(torch.cat(embeds, dim=-1))  # (B, L, M)
        stacked = torch.stack(embeds, dim=-1)                       # (B, L, d_model, M)
        return (stacked * weights.unsqueeze(2)).sum(dim=-1)