# Spider-FLEXITOKENS Remote Training Guide
## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)
- **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+)
- **VRAM**: 48GB GDDR7
- **Precision**: MXFP8 (rowwise_with_gw_hp recipe) — primary; FP8_DYNAMIC fallback
- **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)
## Quick Start
```bash
# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh
```
## Environment Setup
### Required Dependencies
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"  # quoted so the shell does not treat >= as a redirect
pip install datasets transformers
pip install bitsandbytes # optional — only used for BF16 fallback
```
### Optional (Recommended)
```bash
pip install unsloth # MoE kernel optimizations + memory-efficient gradient checkpointing
```
### Verify Installation
```bash
python3 - <<'EOF'
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
# get_device_capability returns (major, minor), e.g. (12, 0) -> sm120
major, minor = torch.cuda.get_device_capability(0)
print(f'Compute capability: sm{major}{minor}')
print(f'FP8 tensor cores (sm89+): {(major, minor) >= (8, 9)}')
import torchao
print(f'torchao: {torchao.__version__}')
# These imports fail if the installed torchao lacks float8 training support
from torchao.float8 import Float8LinearConfig
from torchao.float8.config import Float8LinearRecipeName
print('FP8 training: available')
print(f'Recipes: {[r.value for r in Float8LinearRecipeName]}')
EOF
```
Expected output on RTX 6000 Pro: `sm120` or higher, all 3 recipes available (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`).
## Configuration
### Environment Variables
| Variable | Default | Description |
|---|---|---|
| `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` |
| `SEQ_LEN` | `2048` | Sequence length per sample |
| `MICRO_BATCH` | `8` | Batch size per forward pass |
| `GRAD_ACCUM` | `4` | Gradient accumulation steps |
| `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
| `N_LOOPS` | `6` | Recurrent loop iterations |
| `LR` | `3e-4` | Peak learning rate |
| `CKPT_EVERY` | `500` | Save checkpoint every N steps |
| `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
| `RESUME` | _(empty)_ | Path to checkpoint for manual resume |
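
A launch script would typically read these variables with the defaults shown. The sketch below is a hypothetical Python equivalent (the `TrainConfig` dataclass and `load_config` helper are illustrative names, not the script's actual code); it takes an explicit mapping so it can be tested without touching the real environment.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

VALID_PRECISIONS = ("mxfp8", "fp8_dynamic", "bf16")


@dataclass(frozen=True)
class TrainConfig:
    # Hypothetical container mirroring the table above; not the script's real class.
    precision: str
    seq_len: int
    micro_batch: int
    grad_accum: int
    target_tokens: int
    n_loops: int
    lr: float
    ckpt_every: int
    ckpt_dir: str
    resume: Optional[str]


def load_config(env: Mapping[str, str] = os.environ) -> TrainConfig:
    """Read the documented variables, falling back to the documented defaults."""
    precision = env.get("PRECISION", "mxfp8").lower()
    if precision not in VALID_PRECISIONS:
        raise ValueError(f"PRECISION must be one of {VALID_PRECISIONS}, got {precision!r}")
    return TrainConfig(
        precision=precision,
        seq_len=int(env.get("SEQ_LEN", "2048")),
        micro_batch=int(env.get("MICRO_BATCH", "8")),
        grad_accum=int(env.get("GRAD_ACCUM", "4")),
        target_tokens=int(env.get("TARGET_TOKENS", "10000000000")),
        n_loops=int(env.get("N_LOOPS", "6")),
        lr=float(env.get("LR", "3e-4")),
        ckpt_every=int(env.get("CKPT_EVERY", "500")),
        ckpt_dir=env.get("CKPT_DIR", "checkpoints-spider-remote"),
        # An empty RESUME means "no manual resume" (auto-resume may still apply).
        resume=env.get("RESUME") or None,
    )
```

Validating `PRECISION` up front turns a typo into an immediate error instead of a silent fallback.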
### Recommended Settings for RTX 6000 Pro (48GB)
```bash
# MXFP8 — maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# ~10B tokens ≈ 152,000 steps
```
### Conservative Settings (if VRAM-constrained)
```bash
export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same tokens per step, lower peak VRAM; throughput may drop slightly)
```
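
The step counts in these comments follow from tokens per optimizer step = `MICRO_BATCH * GRAD_ACCUM * SEQ_LEN`. A small sketch of that arithmetic (helper names are illustrative, and a single GPU is assumed as in this guide) confirms both configurations give the same global batch and shows the exact step count behind the rounded ~152,000 figure:

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Tokens consumed per optimizer step on one GPU (no data parallelism)."""
    return micro_batch * grad_accum * seq_len


def steps_for_budget(target_tokens: int, step_tokens: int) -> int:
    """Optimizer steps needed to reach the token budget, rounding up a partial last step."""
    return -(-target_tokens // step_tokens)  # integer ceiling division


recommended = tokens_per_step(8, 4, 2048)
conservative = tokens_per_step(4, 8, 2048)
print(recommended, conservative)                      # 65536 65536
print(steps_for_budget(10_000_000_000, recommended))  # 152588
```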
## Launch
### Fresh Training Run
```bash
bash scripts/train_remote.sh
```
### Resume from Checkpoint
```bash
# Auto-resume (picks latest from CKPT_DIR)
bash scripts/train_remote.sh
# Manual resume from specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
```
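
Auto-resume needs a rule for "latest". One plausible rule, assumed here rather than taken from `train_remote.sh`, is the highest step number among `spider-step{N}.pt` files in `CKPT_DIR`. This sketch implements it with `pathlib`; `latest_step_checkpoint` and `find_resume_path` are hypothetical names.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Step checkpoints as named in the checkpoint table, e.g. spider-step5000.pt
STEP_CKPT = re.compile(r"^spider-step(\d+)\.pt$")


def latest_step_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Pick the spider-step{N}.pt name with the largest N (numeric, not lexical, order)."""
    best_step, best_name = -1, None
    for name in filenames:
        match = STEP_CKPT.match(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


def find_resume_path(ckpt_dir: str) -> Optional[Path]:
    """Latest step checkpoint in ckpt_dir, or None for a fresh run."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    name = latest_step_checkpoint(p.name for p in directory.iterdir())
    return directory / name if name else None
```

Comparing the parsed integer matters: a plain string sort would rank `spider-step500.pt` above `spider-step5000.pt`.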
### Resume from Local Smoke Test
Transfer the local checkpoint to the remote machine, then:
```bash
export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh
```
**Note**: The local checkpoint was trained with 8-bit AdamW (BF16 path). On resume with MXFP8/FP8, the training script will:
1. Load model weights (always compatible)
2. Skip the 8-bit optimizer state with a warning (8-bit AdamW state cannot be loaded into standard AdamW)
3. Continue training with freshly initialized standard AdamW state (moment estimates start from zero)
This is by design — the optimizer state mismatch is handled gracefully.
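
One way to implement "skip with a warning" is to inspect the saved optimizer state before loading it. PyTorch's `AdamW` stores `exp_avg` and `exp_avg_sq` per parameter, while bitsandbytes' 8-bit optimizers store differently named buffers (assumed here to be `state1`/`state2`). The sketch below checks a `state_dict()`-shaped dict for the AdamW keys; `is_adamw_state` and `maybe_load_optimizer` are hypothetical helpers, not the script's code.

```python
import warnings
from typing import Any, Mapping

ADAMW_KEYS = {"exp_avg", "exp_avg_sq"}


def is_adamw_state(opt_state: Mapping[str, Any]) -> bool:
    """True if every per-parameter entry has AdamW's two moment buffers."""
    per_param = opt_state.get("state", {})
    return all(ADAMW_KEYS <= set(entry) for entry in per_param.values())


def maybe_load_optimizer(optimizer: Any, opt_state: Mapping[str, Any]) -> bool:
    """Load opt_state if compatible; otherwise warn and keep the fresh optimizer state."""
    if "state" not in opt_state or not is_adamw_state(opt_state):
        warnings.warn("Optimizer state is not standard AdamW; starting from fresh moments.")
        return False
    try:
        optimizer.load_state_dict(opt_state)
    except ValueError as exc:  # e.g. parameter-group size mismatch
        warnings.warn(f"Could not load optimizer state ({exc}); starting fresh.")
        return False
    return True
```

In a real script, `opt_state` would come from something like `torch.load(path, map_location="cpu")["optimizer"]`; the `"optimizer"` key is an assumption about the checkpoint layout.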
## Monitoring
### Training Logs
The script outputs structured logs every 10 steps:
```
Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
```
Key metrics:
- **loss**: Total loss (lm + aux + bp)
- **lm**: Language modeling loss
- **aux**: MoE load-balancing auxiliary loss
- **bp**: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
- **gnorm**: Gradient norm (should stabilize ~1-5)
- **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
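
The log format is regular enough to parse for plotting or alerting. This sketch (a hypothetical `parse_log_line` helper that relies only on the fields visible in the example line) extracts the metrics and checks the documented identity loss = lm + aux + bp, allowing for 4-decimal rounding:

```python
import re
from typing import Any, Dict

LINE_RE = re.compile(
    r"step (?P<step>\d+)/(?P<total>\d+)"
    r" \| loss (?P<loss>[\d.]+)"
    r" \| lm (?P<lm>[\d.]+)"
    r" \| aux (?P<aux>[\d.]+)"
    r" \| bp (?P<bp>[\d.]+)(?: \[(?P<bp_mode>[^\]]+)\])?"
    r" \| gnorm (?P<gnorm>[\d.]+)"
    r" \| lr (?P<lr>[\d.eE+-]+)"
    r" \| (?P<tok_s>[\d.]+)M tok/s"
)


def parse_log_line(line: str) -> Dict[str, Any]:
    """Turn one training log line into a dict of numbers."""
    m = LINE_RE.search(line)
    if m is None:
        raise ValueError(f"unrecognized log line: {line!r}")
    g = m.groupdict()
    rec: Dict[str, Any] = {k: float(g[k]) for k in ("loss", "lm", "aux", "bp", "gnorm", "lr")}
    rec["step"], rec["total_steps"] = int(g["step"]), int(g["total"])
    rec["bp_mode"] = g["bp_mode"]  # e.g. "FIXED/FROZEN"; None if absent
    rec["tok_per_s"] = float(g["tok_s"]) * 1e6
    return rec


def components_add_up(rec: Dict[str, Any], tol: float = 2e-4) -> bool:
    """Each logged value is rounded to 4 decimals, so allow a small tolerance."""
    return abs(rec["loss"] - (rec["lm"] + rec["aux"] + rec["bp"])) <= tol
```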
### VRAM Monitoring
```bash
watch -n 5 nvidia-smi
```
Expected on RTX 6000 Pro with MXFP8:
- Model: ~2GB (weights in FP8)
- Optimizer: ~8GB (standard AdamW, FP32 states)
- Activations: ~4-8GB (gradient checkpointing enabled)
- **Peak**: ~15-20GB total
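
As a sanity check on the optimizer line: standard AdamW keeps two moment buffers (`exp_avg`, `exp_avg_sq`) per parameter, and PyTorch allocates them in the parameter's dtype, so FP32 states (which assume FP32 master weights) cost 8 bytes per trainable parameter. The illustrative helpers below take the parameter count as an input rather than assuming one; under these assumptions, ~8GB of AdamW state implies roughly 1B trainable parameters.

```python
def tensor_gb(n_elements: int, bytes_per_element: int) -> float:
    """Size in decimal GB of n_elements stored at bytes_per_element each."""
    return n_elements * bytes_per_element / 1e9


def adamw_state_gb(n_params: int, bytes_per_state: int = 4) -> float:
    """AdamW keeps two buffers per parameter (exp_avg, exp_avg_sq); FP32 by default here."""
    return tensor_gb(2 * n_params, bytes_per_state)


def params_implied_by_adamw_gb(state_gb: float, bytes_per_state: int = 4) -> float:
    """Inverse of adamw_state_gb: trainable parameters implied by a given state size."""
    return state_gb * 1e9 / (2 * bytes_per_state)
```

Plug in your model's actual parameter count. `nvidia-smi` reports MiB, so its readings will not match decimal GB exactly.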
### Health Warnings
The `RecurrentMonitor` checks for:
- **Representation drift**: Loop hidden states diverging (cosine sim < 0.5)
- **Collapse**: All experts producing identical outputs (std < 1e-6)
If you see these warnings, consider reducing `N_LOOPS` or lowering the learning rate (`LR`).
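
The two checks can be expressed in a few lines. This plain-Python sketch applies the thresholds listed above, assuming drift is measured between consecutive loop iterations' flattened hidden states and collapse as the spread across experts' outputs for the same input. The function names are hypothetical; the actual `RecurrentMonitor` presumably works on tensors.

```python
import math
from statistics import pstdev
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def drift_warnings(loop_states: Sequence[Sequence[float]], min_cos: float = 0.5) -> List[int]:
    """Loop indices i whose state has cosine similarity < min_cos with loop i-1."""
    return [
        i for i in range(1, len(loop_states))
        if cosine_similarity(loop_states[i - 1], loop_states[i]) < min_cos
    ]


def experts_collapsed(expert_outputs: Sequence[Sequence[float]], eps: float = 1e-6) -> bool:
    """True if every output dimension has std < eps across experts (all experts identical)."""
    return all(pstdev(column) < eps for column in zip(*expert_outputs))
```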
## Precision Fallback Chain
The training script automatically falls back if precision setup fails:
```
MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
```
- **MXFP8**: Row-wise scaling, with the weight-gradient GEMM kept in high precision (`rowwise_with_gw_hp` recipe). Best accuracy of the FP8 options.
- **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
- **BF16**: No quantization. Most VRAM, but simplest path.
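
The chain can be sketched as a loop over precisions. The script's stated behaviour is to fall back when setup fails; this sketch additionally skips a step up front when the GPU's compute capability is below the minimum listed in the chain. `select_precision` and the `setup` callback are hypothetical names, and `setup` stands in for whatever model conversion the script performs.

```python
from typing import Callable, Tuple

CHAIN = ("mxfp8", "fp8_dynamic", "bf16")
# Minimum compute capability per precision, as stated in the chain above (BF16: all GPUs).
MIN_CAPABILITY = {"mxfp8": (12, 0), "fp8_dynamic": (8, 9), "bf16": (0, 0)}


def select_precision(requested: str, capability: Tuple[int, int],
                     setup: Callable[[str], None]) -> str:
    """Walk the chain from `requested`, skipping unsupported GPUs and failed setups."""
    for name in CHAIN[CHAIN.index(requested):]:
        major, minor = MIN_CAPABILITY[name]
        if capability < (major, minor):
            print(f"{name}: needs sm{major}{minor}+, skipping")
            continue
        try:
            setup(name)  # e.g. convert the model for this precision; may raise
            return name
        except Exception as exc:  # broad on purpose: any setup failure triggers fallback
            print(f"{name}: setup failed ({exc}), falling back")
    raise RuntimeError("no usable precision in the fallback chain")
```

On a real machine, `capability` would be `torch.cuda.get_device_capability(0)`, which returns a `(major, minor)` tuple.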
## Checkpoint Files
| File | Description |
|---|---|
| `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) |
| `spider-ep{N}.pt` | Epoch boundary checkpoint |
| `spider-best.pt` | Best loss checkpoint (updated when epoch loss improves) |
| `spider-final-ep{N}.pt` | Final checkpoint at training end |
Each checkpoint contains:
- Model state dict
- Optimizer state dict
- Training step, epoch, config
- `best_loss` value
- BP optimizer state (if active)
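
One reading of the naming rules in the table is sketched below. The helper names are hypothetical, and the sketch assumes training ends at an epoch boundary, so the final checkpoint is written alongside that epoch's files.

```python
import math
from typing import List, Optional, Tuple


def step_checkpoint_name(step: int, ckpt_every: int) -> Optional[str]:
    """spider-step{N}.pt every CKPT_EVERY steps (never at step 0)."""
    return f"spider-step{step}.pt" if step > 0 and step % ckpt_every == 0 else None


def epoch_checkpoint_names(epoch: int, epoch_loss: float, best_loss: float = math.inf,
                           is_last_epoch: bool = False) -> Tuple[List[str], float]:
    """Files written at an epoch boundary, plus the updated best_loss value."""
    names = [f"spider-ep{epoch}.pt"]
    if epoch_loss < best_loss:  # spider-best.pt is updated only when epoch loss improves
        best_loss = epoch_loss
        names.append("spider-best.pt")
    if is_last_epoch:
        names.append(f"spider-final-ep{epoch}.pt")
    return names, best_loss
```

Returning `best_loss` rather than mutating global state keeps the value easy to store in each checkpoint, matching the `best_loss` entry listed above.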
## Troubleshooting
### `mat2 shape must be divisible by 16`
Fixed with `pad_inner_dim=True` in `Float8LinearConfig`, which pads the inner (reduction) dimension of each FP8 matmul to a multiple of 16. The training script sets this automatically.
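
An alternative to padding, shown here only as a hypothetical sketch, is to leave layers with misaligned shapes in high precision by passing a filter to torchao's `convert_to_float8_training`, whose `module_filter_fn` callback receives `(module, fully_qualified_name)` and returns whether to convert. The check on both feature dimensions is an assumption that mirrors the divisible-by-16 error message.

```python
def fp8_module_filter(mod, fqn: str) -> bool:
    """Return True only for Linear-like modules whose in/out features are multiples of 16."""
    in_features = getattr(mod, "in_features", None)
    out_features = getattr(mod, "out_features", None)
    if in_features is None or out_features is None:
        return False  # not a Linear layer: leave it alone
    return in_features % 16 == 0 and out_features % 16 == 0

# Possible usage (not executed here):
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
# convert_to_float8_training(model, config=Float8LinearConfig(),
#                            module_filter_fn=fp8_module_filter)
```

Skipping trades some FP8 speedup on the excluded layers for not paying padding overhead on every call.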
### `CUDA out of memory`
Reduce `MICRO_BATCH` and increase `GRAD_ACCUM` by the same factor to keep the global batch size unchanged:
```bash
export MICRO_BATCH=4 # was 8
export GRAD_ACCUM=8 # was 4 (same 65,536 tok/step)
```
### Optimizer state mismatch on resume
Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues — model weights load fine, optimizer restarts from scratch.
### Slower than expected throughput
- Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script)
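- If `torch.compile` is enabled, the first steps include compilation time, so judge throughput only after warm-up. torchao's FP8 training generally relies on `torch.compile` to fuse its scaling and casting ops, so running in eager mode may lower steady-state throughput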
- Verify torchao version >= 0.17.0 for optimal FP8 kernels
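
When comparing against the expected 0.5-1.0M tok/s, exclude the first steps, which include compilation and allocator warm-up. A minimal sketch, assuming you have per-step wall-clock times; the helper name and the `warmup_steps=20` default are illustrative choices, not values from the script:

```python
from typing import Sequence


def steady_state_tok_per_s(step_times_s: Sequence[float], tokens_per_step: int,
                           warmup_steps: int = 20) -> float:
    """Mean tokens/second after discarding the first `warmup_steps` step timings."""
    timed = step_times_s[warmup_steps:]
    if not timed:
        raise ValueError("not enough steps after warm-up")
    return tokens_per_step * len(timed) / sum(timed)
```

With the default 65,536 tokens/step, 0.5-1.0M tok/s corresponds to roughly 0.07-0.13 s per optimizer step.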