Upload folder using huggingface_hub
Browse files- .gitignore +1 -0
- data/ohlc_stats.npz +1 -1
- log.log +1 -1
- sample_123hWLTtXVG58ARU_0.json +0 -0
- sample_12m6essAkvZc4cRZ_0.json +0 -0
- train.py +39 -35
- train.sh +11 -9
.gitignore
CHANGED
|
@@ -17,3 +17,4 @@ metadata/
|
|
| 17 |
store/
|
| 18 |
preprocessed_configs/
|
| 19 |
.early.coverage
|
|
|
|
|
|
| 17 |
store/
|
| 18 |
preprocessed_configs/
|
| 19 |
.early.coverage
|
| 20 |
+
.ipynb_checkpoints/
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc81a5fbd342767ca491eceeca805f54dd9ffe4ff0bb723cdda26e52f54f914d
|
| 3 |
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2854
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4ceaef802908dd650ce1ade210a0e827ec433b904adc3bf17c3d8a877e59ae6
|
| 3 |
size 2854
|
sample_123hWLTtXVG58ARU_0.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_12m6essAkvZc4cRZ_0.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train.py
CHANGED
|
@@ -38,7 +38,7 @@ from torch.optim import AdamW
|
|
| 38 |
from accelerate import Accelerator
|
| 39 |
from accelerate.logging import get_logger
|
| 40 |
from accelerate.utils import ProjectConfiguration, set_seed
|
| 41 |
-
from transformers import
|
| 42 |
|
| 43 |
# Logging
|
| 44 |
from tqdm.auto import tqdm
|
|
@@ -119,7 +119,7 @@ def quantile_pinball_loss(preds: torch.Tensor,
|
|
| 119 |
return sum(losses) / mask.sum().clamp_min(1.0)
|
| 120 |
|
| 121 |
|
| 122 |
-
def create_balanced_split(dataset,
|
| 123 |
"""
|
| 124 |
Create train/val split with balanced classes in validation set.
|
| 125 |
Uses dataset's internal file_class_map for speed (no file loading).
|
|
@@ -153,10 +153,10 @@ def create_balanced_split(dataset, val_ratio: float = 0.1, seed: int = 42):
|
|
| 153 |
train_indices = []
|
| 154 |
val_indices = []
|
| 155 |
|
| 156 |
-
# For each class, take
|
| 157 |
for class_id, indices in class_to_indices.items():
|
| 158 |
random.shuffle(indices)
|
| 159 |
-
n_val =
|
| 160 |
val_indices.extend(indices[:n_val])
|
| 161 |
train_indices.extend(indices[n_val:])
|
| 162 |
|
|
@@ -329,7 +329,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 329 |
parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=True)
|
| 330 |
parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
|
| 331 |
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
|
| 332 |
-
parser.add_argument("--
|
| 333 |
parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
|
| 334 |
return parser.parse_args()
|
| 335 |
|
|
@@ -374,7 +374,10 @@ def main() -> None:
|
|
| 374 |
logger.info("Initialized with CLI arguments.")
|
| 375 |
tensorboard_dir.mkdir(parents=True, exist_ok=True)
|
| 376 |
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
device = accelerator.device
|
| 380 |
|
|
@@ -413,7 +416,16 @@ def main() -> None:
|
|
| 413 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 414 |
|
| 415 |
# Encoders
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
time_encoder = ContextualTimeEncoder(dtype=init_dtype)
|
| 418 |
token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
|
| 419 |
wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
|
|
@@ -430,18 +442,6 @@ def main() -> None:
|
|
| 430 |
max_seq_len=max_seq_len
|
| 431 |
)
|
| 432 |
|
| 433 |
-
# --- OPTIMIZATION: Pre-load SigLIP encoder to avoid lazy-loading on first batch ---
|
| 434 |
-
# This moves the ~8s model load from first batch to startup (where it's expected)
|
| 435 |
-
logger.info("Pre-loading SigLIP encoder for collator (avoids first-batch delay)...")
|
| 436 |
-
from models.multi_modal_processor import MultiModalEncoder as CollatorEncoder
|
| 437 |
-
collator_encoder = CollatorEncoder(
|
| 438 |
-
model_id=collator.model_id,
|
| 439 |
-
dtype=init_dtype,
|
| 440 |
-
device="cuda" # Use GPU for encoding (requires num_workers=0)
|
| 441 |
-
)
|
| 442 |
-
_set_worker_encoder(collator_encoder)
|
| 443 |
-
logger.info("SigLIP encoder pre-loaded successfully.")
|
| 444 |
-
|
| 445 |
# ==========================================================================
|
| 446 |
# OFFLINE MODE: No DB connections during training for maximum GPU utilization
|
| 447 |
# ==========================================================================
|
|
@@ -475,18 +475,18 @@ def main() -> None:
|
|
| 475 |
raise RuntimeError("Dataset is empty.")
|
| 476 |
|
| 477 |
# --- NEW: Create balanced train/val split ---
|
| 478 |
-
logger.info(f"Creating {
|
| 479 |
train_indices, val_indices, class_distribution = create_balanced_split(
|
| 480 |
-
dataset,
|
| 481 |
)
|
| 482 |
|
| 483 |
# Log class distribution (use set for O(1) lookup)
|
| 484 |
train_set = set(train_indices)
|
| 485 |
logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
|
| 486 |
for class_id, indices in sorted(class_distribution.items()):
|
| 487 |
-
n_val =
|
| 488 |
n_train = len(indices) - n_val
|
| 489 |
-
logger.info(f" Class {class_id}: {len(indices)} total (~{n_train} train,
|
| 490 |
|
| 491 |
# --- Compute class weights for loss weighting ---
|
| 492 |
num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
|
|
@@ -544,7 +544,7 @@ def main() -> None:
|
|
| 544 |
# Validation dataloader (no shuffle, no weighted sampling)
|
| 545 |
val_dl_kwargs = dict(
|
| 546 |
dataset=val_dataset,
|
| 547 |
-
batch_size=
|
| 548 |
shuffle=False,
|
| 549 |
num_workers=int(args.num_workers),
|
| 550 |
pin_memory=bool(args.pin_memory),
|
|
@@ -588,7 +588,7 @@ def main() -> None:
|
|
| 588 |
max_train_steps = epochs * num_update_steps_per_epoch
|
| 589 |
num_warmup_steps = int(max_train_steps * warmup_ratio)
|
| 590 |
|
| 591 |
-
scheduler =
|
| 592 |
optimizer,
|
| 593 |
num_warmup_steps=num_warmup_steps,
|
| 594 |
num_training_steps=max_train_steps
|
|
@@ -623,7 +623,9 @@ def main() -> None:
|
|
| 623 |
if args.resume_from_checkpoint:
|
| 624 |
if args.resume_from_checkpoint == "latest":
|
| 625 |
if latest_checkpoint:
|
| 626 |
-
logger.info(
|
|
|
|
|
|
|
| 627 |
accelerator.load_state(str(latest_checkpoint))
|
| 628 |
else:
|
| 629 |
logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
|
|
@@ -631,12 +633,15 @@ def main() -> None:
|
|
| 631 |
# Specific path
|
| 632 |
custom_ckpt = Path(args.resume_from_checkpoint)
|
| 633 |
if custom_ckpt.exists():
|
| 634 |
-
logger.info(
|
|
|
|
|
|
|
| 635 |
accelerator.load_state(str(custom_ckpt))
|
|
|
|
| 636 |
else:
|
| 637 |
raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
|
| 638 |
else:
|
| 639 |
-
logger.info("No resume flag provided. Starting fresh.")
|
| 640 |
|
| 641 |
# --- 7. Training Loop ---
|
| 642 |
total_steps = 0
|
|
@@ -831,15 +836,14 @@ def main() -> None:
|
|
| 831 |
logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
|
| 832 |
accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
|
| 833 |
|
| 834 |
-
# Save Checkpoint at end of epoch
|
| 835 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
|
| 841 |
-
|
| 842 |
-
pass
|
| 843 |
else:
|
| 844 |
if accelerator.is_main_process:
|
| 845 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
|
|
|
| 38 |
from accelerate import Accelerator
|
| 39 |
from accelerate.logging import get_logger
|
| 40 |
from accelerate.utils import ProjectConfiguration, set_seed
|
| 41 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 42 |
|
| 43 |
# Logging
|
| 44 |
from tqdm.auto import tqdm
|
|
|
|
| 119 |
return sum(losses) / mask.sum().clamp_min(1.0)
|
| 120 |
|
| 121 |
|
| 122 |
+
def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
|
| 123 |
"""
|
| 124 |
Create train/val split with balanced classes in validation set.
|
| 125 |
Uses dataset's internal file_class_map for speed (no file loading).
|
|
|
|
| 153 |
train_indices = []
|
| 154 |
val_indices = []
|
| 155 |
|
| 156 |
+
# For each class, take n_val_per_class samples for validation
|
| 157 |
for class_id, indices in class_to_indices.items():
|
| 158 |
random.shuffle(indices)
|
| 159 |
+
n_val = min(len(indices), n_val_per_class) # Ensure we don't take more than we have
|
| 160 |
val_indices.extend(indices[:n_val])
|
| 161 |
train_indices.extend(indices[n_val:])
|
| 162 |
|
|
|
|
| 329 |
parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=True)
|
| 330 |
parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
|
| 331 |
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
|
| 332 |
+
parser.add_argument("--val_samples_per_class", type=int, default=1, help="Number of validation samples per class (default 1)")
|
| 333 |
parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
|
| 334 |
return parser.parse_args()
|
| 335 |
|
|
|
|
| 374 |
logger.info("Initialized with CLI arguments.")
|
| 375 |
tensorboard_dir.mkdir(parents=True, exist_ok=True)
|
| 376 |
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 377 |
+
from datetime import datetime
|
| 378 |
+
run_name = f"oracle_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 379 |
+
accelerator.init_trackers(run_name)
|
| 380 |
+
logger.info(f"๐ TensorBoard run: {run_name}")
|
| 381 |
|
| 382 |
device = accelerator.device
|
| 383 |
|
|
|
|
| 416 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 417 |
|
| 418 |
# Encoders
|
| 419 |
+
logger.info("Initializing Shared MultiModalEncoder (SigLIP) on GPU...")
|
| 420 |
+
# Initialize ONCE on GPU for both WalletEncoder (dims) and Collator (encoding)
|
| 421 |
+
multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device="cuda")
|
| 422 |
+
|
| 423 |
+
# Use this shared instance for setting the worker encoder (num_workers=0 optimization)
|
| 424 |
+
# This avoids loading a second copy of SigLIP
|
| 425 |
+
logger.info("Setting shared encoder for collator...")
|
| 426 |
+
from models.multi_modal_processor import MultiModalEncoder as CollatorEncoder
|
| 427 |
+
_set_worker_encoder(multi_modal_encoder)
|
| 428 |
+
|
| 429 |
time_encoder = ContextualTimeEncoder(dtype=init_dtype)
|
| 430 |
token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
|
| 431 |
wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
|
|
|
|
| 442 |
max_seq_len=max_seq_len
|
| 443 |
)
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
# ==========================================================================
|
| 446 |
# OFFLINE MODE: No DB connections during training for maximum GPU utilization
|
| 447 |
# ==========================================================================
|
|
|
|
| 475 |
raise RuntimeError("Dataset is empty.")
|
| 476 |
|
| 477 |
# --- NEW: Create balanced train/val split ---
|
| 478 |
+
logger.info(f"Creating balanced split with {args.val_samples_per_class} validation samples per class...")
|
| 479 |
train_indices, val_indices, class_distribution = create_balanced_split(
|
| 480 |
+
dataset, n_val_per_class=args.val_samples_per_class, seed=seed
|
| 481 |
)
|
| 482 |
|
| 483 |
# Log class distribution (use set for O(1) lookup)
|
| 484 |
train_set = set(train_indices)
|
| 485 |
logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
|
| 486 |
for class_id, indices in sorted(class_distribution.items()):
|
| 487 |
+
n_val = min(len(indices), args.val_samples_per_class)
|
| 488 |
n_train = len(indices) - n_val
|
| 489 |
+
logger.info(f" Class {class_id}: {len(indices)} total (~{n_train} train, {n_val} val)")
|
| 490 |
|
| 491 |
# --- Compute class weights for loss weighting ---
|
| 492 |
num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
|
|
|
|
| 544 |
# Validation dataloader (no shuffle, no weighted sampling)
|
| 545 |
val_dl_kwargs = dict(
|
| 546 |
dataset=val_dataset,
|
| 547 |
+
batch_size=1, # Force batch size 1 for validation to prevent OOM
|
| 548 |
shuffle=False,
|
| 549 |
num_workers=int(args.num_workers),
|
| 550 |
pin_memory=bool(args.pin_memory),
|
|
|
|
| 588 |
max_train_steps = epochs * num_update_steps_per_epoch
|
| 589 |
num_warmup_steps = int(max_train_steps * warmup_ratio)
|
| 590 |
|
| 591 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 592 |
optimizer,
|
| 593 |
num_warmup_steps=num_warmup_steps,
|
| 594 |
num_training_steps=max_train_steps
|
|
|
|
| 623 |
if args.resume_from_checkpoint:
|
| 624 |
if args.resume_from_checkpoint == "latest":
|
| 625 |
if latest_checkpoint:
|
| 626 |
+
logger.info("=" * 60)
|
| 627 |
+
logger.info(f"๐ RESUMING FROM LATEST CHECKPOINT: {latest_checkpoint}")
|
| 628 |
+
logger.info("=" * 60)
|
| 629 |
accelerator.load_state(str(latest_checkpoint))
|
| 630 |
else:
|
| 631 |
logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
|
|
|
|
| 633 |
# Specific path
|
| 634 |
custom_ckpt = Path(args.resume_from_checkpoint)
|
| 635 |
if custom_ckpt.exists():
|
| 636 |
+
logger.info("=" * 60)
|
| 637 |
+
logger.info(f"๐ RESUMING FROM CHECKPOINT: {custom_ckpt}")
|
| 638 |
+
logger.info("=" * 60)
|
| 639 |
accelerator.load_state(str(custom_ckpt))
|
| 640 |
+
logger.info("โ
Model, optimizer, scheduler, and dataloader states restored.")
|
| 641 |
else:
|
| 642 |
raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
|
| 643 |
else:
|
| 644 |
+
logger.info("๐ No resume flag provided. Starting fresh training run.")
|
| 645 |
|
| 646 |
# --- 7. Training Loop ---
|
| 647 |
total_steps = 0
|
|
|
|
| 836 |
logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
|
| 837 |
accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
|
| 838 |
|
| 839 |
+
# Save Checkpoint at end of epoch
|
| 840 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 841 |
+
accelerator.save_state(output_dir=str(save_path))
|
| 842 |
+
hf_save_path = save_path / "hf_model"
|
| 843 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 844 |
+
unwrapped_model.save_pretrained(str(hf_save_path))
|
| 845 |
|
| 846 |
+
logger.info(f"Saved epoch checkpoint and HF-style model to {save_path}")
|
|
|
|
| 847 |
else:
|
| 848 |
if accelerator.is_main_process:
|
| 849 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
train.sh
CHANGED
|
@@ -1,22 +1,24 @@
|
|
|
|
|
|
|
|
| 1 |
accelerate launch train.py \
|
| 2 |
-
--epochs
|
| 3 |
-
--batch_size
|
| 4 |
--learning_rate 1e-4 \
|
| 5 |
--warmup_ratio 0.1 \
|
| 6 |
-
--grad_accum_steps
|
| 7 |
--max_grad_norm 1.0 \
|
| 8 |
--seed 42 \
|
| 9 |
-
--log_every
|
| 10 |
-
--save_every
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
| 13 |
--mixed_precision bf16 \
|
| 14 |
--max_seq_len 4096 \
|
| 15 |
-
--horizons_seconds
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 18 |
--num_workers 0 \
|
| 19 |
-
--
|
| 20 |
-
--
|
| 21 |
-
--
|
| 22 |
"$@"
|
|
|
|
| 1 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 2 |
+
|
| 3 |
accelerate launch train.py \
|
| 4 |
+
--epochs 5 \
|
| 5 |
+
--batch_size 2 \
|
| 6 |
--learning_rate 1e-4 \
|
| 7 |
--warmup_ratio 0.1 \
|
| 8 |
+
--grad_accum_steps 8 \
|
| 9 |
--max_grad_norm 1.0 \
|
| 10 |
--seed 42 \
|
| 11 |
+
--log_every 10 \
|
| 12 |
+
--save_every 1000 \
|
| 13 |
--tensorboard_dir runs/oracle \
|
| 14 |
--checkpoint_dir checkpoints \
|
| 15 |
--mixed_precision bf16 \
|
| 16 |
--max_seq_len 4096 \
|
| 17 |
+
--horizons_seconds 300 900 1800 3600 7200 \
|
| 18 |
--quantiles 0.1 0.5 0.9 \
|
| 19 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 20 |
--num_workers 0 \
|
| 21 |
+
--val_samples_per_class 2 \
|
| 22 |
+
--val_every 100 \
|
| 23 |
+
--resume_from_checkpoint checkpoints/checkpoint-3000 \
|
| 24 |
"$@"
|