# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.1",
#     "numpy",
#     "pandas",
#     "scikit-learn",
#     "huggingface-hub",
#     "trackio",
# ]
# ///
"""
Flight-JEPA v2 — bundled training script for HF Jobs.

Self-contained: downloads the dataset from HF, trains either the supervised
baseline (`--lambda-jepa 0`) or the JEPA-augmented model, runs blindspot
scoring + extrapolation eval, and pushes the result to a hub repo.

Usage (HF Jobs):
    python train_v2_prod.py --tag baseline --lambda-jepa 0.0 \
        --hub-model-id guychuk/flight-jepa-v2 --push-to-hub
    python train_v2_prod.py --tag jepa --lambda-jepa 0.5 \
        --hub-model-id guychuk/flight-jepa-v2 --push-to-hub
"""
from __future__ import annotations

import argparse
import copy
import json
import math
import os
import shutil
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# trackio is optional: logging to it is skipped when it is not installed.
try:
    import trackio
    HAS_TRACKIO = True
except ImportError:
    HAS_TRACKIO = False


# ============================================================================
# DATA UTILITIES (inlined from flight_jepa.data)
# ============================================================================

def load_atfm(dset_name, mode, path):
    """Load one split of a dataset from its per-axis TSV files.

    Reads ``{dset_name}_{mode}_X.tsv``, ``..._Y.tsv`` and ``..._Z.tsv`` from
    *path* (tab-separated, no header, the string ``NaN`` parsed as missing).
    In every file column 0 is treated as the label and the remaining columns
    as the series values.

    Returns:
        ``(data, labels)`` where ``data`` stacks the three axes on the last
        dimension and ``labels`` is column 0 of the X file cast to ``int``.
    """
    variables = ["X", "Y", "Z"]
    data, labels = [], None
    for var in variables:
        df = pd.read_csv(
            os.path.join(path, f"{dset_name}_{mode}_{var}.tsv"),
            sep="\t", header=None, na_values="NaN",
        )
        if labels is None:
            labels = df.values[:, 0]
        data.append(df.values[:, 1:])
    return np.stack(data, axis=-1), labels.astype(int)


def compute_features(traj_xyz: np.ndarray) -> np.ndarray:
    """Build the 9-channel per-step feature matrix from ``(T, 3)`` positions.

    Channels: ``x, y, z`` | unit step direction (3) | horizontal radius
    ``sqrt(x^2 + y^2)`` | ``sin`` and ``cos`` of ``arctan2(y, x)``.

    The direction of the last step is a copy of the previous one, because
    ``np.diff`` yields only ``T - 1`` steps. For fewer than two points the
    six derived channels are returned as zeros (in the input dtype);
    otherwise the result is ``float32``.
    """
    if traj_xyz.shape[0] < 2:
        T = traj_xyz.shape[0]
        return np.concatenate([
            traj_xyz,
            np.zeros((T, 3), dtype=traj_xyz.dtype),
            np.zeros((T, 3), dtype=traj_xyz.dtype),
        ], axis=1)
    x, y, z = traj_xyz[:, 0], traj_xyz[:, 1], traj_xyz[:, 2]
    diff = np.diff(traj_xyz, axis=0)
    # Floor the norm so stationary steps do not divide by zero.
    norms = np.maximum(np.linalg.norm(diff, axis=1, keepdims=True), 1e-8)
    u = diff / norms
    u = np.vstack([u, u[-1:]])
    r = np.sqrt(x ** 2 + y ** 2)
    theta = np.arctan2(y, x)
    return np.column_stack([
        traj_xyz, u, r[:, None],
        np.sin(theta)[:, None], np.cos(theta)[:, None],
    ]).astype(np.float32)


def ensure_data(airport: str, data_dir: str = "data"):
    """Make sure ``data_dir/airport`` holds the TSV files; download if not.

    If the directory already contains at least one ``.tsv`` file it is
    returned untouched. Otherwise the ``{airport}/*`` files of the
    ``petchthwr/ATFMTraj`` dataset repo are fetched with
    ``huggingface_hub.snapshot_download`` and copied into place.

    Returns:
        The path ``data_dir/airport``.
    """
    target = os.path.join(data_dir, airport)
    if os.path.isdir(target) and any(f.endswith(".tsv") for f in os.listdir(target)):
        return target
    print(f"[data] downloading {airport} from HF ...")
    from huggingface_hub import snapshot_download
    snap = snapshot_download(
        "petchthwr/ATFMTraj", repo_type="dataset",
        allow_patterns=[f"{airport}/*"],
    )
    os.makedirs(data_dir, exist_ok=True)
    src = os.path.join(snap, airport)
    # dirs_exist_ok: a pre-existing but empty/partial target directory must
    # still be populated, otherwise the subsequent load fails on missing files.
    shutil.copytree(src, target, dirs_exist_ok=True)
    return target


