"""Audio/speech training.

Freezes the text pipeline and trains TalkerHead + OutputRouter.
Uses AudioVQEncoder to prepare training targets from audio files.
"""
import argparse
import glob
import math
import os
import sys
import time

import torch
from torch.utils.tensorboard import SummaryWriter

# Put this file's parent directory on the import path so `arbitor` resolves
# when the script is launched directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from arbitor import ARBModel, AUDIO_SR
from arbitor.kernel.ternary_scale import TScaleType
from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters
from arbitor.encoders.audio import AudioVQEncoder


def freeze_core(model, trainable_keys=("talker_head", "output_router", "video_head")):
    """Freeze every parameter except those of the selected output heads.

    Each parameter yielded by ``model.named_parameters()`` has its
    ``requires_grad`` flag set to ``True`` when its qualified name contains
    at least one entry of ``trainable_keys`` and to ``False`` otherwise.
    With the default keys this leaves only the ``talker_head``,
    ``output_router`` and ``video_head`` parameters trainable; everything
    else (the text pipeline, embedding through MoE/ByteHead) is frozen.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        trainable_keys: Name fragments identifying parameters that stay
            trainable. A single string is treated as one fragment rather
            than being iterated character by character.

    Returns:
        None. The parameters are modified in place.
    """
    if isinstance(trainable_keys, str):
        trainable_keys = (trainable_keys,)
    # Materialise once so a one-shot iterable is not exhausted after the
    # first parameter.
    keys = tuple(trainable_keys)

    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keys)


def load_audio_data(source, sample_rate=AUDIO_SR):
    """Read an audio file and return it as a single-channel waveform tensor.

    Args:
        source: Path (or other input accepted by ``torchaudio.load``) of the
            audio file to read.
        sample_rate: Target sample rate. Audio stored at a different rate is
            resampled to this rate.

    Returns:
        The waveform tensor. When the loaded tensor has more than one entry
        along its first dimension (treated here as the channel axis), the
        channels are averaged into one, keeping that dimension.
    """
    # Kept local so importing this module does not require torchaudio.
    import torchaudio

    waveform, native_rate = torchaudio.load(source)

    if native_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
        waveform = resampler(waveform)

    # Down-mix multi-channel audio by averaging over the channel axis.
    channel_count = waveform.shape[0]
    if channel_count > 1:
        return waveform.mean(dim=0, keepdim=True)
    return waveform


def fit_waveform(wav, sample_len):
    """Coerce one waveform to a mono tensor of shape ``(1, sample_len)``.

    Args:
        wav: 1-D tensor of samples, or 2-D tensor whose first dimension is
            treated as the channel axis and whose second is time.
        sample_len: Number of samples to keep.

    Returns:
        A ``(1, sample_len)`` tensor: multi-channel input is averaged to one
        channel, longer input is truncated to its first ``sample_len``
        samples, shorter input is zero-padded on the right.
    """
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if wav.shape[1] >= sample_len:
        return wav[:, :sample_len]
    return torch.nn.functional.pad(wav, (0, sample_len - wav.shape[1]))


def list_wav_files(audio_dir):
    """Return the ``*.wav`` paths directly inside ``audio_dir``, sorted.

    Sorting makes the selection reproducible, because ``glob`` returns
    matches in arbitrary order.

    Raises:
        FileNotFoundError: If the directory contains no ``*.wav`` file.
            Without this check the training loop would fail later with a
            ``ZeroDivisionError`` when cycling through an empty list.
    """
    files = sorted(glob.glob(os.path.join(audio_dir, "*.wav")))
    if not files:
        raise FileNotFoundError(f"No .wav files found in {audio_dir!r}")
    return files


