| """ |
| LLM4AirTrack: LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction. |
| |
| Architecture (adapted from LLM4STP/Time-LLM for ADS-B): |
| |
| ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder |
| → Cross-Attention Reprogrammer (learned text prototypes) |
| → Prompt-as-Prefix → Frozen GPT-2/LLaMA Backbone |
| → Trajectory Head (future xyz) + Classification Head (route class) |
| |
| Trainable parameters: ~2-5% (adapters only) |
| Frozen: LLM backbone (preserves language understanding for reprogramming) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| from typing import Optional, Dict |
|
|
|
|
class RevIN(nn.Module):
    """Reversible Instance Normalization (per instance, over the time axis).

    ``norm`` standardizes all input features; ``denorm`` maps 3-dim position
    predictions back to the original scale using the statistics of the first
    three input channels (assumed to be x, y, z).
    """
    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))
        self._mean = None
        self._std = None

    def forward(self, x, mode="norm"):
        if mode == "norm":
            # Detach the statistics so gradients flow only through the affine params.
            self._mean = x.mean(dim=1, keepdim=True).detach()
            self._std = (x.std(dim=1, keepdim=True) + self.eps).detach()
            x = (x - self._mean) / self._std
            x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            if self._mean is None:
                raise RuntimeError("RevIN: 'denorm' called before 'norm'.")
            x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps)
            x = x * self._std[:, :, :3] + self._mean[:, :, :3]
        return x
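
# A minimal shape sketch of the norm/denorm round trip (assumed usage; the
# denorm path intentionally touches only the first three channels, which the
# model treats as x, y, z positions):
#
#   revin = RevIN(n_features=9)
#   x_norm = revin(torch.randn(4, 60, 9), mode="norm")   # (4, 60, 9), standardized
#   pred = torch.randn(4, 30, 3)                         # head output, normalized space
#   pred_raw = revin(pred, mode="denorm")                # (4, 30, 3), raw xyz scale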
|
|
|
|
class PatchTokenizer(nn.Module):
    """Convert a (B, T, C) time series into overlapping flattened patches."""
    def __init__(self, patch_len=8, stride=4, n_features=9):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.n_features = n_features

    def forward(self, x):
        B, T, C = x.shape  # C, not F: avoid shadowing torch.nn.functional (imported as F)
        x = x.unfold(1, self.patch_len, self.stride)  # (B, n_patches, C, patch_len)
        x = x.permute(0, 1, 3, 2).contiguous()        # (B, n_patches, patch_len, C)
        return x.reshape(B, x.shape[1], self.patch_len * C)

    def n_patches(self, seq_len):
        return (seq_len - self.patch_len) // self.stride + 1
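
# Worked example with the model defaults (patch_len=8, stride=4): a 60-step,
# 9-feature context yields (60 - 8) // 4 + 1 = 14 patches, each flattened to
# 8 * 9 = 72 values, i.e. (B, 60, 9) -> (B, 14, 72):
#
#   patcher = PatchTokenizer(patch_len=8, stride=4, n_features=9)
#   patches = patcher(torch.randn(2, 60, 9))
#   assert patches.shape == (2, 14, 72) and patcher.n_patches(60) == 14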
|
|
|
|
| class CrossAttentionReprogrammer(nn.Module): |
| """Reprogram trajectory patches into LLM text space via cross-attention over learned prototypes.""" |
| def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1): |
| super().__init__() |
| self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02) |
| self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True) |
| self.layer_norm = nn.LayerNorm(d_model) |
| self.dropout = nn.Dropout(dropout) |
| |
    def forward(self, patch_embeds):
        # Patches attend as queries over the shared prototype bank (keys/values);
        # expand() broadcasts the prototypes across the batch without copying.
        B = patch_embeds.shape[0]
        protos = self.prototypes.unsqueeze(0).expand(B, -1, -1)
        attn_out, _ = self.cross_attn(query=patch_embeds, key=protos, value=protos)
        return self.layer_norm(patch_embeds + self.dropout(attn_out))
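
# Shape sketch (assumed usage; 768 matches GPT-2's hidden width): the patch
# embeddings act as queries over the shared prototype bank, so the sequence
# length is preserved while content is mapped toward the LLM's text space:
#
#   rp = CrossAttentionReprogrammer(d_model=768, n_heads=8, n_prototypes=256)
#   out = rp(torch.randn(2, 14, 768))    # (2, 14, 768) -- shape unchanged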
|
|
|
|
| class TrajectoryPredictionHead(nn.Module): |
| """Maps LLM hidden states to future trajectory (x,y,z).""" |
| def __init__(self, d_model, pred_len, n_output=3): |
| super().__init__() |
| self.pred_len = pred_len |
| self.n_output = n_output |
| self.proj = nn.Sequential( |
| nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(0.1), |
| nn.Linear(d_model // 2, pred_len * n_output), |
| ) |
| |
    def forward(self, hidden):
        # Mean-pool over the token axis, then project and reshape to (B, pred_len, n_output).
        return self.proj(hidden.mean(dim=1)).reshape(-1, self.pred_len, self.n_output)
|
|
|
|
| class ClassificationHead(nn.Module): |
| """Route/procedure classification from LLM hidden states.""" |
| def __init__(self, d_model, n_classes): |
| super().__init__() |
| self.cls = nn.Sequential( |
| nn.Linear(d_model, d_model // 4), nn.GELU(), nn.Dropout(0.2), |
| nn.Linear(d_model // 4, n_classes), |
| ) |
| |
| def forward(self, hidden): |
| return self.cls(hidden.mean(dim=1)) |
|
|
|
|
| class LLM4AirTrack(nn.Module): |
| """ |
| LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction. |
| |
| Args: |
| llm_name: HuggingFace model ID for the LLM backbone |
| n_input_features: Number of input features (default: 9 kinematic) |
| context_len: Input context window length in timesteps |
| pred_len: Prediction horizon in timesteps |
| patch_len: Temporal patch length |
| patch_stride: Patch stride |
| n_prototypes: Number of learned text prototypes |
| n_classes: Number of route/procedure classes |
| reprogrammer_heads: Number of cross-attention heads |
| dropout: Dropout rate |
        freeze_llm: Whether to freeze the LLM backbone
        prompt_text: Natural-language task description used as Prompt-as-Prefix
| """ |
| def __init__(self, llm_name="openai-community/gpt2", n_input_features=9, |
| context_len=60, pred_len=30, patch_len=8, patch_stride=4, |
| n_prototypes=256, n_classes=39, reprogrammer_heads=8, |
| dropout=0.1, freeze_llm=True, |
| prompt_text="This is an aircraft trajectory in 3D airspace near an airport. " |
| "The data represents ADS-B surveillance with position, velocity, and polar components. " |
| "Predict the future trajectory."): |
| super().__init__() |
| self.pred_len = pred_len |
| self.freeze_llm = freeze_llm |
| |
| |
| config = AutoConfig.from_pretrained(llm_name) |
| self.d_llm = config.hidden_size |
| self.tokenizer = AutoTokenizer.from_pretrained(llm_name) |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| self.llm = AutoModelForCausalLM.from_pretrained(llm_name) |
| |
        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False
            # eval() disables dropout in the frozen backbone; note that a later
            # call to .train() on the parent module will re-enable it.
            self.llm.eval()


        # Locate the input embedding table and the bare transformer backbone.
        if hasattr(self.llm, 'transformer'):  # GPT-2-style models
            self.word_embeddings = self.llm.transformer.wte
            self.backbone = self.llm.transformer
        elif hasattr(self.llm, 'model') and hasattr(self.llm.model, 'embed_tokens'):  # LLaMA-style models
            self.word_embeddings = self.llm.model.embed_tokens
            self.backbone = self.llm.model
        else:
            raise ValueError(f"Unsupported LLM architecture: {llm_name}")
| |
| tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64) |
| self.register_buffer("prompt_ids", tokens["input_ids"]) |
| |
| |
| self.revin = RevIN(n_input_features) |
| self.patcher = PatchTokenizer(patch_len, patch_stride, n_input_features) |
| self.patch_embed = nn.Sequential( |
| nn.Linear(patch_len * n_input_features, self.d_llm), nn.GELU(), |
| nn.LayerNorm(self.d_llm), nn.Dropout(dropout), |
| ) |
| self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, reprogrammer_heads, n_prototypes, dropout) |
| self.traj_head = TrajectoryPredictionHead(self.d_llm, pred_len) |
| self.cls_head = ClassificationHead(self.d_llm, n_classes) |
| |
| total = sum(p.numel() for p in self.parameters()) |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)") |
| |
    def forward(self, context: torch.Tensor, target: Optional[torch.Tensor] = None,
                label: Optional[torch.Tensor] = None, task: str = "both") -> Dict[str, torch.Tensor]:
        B, device = context.shape[0], context.device
| |
| x = self.revin(context, mode="norm") |
| patches = self.patcher(x) |
| patch_emb = self.patch_embed(patches) |
| reprogrammed = self.reprogrammer(patch_emb) |
| |
        # Only the prompt-token embedding lookup is gradient-free; the concat must
        # stay outside no_grad so gradients can reach the reprogrammer and patch embedder.
        with torch.no_grad():
            prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        input_emb = torch.cat([prompt_emb.expand(B, -1, -1), reprogrammed], dim=1)

        # The backbone's weights are frozen (requires_grad=False), but its forward
        # pass must not run under no_grad: gradients still have to flow through the
        # frozen layers back to the trainable adapter modules.
        hidden = self.backbone(inputs_embeds=input_emb).last_hidden_state

        results = {}
        loss = torch.tensor(0.0, device=device)  # accumulated multi-task loss
| |
| if task in ("predict", "both"): |
| pred = self.traj_head(hidden) |
| pred = self.revin(pred, mode="denorm") |
| results["pred_trajectory"] = pred |
| if target is not None: |
| traj_loss = F.smooth_l1_loss(pred, target) |
| results["traj_loss"] = traj_loss |
| loss = loss + traj_loss |
| |
| if task in ("classify", "both"): |
| logits = self.cls_head(hidden) |
| results["pred_class"] = logits |
| if label is not None: |
| cls_loss = F.cross_entropy(logits, label) |
| results["cls_loss"] = cls_loss |
                loss = loss + 0.1 * cls_loss  # down-weighted auxiliary objective
| |
| results["loss"] = loss |
| return results |
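
# A minimal training-step sketch (hypothetical DataLoader yielding (context,
# target, label) batches; optimizing only the trainable adapter parameters is
# the usual recipe for frozen-backbone models):
#
#   model = LLM4AirTrack()
#   optim = torch.optim.AdamW(
#       (p for p in model.parameters() if p.requires_grad), lr=1e-3)
#   for context, target, label in loader:
#       out = model(context, target=target, label=label, task="both")
#       optim.zero_grad()
#       out["loss"].backward()
#       optim.step()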
|
|
|
|
| def count_parameters(model): |
| """Parameter breakdown by module.""" |
| breakdown = {} |
| for name, module in model.named_children(): |
| total = sum(p.numel() for p in module.parameters()) |
| trainable = sum(p.numel() for p in module.parameters() if p.requires_grad) |
| if total > 0: |
| breakdown[name] = {"total": total, "trainable": trainable} |
| return breakdown |
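

if __name__ == "__main__":
    # Smoke test (a minimal sketch: uses the default GPT-2 backbone, which is
    # downloaded on first run, and a random batch rather than real ADS-B data).
    model = LLM4AirTrack(context_len=60, pred_len=30)
    context = torch.randn(2, 60, 9)       # (batch, time, kinematic features)
    target = torch.randn(2, 30, 3)        # future x, y, z positions
    label = torch.randint(0, 39, (2,))    # route/procedure class ids
    out = model(context, target=target, label=label, task="both")
    print("pred_trajectory:", tuple(out["pred_trajectory"].shape))
    print("pred_class:", tuple(out["pred_class"].shape))
    print("loss:", out["loss"].item())
    print(count_parameters(model))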
|
|