File size: 4,888 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Audio/speech training.
Freezes the text pipeline and trains the TalkerHead with ternary updates
(the OutputRouter is disabled in this script's model configuration).
Uses AudioVQEncoder to prepare training targets from audio files.
"""
import math, os, sys, time, torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from torch.utils.tensorboard import SummaryWriter
from arbitor import ARBModel, AUDIO_SR
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
from arbitor.encoders.audio import AudioVQEncoder


def freeze_core(model):
    """Leave only the output-head parameters trainable.

    Every parameter of ``model`` whose qualified name contains ``talker_head``,
    ``output_router`` or ``video_head`` gets ``requires_grad = True``; every
    other parameter (the shared text pipeline included) gets
    ``requires_grad = False``.
    """
    head_keys = ('talker_head', 'output_router', 'video_head')
    for param_name, param in model.named_parameters():
        # Single pass: decide each parameter's flag directly instead of
        # freezing everything first and then unfreezing the heads.
        param.requires_grad = any(key in param_name for key in head_keys)


def load_audio_data(source, sample_rate=AUDIO_SR):
    """Load an audio file as a single-channel waveform at ``sample_rate``.

    Args:
        source: Anything ``torchaudio.load`` accepts (e.g. a file path).
        sample_rate: Target sample rate. The waveform is resampled only when
            the file's native rate differs.

    Returns:
        A channels-first waveform tensor with one channel. Multi-channel
        input is averaged down to mono.
    """
    import torchaudio
    wav, sr = torchaudio.load(source)
    # Downmix before resampling. Resampling is linear, so the result matches
    # resample-then-average (up to float rounding), but only one channel has
    # to be resampled instead of every channel in the file.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        resample = torchaudio.transforms.Resample(sr, sample_rate)
        wav = resample(wav)
    return wav


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="ARB audio training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=4)
    # NOTE(review): --data is parsed but never read below; only --audio-dir
    # (or the synthetic fallback) feeds the trainer.
    parser.add_argument("--data", type=str, default=None, help="Path or HF dataset")
    parser.add_argument("--audio-dir", type=str, default=None, help="Dir of .wav files")
    parser.add_argument("--run", type=str, default="audio")
    parser.add_argument("--ctx", type=int, default=AUDIO_SR, help="Audio samples per example")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    args = parser.parse_args()
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    if args.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")
    # Fixed per-example length. --ctx values below AUDIO_SR are raised to AUDIO_SR.
    sample_len = max(args.ctx, AUDIO_SR)

    if args.audio_dir:
        import glob
        # Sorted so that the 100-file cap and the batch order are reproducible
        # (raw glob order depends on the filesystem).
        files = sorted(glob.glob(os.path.join(args.audio_dir, "*.wav")))
        print(f"Found {len(files)} audio files")
        if not files:
            # Fail early with a clear message instead of a ZeroDivisionError
            # from the batch indexing below.
            raise FileNotFoundError(f"No .wav files found in {args.audio_dir!r}")
        audio_data = [load_audio_data(f) for f in files[:100]]
    else:
        # Generate a synthetic 440 Hz tone for smoke testing. The time axis is
        # in seconds at AUDIO_SR, so the pitch no longer depends on --ctx.
        t = torch.arange(sample_len, dtype=torch.float32) / AUDIO_SR
        audio_data = [torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)]
        print("No audio data provided — using synthetic test tones")

    # Normalize every clip once (mono, exactly sample_len samples) rather than
    # repeating the same crop/pad on every training step.
    clips = []
    for w in audio_data:
        if w.dim() == 1:
            w = w.unsqueeze(0)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)
        if w.shape[1] >= sample_len:
            w = w[:, :sample_len]
        else:
            w = torch.nn.functional.pad(w, (0, sample_len - w.shape[1]))
        clips.append(w)

    model = ARBModel(enable_image=False, enable_audio=True,
                     enable_vq=False, enable_graph=False,
                     enable_memory_modules=False, enable_moe=False,
                     max_moe_iters=4,
                     enable_attention=False,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=True).cuda()
    freeze_core(model)
    freeze_float_parameters(model)
    vq_encoder = AudioVQEncoder().cuda()
    # The VQ encoder only produces targets. Putting it in eval mode keeps any
    # train-mode-only behaviour (e.g. dropout, if it has any) out of the labels.
    vq_encoder.eval()
    print(format_audit(audit_model(model)))

    if trainable_parameters(model):
        raise RuntimeError("Audio trainer is pure ternary; use training/finetuning/audio.py for LoRA adapters.")
    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    try:
        for step in range(args.steps):
            # Consecutive clips starting at `step`, wrapping around the list.
            batch = [clips[i % len(clips)] for i in range(step, step + args.batch)]
            wavs = torch.stack(batch).cuda()

            model.zero_grad(set_to_none=True)

            with torch.no_grad():
                _, target_tokens = vq_encoder(wavs)

            rel = model.audio_sequencer(wavs)
            pred_logits = model.talker_head.token_logits(rel, max_frames=target_tokens.shape[1])

            loss = torch.nn.functional.cross_entropy(
                pred_logits.reshape(-1, pred_logits.size(-1)),
                target_tokens.reshape(-1),
            )
            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
            model.zero_grad(set_to_none=True)

            # Also log the final step so short runs always record their last loss.
            if step % 100 == 0 or step == args.steps - 1:
                loss_value = loss.item()
                writer.add_scalar("loss/audio", loss_value, step)
                print(f"step {step:>5d}  loss={loss_value:.3f}")
    finally:
        # Flush pending TensorBoard events even if training fails or is interrupted.
        writer.close()