Jdice27 commited on
Commit
92256a1
·
verified ·
1 Parent(s): b2d2647

Upload train_full.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_full.py +84 -126
train_full.py CHANGED
@@ -1,17 +1,8 @@
1
  """
2
- AirTrackLM - Full GPU Training Script
3
- ======================================
4
- Trains the full-size model on traffic data and pushes to HuggingFace Hub.
5
-
6
- Features:
7
- - Full-size model (256d, 8 heads, 8 layers, ~7M params)
8
- - Multi-method uncertainty (4 preprocessing methods + learned heteroscedastic)
9
- - All kinematic features: COG, SOG, ROT, alt_rate
10
- - 3D binary geohash (40-bit × 3 axes)
11
- - Sub-second temporal encoding
12
- - ENU coordinate system with 3-point derivative
13
- - Trackio monitoring
14
- - Push to Hub on completion
15
  """
16
 
17
  import os
@@ -27,17 +18,13 @@ from torch.optim import AdamW
27
  from torch.optim.lr_scheduler import CosineAnnealingLR
28
  from pathlib import Path
29
 
30
- # Add script directory to path
31
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
32
 
33
- from data_pipeline import (
34
- TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset
35
- )
36
  from model import AirTrackLM, AirTrackConfig, NextStateLoss
37
 
38
 
39
  def collate_fn(batch):
40
- """Custom collate: pad variable-length sequences to max length in batch."""
41
  max_len = max(b['cog_bins'].size(0) for b in batch)
42
  collated = {}
43
  for key in batch[0].keys():
@@ -48,11 +35,9 @@ def collate_fn(batch):
48
  padded = []
49
  for t in tensors:
50
  if t.dim() == 1:
51
- pad_size = max_len - t.size(0)
52
- padded.append(F.pad(t, (0, pad_size), value=0))
53
  elif t.dim() == 2:
54
- pad_size = max_len - t.size(0)
55
- padded.append(F.pad(t, (0, 0, 0, pad_size), value=0))
56
  else:
57
  padded.append(t)
58
  collated[key] = torch.stack(padded)
@@ -72,17 +57,14 @@ def evaluate(model, dataloader, loss_fn, device):
72
  batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
73
  predictions = model(batch)
74
  loss, loss_log = loss_fn(predictions, batch)
75
-
76
  total_loss += loss_log['total']
77
  for k, v in loss_log.items():
78
  loss_components[k] = loss_components.get(k, 0) + v
79
  n_batches += 1
80
-
81
  for feat in ['cog', 'sog', 'rot', 'alt_rate']:
82
  pred_logits = predictions[f'{feat}_logits'][:, :-1, :]
83
  target = batch[f'{feat}_bins'][:, 1:]
84
- pred_class = pred_logits.argmax(dim=-1)
85
- correct[feat] += (pred_class == target).sum().item()
86
  total_preds += batch['cog_bins'][:, 1:].numel()
87
 
88
  avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
@@ -93,29 +75,20 @@ def evaluate(model, dataloader, loss_fn, device):
93
 
94
  def main():
95
  print("=" * 70)
96
- print("AirTrackLM - Full Training")
97
  print("=" * 70)
98
 
99
- # ---- Configuration ----
100
  HUB_MODEL_ID = "Jdice27/AirTrackLM"
101
 
102
  config = AirTrackConfig(
103
- d_model=256,
104
- n_heads=8,
105
- n_layers=8,
106
- d_ff=1024,
107
- dropout=0.1,
108
- max_seq_len=256,
109
- geohash_mode='absolute',
110
- use_multi_uncertainty=True,
111
- n_uncert_methods=4,
112
- use_heteroscedastic=True,
113
- predict_geohash=True,
114
- predict_continuous=True,
115
  )
116
 
117
- SEQ_LEN = 64 # 64 × 5s = ~5.3 min windows
118
- STRIDE = 32 # 50% overlap
119
  BATCH_SIZE = 32
120
  N_EPOCHS = 50
121
  LR = 5e-4
@@ -127,24 +100,22 @@ def main():
127
  print(f"Device: {device}")
128
  if device.type == 'cuda':
