| """ |
| AirTrackLM - Training Script |
| ============================= |
| Pretraining on next-state prediction with multi-head output. |
| """ |
|
|
| import os |
| import time |
| import json |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from torch.utils.data import DataLoader, random_split |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from typing import Dict, Optional |
|
|
| from data_pipeline import ( |
| TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset |
| ) |
| from model import AirTrackLM, AirTrackConfig, NextStateLoss |
|
|
|
|
def collate_fn(batch):
    """Custom collate: pad variable-length sequences to max length in batch.

    Each sample is a dict of tensors. 'prompt' entries are assumed fixed-size
    and are stacked directly; every other 1-D or 2-D tensor is right-padded
    with zeros along its first (time) dimension up to the batch-wide max
    length, then stacked.

    Args:
        batch: list of dicts mapping feature name -> torch.Tensor. Every
            sample must contain a 'cog_bins' key; its length defines the
            padding target.

    Returns:
        Dict of batched tensors, each with a leading batch dimension.
    """
    # Bug fix: the original referenced the alias ``F`` which is imported
    # *below* this definition in the file; use the fully-qualified name via
    # the top-of-file ``torch`` import so there is no forward dependency.
    pad = torch.nn.functional.pad

    max_len = max(b['cog_bins'].size(0) for b in batch)

    collated = {}
    for key in batch[0].keys():
        tensors = [b[key] for b in batch]

        if key == 'prompt':
            # Prompts are fixed-length; stack without padding.
            collated[key] = torch.stack(tensors)
        else:
            padded = []
            for t in tensors:
                if t.dim() == 1:
                    padded.append(pad(t, (0, max_len - t.size(0)), value=0))
                elif t.dim() == 2:
                    # Pad time steps (rows); the feature dim is untouched.
                    padded.append(pad(t, (0, 0, 0, max_len - t.size(0)), value=0))
                else:
                    # >2-D tensors pass through unpadded; torch.stack raises
                    # if their lengths differ — assumed uniform here.
                    padded.append(t)
            collated[key] = torch.stack(padded)

    return collated
|
|
|
|
| import torch.nn.functional as F |
|
|
|
|
def train_epoch(
    model: AirTrackLM,
    dataloader: DataLoader,
    loss_fn: NextStateLoss,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    grad_clip: float = 1.0,
) -> Dict[str, float]:
    """Run one optimization pass over *dataloader*.

    Returns the per-batch average of every entry in the loss log
    (including 'total').
    """
    model.train()

    running_total = 0.0
    component_sums: Dict[str, float] = {}
    seen = 0

    for step, raw_batch in enumerate(dataloader):
        # Move tensor entries to the training device; leave everything else alone.
        moved = {}
        for name, value in raw_batch.items():
            moved[name] = value.to(device) if isinstance(value, torch.Tensor) else value

        outputs = model(moved)
        loss, log = loss_fn(outputs, moved)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        running_total += log['total']
        for name, value in log.items():
            component_sums[name] = component_sums.get(name, 0) + value
        seen += 1

        # Periodic progress report (every 10 batches).
        if (step + 1) % 10 == 0:
            print(f" Batch {step+1}/{len(dataloader)} | Loss: {running_total / seen:.4f}")

    return {name: value / max(seen, 1) for name, value in component_sums.items()}
