| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Flight-JEPA v2 — bundled training script for HF Jobs. |
| |
| Self-contained: downloads the dataset from HF, trains either the supervised |
| baseline (`--lambda-jepa 0`) or the JEPA-augmented model, runs blindspot |
| scoring + extrapolation eval, and pushes the result to a hub repo. |
| |
| Usage (HF Jobs): |
| python train_v2_prod.py --tag baseline --lambda-jepa 0.0 \ |
| --hub-model-id guychuk/flight-jepa-v2 --push-to-hub |
| |
| python train_v2_prod.py --tag jepa --lambda-jepa 0.5 \ |
| --hub-model-id guychuk/flight-jepa-v2 --push-to-hub |
| """ |
| from __future__ import annotations |
| import argparse |
| import copy |
| import json |
| import math |
| import os |
| import shutil |
| import sys |
| import time |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
|
|
| try: |
| import trackio |
| HAS_TRACKIO = True |
| except ImportError: |
| HAS_TRACKIO = False |
|
|
|
|
| |
| |
| |
|
|
def load_atfm(dset_name, mode, path):
    """Load the X/Y/Z tables of one dataset split and stack them.

    Reads ``{dset_name}_{mode}_X.tsv``, ``..._Y.tsv`` and ``..._Z.tsv`` from
    ``path`` (tab-separated, no header, ``NaN`` parsed as missing). In every
    table column 0 is treated as the label and the remaining columns as the
    per-timestep values; labels are taken from the X table.

    Returns:
        ``(data, labels)`` where ``data`` has shape
        ``(n_rows, n_value_columns, 3)`` with X/Y/Z on the last axis and
        ``labels`` is an integer array of length ``n_rows``.
    """
    tables = [
        pd.read_csv(
            os.path.join(path, f"{dset_name}_{mode}_{axis}.tsv"),
            sep="\t", header=None, na_values="NaN",
        ).values
        for axis in ("X", "Y", "Z")
    ]
    class_ids = tables[0][:, 0]
    series = np.stack([table[:, 1:] for table in tables], axis=-1)
    return series, class_ids.astype(int)
|
|
|
|
def compute_features(traj_xyz: np.ndarray) -> np.ndarray:
    """Build the 9-channel per-timestep feature matrix of a trajectory.

    Channels, in order:
        0-2  the xyz position itself,
        3-5  unit vector of the step to the next point (the last row repeats
             the previous heading; all zeros when there are fewer than two
             points, because no step exists),
        6    horizontal radius ``sqrt(x^2 + y^2)``,
        7-8  ``sin`` and ``cos`` of ``atan2(y, x)``.

    Args:
        traj_xyz: array of shape ``(T, 3)``; ``T`` may be 0 or 1.

    Returns:
        ``float32`` array of shape ``(T, 9)`` for every ``T`` (the degenerate
        branch previously returned the input dtype and zeroed the polar
        channels even though they are well-defined for a single point).
    """
    traj_xyz = np.asarray(traj_xyz)
    n_points = traj_xyz.shape[0]

    if n_points >= 2:
        step = np.diff(traj_xyz, axis=0)
        # Floor the step length so stationary points do not divide by zero.
        step_len = np.maximum(np.linalg.norm(step, axis=1, keepdims=True), 1e-8)
        heading = step / step_len
        # The final point has no successor: reuse the previous heading.
        heading = np.vstack([heading, heading[-1:]])
    else:
        heading = np.zeros((n_points, 3), dtype=np.float32)

    x, y = traj_xyz[:, 0], traj_xyz[:, 1]
    radius = np.sqrt(x ** 2 + y ** 2)
    theta = np.arctan2(y, x)
    return np.column_stack([
        traj_xyz, heading,
        radius[:, None], np.sin(theta)[:, None], np.cos(theta)[:, None],
    ]).astype(np.float32)
|
|
|
|
def ensure_data(airport: str, data_dir: str = "data"):
    """Make sure the TSV files of ``airport`` exist locally and return their directory.

    If ``<data_dir>/<airport>`` already contains at least one ``.tsv`` file
    it is returned untouched. Otherwise the ``<airport>/`` folder of the
    ``petchthwr/ATFMTraj`` dataset repo is downloaded and copied into place.

    Args:
        airport: sub-folder name inside the dataset repo.
        data_dir: local root under which the airport folder is stored.

    Returns:
        Path ``<data_dir>/<airport>``.

    Raises:
        FileNotFoundError: the downloaded snapshot has no ``<airport>`` folder.
    """
    target = os.path.join(data_dir, airport)
    if os.path.isdir(target) and any(f.endswith(".tsv") for f in os.listdir(target)):
        return target

    print(f"[data] downloading {airport} from HF ...")
    # Imported lazily so the hub client is only needed when a download happens.
    from huggingface_hub import snapshot_download
    snap = snapshot_download(
        "petchthwr/ATFMTraj",
        repo_type="dataset",
        allow_patterns=[f"{airport}/*"],
    )
    src = os.path.join(snap, airport)
    if not os.path.isdir(src):
        raise FileNotFoundError(
            f"'{airport}' not found in the downloaded snapshot at {snap}"
        )
    os.makedirs(data_dir, exist_ok=True)
    # The target may already exist but be empty or incomplete (that is how we
    # got here), so copy into it instead of skipping the copy.
    shutil.copytree(src, target, dirs_exist_ok=True)
    return target
|
|
|
|
| |
| |
| |
|
|
# Fill value for the unused tail of the fixed-size past-feature buffer built
# in BlindspotDataset.__getitem__.
PAD_VALUE = 0.0
|
|
|
|
class BlindspotDataset(Dataset):
    """Random (observed past, hidden future) windows cut from trajectories.

    Each item selects a trajectory, draws a gap length ``delta`` in
    ``[delta_min, delta_max]`` and a cut index ``t_in``; up to ``past_max``
    positions before ``t_in`` become the input features and the ``delta``
    positions starting at ``t_in`` become the target. The draw depends only
    on ``seed`` and the item index, so a given index always yields the same
    sample.
    """

    def __init__(self, airport, mode, data_dir,
                 past_max=256, past_min=60,
                 delta_min=30, delta_max=120,
                 seed=0, epoch_multiplier=4,
                 held_out_classes=None,
                 keep_only_classes=None,
                 ):
        """Load one split and keep the trajectories usable for sampling.

        Args:
            airport, mode, data_dir: identify the TSV files to load.
            past_max: maximum number of past steps fed to the model (also the
                padded length of ``past_features``).
            past_min: smallest allowed cut index.
            delta_min, delta_max: inclusive range of the gap length.
            seed: base seed of the per-item random draws.
            epoch_multiplier: how many samples each trajectory contributes to
                one pass over the dataset.
            held_out_classes: labels to drop (ignored when
                ``keep_only_classes`` is given).
            keep_only_classes: if given, only these labels are kept.

        Raises:
            RuntimeError: no trajectory is long enough, or none survives the
                class filter.
        """
        ensure_data(airport, data_dir)
        airport_dir = os.path.join(data_dir, airport)
        raw, labels = load_atfm(airport, mode, airport_dir)

        self.past_max = past_max
        self.past_min = past_min
        self.delta_min = delta_min
        self.delta_max = delta_max
        self.epoch_multiplier = epoch_multiplier
        self.rng_seed = seed

        # Trajectory length = number of non-NaN entries in the x channel.
        lengths = np.count_nonzero(~np.isnan(raw[:, :, 0]), axis=1).astype(np.int64)
        min_required = past_min + delta_max + 1
        keep = lengths >= min_required
        if keep.sum() == 0:
            raise RuntimeError(
                f"No trajectories of length >= {min_required} in {airport}/{mode}"
            )

        if keep_only_classes is not None:
            wanted = np.asarray(sorted({int(c) for c in keep_only_classes}),
                                dtype=np.int64)
            keep = keep & np.isin(labels, wanted)
        elif held_out_classes is not None:
            held = np.asarray(sorted({int(c) for c in held_out_classes}),
                              dtype=np.int64)
            keep = keep & ~np.isin(labels, held)

        # The class filter can remove everything that passed the length check;
        # fail here instead of letting __getitem__ hit `idx % 0` later.
        if not keep.any():
            raise RuntimeError(
                f"No trajectories left in {airport}/{mode} after class filtering"
            )

        raw = raw[keep]
        lengths = lengths[keep]
        self.labels = labels[keep].astype(np.int64)

        # Store each trajectory trimmed to its own length; remaining NaNs
        # inside the kept prefix are replaced by 0.
        self.positions = [
            np.nan_to_num(track[:int(n)], nan=0.0).astype(np.float32)
            for track, n in zip(raw, lengths)
        ]
        del raw

        self.n_traj = len(self.positions)
        print(f"[data] {airport}/{mode}: {self.n_traj} trajectories "
              f"(after filtering for L >= {min_required})")

    def __len__(self):
        """Number of samples per pass: trajectories times ``epoch_multiplier``."""
        return self.n_traj * self.epoch_multiplier

    def __getitem__(self, idx):
        """Return the deterministic sample associated with ``idx``.

        Returns a dict with:
            past_features: ``(past_max, 9)`` float tensor, real rows first,
                remainder filled with ``PAD_VALUE``.
            past_length: number of real rows in ``past_features``.
            target_pos: ``(delta_max, 3)`` float tensor, first ``delta`` rows
                hold the target positions, the rest are zero.
            delta: sampled gap length.
            label: class label of the source trajectory.
        """
        traj_idx = idx % self.n_traj
        rng = np.random.default_rng(self.rng_seed + idx * 9173)
        positions = self.positions[traj_idx]
        L = positions.shape[0]

        delta = int(rng.integers(self.delta_min, self.delta_max + 1))
        # Cut index range: at least past_min, and early enough that the whole
        # gap fits inside the trajectory.
        t_in_max = L - delta - 1
        t_in_min = self.past_min
        t_in = int(rng.integers(t_in_min, t_in_max + 1))

        past_start = max(0, t_in - self.past_max)
        past_pos = positions[past_start:t_in]
        target_pos = positions[t_in:t_in + delta]

        past_features = compute_features(past_pos)
        T_past = past_features.shape[0]
        feat_pad = np.full((self.past_max, 9), PAD_VALUE, dtype=np.float32)
        feat_pad[:T_past] = past_features
        tgt_pad = np.zeros((self.delta_max, 3), dtype=np.float32)
        tgt_pad[:delta] = target_pos
        return {
            "past_features": torch.from_numpy(feat_pad),
            "past_length": torch.tensor(T_past, dtype=torch.long),
            "target_pos": torch.from_numpy(tgt_pad),
            "delta": torch.tensor(delta, dtype=torch.long),
            "label": torch.tensor(int(self.labels[traj_idx]), dtype=torch.long),
        }
|
|
|
|
| |
| |
| |
|
|
def sinusoidal_embedding(values, dim):
    """Encode scalars as concatenated sine/cosine features.

    Uses ``dim // 2`` frequencies ``10000 ** (-i / (dim // 2))``; the output
    is ``[sin(v * f_i)..., cos(v * f_i)...]`` along a new last axis. For odd
    ``dim`` one trailing zero column is appended so the last axis has exactly
    ``dim`` entries.

    Args:
        values: tensor of any shape (cast to float).
        dim: size of the embedding axis.

    Returns:
        Tensor of shape ``values.shape + (dim,)``.
    """
    n_freq = dim // 2
    index = torch.arange(n_freq, device=values.device)
    inv_freq = torch.exp(-math.log(10000.0) * index / n_freq)
    phase = values.float().unsqueeze(-1) * inv_freq
    encoded = torch.cat([phase.sin(), phase.cos()], dim=-1)
    if dim % 2 == 1:
        return F.pad(encoded, (0, 1))
    return encoded
|
|
|
|
class LearnablePosEnc(nn.Module):
    """Additive learned positional table of shape ``(1, max_len, d_model)``."""

    def __init__(self, max_len, d_model):
        super().__init__()
        # Small-scale normal initialisation (std 0.02).
        table = torch.randn(1, max_len, d_model) * 0.02
        self.pe = nn.Parameter(table)

    def forward(self, x):
        """Add the first ``x.size(1)`` table rows to ``x`` (batch-broadcast)."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]
