# LLM4AirTrack / train_full.py
# Uploaded by Jdice27 via huggingface_hub (commit 434f8c5, verified)
"""
Full-scale training script for LLM4AirTrack.
Trains on RKSIa (Incheon arrivals) - full dataset.
"""
import os
import sys
import time
import json
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from huggingface_hub import hf_hub_download, HfApi
import pandas as pd
from scipy.ndimage import uniform_filter1d
# ============================================================
# DATA MODULE
# ============================================================
def download_atfm_dataset(airport="RKSIa", cache_dir="/app/data/ATFMTraj"):
    """Ensure the six ATFMTraj TSV splits for *airport* exist locally.

    Downloads any missing {TRAIN, TEST} x {X, Y, Z} file from the
    petchthwr/ATFMTraj dataset repo into ``cache_dir`` and returns the
    per-airport directory path.
    """
    airport_dir = os.path.join(cache_dir, airport)
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(airport_dir, exist_ok=True)
    splits = [(m, v) for m in ("TRAIN", "TEST") for v in ("X", "Y", "Z")]
    for mode, var in splits:
        fname = f"{airport}_{mode}_{var}.tsv"
        if os.path.exists(os.path.join(airport_dir, fname)):
            continue  # already cached locally
        print(f"Downloading {airport}/{fname}...")
        hf_hub_download(
            repo_id="petchthwr/ATFMTraj",
            filename=f"{airport}/{fname}",
            repo_type="dataset",
            local_dir=cache_dir,
        )
    return airport_dir
def load_atfm_raw(airport, mode, cache_dir):
    """Read the X/Y/Z TSV files for one split and return (data, labels).

    Returns:
        data: float array of shape (n_traj, seq_len, 3), NaN-padded.
        labels: int array of shape (n_traj,), taken from column 0 of
            the first (X) file.
    """
    airport_dir = os.path.join(cache_dir, airport)
    per_axis = []
    labels = None
    for axis in ('X', 'Y', 'Z'):
        path = os.path.join(airport_dir, f"{airport}_{mode}_{axis}.tsv")
        frame = pd.read_csv(path, sep='\t', header=None, na_values='NaN')
        values = frame.values
        if labels is None:
            labels = values[:, 0]
        per_axis.append(values[:, 1:])
    return np.stack(per_axis, axis=-1), labels.astype(int)
def compute_kinematic_features(trajectory, dt=1.0):
    """Expand a raw (T, 3) xyz trajectory into 9 kinematic channels.

    Channels: x, y, z, unit velocity direction (ux, uy, uz), horizontal
    range r, and sin/cos of the polar angle theta = atan2(y, x).
    """
    x = trajectory[:, 0]
    y = trajectory[:, 1]
    z = trajectory[:, 2]
    # Finite-difference velocities; np.gradient handles the endpoints.
    vx = np.gradient(x) / dt
    vy = np.gradient(y) / dt
    vz = np.gradient(z) / dt
    speed = np.sqrt(vx**2 + vy**2 + vz**2) + 1e-8  # epsilon avoids /0
    # Horizontal polar coordinates relative to the origin.
    r = np.sqrt(x**2 + y**2) + 1e-8
    theta = np.arctan2(y, x)
    channels = [x, y, z, vx / speed, vy / speed, vz / speed,
                r, np.sin(theta), np.cos(theta)]
    return np.stack(channels, axis=-1)
def create_windows(data, labels, context_len=60, pred_len=30, stride=15):
    """Slice each trajectory into (context, target) training windows.

    NaN-padded timesteps (detected on the x channel) are removed first;
    trajectories with fewer than ``context_len + pred_len`` valid steps
    are skipped. Contexts are expanded to 9 kinematic features, targets
    remain raw xyz.

    Returns:
        (contexts, targets, sample_labels) as float32/float32/int64
        arrays of shapes (N, context_len, 9), (N, pred_len, 3), (N,).
    """
    total_len = context_len + pred_len
    contexts, targets, sample_labels = [], [], []
    for traj, lbl in zip(data, labels):
        keep = ~np.isnan(traj[:, 0])
        n_valid = np.sum(keep)
        if n_valid < total_len:
            continue  # not enough clean timesteps for one window
        clean = traj[keep]
        start = 0
        while start + total_len <= n_valid:
            window = clean[start:start + total_len]
            contexts.append(compute_kinematic_features(window[:context_len]))
            targets.append(window[context_len:])
            sample_labels.append(lbl)
            start += stride
    return (
        np.array(contexts, dtype=np.float32),
        np.array(targets, dtype=np.float32),
        np.array(sample_labels, dtype=np.int64),
    )
