KoHRM-Text VRAM / OOM Notes
์์ฑ์ผ: 2026-05-24
์ด ๋ฌธ์๋ KoHRM-Text-1.4B stage-1 ํ์ต ์ค VRAM์ด ์๊ฐ์ด ์ง๋๋ฉฐ ์ฆ๊ฐํ๋ ์ด์ , ์ด์ OOM ์์ธ, ํ์ฌ ์ด์ ๊ธฐ์ค์ ๊ธฐ๋กํฉ๋๋ค.
ํ์ฌ ๊ด์ธก ์ํ
ํ์ฌ stage-1 run์ ๋ค์ ์ค์ ์ผ๋ก ์ ์ ํ์ต ์ค์ ๋๋ค.
| ํญ๋ชฉ | ๊ฐ |
|---|---|
| GPU | 8 x NVIDIA H200 |
| GPU utilization | 8์ฅ ๋ชจ๋ 99% |
| global batch | 180,224 tokens |
| local token slots/GPU | 22,528 |
| context | 4,096 |
| VRAM | GPU0 ์ฝ 129.9GB, ๋๋จธ์ง ์ฝ 127.6GB / 143.8GB |
| speed | ์ฝ 1.02 step/sec |
| checkpoint interval | 5,000 steps |
ํ์ฌ ์ค์ ์ ๋น ๋ฅด์ง๋ง ์ฌ์ VRAM์ด ์์ฃผ ๋์ ํธ์ ์๋๋๋ค. H200 ์ฅ๋น ์ฝ 144GB ์ค 127-130GB๋ฅผ ์ฌ์ฉํ๋ฏ๋ก, NCCL/allocator/compiler/cache/checkpoint ์๊ฐ ํผํฌ๊ฐ ๊ฒน์น๋ฉด OOM ์ํ์ด ๋ค์ ์๊ธธ ์ ์์ต๋๋ค.
์ ํ์ต ์ค VRAM์ด ์ ์ ์ฌ๋ผ๊ฐ๋
VRAM ์ฆ๊ฐ๊ฐ ๊ณง๋ฐ๋ก โ๋ฉ๋ชจ๋ฆฌ ๋์โ๋ผ๋ ๋ป์ ์๋๋๋ค. ๋ํ PyTorch/FSDP/compile ํ์ต์์๋ ๋ค์ ์์ธ์ด ๊ฒน์น๋ฉด์ ์ด๋ฐ๋ณด๋ค ๋ค์์ VRAM์ด ๋ ๋์์ง๋ ํจํด์ด ํํฉ๋๋ค.
1. torch.compile / CUDA graph / kernel cache
HRM-Text ์ฝ๋๋ ์ฌ๋ฌ forward/backward path๋ฅผ compileํฉ๋๋ค. ์ด๋ฐ ๋ช step์์๋ ๋ชจ๋ shape/path๊ฐ ์์ง compile๋์ง ์์๊ณ , ํ์ต์ด ์งํ๋๋ฉฐ ์ถ๊ฐ graph, Triton kernel, CUDA kernel cache๊ฐ ๋ง๋ค์ด์ง๋๋ค.
ํนํ HRM ๊ตฌ์กฐ๋ H/L recurrent cycle๊ณผ PrefixLM loss๊ฐ ์์ด ๋จ์ decoder-only Transformer๋ณด๋ค compile path๊ฐ ๋ ๋ณต์กํฉ๋๋ค. ์ด๋ฐ VRAM๋ง ๋ณด๊ณ batch๋ฅผ ํฌ๊ฒ ์ก์ผ๋ฉด ํ์ graph๊ฐ ์์ฑ๋ ๋ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ชป ๋ฐ์ OOM์ด ๋ ์ ์์ต๋๋ค.
2. final logits buffer ํฌ๊ธฐ
์ด๋ฒ ๋ชจ๋ธ์ vocab์ด 131,072์ ๋๋ค. upstream HRM-Text ๋ ผ๋ฌธ ์ค์ ์ 65,536 vocab๋ณด๋ค ๋ ๋ฐฐ์ ๋๋ค.
batch token slots๊ฐ ์ปค์ง์๋ก final logits ๋๋ loss ๊ณ์ฐ ์ชฝ ์์ ๋ฒํผ๊ฐ ๋งค์ฐ ์ปค์ง๋๋ค.
์๋ฅผ ๋ค์ด local token slots/GPU๊ฐ 32,768์ด๋ฉด 32768 x 131072 bf16 logits ๊ณ์ด ๋ฒํผ๊ฐ ํ์ํ ์ ์์ต๋๋ค. ์ด๋ก ์ ๋จ์ผ bf16 dense buffer๋ง ์ก์๋ ์ฝ 8GB ์ด์์ด๊ณ , ์ค์ backward/temporary/parallel buffer๊น์ง ํฉ์น๋ฉด ํจ์ฌ ์ปค์ง๋๋ค.
์ด ๋๋ฌธ์ ์ฒ์์๋ global_batch_size=262144 ๋๋ 229376์ด ์ ๊น ๋์๊ฐ๋, ๋ค์์ compile graph์ logits/loss ์์ ๋ฒํผ๊ฐ ๊ฒน์น๋ ์๊ฐ OOM์ด ๋ ์ ์์ต๋๋ค.
3. FSDP2 / optimizer / EMA ์ํ
ํ์ฌ ํ์ต์ model weights๋ง ๋ค๊ณ ์๋ ๊ฒ์ด ์๋๋๋ค.
- model parameters
- gradients
- optimizer state
- Adam-atan2 state
- EMA state
- FSDP shard/all-gather/reduce-scatter buffers
- recurrent carry ๊ด๋ จ state
์ด ์ํ๋ค์ด step๋ง๋ค ํญ์ ๊ฐ์ ์๊ฐ์ ๊ฐ์ ํฌ๊ธฐ๋ก ๋ณด์ด๋ ๊ฒ์ ์๋๋๋ค. ํน์ backward path, optimizer step, checkpoint save ์์ ์ ํผํฌ๊ฐ ์ฌ๋ผ๊ฐ ์ ์์ต๋๋ค.
4. NCCL communication buffers
8 GPU ๋ถ์ฐ ํ์ต์์๋ NCCL ํต์ ๋ฒํผ๊ฐ ํ์ํฉ๋๋ค. all-gather/reduce-scatter ํ์ด๋ฐ, bucket ํฌ๊ธฐ, compile๋ ๊ทธ๋ํ ์คํ ์์์ ๋ฐ๋ผ GPU๋ณ ํผํฌ๊ฐ ๋ค๋ฅด๊ฒ ๋ณด์ผ ์ ์์ต๋๋ค.
GPU0์ด ๋ค๋ฅธ GPU๋ณด๋ค ๋ ๋๊ฒ ๋ณด์ด๋ ๊ฒ๋ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ๋ฅํฉ๋๋ค. rank0๊ฐ ๋ก๊น , ์ผ๋ถ metadata, checkpoint coordination, dataloader/host interaction์ ๋ ๋งก๋ ๊ฒฝ์ฐ๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
5. CUDA caching allocator
nvidia-smi์ used memory๋ โํ์ฌ ํ
์๊ฐ ์ค์ ๋ก ์ฐ๋ ๋ฉ๋ชจ๋ฆฌโ๋ง ๋ปํ์ง ์์ต๋๋ค. PyTorch CUDA allocator๊ฐ ํ ๋ฒ ํ๋ณดํ ๋ธ๋ก์ ์ฌ์ฌ์ฉํ๋ ค๊ณ ์บ์์ ์ก๊ณ ์์ผ๋ฉด nvidia-smi์๋ ๊ณ์ ์ฌ์ฉ ์ค์ฒ๋ผ ๋ณด์
๋๋ค.
๋ฐ๋ผ์ step์ด ์งํ๋ ์๋ก used memory๊ฐ ์ฌ๋ผ๊ฐ๊ณ ์ ๋ด๋ ค๊ฐ์ง ์๋ ๊ฒ์ ์ ์์ผ ์ ์์ต๋๋ค. ์ค์ํ ๊ฒ์ reserved๊ฐ ๊ณ์ ๋ฌดํ ์ฆ๊ฐํ๋์ง, ๋๋ ํน์ step ์ดํ ์์ plateau๋ฅผ ๋ง๋๋์ง์ ๋๋ค.
6. checkpoint ์ ์ฅ ์ ์๊ฐ ํผํฌ
FSDP2 checkpoint ์ ์ฅ ์ .distcp shard, metadata, state_dict materialization, host/device transfer๊ฐ ๊ฒน์นฉ๋๋ค. ์ ์ฅ ์์ฒด๋ ์ฃผ๋ก CPU/disk ์์
์ด์ง๋ง, ์ ์ฅ ์ง์ /์งํ ๋ชจ๋ธ state ์ ๊ทผ ๋๋ฌธ์ GPU/CPU ๋ฉ๋ชจ๋ฆฌ ํผํฌ๊ฐ ์๊ธธ ์ ์์ต๋๋ค.
๊ทธ๋์ ๋๋ฌด ์ฆ์ checkpoint ์ ์ฅ์ ๋ค์ ๋ฌธ์ ๋ฅผ ๋ง๋ญ๋๋ค.
- step ์ฒ๋ฆฌ ์ง์ฐ
- ๋์คํฌ ์ฌ์ฉ๋ ๊ธ์ฆ
- HF upload ๋ฐ scan ๋น์ฉ ์ฆ๊ฐ
- ์ ์ฅ ์์ ํผํฌ ๋ฉ๋ชจ๋ฆฌ ์ฆ๊ฐ
ํ์ฌ 5,000 step๋ง๋ค ์ฝ 21GB๊ธ FSDP2 checkpoint๊ฐ ์๊น๋๋ค. 500 step๋ง๋ค ์ ์ฅํ๋ฉด stage-1 ๊ธฐ์ค์ผ๋ก ์ฒดํฌํฌ์ธํธ ์์ ์ ์ฅ ๋ถํ๊ฐ 10๋ฐฐ ๋์ด ๊ณผํฉ๋๋ค.
์ด์ OOM ์์ธ
์ด์ OOM์ batch๋ฅผ ํฌ๊ฒ ์ก์์ ๋ ์ด๋ฐ ๊ด์ธก VRAM๋ง ๋ณด๊ณ โ๊ด์ฐฎ๋คโ๊ณ ํ๋จํ ๊ฒ์ด ์์ธ์ ๋๋ค.
ํต์ฌ์ ๋ค์์ ๋๋ค.
- vocab 131K๋ผ logits/loss ๊ด๋ จ ์์ ๋ฒํผ๊ฐ ํฝ๋๋ค.
- HRM recurrent compile path๊ฐ ์ด๋ฐ ๋ช step ๋ค ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๊ตฌํฉ๋๋ค.
- H200 8์ฅ์ด๋ผ compute๋ ์ถฉ๋ถํ์ง๋ง, 1.4B + 131K vocab + EMA + optimizer + FSDP2 ์กฐํฉ์์๋ batch๋ฅผ ๋๋ฌด ํฌ๊ฒ ์ก์ผ๋ฉด ํ๋ฐ ํผํฌ๊ฐ ๊ฑธ๋ฆฝ๋๋ค.
global_batch_size=262144,229376์ ์ด๋ฐ์๋ ๊ฐ๋ฅํด ๋ณด์์ง๋ง ์์ ๋ง์ง์ด ๋ถ์กฑํ์ต๋๋ค.
ํ์ฌ๋ global_batch_size=180224๋ก ๋ด๋ ค ์์ ์งํ ์ค์
๋๋ค.
์ด์ ๊ธฐ์ค
ํ์ฌ stage-1์์๋ GPU๋ฅผ ๋๋ฆฌ์ง ์๋ ๊ฒ์ด ์ฐ์ ์ด์ง๋ง, OOM์ผ๋ก run์ด ์ฃฝ์ผ๋ฉด ์ฌ์์/๊ฒ์ฆ/์ฒดํฌํฌ์ธํธ ์ ๋ฆฌ ๋น์ฉ์ด ๋ ํฝ๋๋ค.
๊ถ์ฅ ๊ธฐ์ค:
| ํญ๋ชฉ | ๊ธฐ์ค |
|---|---|
| primary batch | global_batch_size=180224 |
| ์ ์ฅ ์ฃผ๊ธฐ | checkpoint_step_interval=5000 |
| ๋ก์ปฌ ๋ณด๊ด | ์ต์ 2-3๊ฐ checkpoint๋ง ์ ์ง |
| HF main repo | ์ต์ safetensors export ์ค์ฌ |
| HF raw repo | resume๊ฐ ํ์ํ FSDP2 checkpoint๋ง ๋ณ๋ ๋ณด๊ด |
| OOM ์ฌ๋ฐ ์ | batch๋ฅผ 5-10% ๋ฎ์ถ๊ณ ๊ฐ์ resume checkpoint์์ ์ฌ์์ |
500 step checkpoint๊ฐ ๊ณผํ ์ด์
500 step๋ง๋ค ์ ์ฅํ๋ฉด ๋ค์ ๋ฌธ์ ๊ฐ ์๊น๋๋ค.
- ํ์ฌ FSDP2 checkpoint ํ๋๊ฐ ์ฝ 21GB์ ๋๋ค.
- 500 step ๊ฐ๊ฒฉ์ด๋ฉด 10,000 step๋ง๋ค ์ฝ 20๊ฐ, ์ฆ ์ฝ 420GB๊ฐ ์๊น๋๋ค.
- stage-1 ์ ์ฒด 88,522 step ๊ธฐ์ค์ผ๋ก๋ ๋จ์ ๊ณ์ฐ์ 170๊ฐ ์ด์์ด ์๊ฒจ ์ TB๊ฐ ๋ฉ๋๋ค.
- ์ ์ฅ ์์ฒด๊ฐ ํ์ต ๋ฃจํ๋ฅผ ๋ฐฉํดํ๊ณ , HF ์ ๋ก๋/์ค์บ๋ ์ปค์ง๋๋ค.
๋ฐ๋ผ์ ํ์ฌ์ฒ๋ผ 5,000 step ๊ฐ๊ฒฉ์ผ๋ก ์ ์ฅํ๊ณ , ๋ก์ปฌ์ ์ต์ 2-3๊ฐ๋ง ๋จ๊ธฐ๋ ํธ์ด ๋ง์ต๋๋ค.
๋ค์ batch ์กฐ์ ํ๋จ
ํ์ฌ VRAM ์ฌ์ฉ๋์ ๋์ง๋ง ํ์ต ์๋๋ ์์ ์ ์ ๋๋ค.
๋ค์ stage์์ batch๋ฅผ ์ฌ๋ฆฌ๊ณ ์ถ์ผ๋ฉด ํ ๋ฒ์ ํฌ๊ฒ ์ฌ๋ฆฌ์ง ๋ง๊ณ ๋ค์ ์์๊ฐ ๋ซ์ต๋๋ค.
global_batch_size=180224๋ก ์์ ์๋ฃ ํ์ธ- ๋ค์ dataset stage์์
196608ํ ์คํธ - 2-3์ฒ step ์ด์ VRAM plateau ํ์ธ
- checkpoint ์ ์ฅ ์์ ๊น์ง ํต๊ณผํ๋ฉด ์ ์ง
- OOM ๋๋ ํผํฌ ๋ถ์์ ์ ์ฆ์
180224๋๋172032๋ก ๋ณต๊ท
๋ ผ๋ฌธ ์ค์ ๊ณผ ๋น๊ตํ๋ฉด H200 8์ฅ์ ๊ฐํ์ง๋ง, ์ด๋ฒ ๋ชจ๋ธ์ vocab์ด 131K๋ผ upstream๊ณผ ๋ฉ๋ชจ๋ฆฌ ๊ตฌ์กฐ๊ฐ ๋ค๋ฆ ๋๋ค. ๋ฐ๋ผ์ โH200์ด๋๊น ๋ฌด์กฐ๊ฑด H100 16์ฅ batch๋ฅผ ๋๊ธด๋คโ๋ ์์ผ๋ก ์ก์ผ๋ฉด ์์ ์ฑ์ด ๋จ์ด์ง๋๋ค.
๊ฒฐ๋ก
ํ์ฌ VRAM ์์น์ torch compile/cache, 131K vocab logits buffer, FSDP2/optimizer/EMA/NCCL buffer, checkpoint ์๊ฐ ํผํฌ๊ฐ ๊ฒน์น ๊ฒฐ๊ณผ๋ก ๋ณด๋ ๊ฒ์ด ๋ง์ต๋๋ค.
ํ์ฌ global_batch_size=180224, 5,000 step checkpoint, ์ต์ 2-3๊ฐ ๋ณด๊ด ์ ์ฑ
์ ๋น ๋ฅธ ํ์ต๊ณผ OOM ํํผ ์ฌ์ด์ ํ์ค์ ์ธ ๊ท ํ์
๋๋ค. ํ์ต์ด ์์ ํ ์์ plateau๋ฅผ ๋ณด์ด๋ฉด ๋ค์ stage์์๋ง ์ํญ ์ฆ๋์ ๊ฒํ ํฉ๋๋ค.