| import os |
| import torch |
| import numpy as np |
| import torchaudio |
| import yaml |
| from . import asteroid_test |
| from huggingface_hub import hf_hub_download |
| import logging |
logger = logging.getLogger(__name__)

# Recent torchaudio releases deprecate `set_audio_backend` (backend selection
# moved to a per-call dispatcher) and newer ones may not provide it at all;
# older ones raise RuntimeError when sox_io is not installed. Either way the
# module should still import, so only request sox_io when the function exists
# and fall back to torchaudio's default backend otherwise.
if hasattr(torchaudio, "set_audio_backend"):
    try:
        torchaudio.set_audio_backend("sox_io")
    except RuntimeError as exc:
        logger.warning(
            "sox_io audio backend unavailable, using torchaudio default: %s", exc
        )
|
|
|
|
def get_conf():
    """Return the DPTNet hyper-parameters as ``(filterbank_conf, masknet_conf)``.

    Both dicts are freshly built on every call and are unpacked together as
    keyword arguments to the ``DPTNet`` constructor in ``load_dpt_model``.
    The first holds the filterbank settings, the second the masker-network
    settings (two output sources, ``n_src=2``).
    """
    filterbank = dict(
        n_filters=64,
        kernel_size=16,
        stride=8,
    )

    masknet = dict(
        in_chan=64,
        n_src=2,
        out_chan=64,
        ff_hid=256,
        ff_activation="relu",
        norm_type="gLN",
        chunk_size=100,
        hop_size=50,
        n_repeats=2,
        mask_act="sigmoid",
        bidirectional=True,
        dropout=0,
    )

    return filterbank, masknet
|
|
|
|
def load_dpt_model():
    """Download the int8 DPTNet checkpoint and return the model ready for inference.

    Reads the Hugging Face access token from the ``SpeechSeparation``
    environment variable, downloads the checkpoint from the
    ``DeepLearning101/speech-separation`` repository, builds a ``DPTNet`` from
    ``get_conf()``, applies dynamic int8 quantization to its LSTM and Linear
    layers, and loads the checkpoint weights.

    Returns:
        The model in eval mode. It is built and loaded on CPU
        (``map_location="cpu"``).

    Raises:
        EnvironmentError: if ``SpeechSeparation`` is not set or empty.
        RuntimeError: if the checkpoint cannot be unpickled.
    """
    # `pickle` is needed by the except clause below and is not imported at
    # module level. Without it, any torch.load failure would surface as a
    # NameError that hides the real cause.
    import pickle

    print('Load Separation Model...')

    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("環境變數 SpeechSeparation 未設定!")

    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token,
    )

    conf_filterbank, conf_masknet = get_conf()
    model = asteroid_test.DPTNet(**conf_filterbank, **conf_masknet)
    # Quantize *before* loading: the checkpoint name ("quan-int8") indicates it
    # was saved from a dynamically quantized model, so its state-dict layout
    # presumably only matches the quantized modules.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
    )

    try:
        # SECURITY(review): weights_only=False lets torch.load unpickle
        # arbitrary objects, i.e. execute code embedded in the file. This is
        # only acceptable if the Hugging Face repository is trusted — confirm.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        raise RuntimeError(
            "模型載入失敗!請確認:\n"
            "1. 模型來源是否可信\n"
            "2. 是否為舊版 PyTorch 儲存的模型\n"
            "3. 嘗試鎖定 PyTorch 版本為 2.5.x"
        ) from e

    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
| import torchaudio |
| import tempfile |
|
|
def _pick_device(model):
    """Choose where to run *model*: CUDA when available, except for quantized models.

    Dynamically quantized ``LSTM``/``Linear`` modules (what ``load_dpt_model``
    builds via ``torch.quantization.quantize_dynamic``) only have CPU kernels,
    so moving such a model to CUDA would fail at the first forward pass.
    """
    quantized = any(".quantized" in type(m).__module__ for m in model.modules())
    if quantized or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda")


def _load_mono_16k(wav_path, device):
    """Load a WAV file, average its channels to mono and resample it to 16 kHz.

    Returns ``(x, 16000)`` where ``x`` has shape ``(1, num_samples)`` on *device*.
    """
    x, sr = torchaudio.load(wav_path, format="wav")
    x = x.mean(dim=0, keepdim=True).to(device)
    if sr != 16000:
        x = torchaudio.transforms.Resample(sr, 16000).to(device)(x)
        sr = 16000
    return x, sr