129
  print(f"GPU: {torch.cuda.get_device_name(0)}")
130
- print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
131
 
132
  # ---- Trackio ----
 
133
  try:
134
  import trackio
135
  tracker = trackio.init(name="AirTrackLM-pretrain")
136
  print("Trackio initialized ✓")
137
  except Exception as e:
138
- print(f"Trackio not available: {e}")
139
- tracker = None
140
 
141
  # ---- Load Data ----
142
  print("\n1. Loading traffic sample data...")
143
  t0 = time.time()
144
-
145
- # Load multiple sample collections for more data
146
  raw_trajs = []
147
- for sample_name in ['quickstart', 'switzerland', 'savan']:
148
  try:
149
  trajs = load_traffic_sample(sample_name)
150
  raw_trajs.extend(trajs)
@@ -152,37 +123,51 @@ def main():
152
  except Exception as e:
153
  print(f" {sample_name}: failed ({e})")
154
 
 
 
 
 
 
 
 
 
 
155
  print(f" Total: {len(raw_trajs)} flights in {time.time()-t0:.1f}s")
156
 
 
 
 
 
157
  # Data audit
158
  lengths = [len(t['timestamps']) for t in raw_trajs]
159
- print(f" Trajectory lengths: min={min(lengths)}, max={max(lengths)}, median={np.median(lengths):.0f}")
160
 
161
  # ---- Process ----
162
  print("\n2. Processing trajectories...")
163
  t0 = time.time()
164
  processor = TrajectoryProcessor(resample_dt=RESAMPLE_DT)
165
  dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
166
- print(f" Processing took {time.time()-t0:.1f}s")
 
 
 
 
167
 
168
  # Split
169
  n_val = max(1, int(0.15 * len(dataset)))
170
  n_train = len(dataset) - n_val
171
- train_ds, val_ds = random_split(
172
- dataset, [n_train, n_val],
173
- generator=torch.Generator().manual_seed(42)
174
- )
175
  print(f"\n3. Split: {n_train} train, {n_val} val")
176
 
177
  # ---- Model ----
178
  model = AirTrackLM(config).to(device)
179
  param_counts = model.count_parameters()
180
- print(f"\n4. Model: {param_counts['total']:,} parameters")
181
  for name, count in param_counts.items():
182
  if name not in ['total', 'trainable']:
183
  print(f" {name}: {count:,}")
184
 
185
- # ---- Data Loaders ----
186
  train_loader = DataLoader(
187
  train_ds, batch_size=BATCH_SIZE, shuffle=True,
188
  collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'),
@@ -193,16 +178,14 @@ def main():
193
  )
194
  print(f" {len(train_loader)} train batches, {len(val_loader)} val batches")
195
 
196
- # ---- Optimizer & Scheduler ----
197
  loss_fn = NextStateLoss(config)
198
  optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
199
  total_steps = N_EPOCHS * len(train_loader)
200
  scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR * 0.01)
201
-
202
- # Mixed precision
203
  scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
204
 
205
- # ---- Training ----
206
  output_dir = Path('./checkpoints')
207
  output_dir.mkdir(exist_ok=True)
208
 
@@ -212,13 +195,12 @@ def main():
212
  global_step = 0
213
 
214
  print(f"\n{'='*70}")
215
- print(f"Training: {N_EPOCHS} epochs, batch_size={BATCH_SIZE}, lr={LR}")
216
  print(f"{'='*70}\n")
217
 
218
  for epoch in range(N_EPOCHS):
219
  t_epoch = time.time()
220
  model.train()
221
-
222
  train_loss = 0.0
223
  train_components = {}
224
  n_batches = 0
@@ -251,56 +233,47 @@ def main():
251
  train_components[k] = train_components.get(k, 0) + v
252
  n_batches += 1
253
 
254
- # Log every 20 steps
255
  if tracker and global_step % 20 == 0:
256
- trackio.log({
257
- 'train/loss': loss_log['total'],
258
- 'train/lr': scheduler.get_last_lr()[0],
259
- 'train/step': global_step,
260
- **{f'train/{k}': v for k, v in loss_log.items() if k != 'total'},
261
- })
 
 
262
 
