CLIWorks commited on
Commit
c16af35
·
verified ·
1 Parent(s): 513d5a8

Upload remote-training.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. remote-training.md +218 -0
remote-training.md ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Spider-FLEXITOKENS Remote Training Guide
2
+
3
+ ## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell)
4
+
5
+ - **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+)
6
+ - **VRAM**: 48GB GDDR7
7
+ - **Precision**: MXFP8 (rowwise_with_gw_hp recipe) — primary; FP8_DYNAMIC fallback
8
+ - **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing)
9
+
10
+ ## Quick Start
11
+
12
+ ```bash
13
+ # 1. Clone/transfer the repo to the remote machine
14
+ # 2. Install dependencies (see below)
15
+ # 3. Run the launch script
16
+ bash scripts/train_remote.sh
17
+ ```
18
+
19
+ ## Environment Setup
20
+
21
+ ### Required Dependencies
22
+
23
+ ```bash
24
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
25
+ pip install torchao>=0.17.0
26
+ pip install datasets transformers
27
+ pip install bitsandbytes # optional — only used for BF16 fallback
28
+ ```
29
+
30
+ ### Optional (Recommended)
31
+
32
+ ```bash
33
+ pip install unsloth # MoE kernel optimizations + memory-efficient GC
34
+ ```
35
+
36
+ ### Verify Installation
37
+
38
+ ```bash
39
+ python3 -c "
40
+ import torch
41
+ print(f'PyTorch: {torch.__version__}')
42
+ print(f'CUDA: {torch.version.cuda}')
43
+ print(f'GPU: {torch.cuda.get_device_name(0)}')
44
+ print(f'Compute capability: sm{torch.cuda.get_device_capability(0)[0]}')
45
+
46
+ import torchao
47
+ print(f'torchao: {torchao.__version__}')
48
+
49
+ from torchao.float8 import Float8LinearConfig
50
+ print('FP8 training: available')
51
+ print(f'Recipes: {[n.value for n in __import__(\"torchao.float8.config\", fromlist=[\"Float8LinearRecipeName\"]).Float8LinearRecipeName]}')
52
+ "
53
+ ```
54
+
55
+ Expected output on RTX 6000 Pro: `sm120` or higher, all 3 recipes available (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`).
56
+
57
+ ## Configuration
58
+
59
+ ### Environment Variables
60
+
61
+ | Variable | Default | Description |
62
+ |---|---|---|
63
+ | `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` |
64
+ | `SEQ_LEN` | `2048` | Sequence length per sample |
65
+ | `MICRO_BATCH` | `8` | Batch size per forward pass |
66
+ | `GRAD_ACCUM` | `4` | Gradient accumulation steps |
67
+ | `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) |
68
+ | `N_LOOPS` | `6` | Recurrent loop iterations |
69
+ | `LR` | `3e-4` | Peak learning rate |
70
+ | `CKPT_EVERY` | `500` | Save checkpoint every N steps |
71
+ | `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory |
72
+ | `RESUME` | _(empty)_ | Path to checkpoint for manual resume |
73
+
74
+ ### Recommended Settings for RTX 6000 Pro (48GB)
75
+
76
+ ```bash
77
+ # MXFP8 — maximum accuracy, best VRAM efficiency
78
+ export PRECISION=mxfp8
79
+ export MICRO_BATCH=8
80
+ export GRAD_ACCUM=4
81
+ # Global batch: 8 * 4 * 2048 = 65,536 tokens/step
82
+ # ~10B tokens ≈ 152,000 steps
83
+ ```
84
+
85
+ ### Conservative Settings (if VRAM-constrained)
86
+
87
+ ```bash
88
+ export PRECISION=fp8_dynamic
89
+ export MICRO_BATCH=4
90
+ export GRAD_ACCUM=8
91
+ # Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same throughput, lower peak VRAM)
92
+ ```
93
+
94
+ ## Launch
95
+
96
+ ### Fresh Training Run
97
+
98
+ ```bash
99
+ bash scripts/train_remote.sh
100
+ ```
101
+
102
+ ### Resume from Checkpoint
103
+
104
+ ```bash
105
+ # Auto-resume (picks latest from CKPT_DIR)
106
+ bash scripts/train_remote.sh
107
+
108
+ # Manual resume from specific checkpoint
109
+ export RESUME=checkpoints-spider-remote/spider-step5000.pt
110
+ bash scripts/train_remote.sh
111
+ ```
112
+
113
+ ### Resume from Local Smoke Test
114
+
115
+ Transfer the local checkpoint to the remote machine, then:
116
+
117
+ ```bash
118
+ export RESUME=checkpoints-spider-real/spider-final-ep1.pt
119
+ bash scripts/train_remote.sh
120
+ ```
121
+
122
+ **Note**: The local checkpoint was trained with 8-bit AdamW (BF16). On resume with MXFP8/FP8, the training script will:
123
+ 1. Load model weights (always compatible)
124
+ 2. Skip 8-bit optimizer state with a warning (8-bit → standard AdamW mismatch)
125
+ 3. Continue training with standard AdamW from step 0 optimizer state
126
+
127
+ This is by design — the optimizer state mismatch is handled gracefully.
128
+
129
+ ## Monitoring
130
+
131
+ ### Training Logs
132
+
133
+ The script outputs structured logs every 10 steps:
134
+
135
+ ```
136
+ Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens
137
+ ```
138
+
139
+ Key metrics:
140
+ - **loss**: Total loss (lm + aux + bp)
141
+ - **lm**: Language modeling loss
142
+ - **aux**: MoE load-balancing auxiliary loss
143
+ - **bp**: Boundary predictor loss [FIXED=30% curriculum / ADAPTIVE=learned]
144
+ - **gnorm**: Gradient norm (should stabilize ~1-5)
145
+ - **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro)
146
+
147
+ ### VRAM Monitoring
148
+
149
+ ```bash
150
+ watch -n 5 nvidia-smi
151
+ ```
152
+
153
+ Expected on RTX 6000 Pro with MXFP8:
154
+ - Model: ~2GB (weights in FP8)
155
+ - Optimizer: ~8GB (standard AdamW, FP32 states)
156
+ - Activations: ~4-8GB (gradient checkpointing enabled)
157
+ - **Peak**: ~15-20GB total
158
+
159
+ ### Health Warnings
160
+
161
+ The `RecurrentMonitor` checks for:
162
+ - **Representation drift**: Loop hidden states diverging (cosine sim < 0.5)
163
+ - **Collapse**: All experts producing identical outputs (std < 1e-6)
164
+
165
+ If you see these warnings, consider reducing `N_LOOPS` or lowering learning rate.
166
+
167
+ ## Precision Fallback Chain
168
+
169
+ The training script automatically falls back if precision setup fails:
170
+
171
+ ```
172
+ MXFP8 (sm120+ Blackwell) ��� FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs)
173
+ ```
174
+
175
+ - **MXFP8**: Row-wise scaling + high-precision grad weight accumulation. Best accuracy.
176
+ - **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff.
177
+ - **BF16**: No quantization. Most VRAM, but simplest path.
178
+
179
+ ## Checkpoint Files
180
+
181
+ | File | Description |
182
+ |---|---|
183
+ | `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) |
184
+ | `spider-ep{N}.pt` | Epoch boundary checkpoint |
185
+ | `spider-best.pt` | Best loss checkpoint (updated when epoch loss improves) |
186
+ | `spider-final-ep{N}.pt` | Final checkpoint at training end |
187
+
188
+ Each checkpoint contains:
189
+ - Model state dict
190
+ - Optimizer state dict
191
+ - Training step, epoch, config
192
+ - `best_loss` value
193
+ - BP optimizer state (if active)
194
+
195
+ ## Troubleshooting
196
+
197
+ ### `mat2 shape must be divisible by 16`
198
+
199
+ Fixed with `pad_inner_dim=True` in `Float8LinearConfig` (v0.17.0+). The training script handles this automatically.
200
+
201
+ ### `CUDA out of memory`
202
+
203
+ Reduce `MICRO_BATCH` or increase `GRAD_ACCUM` to maintain the same global batch size:
204
+
205
+ ```bash
206
+ export MICRO_BATCH=4 # was 8
207
+ export GRAD_ACCUM=8 # was 4 (same 65,536 tok/step)
208
+ ```
209
+
210
+ ### Optimizer state mismatch on resume
211
+
212
+ Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues — model weights load fine, optimizer restarts from scratch.
213
+
214
+ ### Slower than expected throughput
215
+
216
+ - Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script)
217
+ - Check `torch.compile` isn't being used inadvertently (adds compile overhead)
218
+ - Verify torchao version >= 0.17.0 for optimal FP8 kernels