""" Full-scale training script for LLM4AirTrack. Trains on RKSIa (Incheon arrivals) - full dataset. """ import os import sys import time import json import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts from torch.utils.data import Dataset, DataLoader from pathlib import Path from huggingface_hub import hf_hub_download, HfApi import pandas as pd from scipy.ndimage import uniform_filter1d # ============================================================ # DATA MODULE # ============================================================ def download_atfm_dataset(airport="RKSIa", cache_dir="/app/data/ATFMTraj"): os.makedirs(cache_dir, exist_ok=True) airport_dir = os.path.join(cache_dir, airport) os.makedirs(airport_dir, exist_ok=True) for mode in ["TRAIN", "TEST"]: for var in ["X", "Y", "Z"]: fname = f"{airport}_{mode}_{var}.tsv" fpath = os.path.join(airport_dir, fname) if not os.path.exists(fpath): print(f"Downloading {airport}/{fname}...") hf_hub_download( repo_id="petchthwr/ATFMTraj", filename=f"{airport}/{fname}", repo_type="dataset", local_dir=cache_dir, ) return airport_dir def load_atfm_raw(airport, mode, cache_dir): airport_dir = os.path.join(cache_dir, airport) data, labels = [], None for var in ['X', 'Y', 'Z']: df = pd.read_csv( os.path.join(airport_dir, f"{airport}_{mode}_{var}.tsv"), sep='\t', header=None, na_values='NaN' ) if labels is None: labels = df.values[:, 0] data.append(df.values[:, 1:]) return np.stack(data, axis=-1), labels.astype(int) def compute_kinematic_features(trajectory, dt=1.0): x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2] dx, dy, dz = np.gradient(x)/dt, np.gradient(y)/dt, np.gradient(z)/dt speed = np.sqrt(dx**2 + dy**2 + dz**2) + 1e-8 ux, uy, uz = dx/speed, dy/speed, dz/speed r = np.sqrt(x**2 + y**2) + 1e-8 theta = np.arctan2(y, x) return np.stack([x, y, z, ux, uy, uz, r, np.sin(theta), np.cos(theta)], axis=-1) def create_windows(data, labels, context_len=60, pred_len=30, stride=15): total_len = context_len + pred_len contexts, targets, sample_labels = [], [], [] for i in range(len(data)): traj = data[i] valid_mask = ~np.isnan(traj[:, 0]) valid_len = np.sum(valid_mask) if valid_len < total_len: continue traj_valid = traj[valid_mask] for start in range(0, valid_len - total_len + 1, stride): ctx_raw = traj_valid[start:start + context_len] tgt = traj_valid[start + context_len:start + total_len] ctx = compute_kinematic_features(ctx_raw) contexts.append(ctx) targets.append(tgt) sample_labels.append(labels[i]) return ( np.array(contexts, dtype=np.float32), np.array(targets, dtype=np.float32), np.array(sample_labels, dtype=np.int64), ) class AirTrackDataset(Dataset): def __init__(self, contexts, targets, labels): self.contexts = torch.from_numpy(contexts) self.targets = torch.from_numpy(targets) self.labels = torch.from_numpy(labels) def __len__(self): return len(self.contexts) def __getitem__(self, idx): return { "context": self.contexts[idx], "target": self.targets[idx], "label": self.labels[idx], } # ============================================================ # MODEL MODULE # ============================================================ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig class RevIN(nn.Module): def __init__(self, n_features, eps=1e-5): super().__init__() self.eps = eps self.affine_weight = nn.Parameter(torch.ones(n_features)) self.affine_bias = nn.Parameter(torch.zeros(n_features)) 
# ============================================================
# MODEL MODULE
# ============================================================
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig


class RevIN(nn.Module):
    """Reversible instance normalization over the time dimension."""

    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))

    def forward(self, x, mode="norm"):
        if mode == "norm":
            self._mean = x.mean(dim=1, keepdim=True).detach()
            self._std = (x.std(dim=1, keepdim=True) + self.eps).detach()
            x = (x - self._mean) / self._std
            x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            # Only the first 3 features (x, y, z) are denormalized: the
            # trajectory head predicts positions, not the full 9-dim vector.
            x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps)
            x = x * self._std[:, :, :3] + self._mean[:, :, :3]
        return x


class PatchTokenizer(nn.Module):
    def __init__(self, patch_len=8, stride=4):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        B, T, n_feat = x.shape
        x = x.unfold(1, self.patch_len, self.stride)
        x = x.permute(0, 1, 3, 2).contiguous()
        return x.reshape(B, x.shape[1], self.patch_len * n_feat)

    def n_patches(self, seq_len):
        return (seq_len - self.patch_len) // self.stride + 1


class CrossAttentionReprogrammer(nn.Module):
    def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patch_embeds):
        B = patch_embeds.shape[0]
        protos = self.prototypes.unsqueeze(0).expand(B, -1, -1)
        attn_out, _ = self.cross_attn(query=patch_embeds, key=protos, value=protos)
        return self.layer_norm(patch_embeds + self.dropout(attn_out))
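# Illustrative sketch (not called during training): with the defaults used by
# the model below (patch_len=8, stride=4), a 60-step context yields
# (60 - 8) // 4 + 1 = 14 overlapping patches, each flattened to 8 * 9 = 72 values.
def _demo_patcher():
    tok = PatchTokenizer(patch_len=8, stride=4)
    x = torch.zeros(2, 60, 9)             # (batch, time, features)
    patches = tok(x)
    assert patches.shape == (2, tok.n_patches(60), 8 * 9)   # (2, 14, 72)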
class LLM4AirTrack(nn.Module):
    def __init__(
        self,
        llm_name="openai-community/gpt2",
        n_features=9,
        context_len=60,
        pred_len=30,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
        n_classes=39,
        n_heads=8,
        dropout=0.1,
        freeze_llm=True,
        prompt_text=(
            "This is an aircraft trajectory in 3D airspace near an airport. "
            "The data represents ADS-B surveillance with position, velocity, and polar components. "
            "Predict the future trajectory."
        ),
    ):
        super().__init__()
        self.pred_len = pred_len
        self.freeze_llm = freeze_llm

        # LLM backbone
        print(f"Loading LLM: {llm_name}")
        config = AutoConfig.from_pretrained(llm_name)
        self.d_llm = config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)
        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False
            self.llm.eval()

        # Word embeddings reference
        if hasattr(self.llm, 'transformer'):
            self.word_embeddings = self.llm.transformer.wte
            self.backbone = self.llm.transformer
        elif hasattr(self.llm, 'model') and hasattr(self.llm.model, 'embed_tokens'):
            self.word_embeddings = self.llm.model.embed_tokens
            self.backbone = self.llm.model
        else:
            raise ValueError(f"Unsupported LLM architecture: {llm_name}")

        # Prompt
        tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64)
        self.register_buffer("prompt_ids", tokens["input_ids"])

        # Trainable components
        self.revin = RevIN(n_features)
        self.patcher = PatchTokenizer(patch_len, patch_stride)
        self.patch_embed = nn.Sequential(
            nn.Linear(patch_len * n_features, self.d_llm),
            nn.GELU(),
            nn.LayerNorm(self.d_llm),
            nn.Dropout(dropout),
        )
        self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, n_heads, n_prototypes, dropout)

        # Trajectory prediction head
        self.traj_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_llm // 2, pred_len * 3),
        )

        # Classification head
        self.cls_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.d_llm // 4, n_classes),
        )

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total: {total:,} | Trainable: {trainable:,} ({100 * trainable / total:.2f}%)")

    def forward(self, context, target=None, label=None):
        B = context.shape[0]
        device = context.device

        # Normalize
        x = self.revin(context, mode="norm")

        # Patch + embed
        patches = self.patcher(x)
        patch_emb = self.patch_embed(patches)

        # Reprogram
        reprogrammed = self.reprogrammer(patch_emb)

        # Prompt prefix
        with torch.no_grad():
            prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        prompt_emb = prompt_emb.expand(B, -1, -1)

        # Assemble and pass through the LLM. The frozen backbone's weights never
        # update (requires_grad=False), but gradients must still flow through its
        # activations so the patch embedder and reprogrammer can train; running it
        # under no_grad and detaching the hidden state would silently leave those
        # adapters at their random init.
        input_emb = torch.cat([prompt_emb, reprogrammed], dim=1)
        out = self.backbone(inputs_embeds=input_emb)
        hidden = out.last_hidden_state
        pooled = hidden.mean(dim=1)

        # Heads
        results = {}
        loss = torch.zeros((), device=device)

        # Trajectory prediction
        pred_flat = self.traj_head(pooled)
        pred_traj = pred_flat.reshape(B, self.pred_len, 3)
        pred_traj = self.revin(pred_traj, mode="denorm")
        results["pred_trajectory"] = pred_traj
        if target is not None:
            traj_loss = F.smooth_l1_loss(pred_traj, target)
            results["traj_loss"] = traj_loss
            loss = loss + traj_loss

        # Classification
        class_logits = self.cls_head(pooled)
        results["pred_class"] = class_logits
        if label is not None:
            cls_loss = F.cross_entropy(class_logits, label)
            results["cls_loss"] = cls_loss
            loss = loss + 0.1 * cls_loss

        results["loss"] = loss
        return results
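# Illustrative forward-pass sketch (not called during training; assumes the
# GPT-2 weights can be fetched from the Hub, so the first run is slow).
def _demo_forward():
    model = LLM4AirTrack(n_classes=39)
    model.eval()
    context = torch.randn(2, 60, 9)       # (batch, context_len, n_features)
    with torch.no_grad():
        out = model(context)
    assert out["pred_trajectory"].shape == (2, 30, 3)   # pred_len future positions
    assert out["pred_class"].shape == (2, 39)           # route logits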
# ============================================================
# TRAINING
# ============================================================


def compute_metrics(pred, target):
    # Per-timestep Euclidean displacement between predicted and true positions
    disp = torch.sqrt(((pred - target) ** 2).sum(dim=-1))
    ade = disp.mean().item()          # Average Displacement Error
    fde = disp[:, -1].mean().item()   # Final Displacement Error
    rmse = torch.sqrt(((pred - target) ** 2).mean(dim=(0, 1)))
    return {
        "ADE": ade,
        "FDE": fde,
        "RMSE_x": rmse[0].item(),
        "RMSE_y": rmse[1].item(),
        "RMSE_z": rmse[2].item(),
    }


def evaluate(model, dataloader, device):
    model.eval()
    total_loss, total_correct, n = 0, 0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)
            out = model(ctx, tgt, lbl)
            total_loss += out["loss"].item() * ctx.shape[0]
            if "pred_class" in out:
                total_correct += (out["pred_class"].argmax(-1) == lbl).sum().item()
            all_preds.append(out["pred_trajectory"].cpu())
            all_targets.append(tgt.cpu())
            n += ctx.shape[0]
    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)
    metrics = compute_metrics(preds, targets)
    metrics["loss"] = total_loss / n
    metrics["accuracy"] = total_correct / n
    return metrics
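# Worked example (illustrative, not called during training): a constant 1-unit
# error in x at every horizon step gives ADE = FDE = RMSE_x = 1 and zero RMSE
# in y and z.
def _demo_metrics():
    target = torch.zeros(4, 30, 3)
    pred = target.clone()
    pred[:, :, 0] += 1.0
    m = compute_metrics(pred, target)
    assert abs(m["ADE"] - 1.0) < 1e-6
    assert abs(m["FDE"] - 1.0) < 1e-6
    assert abs(m["RMSE_x"] - 1.0) < 1e-6 and m["RMSE_y"] < 1e-6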
def main():
    import trackio

    # Config
    AIRPORT = "RKSIa"
    CONTEXT_LEN = 60
    PRED_LEN = 30
    STRIDE = 15  # default window stride; overridden per split below
    BATCH_SIZE = 128
    EPOCHS = 5
    LR = 5e-4
    LLM_NAME = "openai-community/gpt2"
    HUB_MODEL_ID = "Jdice27/LLM4AirTrack"
    OUTPUT_DIR = "/app/outputs/llm4airtrack"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    # Trackio
    trackio.init(
        project="LLM4AirTrack",
        name=f"LLM4AirTrack-{AIRPORT}-gpt2",
        config={
            "airport": AIRPORT,
            "context_len": CONTEXT_LEN,
            "pred_len": PRED_LEN,
            "batch_size": BATCH_SIZE,
            "epochs": EPOCHS,
            "lr": LR,
            "llm": LLM_NAME,
        },
    )

    # Data
    print(f"\n{'='*60}")
    print(f"Loading {AIRPORT} data...")
    download_atfm_dataset(AIRPORT)
    train_data, train_labels = load_atfm_raw(AIRPORT, "TRAIN", "/app/data/ATFMTraj")
    test_data, test_labels = load_atfm_raw(AIRPORT, "TEST", "/app/data/ATFMTraj")
    print(f"Raw: train={train_data.shape}, test={test_data.shape}")

    # Use a larger stride for training to reduce dataset size; keep test manageable
    train_ctx, train_tgt, train_lbl = create_windows(train_data, train_labels, CONTEXT_LEN, PRED_LEN, stride=30)
    test_ctx, test_tgt, test_lbl = create_windows(test_data, test_labels, CONTEXT_LEN, PRED_LEN, stride=60)
    print(f"Windows: train={train_ctx.shape}, test={test_ctx.shape}", flush=True)

    all_labels = np.concatenate([train_lbl, test_lbl])
    n_classes = int(all_labels.max()) + 1
    print(f"Classes: {n_classes} (unique in data: {len(np.unique(all_labels))})", flush=True)

    # Subsample the eval set (cap at 20k windows) for faster evaluation
    eval_size = min(len(test_ctx), 20000)
    eval_idx = np.random.RandomState(42).permutation(len(test_ctx))[:eval_size]

    train_ds = AirTrackDataset(train_ctx, train_tgt, train_lbl)
    eval_ds = AirTrackDataset(test_ctx[eval_idx], test_tgt[eval_idx], test_lbl[eval_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Train samples: {len(train_ds)}, Eval samples: {len(eval_ds)}", flush=True)

    # Model
    print(f"\n{'='*60}")
    model = LLM4AirTrack(
        llm_name=LLM_NAME,
        n_features=9,
        context_len=CONTEXT_LEN,
        pred_len=PRED_LEN,
        n_classes=n_classes,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
    ).to(device)

    # Optimizer
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=LR, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader), T_mult=2, eta_min=LR * 0.01)

    # Training
    print(f"\n{'='*60}")
    print(f"Training {EPOCHS} epochs, {len(train_loader)} steps/epoch")
    print(f"{'='*60}\n")

    best_ade = float("inf")
    best_epoch = -1
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        model.train()
        model.backbone.eval()  # keep the frozen LLM in eval mode (disables its dropout)
        epoch_loss, epoch_traj, epoch_cls, n_batches = 0, 0, 0, 0
        t0 = time.time()

        for batch_idx, batch in enumerate(train_loader):
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)

            out = model(ctx, tgt, lbl)
            loss = out["loss"]

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            epoch_traj += out.get("traj_loss", torch.tensor(0)).item()
            epoch_cls += out.get("cls_loss", torch.tensor(0)).item()
            n_batches += 1

            trackio.log({
                "train/loss": loss.item(),
                "train/traj_loss": out.get("traj_loss", torch.tensor(0)).item(),
                "train/cls_loss": out.get("cls_loss", torch.tensor(0)).item(),
                "train/lr": optimizer.param_groups[0]["lr"],
            })

            if (batch_idx + 1) % 25 == 0:
                print(f"  [{epoch+1}/{EPOCHS}] step {batch_idx+1}/{len(train_loader)} | "
                      f"loss={epoch_loss/n_batches:.6f} traj={epoch_traj/n_batches:.6f} "
                      f"cls={epoch_cls/n_batches:.6f} lr={optimizer.param_groups[0]['lr']:.2e}",
                      flush=True)

        dt = time.time() - t0
        avg_loss = epoch_loss / n_batches

        # Evaluate
        metrics = evaluate(model, test_loader, device)
        print(f"\nEpoch {epoch+1}/{EPOCHS} ({dt:.0f}s) | "
              f"Train loss: {avg_loss:.6f} | "
              f"Eval ADE: {metrics['ADE']:.6f} FDE: {metrics['FDE']:.6f} | "
              f"Acc: {metrics['accuracy']:.4f}")
        trackio.log({
            "eval/loss": metrics["loss"],
            "eval/ADE": metrics["ADE"],
            "eval/FDE": metrics["FDE"],
            "eval/accuracy": metrics["accuracy"],
            "eval/RMSE_x": metrics["RMSE_x"],
            "eval/RMSE_y": metrics["RMSE_y"],
            "eval/RMSE_z": metrics["RMSE_z"],
            "epoch": epoch + 1,
        })

        # Save best
        if metrics["ADE"] < best_ade:
            best_ade = metrics["ADE"]
            best_epoch = epoch + 1
            save_dir = os.path.join(OUTPUT_DIR, "best_model")
            os.makedirs(save_dir, exist_ok=True)
            # Save adapter weights only (skip the frozen LLM)
            adapter_state = {
                k: v for k, v in model.state_dict().items()
                if not any(k.startswith(p) for p in ["llm.", "word_embeddings.", "backbone."])
            }
            torch.save(adapter_state, os.path.join(save_dir, "adapter_weights.pt"))
            config = {
                "llm_name": LLM_NAME,
                "n_features": 9,
                "context_len": CONTEXT_LEN,
                "pred_len": PRED_LEN,
                "patch_len": 8,
                "patch_stride": 4,
                "n_prototypes": 256,
                "n_classes": n_classes,
                "n_heads": 8,
                "dropout": 0.1,
                "best_ade": best_ade,
                "best_fde": metrics["FDE"],
                "best_epoch": best_epoch,
                "best_accuracy": metrics["accuracy"],
                "airport": AIRPORT,
                "metrics": metrics,
            }
            with open(os.path.join(save_dir, "config.json"), "w") as f:
                json.dump(config, f, indent=2)
            print(f"  ★ New best! ADE: {best_ade:.6f} (epoch {best_epoch})")
        print()
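    # Note on checkpoint size (informational): only the adapter weights saved
    # above ship to the Hub (~3.1M params, roughly 12 MB in fp32); the frozen
    # GPT-2 backbone (~124M params, ~500 MB) is re-instantiated from the Hub
    # whenever the model is loaded.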
    # Push to Hub
    print(f"\n{'='*60}")
    print(f"Training complete! Best ADE: {best_ade:.6f} (epoch {best_epoch})")
    print(f"Pushing to Hub: {HUB_MODEL_ID}")
    api = HfApi()
    try:
        api.create_repo(HUB_MODEL_ID, exist_ok=True)
    except Exception as e:
        print(f"Repo: {e}")

    save_dir = os.path.join(OUTPUT_DIR, "best_model")
    api.upload_folder(
        folder_path=save_dir,
        repo_id=HUB_MODEL_ID,
        commit_message=f"Best model: ADE={best_ade:.6f}, epoch {best_epoch}",
    )

    # Upload source code
    api.upload_file(
        path_or_fileobj=__file__,
        path_in_repo="train_full.py",
        repo_id=HUB_MODEL_ID,
    )

    # Model card
    model_card = f"""---
license: apache-2.0
tags:
- trajectory-prediction
- aviation
- adsb
- time-series
- llm-reprogramming
- gpt2
datasets:
- petchthwr/ATFMTraj
pipeline_tag: time-series-forecasting
---

# LLM4AirTrack: LLM-Driven Aircraft Trajectory Prediction

Adapts the [LLM4STP](https://github.com/Joker-hang/LLM4STP) framework from maritime AIS to aviation ADS-B.
Uses a **frozen GPT-2 backbone** with lightweight trainable adapters (~2.4% of params).

## Architecture

```
ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder
  → Cross-Attention Reprogrammer (learned text prototypes)
  → Prompt-as-Prefix → Frozen GPT-2 Backbone
  → Trajectory Head (future xyz) + Classification Head (STAR/runway)
```

### Key Components

1. **9-dim Kinematic Features**: Position (x, y, z ENU) + Direction (ux, uy, uz) + Polar (r, sinθ, cosθ)
2. **Patch Tokenization**: Overlapping temporal patches (len=8, stride=4)
3. **Cross-Attention Reprogramming**: 256 learned text prototypes, 8-head attention
4. **Frozen GPT-2**: 124M params frozen, only ~3.1M trainable
5. **Dual Heads**: Trajectory prediction (Smooth L1) + Route classification (CE)

## Training

- **Dataset**: [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - {AIRPORT}
- **Source**: OpenSky ADS-B, Incheon International Airport arrivals (2018-2023)
- **Context**: {CONTEXT_LEN} timesteps (1s intervals)
- **Prediction**: {PRED_LEN} timesteps ahead
- **Optimizer**: AdamW, lr={LR}, cosine annealing
- **Epochs**: {EPOCHS}

## Results

| Metric | Value |
|--------|-------|
| ADE (normalized) | {best_ade:.6f} |
| Best Epoch | {best_epoch} |
| Route Classification Acc | {metrics['accuracy']:.4f} |

## Usage

```python
import torch, json
from train_full import LLM4AirTrack

# Load
with open("config.json") as f:
    cfg = json.load(f)
model = LLM4AirTrack(
    llm_name=cfg["llm_name"],
    context_len=cfg["context_len"],
    pred_len=cfg["pred_len"],
    n_classes=cfg["n_classes"],
)
state = torch.load("adapter_weights.pt", map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()

# Predict (input: 60 timesteps of 9-dim kinematic features)
context = torch.randn(1, 60, 9)
out = model(context)
future_xyz = out["pred_trajectory"]         # (1, 30, 3)
route_class = out["pred_class"].argmax(-1)  # (1,)
```

## Downstream Tasks

- **Track Activity Classification**: Route/procedure identification from trajectory embeddings
- **Anomaly Detection**: Flag deviations from predicted trajectory
- **Conflict Detection**: Multi-aircraft trajectory forecasting
- **ETA Prediction**: Time-to-threshold from trajectory state

## References

- [LLM4STP](https://github.com/Joker-hang/LLM4STP) - Original maritime framework
- [Time-LLM](https://arxiv.org/abs/2310.01728) - Foundational reprogramming approach
- [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - Aviation trajectory dataset
- [ATSCC](https://arxiv.org/abs/2407.20028) - Self-supervised trajectory representation
- [LLM4Delay](https://arxiv.org/abs/2510.23636) - Cross-modality LLM adaptation for aviation
"""
    api.upload_file(
        path_or_fileobj=model_card.encode(),
        path_in_repo="README.md",
        repo_id=HUB_MODEL_ID,
    )
    print(f"✓ Pushed to: https://huggingface.co/{HUB_MODEL_ID}")

    # Close the tracking run
    trackio.finish()


if __name__ == "__main__":
    main()
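
# Illustrative smoke test (a sketch, not part of the training run): checks that
# gradients flow through the frozen GPT-2 backbone into the trainable adapters
# while the backbone itself stays frozen. Run manually, e.g.:
#   python -c "from train_full import _smoke_test; _smoke_test()"
def _smoke_test():
    model = LLM4AirTrack(n_classes=5)
    ctx = torch.randn(2, 60, 9)
    tgt = torch.randn(2, 30, 3)
    lbl = torch.randint(0, 5, (2,))
    out = model(ctx, tgt, lbl)
    out["loss"].backward()
    assert model.reprogrammer.prototypes.grad is not None   # adapters receive gradients
    assert all(p.grad is None for p in model.llm.parameters())  # backbone stays frozen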