def main():
    """Parse the command line and run the audio training loop.

    Raises:
        ValueError: If the ``tilelang`` backend is requested without the
            ``ARB_TILELANG_TRAINING`` opt-in environment variable.
        FileNotFoundError: If ``--audio-dir`` holds no ``*.wav`` files.
        RuntimeError: If trainable parameters remain after freezing.
    """
    parser = argparse.ArgumentParser(description="ARB audio training")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--data", type=str, default=None, help="Path or HF dataset")
    parser.add_argument("--audio-dir", type=str, default=None, help="Dir of .wav files")
    parser.add_argument("--run", type=str, default="audio")
    parser.add_argument("--ctx", type=int, default=AUDIO_SR, help="Audio samples per example")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton")
    args = parser.parse_args()

    # An empty batch cannot be stacked; reject it before any model is built.
    if args.batch < 1:
        parser.error("--batch must be at least 1")

    # The chosen backend is published through the environment.
    # NOTE(review): presumably read lazily by the ternary kernels, since it is
    # set after `arbitor` has been imported; confirm.
    os.environ["ARB_TERNARY_BACKEND"] = args.backend
    tilelang_opt_in = os.environ.get("ARB_TILELANG_TRAINING", "0").lower() in {"1", "true", "yes"}
    if args.backend == "tilelang" and not tilelang_opt_in:
        raise ValueError("TileLang BigInt training is unfinished. Use --backend triton for training.")

    # Every example is at least AUDIO_SR samples long, even for smaller --ctx.
    sample_len = max(args.ctx, AUDIO_SR)

    if args.data is not None:
        # --data is accepted but never read by this script; say so instead of
        # silently training on something else.
        print("warning: --data is ignored by this trainer; use --audio-dir to train on audio files")

    if args.audio_dir:
        files = list_wav_files(args.audio_dir)
        print(f"Found {len(files)} audio files")
        # Only the first 100 files (in sorted order) are loaded.
        audio_data = [load_audio_data(f) for f in files[:100]]
    else:
        # NOTE(review): the sweep covers 440 cycles over `sample_len` samples,
        # so the tone is 440 Hz only when sample_len == AUDIO_SR.
        audio_data = [torch.sin(torch.linspace(0, 440*2*math.pi, sample_len)).unsqueeze(0)]
        print("No audio data provided — using synthetic test tones")

    # Normalising to mono / fixed length does not depend on the step, so do
    # it once here rather than on every iteration.
    clips = [fit_waveform(w, sample_len) for w in audio_data]

    model = ARBModel(enable_image=False, enable_audio=True,
                     enable_vq=False, enable_graph=False,
                     enable_memory_modules=False, enable_moe=False,
                     max_moe_iters=4,
                     enable_attention=False,
                     enable_output_router=False,
                     enable_video_output=False,
                     enable_talker_output=True).cuda()
    freeze_core(model)
    # NOTE(review): defined in arbitor.kernel.ternary_audit; presumably
    # freezes floating-point parameters. Confirm there.
    freeze_float_parameters(model)
    vq_encoder = AudioVQEncoder().cuda()
    print(format_audit(audit_model(model)))

    # This trainer uses no optimizer: updates go through the model's own
    # ternary update call below, so no trainable parameter may remain.
    if trainable_parameters(model):
        raise RuntimeError("Audio trainer is pure ternary; use training/finetuning/audio.py for LoRA adapters.")

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    try:
        for step in range(args.steps):
            # Cycle through the clips so any batch size works with any number
            # of clips.
            picked = [clips[i % len(clips)] for i in range(step, step + args.batch)]
            wavs = torch.stack(picked).cuda()

            model.zero_grad(set_to_none=True)

            # Targets come from the VQ encoder and carry no gradient.
            with torch.no_grad():
                _, target_tokens = vq_encoder(wavs)

            rel = model.audio_sequencer(wavs)
            pred_logits = model.talker_head.token_logits(rel, max_frames=target_tokens.shape[1])

            loss = torch.nn.functional.cross_entropy(
                pred_logits.reshape(-1, pred_logits.size(-1)),
                target_tokens.reshape(-1),
            )
            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=loss)
            model.zero_grad(set_to_none=True)

            if step % 100 == 0:
                # Read the scalar once; each .item() forces a device sync.
                loss_value = loss.item()
                writer.add_scalar("loss/audio", loss_value, step)
                print(f"step {step:>5d} loss={loss_value:.3f}")
    finally:
        # Flush pending events and release the event file even if training
        # is interrupted.
        writer.close()


if __name__ == "__main__":
    main()