zirobtc committed on
Commit
f8896b8
·
1 Parent(s): 697fba5

Upload folder using huggingface_hub

Browse files
.gitignore CHANGED
@@ -17,3 +17,4 @@ metadata/
17
  store/
18
  preprocessed_configs/
19
  .early.coverage
 
 
17
  store/
18
  preprocessed_configs/
19
  .early.coverage
20
+ .ipynb_checkpoints/
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f6a84b63ec605e83a655f404bc89d825aa8ffbb5ac3ea24c7d2197324646d016
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc81a5fbd342767ca491eceeca805f54dd9ffe4ff0bb723cdda26e52f54f914d
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:21f6421b07eb49c1e0a5518a628403ce0ae7149fb81a600aebad2dfcaf0313c9
3
  size 2854
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4ceaef802908dd650ce1ade210a0e827ec433b904adc3bf17c3d8a877e59ae6
3
  size 2854
sample_123hWLTtXVG58ARU_0.json ADDED
The diff for this file is too large to render. See raw diff
 
sample_12m6essAkvZc4cRZ_0.json ADDED
The diff for this file is too large to render. See raw diff
 
train.py CHANGED
@@ -38,7 +38,7 @@ from torch.optim import AdamW
38
  from accelerate import Accelerator
39
  from accelerate.logging import get_logger
40
  from accelerate.utils import ProjectConfiguration, set_seed
41
- from transformers import get_linear_schedule_with_warmup
42
 
43
  # Logging
44
  from tqdm.auto import tqdm
