File size: 8,475 Bytes
11c11f8
 
 
 
 
5b5a08d
11c11f8
 
 
 
 
 
3859a82
b6bcd75
11c11f8
 
5bfbb8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b5a08d
5bfbb8a
 
 
11c11f8
e2f5e25
11c11f8
 
9897d01
11c11f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9897d01
11c11f8
9897d01
 
 
11c11f8
 
9897d01
 
11c11f8
 
 
e2f5e25
5bfbb8a
e2f5e25
9d8c566
 
5bfbb8a
9d8c566
9897d01
11c11f8
 
6d5c935
5bfbb8a
6d5c935
f6670ea
5bfbb8a
 
 
 
11c11f8
 
b6bcd75
5bfbb8a
 
 
 
 
 
5b5a08d
b6bcd75
9d8c566
 
 
5bfbb8a
9897d01
11c11f8
 
 
 
9897d01
11c11f8
5b5a08d
11c11f8
5bfbb8a
5b5a08d
5bfbb8a
 
 
31d69ba
3859a82
11c11f8
 
 
3859a82
5b5a08d
 
 
11c11f8
 
 
 
 
 
 
5bfbb8a
 
31d69ba
3859a82
11c11f8
8e41f12
5b5a08d
b6bcd75
11c11f8
 
b6bcd75
11c11f8
 
 
 
 
31d69ba
5b5a08d
 
11c11f8
e2f5e25
3859a82
31d69ba
11c11f8
31d69ba
5b5a08d
 
3859a82
b6bcd75
 
9897d01
5b5a08d
 
11c11f8
31d69ba
5b5a08d
 
 
 
 
 
 
 
 
 
11c11f8
8e41f12
 
 
 
 
 
11c11f8
5b5a08d
 
9897d01
 
 
 
3859a82
5b5a08d
9897d01
11c11f8
31d69ba
 
5b5a08d
 
31d69ba
5b5a08d
 
 
 
 
 
31d69ba
11c11f8
3859a82
 
11c11f8
5b5a08d
11c11f8
3859a82
 
11c11f8
5b5a08d
 
 
 
11c11f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from __future__ import annotations

import json
import math
import os
import sys
import time

import torch

import chimera_turbo

from .common import save_final_checkpoint, save_training_checkpoint
from .hyper import ProgressiveLoopScheduler


