File size: 2,025 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import sys
import os
import math

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def _cuda_available():
    """Report whether a CUDA device with enough total memory is present.

    Returns:
        bool: ``False`` when CUDA is unavailable or when the total memory
        reported by ``torch.cuda.mem_get_info()`` is below ``10e9``;
        ``True`` otherwise.
    """
    if torch.cuda.is_available():
        # Only the total is compared; the free amount is not consulted.
        _free_mem, total_mem = torch.cuda.mem_get_info()
        return not total_mem < 10e9
    return False


def test_200_step_smoke():
    """Run a 200-step training smoke test on byte-level Tiny Shakespeare.

    The test skips (prints a SKIP line and returns ``None``) when CUDA is
    unavailable or when the GPU reports less than 7.5e9 bytes of total
    memory.  Otherwise it builds an ``ARBModel`` on the GPU, runs 200 update
    steps of two micro-batches each, and checks that every loss is finite.
    On success it prints a PASS line with the first, last, minimum and
    maximum per-step mean loss.

    Raises:
        AssertionError: if a micro-batch loss or a per-step mean loss is
            non-finite.
        OSError: if ``training/data/tinyshakespeare.txt`` cannot be read;
            the path is relative to the current working directory.
    """
    if not torch.cuda.is_available():
        print(" SKIP test_200_step_smoke (no CUDA)")
        return
    # Only the total device memory gates the test; free memory is ignored.
    _, total = torch.cuda.mem_get_info()
    if total < 7.5e9:
        print(f" SKIP test_200_step_smoke (GPU {total/1e9:.1f}GB < 7.5GB)")
        return
    # Project imports happen only after the skip checks have passed.
    from arbitor.main import ARBModel
    from arbitor.kernel.ternary_scale import TScaleType

    num_steps = 200
    accum_steps = 2  # micro-batches per update step
    batch_size = 1
    ctx_len = 64

    model = ARBModel(
        tscale_type=TScaleType.T32,
        enable_image=False,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
        enable_memory_modules=True,
        enable_moe=False,
    ).cuda()

    # Read the corpus as raw bytes; the context manager closes the handle
    # instead of leaving it to the garbage collector.
    with open("training/data/tinyshakespeare.txt", "rb") as fh:
        raw = fh.read()
    data = torch.tensor(list(raw), dtype=torch.long)
    # First 90% of the byte stream is used for training batches.
    train_data = data[:int(0.9 * data.numel())]

    def get_batch(data, bs, ctx):
        """Sample ``bs`` random windows of ``ctx`` tokens and their targets."""
        ix = torch.randint(0, data.numel() - ctx - 1, (bs,))
        x = torch.stack([data[i:i+ctx] for i in ix]).cuda()
        # NOTE(review): targets are the inputs with the first 3 positions
        # dropped -- presumably the alignment ARBModel expects; confirm.
        return x, x[:, 3:].contiguous()

    losses = []
    for step in range(num_steps):
        model.zero_grad(set_to_none=True)
        accum_loss = 0.0
        for _ in range(accum_steps):
            x, t = get_batch(train_data, batch_size, ctx_len)
            _, lc, _, _ = model(x, targets=t)
            loss = lc.total / accum_steps
            assert torch.isfinite(loss).all(), f"Non-finite loss at step {step}"
            accum_loss += lc.total.item()
        # NOTE(review): only the last micro-batch's loss components reach
        # the update -- confirm this is intended.
        model._ternary_update_memory(loss_components=lc)
        losses.append(accum_loss / accum_steps)

    assert all(math.isfinite(value) for value in losses), "Non-finite loss detected"
    print(f" PASS test_200_step_smoke: {losses[0]:.2f} -> {losses[-1]:.2f} (min={min(losses):.2f}, max={max(losses):.2f})")