class AirTrackDataset(Dataset):
    """Thin tensor wrapper exposing (context, target, label) samples."""

    def __init__(self, contexts, targets, labels):
        # Convert the numpy window arrays to tensors once, up front.
        self.contexts = torch.from_numpy(contexts)
        self.targets = torch.from_numpy(targets)
        self.labels = torch.from_numpy(labels)

    def __len__(self):
        return self.contexts.shape[0]

    def __getitem__(self, idx):
        return {
            "context": self.contexts[idx],
            "target": self.targets[idx],
            "label": self.labels[idx],
        }
# ============================================================
# MODEL MODULE
# ============================================================
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
class RevIN(nn.Module):
    """Reversible instance normalization with a learned affine transform.

    ``norm`` standardizes each sequence with its own per-feature
    mean/std (cached on the module); ``denorm`` maps 3-channel xyz
    predictions back to the input scale using the stats of the first
    three features.
    """

    def __init__(self, n_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine_weight = nn.Parameter(torch.ones(n_features))
        self.affine_bias = nn.Parameter(torch.zeros(n_features))

    def forward(self, x, mode="norm"):
        if mode == "norm":
            # Per-sequence statistics, detached so they act as constants.
            mean = x.mean(dim=1, keepdim=True).detach()
            std = (x.std(dim=1, keepdim=True) + self.eps).detach()
            self._mean, self._std = mean, std
            x = (x - mean) / std
            x = x * self.affine_weight + self.affine_bias
        elif mode == "denorm":
            # Invert the affine map, then restore the xyz scale/offset
            # using the stats cached by the preceding "norm" call.
            x = (x - self.affine_bias[:3]) / (self.affine_weight[:3] + self.eps)
            x = x * self._std[:, :, :3] + self._mean[:, :, :3]
        return x
class PatchTokenizer(nn.Module):
    """Split a (B, T, F) sequence into overlapping, flattened patches."""

    def __init__(self, patch_len=8, stride=4):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        batch, _, n_feat = x.shape
        # unfold -> (B, n_patches, F, patch_len); reorder so each patch
        # flattens time-major: [t0 features, t1 features, ...].
        windows = x.unfold(1, self.patch_len, self.stride)
        windows = windows.permute(0, 1, 3, 2).contiguous()
        return windows.reshape(batch, windows.shape[1], self.patch_len * n_feat)

    def n_patches(self, seq_len):
        return (seq_len - self.patch_len) // self.stride + 1
class CrossAttentionReprogrammer(nn.Module):
    """Align patch embeddings with a bank of learned prototype vectors.

    Patch embeddings act as queries over trainable prototypes via
    multi-head cross-attention; the attended result is added back
    residually and layer-normalized.
    """

    def __init__(self, d_model, n_heads=8, n_prototypes=256, dropout=0.1):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(n_prototypes, d_model) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patch_embeds):
        batch = patch_embeds.shape[0]
        # Broadcast the shared prototype bank across the batch.
        bank = self.prototypes.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.cross_attn(query=patch_embeds, key=bank, value=bank)
        residual = patch_embeds + self.dropout(attended)
        return self.layer_norm(residual)
