Lgr54HFi committed on
Commit
6639e7f
·
verified ·
1 Parent(s): bc0ec84

Upload train_fast.py

Browse files
Files changed (1) hide show
  1. train_fast.py +282 -0
train_fast.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Chimera 5.2 — Fast CPU training with pre-tokenized dataset cache."""
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import math
8
+ import os
9
+ import sys
10
+ import time
11
+
12
# CPU threading must be configured *before* importing torch.
# OMP_NUM_THREADS may be unset, empty, or not a plain integer (OpenMP also
# accepts comma-separated lists such as "4,2"); in all of those cases fall
# back to the detected core count, or 4 when it cannot be determined,
# instead of crashing at import time.
try:
    ncpus = int(os.environ.get("OMP_NUM_THREADS") or os.cpu_count() or 4)
except ValueError:
    ncpus = os.cpu_count() or 4
# A zero or negative thread count is never usable; clamp to at least one.
ncpus = max(1, ncpus)
os.environ["OMP_NUM_THREADS"] = str(ncpus)
os.environ["MKL_NUM_THREADS"] = str(ncpus)
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import DataLoader, Dataset
21
+
22
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
23
+ from chimera import Chimera51ForCausalLM
24
+
25
+
26
# Intra-op parallelism uses the configured core count; inter-op parallelism
# is pinned to a single thread.
torch.set_num_threads(ncpus)
try:
    torch.set_num_interop_threads(1)
except RuntimeError:
    # Best-effort: if torch rejects the change, keep its current inter-op
    # setting rather than aborting the run.
    pass
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Pre-tokenized dataset cache
35
+ # ---------------------------------------------------------------------------
36
+
37
class PreTokenizedDataset(Dataset):
    """Fixed-length next-token windows over a token tensor.

    ``ids`` (expected to be 1-D) is cut into rows of ``seq_len + 1`` tokens;
    trailing tokens that do not fill a whole row are dropped. Each item pairs
    the first ``seq_len`` tokens of a row (``input_ids``) with the same row
    shifted by one position (``labels``).
    """

    def __init__(self, ids: torch.Tensor, seq_len: int):
        window = seq_len + 1
        num_rows = ids.numel() // window
        usable = ids[:num_rows * window]
        # Stored as a (num_rows, seq_len + 1) view of the usable tokens.
        self.chunks = usable.view(num_rows, window)
        self.seq_len = seq_len

    def __len__(self) -> int:
        """Return the number of complete windows."""
        return self.chunks.size(0)

    def __getitem__(self, idx: int):
        """Return the ``input_ids``/``labels`` pair for window ``idx``."""
        row = self.chunks[idx]
        return {"input_ids": row[:-1], "labels": row[1:]}
49
+
50
+
51
def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"):
    """Return a ``PreTokenizedDataset`` of TinyStories tokens, cached on disk.

    The cache file is ``tiny_stories_{seq_len}_{max_samples}.pt`` inside
    ``cache_dir`` (created if missing). On a cache hit the stored token
    tensor is loaded directly. On a miss the ``roneneldan/TinyStories`` train
    split is streamed, each non-empty story is encoded with
    ``ChimeraTokenizer(pretrained="o200k_base")`` and followed by the
    tokenizer's EOS id, and tokens are collected until
    ``max_samples * (seq_len + 1)`` have been gathered or the stream ends.
    The result is trimmed to a whole number of ``seq_len + 1`` windows and
    written to the cache.

    Args:
        seq_len: Number of input tokens per training window.
        max_samples: Maximum number of windows to collect.
        cache_dir: Directory holding the cache file.

    Returns:
        A ``PreTokenizedDataset`` over the collected tokens.
    """
    cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt")
    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(cache_path):
        print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}")
        # The cache only ever holds a plain tensor, so the restricted
        # unpickler is sufficient and a tampered cache file cannot run
        # arbitrary code on load.
        chunks = torch.load(cache_path, weights_only=True)
        return PreTokenizedDataset(chunks, seq_len)

    # Imported lazily: these are only needed on a cache miss.
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print("[DATA] Downloading TinyStories...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    target = max_samples * (seq_len + 1)
    buffer = torch.empty(target, dtype=torch.long)
    buf_idx = 0
    processed = 0

    for ex in ds:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        # Never write past the end of the buffer: the last story is truncated.
        n = min(len(ids), target - buf_idx)
        if n <= 0:
            break
        buffer[buf_idx:buf_idx + n] = torch.tensor(ids[:n], dtype=torch.long)
        buf_idx += n
        processed += 1
        if (processed % 1000) == 0:
            print(f" {processed:,} stories, {buf_idx:,}/{target} tokens...")
        if buf_idx >= target:
            break

    all_ids = buffer[:buf_idx]
    n = all_ids.numel() // (seq_len + 1)
    # Clone so the saved file holds only the kept tokens; saving a view would
    # serialize the whole preallocated buffer, including any unfilled tail.
    chunks = all_ids[:n * (seq_len + 1)].clone()

    # Write to a temporary file and rename, so an interrupted run never leaves
    # a truncated cache file that a later run would try to load.
    tmp_path = cache_path + ".tmp"
    torch.save(chunks, tmp_path)
    os.replace(tmp_path, cache_path)
    print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}")
    return PreTokenizedDataset(chunks, seq_len)
100
+
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Fast training loop
104
+ # ---------------------------------------------------------------------------
105
+
106
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    During the first ``warmup`` steps the rate rises linearly and reaches
    ``max_lr`` on the last warmup step. Afterwards it follows a half cosine
    from ``max_lr`` down to ``min_lr``, and stays at ``min_lr`` for any
    ``step >= total``.
    """
    in_warmup = warmup > 0 and step < warmup
    if in_warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    # max(1, ...) avoids dividing by zero when there is no decay phase.
    decay_steps = max(1, total - warmup)
    progress = (step - warmup) / decay_steps
    cosine = 1.0 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine
