#!/usr/bin/env python3
"""Chimera 5.2 — Fast CPU training with pre-tokenized dataset cache."""
from __future__ import annotations

import argparse
import json
import os

# CPU threading must be configured *before* importing torch.
ncpus = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
os.environ["OMP_NUM_THREADS"] = str(ncpus)
os.environ["MKL_NUM_THREADS"] = str(ncpus)

import torch
from torch.utils.data import DataLoader

from chimera import Chimera51ForCausalLM
from chimera.paths import DEFAULT_CONFIG_PATH
from chimera.training import (
    PreTokenizedDataset,
    apply_standard_config_tweaks,
    train_fast_loop,
)


torch.set_num_threads(ncpus)
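# set_num_interop_threads raises a RuntimeError if inter-op parallel work has
# already started, so treat failure as non-fatal.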
try:
    torch.set_num_interop_threads(1)
except RuntimeError:
    pass


def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"):
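    """Tokenize TinyStories once and cache the result as a flat LongTensor.

    The cached tensor holds n * (seq_len + 1) tokens; PreTokenizedDataset
    presumably slices it into (seq_len + 1)-token windows, the extra token
    giving each sample shifted next-token labels.
    """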
    cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt")
    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(cache_path):
        print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}")
        chunks = torch.load(cache_path, weights_only=False)
        return PreTokenizedDataset(chunks, seq_len)

    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print(f"[DATA] Downloading TinyStories...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    target = max_samples * (seq_len + 1)
    buffer = torch.empty(target, dtype=torch.long)
    buf_idx = 0
    processed = 0

    for ex in ds:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        n = len(ids)
        if buf_idx + n > target:
            n = target - buf_idx
            if n <= 0:
                break
            ids = ids[:n]
        if n > 0:
            buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long)
            buf_idx += n
        processed += 1
        if (processed % 1000) == 0:
            print(f"  {processed:,} stories, {buf_idx:,}/{target} tokens...")
        if buf_idx >= target:
            break

    all_ids = buffer[:buf_idx]
    n = all_ids.numel() // (seq_len + 1)
    # .clone() so torch.save writes only these tokens rather than the whole
    # pre-allocated buffer storage backing the slice.
    chunks = all_ids[:n * (seq_len + 1)].clone()

    torch.save(chunks, cache_path)
    print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}")
    return PreTokenizedDataset(chunks, seq_len)


def train(args) -> None:
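    """Build the model and cached dataset from `args`, then hand off to the fast CPU loop."""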
    with open(args.config) as f:
        config = json.load(f)
    config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)

    print("=" * 60)
    print(f"CHIMERA 5.2 FAST TRAIN — scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}")
    print(f"Threads: {torch.get_num_threads()}  bf16={args.bf16}  compile={args.compile}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    if args.compile:
        print("[OPT] Compiling model...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir)
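    # The dataset is an in-memory pre-tokenized tensor, so worker processes would
    # only add IPC overhead; drop_last keeps every batch the same shape.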
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True,
    )

    def compute_loss(batch) -> torch.Tensor:
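        # Forward pass only; the returned loss is presumably backpropagated inside
        # train_fast_loop. With --bf16, CPU autocast runs the forward in bfloat16.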
        ids = batch["input_ids"]
        labels = batch["labels"]
        if args.bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    # train_fast_loop (from chimera.training) presumably consumes the remaining args
    # (--lr, --warmup, --max_steps, --log_every, --save_every, --output_dir).
    train_fast_loop(args, model, config, loader, compute_loss)


if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training")
    p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
    p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
    p.add_argument("--seq_len", type=int, default=32)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--warmup", type=int, default=100)
    p.add_argument("--max_steps", type=int, default=1000)
    p.add_argument("--max_samples", type=int, default=5000)
    p.add_argument("--bf16", action="store_true", default=False)
    p.add_argument("--compile", action="store_true", default=False)
    p.add_argument("--cache_dir", default="./cache")
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--save_every", type=int, default=500)
    p.add_argument("--output_dir", default="./chimera_output")
    args = p.parse_args()
    train(args)