# Spider-FLEXITOKENS Remote Training Guide

## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)

- **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+)
- **VRAM**: 48GB GDDR7
- **Precision**: MXFP8 (`rowwise_with_gw_hp` recipe) as primary, FP8_DYNAMIC as fallback
- **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)

## Quick Start

```bash
# 1. Clone/transfer the repo to the remote machine
# 2. Install dependencies (see below)
# 3. Run the launch script
bash scripts/train_remote.sh
```

## Environment Setup

### Required Dependencies

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip install "torchao>=0.17.0"  # quoted so the shell does not treat >= as a redirect
pip install datasets transformers
pip install bitsandbytes  # optional: only used for BF16 fallback
```

### Optional (Recommended)

```bash
pip install unsloth  # MoE kernel optimizations + memory-efficient gradient checkpointing
```

### Verify Installation

```bash
python3 -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
major, minor = torch.cuda.get_device_capability(0)
print(f'Compute capability: sm{major}{minor}')

import torchao
print(f'torchao: {torchao.__version__}')

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import Float8LinearRecipeName
print('FP8 training: available')
print(f'Recipes: {[n.value for n in Float8LinearRecipeName]}')
"
```

Expected output on RTX 6000 Pro: `sm120` or higher, all 3 recipes available (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`).

## Configuration

### Environment Variables

| Variable | Default | Description |
|---|---|---|
| `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` |
| `SEQ_LEN` | `2048` | Sequence length per sample |
| `MICRO_BATCH` | `8` | Batch size per forward pass |
| `GRAD_ACCUM` | `4` | Gradient accumulation steps |
| `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
| `N_LOOPS` | `6` | Recurrent loop iterations |
| `LR` | `3e-4` | Peak learning rate |
| `CKPT_EVERY` | `500` | Save checkpoint every N steps |
| `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
| `RESUME` | _(empty)_ | Path to checkpoint for manual resume |
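
The launch script itself is not reproduced in this guide, so the following is only a sketch of how a Python entry point could read the variables above with the same defaults. `TrainConfig` and `load_config` are hypothetical names, not part of the project.

```python
import os
from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Hypothetical container mirroring the environment variable table."""
    precision: str = "mxfp8"
    seq_len: int = 2048
    micro_batch: int = 8
    grad_accum: int = 4
    target_tokens: int = 10_000_000_000
    n_loops: int = 6
    lr: float = 3e-4
    ckpt_every: int = 500
    ckpt_dir: str = "checkpoints-spider-remote"
    resume: str = ""


def load_config(env=None) -> TrainConfig:
    """Build a TrainConfig from environment variables, keeping defaults when unset."""
    env = os.environ if env is None else env
    cfg = TrainConfig()
    for name, default in list(vars(cfg).items()):
        raw = env.get(name.upper())
        if raw is None or raw == "":
            continue  # keep the documented default
        # Cast the string to the type of the default (int, float, or str).
        setattr(cfg, name, type(default)(raw))
    cfg.precision = cfg.precision.lower()
    if cfg.precision not in {"mxfp8", "fp8_dynamic", "bf16"}:
        raise ValueError(f"Unsupported PRECISION: {cfg.precision!r}")
    return cfg
```

Validating `PRECISION` up front means a typo fails at launch instead of partway through setup.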

### Recommended Settings for RTX 6000 Pro (48GB)

```bash
# MXFP8: maximum accuracy, best VRAM efficiency
export PRECISION=mxfp8
export MICRO_BATCH=8
export GRAD_ACCUM=4
# Global batch: 8 * 4 * 2048 = 65,536 tokens/step
# 10B tokens / 65,536 tokens per step ≈ 152,600 steps
```

### Conservative Settings (if VRAM-constrained)

```bash
export PRECISION=fp8_dynamic
export MICRO_BATCH=4
export GRAD_ACCUM=8
# Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same global batch, lower peak VRAM)
```
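
The batch arithmetic in the two presets can be checked with a few lines of Python. This is a minimal sketch that assumes a single GPU, so no data-parallel factor is included; the function names are illustrative.

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Tokens consumed by one optimizer step on a single GPU."""
    return micro_batch * grad_accum * seq_len


def total_steps(target_tokens: int, per_step: int) -> int:
    """Optimizer steps needed to see at least target_tokens (integer ceiling)."""
    return -(-target_tokens // per_step)


recommended = tokens_per_step(micro_batch=8, grad_accum=4, seq_len=2048)
conservative = tokens_per_step(micro_batch=4, grad_accum=8, seq_len=2048)
print(recommended, conservative)                 # 65536 65536
print(total_steps(10_000_000_000, recommended))  # 152588
```

Halving `MICRO_BATCH` while doubling `GRAD_ACCUM` leaves the tokens per optimizer step unchanged, so the learning-rate schedule does not need to be retuned.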

## Launch

### Fresh Training Run

```bash
bash scripts/train_remote.sh
```

### Resume from Checkpoint

```bash
# Auto-resume (picks latest from CKPT_DIR)
bash scripts/train_remote.sh

# Manual resume from specific checkpoint
export RESUME=checkpoints-spider-remote/spider-step5000.pt
bash scripts/train_remote.sh
```
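
How the script chooses the "latest" checkpoint is not specified here. One plausible approach, sketched below as an assumption, is to pick the `spider-step{N}.pt` file with the largest step number. `latest_step_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

STEP_PATTERN = re.compile(r"^spider-step(\d+)\.pt$")


def latest_step_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the spider-step{N}.pt file with the largest N, or None if absent."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    best_step, best_path = -1, None
    for path in directory.iterdir():
        match = STEP_PATTERN.match(path.name)
        # Compare step numbers as integers: step5000 must beat step900.
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

Comparing the parsed integers avoids the lexicographic trap where `spider-step900.pt` sorts after `spider-step5000.pt`.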

### Resume from Local Smoke Test

Transfer the local checkpoint to the remote machine, then:

```bash
export RESUME=checkpoints-spider-real/spider-final-ep1.pt
bash scripts/train_remote.sh
```

**Note**: The local checkpoint was trained with 8-bit AdamW (BF16). On resume with MXFP8/FP8, the training script will:
1. Load model weights (always compatible)
2. Skip the 8-bit optimizer state with a warning (8-bit → standard AdamW mismatch)
3. Continue training with standard AdamW, with optimizer state reinitialized from scratch

This is by design: the optimizer state mismatch is handled gracefully.
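
The resume behavior described above can be sketched as a tolerant loader. This is an illustration, not the training script's code: the `"optimizer"` checkpoint key and the set of caught exceptions are assumptions, and `optimizer` is any object with a PyTorch-style `load_state_dict` method.

```python
import warnings


def load_optimizer_state(optimizer, checkpoint: dict) -> bool:
    """Try to restore optimizer state; return False if it had to be skipped."""
    state = checkpoint.get("optimizer")  # assumed key name
    if state is None:
        warnings.warn("Checkpoint has no optimizer state; starting fresh.")
        return False
    try:
        optimizer.load_state_dict(state)
    except (ValueError, KeyError, RuntimeError) as err:
        # Expected when 8-bit AdamW state is loaded into standard AdamW.
        warnings.warn(f"Skipping incompatible optimizer state: {err}")
        return False
    return True
```

Returning a flag lets the caller decide whether to re-run learning-rate warmup when the optimizer starts without its moment estimates.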

## Monitoring

### Training Logs

The script outputs structured logs every 10 steps:

```
Epoch 1 | step    10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
```

Key metrics:
- **loss**: Total loss (lm + aux + bp)
- **lm**: Language modeling loss
- **aux**: MoE load-balancing auxiliary loss
- **bp**: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
- **gnorm**: Gradient norm (should stabilize ~1-5)
- **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
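
For dashboards or quick sanity checks, the log line shown above can be parsed with a regular expression. The pattern below is a sketch written against that single sample line; other lines the script emits may need adjustments.

```python
import re

LOG_PATTERN = re.compile(
    r"Epoch\s+(?P<epoch>\d+)\s*\|\s*step\s+(?P<step>\d+)/(?P<total_steps>\d+)"
    r"\s*\|\s*loss\s+(?P<loss>[\d.]+)\s*\|\s*lm\s+(?P<lm>[\d.]+)"
    r"\s*\|\s*aux\s+(?P<aux>[\d.]+)\s*\|\s*bp\s+(?P<bp>[\d.]+)"
    r"(?:\s*\[(?P<bp_mode>[^\]]*)\])?"
    r"\s*\|\s*gnorm\s+(?P<gnorm>[\d.]+)\s*\|\s*lr\s+(?P<lr>[\d.eE+-]+)"
    r"\s*\|\s*(?P<mtok_per_s>[\d.]+)M tok/s"
)
INT_KEYS = {"epoch", "step", "total_steps"}


def parse_log_line(line: str):
    """Return the metrics from one training log line, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    parsed = {}
    for key, value in match.groupdict().items():
        if key == "bp_mode":
            parsed[key] = value  # e.g. "FIXED/FROZEN", or None when absent
        elif key in INT_KEYS:
            parsed[key] = int(value)
        else:
            parsed[key] = float(value)
    return parsed


sample = (
    "Epoch 1 | step    10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | "
    "bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens"
)
metrics = parse_log_line(sample)
# The total loss should equal lm + aux + bp up to rounding.
assert abs(metrics["loss"] - (metrics["lm"] + metrics["aux"] + metrics["bp"])) < 1e-3
```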

### VRAM Monitoring

```bash
watch -n 5 nvidia-smi
```

Expected on RTX 6000 Pro with MXFP8:
- Model: ~2GB (weights in FP8)
- Optimizer: ~8GB (standard AdamW, FP32 states)
- Activations: ~4-8GB (gradient checkpointing enabled)
- **Peak**: ~15-20GB total

### Health Warnings

The `RecurrentMonitor` checks for:
- **Representation drift**: Loop hidden states diverging (cosine sim < 0.5)
- **Collapse**: All experts producing identical outputs (std < 1e-6)

If you see these warnings, consider reducing `N_LOOPS` or lowering learning rate.
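
The two thresholds can be illustrated in plain Python. This sketch does not reproduce `RecurrentMonitor`: which hidden states are compared, and whether the standard deviation is taken across experts, are assumptions made for the example.

```python
import math
from statistics import pstdev


def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def has_drift(prev_hidden, next_hidden, threshold: float = 0.5) -> bool:
    """True when consecutive loop states point in clearly different directions."""
    return cosine_similarity(prev_hidden, next_hidden) < threshold


def has_collapse(expert_outputs, threshold: float = 1e-6) -> bool:
    """True when every output position barely varies across experts."""
    per_position = zip(*expert_outputs)  # group values by output position
    return all(pstdev(values) < threshold for values in per_position)
```

For example, `has_drift([1.0, 0.0], [0.0, 1.0])` is true because orthogonal states have cosine similarity 0, and `has_collapse([[1.0, 2.0], [1.0, 2.0]])` is true because both experts return the same vector.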

## Precision Fallback Chain

The training script automatically falls back if precision setup fails:

```
MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
```

- **MXFP8**: Row-wise scaling + high-precision grad weight accumulation. Best accuracy.
- **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
- **BF16**: No quantization. Most VRAM, but simplest path.
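
The chain can be sketched as a selection function keyed on compute capability. The thresholds come from the diagram above; the function names and the use of `RuntimeError` to signal a failed setup are assumptions, not the script's actual interface.

```python
def supported_precisions(capability):
    """Precisions to try, in order, for a (major, minor) compute capability."""
    sm = capability[0] * 10 + capability[1]
    chain = []
    if sm >= 120:
        chain.append("mxfp8")
    if sm >= 89:
        chain.append("fp8_dynamic")
    chain.append("bf16")  # always available as the last resort
    return chain


def pick_precision(requested, capability, setup):
    """Try `requested`, then each fallback; `setup(precision)` raises on failure."""
    chain = supported_precisions(capability)
    start = chain.index(requested) if requested in chain else 0
    for precision in chain[start:]:
        try:
            setup(precision)
            return precision
        except RuntimeError:
            continue  # fall through to the next precision in the chain
    raise RuntimeError("no usable precision found")
```

In a real run the capability tuple would come from `torch.cuda.get_device_capability(0)`, which returns `(major, minor)`.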

## Checkpoint Files

| File | Description |
|---|---|
| `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) |
| `spider-ep{N}.pt` | Epoch boundary checkpoint |
| `spider-best.pt` | Best loss checkpoint (updated when epoch loss improves) |
| `spider-final-ep{N}.pt` | Final checkpoint at training end |

Each checkpoint contains:
- Model state dict
- Optimizer state dict
- Training step, epoch, config
- `best_loss` value
- BP optimizer state (if active)

## Troubleshooting

### `mat2 shape must be divisible by 16`

Fixed with `pad_inner_dim=True` in `Float8LinearConfig` (v0.17.0+). The training script handles this automatically.
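
The error means a matmul inner dimension is not a multiple of 16. The snippet below only illustrates the rounding that padding implies; it is not torchao's implementation, and `padded_dim` is a hypothetical helper.

```python
def padded_dim(dim: int, multiple: int = 16) -> int:
    """Round dim up to the next multiple (unchanged if already aligned)."""
    return -(-dim // multiple) * multiple


print(padded_dim(1000))  # 1008
print(padded_dim(1024))  # 1024
```

Choosing layer widths that are already multiples of 16 avoids the padding overhead entirely.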

### `CUDA out of memory`

Reduce `MICRO_BATCH` and increase `GRAD_ACCUM` by the same factor to keep the global batch size unchanged:

```bash
export MICRO_BATCH=4    # was 8
export GRAD_ACCUM=8     # was 4 (same 65,536 tok/step)
```

### Optimizer state mismatch on resume

Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues: model weights load fine, and the optimizer restarts from scratch.

### Slower than expected throughput

- Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script)
- Check `torch.compile` isn't being used inadvertently (adds compile overhead)
- Verify torchao version >= 0.17.0 for optimal FP8 kernels