Jdice27 committed on
Commit 8cae25f · verified · 1 Parent(s): 0c01cdc

Add model module

Files changed (1)
  1. llm4airtrack/model.py +222 -0
llm4airtrack/model.py ADDED
@@ -0,0 +1,222 @@
"""
LLM4AirTrack: LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction.

Architecture (adapted from LLM4STP/Time-LLM for ADS-B):

ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder
    → Cross-Attention Reprogrammer (learned text prototypes)
    → Prompt-as-Prefix → Frozen GPT-2/LLaMA Backbone
    → Trajectory Head (future xyz) + Classification Head (route class)

Trainable parameters: ~2-5% (adapters only)
Frozen: LLM backbone (preserves language understanding for reprogramming)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from typing import Optional, Dict

class RevIN(nn.Module):
    """Reversible Instance Normalization over the time dimension."""
    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))

    def forward(self, x, mode="norm"):
        if mode == "norm":
            # Per-instance statistics over time, stored for later denormalization.
            self._mean = x.mean(dim=1, keepdim=True).detach()
            self._std = (x.std(dim=1, keepdim=True) + self.eps).detach()
            x = (x - self._mean) / self._std
            x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            # Denormalize only the first 3 channels (x, y, z), matching the trajectory head output.
            x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps)
            x = x * self._std[:, :, :3] + self._mean[:, :, :3]
        return x

class PatchTokenizer(nn.Module):
    """Convert time series into overlapping patches."""
    def __init__(self, patch_len=8, stride=4, n_features=9):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        # x: (B, T, n_feat) -> patches: (B, n_patches, patch_len * n_feat)
        B, T, n_feat = x.shape  # renamed from F to avoid shadowing torch.nn.functional
        x = x.unfold(1, self.patch_len, self.stride)   # (B, n_patches, n_feat, patch_len)
        x = x.permute(0, 1, 3, 2).contiguous()         # (B, n_patches, patch_len, n_feat)
        return x.reshape(B, x.shape[1], self.patch_len * n_feat)

    def n_patches(self, seq_len):
        return (seq_len - self.patch_len) // self.stride + 1

class CrossAttentionReprogrammer(nn.Module):
    """Reprogram trajectory patches into LLM text space via cross-attention over learned prototypes."""
    def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02)
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads,
                                                dropout=dropout, batch_first=True)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patch_embeds):
        B = patch_embeds.shape[0]
        protos = self.prototypes.unsqueeze(0).expand(B, -1, -1)
        attn_out, _ = self.cross_attn(query=patch_embeds, key=protos, value=protos)
        return self.layer_norm(patch_embeds + self.dropout(attn_out))

class TrajectoryPredictionHead(nn.Module):
    """Maps LLM hidden states to future trajectory (x, y, z)."""
    def __init__(self, d_model, pred_len, n_output=3):
        super().__init__()
        self.pred_len = pred_len
        self.n_output = n_output
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_model // 2, pred_len * n_output),
        )

    def forward(self, hidden):
        return self.proj(hidden.mean(dim=1)).reshape(-1, self.pred_len, self.n_output)

class ClassificationHead(nn.Module):
    """Route/procedure classification from LLM hidden states."""
    def __init__(self, d_model, n_classes):
        super().__init__()
        self.cls = nn.Sequential(
            nn.Linear(d_model, d_model // 4), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(d_model // 4, n_classes),
        )

    def forward(self, hidden):
        return self.cls(hidden.mean(dim=1))

class LLM4AirTrack(nn.Module):
    """
    LLM-Driven Multi-Feature Fusion for Aircraft Trajectory Prediction.

    Args:
        llm_name: HuggingFace model ID for the LLM backbone
        n_input_features: Number of input features (default: 9 kinematic)
        context_len: Input context window length in timesteps
        pred_len: Prediction horizon in timesteps
        patch_len: Temporal patch length
        patch_stride: Patch stride
        n_prototypes: Number of learned text prototypes
        n_classes: Number of route/procedure classes
        reprogrammer_heads: Number of cross-attention heads
        dropout: Dropout rate
        freeze_llm: Whether to freeze the LLM backbone
    """
    def __init__(self, llm_name="openai-community/gpt2", n_input_features=9,
                 context_len=60, pred_len=30, patch_len=8, patch_stride=4,
                 n_prototypes=256, n_classes=39, reprogrammer_heads=8,
                 dropout=0.1, freeze_llm=True,
                 prompt_text="This is an aircraft trajectory in 3D airspace near an airport. "
                             "The data represents ADS-B surveillance with position, velocity, and polar components. "
                             "Predict the future trajectory."):
        super().__init__()
        self.pred_len = pred_len
        self.freeze_llm = freeze_llm

        # LLM backbone
        config = AutoConfig.from_pretrained(llm_name)
        self.d_llm = config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)

        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False
            self.llm.eval()

        # Backbone reference: resolve the decoder stack and input embeddings (GPT-2 vs. LLaMA-style layouts)
        if hasattr(self.llm, "transformer"):
            self.word_embeddings = self.llm.transformer.wte
            self.backbone = self.llm.transformer
        elif hasattr(self.llm, "model") and hasattr(self.llm.model, "embed_tokens"):
            self.word_embeddings = self.llm.model.embed_tokens
            self.backbone = self.llm.model
        else:
            raise ValueError(f"Unsupported LLM architecture for backbone: {llm_name}")

        # Prompt-as-Prefix: fixed textual context prepended to the trajectory tokens
        tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64)
        self.register_buffer("prompt_ids", tokens["input_ids"])

        # Trainable adapters
        self.revin = RevIN(n_input_features)
        self.patcher = PatchTokenizer(patch_len, patch_stride, n_input_features)
        self.patch_embed = nn.Sequential(
            nn.Linear(patch_len * n_input_features, self.d_llm), nn.GELU(),
            nn.LayerNorm(self.d_llm), nn.Dropout(dropout),
        )
        self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, reprogrammer_heads, n_prototypes, dropout)
        self.traj_head = TrajectoryPredictionHead(self.d_llm, pred_len)
        self.cls_head = ClassificationHead(self.d_llm, n_classes)

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total: {total:,} | Trainable: {trainable:,} ({100 * trainable / total:.2f}%)")

    def forward(self, context, target=None, label=None, task="both"):
        B, device = context.shape[0], context.device

        # Adapter path: normalize, patch, embed, and reprogram into the LLM embedding space
        x = self.revin(context, mode="norm")
        patches = self.patcher(x)
        patch_emb = self.patch_embed(patches)
        reprogrammed = self.reprogrammer(patch_emb)

        # Prompt-as-Prefix: prepend the prompt embeddings to the reprogrammed patches
        with torch.no_grad():
            prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        input_emb = torch.cat([prompt_emb.expand(B, -1, -1), reprogrammed], dim=1)

        # Run the backbone. Its parameters are frozen via requires_grad=False, but the call
        # must stay on the autograd graph so gradients reach the patch embedder and
        # reprogrammer through the frozen activations.
        hidden = self.backbone(inputs_embeds=input_emb).last_hidden_state

        results = {}
        loss = torch.tensor(0.0, device=device, requires_grad=True)

        if task in ("predict", "both"):
            pred = self.traj_head(hidden)
            pred = self.revin(pred, mode="denorm")
            results["pred_trajectory"] = pred
            if target is not None:
                traj_loss = F.smooth_l1_loss(pred, target)
                results["traj_loss"] = traj_loss
                loss = loss + traj_loss

        if task in ("classify", "both"):
            logits = self.cls_head(hidden)
            results["pred_class"] = logits
            if label is not None:
                cls_loss = F.cross_entropy(logits, label)
                results["cls_loss"] = cls_loss
                loss = loss + 0.1 * cls_loss

        results["loss"] = loss
        return results

def count_parameters(model):
    """Parameter breakdown by module."""
    breakdown = {}
    for name, module in model.named_children():
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        if total > 0:
            breakdown[name] = {"total": total, "trainable": trainable}
    return breakdown
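
A minimal usage sketch for the committed module (not part of the commit itself; the batch size, dummy tensors, and printed shapes below are illustrative assumptions based on the default hyperparameters):

# Hypothetical smoke test for llm4airtrack/model.py: illustrative values only.
import torch
from llm4airtrack.model import LLM4AirTrack, count_parameters

context = torch.randn(4, 60, 9)      # 4 trajectories, 60-step context, 9 ADS-B features
target = torch.randn(4, 30, 3)       # future (x, y, z) positions over the 30-step horizon
label = torch.randint(0, 39, (4,))   # route/procedure class per trajectory

model = LLM4AirTrack(llm_name="openai-community/gpt2", context_len=60, pred_len=30, n_classes=39)
out = model(context, target=target, label=label, task="both")

print(out["pred_trajectory"].shape)  # torch.Size([4, 30, 3])
print(out["pred_class"].shape)       # torch.Size([4, 39])
print(out["loss"])                   # smooth-L1 trajectory loss + 0.1 * cross-entropy
print(count_parameters(model))       # per-module parameter breakdown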