| # KoHRM-Text VRAM / OOM Notes |
|
|
| ์์ฑ์ผ: 2026-05-24 |
|
|
| ์ด ๋ฌธ์๋ `KoHRM-Text-1.4B` stage-1 ํ์ต ์ค VRAM์ด ์๊ฐ์ด ์ง๋๋ฉฐ ์ฆ๊ฐํ๋ ์ด์ , ์ด์ OOM ์์ธ, ํ์ฌ ์ด์ ๊ธฐ์ค์ ๊ธฐ๋กํฉ๋๋ค. |
|
|
| ## ํ์ฌ ๊ด์ธก ์ํ |
|
|
| ํ์ฌ stage-1 run์ ๋ค์ ์ค์ ์ผ๋ก ์ ์ ํ์ต ์ค์
๋๋ค. |
|
|
| | ํญ๋ชฉ | ๊ฐ | |
| |---|---:| |
| | GPU | 8 x NVIDIA H200 | |
| | GPU utilization | 8์ฅ ๋ชจ๋ 99% | |
| | global batch | 180,224 tokens | |
| | local token slots/GPU | 22,528 | |
| | context | 4,096 | |
| | VRAM | GPU0 ์ฝ 129.9GB, ๋๋จธ์ง ์ฝ 127.6GB / 143.8GB | |
| | speed | ์ฝ 1.02 step/sec | |
| | checkpoint interval | 5,000 steps | |
|
|
| ํ์ฌ ์ค์ ์ ๋น ๋ฅด์ง๋ง ์ฌ์ VRAM์ด ์์ฃผ ๋์ ํธ์ ์๋๋๋ค. H200 ์ฅ๋น ์ฝ 144GB ์ค 127-130GB๋ฅผ ์ฌ์ฉํ๋ฏ๋ก, NCCL/allocator/compiler/cache/checkpoint ์๊ฐ ํผํฌ๊ฐ ๊ฒน์น๋ฉด OOM ์ํ์ด ๋ค์ ์๊ธธ ์ ์์ต๋๋ค. |
|
|
| ## ์ ํ์ต ์ค VRAM์ด ์ ์ ์ฌ๋ผ๊ฐ๋ |
|
|
| VRAM ์ฆ๊ฐ๊ฐ ๊ณง๋ฐ๋ก โ๋ฉ๋ชจ๋ฆฌ ๋์โ๋ผ๋ ๋ป์ ์๋๋๋ค. ๋ํ PyTorch/FSDP/compile ํ์ต์์๋ ๋ค์ ์์ธ์ด ๊ฒน์น๋ฉด์ ์ด๋ฐ๋ณด๋ค ๋ค์์ VRAM์ด ๋ ๋์์ง๋ ํจํด์ด ํํฉ๋๋ค. |
|
|
| ### 1. torch.compile / CUDA graph / kernel cache |
|
|
| HRM-Text ์ฝ๋๋ ์ฌ๋ฌ forward/backward path๋ฅผ compileํฉ๋๋ค. ์ด๋ฐ ๋ช step์์๋ ๋ชจ๋ shape/path๊ฐ ์์ง compile๋์ง ์์๊ณ , ํ์ต์ด ์งํ๋๋ฉฐ ์ถ๊ฐ graph, Triton kernel, CUDA kernel cache๊ฐ ๋ง๋ค์ด์ง๋๋ค. |
|
|
| ํนํ HRM ๊ตฌ์กฐ๋ H/L recurrent cycle๊ณผ PrefixLM loss๊ฐ ์์ด ๋จ์ decoder-only Transformer๋ณด๋ค compile path๊ฐ ๋ ๋ณต์กํฉ๋๋ค. ์ด๋ฐ VRAM๋ง ๋ณด๊ณ batch๋ฅผ ํฌ๊ฒ ์ก์ผ๋ฉด ํ์ graph๊ฐ ์์ฑ๋ ๋ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ชป ๋ฐ์ OOM์ด ๋ ์ ์์ต๋๋ค. |
|
|
| ### 2. final logits buffer ํฌ๊ธฐ |
|
|
| ์ด๋ฒ ๋ชจ๋ธ์ vocab์ด 131,072์
๋๋ค. upstream HRM-Text ๋
ผ๋ฌธ ์ค์ ์ 65,536 vocab๋ณด๋ค ๋ ๋ฐฐ์
๋๋ค. |
|
|
| batch token slots๊ฐ ์ปค์ง์๋ก final logits ๋๋ loss ๊ณ์ฐ ์ชฝ ์์ ๋ฒํผ๊ฐ ๋งค์ฐ ์ปค์ง๋๋ค. |
|
|
| ์๋ฅผ ๋ค์ด local token slots/GPU๊ฐ 32,768์ด๋ฉด `32768 x 131072` bf16 logits ๊ณ์ด ๋ฒํผ๊ฐ ํ์ํ ์ ์์ต๋๋ค. ์ด๋ก ์ ๋จ์ผ bf16 dense buffer๋ง ์ก์๋ ์ฝ 8GB ์ด์์ด๊ณ , ์ค์ backward/temporary/parallel buffer๊น์ง ํฉ์น๋ฉด ํจ์ฌ ์ปค์ง๋๋ค. |
|
|
| ์ด ๋๋ฌธ์ ์ฒ์์๋ `global_batch_size=262144` ๋๋ `229376`์ด ์ ๊น ๋์๊ฐ๋, ๋ค์์ compile graph์ logits/loss ์์ ๋ฒํผ๊ฐ ๊ฒน์น๋ ์๊ฐ OOM์ด ๋ ์ ์์ต๋๋ค. |
|
|
| ### 3. FSDP2 / optimizer / EMA ์ํ |
|
|
| ํ์ฌ ํ์ต์ model weights๋ง ๋ค๊ณ ์๋ ๊ฒ์ด ์๋๋๋ค. |
|
|
| - model parameters |
| - gradients |
| - optimizer state |
| - Adam-atan2 state |
| - EMA state |
| - FSDP shard/all-gather/reduce-scatter buffers |
| - recurrent carry ๊ด๋ จ state |
|
|
| ์ด ์ํ๋ค์ด step๋ง๋ค ํญ์ ๊ฐ์ ์๊ฐ์ ๊ฐ์ ํฌ๊ธฐ๋ก ๋ณด์ด๋ ๊ฒ์ ์๋๋๋ค. ํน์ backward path, optimizer step, checkpoint save ์์ ์ ํผํฌ๊ฐ ์ฌ๋ผ๊ฐ ์ ์์ต๋๋ค. |
|
|
| ### 4. NCCL communication buffers |
|
|
| 8 GPU ๋ถ์ฐ ํ์ต์์๋ NCCL ํต์ ๋ฒํผ๊ฐ ํ์ํฉ๋๋ค. all-gather/reduce-scatter ํ์ด๋ฐ, bucket ํฌ๊ธฐ, compile๋ ๊ทธ๋ํ ์คํ ์์์ ๋ฐ๋ผ GPU๋ณ ํผํฌ๊ฐ ๋ค๋ฅด๊ฒ ๋ณด์ผ ์ ์์ต๋๋ค. |
|
|
| GPU0์ด ๋ค๋ฅธ GPU๋ณด๋ค ๋ ๋๊ฒ ๋ณด์ด๋ ๊ฒ๋ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ๋ฅํฉ๋๋ค. rank0๊ฐ ๋ก๊น
, ์ผ๋ถ metadata, checkpoint coordination, dataloader/host interaction์ ๋ ๋งก๋ ๊ฒฝ์ฐ๊ฐ ์๊ธฐ ๋๋ฌธ์
๋๋ค. |
|
|
| ### 5. CUDA caching allocator |
|
|
| `nvidia-smi`์ used memory๋ โํ์ฌ ํ
์๊ฐ ์ค์ ๋ก ์ฐ๋ ๋ฉ๋ชจ๋ฆฌโ๋ง ๋ปํ์ง ์์ต๋๋ค. PyTorch CUDA allocator๊ฐ ํ ๋ฒ ํ๋ณดํ ๋ธ๋ก์ ์ฌ์ฌ์ฉํ๋ ค๊ณ ์บ์์ ์ก๊ณ ์์ผ๋ฉด `nvidia-smi`์๋ ๊ณ์ ์ฌ์ฉ ์ค์ฒ๋ผ ๋ณด์
๋๋ค. |
|
|
| ๋ฐ๋ผ์ step์ด ์งํ๋ ์๋ก used memory๊ฐ ์ฌ๋ผ๊ฐ๊ณ ์ ๋ด๋ ค๊ฐ์ง ์๋ ๊ฒ์ ์ ์์ผ ์ ์์ต๋๋ค. ์ค์ํ ๊ฒ์ reserved๊ฐ ๊ณ์ ๋ฌดํ ์ฆ๊ฐํ๋์ง, ๋๋ ํน์ step ์ดํ ์์ plateau๋ฅผ ๋ง๋๋์ง์
๋๋ค. |
|
|
| ### 6. checkpoint ์ ์ฅ ์ ์๊ฐ ํผํฌ |
|
|
| FSDP2 checkpoint ์ ์ฅ ์ `.distcp` shard, metadata, state_dict materialization, host/device transfer๊ฐ ๊ฒน์นฉ๋๋ค. ์ ์ฅ ์์ฒด๋ ์ฃผ๋ก CPU/disk ์์
์ด์ง๋ง, ์ ์ฅ ์ง์ /์งํ ๋ชจ๋ธ state ์ ๊ทผ ๋๋ฌธ์ GPU/CPU ๋ฉ๋ชจ๋ฆฌ ํผํฌ๊ฐ ์๊ธธ ์ ์์ต๋๋ค. |
| |
| ๊ทธ๋์ ๋๋ฌด ์ฆ์ checkpoint ์ ์ฅ์ ๋ค์ ๋ฌธ์ ๋ฅผ ๋ง๋ญ๋๋ค. |
| |
| - step ์ฒ๋ฆฌ ์ง์ฐ |
| - ๋์คํฌ ์ฌ์ฉ๋ ๊ธ์ฆ |
| - HF upload ๋ฐ scan ๋น์ฉ ์ฆ๊ฐ |
| - ์ ์ฅ ์์ ํผํฌ ๋ฉ๋ชจ๋ฆฌ ์ฆ๊ฐ |
| |
| ํ์ฌ 5,000 step๋ง๋ค ์ฝ 21GB๊ธ FSDP2 checkpoint๊ฐ ์๊น๋๋ค. 500 step๋ง๋ค ์ ์ฅํ๋ฉด stage-1 ๊ธฐ์ค์ผ๋ก ์ฒดํฌํฌ์ธํธ ์์ ์ ์ฅ ๋ถํ๊ฐ 10๋ฐฐ ๋์ด ๊ณผํฉ๋๋ค. |
| |
| ## ์ด์ OOM ์์ธ |
| |
| ์ด์ OOM์ batch๋ฅผ ํฌ๊ฒ ์ก์์ ๋ ์ด๋ฐ ๊ด์ธก VRAM๋ง ๋ณด๊ณ โ๊ด์ฐฎ๋คโ๊ณ ํ๋จํ ๊ฒ์ด ์์ธ์
๋๋ค. |
| |
| ํต์ฌ์ ๋ค์์
๋๋ค. |
| |
| 1. vocab 131K๋ผ logits/loss ๊ด๋ จ ์์ ๋ฒํผ๊ฐ ํฝ๋๋ค. |
| 2. HRM recurrent compile path๊ฐ ์ด๋ฐ ๋ช step ๋ค ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๊ตฌํฉ๋๋ค. |
| 3. H200 8์ฅ์ด๋ผ compute๋ ์ถฉ๋ถํ์ง๋ง, 1.4B + 131K vocab + EMA + optimizer + FSDP2 ์กฐํฉ์์๋ batch๋ฅผ ๋๋ฌด ํฌ๊ฒ ์ก์ผ๋ฉด ํ๋ฐ ํผํฌ๊ฐ ๊ฑธ๋ฆฝ๋๋ค. |
| 4. `global_batch_size=262144`, `229376`์ ์ด๋ฐ์๋ ๊ฐ๋ฅํด ๋ณด์์ง๋ง ์์ ๋ง์ง์ด ๋ถ์กฑํ์ต๋๋ค. |
| |
| ํ์ฌ๋ `global_batch_size=180224`๋ก ๋ด๋ ค ์์ ์งํ ์ค์
๋๋ค. |
| |
| ## ์ด์ ๊ธฐ์ค |
| |
| ํ์ฌ stage-1์์๋ GPU๋ฅผ ๋๋ฆฌ์ง ์๋ ๊ฒ์ด ์ฐ์ ์ด์ง๋ง, OOM์ผ๋ก run์ด ์ฃฝ์ผ๋ฉด ์ฌ์์/๊ฒ์ฆ/์ฒดํฌํฌ์ธํธ ์ ๋ฆฌ ๋น์ฉ์ด ๋ ํฝ๋๋ค. |
| |
| ๊ถ์ฅ ๊ธฐ์ค: |
| |
| | ํญ๋ชฉ | ๊ธฐ์ค | |
| |---|---| |
| | primary batch | `global_batch_size=180224` | |
| | ์ ์ฅ ์ฃผ๊ธฐ | `checkpoint_step_interval=5000` | |
| | ๋ก์ปฌ ๋ณด๊ด | ์ต์ 2-3๊ฐ checkpoint๋ง ์ ์ง | |
| | HF main repo | ์ต์ safetensors export ์ค์ฌ | |
| | HF raw repo | resume๊ฐ ํ์ํ FSDP2 checkpoint๋ง ๋ณ๋ ๋ณด๊ด | |
| | OOM ์ฌ๋ฐ ์ | batch๋ฅผ 5-10% ๋ฎ์ถ๊ณ ๊ฐ์ resume checkpoint์์ ์ฌ์์ | |
| |
| ## 500 step checkpoint๊ฐ ๊ณผํ ์ด์ |
| |
| 500 step๋ง๋ค ์ ์ฅํ๋ฉด ๋ค์ ๋ฌธ์ ๊ฐ ์๊น๋๋ค. |
| |
| - ํ์ฌ FSDP2 checkpoint ํ๋๊ฐ ์ฝ 21GB์
๋๋ค. |
| - 500 step ๊ฐ๊ฒฉ์ด๋ฉด 10,000 step๋ง๋ค ์ฝ 20๊ฐ, ์ฆ ์ฝ 420GB๊ฐ ์๊น๋๋ค. |
| - stage-1 ์ ์ฒด 88,522 step ๊ธฐ์ค์ผ๋ก๋ ๋จ์ ๊ณ์ฐ์ 170๊ฐ ์ด์์ด ์๊ฒจ ์ TB๊ฐ ๋ฉ๋๋ค. |
| - ์ ์ฅ ์์ฒด๊ฐ ํ์ต ๋ฃจํ๋ฅผ ๋ฐฉํดํ๊ณ , HF ์
๋ก๋/์ค์บ๋ ์ปค์ง๋๋ค. |
| |
| ๋ฐ๋ผ์ ํ์ฌ์ฒ๋ผ 5,000 step ๊ฐ๊ฒฉ์ผ๋ก ์ ์ฅํ๊ณ , ๋ก์ปฌ์ ์ต์ 2-3๊ฐ๋ง ๋จ๊ธฐ๋ ํธ์ด ๋ง์ต๋๋ค. |
| |
| ## ๋ค์ batch ์กฐ์ ํ๋จ |
| |
| ํ์ฌ VRAM ์ฌ์ฉ๋์ ๋์ง๋ง ํ์ต ์๋๋ ์์ ์ ์
๋๋ค. |
| |
| ๋ค์ stage์์ batch๋ฅผ ์ฌ๋ฆฌ๊ณ ์ถ์ผ๋ฉด ํ ๋ฒ์ ํฌ๊ฒ ์ฌ๋ฆฌ์ง ๋ง๊ณ ๋ค์ ์์๊ฐ ๋ซ์ต๋๋ค. |
| |
| 1. `global_batch_size=180224`๋ก ์์ ์๋ฃ ํ์ธ |
| 2. ๋ค์ dataset stage์์ `196608` ํ
์คํธ |
| 3. 2-3์ฒ step ์ด์ VRAM plateau ํ์ธ |
| 4. checkpoint ์ ์ฅ ์์ ๊น์ง ํต๊ณผํ๋ฉด ์ ์ง |
| 5. OOM ๋๋ ํผํฌ ๋ถ์์ ์ ์ฆ์ `180224` ๋๋ `172032`๋ก ๋ณต๊ท |
| |
| ๋
ผ๋ฌธ ์ค์ ๊ณผ ๋น๊ตํ๋ฉด H200 8์ฅ์ ๊ฐํ์ง๋ง, ์ด๋ฒ ๋ชจ๋ธ์ vocab์ด 131K๋ผ upstream๊ณผ ๋ฉ๋ชจ๋ฆฌ ๊ตฌ์กฐ๊ฐ ๋ค๋ฆ
๋๋ค. ๋ฐ๋ผ์ โH200์ด๋๊น ๋ฌด์กฐ๊ฑด H100 16์ฅ batch๋ฅผ ๋๊ธด๋คโ๋ ์์ผ๋ก ์ก์ผ๋ฉด ์์ ์ฑ์ด ๋จ์ด์ง๋๋ค. |
| |
| ## ๊ฒฐ๋ก |
| |
| ํ์ฌ VRAM ์์น์ torch compile/cache, 131K vocab logits buffer, FSDP2/optimizer/EMA/NCCL buffer, checkpoint ์๊ฐ ํผํฌ๊ฐ ๊ฒน์น ๊ฒฐ๊ณผ๋ก ๋ณด๋ ๊ฒ์ด ๋ง์ต๋๋ค. |
| |
| ํ์ฌ `global_batch_size=180224`, 5,000 step checkpoint, ์ต์ 2-3๊ฐ ๋ณด๊ด ์ ์ฑ
์ ๋น ๋ฅธ ํ์ต๊ณผ OOM ํํผ ์ฌ์ด์ ํ์ค์ ์ธ ๊ท ํ์
๋๋ค. ํ์ต์ด ์์ ํ ์์ plateau๋ฅผ ๋ณด์ด๋ฉด ๋ค์ stage์์๋ง ์ํญ ์ฆ๋์ ๊ฒํ ํฉ๋๋ค. |
| |
| |