|
|
|
|
class PatchTokenizer(nn.Module):
    """Turn a per-timestep feature sequence into a shorter sequence of patch tokens.

    Two same-length 1-D convolutions embed every timestep; consecutive groups
    of ``patch_size`` timesteps are then mean-pooled into one token, a learned
    positional encoding is added and the result is layer-normalised.
    """

    def __init__(self, in_channels=9, d_model=256, patch_size=8, max_patches=64):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        self.embed = nn.Sequential(
            nn.Conv1d(in_channels, d_model // 2, 5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model // 2, d_model, 3, padding=1),
            nn.GELU(),
        )
        self.pos_enc = LearnablePosEnc(max_patches, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, features, lengths):
        """Tokenise a batch.

        Args:
            features: ``(B, T, C)`` tensor, ``T >= 1``.
            lengths: ``(B,)`` number of real (non-padded) timesteps per row.

        Returns:
            ``(tokens, patch_lengths)``: tokens of shape ``(B, N, d_model)``
            with ``N = max(1, T // patch_size)``, and the per-row count of
            valid patches, ``floor(length / patch_size)`` clamped to
            ``[1, N]``.
        """
        B, T, C = features.shape
        h = self.embed(features.transpose(1, 2))  # (B, d_model, T)
        N = max(1, T // self.patch_size)
        if T < self.patch_size:
            # Shorter than one patch: N is 1 but there are not enough steps to
            # reshape into a full patch, so pool whatever is available.
            h = h.mean(-1, keepdim=True)
        else:
            # Drop the ragged tail, then average inside each patch.
            h = h[:, :, :N * self.patch_size]
            h = h.reshape(B, self.d_model, N, self.patch_size).mean(-1)
        h = h.transpose(1, 2)  # (B, N, d_model)
        h = self.norm(self.pos_enc(h))
        patch_lengths = (lengths.float() / self.patch_size).clamp(min=1).long()
        patch_lengths = patch_lengths.clamp(max=N)
        return h, patch_lengths
|
|
|
|
| class CausalEncoder(nn.Module): |
| def __init__(self, d_model=256, n_heads=8, n_layers=4, d_ff=1024, dropout=0.1): |
| super().__init__() |
| layer = nn.TransformerEncoderLayer( |
| d_model=d_model, nhead=n_heads, dim_feedforward=d_ff, |
| dropout=dropout, activation="gelu", batch_first=True, |
| norm_first=True, |
| ) |
| self.tf = nn.TransformerEncoder(layer, num_layers=n_layers) |
| self.norm = nn.LayerNorm(d_model) |
|
|
| def forward(self, x, key_padding_mask): |
| N = x.size(1) |
| causal_mask = torch.triu( |
| torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1 |
| ) |
| return self.norm( |
| self.tf(x, mask=causal_mask, src_key_padding_mask=key_padding_mask) |
| ) |
|
|
|
|
def last_valid_token(encoded, patch_lengths):
    """Pick, per batch row, the token at index ``patch_lengths - 1``.

    Args:
        encoded: ``(B, N, D)`` token sequence.
        patch_lengths: ``(B,)`` number of valid tokens per row; values below
            1 select index 0.

    Returns:
        ``(B, D)`` tensor.
    """
    batch = encoded.size(0)
    rows = torch.arange(batch, device=encoded.device)
    last = (patch_lengths - 1).clamp(min=0).reshape(batch)
    return encoded[rows, last]
|
|
|
|
class DeltaEmbedding(nn.Module):
    """Embed the prediction horizon, both absolute and relative to the past length.

    Two sinusoidal codes of width ``d_freq`` are concatenated and passed
    through a two-layer MLP: one of ``delta`` itself and one of
    ``100 * delta / max(t_past, 1)``.
    """

    def __init__(self, d_model=256, d_freq=64):
        super().__init__()
        self.d_freq = d_freq
        self.proj = nn.Sequential(
            nn.Linear(d_freq * 2, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, delta, t_past):
        """Return a ``(..., d_model)`` embedding for ``delta`` given ``t_past``."""
        horizon = delta.float()
        observed = t_past.float().clamp(min=1.0)  # avoid division by zero
        ratio = horizon / observed
        codes = [
            sinusoidal_embedding(horizon, self.d_freq),
            sinusoidal_embedding(ratio * 100.0, self.d_freq),
        ]
        return self.proj(torch.cat(codes, dim=-1))
|
|
|
|
class GaussianHead(nn.Module):
    """Map a hidden state to the parameters of one Gaussian step.

    Outputs a 3-D mean, a 3-D log standard deviation clamped to ``[-7, 2]``
    and a scalar correlation squashed into ``(-0.99, 0.99)``.
    """

    def __init__(self, d_model=256, d_hidden=256):
        super().__init__()
        trunk = [
            nn.Linear(d_model, d_hidden), nn.GELU(),
            nn.Linear(d_hidden, d_hidden), nn.GELU(),
        ]
        self.net = nn.Sequential(*trunk)
        self.mu_head = nn.Linear(d_hidden, 3)
        self.log_sigma_head = nn.Linear(d_hidden, 3)
        self.rho_head = nn.Linear(d_hidden, 1)

    def forward(self, h):
        """Return ``(delta_mu (B, 3), log_sigma (B, 3), rho (B,))``."""
        hidden = self.net(h)
        mean = self.mu_head(hidden)
        log_std = torch.clamp(self.log_sigma_head(hidden), min=-7.0, max=2.0)
        corr = torch.tanh(self.rho_head(hidden)).squeeze(-1) * 0.99
        return mean, log_std, corr
|
|
|
|
def gaussian_nll_xyz(true_delta, mu, log_sigma, rho, beta: float = 0.5):
    """
    β-NLL Gaussian for (x, y, z) — bivariate on xy + independent z.

    Standard NLL has a degenerate minimum where σ→0 ("σ-collapse",
    Detlefsen 2019). β-NLL (Seitzer et al., arxiv:2203.09168) reweights
    each sample's NLL by σ^{2β} (detached) so points with large σ get
    proportionally more gradient on the mean term, preventing collapse.

    β = 0   → standard NLL (collapse-prone, what v2 used)
    β = 0.5 → recommended; preserves uncertainty learning
    β = 1   → pure squared-error scaling (loses σ learning)

    Args:
        true_delta, mu, log_sigma: tensors of shape ``(..., 3)``.
        rho: xy correlation of shape ``(...)``.
        beta: β exponent; values <= 0 give the unweighted NLL.

    Returns:
        Element-wise loss of shape ``(...)`` (no reduction). Any number of
        leading dimensions is accepted, e.g. ``(B, 3)`` or ``(B, T, 3)``.
    """
    log_sx, log_sy, log_sz = (log_sigma[..., k] for k in range(3))
    sx, sy, sz = log_sx.exp(), log_sy.exp(), log_sz.exp()
    dx = true_delta[..., 0] - mu[..., 0]
    dy = true_delta[..., 1] - mu[..., 1]
    dz = true_delta[..., 2] - mu[..., 2]

    # 1 - rho^2, floored so the division and the log stay finite.
    omr2 = (1.0 - rho * rho).clamp(min=1e-6)
    z2 = (((dx / sx) ** 2)
          - 2.0 * rho * (dx / sx) * (dy / sy)
          + ((dy / sy) ** 2)) / omr2
    log_det = 2.0 * (log_sx + log_sy) + torch.log(omr2)
    nll_xy = 0.5 * (z2 + log_det + 2.0 * math.log(2.0 * math.pi))
    nll_z = 0.5 * ((dz / sz) ** 2 + 2.0 * log_sz
                   + math.log(2.0 * math.pi))

    if beta > 0.0:
        # Weights are detached: they rescale gradients but are not trained.
        # The xy term uses the geometric mean of sigma_x and sigma_y.
        sxy = (sx * sy).sqrt().detach()
        wxy = sxy.pow(2.0 * beta)
        wz = sz.detach().pow(2.0 * beta)
        return wxy * nll_xy + wz * nll_z
    return nll_xy + nll_z
|
|
|
|
class ParallelDecoder(nn.Module):
    """
    HiVT-style parallel decoder (arxiv:2207.09588).

    A single context vector h ∈ R^d (fused z_in + Δ_emb) is copied to every
    one of ``t_max`` future steps, a learned per-step positional embedding is
    added, and a shared MLP maps each step to 7 raw values:
    (μ_x, μ_y, μ_z, log σ_x, log σ_y, log σ_z, ρ_pre).

    All steps come out of one forward pass. Variable Δ is handled by the
    caller, which masks unused steps in the loss.

    The means are *absolute positions* at each step, not deltas; the NLL
    loss is applied per step against target_pos[:, t].
    """

    def __init__(self, d_model: int = 256, t_max: int = 120, mlp_hidden: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.t_max = t_max
        self.d_model = d_model
        self.step_pe = LearnablePosEnc(t_max, d_model)
        layers = [
            nn.Linear(d_model, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, 7),
        ]
        self.mlp = nn.Sequential(*layers)
        # Output layer starts with small weights and zero bias.
        out_layer = self.mlp[-1]
        nn.init.trunc_normal_(out_layer.weight, std=0.02)
        nn.init.zeros_(out_layer.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        h: (B, D) context vector
        returns: (B, T_max, 7) — (μ_x, μ_y, μ_z, log σ_x, log σ_y, log σ_z, rho_pre)
        """
        batch = h.size(0)
        per_step = h.unsqueeze(1).expand(batch, self.t_max, self.d_model)
        return self.mlp(self.step_pe(per_step))
|
|
|
|
def split_parallel_output(raw: torch.Tensor):
    """Split raw decoder output ``(B, T, 7)`` into distribution parameters.

    Returns ``(mu, log_sigma, rho)``:
        mu: ``(B, T, 3)``, channels 0-2 unchanged;
        log_sigma: ``(B, T, 3)``, channels 3-5 clamped to ``[-7, 2]``;
        rho: ``(B, T)``, ``0.99 * tanh`` of channel 6.
    """
    mean = raw.narrow(-1, 0, 3)
    log_std = torch.clamp(raw.narrow(-1, 3, 3), min=-7.0, max=2.0)
    corr = torch.tanh(raw.select(-1, 6)) * 0.99
    return mean, log_std, corr
|
|
|
|
def parallel_nll_xyz(true_pos: torch.Tensor, mu: torch.Tensor,
                     log_sigma: torch.Tensor, rho: torch.Tensor,
                     mask: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """
    Masked β-NLL over a whole (B, T, ·) prediction.

    The per-step loss is a bivariate Gaussian NLL on xy plus an independent
    Gaussian NLL on z; for beta > 0 each part is multiplied by its detached
    σ^{2β} weight. mask: (B, T) float, 1 for valid steps.

    Returns a scalar: per-sample mean over valid steps, then mean over the
    batch (samples without valid steps contribute 0).
    """
    two_pi_log = math.log(2.0 * math.pi)
    std_x = log_sigma[..., 0].exp()
    std_y = log_sigma[..., 1].exp()
    std_z = log_sigma[..., 2].exp()

    # Standardised residuals per axis.
    ux = (true_pos[..., 0] - mu[..., 0]) / std_x
    uy = (true_pos[..., 1] - mu[..., 1]) / std_y
    uz = (true_pos[..., 2] - mu[..., 2]) / std_z

    one_minus_r2 = (1.0 - rho * rho).clamp(min=1e-6)
    mahal_xy = ((ux ** 2) - 2.0 * rho * ux * uy + (uy ** 2)) / one_minus_r2
    log_det_xy = 2.0 * (log_sigma[..., 0] + log_sigma[..., 1]) + torch.log(one_minus_r2)
    loss_xy = 0.5 * (mahal_xy + log_det_xy + 2.0 * two_pi_log)
    loss_z = 0.5 * (uz ** 2 + 2.0 * log_sigma[..., 2] + two_pi_log)

    if beta > 0.0:
        weight_xy = (std_x * std_y).sqrt().detach().pow(2.0 * beta)
        weight_z = std_z.detach().pow(2.0 * beta)
        per_step = weight_xy * loss_xy + weight_z * loss_z
    else:
        per_step = loss_xy + loss_z

    n_valid = mask.sum(-1).clamp(min=1.0)
    per_sample = (per_step * mask).sum(-1) / n_valid
    return per_sample.mean()
|
|
|
|
class FuturePredictor(nn.Module):
    """Predict a ``d_model``-sized latent from a context latent and a horizon embedding.

    The two inputs are concatenated, projected to ``pred_dim``, passed as a
    single token through a 2-layer pre-norm Transformer encoder, projected
    back to ``d_model`` and layer-normalised.
    """

    def __init__(self, d_model=256, pred_dim=128, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_model * 2, pred_dim)
        block = nn.TransformerEncoderLayer(
            d_model=pred_dim,
            nhead=4,
            dim_feedforward=pred_dim * 2,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.tf = nn.TransformerEncoder(block, num_layers=2)
        self.proj_out = nn.Linear(pred_dim, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, z_in, delta_emb):
        """``z_in`` and ``delta_emb`` are (B, d_model); returns (B, d_model)."""
        joint = torch.cat([z_in, delta_emb], dim=-1)
        token = self.proj_in(joint).unsqueeze(1)  # sequence of length 1
        token = self.tf(token).squeeze(1)
        return self.norm(self.proj_out(token))
|
|
|
|
class FlightJEPAv2(nn.Module):
    """Trajectory gap predictor with an optional JEPA auxiliary objective.

    The observed past is tokenised and encoded causally; the last valid token
    is fused with an embedding of the horizon ``delta`` into a context vector
    ``h``. ``h`` is decoded either step by step (``decoder_mode == "ar"``:
    ``GaussianHead`` + ``GRUCell``, predicting per-step displacements) or in
    one shot (``"parallel"``: ``ParallelDecoder``, predicting absolute
    positions). When ``lambda_jepa > 0`` an L1 loss between a predicted latent
    and the latent of the future produced by EMA copies of the tokenizer and
    encoder is added.

    Config keys read from ``cfg`` (with defaults): ``d_model`` 256,
    ``n_heads`` 8, ``n_layers`` 4, ``d_ff`` 1024, ``dropout`` 0.1,
    ``patch_size`` 8, ``past_max`` 256, ``lambda_jepa`` 0.0, ``ema_decay``
    0.998, ``beta_nll`` 0.5, ``decoder_mode`` "ar", ``delta_max`` 120.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d = cfg.get("d_model", 256)
        h_ = cfg.get("n_heads", 8)
        n_l = cfg.get("n_layers", 4)
        d_ff = cfg.get("d_ff", 1024)
        dr = cfg.get("dropout", 0.1)
        ps = cfg.get("patch_size", 8)
        past_max = cfg.get("past_max", 256)
        max_patches = past_max // ps
        self.lambda_jepa = cfg.get("lambda_jepa", 0.0)
        self.ema_decay = cfg.get("ema_decay", 0.998)
        self.beta_nll = cfg.get("beta_nll", 0.5)
        self.decoder_mode = cfg.get("decoder_mode", "ar")
        self.t_max = cfg.get("delta_max", 120)

        self.tokenizer = PatchTokenizer(9, d, ps, max_patches)
        self.encoder = CausalEncoder(d, h_, n_l, d_ff, dr)
        self.delta_emb = DeltaEmbedding(d, 64)
        self.fuse_in = nn.Sequential(
            nn.Linear(d * 2, d), nn.GELU(),
            nn.Linear(d, d),
        )

        # Both decoders are always built; decoder_mode selects which is used.
        self.head = GaussianHead(d, d)
        self.step_cell = nn.GRUCell(input_size=3, hidden_size=d)
        self.parallel_decoder = ParallelDecoder(
            d_model=d, t_max=self.t_max,
            mlp_hidden=d * 2, dropout=dr,
        )

        # Target networks: frozen copies that only change through update_ema.
        self.target_tokenizer = copy.deepcopy(self.tokenizer)
        self.target_encoder = copy.deepcopy(self.encoder)
        for p in self.target_tokenizer.parameters():
            p.requires_grad = False
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        self.predictor = FuturePredictor(d, d // 2, dr)

    @torch.no_grad()
    def update_ema(self):
        """Move target weights toward the online ones: ``t = m*t + (1-m)*o``."""
        m = self.ema_decay
        for online, target in [(self.tokenizer, self.target_tokenizer),
                               (self.encoder, self.target_encoder)]:
            for po, pt in zip(online.parameters(), target.parameters()):
                pt.data.mul_(m).add_(po.data, alpha=1.0 - m)

    def encode_past(self, past_features, past_length):
        """Encode the past with the online networks.

        Returns ``(z_in, encoded, patch_lens)``: the last valid token per
        row, the full token sequence, and the valid patch counts.
        """
        patches, patch_lens = self.tokenizer(past_features, past_length)
        N = patches.size(1)
        # True marks patch slots beyond each row's valid length.
        pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0)
                    >= patch_lens.unsqueeze(1))
        encoded = self.encoder(patches, key_padding_mask=pad_mask)
        z_in = last_valid_token(encoded, patch_lens)
        return z_in, encoded, patch_lens

    @torch.no_grad()
    def encode_future_target(self, target_features, target_length):
        """Encode the future with the target (EMA) networks; no gradients."""
        patches, patch_lens = self.target_tokenizer(target_features, target_length)
        N = patches.size(1)
        pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0)
                    >= patch_lens.unsqueeze(1))
        encoded = self.target_encoder(patches, key_padding_mask=pad_mask)
        return last_valid_token(encoded, patch_lens)

    def _jepa_loss(self, z_in, delta_e, target_pos, delta, device):
        """L1 distance between the predicted latent and the target-encoder latent."""
        B, n_steps = target_pos.size(0), target_pos.size(1)
        # Only the position channels of the 9-channel input are filled.
        tgt_feat = torch.zeros(B, n_steps, 9, device=device)
        tgt_feat[..., :3] = target_pos
        z_target = self.encode_future_target(tgt_feat, delta)
        z_pred = self.predictor(z_in, delta_e)
        return F.l1_loss(z_pred, z_target.detach())

    def _parallel_losses(self, h, target_pos, delta, device):
        """One-shot decoder: masked β-NLL and mean per-step L2 on absolute positions."""
        n_steps = target_pos.size(1)
        raw = self.parallel_decoder(h)
        raw = raw[:, :n_steps]
        mu, log_sigma, rho = split_parallel_output(raw)
        arange = torch.arange(n_steps, device=device).unsqueeze(0)
        mask = (arange < delta.unsqueeze(1)).float()

        nll_loss = parallel_nll_xyz(target_pos, mu, log_sigma, rho, mask,
                                    beta=self.beta_nll)
        with torch.no_grad():
            step_l2 = (target_pos - mu).pow(2).sum(-1).sqrt()
            ade_train = (step_l2 * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            ade_train = ade_train.mean()
        return nll_loss, ade_train

    def _ar_losses(self, h, target_pos, delta, last_pos, ss_prob, device):
        """Step-wise decoder: β-NLL and L2 error on per-step displacements."""
        B = target_pos.size(0)
        n_steps = target_pos.size(1)
        prev_pos = last_pos
        nll_total = torch.zeros(B, device=device)
        valid_steps = torch.zeros(B, device=device)
        ade_total = torch.zeros(B, device=device)

        for t in range(n_steps):
            delta_mu, log_sigma, rho = self.head(h)
            true_pos_t = target_pos[:, t]
            true_delta = true_pos_t - prev_pos

            nll = gaussian_nll_xyz(true_delta, delta_mu, log_sigma, rho,
                                   beta=self.beta_nll)
            # Steps at or beyond a sample's own delta do not count.
            mask = (t < delta).float()
            nll_total = nll_total + nll * mask
            ade_total = (ade_total
                         + (true_delta - delta_mu).pow(2).sum(-1).sqrt() * mask)
            valid_steps = valid_steps + mask

            if ss_prob > 0.0 and self.training:
                # Scheduled sampling: per sample, feed back the (detached)
                # prediction instead of the ground truth with prob ss_prob.
                use_pred = (torch.rand(B, device=device) < ss_prob).float().unsqueeze(-1)
                fed_delta = use_pred * delta_mu.detach() + (1 - use_pred) * true_delta
                fed_pos = use_pred * (prev_pos + delta_mu.detach()) + (1 - use_pred) * true_pos_t
            else:
                fed_delta = true_delta
                fed_pos = true_pos_t

            h = self.step_cell(fed_delta, h)
            prev_pos = fed_pos

        nll_loss = (nll_total / valid_steps.clamp(min=1.0)).mean()
        ade_train = (ade_total / valid_steps.clamp(min=1.0)).mean().detach()
        return nll_loss, ade_train

    def forward(self, past_features, past_length, target_pos, delta, last_pos,
                ss_prob: float = 0.0):
        """
        Compute the training losses for one batch.

        ss_prob: scheduled-sampling probability ∈ [0, 1]. AR mode only —
        parallel mode predicts all timesteps in one shot, no SS needed.

        Returns a dict with ``"nll"``, ``"ade_train"`` (detached),
        ``"total"`` and, when ``lambda_jepa > 0``, ``"jepa"``;
        ``total = nll + lambda_jepa * jepa``.
        """
        device = past_features.device

        z_in, _, _ = self.encode_past(past_features, past_length)
        delta_e = self.delta_emb(delta, past_length)
        h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1))

        if self.decoder_mode == "parallel":
            nll_loss, ade_train = self._parallel_losses(h, target_pos, delta, device)
        else:
            nll_loss, ade_train = self._ar_losses(
                h, target_pos, delta, last_pos, ss_prob, device
            )

        losses = {"nll": nll_loss, "ade_train": ade_train, "total": nll_loss}

        if self.lambda_jepa > 0.0:
            jepa_loss = self._jepa_loss(z_in, delta_e, target_pos, delta, device)
            losses["jepa"] = jepa_loss
            losses["total"] = nll_loss + self.lambda_jepa * jepa_loss

        return losses

    @torch.no_grad()
    def rollout(self, past_features, past_length, delta, last_pos, delta_max):
        """Predict ``delta_max`` future steps without ground truth.

        Returns ``(mu_pos (B, delta_max, 3), sigma (B, delta_max, 3),
        rho (B, delta_max))``. In AR mode the predicted displacements are
        accumulated from ``last_pos``. In parallel mode ``last_pos`` is not
        used; if ``delta_max`` exceeds the decoder's ``t_max`` the last
        emitted step is repeated for the remaining slots.
        """
        B = past_features.size(0)
        device = past_features.device
        z_in, _, _ = self.encode_past(past_features, past_length)
        delta_e = self.delta_emb(delta, past_length)
        h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1))

        if self.decoder_mode == "parallel":
            # Never emit more steps than requested: forcing at least one step
            # would not fit the zero-length output buffers when delta_max == 0.
            n_emit = max(0, min(delta_max, self.t_max))
            raw = self.parallel_decoder(h)
            raw = raw[:, :n_emit]
            mu_abs, log_sigma, rho = split_parallel_output(raw)
            sigma = log_sigma.exp()
            mu_pos = torch.zeros(B, delta_max, 3, device=device)
            sg = torch.zeros(B, delta_max, 3, device=device)
            ro = torch.zeros(B, delta_max, device=device)
            if n_emit > 0:
                mu_pos[:, :n_emit] = mu_abs
                sg[:, :n_emit] = sigma
                ro[:, :n_emit] = rho
                if delta_max > n_emit:
                    # Beyond the decoder horizon: hold the last prediction.
                    mu_pos[:, n_emit:] = mu_abs[:, -1:]
                    sg[:, n_emit:] = sigma[:, -1:]
                    ro[:, n_emit:] = rho[:, -1:]
            return mu_pos, sg, ro

        prev_pos = last_pos
        mu_pos = torch.zeros(B, delta_max, 3, device=device)
        sigma = torch.zeros(B, delta_max, 3, device=device)
        rho_out = torch.zeros(B, delta_max, device=device)
        for t in range(delta_max):
            delta_mu, log_sigma, rho = self.head(h)
            cur_pos = prev_pos + delta_mu
            mu_pos[:, t] = cur_pos
            sigma[:, t] = log_sigma.exp()
            rho_out[:, t] = rho
            h = self.step_cell(delta_mu, h)
            prev_pos = cur_pos
        return mu_pos, sigma, rho_out
