Jdice27 committed on
Commit
faf2651
·
verified ·
1 Parent(s): d12cc01

Add train_full.py

Browse files
Files changed (1) hide show
  1. train_full.py +383 -0
train_full.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AirTrackLM - Full GPU Training Script
3
+ ======================================
4
+ Trains the full-size model on traffic data and pushes to HuggingFace Hub.
5
+
6
+ Features:
7
+ - Full-size model (256d, 8 heads, 8 layers, ~7M params)
8
+ - Multi-method uncertainty (4 preprocessing methods + learned heteroscedastic)
9
+ - All kinematic features: COG, SOG, ROT, alt_rate
10
+ - 3D binary geohash (40-bit × 3 axes)
11
+ - Sub-second temporal encoding
12
+ - ENU coordinate system with 3-point derivative
13
+ - Trackio monitoring
14
+ - Push to Hub on completion
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import time
20
+ import json
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+ from torch.utils.data import DataLoader, random_split
26
+ from torch.optim import AdamW
27
+ from torch.optim.lr_scheduler import CosineAnnealingLR
28
+ from pathlib import Path
29
+
30
+ # Add script directory to path
31
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
32
+
33
+ from data_pipeline import (
34
+ TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset
35
+ )
36
+ from model import AirTrackLM, AirTrackConfig, NextStateLoss
37
+
38
+
39
def collate_fn(batch):
    """Collate variable-length samples by right-padding to the batch max length.

    The 'prompt' entry is fixed-size and stacked directly. Every other key is
    zero-padded along the time axis (dim 0) up to the longest sequence in the
    batch; tensors with more than 2 dims are stacked unchanged.
    (Assumes all per-timestep tensors share 'cog_bins' length — confirm upstream.)
    """
    target_len = max(sample['cog_bins'].size(0) for sample in batch)
    out = {}
    for key in batch[0]:
        items = [sample[key] for sample in batch]
        if key == 'prompt':
            out[key] = torch.stack(items)
            continue
        resized = []
        for tensor in items:
            deficit = target_len - tensor.size(0)
            if tensor.dim() == 1:
                resized.append(F.pad(tensor, (0, deficit), value=0))
            elif tensor.dim() == 2:
                resized.append(F.pad(tensor, (0, 0, 0, deficit), value=0))
            else:
                resized.append(tensor)
        out[key] = torch.stack(resized)
    return out
60
+
61
+
62
@torch.no_grad()
def evaluate(model, dataloader, loss_fn, device):
    """Run validation: batch-averaged loss components + per-feature top-1 accuracy.

    Returns a dict with every key reported by ``loss_fn``'s log (averaged over
    batches) plus ``{feat}_acc`` for cog/sog/rot/alt_rate. Accuracy is next-step
    classification: logits at positions [:-1] vs. targets at [1:].

    NOTE(review): padded positions are counted in accuracy — assumes padding is
    negligible or masked by the loss upstream; confirm.

    Fixes vs. original: removed the dead ``total_loss`` accumulator (its value
    was never used — ``loss_components['total']`` already carries it) and the
    unused ``loss`` unpack.
    """
    model.eval()
    loss_components = {}
    n_batches = 0
    correct = {'cog': 0, 'sog': 0, 'rot': 0, 'alt_rate': 0}
    total_preds = 0

    for batch in dataloader:
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        predictions = model(batch)
        _, loss_log = loss_fn(predictions, batch)

        for k, v in loss_log.items():
            loss_components[k] = loss_components.get(k, 0) + v
        n_batches += 1

        for feat in ['cog', 'sog', 'rot', 'alt_rate']:
            pred_logits = predictions[f'{feat}_logits'][:, :-1, :]
            target = batch[f'{feat}_bins'][:, 1:]
            pred_class = pred_logits.argmax(dim=-1)
            correct[feat] += (pred_class == target).sum().item()
        # One increment per batch: every feature shares the same target count.
        total_preds += batch['cog_bins'][:, 1:].numel()

    avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
    for feat in ['cog', 'sog', 'rot', 'alt_rate']:
        avg_metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1)
    return avg_metrics