# ============================================================================
# DATASET — variable-length blindspot
# ============================================================================

PAD_VALUE = 0.0


class BlindspotDataset(Dataset):
    """Random (past window, future horizon) samples cut from trajectories.

    Each item picks a horizon ``delta`` in ``[delta_min, delta_max]`` and a
    cut index ``t_in``; the model sees at most ``past_max`` steps before
    ``t_in`` and must predict the ``delta`` positions from ``t_in`` on.
    Trajectories shorter than ``past_min + delta_max + 1`` valid steps are
    dropped so every (delta, t_in) draw is feasible.

    Sampling is seeded per index (``seed + idx * 9173``), so an index always
    yields the same sample; ``epoch_multiplier`` repeats every trajectory
    that many times per epoch with different cuts.
    """

    def __init__(self, airport, mode, data_dir, past_max=256, past_min=60,
                 delta_min=30, delta_max=120, seed=0, epoch_multiplier=4):
        ensure_data(airport, data_dir)
        airport_dir = os.path.join(data_dir, airport)
        raw, labels = load_atfm(airport, mode, airport_dir)
        self.past_max = past_max
        self.past_min = past_min
        self.delta_min = delta_min
        self.delta_max = delta_max
        self.epoch_multiplier = epoch_multiplier
        self.rng_seed = seed

        # Valid length of a trajectory = number of non-NaN x samples.
        lengths = np.sum(~np.isnan(raw[:, :, 0]), axis=1).astype(np.int64)
        min_required = past_min + delta_max + 1
        keep = lengths >= min_required
        if keep.sum() == 0:
            raise RuntimeError(
                f"No trajectories of length >= {min_required} in {airport}/{mode}"
            )
        raw = raw[keep]
        lengths = lengths[keep]
        self.labels = labels[keep].astype(np.int64)
        # Store each trajectory trimmed to its own length (first L steps).
        self.positions = [
            np.nan_to_num(traj[:int(L)], nan=0.0).astype(np.float32)
            for traj, L in zip(raw, lengths)
        ]
        del raw
        self.n_traj = len(self.positions)
        print(f"[data] {airport}/{mode}: {self.n_traj} trajectories "
              f"(after filtering for L >= {min_required})")

    def __len__(self):
        return self.n_traj * self.epoch_multiplier

    def __getitem__(self, idx):
        """Return one padded sample as a dict of tensors.

        Keys: ``past_features`` (past_max, 9), ``past_length`` (scalar),
        ``target_pos`` (delta_max, 3), ``delta`` (scalar), ``label`` (scalar).
        Unused trailing rows are filled with ``PAD_VALUE`` / zeros.
        """
        traj_idx = idx % self.n_traj
        rng = np.random.default_rng(self.rng_seed + idx * 9173)
        positions = self.positions[traj_idx]
        L = positions.shape[0]

        delta = int(rng.integers(self.delta_min, self.delta_max + 1))
        # Leave room for `delta` future steps after the cut.
        t_in_max = L - delta - 1
        t_in_min = self.past_min
        t_in = int(rng.integers(t_in_min, t_in_max + 1))

        past_start = max(0, t_in - self.past_max)
        past_pos = positions[past_start:t_in]
        target_pos = positions[t_in:t_in + delta]

        past_features = compute_features(past_pos)
        T_past = past_features.shape[0]
        feat_pad = np.full((self.past_max, 9), PAD_VALUE, dtype=np.float32)
        feat_pad[:T_past] = past_features
        tgt_pad = np.zeros((self.delta_max, 3), dtype=np.float32)
        tgt_pad[:delta] = target_pos

        return {
            "past_features": torch.from_numpy(feat_pad),
            "past_length": torch.tensor(T_past, dtype=torch.long),
            "target_pos": torch.from_numpy(tgt_pad),
            "delta": torch.tensor(delta, dtype=torch.long),
            "label": torch.tensor(int(self.labels[traj_idx]), dtype=torch.long),
        }


# ============================================================================
# MODEL
# ============================================================================

def sinusoidal_embedding(values, dim):
    """Embed scalars as ``[sin(v * f_k), cos(v * f_k)]`` over ``dim // 2``
    geometrically spaced frequencies; odd ``dim`` is zero-padded by one.

    Returns a tensor of shape ``values.shape + (dim,)``.
    """
    half = dim // 2
    device = values.device
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=device) / half)
    angles = values.float().unsqueeze(-1) * freqs
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb


class LearnablePosEnc(nn.Module):
    """Additive learned positional encoding for up to ``max_len`` tokens."""

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class PatchTokenizer(nn.Module):
    """Turn a ``(B, T, C)`` feature sequence into ``T // patch_size`` tokens.

    Two 1-D convolutions embed every step, then steps are mean-pooled in
    non-overlapping patches, position-encoded and layer-normalised.
    """

    def __init__(self, in_channels=9, d_model=256, patch_size=8, max_patches=64):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        self.embed = nn.Sequential(
            nn.Conv1d(in_channels, d_model // 2, 5, padding=2), nn.GELU(),
            nn.Conv1d(d_model // 2, d_model, 3, padding=1), nn.GELU(),
        )
        self.pos_enc = LearnablePosEnc(max_patches, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, features, lengths):
        """Return ``(tokens, patch_lengths)``.

        ``tokens`` is ``(B, N, d_model)`` with ``N = max(1, T // patch_size)``;
        ``patch_lengths`` is the per-sample count of patches fully covered
        by valid steps, clamped to ``[1, N]``.
        """
        B, T, C = features.shape
        h = self.embed(features.transpose(1, 2))
        N = max(1, T // self.patch_size)
        # Drop the trailing remainder that does not fill a whole patch.
        h = h[:, :, :N * self.patch_size]
        h = h.reshape(B, self.d_model, N, self.patch_size).mean(-1)
        h = h.transpose(1, 2)
        h = self.norm(self.pos_enc(h))
        patch_lengths = (lengths.float() / self.patch_size).clamp(min=1).long()
        patch_lengths = patch_lengths.clamp(max=N)
        return h, patch_lengths


class CausalEncoder(nn.Module):
    """Pre-norm Transformer encoder with a causal (upper-triangular) mask."""

    def __init__(self, d_model=256, n_heads=8, n_layers=4, d_ff=1024, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.tf = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask):
        """Encode ``x`` (B, N, D); ``key_padding_mask`` is True at padding."""
        N = x.size(1)
        # True above the diagonal = token i may not attend to tokens j > i.
        causal_mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
        )
        return self.norm(
            self.tf(x, mask=causal_mask, src_key_padding_mask=key_padding_mask)
        )


def last_valid_token(encoded, patch_lengths):
    """Pick, per sample, the token at index ``patch_lengths - 1`` -> (B, D)."""
    B, N, D = encoded.shape
    idx = (patch_lengths - 1).clamp(min=0).view(B, 1, 1).expand(-1, 1, D)
    return encoded.gather(1, idx).squeeze(1)


class DeltaEmbedding(nn.Module):
    """Embed the prediction horizon, both absolute and relative to the
    observed past length, into a ``d_model`` vector."""

    def __init__(self, d_model=256, d_freq=64):
        super().__init__()
        self.d_freq = d_freq
        self.proj = nn.Sequential(
            nn.Linear(d_freq * 2, d_model), nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, delta, t_past):
        d_emb = sinusoidal_embedding(delta.float(), self.d_freq)
        rel = delta.float() / t_past.float().clamp(min=1.0)
        # The ratio is scaled by 100 before the sinusoidal embedding.
        rel_emb = sinusoidal_embedding(rel * 100.0, self.d_freq)
        return self.proj(torch.cat([d_emb, rel_emb], dim=-1))


class GaussianHead(nn.Module):
    """Predict a per-step displacement distribution from the hidden state.

    Outputs the mean displacement (3), log standard deviations (3, clamped
    to ``[-7, 2]``) and an x/y correlation in ``(-0.99, 0.99)``.
    """

    def __init__(self, d_model=256, d_hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.GELU(),
            nn.Linear(d_hidden, d_hidden), nn.GELU(),
        )
        self.mu_head = nn.Linear(d_hidden, 3)
        self.log_sigma_head = nn.Linear(d_hidden, 3)
        self.rho_head = nn.Linear(d_hidden, 1)

    def forward(self, h):
        z = self.net(h)
        delta_mu = self.mu_head(z)
        log_sigma = self.log_sigma_head(z).clamp(min=-7.0, max=2.0)
        # Scale tanh so |rho| stays strictly below 1.
        rho = torch.tanh(self.rho_head(z)).squeeze(-1) * 0.99
        return delta_mu, log_sigma, rho


def gaussian_nll_xyz(true_delta, mu, log_sigma, rho):
    """Negative log-likelihood of ``true_delta`` (B, 3), one value per sample.

    x and y are modelled as a correlated bivariate Gaussian (correlation
    ``rho``); z is an independent univariate Gaussian. The two terms are
    summed.
    """
    sx = log_sigma[:, 0].exp()
    sy = log_sigma[:, 1].exp()
    sz = log_sigma[:, 2].exp()
    dx = true_delta[:, 0] - mu[:, 0]
    dy = true_delta[:, 1] - mu[:, 1]
    dz = true_delta[:, 2] - mu[:, 2]
    omr2 = (1.0 - rho * rho).clamp(min=1e-6)  # 1 - rho^2, kept positive
    z2 = (((dx / sx) ** 2)
          - 2.0 * rho * (dx / sx) * (dy / sy)
          + ((dy / sy) ** 2)) / omr2
    log_det = 2.0 * (log_sigma[:, 0] + log_sigma[:, 1]) + torch.log(omr2)
    nll_xy = 0.5 * (z2 + log_det + 2.0 * math.log(2.0 * math.pi))
    nll_z = 0.5 * ((dz / sz) ** 2 + 2.0 * log_sigma[:, 2] + math.log(2.0 * math.pi))
    return nll_xy + nll_z


class FuturePredictor(nn.Module):
    """JEPA predictor: maps (past latent, horizon embedding) to a predicted
    future latent of size ``d_model`` via a small single-token Transformer."""

    def __init__(self, d_model=256, pred_dim=128, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_model * 2, pred_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=pred_dim, nhead=4, dim_feedforward=pred_dim * 2,
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True,
        )
        self.tf = nn.TransformerEncoder(layer, num_layers=2)
        self.proj_out = nn.Linear(pred_dim, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, z_in, delta_emb):
        h = self.proj_in(torch.cat([z_in, delta_emb], dim=-1)).unsqueeze(1)
        h = self.tf(h)
        return self.norm(self.proj_out(h.squeeze(1)))


class FlightJEPAv2(nn.Module):
    """Causal past encoder + GRU roll-out head, with an optional JEPA loss.

    ``cfg`` keys (all optional): ``d_model``, ``n_heads``, ``n_layers``,
    ``d_ff``, ``dropout``, ``patch_size``, ``past_max``, ``lambda_jepa``,
    ``ema_decay``. The target tokenizer/encoder are frozen copies of the
    online ones and are moved only by :meth:`update_ema`.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d = cfg.get("d_model", 256)
        h_ = cfg.get("n_heads", 8)
        n_l = cfg.get("n_layers", 4)
        d_ff = cfg.get("d_ff", 1024)
        dr = cfg.get("dropout", 0.1)
        ps = cfg.get("patch_size", 8)
        past_max = cfg.get("past_max", 256)
        max_patches = past_max // ps
        self.lambda_jepa = cfg.get("lambda_jepa", 0.0)
        self.ema_decay = cfg.get("ema_decay", 0.998)

        self.tokenizer = PatchTokenizer(9, d, ps, max_patches)
        self.encoder = CausalEncoder(d, h_, n_l, d_ff, dr)
        self.delta_emb = DeltaEmbedding(d, 64)
        self.head = GaussianHead(d, d)
        self.fuse_in = nn.Sequential(
            nn.Linear(d * 2, d), nn.GELU(), nn.Linear(d, d),
        )
        self.step_cell = nn.GRUCell(input_size=3, hidden_size=d)

        self.target_tokenizer = copy.deepcopy(self.tokenizer)
        self.target_encoder = copy.deepcopy(self.encoder)
        for p in self.target_tokenizer.parameters():
            p.requires_grad = False
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        self.predictor = FuturePredictor(d, d // 2, dr)

    @torch.no_grad()
    def update_ema(self):
        """Move target weights towards the online ones:
        ``target = m * target + (1 - m) * online`` with ``m = ema_decay``."""
        m = self.ema_decay
        for online, target in [(self.tokenizer, self.target_tokenizer),
                               (self.encoder, self.target_encoder)]:
            for po, pt in zip(online.parameters(), target.parameters()):
                pt.data.mul_(m).add_(po.data, alpha=1.0 - m)

    def encode_past(self, past_features, past_length):
        """Encode the past window; returns ``(z_in, encoded, patch_lens)``
        where ``z_in`` is the encoding of the last valid patch."""
        patches, patch_lens = self.tokenizer(past_features, past_length)
        N = patches.size(1)
        pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0)
                    >= patch_lens.unsqueeze(1))
        encoded = self.encoder(patches, key_padding_mask=pad_mask)
        z_in = last_valid_token(encoded, patch_lens)
        return z_in, encoded, patch_lens

    @torch.no_grad()
    def encode_future_target(self, target_features, target_length):
        """Encode the future with the frozen target networks (no gradient);
        returns the encoding of the last valid patch."""
        patches, patch_lens = self.target_tokenizer(target_features, target_length)
        N = patches.size(1)
        pad_mask = (torch.arange(N, device=patches.device).unsqueeze(0)
                    >= patch_lens.unsqueeze(1))
        encoded = self.target_encoder(patches, key_padding_mask=pad_mask)
        return last_valid_token(encoded, patch_lens)

    def forward(self, past_features, past_length, target_pos, delta, last_pos):
        """Compute training losses with teacher forcing.

        At each step the head predicts the displacement from the previous
        *true* position, and the GRU is advanced with the true displacement.
        Steps at or beyond a sample's ``delta`` are masked out of the loss.

        Returns:
            dict with ``nll`` (mean per-sample average step NLL),
            ``ade_train`` (detached mean displacement error), ``total``
            and, when ``lambda_jepa > 0``, ``jepa`` (L1 between predicted
            and target latents); ``total = nll + lambda_jepa * jepa``.
        """
        B = past_features.size(0)
        device = past_features.device
        delta_max = target_pos.size(1)

        z_in, _, _ = self.encode_past(past_features, past_length)
        delta_e = self.delta_emb(delta, past_length)
        h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1))

        prev_pos = last_pos
        nll_total = torch.zeros(B, device=device)
        valid_steps = torch.zeros(B, device=device)
        ade_total = torch.zeros(B, device=device)

        for t in range(delta_max):
            delta_mu, log_sigma, rho = self.head(h)
            true_pos_t = target_pos[:, t]
            true_delta = true_pos_t - prev_pos
            nll = gaussian_nll_xyz(true_delta, delta_mu, log_sigma, rho)
            mask = (t < delta).float()
            nll_total = nll_total + nll * mask
            ade_total = (ade_total
                         + (true_delta - delta_mu).pow(2).sum(-1).sqrt() * mask)
            valid_steps = valid_steps + mask
            h = self.step_cell(true_delta, h)
            prev_pos = true_pos_t

        nll_loss = (nll_total / valid_steps.clamp(min=1.0)).mean()
        ade_train = (ade_total / valid_steps.clamp(min=1.0)).mean().detach()
        losses = {"nll": nll_loss, "ade_train": ade_train, "total": nll_loss}

        if self.lambda_jepa > 0.0:
            # Only the xyz channels of the target features are filled; the
            # six derived channels stay zero.
            tgt_feat = torch.zeros(B, delta_max, 9, device=device)
            tgt_feat[..., :3] = target_pos
            z_target = self.encode_future_target(tgt_feat, delta)
            z_pred = self.predictor(z_in, delta_e)
            jepa_loss = F.l1_loss(z_pred, z_target.detach())
            losses["jepa"] = jepa_loss
            losses["total"] = nll_loss + self.lambda_jepa * jepa_loss
        return losses

    @torch.no_grad()
    def rollout(self, past_features, past_length, delta, last_pos, delta_max):
        """Autoregressively predict ``delta_max`` future steps.

        Unlike :meth:`forward`, the GRU is fed the *predicted* mean
        displacement. ``delta`` only conditions the horizon embedding.

        Returns:
            ``(mu_pos, sigma, rho)`` of shapes (B, delta_max, 3),
            (B, delta_max, 3) and (B, delta_max): predicted positions,
            per-axis standard deviations and x/y correlations.
        """
        B = past_features.size(0)
        device = past_features.device
        z_in, _, _ = self.encode_past(past_features, past_length)
        delta_e = self.delta_emb(delta, past_length)
        h = self.fuse_in(torch.cat([z_in, delta_e], dim=-1))

        prev_pos = last_pos
        mu_pos = torch.zeros(B, delta_max, 3, device=device)
        sigma = torch.zeros(B, delta_max, 3, device=device)
        rho_out = torch.zeros(B, delta_max, device=device)
        for t in range(delta_max):
            delta_mu, log_sigma, rho = self.head(h)
            cur_pos = prev_pos + delta_mu
            mu_pos[:, t] = cur_pos
            sigma[:, t] = log_sigma.exp()
            rho_out[:, t] = rho
            h = self.step_cell(delta_mu, h)
            prev_pos = cur_pos
        return mu_pos, sigma, rho_out


# ============================================================================
# TRAIN + SCORE
# ============================================================================

# Position errors are multiplied by RMAX_KM * 1000 to report metres
# (presumably positions are normalised by this radius — confirm against
# the dataset preprocessing).
RMAX_KM = 120.0
DELTA_BUCKETS = [(30, 60), (60, 90), (90, 120)]
EXTRAP_DELTAS = [180, 300]
THRESH_M = [500.0, 1000.0, 2000.0]


def get_last_pos(past_features, past_length):
    """Return the xyz channels of the last valid past step, shape (B, 3)."""
    B = past_features.size(0)
    idx = (past_length - 1).clamp(min=0)
    return past_features[torch.arange(B, device=past_features.device), idx, :3]


def train_one_epoch(model, loader, optimizer, device, grad_clip=1.0,
                    log_every: int = 50):
    """Run one optimisation pass over *loader*.

    Gradients are clipped to ``grad_clip``; the EMA target networks are
    updated after each step when ``model.lambda_jepa > 0``. Progress is
    printed on the first batch and every ``log_every`` batches.

    Returns:
        Sample-weighted means under the keys ``nll``, ``ade``, ``jepa``,
        ``total`` and ``ade_train`` (same value as ``ade``).
    """
    model.train()
    sums = {"nll": 0.0, "ade": 0.0, "jepa": 0.0, "total": 0.0, "n": 0}
    t0 = time.time()
    n_batches = len(loader) if hasattr(loader, "__len__") else 0
    for bi, batch in enumerate(loader):
        past_f = batch["past_features"].to(device)
        past_l = batch["past_length"].to(device)
        target = batch["target_pos"].to(device)
        delta = batch["delta"].to(device)
        last_pos = get_last_pos(past_f, past_l)

        losses = model(past_f, past_l, target, delta, last_pos)
        optimizer.zero_grad()
        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        if model.lambda_jepa > 0.0:
            model.update_ema()

        bs = past_f.size(0)
        sums["nll"] += losses["nll"].item() * bs
        sums["ade"] += losses["ade_train"].item() * bs
        if "jepa" in losses:
            sums["jepa"] += losses["jepa"].item() * bs
        sums["total"] += losses["total"].item() * bs
        sums["n"] += bs

        if (bi + 1) % log_every == 0 or bi == 0:
            dt = time.time() - t0
            rate = (bi + 1) / max(dt, 0.001)
            print(f"    [batch {bi+1}/{n_batches}] {dt:.1f}s elapsed, "
                  f"{rate:.1f} batch/s, loss={losses['total'].item():.4f}",
                  flush=True)

    n = max(sums["n"], 1)
    return {k: v / n for k, v in sums.items() if k != "n"} | {
        "ade_train": sums["ade"] / n
    }


@torch.no_grad()
def score_loader(model, loader, device, extrap_delta=None):
    """Evaluate free-running roll-outs of *model* on *loader*.

    With ``extrap_delta`` set, the model is conditioned on that horizon for
    every sample and rolled out that many steps; errors are still measured
    only over each sample's own ``delta`` steps of ground truth. Batches
    whose roll-out length exceeds the dataset's ``delta_max`` are skipped.

    Returns:
        ``{}`` if nothing was scored, otherwise a dict with ``ade_m``,
        ``fde_m``, ``fde_median_m``, ``nll_xy_z`` and ``coverage_95`` (both
        computed at the final scored step), ``n``, ``miss_rate_{T}m`` for each
        threshold in ``THRESH_M`` and — when ``extrap_delta`` is None —
        ``per_bucket`` statistics over ``DELTA_BUCKETS``.
    """
    model.train(False)
    delta_max_dataset = loader.dataset.delta_max
    per_sample = []
    for batch in loader:
        past_f = batch["past_features"].to(device)
        past_l = batch["past_length"].to(device)
        target = batch["target_pos"].to(device)
        delta = batch["delta"].to(device)
        last_pos = get_last_pos(past_f, past_l)

        if extrap_delta is not None:
            forced = torch.full_like(delta, extrap_delta)
            roll_len = extrap_delta
        else:
            forced = delta
            roll_len = int(delta.max().item())
        if roll_len > delta_max_dataset:
            continue

        mu_pos, sigma, rho = model.rollout(past_f, past_l, forced, last_pos, roll_len)
        # Ground truth exists only for the first `delta` steps of a sample.
        active_len = torch.minimum(forced, delta).clamp(min=1)
        for i in range(past_f.size(0)):
            L = int(active_len[i].item())
            per_sample.append({
                "mu": mu_pos[i, :L].cpu().numpy(),
                "sigma": sigma[i, :L].cpu().numpy(),
                "rho": rho[i, :L].cpu().numpy(),
                "target": target[i, :L].cpu().numpy(),
                "delta_orig": int(delta[i].item()),
            })

    if not per_sample:
        return {}

    ades, fdes = [], []
    in_circle = {t: [] for t in THRESH_M}
    nlls, coverage95, delta_orig = [], [], []
    for s in per_sample:
        diff = s["target"] - s["mu"]
        per_step_l2 = np.linalg.norm(diff, axis=1) * RMAX_KM * 1000.0
        ades.append(per_step_l2.mean())
        fdes.append(per_step_l2[-1])
        for t in THRESH_M:
            in_circle[t].append(per_step_l2[-1] <= t)

        # Final-step likelihood / calibration, same form as gaussian_nll_xyz.
        sx = max(s["sigma"][-1, 0], 1e-9)
        sy = max(s["sigma"][-1, 1], 1e-9)
        sz = max(s["sigma"][-1, 2], 1e-9)
        rho_xy = s["rho"][-1]
        dx = diff[-1, 0]
        dy = diff[-1, 1]
        dz = diff[-1, 2]
        omr2 = max(1.0 - rho_xy * rho_xy, 1e-6)
        z2 = ((dx / sx) ** 2
              - 2 * rho_xy * (dx / sx) * (dy / sy)
              + (dy / sy) ** 2) / omr2
        # 5.991 = 95% quantile of chi-square with 2 degrees of freedom.
        coverage95.append(z2 <= 5.991)
        log_det = 2 * (math.log(sx) + math.log(sy)) + math.log(omr2)
        nll_xy = 0.5 * (z2 + log_det + 2 * math.log(2 * math.pi))
        nll_z = 0.5 * ((dz / sz) ** 2 + 2 * math.log(sz) + math.log(2 * math.pi))
        nlls.append(nll_xy + nll_z)
        delta_orig.append(s["delta_orig"])

    ades = np.array(ades)
    fdes = np.array(fdes)
    nlls = np.array(nlls)
    coverage95 = np.array(coverage95, dtype=float)
    delta_orig = np.array(delta_orig)

    out = {
        "ade_m": float(ades.mean()),
        "fde_m": float(fdes.mean()),
        "fde_median_m": float(np.median(fdes)),
        "nll_xy_z": float(nlls.mean()),
        "coverage_95": float(coverage95.mean()),
        "n": len(ades),
    }
    for t in THRESH_M:
        out[f"miss_rate_{int(t)}m"] = float(1.0 - np.mean(in_circle[t]))

    if extrap_delta is None:
        per_bucket = {}
        for lo, hi in DELTA_BUCKETS:
            mask = (delta_orig >= lo) & (delta_orig <= hi)
            if mask.sum() == 0:
                continue
            per_bucket[f"delta_{lo}_{hi}"] = {
                "ade_m": float(ades[mask].mean()),
                "fde_m": float(fdes[mask].mean()),
                "coverage_95": float(coverage95[mask].mean()),
                "n": int(mask.sum()),
            }
        out["per_bucket"] = per_bucket
    return out


def main():
    """CLI entry point: train, evaluate, save and optionally upload.

    The test split is scored every 5 epochs and on the last epoch; the
    state with the lowest test FDE is kept. The final in-distribution and
    extrapolation metrics are computed on that best state, which is also
    the one written to ``{out_dir}/{tag}.pt``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--airport", default="RKSIa")
    p.add_argument("--data-dir", default="data")
    p.add_argument("--tag", default="run")
    p.add_argument("--out-dir", default="runs")
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--past-max", type=int, default=256)
    p.add_argument("--past-min", type=int, default=60)
    p.add_argument("--delta-min", type=int, default=30)
    p.add_argument("--delta-max", type=int, default=120)
    p.add_argument("--extrap-delta-max", type=int, default=300)
    p.add_argument("--epoch-multiplier", type=int, default=4)
    p.add_argument("--lambda-jepa", type=float, default=0.0)
    p.add_argument("--ema-decay", type=float, default=0.998)
    p.add_argument("--d-model", type=int, default=256)
    p.add_argument("--n-layers", type=int, default=4)
    p.add_argument("--n-heads", type=int, default=8)
    p.add_argument("--patch-size", type=int, default=8)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--push-to-hub", action="store_true")
    p.add_argument("--hub-model-id", default=None)
    p.add_argument("--trackio-name", default=None)
    args = p.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[v2] device={device} tag={args.tag} lambda_jepa={args.lambda_jepa}",
          flush=True)
    if device == "cuda":
        print(f"[v2] cuda device: {torch.cuda.get_device_name(0)} "
              f"vram={torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB",
              flush=True)
    else:
        print("[v2] WARNING: CUDA not available, training on CPU. "
              "This will be very slow.", flush=True)

    if HAS_TRACKIO and args.trackio_name:
        trackio.init(project="flight-jepa-v2", name=args.trackio_name,
                     config=vars(args))

    train_ds = BlindspotDataset(
        airport=args.airport, mode="TRAIN", data_dir=args.data_dir,
        past_max=args.past_max, past_min=args.past_min,
        delta_min=args.delta_min, delta_max=args.delta_max,
        seed=args.seed, epoch_multiplier=args.epoch_multiplier,
    )
    test_ds = BlindspotDataset(
        airport=args.airport, mode="TEST", data_dir=args.data_dir,
        past_max=args.past_max, past_min=args.past_min,
        delta_min=args.delta_min, delta_max=args.delta_max,
        seed=args.seed + 1, epoch_multiplier=1,
    )
    # Same test split, but sampled with the longer extrapolation horizon.
    extrap_ds = BlindspotDataset(
        airport=args.airport, mode="TEST", data_dir=args.data_dir,
        past_max=args.past_max, past_min=args.past_min,
        delta_min=args.delta_min, delta_max=args.extrap_delta_max,
        seed=args.seed + 99, epoch_multiplier=1,
    )

    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.num_workers, pin_memory=True,
                          drop_last=True)
    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                         num_workers=args.num_workers, pin_memory=True)
    extrap_dl = DataLoader(extrap_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers, pin_memory=True)

    cfg = {
        "d_model": args.d_model, "n_heads": args.n_heads,
        "n_layers": args.n_layers, "d_ff": args.d_model * 4,
        "dropout": 0.1, "patch_size": args.patch_size,
        "past_max": args.past_max,
        "lambda_jepa": args.lambda_jepa, "ema_decay": args.ema_decay,
    }
    model = FlightJEPAv2(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[v2] params={n_params/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs)

    os.makedirs(args.out_dir, exist_ok=True)
    history = []
    best_fde = float("inf")
    best_state = None

    for epoch in range(args.epochs):
        t0 = time.time()
        train_stats = train_one_epoch(model, train_dl, optimizer, device)
        scheduler.step()

        score_stats = None
        if (epoch + 1) % 5 == 0 or epoch == args.epochs - 1:
            score_stats = score_loader(model, test_dl, device)
            if score_stats and score_stats["fde_m"] < best_fde:
                best_fde = score_stats["fde_m"]
                # Snapshot on CPU so later training does not mutate it.
                best_state = {k: v.detach().cpu().clone()
                              for k, v in model.state_dict().items()}

        elapsed = time.time() - t0
        log = {
            "epoch": epoch + 1, "elapsed_s": elapsed,
            "lr": optimizer.param_groups[0]["lr"],
            "train": train_stats, "score": score_stats,
        }
        history.append(log)

        msg = (f"[v2] ep {epoch+1:03d} | loss={train_stats['total']:.4f} "
               f"nll={train_stats['nll']:.4f} ade_t={train_stats['ade_train']:.4f} "
               f"jepa={train_stats['jepa']:.4f}")
        if score_stats:
            msg += f" | fde={score_stats['fde_m']:.0f}m ade={score_stats['ade_m']:.0f}m"
        msg += f" | {elapsed:.0f}s"
        print(msg, flush=True)

        if HAS_TRACKIO and args.trackio_name:
            tlog = {f"train/{k}": v for k, v in train_stats.items()}
            if score_stats:
                tlog.update({f"test/{k}": v for k, v in score_stats.items()
                             if isinstance(v, (int, float))})
            trackio.log(tlog, step=epoch + 1)

    # Restore the best checkpoint BEFORE the final evaluation, so that the
    # metrics stored under "final" describe the weights that get saved.
    if best_state is not None:
        model.load_state_dict(best_state)

    final = {"in_distribution": score_loader(model, test_dl, device)}
    for d in EXTRAP_DELTAS:
        final[f"extrap_delta_{d}"] = score_loader(model, extrap_dl, device,
                                                  extrap_delta=d)

    out_path = os.path.join(args.out_dir, f"{args.tag}.pt")
    torch.save({
        "state_dict": model.state_dict(),
        "config": cfg, "args": vars(args),
        "history": history, "final": final,
        "best_fde_m": best_fde,
    }, out_path)
    print(f"[v2] saved {out_path}")

    summary_path = os.path.join(args.out_dir, f"{args.tag}_summary.json")
    with open(summary_path, "w") as f:
        json.dump({
            "tag": args.tag, "lambda_jepa": args.lambda_jepa,
            "n_params": n_params, "best_fde_m": best_fde,
            "final": final, "args": vars(args),
        }, f, indent=2, default=float)
    print(f"[v2] summary -> {summary_path}", flush=True)

    if args.push_to_hub and args.hub_model_id:
        # Upload is best-effort: a hub failure must not discard a finished
        # run whose artefacts are already on disk.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            api.create_repo(args.hub_model_id, exist_ok=True)
            for path, fname in [(out_path, f"{args.tag}.pt"),
                                (summary_path, f"{args.tag}_summary.json")]:
                api.upload_file(path_or_fileobj=path, path_in_repo=fname,
                                repo_id=args.hub_model_id)
            print(f"[v2] uploaded to {args.hub_model_id}")
        except Exception as e:
            print(f"[v2] hub upload failed: {e}")

    if HAS_TRACKIO and args.trackio_name:
        trackio.finish()


if __name__ == "__main__":
    main()