|
|
|
|
| |
| |
| |
|
|
# score_loader multiplies position errors by RMAX_KM * 1000 before reporting
# them as metres; presumably the dataset's normalisation radius in km — TODO confirm.
RMAX_KM = 120.0
# (lo, hi) horizon ranges used for the per-bucket breakdown in score_loader.
DELTA_BUCKETS = [(30, 60), (60, 90), (90, 120)]
# Not referenced in the visible part of the file; presumably the horizons
# used for extrapolation scoring — TODO confirm against the caller.
EXTRAP_DELTAS = [180, 300]
# Final-step error thresholds behind the miss_rate_<t>m metrics.
THRESH_M = [500.0, 1000.0, 2000.0]
|
|
|
|
def get_last_pos(past_features, past_length):
    """Return the xyz channels of the last real timestep of every row.

    Args:
        past_features: ``(B, T, C)`` tensor whose first three channels are
            positions.
        past_length: ``(B,)`` number of real timesteps per row; values below
            1 select timestep 0.

    Returns:
        ``(B, 3)`` tensor.
    """
    rows = torch.arange(past_features.size(0), device=past_features.device)
    final_step = torch.clamp(past_length - 1, min=0)
    return past_features[rows, final_step, :3]
|
|
|
|
def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0,
                    log_every: int = 50, ss_prob: float = 0.0):
    """Run one optimisation pass over ``loader`` and return averaged losses.

    Per batch: forward, backward on ``losses["total"]``, gradient-norm
    clipping to ``grad_clip``, optimizer step and, when
    ``model.lambda_jepa > 0``, an EMA update of the target networks. A
    progress line is printed for the first batch and every ``log_every``
    batches.

    Returns:
        Dict with sample-weighted means ``nll``, ``ade``, ``jepa``, ``total``
        and ``ade_train`` (same value as ``ade``).
    """
    model.train()
    running = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0}
    n_seen = 0
    started = time.time()
    n_batches = len(loader) if hasattr(loader, "__len__") else 0

    for step, batch in enumerate(loader, start=1):
        past_f = batch["past_features"].to(device)
        past_l = batch["past_length"].to(device)
        target = batch["target_pos"].to(device)
        delta = batch["delta"].to(device)
        last_pos = get_last_pos(past_f, past_l)

        losses = model(past_f, past_l, target, delta, last_pos,
                       ss_prob=ss_prob)

        optimizer.zero_grad()
        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        if model.lambda_jepa > 0.0:
            model.update_ema()

        # Weight every batch by its size so the epoch mean is per sample.
        bs = past_f.size(0)
        running["nll"] += losses["nll"].item() * bs
        running["ade"] += losses["ade_train"].item() * bs
        if "jepa" in losses:
            running["jepa"] += losses["jepa"].item() * bs
        running["total"] += losses["total"].item() * bs
        n_seen += bs

        if step % log_every == 0 or step == 1:
            elapsed = time.time() - started
            rate = step / max(elapsed, 0.001)
            print(f" [batch {step}/{n_batches}] {elapsed:.1f}s elapsed, "
                  f"{rate:.1f} batch/s, loss={losses['total'].item():.4f}",
                  flush=True)

    denom = max(n_seen, 1)
    averaged = {name: value / denom for name, value in running.items()}
    averaged["ade_train"] = running["ade"] / denom
    return averaged
