Jdice27's picture
Add model module
8cae25f verified
"""
LLM4AirTrack: LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction.
Architecture (adapted from LLM4STP/Time-LLM for ADS-B):
ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder
→ Cross-Attention Reprogrammer (learned text prototypes)
→ Prompt-as-Prefix → Frozen GPT-2/LLaMA Backbone
→ Trajectory Head (future xyz) + Classification Head (route class)
Trainable parameters: ~2-5% (adapters only)
Frozen: LLM backbone (preserves language understanding for reprogramming)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from typing import Optional, Dict
class RevIN(nn.Module):
    """Reversible Instance Normalization.

    ``mode="norm"`` standardises each sample over dim 1 (treated as the time
    axis), then applies a learnable per-channel affine transform. The mean and
    std used are cached on the module (detached) so that a later
    ``mode="denorm"`` call can map outputs back to the original scale.

    ``mode="denorm"`` inverts that transform for the *leading* ``k`` channels,
    where ``k = x.shape[-1]``. With ``k == 3`` this is exactly the previous
    hard-coded behaviour (first three channels).

    Args:
        n_features: number of input channels.
        eps: numerical guard added to the std and to the affine weight.

    Raises:
        RuntimeError: ``mode="denorm"`` is used before any ``mode="norm"`` call.
        ValueError: ``mode`` is neither ``"norm"`` nor ``"denorm"``.
    """

    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))
        # Statistics from the most recent "norm" call; consumed by "denorm".
        self._mean = None
        self._std = None

    def forward(self, x, mode="norm"):
        if mode == "norm":
            # The unbiased (Bessel) std is undefined for a single timestep and
            # would yield NaN; use the biased estimator (0) there so std == eps.
            unbiased = x.shape[1] > 1
            self._mean = x.mean(dim=1, keepdim=True).detach()
            self._std = (x.std(dim=1, unbiased=unbiased, keepdim=True) + self.eps).detach()
            x = (x - self._mean) / self._std
            return x * self.affine_weight + self.affine_bias
        if mode == "denorm":
            if self._mean is None:
                raise RuntimeError("RevIN: call with mode='norm' before mode='denorm'")
            # Invert for the first k channels only (k = channels in x).
            k = x.shape[-1]
            x = (x - self.affine_bias[:k]) / (self.affine_weight[:k] + self.eps)
            return x * self._std[..., :k] + self._mean[..., :k]
        raise ValueError(f"RevIN mode must be 'norm' or 'denorm', got {mode!r}")
class PatchTokenizer(nn.Module):
    """Split a series into overlapping temporal patches.

    Dim 1 of the input is the time axis and the last dim the channels (``C``).
    Patch ``i`` covers timesteps ``[i * stride, i * stride + patch_len)`` and is
    flattened timestep-by-timestep, giving ``(B, n_patches, patch_len * C)``.
    Trailing timesteps that do not fill a whole patch are dropped.

    Args:
        patch_len: timesteps per patch.
        stride: step between consecutive patch starts.
        n_features: expected channel count (stored for reference; ``forward``
            reads the channel count from its input).
    """

    def __init__(self, patch_len=8, stride=4, n_features=9):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.n_features = n_features

    def forward(self, x):
        # Local names avoid shadowing the module-level ``F`` (torch.nn.functional).
        batch, _, n_chan = x.shape
        # unfold -> (B, n_patches, C, patch_len); swap the last two axes so each
        # flattened patch is ordered time-major.
        windows = x.unfold(1, self.patch_len, self.stride).transpose(2, 3)
        return windows.reshape(batch, windows.shape[1], self.patch_len * n_chan)

    def n_patches(self, seq_len):
        """Number of full patches in ``seq_len`` timesteps (0 if none fit)."""
        if seq_len < self.patch_len:
            return 0
        return (seq_len - self.patch_len) // self.stride + 1
class CrossAttentionReprogrammer(nn.Module):
    """Re-express patch embeddings via cross-attention over a learned prototype bank.

    Each patch embedding is a query over ``n_prototypes`` learnable
    ``d_model``-dimensional vectors (randomly initialised, std 0.02 — not
    derived from the LLM vocabulary). The attended summary passes through
    dropout, is added back to the input, and the sum is layer-normalised, so
    the output has the same shape as the input.

    Args:
        d_model: width of both the patch embeddings and the prototypes.
        n_heads: number of attention heads (must divide ``d_model``).
        n_prototypes: number of prototype vectors.
        dropout: dropout used inside attention and on the residual branch.
    """

    def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1):
        super().__init__()
        bank = torch.randn(n_prototypes, d_model) * 0.02
        self.prototypes = nn.Parameter(bank)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patch_embeds):
        """Map batch-first ``(B, N, d_model)`` patch embeddings to the same shape."""
        batch_size = patch_embeds.size(0)
        # Share one prototype bank across the batch without copying it.
        memory = self.prototypes[None].expand(batch_size, -1, -1)
        attended, _weights = self.cross_attn(query=patch_embeds, key=memory, value=memory)
        residual = patch_embeds + self.dropout(attended)
        return self.layer_norm(residual)
class TrajectoryPredictionHead(nn.Module):
    """Regress a future trajectory from a sequence of hidden states.

    The hidden states are mean-pooled over dim 1 and passed through a two-layer
    MLP (``d_model -> d_model // 2 -> pred_len * n_output``, GELU, dropout 0.1).
    The result is reshaped to ``(-1, pred_len, n_output)``.

    Args:
        d_model: width of the incoming hidden states.
        pred_len: number of future steps to predict.
        n_output: values per step (3 by default).
    """

    def __init__(self, d_model, pred_len, n_output=3):
        super().__init__()
        self.pred_len = pred_len
        self.n_output = n_output
        bottleneck = d_model // 2
        layers = [
            nn.Linear(d_model, bottleneck),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(bottleneck, pred_len * n_output),
        ]
        # Index-addressed Sequential keeps checkpoint keys as proj.0.* / proj.3.*.
        self.proj = nn.Sequential(*layers)

    def forward(self, hidden):
        pooled = hidden.mean(dim=1)
        flat = self.proj(pooled)
        return flat.reshape(-1, self.pred_len, self.n_output)
class ClassificationHead(nn.Module):
    """Produce class logits from a sequence of hidden states.

    The hidden states are mean-pooled over dim 1 and passed through a two-layer
    MLP (``d_model -> d_model // 4 -> n_classes``, GELU, dropout 0.2). Raw
    logits are returned (no softmax).

    Args:
        d_model: width of the incoming hidden states.
        n_classes: number of output classes.
    """

    def __init__(self, d_model, n_classes):
        super().__init__()
        width = d_model // 4
        # Index-addressed Sequential keeps checkpoint keys as cls.0.* / cls.3.*.
        self.cls = nn.Sequential(*[
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(width, n_classes),
        ])

    def forward(self, hidden):
        pooled = hidden.mean(dim=1)
        return self.cls(pooled)
class LLM4AirTrack(nn.Module):
    """
    LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction.

    Pipeline (see ``forward``): RevIN -> patching -> patch embedding ->
    prototype cross-attention -> [prompt prefix ; patch tokens] -> LLM
    backbone -> trajectory and/or classification heads.

    Args:
        llm_name: HuggingFace model ID for the LLM backbone
        n_input_features: Number of input features (default: 9 kinematic)
        context_len: Input context window length in timesteps (stored as
            ``self.context_len``; ``forward`` accepts any length that fits at
            least one patch)
        pred_len: Prediction horizon in timesteps
        patch_len: Temporal patch length
        patch_stride: Patch stride
        n_prototypes: Number of learned text prototypes
        n_classes: Number of route/procedure classes
        reprogrammer_heads: Number of cross-attention heads
        dropout: Dropout rate (patch embedder and reprogrammer)
        freeze_llm: Whether to freeze LLM backbone weights. Frozen weights get
            no gradient, but gradients still flow *through* the backbone to
            the trainable adapters. The frozen backbone is also kept in eval
            mode (no dropout) even when ``.train()`` is called on this module.
        prompt_text: Text whose token embeddings are prepended to the patch
            tokens.

    Raises:
        ValueError: the loaded model exposes neither ``transformer.wte`` nor
            ``model.embed_tokens``.
    """

    def __init__(self, llm_name="openai-community/gpt2", n_input_features=9,
                 context_len=60, pred_len=30, patch_len=8, patch_stride=4,
                 n_prototypes=256, n_classes=39, reprogrammer_heads=8,
                 dropout=0.1, freeze_llm=True,
                 prompt_text="This is an aircraft trajectory in 3D airspace near an airport. "
                             "The data represents ADS-B surveillance with position, velocity, and polar components. "
                             "Predict the future trajectory."):
        super().__init__()
        self.context_len = context_len
        self.pred_len = pred_len
        self.freeze_llm = freeze_llm

        # --- LLM backbone ---------------------------------------------------
        config = AutoConfig.from_pretrained(llm_name)
        self.d_llm = config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)
        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False
            self.llm.eval()

        # NOTE(review): assigning these sub-modules registers them a second
        # time, so state_dict() also contains them under "word_embeddings.*" /
        # "backbone.*". Left unchanged so previously saved checkpoints still
        # load with strict=True.
        self.word_embeddings, self.backbone = self._locate_backbone(self.llm)

        # --- Prompt prefix (fixed token ids; embedded on every forward) -----
        tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64)
        self.register_buffer("prompt_ids", tokens["input_ids"])

        # --- Trainable adapters ----------------------------------------------
        self.revin = RevIN(n_input_features)
        self.patcher = PatchTokenizer(patch_len, patch_stride, n_input_features)
        self.patch_embed = nn.Sequential(
            nn.Linear(patch_len * n_input_features, self.d_llm), nn.GELU(),
            nn.LayerNorm(self.d_llm), nn.Dropout(dropout),
        )
        self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, reprogrammer_heads, n_prototypes, dropout)
        self.traj_head = TrajectoryPredictionHead(self.d_llm, pred_len)
        self.cls_head = ClassificationHead(self.d_llm, n_classes)

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")

    @staticmethod
    def _locate_backbone(llm):
        """Return ``(token_embedding, decoder)`` from a HuggingFace causal-LM wrapper."""
        if hasattr(llm, "transformer"):
            # GPT-2-style layout: llm.transformer.wte
            return llm.transformer.wte, llm.transformer
        inner = getattr(llm, "model", None)
        if inner is not None and hasattr(inner, "embed_tokens"):
            # LLaMA-style layout: llm.model.embed_tokens
            return inner.embed_tokens, inner
        raise ValueError(
            f"Unsupported LLM backbone {type(llm).__name__}: expected a "
            "'transformer.wte' or 'model.embed_tokens' attribute"
        )

    def train(self, mode=True):
        """Set train/eval mode; a frozen backbone always stays in eval mode."""
        super().train(mode)
        if getattr(self, "freeze_llm", False) and getattr(self, "llm", None) is not None:
            # nn.Module.train() recurses into self.llm; undo that so the frozen
            # backbone's dropout stays disabled while the adapters train.
            self.llm.eval()
        return self

    def forward(self, context, target=None, label=None, task="both"):
        """Run the model.

        Args:
            context: input series; RevIN and patching treat dim 1 as time and
                the last dim as features (presumably
                ``(B, context_len, n_input_features)`` — confirm with caller).
            target: optional ground truth compared with ``pred_trajectory``
                via smooth-L1 (must match its shape).
            label: optional class indices for ``F.cross_entropy``.
            task: "predict", "classify" or "both".

        Returns:
            dict with ``pred_trajectory``/``traj_loss`` (predict) and
            ``pred_class``/``cls_loss`` (classify) as applicable, plus
            ``loss`` = traj_loss + 0.1 * cls_loss (each term only when its
            target/label is given).

        Raises:
            ValueError: for an unknown ``task``.
        """
        if task not in ("predict", "classify", "both"):
            raise ValueError(f"task must be 'predict', 'classify' or 'both', got {task!r}")
        B, device = context.shape[0], context.device

        # 1) Trainable encoder: normalise, patch, embed to LLM width, reprogram.
        x = self.revin(context, mode="norm")
        patches = self.patcher(x)
        patch_emb = self.patch_embed(patches)
        reprogrammed = self.reprogrammer(patch_emb)

        # 2) Prompt-as-prefix, shared across the batch.
        prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        input_emb = torch.cat([prompt_emb.expand(B, -1, -1), reprogrammed], dim=1)

        # 3) Backbone. Deliberately NOT under torch.no_grad(): frozen weights
        # already have requires_grad=False, but the upstream adapters
        # (patch_embed, reprogrammer, RevIN affine) only get gradients if
        # autograd records the path through the backbone's activations.
        hidden = self.backbone(inputs_embeds=input_emb).last_hidden_state
        # NOTE(review): with a causal backbone the prompt positions cannot see
        # the patch tokens, so their hidden states are input-independent and
        # the heads' mean-pooling includes a constant component — confirm
        # this is intended.

        results = {}
        # Grad-requiring leaf so results["loss"] always supports .backward(),
        # even when neither target nor label is supplied.
        loss = torch.tensor(0.0, device=device, requires_grad=True)
        if task in ("predict", "both"):
            pred = self.traj_head(hidden)
            pred = self.revin(pred, mode="denorm")
            results["pred_trajectory"] = pred
            if target is not None:
                traj_loss = F.smooth_l1_loss(pred, target)
                results["traj_loss"] = traj_loss
                loss = loss + traj_loss
        if task in ("classify", "both"):
            logits = self.cls_head(hidden)
            results["pred_class"] = logits
            if label is not None:
                cls_loss = F.cross_entropy(logits, label)
                results["cls_loss"] = cls_loss
                loss = loss + 0.1 * cls_loss
        results["loss"] = loss
        return results
def count_parameters(model):
    """Summarise parameter counts per direct child of ``model``.

    Returns a dict ``{child_name: {"total": int, "trainable": int}}`` that
    includes only children owning at least one parameter. Each child is
    counted independently, so a module reachable from several children is
    counted under each of them.
    """
    summary = {}
    for child_name, child in model.named_children():
        params = list(child.parameters())
        n_total = sum(p.numel() for p in params)
        if not n_total:
            continue
        summary[child_name] = {
            "total": n_total,
            "trainable": sum(p.numel() for p in params if p.requires_grad),
        }
    return summary