ARBS / training /finetuning /audio.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Fine-tune ARB model on audio/speech tasks using LoRA.
Freezes text pipeline, adapts audio encoder + core MoE.
Designed for 8GB VRAM with batch_size=1.
Usage:
python training/finetuning/audio.py \\
--audio-dir ./speech-data \\
--steps 2000 --batch 1 --accum 4 --lr 1e-4 \\
--lora-rank 16 --run audio-finetune
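Data format: directory of .wav files + transcripts.txt
transcripts.txt: each line is "filename.wav|transcript text"
If transcripts.txt is missing, up to 500 raw .wav/.mp3 files from the
directory are used instead, without text targets.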
"""
import os, sys, time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter
def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1):
    """Construct the audio-enabled ARB model on the GPU and attach LoRA adapters.

    Args:
        lora_rank: Rank forwarded to ``apply_lora_to_model``.
        lora_alpha: Alpha forwarded to ``apply_lora_to_model``.
        max_moe_iters: Forwarded to the ``ARBModel`` constructor.

    Returns:
        ``(model, lora_layers)`` where ``lora_layers`` is whatever
        ``apply_lora_to_model`` returns for the targeted modules.

    NOTE(review): freezing of the non-LoRA weights is presumably done inside
    ``apply_lora_to_model``; nothing in this function sets ``requires_grad``.
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    # Feature switches for this run: audio on, image and memory modules off.
    feature_flags = {
        "enable_image": False,
        "enable_audio": True,
        "enable_vq": True,
        "enable_graph": True,
        "enable_memory_modules": False,
        "enable_moe": True,
        "max_moe_iters": max_moe_iters,
    }
    net = ARBModel(**feature_flags).cuda()

    # Names of the sub-modules that receive LoRA adapters.
    adapter_targets = [
        'W_gate', 'W_transform', 'byte_head', 'head', 'router',
        'shared_up', 'shared_expert_gate', 'shared_expert_up',
        'frame_proj', 'audio_sequencer',
    ]
    adapters = apply_lora_to_model(
        net,
        rank=lora_rank,
        alpha=lora_alpha,
        target_modules=adapter_targets,
    )

    # Only the LoRA count is reported; the total is not used here.
    n_lora, _ = count_lora_params(net)
    print(f" LoRA trainable: {n_lora:,} params ({n_lora/1e6:.2f}M)", flush=True)
    return net, adapters
def load_audio_data(audio_dir, sr=16000):
    """Load audio files and their transcripts from a directory.

    Expects ``transcripts.txt`` inside ``audio_dir`` with lines like:
        sample1.wav|Hello world this is a test
        sample2.wav|Another example transcript

    Lines that are blank or contain no ``|`` are ignored, and entries whose
    audio file does not exist are skipped (the number skipped is reported).
    If ``transcripts.txt`` itself is missing, the function falls back to
    ``_load_raw_audio`` and returns audio without text targets.

    Args:
        audio_dir: Directory holding the audio files and ``transcripts.txt``.
        sr: Target sample rate; audio at a different rate is resampled.

    Returns:
        A list of ``(wav, text)`` tuples. ``wav`` is a 2-D tensor (channel
        dimension first) down-mixed to a single channel and truncated to
        5 seconds at ``sr``. ``text`` is a 1-D ``torch.long`` tensor of
        ``BOS`` + UTF-8 bytes of the transcript + ``EOS``, padded with
        ``PAD`` to at least 4 tokens.
    """
    from arbitor.config import SPECIAL_VOCAB
    import torchaudio

    trans_path = os.path.join(audio_dir, "transcripts.txt")
    if not os.path.isfile(trans_path):
        print(f" No transcripts.txt found in {audio_dir}", flush=True)
        print(" Using raw audio only (no text targets)", flush=True)
        return _load_raw_audio(audio_dir, sr)

    max_samples = sr * 5  # max 5 seconds
    min_tokens = 4  # the training loop slices targets from position 3
    data = []
    missing = 0

    # Read as UTF-8 explicitly: the transcript is re-encoded to UTF-8 bytes
    # below, so relying on the platform's default encoding could corrupt
    # non-ASCII text. "utf-8-sig" additionally drops a leading BOM, which
    # would otherwise become part of the first filename.
    with open(trans_path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line or '|' not in line:
                continue
            wav_name, transcript = line.split("|", 1)
            # Tolerate whitespace around the filename ("a.wav | text").
            wav_path = os.path.join(audio_dir, wav_name.strip())
            if not os.path.isfile(wav_path):
                missing += 1
                continue

            wav, sample_rate = torchaudio.load(wav_path)
            if sample_rate != sr:
                resample = torchaudio.transforms.Resample(sample_rate, sr)
                wav = resample(wav)
            if wav.shape[0] > 1:
                # Down-mix multi-channel audio to mono.
                wav = wav.mean(dim=0, keepdim=True)
            wav = wav[:, :max_samples]

            # Tokenize transcript as raw UTF-8 bytes between BOS and EOS.
            tokens = [SPECIAL_VOCAB['BOS'], *transcript.encode('utf-8'),
                      SPECIAL_VOCAB['EOS']]
            tokens.extend([SPECIAL_VOCAB['PAD']] * (min_tokens - len(tokens)))
            text = torch.tensor(tokens, dtype=torch.long)
            data.append((wav, text))

    if missing:
        print(f" Skipped {missing} transcript entries with missing audio files",
              flush=True)
    print(f" Loaded {len(data)} audio-transcript pairs from {audio_dir}", flush=True)
    return data
def _load_raw_audio(audio_dir, sr):
    """Fallback: load raw audio without transcripts for self-supervised fine-tuning.

    Collects ``*.wav`` then ``*.mp3`` files directly inside ``audio_dir``
    (each group in sorted order) and loads at most the first 500 of them.

    Args:
        audio_dir: Directory to scan for audio files.
        sr: Target sample rate; audio at a different rate is resampled.

    Returns:
        A list of ``(wav, None)`` tuples, where ``wav`` is down-mixed to a
        single channel and truncated to 5 seconds at ``sr``.
    """
    import glob
    import torchaudio

    max_files = 500
    max_samples = sr * 5  # max 5 seconds

    # glob returns entries in arbitrary (filesystem-dependent) order. Sorting
    # makes both the file cap and the caller's positional train/val split
    # reproducible between runs and machines. .wav files keep priority.
    files = sorted(glob.glob(os.path.join(audio_dir, "*.wav"))) + \
        sorted(glob.glob(os.path.join(audio_dir, "*.mp3")))
    if len(files) > max_files:
        print(f" Found {len(files)} audio files, using the first {max_files}",
              flush=True)

    data = []
    for path in files[:max_files]:
        wav, sample_rate = torchaudio.load(path)
        if sample_rate != sr:
            resample = torchaudio.transforms.Resample(sample_rate, sr)
            wav = resample(wav)
        if wav.shape[0] > 1:
            # Down-mix multi-channel audio to mono.
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav[:, :max_samples]
        data.append((wav, None))

    print(f" Loaded {len(data)} raw audio files (no transcripts)", flush=True)
    return data
def _sample_loss(model, audio_target_encoder, wav, text):
    """Compute the scalar loss for a single ``(wav, text)`` sample.

    With a transcript, the model is run on text + audio and its own
    ``losses.total`` is returned. Without one, the tokens produced by
    ``audio_target_encoder`` serve as targets for a cross-entropy loss on
    the logits from ``model.talker_head.token_logits``.

    Args:
        model: The LoRA-adapted model being trained.
        audio_target_encoder: Encoder used (without gradients) to produce
            target tokens when no transcript is available.
        wav: Waveform tensor on the CPU; moved to the GPU here.
        text: 1-D token tensor, or ``None`` for raw-audio samples.

    Returns:
        A scalar loss tensor.
    """
    wav = wav.cuda()
    if text is not None:
        text = text.cuda().unsqueeze(0)
        _, losses, _, _ = model(x=text, audio=wav, targets=text[:, 3:])
        return losses.total

    # The target encoder is fixed, so its tokens are computed without grads;
    # the prediction path below must stay outside no_grad for training.
    with torch.no_grad():
        _, target_tokens = audio_target_encoder(
            wav.unsqueeze(0) if wav.dim() == 2 else wav
        )
    rel = model.audio_sequencer(wav)
    pred_logits = model.talker_head.token_logits(
        rel, max_frames=target_tokens.shape[1]
    )
    return torch.nn.functional.cross_entropy(
        pred_logits.reshape(-1, pred_logits.shape[-1]),
        target_tokens.reshape(-1),
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="ARB audio fine-tuning")
    parser.add_argument("--audio-dir", type=str, required=True, help="Dir with .wav files + transcripts.txt")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="audio-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--save-every", type=int, default=500)
    args = parser.parse_args()
    if args.batch < 1 or args.accum < 1:
        parser.error("--batch and --accum must be >= 1")

    from training.finetuning.lora import save_lora

    print("Building model with audio + LoRA...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)
    from arbitor.encoders.audio import AudioVQEncoder
    audio_target_encoder = AudioVQEncoder().cuda().eval()

    # Only parameters left trainable (requires_grad) are optimized.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    data = load_audio_data(args.audio_dir)
    if len(data) == 0:
        print("No audio data found! Use --audio-dir with .wav files.", flush=True)
        sys.exit(1)

    # 80/20 positional split, keeping at least one sample on each side when
    # there is more than one sample; a single sample is used for both.
    n = int(0.8 * len(data))
    if len(data) > 1:
        n = min(max(1, n), len(data) - 1)
    train_data = data[:n] if n > 0 else data
    val_data = data[n:] if n < len(data) else data[:1]

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)

    step = 0
    best_val = float('inf')
    # Each optimizer step averages over accum micro-batches of `batch` samples.
    samples_per_step = args.accum * args.batch
    model.train()
    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0
            for _ in range(args.accum):
                # .tolist() rather than .item(): .item() raises for a tensor
                # with more than one element, i.e. for any --batch > 1.
                # Samples have different lengths, so they are run one by one.
                picks = torch.randint(0, len(train_data), (args.batch,)).tolist()
                for idx in picks:
                    wav, text = train_data[idx]
                    loss = _sample_loss(model, audio_target_encoder, wav, text)
                    loss = loss / samples_per_step
                    loss.backward()
                    # Accumulate the scaled loss so the logged train value is
                    # a per-sample mean, on the same scale as the eval loss.
                    accum_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            # Periodic checkpoint; --save-every was previously parsed but unused.
            if args.save_every > 0 and step % args.save_every == 0:
                save_lora(lora_layers, f"{run_dir}/step_{step}_lora.pt")

            if args.eval_interval > 0 and step % args.eval_interval == 0:
                model.eval()
                eval_samples = val_data[:10]  # evaluate on at most 10 samples
                val_loss = 0.0
                with torch.no_grad():
                    for wav, text in eval_samples:
                        if text is not None:
                            # Evaluation uses at most the first 16 tokens.
                            text = text[:max(4, min(text.shape[0], 16))]
                        val_loss += _sample_loss(
                            model, audio_target_encoder, wav, text
                        ).item()
                val_loss /= len(eval_samples)

                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)
                if 0 < val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")
                print(f"step {step:>5d}/{args.steps} train={accum_loss:.3f} "
                      f"eval={val_loss:.3f} best={best_val:.3f}", flush=True)
                model.train()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA saved to {run_dir}/", flush=True)