""" AirTrackLM - Full Training Script =================================== Trains decoder-only transformer on traffic library ADS-B data. Pushes model + source to HuggingFace Hub. """ import os import sys import time import json import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from torch.utils.data import DataLoader, random_split from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from pathlib import Path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from data_pipeline import TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset from model import AirTrackLM, AirTrackConfig, NextStateLoss def collate_fn(batch): max_len = max(b['cog_bins'].size(0) for b in batch) collated = {} for key in batch[0].keys(): tensors = [b[key] for b in batch] if key == 'prompt': collated[key] = torch.stack(tensors) else: padded = [] for t in tensors: if t.dim() == 1: padded.append(F.pad(t, (0, max_len - t.size(0)), value=0)) elif t.dim() == 2: padded.append(F.pad(t, (0, 0, 0, max_len - t.size(0)), value=0)) else: padded.append(t) collated[key] = torch.stack(padded) return collated @torch.no_grad() def evaluate(model, dataloader, loss_fn, device): model.eval() total_loss = 0.0 loss_components = {} n_batches = 0 correct = {'cog': 0, 'sog': 0, 'rot': 0, 'alt_rate': 0} total_preds = 0 for batch in dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} predictions = model(batch) loss, loss_log = loss_fn(predictions, batch) total_loss += loss_log['total'] for k, v in loss_log.items(): loss_components[k] = loss_components.get(k, 0) + v n_batches += 1 for feat in ['cog', 'sog', 'rot', 'alt_rate']: pred_logits = predictions[f'{feat}_logits'][:, :-1, :] target = batch[f'{feat}_bins'][:, 1:] correct[feat] += (pred_logits.argmax(dim=-1) == target).sum().item() total_preds += batch['cog_bins'][:, 1:].numel() avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()} for feat in ['cog', 'sog', 'rot', 'alt_rate']: avg_metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1) return avg_metrics def main(): print("=" * 70) print("AirTrackLM - Full Training Pipeline") print("=" * 70) HUB_MODEL_ID = "Jdice27/AirTrackLM" config = AirTrackConfig( d_model=256, n_heads=8, n_layers=8, d_ff=1024, dropout=0.1, max_seq_len=256, geohash_mode='absolute', use_multi_uncertainty=True, n_uncert_methods=4, use_heteroscedastic=True, predict_geohash=True, predict_continuous=True, ) SEQ_LEN = 64 STRIDE = 32 BATCH_SIZE = 32 N_EPOCHS = 50 LR = 5e-4 WEIGHT_DECAY = 0.01 PATIENCE = 10 RESAMPLE_DT = 5.0 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") if device.type == 'cuda': print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") # ---- Trackio ---- tracker = None try: import trackio tracker = trackio.init(name="AirTrackLM-pretrain") print("Trackio initialized ✓") except Exception as e: print(f"Trackio: {e}") # ---- Load Data ---- print("\n1. Loading traffic sample data...") t0 = time.time() raw_trajs = [] for sample_name in ['quickstart']: try: trajs = load_traffic_sample(sample_name) raw_trajs.extend(trajs) print(f" {sample_name}: {len(trajs)} flights") except Exception as e: print(f" {sample_name}: failed ({e})") # Try additional samples for sample_name in ['switzerland', 'savan']: try: trajs = load_traffic_sample(sample_name) raw_trajs.extend(trajs) print(f" {sample_name}: {len(trajs)} flights") except Exception as e: print(f" {sample_name}: skipped ({e})") print(f" Total: {len(raw_trajs)} flights in {time.time()-t0:.1f}s") if len(raw_trajs) == 0: print("ERROR: No trajectories loaded!") return # Data audit lengths = [len(t['timestamps']) for t in raw_trajs] print(f" Lengths: min={min(lengths)}, max={max(lengths)}, median={np.median(lengths):.0f}") # ---- Process ---- print("\n2. Processing trajectories...") t0 = time.time() processor = TrajectoryProcessor(resample_dt=RESAMPLE_DT) dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE) print(f" Processing: {time.time()-t0:.1f}s") if len(dataset) == 0: print("ERROR: No valid windows!") return # Split n_val = max(1, int(0.15 * len(dataset))) n_train = len(dataset) - n_val train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)) print(f"\n3. Split: {n_train} train, {n_val} val") # ---- Model ---- model = AirTrackLM(config).to(device) param_counts = model.count_parameters() print(f"\n4. Model: {param_counts['total']:,} params ({param_counts['trainable']:,} trainable)") for name, count in param_counts.items(): if name not in ['total', 'trainable']: print(f" {name}: {count:,}") # ---- Loaders ---- train_loader = DataLoader( train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'), ) val_loader = DataLoader( val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'), ) print(f" {len(train_loader)} train batches, {len(val_loader)} val batches") # ---- Optimizer ---- loss_fn = NextStateLoss(config) optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999)) total_steps = N_EPOCHS * len(train_loader) scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR * 0.01) scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None # ---- Train ---- output_dir = Path('./checkpoints') output_dir.mkdir(exist_ok=True) best_val_loss = float('inf') patience_counter = 0 history = [] global_step = 0 print(f"\n{'='*70}") print(f"Training: {N_EPOCHS} epochs, bs={BATCH_SIZE}, lr={LR}") print(f"{'='*70}\n") for epoch in range(N_EPOCHS): t_epoch = time.time() model.train() train_loss = 0.0 train_components = {} n_batches = 0 for batch_idx, batch in enumerate(train_loader): batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} if scaler is not None: with torch.amp.autocast('cuda'): predictions = model(batch) loss, loss_log = loss_fn(predictions, batch) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() else: predictions = model(batch) loss, loss_log = loss_fn(predictions, batch) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() scheduler.step() global_step += 1 train_loss += loss_log['total'] for k, v in loss_log.items(): train_components[k] = train_components.get(k, 0) + v n_batches += 1 if tracker and global_step % 20 == 0: try: trackio.log({ 'train/loss': loss_log['total'], 'train/lr': scheduler.get_last_lr()[0], 'train/step': global_step, }) except Exception: pass if (batch_idx + 1) % 50 == 0: print(f" Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} | Loss: {train_loss/n_batches:.4f}") train_avg = {k: v / n_batches for k, v in train_components.items()} val_metrics = evaluate(model, val_loader, loss_fn, device) elapsed = time.time() - t_epoch improved = val_metrics['total'] < best_val_loss print(f"\nEpoch {epoch+1}/{N_EPOCHS} [{elapsed:.1f}s] {'★' if improved else ''}") print(f" Train loss={train_avg['total']:.4f} | Val loss={val_metrics['total']:.4f}") print(f" Val Acc - COG:{val_metrics.get('cog_acc',0):.3f} SOG:{val_metrics.get('sog_acc',0):.3f} " f"ROT:{val_metrics.get('rot_acc',0):.3f} AltRate:{val_metrics.get('alt_rate_acc',0):.3f}") print(f" LR: {scheduler.get_last_lr()[0]:.6f}") if tracker: try: trackio.log({ 'epoch': epoch + 1, 'val/loss': val_metrics['total'], **{f'val/{k}': v for k, v in val_metrics.items()}, 'train/epoch_loss': train_avg['total'], }) except Exception: pass history.append({ 'epoch': epoch + 1, 'train': train_avg, 'val': val_metrics, 'lr': scheduler.get_last_lr()[0], 'time': elapsed, }) if improved: best_val_loss = val_metrics['total'] patience_counter = 0 torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'config': config.__dict__, 'val_loss': best_val_loss, 'val_metrics': val_metrics, }, output_dir / 'best_model.pt') print(f" ★ Best model saved (val_loss={best_val_loss:.4f})") else: patience_counter += 1 if patience_counter >= PATIENCE: print(f"\nEarly stopping at epoch {epoch+1}") break print() # ---- Save & Push ---- print("\n" + "=" * 70) print("Saving and pushing to Hub...") torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'config': config.__dict__, 'best_val_loss': best_val_loss, 'history': history, }, output_dir / 'final_model.pt') with open(output_dir / 'training_history.json', 'w') as f: json.dump(history, f, indent=2, default=str) with open(output_dir / 'config.json', 'w') as f: json.dump(config.__dict__, f, indent=2) try: from huggingface_hub import HfApi api = HfApi() api.upload_folder( folder_path=str(output_dir), repo_id=HUB_MODEL_ID, repo_type="model", commit_message=f"Training: val_loss={best_val_loss:.4f}", ) print(f"✓ Checkpoints pushed to https://huggingface.co/{HUB_MODEL_ID}") except Exception as e: print(f"Push checkpoints failed: {e}") # Upload source files try: script_dir = os.path.dirname(os.path.abspath(__file__)) for fname in ['data_pipeline.py', 'model.py', 'uncertainty.py', 'train_full.py']: fpath = os.path.join(script_dir, fname) if os.path.exists(fpath): api.upload_file( path_or_fileobj=fpath, path_in_repo=fname, repo_id=HUB_MODEL_ID, repo_type="model", ) print(f"✓ Source files uploaded") except Exception as e: print(f"Source upload failed: {e}") print(f"\nBest val loss: {best_val_loss:.4f}") print("Done!") if __name__ == '__main__': main()