# Spider-FLEXITOKENS Remote Training Guide
## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)
- **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+)
- **VRAM**: 48GB GDDR7
- **Precision**: MXFP8 (rowwise_with_gw_hp recipe) as primary; FP8_DYNAMIC as fallback
- **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)
## Quick Start
```bash
# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh
```
## Environment Setup
### Required Dependencies
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"  # quoted so the shell does not treat >= as a redirect
pip install datasets transformers
pip install bitsandbytes  # optional: only used for BF16 fallback
```
### Optional (Recommended)
```bash
pip install unsloth # MoE kernel optimizations + memory-efficient GC
```
### Verify Installation
```bash
python3 -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
print('Compute capability: sm%d%d' % torch.cuda.get_device_capability(0))
import torchao
print(f'torchao: {torchao.__version__}')
from torchao.float8 import Float8LinearConfig
print('FP8 training: available')
print(f'Recipes: {[n.value for n in __import__(\"torchao.float8.config\", fromlist=[\"Float8LinearRecipeName\"]).Float8LinearRecipeName]}')
"
```
Expected output on RTX 6000 Pro: `sm120` or higher, all 3 recipes available (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`).
## Configuration
### Environment Variables
| Variable | Default | Description |
|---|---|---|
| `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` |
| `SEQ_LEN` | `2048` | Sequence length per sample |
| `MICRO_BATCH` | `8` | Batch size per forward pass |
| `GRAD_ACCUM` | `4` | Gradient accumulation steps |
| `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
| `N_LOOPS` | `6` | Recurrent loop iterations |
| `LR` | `3e-4` | Peak learning rate |
| `CKPT_EVERY` | `500` | Save checkpoint every N steps |
| `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
| `RESUME` | _(empty)_ | Path to checkpoint for manual resume |
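
As an illustration of how these variables could be consumed, here is a minimal sketch that reads them with the defaults from the table. `TrainConfig` and `load_config` are hypothetical names; the actual parsing in `scripts/train_remote.sh` and the training script may differ.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Defaults mirror the environment-variable table (hypothetical container)."""
    precision: str = "mxfp8"
    seq_len: int = 2048
    micro_batch: int = 8
    grad_accum: int = 4
    target_tokens: int = 10_000_000_000
    n_loops: int = 6
    lr: float = 3e-4
    ckpt_every: int = 500
    ckpt_dir: str = "checkpoints-spider-remote"
    resume: str = ""


def load_config(env=None) -> TrainConfig:
    """Build a TrainConfig from a mapping (os.environ by default)."""
    env = os.environ if env is None else env
    values = {}
    for name, default in vars(TrainConfig()).items():
        raw = env.get(name.upper())
        # Unset or empty variables keep the documented default.
        values[name] = default if raw in (None, "") else type(default)(raw)
    cfg = TrainConfig(**values)
    if cfg.precision not in {"mxfp8", "fp8_dynamic", "bf16"}:
        raise ValueError(f"Unsupported PRECISION: {cfg.precision!r}")
    return cfg


if __name__ == "__main__":
    print(load_config())
```

Casting with the type of each default keeps integer, float, and string variables in one loop, at the cost of raising a plain `ValueError` on malformed input.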
### Recommended Settings for RTX 6000 Pro (48GB)
```bash
# MXFP8: maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# ~10B tokens ≈ 152,000 steps
```
### Conservative Settings (if VRAM-constrained)
```bash
export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same global batch, lower peak VRAM)
```
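
The batch arithmetic in both presets can be checked with a short sketch. The helper names `tokens_per_step` and `total_steps` are illustrative and not taken from the training script.

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int = 2048) -> int:
    """Tokens consumed by one optimizer step (one global batch)."""
    return micro_batch * grad_accum * seq_len


def total_steps(target_tokens: int, micro_batch: int, grad_accum: int,
                seq_len: int = 2048) -> int:
    """Optimizer steps needed to reach target_tokens, rounded up."""
    per_step = tokens_per_step(micro_batch, grad_accum, seq_len)
    return -(-target_tokens // per_step)  # integer ceiling division


if __name__ == "__main__":
    for micro_batch, grad_accum in [(8, 4), (4, 8)]:
        print(micro_batch, grad_accum,
              tokens_per_step(micro_batch, grad_accum),
              total_steps(10_000_000_000, micro_batch, grad_accum))
```

Both presets give 65,536 tokens per step, so a 10B-token run takes 152,588 optimizer steps under either one; only the per-forward-pass memory footprint changes.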
## Launch
### Fresh Training Run
```bash
bash scripts/train_remote.sh
```
### Resume from Checkpoint
```bash
# Auto-resume (picks latest from CKPT_DIR)
bash scripts/train_remote.sh
# Manual resume from specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
```
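
How auto-resume chooses "latest" is not shown in this guide. One plausible approach, sketched below as an assumption, is to pick the `spider-step{N}.pt` file with the largest `N`. The function name `latest_step_checkpoint` is hypothetical.

```python
import re
from pathlib import Path

# Matches the step-checkpoint naming listed under "Checkpoint Files".
STEP_RE = re.compile(r"^spider-step(\d+)\.pt$")


def latest_step_checkpoint(ckpt_dir):
    """Return the step checkpoint with the highest step number, or None."""
    best = None
    for path in Path(ckpt_dir).glob("spider-step*.pt"):
        match = STEP_RE.match(path.name)
        if match is None:
            continue
        step = int(match.group(1))
        # Compare numerically: a plain string sort would rank step900 above step10000.
        if best is None or step > best[0]:
            best = (step, path)
    return None if best is None else best[1]


if __name__ == "__main__":
    print(latest_step_checkpoint("checkpoints-spider-remote"))
```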
### Resume from Local Smoke Test
Transfer the local checkpoint to the remote machine, then:
```bash
export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh
```
**Note**: The local checkpoint was trained with 8-bit AdamW (BF16). On resume with MXFP8/FP8, the training script will:
1. Load model weights (always compatible)
2. Skip the 8-bit optimizer state with a warning (8-bit vs. standard AdamW mismatch)
3. Continue training with a freshly initialized standard AdamW optimizer state
This is by design: the optimizer state mismatch is handled gracefully.
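
The three steps above can be sketched roughly as follows. This is not the script's actual code: the checkpoint keys `"model"` and `"optimizer"`, the function name, and the set of caught exceptions are all assumptions. In a real run the `checkpoint` dict would come from `torch.load(path, map_location="cpu")`.

```python
import warnings


def restore_training_state(checkpoint, model, optimizer):
    """Load weights, then try the optimizer state.

    Returns True if the optimizer state was restored, False if it was skipped.
    `model` and `optimizer` only need a `load_state_dict` method.
    """
    # Step 1: model weights are expected to load regardless of precision.
    model.load_state_dict(checkpoint["model"])  # "model" key is hypothetical

    opt_state = checkpoint.get("optimizer")  # "optimizer" key is hypothetical
    if opt_state is None:
        return False
    try:
        optimizer.load_state_dict(opt_state)
    except (ValueError, KeyError, RuntimeError) as exc:
        # Steps 2 and 3: warn, keep the freshly initialized optimizer.
        warnings.warn(f"Skipping incompatible optimizer state: {exc}")
        return False
    return True
```

Catching the error at load time only covers mismatches that the optimizer detects itself; a stricter variant could compare optimizer class names stored in the checkpoint before attempting the load.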
## Monitoring
### Training Logs
The script outputs structured logs every 10 steps:
```
Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
```
Key metrics:
- **loss**: Total loss (lm + aux + bp)
- **lm**: Language modeling loss
- **aux**: MoE load-balancing auxiliary loss
- **bp**: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
- **gnorm**: Gradient norm (should stabilize around 1-5)
- **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
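
For plotting or alerting, a log line in the format shown above can be turned into a dictionary. The sketch below assumes the field order and spacing of that single sample line; `parse_log_line` is an illustrative helper, not part of the repo.

```python
import re

LOG_RE = re.compile(
    r"Epoch (?P<epoch>\d+) \| step (?P<step>\d+)/(?P<total_steps>\d+)"
    r" \| loss (?P<loss>[\d.]+) \| lm (?P<lm>[\d.]+) \| aux (?P<aux>[\d.]+)"
    r" \| bp (?P<bp>[\d.]+)(?: \[(?P<bp_mode>[^\]]+)\])?"
    r" \| gnorm (?P<gnorm>[\d.]+) \| lr (?P<lr>[\d.eE+-]+)"
    r" \| (?P<mtok_per_s>[\d.]+)M tok/s \| (?P<btokens>[\d.]+)B tokens"
)
INT_FIELDS = {"epoch", "step", "total_steps"}


def parse_log_line(line):
    """Return the metrics of one training log line as a dict, or None if it does not match."""
    match = LOG_RE.search(line)
    if match is None:
        return None
    metrics = {}
    for key, value in match.groupdict().items():
        if key == "bp_mode":
            metrics[key] = value  # e.g. "FIXED/FROZEN"; None when absent
        elif key in INT_FIELDS:
            metrics[key] = int(value)
        else:
            metrics[key] = float(value)
    return metrics


if __name__ == "__main__":
    sample = ("Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | "
              "bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens")
    print(parse_log_line(sample))
```

Lines whose values are not plain numbers (for example a `nan` loss) return `None`, which is itself a useful signal to alert on.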
### VRAM Monitoring
```bash
watch -n 5 nvidia-smi
```
Expected on RTX 6000 Pro with MXFP8:
- Model: ~2GB (weights in FP8)
- Optimizer: ~8GB (standard AdamW, FP32 states)
- Activations: ~4-8GB (gradient checkpointing enabled)
- **Peak**: ~15-20GB total
### Health Warnings
The `RecurrentMonitor` checks for:
- **Representation drift**: Loop hidden states diverging (cosine sim < 0.5)
- **Collapse**: All experts producing identical outputs (std < 1e-6)
If you see these warnings, consider reducing `N_LOOPS` or lowering the learning rate (`LR`).
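
To make the two thresholds concrete, here is a dependency-free sketch of both checks on plain Python lists. It is not the `RecurrentMonitor` implementation. It assumes drift is measured between hidden states of consecutive loops and that collapse is measured as the standard deviation across experts at each output position; both are interpretations of the descriptions above.

```python
import math
import statistics


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def has_drift(prev_hidden, curr_hidden, threshold=0.5):
    """True when two loop hidden states have diverged (cosine sim below threshold)."""
    return cosine_similarity(prev_hidden, curr_hidden) < threshold


def has_collapse(expert_outputs, threshold=1e-6):
    """True when every output position is (near) identical across all experts.

    expert_outputs: one equal-length list of floats per expert.
    """
    per_position_std = [statistics.pstdev(column) for column in zip(*expert_outputs)]
    return max(per_position_std) < threshold


if __name__ == "__main__":
    print(has_drift([1.0, 0.0], [0.0, 1.0]))           # orthogonal states
    print(has_collapse([[1.0, 2.0], [1.0, 2.0]]))      # identical experts
```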
## Precision Fallback Chain
The training script automatically falls back if precision setup fails:
```
MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
```
- **MXFP8**: Row-wise scaling + high-precision grad weight accumulation. Best accuracy.
- **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
- **BF16**: No quantization. Most VRAM, but simplest path.
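
The chain can be pictured as a walk down an ordered list, starting at the requested precision. The sketch below selects purely by compute capability, which is a simplification: the guide says the script falls back when precision setup fails, and that logic is not shown here. `FALLBACK_CHAIN` and `select_precision` are hypothetical names; `capability` stands for the `(major, minor)` tuple returned by `torch.cuda.get_device_capability()`.

```python
# (precision name, minimum compute capability) in fallback order.
FALLBACK_CHAIN = [
    ("mxfp8", (12, 0)),       # sm120+ (Blackwell)
    ("fp8_dynamic", (8, 9)),  # sm89+ (Ada)
    ("bf16", (0, 0)),         # treated as always available, per the chain above
]


def select_precision(requested: str, capability) -> str:
    """Return the first precision at or below `requested` that the GPU supports."""
    names = [name for name, _ in FALLBACK_CHAIN]
    if requested not in names:
        raise ValueError(f"Unknown precision: {requested!r}")
    start = names.index(requested)
    # Try every entry except the last; the last one is the unconditional fallback.
    for name, min_capability in FALLBACK_CHAIN[start:-1]:
        if tuple(capability) >= min_capability:
            return name
    return FALLBACK_CHAIN[-1][0]


if __name__ == "__main__":
    for cap in [(12, 0), (8, 9), (8, 0)]:
        print(cap, select_precision("mxfp8", cap))
```

Starting the walk at the requested entry means that `PRECISION=fp8_dynamic` is never upgraded to MXFP8, even on a GPU that supports it.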
## Checkpoint Files
| File | Description |
|---|---|
| `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) |
| `spider-ep{N}.pt` | Epoch boundary checkpoint |
| `spider-best.pt` | Best loss checkpoint (updated when epoch loss improves) |
| `spider-final-ep{N}.pt` | Final checkpoint at training end |
Each checkpoint contains:
- Model state dict
- Optimizer state dict
- Training step, epoch, config
- `best_loss` value
- BP optimizer state (if active)
## Troubleshooting
### `mat2 shape must be divisible by 16`
Fixed with `pad_inner_dim=True` in `Float8LinearConfig` (v0.17.0+). The training script handles this automatically.
### `CUDA out of memory`
Reduce `MICRO_BATCH` and increase `GRAD_ACCUM` by the same factor to keep the global batch size unchanged:
```bash
export MICRO_BATCH=4 # was 8
export GRAD_ACCUM=8 # was 4 (same 65,536 tok/step)
```
### Optimizer state mismatch on resume
Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues: model weights load fine, and the optimizer restarts from scratch.
### Slower than expected throughput
- Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script)
- Check `torch.compile` isn't being used inadvertently (adds compile overhead)
- Verify torchao version >= 0.17.0 for optimal FP8 kernels