
Spider-FLEXITOKENS Remote Training Guide

Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)

  • GPU: RTX 6000 Pro (Blackwell architecture, sm120+)
  • VRAM: 48GB GDDR7
  • Precision: MXFP8 (rowwise_with_gw_hp recipe) as primary, FP8_DYNAMIC as fallback
  • Expected peak VRAM: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)

Quick Start

# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh

Environment Setup

Required Dependencies

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"  # quoted so the shell does not treat > as an output redirect
pip install datasets transformers
pip install bitsandbytes  # optional; only used for the BF16 fallback

Optional (Recommended)

pip install unsloth  # MoE kernel optimizations + memory-efficient gradient checkpointing

Verify Installation

python3 - <<'EOF'
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
major, minor = torch.cuda.get_device_capability(0)
print(f'Compute capability: sm{major}{minor}')  # sm120 expected on the RTX 6000 Pro (Blackwell)

import torchao
print(f'torchao: {torchao.__version__}')

from torchao.float8 import Float8LinearConfig  # noqa: F401 (import check only)
from torchao.float8.config import Float8LinearRecipeName
print('FP8 training: available')
print(f'Recipes: {[r.value for r in Float8LinearRecipeName]}')
EOF

Expected output on RTX 6000 Pro: sm120 or higher, all 3 recipes available (tensorwise, rowwise, rowwise_with_gw_hp).

Configuration

Environment Variables

| Variable | Default | Description |
|---|---|---|
| PRECISION | mxfp8 | Training precision: mxfp8, fp8_dynamic, bf16 |
| SEQ_LEN | 2048 | Sequence length per sample |
| MICRO_BATCH | 8 | Batch size per forward pass |
| GRAD_ACCUM | 4 | Gradient accumulation steps |
| TARGET_TOKENS | 10000000000 | Total training tokens (10B) |
| N_LOOPS | 6 | Recurrent loop iterations |
| LR | 3e-4 | Peak learning rate |
| CKPT_EVERY | 500 | Save a checkpoint every N steps |
| CKPT_DIR | checkpoints-spider-remote | Checkpoint output directory |
| RESUME | (empty) | Path to a checkpoint for manual resume |
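For reference, reading these variables with the defaults above might look like the sketch below. `TrainConfig` and `load_config` are hypothetical names used only for illustration; the actual launch script may parse the variables in shell or in some other way.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

VALID_PRECISIONS = ("mxfp8", "fp8_dynamic", "bf16")


@dataclass(frozen=True)
class TrainConfig:  # hypothetical container mirroring the table
    precision: str
    seq_len: int
    micro_batch: int
    grad_accum: int
    target_tokens: int
    n_loops: int
    lr: float
    ckpt_every: int
    ckpt_dir: str
    resume: Optional[str]  # None when RESUME is unset or empty (auto-resume)


def load_config(env: Mapping[str, str] = os.environ) -> TrainConfig:
    """Read the documented variables, using the table defaults when a variable is unset."""
    precision = env.get("PRECISION", "mxfp8")
    if precision not in VALID_PRECISIONS:
        raise ValueError(f"PRECISION must be one of {VALID_PRECISIONS}, got {precision!r}")
    return TrainConfig(
        precision=precision,
        seq_len=int(env.get("SEQ_LEN", "2048")),
        micro_batch=int(env.get("MICRO_BATCH", "8")),
        grad_accum=int(env.get("GRAD_ACCUM", "4")),
        target_tokens=int(env.get("TARGET_TOKENS", "10000000000")),
        n_loops=int(env.get("N_LOOPS", "6")),
        lr=float(env.get("LR", "3e-4")),
        ckpt_every=int(env.get("CKPT_EVERY", "500")),
        ckpt_dir=env.get("CKPT_DIR", "checkpoints-spider-remote"),
        resume=env.get("RESUME") or None,
    )
```

Passing an explicit mapping, e.g. `load_config({"PRECISION": "fp8_dynamic"})`, makes the defaults easy to inspect without touching the real environment.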

Recommended Settings for RTX 6000 Pro (48GB)

# MXFP8: maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# ~10B tokens ≈ 152,000 steps

Conservative Settings (if VRAM-constrained)

export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same tokens per optimizer step, lower peak VRAM)
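The global-batch arithmetic in the comments above can be checked with a few lines of Python. This is a sketch that assumes a single GPU with no data parallelism; `tokens_per_step` and `steps_for_budget` are hypothetical helper names.

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Tokens consumed per optimizer step (assumes one GPU, no data parallelism)."""
    return micro_batch * grad_accum * seq_len


def steps_for_budget(target_tokens: int, micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Optimizer steps needed to reach target_tokens, rounded up (integer ceiling division)."""
    per_step = tokens_per_step(micro_batch, grad_accum, seq_len)
    return -(-target_tokens // per_step)


if __name__ == "__main__":
    # Recommended (8 x 4) and conservative (4 x 8) settings from this guide
    for micro_batch, grad_accum in [(8, 4), (4, 8)]:
        per_step = tokens_per_step(micro_batch, grad_accum, 2048)
        steps = steps_for_budget(10_000_000_000, micro_batch, grad_accum, 2048)
        print(f"MICRO_BATCH={micro_batch} GRAD_ACCUM={grad_accum}: {per_step:,} tok/step, {steps:,} steps")
```

Rounded up, 10B tokens at 65,536 tokens/step comes to 152,588 steps, which the settings above round to about 152,000. Halving MICRO_BATCH while doubling GRAD_ACCUM leaves both numbers unchanged; only the per-forward-pass memory changes.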

Launch

Fresh Training Run

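# Starts from scratch only if CKPT_DIR holds no earlier checkpoints (otherwise auto-resume applies)
bash scripts/train_remote.sh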

Resume from Checkpoint

# Auto-resume (picks latest from CKPT_DIR)
bash scripts/train_remote.sh

# Manual resume from specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
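A sketch of how RESUME and auto-resume could interact, assuming that an explicit RESUME takes priority and that auto-resume only considers spider-step{N}.pt files (see Checkpoint Files). The helper names are hypothetical; the training script's actual selection logic may differ.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches the step-checkpoint naming used in this guide, e.g. "spider-step5000.pt"
STEP_CKPT = re.compile(r"^spider-step(\d+)\.pt$")


def latest_step_checkpoint(names: Iterable[str]) -> Optional[str]:
    """Return the spider-step{N}.pt name with the largest N (numeric, not lexicographic)."""
    best_name, best_step = None, -1
    for name in names:
        match = STEP_CKPT.match(name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


def resolve_resume(resume: str, ckpt_dir: str) -> Optional[Path]:
    """An explicit RESUME path wins; otherwise pick the latest step checkpoint in ckpt_dir."""
    if resume:
        return Path(resume)
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None  # fresh run: nothing to resume from
    latest = latest_step_checkpoint(p.name for p in directory.iterdir())
    return directory / latest if latest else None
```

The numeric comparison matters: sorting names as plain strings would rank spider-step900.pt above spider-step1000.pt.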

Resume from Local Smoke Test

Transfer the local checkpoint to the remote machine, then:

export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh

Note: The local checkpoint was trained with 8-bit AdamW (BF16). On resume with MXFP8/FP8, the training script will:

  1. Load model weights (always compatible)
  2. Skip the 8-bit optimizer state with a warning (8-bit → standard AdamW mismatch)
  3. Continue training with standard AdamW, starting from freshly initialized optimizer state

This is by design: the optimizer state mismatch is handled gracefully.
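One way to detect the mismatch without touching tensors is to inspect the saved optimizer state: torch.optim.AdamW stores `exp_avg` and `exp_avg_sq` for every parameter in `state_dict()["state"]`, and anything lacking those buffers can be skipped. This is a minimal sketch under that assumption; the checkpoint keys `"model"` and `"optimizer"` and the function names are hypothetical, not the script's confirmed layout.

```python
import warnings
from typing import Any, Mapping, Optional, Tuple

# Per-parameter keys that torch.optim.AdamW keeps in state_dict()["state"] entries
ADAMW_KEYS = {"exp_avg", "exp_avg_sq"}


def is_standard_adamw_state(opt_state: Mapping[str, Any]) -> bool:
    """True if every per-parameter entry carries the AdamW moment buffers."""
    return all(ADAMW_KEYS <= set(entry) for entry in opt_state.get("state", {}).values())


def split_checkpoint(ckpt: Mapping[str, Any]) -> Tuple[Any, Optional[Mapping[str, Any]]]:
    """Return (model_state, optimizer_state_or_None); incompatible optimizer state is dropped."""
    model_state = ckpt["model"]  # hypothetical key; weights are always loaded
    opt_state = ckpt.get("optimizer")  # hypothetical key
    if opt_state is not None and not is_standard_adamw_state(opt_state):
        warnings.warn("Optimizer state is not standard AdamW (e.g. 8-bit); starting with a fresh optimizer.")
        opt_state = None
    return model_state, opt_state
```

The returned optimizer state would then be passed to `optimizer.load_state_dict(...)` only when it is not None.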

Monitoring

Training Logs

The script outputs structured logs every 10 steps:

Epoch 1 | step    10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens

Key metrics:

  • loss: Total loss (lm + aux + bp)
  • lm: Language modeling loss
  • aux: MoE load-balancing auxiliary loss
  • bp: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
  • gnorm: Gradient norm (should stabilize ~1-5)
  • tok/s: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
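To post-process the logs, a small regex parser keyed on the example line above is enough. This sketch assumes the exact field layout shown there; `parse_log_line` and `loss_is_consistent` are hypothetical helpers, not part of the training script. The second function checks the stated relation loss = lm + aux + bp.

```python
import math
import re

# Regexes assume the field layout of the example log line in this guide.
PATTERNS = {
    "step": r"step\s+(\d+)/\d+",
    "total_steps": r"step\s+\d+/(\d+)",
    "loss": r"\| loss ([\d.]+)",
    "lm": r"\| lm ([\d.]+)",
    "aux": r"\| aux ([\d.]+)",
    "bp": r"\| bp ([\d.]+)",
    "gnorm": r"\| gnorm ([\d.]+)",
    "lr": r"\| lr ([\d.eE+-]+)",
    "mtok_per_s": r"\| ([\d.]+)M tok/s",
}
INT_FIELDS = {"step", "total_steps"}


def parse_log_line(line: str) -> dict:
    """Extract numeric metrics from one training log line; missing fields are skipped."""
    metrics = {}
    for name, pattern in PATTERNS.items():
        match = re.search(pattern, line)
        if match:
            raw = match.group(1)
            metrics[name] = int(raw) if name in INT_FIELDS else float(raw)
    return metrics


def loss_is_consistent(metrics: dict, tol: float = 1e-3) -> bool:
    """Check that the total loss equals lm + aux + bp within a tolerance."""
    total = metrics["lm"] + metrics["aux"] + metrics["bp"]
    return math.isclose(metrics["loss"], total, abs_tol=tol)


line = ("Epoch 1 | step    10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | "
        "bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens")
metrics = parse_log_line(line)
print(metrics["step"], metrics["loss"], loss_is_consistent(metrics))
```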

VRAM Monitoring

watch -n 5 nvidia-smi

Expected on RTX 6000 Pro with MXFP8:

  • Model: ~4GB
  • Optimizer: ~8GB (standard AdamW, FP32 states)
  • Activations: ~4-8GB (gradient checkpointing enabled)
  • Peak: ~15-20GB total
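The optimizer figure can be sanity-checked by hand: standard AdamW keeps two FP32 moment buffers (`exp_avg`, `exp_avg_sq`) per parameter, i.e. 8 bytes/param, so ~8GB of optimizer state corresponds to roughly 1B trainable parameters. The sketch below makes that arithmetic explicit. The parameter count is a hypothetical input chosen for illustration, not a stated property of this model, and gradients and activations are not modeled.

```python
GB = 10**9  # decimal gigabytes

BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp8": 1}


def adamw_state_bytes(n_params: int, state_dtype: str = "fp32") -> int:
    """AdamW keeps two moment buffers (exp_avg, exp_avg_sq) per parameter."""
    return 2 * n_params * BYTES_PER_PARAM[state_dtype]


def weight_bytes(n_params: int, dtype: str) -> int:
    """Memory for one copy of the weights in the given dtype."""
    return n_params * BYTES_PER_PARAM[dtype]


if __name__ == "__main__":
    n_params = 1_000_000_000  # HYPOTHETICAL parameter count, for illustration only
    print(f"AdamW states (fp32): {adamw_state_bytes(n_params) / GB:.1f} GB")
    for dtype in ("fp32", "bf16", "fp8"):
        print(f"weights ({dtype}): {weight_bytes(n_params, dtype) / GB:.1f} GB")
```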

Health Warnings

The RecurrentMonitor checks for:

  • Representation drift: Loop hidden states diverging (cosine sim < 0.5)
  • Collapse: All experts producing identical outputs (std < 1e-6)

If you see these warnings, consider reducing N_LOOPS or lowering the learning rate (LR).
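A minimal pure-Python sketch of the two checks, assuming drift is measured between the hidden states of consecutive loop iterations and collapse as the spread of expert outputs for the same input. `check_health` is a hypothetical name; the actual RecurrentMonitor is not shown in this guide and likely works on batched tensors.

```python
import math
from statistics import pstdev
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def check_health(loop_states: Sequence[Sequence[float]],
                 expert_outputs: Sequence[Sequence[float]],
                 drift_threshold: float = 0.5,
                 collapse_threshold: float = 1e-6) -> List[str]:
    """Return warnings for representation drift and expert collapse (thresholds from this guide)."""
    issues = []
    # Drift: compare each loop's hidden state with the previous loop's
    for i in range(1, len(loop_states)):
        sim = cosine_similarity(loop_states[i - 1], loop_states[i])
        if sim < drift_threshold:
            issues.append(f"drift: loop {i - 1}->{i} cosine similarity {sim:.3f}")
    # Collapse: largest per-feature standard deviation across experts
    if len(expert_outputs) > 1:
        spread = max(pstdev(column) for column in zip(*expert_outputs))
        if spread < collapse_threshold:
            issues.append(f"collapse: expert output std {spread:.2e}")
    return issues
```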

Precision Fallback Chain

The training script automatically falls back if precision setup fails:

MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
  • MXFP8: Row-wise scaling + high-precision grad weight accumulation. Best accuracy.
  • FP8_DYNAMIC: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
  • BF16: No quantization. Most VRAM, but simplest path.
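The fallback order can be sketched as a loop over the chain. This assumes the compute capability is checked before attempting setup and that any setup exception moves on to the next mode; `select_precision` and its `setup` hook are hypothetical stand-ins for whatever the training script actually does.

```python
from typing import Callable, Tuple

CHAIN = ("mxfp8", "fp8_dynamic", "bf16")
# Minimum (major, minor) compute capability per mode, as listed in the chain above.
# bf16 gets (0, 0) to mirror the "all GPUs" entry.
MIN_CAPABILITY = {"mxfp8": (12, 0), "fp8_dynamic": (8, 9), "bf16": (0, 0)}


def select_precision(requested: str,
                     capability: Tuple[int, int],
                     setup: Callable[[str], None] = lambda mode: None) -> str:
    """Walk the chain from `requested` downward; return the first supported mode that sets up cleanly."""
    for mode in CHAIN[CHAIN.index(requested):]:
        if capability < MIN_CAPABILITY[mode]:
            print(f"{mode}: needs sm{''.join(map(str, MIN_CAPABILITY[mode]))}+, skipping")
            continue
        try:
            setup(mode)  # hypothetical hook, e.g. converting Linear layers; any exception triggers fallback
        except Exception as err:
            print(f"{mode}: setup failed ({err}), falling back")
            continue
        return mode
    raise RuntimeError("no usable precision mode")
```

Passing a `setup` callable that raises for "mxfp8" makes `select_precision("mxfp8", (12, 0), setup)` return "fp8_dynamic", which models the "falls back if precision setup fails" behavior.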

Checkpoint Files

| File | Description |
|---|---|
| spider-step{N}.pt | Step checkpoint (every CKPT_EVERY steps) |
| spider-ep{N}.pt | Epoch-boundary checkpoint |
| spider-best.pt | Best-loss checkpoint (updated when the epoch loss improves) |
| spider-final-ep{N}.pt | Final checkpoint at the end of training |

Each checkpoint contains:

  • Model state dict
  • Optimizer state dict
  • Training step, epoch, config
  • best_loss value
  • Boundary predictor (BP) optimizer state (if active)
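The save rules in the table can be summarized as a small function. This is a hypothetical sketch: it assumes spider-best.pt is written at an epoch boundary when the epoch loss beats best_loss, and that the final checkpoint is written at the last epoch boundary; the real script may order or combine these saves differently.

```python
from typing import List, Optional


def checkpoint_files(step: int,
                     ckpt_every: int,
                     epoch_done: Optional[int] = None,
                     epoch_loss: Optional[float] = None,
                     best_loss: float = float("inf"),
                     training_done: bool = False) -> List[str]:
    """Names of the files that would be written at this point in training."""
    files = []
    if step > 0 and step % ckpt_every == 0:
        files.append(f"spider-step{step}.pt")
    if epoch_done is not None:  # only set at an epoch boundary
        files.append(f"spider-ep{epoch_done}.pt")
        if epoch_loss is not None and epoch_loss < best_loss:
            files.append("spider-best.pt")  # overwritten whenever the epoch loss improves
        if training_done:
            files.append(f"spider-final-ep{epoch_done}.pt")
    return files
```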

Troubleshooting

mat2 shape must be divisible by 16

Resolved by setting pad_inner_dim=True in torchao's Float8LinearConfig (torchao v0.17.0+). The training script handles this automatically.
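Why padding is safe: appending zeros to the inner (reduction) dimension of both operands only adds zero terms to each dot product, so the result is unchanged while the shape becomes a multiple of 16. The pure-Python sketch below illustrates that arithmetic only; it is not how torchao implements the padding internally, and the helper names are hypothetical.

```python
def pad_to_multiple(dim: int, multiple: int = 16) -> int:
    """Round dim up to the next multiple (ceiling division)."""
    return -(-dim // multiple) * multiple


def matmul(a, b):
    """Plain-Python product of an (m x k) and a (k x n) list-of-lists matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def pad_inner(a, b, multiple: int = 16):
    """Zero-pad the shared inner dimension k of a (m x k) and b (k x n)."""
    k = len(b)
    extra = pad_to_multiple(k, multiple) - k
    a_padded = [row + [0] * extra for row in a]
    b_padded = b + [[0] * len(b[0]) for _ in range(extra)]
    return a_padded, b_padded


a = [[1, 2, 3]]      # 1 x 3
b = [[1], [2], [3]]  # 3 x 1
a_p, b_p = pad_inner(a, b)
print(len(a_p[0]), len(b_p), matmul(a, b) == matmul(a_p, b_p))
```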

CUDA out of memory

Reduce MICRO_BATCH or increase GRAD_ACCUM to maintain the same global batch size:

export MICRO_BATCH=4    # was 8
export GRAD_ACCUM=8     # was 4 (same 65,536 tok/step)

Optimizer state mismatch on resume

Expected when resuming a BF16 (8-bit AdamW) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues: model weights load fine, and the optimizer restarts from scratch.

Slower than expected throughput

  • Ensure PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is set (default in script)
  • Check torch.compile isn't being used inadvertently (adds compile overhead)
  • Verify torchao version >= 0.17.0 for optimal FP8 kernels
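The first and last checks in this list can be scripted. This is a minimal sketch; `throughput_checklist` and `version_tuple` are hypothetical helpers, and the version parser only looks at the leading X.Y[.Z] numbers, so suffixes such as "+cu128" are ignored.

```python
import re
from typing import List, Mapping, Tuple


def version_tuple(version: str) -> Tuple[int, int, int]:
    """Parse the leading 'X.Y[.Z]' of a version string such as '0.17.0+cu128'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def throughput_checklist(env: Mapping[str, str], torchao_version: str) -> List[str]:
    """Return a list of problems found; an empty list means both checks pass."""
    issues = []
    alloc_conf = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    if "expandable_segments:True" not in alloc_conf.split(","):
        issues.append("PYTORCH_CUDA_ALLOC_CONF does not include expandable_segments:True")
    if version_tuple(torchao_version) < (0, 17, 0):
        issues.append(f"torchao {torchao_version} is older than 0.17.0")
    return issues
```

On the training machine this would be called as `throughput_checklist(os.environ, torchao.__version__)`.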