@@ -119,7 +119,7 @@ def quantile_pinball_loss(preds: torch.Tensor,
119
  return sum(losses) / mask.sum().clamp_min(1.0)
120
 
121
 
122
- def create_balanced_split(dataset, val_ratio: float = 0.1, seed: int = 42):
123
  """
124
  Create train/val split with balanced classes in validation set.
125
  Uses dataset's internal file_class_map for speed (no file loading).
@@ -153,10 +153,10 @@ def create_balanced_split(dataset, val_ratio: float = 0.1, seed: int = 42):
153
  train_indices = []
154
  val_indices = []
155
 
156
- # For each class, take val_ratio samples for validation
157
  for class_id, indices in class_to_indices.items():
158
  random.shuffle(indices)
159
- n_val = max(1, int(len(indices) * val_ratio)) # At least 1 sample per class
160
  val_indices.extend(indices[:n_val])
161
  train_indices.extend(indices[n_val:])
162
 
@@ -329,7 +329,7 @@ def parse_args() -> argparse.Namespace:
329
  parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=True)
330
  parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
331
  parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
332
- parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of data for validation (default 0.1)")
333
  parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
334
  return parser.parse_args()
335
 
@@ -374,7 +374,10 @@ def main() -> None:
374
  logger.info("Initialized with CLI arguments.")
375
  tensorboard_dir.mkdir(parents=True, exist_ok=True)
376
  checkpoint_dir.mkdir(parents=True, exist_ok=True)
377
- accelerator.init_trackers("oracle_training")
 
 
 
378
 
379
  device = accelerator.device
380
 
@@ -413,7 +416,16 @@ def main() -> None:
413
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
414
 
415
  # Encoders
416
- multi_modal_encoder = MultiModalEncoder(dtype=init_dtype)
 
 
 
 
 
 
 
 
 
417
  time_encoder = ContextualTimeEncoder(dtype=init_dtype)
418
  token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
419
  wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
@@ -430,18 +442,6 @@ def main() -> None:
430
  max_seq_len=max_seq_len
431
  )
432
 
433
- # --- OPTIMIZATION: Pre-load SigLIP encoder to avoid lazy-loading on first batch ---
434
- # This moves the ~8s model load from first batch to startup (where it's expected)
435
- logger.info("Pre-loading SigLIP encoder for collator (avoids first-batch delay)...")
436
- from models.multi_modal_processor import MultiModalEncoder as CollatorEncoder
437
- collator_encoder = CollatorEncoder(
438
- model_id=collator.model_id,
439
- dtype=init_dtype,
440
- device="cuda" # Use GPU for encoding (requires num_workers=0)
441
- )
442
- _set_worker_encoder(collator_encoder)
443
- logger.info("SigLIP encoder pre-loaded successfully.")
444
-
445
  # ==========================================================================
446
  # OFFLINE MODE: No DB connections during training for maximum GPU utilization
447
  # ==========================================================================
@@ -475,18 +475,18 @@ def main() -> None:
475
  raise RuntimeError("Dataset is empty.")
476
 
477
  # --- NEW: Create balanced train/val split ---
478
- logger.info(f"Creating {1-args.val_split:.0%}/{args.val_split:.0%} train/val split with balanced classes...")
479
  train_indices, val_indices, class_distribution = create_balanced_split(
480
- dataset, val_ratio=args.val_split, seed=seed
481
  )
482
 
483
  # Log class distribution (use set for O(1) lookup)
484
  train_set = set(train_indices)
485
  logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
486
  for class_id, indices in sorted(class_distribution.items()):
487
- n_val = int(len(indices) * args.val_split) # Approximate from split ratio
488
  n_train = len(indices) - n_val
489
- logger.info(f" Class {class_id}: {len(indices)} total (~{n_train} train, ~{n_val} val)")
490
 
491
  # --- Compute class weights for loss weighting ---
492
  num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
@@ -544,7 +544,7 @@ def main() -> None:
544
  # Validation dataloader (no shuffle, no weighted sampling)
545
  val_dl_kwargs = dict(
546
  dataset=val_dataset,
547
- batch_size=batch_size,
548
  shuffle=False,
549
  num_workers=int(args.num_workers),
550
  pin_memory=bool(args.pin_memory),
@@ -588,7 +588,7 @@ def main() -> None:
588
  max_train_steps = epochs * num_update_steps_per_epoch
589
  num_warmup_steps = int(max_train_steps * warmup_ratio)
590
 
591
- scheduler = get_linear_schedule_with_warmup(
592
  optimizer,
593
  num_warmup_steps=num_warmup_steps,
594
  num_training_steps=max_train_steps
@@ -623,7 +623,9 @@ def main() -> None:
623
  if args.resume_from_checkpoint:
624
  if args.resume_from_checkpoint == "latest":
625
  if latest_checkpoint:
626
- logger.info(f"Resuming from latest checkpoint: {latest_checkpoint}")
 
 
627
  accelerator.load_state(str(latest_checkpoint))
628
  else:
629
  logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
@@ -631,12 +633,15 @@ def main() -> None:
631
  # Specific path
632
  custom_ckpt = Path(args.resume_from_checkpoint)
633
  if custom_ckpt.exists():
634
- logger.info(f"Resuming from specific checkpoint: {custom_ckpt}")
 
 
635
  accelerator.load_state(str(custom_ckpt))
 
636
  else:
637
  raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
638
  else:
639
- logger.info("No resume flag provided. Starting fresh.")
640
 
641
  # --- 7. Training Loop ---
642
  total_steps = 0
@@ -831,15 +836,14 @@ def main() -> None:
831
  logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
832
  accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
833
 
834
- # Save Checkpoint at end of epoch (REMOVED: saving every epoch is too much)
835
  save_path = checkpoint_dir / f"epoch_{epoch+1}"
836
- # accelerator.save_state(output_dir=str(save_path))
837
- # hf_save_path = save_path / "hf_model"
838
- # unwrapped_model = accelerator.unwrap_model(model)
839
- # unwrapped_model.save_pretrained(str(hf_save_path))
840
 
841
- # logger.info(f"Saved and HF-style model (EOF) to {save_path}")
842
- pass
843
  else:
844
  if accelerator.is_main_process:
845
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
 
38
  from accelerate import Accelerator
39
  from accelerate.logging import get_logger
40
  from accelerate.utils import ProjectConfiguration, set_seed
41
+ from transformers import get_cosine_schedule_with_warmup
42
 
43
  # Logging
44
  from tqdm.auto import tqdm
 
119
  return sum(losses) / mask.sum().clamp_min(1.0)
120
 
121
 
122
+ def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
123
  """
124
  Create train/val split with balanced classes in validation set.
125
  Uses dataset's internal file_class_map for speed (no file loading).
 
153
  train_indices = []
154
  val_indices = []
155
 
156
+ # For each class, take n_val_per_class samples for validation
157
  for class_id, indices in class_to_indices.items():
158
  random.shuffle(indices)
159
+ n_val = min(len(indices), n_val_per_class) # Ensure we don't take more than we have
160
  val_indices.extend(indices[:n_val])
161
  train_indices.extend(indices[n_val:])
162
 
 
329
  parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=True)
330
  parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
331
  parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
332
+ parser.add_argument("--val_samples_per_class", type=int, default=1, help="Number of validation samples per class (default 1)")
333
  parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
334
  return parser.parse_args()
335
 
 
374
  logger.info("Initialized with CLI arguments.")
375
  tensorboard_dir.mkdir(parents=True, exist_ok=True)
376
  checkpoint_dir.mkdir(parents=True, exist_ok=True)
377
+ from datetime import datetime
378
+ run_name = f"oracle_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
379
+ accelerator.init_trackers(run_name)
380
+ logger.info(f"📊 TensorBoard run: {run_name}")
381
 
382
  device = accelerator.device
383
 
 
416
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
417
 
418
  # Encoders
419
+ logger.info("Initializing Shared MultiModalEncoder (SigLIP) on GPU...")
420
+ # Initialize ONCE on GPU for both WalletEncoder (dims) and Collator (encoding)
421
+ multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device="cuda")
422
+
423
+ # Use this shared instance for setting the worker encoder (num_workers=0 optimization)
424
+ # This avoids loading a second copy of SigLIP
425
+ logger.info("Setting shared encoder for collator...")
426
+ from models.multi_modal_processor import MultiModalEncoder as CollatorEncoder
427
+ _set_worker_encoder(multi_modal_encoder)
428
+
429
  time_encoder = ContextualTimeEncoder(dtype=init_dtype)
430
  token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
431
  wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
 
442
  max_seq_len=max_seq_len
