""" AirTrackLM - Training Script ============================= Pretraining on next-state prediction with multi-head output. """ import os import time import json import torch import torch.nn as nn import numpy as np from torch.utils.data import DataLoader, random_split from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from typing import Dict, Optional from data_pipeline import ( TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset ) from model import AirTrackLM, AirTrackConfig, NextStateLoss def collate_fn(batch): """Custom collate: pad variable-length sequences to max length in batch.""" # Find max sequence length in this batch max_len = max(b['cog_bins'].size(0) for b in batch) collated = {} for key in batch[0].keys(): tensors = [b[key] for b in batch] if key == 'prompt': # Fixed length, just stack collated[key] = torch.stack(tensors) else: # Pad to max_len padded = [] for t in tensors: if t.dim() == 1: pad_size = max_len - t.size(0) padded.append(F.pad(t, (0, pad_size), value=0)) elif t.dim() == 2: pad_size = max_len - t.size(0) padded.append(F.pad(t, (0, 0, 0, pad_size), value=0)) else: padded.append(t) collated[key] = torch.stack(padded) return collated import torch.nn.functional as F def train_epoch( model: AirTrackLM, dataloader: DataLoader, loss_fn: NextStateLoss, optimizer: torch.optim.Optimizer, device: torch.device, grad_clip: float = 1.0, ) -> Dict[str, float]: """Train for one epoch.""" model.train() total_loss = 0.0 loss_components = {} n_batches = 0 for batch_idx, batch in enumerate(dataloader): # Move to device batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # Forward predictions = model(batch) loss, loss_log = loss_fn(predictions, batch) # Backward optimizer.zero_grad() loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() # Accumulate metrics total_loss += loss_log['total'] for k, v in loss_log.items(): loss_components[k] = loss_components.get(k, 0) + v n_batches += 1 if (batch_idx + 1) % 10 == 0: avg_loss = total_loss / n_batches print(f" Batch {batch_idx+1}/{len(dataloader)} | Loss: {avg_loss:.4f}") # Average avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()} return avg_metrics @torch.no_grad() def evaluate( model: AirTrackLM, dataloader: DataLoader, loss_fn: NextStateLoss, device: torch.device, ) -> Dict[str, float]: """Evaluate model on validation set.""" model.eval() total_loss = 0.0 loss_components = {} n_batches = 0 # Also compute accuracy for discrete predictions correct = {'cog': 0, 'sog': 0, 'rot': 0, 'alt_rate': 0} total_preds = 0 for batch in dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} predictions = model(batch) loss, loss_log = loss_fn(predictions, batch) total_loss += loss_log['total'] for k, v in loss_log.items(): loss_components[k] = loss_components.get(k, 0) + v n_batches += 1 # Accuracy for feat in ['cog', 'sog', 'rot', 'alt_rate']: pred_logits = predictions[f'{feat}_logits'][:, :-1, :] target = batch[f'{feat}_bins'][:, 1:] pred_class = pred_logits.argmax(dim=-1) correct[feat] += (pred_class == target).sum().item() total_preds += batch['cog_bins'][:, 1:].numel() avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()} # Add accuracy for feat in ['cog', 'sog', 'rot', 'alt_rate']: avg_metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1) return avg_metrics def train( config: AirTrackConfig, 
def train(
    config: AirTrackConfig,
    train_dataset,
    val_dataset,
    output_dir: str = './checkpoints',
    n_epochs: int = 30,
    batch_size: int = 32,
    learning_rate: float = 5e-4,
    weight_decay: float = 0.01,
    warmup_fraction: float = 0.05,
    grad_clip: float = 1.0,
    patience: int = 5,
    device: str = 'auto',
    use_trackio: bool = False,
):
    """Full training loop."""
    # Device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    print(f"Using device: {device}")

    # Model
    model = AirTrackLM(config).to(device)
    param_counts = model.count_parameters()
    print(f"Model parameters: {param_counts['total']:,} ({param_counts['trainable']:,} trainable)")

    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device.type == 'cuda'),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device.type == 'cuda'),
    )
    print(f"Train: {len(train_dataset)} samples, {len(train_loader)} batches")
    print(f"Val: {len(val_dataset)} samples, {len(val_loader)} batches")

    # Loss
    loss_fn = NextStateLoss(config)

    # Optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )

    # Scheduler: linear warmup for warmup_fraction of training, then cosine
    # decay. Stepped once per epoch (see scheduler.step() in the loop below),
    # so all horizons are expressed in epochs.
    warmup_epochs = max(1, round(warmup_fraction * n_epochs))
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs),
            CosineAnnealingLR(optimizer, T_max=max(1, n_epochs - warmup_epochs),
                              eta_min=learning_rate * 0.01),
        ],
        milestones=[warmup_epochs],
    )

    # Trackio (optional experiment tracking)
    tracker = None
    if use_trackio:
        try:
            import trackio
            tracker = trackio.init(name="AirTrackLM-pretrain")
            print("Trackio initialized")
        except ImportError:
            print("Trackio not available, skipping monitoring")

    # Output directory
    os.makedirs(output_dir, exist_ok=True)

    # Training state
    best_val_loss = float('inf')
    patience_counter = 0
    history = []

    print(f"\n{'='*60}")
    print(f"Starting training: {n_epochs} epochs")
    print(f"{'='*60}\n")

    for epoch in range(n_epochs):
        t_start = time.time()

        # Train
        print(f"Epoch {epoch+1}/{n_epochs}")
        train_metrics = train_epoch(model, train_loader, loss_fn, optimizer, device, grad_clip)

        # Step scheduler (once per epoch)
        scheduler.step()

        # Validate
        val_metrics = evaluate(model, val_loader, loss_fn, device)

        t_elapsed = time.time() - t_start

        # Log
        print(f"  Train Loss: {train_metrics['total']:.4f} | Val Loss: {val_metrics['total']:.4f}")
        print(f"  Val Acc - COG: {val_metrics.get('cog_acc', 0):.3f}, SOG: {val_metrics.get('sog_acc', 0):.3f}, "
              f"ROT: {val_metrics.get('rot_acc', 0):.3f}, AltRate: {val_metrics.get('alt_rate_acc', 0):.3f}")
        print(f"  Time: {t_elapsed:.1f}s | LR: {scheduler.get_last_lr()[0]:.6f}")

        # Trackio logging
        if tracker is not None:
            trackio.log({
                'train/loss': train_metrics['total'],
                'val/loss': val_metrics['total'],
                **{f'train/{k}': v for k, v in train_metrics.items() if k != 'total'},
                **{f'val/{k}': v for k, v in val_metrics.items() if k != 'total'},
                'lr': scheduler.get_last_lr()[0],
                'epoch': epoch + 1,
            })

        # History
        history.append({
            'epoch': epoch + 1,
            'train': train_metrics,
            'val': val_metrics,
            'lr': scheduler.get_last_lr()[0],
            'time': t_elapsed,
        })

        # Best model checkpoint
        if val_metrics['total'] < best_val_loss:
            best_val_loss = val_metrics['total']
            patience_counter = 0
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'config': config.__dict__,
                'val_loss': best_val_loss,
                'val_metrics': val_metrics,
            }
            torch.save(checkpoint, os.path.join(output_dir, 'best_model.pt'))
            print(f"  ★ New best model saved (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping after {patience} epochs without improvement.")
                break

        print()

    # Save final model
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
    }, os.path.join(output_dir, 'final_model.pt'))

    # Save history
    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2, default=str)

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to {output_dir}")
    return model, history
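
# Hypothetical convenience helper (not used by the training loop above):
# restores a saved checkpoint for inference. It relies only on the keys
# written by train() and assumes AirTrackConfig can be rebuilt from the plain
# dict stored under 'config' (e.g. a dataclass whose fields match those keys).
def load_best_model(checkpoint_path: str = './checkpoints/best_model.pt',
                    map_location: str = 'cpu') -> AirTrackLM:
    ckpt = torch.load(checkpoint_path, map_location=map_location)
    model = AirTrackLM(AirTrackConfig(**ckpt['config']))
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()  # inference mode: disables dropout
    return model
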
# ============================================================
# Main entry point
# ============================================================

if __name__ == '__main__':
    print("=" * 60)
    print("AirTrackLM - Pretraining on Traffic Sample Data")
    print("=" * 60)

    # Configuration
    config = AirTrackConfig(
        d_model=256,
        n_heads=8,
        n_layers=8,
        d_ff=1024,
        dropout=0.1,
        max_seq_len=256,
        geohash_mode='absolute',
    )

    # Load data
    print("\n1. Loading traffic sample data...")
    raw_trajs = load_traffic_sample()
    print(f"   Loaded {len(raw_trajs)} raw trajectories")

    # Process
    print("\n2. Processing trajectories...")
    processor = TrajectoryProcessor(resample_dt=5.0)
    seq_len = 64   # 64 states × 5 s ≈ 5 minutes per window
    stride = 32    # 50% overlap between consecutive windows
    dataset = build_dataset(raw_trajs, processor, seq_len=seq_len, stride=stride)

    if len(dataset) == 0:
        raise SystemExit("ERROR: No valid windows found. Check data.")

    # Split
    n_val = max(1, int(0.15 * len(dataset)))
    n_train = len(dataset) - n_val
    train_dataset, val_dataset = random_split(dataset, [n_train, n_val])
    print(f"\n3. Dataset split: {n_train} train, {n_val} val")

    # Train
    print("\n4. Starting training...")
    model, history = train(
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        output_dir='./checkpoints',
        n_epochs=10,   # quick run for testing
        batch_size=16,
        learning_rate=5e-4,
        patience=5,
        device='auto',
        use_trackio=False,
    )
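
    # Optional follow-up (illustrative): reload the best checkpoint with the
    # hypothetical load_best_model helper above and re-score it on the
    # validation split, reusing evaluate() and the same collate_fn.
    best = load_best_model('./checkpoints/best_model.pt')
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False,
                            collate_fn=collate_fn)
    val_scores = evaluate(best, val_loader, NextStateLoss(config), torch.device('cpu'))
    print("Best-checkpoint val metrics:",
          {k: round(v, 4) for k, v in val_scores.items()})

    print("\nDone!")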