|
|
|
|
@torch.no_grad()
def score_loader(model, loader, device, extrap_delta=None):
    """Score ``model.rollout`` predictions on every sample of ``loader``.

    Args:
        model: object exposing ``train(False)`` and ``rollout(...)``.
        loader: DataLoader whose ``dataset`` has a ``delta_max`` attribute.
        device: device the batches are moved to.
        extrap_delta: if given, every sample is rolled out with this forced
            horizon instead of its own ``delta``.

    Batches whose rollout length exceeds the dataset's ``delta_max`` are
    skipped. Each sample is scored on its first ``min(forced, delta)`` steps
    (at least 1).

    Returns:
        ``{}`` if nothing was scored, otherwise a dict with ``ade_m``,
        ``fde_m``, ``fde_median_m``, ``nll_xy_z`` and ``coverage_95`` (both at
        the final scored step), ``n``, one ``miss_rate_<t>m`` per threshold
        in ``THRESH_M`` and, when ``extrap_delta`` is None, ``per_bucket``
        statistics keyed ``delta_<lo>_<hi>``.
    """
    model.train(False)
    delta_max_dataset = loader.dataset.delta_max
    log_2pi = math.log(2 * math.pi)
    ades, fdes, nlls, coverage95, delta_orig = [], [], [], [], []

    for batch in loader:
        past_f = batch["past_features"].to(device)
        past_l = batch["past_length"].to(device)
        target = batch["target_pos"].to(device)
        delta = batch["delta"].to(device)
        last_pos = get_last_pos(past_f, past_l)
        if extrap_delta is not None:
            forced = torch.full_like(delta, extrap_delta)
            roll_len = extrap_delta
        else:
            forced = delta
            roll_len = int(delta.max().item())
        if roll_len > delta_max_dataset:
            continue
        mu_pos, sigma, rho = model.rollout(past_f, past_l, forced, last_pos, roll_len)
        active_len = torch.minimum(forced, delta).clamp(min=1)

        # Copy each tensor to the host once per batch, not once per sample.
        mu_np = mu_pos.cpu().numpy()
        sigma_np = sigma.cpu().numpy()
        rho_np = rho.cpu().numpy()
        target_np = target.cpu().numpy()
        lens = active_len.cpu().tolist()
        deltas = delta.cpu().tolist()

        for i, (L, d_orig) in enumerate(zip(lens, deltas)):
            diff = target_np[i, :L] - mu_np[i, :L]
            per_step_l2 = np.linalg.norm(diff, axis=1) * RMAX_KM * 1000.0
            ades.append(float(per_step_l2.mean()))
            fdes.append(float(per_step_l2[-1]))

            # Calibration statistics use the final scored step only.
            sig_last = sigma_np[i, :L][-1]
            sx = max(float(sig_last[0]), 1e-9)
            sy = max(float(sig_last[1]), 1e-9)
            sz = max(float(sig_last[2]), 1e-9)
            rho_xy = float(rho_np[i, :L][-1])
            dx, dy, dz = (float(v) for v in diff[-1])
            omr2 = max(1.0 - rho_xy * rho_xy, 1e-6)
            z2 = ((dx / sx) ** 2 - 2 * rho_xy * (dx / sx) * (dy / sy)
                  + (dy / sy) ** 2) / omr2
            # 5.991 is the 95% quantile of chi-square with 2 degrees of freedom.
            coverage95.append(z2 <= 5.991)
            log_det = 2 * (math.log(sx) + math.log(sy)) + math.log(omr2)
            nll_xy = 0.5 * (z2 + log_det + 2 * log_2pi)
            nll_z = 0.5 * ((dz / sz) ** 2 + 2 * math.log(sz) + log_2pi)
            nlls.append(nll_xy + nll_z)
            delta_orig.append(int(d_orig))

    if not ades:
        return {}

    ades = np.array(ades)
    fdes = np.array(fdes)
    nlls = np.array(nlls)
    coverage95 = np.array(coverage95, dtype=float)
    delta_orig = np.array(delta_orig)
    out = {
        "ade_m": float(ades.mean()),
        "fde_m": float(fdes.mean()),
        "fde_median_m": float(np.median(fdes)),
        "nll_xy_z": float(nlls.mean()),
        "coverage_95": float(coverage95.mean()),
        "n": len(ades),
    }
    for t in THRESH_M:
        out[f"miss_rate_{int(t)}m"] = float(1.0 - np.mean(fdes <= t))

    if extrap_delta is None:
        per_bucket = {}
        last_bucket = len(DELTA_BUCKETS) - 1
        for b, (lo, hi) in enumerate(DELTA_BUCKETS):
            # Buckets share their edges (60, 90); use half-open ranges and
            # close only the last one so no sample is counted twice.
            below_hi = delta_orig <= hi if b == last_bucket else delta_orig < hi
            mask = (delta_orig >= lo) & below_hi
            if mask.sum() == 0:
                continue
            per_bucket[f"delta_{lo}_{hi}"] = {
                "ade_m": float(ades[mask].mean()),
                "fde_m": float(fdes[mask].mean()),
                "coverage_95": float(coverage95[mask].mean()),
                "n": int(mask.sum()),
            }
        out["per_bucket"] = per_bucket
    return out