def _separate_in_chunks(model, x, chunk_samples, min_tail):
    """Run *model* on consecutive windows of *x* and concatenate the outputs in time.

    A trailing window shorter than *min_tail* samples is merged into the previous
    window. A few-sample tail can be shorter than the encoder kernel
    (``kernel_size`` in ``get_conf``), which would likely make the forward pass
    fail, and would be separated with almost no context anyway.
    """
    total = x.shape[1]
    starts = list(range(0, total, chunk_samples))
    if len(starts) > 1 and total - starts[-1] < min_tail:
        starts.pop()
    ends = starts[1:] + [total]

    pieces = []
    with torch.no_grad():
        for start, end in zip(starts, ends):
            pieces.append(model(x[:, start:end]).cpu())
    # Assumes the model returns (batch, n_src, time), so dim=2 is time — TODO confirm.
    return torch.cat(pieces, dim=2)


def _match_peak(sig, peak):
    """Rescale *sig* so its largest absolute sample equals *peak*.

    An all-zero signal is returned unchanged. Dividing it by its (zero) peak
    would turn the whole output into NaN.
    """
    sig_peak = sig.abs().max()
    if sig_peak == 0:
        return sig
    return peak * sig / sig_peak


def _sep_output_paths(outfilename):
    """Return the ``(sep1, sep2)`` WAV paths built from *outfilename*'s stem.

    Only the final extension is replaced. Directory names are left alone, and
    the two paths always differ even when *outfilename* has no ``.wav`` suffix.
    """
    stem, _ext = os.path.splitext(os.fspath(outfilename))
    return f"{stem}_sep1.wav", f"{stem}_sep2.wav"


def _save_atomically(signals, sr, final_paths):
    """Write each 1-D signal as a mono WAV to a temp file, then move it into place.

    The temporary directory is created next to the destination because
    ``os.replace`` cannot move files across filesystems.
    """
    out_dir = os.path.dirname(os.path.abspath(final_paths[0]))
    with tempfile.TemporaryDirectory(dir=out_dir) as tmp_dir:
        tmp_paths = [
            os.path.join(tmp_dir, f"sep{i}.wav") for i in range(1, len(signals) + 1)
        ]
        for sig, tmp_path in zip(signals, tmp_paths):
            torchaudio.save(tmp_path, sig.unsqueeze(0), sr)
        for tmp_path, final_path in zip(tmp_paths, final_paths):
            os.replace(tmp_path, final_path)


def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Separate a two-speaker WAV recording into two WAV files.

    The input is downmixed to mono, resampled to 16 kHz and fed to *model* in
    60-second windows. Each separated source is rescaled to 90% of the input's
    peak amplitude and written as ``<stem>_sep1.wav`` / ``<stem>_sep2.wav``.

    Args:
        wav_path: path to the input WAV file.
        model: separation model; when ``None`` it is loaded with ``load_dpt_model()``.
        outfilename: path whose stem names the two output files; defaults to
            *wav_path*.

    Returns:
        Tuple ``(sep1_path, sep2_path)`` of the written files.

    Raises:
        ValueError: if the input has no samples, or *outfilename* is omitted
            while *wav_path* is not a filesystem path.
        RuntimeError: with a (Chinese) "out of memory, shorten the audio"
            message when CUDA runs out of memory; other runtime errors are
            re-raised unchanged.
    """
    if model is None:
        model = load_dpt_model()
    if outfilename is None:
        if not isinstance(wav_path, (str, os.PathLike)):
            raise ValueError(
                "outfilename is required when wav_path is not a filesystem path"
            )
        outfilename = wav_path
    final_sep1, final_sep2 = _sep_output_paths(outfilename)

    try:
        device = _pick_device(model)
        model = model.to(device)

        x, sr = _load_mono_16k(wav_path, device)
        if x.shape[1] == 0:
            raise ValueError(f"input audio has no samples: {wav_path!r}")

        # NOTE(review): every 60 s window is separated independently, so the
        # speaker order of sep1/sep2 may swap between windows if the model has
        # no fixed output order — confirm; overlapping windows with permutation
        # alignment would fix it.
        est_sources = _separate_in_chunks(model, x, chunk_samples=sr * 60, min_tail=sr)
        est_sources = est_sources.squeeze(0)

        # `.item()` brings the peak to the host, since x may live on CUDA while
        # the separated sources are already on CPU.
        peak = 0.9 * x.abs().max().item()
        sep_1 = _match_peak(est_sources[0], peak)
        sep_2 = _match_peak(est_sources[1], peak)

        _save_atomically((sep_1, sep_2), sr, (final_sep1, final_sep2))
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            raise RuntimeError("記憶體不足,請縮短音訊長度") from e
        raise

    logger.info("💾 寫入輸出檔案至: %s, %s", final_sep1, final_sep2)
    return final_sep1, final_sep2
|
|
if __name__ == "__main__":
    # Library module: the separation entry points are meant to be driven by a
    # Flask or Gradio front end, not run directly as a script.
    _usage_hint = "This module should be used via Flask or Gradio."
    print(_usage_hint)