# /// script # requires-python = ">=3.10" # dependencies = [ # "torch>=2.1", # "numpy", # "pandas", # "scikit-learn", # "huggingface-hub", # "trackio", # ] # /// """ Flight-JEPA v2 — bundled training script for HF Jobs. Self-contained: downloads the dataset from HF, trains either the supervised baseline (`--lambda-jepa 0`) or the JEPA-augmented model, runs blindspot scoring + extrapolation eval, and pushes the result to a hub repo. Usage (HF Jobs): python train_v2_prod.py --tag baseline --lambda-jepa 0.0 \ --hub-model-id guychuk/flight-jepa-v2 --push-to-hub python train_v2_prod.py --tag jepa --lambda-jepa 0.5 \ --hub-model-id guychuk/flight-jepa-v2 --push-to-hub """ from __future__ import annotations import argparse import copy import json import math import os import shutil import sys import time import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader try: import trackio HAS_TRACKIO = True except ImportError: HAS_TRACKIO = False # ============================================================================ # DATA UTILITIES (inlined from flight_jepa.data) # ============================================================================ def load_atfm(dset_name, mode, path): variables = ["X", "Y", "Z"] data, labels = [], None for var in variables: df = pd.read_csv(os.path.join(path, f"{dset_name}_{mode}_{var}.tsv"), sep="\t", header=None, na_values="NaN") if labels is None: labels = df.values[:, 0] data.append(df.values[:, 1:]) return np.stack(data, axis=-1), labels.astype(int) def compute_features(traj_xyz: np.ndarray) -> np.ndarray: if traj_xyz.shape[0] < 2: T = traj_xyz.shape[0] return np.concatenate([ traj_xyz, np.zeros((T, 3), dtype=traj_xyz.dtype), np.zeros((T, 3), dtype=traj_xyz.dtype) ], axis=1) x, y, z = traj_xyz[:, 0], traj_xyz[:, 1], traj_xyz[:, 2] diff = np.diff(traj_xyz, axis=0) norms = np.maximum(np.linalg.norm(diff, axis=1, keepdims=True), 1e-8) u = diff / norms u = np.vstack([u, u[-1:]]) r = np.sqrt(x ** 2 + y ** 2) theta = np.arctan2(y, x) return np.column_stack([ traj_xyz, u, r[:, None], np.sin(theta)[:, None], np.cos(theta)[:, None] ]).astype(np.float32) def ensure_data(airport: str, data_dir: str = "data"): target = os.path.join(data_dir, airport) if os.path.isdir(target) and any(f.endswith(".tsv") for f in os.listdir(target)): return target print(f"[data] downloading {airport} from HF ...") from huggingface_hub import snapshot_download snap = snapshot_download( "petchthwr/ATFMTraj", repo_type="dataset", allow_patterns=[f"{airport}/*"], ) os.makedirs(data_dir, exist_ok=True) src = os.path.join(snap, airport) if not os.path.isdir(target): shutil.copytree(src, target) return target # ============================================================================ # DATASET — variable-length blindspot # ============================================================================ PAD_VALUE = 0.0 class BlindspotDataset(Dataset): def __init__(self, airport, mode, data_dir, past_max=256, past_min=60, delta_min=30, delta_max=120, seed=0, epoch_multiplier=4): ensure_data(airport, data_dir) airport_dir = os.path.join(data_dir, airport) raw, labels = load_atfm(airport, mode, airport_dir) self.past_max = past_max self.past_min = past_min self.delta_min = delta_min self.delta_max = delta_max self.epoch_multiplier = epoch_multiplier self.rng_seed = seed lengths = np.array( [int(np.sum(~np.isnan(raw[i, :, 0]))) for i in range(raw.shape[0])], dtype=np.int64, ) min_required = past_min + delta_max + 1 keep = lengths >= min_required if keep.sum() == 0: raise RuntimeError( f"No trajectories of length >= {min_required} in {airport}/{mode}" ) raw = raw[keep] lengths = lengths[keep] self.labels = labels[keep].astype(np.int64) self.positions = [] for i in range(raw.shape[0]): L = int(lengths[i]) self.positions.append(np.nan_to_num(raw[i, :L], nan=0.0).astype(np.float32)) del raw self.n_traj = len(self.positions) print(f"[data] {airport}/{mode}: {self.n_traj} trajectories " f"(after filtering for L >= {min_required})") def __len__(self): return self.n_traj * self.epoch_multiplier def __getitem__(self, idx): traj_idx = idx % self.n_traj rng = np.random.default_rng(self.rng_seed + idx * 9173) positions = self.positions[traj_idx] L = positions.shape[0] delta = int(rng.integers(self.delta_min, self.delta_max + 1)) t_in_max = L - delta - 1 t_in_min = self.past_min t_in = int(rng.integers(t_in_min, t_in_max + 1)) past_start = max(0, t_in - self.past_max) past_pos = positions[past_start:t_in] target_pos = positions[t_in:t_in + delta] past_features = compute_features(past_pos) T_past = past_features.shape[0] feat_pad = np.full((self.past_max, 9), PAD_VALUE, dtype=np.float32) feat_pad[:T_past] = past_features tgt_pad = np.zeros((self.delta_max, 3), dtype=np.float32) tgt_pad[:delta] = target_pos return { "past_features": torch.from_numpy(feat_pad), "past_length": torch.tensor(T_past, dtype=torch.long), "target_pos": torch.from_numpy(tgt_pad), "delta": torch.tensor(delta, dtype=torch.long), "label": torch.tensor(int(self.labels[traj_idx]), dtype=torch.long), } # ============================================================================ # MODEL # ============================================================================ def sinusoidal_embedding(values, dim): half = dim // 2 device = values.device freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=device) / half) angles = values.float().unsqueeze(-1) * freqs emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) if dim % 2 == 1: emb = F.pad(emb, (0, 1)) return emb class LearnablePosEnc(nn.Module): def __init__(self, max_len, d_model): super().__init__() self.pe = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02) def forward(self, x): return x + self.pe[:, :x.size(1)] class PatchTokenizer(nn.Module): def __init__(self, in_channels=9, d_model=256, patch_size=8, max_patches=64): super().__init__() self.patch_size = patch_size self.d_model = d_model self.embed = nn.Sequential( nn.Conv1d(in_channels, d_model // 2, 5, padding=2), nn.GELU(), nn.Conv1d(d_model // 2, d_model, 3, padding=1), nn.GELU(), ) self.pos_enc = LearnablePosEnc(max_patches, d_model) self.norm = nn.LayerNorm(d_model) def forward(self, features, lengths): B, T, C = features.shape h = self.embed(features.transpose(1, 2)) N = max(1, T // self.patch_size) h = h[:, :, :N * self.patch_size] h = h.reshape(B, self.d_model, N, self.patch_size).mean(-1) h = h.transpose(1, 2) h = self.norm(self.pos_enc(h)) patch_lengths = (lengths.float() / self.patch_size).clamp(min=1).long() patch_lengths = patch_lengths.clamp(max=N) return h, patch_lengths class CausalEncoder(nn.Module): def __init__(self, d_model=256, n_heads=8, n_layers=4, d_ff=1024, dropout=0.1): super().__init__() layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=d_ff, dropout=dropout, activation="gelu", batch_first=True, norm_first=True, ) self.tf = nn.TransformerEncoder(layer, num_layers=n_layers) self.norm = nn.LayerNorm(d_model) def forward(self, x, key_padding_mask): N = x.size(1) causal_mask = torch.triu( torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1 ) return self.norm( self.tf(x, mask=causal_mask, src_key_padding_mask=key_padding_mask) ) def last_valid_token(encoded, patch_lengths): B, N, D = encoded.shape idx = (patch_lengths - 1).clamp(min=0).view(B, 1, 1).expand(-1, 1, D) return encoded.gather(1, idx).squeeze(1) class DeltaEmbedding(nn.Module): def __init__(self, d_model=256, d_freq=64): super().__init__() self.d_freq = d_freq self.proj = nn.Sequential( nn.Linear(d_freq * 2, d_model), nn.GELU(), nn.Linear(d_model, d_model), ) def forward(self, delta, t_past): d_emb = sinusoidal_embedding(delta.float(), self.d_freq) rel = delta.float() / t_past.float().clamp(min=1.0) rel_emb = sinusoidal_embedding(rel * 100.0, self.d_freq) return self.proj(torch.cat([d_emb, rel_emb], dim=-1)) class GaussianHead(nn.Module): def __init__(self, d_model=256, d_hidden=256): super().__init__() self.net = nn.Sequential( nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_hidden), nn.GELU(), ) self.mu_head = nn.Linear(d_hidden, 3) self.log_sigma_head = nn.Linear(d_hidden, 3) self.rho_head = nn.Linear(d_hidden, 1) def forward(self, h): z = self.net(h) delta_mu = self.mu_head(z) log_sigma = self.log_sigma_head(z).clamp(min=-7.0, max=2.0) rho = torch.tanh(self.rho_head(z)).squeeze(-1) * 0.99 return delta_mu, log_sigma, rho def gaussian_nll_xyz(true_delta, mu, log_sigma, rho, beta: float = 0.5): """ β-NLL Gaussian for (x, y, z) — bivariate on xy + independent z. Standard NLL has a degenerate minimum where σ→0 ("σ-collapse", Detlefsen 2019). β-NLL (Seitzer et al., arxiv:2203.09168) reweights each sample's NLL by σ^{2β} (detached) so points with large σ get proportionally more gradient on the mean term, preventing collapse. β = 0 → standard NLL (collapse-prone, what v2 used) β = 0.5 → recommended; preserves uncertainty learning β = 1 → pure squared-error scaling (loses σ learning) """ sx = log_sigma[:, 0].exp() sy = log_sigma[:, 1].exp() sz = log_sigma[:, 2].exp() dx = true_delta[:, 0] - mu[:, 0] dy = true_delta[:, 1] - mu[:, 1] dz = true_delta[:, 2] - mu[:, 2] omr2 = (1.0 - rho * rho).clamp(min=1e-6) z2 = (((dx / sx) ** 2) - 2.0 * rho * (dx / sx) * (dy / sy) + ((dy / sy) ** 2)) / omr2 log_det = 2.0 * (log_sigma[:, 0] + log_sigma[:, 1]) + torch.log(omr2) nll_xy = 0.5 * (z2 + log_det + 2.0 * math.log(2.0 * math.pi)) nll_z = 0.5 * ((dz / sz) ** 2 + 2.0 * log_sigma[:, 2] + math.log(2.0 * math.pi)) if beta > 0.0: # Detached per-sample weights: σ^{2β}. Weight is treated as constant # during backward, so it rescales the gradient without participating # in optimization. # For xy use geometric-mean σ; for z use σz directly. sxy = (sx * sy).sqrt().detach() wxy = sxy.pow(2.0 * beta) wz = sz.detach().pow(2.0 * beta) return wxy * nll_xy + wz * nll_z return nll_xy + nll_z class FuturePredictor(nn.Module): def __init__(self, d_model=256, pred_dim=128, dropout=0.1): super().__init__() self.proj_in = nn.Linear(d_model * 2, pred_dim) layer = nn.TransformerEncoderLayer( d_model=pred_dim, nhead=4, dim_feedforward=pred_dim * 2, dropout=dropout, activation="gelu", batch_first=True, norm_first=True, ) self.tf = nn.TransformerEncoder(layer, num_layers=2) self.proj_out = nn.Linear(pred_dim, d_model) self.norm = nn.LayerNorm(d_model) def forward(self, z_in, delta_emb): h = self.proj_in(torch.cat([z_in, delta_emb], dim=-1)).unsqueeze(1) h = self.tf(h) return self.norm(self.proj_out(h.squeeze(1))) class FlightJEPAv2(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg d = cfg.get("d_model", 256) h_ = cfg.get("n_heads", 8) n_l = cfg.get("n_layers", 4) d_ff = cfg.get("d_ff", 1024) dr = cfg.get("dropout", 0.1) ps = cfg.get("patch_size", 8) past_max = cfg.get("past_max", 256) max_patches = past_max // ps self.lambda_jepa = cfg.get("lambda_jepa", 0.0) self.ema_decay = cfg.get("ema_decay", 0.998) self.beta_nll = cfg.get("beta_nll", 0.5) self.tokenizer = PatchTokenizer(9, d, ps, max_patches) self.encoder = CausalEncoder(d, h_, n_l, d_ff, dr) self.delta_emb = DeltaEmbedding(d, 64) self.head = GaussianHead(d, d) self.fuse_in = nn.Sequential( nn.Linear(d * 2, d), nn.GELU(), nn.Linear(d, d), ) self.step_cell = nn.GRUCell(input_size=3, hidden_size=d) self.target_tokenizer = copy.deepcopy(self.tokenizer) self.target_encoder = copy.deepcopy(self.encoder) for p in self.target_tokenizer.parameters(): p.requires_grad = False for p in self.target_encoder.parameters(): p.requires_grad = False self.predictor = FuturePredictor(d, d // 2, dr) @torch.no_grad() def update_ema(self): m = self.ema_decay for online, target in [(self.tokenizer, self.target_tokenizer), (self.encoder, self.target_encoder)]: for po, pt in zip(online.parameters(), target.parameters()): pt.data.mul_(m).add_(po.data, alpha=1.0 - m) def encode_past(self, past_features, past_length): patches, patch_lens = self.tokenizer(past_features, past_length) N = patches.size(1) pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0) >= patch_lens.unsqueeze(1)) encoded = self.encoder(patches, key_padding_mask=pad_mask) z_in = last_valid_token(encoded, patch_lens) return z_in, encoded, patch_lens @torch.no_grad() def encode_future_target(self, target_features, target_length): patches, patch_lens = self.target_tokenizer(target_features, target_length) N = patches.size(1) pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0) >= patch_lens.unsqueeze(1)) encoded = self.target_encoder(patches, key_padding_mask=pad_mask) return last_valid_token(encoded, patch_lens) def forward(self, past_features, past_length, target_pos, delta, last_pos, ss_prob: float = 0.0): """ ss_prob: scheduled-sampling probability ∈ [0, 1]. With this probability per (batch element, timestep), the *predicted* delta replaces the *true* delta in the recurrence. NLL loss is always against truth — only the GRU input + prev_pos accumulator are mixed. """ B = past_features.size(0) device = past_features.device delta_max = target_pos.size(1) z_in, _, _ = self.encode_past(past_features, past_length) delta_e = self.delta_emb(delta, past_length) h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1)) prev_pos = last_pos nll_total = torch.zeros(B, device=device) valid_steps = torch.zeros(B, device=device) ade_total = torch.zeros(B, device=device) for t in range(delta_max): delta_mu, log_sigma, rho = self.head(h) true_pos_t = target_pos[:, t] true_delta = true_pos_t - prev_pos # NLL computed always vs truth. nll = gaussian_nll_xyz(true_delta, delta_mu, log_sigma, rho, beta=self.beta_nll) mask = (t < delta).float() nll_total = nll_total + nll * mask ade_total = (ade_total + (true_delta - delta_mu).pow(2).sum(-1).sqrt() * mask) valid_steps = valid_steps + mask # Scheduled-sampling: with prob ss_prob, feed predicted delta instead # of true delta into the recurrence. Sampled per (batch, step). if ss_prob > 0.0 and self.training: use_pred = (torch.rand(B, device=device) < ss_prob).float().unsqueeze(-1) # Use predicted mean as "what we would do at inference time". # Detach so the prev_pos accumulator gradient doesn't recurse. fed_delta = use_pred * delta_mu.detach() + (1 - use_pred) * true_delta fed_pos = use_pred * (prev_pos + delta_mu.detach()) + (1 - use_pred) * true_pos_t else: fed_delta = true_delta fed_pos = true_pos_t h = self.step_cell(fed_delta, h) prev_pos = fed_pos nll_loss = (nll_total / valid_steps.clamp(min=1.0)).mean() ade_train = (ade_total / valid_steps.clamp(min=1.0)).mean().detach() losses = {"nll": nll_loss, "ade_train": ade_train, "total": nll_loss} if self.lambda_jepa > 0.0: tgt_feat = torch.zeros(B, delta_max, 9, device=device) tgt_feat[..., :3] = target_pos z_target = self.encode_future_target(tgt_feat, delta) z_pred = self.predictor(z_in, delta_e) jepa_loss = F.l1_loss(z_pred, z_target.detach()) losses["jepa"] = jepa_loss losses["total"] = nll_loss + self.lambda_jepa * jepa_loss return losses @torch.no_grad() def rollout(self, past_features, past_length, delta, last_pos, delta_max): B = past_features.size(0) device = past_features.device z_in, _, _ = self.encode_past(past_features, past_length) delta_e = self.delta_emb(delta, past_length) h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1)) prev_pos = last_pos mu_pos = torch.zeros(B, delta_max, 3, device=device) sigma = torch.zeros(B, delta_max, 3, device=device) rho_out = torch.zeros(B, delta_max, device=device) for t in range(delta_max): delta_mu, log_sigma, rho = self.head(h) cur_pos = prev_pos + delta_mu mu_pos[:, t] = cur_pos sigma[:, t] = log_sigma.exp() rho_out[:, t] = rho h = self.step_cell(delta_mu, h) prev_pos = cur_pos return mu_pos, sigma, rho_out # ============================================================================ # TRAIN + SCORE # ============================================================================ RMAX_KM = 120.0 DELTA_BUCKETS = [(30, 60), (60, 90), (90, 120)] EXTRAP_DELTAS = [180, 300] THRESH_M = [500.0, 1000.0, 2000.0] def get_last_pos(past_features, past_length): B = past_features.size(0) idx = (past_length - 1).clamp(min=0) return past_features[torch.arange(B, device=past_features.device), idx, :3] def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0, log_every: int = 50, ss_prob: float = 0.0): model.train() sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0} t0 = time.time() n_batches = len(loader) if hasattr(loader, "__len__") else 0 for bi, batch in enumerate(loader): past_f = batch["past_features"].to(device) past_l = batch["past_length"].to(device) target = batch["target_pos"].to(device) delta = batch["delta"].to(device) last_pos = get_last_pos(past_f, past_l) losses = model(past_f, past_l, target, delta, last_pos, ss_prob=ss_prob) optimizer.zero_grad() losses["total"].backward() torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() if model.lambda_jepa > 0.0: model.update_ema() bs = past_f.size(0) sums["nll"] += losses["nll"].item() * bs sums["ade"] += losses["ade_train"].item() * bs if "jepa" in losses: sums["jepa"] += losses["jepa"].item() * bs sums["total"] += losses["total"].item() * bs sums["n"] += bs if (bi + 1) % log_every == 0 or bi == 0: dt = time.time() - t0 rate = (bi + 1) / max(dt, 0.001) print(f" [batch {bi+1}/{n_batches}] {dt:.1f}s elapsed, " f"{rate:.1f} batch/s, loss={losses['total'].item():.4f}", flush=True) n = max(sums["n"], 1) return {k: v / n for k, v in sums.items() if k != "n"} | { "ade_train": sums["ade"] / n } @torch.no_grad() def score_loader(model, loader, device, extrap_delta=None): model.train(False) delta_max_dataset = loader.dataset.delta_max per_sample = [] for batch in loader: past_f = batch["past_features"].to(device) past_l = batch["past_length"].to(device) target = batch["target_pos"].to(device) delta = batch["delta"].to(device) last_pos = get_last_pos(past_f, past_l) if extrap_delta is not None: forced = torch.full_like(delta, extrap_delta) roll_len = extrap_delta else: forced = delta roll_len = int(delta.max().item()) if roll_len > delta_max_dataset: continue mu_pos, sigma, rho = model.rollout(past_f, past_l, forced, last_pos, roll_len) active_len = torch.minimum(forced, delta).clamp(min=1) for i in range(past_f.size(0)): L = int(active_len[i].item()) per_sample.append({ "mu": mu_pos[i, :L].cpu().numpy(), "sigma": sigma[i, :L].cpu().numpy(), "rho": rho[i, :L].cpu().numpy(), "target": target[i, :L].cpu().numpy(), "delta_orig": int(delta[i].item()), }) if not per_sample: return {} ades, fdes = [], [] in_circle = {t: [] for t in THRESH_M} nlls, coverage95, delta_orig = [], [], [] for s in per_sample: diff = s["target"] - s["mu"] per_step_l2 = np.linalg.norm(diff, axis=1) * RMAX_KM * 1000.0 ades.append(per_step_l2.mean()) fdes.append(per_step_l2[-1]) for t in THRESH_M: in_circle[t].append(per_step_l2[-1] <= t) sx = max(s["sigma"][-1, 0], 1e-9) sy = max(s["sigma"][-1, 1], 1e-9) sz = max(s["sigma"][-1, 2], 1e-9) rho_xy = s["rho"][-1] dx = diff[-1, 0]; dy = diff[-1, 1]; dz = diff[-1, 2] omr2 = max(1.0 - rho_xy * rho_xy, 1e-6) z2 = ((dx / sx) ** 2 - 2 * rho_xy * (dx / sx) * (dy / sy) + (dy / sy) ** 2) / omr2 coverage95.append(z2 <= 5.991) log_det = 2 * (math.log(sx) + math.log(sy)) + math.log(omr2) nll_xy = 0.5 * (z2 + log_det + 2 * math.log(2 * math.pi)) nll_z = 0.5 * ((dz / sz) ** 2 + 2 * math.log(sz) + math.log(2 * math.pi)) nlls.append(nll_xy + nll_z) delta_orig.append(s["delta_orig"]) ades = np.array(ades); fdes = np.array(fdes) nlls = np.array(nlls); coverage95 = np.array(coverage95, dtype=float) delta_orig = np.array(delta_orig) out = { "ade_m": float(ades.mean()), "fde_m": float(fdes.mean()), "fde_median_m": float(np.median(fdes)), "nll_xy_z": float(nlls.mean()), "coverage_95": float(coverage95.mean()), "n": len(ades), } for t in THRESH_M: out[f"miss_rate_{int(t)}m"] = float(1.0 - np.mean(in_circle[t])) if extrap_delta is None: per_bucket = {} for lo, hi in DELTA_BUCKETS: mask = (delta_orig >= lo) & (delta_orig <= hi) if mask.sum() == 0: continue per_bucket[f"delta_{lo}_{hi}"] = { "ade_m": float(ades[mask].mean()), "fde_m": float(fdes[mask].mean()), "coverage_95": float(coverage95[mask].mean()), "n": int(mask.sum()), } out["per_bucket"] = per_bucket return out def main(): p = argparse.ArgumentParser() p.add_argument("--airport", default="RKSIa") p.add_argument("--data-dir", default="data") p.add_argument("--tag", default="run") p.add_argument("--out-dir", default="runs") p.add_argument("--epochs", type=int, default=30) p.add_argument("--batch-size", type=int, default=64) p.add_argument("--lr", type=float, default=1e-4) p.add_argument("--weight-decay", type=float, default=1e-4) p.add_argument("--past-max", type=int, default=256) p.add_argument("--past-min", type=int, default=60) p.add_argument("--delta-min", type=int, default=30) p.add_argument("--delta-max", type=int, default=120) p.add_argument("--extrap-delta-max", type=int, default=300) p.add_argument("--epoch-multiplier", type=int, default=4) p.add_argument("--lambda-jepa", type=float, default=0.0) p.add_argument("--ema-decay", type=float, default=0.998) p.add_argument("--beta-nll", type=float, default=0.5, help="β-NLL exponent (Seitzer 2022). 0=plain NLL, 0.5=recommended.") p.add_argument("--ss-max", type=float, default=0.0, help="Max scheduled-sampling probability (0=teacher-forcing only, 0.5=Bengio recommended).") p.add_argument("--ss-warmup-frac", type=float, default=0.5, help="Fraction of training over which ss_prob ramps from 0 to ss_max linearly.") p.add_argument("--d-model", type=int, default=256) p.add_argument("--n-layers", type=int, default=4) p.add_argument("--n-heads", type=int, default=8) p.add_argument("--patch-size", type=int, default=8) p.add_argument("--seed", type=int, default=0) p.add_argument("--num-workers", type=int, default=2) p.add_argument("--push-to-hub", action="store_true") p.add_argument("--hub-model-id", default=None) p.add_argument("--trackio-name", default=None) args = p.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"[v2] device={device} tag={args.tag} " f"lambda_jepa={args.lambda_jepa} beta_nll={args.beta_nll} " f"ss_max={args.ss_max} ss_warmup_frac={args.ss_warmup_frac}", flush=True) if device == "cuda": print(f"[v2] cuda device: {torch.cuda.get_device_name(0)} " f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB", flush=True) else: print("[v2] WARNING: CUDA not available, training on CPU. " "This will be very slow.", flush=True) if HAS_TRACKIO and args.trackio_name: trackio.init(project="flight-jepa-v2", name=args.trackio_name, config=vars(args)) train_ds = BlindspotDataset( airport=args.airport, mode="TRAIN", data_dir=args.data_dir, past_max=args.past_max, past_min=args.past_min, delta_min=args.delta_min, delta_max=args.delta_max, seed=args.seed, epoch_multiplier=args.epoch_multiplier, ) test_ds = BlindspotDataset( airport=args.airport, mode="TEST", data_dir=args.data_dir, past_max=args.past_max, past_min=args.past_min, delta_min=args.delta_min, delta_max=args.delta_max, seed=args.seed + 1, epoch_multiplier=1, ) extrap_ds = BlindspotDataset( airport=args.airport, mode="TEST", data_dir=args.data_dir, past_max=args.past_max, past_min=args.past_min, delta_min=args.delta_min, delta_max=args.extrap_delta_max, seed=args.seed + 99, epoch_multiplier=1, ) train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True) test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) extrap_dl = DataLoader(extrap_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) cfg = { "d_model": args.d_model, "n_heads": args.n_heads, "n_layers": args.n_layers, "d_ff": args.d_model * 4, "dropout": 0.1, "patch_size": args.patch_size, "past_max": args.past_max, "lambda_jepa": args.lambda_jepa, "ema_decay": args.ema_decay, "beta_nll": args.beta_nll, } model = FlightJEPAv2(cfg).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"[v2] params={n_params/1e6:.2f}M") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) os.makedirs(args.out_dir, exist_ok=True) history = [] best_fde = float("inf") best_state = None for epoch in range(args.epochs): t0 = time.time() # Linear ramp ss_prob: 0 → ss_max over args.ss_warmup_frac of training, # then hold at ss_max. warmup_epochs = max(1, int(args.epochs * args.ss_warmup_frac)) ss_prob = min(args.ss_max, args.ss_max * (epoch + 1) / warmup_epochs) train_stats = train_one_epoch(model, train_dl, optimizer, device, ss_prob=ss_prob) scheduler.step() score_stats = None if (epoch + 1) % 5 == 0 or epoch == args.epochs - 1: score_stats = score_loader(model, test_dl, device) if score_stats and score_stats["fde_m"] < best_fde: best_fde = score_stats["fde_m"] best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} elapsed = time.time() - t0 log = { "epoch": epoch + 1, "elapsed_s": elapsed, "lr": optimizer.param_groups[0]["lr"], "train": train_stats, "score": score_stats, } history.append(log) msg = (f"[v2] ep {epoch+1:03d} | loss={train_stats['total']:.4f} " f"nll={train_stats['nll']:.4f} ade_t={train_stats['ade_train']:.4f} " f"jepa={train_stats['jepa']:.4f} ss={ss_prob:.2f}") if score_stats: msg += f" | fde={score_stats['fde_m']:.0f}m ade={score_stats['ade_m']:.0f}m" msg += f" | {elapsed:.0f}s" print(msg, flush=True) if HAS_TRACKIO and args.trackio_name: tlog = {f"train/{k}": v for k, v in train_stats.items()} if score_stats: tlog.update({f"test/{k}": v for k, v in score_stats.items() if isinstance(v, (int, float))}) trackio.log(tlog, step=epoch + 1) final = {"in_distribution": score_loader(model, test_dl, device)} for d in EXTRAP_DELTAS: final[f"extrap_delta_{d}"] = score_loader(model, extrap_dl, device, extrap_delta=d) if best_state is not None: model.load_state_dict(best_state) out_path = os.path.join(args.out_dir, f"{args.tag}.pt") torch.save({ "state_dict": model.state_dict(), "config": cfg, "args": vars(args), "history": history, "final": final, "best_fde_m": best_fde, }, out_path) print(f"[v2] saved {out_path}") summary_path = os.path.join(args.out_dir, f"{args.tag}_summary.json") with open(summary_path, "w") as f: json.dump({ "tag": args.tag, "lambda_jepa": args.lambda_jepa, "beta_nll": args.beta_nll, "n_params": n_params, "best_fde_m": best_fde, "final": final, "args": vars(args), }, f, indent=2, default=float) print(f"[v2] summary -> {summary_path}", flush=True) if args.push_to_hub and args.hub_model_id: try: from huggingface_hub import HfApi api = HfApi() api.create_repo(args.hub_model_id, exist_ok=True) for path, fname in [(out_path, f"{args.tag}.pt"), (summary_path, f"{args.tag}_summary.json")]: api.upload_file(path_or_fileobj=path, path_in_repo=fname, repo_id=args.hub_model_id) print(f"[v2] uploaded to {args.hub_model_id}") except Exception as e: print(f"[v2] hub upload failed: {e}") if HAS_TRACKIO and args.trackio_name: trackio.finish() if __name__ == "__main__": main()