from __future__ import annotations

import json
import math
import os
from pathlib import Path
from typing import Any

import torch

from chimera import Chimera51ForCausalLM


DEFAULT_SCALE_PRESETS = {
    "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
    "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
    "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}


def setup_cpu_runtime(*, interop_threads: int | None = None) -> int:
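    """Configure OMP/MKL environment variables and torch thread counts for CPU training.

    Values already present in the environment take precedence (``setdefault``).
    Returns the detected CPU count.
    """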
    n_cpus = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")

    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", n_cpus)))
    try:
        target = interop_threads
        if target is None:
            target = int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))
        torch.set_num_interop_threads(target)
    except RuntimeError:
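        # set_num_interop_threads() raises once inter-op parallel work has
        # already started; in that case keep the existing setting.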
        pass
    return n_cpus


def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
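    """Linear warmup for ``warmup`` steps, then cosine decay from ``max_lr`` to ``min_lr`` over ``total`` steps."""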
    if warmup > 0 and step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    progress = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def load_json_config(path: str | os.PathLike[str]) -> dict[str, Any]:
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def apply_standard_config_tweaks(config: dict[str, Any], *, scale: str, seq_len: int) -> dict[str, Any]:
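    """Return a copy of ``config`` with scale presets and per-component defaults applied.

    Sequence-length-dependent settings (DeltaNet chunk size, Titans local window)
    are clamped to ``seq_len``.
    """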
    config = dict(config)
    if scale in DEFAULT_SCALE_PRESETS:
        config.update(DEFAULT_SCALE_PRESETS[scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2,
        "persistent_memory_slots": 16,
        "local_window_size": min(seq_len, 256),
    })

    # ── MoE config ──
    # IMPORTANT: force-override MoE sizes when a scale preset is active.
    # The full-scale config.json has moe_intermediate_size=1728 (for hidden=2560).
    # Using setdefault() would silently keep 1728 even at tiny scale (hidden=256),
    # resulting in 158M params in MoE alone (70% of model).
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    if scale in DEFAULT_SCALE_PRESETS:
        # Force scale-appropriate MoE sizes
        moe_cfg["moe_intermediate_size"] = config["intermediate_size"] // 4
        moe_cfg["n_routed_experts"] = 8
    else:
        moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
        moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)

    config.setdefault("looping", {}).update({
        "enabled": True,
        "prelude": [0, 3],
        "loop": [4, 23],
        "coda": [24, 27],
        "loop_range": [1, 3],
        "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False
    return config


def build_model_from_args(args) -> tuple[Chimera51ForCausalLM, dict[str, Any]]:
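    """Load the JSON config at ``args.config``, apply scale/seq-len tweaks, and build the model."""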
    config = load_json_config(args.config)
    config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
    return Chimera51ForCausalLM(config), config


def save_training_checkpoint(model, config: dict[str, Any], step: int, output_dir: str) -> str:
    ckpt_dir = Path(output_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
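    # If the model was wrapped by torch.compile, the original module lives at
    # `_orig_mod`; save its state_dict so the checkpoint loads without the wrapper.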
    raw_model = getattr(model, "_orig_mod", model)
    torch.save({"model": raw_model.state_dict(), "config": config, "step": step}, ckpt_dir / "ckpt.pt")
    return str(ckpt_dir)


def save_final_checkpoint(
    model,
    config: dict[str, Any],
    step: int,
    best_loss: float,
    output_dir: str,
) -> str:
    final_dir = Path(output_dir)
    final_dir.mkdir(parents=True, exist_ok=True)
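    # Unwrap a torch.compile wrapper, if present.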
    raw_model = getattr(model, "_orig_mod", model)
    torch.save(
        {"model": raw_model.state_dict(), "config": config, "step": step, "best_loss": best_loss},
        final_dir / "model.pt",
    )
    with open(final_dir / "config.json", "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    return str(final_dir)
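

if __name__ == "__main__":
    # Minimal usage sketch: builds a tiny-scale config from an empty base dict
    # (every value below comes from the presets/defaults above, not from the
    # project's config.json) and prints a few points of the LR schedule.
    cfg = apply_standard_config_tweaks({}, scale="tiny", seq_len=512)
    print(json.dumps({k: cfg[k] for k in ("hidden_size", "num_hidden_layers", "vocab_size")}, indent=2))
    for step in (0, 50, 100, 500, 999, 1000):
        print(step, round(cosine_lr(step, warmup=100, total=1000, max_lr=3e-4, min_lr=3e-5), 8))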