""" LLM4AirTrack: LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction. Architecture (adapted from LLM4STP/Time-LLM for ADS-B): ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder → Cross-Attention Reprogrammer (learned text prototypes) → Prompt-as-Prefix → Frozen GPT-2/LLaMA Backbone → Trajectory Head (future xyz) + Classification Head (route class) Trainable parameters: ~2-5% (adapters only) Frozen: LLM backbone (preserves language understanding for reprogramming) """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from typing import Optional, Dict class RevIN(nn.Module): """Reversible Instance Normalization.""" def __init__(self, n_features, eps=1e-5): super().__init__() self.eps = eps self.affine_weight = nn.Parameter(torch.ones(n_features)) self.affine_bias = nn.Parameter(torch.zeros(n_features)) def forward(self, x, mode="norm"): if mode == "norm": self._mean = x.mean(dim=1, keepdim=True).detach() self._std = (x.std(dim=1, keepdim=True) + self.eps).detach() x = (x - self._mean) / self._std x = x * self.affine_weight + self.affine_bias elif mode == "denorm": x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps) x = x * self._std[:, :, :3] + self._mean[:, :, :3] return x class PatchTokenizer(nn.Module): """Convert time series into overlapping patches.""" def __init__(self, patch_len=8, stride=4, n_features=9): super().__init__() self.patch_len = patch_len self.stride = stride def forward(self, x): B, T, F = x.shape x = x.unfold(1, self.patch_len, self.stride) x = x.permute(0, 1, 3, 2).contiguous() return x.reshape(B, x.shape[1], self.patch_len * F) def n_patches(self, seq_len): return (seq_len - self.patch_len) // self.stride + 1 class CrossAttentionReprogrammer(nn.Module): """Reprogram trajectory patches into LLM text space via cross-attention over learned prototypes.""" def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1): super().__init__() self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02) self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True) self.layer_norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, patch_embeds): B = patch_embeds.shape[0] protos = self.prototypes.unsqueeze(0).expand(B, -1, -1) attn_out, _ = self.cross_attn(query=patch_embeds, key=protos, value=protos) return self.layer_norm(patch_embeds + self.dropout(attn_out)) class TrajectoryPredictionHead(nn.Module): """Maps LLM hidden states to future trajectory (x,y,z).""" def __init__(self, d_model, pred_len, n_output=3): super().__init__() self.pred_len = pred_len self.n_output = n_output self.proj = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(0.1), nn.Linear(d_model // 2, pred_len * n_output), ) def forward(self, hidden): return self.proj(hidden.mean(dim=1)).reshape(-1, self.pred_len, self.n_output) class ClassificationHead(nn.Module): """Route/procedure classification from LLM hidden states.""" def __init__(self, d_model, n_classes): super().__init__() self.cls = nn.Sequential( nn.Linear(d_model, d_model // 4), nn.GELU(), nn.Dropout(0.2), nn.Linear(d_model // 4, n_classes), ) def forward(self, hidden): return self.cls(hidden.mean(dim=1)) class LLM4AirTrack(nn.Module): """ LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction. Args: llm_name: HuggingFace model ID for the LLM backbone n_input_features: Number of input features (default: 9 kinematic) context_len: Input context window length in timesteps pred_len: Prediction horizon in timesteps patch_len: Temporal patch length patch_stride: Patch stride n_prototypes: Number of learned text prototypes n_classes: Number of route/procedure classes reprogrammer_heads: Number of cross-attention heads dropout: Dropout rate freeze_llm: Whether to freeze LLM backbone """ def __init__(self, llm_name="openai-community/gpt2", n_input_features=9, context_len=60, pred_len=30, patch_len=8, patch_stride=4, n_prototypes=256, n_classes=39, reprogrammer_heads=8, dropout=0.1, freeze_llm=True, prompt_text="This is an aircraft trajectory in 3D airspace near an airport. " "The data represents ADS-B surveillance with position, velocity, and polar components. " "Predict the future trajectory."): super().__init__() self.pred_len = pred_len self.freeze_llm = freeze_llm # LLM backbone config = AutoConfig.from_pretrained(llm_name) self.d_llm = config.hidden_size self.tokenizer = AutoTokenizer.from_pretrained(llm_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.llm = AutoModelForCausalLM.from_pretrained(llm_name) if freeze_llm: for p in self.llm.parameters(): p.requires_grad = False self.llm.eval() # Backbone reference if hasattr(self.llm, 'transformer'): self.word_embeddings = self.llm.transformer.wte self.backbone = self.llm.transformer elif hasattr(self.llm, 'model') and hasattr(self.llm.model, 'embed_tokens'): self.word_embeddings = self.llm.model.embed_tokens self.backbone = self.llm.model # Prompt tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64) self.register_buffer("prompt_ids", tokens["input_ids"]) # Trainable adapters self.revin = RevIN(n_input_features) self.patcher = PatchTokenizer(patch_len, patch_stride, n_input_features) self.patch_embed = nn.Sequential( nn.Linear(patch_len * n_input_features, self.d_llm), nn.GELU(), nn.LayerNorm(self.d_llm), nn.Dropout(dropout), ) self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, reprogrammer_heads, n_prototypes, dropout) self.traj_head = TrajectoryPredictionHead(self.d_llm, pred_len) self.cls_head = ClassificationHead(self.d_llm, n_classes) total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)") def forward(self, context, target=None, label=None, task="both"): B, device = context.shape[0], context.device x = self.revin(context, mode="norm") patches = self.patcher(x) patch_emb = self.patch_embed(patches) reprogrammed = self.reprogrammer(patch_emb) with torch.no_grad(): prompt_emb = self.word_embeddings(self.prompt_ids.to(device)) input_emb = torch.cat([prompt_emb.expand(B, -1, -1), reprogrammed], dim=1) if self.freeze_llm: with torch.no_grad(): hidden = self.backbone(inputs_embeds=input_emb).last_hidden_state.detach() else: hidden = self.backbone(inputs_embeds=input_emb).last_hidden_state hidden = hidden.requires_grad_(True) results = {} loss = torch.tensor(0.0, device=device, requires_grad=True) if task in ("predict", "both"): pred = self.traj_head(hidden) pred = self.revin(pred, mode="denorm") results["pred_trajectory"] = pred if target is not None: traj_loss = F.smooth_l1_loss(pred, target) results["traj_loss"] = traj_loss loss = loss + traj_loss if task in ("classify", "both"): logits = self.cls_head(hidden) results["pred_class"] = logits if label is not None: cls_loss = F.cross_entropy(logits, label) results["cls_loss"] = cls_loss loss = loss + 0.1 * cls_loss results["loss"] = loss return results def count_parameters(model): """Parameter breakdown by module.""" breakdown = {} for name, module in model.named_children(): total = sum(p.numel() for p in module.parameters()) trainable = sum(p.numel() for p in module.parameters() if p.requires_grad) if total > 0: breakdown[name] = {"total": total, "trainable": trainable} return breakdown