OpenTransformer committed on
Commit 46d7788 · verified · 1 Parent(s): 2334f27

Add n_tt_singlefile.py — Tenstorrent N300s training port

Files changed (1)
  1. n_tt_singlefile.py +1636 -0
n_tt_singlefile.py ADDED
@@ -0,0 +1,1636 @@
#!/usr/bin/env python3
"""
n_tt_singlefile.py

Single-file AR+SAT trainer adapted for Tenstorrent training via TT-XLA / torch_xla.
It keeps the same rough structure as the original script, but changes the runtime
model so it can run on:
- Tenstorrent TT-XLA (`--device_backend tt`)
- CUDA
- CPU

Important TT notes:
- Training is the primary target.
- Inference remains available, but TT token-by-token generation is a bad fit for XLA
  because sequence length changes trigger recompiles. The script therefore falls back
  to CPU for `infer` unless `--force_tt_infer` is explicitly set.
- On TT, CUDA AMP/GradScaler and `torch.compile(mode="reduce-overhead")` are disabled.
- On TT, a manual token cross-entropy path is used instead of `nn.CrossEntropyLoss`.
"""
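
# Example invocations (top-level flags like --device_backend must come before the
# subcommand; TT runs assume the TT-XLA container or a pjrt-plugin-tt install):
#   python n_tt_singlefile.py --device_backend tt train --preset nano_3x --steps 1000
#   python n_tt_singlefile.py infer --mode ar --ckpt ckpts_expansion/final.pt --prompt "Hello"
#   python n_tt_singlefile.py status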

from __future__ import annotations

import argparse
import json
import math
import os
import pathlib
import random
import sys
import time
from contextlib import nullcontext
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple, Any

# -----------------------------------------------------------------------------
# Pre-parse runtime backend before importing torch_xla.
# -----------------------------------------------------------------------------
_pre = argparse.ArgumentParser(add_help=False)
_pre.add_argument("cmd", nargs="?")
_pre.add_argument("--device_backend", default=os.environ.get("AGI_DEVICE_BACKEND", "auto"))
_pre.add_argument("--force_tt_infer", action="store_true")
_pre_known, _ = _pre.parse_known_args()

REQUESTED_BACKEND = (_pre_known.device_backend or "auto").lower()
if REQUESTED_BACKEND == "tt":
    os.environ.setdefault("PJRT_DEVICE", "TT")
    os.environ.setdefault("XLA_STABLEHLO_COMPILE", "1")
    os.environ.setdefault("TT_RUNTIME_TRACE_REGION_SIZE", "10000000")

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DownloadConfig
from transformers import AutoTokenizer, logging as hf_log

TORCH_XLA_OK = False
torch_xla = None
xm = None
xr = None
xs = None

try:
    import torch_xla  # type: ignore
    import torch_xla.core.xla_model as xm  # type: ignore
    import torch_xla.runtime as xr  # type: ignore
    import torch_xla.distributed.spmd as xs  # type: ignore

    TORCH_XLA_OK = True
except Exception:
    TORCH_XLA_OK = False
    torch_xla = None
    xm = None
    xr = None
    xs = None

# -----------------------------------------------------------------------------
# Global runtime state
# -----------------------------------------------------------------------------
STATUS_FILE = "/workspace/status.json"

BACKEND = "cpu"
DEV = torch.device("cpu")
TT_MESH = None

hf_log.set_verbosity_error()

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2")
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<|pad|>"})

VOCAB, EOS = (
    max(tok.get_vocab().values()) + 1,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id,
)

PRESETS: Dict[str, Dict[str, int]] = {
    "femto_1x": dict(d=16, layers=1, heads=1, rank=16),
    "femto_12x": dict(d=16, layers=1, heads=1, rank=192),
    "femto_24x": dict(d=16, layers=1, heads=1, rank=384),
    "pico_1x": dict(d=32, layers=1, heads=2, rank=16),
    "pico_3x": dict(d=32, layers=1, heads=2, rank=48),
    "pico_6x": dict(d=32, layers=1, heads=2, rank=96),
    "pico_12x": dict(d=32, layers=1, heads=2, rank=192),
    "pico_24x": dict(d=32, layers=1, heads=2, rank=384),
    "pico_48x": dict(d=32, layers=1, heads=2, rank=768),
    "nano_1x": dict(d=64, layers=2, heads=4, rank=16),
    "nano_3x": dict(d=64, layers=2, heads=4, rank=48),
    "nano_6x": dict(d=64, layers=2, heads=4, rank=96),
    "nano_12x": dict(d=64, layers=2, heads=4, rank=192),
    "nano_24x": dict(d=64, layers=2, heads=4, rank=384),
    "nano_48x": dict(d=64, layers=2, heads=4, rank=768),
    "nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
    "micro_3x": dict(d=128, layers=4, heads=8, rank=48),
    "micro_6x": dict(d=128, layers=4, heads=8, rank=96),
    "micro_12x": dict(d=128, layers=4, heads=8, rank=192),
    "micro_24x": dict(d=128, layers=4, heads=8, rank=384),
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
    "base18": dict(d=768, layers=18, heads=24, rank=96),
    "large": dict(d=1024, layers=24, heads=16, rank=128),
}
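
# Preset naming: the "_Nx" suffix is the rank/d_k ratio, where d_k = d / heads.
# e.g. nano_3x has d_k = 64 / 4 = 16 and rank = 48 = 3 * d_k ("EXPANSION" in
# print_expansion_info's terms); ratios below 1x would be "COMPRESSION".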

DEFAULT_BLOCK = 1122
DEFAULT_BATCH = 1
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_expansion")

DEFAULT_PRETRAIN_SOURCES = (
    "OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,"
    "OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,"
    "OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,"
    "OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1"
)
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
DEFAULT_AFTER_SFT_BLOCK = 1122

_HOT_CFG_PATH = pathlib.Path("/workspace/hot_config.json")
_hot_cache = {"mtime": 0.0, "data": {}}
_MASK_CACHE: Dict[Tuple[str, str, int], torch.Tensor] = {}

# -----------------------------------------------------------------------------
# Cosmetic bits
# -----------------------------------------------------------------------------
class Colors:
    RESET = "\033[0m"
    BOLD = "\033[1m"
    PROMPT = "\033[36m"
    GEN = "\033[0m"
    INFO = "\033[90m"
    WARN = "\033[93m"


class SafeProgress:
    def __init__(self, total, initial=0, unit="tok"):
        self.total, self.n, self.unit = total, initial, unit
        self.last_print, self.postfix = initial, {}
        self.start_time = time.time()

    def update(self, n=1):
        self.n += n
        if self.n - self.last_print >= 1_000_000:
            self._print()
            self.last_print = self.n

    def set_postfix(self, **kwargs):
        self.postfix = kwargs

    def _print(self):
        elapsed = time.time() - self.start_time
        rate = self.n / elapsed if elapsed > 0 else 0
        pct = 100 * self.n / self.total if self.total > 0 else 0
        pf = " ".join(f"{k}={v}" for k, v in self.postfix.items())
        print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")

    def close(self):
        self._print()
        print("Done.")


# -----------------------------------------------------------------------------
# Runtime / backend helpers
# -----------------------------------------------------------------------------
def is_tt() -> bool:
    return BACKEND == "tt"


def is_cuda() -> bool:
    return BACKEND == "cuda"


def get_uk_time() -> str:
    utc_now = datetime.now(timezone.utc)
    year = utc_now.year
    march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc)
    while march_last.weekday() != 6:
        march_last = march_last.replace(day=march_last.day - 1)
    oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc)
    while oct_last.weekday() != 6:
        oct_last = oct_last.replace(day=oct_last.day - 1)
    if march_last <= utc_now < oct_last:
        tz_name = "BST"
        uk_time = utc_now + timedelta(hours=1)
    else:
        tz_name = "GMT"
        uk_time = utc_now
    return uk_time.strftime(f"%Y-%m-%d %H:%M:%S {tz_name}")
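
# Equivalent via the stdlib (Python 3.9+, assuming tzdata is available), shown for
# reference only:
#   from zoneinfo import ZoneInfo
#   datetime.now(ZoneInfo("Europe/London")).strftime("%Y-%m-%d %H:%M:%S %Z")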

def init_runtime(device_backend: str, *, for_infer: bool = False, force_tt_infer: bool = False):
    global BACKEND, DEV, TT_MESH

    device_backend = (device_backend or "auto").lower()

    if device_backend == "auto":
        if torch.cuda.is_available():
            device_backend = "cuda"
        else:
            device_backend = "cpu"

    if for_infer and device_backend == "tt" and not force_tt_infer:
        print("[runtime] TT inference was requested, but token-by-token XLA generation is a recompile trap.")
        print("[runtime] Falling back to CPU for infer. Training on TT is the intended path here.")
        device_backend = "cpu"

    if device_backend == "tt":
        if not TORCH_XLA_OK:
            raise RuntimeError(
                "TT backend requested but torch_xla / TT-XLA is not available. "
                "Use the TT-XLA image or install the pjrt-plugin-tt wheel."
            )

        try:
            xr.set_device_type("TT")
        except Exception:
            pass

        os.environ.setdefault("PJRT_DEVICE", "TT")
        os.environ.setdefault("XLA_STABLEHLO_COMPILE", "1")

        BACKEND = "tt"
        DEV = torch_xla.device()
        TT_MESH = None

        try:
            torch_xla.set_custom_compile_options({
                "fp32_dest_acc_en": True,
                "math_fidelity": "hifi4",
            })
        except Exception:
            pass

        print(f"[runtime] backend=tt device={DEV}")
        return

    if device_backend == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA backend requested but CUDA is not available.")
        BACKEND = "cuda"
        DEV = torch.device("cuda")
        TT_MESH = None
        print(f"[runtime] backend=cuda device={DEV}")
        return

    BACKEND = "cpu"
    DEV = torch.device("cpu")
    TT_MESH = None
    print(f"[runtime] backend=cpu device={DEV}")


def tt_sync(wait: bool = True):
    if is_tt():
        torch_xla.sync(wait=wait)
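
# XLA tensors are lazy: ops accumulate into a graph until a sync materializes it.
# On TT, tt_sync() is the explicit step boundary; without it, host reads such as
# loss.item() would flush the pending graph at unpredictable points.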

def maybe_clear_tt_cache():
    if is_tt():
        try:
            xr.clear_computation_cache()
        except Exception:
            pass


def optimizer_step_backend(optimizer: torch.optim.Optimizer):
    optimizer.step()
    tt_sync(wait=True)  # no-op off TT


def state_dict_to_cpu(sd: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for k, v in sd.items():
        if torch.is_tensor(v):
            out[k] = v.detach().cpu()
        else:
            out[k] = v
    return out


def rng_state():
    if is_cuda():
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()


def _state_dict_cpu(module: nn.Module) -> Dict[str, torch.Tensor]:
    raw = _strip_compiled_prefix(module.state_dict())
    return state_dict_to_cpu(raw)


def _state_dict_load(module: nn.Module, sd: Dict[str, torch.Tensor]):
    module.load_state_dict(_strip_compiled_prefix(sd), strict=True)


# -----------------------------------------------------------------------------
# Status / utilities
# -----------------------------------------------------------------------------
def write_status(step, seen_tok, loss, batch, block, tok_per_sec, phase):
    try:
        with open(STATUS_FILE, "w") as f:
            json.dump(
                {
                    "step": step,
                    "seen_tok": seen_tok,
                    "loss": float(loss) if loss is not None else None,
                    "batch": batch,
                    "block": block,
                    "tok_per_sec": tok_per_sec,
                    "phase": phase,
                    "updated": time.time(),
                    "target_tok": 35737600000,
                    "backend": BACKEND,
                },
                f,
            )
    except Exception:
        pass


def show_status():
    try:
        with open(STATUS_FILE) as f:
            s = json.load(f)
        age = time.time() - s.get("updated", 0)
        target = s.get("target_tok") or 35737600000
        remaining = target - s.get("seen_tok", 0)
        eta_sec = remaining / max(s.get("tok_per_sec", 1), 1)
        eta_days = eta_sec / 86400
        print(
            f"Step: {s.get('step', '?'):,} | Tokens: {s.get('seen_tok', 0)/1e9:.2f}B / {target/1e9:.1f}B | "
            f"Loss: {s.get('loss', 0):.4f}"
        )
        print(
            f"Speed: {s.get('tok_per_sec', 0):.0f} tok/s | B={s.get('batch')} L={s.get('block')} | "
            f"ETA: {eta_days:.1f} days | {age:.0f}s ago | backend={s.get('backend', '?')}"
        )
    except FileNotFoundError:
        print("No status file. Training not running?")
    except Exception as e:
        print(f"Error: {e}")


def get_hot_datasets(default):
    try:
        if _HOT_CFG_PATH.exists():
            mt = _HOT_CFG_PATH.stat().st_mtime
            if mt > _hot_cache["mtime"]:
                _hot_cache["data"] = json.loads(_HOT_CFG_PATH.read_text())
                _hot_cache["mtime"] = mt
            cfg = _hot_cache["data"]
            if "datasets" in cfg:
                ds = cfg["datasets"]
                if isinstance(ds, list):
                    ds = ",".join(ds)
                print(f"[HOT] Using: {ds[:60]}...")
                return ds
    except Exception as e:
        print(f"[HOT] Error: {e}")
    return default


def print_expansion_info(cfg: dict, tie_weights: bool = False):
    d_k = cfg["d"] // cfg["heads"]
    rank = cfg["rank"]
    ratio = rank / d_k
    regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION")
    tie_str = "YES" if tie_weights else "NO"
    print("┌─────────────────────────────────────────┐")
    print("│ TUNEABLE ATTENTION CONFIG │")
    print("├─────────────────────────────────────────┤")
    print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │")
    print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │")
    print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │")
    print("└─────────────────────────────────────────┘")


def _is_probably_ckpt(path: pathlib.Path) -> bool:
    try:
        return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1 << 20)
    except Exception:
        return False


def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    try:
        if path.is_dir():
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)], key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None


def _try_load(path: pathlib.Path, map_location="cpu"):
    try:
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None


def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: Optional[int]):
    if max_ckpts is None or max_ckpts <= 0:
        return
    try:
        for tmp in save_dir.glob("*.pt.tmp"):
            try:
                tmp.unlink()
                print(f" [prune] cleaned stale tmp {tmp.name}")
            except Exception:
                pass
        pattern = f"{phase_name}_step*.pt"
        ckpts = sorted([p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)], key=lambda p: p.stat().st_mtime)
        excess = len(ckpts) - max_ckpts
        if excess > 0:
            for p in ckpts[:excess]:
                try:
                    p.unlink()
                    print(f" [prune] deleted old {p.name}")
                except Exception:
                    pass
    except Exception as e:
        print(f"[ckpt-prune] error: {e}")


def _strip_compiled_prefix(sd):
    return {k.replace("_orig_mod.", ""): v for k, v in sd.items()}


def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")

    if is_tt():
        tt_sync(wait=True)

    state = {
        "core": _state_dict_cpu(core),
        "ar": _state_dict_cpu(ar_h),
        "sat": _state_dict_cpu(sat_h),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict() if hasattr(scaler, "state_dict") else {},
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        "tie_weights": meta.get("tie_weights", False),
        "backend": BACKEND,
        **{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(
        json.dumps(
            {
                "path": str(path),
                "step": meta["step"],
                "block_size": meta.get("block_size"),
                "batch_size": meta.get("batch_size"),
                "seen_tok": meta.get("seen_tok"),
            }
        )
    )
    print(f"\n✓ saved checkpoint {path.name}")


def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    _state_dict_load(core, ck["core"])
    _state_dict_load(ar_h, ck["ar"])
    _state_dict_load(sat_h, ck["sat"])
    try:
        opt.load_state_dict(ck["opt"])
    except Exception:
        pass
    if ck.get("scaler"):
        try:
            scaler.load_state_dict(ck["scaler"])
        except Exception:
            pass
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time()), ck.get("block_size")


def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)


def infer_cfg_from_ckpt(path: pathlib.Path):
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return None
    sd = _try_load(p, map_location="cpu")
    if sd is None:
        return None
    if "cfg" in sd:
        return dict(sd["cfg"])
    return None


# -----------------------------------------------------------------------------
# AMP helpers
# -----------------------------------------------------------------------------
try:
    from torch.amp import autocast as _ac, GradScaler  # type: ignore
except Exception:
    from torch.cuda.amp import autocast as _ac, GradScaler  # type: ignore


def _auto_amp_dtype():
    if is_cuda():
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            return torch.float16
    return torch.float32


def amp(enabled: bool):
    if not (enabled and is_cuda()):
        return nullcontext()
    try:
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    except TypeError:
        return _ac(dtype=_auto_amp_dtype())


# -----------------------------------------------------------------------------
# Dataset stream helpers
# -----------------------------------------------------------------------------
def _coerce_role(r: str) -> str:
    r = (r or "").lower()
    if r in {"user", "human", "customer"}:
        return "user"
    if r in {"assistant", "gpt", "bot"}:
        return "assistant"
    if r in {"system", "context"}:
        return "system"
    return r or "user"


def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
    msgs = ex.get(messages_key)
    if msgs is None:
        for alt in ("conversations", "dialog", "turns"):
            if isinstance(ex.get(alt), list):
                msgs = ex[alt]
                break
    if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
        try:
            norm = []
            for m in msgs:
                role = _coerce_role(m.get("role", ""))
                content = m.get("content", m.get("text", ""))
                if not isinstance(content, str):
                    continue
                norm.append({"role": role, "content": content})
            if not norm:
                return None
            return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
        except Exception:
            return None
    for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
        if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
            return f"User: {ex[a]}\nAssistant: {ex[b]}"
    return None


def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    if ":" in ds_name:
        base, config = ds_name.split(":", 1)
    else:
        base, config = ds_name, None

    if not streaming:
        print(f"[download] Downloading {ds_name} (non-streaming)...")

    if base == "json":
        data_files = {"train": config}
        ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc)
    else:
        ds = (
            load_dataset(base, config, split="train", streaming=streaming, download_config=dc)
            if config
            else load_dataset(base, split="train", streaming=streaming, download_config=dc)
        )

    if streaming:
        return iter(ds.shuffle(buffer_size=1000, seed=seed))

    print(f"[download] Got {len(ds):,} examples. Shuffling...")
    ds = ds.shuffle(seed=seed)
    return iter(ds)


def token_stream(
    ds_names: str,
    target: int,
    seed: int = 42,
    chat: bool = False,
    chat_messages_key: str = "messages",
    sft_add_generation_prompt: bool = False,
    dataset_field_text: str = "text",
    streaming: bool = True,
):
    ds_names = get_hot_datasets(ds_names)
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources:
        return

    src_idx = 0
    emitted = 0
    it = None
    attempts = 0
    backoff_base = 2.0

    while emitted < target:
        try:
            if it is None:
                it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
            ex = next(it)
            text = None
            if isinstance(ex, dict):
                if chat:
                    text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
                if text is None:
                    if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
                        text = ex[dataset_field_text]
                    elif isinstance(ex.get("text"), str):
                        text = ex["text"]
            if not isinstance(text, str):
                attempts = 0
                continue
            enc = tok.encode(text)
            if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
                enc = enc + [EOS]
            for t in enc:
                yield t
                emitted += 1
                if emitted >= target:
                    return
            attempts = 0
        except StopIteration:
            it = None
            src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            attempts += 1
            sleep_s = min(60.0, backoff_base ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
            time.sleep(sleep_s)
            it = None
            if attempts % 2 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
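
# token_stream yields individual token ids, rotating across the comma-separated
# sources on exhaustion or error; _train_phase drains it in BLOCK-sized slices.
# Illustration:
#   stream = token_stream("OpenTransformer/web-crawl-v1", target=10_000)
#   first_block = [next(stream) for _ in range(128)]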

# -----------------------------------------------------------------------------
# TT-friendly manual cross-entropy
# -----------------------------------------------------------------------------
def token_ce_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    *,
    label_smoothing: float = 0.0,
    ignore_index: int = -100,
) -> torch.Tensor:
    """
    Manual token cross-entropy using log_softmax + gather.

    This avoids `nn.CrossEntropyLoss`, which Tenstorrent's own training utilities
    work around for language-model-style logits.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    valid = targets.ne(ignore_index)
    safe_targets = torch.where(valid, targets, torch.zeros_like(targets))

    nll = -log_probs.gather(dim=-1, index=safe_targets.unsqueeze(-1)).squeeze(-1)
    if label_smoothing > 0.0:
        smooth = -log_probs.mean(dim=-1)
        per_tok = (1.0 - label_smoothing) * nll + label_smoothing * smooth
    else:
        per_tok = nll

    per_tok = per_tok * valid.to(per_tok.dtype)
    denom = valid.sum().clamp_min(1).to(per_tok.dtype)
    return per_tok.sum() / denom
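
# Sanity check (no smoothing, all targets valid): this matches the stock loss,
#   a = token_ce_loss(logits, targets)
#   b = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
#   assert torch.allclose(a, b, atol=1e-6)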

# -----------------------------------------------------------------------------
# ALiBi / masks
# -----------------------------------------------------------------------------
@torch._dynamo.disable
def _alibi_slopes(n_heads: int):
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)


@torch._dynamo.disable
def alibi_bias(n_heads: int, n_tokens: int):
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    # Distance to past keys is (i - j); the previous (j - i) zeroed the bias for
    # every causally visible position, making ALiBi a no-op in the AR path.
    dist = (i - j).clamp_min(0)
    return -_alibi_slopes(n_heads) * dist
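
# ALiBi in one line: att[b, h, i, j] += -m_h * max(i - j, 0), where the per-head
# slopes m_h form a geometric sequence (for 8 heads: 1/2, 1/4, ..., 1/256), so
# heads differ only in how steeply they discount distant past tokens.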

def causal_mask(n):
    key = (BACKEND, "causal", n)
    if key not in _MASK_CACHE:
        _MASK_CACHE[key] = torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
    return _MASK_CACHE[key]


def sat_mask(n, block=SAT_BLOCK):
    key = (BACKEND, f"sat_{block}", n)
    if key not in _MASK_CACHE:
        idx = torch.arange(n, device=DEV)
        grp = idx.unsqueeze(0) // block
        allow = (grp.T == grp) | (grp.T > grp)
        _MASK_CACHE[key] = torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
    return _MASK_CACHE[key]
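
# sat_mask is block-causal: a query sees its own SAT_BLOCK-sized group plus all
# earlier groups. For n=4, block=2 the allowed keys per query are:
#   q0: k0 k1 . .      q2: k0 k1 k2 k3
#   q1: k0 k1 . .      q3: k0 k1 k2 k3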

def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
    total_len = cached_len + new_len
    return torch.zeros((1, 1, new_len, total_len), device=DEV)


# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
class TuneableAttentionMHA(nn.Module):
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0
        self.h, self.dk, self.r = h, d // h, r
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * self.dk, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj_qk(self, x):
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def _reshape_v(self, x):
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2)

    def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
        q = self._proj_qk(self.q(x))
        k_new = self._proj_qk(self.k(x))
        v_new = self._reshape_v(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k_cached, v_cached = kv_cache
            if use_cache:
                k = torch.cat([k_cached, k_new], dim=2)
                v = torch.cat([v_cached, v_new], dim=2)
            else:
                k, v = k_new, v_new

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if self.use_relpos and rel_bias_tokens is not None:
            att = att + alibi_bias(self.h, rel_bias_tokens).to(att.dtype)[:, :, -q.size(2):, :]
        if mask is not None:
            att = att + mask.to(att.dtype)

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out


class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.mha = TuneableAttentionMHA(d, h, r)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True)
            x = x + y + self.ff(self.ln2(x + y))
            return x, new_kv
        n = x.size(1)
        x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
        return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    def __init__(self, cfg, tie_weights: bool = False):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)
        self.tie_weights = tie_weights

    def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if kv_caches else None
            x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs


class ARHead(nn.Module):
    def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
        super().__init__()
        self.tie_weights = tie_weights
        if tie_weights and embedding_weight is not None:
            self.proj = nn.Linear(d, VOCAB, bias=False)
            self.proj.weight = embedding_weight
        else:
            self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)


class SATHead(nn.Module):
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)


# -----------------------------------------------------------------------------
# Train helpers
# -----------------------------------------------------------------------------
def _parse_grow_plan(s: str) -> List[int]:
    return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
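
# e.g. _parse_grow_plan("640, 576, 576, 64") -> [576, 640]  (dedupe, sort, drop < 128)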

def _count_enabled_params(*modules) -> int:
    seen_data_ptrs = set()
    total = 0
    for m in modules:
        if m is None:
            continue
        for p in m.parameters():
            if p.data_ptr() not in seen_data_ptrs:
                seen_data_ptrs.add(p.data_ptr())
                total += p.numel()
    return total


def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
    for p in core.parameters():
        p.requires_grad = not freeze_core
    if freeze_core:
        if unfreeze_ln:
            for blk in core.blocks:
                for p in blk.ln1.parameters():
                    p.requires_grad = True
                for p in blk.ln2.parameters():
                    p.requires_grad = True
            for p in core.ln.parameters():
                p.requires_grad = True
        if train_emb:
            for p in core.emb.parameters():
                p.requires_grad = True


def _move_model_for_backend(*mods, tt_bf16: bool):
    for m in mods:
        if is_tt():
            if tt_bf16:
                m.to(device=DEV, dtype=torch.bfloat16)
            else:
                m.to(device=DEV)
        else:
            m.to(device=DEV)


def _losses_for_batch(args, core, ar_h, sat_h, ids):
    tgt_ar = ids

    h_ar = core(ids, causal_mask(ids.size(1)))
    logits_ar = ar_h(h_ar)[:, :-1]
    loss_ar = token_ce_loss(logits_ar, tgt_ar[:, 1:], label_smoothing=args.label_smoothing)

    if args.ar_only:
        return loss_ar, loss_ar, None

    h_sat = core(ids, sat_mask(ids.size(1)))
    logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
    tgt_sat = ids[:, 1:SAT_BLOCK + 1]
    loss_sat = token_ce_loss(logits_sat, tgt_sat, label_smoothing=args.label_smoothing)
    if gate is not None:
        loss_sat = loss_sat + EMIT_LAMBDA * F.cross_entropy(
            gate,
            torch.ones(ids.size(0), device=DEV, dtype=torch.long),
        )
    return loss_ar + loss_sat, loss_ar, loss_sat
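
# Shape walk-through for ids of shape (B, L): the AR path scores every prefix
# (logits [B, L-1, V] against ids[:, 1:]), while the SAT path reads only the last
# SAT_BLOCK hidden states under the block mask and is supervised on
# ids[:, 1:SAT_BLOCK + 1]; the gate term pushes the var head toward emitting the
# full block (target class 1 -> stride 2 at inference).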

def _train_phase(
    args,
    phase_name: str,
    core,
    ar_h,
    sat_h,
    opt,
    scaler,
    start_step,
    seen_tok,
    resume_wall_time,
    cfg,
    source,
    steps,
    block_size,
    batch_size,
    chat_cfg: dict,
    max_ckpts: int,
    target_tokens_override: Optional[int] = None,
    tie_weights: bool = False,
    streaming: bool = True,
):
    BLOCK = block_size
    BATCH = batch_size

    if target_tokens_override is not None:
        target_tokens = target_tokens_override
    else:
        ratio = 51.2 if args.chilla_max_double else 25
        param_count = _count_enabled_params(core, ar_h, sat_h)
        target_tokens = int(ratio * param_count)

    if steps:
        phase_target_tokens = steps * BLOCK * BATCH
        total_tokens_needed = seen_tok + phase_target_tokens
    else:
        total_tokens_needed = target_tokens
    if total_tokens_needed <= seen_tok:
        print(f"[{phase_name}] target {total_tokens_needed} already reached.")
        return start_step, seen_tok, resume_wall_time

    stream = token_stream(
        source,
        total_tokens_needed,
        seed=42,
        chat=chat_cfg.get("chat", False),
        chat_messages_key=chat_cfg.get("key", "messages"),
        sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
        dataset_field_text=chat_cfg.get("text_field", "text"),
        streaming=streaming,
    )

    pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    buf: List[int] = []
    batch_accum: List[List[int]] = []
    step = start_step
    steps_since_last_grow = 0
    now_wall = time.time()
    last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
    print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
    print(f"[{phase_name}] AR_ONLY={args.ar_only}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}, backend={BACKEND}")

    step_start_time = time.monotonic()
    tok_per_sec_avg = 0.0

    while seen_tok < total_tokens_needed:
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break

        seq = buf[:BLOCK]
        buf = buf[BLOCK:]
        batch_accum.append(seq)
        if len(batch_accum) < BATCH:
            continue

        ids = torch.tensor(batch_accum, device=DEV, dtype=torch.long)
        batch_accum = []

        try:
            if is_tt():
                opt.zero_grad(set_to_none=True)
                loss, loss_ar, loss_sat = _losses_for_batch(args, core, ar_h, sat_h, ids)
                loss.backward()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(list(core.parameters()) + list(ar_h.parameters()) + list(sat_h.parameters()), args.max_grad_norm)
                optimizer_step_backend(opt)
                opt.zero_grad(set_to_none=True)
                tt_sync(wait=True)
                loss_value = float(loss.detach().cpu())
            else:
                with amp(args.amp):
                    loss, loss_ar, loss_sat = _losses_for_batch(args, core, ar_h, sat_h, ids)
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(list(core.parameters()) + list(ar_h.parameters()) + list(sat_h.parameters()), args.max_grad_norm)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                loss_value = float(loss.item())
        except RuntimeError as e:
            msg = str(e).lower()

            if is_tt():
                tt_sync(wait=True)
                raise RuntimeError(
                    f"TT training step failed. Keep shapes static and reduce --block / --batch_size manually. "
                    f"Original error: {e}"
                ) from e

            if "out of memory" in msg or "cuda error" in msg:
                batch_accum = []
                opt.zero_grad(set_to_none=True)
                if is_cuda():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                if BATCH > 1:
                    print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1}")
                    BATCH -= 1
                    time.sleep(3)
                    continue
                if args.auto_grow:
                    smaller = [b for b in grow_plan if b < BLOCK]
                    new_block = smaller[-1] if smaller else max(128, BLOCK // 2)
                else:
                    new_block = max(128, BLOCK // 2)
                print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
                BLOCK = new_block
                time.sleep(3)
                continue
            raise

        step += 1
        toks_processed = BLOCK * BATCH
        seen_tok += toks_processed
        pbar.update(toks_processed)
        pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)

        step_elapsed = time.monotonic() - step_start_time
        tok_per_sec_now = toks_processed / step_elapsed if step_elapsed > 0 else 0
        tok_per_sec_avg = 0.9 * tok_per_sec_avg + 0.1 * tok_per_sec_now if tok_per_sec_avg > 0 else tok_per_sec_now
        step_start_time = time.monotonic()

        write_status(step, seen_tok, loss_value, BATCH, BLOCK, tok_per_sec_avg, phase_name)

        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"{phase_name}_step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core,
                    ar_h,
                    sat_h,
                    opt,
                    scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "tie_weights": tie_weights,
                        "block_size": BLOCK,
                        "batch_size": BATCH,
                    },
                )
                _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
                last_save_mono = now_mono

        if args.auto_grow and not is_tt():
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        BLOCK = grow_plan[idx + 1]
                        print(f"[{phase_name} Grow] Block -> {BLOCK}")
                        if is_cuda():
                            torch.cuda.empty_cache()
                except ValueError:
                    grow_plan = sorted(set(grow_plan + [BLOCK]))

    pbar.close()
    save_ckpt(
        pathlib.Path(args.save_dir) / f"{phase_name}_final.pt",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "tie_weights": tie_weights,
            "block_size": BLOCK,
            "batch_size": BATCH,
        },
    )

    return step, seen_tok, time.time()

def train(args):
    init_runtime(args.device_backend, for_infer=False)

    cfg = PRESETS[args.preset].copy()
    tie_weights = args.tie_weights
    print_expansion_info(cfg, tie_weights)

    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)

    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    print(f"Config: {cfg}")

    core = Encoder(cfg, tie_weights=tie_weights)
    ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None)
    sat_h = SATHead(cfg["d"], mode="var")

    _move_model_for_backend(core, ar_h, sat_h, tt_bf16=args.tt_bf16)
    total_params = _count_enabled_params(core, ar_h, sat_h)
    print(f"Total parameters: {total_params:,}")

    if tie_weights:
        print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}")

    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded = _safe_load_any(src, core, key="core")
            _safe_load_any(src, ar_h, key="ar")
            _safe_load_any(src, sat_h, key="sat")
            if loaded:
                print(f"Warm-start loaded from {src}")

    _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)

    opt = torch.optim.AdamW(
        [
            {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
            {"params": ar_h.parameters(), "lr": args.lr_head},
            {"params": sat_h.parameters(), "lr": args.lr_head},
        ]
    )

    scaler = GradScaler(enabled=(args.amp and is_cuda()))
    start_step, seen_tok, last_wall, _resumed_block = 0, 0, None, None

    if args.resume and not args.fresh:
        start_step, seen_tok, last_wall, _resumed_block = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
        print(f"Resumed from step {start_step}" + (f", block_size={_resumed_block}" if _resumed_block else ""))

    if args.compile and is_cuda():
        print("[torch.compile] Compiling model for CUDA...")
        core = torch.compile(core, mode="reduce-overhead")
        ar_h = torch.compile(ar_h, mode="reduce-overhead")
        sat_h = torch.compile(sat_h, mode="reduce-overhead")
        print("[torch.compile] Done.")
    elif args.compile and is_tt():
        print("[torch.compile] Ignored on TT backend. Use TT-XLA runtime instead.")

    step, seen_tok, last_wall = _train_phase(
        args,
        "pretrain",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        start_step,
        seen_tok,
        last_wall,
        cfg,
        args.source,
        args.steps,
        (_resumed_block if _resumed_block and args.auto_grow else None) or args.block or DEFAULT_BLOCK,
        args.batch_size or DEFAULT_BATCH,
        chat_cfg={
            "chat": args.chat,
            "key": args.chat_messages_key,
            "gen_prompt": args.sft_add_generation_prompt,
            "text_field": args.dataset_field_text,
        },
        max_ckpts=args.max_ckpts,
        target_tokens_override=args.target_tokens,
        tie_weights=tie_weights,
    )

    if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
        args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
        args.after_sft_chat = True
        if args.after_sft_add_generation_prompt is None:
            args.after_sft_add_generation_prompt = True
        if not args.after_sft_block:
            args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK

    if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
        print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")

        _phase_freeze(
            core,
            freeze_core=args.after_sft_freeze_core,
            unfreeze_ln=args.after_sft_unfreeze_ln,
            train_emb=args.after_sft_train_emb,
        )

        opt = torch.optim.AdamW(
            [
                {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
                {"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
                {"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
            ]
        )

        step, seen_tok, last_wall = _train_phase(
            args,
            "sft",
            core,
            ar_h,
            sat_h,
            opt,
            scaler,
            step,
            seen_tok,
            last_wall,
            cfg,
            args.after_sft_source,
            args.after_sft_steps,
            args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
            args.batch_size or DEFAULT_BATCH,
            chat_cfg={
                "chat": args.after_sft_chat,
                "key": args.after_sft_chat_messages_key,
                "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
                "text_field": args.after_sft_dataset_field_text,
            },
            max_ckpts=args.max_ckpts,
            target_tokens_override=None,
            tie_weights=tie_weights,
            streaming=False,
        )

    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "tie_weights": tie_weights,
            "block_size": args.block or DEFAULT_BLOCK,
            "batch_size": args.batch_size or DEFAULT_BATCH,
        },
    )
    print("🎉 All Training Complete")


# -----------------------------------------------------------------------------
# Sampling / inference
# -----------------------------------------------------------------------------
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
    if ids.numel() == 0:
        return logits
    hist = ids[0, -n:].long() if n > 0 else ids[0].long()
    uniq, counts = torch.unique(hist, return_counts=True)
    if pres_p or freq_p:
        logits[..., uniq] -= (pres_p + freq_p * counts.to(logits.dtype))
    if rep_p != 1.0:
        sel = logits[..., uniq]
        logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
    return logits
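
# Penalty semantics, as implemented above: presence subtracts a flat pres_p from
# every token in the last-n window, frequency subtracts freq_p * count, and
# repetition_penalty divides positive logits (multiplies negative ones) by rep_p.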

def _sample(logits, T, top_k, top_p, min_p, greedy):
    if greedy:
        return logits.argmax(-1, keepdim=True)
    probs = (logits / max(T, 1e-8)).softmax(-1)
    if top_k:
        v, i = torch.topk(probs, min(top_k, probs.size(-1)))
        probs = torch.zeros_like(probs).scatter_(-1, i, v)
    if top_p < 1.0:
        s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
        probs = torch.zeros_like(probs).scatter_(
            -1, s_idx, s_probs * (torch.cumsum(s_probs, -1) <= top_p).to(probs.dtype)
        )
    if min_p > 0:
        probs[probs < min_p] = 0
    if probs.sum() == 0:
        return logits.argmax(-1, keepdim=True)
    return probs.div_(probs.sum()).multinomial(1)
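
# Filter order is top-k -> top-p -> min-p on the temperature-scaled distribution,
# with a greedy fallback if everything is zeroed out. Note min_p here is an
# absolute probability floor, not the max-relative variant some samplers use.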

@torch.no_grad()
def infer(args):
    init_runtime(args.device_backend, for_infer=True, force_tt_infer=args.force_tt_infer)

    if args.mode == "ar":
        if args.temperature is None:
            args.temperature = 0.7
        if args.top_k is None:
            args.top_k = 0
        if args.repetition_penalty is None:
            args.repetition_penalty = 1.3
        if args.presence_penalty is None:
            args.presence_penalty = 0.0
        if args.frequency_penalty is None:
            args.frequency_penalty = 0.3
        if args.penalty_last_n is None:
            args.penalty_last_n = 128
        if args.var is None:
            args.var = False
    else:
        if args.temperature is None:
            args.temperature = 0.5
        if args.top_k is None:
            args.top_k = 30
        if args.repetition_penalty is None:
            args.repetition_penalty = 2.0
        if args.presence_penalty is None:
            args.presence_penalty = 0.6
        if args.frequency_penalty is None:
            args.frequency_penalty = 1.0
        if args.penalty_last_n is None:
            args.penalty_last_n = 200
        if args.var is None:
            args.var = True

    path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
    sd = torch.load(path, map_location="cpu")
    cfg = sd["cfg"]
    tie_weights = sd.get("tie_weights", False)

    uk_time = get_uk_time()
    ckpt_name = path.name
    print("┌─────────────────────────────────────────────────┐")
    print(f"│ INFERENCE @ {uk_time:<35s} │")
    print("├─────────────────────────────────────────────────┤")
    print(f"│ Checkpoint: {ckpt_name:<35s} │")
    print("└─────────────────────────────────────────────────┘")
    print_expansion_info(cfg, tie_weights)

    core = Encoder(cfg, tie_weights=tie_weights)
    ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None)
    sat_h = SATHead(cfg["d"])

    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    sat_h.load_state_dict(sd["sat"])

    if args.fp16 and is_cuda():
        core.half()
        ar_h.half()
        sat_h.half()
        print(f"{Colors.INFO}Using fp16 inference{Colors.RESET}")

    _move_model_for_backend(core, ar_h, sat_h, tt_bf16=args.tt_bf16)
    core.eval()
    ar_h.eval()
    sat_h.eval()

    total_params = _count_enabled_params(core, ar_h, sat_h)
    if total_params >= 1_000_000_000:
        param_str = f"{total_params / 1_000_000_000:.2f}B"
    elif total_params >= 1_000_000:
        param_str = f"{total_params / 1_000_000:.2f}M"
    elif total_params >= 1_000:
        param_str = f"{total_params / 1_000:.2f}K"
    else:
        param_str = f"{total_params}"
    print(f"Model size: {param_str} parameters ({total_params:,})")

    prompt_tokens = tok.encode(args.prompt)
    prompt_len = len(prompt_tokens)
    ids = torch.tensor([prompt_tokens], device=DEV, dtype=torch.long)
    if ids.size(1) == 0:
        ids = torch.tensor([[EOS]], device=DEV, dtype=torch.long)
        prompt_len = 1

    mode_str = args.mode
    if args.mode == "sat":
        mode_str = f"sat-{'var' if args.var else 'fixed'}"
    print(f"{Colors.INFO}Generating ({mode_str}) on {BACKEND}...{Colors.RESET}")

    start = time.time()
    if args.mode == "ar":
        h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
        for _ in range(args.max_new):
            logits = ar_h(h)[:, -1]
            logits = _apply_penalties(
                logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty
            )
            nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
            ids = torch.cat([ids, nxt], 1)
            h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
    else:
        cached_len = ids.size(1)
        h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
        added = 0
        while added < args.max_new:
            logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
            stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
            new_tokens = []
            for i in range(int(stride)):
                logits = logits_all[:, i]
                logits = _apply_penalties(
                    logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty
                )
                nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
                new_tokens.append(nxt)
                ids = torch.cat([ids, nxt], 1)
                added += 1
                if added >= args.max_new:
                    break
            if added >= args.max_new:
                break
            new_ids = torch.cat(new_tokens, dim=1)
            mask = sat_mask_cached(new_ids.size(1), cached_len)
            h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
            cached_len = ids.size(1)

    if is_tt():
        tt_sync(wait=True)

    elapsed = time.time() - start
    gen_tokens = len(ids[0]) - prompt_len
    tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0

    all_tokens = ids[0].detach().cpu().tolist()
    prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
    print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}")
    print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s | backend={BACKEND}]{Colors.RESET}")


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing - TT/CUDA/CPU single-file trainer")

    ap.add_argument("--device_backend", choices=["auto", "cpu", "cuda", "tt"], default=REQUESTED_BACKEND)
    ap.add_argument("--force_tt_infer", action="store_true", help="Allow TT inference even though it recompiles on changing sequence lengths")

    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
    tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--compile", action="store_true", help="CUDA only. TT ignores this.")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true")
    tr.add_argument("--warmstart_from", type=str)
    tr.add_argument("--fresh", action="store_true")
    tr.add_argument("--max_ckpts", type=int, default=None)
    tr.add_argument("--chilla_max_double", action="store_true")
    tr.add_argument("--tie_weights", action="store_true")
    tr.add_argument("--ar_only", action="store_true")
    tr.add_argument("--freeze_core", action="store_true")
    tr.add_argument("--unfreeze_ln", action="store_true")
    tr.add_argument("--train_emb", action="store_true")
    tr.add_argument("--lr_core", type=float, default=LR_CORE)
    tr.add_argument("--lr_head", type=float, default=LR_HEAD)
    tr.add_argument("--chat", action="store_true")
    tr.add_argument("--chat_messages_key", default="messages")
    tr.add_argument("--dataset_field_text", default="text")
    tr.add_argument("--sft_add_generation_prompt", action="store_true")
    tr.add_argument("--auto_grow", action="store_true")
    tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
    tr.add_argument("--grow_every_steps", type=int, default=50000)

    tr.add_argument("--after_sft_source", default="")
    tr.add_argument("--after_sft_steps", type=int, default=0)
    tr.add_argument("--after_sft_chat", action="store_true")
    tr.add_argument("--after_sft_chat_messages_key", default="messages")
    tr.add_argument("--after_sft_dataset_field_text", default="text")
    # argparse's type=bool would map any non-empty string (even "False") to True,
    # so parse the value explicitly.
    tr.add_argument(
        "--after_sft_add_generation_prompt",
        type=lambda s: s.strip().lower() in {"1", "true", "yes"},
        default=None,
    )
    tr.add_argument("--after_sft_block", type=int, default=0)
    tr.add_argument("--after_sft_freeze_core", action="store_true")
    tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
    tr.add_argument("--after_sft_train_emb", action="store_true")
    tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
    tr.add_argument("--after_sft_lr_head", type=float, default=0.0)

    tr.add_argument("--tt_bf16", action="store_true", default=True, help="Use bf16 weights on TT (recommended)")
    tr.add_argument("--no-tt-bf16", dest="tt_bf16", action="store_false")
    tr.add_argument("--label_smoothing", type=float, default=0.1)
    tr.add_argument("--max_grad_norm", type=float, default=1.0)

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=None)
    inf.add_argument("--greedy", action="store_true")
    inf.add_argument("--top_k", type=int, default=None)
    inf.add_argument("--top_p", type=float, default=0.9)
    inf.add_argument("--min_p", type=float, default=0.0)
    inf.add_argument("--repetition_penalty", type=float, default=None)
    inf.add_argument("--presence_penalty", type=float, default=None)
    inf.add_argument("--frequency_penalty", type=float, default=None)
    inf.add_argument("--penalty_last_n", type=int, default=None)
    inf.add_argument("--var", action="store_true", default=None)
    inf.add_argument("--no-var", dest="var", action="store_false")
    inf.add_argument("--fp16", action="store_true", help="Use fp16 inference (CUDA only)")
    inf.add_argument("--tt_bf16", action="store_true", default=True, help="Use bf16 weights on TT if TT infer is forced")
    inf.add_argument("--no-tt-bf16", dest="tt_bf16", action="store_false")

    sub.add_parser("status")

    args = ap.parse_args()

    if args.cmd == "train":
        train(args)
    elif args.cmd == "infer":
        infer(args)
    else:
        show_status()


if __name__ == "__main__":
    main()