# chimera/training/common.py

from __future__ import annotations

import copy
import json
import math
import os
from pathlib import Path
from typing import Any

import torch

from chimera import Chimera51ForCausalLM

DEFAULT_SCALE_PRESETS = {
    "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
    "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
    "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}
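
# Note: num_heads * head_dim is decoupled from hidden_size in these presets
# (e.g. "tiny": 4 * 48 = 192 vs hidden_size = 256); this assumes the attention
# blocks project to and from the head dimension explicitly rather than
# splitting hidden_size evenly across heads.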

def setup_cpu_runtime(*, interop_threads: int | None = None) -> int:
    """Pin CPU math-library thread counts for training and return the CPU count."""
    n_cpus = os.cpu_count() or 4
    # OpenMP / MKL intra-op parallelism: one thread per logical CPU by default.
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    # Intel OpenMP: pin threads to cores and return idle workers to the pool quickly.
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    # jemalloc tuning (no effect unless jemalloc is the active allocator).
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", n_cpus)))
    try:
        target = interop_threads
        if target is None:
            target = int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))
        torch.set_num_interop_threads(target)
    except RuntimeError:
        # The inter-op thread count can only be set before parallel work starts.
        pass
    return n_cpus
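
# Usage sketch (illustrative): call once at process start, before any tensor
# work, since the OpenMP/MKL runtimes read the *_NUM_THREADS variables on
# first use and torch.set_num_interop_threads() raises RuntimeError once
# inter-op parallelism has started (hence the try/except above).
#
#   n_cpus = setup_cpu_runtime(interop_threads=1)
#   print(f"intra-op threads: {n_cpus}, inter-op threads: 1")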

def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr over the remaining steps."""
    if warmup > 0 and step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    progress = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
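
# Worked example (illustrative numbers): with warmup=10, total=100,
# max_lr=3e-4, min_lr=3e-5:
#
#   cosine_lr(0,   10, 100, 3e-4, 3e-5) -> 3.0e-5   # first warmup step: max_lr / 10
#   cosine_lr(10,  10, 100, 3e-4, 3e-5) -> 3.0e-4   # progress = 0, cos(0) = 1
#   cosine_lr(55,  10, 100, 3e-4, 3e-5) -> 1.65e-4  # progress = 0.5, cos(pi/2) = 0
#   cosine_lr(100, 10, 100, 3e-4, 3e-5) -> 3.0e-5   # step >= total: clamped to min_lr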

def load_json_config(path: str | os.PathLike[str]) -> dict[str, Any]:
    """Read a JSON config file into a plain dict."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)

def apply_standard_config_tweaks(config: dict[str, Any], *, scale: str, seq_len: int) -> dict[str, Any]:
    """Return a copy of ``config`` with scale presets and standard sub-configs applied."""
    # Deep-copy so the nested setdefault/update calls below cannot mutate the
    # caller's dicts (a shallow dict(config) would alias every nested dict).
    config = copy.deepcopy(config)
    if scale in DEFAULT_SCALE_PRESETS:
        config.update(DEFAULT_SCALE_PRESETS[scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config.setdefault("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(seq_len, 64)
    # head_dim comes from the preset, or must already be present in config.json.
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2,
        "persistent_memory_slots": 16,
        "local_window_size": min(seq_len, 256),
    })
    # ── MoE config ──
    # IMPORTANT: force-override MoE sizes when a scale preset is active.
    # The full-scale config.json has moe_intermediate_size=1728 (for hidden=2560).
    # Using setdefault() would silently keep 1728 even at tiny scale (hidden=256),
    # resulting in 158M params in MoE alone (70% of the model).
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    if scale in DEFAULT_SCALE_PRESETS:
        # Force scale-appropriate MoE sizes.
        moe_cfg["moe_intermediate_size"] = config["intermediate_size"] // 4
        moe_cfg["n_routed_experts"] = 8
    else:
        moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
        moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True,
        "prelude": [0, 3],
        "loop": [4, 23],
        "coda": [24, 27],
        "loop_range": [1, 3],
        "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False
    return config
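
# Illustrative check of the MoE override (values follow from the "tiny"
# preset, not from any real config.json): even if the input config carries
# the full-scale moe_intermediate_size=1728, the preset forces 512 // 4 = 128.
#
#   cfg = apply_standard_config_tweaks(
#       {"backbone": {"moe": {"moe_intermediate_size": 1728}}},
#       scale="tiny",
#       seq_len=512,
#   )
#   assert cfg["backbone"]["moe"]["moe_intermediate_size"] == 128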

def build_model_from_args(args) -> tuple[Chimera51ForCausalLM, dict[str, Any]]:
    """Build a model from CLI args; only args.config, args.scale and args.seq_len are used."""
    config = load_json_config(args.config)
    config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
    return Chimera51ForCausalLM(config), config
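
# Because only three attributes of `args` are read, a minimal stand-in is
# enough for testing (illustrative values):
#
#   import argparse
#   args = argparse.Namespace(config="config.json", scale="tiny", seq_len=512)
#   model, config = build_model_from_args(args)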

def save_training_checkpoint(model, config: dict[str, Any], step: int, output_dir: str) -> str:
    """Save an intermediate checkpoint (weights + config + step) and return its directory."""
    ckpt_dir = Path(output_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    # Unwrap torch.compile's OptimizedModule so the state_dict keys stay stable.
    raw_model = getattr(model, "_orig_mod", model)
    torch.save({"model": raw_model.state_dict(), "config": config, "step": step}, ckpt_dir / "ckpt.pt")
    return str(ckpt_dir)
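
# Illustrative load-side counterpart (a sketch; this helper is not part of
# the original module, and the name is hypothetical):
def load_training_checkpoint(ckpt_dir: str | os.PathLike[str]) -> tuple[dict[str, Any], dict[str, Any], int]:
    """Load a checkpoint written by save_training_checkpoint."""
    state = torch.load(Path(ckpt_dir) / "ckpt.pt", map_location="cpu")
    return state["model"], state["config"], state["step"]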

def save_final_checkpoint(
    model,
    config: dict[str, Any],
    step: int,
    best_loss: float,
    output_dir: str,
) -> str:
    """Save the final weights plus a human-readable config.json; return the directory."""
    final_dir = Path(output_dir)
    final_dir.mkdir(parents=True, exist_ok=True)
    # Unwrap torch.compile's OptimizedModule, as above.
    raw_model = getattr(model, "_orig_mod", model)
    torch.save(
        {"model": raw_model.state_dict(), "config": config, "step": step, "best_loss": best_loss},
        final_dir / "model.pt",
    )
    with open(final_dir / "config.json", "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    return str(final_dir)
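
# Reloading a final checkpoint (sketch, assuming the formats written above):
#
#   state = torch.load(Path(output_dir) / "model.pt", map_location="cpu")
#   model = Chimera51ForCausalLM(state["config"])
#   model.load_state_dict(state["model"])
#   print("step:", state["step"], "best_loss:", state["best_loss"])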