"""Fine-tune ARB model on audio/speech tasks using LoRA. Freezes text pipeline, adapts audio encoder + core MoE. Designed for 8GB VRAM with batch_size=1. Usage: python training/finetuning/audio.py \\ --audio-dir ./speech-data \\ --steps 2000 --batch 1 --accum 4 --lr 1e-4 \\ --lora-rank 16 --run audio-finetune Data format: directory of .wav files + transcripts.txt transcripts.txt: each line is "filename.wav|transcript text" """ import os, sys, time sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) import torch from torch.utils.tensorboard import SummaryWriter def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1): """Build ARB model with audio + LoRA, freeze text parts.""" from arbitor import ARBModel from training.finetuning.lora import apply_lora_to_model, count_lora_params model = ARBModel( enable_image=False, enable_audio=True, enable_vq=True, enable_graph=True, enable_memory_modules=False, enable_moe=True, max_moe_iters=max_moe_iters, ).cuda() target_modules = ['W_gate', 'W_transform', 'byte_head', 'head', 'router', 'shared_up', 'shared_expert_gate', 'shared_expert_up', 'frame_proj', 'audio_sequencer'] lora_layers = apply_lora_to_model(model, rank=lora_rank, alpha=lora_alpha, target_modules=target_modules) lora_p, total_p = count_lora_params(model) print(f" LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True) return model, lora_layers def load_audio_data(audio_dir, sr=16000): """Load audio files and transcripts from directory. Expects transcripts.txt with lines like: sample1.wav|Hello world this is a test sample2.wav|Another example transcript """ from arbitor.config import SPECIAL_VOCAB import torchaudio trans_path = os.path.join(audio_dir, "transcripts.txt") if not os.path.isfile(trans_path): print(f" No transcripts.txt found in {audio_dir}", flush=True) print(f" Using raw audio only (no text targets)", flush=True) return _load_raw_audio(audio_dir, sr) data = [] with open(trans_path, "r") as f: for line in f: line = line.strip() if not line or '|' not in line: continue wav_name, transcript = line.split("|", 1) wav_path = os.path.join(audio_dir, wav_name) if not os.path.isfile(wav_path): continue wav, sample_rate = torchaudio.load(wav_path) if sample_rate != sr: resample = torchaudio.transforms.Resample(sample_rate, sr) wav = resample(wav) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) wav = wav[:, :sr * 5] # max 5 seconds # Tokenize transcript tokens = [SPECIAL_VOCAB['BOS']] for byte in transcript.encode('utf-8'): tokens.append(byte) tokens.append(SPECIAL_VOCAB['EOS']) while len(tokens) < 4: tokens.append(SPECIAL_VOCAB['PAD']) text = torch.tensor(tokens, dtype=torch.long) data.append((wav, text)) print(f" Loaded {len(data)} audio-transcript pairs from {audio_dir}", flush=True) return data def _load_raw_audio(audio_dir, sr): """Fallback: load raw audio without transcripts for self-supervised fine-tuning.""" import glob, torchaudio files = glob.glob(os.path.join(audio_dir, "*.wav")) + \ glob.glob(os.path.join(audio_dir, "*.mp3")) data = [] for f in files[:500]: wav, sample_rate = torchaudio.load(f) if sample_rate != sr: resample = torchaudio.transforms.Resample(sample_rate, sr) wav = resample(wav) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) wav = wav[:, :sr * 5] data.append((wav, None)) print(f" Loaded {len(data)} raw audio files (no transcripts)", flush=True) return data if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="ARB audio fine-tuning") parser.add_argument("--audio-dir", type=str, required=True, help="Dir with .wav files + transcripts.txt") parser.add_argument("--steps", type=int, default=2000) parser.add_argument("--batch", type=int, default=1) parser.add_argument("--accum", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--lora-rank", type=int, default=16) parser.add_argument("--lora-alpha", type=float, default=32.0) parser.add_argument("--max-moe-iters", type=int, default=1) parser.add_argument("--run", type=str, default="audio-finetune") parser.add_argument("--eval-interval", type=int, default=100) parser.add_argument("--save-every", type=int, default=500) args = parser.parse_args() print("Building model with audio + LoRA...", flush=True) model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters) from arbitor.encoders.audio import AudioVQEncoder audio_target_encoder = AudioVQEncoder().cuda().eval() opt = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=0.01 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps) data = load_audio_data(args.audio_dir) if len(data) == 0: print("No audio data found! Use --audio-dir with .wav files.", flush=True) sys.exit(1) n = int(0.8 * len(data)) if len(data) > 1: n = min(max(1, n), len(data) - 1) train_data = data[:n] if n > 0 else data val_data = data[n:] if n < len(data) else data[:1] run_dir = f"models/checkpoints/{args.run}" os.makedirs(run_dir, exist_ok=True) writer = SummaryWriter(run_dir) step = 0 best_val = float('inf') model.train() while step < args.steps: opt.zero_grad() accum_loss = 0.0 for _ in range(args.accum): idx = torch.randint(0, len(train_data), (args.batch,)).item() wav, text = train_data[idx] wav = wav.cuda() if text is not None: text = text.cuda().unsqueeze(0) _, losses, _, _ = model(x=text, audio=wav, targets=text[:, 3:]) loss_val = losses.total else: with torch.no_grad(): _, target_tokens = audio_target_encoder(wav.unsqueeze(0) if wav.dim() == 2 else wav) rel = model.audio_sequencer(wav) pred_logits = model.talker_head.token_logits(rel, max_frames=target_tokens.shape[1]) loss_val = torch.nn.functional.cross_entropy( pred_logits.reshape(-1, pred_logits.shape[-1]), target_tokens.reshape(-1), ) loss = loss_val / args.accum loss.backward() accum_loss += loss_val.item() torch.nn.utils.clip_grad_norm_( [p for p in model.parameters() if p.requires_grad], 1.0 ) opt.step() scheduler.step() step += 1 if step % args.eval_interval == 0: model.eval() val_loss = 0.0 with torch.no_grad(): for idx in range(min(10, len(val_data))): wav, text = val_data[idx] wav = wav.cuda() if text is not None: text = text.cuda().unsqueeze(0) txt_ctx = text[:, :max(4, min(text.shape[1], 16))] _, lv, _, _ = model(x=txt_ctx, audio=wav, targets=txt_ctx[:, 3:]) val_loss += lv.total.item() else: with torch.no_grad(): _, target_tokens = audio_target_encoder(wav.unsqueeze(0) if wav.dim() == 2 else wav) rel = model.audio_sequencer(wav) pred_logits = model.talker_head.token_logits(rel, max_frames=target_tokens.shape[1]) val_loss += torch.nn.functional.cross_entropy( pred_logits.reshape(-1, pred_logits.shape[-1]), target_tokens.reshape(-1), ).item() val_loss /= min(10, len(val_data)) if val_loss > 0 else 1 writer.add_scalar("loss/train", accum_loss, step) writer.add_scalar("loss/eval", val_loss, step) if val_loss < best_val and val_loss > 0: best_val = val_loss from training.finetuning.lora import save_lora save_lora(lora_layers, f"{run_dir}/best_lora.pt") print(f"step {step:>5d}/{args.steps} train={accum_loss:.3f} " f"eval={val_loss:.3f} best={best_val:.3f}", flush=True) model.train() from training.finetuning.lora import save_lora save_lora(lora_layers, f"{run_dir}/final_lora.pt") print(f"Done. LoRA saved to {run_dir}/", flush=True)