"""Audio encoding and iterative unmasking inference.

Adapted from midmid/prediction/model.py for standalone use.
Device management is caller-controlled (for ZeroGPU compatibility).
"""

import itertools as _it
import json
import math
from pathlib import Path
from typing import Optional

import numpy as np
import torch

from midmid.nn import (
    ChartMaskPredictor,
    ChartMaskPredictorConfig,
    MASK_TOKEN,
    SILENCE_TOKEN,
)
from midmid.datatypes import NoteEvent

MERT_MODEL_ID = "m-a-p/MERT-v1-95M"
DIFF_ID = {"easy": 0, "medium": 1, "hard": 2, "expert": 3}

# Class ID -> fret tuple.
# Classes 0-30 are the 31 non-empty subsets of frets 0-4, ordered by subset
# size and then lexicographically (itertools.combinations order).
_CLASS_TO_FRETS: list[tuple[int, ...]] = []
for _r in range(1, 6):
    _CLASS_TO_FRETS.extend(_it.combinations(range(5), _r))
_CLASS_TO_FRETS.append((7,))  # class 31 = open

# Sustain bucket center values in beats (indexed by the duration-head argmax).
_BUCKET_BEATS = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0]


# ---------------------------------------------------------------------------
# Model loading (safetensors from HF Hub)
# ---------------------------------------------------------------------------

