# Spider-FLEXITOKENS Remote Training Guide

## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)

- **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+)
- **VRAM**: 48GB GDDR7
- **Precision**: MXFP8 (`rowwise_with_gw_hp` recipe) as primary; FP8_DYNAMIC as fallback
- **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)

## Quick Start

```bash
# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh
```

## Environment Setup

### Required Dependencies

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"   # quote the spec: unquoted, the shell treats ">" as a redirect
pip install datasets transformers
pip install bitsandbytes          # optional: only used for the BF16 fallback
```

### Optional (Recommended)

```bash
pip install unsloth  # MoE kernel optimizations + memory-efficient gradient checkpointing
```

### Verify Installation

```bash
python3 -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
# Print major and minor so Blackwell shows as sm120 (not just sm12)
major, minor = torch.cuda.get_device_capability(0)
print(f'Compute capability: sm{major}{minor}')
import torchao
print(f'torchao: {torchao.__version__}')
from torchao.float8 import Float8LinearConfig  # import succeeds only if FP8 training is available
from torchao.float8.config import Float8LinearRecipeName
print('FP8 training: available')
print(f'Recipes: {[r.value for r in Float8LinearRecipeName]}')
"
```

Expected output on RTX 6000 Pro: `sm120` or higher, with all three recipes listed (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`).
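The check above only prints values. To make it fail loudly, you could compare those values against this guide's stated requirements (sm120 for MXFP8, torchao 0.17.0 or newer) with a small preflight script like the sketch below. The helper names (`sm_string`, `parse_version`, `preflight`) are hypothetical and not part of the repo. The version parser is a simplified assumption: it reads only the leading numeric components, so local tags such as `+cu128` and `.dev` suffixes are ignored.

```python
from typing import List, Tuple

MIN_TORCHAO = (0, 17, 0)        # this guide's requirement: torchao>=0.17.0
MXFP8_MIN_CAPABILITY = (12, 0)  # sm120, per this guide


def sm_string(capability: Tuple[int, int]) -> str:
    """Format a (major, minor) capability as this guide writes it: (12, 0) -> 'sm120'."""
    major, minor = capability
    return f"sm{major}{minor}"


def parse_version(version: str) -> Tuple[int, ...]:
    """Simplified parser: '0.17.0+cu128' -> (0, 17, 0); '0.18.0.dev20250101' -> (0, 18, 0)."""
    parts: List[int] = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component, e.g. "dev20250101"
        parts.append(int(digits))
    parts += [0] * (3 - len(parts))  # pad "0.17" -> (0, 17, 0)
    return tuple(parts[:3])


def preflight(capability: Tuple[int, int], torchao_version: str) -> List[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems = []
    if tuple(capability) < MXFP8_MIN_CAPABILITY:
        problems.append(f"{sm_string(capability)} is below sm120; MXFP8 will be skipped")
    if parse_version(torchao_version) < MIN_TORCHAO:
        problems.append(f"torchao {torchao_version} is older than 0.17.0")
    return problems


if __name__ == "__main__":
    try:
        import torch
        import torchao
    except ImportError as exc:
        print(f"preflight skipped: {exc}")
    else:
        if not torch.cuda.is_available():
            print("preflight skipped: no CUDA device visible")
        else:
            cap = torch.cuda.get_device_capability(0)
            issues = preflight(cap, torchao.__version__)
            print(sm_string(cap), "OK" if not issues else issues)
```

Returning a list of problems rather than raising on the first failure lets a launch script report every unmet requirement in one pass.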
## Configuration

### Environment Variables

| Variable | Default | Description |
|---|---|---|
| `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` |
| `SEQ_LEN` | `2048` | Sequence length per sample |
| `MICRO_BATCH` | `8` | Batch size per forward pass |
| `GRAD_ACCUM` | `4` | Gradient accumulation steps |
| `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
| `N_LOOPS` | `6` | Recurrent loop iterations |
| `LR` | `3e-4` | Peak learning rate |
| `CKPT_EVERY` | `500` | Save a checkpoint every N steps |
| `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
| `RESUME` | _(empty)_ | Path to a checkpoint for manual resume |

### Recommended Settings for RTX 6000 Pro (48GB)

```bash
# MXFP8: maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# 10B tokens / 65,536 tokens/step ≈ 152,600 steps
```

### Conservative Settings (if VRAM-constrained)

```bash
export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same global batch, lower peak VRAM)
```

## Launch

### Fresh Training Run

```bash
bash scripts/train_remote.sh
```

If `CKPT_DIR` already contains checkpoints, this same command auto-resumes (see below); point `CKPT_DIR` at a new directory to start fresh.

### Resume from Checkpoint

```bash
# Auto-resume (picks the latest checkpoint in CKPT_DIR)
bash scripts/train_remote.sh

# Manual resume from a specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
```

### Resume from Local Smoke Test

Transfer the local checkpoint to the remote machine, then:

```bash
export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh
```

**Note**: The local checkpoint was trained in BF16 with 8-bit AdamW. When it is resumed under MXFP8/FP8, the training script will:

1. Load the model weights (always compatible)
2. Skip the 8-bit optimizer state with a warning (8-bit AdamW state does not match standard AdamW)
3. Continue training with standard AdamW, starting from freshly initialized optimizer state

This is by design: the optimizer state mismatch is handled gracefully.

## Monitoring

### Training Logs

The script prints a structured log line every 10 steps:

```
Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
```

Key metrics:

- **loss**: Total loss (lm + aux + bp)
- **lm**: Language modeling loss
- **aux**: MoE load-balancing auxiliary loss
- **bp**: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
- **gnorm**: Gradient norm (should stabilize around 1-5)
- **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)

### VRAM Monitoring

```bash
watch -n 5 nvidia-smi
```

Expected on RTX 6000 Pro with MXFP8:

- Model: ~2GB (weights in FP8)
- Optimizer: ~8GB (standard AdamW, FP32 states)
- Activations: ~4-8GB (gradient checkpointing enabled)
- **Peak**: ~15-20GB total

### Health Warnings

The `RecurrentMonitor` checks for:

- **Representation drift**: Loop hidden states diverging (cosine similarity < 0.5)
- **Collapse**: All experts producing identical outputs (std < 1e-6)

If you see these warnings, consider reducing `N_LOOPS` or lowering the learning rate (`LR`).

## Precision Fallback Chain

The training script automatically falls back to the next tier if precision setup fails:

```
MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
```

- **MXFP8**: Row-wise scaling, with the weight-gradient matmul kept in high precision. Best accuracy.
- **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
- **BF16**: No quantization. Uses the most VRAM, but is the simplest path.
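The following is a minimal sketch of the fallback order. It assumes that eligibility for each tier is decided by compute capability and that any exception raised during setup moves on to the next tier. It is not the training script's code: `select_precision`, the `try_setup` hook, and `simulate_mxfp8_failure` are hypothetical names, and the real script may detect setup failures differently.

```python
from typing import Callable, Dict, Tuple

# Tier order and minimum (major, minor) compute capability, taken from the chain above.
FALLBACK_CHAIN = ["mxfp8", "fp8_dynamic", "bf16"]
MIN_CAPABILITY: Dict[str, Tuple[int, int]] = {
    "mxfp8": (12, 0),       # sm120+ (Blackwell)
    "fp8_dynamic": (8, 9),  # sm89+ (Ada)
    "bf16": (0, 0),         # treated as always available, as the chain states
}


def no_op_setup(name: str) -> None:
    """Default hook: pretend setup always succeeds."""


def select_precision(
    requested: str,
    capability: Tuple[int, int],
    try_setup: Callable[[str], None] = no_op_setup,
) -> str:
    """Start at PRECISION=`requested` and walk down the chain; return the first
    tier the GPU supports whose setup hook does not raise."""
    if requested not in FALLBACK_CHAIN:
        raise ValueError(f"unknown PRECISION={requested!r}; expected one of {FALLBACK_CHAIN}")
    for name in FALLBACK_CHAIN[FALLBACK_CHAIN.index(requested):]:
        major, minor = MIN_CAPABILITY[name]
        if tuple(capability) < (major, minor):
            print(f"[precision] {name} needs sm{major}{minor}+, falling back")
            continue
        try:
            try_setup(name)  # hypothetical hook, e.g. converting nn.Linear layers
        except Exception as exc:  # mirrors "falls back if precision setup fails"
            print(f"[precision] {name} setup failed ({exc}), falling back")
            continue
        return name
    raise RuntimeError("no precision tier could be set up")


def simulate_mxfp8_failure(name: str) -> None:
    """Demo hook that fails only for MXFP8."""
    if name == "mxfp8":
        raise RuntimeError("simulated MXFP8 setup error")


if __name__ == "__main__":
    print(select_precision("mxfp8", (12, 0)))                          # returns mxfp8
    print(select_precision("mxfp8", (8, 9)))                           # returns fp8_dynamic
    print(select_precision("mxfp8", (12, 0), simulate_mxfp8_failure))  # returns fp8_dynamic
```

Because the walk starts at the requested tier and only moves downward, setting `PRECISION=fp8_dynamic` never attempts MXFP8, even on a Blackwell GPU.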
## Checkpoint Files

| File | Description |
|---|---|
| `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) |
| `spider-ep{N}.pt` | Epoch boundary checkpoint |
| `spider-best.pt` | Best-loss checkpoint (updated when the epoch loss improves) |
| `spider-final-ep{N}.pt` | Final checkpoint at the end of training |

Each checkpoint contains:

- Model state dict
- Optimizer state dict
- Training step, epoch, and config
- `best_loss` value
- BP (boundary predictor) optimizer state (if active)

## Troubleshooting

### `mat2 shape must be divisible by 16`

Fixed with `pad_inner_dim=True` in `Float8LinearConfig` (v0.17.0+). The training script sets this automatically.

### `CUDA out of memory`

Reduce `MICRO_BATCH` and increase `GRAD_ACCUM` by the same factor to keep the global batch size unchanged:

```bash
export MICRO_BATCH=4  # was 8
export GRAD_ACCUM=8   # was 4 (same 65,536 tokens/step)
```

### Optimizer state mismatch on resume

This is expected when resuming a BF16 (8-bit AdamW) checkpoint under FP8/MXFP8 (standard AdamW). The script logs a warning and continues: model weights load fine, and the optimizer restarts from scratch.

### Slower than expected throughput

- Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (the script sets it by default)
- Check that `torch.compile` isn't enabled inadvertently (it adds compile overhead)
- Verify torchao >= 0.17.0 for optimal FP8 kernels
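The batch-size arithmetic behind the recommended settings and the out-of-memory fix can be checked with a few lines of Python. These helpers are illustrative only; their names are hypothetical and not taken from `scripts/train_remote.sh`. They assume a single GPU, as in this guide, so tokens per optimizer step = `MICRO_BATCH × GRAD_ACCUM × SEQ_LEN`. They also assume the step count is `TARGET_TOKENS` divided by that value, rounded up.

```python
import math
from typing import Tuple


def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Global batch in tokens on one GPU: MICRO_BATCH * GRAD_ACCUM * SEQ_LEN."""
    return micro_batch * grad_accum * seq_len


def total_steps(target_tokens: int, micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Optimizer steps needed to reach TARGET_TOKENS (rounded up; an assumption)."""
    return math.ceil(target_tokens / tokens_per_step(micro_batch, grad_accum, seq_len))


def halve_micro_batch(micro_batch: int, grad_accum: int) -> Tuple[int, int]:
    """OOM remedy: halve MICRO_BATCH and double GRAD_ACCUM so tokens per step are unchanged."""
    if micro_batch < 2 or micro_batch % 2:
        raise ValueError("MICRO_BATCH must be an even number >= 2 to halve cleanly")
    return micro_batch // 2, grad_accum * 2


if __name__ == "__main__":
    print(tokens_per_step(8, 4, 2048))              # 65536
    print(total_steps(10_000_000_000, 8, 4, 2048))  # 152588
    print(halve_micro_batch(8, 4))                  # (4, 8)
```

With the recommended settings, this gives 65,536 tokens per step and 152,588 steps for the 10B-token default. Applying `halve_micro_batch` repeatedly is a quick way to confirm that each OOM retreat keeps the same global batch.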