| """Fine-tune ARB model on audio/speech tasks using LoRA. |
| |
| Freezes text pipeline, adapts audio encoder + core MoE. |
| Designed for 8GB VRAM with batch_size=1. |
| |
| Usage: |
| python training/finetuning/audio.py \\ |
| --audio-dir ./speech-data \\ |
| --steps 2000 --batch 1 --accum 4 --lr 1e-4 \\ |
| --lora-rank 16 --run audio-finetune |
| |
| Data format: directory of .wav files + transcripts.txt |
| transcripts.txt: each line is "filename.wav|transcript text" |
| """ |
| import os, sys, time |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| import torch |
| from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Create the audio-enabled ARB model on the GPU and attach LoRA adapters.

    The model is built with the audio, VQ, graph and MoE features switched on
    and the image and memory-module features switched off, then moved to CUDA.

    Args:
        lora_rank: Rank forwarded to ``apply_lora_to_model``.
        lora_alpha: Alpha (scaling) value forwarded to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to ``ARBModel`` as ``max_moe_iters``.

    Returns:
        A ``(model, lora_layers)`` tuple, where ``lora_layers`` is whatever
        ``apply_lora_to_model`` returns.
    """
    # Project imports stay local so importing this module has no heavy deps.
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    feature_flags = dict(
        enable_image=False,
        enable_audio=True,
        enable_vq=True,
        enable_graph=True,
        enable_memory_modules=False,
        enable_moe=True,
        max_moe_iters=max_moe_iters,
    )
    model = ARBModel(**feature_flags).cuda()

    # Names of the sub-modules that receive adapters.
    # NOTE(review): freezing of the remaining weights is presumably handled
    # inside apply_lora_to_model -- nothing here freezes them; confirm.
    core_targets = ['W_gate', 'W_transform', 'byte_head', 'head', 'router']
    expert_targets = ['shared_up', 'shared_expert_gate', 'shared_expert_up']
    audio_targets = ['frame_proj', 'audio_sequencer']
    target_modules = core_targets + expert_targets + audio_targets

    lora_layers = apply_lora_to_model(
        model,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=target_modules,
    )

    # Only the adapter parameter count is reported; the total is not used.
    lora_p, _total = count_lora_params(model)
    print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers
|
|
|
|
def load_audio_data(audio_dir, sr=16000):
    """Load (waveform, token tensor) pairs described by ``transcripts.txt``.

    ``transcripts.txt`` must live in ``audio_dir`` and contain one entry per
    line in the form::

        sample1.wav|Hello world this is a test
        sample2.wav|Another example transcript

    Each referenced file is loaded with ``torchaudio``, resampled to ``sr`` if
    needed, mixed down to one channel, and truncated to its first 5 seconds.
    The transcript becomes ``BOS + UTF-8 bytes + EOS``, right-padded with
    ``PAD`` to at least 4 tokens.

    If ``transcripts.txt`` does not exist, the function falls back to
    ``_load_raw_audio`` and the returned pairs carry ``None`` as the text.

    Args:
        audio_dir: Directory holding the audio files and ``transcripts.txt``.
        sr: Target sample rate in Hz.

    Returns:
        A list of ``(wav, text)`` tuples; ``text`` is a 1-D ``torch.long``
        tensor (or ``None`` in the raw-audio fallback).
    """
    from arbitor.config import SPECIAL_VOCAB
    import torchaudio

    trans_path = os.path.join(audio_dir, "transcripts.txt")
    if not os.path.isfile(trans_path):
        print(f" No transcripts.txt found in {audio_dir}", flush=True)
        print(f" Using raw audio only (no text targets)", flush=True)
        return _load_raw_audio(audio_dir, sr)

    max_samples = sr * 5  # keep only the first 5 seconds of every clip
    min_tokens = 4  # the training loop slices targets from index 3 onwards
    bos, eos, pad = (SPECIAL_VOCAB[name] for name in ('BOS', 'EOS', 'PAD'))

    resamplers = {}  # one Resample transform per distinct source sample rate
    skipped = 0
    data = []

    # Explicit encoding: transcripts are re-encoded as UTF-8 bytes below, so
    # decoding with the platform default would corrupt non-ASCII text.
    # "utf-8-sig" also tolerates a leading BOM, which would otherwise become
    # part of the first filename.
    with open(trans_path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            wav_name, sep, transcript = line.partition("|")
            if not sep:
                skipped += 1  # malformed line: no "|" separator
                continue
            # Tolerate "name.wav | text" style spacing around the filename.
            wav_path = os.path.join(audio_dir, wav_name.strip())
            if not os.path.isfile(wav_path):
                skipped += 1  # referenced audio file is missing
                continue

            wav, sample_rate = torchaudio.load(wav_path)
            if sample_rate != sr:
                if sample_rate not in resamplers:
                    resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, sr)
                wav = resamplers[sample_rate](wav)
            if wav.shape[0] > 1:
                # Mix multi-channel audio down to one channel, keeping dim 0.
                wav = wav.mean(dim=0, keepdim=True)
            wav = wav[:, :max_samples]

            tokens = [bos, *transcript.encode('utf-8'), eos]
            tokens.extend([pad] * (min_tokens - len(tokens)))
            text = torch.tensor(tokens, dtype=torch.long)

            data.append((wav, text))

    if skipped:
        print(f" Skipped {skipped} transcript lines (malformed or missing audio file)", flush=True)
    print(f" Loaded {len(data)} audio-transcript pairs from {audio_dir}", flush=True)
    return data
|
|
|
|
def _load_raw_audio(audio_dir, sr, max_files=500):
    """Fallback loader: raw audio without transcripts.

    Collects ``*.wav`` then ``*.mp3`` files directly inside ``audio_dir``,
    loads at most ``max_files`` of them, resamples each to ``sr`` when needed,
    mixes it down to one channel and truncates it to its first 5 seconds.

    Args:
        audio_dir: Directory to scan (non-recursive).
        sr: Target sample rate in Hz.
        max_files: Upper bound on the number of files loaded.

    Returns:
        A list of ``(wav, None)`` tuples; ``None`` marks the missing transcript.
    """
    import glob
    import torchaudio

    # glob gives no ordering guarantee, so sort each group: the capped subset
    # (and any later positional train/val split) is then reproducible.
    files = sorted(glob.glob(os.path.join(audio_dir, "*.wav"))) + \
        sorted(glob.glob(os.path.join(audio_dir, "*.mp3")))

    max_samples = sr * 5  # keep only the first 5 seconds of every clip
    resamplers = {}  # one Resample transform per distinct source sample rate
    data = []
    for path in files[:max_files]:
        wav, sample_rate = torchaudio.load(path)
        if sample_rate != sr:
            if sample_rate not in resamplers:
                resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, sr)
            wav = resamplers[sample_rate](wav)
        if wav.shape[0] > 1:
            # Mix multi-channel audio down to one channel, keeping dim 0.
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav[:, :max_samples]
        data.append((wav, None))
    print(f" Loaded {len(data)} raw audio files (no transcripts)", flush=True)
    return data
|
|
|
|
def _example_loss(model, audio_target_encoder, wav, text, max_text_len=None):
    """Compute the loss tensor for a single (wav, text) example.

    Args:
        model: The LoRA-adapted model returned by ``load_model``.
        audio_target_encoder: Encoder whose second output supplies the target
            tokens for examples that have no transcript.
        wav: Waveform tensor as produced by the data loaders (CPU).
        text: 1-D token tensor, or ``None`` for transcript-less examples.
        max_text_len: If given, the token sequence is cut to at most this many
            tokens (but never fewer than 4) before the forward pass.

    Returns:
        A scalar loss tensor. Gradients flow through it unless the caller
        wraps the call in ``torch.no_grad()``.
    """
    wav = wav.cuda()
    if text is not None:
        tokens = text.cuda().unsqueeze(0)
        if max_text_len is not None:
            tokens = tokens[:, :max(4, min(tokens.shape[1], max_text_len))]
        # The model returns a 4-tuple; its second element exposes `.total`.
        _, losses, _, _ = model(x=tokens, audio=wav, targets=tokens[:, 3:])
        return losses.total

    # No transcript: predict the token ids emitted by the target encoder.
    # Only the target computation is excluded from autograd.
    with torch.no_grad():
        _, target_tokens = audio_target_encoder(wav.unsqueeze(0) if wav.dim() == 2 else wav)
    rel = model.audio_sequencer(wav)
    pred_logits = model.talker_head.token_logits(rel, max_frames=target_tokens.shape[1])
    return torch.nn.functional.cross_entropy(
        pred_logits.reshape(-1, pred_logits.shape[-1]),
        target_tokens.reshape(-1),
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="ARB audio fine-tuning")
    parser.add_argument("--audio-dir", type=str, required=True, help="Dir with .wav files + transcripts.txt")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="audio-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--save-every", type=int, default=500)
    args = parser.parse_args()
    if args.batch < 1 or args.accum < 1:
        parser.error("--batch and --accum must both be >= 1")

    # Imported after argument parsing so `--help` works without the project.
    from training.finetuning.lora import save_lora

    print("Building model with audio + LoRA...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)
    from arbitor.encoders.audio import AudioVQEncoder
    audio_target_encoder = AudioVQEncoder().cuda().eval()

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    data = load_audio_data(args.audio_dir)
    if not data:
        print("No audio data found! Use --audio-dir with .wav files.", flush=True)
        sys.exit(1)

    # Positional 80/20 split; with two or more examples both sides are
    # non-empty, with exactly one example it is used for both.
    n = int(0.8 * len(data))
    if len(data) > 1:
        n = min(max(1, n), len(data) - 1)
    train_data = data[:n] if n > 0 else data
    val_data = data[n:] if n < len(data) else data[:1]

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best_val = float('inf')
    # Every optimizer step averages over accum * batch individually forwarded
    # examples (clips are processed one at a time, as before).
    micro_steps = args.accum * args.batch
    model.train()

    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0

            for _ in range(args.accum):
                # .tolist() instead of .item(): .item() only works for a
                # single-element tensor and crashed for any --batch > 1.
                for idx in torch.randint(0, len(train_data), (args.batch,)).tolist():
                    wav, text = train_data[idx]
                    loss = _example_loss(model, audio_target_encoder, wav, text) / micro_steps
                    loss.backward()
                    # Accumulate the scaled loss so the logged value is the
                    # mean per example, comparable to the eval loss.
                    accum_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if args.eval_interval > 0 and step % args.eval_interval == 0:
                model.eval()
                n_val = min(10, len(val_data))
                val_loss = 0.0
                with torch.no_grad():
                    for wav, text in val_data[:n_val]:
                        val_loss += _example_loss(
                            model, audio_target_encoder, wav, text, max_text_len=16
                        ).item()
                val_loss /= n_val

                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)

                # NOTE(review): the `> 0` guard is kept from the original; it
                # means a non-positive eval loss never updates the best
                # checkpoint -- confirm that is intended.
                if val_loss < best_val and val_loss > 0:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps} train={accum_loss:.3f} "
                      f"eval={val_loss:.3f} best={best_val:.3f}", flush=True)
                model.train()

            # Honour --save-every, which was previously parsed but ignored.
            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/step{step}_lora.pt")

        save_lora(lora_layers, f"{run_dir}/final_lora.pt")
        print(f"Done. LoRA saved to {run_dir}/", flush=True)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
|
|