
KoHRM-Text VRAM / OOM Notes

Written: 2026-05-24

This document records why VRAM usage grows over time during KoHRM-Text-1.4B stage-1 training, what caused the earlier OOMs, and the current operating guidelines.

Current observed state

The current stage-1 run is training normally with the following settings.

| Item | Value |
| --- | --- |
| GPU | 8 x NVIDIA H200 |
| GPU utilization | 99% on all 8 GPUs |
| global batch | 180,224 tokens |
| local token slots/GPU | 22,528 |
| context | 4,096 |
| VRAM | GPU0 about 129.9GB, the rest about 127.6GB, out of 143.8GB |
| speed | about 1.02 step/sec |
| checkpoint interval | 5,000 steps |

The current configuration is fast, but VRAM headroom is not generous. Each H200 uses 127-130GB of roughly 144GB, so the OOM risk can return if momentary peaks from NCCL, the allocator, the compiler, caches, and checkpointing coincide.
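As a quick check of the numbers above, the following minimal sketch recomputes the per-GPU token slots and the remaining headroom. It only restates figures from this note and assumes an even data-parallel split across the 8 GPUs; the function names are illustrative.

```python
# Figures copied from this note; the script measures nothing itself.
NUM_GPUS = 8
GLOBAL_BATCH_TOKENS = 180_224
VRAM_TOTAL_GB = 143.8
VRAM_USED_GB = {"GPU0": 129.9, "other GPUs": 127.6}


def local_token_slots(global_batch_tokens: int, num_gpus: int) -> int:
    """Tokens per GPU per step, assuming an even data-parallel split."""
    return global_batch_tokens // num_gpus


def headroom_gb(total_gb: float, used_gb: float) -> float:
    """Free VRAM left on one GPU, in GB."""
    return total_gb - used_gb


print("local token slots/GPU:", local_token_slots(GLOBAL_BATCH_TOKENS, NUM_GPUS))
for name, used in VRAM_USED_GB.items():
    free = headroom_gb(VRAM_TOTAL_GB, used)
    print(f"{name}: {free:.1f} GB free ({free / VRAM_TOTAL_GB:.1%} of capacity)")
```

The headroom comes out at roughly 14 to 16GB per GPU, which is the margin any overlapping peaks have to fit into.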

Why VRAM keeps climbing during training

Rising VRAM does not automatically mean a memory leak. In large PyTorch/FSDP/compile training runs, it is common for the following factors to overlap so that VRAM ends up higher later in the run than at the start.

1. torch.compile / CUDA graph / kernel cache

The HRM-Text code compiles several forward/backward paths. In the first few steps not every shape/path has been compiled yet; as training proceeds, additional graphs, Triton kernels, and CUDA kernel caches are created.

The HRM architecture in particular has H/L recurrent cycles and a PrefixLM loss, so its compile paths are more complex than those of a plain decoder-only Transformer. If the batch is sized from early VRAM readings alone, later graphs may fail to get the extra memory they need and the run can OOM.

2. final logits buffer size

This model has a vocab of 131,072, twice the 65,536 vocab of the upstream HRM-Text paper configuration.

As batch token slots grow, the temporary buffers around the final logits or loss computation become very large.

For example, with 32,768 local token slots/GPU, a 32768 x 131072 bf16 logits-type buffer may be needed. In theory a single dense bf16 buffer alone is at least about 8GB (32768 x 131072 x 2 bytes), and once the actual backward/temporary/parallel buffers are added it becomes much larger.
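The arithmetic behind that estimate can be sketched as follows. This counts only one dense logits tensor of shape (token slots, vocab) at 2 bytes per element; whether the training code actually materializes such a tensor in one piece is not confirmed here, and `dense_logits_bytes` is an illustrative helper name.

```python
GIB = 1024 ** 3


def dense_logits_bytes(token_slots: int, vocab_size: int, bytes_per_elem: int = 2) -> int:
    """Size of one dense (token_slots x vocab_size) buffer; bf16 uses 2 bytes per element."""
    return token_slots * vocab_size * bytes_per_elem


# Compare the current and the larger per-GPU slot counts at both vocab sizes.
for slots in (22_528, 32_768):
    for vocab in (65_536, 131_072):
        size_gib = dense_logits_bytes(slots, vocab) / GIB
        print(f"slots={slots:>6} vocab={vocab:>7}: {size_gib:.2f} GiB per dense bf16 buffer")
```

At the current 22,528 slots/GPU a single such buffer is 5.5 GiB, against 8 GiB at 32,768 slots, and doubling the vocab doubles either figure.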

Because of this, global_batch_size=262144 or 229376 may run briefly at first and still OOM later, at the moment the compile graphs and the logits/loss temporary buffers overlap.

3. FSDP2 / optimizer / EMA state

Training holds more than just the model weights.

  • model parameters
  • gradients
  • optimizer state
  • Adam-atan2 state
  • EMA state
  • FSDP shard/all-gather/reduce-scatter buffers
  • recurrent carry state

These states do not show up at the same size at the same moment in every step. Peaks can rise at particular backward paths, at the optimizer step, and at checkpoint save time.

4. NCCL communication buffers

Distributed training across 8 GPUs requires NCCL communication buffers. Per-GPU peaks can differ depending on all-gather/reduce-scatter timing, bucket sizes, and the execution order of the compiled graphs.

It is also generally possible for GPU0 to read higher than the other GPUs, because rank0 sometimes takes on more of the logging, some metadata, checkpoint coordination, and dataloader/host interaction.

5. CUDA caching allocator

The used memory reported by nvidia-smi does not mean only "memory that tensors are actually using right now". When the PyTorch CUDA allocator holds blocks it has already obtained in its cache for reuse, nvidia-smi keeps showing them as in use.

So it can be normal for used memory to rise as steps progress and not come back down much. What matters is whether reserved memory keeps growing without bound or settles into a stable plateau after a certain step.
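One way to tell a plateau from unbounded growth is to sample reserved memory once per step and check the spread over a sliding window. The sketch below is an illustration, not part of the HRM-Text code: `PlateauTracker`, its window, and its tolerance are hypothetical choices, and only `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved` are real PyTorch calls.

```python
from collections import deque


def read_cuda_memory_gib():
    """Return (allocated, reserved) in GiB for the current device, or None without CUDA."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    gib = 1024 ** 3
    return torch.cuda.memory_allocated() / gib, torch.cuda.memory_reserved() / gib


class PlateauTracker:
    """Reports a plateau once the last `window` samples stay within `tolerance_gib`."""

    def __init__(self, window: int = 2000, tolerance_gib: float = 0.5):
        self.window = window
        self.tolerance_gib = tolerance_gib
        self.samples = deque(maxlen=window)

    def update(self, reserved_gib: float) -> bool:
        self.samples.append(reserved_gib)
        if len(self.samples) < self.window:
            return False  # not enough history to judge yet
        return max(self.samples) - min(self.samples) <= self.tolerance_gib
```

In a training loop this would be called once per step with the reserved value from `read_cuda_memory_gib()`; a tracker that never reports a plateau over several thousand steps is the signal worth investigating.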

6. Momentary peaks during checkpoint saves

An FSDP2 checkpoint save overlaps .distcp shards, metadata, state_dict materialization, and host/device transfers. The save itself is mostly CPU/disk work, but accessing model state just before and after the save can cause GPU/CPU memory peaks.

Saving checkpoints too often therefore causes the following problems.

  • delayed step processing
  • a sharp increase in disk usage
  • higher HF upload and scan cost
  • higher peak memory at save time

Currently an FSDP2 checkpoint of about 21GB is produced every 5,000 steps. Saving every 500 steps would multiply the checkpoint count and save load by 10 over stage-1, which is excessive.

Cause of the earlier OOMs

The earlier OOMs happened because, with a large batch, the run was judged "fine" from the VRAM observed in the early steps alone.

The key points are as follows.

  1. With a 131K vocab, the logits/loss temporary buffers are large.
  2. The HRM recurrent compile paths demand additional memory after the first few steps.
  3. With 8 H200s there is plenty of compute, but in the 1.4B + 131K vocab + EMA + optimizer + FSDP2 combination, an oversized batch hits late peaks.
  4. global_batch_size=262144 and 229376 looked feasible at first but lacked a stable margin.

The batch has since been lowered to global_batch_size=180224, and training is proceeding stably.

Operating guidelines

In the current stage-1 the priority is to keep the GPUs busy, but if an OOM kills the run, the cost of restarting, validating, and cleaning up checkpoints is higher.

Recommended guidelines:

| Item | Guideline |
| --- | --- |
| primary batch | global_batch_size=180224 |
| save interval | checkpoint_step_interval=5000 |
| local retention | keep only the latest 2-3 checkpoints |
| HF main repo | focus on the latest safetensors export |
| HF raw repo | separately archive only the FSDP2 checkpoints needed for resume |
| on OOM recurrence | lower batch by 5-10% and restart from the same resume checkpoint |
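The local retention rule could be automated along these lines. This is a minimal sketch under an assumed layout: the `step_<N>` directory naming and the `prune_checkpoints` helper are hypothetical and do not describe the actual checkpoint layout of this run.

```python
import re
import shutil
from pathlib import Path

# Hypothetical naming convention: one directory per checkpoint, e.g. step_5000.
STEP_DIR = re.compile(r"^step_(\d+)$")


def prune_checkpoints(root: Path, keep: int = 3) -> list[Path]:
    """Delete all but the `keep` highest-step checkpoint directories under `root`."""
    if keep < 1:
        raise ValueError("keep must be at least 1")
    found = []
    for child in root.iterdir():
        match = STEP_DIR.match(child.name)
        if child.is_dir() and match:
            found.append((int(match.group(1)), child))
    found.sort()  # numeric order, so step_10000 sorts after step_5000
    stale = [path for _, path in found[:-keep]]
    for path in stale:
        shutil.rmtree(path)
    return stale
```

Sorting by the parsed step number rather than by name matters, since a plain string sort would place `step_10000` before `step_5000`. Any checkpoint that must survive for resume should be uploaded before pruning runs.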

Why a 500-step checkpoint interval is excessive

Saving every 500 steps causes the following problems.

  • One FSDP2 checkpoint is currently about 21GB.
  • At a 500-step interval, every 10,000 steps produces about 20 checkpoints, or about 420GB.
  • Over the full 88,522 steps of stage-1, a simple calculation gives more than 170 checkpoints, totaling several TB.
  • The saves themselves interrupt the training loop, and HF upload/scan volume grows as well.

It is therefore right to keep saving every 5,000 steps as now and to keep only the latest 2-3 checkpoints locally.
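The storage figures can be reproduced with a few lines of arithmetic. The sketch assumes every checkpoint is the same 21GB and that nothing is pruned, which is the worst case the text describes; `checkpoint_footprint` is an illustrative helper name.

```python
CHECKPOINT_GB = 21      # approximate size of one FSDP2 checkpoint, from this note
STAGE1_STEPS = 88_522   # total stage-1 steps, from this note


def checkpoint_footprint(total_steps: int, interval: int, size_gb: int = CHECKPOINT_GB):
    """Return (number of checkpoints, total GB) if every checkpoint is kept."""
    count = total_steps // interval
    return count, count * size_gb


for interval in (500, 5_000):
    count, total_gb = checkpoint_footprint(STAGE1_STEPS, interval)
    print(f"every {interval} steps: {count} checkpoints, about {total_gb / 1000:.1f} TB")
```

At a 500-step interval this gives 177 checkpoints and about 3.7TB, against 17 checkpoints and about 0.4TB at 5,000 steps.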

Deciding the next batch adjustment

VRAM usage is currently high, but training speed is stable.

To raise the batch in the next stage, do not raise it sharply in one go; the following order is better.

  1. Confirm that the run completes stably at global_batch_size=180224
  2. Test 196608 in the next dataset stage
  3. Confirm a VRAM plateau over at least 2,000-3,000 steps
  4. Keep the setting if it passes through a checkpoint save
  5. On OOM or unstable peaks, revert immediately to 180224 or 172032
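Fallback batch sizes for the last step can be derived mechanically. The sketch below rounds a reduced batch to the nearest multiple of 8,192 tokens. That granularity is an assumption inferred from the batch sizes named in this note (172032, 180224, 196608, 229376, and 262144 are all multiples of 8,192), not a documented constraint of the trainer, and `scaled_batch` is a hypothetical helper.

```python
NUM_GPUS = 8
BATCH_MULTIPLE = 8_192  # assumed granularity, inferred from the batch sizes in this note


def scaled_batch(batch: int, reduction: float, multiple: int = BATCH_MULTIPLE) -> int:
    """Reduce `batch` by the fraction `reduction`, rounded to the nearest `multiple`."""
    target = batch * (1.0 - reduction)
    return max(multiple, round(target / multiple) * multiple)


current = 180_224
for reduction in (0.05, 0.10):
    candidate = scaled_batch(current, reduction)
    print(f"-{reduction:.0%}: global_batch_size={candidate} "
          f"({candidate // NUM_GPUS} token slots/GPU)")
```

Under that rounding, a 5% cut from 180224 lands on 172032, the fallback value listed above, and a 10% cut lands on 163840.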

Compared with the paper configuration, 8 H200s are powerful, but this model's 131K vocab gives it a different memory profile from upstream. Sizing the batch on the premise that "it's H200, so it must exceed the 16 x H100 batch" therefore reduces stability.

Conclusion

The current VRAM rise is best read as the combined result of torch compile/cache, the 131K-vocab logits buffer, FSDP2/optimizer/EMA/NCCL buffers, and momentary checkpoint peaks.

The current policy of global_batch_size=180224, a 5,000-step checkpoint interval, and keeping the latest 2-3 checkpoints is a realistic balance between fast training and OOM avoidance. A small batch increase will be considered only in the next stage, once training shows a fully stable plateau.