| """ |
| AirTrackLM - Full Training Script |
| =================================== |
| Trains decoder-only transformer on traffic library ADS-B data. |
| Pushes model + source to HuggingFace Hub. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from torch.utils.data import DataLoader, random_split |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from pathlib import Path |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from data_pipeline import TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset |
| from model import AirTrackLM, AirTrackConfig, NextStateLoss |
|
|
|
|
def _pad_first_dim(tensor, target_len):
    """Zero-pad a 1-D or 2-D tensor along dim 0 to ``target_len``; return others as-is."""
    if tensor.dim() == 1:
        return F.pad(tensor, (0, target_len - tensor.size(0)), value=0)
    if tensor.dim() == 2:
        # F.pad lists pads from the last dim backwards: (last_lo, last_hi, first_lo, first_hi).
        return F.pad(tensor, (0, 0, 0, target_len - tensor.size(0)), value=0)
    return tensor


def collate_fn(batch):
    """Collate a list of per-sample dicts into a single batched dict.

    The target length is the longest ``'cog_bins'`` tensor in the batch. For
    every key except ``'prompt'``, 1-D and 2-D tensors are right-padded with
    zeros along dim 0 to that length and then stacked; tensors of any other
    rank are stacked unchanged. ``'prompt'`` tensors are stacked without
    padding. Keys whose values are not tensors are returned as a plain list
    (one entry per sample) instead of raising, which matches the training and
    evaluation loops that already pass non-tensor batch values through.

    Args:
        batch: Non-empty list of dicts that share the keys of ``batch[0]``;
            every dict must contain a ``'cog_bins'`` tensor.

    Returns:
        Dict with the same keys as ``batch[0]`` holding stacked tensors, or
        lists for non-tensor fields.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    max_len = max(sample['cog_bins'].size(0) for sample in batch)
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if not all(isinstance(v, torch.Tensor) for v in values):
            # Metadata (ids, strings, ...) cannot be stacked; keep it per-sample.
            collated[key] = values
            continue
        if key != 'prompt':
            values = [_pad_first_dim(v, max_len) for v in values]
        collated[key] = torch.stack(values)
    return collated
|
|
|
|
@torch.no_grad()
def evaluate(model, dataloader, loss_fn, device):
    """Average the loss components and next-step bin accuracies over a loader.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: Callable taking a batch dict and returning a dict that contains
            ``'<feat>_logits'`` for each of cog, sog, rot and alt_rate.
        dataloader: Iterable of batch dicts containing ``'<feat>_bins'``.
        loss_fn: Callable ``(predictions, batch) -> (loss, loss_log)`` where
            ``loss_log`` maps component names to summable numbers.
        device: Device that tensor values of each batch are moved to.

    Returns:
        Dict with every ``loss_log`` key averaged over the number of batches
        (unweighted by batch size) plus ``'<feat>_acc'`` for each feature,
        computed over all predicted positions. An empty loader yields only
        the accuracy keys, each equal to 0.
    """
    features = ('cog', 'sog', 'rot', 'alt_rate')

    model.eval()
    loss_sums = {}
    correct = dict.fromkeys(features, 0)
    n_batches = 0
    total_preds = 0

    for batch in dataloader:
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        predictions = model(batch)
        _, loss_log = loss_fn(predictions, batch)
        for name, value in loss_log.items():
            loss_sums[name] = loss_sums.get(name, 0) + value
        n_batches += 1

        for feat in features:
            # Logits at position t are scored against the bin at t + 1, so the
            # last logit and the first target have no counterpart.
            predicted = predictions[f'{feat}_logits'][:, :-1, :].argmax(dim=-1)
            target = batch[f'{feat}_bins'][:, 1:]
            correct[feat] += (predicted == target).sum().item()
        # NOTE(review): every position is counted, including any zero-padded
        # ones added by the collate function - confirm whether a mask is needed.
        total_preds += batch['cog_bins'][:, 1:].numel()

    avg_metrics = {name: total / max(n_batches, 1) for name, total in loss_sums.items()}
    for feat in features:
        avg_metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1)
    return avg_metrics
