# hanxunh: "Add self-contained preprocess() with per-variant defaults" (commit 13b7e55, verified)
"""
Self-contained loader for AudioMosaic-vit-b16-pretrained.

Usage:
    from huggingface_hub import snapshot_download
    import sys

    local = snapshot_download("hanxunh/AudioMosaic-vit-b16-pretrained")
    sys.path.insert(0, local)
    from load_model import load_pretrained_encoder

    model = load_pretrained_encoder()
"""
import json
import os
import torch
from safetensors.torch import load_file
def load_pretrained_encoder(repo_dir: str = None, device: str = "cpu"):
    """Load the AudioMosaic ViT-B/16 encoder pretrained with NT-Xent + masking on AudioSet-2M.

    Args:
        repo_dir: Directory holding ``config.json``, ``model.safetensors`` and
            ``modeling.py``. Defaults to the directory containing this file.
        device: Device the model is moved to before it is returned.

    Returns:
        The ``AudioMosaicPretrain`` instance built from ``config.json``, with
        the safetensors weights loaded, moved to ``device`` and put in eval
        mode.
    """
    import sys

    if repo_dir is None:
        repo_dir = os.path.dirname(os.path.abspath(__file__))
    else:
        # Normalise so the sys.path entry below stays valid if the working
        # directory changes later (also turns a PathLike into a str).
        repo_dir = os.path.abspath(repo_dir)

    cfg_path = os.path.join(repo_dir, "config.json")
    weights_path = os.path.join(repo_dir, "model.safetensors")

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)
    # "model_class" is dropped so it is not forwarded to the constructor.
    cfg.pop("model_class", None)
    # NOTE(review): preprocess() reads an "audio_preprocessing" entry from this
    # same config.json; if present it is forwarded to the constructor here.
    # Confirm that AudioMosaicPretrain accepts it.

    # `modeling` is imported as a top-level module, so the directory that
    # contains modeling.py must be importable. Without this, passing an
    # explicit repo_dir fails unless the caller already edited sys.path.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    from modeling import AudioMosaicPretrain

    model = AudioMosaicPretrain(**cfg)
    state = load_file(weights_path)
    # strict=False: key mismatches are reported in the returned message
    # (printed below) instead of raising.
    msg = model.load_state_dict(state, strict=False)
    print(msg)
    return model.to(device).eval()
def preprocess(audio, repo_dir: str = None, **overrides):
    """Call ``modeling.preprocess`` with the defaults stored in ``config.json``.

    The ``audio_preprocessing`` entry of ``config.json`` supplies the default
    keyword arguments; anything given in ``**overrides`` takes precedence
    (e.g., ``target_length=2048``).

    Args:
        audio: Passed through unchanged as the first positional argument of
            ``modeling.preprocess``.
        repo_dir: Directory holding ``config.json`` and ``modeling.py``.
            Defaults to the directory containing this file.
        **overrides: Keyword arguments that replace or extend the defaults.

    Returns:
        Whatever ``modeling.preprocess`` returns.
    """
    import sys

    if repo_dir is None:
        repo_dir = os.path.dirname(os.path.abspath(__file__))
    else:
        # Normalise so the sys.path entry below stays valid if the working
        # directory changes later (also turns a PathLike into a str).
        repo_dir = os.path.abspath(repo_dir)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(os.path.join(repo_dir, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)

    # A missing or null "audio_preprocessing" entry means "no defaults".
    kwargs = {**(cfg.get("audio_preprocessing") or {}), **overrides}

    # `modeling` is imported as a top-level module, so the directory that
    # contains modeling.py must be importable. Without this, passing an
    # explicit repo_dir fails unless the caller already edited sys.path.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    from modeling import preprocess as _preprocess

    return _preprocess(audio, **kwargs)