| import os
|
| import logging
|
| import functools
|
| import torch
|
|
|
logger = logging.getLogger(__name__)


# fairseq is optional: when it cannot be imported, FAIRSEQ_AVAILABLE is False and
# load_hubert uses torchaudio instead. The catch is deliberately broad (not just
# ImportError) — presumably because importing fairseq can fail in other ways on
# unsupported Python/PyTorch versions; confirm before narrowing it.
try:
    from fairseq import checkpoint_utils

    FAIRSEQ_AVAILABLE = True
except Exception as exc:
    FAIRSEQ_AVAILABLE = False
    # Record why fairseq is unavailable instead of discarding the error silently.
    logger.debug("fairseq import failed; HuBERT will be loaded via torchaudio: %r", exc)
|
|
|
|
|
def _patch_torch_load():
    """Build a ``torch.load`` replacement that defaults ``weights_only`` to False.

    Intended for fairseq compatibility on PyTorch 2.6+, where the default of
    ``torch.load(weights_only=...)`` became True. The patched function fills in
    ``weights_only=False`` only when the caller did not pass it explicitly. It
    does so only if the captured ``torch.load`` declares a ``weights_only``
    parameter: on releases without that keyword, the extra kwarg would be
    forwarded to the unpickler and make every load fail.

    This function does not install the patch. The caller assigns the returned
    wrapper to ``torch.load`` and restores the original afterwards.

    Returns:
        tuple: ``(original, patched)``. ``original`` is the ``torch.load``
        captured at call time. ``patched`` wraps it via ``functools.wraps``.
    """
    import inspect

    _original = torch.load

    try:
        _accepts_weights_only = "weights_only" in inspect.signature(_original).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: keep the previous behavior and inject.
        _accepts_weights_only = True

    @functools.wraps(_original)
    def _patched(*args, **kwargs):
        if _accepts_weights_only:
            # setdefault: an explicit caller-supplied weights_only always wins.
            kwargs.setdefault("weights_only", False)
        return _original(*args, **kwargs)

    return _original, _patched
|
|
|
|
|
def get_index_path_from_model(sid, index_root=None):
    """Return the path of the first ``.index`` file that appears to belong to *sid*.

    Walks *index_root* bottom-up (``os.walk(..., topdown=False)``). Files whose
    name contains ``"trained"`` and files not ending in ``.index`` are skipped.
    The first remaining path whose full path contains the part of *sid* before
    its first ``"."`` is returned.

    Matching is a substring test on the whole path, so a directory name that
    contains the model stem also counts as a match.

    Args:
        sid: Model file name, e.g. ``"name.pth"``.
        index_root: Directory to search. Defaults to the ``index_root``
            environment variable.

    Returns:
        The matching path, or ``""`` if nothing matches or no index root is set.
    """
    if index_root is None:
        index_root = os.getenv("index_root")
    if not index_root:
        # os.walk(None) would raise TypeError; an unset root simply means "not found".
        return ""

    model_stem = sid.split(".")[0]
    # Iterate lazily in walk order so the search stops at the first match
    # instead of listing the entire tree up front.
    for root, _, files in os.walk(index_root, topdown=False):
        for name in files:
            if not name.endswith(".index") or "trained" in name:
                continue
            path = os.path.join(root, name)
            if model_stem in path:
                return path
    return ""
|
|
|
|
|
class _HubertWrapper:
    """Adapt a torchaudio HuBERT model to a fairseq-style ``extract_features`` call.

    Implements only ``extract_features``, ``final_proj``, ``to``, ``half``,
    ``float`` and ``eval``. The ``to``/``half``/``float``/``eval`` methods
    return the wrapper itself so calls can be chained.
    """

    def __init__(self, model):
        self.model = model
        # If the wrapped model has no ``final_proj``, Identity is used, so the
        # features are returned unprojected. NOTE(review): confirm callers that
        # apply ``final_proj`` tolerate the unprojected feature width.
        self.final_proj = getattr(model, "final_proj", torch.nn.Identity())

    def extract_features(self, source, padding_mask=None, output_layer=None):
        """Return ``(features, None)`` taken from one layer of the wrapped model.

        ``padding_mask`` is accepted for signature compatibility and ignored.
        ``output_layer`` is treated as 1-based. ``None`` selects the last
        layer, and values past the last available layer are clamped to it.
        """
        feats, _ = self.model.extract_features(source)
        if output_layer is None:
            idx = -1
        else:
            idx = min(output_layer - 1, len(feats) - 1)
        return (feats[idx], None)

    def to(self, device):
        self.model = self.model.to(device)
        return self

    def half(self):
        self.model = self.model.half()
        return self

    def float(self):
        self.model = self.model.float()
        return self

    def eval(self):
        self.model.eval()
        return self


def _prepare_hubert(model, config):
    """Move *model* to ``config.device``, set precision from ``config.is_half``, and return it in eval mode."""
    model = model.to(config.device)
    model = model.half() if config.is_half else model.float()
    return model.eval()


def _load_fairseq_hubert():
    """Load the fairseq HuBERT checkpoint with ``torch.load`` temporarily patched."""
    _original, _patched = _patch_torch_load()
    torch.load = _patched
    try:
        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            ["assets/hubert/hubert_base.pt"],
            suffix="",
        )
    finally:
        # Always restore the global torch.load, even if loading fails.
        torch.load = _original
    return models[0]


def _load_torchaudio_hubert():
    """Build torchaudio's ``HUBERT_BASE`` model wrapped in ``_HubertWrapper``.

    Raises:
        RuntimeError: If torchaudio cannot be imported or the model cannot be built.
    """
    try:
        import torchaudio

        model = torchaudio.pipelines.HUBERT_BASE.get_model()
    except Exception as e:
        raise RuntimeError(
            "HuBERT 模型加载失败,请检查 fairseq 和 torchaudio 是否已安装"
        ) from e
    return _HubertWrapper(model)


def load_hubert(config):
    """Load HuBERT, preferring fairseq and falling back to torchaudio.

    Args:
        config: Object exposing ``device`` (target for ``.to``) and ``is_half``
            (half precision if truthy, otherwise float32).

    Returns:
        The model (fairseq) or ``_HubertWrapper`` (torchaudio), on
        ``config.device`` and in eval mode.

    Raises:
        RuntimeError: If the torchaudio fallback cannot be imported or built.
            Errors from the fairseq path and from device placement are not
            wrapped, so they keep their original type and message.
    """
    if FAIRSEQ_AVAILABLE:
        model = _load_fairseq_hubert()
    else:
        logger.info("fairseq unavailable; loading HuBERT from torchaudio")
        model = _load_torchaudio_hubert()
    return _prepare_hubert(model, config)
|
|
|