263
  if (batch_idx + 1) % 50 == 0:
264
- print(f" Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} | "
265
- f"Loss: {train_loss/n_batches:.4f}")
266
 
267
  train_avg = {k: v / n_batches for k, v in train_components.items()}
268
-
269
- # Validate
270
  val_metrics = evaluate(model, val_loader, loss_fn, device)
271
 
272
  elapsed = time.time() - t_epoch
273
  improved = val_metrics['total'] < best_val_loss
274
 
275
- print(f"\nEpoch {epoch+1}/{N_EPOCHS} [{elapsed:.1f}s] {'★' if improved else ' '}")
276
- print(f" Train: loss={train_avg['total']:.4f} "
277
- f"(geo={train_avg.get('geohash',0):.4f}, cont={train_avg.get('continuous',0):.4f}, "
278
- f"cog={train_avg.get('cog',0):.4f}, sog={train_avg.get('sog',0):.4f})")
279
- print(f" Val: loss={val_metrics['total']:.4f}")
280
- print(f" Val Acc - COG: {val_metrics.get('cog_acc',0):.3f}, "
281
- f"SOG: {val_metrics.get('sog_acc',0):.3f}, "
282
- f"ROT: {val_metrics.get('rot_acc',0):.3f}, "
283
- f"AltRate: {val_metrics.get('alt_rate_acc',0):.3f}")
284
  print(f" LR: {scheduler.get_last_lr()[0]:.6f}")
285
 
286
- # Trackio epoch log
287
  if tracker:
288
- trackio.log({
289
- 'epoch': epoch + 1,
290
- 'val/loss': val_metrics['total'],
291
- **{f'val/{k}': v for k, v in val_metrics.items()},
292
- 'train/epoch_loss': train_avg['total'],
293
- })
 
 
 
294
 
295
  history.append({
296
- 'epoch': epoch + 1,
297
- 'train': train_avg,
298
- 'val': val_metrics,
299
- 'lr': scheduler.get_last_lr()[0],
300
- 'time': elapsed,
301
  })
302
 
303
- # Checkpointing
304
  if improved:
305
  best_val_loss = val_metrics['total']
306
  patience_counter = 0
@@ -312,7 +285,7 @@ def main():
312
  'val_loss': best_val_loss,
313
  'val_metrics': val_metrics,
314
  }, output_dir / 'best_model.pt')
315
- print(f" ★ New best model (val_loss={best_val_loss:.4f})")
316
  else:
317
  patience_counter += 1
318
  if patience_counter >= PATIENCE:
@@ -320,62 +293,47 @@ def main():
320
  break
321
  print()
322
 
323
- # ---- Save Final + Push to Hub ----
324
  print("\n" + "=" * 70)
325
- print("Training complete. Saving and pushing to Hub...")
326
 
327
- # Save final checkpoint
328
  torch.save({
329
- 'epoch': epoch + 1,
330
- 'model_state_dict': model.state_dict(),
331
- 'config': config.__dict__,
332
- 'best_val_loss': best_val_loss,
333
- 'history': history,
334
  }, output_dir / 'final_model.pt')
335
 
336
- # Save training history
337
  with open(output_dir / 'training_history.json', 'w') as f:
338
  json.dump(history, f, indent=2, default=str)
339
 
340
- # Save config
341
  with open(output_dir / 'config.json', 'w') as f:
342
  json.dump(config.__dict__, f, indent=2)
343
 
344
- # Push to HuggingFace Hub
345
  try:
346
- from huggingface_hub import HfApi, upload_folder
347
  api = HfApi()
348
-
349
- # Upload all checkpoint files
350
  api.upload_folder(
351
- folder_path=str(output_dir),
352
- repo_id=HUB_MODEL_ID,
353
- repo_type="model",
354
- commit_message=f"Training complete: val_loss={best_val_loss:.4f}",
355
  )