def load_model_from_hub(
    repo_id: str = "markury/midmid3-19m-0326",
    device: str = "cpu",
) -> ChartMaskPredictor:
    """Download and load model from HuggingFace Hub (safetensors).

    Fetches ``config.json`` and ``model.safetensors`` from ``repo_id``,
    builds a ``ChartMaskPredictor`` from the config, and loads the weights.

    Args:
        repo_id: Hub repository holding ``config.json`` and
            ``model.safetensors``.
        device: Device the weights are loaded onto and the model moved to.

    Returns:
        The model on ``device``, in eval mode.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    config_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "model.safetensors")

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ChartMaskPredictorConfig(**config_dict)
    model = ChartMaskPredictor(config)

    state_dict = load_file(weights_path, device=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# ---------------------------------------------------------------------------
# MERT audio encoding (lazy-loaded)
# ---------------------------------------------------------------------------

# Module-level cache, populated on first call to _ensure_mert().
_mert_model = None
_mert_processor = None
_mert_frame_rate = None


def _ensure_mert(device: torch.device):
    """Load MERT model and processor on first use.

    On later calls, only moves the cached model to ``device`` if its
    parameters live elsewhere.
    """
    global _mert_model, _mert_processor, _mert_frame_rate
    if _mert_model is not None:
        # Move to correct device if needed
        if next(_mert_model.parameters()).device != device:
            _mert_model.to(device)
        return

    from transformers import AutoModel, Wav2Vec2FeatureExtractor

    print(f"Loading MERT ({MERT_MODEL_ID}) ...")
    # NOTE: trust_remote_code=True executes modelling code from the Hub repo.
    _mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
        MERT_MODEL_ID, trust_remote_code=True,
    )
    _mert_model = AutoModel.from_pretrained(MERT_MODEL_ID, trust_remote_code=True)
    _mert_model.to(device)
    _mert_model.eval()

    # Compute frame rate dynamically: the number of output frames produced
    # for exactly one second of silence.
    # NOTE(review): a strided conv front-end can emit slightly fewer frames
    # for a 1 s input than its true hop rate implies, so this may under-
    # estimate the rate (encode_audio_mert measures fps per chunk instead).
    # Confirm it matches how training features were indexed before changing.
    sr = _mert_processor.sampling_rate
    test_wav = np.zeros(sr, dtype=np.float32)
    inputs = _mert_processor(test_wav, sampling_rate=sr, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        out = _mert_model(**inputs, output_hidden_states=False)
    _mert_frame_rate = float(out.last_hidden_state.shape[1])
    print(f" MERT frame rate: {_mert_frame_rate:.2f} Hz")


def move_models_to_device(device: torch.device):
    """Move all cached models to the specified device (for ZeroGPU).

    A no-op if MERT has not been loaded yet.
    """
    # Only reads the module-level cache, so no ``global`` statement needed.
    if _mert_model is not None:
        _mert_model.to(device)


@torch.no_grad()
def encode_audio_mert(
    audio_path: str,
    device: torch.device,
    chunk_sec: float = 60.0,
) -> tuple[torch.Tensor, float]:
    """Encode audio with MERT, return (embeddings, frame_rate).

    The file is loaded mono at the MERT processor's sampling rate. Audio no
    longer than ``chunk_sec`` is encoded in a single pass. Longer audio is
    encoded in ``chunk_sec`` windows that overlap by 5 s. Each side of an
    overlap keeps half of it, so the stitched sequence has no duplicated
    frames.

    Args:
        audio_path: Path to an audio file readable by librosa.
        device: Device MERT runs on; outputs are returned on CPU.
        chunk_sec: Window length in seconds for long audio.

    Returns:
        ``(embeddings, frame_rate)``. ``embeddings`` holds the CPU
        last-hidden-state frames with the batch dim squeezed, i.e.
        ``(num_frames, hidden_dim)``. ``frame_rate`` is the rate measured
        by ``_ensure_mert``.

    Raises:
        ValueError: If chunking is needed and ``chunk_sec`` does not exceed
            the 5 s overlap (the stride would be non-positive).
    """
    import librosa

    _ensure_mert(device)
    sr = _mert_processor.sampling_rate
    wav, _ = librosa.load(audio_path, sr=sr, mono=True)

    chunk_samples = int(chunk_sec * sr)
    overlap_sec = 5.0
    overlap_samples = int(overlap_sec * sr)
    stride_samples = chunk_samples - overlap_samples

    if len(wav) <= chunk_samples:
        inputs = _mert_processor(wav, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        return out.last_hidden_state.squeeze(0).cpu(), _mert_frame_rate

    if stride_samples <= 0:
        # Otherwise the loop below would never advance.
        raise ValueError(
            f"chunk_sec ({chunk_sec}) must be greater than the "
            f"{overlap_sec} s overlap"
        )

    # Chunked processing for long audio
    all_emb = []
    pos = 0
    idx = 0
    min_len = chunk_samples // 4
    while pos < len(wav):
        end = min(pos + chunk_samples, len(wav))
        is_last = end >= len(wav)
        chunk = wav[pos:end]
        # Very short tail chunks are zero-padded so MERT gets enough context.
        if len(chunk) < min_len:
            chunk = np.pad(chunk, (0, min_len - len(chunk)))

        inputs = _mert_processor(chunk, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        emb = out.last_hidden_state.squeeze(0)

        # Per-chunk frames-per-second, used to convert seconds to frames here.
        n = emb.shape[0]
        fps = n / (len(chunk) / sr)
        half_overlap = int(round((overlap_sec / 2) * fps))

        if idx == 0:
            # First chunk: drop the trailing half of the overlap.
            keep = n if is_last else n - half_overlap
            all_emb.append(emb[:keep].cpu())
        elif is_last:
            # Last chunk: drop the leading half of the overlap, keep the rest.
            all_emb.append(emb[half_overlap:].cpu())
        else:
            # Middle chunk: skip half the overlap, keep one stride of frames.
            keep = int(round((len(chunk) / sr - overlap_sec) * fps))
            all_emb.append(emb[half_overlap:half_overlap + keep].cpu())

        if is_last:
            # Without this, a final chunk longer than the stride triggers one
            # more pass whose frames (re-encoded tail + padding) would be
            # appended past the real end of the audio.
            break
        pos += stride_samples
        idx += 1

    return torch.cat(all_emb, dim=0), _mert_frame_rate


# ---------------------------------------------------------------------------
# Grid helpers
# ---------------------------------------------------------------------------

def _build_16th_grid(fretbars):
    """Build 16th-note timestamps (ms) from beat positions.

    Each interval ``[fretbars[i], fretbars[i + 1])`` is split into four
    equal steps and the final beat is appended once, giving
    ``4 * (len(fretbars) - 1) + 1`` entries. With fewer than two beats the
    input is returned unchanged (as a list).
    """
    if len(fretbars) < 2:
        return list(fretbars)
    positions = []
    for i in range(len(fretbars) - 1):
        start = fretbars[i]
        interval = fretbars[i + 1] - start
        for sub in range(4):
            positions.append(start + sub * interval / 4.0)
    positions.append(fretbars[-1])
    return positions


def _get_local_beat_ms(grid_idx, fretbars):
    """Return the length (ms) of the beat containing 16th-grid index ``grid_idx``.

    Indices past the last interval use the last interval. Falls back to
    500.0 ms when fewer than two beats are available.
    """
    beat_idx = min(grid_idx // 4, len(fretbars) - 2)
    beat_idx = max(0, beat_idx)
    if beat_idx + 1 < len(fretbars):
        return fretbars[beat_idx + 1] - fretbars[beat_idx]
    return 500.0


def _sample_grid_embeddings(embeddings, frame_rate, grid_times, window=2):
    """Pick one embedding per grid time (ms), averaged over +/- ``window`` frames.

    Times map to the nearest frame index, clamped to the valid range.
    Averaging uses edge-replicated padding. It is skipped, falling back to a
    direct lookup, when the sequence is too short for the window.
    """
    max_frame = embeddings.shape[0] - 1
    frame_indices = torch.tensor(
        [
            max(0, min(int(round(t / 1000.0 * frame_rate)), max_frame))
            for t in grid_times
        ],
        dtype=torch.long,
    )
    if window > 0 and max_frame >= window * 2:
        padded = torch.nn.functional.pad(
            embeddings.unsqueeze(0), (0, 0, window, window), mode="replicate",
        ).squeeze(0)
        shifted = frame_indices + window
        stacked = torch.stack(
            [padded[shifted + d] for d in range(-window, window + 1)], dim=0,
        )
        return stacked.mean(dim=0)
    return embeddings[frame_indices]


def _librosa_grid_features(audio_path, grid_times):
    """Return a float32 (len(grid_times), 3) tensor of onset/RMS/centroid.

    Each feature track is min-max normalized over the whole file and then
    sampled at the nearest librosa frame (24 kHz, hop 320) to each grid
    time in ms.
    """
    import librosa as _lr

    wav, _ = _lr.load(audio_path, sr=24000, mono=True)
    hop = 320
    onset = _lr.onset.onset_strength(y=wav, sr=24000, hop_length=hop)
    rms_arr = _lr.feature.rms(y=wav, hop_length=hop)[0]
    centroid = _lr.feature.spectral_centroid(y=wav, sr=24000, hop_length=hop)[0]

    def _norm(x):
        mn, mx = x.min(), x.max()
        # Guard against constant tracks (e.g. silence) dividing by zero.
        return (x - mn) / max(mx - mn, 1e-8)

    onset, rms_arr, centroid = _norm(onset), _norm(rms_arr), _norm(centroid)
    af_rate = 24000 / hop
    af_max = len(onset) - 1
    af_indices = [
        max(0, min(int(round(t / 1000.0 * af_rate)), af_max)) for t in grid_times
    ]
    return torch.tensor(
        [[onset[i], rms_arr[i], centroid[i]] for i in af_indices],
        dtype=torch.float32,
    )


def _cosine_schedule(num_steps, num_positions):
    """Per-step unmask counts from a cosine mask-ratio schedule (each >= 1).

    The ``int()`` truncation means the counts may sum to fewer than
    ``num_positions``; callers must commit any remainder themselves.
    """
    schedule = []
    for step in range(num_steps):
        r_prev = math.cos(math.pi / 2 * step / num_steps)
        r_next = math.cos(math.pi / 2 * (step + 1) / num_steps)
        schedule.append(max(1, int((r_prev - r_next) * num_positions)))
    return schedule


# ---------------------------------------------------------------------------
# Main inference
# ---------------------------------------------------------------------------

@torch.no_grad()
def predict_notes(
    audio_path: str,
    model: ChartMaskPredictor,
    beat_times: list[float],
    difficulty: str = "expert",
    device: Optional[torch.device] = None,
    num_steps: int = 12,
    temperature: float = 0.9,
) -> list[NoteEvent]:
    """MaskGIT-style iterative unmasking inference.

    Builds a 16th-note grid from ``beat_times`` and starts with every grid
    position set to ``MASK_TOKEN``. Each of ``num_steps`` steps samples a
    token for every position and commits a random subset of the
    still-masked ones, sized by a cosine schedule. The final step commits
    everything still masked. A last forward pass supplies the
    sustain/duration predictions.

    Args:
        audio_path: Path to the audio file.
        model: Trained chart predictor.
        beat_times: Beat positions in seconds.
        difficulty: Key of ``DIFF_ID``; unknown values fall back to expert.
        device: Torch device; defaults to CUDA when available, else CPU.
        num_steps: Number of unmasking steps.
        temperature: Softmax sampling temperature (must be > 0).

    Returns:
        One ``NoteEvent`` per grid position whose token is a fret class.
        ``tick`` is the 16th-grid index. ``sustain_ticks`` is 0, or a
        length in milliseconds (bucket beats x local beat length in ms).
        NOTE(review): confirm downstream consumers expect ms in that field.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    fretbars = [t * 1000.0 for t in beat_times]
    if len(fretbars) < 2:
        return []

    # MERT embeddings sampled onto the 16th-note grid.
    embeddings, frame_rate = encode_audio_mert(audio_path, device)
    grid_times = _build_16th_grid(fretbars)
    num_positions = len(grid_times)
    grid_emb = _sample_grid_embeddings(embeddings, frame_rate, grid_times)

    # Compute and concat audio features if model expects them
    if model.config.audio_dim > grid_emb.shape[-1]:
        af_tensor = _librosa_grid_features(audio_path, grid_times)
        grid_emb = torch.cat([grid_emb, af_tensor], dim=-1)

    audio_features = grid_emb.unsqueeze(0).to(device)
    diff_id = DIFF_ID.get(difficulty, 3)
    diff_tensor = torch.tensor([diff_id], dtype=torch.long, device=device)
    padding_mask = torch.ones(1, num_positions, dtype=torch.bool, device=device)

    # Start fully masked
    chart_tokens = torch.full(
        (1, num_positions), MASK_TOKEN, dtype=torch.long, device=device,
    )
    schedule = _cosine_schedule(num_steps, num_positions)

    # Iterative unmasking (positions are chosen at random, not by confidence).
    for step in range(num_steps):
        outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
        token_logits = outputs["token_logits"].squeeze(0)

        is_masked = chart_tokens.squeeze(0) == MASK_TOKEN
        masked_indices = is_masked.nonzero(as_tuple=True)[0]
        if len(masked_indices) == 0:
            break

        probs = torch.softmax(token_logits / temperature, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        if step == num_steps - 1:
            # The truncated schedule can leave positions masked; commit all of
            # them now so none are silently dropped as "no note".
            n_unmask = len(masked_indices)
        else:
            n_unmask = min(schedule[step], len(masked_indices))
        perm = torch.randperm(len(masked_indices), device=device)
        unmask_idx = masked_indices[perm[:n_unmask]]
        chart_tokens[0, unmask_idx] = sampled[unmask_idx]

    # Final pass for sustain predictions; move to CPU once so the per-position
    # loop below doesn't sync with the device on every element.
    outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
    sustain_prob = outputs["sustain_logits"].squeeze(0).squeeze(-1).sigmoid().cpu()
    dur_pred = outputs["duration_logits"].squeeze(0).argmax(dim=-1).cpu()

    # Convert tokens to NoteEvents
    tokens = chart_tokens.squeeze(0).cpu()
    notes = []
    for i in range(num_positions):
        tok = tokens[i].item()
        # Silence (and anything at/above it, incl. leftover special tokens).
        if tok >= SILENCE_TOKEN or tok < 0:
            continue
        fret_set = set(_CLASS_TO_FRETS[tok])
        if not fret_set:
            continue

        sustain_ticks = 0
        if sustain_prob[i] >= 0.5:
            bucket = dur_pred[i].item()
            beat_ms = _get_local_beat_ms(i, fretbars)
            sustain_ticks = _BUCKET_BEATS[bucket] * beat_ms

        notes.append(NoteEvent(
            tick=i,
            fret_set=fret_set,
            sustain_ticks=sustain_ticks,
        ))

    return notes