class LLM4AirTrack(nn.Module):
    """LLM-reprogramming model for aircraft trajectory prediction.

    Kinematic features are instance-normalized (RevIN), patched,
    linearly embedded, and reprogrammed against learned prototypes,
    then fed -- behind a natural-language prompt prefix -- through a
    frozen causal-LM backbone. The mean-pooled hidden state drives two
    heads: future-xyz regression and route classification. Only the
    adapters and heads are trainable when ``freeze_llm`` is True.
    """

    def __init__(
        self,
        llm_name="openai-community/gpt2",
        n_features=9,
        context_len=60,
        pred_len=30,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
        n_classes=39,
        n_heads=8,
        dropout=0.1,
        freeze_llm=True,
        prompt_text=(
            "This is an aircraft trajectory in 3D airspace near an airport. "
            "The data represents ADS-B surveillance with position, velocity, and polar components. "
            "Predict the future trajectory."
        ),
    ):
        """Build the frozen backbone, prompt buffer, adapters, and heads.

        Raises:
            ValueError: if the loaded LLM uses a layout other than the
                GPT-2 (``.transformer``) or LLaMA (``.model``) style.
        """
        super().__init__()
        self.pred_len = pred_len
        self.freeze_llm = freeze_llm
        # LLM backbone (loaded with its LM head; forward only uses .backbone).
        print(f"Loading LLM: {llm_name}")
        config = AutoConfig.from_pretrained(llm_name)
        self.d_llm = config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token  # e.g. GPT-2 has no pad token
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)
        if freeze_llm:
            # Freeze the weights; gradients still flow THROUGH the
            # backbone activations in forward (see note there).
            for p in self.llm.parameters():
                p.requires_grad = False
            self.llm.eval()
        # Locate the word-embedding table and the transformer stack for
        # the two supported layouts (GPT-2 style vs LLaMA style).
        if hasattr(self.llm, 'transformer'):
            self.word_embeddings = self.llm.transformer.wte
            self.backbone = self.llm.transformer
        elif hasattr(self.llm, 'model') and hasattr(self.llm.model, 'embed_tokens'):
            self.word_embeddings = self.llm.model.embed_tokens
            self.backbone = self.llm.model
        else:
            raise ValueError(f"Unsupported LLM architecture: {llm_name}")
        # Prompt-as-prefix token ids; a buffer so .to(device) moves them.
        tokens = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=64)
        self.register_buffer("prompt_ids", tokens["input_ids"])
        # Trainable adapter components.
        self.revin = RevIN(n_features)
        self.patcher = PatchTokenizer(patch_len, patch_stride)
        self.patch_embed = nn.Sequential(
            nn.Linear(patch_len * n_features, self.d_llm),
            nn.GELU(),
            nn.LayerNorm(self.d_llm),
            nn.Dropout(dropout),
        )
        self.reprogrammer = CrossAttentionReprogrammer(self.d_llm, n_heads, n_prototypes, dropout)
        # Trajectory regression head: pooled hidden -> pred_len * 3 coords.
        self.traj_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_llm // 2, pred_len * 3),
        )
        # Route/procedure classification head (auxiliary task).
        self.cls_head = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm // 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.d_llm // 4, n_classes),
        )
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")

    def forward(self, context, target=None, label=None):
        """Run the full pipeline; compute losses when supervision is given.

        Args:
            context: (B, context_len, n_features) kinematic features.
            target: optional (B, pred_len, 3) future xyz positions.
            label: optional (B,) route class ids.

        Returns:
            dict with "pred_trajectory" (B, pred_len, 3), "pred_class"
            (B, n_classes), "loss", plus "traj_loss"/"cls_loss" when
            the corresponding supervision was provided.
        """
        B = context.shape[0]
        device = context.device
        # Normalize (stats are cached inside RevIN for the later denorm).
        x = self.revin(context, mode="norm")
        # Patch + embed into the LLM hidden size.
        patches = self.patcher(x)
        patch_emb = self.patch_embed(patches)
        # Cross-attend against the learned prototype bank.
        reprogrammed = self.reprogrammer(patch_emb)
        # Prompt prefix (constant lookup; no gradient needed here).
        with torch.no_grad():
            prompt_emb = self.word_embeddings(self.prompt_ids.to(device))
        prompt_emb = prompt_emb.expand(B, -1, -1)
        input_emb = torch.cat([prompt_emb, reprogrammed], dim=1)
        # BUGFIX: run the backbone WITH autograd even when frozen. The
        # previous no_grad + detach here cut the graph, so the RevIN
        # affine, patch embedder, and reprogrammer never received
        # gradients and only the two heads actually trained. Frozen
        # parameters (requires_grad=False) are still never updated.
        out = self.backbone(inputs_embeds=input_emb)
        hidden = out.last_hidden_state
        pooled = hidden.mean(dim=1)
        results = {}
        loss = torch.tensor(0.0, device=device, requires_grad=True)
        # Trajectory prediction, denormalized back to the input xyz scale.
        pred_flat = self.traj_head(pooled)
        pred_traj = pred_flat.reshape(B, self.pred_len, 3)
        pred_traj = self.revin(pred_traj, mode="denorm")
        results["pred_trajectory"] = pred_traj
        if target is not None:
            traj_loss = F.smooth_l1_loss(pred_traj, target)
            results["traj_loss"] = traj_loss
            loss = loss + traj_loss
        # Classification (auxiliary task, weighted 0.1 in the total loss).
        class_logits = self.cls_head(pooled)
        results["pred_class"] = class_logits
        if label is not None:
            cls_loss = F.cross_entropy(class_logits, label)
            results["cls_loss"] = cls_loss
            loss = loss + 0.1 * cls_loss
        results["loss"] = loss
        return results
