"""
AirTrackLM - CPU Training + Hub Push
=====================================
Trains the full model on CPU and pushes the checkpoint directory to HF Hub.
"""

import os
import sys
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from pathlib import Path

# Make the sibling modules importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from data_pipeline import TrajectoryProcessor, load_traffic_sample, build_dataset
from model import AirTrackLM, AirTrackConfig, NextStateLoss

# Features for which next-step classification accuracy is reported. Each one
# is read as `<feat>_logits` from the model output and `<feat>_bins` from the
# batch.
_ACC_FEATURES = ('cog', 'sog', 'rot', 'alt_rate')


def _to_device(batch, device):
    """Return a copy of *batch* with every tensor value moved to *device*.

    Non-tensor values are passed through untouched.
    """
    return {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }


def collate_fn(batch):
    """Collate a list of sample dicts into one dict of stacked tensors.

    The target length is the longest ``'cog_bins'`` tensor in the batch.
    For every key except ``'prompt'``, 1-D and 2-D tensors are zero-padded at
    the end of their first dimension up to that length; tensors of any other
    rank are stacked as they are. ``'prompt'`` tensors are stacked without
    padding, so they must already share a shape.

    Args:
        batch: Non-empty list of dicts sharing the same keys, each value a
            tensor. Every dict must contain ``'cog_bins'``.

    Returns:
        Dict with the same keys as the samples, each value the stacked batch
        tensor.
    """
    max_len = max(b['cog_bins'].size(0) for b in batch)
    collated = {}
    for key in batch[0].keys():
        tensors = [b[key] for b in batch]
        if key == 'prompt':
            collated[key] = torch.stack(tensors)
            continue

        padded = []
        for t in tensors:
            if t.dim() == 1:
                padded.append(F.pad(t, (0, max_len - t.size(0)), value=0))
            elif t.dim() == 2:
                # F.pad takes pads from the last dim backwards: leave the
                # last dim alone and extend the first dim at its end.
                padded.append(F.pad(t, (0, 0, 0, max_len - t.size(0)), value=0))
            else:
                padded.append(t)
        collated[key] = torch.stack(padded)
    return collated


@torch.no_grad()
def evaluate(model, dataloader, loss_fn, device):
    """Compute mean loss components and next-step accuracies over a loader.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Callable taking a batch dict and returning a dict that holds
            ``'<feat>_logits'`` for each of cog, sog, rot and alt_rate.
        dataloader: Iterable of batch dicts holding ``'<feat>_bins'``.
        loss_fn: Callable ``(predictions, batch) -> (loss, log_dict)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Dict with every key of the loss log averaged over batches, plus
        ``'<feat>_acc'`` for each feature.
    """
    # NOTE(review): accuracy is computed over every position, including any
    # zero-padded ones added by collate_fn - confirm whether a mask is wanted.
    model.eval()
    loss_components = {}
    n_batches = 0
    correct = {feat: 0 for feat in _ACC_FEATURES}
    total_preds = 0

    for batch in dataloader:
        batch = _to_device(batch, device)
        predictions = model(batch)
        _, loss_log = loss_fn(predictions, batch)

        for k, v in loss_log.items():
            loss_components[k] = loss_components.get(k, 0) + v
        n_batches += 1

        for feat in _ACC_FEATURES:
            # Logits at position t are scored against the bin at t + 1.
            pred = predictions[f'{feat}_logits'][:, :-1, :].argmax(dim=-1)
            target = batch[f'{feat}_bins'][:, 1:]
            correct[feat] += (pred == target).sum().item()
        total_preds += batch['cog_bins'][:, 1:].numel()

    metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
    for feat in _ACC_FEATURES:
        metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1)
    return metrics