356
- print(f"✓ Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
357
  except Exception as e:
358
- print(f"Failed to push to Hub: {e}")
359
 
360
- # Also upload source files
361
  try:
362
  script_dir = os.path.dirname(os.path.abspath(__file__))
363
- for fname in ['data_pipeline.py', 'model.py', 'train.py', 'uncertainty.py',
364
- 'train_full.py', 'ARCHITECTURE.md']:
365
  fpath = os.path.join(script_dir, fname)
366
  if os.path.exists(fpath):
367
  api.upload_file(
368
- path_or_fileobj=fpath,
369
- path_in_repo=fname,
370
- repo_id=HUB_MODEL_ID,
371
- repo_type="model",
372
  )
373
- print(f"✓ Source files uploaded to {HUB_MODEL_ID}")
374
  except Exception as e:
375
- print(f"Failed to upload source files: {e}")
376
 
377
  print(f"\nBest val loss: {best_val_loss:.4f}")
378
- print(f"Final val metrics: {val_metrics}")
379
  print("Done!")
380
 
381
 
 
1
  """
2
+ AirTrackLM - Full Training Script
3
+ ===================================
4
+ Trains decoder-only transformer on traffic library ADS-B data.
5
+ Pushes model + source to HuggingFace Hub.
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import os
 
18
  from torch.optim.lr_scheduler import CosineAnnealingLR
19
  from pathlib import Path
20
 
 
21
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
22
 
23
+ from data_pipeline import TrajectoryProcessor, FeatureBins, load_traffic_sample, build_dataset
 
 
24
  from model import AirTrackLM, AirTrackConfig, NextStateLoss
25
 
26
 
27
  def collate_fn(batch):
 
28
  max_len = max(b['cog_bins'].size(0) for b in batch)
29
  collated = {}
30
  for key in batch[0].keys():
 
35
  padded = []
36
  for t in tensors:
37
  if t.dim() == 1:
38
+ padded.append(F.pad(t, (0, max_len - t.size(0)), value=0))
 
39
  elif t.dim() == 2:
40
+ padded.append(F.pad(t, (0, 0, 0, max_len - t.size(0)), value=0))
 
41
  else:
42
  padded.append(t)
43
  collated[key] = torch.stack(padded)
 
57
  batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
58
  predictions = model(batch)
59
  loss, loss_log = loss_fn(predictions, batch)
 
60
  total_loss += loss_log['total']
61
  for k, v in loss_log.items():
62
  loss_components[k] = loss_components.get(k, 0) + v
63
  n_batches += 1
 
64
  for feat in ['cog', 'sog', 'rot', 'alt_rate']:
65
  pred_logits = predictions[f'{feat}_logits'][:, :-1, :]
66
  target = batch[f'{feat}_bins'][:, 1:]
67
+ correct[feat] += (pred_logits.argmax(dim=-1) == target).sum().item()
 
68
  total_preds += batch['cog_bins'][:, 1:].numel()
69
 
70
  avg_metrics = {k: v / max(n_batches, 1) for k, v in loss_components.items()}
 
75
 
76
  def main():
77
  print("=" * 70)
78
+ print("AirTrackLM - Full Training Pipeline")
79
  print("=" * 70)
80
 
 
81
  HUB_MODEL_ID = "Jdice27/AirTrackLM"
82
 
83
  config = AirTrackConfig(
84
+ d_model=256, n_heads=8, n_layers=8, d_ff=1024,
85
+ dropout=0.1, max_seq_len=256, geohash_mode='absolute',
86
+ use_multi_uncertainty=True, n_uncert_methods=4,
87
+ use_heteroscedastic=True, predict_geohash=True, predict_continuous=True,
 
 
 
 
 
 
 
 
88
  )
89
 
90
+ SEQ_LEN = 64
91
+ STRIDE = 32
92
  BATCH_SIZE = 32
93
  N_EPOCHS = 50
94
  LR = 5e-4
 
100
  print(f"Device: {device}")
101
  if device.type == 'cuda':
102
  print(f"GPU: {torch.cuda.get_device_name(0)}")
103
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
104
 
105
  # ---- Trackio ----
106
+ tracker = None
107
  try:
108
  import trackio
109
  tracker = trackio.init(name="AirTrackLM-pretrain")
110
  print("Trackio initialized ✓")
111
  except Exception as e:
112
+ print(f"Trackio: {e}")
 
113
 
114
  # ---- Load Data ----
115
  print("\n1. Loading traffic sample data...")
116
  t0 = time.time()
 
 
117
  raw_trajs = []
118
+ for sample_name in ['quickstart']:
119
  try:
120
  trajs = load_traffic_sample(sample_name)
121
  raw_trajs.extend(trajs)
 
123
  except Exception as e:
124
  print(f" {sample_name}: failed ({e})")
125
 
126
+ # Try additional samples
127
+ for sample_name in ['switzerland', 'savan']:
128
+ try:
129
+ trajs = load_traffic_sample(sample_name)
130
+ raw_trajs.extend(trajs)
131
+ print(f" {sample_name}: {len(trajs)} flights")
132
+ except Exception as e:
133
+ print(f" {sample_name}: skipped ({e})")
134
+
135
  print(f" Total: {len(raw_trajs)} flights in {time.time()-t0:.1f}s")
136
 
137
+ if len(raw_trajs) == 0:
138
+ print("ERROR: No trajectories loaded!")
139
+ return
140
+
141
  # Data audit
142
  lengths = [len(t['timestamps']) for t in raw_trajs]
143
+ print(f" Lengths: min={min(lengths)}, max={max(lengths)}, median={np.median(lengths):.0f}")
144
 
145
  # ---- Process ----
146
  print("\n2. Processing trajectories...")
147
  t0 = time.time()
148
  processor = TrajectoryProcessor(resample_dt=RESAMPLE_DT)
149
  dataset = build_dataset(raw_trajs, processor, seq_len=SEQ_LEN, stride=STRIDE)
150
+ print(f" Processing: {time.time()-t0:.1f}s")
151
+
152
+ if len(dataset) == 0:
153
+ print("ERROR: No valid windows!")
154
+ return
155
 
156
  # Split
157
  n_val = max(1, int(0.15 * len(dataset)))
158
  n_train = len(dataset) - n_val
159
+ train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))
 
 
 
160
  print(f"\n3. Split: {n_train} train, {n_val} val")
161
 
162
  # ---- Model ----
163
  model = AirTrackLM(config).to(device)
164
  param_counts = model.count_parameters()
165
+ print(f"\n4. Model: {param_counts['total']:,} params ({param_counts['trainable']:,} trainable)")
166
  for name, count in param_counts.items():
167
  if name not in ['total', 'trainable']:
168
  print(f" {name}: {count:,}")
169
 
170
+ # ---- Loaders ----
171
  train_loader = DataLoader(
172
  train_ds, batch_size=BATCH_SIZE, shuffle=True,
173
  collate_fn=collate_fn, num_workers=2, pin_memory=(device.type == 'cuda'),
 
178
  )
179
  print(f" {len(train_loader)} train batches, {len(val_loader)} val batches")
180
 
181
+ # ---- Optimizer ----
182
  loss_fn = NextStateLoss(config)
183
  optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
184
  total_steps = N_EPOCHS * len(train_loader)
185
  scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR * 0.01)
 
 
186
  scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
187
 
188
+ # ---- Train ----
189
  output_dir = Path('./checkpoints')
190
  output_dir.mkdir(exist_ok=True)
191
 
 
195
  global_step = 0
196
 
197
  print(f"\n{'='*70}")
198
+ print(f"Training: {N_EPOCHS} epochs, bs={BATCH_SIZE}, lr={LR}")
199
  print(f"{'='*70}\n")
200
 
201
  for epoch in range(N_EPOCHS):
202
  t_epoch = time.time()
203
  model.train()
 
204
  train_loss = 0.0
205
  train_components = {}
206
  n_batches = 0
 
233
  train_components[k] = train_components.get(k, 0) + v
234
  n_batches += 1
235
 
 
236
  if tracker and global_step % 20 == 0:
237
+ try:
238
+ trackio.log({
239
+ 'train/loss': loss_log['total'],
240
+ 'train/lr': scheduler.get_last_lr()[0],
241
+ 'train/step': global_step,
242
+ })
243
+ except Exception:
244
+ pass
245
 
246
  if (batch_idx + 1) % 50 == 0:
247
+ print(f" Epoch {epoch+1} Batch {batch_idx+1}/{len(train_loader)} | Loss: {train_loss/n_batches:.4f}")
 
248
 
249
  train_avg = {k: v / n_batches for k, v in train_components.items()}
 
 
250
  val_metrics = evaluate(model, val_loader, loss_fn, device)
251
 
252
  elapsed = time.time() - t_epoch
253
  improved = val_metrics['total'] < best_val_loss
254
 
255
+ print(f"\nEpoch {epoch+1}/{N_EPOCHS} [{elapsed:.1f}s] {'★' if improved else ''}")
256
+ print(f" Train loss={train_avg['total']:.4f} | Val loss={val_metrics['total']:.4f}")
257
+ print(f" Val Acc - COG:{val_metrics.get('cog_acc',0):.3f} SOG:{val_metrics.get('sog_acc',0):.3f} "
258
+ f"ROT:{val_metrics.get('rot_acc',0):.3f} AltRate:{val_metrics.get('alt_rate_acc',0):.3f}")
 
 
 
 
 
259
  print(f" LR: {scheduler.get_last_lr()[0]:.6f}")
260
 
 
261
  if tracker:
262
+ try:
263
+ trackio.log({
264
+ 'epoch': epoch + 1,
265
+ 'val/loss': val_metrics['total'],
266
+ **{f'val/{k}': v for k, v in val_metrics.items()},
267
+ 'train/epoch_loss': train_avg['total'],
268
+ })
269
+ except Exception:
270
+ pass
271
 
272
  history.append({
273
+ 'epoch': epoch + 1, 'train': train_avg,
274
+ 'val': val_metrics, 'lr': scheduler.get_last_lr()[0], 'time': elapsed,
 
 
 
275
  })
276
 
 
277
  if improved:
278
  best_val_loss = val_metrics['total']
279
  patience_counter = 0
 
285
  'val_loss': best_val_loss,
286
  'val_metrics': val_metrics,
287
  }, output_dir / 'best_model.pt')
288
+ print(f" ★ Best model saved (val_loss={best_val_loss:.4f})")
289
  else:
290
  patience_counter += 1
291
  if patience_counter >= PATIENCE:
 
293
  break
294
  print()
295
 
296
+ # ---- Save & Push ----
297
  print("\n" + "=" * 70)
298
+ print("Saving and pushing to Hub...")
299
 
 
300
  torch.save({
301
+ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
302
+ 'config': config.__dict__, 'best_val_loss': best_val_loss, 'history': history,
 
 
 
303
  }, output_dir / 'final_model.pt')
304
 
 
305
  with open(output_dir / 'training_history.json', 'w') as f:
306
  json.dump(history, f, indent=2, default=str)
307
 
 
308
  with open(output_dir / 'config.json', 'w') as f:
309
  json.dump(config.__dict__, f, indent=2)
310
 
 
311
  try:
312
+ from huggingface_hub import HfApi
313
  api = HfApi()
 
 
314
  api.upload_folder(
315
+ folder_path=str(output_dir), repo_id=HUB_MODEL_ID, repo_type="model",
316
+ commit_message=f"Training: val_loss={best_val_loss:.4f}",
 
 
317
  )
318
+ print(f"✓ Checkpoints pushed to https://huggingface.co/{HUB_MODEL_ID}")
319
  except Exception as e:
320
+ print(f"Push checkpoints failed: {e}")
321
 
322
+ # Upload source files
323
  try:
324
  script_dir = os.path.dirname(os.path.abspath(__file__))
325
+ for fname in ['data_pipeline.py', 'model.py', 'uncertainty.py', 'train_full.py']:
 
326
  fpath = os.path.join(script_dir, fname)
327
  if os.path.exists(fpath):
328
  api.upload_file(
329
+ path_or_fileobj=fpath, path_in_repo=fname,
330
+ repo_id=HUB_MODEL_ID, repo_type="model",
 
 
331
  )
332
+ print(f"✓ Source files uploaded")
333
  except Exception as e:
334
+ print(f"Source upload failed: {e}")
335
 
336
  print(f"\nBest val loss: {best_val_loss:.4f}")
 
337
  print("Done!")
338
 
339