def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
                max_logits_gb: float = 2.0) -> int:
    """Cap batch size so the logits tensor fits in memory.

    Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
    With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
    Backward doubles this. Must stay well under 32 GB total.
    """
    bytes_per_sample = seq_len * vocab_size * 4  # fp32 logits
    max_bytes = int(max_logits_gb * 1024**3)
    max_batch = max(1, max_bytes // bytes_per_sample)
    capped = min(desired_batch, max_batch)
    if capped < desired_batch:
        print(f"  [MEM] Batch {desired_batch} β†’ {capped} (logits would be "
              f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
        sys.stdout.flush()
    return capped


def train_fast_loop(args, model, config, loader, compute_loss) -> str:
    """Train ``model`` with a plain AdamW loop and save a final checkpoint.

    Cycles through ``loader`` indefinitely (re-creating the iterator on
    exhaustion) until ``args.max_steps`` optimizer steps have been taken.
    Gradients are clipped to a global norm of 1.0 before each step.

    Args:
        args: Namespace providing ``lr``, ``output_dir``, ``max_steps`` and
            ``log_every``.
        model: Module to train.
        config: Passed through to ``save_final_checkpoint``.
        loader: Iterable of batches. Each batch must have an ``"input_ids"``
            entry supporting ``.numel()`` (used for tokens/sec).
        compute_loss: Callable mapping a batch to a scalar loss tensor.

    Returns:
        Path of the final checkpoint directory (``<output_dir>/final``).

    Raises:
        RuntimeError: If ``loader`` yields no batches at all.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
    os.makedirs(args.output_dir, exist_ok=True)
    final_dir = os.path.join(args.output_dir, "final")
    model.train()
    step, total_loss, best_loss, toks = 0, 0.0, float("inf"), 0
    t0 = time.time()
    data_iter = iter(loader)
    while step < args.max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter, None)
            if batch is None:
                # A freshly created iterator that is immediately empty means
                # the loader has no data; fail loudly instead of leaking a
                # bare StopIteration out of this function.
                raise RuntimeError("loader yielded no batches; cannot train") from None
        loss = compute_loss(batch)
        loss.backward()
        total_loss += float(loss.item())
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        toks += batch["input_ids"].numel()
        step += 1
        if step % args.log_every == 0:
            dt = time.time() - t0
            avg = total_loss / args.log_every
            ppl = math.exp(min(avg, 20))  # clamp to avoid overflow
            tps = toks / dt if dt > 0 else 0
            print(f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:.0f} tok/s")
            best_loss = min(best_loss, avg)
            total_loss, toks, t0 = 0.0, 0, time.time()

    # Steps after the last log boundary (or a run shorter than log_every)
    # were never averaged; fold them in so best_loss is not left at +inf.
    tail_steps = step % args.log_every
    if tail_steps:
        tail_avg = total_loss / tail_steps
        if math.isfinite(tail_avg):
            best_loss = min(best_loss, tail_avg)

    save_final_checkpoint(model, config, step, best_loss, final_dir)
    return final_dir


def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
    """Placeholder for the standard (optionally MeZO) training loop.

    Not implemented. It raises rather than silently returning ``None`` so a
    caller cannot mistake a no-op for a completed training run.
    ``train_fast_loop`` and ``train_hyper_loop`` are the implemented
    alternatives.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "train_standard_loop is not implemented; "
        "use train_fast_loop or train_hyper_loop instead")


def _build_hyper_loader(args, dataset, cur_seq, vocab_size):
    """Build a shuffled DataLoader sized for ``cur_seq``; return ``(loader, batch)``.

    The requested batch is ``args.batch_size`` scaled by
    ``args.seq_len // cur_seq`` (at least 1), so shorter sequences get
    proportionally larger batches. That request is then capped by
    ``_safe_batch`` to keep the logits tensor within its memory budget.
    """
    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    return loader, eff_batch


def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
    """Train with the ``chimera_turbo`` setup plus optional sequence growth.

    Wraps ``model`` via ``chimera_turbo.apply`` (returning model, optimizer,
    scheduler and extras). It then enables gradient checkpointing when
    available, pins the loop controller to one loop, and runs
    ``chimera_turbo.training_step`` for ``args.max_steps`` steps.

    Progress is logged to stdout and to ``<output_dir>/log_hyper.jsonl``.
    Periodic checkpoints go to ``ckpt-<step>`` and the final one to
    ``final``.

    Args:
        args: Namespace providing ``max_steps``, ``bf16``, ``output_dir``,
            ``batch_size``, ``seq_len``, ``log_every``, ``save_every`` and,
            optionally, ``compile``.
        model: Module to train.
        config: Mapping with optional ``vocab_size`` (default 200073). It is
            also passed to the checkpoint savers.
        dataset: Dataset for the DataLoader. It must provide
            ``set_seq_len(n)`` when ``grow`` is given.
        initial_seq: Starting sequence length.
        grow: Optional object whose ``get_seq_len(step)`` returns the
            sequence length for that step. If falsy, the length stays fixed.
        unfreezer: Optional object whose ``update(step)`` is called before
            each step.

    Returns:
        The value returned by ``save_final_checkpoint`` for ``<output_dir>/final``.

    Raises:
        RuntimeError: If the DataLoader yields no batches.
    """
    use_compile = getattr(args, "compile", False)
    vocab_size = int(config.get("vocab_size", 200073))

    # ── Muon LR for ternary BitLinear ──
    muon_lr = 0.012
    muon_warmup = 30

    model, optimizer, scheduler, extras = chimera_turbo.apply(
        model,
        max_steps=args.max_steps,
        lr=muon_lr,
        weight_decay=0.02,
        warmup_steps=muon_warmup,
        use_compile=use_compile,
        mtp_heads=0,
        llrd_decay=0.90,
        grokfast_alpha=0.95,
        grokfast_lambda=1.5,
    )
    model.train()

    # ── Gradient checkpointing: trade recompute for activation memory ──
    # ``_orig_mod`` is presumably the module wrapped by torch.compile.
    raw_model = getattr(model, "_orig_mod", model)
    if hasattr(raw_model, "enable_gradient_checkpointing"):
        raw_model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing: ON")

    # ── Looping: force loops=1 ──
    cur_loops = 1
    if hasattr(raw_model, "loop_controller"):
        raw_model.loop_controller.loop_default = 1
        raw_model.loop_controller.loop_min = 1
        raw_model.loop_controller.loop_max = 1

    use_bf16 = bool(args.bf16)
    autocast_dtype = torch.bfloat16 if use_bf16 else None  # loop-invariant

    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log_hyper.jsonl")
    step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
    t0 = time.time()
    t_start = t0
    cur_seq = initial_seq

    # ── Memory-safe batch size ──
    loader, eff_batch = _build_hyper_loader(args, dataset, cur_seq, vocab_size)
    data_iter = iter(loader)

    print(f"\n{'=' * 65}")
    print(f"Training  batch={eff_batch}  seq={cur_seq}  loops={cur_loops}")
    print("Starting first step (may take 30-60s on CPU with 227M params)...")
    print(f"{'=' * 65}")
    sys.stdout.flush()

    # Context-managed so the log is closed even if a step or save raises.
    with open(log_path, "w") as log_f:
        while step < args.max_steps:
            if grow:
                ns = grow.get_seq_len(step)
                if ns != cur_seq:
                    cur_seq = ns
                    dataset.set_seq_len(cur_seq)
                    loader, eff_batch = _build_hyper_loader(args, dataset, cur_seq, vocab_size)
                    data_iter = iter(loader)
                    print(f"  [P1] seq -> {cur_seq}  batch -> {eff_batch}")
                    sys.stdout.flush()

            if unfreezer:
                unfreezer.update(step)

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(loader)
                batch = next(data_iter, None)
                if batch is None:
                    # With drop_last=True the loader is empty when the dataset
                    # holds fewer samples than one batch; fail with context
                    # instead of leaking a bare StopIteration.
                    raise RuntimeError(
                        f"DataLoader yielded no batches (batch={eff_batch}, seq={cur_seq}); "
                        "the dataset may be smaller than one batch with drop_last=True"
                    ) from None

            step_t0 = time.time()

            # Assumed to return a Python float (or float-convertible) loss.
            loss_val = chimera_turbo.training_step(
                model, batch, optimizer, scheduler,
                extras=extras, grad_accum_steps=1, step=step,
                autocast_dtype=autocast_dtype,
            )

            step_dt = time.time() - step_t0

            cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
            # Non-finite losses are excluded from the running average.
            if math.isfinite(loss_val):
                total_loss += loss_val
                valid_count += 1
            step_toks = batch["input_ids"].numel()
            toks += step_toks
            step += 1

            # Print every step for the first 5 steps, then every log_every
            should_log = (step <= 5) or (step % args.log_every == 0)

            if step == 1:
                step_tps = step_toks / step_dt if step_dt > 0 else 0
                print(f"  ✓ Step 1 completed in {step_dt:.1f}s "
                      f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
                sys.stdout.flush()

            if should_log:
                dt = time.time() - t0
                if valid_count > 0:
                    avg = total_loss / valid_count
                    ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
                else:
                    avg = float("nan")
                    ppl = float("nan")
                tps = toks / dt if dt > 0 else 0
                elapsed = time.time() - t_start
                eta_s = (args.max_steps - step) * (elapsed / max(1, step))
                log_f.write(json.dumps({
                    "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                    "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                    "lr": round(cur_lr, 6), "tok/s": round(tps),
                    "seq": cur_seq, "loops": cur_loops,
                    "step_time": round(step_dt, 2),
                }) + "\n")
                log_f.flush()
                print(
                    f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
                    f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
                    f"| ETA {eta_s / 60:.0f}m"
                )
                sys.stdout.flush()

                if step > 5:
                    # Reset counters for clean averages (warm-up steps 1-5
                    # keep accumulating so early averages are not noisy).
                    best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
                    total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()

            if step % args.save_every == 0:
                d = save_training_checkpoint(model, config, step,
                                             os.path.join(args.output_dir, f"ckpt-{step}"))
                print(f"  [SAVE] {d}")
                sys.stdout.flush()

        # Losses accumulated since the last counter reset (a trailing partial
        # window, or a run too short to ever reset) would otherwise never
        # reach best_loss, leaving it at +inf for short runs.
        if valid_count > 0:
            tail_avg = total_loss / valid_count
            if math.isfinite(tail_avg):
                best_loss = min(best_loss, tail_avg)

        d = save_final_checkpoint(model, config, step, best_loss,
                                  os.path.join(args.output_dir, "final"))

    total_time = time.time() - t_start
    print(f"\nDONE -- best loss {best_loss:.4f}  ppl {math.exp(min(best_loss, 20)):.2f}"
          f"  total time {total_time / 60:.1f}m")
    sys.stdout.flush()
    return d