def main():
    """Train AirTrackLM on CPU, checkpoint it, and upload the results.

    Steps: load the sample trajectory sets, build the windowed dataset, split
    85/15 into train/val with a fixed seed, train with AdamW and a per-step
    cosine schedule with early stopping on validation loss, then write
    checkpoints, history and config to ``./checkpoints`` and try to upload
    that folder to the Hub.

    Raises:
        RuntimeError: If fewer than two training sequences could be built, so
            that a non-empty train and validation split cannot both exist.
    """
    print("=" * 70)
    print("AirTrackLM - Training (CPU) + Push to Hub")
    print("=" * 70)

    HUB_MODEL_ID = "Jdice27/AirTrackLM"
    device = torch.device('cpu')

    config = AirTrackConfig(
        d_model=256,
        n_heads=8,
        n_layers=8,
        d_ff=1024,
        dropout=0.1,
        max_seq_len=256,
        geohash_mode='absolute',
        use_multi_uncertainty=True,
        n_uncert_methods=4,
        use_heteroscedastic=True,
        predict_geohash=True,
        predict_continuous=True,
    )

    SEQ_LEN, STRIDE = 64, 32
    BATCH_SIZE = 16
    N_EPOCHS = 30
    LR = 5e-4
    PATIENCE = 8

    # ---- Load Data ----
    print("\n1. Loading data...")
    t0 = time.time()
    raw_trajs = []
    for name in ['quickstart', 'switzerland', 'savan']:
        # Best effort: a sample set that fails to load is reported and
        # skipped so the remaining sets can still be used.
        try:
            trajs = load_traffic_sample(name)
            raw_trajs.extend(trajs)
            print(f" {name}: {len(trajs)} flights")
        except Exception as e:
            print(f" {name}: failed ({e})")
    print(f" Total: {len(raw_trajs)} flights ({time.time()-t0:.1f}s)")

    # ---- Process ----
    print("\n2. Processing...")
    t0 = time.time()
    processor = TrajectoryProcessor(resample_dt=5.0)
    dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
    print(f" {time.time()-t0:.1f}s")

    # With fewer than two samples the train split would be empty and the run
    # would only fail later with an unrelated-looking KeyError.
    if len(dataset) < 2:
        raise RuntimeError(
            f"Need at least 2 sequences to build train/val splits, "
            f"got {len(dataset)}"
        )

    n_val = max(1, int(0.15 * len(dataset)))
    train_ds, val_ds = random_split(
        dataset,
        [len(dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    print(f" Train: {len(train_ds)}, Val: {len(val_ds)}")

    # ---- Model ----
    model = AirTrackLM(config)
    print(f"\n3. Model: {sum(p.numel() for p in model.parameters()):,} params")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            collate_fn=collate_fn, num_workers=0)

    loss_fn = NextStateLoss(config)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    # The scheduler is stepped once per batch, so T_max is the total number
    # of optimizer steps across all epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=N_EPOCHS * len(train_loader),
                                  eta_min=LR * 0.01)

    output_dir = Path('./checkpoints')
    output_dir.mkdir(exist_ok=True)

    best_val_loss = float('inf')
    patience_counter = 0
    history = []
    # Bound up front so the final summary cannot hit an unbound name.
    val_metrics = {}

    print(f"\n4. Training: {N_EPOCHS} epochs")
    print("=" * 70)

    for epoch in range(N_EPOCHS):
        t_epoch = time.time()
        model.train()  # evaluate() leaves the model in eval mode
        train_comp = {}
        n_b = 0

        for batch in train_loader:
            # Same device handling as evaluate().
            batch = _to_device(batch, device)
            predictions = model(batch)
            loss, log = loss_fn(predictions, batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            for k, v in log.items():
                train_comp[k] = train_comp.get(k, 0) + v
            n_b += 1

        train_avg = {k: v / max(n_b, 1) for k, v in train_comp.items()}
        val_metrics = evaluate(model, val_loader, loss_fn, device)
        elapsed = time.time() - t_epoch

        improved = val_metrics['total'] < best_val_loss
        print(f"Epoch {epoch+1:02d}/{N_EPOCHS} [{elapsed:.0f}s] {'★' if improved else ' '} "
              f"train={train_avg['total']:.3f} val={val_metrics['total']:.3f} "
              f"COG={val_metrics.get('cog_acc',0):.3f} SOG={val_metrics.get('sog_acc',0):.3f} "
              f"ROT={val_metrics.get('rot_acc',0):.3f} AltRate={val_metrics.get('alt_rate_acc',0):.3f}")

        history.append({
            'epoch': epoch+1,
            'train': train_avg,
            'val': val_metrics,
            'lr': scheduler.get_last_lr()[0],
            'time': elapsed,
        })

        if improved:
            best_val_loss = val_metrics['total']
            patience_counter = 0
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'config': config.__dict__,
                'val_loss': best_val_loss,
                'val_metrics': val_metrics,
            }, output_dir / 'best_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # ---- Save + Push ----
    print("\n" + "=" * 70)
    print("Saving and pushing to Hub...")

    # final_model.pt holds the weights from the last epoch that ran; the
    # best-validation weights are in best_model.pt.
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
        'best_val_loss': best_val_loss,
        'history': history,
    }, output_dir / 'final_model.pt')

    with open(output_dir / 'training_history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, default=str)
    with open(output_dir / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(config.__dict__, f, indent=2)

    # The upload is best effort: a failure is reported but does not discard
    # the locally saved checkpoints.
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_folder(
            folder_path=str(output_dir),
            repo_id=HUB_MODEL_ID,
            repo_type="model",
            commit_message=f"Training: val_loss={best_val_loss:.4f}",
        )
        print(f"✓ Checkpoints pushed to https://huggingface.co/{HUB_MODEL_ID}")
    except Exception as e:
        print(f"Push failed: {e}")

    print(f"\nBest val loss: {best_val_loss:.4f}")
    print(f"Final metrics: COG={val_metrics.get('cog_acc',0):.3f} SOG={val_metrics.get('sog_acc',0):.3f}")
    print("Done!")


if __name__ == '__main__':
    main()