|
|
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument("--airport", default="RKSIa") |
| p.add_argument("--data-dir", default="data") |
| p.add_argument("--tag", default="run") |
| p.add_argument("--out-dir", default="runs") |
| p.add_argument("--epochs", type=int, default=30) |
| p.add_argument("--batch-size", type=int, default=64) |
| p.add_argument("--lr", type=float, default=1e-4) |
| p.add_argument("--weight-decay", type=float, default=1e-4) |
| p.add_argument("--past-max", type=int, default=256) |
| p.add_argument("--past-min", type=int, default=60) |
| p.add_argument("--delta-min", type=int, default=30) |
| p.add_argument("--delta-max", type=int, default=120) |
| p.add_argument("--extrap-delta-max", type=int, default=300) |
| p.add_argument("--epoch-multiplier", type=int, default=4) |
| p.add_argument("--lambda-jepa", type=float, default=0.0) |
| p.add_argument("--ema-decay", type=float, default=0.998) |
| p.add_argument("--beta-nll", type=float, default=0.5, |
| help="β-NLL exponent (Seitzer 2022). 0=plain NLL, 0.5=recommended.") |
| p.add_argument("--ss-max", type=float, default=0.0, |
| help="Max scheduled-sampling probability (0=teacher-forcing only, 0.5=Bengio recommended).") |
| p.add_argument("--ss-warmup-frac", type=float, default=0.5, |
| help="Fraction of training over which ss_prob ramps from 0 to ss_max linearly.") |
| p.add_argument("--decoder-mode", choices=["ar", "parallel"], default="ar", |
| help="ar = v5 GRU autoregressive; parallel = v6 HiVT-style MLP decoder.") |
| p.add_argument("--d-model", type=int, default=256) |
| p.add_argument("--n-layers", type=int, default=4) |
| p.add_argument("--n-heads", type=int, default=8) |
| p.add_argument("--patch-size", type=int, default=8) |
| p.add_argument("--seed", type=int, default=0) |
| p.add_argument("--num-workers", type=int, default=2) |
| p.add_argument("--push-to-hub", action="store_true") |
| p.add_argument("--hub-model-id", default=None) |
| p.add_argument("--pretrained-encoder", default=None, |
| help="Path or HF repo id to a pretrained encoder checkpoint " |
| "(loaded into tokenizer + encoder weights before training).") |
| p.add_argument("--pretrained-encoder-file", default=None, |
| help="If --pretrained-encoder is a HF repo, name of the file in it.") |
| p.add_argument("--freeze-encoder", action="store_true", |
| help="Freeze tokenizer + encoder weights after loading pretrained.") |
| p.add_argument("--held-out-classes", default=None, |
| help="Comma-separated class IDs to EXCLUDE from training (e.g., '6,18,28').") |
| p.add_argument("--keep-only-classes", default=None, |
| help="Comma-separated class IDs to KEEP for evaluation (eval on these only).") |
| p.add_argument("--trackio-name", default=None) |
| args = p.parse_args() |
|
|
| torch.manual_seed(args.seed) |
| np.random.seed(args.seed) |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"[v2] device={device} tag={args.tag} " |
| f"decoder_mode={args.decoder_mode} " |
| f"lambda_jepa={args.lambda_jepa} beta_nll={args.beta_nll} " |
| f"ss_max={args.ss_max} ss_warmup_frac={args.ss_warmup_frac}", |
| flush=True) |
| if device == "cuda": |
| print(f"[v2] cuda device: {torch.cuda.get_device_name(0)} " |
| f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB", |
| flush=True) |
| else: |
| print("[v2] WARNING: CUDA not available, training on CPU. " |
| "This will be very slow.", flush=True) |
|
|
| if HAS_TRACKIO and args.trackio_name: |
| trackio.init(project="flight-jepa-v2", name=args.trackio_name, |
| config=vars(args)) |
|
|
| held_out = ( |
| [int(c) for c in args.held_out_classes.split(",")] |
| if args.held_out_classes else None |
| ) |
| keep_only = ( |
| [int(c) for c in args.keep_only_classes.split(",")] |
| if args.keep_only_classes else None |
| ) |
| train_ds = BlindspotDataset( |
| airport=args.airport, mode="TRAIN", data_dir=args.data_dir, |
| past_max=args.past_max, past_min=args.past_min, |
| delta_min=args.delta_min, delta_max=args.delta_max, |
| seed=args.seed, epoch_multiplier=args.epoch_multiplier, |
| held_out_classes=held_out, keep_only_classes=keep_only, |
| ) |
| test_ds = BlindspotDataset( |
| airport=args.airport, mode="TEST", data_dir=args.data_dir, |
| past_max=args.past_max, past_min=args.past_min, |
| delta_min=args.delta_min, delta_max=args.delta_max, |
| seed=args.seed + 1, epoch_multiplier=1, |
| held_out_classes=held_out, keep_only_classes=keep_only, |
| ) |
| extrap_ds = BlindspotDataset( |
| airport=args.airport, mode="TEST", data_dir=args.data_dir, |
| past_max=args.past_max, past_min=args.past_min, |
| delta_min=args.delta_min, delta_max=args.extrap_delta_max, |
| seed=args.seed + 99, epoch_multiplier=1, |
| held_out_classes=held_out, keep_only_classes=keep_only, |
| ) |
|
|
| train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, |
| num_workers=args.num_workers, pin_memory=True, |
| drop_last=True) |
| test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, |
| num_workers=args.num_workers, pin_memory=True) |
| extrap_dl = DataLoader(extrap_ds, batch_size=args.batch_size, shuffle=False, |
| num_workers=args.num_workers, pin_memory=True) |
|
|
| cfg = { |
| "d_model": args.d_model, "n_heads": args.n_heads, |
| "n_layers": args.n_layers, "d_ff": args.d_model * 4, |
| "dropout": 0.1, "patch_size": args.patch_size, |
| "past_max": args.past_max, "lambda_jepa": args.lambda_jepa, |
| "ema_decay": args.ema_decay, "beta_nll": args.beta_nll, |
| "decoder_mode": args.decoder_mode, |
| "delta_max": args.delta_max, |
| } |
| model = FlightJEPAv2(cfg).to(device) |
| n_params = sum(p.numel() for p in model.parameters()) |
| print(f"[v2] params={n_params/1e6:.2f}M") |
|
|
| |
| if args.pretrained_encoder: |
| path = args.pretrained_encoder |
| if not os.path.exists(path): |
| |
| from huggingface_hub import hf_hub_download |
| file_name = args.pretrained_encoder_file or "v7-pretrain.pt" |
| path = hf_hub_download(args.pretrained_encoder, file_name) |
| ck = torch.load(path, map_location=device, weights_only=False) |
| miss_t, unx_t = model.tokenizer.load_state_dict( |
| ck["tokenizer_state_dict"], strict=False |
| ) |
| miss_e, unx_e = model.encoder.load_state_dict( |
| ck["encoder_state_dict"], strict=False |
| ) |
| print(f"[v2] loaded pretrained encoder from {path}") |
| print(f" tokenizer missing={len(miss_t)} unexpected={len(unx_t)}") |
| print(f" encoder missing={len(miss_e)} unexpected={len(unx_e)}") |
| |
| model.target_tokenizer.load_state_dict(model.tokenizer.state_dict()) |
| model.target_encoder.load_state_dict(model.encoder.state_dict()) |
| if args.freeze_encoder: |
| for p_ in model.tokenizer.parameters(): |
| p_.requires_grad = False |
| for p_ in model.encoder.parameters(): |
| p_.requires_grad = False |
| print("[v2] tokenizer + encoder FROZEN") |
|
|
| optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, |
| weight_decay=args.weight_decay) |
| scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) |
|
|
| os.makedirs(args.out_dir, exist_ok=True) |
| history = [] |
| best_fde = float("inf") |
| best_state = None |
|
|
| for epoch in range(args.epochs): |
| t0 = time.time() |
| |
| |
| warmup_epochs = max(1, int(args.epochs * args.ss_warmup_frac)) |
| ss_prob = min(args.ss_max, |
| args.ss_max * (epoch + 1) / warmup_epochs) |
| train_stats = train_one_epoch(model, train_dl, optimizer, device, |
| ss_prob=ss_prob) |
| scheduler.step() |
|
|
| score_stats = None |
| if (epoch + 1) % 5 == 0 or epoch == args.epochs - 1: |
| score_stats = score_loader(model, test_dl, device) |
| if score_stats and score_stats["fde_m"] < best_fde: |
| best_fde = score_stats["fde_m"] |
| best_state = {k: v.detach().cpu().clone() |
| for k, v in model.state_dict().items()} |
|
|
| elapsed = time.time() - t0 |
| log = { |
| "epoch": epoch + 1, "elapsed_s": elapsed, |
| "lr": optimizer.param_groups[0]["lr"], |
| "train": train_stats, "score": score_stats, |
| } |
| history.append(log) |
| msg = (f"[v2] ep {epoch+1:03d} | loss={train_stats['total']:.4f} " |
| f"nll={train_stats['nll']:.4f} ade_t={train_stats['ade_train']:.4f} " |
| f"jepa={train_stats['jepa']:.4f} ss={ss_prob:.2f}") |
| if score_stats: |
| msg += f" | fde={score_stats['fde_m']:.0f}m ade={score_stats['ade_m']:.0f}m" |
| msg += f" | {elapsed:.0f}s" |
| print(msg, flush=True) |
|
|
| if HAS_TRACKIO and args.trackio_name: |
| tlog = {f"train/{k}": v for k, v in train_stats.items()} |
| if score_stats: |
| tlog.update({f"test/{k}": v for k, v in score_stats.items() |
| if isinstance(v, (int, float))}) |
| trackio.log(tlog, step=epoch + 1) |
|
|
| final = {"in_distribution": score_loader(model, test_dl, device)} |
| for d in EXTRAP_DELTAS: |
| final[f"extrap_delta_{d}"] = score_loader(model, extrap_dl, device, extrap_delta=d) |
|
|
| if best_state is not None: |
| model.load_state_dict(best_state) |
|
|
| out_path = os.path.join(args.out_dir, f"{args.tag}.pt") |
| torch.save({ |
| "state_dict": model.state_dict(), |
| "config": cfg, "args": vars(args), |
| "history": history, "final": final, |
| "best_fde_m": best_fde, |
| }, out_path) |
| print(f"[v2] saved {out_path}") |
|
|
| summary_path = os.path.join(args.out_dir, f"{args.tag}_summary.json") |
| with open(summary_path, "w") as f: |
| json.dump({ |
| "tag": args.tag, "lambda_jepa": args.lambda_jepa, |
| "beta_nll": args.beta_nll, |
| "n_params": n_params, "best_fde_m": best_fde, |
| "final": final, "args": vars(args), |
| }, f, indent=2, default=float) |
| print(f"[v2] summary -> {summary_path}", flush=True) |
|
|
| if args.push_to_hub and args.hub_model_id: |
| try: |
| from huggingface_hub import HfApi |
| api = HfApi() |
| api.create_repo(args.hub_model_id, exist_ok=True) |
| for path, fname in [(out_path, f"{args.tag}.pt"), |
| (summary_path, f"{args.tag}_summary.json")]: |
| api.upload_file(path_or_fileobj=path, path_in_repo=fname, |
| repo_id=args.hub_model_id) |
| print(f"[v2] uploaded to {args.hub_model_id}") |
| except Exception as e: |
| print(f"[v2] hub upload failed: {e}") |
|
|
| if HAS_TRACKIO and args.trackio_name: |
| trackio.finish() |
|
|
|
|
# Script entry point: call main() only when this file is executed directly;
# a plain import of the module does not trigger it.
if __name__ == "__main__":
    main()
|
|