Jdice27 committed
Commit 434f8c5 · verified · 1 Parent(s): 73acb6d

Upload train_full.py with huggingface_hub

Files changed (1)
  1. train_full.py +661 -0
train_full.py ADDED
@@ -0,0 +1,661 @@
"""
Full-scale training script for LLM4AirTrack.
Trains on RKSIa (Incheon arrivals) - full dataset.
"""

import os
import sys
import time
import json
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from huggingface_hub import hf_hub_download, HfApi
import pandas as pd
from scipy.ndimage import uniform_filter1d


# ============================================================
# DATA MODULE
# ============================================================

def download_atfm_dataset(airport="RKSIa", cache_dir="/app/data/ATFMTraj"):
    os.makedirs(cache_dir, exist_ok=True)
    airport_dir = os.path.join(cache_dir, airport)
    os.makedirs(airport_dir, exist_ok=True)
    for mode in ["TRAIN", "TEST"]:
        for var in ["X", "Y", "Z"]:
            fname = f"{airport}_{mode}_{var}.tsv"
            fpath = os.path.join(airport_dir, fname)
            if not os.path.exists(fpath):
                print(f"Downloading {airport}/{fname}...")
                hf_hub_download(
                    repo_id="petchthwr/ATFMTraj",
                    filename=f"{airport}/{fname}",
                    repo_type="dataset",
                    local_dir=cache_dir,
                )
    return airport_dir


def load_atfm_raw(airport, mode, cache_dir):
    airport_dir = os.path.join(cache_dir, airport)
    data, labels = [], None
    for var in ['X', 'Y', 'Z']:
        df = pd.read_csv(
            os.path.join(airport_dir, f"{airport}_{mode}_{var}.tsv"),
            sep='\t', header=None, na_values='NaN'
        )
        if labels is None:
            labels = df.values[:, 0]
        data.append(df.values[:, 1:])
    return np.stack(data, axis=-1), labels.astype(int)


def compute_kinematic_features(trajectory, dt=1.0):
    x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
    dx, dy, dz = np.gradient(x)/dt, np.gradient(y)/dt, np.gradient(z)/dt
    speed = np.sqrt(dx**2 + dy**2 + dz**2) + 1e-8
    ux, uy, uz = dx/speed, dy/speed, dz/speed
    r = np.sqrt(x**2 + y**2) + 1e-8
    theta = np.arctan2(y, x)
    return np.stack([x, y, z, ux, uy, uz, r, np.sin(theta), np.cos(theta)], axis=-1)
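
# Illustrative note (not executed here): given one raw trajectory as a (T, 3) array of
# x/y/z positions, compute_kinematic_features returns a (T, 9) array ordered as
# [x, y, z, ux, uy, uz, r, sin(theta), cos(theta)], where (ux, uy, uz) is the unit
# velocity direction from np.gradient finite differences and (r, theta) are horizontal
# polar coordinates about the coordinate origin, e.g.
#     feats = compute_kinematic_features(traj_xyz)   # traj_xyz: (60, 3) -> feats: (60, 9)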


def create_windows(data, labels, context_len=60, pred_len=30, stride=15):
    total_len = context_len + pred_len
    contexts, targets, sample_labels = [], [], []
    for i in range(len(data)):
        traj = data[i]
        valid_mask = ~np.isnan(traj[:, 0])
        valid_len = np.sum(valid_mask)
        if valid_len < total_len:
            continue
        traj_valid = traj[valid_mask]
        for start in range(0, valid_len - total_len + 1, stride):
            ctx_raw = traj_valid[start:start + context_len]
            tgt = traj_valid[start + context_len:start + total_len]
            ctx = compute_kinematic_features(ctx_raw)
            contexts.append(ctx)
            targets.append(tgt)
            sample_labels.append(labels[i])
    return (
        np.array(contexts, dtype=np.float32),
        np.array(targets, dtype=np.float32),
        np.array(sample_labels, dtype=np.int64),
    )
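
# Window-count arithmetic (illustrative): each trajectory with valid_len non-NaN points
# yields (valid_len - (context_len + pred_len)) // stride + 1 windows. For example, with
# context_len=60, pred_len=30 and stride=15, a 300-point trajectory produces
# (300 - 90) // 15 + 1 = 15 context/target pairs; trajectories shorter than 90 points
# are skipped entirely.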


class AirTrackDataset(Dataset):
    def __init__(self, contexts, targets, labels):
        self.contexts = torch.from_numpy(contexts)
        self.targets = torch.from_numpy(targets)
        self.labels = torch.from_numpy(labels)

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        return {
            "context": self.contexts[idx],
            "target": self.targets[idx],
            "label": self.labels[idx],
        }


# ============================================================
# MODEL MODULE
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig


class RevIN(nn.Module):
    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))

    def forward(self, x, mode="norm"):
        if mode == "norm":
            # Per-window instance normalization; statistics are cached for denorm.
            self._mean = x.mean(dim=1, keepdim=True).detach()
            self._std = (x.std(dim=1, keepdim=True) + self.eps).detach()
            x = (x - self._mean) / self._std
            x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            # Only the first 3 channels (x, y, z) are denormalized, since the
            # trajectory head predicts positions only.
            x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps)
            x = x * self._std[:, :, :3] + self._mean[:, :, :3]
        return x


class PatchTokenizer(nn.Module):
    def __init__(self, patch_len=8, stride=4):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        B, T, n_feat = x.shape  # avoid shadowing torch.nn.functional (imported as F)
        x = x.unfold(1, self.patch_len, self.stride)
        x = x.permute(0, 1, 3, 2).contiguous()
        return x.reshape(B, x.shape[1], self.patch_len * n_feat)

    def n_patches(self, seq_len):
        return (seq_len - self.patch_len) // self.stride + 1
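
# Patch arithmetic (illustrative): with context_len=60, patch_len=8 and stride=4 the
# tokenizer produces (60 - 8) // 4 + 1 = 14 overlapping patches per window, each
# flattened to patch_len * n_features = 8 * 9 = 72 values before the linear patch
# embedder projects it to the LLM hidden size (768 for GPT-2).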


class CrossAttentionReprogrammer(nn.Module):
    def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patch_embeds):
        B = patch_embeds.shape[0]
        protos = self.prototypes.unsqueeze(0).expand(B, -1, -1)
        attn_out, _ = self.cross_attn(query=patch_embeds, key=protos, value=protos)
        return self.layer_norm(patch_embeds + self.dropout(attn_out))


class LLM4AirTrack(nn.Module):
    def __init__(
        self,
        llm_name="openai-community/gpt2",
        n_features=9,
        context_len=60,
        pred_len=30,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
        n_classes=39,
        n_heads=8,
        dropout=0.1,
        freeze_llm=True,
        prompt_text=(
            "This is an aircraft trajectory in 3D airspace near an airport. "
            "The data represents ADS-B surveillance with position, velocity, and polar components. "
            "Predict the future trajectory."
        ),
    ):
        super().__init__()
        self.pred_len = pred_len
        self.freeze_llm = freeze_llm

        # LLM backbone
        print(f"Loading LLM: {llm_name}")
        config = AutoConfig.from_pretrained(llm_name)
        self.d_llm = config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)

        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False
            self.llm.eval()

        # Word embeddings reference
        if hasattr(self.llm, 'transformer'):
            self.word_embeddings = self.llm.transformer.wte
            self.backbone = self.llm.transformer
        elif hasattr(self.llm, 'model') and hasattr(self.llm.model, 'embed_tokens'):
            self.word_embeddings = self.llm.model.embed_tokens
            self.backbone = self.llm.model

        # Prompt
        tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64)
        self.register_buffer("prompt_ids", tokens["input_ids"])

        # Trainable components
        self.revin = RevIN(n_features)
        self.patcher = PatchTokenizer(patch_len, patch_stride)
        self.patch_embed = nn.Sequential(
            nn.Linear(patch_len * n_features, self.d_llm),
            nn.GELU(),
            nn.LayerNorm(self.d_llm),
            nn.Dropout(dropout),
        )
        self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, n_heads, n_prototypes, dropout)

        # Trajectory prediction head
        self.traj_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_llm // 2, pred_len * 3),
        )

        # Classification head
        self.cls_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.d_llm // 4, n_classes),
        )

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")

    def forward(self, context, target=None, label=None):
        B = context.shape[0]
        device = context.device

        # Normalize
        x = self.revin(context, mode="norm")

        # Patch + embed
        patches = self.patcher(x)
        patch_emb = self.patch_embed(patches)

        # Reprogram
        reprogrammed = self.reprogrammer(patch_emb)

        # Prompt prefix
        with torch.no_grad():
            prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        prompt_emb = prompt_emb.expand(B, -1, -1)

        # Assemble and pass through the LLM. The backbone's weights are frozen via
        # requires_grad=False, but the forward pass must stay on the autograd graph:
        # wrapping it in torch.no_grad() / detaching the hidden states would cut the
        # gradient path to the trainable patch embedder and reprogrammer upstream.
        input_emb = torch.cat([prompt_emb, reprogrammed], dim=1)
        out = self.backbone(inputs_embeds=input_emb)
        hidden = out.last_hidden_state
        pooled = hidden.mean(dim=1)

        # Heads
        results = {}
        loss = torch.tensor(0.0, device=device)

        # Trajectory prediction
        pred_flat = self.traj_head(pooled)
        pred_traj = pred_flat.reshape(B, self.pred_len, 3)
        pred_traj = self.revin(pred_traj, mode="denorm")
        results["pred_trajectory"] = pred_traj

        if target is not None:
            traj_loss = F.smooth_l1_loss(pred_traj, target)
            results["traj_loss"] = traj_loss
            loss = loss + traj_loss

        # Classification
        class_logits = self.cls_head(pooled)
        results["pred_class"] = class_logits

        if label is not None:
            cls_loss = F.cross_entropy(class_logits, label)
            results["cls_loss"] = cls_loss
            loss = loss + 0.1 * cls_loss

        results["loss"] = loss
        return results


# ============================================================
# TRAINING
# ============================================================

def compute_metrics(pred, target):
    disp = torch.sqrt(((pred - target) ** 2).sum(dim=-1))
    ade = disp.mean().item()
    fde = disp[:, -1].mean().item()
    rmse = torch.sqrt(((pred - target) ** 2).mean(dim=(0, 1)))
    return {
        "ADE": ade, "FDE": fde,
        "RMSE_x": rmse[0].item(), "RMSE_y": rmse[1].item(), "RMSE_z": rmse[2].item(),
    }
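
# Note on the metrics (illustrative): disp is the per-step Euclidean distance between
# predicted and true (x, y, z); ADE averages it over all predicted steps and samples,
# while FDE keeps only the final step. Per-axis RMSEs are reported separately. For
# example, a constant 1-unit error on every axis at every step gives
# RMSE_x = RMSE_y = RMSE_z = 1 and ADE = FDE = sqrt(3), about 1.73.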


def evaluate(model, dataloader, device):
    model.eval()
    total_loss, total_correct, n = 0, 0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)
            out = model(ctx, tgt, lbl)
            total_loss += out["loss"].item() * ctx.shape[0]
            if "pred_class" in out:
                total_correct += (out["pred_class"].argmax(-1) == lbl).sum().item()
            all_preds.append(out["pred_trajectory"].cpu())
            all_targets.append(tgt.cpu())
            n += ctx.shape[0]

    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)
    metrics = compute_metrics(preds, targets)
    metrics["loss"] = total_loss / n
    metrics["accuracy"] = total_correct / n
    return metrics


def main():
    import trackio

    # Config
    AIRPORT = "RKSIa"
    CONTEXT_LEN = 60
    PRED_LEN = 30
    STRIDE = 15  # default window stride (overridden per split below)
    BATCH_SIZE = 128
    EPOCHS = 5
    LR = 5e-4
    LLM_NAME = "openai-community/gpt2"
    HUB_MODEL_ID = "Jdice27/LLM4AirTrack"
    OUTPUT_DIR = "/app/outputs/llm4airtrack"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    # Trackio
    tracker = trackio.init(project="LLM4AirTrack", name=f"LLM4AirTrack-{AIRPORT}-gpt2", config={
        "airport": AIRPORT, "context_len": CONTEXT_LEN, "pred_len": PRED_LEN,
        "batch_size": BATCH_SIZE, "epochs": EPOCHS, "lr": LR, "llm": LLM_NAME,
    })

    # Data
    print(f"\n{'='*60}")
    print(f"Loading {AIRPORT} data...")
    download_atfm_dataset(AIRPORT)
    train_data, train_labels = load_atfm_raw(AIRPORT, "TRAIN", "/app/data/ATFMTraj")
    test_data, test_labels = load_atfm_raw(AIRPORT, "TEST", "/app/data/ATFMTraj")
    print(f"Raw: train={train_data.shape}, test={test_data.shape}")

    # Use larger strides than the default to keep the train/test window counts manageable
    train_ctx, train_tgt, train_lbl = create_windows(train_data, train_labels, CONTEXT_LEN, PRED_LEN, stride=30)
    test_ctx, test_tgt, test_lbl = create_windows(test_data, test_labels, CONTEXT_LEN, PRED_LEN, stride=60)
    print(f"Windows: train={train_ctx.shape}, test={test_ctx.shape}", flush=True)

    all_labels = np.concatenate([train_lbl, test_lbl])
    n_classes = int(all_labels.max()) + 1
    print(f"Classes: {n_classes} (unique in data: {len(np.unique(all_labels))})", flush=True)

    # Subsample the eval set for faster evaluation (cap at 20k windows)
    eval_size = min(len(test_ctx), 20000)
    eval_idx = np.random.RandomState(42).permutation(len(test_ctx))[:eval_size]

    train_ds = AirTrackDataset(train_ctx, train_tgt, train_lbl)
    eval_ds = AirTrackDataset(test_ctx[eval_idx], test_tgt[eval_idx], test_lbl[eval_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    print(f"Train samples: {len(train_ds)}, Eval samples: {len(eval_ds)}", flush=True)

    # Model
    print(f"\n{'='*60}")
    model = LLM4AirTrack(
        llm_name=LLM_NAME,
        n_features=9,
        context_len=CONTEXT_LEN,
        pred_len=PRED_LEN,
        n_classes=n_classes,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
    ).to(device)

    # Optimizer
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=LR, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader), T_mult=2, eta_min=LR * 0.01)
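
    # Note on the schedule (illustrative): with T_0 = len(train_loader) and T_mult = 2,
    # the first cosine cycle lasts 1 epoch, the next 2 epochs, then 4, ... so the
    # learning rate restarts at LR at the start of epochs 1, 2 and 4 (within the 5
    # epochs trained here) and decays toward eta_min = 1% of LR inside each cycle.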

    # Training
    print(f"\n{'='*60}")
    print(f"Training {EPOCHS} epochs, {len(train_loader)} steps/epoch")
    print(f"{'='*60}\n")

    best_ade = float("inf")
    best_epoch = -1
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        model.train()
        model.backbone.eval()  # keep the frozen LLM in eval mode (disables its dropout)

        epoch_loss, epoch_traj, epoch_cls, n_batches = 0, 0, 0, 0
        t0 = time.time()

        for batch_idx, batch in enumerate(train_loader):
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)

            out = model(ctx, tgt, lbl)
            loss = out["loss"]

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            epoch_traj += out.get("traj_loss", torch.tensor(0)).item()
            epoch_cls += out.get("cls_loss", torch.tensor(0)).item()
            n_batches += 1

            trackio.log({
                "train/loss": loss.item(),
                "train/traj_loss": out.get("traj_loss", torch.tensor(0)).item(),
                "train/cls_loss": out.get("cls_loss", torch.tensor(0)).item(),
                "train/lr": optimizer.param_groups[0]["lr"],
            })

            if (batch_idx + 1) % 25 == 0:
                print(f"  [{epoch+1}/{EPOCHS}] step {batch_idx+1}/{len(train_loader)} | "
                      f"loss={epoch_loss/n_batches:.6f} traj={epoch_traj/n_batches:.6f} "
                      f"cls={epoch_cls/n_batches:.6f} lr={optimizer.param_groups[0]['lr']:.2e}",
                      flush=True)

        dt = time.time() - t0
        avg_loss = epoch_loss / n_batches

        # Evaluate
        metrics = evaluate(model, test_loader, device)

        print(f"\nEpoch {epoch+1}/{EPOCHS} ({dt:.0f}s) | "
              f"Train loss: {avg_loss:.6f} | "
              f"Eval ADE: {metrics['ADE']:.6f} FDE: {metrics['FDE']:.6f} | "
              f"Acc: {metrics['accuracy']:.4f}")

        trackio.log({
            "eval/loss": metrics["loss"],
            "eval/ADE": metrics["ADE"],
            "eval/FDE": metrics["FDE"],
            "eval/accuracy": metrics["accuracy"],
            "eval/RMSE_x": metrics["RMSE_x"],
            "eval/RMSE_y": metrics["RMSE_y"],
            "eval/RMSE_z": metrics["RMSE_z"],
            "epoch": epoch + 1,
        })

        # Save best
        if metrics["ADE"] < best_ade:
            best_ade = metrics["ADE"]
            best_epoch = epoch + 1

            save_dir = os.path.join(OUTPUT_DIR, "best_model")
            os.makedirs(save_dir, exist_ok=True)

            # Save adapter weights (everything except the frozen LLM)
            adapter_state = {
                k: v for k, v in model.state_dict().items()
                if not any(k.startswith(p) for p in ["llm.", "word_embeddings.", "backbone."])
            }
            torch.save(adapter_state, os.path.join(save_dir, "adapter_weights.pt"))

            config = {
                "llm_name": LLM_NAME,
                "n_features": 9,
                "context_len": CONTEXT_LEN,
                "pred_len": PRED_LEN,
                "patch_len": 8,
                "patch_stride": 4,
                "n_prototypes": 256,
                "n_classes": n_classes,
                "n_heads": 8,
                "dropout": 0.1,
                "best_ade": best_ade,
                "best_fde": metrics["FDE"],
                "best_epoch": best_epoch,
                "best_accuracy": metrics["accuracy"],
                "airport": AIRPORT,
                "metrics": metrics,
            }
            with open(os.path.join(save_dir, "config.json"), "w") as f:
                json.dump(config, f, indent=2)

            print(f"  ★ New best! ADE: {best_ade:.6f} (epoch {best_epoch})")
        print()

    # Push to Hub
    print(f"\n{'='*60}")
    print(f"Training complete! Best ADE: {best_ade:.6f} (epoch {best_epoch})")
    print(f"Pushing to Hub: {HUB_MODEL_ID}")

    api = HfApi()
    try:
        api.create_repo(HUB_MODEL_ID, exist_ok=True)
    except Exception as e:
        print(f"Repo: {e}")

    save_dir = os.path.join(OUTPUT_DIR, "best_model")
    api.upload_folder(folder_path=save_dir, repo_id=HUB_MODEL_ID,
                      commit_message=f"Best model: ADE={best_ade:.6f}, epoch {best_epoch}")

    # Upload source code
    api.upload_file(
        path_or_fileobj=__file__,
        path_in_repo="train_full.py",
        repo_id=HUB_MODEL_ID,
    )

    # Model card
    model_card = f"""---
license: apache-2.0
tags:
- trajectory-prediction
- aviation
- adsb
- time-series
- llm-reprogramming
- gpt2
datasets:
- petchthwr/ATFMTraj
pipeline_tag: time-series-forecasting
---

# LLM4AirTrack: LLM-Driven Aircraft Trajectory Prediction

Adapts the [LLM4STP](https://github.com/Joker-hang/LLM4STP) framework from maritime AIS to aviation ADS-B.
Uses a **frozen GPT-2 backbone** with lightweight trainable adapters (~2.4% of params).

## Architecture

```
ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder
  → Cross-Attention Reprogrammer (learned text prototypes)
  → Prompt-as-Prefix → Frozen GPT-2 Backbone
  → Trajectory Head (future xyz) + Classification Head (STAR/runway)
```

### Key Components
1. **9-dim Kinematic Features**: Position (x,y,z ENU) + Direction (ux,uy,uz) + Polar (r, sinθ, cosθ) (layout shown below)
2. **Patch Tokenization**: Overlapping temporal patches (len=8, stride=4)
3. **Cross-Attention Reprogramming**: 256 learned text prototypes, 8-head attention
4. **Frozen GPT-2**: 124M params frozen, only ~3.1M trainable
5. **Dual Heads**: Trajectory prediction (Smooth L1) + Route classification (CE)
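
For reference, the per-timestep feature layout produced by `compute_kinematic_features` in `train_full.py` is:

```python
# index 0-2 : x, y, z                  # ENU position
# index 3-5 : ux, uy, uz               # unit velocity direction (finite differences)
# index 6   : r                        # horizontal range, sqrt(x^2 + y^2)
# index 7-8 : sin(theta), cos(theta)   # bearing, theta = atan2(y, x)
```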

## Training

- **Dataset**: [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - {AIRPORT}
- **Source**: OpenSky ADS-B, Incheon International Airport arrivals (2018-2023)
- **Context**: {CONTEXT_LEN} timesteps (1s intervals)
- **Prediction**: {PRED_LEN} timesteps ahead
- **Optimizer**: AdamW, lr={LR}, cosine annealing with warm restarts
- **Epochs**: {EPOCHS}

## Results

| Metric | Value |
|--------|-------|
| ADE (normalized) | {best_ade:.6f} |
| Best Epoch | {best_epoch} |
| Route Classification Acc | {metrics['accuracy']:.4f} |

## Usage

```python
import torch, json
from train_full import LLM4AirTrack

# Load
with open("config.json") as f:
    cfg = json.load(f)

model = LLM4AirTrack(
    llm_name=cfg["llm_name"],
    context_len=cfg["context_len"],
    pred_len=cfg["pred_len"],
    n_classes=cfg["n_classes"],
)
state = torch.load("adapter_weights.pt", map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()

# Predict (input: 60 timesteps of 9-dim kinematic features)
context = torch.randn(1, 60, 9)
out = model(context)
future_xyz = out["pred_trajectory"]  # (1, 30, 3)
route_class = out["pred_class"].argmax(-1)  # (1,)
```

## Downstream Tasks

- **Track Activity Classification**: Route/procedure identification from trajectory embeddings
- **Anomaly Detection**: Flag deviations from predicted trajectory (see the sketch below)
- **Conflict Detection**: Multi-aircraft trajectory forecasting
- **ETA Prediction**: Time-to-threshold from trajectory state
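
A minimal sketch of the anomaly-detection idea (illustrative only; the tensors and the threshold are placeholders, not part of this repository):

```python
import torch

# pred: model output from the Usage example above, shape (1, 30, 3)
# observed: the positions actually flown over the same 30 steps, shape (1, 30, 3)
pred = out["pred_trajectory"]
observed = torch.randn(1, 30, 3)                            # placeholder for real data
per_step_error = (pred - observed).pow(2).sum(-1).sqrt()    # Euclidean error per step
is_anomalous = per_step_error.mean().item() > 2.0           # threshold chosen for illustration
```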

## References

- [LLM4STP](https://github.com/Joker-hang/LLM4STP) - Original maritime framework
- [Time-LLM](https://arxiv.org/abs/2310.01728) - Foundational reprogramming approach
- [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - Aviation trajectory dataset
- [ATSCC](https://arxiv.org/abs/2407.20028) - Self-supervised trajectory representation
- [LLM4Delay](https://arxiv.org/abs/2510.23636) - Cross-modality LLM adaptation for aviation
"""
    api.upload_file(
        path_or_fileobj=model_card.encode(),
        path_in_repo="README.md",
        repo_id=HUB_MODEL_ID,
    )

    print(f"✓ Pushed to: https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == "__main__":
    main()