Spider-FLEXITOKENS Remote Training Guide
Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)
- GPU: RTX 6000 Pro (Blackwell architecture, sm120+)
- VRAM: 48GB GDDR7
- Precision: MXFP8 (rowwise_with_gw_hp recipe) as primary; FP8_DYNAMIC as fallback
- Expected peak VRAM: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)
Quick Start
```bash
# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh
```
Environment Setup
Required Dependencies
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"   # quoted so the shell does not treat >= as an output redirect
pip install datasets transformers
pip install bitsandbytes        # optional; only used for the BF16 fallback (8-bit AdamW)
```
Optional (Recommended)
```bash
pip install unsloth  # MoE kernel optimizations + memory-efficient gradient checkpointing
```
Verify Installation
```bash
python3 - <<'EOF'
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: sm{major}{minor}")  # e.g. sm120 for (12, 0)

import torchao
print(f"torchao: {torchao.__version__}")
from torchao.float8 import Float8LinearConfig  # ImportError here means FP8 training is unavailable
from torchao.float8.config import Float8LinearRecipeName
print("FP8 training: available")
print(f"Recipes: {[r.value for r in Float8LinearRecipeName]}")
EOF
```
Expected output on RTX 6000 Pro: sm120 or higher, all 3 recipes available (tensorwise, rowwise, rowwise_with_gw_hp).
Configuration
Environment Variables
| Variable | Default | Description |
|---|---|---|
| `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, or `bf16` |
| `SEQ_LEN` | `2048` | Sequence length per sample |
| `MICRO_BATCH` | `8` | Batch size per forward pass |
| `GRAD_ACCUM` | `4` | Gradient accumulation steps |
| `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
| `N_LOOPS` | `6` | Recurrent loop iterations |
| `LR` | `3e-4` | Peak learning rate |
| `CKPT_EVERY` | `500` | Save a checkpoint every N steps |
| `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
| `RESUME` | (empty) | Path to a checkpoint for manual resume |
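A minimal sketch of how these variables and their defaults could be read in Python. The `TrainEnv` dataclass and `load_train_env` function are hypothetical names for illustration, and case-insensitive matching of `PRECISION` is an assumption; the actual training script may parse its environment differently.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

VALID_PRECISIONS = ("mxfp8", "fp8_dynamic", "bf16")


@dataclass(frozen=True)
class TrainEnv:
    """Hypothetical container for the environment variables in the table above."""
    precision: str
    seq_len: int
    micro_batch: int
    grad_accum: int
    target_tokens: int
    n_loops: int
    lr: float
    ckpt_every: int
    ckpt_dir: str
    resume: Optional[str]


def load_train_env(env: Optional[Mapping[str, str]] = None) -> TrainEnv:
    """Read settings from `env` (defaults to os.environ), falling back to the table defaults."""
    env = os.environ if env is None else env
    # Assumption: PRECISION is matched case-insensitively.
    precision = env.get("PRECISION", "mxfp8").lower()
    if precision not in VALID_PRECISIONS:
        raise ValueError(f"PRECISION must be one of {VALID_PRECISIONS}, got {precision!r}")
    return TrainEnv(
        precision=precision,
        seq_len=int(env.get("SEQ_LEN", "2048")),
        micro_batch=int(env.get("MICRO_BATCH", "8")),
        grad_accum=int(env.get("GRAD_ACCUM", "4")),
        target_tokens=int(env.get("TARGET_TOKENS", "10000000000")),
        n_loops=int(env.get("N_LOOPS", "6")),
        lr=float(env.get("LR", "3e-4")),
        ckpt_every=int(env.get("CKPT_EVERY", "500")),
        ckpt_dir=env.get("CKPT_DIR", "checkpoints-spider-remote"),
        # An unset or empty RESUME means "no manual resume".
        resume=env.get("RESUME") or None,
    )
```

Passing a plain dict instead of `os.environ` makes the defaults easy to check in isolation.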
Recommended Settings for RTX 6000 Pro (48GB)
```bash
# MXFP8: maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# ~10B tokens / 65,536 tokens/step ≈ 152,588 steps
```
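The step count follows from dividing the token budget by the tokens consumed per optimizer step. A small sketch of that arithmetic; the function names are illustrative, and whether the script rounds the step count up, down, or to a round figure is not specified here (this version rounds up so the full budget is covered).

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Tokens consumed per optimizer step (the global batch, in tokens)."""
    return micro_batch * grad_accum * seq_len


def steps_for_budget(target_tokens: int, micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Optimizer steps needed to reach target_tokens, rounded up."""
    per_step = tokens_per_step(micro_batch, grad_accum, seq_len)
    return -(-target_tokens // per_step)  # integer ceiling division, no float rounding


print(tokens_per_step(8, 4, 2048))                   # 65536
print(steps_for_budget(10_000_000_000, 8, 4, 2048))  # 152588
```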
Conservative Settings (if VRAM-constrained)
```bash
export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same tokens per step; lower peak VRAM, possibly lower throughput)
```
Launch
Fresh Training Run
```bash
bash scripts/train_remote.sh
```
Resume from Checkpoint
```bash
# Auto-resume (picks latest from CKPT_DIR)
bash scripts/train_remote.sh

# Manual resume from a specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
```
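Auto-resume has to decide which checkpoint is "latest". A sketch of one way to do that, assuming "latest" means the `spider-step{N}.pt` file with the largest N and that a non-empty RESUME takes priority; the helper names are hypothetical, and the script's actual rule may differ (for example, it might also consider epoch or final checkpoints). Comparing N numerically matters: sorting names as strings would rank `spider-step999.pt` above `spider-step5000.pt`.

```python
import re
from pathlib import Path
from typing import Iterable, Mapping, Optional

STEP_CKPT = re.compile(r"^spider-step(\d+)\.pt$")


def pick_latest_step(names: Iterable[str]) -> Optional[str]:
    """Return the spider-step{N}.pt name with the largest N, or None if none match."""
    best_step, best_name = -1, None
    for name in names:
        match = STEP_CKPT.match(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


def resolve_resume_path(env: Mapping[str, str]) -> Optional[Path]:
    """Manual RESUME wins; otherwise fall back to the newest step checkpoint in CKPT_DIR."""
    manual = env.get("RESUME")
    if manual:
        return Path(manual)
    ckpt_dir = Path(env.get("CKPT_DIR", "checkpoints-spider-remote"))
    if not ckpt_dir.is_dir():
        return None  # nothing to resume from: fresh run
    latest = pick_latest_step(p.name for p in ckpt_dir.iterdir())
    return ckpt_dir / latest if latest else None
```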
Resume from Local Smoke Test
Transfer the local checkpoint to the remote machine, then:
```bash
export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh
```
Note: The local checkpoint was trained with 8-bit AdamW (BF16). On resume with MXFP8/FP8, the training script will:
- Load model weights (always compatible)
- Skip the 8-bit optimizer state with a warning (8-bit AdamW state is incompatible with standard AdamW)
- Continue training with a freshly initialized standard AdamW optimizer state
This is by design: the optimizer state mismatch is handled gracefully.
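A sketch of how such a graceful skip could work, assuming compatibility is decided by inspecting per-parameter state keys before loading. `torch.optim.AdamW` stores `exp_avg` and `exp_avg_sq` for each parameter, while bitsandbytes' 8-bit optimizers keep differently named, quantized buffers. The function names are hypothetical; the `optimizer` argument only needs a `load_state_dict` method, so any `torch.optim.Optimizer` fits.

```python
import warnings
from typing import Any, Mapping, Optional

# Per-parameter buffers kept by torch.optim.AdamW (with amsgrad=False).
ADAMW_STATE_KEYS = {"exp_avg", "exp_avg_sq"}


def is_adamw_state(opt_state: Mapping[str, Any]) -> bool:
    """True if every per-parameter entry has the standard AdamW moment buffers."""
    per_param = opt_state.get("state", {})
    return all(ADAMW_STATE_KEYS <= set(s.keys()) for s in per_param.values())


def load_optimizer_state_if_compatible(optimizer, opt_state: Optional[Mapping[str, Any]]) -> bool:
    """Load opt_state into optimizer if compatible; otherwise warn and keep the fresh state."""
    if opt_state is None or not is_adamw_state(opt_state):
        warnings.warn("Checkpoint optimizer state is not standard AdamW (e.g. 8-bit); "
                      "starting optimizer from scratch.")
        return False
    try:
        optimizer.load_state_dict(opt_state)
    except (ValueError, KeyError) as exc:  # e.g. mismatched parameter groups
        warnings.warn(f"Could not load optimizer state ({exc}); starting optimizer from scratch.")
        return False
    return True
```

Checking keys up front matters because a foreign state dict can otherwise load silently and only fail later, inside the first `optimizer.step()`.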
Monitoring
Training Logs
The script outputs structured logs every 10 steps:
```
Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
```
Key metrics:
- loss: Total loss (lm + aux + bp)
- lm: Language modeling loss
- aux: MoE load-balancing auxiliary loss
- bp: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
- gnorm: Gradient norm (should stabilize ~1-5)
- tok/s: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
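To track these metrics outside the console, each log line can be parsed with regular expressions. A minimal sketch that assumes the exact field layout of the example line above (the script's format may change) and treats the total as an unweighted sum of lm + aux + bp, as the metric list describes; the function names are hypothetical.

```python
import re

METRIC = re.compile(r"\b(loss|lm|aux|bp|gnorm|lr)\s+([-+0-9.eE]+)")
STEP = re.compile(r"\bstep\s+(\d+)/(\d+)")
THROUGHPUT = re.compile(r"([0-9.]+)M tok/s")


def parse_log_line(line: str) -> dict:
    """Extract numeric fields from one training log line."""
    record = {name: float(value) for name, value in METRIC.findall(line)}
    if (m := STEP.search(line)):
        record["step"], record["total_steps"] = int(m.group(1)), int(m.group(2))
    if (m := THROUGHPUT.search(line)):
        record["tok_per_s"] = float(m.group(1)) * 1e6
    return record


def loss_components_consistent(record: dict, tol: float = 1e-3) -> bool:
    """True if loss matches lm + aux + bp within tol (logged values are rounded)."""
    return abs(record["loss"] - (record["lm"] + record["aux"] + record["bp"])) <= tol


line = ("Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 "
        "[FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens")
rec = parse_log_line(line)
print(rec["step"], rec["loss"], loss_components_consistent(rec))
```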
VRAM Monitoring
```bash
watch -n 5 nvidia-smi
```
Expected on RTX 6000 Pro with MXFP8:
- Model: ~2GB (weights in FP8)
- Optimizer: ~8GB (standard AdamW, FP32 states)
- Activations: ~4-8GB (gradient checkpointing enabled)
- Peak: ~15-20GB total
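The optimizer figure follows from AdamW keeping two state buffers (`exp_avg`, `exp_avg_sq`) per parameter. A back-of-the-envelope sketch, assuming FP32 states (4 bytes each) and decimal gigabytes; the parameter count in the usage line is illustrative, not the model's documented size. For a measured number, `torch.cuda.max_memory_allocated()` reports peak bytes allocated by PyTorch's caching allocator; nvidia-smi usually shows more because it also counts cached-but-unused memory and the CUDA context.

```python
def adamw_state_gb(n_params: int, bytes_per_state: int = 4) -> float:
    """Memory for AdamW's two per-parameter moment buffers (exp_avg, exp_avg_sq)."""
    return 2 * n_params * bytes_per_state / 1e9


def weights_gb(n_params: int, bytes_per_param: float) -> float:
    """Memory for one copy of the weights (4 = FP32, 2 = BF16, 1 = FP8)."""
    return n_params * bytes_per_param / 1e9


# Illustrative only: a hypothetical 1B-parameter model.
n = 1_000_000_000
print(f"AdamW states: {adamw_state_gb(n):.1f} GB, FP8 weights: {weights_gb(n, 1):.1f} GB")
```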
Health Warnings
The RecurrentMonitor checks for:
- Representation drift: Loop hidden states diverging (cosine sim < 0.5)
- Collapse: All experts producing identical outputs (std < 1e-6)
If you see these warnings, consider reducing N_LOOPS or lowering the learning rate (LR).
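The two checks can be illustrated in plain Python. This is a sketch under stated assumptions, not the RecurrentMonitor implementation: it assumes drift is measured as cosine similarity between consecutive loop iterations' (pooled) hidden states, and collapse as the standard deviation across experts' outputs, feature by feature. `recurrent_health_warnings` is a hypothetical name; with real tensors, `torch.nn.functional.cosine_similarity` and `Tensor.std` do the same work.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def recurrent_health_warnings(
    loop_states: Sequence[Sequence[float]],
    expert_outputs: Sequence[Sequence[float]],
    drift_threshold: float = 0.5,
    collapse_std: float = 1e-6,
) -> List[str]:
    found = []
    # Drift: consecutive loop iterations point in very different directions.
    for i in range(1, len(loop_states)):
        sim = cosine_similarity(loop_states[i - 1], loop_states[i])
        if sim < drift_threshold:
            found.append(f"drift: loop {i - 1}->{i} cosine sim {sim:.3f} < {drift_threshold}")
    # Collapse: every feature has (near-)zero spread across experts, i.e. all experts agree.
    if len(expert_outputs) > 1:
        max_std = 0.0
        for feature in zip(*expert_outputs):
            mean = sum(feature) / len(feature)
            std = math.sqrt(sum((v - mean) ** 2 for v in feature) / len(feature))
            max_std = max(max_std, std)
        if max_std < collapse_std:
            found.append(f"collapse: max std across experts {max_std:.2e} < {collapse_std}")
    return found
```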
Precision Fallback Chain
The training script automatically falls back if precision setup fails:
MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
- MXFP8: Row-wise scaling + high-precision grad weight accumulation. Best accuracy.
- FP8_DYNAMIC: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
- BF16: No quantization. Most VRAM, but simplest path.
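The fallback order can be sketched as trying each candidate in turn and skipping those the GPU's compute capability rules out. The thresholds below are the ones stated in the chain above; `setup` stands in for whatever precision-specific initialization the script performs, and all names here are hypothetical. On a real machine, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple, e.g. `(12, 0)` for sm120.

```python
from typing import Callable, Dict, List, Tuple

# (precision, minimum compute capability), in fallback order.
FALLBACK_CHAIN: Tuple[Tuple[str, Tuple[int, int]], ...] = (
    ("mxfp8", (12, 0)),       # sm120+ (Blackwell)
    ("fp8_dynamic", (8, 9)),  # sm89+ (Ada)
    ("bf16", (0, 0)),         # all GPUs
)


def candidate_precisions(requested: str, capability: Tuple[int, int]) -> List[str]:
    """The requested precision and everything after it in the chain that this GPU supports."""
    names = [name for name, _ in FALLBACK_CHAIN]
    if requested not in names:
        raise ValueError(f"unknown precision {requested!r}; expected one of {names}")
    start = names.index(requested)
    return [name for name, min_cc in FALLBACK_CHAIN[start:] if capability >= min_cc]


def setup_with_fallback(requested: str, capability: Tuple[int, int],
                        setup: Callable[[str], None]) -> str:
    """Call setup() for each candidate until one succeeds; return the precision that worked."""
    errors: Dict[str, Exception] = {}
    for name in candidate_precisions(requested, capability):
        try:
            setup(name)
            return name
        except Exception as exc:  # any setup failure moves on to the next fallback
            errors[name] = exc
    raise RuntimeError(f"no precision could be initialized: {errors}")
```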
Checkpoint Files
| File | Description |
|---|---|
| `spider-step{N}.pt` | Step checkpoint (every CKPT_EVERY steps) |
| `spider-ep{N}.pt` | Epoch boundary checkpoint |
| `spider-best.pt` | Best-loss checkpoint (updated when epoch loss improves) |
| `spider-final-ep{N}.pt` | Final checkpoint at training end |
Each checkpoint contains:
- Model state dict
- Optimizer state dict
- Training step, epoch, config
- `best_loss` value
- BP optimizer state (if active)
Troubleshooting
mat2 shape must be divisible by 16
Fixed with pad_inner_dim=True in Float8LinearConfig (v0.17.0+). The training script handles this automatically.
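FP8 matmuls need any dimension that becomes the inner (reduction) dimension to be a multiple of 16, and across a Linear layer's forward and backward passes both `in_features` and `out_features` take that role. A small sketch for spotting layers that would hit this error without padding; the layer names and shapes in the usage lines are hypothetical (with a real model you would collect them from `nn.Linear` modules via `model.named_modules()`). `pad_inner_dim=True` avoids the error by zero-padding the inner dimension up to the next multiple of 16.

```python
from typing import Iterable, List, Tuple


def round_up(n: int, multiple: int = 16) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def unaligned_linears(layers: Iterable[Tuple[str, int, int]],
                      multiple: int = 16) -> List[Tuple[str, int, int]]:
    """Return (name, in_features, out_features) for layers with a dim not divisible by `multiple`."""
    return [(name, fan_in, fan_out) for name, fan_in, fan_out in layers
            if fan_in % multiple or fan_out % multiple]


# Hypothetical layer shapes for illustration.
layers = [("attn.proj", 2048, 2048), ("mlp.up", 2048, 5632), ("bp.head", 2048, 1)]
print(unaligned_linears(layers))  # [('bp.head', 2048, 1)]
print(round_up(1), round_up(2049))  # 16 2064
```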
CUDA out of memory
Reduce MICRO_BATCH and increase GRAD_ACCUM by the same factor to keep the global batch size unchanged:
```bash
export MICRO_BATCH=4  # was 8
export GRAD_ACCUM=8   # was 4 (same 65,536 tok/step)
```
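Keeping the global batch fixed while changing MICRO_BATCH means solving MICRO_BATCH × GRAD_ACCUM × SEQ_LEN = 65,536 for GRAD_ACCUM. A hypothetical helper that does this and rejects micro-batch sizes that don't divide the global batch evenly:

```python
def grad_accum_for(micro_batch: int, seq_len: int = 2048, tokens_per_step: int = 65_536) -> int:
    """GRAD_ACCUM that keeps micro_batch * grad_accum * seq_len == tokens_per_step."""
    tokens_per_micro = micro_batch * seq_len
    if micro_batch <= 0 or tokens_per_step % tokens_per_micro:
        raise ValueError(
            f"MICRO_BATCH={micro_batch} with SEQ_LEN={seq_len} does not divide {tokens_per_step} tokens/step"
        )
    return tokens_per_step // tokens_per_micro


for mb in (8, 4, 2, 1):
    print(f"MICRO_BATCH={mb} -> GRAD_ACCUM={grad_accum_for(mb)}")
```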
Optimizer state mismatch on resume
Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues: model weights load fine, and the optimizer restarts from scratch.
Slower than expected throughput
- Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script)
- Check that `torch.compile` isn't being used inadvertently (it adds compile overhead)
- Verify torchao version >= 0.17.0 for optimal FP8 kernels
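The allocator setting and the torchao version check can be scripted. A sketch assuming the allocator variable is set from Python before `import torch` (PyTorch reads `PYTORCH_CUDA_ALLOC_CONF` when it initializes CUDA, so that is the safe place), with a simple numeric version comparison that ignores pre-release and local tags; `configure_allocator`, `parse_version`, and `version_at_least` are hypothetical helpers.

```python
import os
import re
from typing import Tuple


def configure_allocator() -> None:
    """Enable expandable segments unless the user already set PYTORCH_CUDA_ALLOC_CONF."""
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


def parse_version(version: str) -> Tuple[int, ...]:
    """'0.17.0+cu128' -> (0, 17, 0); anything after the numeric major.minor.patch is ignored."""
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(g) if g is not None else 0 for g in match.groups())


def version_at_least(installed: str, minimum: str = "0.17.0") -> bool:
    return parse_version(installed) >= parse_version(minimum)


configure_allocator()  # call before `import torch`
```

After `import torchao`, `version_at_least(torchao.__version__)` returns False for anything older than 0.17.0.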