Jdice27 commited on
Commit
b41eaa4
·
verified ·
1 Parent(s): a67d720

Update train_cpu.py - fix heteroscedastic loss clamping

Browse files
Files changed (1) hide show
  1. train_cpu.py +221 -0
train_cpu.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AirTrackLM - CPU Training + Hub Push
3
+ =====================================
4
+ Trains the full model on CPU and pushes checkpoints + source to HF Hub.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import json
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ from torch.utils.data import DataLoader, random_split
16
+ from torch.optim import AdamW
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR
18
+ from pathlib import Path
19
+
20
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
21
+
22
+ from data_pipeline import TrajectoryProcessor, load_traffic_sample, build_dataset
23
+ from model import AirTrackLM, AirTrackConfig, NextStateLoss
24
+
25
+
26
+ def collate_fn(batch):
27
+ max_len = max(b['cog_bins'].size(0) for b in batch)
28
+ collated = {}
29
+ for key in batch[0].keys():
30
+ tensors = [b[key] for b in batch]
31
+ if key == 'prompt':
32
+ collated[key] = torch.stack(tensors)
33
+ else:
34
+ padded = []
35
+ for t in tensors:
36
+ if t.dim() == 1:
37
+ padded.append(F.pad(t, (0, max_len - t.size(0)), value=0))
38
+ elif t.dim() == 2:
39
+ padded.append(F.pad(t, (0, 0, 0, max_len - t.size(0)), value=0))
40
+ else:
41
+ padded.append(t)
42
+ collated[key] = torch.stack(padded)
43
+ return collated
44
+
45
+
46
+ @torch.no_grad()
47
+ def evaluate(model, dataloader, loss_fn, device):
48
+ model.eval()
49
+ loss_components = {}
50
+ n_batches = 0
51
+ correct = {'cog': 0, 'sog': 0, 'rot': 0, 'alt_rate': 0}
52
+ total_preds = 0
53
+
54
+ for batch in dataloader:
55
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
56
+ predictions = model(batch)
57
+ _, loss_log = loss_fn(predictions, batch)
58
+ for k, v in loss_log.items():
59
+ loss_components[k] = loss_components.get(k, 0) + v
60
+ n_batches += 1
61
+ for feat in ['cog', 'sog', 'rot', 'alt_rate']:
62
+ pred = predictions[f'{feat}_logits'][:, :-1, :].argmax(dim=-1)
63
+ target = batch[f'{feat}_bins'][:, 1:]
64
+ correct[feat] += (pred == target).sum().item()
65
+ total_preds += batch['cog_bins'][:, 1:].numel()
66
+
67
+ metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
68
+ for feat in ['cog', 'sog', 'rot', 'alt_rate']:
69
+ metrics[f'{feat}_acc'] = correct[feat] / max(total_preds, 1)
70
+ return metrics
71
+
72
+
73
+ def main():
74
+ print("=" * 70)
75
+ print("AirTrackLM - Training (CPU) + Push to Hub")
76
+ print("=" * 70)
77
+
78
+ HUB_MODEL_ID = "Jdice27/AirTrackLM"
79
+ device = torch.device('cpu')
80
+
81
+ config = AirTrackConfig(
82
+ d_model=256, n_heads=8, n_layers=8, d_ff=1024,
83
+ dropout=0.1, max_seq_len=256, geohash_mode='absolute',
84
+ use_multi_uncertainty=True, n_uncert_methods=4,
85
+ use_heteroscedastic=True, predict_geohash=True, predict_continuous=True,
86
+ )
87
+
88
+ SEQ_LEN, STRIDE = 64, 32
89
+ BATCH_SIZE = 16
90
+ N_EPOCHS = 30
91
+ LR = 5e-4
92
+ PATIENCE = 8
93
+
94
+ # ---- Load Data ----
95
+ print("\n1. Loading data...")
96
+ t0 = time.time()
97
+ raw_trajs = []
98
+ for name in ['quickstart', 'switzerland', 'savan']:
99
+ try:
100
+ trajs = load_traffic_sample(name)
101
+ raw_trajs.extend(trajs)
102
+ print(f" {name}: {len(trajs)} flights")
103
+ except Exception as e:
104
+ print(f" {name}: failed ({e})")
105
+ print(f" Total: {len(raw_trajs)} flights ({time.time()-t0:.1f}s)")
106
+
107
+ # ---- Process ----
108
+ print("\n2. Processing...")
109
+ t0 = time.time()
110
+ processor = TrajectoryProcessor(resample_dt=5.0)
111
+ dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
112
+ print(f" {time.time()-t0:.1f}s")
113
+
114
+ n_val = max(1, int(0.15 * len(dataset)))
115
+ train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val],
116
+ generator=torch.Generator().manual_seed(42))
117
+ print(f" Train: {len(train_ds)}, Val: {len(val_ds)}")
118
+
119
+ # ---- Model ----
120
+ model = AirTrackLM(config)
121
+ print(f"\n3. Model: {sum(p.numel() for p in model.parameters()):,} params")
122
+
123
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
124
+ collate_fn=collate_fn, num_workers=0)
125
+ val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
126
+ collate_fn=collate_fn, num_workers=0)
127
+
128
+ loss_fn = NextStateLoss(config)
129
+ optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
130
+ scheduler = CosineAnnealingLR(optimizer, T_max=N_EPOCHS * len(train_loader), eta_min=LR * 0.01)
131
+
132
+ output_dir = Path('./checkpoints')
133
+ output_dir.mkdir(exist_ok=True)
134
+
135
+ best_val_loss = float('inf')
136
+ patience_counter = 0
137
+ history = []
138
+
139
+ print(f"\n4. Training: {N_EPOCHS} epochs")
140
+ print("=" * 70)
141
+
142
+ for epoch in range(N_EPOCHS):
143
+ t_epoch = time.time()
144
+ model.train()
145
+ train_loss = 0
146
+ train_comp = {}
147
+ n_b = 0
148
+
149
+ for batch in train_loader:
150
+ predictions = model(batch)
151
+ loss, log = loss_fn(predictions, batch)
152
+ loss.backward()
153
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
154
+ optimizer.step()
155
+ optimizer.zero_grad()
156
+ scheduler.step()
157
+
158
+ train_loss += log['total']
159
+ for k, v in log.items():
160
+ train_comp[k] = train_comp.get(k, 0) + v
161
+ n_b += 1
162
+
163
+ train_avg = {k: v/n_b for k, v in train_comp.items()}
164
+ val_metrics = evaluate(model, val_loader, loss_fn, device)
165
+
166
+ elapsed = time.time() - t_epoch
167
+ improved = val_metrics['total'] < best_val_loss
168
+
169
+ print(f"Epoch {epoch+1:02d}/{N_EPOCHS} [{elapsed:.0f}s] {'★' if improved else ' '} "
170
+ f"train={train_avg['total']:.3f} val={val_metrics['total']:.3f} "
171
+ f"COG={val_metrics.get('cog_acc',0):.3f} SOG={val_metrics.get('sog_acc',0):.3f} "
172
+ f"ROT={val_metrics.get('rot_acc',0):.3f} AltRate={val_metrics.get('alt_rate_acc',0):.3f}")
173
+
174
+ history.append({'epoch': epoch+1, 'train': train_avg, 'val': val_metrics,
175
+ 'lr': scheduler.get_last_lr()[0], 'time': elapsed})
176
+
177
+ if improved:
178
+ best_val_loss = val_metrics['total']
179
+ patience_counter = 0
180
+ torch.save({
181
+ 'epoch': epoch+1, 'model_state_dict': model.state_dict(),
182
+ 'config': config.__dict__, 'val_loss': best_val_loss, 'val_metrics': val_metrics,
183
+ }, output_dir / 'best_model.pt')
184
+ else:
185
+ patience_counter += 1
186
+ if patience_counter >= PATIENCE:
187
+ print(f"Early stopping at epoch {epoch+1}")
188
+ break
189
+
190
+ # ---- Save + Push ----
191
+ print("\n" + "=" * 70)
192
+ print("Saving and pushing to Hub...")
193
+
194
+ torch.save({
195
+ 'model_state_dict': model.state_dict(), 'config': config.__dict__,
196
+ 'best_val_loss': best_val_loss, 'history': history,
197
+ }, output_dir / 'final_model.pt')
198
+
199
+ with open(output_dir / 'training_history.json', 'w') as f:
200
+ json.dump(history, f, indent=2, default=str)
201
+ with open(output_dir / 'config.json', 'w') as f:
202
+ json.dump(config.__dict__, f, indent=2)
203
+
204
+ try:
205
+ from huggingface_hub import HfApi
206
+ api = HfApi()
207
+ api.upload_folder(
208
+ folder_path=str(output_dir), repo_id=HUB_MODEL_ID,
209
+ repo_type="model", commit_message=f"Training: val_loss={best_val_loss:.4f}",
210
+ )
211
+ print(f"✓ Checkpoints pushed to https://huggingface.co/{HUB_MODEL_ID}")
212
+ except Exception as e:
213
+ print(f"Push failed: {e}")
214
+
215
+ print(f"\nBest val loss: {best_val_loss:.4f}")
216
+ print(f"Final metrics: COG={val_metrics.get('cog_acc',0):.3f} SOG={val_metrics.get('sog_acc',0):.3f}")
217
+ print("Done!")
218
+
219
+
220
+ if __name__ == '__main__':
221
+ main()