113
+
114
+
115
# Size overrides merged into the loaded config for each named scale preset.
_SCALE_PRESETS = {
    "tiny": {
        "hidden_size": 256,
        "intermediate_size": 512,
        "num_heads": 4,
        "head_dim": 48,
    },
    "small": {
        "hidden_size": 512,
        "intermediate_size": 1024,
        "num_heads": 8,
        "head_dim": 48,
    },
    "medium": {
        "hidden_size": 1024,
        "intermediate_size": 2048,
        "num_heads": 8,
        "head_dim": 96,
    },
}
120
+
121
+
122
def _save_state(model, config, path: str, **extra) -> None:
    """Write the model's state dict, ``config`` and any ``extra`` fields to ``path``."""
    # Use the `_orig_mod` attribute when the model has one (the model may have
    # been passed through torch.compile by the caller).
    raw = getattr(model, "_orig_mod", model)
    torch.save({"model": raw.state_dict(), "config": config, **extra}, path)


def train(args) -> None:
    """Run the fast CPU training loop described by ``args``.

    Loads the JSON config at ``args.config``, applies the scale preset and the
    fixed feature overrides below, builds the model and the cached dataset,
    then trains for ``args.max_steps`` optimizer steps with AdamW, gradient
    clipping at 1.0 and a warmup + cosine learning-rate schedule. Metrics are
    appended to ``log.jsonl`` every ``args.log_every`` steps, checkpoints are
    written every ``args.save_every`` steps, and the final weights and config
    are stored under ``<output_dir>/final``.

    Args:
        args: Namespace with the attributes defined by this script's CLI
            (config, scale, seq_len, batch_size, lr, warmup, max_steps,
            max_samples, bf16, compile, cache_dir, log_every, save_every,
            output_dir).

    Raises:
        ValueError: If ``log_every`` or ``save_every`` is below 1, or if the
            dataset is too small to produce a single full batch.
    """
    # Fail before any expensive work instead of dividing by zero mid-run.
    if args.log_every < 1 or args.save_every < 1:
        raise ValueError("log_every and save_every must be >= 1")

    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    # NOTE(review): SOURCE lost its indentation; only the `config.update`
    # line is treated as the body of this `if`. Confirm against the original.
    if args.scale in _SCALE_PRESETS:
        config.update(_SCALE_PRESETS[args.scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2, "persistent_memory_slots": 16,
        "local_window_size": min(args.seq_len, 256),
    })
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
    moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27],
        "loop_range": [1, 3], "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False

    print("=" * 60)
    print(f"CHIMERA 5.2 FAST TRAIN — scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}")
    print(f"Threads: {torch.get_num_threads()} bf16={args.bf16} compile={args.compile}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    if args.compile:
        print("[OPT] Compiling model...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True,
    )
    # With drop_last=True a dataset smaller than one batch yields nothing;
    # report that clearly instead of leaking a bare StopIteration.
    if len(loader) == 0:
        raise ValueError(
            f"dataset has {len(dataset)} samples, fewer than one batch of {args.batch_size}"
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))

    def compute_loss(batch) -> torch.Tensor:
        """Run a forward pass (under CPU bf16 autocast if requested) and return the loss."""
        ids = batch["input_ids"]
        labels = batch["labels"]
        if args.bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log.jsonl")

    model.train()
    step = 0
    total_loss = 0.0
    best_loss = float("inf")
    toks = 0
    data_iter = iter(loader)
    warmup = min(args.warmup, max(1, args.max_steps // 10))
    min_lr = args.lr * 0.1

    print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")

    # The context manager closes the log even if a training step raises.
    with open(log_path, "w", encoding="utf-8") as log_f:
        t0 = time.time()
        while step < args.max_steps:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch exhausted: restart the (reshuffled) loader.
                data_iter = iter(loader)
                batch = next(data_iter)

            loss = compute_loss(batch)
            loss.backward()
            total_loss += float(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, min_lr)
            for pg in optimizer.param_groups:
                pg["lr"] = cur_lr
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            toks += batch["input_ids"].numel()
            step += 1

            if step % args.log_every == 0:
                # dt, toks and total_loss all cover only the last log interval.
                dt = time.time() - t0
                avg = total_loss / args.log_every
                ppl = math.exp(min(avg, 20))
                tps = toks / dt if dt > 0 else 0
                # Seconds per step must be measured over the same interval as
                # dt; dividing the global step count by dt would understate
                # the remaining time after the first interval.
                eta_h = (args.max_steps - step) * (dt / args.log_every) / 3600
                log_f.write(json.dumps({
                    "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
                    "lr": cur_lr, "tok/s": round(tps),
                }) + "\n")
                log_f.flush()
                print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
                      f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | "
                      f"{tps:.0f} tok/s | ETA {eta_h:.1f}h")
                best_loss = min(best_loss, avg)
                total_loss = 0.0
                toks = 0
                t0 = time.time()

            if step % args.save_every == 0:
                ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
                os.makedirs(ckpt_dir, exist_ok=True)
                _save_state(model, config, os.path.join(ckpt_dir, "ckpt.pt"), step=step)
                print(f" [SAVE] {ckpt_dir}")

    final_dir = os.path.join(args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    _save_state(
        model, config, os.path.join(final_dir, "model.pt"),
        step=step, best_loss=best_loss,
    )
    with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)

    print(f"\n{'=' * 60}")
    print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
    print(f"Saved to {final_dir}")
263
+
264
+
265
if __name__ == "__main__":
    # Command-line entry point: declare every option in one table, parse,
    # and hand the resulting namespace to train().
    parser = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training")
    option_table = (
        ("--config", {"default": "config.json"}),
        ("--scale", {"default": "tiny", "choices": ["tiny", "small", "medium", "full"]}),
        ("--seq_len", {"type": int, "default": 32}),
        ("--batch_size", {"type": int, "default": 4}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--warmup", {"type": int, "default": 100}),
        ("--max_steps", {"type": int, "default": 1000}),
        ("--max_samples", {"type": int, "default": 5000}),
        ("--bf16", {"action": "store_true", "default": False}),
        ("--compile", {"action": "store_true", "default": False}),
        ("--cache_dir", {"default": "./cache"}),
        ("--log_every", {"type": int, "default": 10}),
        ("--save_every", {"type": int, "default": 500}),
        ("--output_dir", {"default": "./chimera_output"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    train(parser.parse_args())