Upload folder using huggingface_hub
Browse files- data/ohlc_stats.npz +1 -1
- inference_utils.py +34 -0
- ingest.sh +1 -1
- train.py +132 -22
- train.sh +2 -2
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92f50d146182941b8b01be19b4699c1b0ebe37bac1ff155580b20a8755994070
|
| 3 |
size 1660
|
inference_utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def transform_targets(targets):
|
| 4 |
+
"""
|
| 5 |
+
Applies the log-transform used during training:
|
| 6 |
+
y_trans = sign(y) * log(1 + |y|)
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
targets: torch.Tensor or float, raw returns (e.g. 1.5 for 150%)
|
| 10 |
+
Returns:
|
| 11 |
+
transformed targets in the same shape/type
|
| 12 |
+
"""
|
| 13 |
+
if isinstance(targets, torch.Tensor):
|
| 14 |
+
return torch.sign(targets) * torch.log1p(torch.abs(targets))
|
| 15 |
+
else:
|
| 16 |
+
# Handle float/numpy
|
| 17 |
+
import numpy as np
|
| 18 |
+
return np.sign(targets) * np.log1p(np.abs(targets))
|
| 19 |
+
|
| 20 |
+
def inverse_transform_targets(transformed_targets):
|
| 21 |
+
"""
|
| 22 |
+
Inverts the log-transform to get back raw returns:
|
| 23 |
+
y = sign(y_trans) * (exp(|y_trans|) - 1)
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
transformed_targets: torch.Tensor, model outputs (logits)
|
| 27 |
+
Returns:
|
| 28 |
+
raw returns (e.g. 1.5 for 150%)
|
| 29 |
+
"""
|
| 30 |
+
if isinstance(transformed_targets, torch.Tensor):
|
| 31 |
+
return torch.sign(transformed_targets) * (torch.exp(torch.abs(transformed_targets)) - 1)
|
| 32 |
+
else:
|
| 33 |
+
import numpy as np
|
| 34 |
+
return np.sign(transformed_targets) * (np.exp(np.abs(transformed_targets)) - 1)
|
ingest.sh
CHANGED
|
@@ -20,7 +20,7 @@ error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
|
|
| 20 |
#===============================================================================
|
| 21 |
header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
|
| 22 |
|
| 23 |
-
EPOCHS=(844 845 846)
|
| 24 |
|
| 25 |
|
| 26 |
log "Processing epochs one at a time to minimize disk usage..."
|
|
|
|
| 20 |
#===============================================================================
|
| 21 |
header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
|
| 22 |
|
| 23 |
+
EPOCHS=(844 845 846 847 848 849)
|
| 24 |
|
| 25 |
|
| 26 |
log "Processing epochs one at a time to minimize disk usage..."
|
train.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import argparse
|
| 3 |
import math
|
| 4 |
import logging
|
|
@@ -202,6 +203,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
|
|
| 202 |
if labels_mask.sum() == 0:
|
| 203 |
return_loss = torch.tensor(0.0, device=accelerator.device)
|
| 204 |
else:
|
|
|
|
|
|
|
| 205 |
return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
|
| 206 |
|
| 207 |
quality_loss = quality_loss_fn(quality_preds, quality_targets)
|
|
@@ -608,43 +611,102 @@ def main() -> None:
|
|
| 608 |
# Load checkpoint if it exists
|
| 609 |
starting_epoch = 0
|
| 610 |
resume_step = 0
|
|
|
|
|
|
|
| 611 |
|
| 612 |
# Check for existing checkpoints
|
| 613 |
if checkpoint_dir.exists():
|
| 614 |
-
# Look for subfolders named 'checkpoint-X' or 'epoch_X'
|
| 615 |
-
# Accelerate saves to folders.
|
| 616 |
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
|
| 617 |
if dirs:
|
| 618 |
-
# Sort by modification time or name to find latest
|
| 619 |
-
# Sort by modification time or name to find latest
|
| 620 |
dirs.sort(key=lambda x: x.stat().st_mtime)
|
| 621 |
latest_checkpoint = dirs[-1]
|
| 622 |
|
| 623 |
if args.resume_from_checkpoint:
|
|
|
|
| 624 |
if args.resume_from_checkpoint == "latest":
|
| 625 |
if latest_checkpoint:
|
| 626 |
-
|
| 627 |
-
logger.info(f"🔄 RESUMING FROM LATEST CHECKPOINT: {latest_checkpoint}")
|
| 628 |
-
logger.info("=" * 60)
|
| 629 |
-
accelerator.load_state(str(latest_checkpoint))
|
| 630 |
else:
|
| 631 |
logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
|
| 632 |
else:
|
| 633 |
-
# Specific path
|
| 634 |
custom_ckpt = Path(args.resume_from_checkpoint)
|
| 635 |
if custom_ckpt.exists():
|
| 636 |
-
|
| 637 |
-
logger.info(f"🔄 RESUMING FROM CHECKPOINT: {custom_ckpt}")
|
| 638 |
-
logger.info("=" * 60)
|
| 639 |
-
accelerator.load_state(str(custom_ckpt))
|
| 640 |
-
logger.info("✅ Model, optimizer, scheduler, and dataloader states restored.")
|
| 641 |
else:
|
| 642 |
raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
else:
|
| 644 |
logger.info("🆕 No resume flag provided. Starting fresh training run.")
|
| 645 |
|
| 646 |
# --- 7. Training Loop ---
|
| 647 |
-
total_steps = 0
|
| 648 |
quality_loss_fn = nn.MSELoss()
|
| 649 |
|
| 650 |
logger.info("***** Running training *****")
|
|
@@ -660,18 +722,21 @@ def main() -> None:
|
|
| 660 |
epoch_loss = 0.0
|
| 661 |
valid_batches = 0
|
| 662 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
# Tqdm only on main process
|
| 664 |
progress_bar = tqdm(
|
| 665 |
-
|
| 666 |
desc=f"Epoch {epoch+1}/{epochs}",
|
| 667 |
disable=not accelerator.is_local_main_process,
|
| 668 |
-
initial=resume_step
|
| 669 |
)
|
| 670 |
|
| 671 |
for step, batch in enumerate(progress_bar):
|
| 672 |
-
# Skip steps if resuming (Accelerate dataloader might handle this automatically if configured,
|
| 673 |
-
# but 'skip_first_batches' is often manual.
|
| 674 |
-
# For simplicity here, we assume load_state restored the dataloader iterator.)
|
| 675 |
|
| 676 |
if batch is None:
|
| 677 |
continue
|
|
@@ -685,7 +750,8 @@ def main() -> None:
|
|
| 685 |
|
| 686 |
grad_stats: Optional[Dict[str, float]] = None
|
| 687 |
module_grad_stats: Dict[str, float] = {}
|
| 688 |
-
|
|
|
|
| 689 |
outputs = model(batch)
|
| 690 |
|
| 691 |
preds = outputs["quantile_logits"]
|
|
@@ -706,6 +772,11 @@ def main() -> None:
|
|
| 706 |
t_cutoffs[0] if t_cutoffs else "unknown",
|
| 707 |
)
|
| 708 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
if labels_mask.sum() == 0:
|
| 710 |
return_loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
|
| 711 |
else:
|
|
@@ -747,6 +818,26 @@ def main() -> None:
|
|
| 747 |
scheduler.step()
|
| 748 |
optimizer.zero_grad()
|
| 749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
# Logging
|
| 751 |
if accelerator.sync_gradients:
|
| 752 |
total_steps += 1
|
|
@@ -794,7 +885,16 @@ def main() -> None:
|
|
| 794 |
save_path = checkpoint_dir / f"checkpoint-{total_steps}"
|
| 795 |
accelerator.save_state(output_dir=str(save_path))
|
| 796 |
|
| 797 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
hf_save_path = save_path / "hf_model"
|
| 799 |
unwrapped_model = accelerator.unwrap_model(model)
|
| 800 |
unwrapped_model.save_pretrained(str(hf_save_path))
|
|
@@ -839,6 +939,16 @@ def main() -> None:
|
|
| 839 |
# Save Checkpoint at end of epoch
|
| 840 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 841 |
accelerator.save_state(output_dir=str(save_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
hf_save_path = save_path / "hf_model"
|
| 843 |
unwrapped_model = accelerator.unwrap_model(model)
|
| 844 |
unwrapped_model.save_pretrained(str(hf_save_path))
|
|
|
|
| 1 |
import os
|
| 2 |
+
import json
|
| 3 |
import argparse
|
| 4 |
import math
|
| 5 |
import logging
|
|
|
|
| 203 |
if labels_mask.sum() == 0:
|
| 204 |
return_loss = torch.tensor(0.0, device=accelerator.device)
|
| 205 |
else:
|
| 206 |
+
# Log-transform targets for validation too (so val loss matches train loss scale)
|
| 207 |
+
labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
|
| 208 |
return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
|
| 209 |
|
| 210 |
quality_loss = quality_loss_fn(quality_preds, quality_targets)
|
|
|
|
| 611 |
# Load checkpoint if it exists
|
| 612 |
starting_epoch = 0
|
| 613 |
resume_step = 0
|
| 614 |
+
total_steps = 0
|
| 615 |
+
latest_checkpoint = None
|
| 616 |
|
| 617 |
# Check for existing checkpoints
|
| 618 |
if checkpoint_dir.exists():
|
|
|
|
|
|
|
| 619 |
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
|
| 620 |
if dirs:
|
|
|
|
|
|
|
| 621 |
dirs.sort(key=lambda x: x.stat().st_mtime)
|
| 622 |
latest_checkpoint = dirs[-1]
|
| 623 |
|
| 624 |
if args.resume_from_checkpoint:
|
| 625 |
+
resolved_ckpt = None
|
| 626 |
if args.resume_from_checkpoint == "latest":
|
| 627 |
if latest_checkpoint:
|
| 628 |
+
resolved_ckpt = latest_checkpoint
|
|
|
|
|
|
|
|
|
|
| 629 |
else:
|
| 630 |
logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
|
| 631 |
else:
|
|
|
|
| 632 |
custom_ckpt = Path(args.resume_from_checkpoint)
|
| 633 |
if custom_ckpt.exists():
|
| 634 |
+
resolved_ckpt = custom_ckpt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
else:
|
| 636 |
raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
|
| 637 |
+
|
| 638 |
+
if resolved_ckpt is not None:
|
| 639 |
+
logger.info("=" * 60)
|
| 640 |
+
logger.info(f"🔄 RESUMING FROM CHECKPOINT: {resolved_ckpt}")
|
| 641 |
+
logger.info("=" * 60)
|
| 642 |
+
accelerator.load_state(str(resolved_ckpt))
|
| 643 |
+
|
| 644 |
+
# Restore epoch / step counters from training_state.json
|
| 645 |
+
state_file = resolved_ckpt / "training_state.json"
|
| 646 |
+
if state_file.exists():
|
| 647 |
+
with open(state_file, "r") as f:
|
| 648 |
+
training_state = json.load(f)
|
| 649 |
+
starting_epoch = training_state.get("epoch", 0)
|
| 650 |
+
resume_step = training_state.get("step_in_epoch", 0)
|
| 651 |
+
total_steps = training_state.get("global_step", 0)
|
| 652 |
+
logger.info(
|
| 653 |
+
f"✅ Resumed: epoch={starting_epoch}, step_in_epoch={resume_step}, "
|
| 654 |
+
f"global_step={total_steps}"
|
| 655 |
+
)
|
| 656 |
+
else:
|
| 657 |
+
# Try to infer step count from the restored scheduler state
|
| 658 |
+
inner_sched = scheduler.scheduler if hasattr(scheduler, 'scheduler') else scheduler
|
| 659 |
+
restored_step = getattr(inner_sched, 'last_epoch', 0)
|
| 660 |
+
if restored_step > 0:
|
| 661 |
+
total_steps = restored_step
|
| 662 |
+
logger.warning(
|
| 663 |
+
f"⚠️ training_state.json not found. "
|
| 664 |
+
f"Inferred global_step={total_steps} from scheduler state."
|
| 665 |
+
)
|
| 666 |
+
else:
|
| 667 |
+
logger.warning(
|
| 668 |
+
"⚠️ training_state.json not found and scheduler step is 0. "
|
| 669 |
+
"Epoch/step counters start from 0."
|
| 670 |
+
)
|
| 671 |
+
logger.info("✅ Model, optimizer, scheduler, and dataloader states restored.")
|
| 672 |
+
|
| 673 |
+
# --- FIX: Rebuild scheduler to extend over new epochs ---
|
| 674 |
+
# Without this, the cosine schedule wraps/oscillates because the
|
| 675 |
+
# internal step counter (from past runs) exceeds the original
|
| 676 |
+
# max_train_steps (computed for just --epochs in this run).
|
| 677 |
+
# We extend the schedule: completed_steps → completed_steps + this_run_steps
|
| 678 |
+
# with NO warmup (model is already trained).
|
| 679 |
+
extended_total = total_steps + max_train_steps
|
| 680 |
+
logger.info(
|
| 681 |
+
f"♻️ Rebuilding LR schedule: "
|
| 682 |
+
f"completed={total_steps}, this_run={max_train_steps}, "
|
| 683 |
+
f"extended_total={extended_total} (no warmup)"
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Get the base optimizer (unwrap Accelerate wrapper)
|
| 687 |
+
base_opt = optimizer.optimizer if hasattr(optimizer, 'optimizer') else optimizer
|
| 688 |
+
new_sched = get_cosine_schedule_with_warmup(
|
| 689 |
+
base_opt,
|
| 690 |
+
num_warmup_steps=0, # no re-warmup
|
| 691 |
+
num_training_steps=extended_total
|
| 692 |
+
)
|
| 693 |
+
# Fast-forward to current position
|
| 694 |
+
new_sched.last_epoch = total_steps
|
| 695 |
+
new_sched._step_count = total_steps + 1
|
| 696 |
+
|
| 697 |
+
# Swap into the Accelerate wrapper
|
| 698 |
+
if hasattr(scheduler, 'scheduler'):
|
| 699 |
+
scheduler.scheduler = new_sched
|
| 700 |
+
else:
|
| 701 |
+
scheduler = new_sched
|
| 702 |
+
|
| 703 |
+
# Log the resulting LR for verification
|
| 704 |
+
current_lr = base_opt.param_groups[0]['lr']
|
| 705 |
+
logger.info(f"♻️ LR after rebuild: {current_lr:.6e}")
|
| 706 |
else:
|
| 707 |
logger.info("🆕 No resume flag provided. Starting fresh training run.")
|
| 708 |
|
| 709 |
# --- 7. Training Loop ---
|
|
|
|
| 710 |
quality_loss_fn = nn.MSELoss()
|
| 711 |
|
| 712 |
logger.info("***** Running training *****")
|
|
|
|
| 722 |
epoch_loss = 0.0
|
| 723 |
valid_batches = 0
|
| 724 |
|
| 725 |
+
# Skip already-processed batches when resuming mid-epoch
|
| 726 |
+
active_dataloader = dataloader
|
| 727 |
+
if epoch == starting_epoch and resume_step > 0:
|
| 728 |
+
logger.info(f"⏩ Skipping {resume_step} batches in epoch {epoch+1} (already processed)")
|
| 729 |
+
active_dataloader = accelerator.skip_first_batches(dataloader, num_batches=resume_step)
|
| 730 |
+
|
| 731 |
# Tqdm only on main process
|
| 732 |
progress_bar = tqdm(
|
| 733 |
+
active_dataloader,
|
| 734 |
desc=f"Epoch {epoch+1}/{epochs}",
|
| 735 |
disable=not accelerator.is_local_main_process,
|
| 736 |
+
initial=resume_step if epoch == starting_epoch else 0
|
| 737 |
)
|
| 738 |
|
| 739 |
for step, batch in enumerate(progress_bar):
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
if batch is None:
|
| 742 |
continue
|
|
|
|
| 750 |
|
| 751 |
grad_stats: Optional[Dict[str, float]] = None
|
| 752 |
module_grad_stats: Dict[str, float] = {}
|
| 753 |
+
try:
|
| 754 |
+
with accelerator.accumulate(model):
|
| 755 |
outputs = model(batch)
|
| 756 |
|
| 757 |
preds = outputs["quantile_logits"]
|
|
|
|
| 772 |
t_cutoffs[0] if t_cutoffs else "unknown",
|
| 773 |
)
|
| 774 |
|
| 775 |
+
# Log-transform targets to stabilize Class 5 gradients
|
| 776 |
+
# y_trans = sign(y) * log(1 + |y|)
|
| 777 |
+
# Compresses 10000x return (label ~15) to ~2.7
|
| 778 |
+
labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
|
| 779 |
+
|
| 780 |
if labels_mask.sum() == 0:
|
| 781 |
return_loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
|
| 782 |
else:
|
|
|
|
| 818 |
scheduler.step()
|
| 819 |
optimizer.zero_grad()
|
| 820 |
|
| 821 |
+
except torch.cuda.OutOfMemoryError:
|
| 822 |
+
# Log the offending batch for debugging
|
| 823 |
+
token_addresses = batch.get('token_addresses', [])
|
| 824 |
+
sample_indices = batch.get('sample_indices', [])
|
| 825 |
+
seq_len = batch.get('event_type_ids', torch.empty(0)).shape[-1] if 'event_type_ids' in batch else '?'
|
| 826 |
+
logger.warning(
|
| 827 |
+
"⚠️ CUDA OOM — skipping batch! "
|
| 828 |
+
"seq_len=%s sample_idx=%s token=%s | "
|
| 829 |
+
"Clearing cache and continuing.",
|
| 830 |
+
seq_len,
|
| 831 |
+
sample_indices[0] if sample_indices else "unknown",
|
| 832 |
+
token_addresses[0] if token_addresses else "unknown",
|
| 833 |
+
)
|
| 834 |
+
# Clean up to recover
|
| 835 |
+
optimizer.zero_grad(set_to_none=True)
|
| 836 |
+
torch.cuda.empty_cache()
|
| 837 |
+
if hasattr(torch.cuda, 'reset_peak_memory_stats'):
|
| 838 |
+
torch.cuda.reset_peak_memory_stats()
|
| 839 |
+
continue
|
| 840 |
+
|
| 841 |
# Logging
|
| 842 |
if accelerator.sync_gradients:
|
| 843 |
total_steps += 1
|
|
|
|
| 885 |
save_path = checkpoint_dir / f"checkpoint-{total_steps}"
|
| 886 |
accelerator.save_state(output_dir=str(save_path))
|
| 887 |
|
| 888 |
+
# Save resume metadata
|
| 889 |
+
state_file = save_path / "training_state.json"
|
| 890 |
+
with open(state_file, "w") as f:
|
| 891 |
+
json.dump({
|
| 892 |
+
"epoch": epoch,
|
| 893 |
+
"step_in_epoch": step + (resume_step if epoch == starting_epoch else 0) + 1,
|
| 894 |
+
"global_step": total_steps,
|
| 895 |
+
}, f, indent=2)
|
| 896 |
+
|
| 897 |
+
# Save in standard HF-loadable way
|
| 898 |
hf_save_path = save_path / "hf_model"
|
| 899 |
unwrapped_model = accelerator.unwrap_model(model)
|
| 900 |
unwrapped_model.save_pretrained(str(hf_save_path))
|
|
|
|
| 939 |
# Save Checkpoint at end of epoch
|
| 940 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 941 |
accelerator.save_state(output_dir=str(save_path))
|
| 942 |
+
|
| 943 |
+
# Save resume metadata (epoch completed, so next epoch starts fresh)
|
| 944 |
+
state_file = save_path / "training_state.json"
|
| 945 |
+
with open(state_file, "w") as f:
|
| 946 |
+
json.dump({
|
| 947 |
+
"epoch": epoch + 1,
|
| 948 |
+
"step_in_epoch": 0,
|
| 949 |
+
"global_step": total_steps,
|
| 950 |
+
}, f, indent=2)
|
| 951 |
+
|
| 952 |
hf_save_path = save_path / "hf_model"
|
| 953 |
unwrapped_model = accelerator.unwrap_model(model)
|
| 954 |
unwrapped_model.save_pretrained(str(hf_save_path))
|
train.sh
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 2 |
|
| 3 |
accelerate launch train.py \
|
| 4 |
-
--epochs
|
| 5 |
--batch_size 2 \
|
| 6 |
--learning_rate 1e-4 \
|
| 7 |
--warmup_ratio 0.1 \
|
|
@@ -20,5 +20,5 @@ accelerate launch train.py \
|
|
| 20 |
--num_workers 0 \
|
| 21 |
--val_samples_per_class 2 \
|
| 22 |
--val_every 100 \
|
| 23 |
-
--resume_from_checkpoint checkpoints/
|
| 24 |
"$@"
|
|
|
|
| 1 |
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 2 |
|
| 3 |
accelerate launch train.py \
|
| 4 |
+
--epochs 10 \
|
| 5 |
--batch_size 2 \
|
| 6 |
--learning_rate 1e-4 \
|
| 7 |
--warmup_ratio 0.1 \
|
|
|
|
| 20 |
--num_workers 0 \
|
| 21 |
--val_samples_per_class 2 \
|
| 22 |
--val_every 100 \
|
| 23 |
+
--resume_from_checkpoint checkpoints/apollo-1 \
|
| 24 |
"$@"
|