zirobtc committed
Commit 16f4534 · 1 Parent(s): f8896b8

Upload folder using huggingface_hub

Files changed (5):
  1. data/ohlc_stats.npz +1 -1
  2. inference_utils.py +34 -0
  3. ingest.sh +1 -1
  4. train.py +132 -22
  5. train.sh +2 -2
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fc81a5fbd342767ca491eceeca805f54dd9ffe4ff0bb723cdda26e52f54f914d
+oid sha256:92f50d146182941b8b01be19b4699c1b0ebe37bac1ff155580b20a8755994070
 size 1660
inference_utils.py ADDED
@@ -0,0 +1,34 @@
+import torch
+
+def transform_targets(targets):
+    """
+    Applies the log-transform used during training:
+        y_trans = sign(y) * log(1 + |y|)
+
+    Args:
+        targets: torch.Tensor or float, raw returns (e.g. 1.5 for 150%)
+    Returns:
+        transformed targets in the same shape/type
+    """
+    if isinstance(targets, torch.Tensor):
+        return torch.sign(targets) * torch.log1p(torch.abs(targets))
+    else:
+        # Handle float/numpy
+        import numpy as np
+        return np.sign(targets) * np.log1p(np.abs(targets))
+
+def inverse_transform_targets(transformed_targets):
+    """
+    Inverts the log-transform to get back raw returns:
+        y = sign(y_trans) * (exp(|y_trans|) - 1)
+
+    Args:
+        transformed_targets: torch.Tensor or float, model outputs (quantile logits)
+    Returns:
+        raw returns (e.g. 1.5 for 150%)
+    """
+    if isinstance(transformed_targets, torch.Tensor):
+        return torch.sign(transformed_targets) * (torch.exp(torch.abs(transformed_targets)) - 1)
+    else:
+        import numpy as np
+        return np.sign(transformed_targets) * (np.exp(np.abs(transformed_targets)) - 1)
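A quick round-trip check for these helpers (separate from the commit; assumes inference_utils.py is importable):

    import torch
    from inference_utils import transform_targets, inverse_transform_targets

    y = torch.tensor([-0.5, 0.0, 1.5, 9999.0])    # raw returns, e.g. 1.5 = 150%
    t = transform_targets(y)                      # sign(y) * log1p(|y|)
    y_back = inverse_transform_targets(t)         # sign(t) * (exp(|t|) - 1)
    assert torch.allclose(y, y_back, rtol=1e-4)   # inverse holds within float error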
ingest.sh CHANGED
@@ -20,7 +20,7 @@ error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
 #===============================================================================
 header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
 
-EPOCHS=(844 845 846)
+EPOCHS=(844 845 846 847 848 849)
 
 log "Processing epochs one at a time to minimize disk usage..."
train.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import json
 import argparse
 import math
 import logging
 
@@ -202,6 +203,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
         if labels_mask.sum() == 0:
             return_loss = torch.tensor(0.0, device=accelerator.device)
         else:
+            # Log-transform targets for validation too (so val loss matches train loss scale)
+            labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
             return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
 
         quality_loss = quality_loss_fn(quality_preds, quality_targets)
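The same transform is applied to the training targets later in this diff. A quick numeric check, separate from the commit: sign(y) * log1p(|y|) is strictly increasing, so quantile ordering is preserved and predictions can be inverted after the fact.

    import torch

    y = torch.tensor([-0.9, 0.0, 1.5, 15.0, 9999.0])   # raw returns
    t = torch.sign(y) * torch.log1p(torch.abs(y))
    print(t)  # tensor([-0.6419, 0.0000, 0.9163, 2.7726, 9.2103])
    assert torch.all(torch.diff(t) > 0)  # monotone, so quantile order survives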
 
@@ -608,43 +611,102 @@ def main() -> None:
     # Load checkpoint if it exists
     starting_epoch = 0
     resume_step = 0
+    total_steps = 0
+    latest_checkpoint = None
 
     # Check for existing checkpoints
     if checkpoint_dir.exists():
-        # Look for subfolders named 'checkpoint-X' or 'epoch_X'
-        # Accelerate saves to folders.
        dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
        if dirs:
-            # Sort by modification time or name to find latest
-            # Sort by modification time or name to find latest
            dirs.sort(key=lambda x: x.stat().st_mtime)
            latest_checkpoint = dirs[-1]
 
     if args.resume_from_checkpoint:
+        resolved_ckpt = None
         if args.resume_from_checkpoint == "latest":
             if latest_checkpoint:
-                logger.info("=" * 60)
-                logger.info(f"🔄 RESUMING FROM LATEST CHECKPOINT: {latest_checkpoint}")
-                logger.info("=" * 60)
-                accelerator.load_state(str(latest_checkpoint))
+                resolved_ckpt = latest_checkpoint
             else:
                 logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
         else:
-            # Specific path
             custom_ckpt = Path(args.resume_from_checkpoint)
             if custom_ckpt.exists():
-                logger.info("=" * 60)
-                logger.info(f"🔄 RESUMING FROM CHECKPOINT: {custom_ckpt}")
-                logger.info("=" * 60)
-                accelerator.load_state(str(custom_ckpt))
-                logger.info("✅ Model, optimizer, scheduler, and dataloader states restored.")
+                resolved_ckpt = custom_ckpt
             else:
                 raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
+
+        if resolved_ckpt is not None:
+            logger.info("=" * 60)
+            logger.info(f"🔄 RESUMING FROM CHECKPOINT: {resolved_ckpt}")
+            logger.info("=" * 60)
+            accelerator.load_state(str(resolved_ckpt))
+
+            # Restore epoch / step counters from training_state.json
+            state_file = resolved_ckpt / "training_state.json"
+            if state_file.exists():
+                with open(state_file, "r") as f:
+                    training_state = json.load(f)
+                starting_epoch = training_state.get("epoch", 0)
+                resume_step = training_state.get("step_in_epoch", 0)
+                total_steps = training_state.get("global_step", 0)
+                logger.info(
+                    f"✅ Resumed: epoch={starting_epoch}, step_in_epoch={resume_step}, "
+                    f"global_step={total_steps}"
+                )
+            else:
+                # Try to infer step count from the restored scheduler state
+                inner_sched = scheduler.scheduler if hasattr(scheduler, 'scheduler') else scheduler
+                restored_step = getattr(inner_sched, 'last_epoch', 0)
+                if restored_step > 0:
+                    total_steps = restored_step
+                    logger.warning(
+                        f"⚠️ training_state.json not found. "
+                        f"Inferred global_step={total_steps} from scheduler state."
+                    )
+                else:
+                    logger.warning(
+                        "⚠️ training_state.json not found and scheduler step is 0. "
+                        "Epoch/step counters start from 0."
+                    )
+            logger.info("✅ Model, optimizer, scheduler, and dataloader states restored.")
+
+            # --- FIX: Rebuild scheduler to extend over new epochs ---
+            # Without this, the cosine schedule wraps/oscillates because the
+            # internal step counter (from past runs) exceeds the original
+            # max_train_steps (computed for just --epochs in this run).
+            # We extend the schedule: completed_steps → completed_steps + this_run_steps
+            # with NO warmup (model is already trained).
+            extended_total = total_steps + max_train_steps
+            logger.info(
+                f"♻️ Rebuilding LR schedule: "
+                f"completed={total_steps}, this_run={max_train_steps}, "
+                f"extended_total={extended_total} (no warmup)"
+            )
+
+            # Get the base optimizer (unwrap Accelerate wrapper)
+            base_opt = optimizer.optimizer if hasattr(optimizer, 'optimizer') else optimizer
+            new_sched = get_cosine_schedule_with_warmup(
+                base_opt,
+                num_warmup_steps=0,  # no re-warmup
+                num_training_steps=extended_total
+            )
+            # Fast-forward to current position
+            new_sched.last_epoch = total_steps
+            new_sched._step_count = total_steps + 1
+
+            # Swap into the Accelerate wrapper
+            if hasattr(scheduler, 'scheduler'):
+                scheduler.scheduler = new_sched
+            else:
+                scheduler = new_sched
+
+            # Log the resulting LR for verification
+            current_lr = base_opt.param_groups[0]['lr']
+            logger.info(f"♻️ LR after rebuild: {current_lr:.6e}")
     else:
         logger.info("🆕 No resume flag provided. Starting fresh training run.")
 
     # --- 7. Training Loop ---
-    total_steps = 0
     quality_loss_fn = nn.MSELoss()
 
     logger.info("***** Running training *****")
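A minimal standalone sketch of the fast-forward trick above, with a toy optimizer and made-up step counts (get_cosine_schedule_with_warmup is the transformers scheduler the script already uses):

    import torch
    from transformers import get_cosine_schedule_with_warmup

    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    completed, this_run = 3000, 2000

    sched = get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=0, num_training_steps=completed + this_run)
    sched.last_epoch = completed      # pretend 3000 steps already happened
    sched._step_count = completed + 1
    sched.step()                      # first step of the resumed run
    print(opt.param_groups[0]["lr"])  # decays smoothly, no warmup restart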
 
@@ -660,18 +722,21 @@ def main() -> None:
         epoch_loss = 0.0
         valid_batches = 0
 
+        # Skip already-processed batches when resuming mid-epoch
+        active_dataloader = dataloader
+        if epoch == starting_epoch and resume_step > 0:
+            logger.info(f"⏩ Skipping {resume_step} batches in epoch {epoch+1} (already processed)")
+            active_dataloader = accelerator.skip_first_batches(dataloader, num_batches=resume_step)
+
         # Tqdm only on main process
         progress_bar = tqdm(
-            dataloader,
+            active_dataloader,
             desc=f"Epoch {epoch+1}/{epochs}",
             disable=not accelerator.is_local_main_process,
-            initial=resume_step # If you calculate resume_step from checkpoint
+            initial=resume_step if epoch == starting_epoch else 0
         )
 
         for step, batch in enumerate(progress_bar):
-            # Skip steps if resuming (Accelerate dataloader might handle this automatically if configured,
-            # but 'skip_first_batches' is often manual.
-            # For simplicity here, we assume load_state restored the dataloader iterator.)
 
             if batch is None:
                 continue
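How skip_first_batches behaves, shown on a toy dataloader (illustration only):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    loader = torch.utils.data.DataLoader(list(range(10)), batch_size=2)
    resumed = accelerator.skip_first_batches(loader, num_batches=2)
    print([b.tolist() for b in resumed])  # [[4, 5], [6, 7], [8, 9]]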
 
@@ -685,7 +750,8 @@ def main() -> None:
 
             grad_stats: Optional[Dict[str, float]] = None
             module_grad_stats: Dict[str, float] = {}
-            with accelerator.accumulate(model):
+            try:
+                with accelerator.accumulate(model):
                     outputs = model(batch)
 
                     preds = outputs["quantile_logits"]
 
@@ -706,6 +772,11 @@ def main() -> None:
                         t_cutoffs[0] if t_cutoffs else "unknown",
                     )
 
+                    # Log-transform targets to stabilize Class 5 gradients:
+                    #   y_trans = sign(y) * log(1 + |y|)
+                    # Compresses 10000x return (label ~15) to ~2.7
+                    labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
+
                     if labels_mask.sum() == 0:
                         return_loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
                     else:
818
  scheduler.step()
819
  optimizer.zero_grad()
820
 
821
+ except torch.cuda.OutOfMemoryError:
822
+ # Log the offending batch for debugging
823
+ token_addresses = batch.get('token_addresses', [])
824
+ sample_indices = batch.get('sample_indices', [])
825
+ seq_len = batch.get('event_type_ids', torch.empty(0)).shape[-1] if 'event_type_ids' in batch else '?'
826
+ logger.warning(
827
+ "⚠️ CUDA OOM — skipping batch! "
828
+ "seq_len=%s sample_idx=%s token=%s | "
829
+ "Clearing cache and continuing.",
830
+ seq_len,
831
+ sample_indices[0] if sample_indices else "unknown",
832
+ token_addresses[0] if token_addresses else "unknown",
833
+ )
834
+ # Clean up to recover
835
+ optimizer.zero_grad(set_to_none=True)
836
+ torch.cuda.empty_cache()
837
+ if hasattr(torch.cuda, 'reset_peak_memory_stats'):
838
+ torch.cuda.reset_peak_memory_stats()
839
+ continue
840
+
841
  # Logging
842
  if accelerator.sync_gradients:
843
  total_steps += 1
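Note on the handler above: torch.cuda.OutOfMemoryError (added in PyTorch 1.13) subclasses RuntimeError, so a broader except RuntimeError would also catch it:

    import torch
    assert issubclass(torch.cuda.OutOfMemoryError, RuntimeError)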
 
@@ -794,7 +885,16 @@ def main() -> None:
                 save_path = checkpoint_dir / f"checkpoint-{total_steps}"
                 accelerator.save_state(output_dir=str(save_path))
 
-                # NEW: Save in standard HF-loadable way
+                # Save resume metadata
+                state_file = save_path / "training_state.json"
+                with open(state_file, "w") as f:
+                    json.dump({
+                        "epoch": epoch,
+                        "step_in_epoch": step + (resume_step if epoch == starting_epoch else 0) + 1,
+                        "global_step": total_steps,
+                    }, f, indent=2)
+
+                # Save in standard HF-loadable way
                 hf_save_path = save_path / "hf_model"
                 unwrapped_model = accelerator.unwrap_model(model)
                 unwrapped_model.save_pretrained(str(hf_save_path))
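The metadata file is plain JSON; an illustrative write/read round-trip with made-up path and values:

    import json
    from pathlib import Path

    state_file = Path("checkpoints/checkpoint-3000/training_state.json")
    state_file.parent.mkdir(parents=True, exist_ok=True)
    state_file.write_text(json.dumps(
        {"epoch": 2, "step_in_epoch": 451, "global_step": 3000}, indent=2))
    print(json.loads(state_file.read_text())["global_step"])  # 3000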
 
@@ -839,6 +939,16 @@ def main() -> None:
         # Save Checkpoint at end of epoch
         save_path = checkpoint_dir / f"epoch_{epoch+1}"
         accelerator.save_state(output_dir=str(save_path))
+
+        # Save resume metadata (epoch completed, so next epoch starts fresh)
+        state_file = save_path / "training_state.json"
+        with open(state_file, "w") as f:
+            json.dump({
+                "epoch": epoch + 1,
+                "step_in_epoch": 0,
+                "global_step": total_steps,
+            }, f, indent=2)
+
         hf_save_path = save_path / "hf_model"
         unwrapped_model = accelerator.unwrap_model(model)
         unwrapped_model.save_pretrained(str(hf_save_path))
train.sh CHANGED
@@ -1,7 +1,7 @@
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 
 accelerate launch train.py \
-  --epochs 5 \
+  --epochs 10 \
   --batch_size 2 \
   --learning_rate 1e-4 \
   --warmup_ratio 0.1 \
@@ -20,5 +20,5 @@ accelerate launch train.py \
   --num_workers 0 \
   --val_samples_per_class 2 \
   --val_every 100 \
-  --resume_from_checkpoint checkpoints/checkpoint-3000 \
+  --resume_from_checkpoint checkpoints/apollo-1 \
   "$@"