Jdice27 commited on
Commit
9ad9478
·
verified ·
1 Parent(s): f967e70

Add train.py

Browse files
Files changed (1) hide show
  1. train.py +379 -0
train.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AirTrackLM - Training Script
3
+ =============================
4
+ Pretraining on next-state prediction with multi-head output.
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import json
10
+ import torch
11
+ import torch.nn as nn
12
+ import numpy as np
13
+ from torch.utils.data import DataLoader, random_split
14
+ from torch.optim import AdamW
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR
16
+ from typing import Dict, Optional
17
+
18
+ from data_pipeline import (
19
+ TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset
20
+ )
21
+ from model import AirTrackLM, AirTrackConfig, NextStateLoss
22
+
23
+
24
+ def collate_fn(batch):
25
+ """Custom collate: pad variable-length sequences to max length in batch."""
26
+ # Find max sequence length in this batch
27
+ max_len = max(b['cog_bins'].size(0) for b in batch)
28
+
29
+ collated = {}
30
+ for key in batch[0].keys():
31
+ tensors = [b[key] for b in batch]
32
+
33
+ if key == 'prompt':
34
+ # Fixed length, just stack
35
+ collated[key] = torch.stack(tensors)
36
+ else:
37
+ # Pad to max_len
38
+ padded = []
39
+ for t in tensors:
40
+ if t.dim() == 1:
41
+ pad_size = max_len - t.size(0)
42
+ padded.append(F.pad(t, (0, pad_size), value=0))
43
+ elif t.dim() == 2:
44
+ pad_size = max_len - t.size(0)
45
+ padded.append(F.pad(t, (0, 0, 0, pad_size), value=0))
46
+ else:
47
+ padded.append(t)
48
+ collated[key] = torch.stack(padded)
49
+
50
+ return collated
51
+
52
+
53
+ import torch.nn.functional as F
54
+
55
+
56
+ def train_epoch(
57
+ model: AirTrackLM,
58
+ dataloader: DataLoader,
59
+ loss_fn: NextStateLoss,
60
+ optimizer: torch.optim.Optimizer,
61
+ device: torch.device,
62
+ grad_clip: float = 1.0,
63
+ ) -> Dict[str, float]:
64
+ """Train for one epoch."""
65
+ model.train()
66
+
67
+ total_loss = 0.0
68
+ loss_components = {}
69
+ n_batches = 0
70
+
71
+ for batch_idx, batch in enumerate(dataloader):
72
+ # Move to device
73
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
74
+
75
+ # Forward
76
+ predictions = model(batch)
77
+ loss, loss_log = loss_fn(predictions, batch)
78
+
79
+ # Backward
80
+ optimizer.zero_grad()
81
+ loss.backward()
82
+
83
+ # Gradient clipping
84
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
85
+
86
+ optimizer.step()
87
+
88
+ # Accumulate metrics
89
+ total_loss += loss_log['total']
90
+ for k, v in loss_log.items():
91
+ loss_components[k] = loss_components.get(k, 0) + v
92
+ n_batches += 1
93
+
94
+ if (batch_idx + 1) % 10 == 0:
95
+ avg_loss = total_loss / n_batches
96
+ print(f" Batch {batch_idx+1}/{len(dataloader)} | Loss: {avg_loss:.4f}")
97
+
98
+ # Average
99
+ avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
100
+ return avg_metrics
101
+
102
+
103
+ @torch.no_grad()
104
+ def evaluate(
105
+ model: AirTrackLM,
106
+ dataloader: DataLoader,
107
+ loss_fn: NextStateLoss,
108
+ device: torch.device,
109
+ ) -> Dict[str, float]:
110
+ """Evaluate model on validation set."""
111
+ model.eval()
112
+
113
+ total_loss = 0.0
114
+ loss_components = {}
115
+ n_batches = 0
116
+
117
+ # Also compute accuracy for discrete predictions
118
+ correct = {'cog': 0, 'sog': 0, 'rot': 0, 'alt_rate': 0}
119
+ total_preds = 0
120
+
121
+ for batch in dataloader:
122
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
123
+
124
+ predictions = model(batch)
125
+ loss, loss_log = loss_fn(predictions, batch)
126
+
127
+ total_loss += loss_log['total']
128
+ for k, v in loss_log.items():
129
+ loss_components[k] = loss_components.get(k, 0) + v
130
+ n_batches += 1
131
+
132
+ # Accuracy
133
+ for feat in ['cog', 'sog', 'rot', 'alt_rate']:
134
+ pred_logits = predictions[f'{feat}_logits'][:, :-1, :]
135
+ target = batch[f'{feat}_bins'][:, 1:]
136
+ pred_class = pred_logits.argmax(dim=-1)
137
+ correct[feat] += (pred_class == target).sum().item()
138
+
139
+ total_preds += batch['cog_bins'][:, 1:].numel()
140
+
141
+ avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
142
+
143
+ # Add accuracy
144
+ for feat in ['cog', 'sog', 'rot', 'alt_rate']:
145
+ avg_metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1)
146
+
147
+ return avg_metrics
148
+
149
+
150
+ def train(
151
+ config: AirTrackConfig,
152
+ train_dataset,
153
+ val_dataset,
154
+ output_dir: str = './checkpoints',
155
+ n_epochs: int = 30,
156
+ batch_size: int = 32,
157
+ learning_rate: float = 5e-4,
158
+ weight_decay: float = 0.01,
159
+ warmup_fraction: float = 0.05,
160
+ grad_clip: float = 1.0,
161
+ patience: int = 5,
162
+ device: str = 'auto',
163
+ use_trackio: bool = False,
164
+ ):
165
+ """Full training loop."""
166
+
167
+ # Device
168
+ if device == 'auto':
169
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
170
+ else:
171
+ device = torch.device(device)
172
+ print(f"Using device: {device}")
173
+
174
+ # Model
175
+ model = AirTrackLM(config).to(device)
176
+ param_counts = model.count_parameters()
177
+ print(f"Model parameters: {param_counts['total']:,} ({param_counts['trainable']:,} trainable)")
178
+
179
+ # Data loaders
180
+ train_loader = DataLoader(
181
+ train_dataset,
182
+ batch_size=batch_size,
183
+ shuffle=True,
184
+ collate_fn=collate_fn,
185
+ num_workers=0,
186
+ pin_memory=(device.type == 'cuda'),
187
+ )
188
+ val_loader = DataLoader(
189
+ val_dataset,
190
+ batch_size=batch_size,
191
+ shuffle=False,
192
+ collate_fn=collate_fn,
193
+ num_workers=0,
194
+ pin_memory=(device.type == 'cuda'),
195
+ )
196
+
197
+ print(f"Train: {len(train_dataset)} samples, {len(train_loader)} batches")
198
+ print(f"Val: {len(val_dataset)} samples, {len(val_loader)} batches")
199
+
200
+ # Loss
201
+ loss_fn = NextStateLoss(config)
202
+
203
+ # Optimizer
204
+ optimizer = AdamW(
205
+ model.parameters(),
206
+ lr=learning_rate,
207
+ weight_decay=weight_decay,
208
+ betas=(0.9, 0.999),
209
+ )
210
+
211
+ # Scheduler
212
+ total_steps = n_epochs * len(train_loader)
213
+ scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=learning_rate * 0.01)
214
+
215
+ # Trackio
216
+ tracker = None
217
+ if use_trackio:
218
+ try:
219
+ import trackio
220
+ tracker = trackio.init(name="AirTrackLM-pretrain")
221
+ print("Trackio initialized")
222
+ except ImportError:
223
+ print("Trackio not available, skipping monitoring")
224
+
225
+ # Output directory
226
+ os.makedirs(output_dir, exist_ok=True)
227
+
228
+ # Training state
229
+ best_val_loss = float('inf')
230
+ patience_counter = 0
231
+ history = []
232
+
233
+ print(f"\n{'='*60}")
234
+ print(f"Starting training: {n_epochs} epochs")
235
+ print(f"{'='*60}\n")
236
+
237
+ for epoch in range(n_epochs):
238
+ t_start = time.time()
239
+
240
+ # Train
241
+ print(f"Epoch {epoch+1}/{n_epochs}")
242
+ train_metrics = train_epoch(model, train_loader, loss_fn, optimizer, device, grad_clip)
243
+
244
+ # Step scheduler
245
+ scheduler.step()
246
+
247
+ # Validate
248
+ val_metrics = evaluate(model, val_loader, loss_fn, device)
249
+
250
+ t_elapsed = time.time() - t_start
251
+
252
+ # Log
253
+ print(f" Train Loss: {train_metrics['total']:.4f} | Val Loss: {val_metrics['total']:.4f}")
254
+ print(f" Val Acc - COG: {val_metrics.get('cog_acc', 0):.3f}, SOG: {val_metrics.get('sog_acc', 0):.3f}, "
255
+ f"ROT: {val_metrics.get('rot_acc', 0):.3f}, AltRate: {val_metrics.get('alt_rate_acc', 0):.3f}")
256
+ print(f" Time: {t_elapsed:.1f}s | LR: {scheduler.get_last_lr()[0]:.6f}")
257
+
258
+ # Trackio logging
259
+ if tracker is not None:
260
+ trackio.log({
261
+ 'train/loss': train_metrics['total'],
262
+ 'val/loss': val_metrics['total'],
263
+ **{f'train/{k}': v for k, v in train_metrics.items() if k != 'total'},
264
+ **{f'val/{k}': v for k, v in val_metrics.items()},
265
+ 'lr': scheduler.get_last_lr()[0],
266
+ 'epoch': epoch + 1,
267
+ })
268
+
269
+ # History
270
+ history.append({
271
+ 'epoch': epoch + 1,
272
+ 'train': train_metrics,
273
+ 'val': val_metrics,
274
+ 'lr': scheduler.get_last_lr()[0],
275
+ 'time': t_elapsed,
276
+ })
277
+
278
+ # Best model checkpoint
279
+ if val_metrics['total'] < best_val_loss:
280
+ best_val_loss = val_metrics['total']
281
+ patience_counter = 0
282
+
283
+ checkpoint = {
284
+ 'epoch': epoch + 1,
285
+ 'model_state_dict': model.state_dict(),
286
+ 'optimizer_state_dict': optimizer.state_dict(),
287
+ 'scheduler_state_dict': scheduler.state_dict(),
288
+ 'config': config.__dict__,
289
+ 'val_loss': best_val_loss,
290
+ 'val_metrics': val_metrics,
291
+ }
292
+ torch.save(checkpoint, os.path.join(output_dir, 'best_model.pt'))
293
+ print(f" ★ New best model saved (val_loss={best_val_loss:.4f})")
294
+ else:
295
+ patience_counter += 1
296
+ if patience_counter >= patience:
297
+ print(f"\nEarly stopping after {patience} epochs without improvement.")
298
+ break
299
+
300
+ print()
301
+
302
+ # Save final model
303
+ torch.save({
304
+ 'epoch': epoch + 1,
305
+ 'model_state_dict': model.state_dict(),
306
+ 'config': config.__dict__,
307
+ }, os.path.join(output_dir, 'final_model.pt'))
308
+
309
+ # Save history
310
+ with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
311
+ json.dump(history, f, indent=2, default=str)
312
+
313
+ print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
314
+ print(f"Checkpoints saved to {output_dir}")
315
+
316
+ return model, history
317
+
318
+
319
+ # ============================================================
320
+ # Main entry point
321
+ # ============================================================
322
+
323
+ if __name__ == '__main__':
324
+ print("=" * 60)
325
+ print("AirTrackLM - Pretraining on Traffic Sample Data")
326
+ print("=" * 60)
327
+
328
+ # Configuration
329
+ config = AirTrackConfig(
330
+ d_model=256,
331
+ n_heads=8,
332
+ n_layers=8,
333
+ d_ff=1024,
334
+ dropout=0.1,
335
+ max_seq_len=256,
336
+ geohash_mode='absolute',
337
+ )
338
+
339
+ # Load data
340
+ print("\n1. Loading traffic sample data...")
341
+ raw_trajs = load_traffic_sample()
342
+ print(f" Loaded {len(raw_trajs)} raw trajectories")
343
+
344
+ # Process
345
+ print("\n2. Processing trajectories...")
346
+ processor = TrajectoryProcessor(resample_dt=5.0)
347
+
348
+ seq_len = 64 # 64 states × 5s = ~5 minutes per window
349
+ stride = 32 # 50% overlap
350
+
351
+ dataset = build_dataset(raw_trajs, processor, seq_len=seq_len, stride=stride)
352
+
353
+ if len(dataset) == 0:
354
+ print("ERROR: No valid windows found. Check data.")
355
+ exit(1)
356
+
357
+ # Split
358
+ n_val = max(1, int(0.15 * len(dataset)))
359
+ n_train = len(dataset) - n_val
360
+ train_dataset, val_dataset = random_split(dataset, [n_train, n_val])
361
+
362
+ print(f"\n3. Dataset split: {n_train} train, {n_val} val")
363
+
364
+ # Train
365
+ print("\n4. Starting training...")
366
+ model, history = train(
367
+ config=config,
368
+ train_dataset=train_dataset,
369
+ val_dataset=val_dataset,
370
+ output_dir='./checkpoints',
371
+ n_epochs=10, # quick run for testing
372
+ batch_size=16,
373
+ learning_rate=5e-4,
374
+ patience=5,
375
+ device='auto',
376
+ use_trackio=False,
377
+ )
378
+
379
+ print("\nDone!")