# ============================================================
# TRAINING
# ============================================================
def compute_metrics(pred, target):
    """Displacement metrics between predicted and true trajectories.

    Args:
        pred, target: (B, T, 3) tensors.

    Returns:
        dict with ADE (mean per-step displacement), FDE (final-step
        displacement), and per-axis RMSE.
    """
    err = pred - target
    disp = torch.sqrt((err ** 2).sum(dim=-1))
    rmse = torch.sqrt((err ** 2).mean(dim=(0, 1)))
    return {
        "ADE": disp.mean().item(),
        "FDE": disp[:, -1].mean().item(),
        "RMSE_x": rmse[0].item(),
        "RMSE_y": rmse[1].item(),
        "RMSE_z": rmse[2].item(),
    }
def evaluate(model, dataloader, device):
    """Run one full pass over ``dataloader`` and aggregate eval metrics.

    Returns displacement metrics from ``compute_metrics`` plus the
    sample-weighted mean loss and route-classification accuracy.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    n_samples = 0
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for batch in dataloader:
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)
            out = model(ctx, tgt, lbl)
            batch_size = ctx.shape[0]
            # Loss is a batch mean, so weight it by batch size.
            total_loss += out["loss"].item() * batch_size
            if "pred_class" in out:
                total_correct += (out["pred_class"].argmax(-1) == lbl).sum().item()
            pred_chunks.append(out["pred_trajectory"].cpu())
            target_chunks.append(tgt.cpu())
            n_samples += batch_size
    metrics = compute_metrics(torch.cat(pred_chunks), torch.cat(target_chunks))
    metrics["loss"] = total_loss / n_samples
    metrics["accuracy"] = total_correct / n_samples
    return metrics
def main():
    """Full training entry point.

    Downloads the ATFMTraj RKSIa split, builds sliding-window datasets,
    trains LLM4AirTrack with per-step trackio logging, keeps the best
    checkpoint by eval ADE, and pushes adapter weights, config, source,
    and a model card to the Hub.
    """
    import trackio  # experiment tracker; imported lazily inside main
    # --- Config (hard-coded for the full RKSIa run) ---
    AIRPORT = "RKSIa"
    CONTEXT_LEN = 60
    PRED_LEN = 30
    STRIDE = 15  # NOTE(review): unused below -- windowing uses literal strides 30/60
    BATCH_SIZE = 128
    EPOCHS = 5
    LR = 5e-4
    LLM_NAME = "openai-community/gpt2"
    HUB_MODEL_ID = "Jdice27/LLM4AirTrack"
    OUTPUT_DIR = "/app/outputs/llm4airtrack"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
    # --- Trackio run setup ---
    # NOTE(review): the 'tracker' handle is never used; logging goes
    # through module-level trackio.log calls below.
    tracker = trackio.init(project="LLM4AirTrack", name=f"LLM4AirTrack-{AIRPORT}-gpt2", config={
        "airport": AIRPORT, "context_len": CONTEXT_LEN, "pred_len": PRED_LEN,
        "batch_size": BATCH_SIZE, "epochs": EPOCHS, "lr": LR, "llm": LLM_NAME,
    })
    # --- Data: download, load raw splits, build windows ---
    print(f"\n{'='*60}")
    print(f"Loading {AIRPORT} data...")
    download_atfm_dataset(AIRPORT)
    train_data, train_labels = load_atfm_raw(AIRPORT, "TRAIN", "/app/data/ATFMTraj")
    test_data, test_labels = load_atfm_raw(AIRPORT, "TEST", "/app/data/ATFMTraj")
    print(f"Raw: train={train_data.shape}, test={test_data.shape}")
    # Larger strides than the function default keep the window counts manageable.
    train_ctx, train_tgt, train_lbl = create_windows(train_data, train_labels, CONTEXT_LEN, PRED_LEN, stride=30)
    test_ctx, test_tgt, test_lbl = create_windows(test_data, test_labels, CONTEXT_LEN, PRED_LEN, stride=60)
    print(f"Windows: train={train_ctx.shape}, test={test_ctx.shape}", flush=True)
    all_labels = np.concatenate([train_lbl, test_lbl])
    # Head size must cover the max label id (labels may be sparse).
    n_classes = int(all_labels.max()) + 1
    print(f"Classes: {n_classes} (unique in data: {len(np.unique(all_labels))})", flush=True)
    # Subsample the eval set (fixed seed, cap 20k) for faster per-epoch evaluation.
    eval_size = min(len(test_ctx), 20000)
    eval_idx = np.random.RandomState(42).permutation(len(test_ctx))[:eval_size]
    train_ds = AirTrackDataset(train_ctx, train_tgt, train_lbl)
    eval_ds = AirTrackDataset(test_ctx[eval_idx], test_tgt[eval_idx], test_lbl[eval_idx])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Train samples: {len(train_ds)}, Eval samples: {len(eval_ds)}", flush=True)
    # --- Model ---
    print(f"\n{'='*60}")
    model = LLM4AirTrack(
        llm_name=LLM_NAME,
        n_features=9,
        context_len=CONTEXT_LEN,
        pred_len=PRED_LEN,
        n_classes=n_classes,
        patch_len=8,
        patch_stride=4,
        n_prototypes=256,
    ).to(device)
    # --- Optimizer: only the adapter/head params require grad ---
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=LR, weight_decay=1e-5)
    # Warm restart each epoch (T_0 = steps per epoch), doubling cycle length.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader), T_mult=2, eta_min=LR * 0.01)
    # --- Training loop ---
    print(f"\n{'='*60}")
    print(f"Training {EPOCHS} epochs, {len(train_loader)} steps/epoch")
    print(f"{'='*60}\n")
    best_ade = float("inf")
    best_epoch = -1
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for epoch in range(EPOCHS):
        model.train()
        model.backbone.eval()  # keep the frozen LLM in eval mode (no dropout)
        epoch_loss, epoch_traj, epoch_cls, n_batches = 0, 0, 0, 0
        t0 = time.time()
        for batch_idx, batch in enumerate(train_loader):
            ctx = batch["context"].to(device)
            tgt = batch["target"].to(device)
            lbl = batch["label"].to(device)
            out = model(ctx, tgt, lbl)
            loss = out["loss"]
            optimizer.zero_grad()
            loss.backward()
            # Clip only the trainable (adapter/head) parameters.
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()
            scheduler.step()  # per-step schedule, not per-epoch
            epoch_loss += loss.item()
            epoch_traj += out.get("traj_loss", torch.tensor(0)).item()
            epoch_cls += out.get("cls_loss", torch.tensor(0)).item()
            n_batches += 1
            trackio.log({
                "train/loss": loss.item(),
                "train/traj_loss": out.get("traj_loss", torch.tensor(0)).item(),
                "train/cls_loss": out.get("cls_loss", torch.tensor(0)).item(),
                "train/lr": optimizer.param_groups[0]["lr"],
            })
            if (batch_idx + 1) % 25 == 0:
                print(f" [{epoch+1}/{EPOCHS}] step {batch_idx+1}/{len(train_loader)} | "
                      f"loss={epoch_loss/n_batches:.6f} traj={epoch_traj/n_batches:.6f} "
                      f"cls={epoch_cls/n_batches:.6f} lr={optimizer.param_groups[0]['lr']:.2e}",
                      flush=True)
        dt = time.time() - t0
        avg_loss = epoch_loss / n_batches
        # --- Per-epoch evaluation on the subsampled test windows ---
        metrics = evaluate(model, test_loader, device)
        print(f"\nEpoch {epoch+1}/{EPOCHS} ({dt:.0f}s) | "
              f"Train loss: {avg_loss:.6f} | "
              f"Eval ADE: {metrics['ADE']:.6f} FDE: {metrics['FDE']:.6f} | "
              f"Acc: {metrics['accuracy']:.4f}")
        trackio.log({
            "eval/loss": metrics["loss"],
            "eval/ADE": metrics["ADE"],
            "eval/FDE": metrics["FDE"],
            "eval/accuracy": metrics["accuracy"],
            "eval/RMSE_x": metrics["RMSE_x"],
            "eval/RMSE_y": metrics["RMSE_y"],
            "eval/RMSE_z": metrics["RMSE_z"],
            "epoch": epoch + 1,
        })
        # --- Save best checkpoint (selection metric: eval ADE) ---
        if metrics["ADE"] < best_ade:
            best_ade = metrics["ADE"]
            best_epoch = epoch + 1
            save_dir = os.path.join(OUTPUT_DIR, "best_model")
            os.makedirs(save_dir, exist_ok=True)
            # Save only the adapter/head weights -- the frozen LLM is
            # reloadable from the Hub by name.
            adapter_state = {
                k: v for k, v in model.state_dict().items()
                if not any(k.startswith(p) for p in ["llm.", "word_embeddings.", "backbone."])
            }
            torch.save(adapter_state, os.path.join(save_dir, "adapter_weights.pt"))
            # Config needed to rebuild the model and reproduce the run.
            config = {
                "llm_name": LLM_NAME,
                "n_features": 9,
                "context_len": CONTEXT_LEN,
                "pred_len": PRED_LEN,
                "patch_len": 8,
                "patch_stride": 4,
                "n_prototypes": 256,
                "n_classes": n_classes,
                "n_heads": 8,
                "dropout": 0.1,
                "best_ade": best_ade,
                "best_fde": metrics["FDE"],
                "best_epoch": best_epoch,
                "best_accuracy": metrics["accuracy"],
                "airport": AIRPORT,
                "metrics": metrics,
            }
            with open(os.path.join(save_dir, "config.json"), "w") as f:
                json.dump(config, f, indent=2)
            print(f" ★ New best! ADE: {best_ade:.6f} (epoch {best_epoch})")
        print()
    # --- Push best checkpoint + source + model card to the Hub ---
    print(f"\n{'='*60}")
    print(f"Training complete! Best ADE: {best_ade:.6f} (epoch {best_epoch})")
    print(f"Pushing to Hub: {HUB_MODEL_ID}")
    api = HfApi()
    try:
        api.create_repo(HUB_MODEL_ID, exist_ok=True)
    except Exception as e:
        # Best-effort: repo may already exist or creation may be forbidden.
        print(f"Repo: {e}")
    save_dir = os.path.join(OUTPUT_DIR, "best_model")
    api.upload_folder(folder_path=save_dir, repo_id=HUB_MODEL_ID,
                      commit_message=f"Best model: ADE={best_ade:.6f}, epoch {best_epoch}")
    # Upload this script alongside the weights for reproducibility.
    api.upload_file(
        path_or_fileobj=__file__,
        path_in_repo="train_full.py",
        repo_id=HUB_MODEL_ID,
    )
    # Model card (f-string: metric placeholders filled from this run).
    model_card = f"""---