443
  )
444
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  # ==========================================================================
446
  # OFFLINE MODE: No DB connections during training for maximum GPU utilization
447
  # ==========================================================================
 
475
  raise RuntimeError("Dataset is empty.")
476
 
477
  # --- NEW: Create balanced train/val split ---
478
+ logger.info(f"Creating balanced split with {args.val_samples_per_class} validation samples per class...")
479
  train_indices, val_indices, class_distribution = create_balanced_split(
480
+ dataset, n_val_per_class=args.val_samples_per_class, seed=seed
481
  )
482
 
483
  # Log class distribution (use set for O(1) lookup)
484
  train_set = set(train_indices)
485
  logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
486
  for class_id, indices in sorted(class_distribution.items()):
487
+ n_val = min(len(indices), args.val_samples_per_class)
488
  n_train = len(indices) - n_val
489
+ logger.info(f" Class {class_id}: {len(indices)} total (~{n_train} train, {n_val} val)")
490
 
491
  # --- Compute class weights for loss weighting ---
492
  num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
 
544
  # Validation dataloader (no shuffle, no weighted sampling)
545
  val_dl_kwargs = dict(
546
  dataset=val_dataset,
547
+ batch_size=1, # Force batch size 1 for validation to prevent OOM
548
  shuffle=False,
549
  num_workers=int(args.num_workers),
550
  pin_memory=bool(args.pin_memory),
 
588
  max_train_steps = epochs * num_update_steps_per_epoch
589
  num_warmup_steps = int(max_train_steps * warmup_ratio)
590
 
591
+ scheduler = get_cosine_schedule_with_warmup(
592
  optimizer,
593
  num_warmup_steps=num_warmup_steps,
594
  num_training_steps=max_train_steps
 
623
  if args.resume_from_checkpoint:
624
  if args.resume_from_checkpoint == "latest":
625
  if latest_checkpoint:
626
+ logger.info("=" * 60)
627
+ logger.info(f"🔄 RESUMING FROM LATEST CHECKPOINT: {latest_checkpoint}")
628
+ logger.info("=" * 60)
629
  accelerator.load_state(str(latest_checkpoint))
630
  else:
631
  logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
 
633
  # Specific path
634
  custom_ckpt = Path(args.resume_from_checkpoint)
635
  if custom_ckpt.exists():
636
+ logger.info("=" * 60)
637
+ logger.info(f"🔄 RESUMING FROM CHECKPOINT: {custom_ckpt}")
638
+ logger.info("=" * 60)
639
  accelerator.load_state(str(custom_ckpt))
640
+ logger.info("✅ Model, optimizer, scheduler, and dataloader states restored.")
641
  else:
642
  raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
643
  else:
644
+ logger.info("🆕 No resume flag provided. Starting fresh training run.")
645
 
646
  # --- 7. Training Loop ---
647
  total_steps = 0
 
836
  logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
837
  accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
838
 
839
+ # Save Checkpoint at end of epoch
840
  save_path = checkpoint_dir / f"epoch_{epoch+1}"
841
+ accelerator.save_state(output_dir=str(save_path))
842
+ hf_save_path = save_path / "hf_model"
843
+ unwrapped_model = accelerator.unwrap_model(model)
844
+ unwrapped_model.save_pretrained(str(hf_save_path))
845
 
846
+ logger.info(f"Saved epoch checkpoint and HF-style model to {save_path}")
 
847
  else:
848
  if accelerator.is_main_process:
849
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
train.sh CHANGED
@@ -1,22 +1,24 @@
 
 
1
  accelerate launch train.py \
2
- --epochs 1 \
3
- --batch_size 8 \
4
  --learning_rate 1e-4 \
5
  --warmup_ratio 0.1 \
6
- --grad_accum_steps 2 \
7
  --max_grad_norm 1.0 \
8
  --seed 42 \
9
- --log_every 3 \
10
- --save_every 2000 \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
  --max_seq_len 4096 \
15
- --horizons_seconds 30 60 120 240 420 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
  --num_workers 0 \
19
- --pin_memory \
20
- --val_split 0.1 \
21
- --val_every 50 \
22
  "$@"
 
1
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
2
+
3
  accelerate launch train.py \
4
+ --epochs 5 \
5
+ --batch_size 2 \
6
  --learning_rate 1e-4 \
7
  --warmup_ratio 0.1 \
8
+ --grad_accum_steps 8 \
9
  --max_grad_norm 1.0 \
10
  --seed 42 \
11
+ --log_every 10 \
12
+ --save_every 1000 \
13
  --tensorboard_dir runs/oracle \
14
  --checkpoint_dir checkpoints \
15
  --mixed_precision bf16 \
16
  --max_seq_len 4096 \
17
+ --horizons_seconds 300 900 1800 3600 7200 \
18
  --quantiles 0.1 0.5 0.9 \
19
  --ohlc_stats_path ./data/ohlc_stats.npz \
20
  --num_workers 0 \
21
+ --val_samples_per_class 2 \
22
+ --val_every 100 \
23
+ --resume_from_checkpoint checkpoints/checkpoint-3000 \
24
  "$@"