|
|
|
|
def main():
    """Run the full AirTrackLM training pipeline from data loading to Hub upload.

    Steps, in order:
      1. Load the traffic sample trajectories (each sample is best-effort).
      2. Turn them into fixed-length windows with ``build_dataset``.
      3. Split the windows 85/15 into train/validation with a fixed seed.
      4. Build the model, data loaders, AdamW optimizer, per-step cosine LR
         schedule and (on CUDA only) an AMP gradient scaler.
      5. Train with gradient clipping, validating after every epoch, saving
         the best checkpoint and stopping early after ``PATIENCE`` epochs
         without improvement.
      6. Save the final checkpoint, history and config under ``./checkpoints``
         and try to upload checkpoints and source files to the Hub.

    Trackio logging and the Hub upload are optional: their failures are
    printed and never abort training. The function returns ``None``; it
    returns early (after printing an error) when no trajectories or windows
    are available, or when the dataset is too small to leave any training
    windows.
    """
    print("=" * 70)
    print("AirTrackLM - Full Training Pipeline")
    print("=" * 70)

    HUB_MODEL_ID = "Jdice27/AirTrackLM"

    config = AirTrackConfig(
        d_model=256, n_heads=8, n_layers=8, d_ff=1024,
        dropout=0.1, max_seq_len=256, geohash_mode='absolute',
        use_multi_uncertainty=True, n_uncert_methods=4,
        use_heteroscedastic=True, predict_geohash=True, predict_continuous=True,
    )

    SEQ_LEN = 64
    STRIDE = 32
    BATCH_SIZE = 32
    N_EPOCHS = 50
    LR = 5e-4
    WEIGHT_DECAY = 0.01
    PATIENCE = 10
    RESAMPLE_DT = 5.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- Optional experiment tracking -------------------------------------
    tracker = None
    trackio = None
    try:
        import trackio
        # Trackio's init() takes the project name as a required argument;
        # without it the call raises and tracking is disabled for the run.
        tracker = trackio.init(project="AirTrackLM", name="AirTrackLM-pretrain")
        print("Trackio initialized ✓")
    except Exception as e:
        print(f"Trackio: {e}")

    track_failure_reported = False

    def track(payload):
        """Best-effort metric logging; reports only the first failure."""
        nonlocal track_failure_reported
        if tracker is None:
            return
        try:
            trackio.log(payload)
        except Exception as e:
            if not track_failure_reported:
                print(f"Trackio log failed (further failures suppressed): {e}")
                track_failure_reported = True

    # ---- 1. Data loading ---------------------------------------------------
    print("\n1. Loading traffic sample data...")
    t0 = time.time()
    raw_trajs = []
    # Each entry is (sample name, word printed when loading it fails).
    sample_plan = [('quickstart', 'failed'), ('switzerland', 'skipped'), ('savan', 'skipped')]
    for sample_name, failure_label in sample_plan:
        try:
            trajs = load_traffic_sample(sample_name)
            raw_trajs.extend(trajs)
            print(f" {sample_name}: {len(trajs)} flights")
        except Exception as e:
            print(f" {sample_name}: {failure_label} ({e})")

    print(f" Total: {len(raw_trajs)} flights in {time.time()-t0:.1f}s")

    if not raw_trajs:
        print("ERROR: No trajectories loaded!")
        return

    lengths = [len(t['timestamps']) for t in raw_trajs]
    print(f" Lengths: min={min(lengths)}, max={max(lengths)}, median={np.median(lengths):.0f}")

    # ---- 2. Windowing ------------------------------------------------------
    print("\n2. Processing trajectories...")
    t0 = time.time()
    processor = TrajectoryProcessor(resample_dt=RESAMPLE_DT)
    dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
    print(f" Processing: {time.time()-t0:.1f}s")

    if len(dataset) == 0:
        print("ERROR: No valid windows!")
        return

    # ---- 3. Train / validation split ---------------------------------------
    # NOTE(review): STRIDE < SEQ_LEN, so if build_dataset emits overlapping
    # windows a window-level random split can put near-duplicate windows in
    # both train and val - confirm, and consider splitting by flight instead.
    n_val = max(1, int(0.15 * len(dataset)))
    n_train = len(dataset) - n_val
    if n_train == 0:
        # A single-window dataset would leave nothing to train on and would
        # fail later inside the sampler / scheduler with an unrelated error.
        print("ERROR: Not enough windows for a train/val split!")
        return
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )
    print(f"\n3. Split: {n_train} train, {n_val} val")

    # ---- 4. Model ----------------------------------------------------------
    model = AirTrackLM(config).to(device)
    param_counts = model.count_parameters()
    print(f"\n4. Model: {param_counts['total']:,} params ({param_counts['trainable']:,} trainable)")
    for name, count in param_counts.items():
        if name not in ['total', 'trainable']:
            print(f" {name}: {count:,}")

    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False,
        collate_fn=collate_fn, num_workers=2, pin_memory=pin_memory,
    )
    print(f" {len(train_loader)} train batches, {len(val_loader)} val batches")

    loss_fn = NextStateLoss(config)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    # The scheduler is stepped once per batch, so T_max counts optimizer steps.
    total_steps = N_EPOCHS * len(train_loader)
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR * 0.01)
    # Mixed precision is only used on CUDA; on CPU the plain fp32 path runs.
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    output_dir = Path('./checkpoints')
    output_dir.mkdir(exist_ok=True)

    best_val_loss = float('inf')
    patience_counter = 0
    history = []
    global_step = 0
    epochs_run = 0

    print(f"\n{'='*70}")
    print(f"Training: {N_EPOCHS} epochs, bs={BATCH_SIZE}, lr={LR}")
    print(f"{'='*70}\n")

    # ---- 5. Training loop --------------------------------------------------
    for epoch in range(N_EPOCHS):
        epochs_run = epoch + 1
        t_epoch = time.time()
        model.train()  # evaluate() leaves the model in eval mode
        train_loss = 0.0
        train_components = {}
        n_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    predictions = model(batch)
                    loss, loss_log = loss_fn(predictions, batch)
                scaler.scale(loss).backward()
                # Unscale first so the clipping threshold applies to true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                predictions = model(batch)
                loss, loss_log = loss_fn(predictions, batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            optimizer.zero_grad()
            scheduler.step()
            global_step += 1

            train_loss += loss_log['total']
            for k, v in loss_log.items():
                train_components[k] = train_components.get(k, 0) + v
            n_batches += 1

            if tracker is not None and global_step % 20 == 0:
                track({
                    'train/loss': loss_log['total'],
                    'train/lr': scheduler.get_last_lr()[0],
                    'train/step': global_step,
                })

            if (batch_idx + 1) % 50 == 0:
                print(f" Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} | Loss: {train_loss/n_batches:.4f}")

        train_avg = {k: v / n_batches for k, v in train_components.items()}
        val_metrics = evaluate(model, val_loader, loss_fn, device)

        elapsed = time.time() - t_epoch
        improved = val_metrics['total'] < best_val_loss

        print(f"\nEpoch {epoch+1}/{N_EPOCHS} [{elapsed:.1f}s] {'★' if improved else ''}")
        print(f" Train loss={train_avg['total']:.4f} | Val loss={val_metrics['total']:.4f}")
        print(f" Val Acc - COG:{val_metrics.get('cog_acc',0):.3f} SOG:{val_metrics.get('sog_acc',0):.3f} "
              f"ROT:{val_metrics.get('rot_acc',0):.3f} AltRate:{val_metrics.get('alt_rate_acc',0):.3f}")
        print(f" LR: {scheduler.get_last_lr()[0]:.6f}")

        if tracker is not None:
            track({
                'epoch': epoch + 1,
                'val/loss': val_metrics['total'],
                **{f'val/{k}': v for k, v in val_metrics.items()},
                'train/epoch_loss': train_avg['total'],
            })

        history.append({
            'epoch': epoch + 1, 'train': train_avg,
            'val': val_metrics, 'lr': scheduler.get_last_lr()[0], 'time': elapsed,
        })

        if improved:
            best_val_loss = val_metrics['total']
            patience_counter = 0
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config.__dict__,
                'val_loss': best_val_loss,
                'val_metrics': val_metrics,
            }, output_dir / 'best_model.pt')
            print(f" ★ Best model saved (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
        print()

    # ---- 6. Persist artifacts and push to the Hub ---------------------------
    print("\n" + "=" * 70)
    print("Saving and pushing to Hub...")

    # final_model.pt holds the weights of the last epoch that ran, which is
    # not necessarily the best epoch (that one is in best_model.pt).
    torch.save({
        'epoch': epochs_run, 'model_state_dict': model.state_dict(),
        'config': config.__dict__, 'best_val_loss': best_val_loss, 'history': history,
    }, output_dir / 'final_model.pt')

    with open(output_dir / 'training_history.json', 'w') as f:
        # default=str stringifies any metric value json cannot encode natively.
        json.dump(history, f, indent=2, default=str)

    with open(output_dir / 'config.json', 'w') as f:
        json.dump(config.__dict__, f, indent=2)

    # Create the Hub client once so both upload steps share it; previously a
    # failed import made the source upload fail with an unrelated NameError.
    api = None
    try:
        from huggingface_hub import HfApi
        api = HfApi()
    except Exception as e:
        print(f"Push checkpoints failed: {e}")
        print("Source upload skipped: Hub client unavailable")

    if api is not None:
        try:
            api.upload_folder(
                folder_path=str(output_dir), repo_id=HUB_MODEL_ID, repo_type="model",
                commit_message=f"Training: val_loss={best_val_loss:.4f}",
            )
            print(f"✓ Checkpoints pushed to https://huggingface.co/{HUB_MODEL_ID}")
        except Exception as e:
            print(f"Push checkpoints failed: {e}")

        # Attempted even if the checkpoint push failed, as the two are independent.
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            for fname in ['data_pipeline.py', 'model.py', 'uncertainty.py', 'train_full.py']:
                fpath = os.path.join(script_dir, fname)
                if os.path.exists(fpath):
                    api.upload_file(
                        path_or_fileobj=fpath, path_in_repo=fname,
                        repo_id=HUB_MODEL_ID, repo_type="model",
                    )
            print("✓ Source files uploaded")
        except Exception as e:
            print(f"Source upload failed: {e}")

    if tracker is not None:
        # Close the tracking run so buffered metrics are written out.
        try:
            trackio.finish()
        except Exception as e:
            print(f"Trackio finish failed: {e}")

    print(f"\nBest val loss: {best_val_loss:.4f}")
    print("Done!")
|
|
|
|
# Script entry point: run the training pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|