CLIWorks commited on
Commit
02f20fc
·
verified ·
1 Parent(s): 79ad610

Upload mythos-fineweb-moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mythos-fineweb-moe.py +9 -8
mythos-fineweb-moe.py CHANGED
@@ -958,6 +958,7 @@ def main():
958
  micro_batch = int(os.environ.get("MICRO_BATCH", "32"))
959
  target_tokens = int(os.environ.get("TARGET_TOKENS", "50_000_000"))
960
  grad_accum = int(os.environ.get("GRAD_ACCUM", "1"))
 
961
  global_batch_tok = world_size * micro_batch * grad_accum * seq_len
962
  total_steps = target_tokens // global_batch_tok
963
  warmup_steps = 200
@@ -971,8 +972,8 @@ def main():
971
 
972
  if master:
973
  logger.info(
974
- f"[MOE MLA+Engram] hidden=2048 | layers=6 | experts=32 | top-2 | "
975
- f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
976
  f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
977
  )
978
  logger.info(
@@ -986,9 +987,9 @@ def main():
986
  # ------------------------------------------------------------------
987
  cfg = SpiderPortalConfig(
988
  hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
989
- num_key_value_heads=4, intermediate_size=4096,
990
- num_experts=32, num_experts_per_tok=2, num_shared_experts=1,
991
- router_aux_loss_coef=0.05, max_loop_iters=2,
992
  prelude_layers=2, coda_layers=2, lora_rank=128,
993
  rope_theta=10000000.0,
994
  rope_scaling=None,
@@ -1172,7 +1173,7 @@ def main():
1172
  else model.no_sync()
1173
  )
1174
  with sync, amp_ctx, sdpa_ctx:
1175
- output = model(x)
1176
  if master and step == start_step and micro_step == 0:
1177
  peak_vram = torch.cuda.max_memory_allocated() / 1024**3
1178
  logger.info(f"Reached first model forward | Peak VRAM: {peak_vram:.1f}GB")
@@ -1182,8 +1183,8 @@ def main():
1182
  else:
1183
  logits = output
1184
  aux_loss = 0.0
1185
- loss = nn.functional.cross_entropy(
1186
- logits.view(-1, vocab_size), y.view(-1)
1187
  )
1188
  loss = loss + cfg.router_aux_loss_coef * aux_loss
1189
  loss = loss / grad_accum
 
958
  micro_batch = int(os.environ.get("MICRO_BATCH", "32"))
959
  target_tokens = int(os.environ.get("TARGET_TOKENS", "50_000_000"))
960
  grad_accum = int(os.environ.get("GRAD_ACCUM", "1"))
961
+ n_loops = int(os.environ.get("N_LOOPS", "6"))
962
  global_batch_tok = world_size * micro_batch * grad_accum * seq_len
963
  total_steps = target_tokens // global_batch_tok
964
  warmup_steps = 200
 
972
 
973
  if master:
974
  logger.info(
975
+ f"[MOE MLA+Engram] hidden=2048 | layers=6 | experts=16 | top-1 | "
976
+ f"n_loops={n_loops} | seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
977
  f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
978
  )
979
  logger.info(
 
987
  # ------------------------------------------------------------------
988
  cfg = SpiderPortalConfig(
989
  hidden_size=2048, num_hidden_layers=6, num_attention_heads=16,
990
+ num_key_value_heads=4, intermediate_size=1024,
991
+ num_experts=16, num_experts_per_tok=1, num_shared_experts=1,
992
+ router_aux_loss_coef=0.05, max_loop_iters=16,
993
  prelude_layers=2, coda_layers=2, lora_rank=128,
994
  rope_theta=10000000.0,
995
  rope_scaling=None,
 
1173
  else model.no_sync()
1174
  )
1175
  with sync, amp_ctx, sdpa_ctx:
1176
+ output = model(x, n_loops=n_loops)
1177
  if master and step == start_step and micro_step == 0:
1178
  peak_vram = torch.cuda.max_memory_allocated() / 1024**3
1179
  logger.info(f"Reached first model forward | Peak VRAM: {peak_vram:.1f}GB")
 
1183
  else:
1184
  logits = output
1185
  aux_loss = 0.0
1186
+ loss = F.nll_loss(
1187
+ logits.view(-1, vocab_size).log_softmax(dim=-1), y.view(-1)
1188
  )
1189
  loss = loss + cfg.router_aux_loss_coef * aux_loss
1190
  loss = loss / grad_accum