92
+
93
+
94
def main():
    """Train the full AirTrackLM model and push checkpoints to the HF Hub.

    Pipeline: load traffic samples -> window into sequences -> train/val
    split -> AMP training loop (cosine LR, grad clipping, early stopping)
    -> save best/final checkpoints + history -> upload to the Hub.

    Fixes vs. original:
    - ``get_device_properties(0).total_mem`` -> ``.total_memory`` (the real
      attribute; the old name raised AttributeError on any CUDA run).
    - ``api`` was only bound inside the first Hub try-block; the source-file
      upload now runs only when the client initialized (no masked NameError).
    - Dropped unused ``upload_folder`` import.
    - Abort cleanly when no trajectories load (min()/max() on [] crashed).
    """
    print("=" * 70)
    print("AirTrackLM - Full Training")
    print("=" * 70)

    # ---- Configuration ----
    HUB_MODEL_ID = "Jdice27/AirTrackLM"

    config = AirTrackConfig(
        d_model=256,
        n_heads=8,
        n_layers=8,
        d_ff=1024,
        dropout=0.1,
        max_seq_len=256,
        geohash_mode='absolute',
        use_multi_uncertainty=True,
        n_uncert_methods=4,
        use_heteroscedastic=True,
        predict_geohash=True,
        predict_continuous=True,
    )

    SEQ_LEN = 64      # 64 × 5s = ~5.3 min windows
    STRIDE = 32       # 50% overlap between consecutive windows
    BATCH_SIZE = 32
    N_EPOCHS = 50
    LR = 5e-4
    WEIGHT_DECAY = 0.01
    PATIENCE = 10     # early-stopping patience (epochs without val improvement)
    RESAMPLE_DT = 5.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # FIX: property is `total_memory`, not `total_mem`.
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- Trackio (optional experiment tracking; training proceeds without it) ----
    try:
        import trackio
        tracker = trackio.init(name="AirTrackLM-pretrain")
        print("Trackio initialized ✓")
    except Exception as e:
        print(f"Trackio not available: {e}")
        tracker = None

    # ---- Load Data ----
    print("\n1. Loading traffic sample data...")
    t0 = time.time()

    # Load multiple sample collections for more data; each is best-effort.
    raw_trajs = []
    for sample_name in ['quickstart', 'switzerland', 'savan']:
        try:
            trajs = load_traffic_sample(sample_name)
            raw_trajs.extend(trajs)
            print(f"  {sample_name}: {len(trajs)} flights")
        except Exception as e:
            print(f"  {sample_name}: failed ({e})")

    print(f"  Total: {len(raw_trajs)} flights in {time.time()-t0:.1f}s")

    # FIX: abort cleanly instead of crashing on the empty-data audit below.
    if not raw_trajs:
        print("No trajectories loaded; aborting.")
        return

    # Data audit
    lengths = [len(t['timestamps']) for t in raw_trajs]
    print(f"  Trajectory lengths: min={min(lengths)}, max={max(lengths)}, median={np.median(lengths):.0f}")

    # ---- Process ----
    print("\n2. Processing trajectories...")
    t0 = time.time()
    processor = TrajectoryProcessor(resample_dt=RESAMPLE_DT)
    dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
    print(f"  Processing took {time.time()-t0:.1f}s")

    # Split (seeded for reproducibility)
    n_val = max(1, int(0.15 * len(dataset)))
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"\n3. Split: {n_train} train, {n_val} val")

    # ---- Model ----
    model = AirTrackLM(config).to(device)
    param_counts = model.count_parameters()
    print(f"\n4. Model: {param_counts['total']:,} parameters")
    for name, count in param_counts.items():
        if name not in ['total', 'trainable']:
            print(f"  {name}: {count:,}")

    # ---- Data Loaders ----
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'),
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False,
        collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'),
    )
    print(f"  {len(train_loader)} train batches, {len(val_loader)} val batches")

    # ---- Optimizer & Scheduler ----
    loss_fn = NextStateLoss(config)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    total_steps = N_EPOCHS * len(train_loader)
    # Per-step cosine decay down to 1% of the peak LR.
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR * 0.01)

    # Mixed precision (CUDA only)
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    # ---- Training ----
    output_dir = Path('./checkpoints')
    output_dir.mkdir(exist_ok=True)

    best_val_loss = float('inf')
    patience_counter = 0
    history = []
    global_step = 0

    print(f"\n{'='*70}")
    print(f"Training: {N_EPOCHS} epochs, batch_size={BATCH_SIZE}, lr={LR}")
    print(f"{'='*70}\n")

    for epoch in range(N_EPOCHS):
        t_epoch = time.time()
        model.train()

        train_loss = 0.0
        train_components = {}
        n_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            if scaler is not None:
                # AMP path: scale loss, unscale before clipping.
                with torch.amp.autocast('cuda'):
                    predictions = model(batch)
                    loss, loss_log = loss_fn(predictions, batch)
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                predictions = model(batch)
                loss, loss_log = loss_fn(predictions, batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            optimizer.zero_grad()
            scheduler.step()
            global_step += 1

            train_loss += loss_log['total']
            for k, v in loss_log.items():
                train_components[k] = train_components.get(k, 0) + v
            n_batches += 1

            # Log every 20 steps
            if tracker and global_step % 20 == 0:
                trackio.log({
                    'train/loss': loss_log['total'],
                    'train/lr': scheduler.get_last_lr()[0],
                    'train/step': global_step,
                    **{f'train/{k}': v for k, v in loss_log.items() if k != 'total'},
                })

            if (batch_idx + 1) % 50 == 0:
                print(f"  Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {train_loss/n_batches:.4f}")

        train_avg = {k: v / n_batches for k, v in train_components.items()}

        # Validate
        val_metrics = evaluate(model, val_loader, loss_fn, device)

        elapsed = time.time() - t_epoch
        improved = val_metrics['total'] < best_val_loss

        print(f"\nEpoch {epoch+1}/{N_EPOCHS} [{elapsed:.1f}s] {'★' if improved else ' '}")
        print(f"  Train: loss={train_avg['total']:.4f} "
              f"(geo={train_avg.get('geohash',0):.4f}, cont={train_avg.get('continuous',0):.4f}, "
              f"cog={train_avg.get('cog',0):.4f}, sog={train_avg.get('sog',0):.4f})")
        print(f"  Val:   loss={val_metrics['total']:.4f}")
        print(f"  Val Acc - COG: {val_metrics.get('cog_acc',0):.3f}, "
              f"SOG: {val_metrics.get('sog_acc',0):.3f}, "
              f"ROT: {val_metrics.get('rot_acc',0):.3f}, "
              f"AltRate: {val_metrics.get('alt_rate_acc',0):.3f}")
        print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")

        # Trackio epoch log
        if tracker:
            trackio.log({
                'epoch': epoch + 1,
                'val/loss': val_metrics['total'],
                **{f'val/{k}': v for k, v in val_metrics.items()},
                'train/epoch_loss': train_avg['total'],
            })

        history.append({
            'epoch': epoch + 1,
            'train': train_avg,
            'val': val_metrics,
            'lr': scheduler.get_last_lr()[0],
            'time': elapsed,
        })

        # Checkpointing: keep the best model by validation loss.
        if improved:
            best_val_loss = val_metrics['total']
            patience_counter = 0
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config.__dict__,
                'val_loss': best_val_loss,
                'val_metrics': val_metrics,
            }, output_dir / 'best_model.pt')
            print(f"  ★ New best model (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
        print()

    # ---- Save Final + Push to Hub ----
    print("\n" + "=" * 70)
    print("Training complete. Saving and pushing to Hub...")

    # Save final checkpoint
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
        'best_val_loss': best_val_loss,
        'history': history,
    }, output_dir / 'final_model.pt')

    # Save training history
    with open(output_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2, default=str)

    # Save config
    with open(output_dir / 'config.json', 'w') as f:
        json.dump(config.__dict__, f, indent=2)

    # Push to HuggingFace Hub (best-effort; local checkpoints are already saved).
    api = None
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        # Upload all checkpoint files
        api.upload_folder(
            folder_path=str(output_dir),
            repo_id=HUB_MODEL_ID,
            repo_type="model",
            commit_message=f"Training complete: val_loss={best_val_loss:.4f}",
        )
        print(f"✓ Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
    except Exception as e:
        print(f"Failed to push to Hub: {e}")

    # Also upload source files — only if the Hub client initialized above.
    if api is not None:
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            for fname in ['data_pipeline.py', 'model.py', 'train.py', 'uncertainty.py',
                          'train_full.py', 'ARCHITECTURE.md']:
                fpath = os.path.join(script_dir, fname)
                if os.path.exists(fpath):
                    api.upload_file(
                        path_or_fileobj=fpath,
                        path_in_repo=fname,
                        repo_id=HUB_MODEL_ID,
                        repo_type="model",
                    )
            print(f"✓ Source files uploaded to {HUB_MODEL_ID}")
        except Exception as e:
            print(f"Failed to upload source files: {e}")

    print(f"\nBest val loss: {best_val_loss:.4f}")
    print(f"Final val metrics: {val_metrics}")
    print("Done!")
380
+
381
+
382
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    main()