license: apache-2.0
tags:
- trajectory-prediction
- aviation
- adsb
- time-series
- llm-reprogramming
- gpt2
datasets:
- petchthwr/ATFMTraj
pipeline_tag: time-series-forecasting
---
# LLM4AirTrack: LLM-Driven Aircraft Trajectory Prediction
Adapts the [LLM4STP](https://github.com/Joker-hang/LLM4STP) framework from maritime AIS to aviation ADS-B.
Uses a **frozen GPT-2 backbone** with lightweight trainable adapters (~2.4% of params).
## Architecture
```
ADS-B Features (9-dim) → RevIN → Patch Tokenizer → Patch Embedder
→ Cross-Attention Reprogrammer (learned text prototypes)
→ Prompt-as-Prefix → Frozen GPT-2 Backbone
→ Trajectory Head (future xyz) + Classification Head (STAR/runway)
```
### Key Components
1. **9-dim Kinematic Features**: Position (x,y,z ENU) + Direction (ux,uy,uz) + Polar (r, sinθ, cosθ)
2. **Patch Tokenization**: Overlapping temporal patches (len=8, stride=4)
3. **Cross-Attention Reprogramming**: 256 learned text prototypes, 8-head attention
4. **Frozen GPT-2**: 124M params frozen, only ~3.1M trainable
5. **Dual Heads**: Trajectory prediction (Smooth L1) + Route classification (CE)
## Training
- **Dataset**: [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - {AIRPORT}
- **Source**: OpenSky ADS-B, Incheon International Airport arrivals (2018-2023)
- **Context**: {CONTEXT_LEN} timesteps (1s intervals)
- **Prediction**: {PRED_LEN} timesteps ahead
- **Optimizer**: AdamW, lr={LR}, cosine annealing
- **Epochs**: {EPOCHS}
## Results
| Metric | Value |
|--------|-------|
| ADE (normalized) | {best_ade:.6f} |
| Best Epoch | {best_epoch} |
| Route Classification Acc | {metrics['accuracy']:.4f} |
## Usage
```python
import torch, json
from train_full import LLM4AirTrack
# Load
with open("config.json") as f:
    cfg = json.load(f)
model = LLM4AirTrack(
    llm_name=cfg["llm_name"],
    context_len=cfg["context_len"],
    pred_len=cfg["pred_len"],
    n_classes=cfg["n_classes"],
)
state = torch.load("adapter_weights.pt", map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()
# Predict (input: 60 timesteps of 9-dim kinematic features)
context = torch.randn(1, 60, 9)
out = model(context)
future_xyz = out["pred_trajectory"]  # (1, 30, 3)
route_class = out["pred_class"].argmax(-1)  # (1,)
```
## Downstream Tasks
- **Track Activity Classification**: Route/procedure identification from trajectory embeddings
- **Anomaly Detection**: Flag deviations from predicted trajectory
- **Conflict Detection**: Multi-aircraft trajectory forecasting
- **ETA Prediction**: Time-to-threshold from trajectory state
## References
- [LLM4STP](https://github.com/Joker-hang/LLM4STP) - Original maritime framework
- [Time-LLM](https://arxiv.org/abs/2310.01728) - Foundational reprogramming approach
- [ATFMTraj](https://huggingface.co/datasets/petchthwr/ATFMTraj) - Aviation trajectory dataset
- [ATSCC](https://arxiv.org/abs/2407.20028) - Self-supervised trajectory representation
- [LLM4Delay](https://arxiv.org/abs/2510.23636) - Cross-modality LLM adaptation for aviation
"""
    api.upload_file(
        path_or_fileobj=model_card.encode(),
        path_in_repo="README.md",
        repo_id=HUB_MODEL_ID,
    )
    print(f"✓ Pushed to: https://huggingface.co/{HUB_MODEL_ID}")
# Script entry point: run the full training pipeline.
if __name__ == "__main__":
    main()