|
|
|
|
@torch.no_grad()
def evaluate(
    model: AirTrackLM,
    dataloader: DataLoader,
    loss_fn: NextStateLoss,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate model on validation set.

    Returns per-batch-averaged loss components plus a '<feat>_acc' entry per
    discretized head, where the argmax logit at step t is scored against the
    target bin at step t+1 (teacher forcing).

    NOTE(review): accuracy is counted over every shifted position, including
    any zero-padded tail from collate_fn — confirm whether padding should be
    masked out.
    """
    model.eval()

    features = ['cog', 'sog', 'rot', 'alt_rate']
    component_sums: Dict[str, float] = {}
    seen = 0
    hits = {name: 0 for name in features}
    n_positions = 0

    for raw_batch in dataloader:
        moved = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                 for k, v in raw_batch.items()}

        outputs = model(moved)
        _, log = loss_fn(outputs, moved)

        for name, value in log.items():
            component_sums[name] = component_sums.get(name, 0) + value
        seen += 1

        # Shifted comparison: logits at t predict the bin at t+1.
        for name in features:
            logits = outputs[f'{name}_logits'][:, :-1, :]
            targets = moved[f'{name}_bins'][:, 1:]
            hits[name] += (logits.argmax(dim=-1) == targets).sum().item()

        n_positions += moved['cog_bins'][:, 1:].numel()

    metrics = {name: value / max(seen, 1) for name, value in component_sums.items()}
    for name in features:
        metrics[f'{name}_acc'] = hits[name] / max(n_positions, 1)

    return metrics
|
|
|
|
def train(
    config: AirTrackConfig,
    train_dataset,
    val_dataset,
    output_dir: str = './checkpoints',
    n_epochs: int = 30,
    batch_size: int = 32,
    learning_rate: float = 5e-4,
    weight_decay: float = 0.01,
    warmup_fraction: float = 0.05,
    grad_clip: float = 1.0,
    patience: int = 5,
    device: str = 'auto',
    use_trackio: bool = False,
):
    """Full training loop with early stopping and checkpointing.

    Args:
        config: model hyperparameters for AirTrackLM.
        train_dataset: torch Dataset of dict samples for training.
        val_dataset: torch Dataset of dict samples for validation.
        output_dir: directory receiving 'best_model.pt', 'final_model.pt'
            and 'training_history.json'.
        n_epochs: maximum number of epochs.
        batch_size: mini-batch size for both loaders.
        learning_rate: initial AdamW learning rate (cosine-annealed to 1%).
        weight_decay: AdamW weight decay.
        warmup_fraction: reserved for LR warmup — currently UNUSED; TODO wire
            into the scheduler or remove.
        grad_clip: max global gradient norm passed to train_epoch.
        patience: epochs without val-loss improvement before early stopping.
        device: 'auto' picks CUDA when available, otherwise an explicit
            torch device string.
        use_trackio: log per-epoch metrics via trackio when installed.

    Returns:
        (model, history): the trained model and a list of per-epoch records.
    """
    # --- Device selection -------------------------------------------------
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    print(f"Using device: {device}")

    # --- Model ------------------------------------------------------------
    model = AirTrackLM(config).to(device)
    param_counts = model.count_parameters()
    print(f"Model parameters: {param_counts['total']:,} ({param_counts['trainable']:,} trainable)")

    # --- Data loaders -----------------------------------------------------
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device.type == 'cuda'),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device.type == 'cuda'),
    )

    print(f"Train: {len(train_dataset)} samples, {len(train_loader)} batches")
    print(f"Val: {len(val_dataset)} samples, {len(val_loader)} batches")

    # --- Loss / optimizer / scheduler -------------------------------------
    loss_fn = NextStateLoss(config)

    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )

    # Bug fix: the scheduler is stepped once per EPOCH (below), so T_max must
    # be the epoch count. It was previously n_epochs * len(train_loader)
    # (a per-step horizon), which made the cosine decay ~len(train_loader)x
    # too slow and left the LR essentially constant across training.
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=learning_rate * 0.01)

    # --- Optional experiment tracking --------------------------------------
    tracker = None
    if use_trackio:
        try:
            import trackio
            tracker = trackio.init(name="AirTrackLM-pretrain")
            print("Trackio initialized")
        except ImportError:
            print("Trackio not available, skipping monitoring")

    os.makedirs(output_dir, exist_ok=True)

    # --- Training loop with early stopping ---------------------------------
    best_val_loss = float('inf')
    patience_counter = 0
    history = []
    epoch = -1  # guards the final-save below when n_epochs == 0

    print(f"\n{'='*60}")
    print(f"Starting training: {n_epochs} epochs")
    print(f"{'='*60}\n")

    for epoch in range(n_epochs):
        t_start = time.time()

        print(f"Epoch {epoch+1}/{n_epochs}")
        train_metrics = train_epoch(model, train_loader, loss_fn, optimizer, device, grad_clip)

        # One scheduler step per epoch (matches T_max above).
        scheduler.step()

        val_metrics = evaluate(model, val_loader, loss_fn, device)

        t_elapsed = time.time() - t_start

        print(f" Train Loss: {train_metrics['total']:.4f} | Val Loss: {val_metrics['total']:.4f}")
        print(f" Val Acc - COG: {val_metrics.get('cog_acc', 0):.3f}, SOG: {val_metrics.get('sog_acc', 0):.3f}, "
              f"ROT: {val_metrics.get('rot_acc', 0):.3f}, AltRate: {val_metrics.get('alt_rate_acc', 0):.3f}")
        print(f" Time: {t_elapsed:.1f}s | LR: {scheduler.get_last_lr()[0]:.6f}")

        # Log to trackio when initialized (trackio exposes a module-level log,
        # mirroring the wandb API).
        if tracker is not None:
            trackio.log({
                'train/loss': train_metrics['total'],
                'val/loss': val_metrics['total'],
                **{f'train/{k}': v for k, v in train_metrics.items() if k != 'total'},
                **{f'val/{k}': v for k, v in val_metrics.items()},
                'lr': scheduler.get_last_lr()[0],
                'epoch': epoch + 1,
            })

        history.append({
            'epoch': epoch + 1,
            'train': train_metrics,
            'val': val_metrics,
            'lr': scheduler.get_last_lr()[0],
            'time': t_elapsed,
        })

        # Checkpoint on improvement; otherwise count toward early stopping.
        if val_metrics['total'] < best_val_loss:
            best_val_loss = val_metrics['total']
            patience_counter = 0

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'config': config.__dict__,
                'val_loss': best_val_loss,
                'val_metrics': val_metrics,
            }
            torch.save(checkpoint, os.path.join(output_dir, 'best_model.pt'))
            print(f" ★ New best model saved (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping after {patience} epochs without improvement.")
                break

        print()

    # Always save the final weights (may differ from the best checkpoint).
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
    }, os.path.join(output_dir, 'final_model.pt'))

    # default=str so non-JSON-native values (e.g. tensors) are stringified.
    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2, default=str)

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Checkpoints saved to {output_dir}")

    return model, history
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    print("=" * 60)
    print("AirTrackLM - Pretraining on Traffic Sample Data")
    print("=" * 60)

    # Compact transformer configuration for the traffic sample corpus.
    config = AirTrackConfig(
        d_model=256,
        n_heads=8,
        n_layers=8,
        d_ff=1024,
        dropout=0.1,
        max_seq_len=256,
        geohash_mode='absolute',
    )

    print("\n1. Loading traffic sample data...")
    raw_trajs = load_traffic_sample()
    print(f" Loaded {len(raw_trajs)} raw trajectories")

    print("\n2. Processing trajectories...")
    processor = TrajectoryProcessor(resample_dt=5.0)

    # Sliding windows: 64 steps with 50% overlap.
    seq_len = 64
    stride = 32

    dataset = build_dataset(raw_trajs, processor, seq_len=seq_len, stride=stride)

    if len(dataset) == 0:
        print("ERROR: No valid windows found. Check data.")
        # Bug fix: raise SystemExit instead of calling the site-injected
        # `exit()` builtin, which is not guaranteed to exist (e.g. under
        # `python -S` or in frozen/embedded interpreters).
        raise SystemExit(1)

    # 85/15 train/val split, guaranteeing at least one validation sample.
    n_val = max(1, int(0.15 * len(dataset)))
    n_train = len(dataset) - n_val
    train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

    print(f"\n3. Dataset split: {n_train} train, {n_val} val")

    print("\n4. Starting training...")
    model, history = train(
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        output_dir='./checkpoints',
        n_epochs=10,
        batch_size=16,
        learning_rate=5e-4,
        patience=5,
        device='auto',
        use_trackio=False,
    )

    print("\nDone!")
|
|