File size: 5,282 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Text-only pure-ternary training.

For the maintained CLI entrypoint prefer ``python -m arbitor.train`` or the
``arbs-train`` console script after ``pip install -e .``.
"""
import os, sys, time, math, json, torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from torch.utils.tensorboard import SummaryWriter
from arbitor import ARBModel, VOCAB, CTX
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import freeze_float_parameters, audit_model, format_audit


def load_text_data(source, ctx=CTX, device="cuda"):
    """Load a token dataset and split it 90/10 into train and validation tensors.

    Args:
        source: Path to either a ``.pt`` file (loaded with
            ``torch.load(..., weights_only=True)`` and used as-is) or any other
            existing file, whose raw bytes become a ``torch.long`` tensor of
            values 0-255.
        ctx: Accepted for call-site compatibility; not used by this function.
        device: Device both returned tensors are moved to. Defaults to
            ``"cuda"``, which matches the previous hard-coded ``.cuda()``.

    Returns:
        ``(train, val)``: the first ``int(0.9 * len(data))`` elements and the
        remaining elements, both on ``device``.

    Raises:
        ValueError: if ``source`` is neither a ``.pt`` path nor an existing file.
    """
    if source.endswith('.pt'):
        data = torch.load(source, weights_only=True)
    elif os.path.isfile(source):
        # Read inside a context manager so the handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(source, 'rb') as fh:
            raw = fh.read()
        data = torch.tensor(list(raw), dtype=torch.long)
    else:
        raise ValueError(f"Dataset not found: {source}")
    n = int(0.9 * len(data))
    return data[:n].to(device), data[n:].to(device)


# Offset between an input window and the target slice passed to the model
# (``targets = x[:, _TARGET_OFFSET:]``). Kept at the original value;
# NOTE(review): confirm it matches ARBModel's expected target alignment.
_TARGET_OFFSET = 3


def _sample_batch(data, batch, ctx):
    """Draw ``batch`` random length-``ctx`` windows from ``data``.

    Start indices are drawn uniformly from ``[0, len(data) - ctx - 1)``, so
    ``data`` must hold at least ``ctx + 2`` elements (checked by the caller).

    Returns:
        ``(x, t)`` where ``x`` stacks the windows and ``t = x[:, _TARGET_OFFSET:]``.
    """
    ix = torch.randint(0, len(data) - ctx - 1, (batch,))
    x = torch.stack([data[i:i + ctx] for i in ix])
    return x, x[:, _TARGET_OFFSET:]


if __name__ == "__main__":
    # Script entry: parse CLI flags, build a text-only ARBModel, verify that no
    # float parameters are trainable, then run the ternary update loop with
    # periodic TensorBoard logging, evaluation and best-checkpoint saving.
    import argparse
    parser = argparse.ArgumentParser(description="ARB text training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--ctx", type=int, default=min(CTX, 128))
    parser.add_argument("--data", type=str, default="training/data/tinyshakespeare.txt")
    parser.add_argument("--run", type=str, default="text")
    parser.add_argument("--eval-interval", type=int, default=250)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    parser.add_argument("--full-text-stack", action="store_true",
                        help="Enable text VQ + KG/MoEGraph + KV attention. Default is byte-text only.")
    parser.add_argument("--with-vq", action="store_true")
    parser.add_argument("--with-graph", action="store_true")
    parser.add_argument("--with-moe", action="store_true")
    parser.add_argument("--with-attention", action="store_true")
    args = parser.parse_args()
    if args.batch < 1:
        parser.error("--batch must be >= 1")
    if args.ctx <= _TARGET_OFFSET:
        # Otherwise the target slice x[:, _TARGET_OFFSET:] would be empty.
        parser.error(f"--ctx must be > {_TARGET_OFFSET}")
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")

    # Graph requires VQ, and attention requires graph; MoE is independent.
    enable_vq = args.full_text_stack or args.with_vq
    enable_graph = (args.full_text_stack or args.with_graph) and enable_vq
    enable_moe = args.full_text_stack or args.with_moe
    enable_attention = (args.full_text_stack or args.with_attention) and enable_graph

    model = ARBModel(enable_image=False, enable_audio=False,
                     enable_vq=enable_vq, enable_graph=enable_graph,
                     enable_memory_modules=False, enable_moe=enable_moe,
                     max_moe_iters=4,
                     enable_attention=enable_attention,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=False).cuda()
    freeze_float_parameters(model)

    print(format_audit(audit_model(model)))
    print(
        "Text trainer config: "
        f"backend={args.backend}, vq={enable_vq}, graph={enable_graph}, "
        f"moe={enable_moe}, attention={enable_attention}"
    )
    train_params = [p for p in model.parameters() if p.requires_grad]
    if train_params:
        raise RuntimeError("Text trainer is pure ternary; use training/finetuning for float LoRA adapters.")
    print("Pure ternary update path is active")

    train_data, val_data = load_text_data(args.data, args.ctx)
    # Fail fast: a split shorter than ctx + 2 gives torch.randint an empty range.
    # For the val split this would otherwise only surface at the first eval.
    splits_to_check = [("train", train_data)]
    if args.eval_interval:
        splits_to_check.append(("val", val_data))
    for split_name, split in splits_to_check:
        if len(split) < args.ctx + 2:
            raise ValueError(
                f"{split_name} split has {len(split)} tokens; "
                f"--ctx {args.ctx} needs at least {args.ctx + 2}"
            )

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best = float('inf')

    try:
        while step < args.steps:
            x, t = _sample_batch(train_data, args.batch, args.ctx)

            _, losses, _, _ = model(x, targets=t)
            model.prepare_ternary_backward(losses.total.detach(), update_scales=True)
            losses.total.backward()

            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=losses.total)
            model.zero_grad(set_to_none=True)

            step += 1
            log_now = bool(args.log_interval) and step % args.log_interval == 0
            eval_now = bool(args.eval_interval) and step % args.eval_interval == 0
            if not (log_now or eval_now):
                continue

            # One host sync for the train loss, shared by the log and eval reports.
            # "loss/train" is written once per step, even when log and eval coincide.
            train_loss = losses.total.item()
            writer.add_scalar("loss/train", train_loss, step)
            if log_now:
                print(f"step {step:>5d}  train={train_loss:.3f}")

            if eval_now:
                # NOTE(review): eval runs in the model's current (training) mode, as before.
                with torch.no_grad():
                    xv, tv = _sample_batch(val_data, args.batch, args.ctx)
                    _, lv, _, _ = model(xv, targets=tv)
                    val = lv.total.item()
                writer.add_scalar("loss/eval", val, step)
                if val < best:
                    best = val
                    torch.save(model.state_dict(), f"{run_dir}/best.pt")
                print(f"step {step:>5d}  train={train_loss:.3f}  eval={val:.3f}  best={best:.3f}")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()