import io
import os
import urllib.request
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import whisper
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from whisper.model import AudioEncoder, ModelDimensions
from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables


_HF_MODELS = {
    "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
}


def available_models() -> List[str]:
    """Returns the names of available models."""
    return list(_HF_MODELS.keys())


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """Download `url` into `root`; return raw bytes if `in_memory`, else the local path."""
    os.makedirs(root, exist_ok=True)

    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    # Reuse a previously downloaded checkpoint if one is already on disk.
    # Note: unlike the original Whisper downloader, no SHA256 verification is performed.
    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        return model_bytes if in_memory else download_target

    # Disable TLS certificate verification for this download.
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    with open(download_target, "rb") as f:
        model_bytes = f.read()
    return model_bytes if in_memory else download_target


class CustomWhisperEncoder(nn.Module):
    """
    Lightweight wrapper that only loads the AudioEncoder part of Whisper.
    """

    def __init__(self, name: str, device: Optional[str] = None,
                 download_root: Optional[str] = None, in_memory: bool = False):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            # Default to the directory two levels above this file.
            download_root = os.path.dirname(
                os.path.dirname(os.path.realpath(__file__)))

        if name in _HF_MODELS:
            checkpoint_file = _download(
                _HF_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            checkpoint_file = open(name, "rb").read() if in_memory else name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        # Read only the model dimensions and encoder weights from the checkpoint.
        with (
            io.BytesIO(checkpoint_file) if in_memory else open(
                checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        del checkpoint_file

        dims = ModelDimensions(**checkpoint["dims"])
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )
        self.encoder.load_state_dict(checkpoint["model_state_dict"])

        self.to(device)
        self.eval()

    def forward(self, mel: torch.Tensor):
        return self.encoder(mel)
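
    # Usage sketch (names and paths are illustrative, not part of this module):
    #   enc = CustomWhisperEncoder("medium", device="cpu")
    #   mel = whisper.log_mel_spectrogram(whisper.load_audio("example.wav"))  # (80, frames)
    #   mel = whisper.pad_or_trim(mel, whisper.audio.N_FRAMES).unsqueeze(0)   # (1, 80, 3000)
    #   feats = enc(mel)                                                      # (1, 1500, n_audio_state)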


class CustomRQBottleneckTransformer(RQBottleneckTransformer):
    @classmethod
    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                     repo_id=None, filename=None, local_filename=None):
        """Build the model from a checkpoint but load only the quantizer weights."""
        # Resolve `ref` into either a Hugging Face repo_id/filename pair or a local path.
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(
                repo_id=repo_id, filename=filename)

        spec = torch.load(local_filename)

        instance = cls(**spec['config'],
                       tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))

        # Keep only the residual quantizer and projection layers; the Whisper
        # encoder is loaded separately via load_encoder().
        required_components = {
            'rq', 'mlp', 'mlp_ln'
        }
        filtered_state_dict = {
            k: v for k, v in spec['state_dict'].items()
            if any(k.startswith(comp) for comp in required_components)
        }

        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()
        return instance
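
    # Example (sketch): load the default multilingual checkpoint from the Hub,
    #   vq = CustomRQBottleneckTransformer.load_vq_only(
    #       "collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model")
    # or pass a local checkpoint path, as in the __main__ block below.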

    def load_encoder(self, device=None):
        if self.whmodel is not None:
            return
        device = device or self.device

        # Encoder-only Whisper model; no decoder weights are downloaded or loaded.
        self.whmodel = CustomWhisperEncoder(
            self.whisper_model_name, device=device)
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    def optimized_encode_mel(self, mel):
        assert len(
            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch, chn, time)"
        self.load_encoder()
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            # Truncate inputs longer than Whisper's 30-second window.
            padding = 0
            padded = mel[:, :, :whisper.audio.N_FRAMES]
        else:
            # Pad shorter inputs up to a full window.
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)

        embs = self.whmodel.encoder(padded)
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            # Keep only the tokens that correspond to the original (unpadded) frames.
            return stoks[:, :n // 2 // self.downsample]
        else:
            return stoks

    def encode_audio(self, audio):
        if isinstance(audio, str):
            # Accept a file path: load, resample to 16 kHz, and keep the first channel.
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.optimized_encode_mel(self.log_mel_spectrogram(audio).to(self.device))


if __name__ == "__main__":
    # Load the quantizer from a local checkpoint, then attach the Whisper encoder on the GPU.
    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
        "whisper-vq-stoks-v3-7lang-fixed.model"
    ).to("cuda")
    vqmodel.load_encoder('cuda')
    vqmodel.eval()
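
    # Example usage (sketch): "sample.wav" is a placeholder path, not part of this repo.
    if os.path.exists("sample.wav"):
        stoks = vqmodel.encode_audio("sample.wav")
        print(stoks.shape)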