from __future__ import annotations

import copy
import json
import math
import os
from pathlib import Path
from typing import Any

import torch

from chimera import Chimera51ForCausalLM
|
|
|
|
# Preset model dimensions keyed by scale name. apply_standard_config_tweaks
# overlays these onto the base config when the requested scale matches;
# unknown scale names fall through and leave the base config's sizes intact.
DEFAULT_SCALE_PRESETS = {
    "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
    "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
    "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}
|
|
|
|
| def setup_cpu_runtime(*, interop_threads: int | None = None) -> int: |
| n_cpus = os.cpu_count() or 4 |
| os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus)) |
| os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus)) |
| os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0") |
| os.environ.setdefault("KMP_BLOCKTIME", "1") |
| os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto") |
|
|
| torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", n_cpus))) |
| try: |
| target = interop_threads |
| if target is None: |
| target = int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")) |
| torch.set_num_interop_threads(target) |
| except RuntimeError: |
| pass |
| return n_cpus |
|
|
|
|
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Linear-warmup + half-cosine-decay learning-rate schedule.

    Ramps linearly up to ``max_lr`` over ``warmup`` steps, then decays along a
    half cosine from ``max_lr`` at ``warmup`` to ``min_lr`` at ``total``, and
    stays at ``min_lr`` for any step beyond ``total``.
    """
    if warmup > 0 and step < warmup:
        # (step + 1) so the very first step already gets a non-zero LR.
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    frac = (step - warmup) / max(1, total - warmup)
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * frac))
    return min_lr + (max_lr - min_lr) * cosine_weight
|
|
|
|
| def load_json_config(path: str | os.PathLike[str]) -> dict[str, Any]: |
| with open(path, encoding="utf-8") as fh: |
| return json.load(fh) |
|
|
|
|
def apply_standard_config_tweaks(config: dict[str, Any], *, scale: str, seq_len: int) -> dict[str, Any]:
    """Return a copy of *config* with the standard Chimera training overrides.

    Args:
        config: Base model configuration (JSON-style nested dicts).
        scale: Preset name; when it matches DEFAULT_SCALE_PRESETS, the preset
            dimensions overwrite the base config (and force the MoE sizes).
        seq_len: Training sequence length; caps chunk and window sizes.

    Returns:
        A new config dict. The caller's *config* is never mutated.
    """
    # Fix: deep-copy instead of dict(config). The previous shallow copy shared
    # nested sub-dicts with the caller, so every setdefault(...)[...] below
    # silently mutated the caller's config object.
    config = copy.deepcopy(config)

    if scale in DEFAULT_SCALE_PRESETS:
        config.update(DEFAULT_SCALE_PRESETS[scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(seq_len, 64)
    # NOTE(review): assumes "head_dim" is present when `scale` is not a known
    # preset — confirm all callers supply it in that case.
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2,
        "persistent_memory_slots": 16,
        "local_window_size": min(seq_len, 256),
    })

    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    if scale in DEFAULT_SCALE_PRESETS:
        # Preset scales force the MoE sizes to match the preset dimensions;
        # non-preset configs only get defaults filled in.
        moe_cfg["moe_intermediate_size"] = config["intermediate_size"] // 4
        moe_cfg["n_routed_experts"] = 8
    else:
        moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
        moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)

    config.setdefault("looping", {}).update({
        "enabled": True,
        "prelude": [0, 3],
        "loop": [4, 23],
        "coda": [24, 27],
        "loop_range": [1, 3],
        "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False
    return config
|
|
|
|
def build_model_from_args(args) -> tuple[Chimera51ForCausalLM, dict[str, Any]]:
    """Build a Chimera model from parsed CLI args (uses .config, .scale, .seq_len)."""
    raw_cfg = load_json_config(args.config)
    tweaked_cfg = apply_standard_config_tweaks(raw_cfg, scale=args.scale, seq_len=args.seq_len)
    model = Chimera51ForCausalLM(tweaked_cfg)
    return model, tweaked_cfg
|
|
|
|
def save_training_checkpoint(model, config: dict[str, Any], step: int, output_dir: str) -> str:
    """Persist an intermediate checkpoint (weights + config + step) to ckpt.pt.

    Unwraps via ``_orig_mod`` when present (presumably a torch.compile wrapper)
    so the saved state_dict carries clean key names. Returns *output_dir* as a
    string.
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)
    unwrapped = getattr(model, "_orig_mod", model)
    payload = {"model": unwrapped.state_dict(), "config": config, "step": step}
    torch.save(payload, target / "ckpt.pt")
    return str(target)
|
|
|
|
def save_final_checkpoint(
    model,
    config: dict[str, Any],
    step: int,
    best_loss: float,
    output_dir: str,
) -> str:
    """Write the final artifacts: model.pt plus a human-readable config.json.

    Unwraps via ``_orig_mod`` when present (presumably a torch.compile wrapper)
    before extracting the state_dict. Returns *output_dir* as a string.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    base_model = getattr(model, "_orig_mod", model)
    payload = {
        "model": base_model.state_dict(),
        "config": config,
        "step": step,
        "best_loss": best_loss,
    }
    torch.save(payload, out_dir / "model.pt")
    (out_dir / "config.json").write_text(json.dumps(config, indent=2), encoding="utf-8")
    return str(out_dir)
|
|