| |
| import logging |
| import tempfile |
|
|
| from inference.infer_tool import Svc |
| from typing import * |
| import api.base |
| import os |
| import io |
| import wave |
| import numpy as np |
| from service.tool import audio_normalize, read_wav_file_to_numpy_array |
| from utils import get_hparams_from_file |
|
|
logger = logging.getLogger(__name__)


# Shared inference engine, created by init() and used by every handler below.
# NOTE(review): handlers dereference it without a None check, so init() is
# presumably called at server start-up — confirm in the entry point.
_svc: Optional[Svc] = None
# Model directory names discovered under ./checkpoints by init(); served
# verbatim as the "data" field by ModelListHandler.
_model_paths: Optional[List[str]] = None
|
|
|
|
def init(checkpoints_dir: Optional[str] = None) -> None:
    """Create the shared ``Svc`` engine and index the available models.

    Sets the module globals ``_svc`` to a fresh ``Svc()`` and ``_model_paths``
    to the names of the immediate sub-directories of the checkpoints
    directory, sorted so the model-list response is stable.

    Args:
        checkpoints_dir: Directory containing one sub-directory per model.
            Defaults to ``checkpoints`` under the current working directory,
            the same place the inference handlers read
            ``checkpoints/<dsid>/config.json`` from.
    """
    global _svc, _model_paths
    _svc = Svc()

    if checkpoints_dir is None:
        checkpoints_dir = os.path.join(os.getcwd(), "checkpoints")
    logger.debug(f"CkPoints Dir: {checkpoints_dir}")

    # List only the top level: a recursive walk would also report folders
    # nested inside a model directory as if they were models themselves.
    try:
        with os.scandir(checkpoints_dir) as entries:
            _model_paths = sorted(entry.name for entry in entries if entry.is_dir())
    except OSError as err:
        # Best effort, as before: a missing/unreadable directory means "no
        # models", but say so instead of failing silently.
        logger.warning(f"Cannot list checkpoints in {checkpoints_dir}: {err}")
        _model_paths = []
|
|
|
|
| |
class ModelListHandler(api.base.ApiHandler):
    """Expose the model directory names collected by ``init()``."""

    async def get(self):
        """Reply with the standard success envelope wrapping ``_model_paths``."""
        payload = dict(code=200, msg="ok", data=_model_paths)
        self.write(payload)
|
|
|
|
| |
class SwitchHandler(api.base.ApiHandler):
    """Move the shared engine to a device and load the requested checkpoint."""

    async def post(self):
        """Validate the ``model``/``mode``/``device`` arguments, then switch.

        Replies 400 for the first invalid argument, 500 if changing the device
        or loading the checkpoint raises, and 200 echoing ``mode`` otherwise.
        """
        model_name = self.get_argument("model", "")
        mode = self.get_argument("mode", "single")
        device = self.get_argument("device", "cuda")

        # Ordered (failed?, message) pairs; the first failing rule wins.
        # Messages: "No model selected!", "Invalid run mode!", "Invalid device!"
        validation_rules = (
            (model_name == "", "未选择模型!"),
            (mode not in ("single", "batch"), "运行模式选择错误!"),
            (device not in ("cpu", "cuda"), "设备选择错误!"),
        )
        for is_invalid, message in validation_rules:
            if is_invalid:
                self.set_status(400)
                self.write({"code": 400, "msg": message, "data": None})
                return

        logger.debug(f"modelname: {model_name}\nmode: {mode}\ndevice: {device}\n")
        try:
            _svc.set_device(device=device)
            logger.debug("Device set.")
            _svc.load_checkpoint(path=model_name)
            logger.debug("Model set.")
        except Exception as err:
            logger.exception(err)
            self.set_status(500)
            self.write({"code": 500, "msg": "system_error", "data": None})
        else:
            self.write({"code": 200, "msg": "ok", "data": {"mode": mode}})
|
|
|
|
| |
class SingleInferenceHandler(api.base.ApiHandler):
    """Run voice conversion on one uploaded clip and return it as a WAV file."""

    async def post(self):
        """Convert the first ``srcaudio`` upload with the model named by ``dsid``.

        Form arguments (strings; the numeric ones are parsed with ``float``):
            dsid: model directory name; ``checkpoints/<dsid>/config.json`` is
                read to choose the ``chara`` passed to inference.
            tran: passed to ``Svc.inference`` as ``tran`` (default ``"0"``).
            th: passed as ``slice_db`` (default ``"-40.0"``).
            ns: passed as ``ns`` (default ``"0.4"``).

        Replies 400 when no file or no model is given and 500 on any other
        failure (including non-numeric ``tran``/``th``/``ns``). Otherwise the
        body is the converted audio as an ``svc_output.wav`` attachment.
        """
        try:
            from scipy.io import wavfile

            dsid = self.get_argument("dsid", "")
            tran = self.get_argument("tran", "0")
            th = self.get_argument("th", "-40.0")
            ns = self.get_argument("ns", "0.4")
            audiofile_dict = self.request.files.get("srcaudio", [])

            if not audiofile_dict:
                # "No file uploaded!"
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未上传文件!",
                    "data": None
                })
                return

            if dsid == "":
                # "No model selected!"
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未选择模型!",
                    "data": None
                })
                return

            # Only the first uploaded file is converted.
            audiofile = audiofile_dict[0]
            audio_filename = audiofile['filename']
            audio_filebody = audiofile['body']

            # audio_normalize receives the raw upload bytes, so no on-disk copy
            # of the upload is needed. The path it returns belongs to this
            # handler and is deleted even if reading it fails.
            converted_file = await audio_normalize(full_filename=audio_filename, file_data=audio_filebody)
            try:
                sampling_rate, audio_array = read_wav_file_to_numpy_array(converted_file)
            finally:
                os.remove(converted_file)

            scraudio = (sampling_rate, audio_array)

            logger.debug(f"read file {audio_filename}\n"
                         f"sampling rate: {sampling_rate}")

            tran = float(tran)
            th = float(th)
            ns = float(ns)

            # NOTE(review): dsid comes straight from the request and is joined
            # into a filesystem path; consider restricting it to _model_paths.
            hparams = get_hparams_from_file(f"checkpoints/{dsid}/config.json")
            # Pick the key of hparams.spk whose value is 0 (last match wins).
            real_dsid = ""
            for k, v in hparams.spk.items():
                if v == 0:
                    real_dsid = k
            logger.debug(f"read dsid is: {real_dsid}")

            output_audio_sr, output_audio_array = _svc.inference(srcaudio=scraudio,
                                                                 chara=real_dsid,
                                                                 tran=tran,
                                                                 slice_db=th,
                                                                 ns=ns)

            logger.debug(f"svc for {audio_filename} succeed. \n"
                         f"audio data type: {type(output_audio_array)}\n"
                         f"audio data sr: {output_audio_sr}")

            logger.debug("start output data.")

            with io.BytesIO() as wav_file:
                # Encode at the rate inference reports for its output. It may
                # differ from the input rate, and using the wrong one makes the
                # result play at the wrong speed and pitch.
                wavfile.write(wav_file, output_audio_sr, output_audio_array)
                wav_data = wav_file.getvalue()

            self.set_header('Content-Type', 'audio/wav')
            self.set_header('Content-Disposition', 'attachment; filename="svc_output.wav"')
            self.write(wav_data)
            await self.flush()
            logger.debug("response completed.")
        except Exception as e:
            logger.exception(e)
            self.set_status(500)
            self.write({
                "code": 500,
                "msg": "system_error",
                "data": None
            })
|
|
|
|
| |
class BatchInferenceHandler(api.base.ApiHandler):
    """Run voice conversion on several uploaded clips and return them zipped."""

    @staticmethod
    def _output_name(upload_name, used_names):
        """Return a unique ``<stem>.wav`` archive member name for an upload.

        The converted audio is always WAV, so the upload's own extension is
        replaced. ``_1``, ``_2``, ... suffixes keep uploads that share a name
        from clobbering each other. ``used_names`` is updated in place.
        """
        stem = os.path.splitext(os.path.basename(upload_name))[0] or "audio"
        candidate = f"{stem}.wav"
        counter = 1
        while candidate in used_names:
            candidate = f"{stem}_{counter}.wav"
            counter += 1
        used_names.add(candidate)
        return candidate

    async def post(self):
        """Convert every ``srcaudio`` upload with the model named by ``dsid``.

        Form arguments: ``dsid`` (model directory under ``checkpoints/``), plus
        ``tran``, ``th`` (passed as ``slice_db``) and ``ns``, each parsed with
        ``float``; the defaults match single inference.

        Replies 400 when no file or no model is given and 500 on any other
        failure. Otherwise it sends ``output.zip`` with one WAV per upload.
        """
        try:
            from zipfile import ZipFile
            from scipy.io import wavfile

            dsid = self.get_argument("dsid", "")
            tran = self.get_argument("tran", "0")
            th = self.get_argument("th", "-40.0")
            ns = self.get_argument("ns", "0.4")
            audiofile_dict = self.request.files.get("srcaudio", [])

            logger.debug(len(self.request.files))

            if not audiofile_dict:
                # "No file uploaded!"
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未上传文件!",
                    "data": None
                })
                return

            if dsid == "":
                # "No model selected!"
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未选择模型!",
                    "data": None
                })
                return

            tran = float(tran)
            th = float(th)
            ns = float(ns)

            # NOTE(review): dsid comes straight from the request and is joined
            # into a filesystem path; consider restricting it to _model_paths.
            hparams = get_hparams_from_file(f"checkpoints/{dsid}/config.json")
            # Pick the key of hparams.spk whose value is 0 (last match wins).
            real_dsid = ""
            for k, v in hparams.spk.items():
                if v == 0:
                    real_dsid = k
            logger.debug(f"read dsid is: {real_dsid}")

            # Build the archive in memory so no per-request directory, WAV or
            # zip file is left behind on disk. The zip was already read fully
            # into memory for the response, so peak memory is about the same.
            used_names = set()
            archive = io.BytesIO()
            with ZipFile(archive, "w") as zip_obj:
                for idx, upload in enumerate(audiofile_dict):
                    audio_filename = upload["filename"]
                    audio_filebody = upload["body"]

                    # The returned path belongs to this handler; delete it even
                    # if reading fails.
                    converted_file = await audio_normalize(full_filename=audio_filename, file_data=audio_filebody)
                    try:
                        sampling_rate, audio_array = read_wav_file_to_numpy_array(converted_file)
                    finally:
                        os.remove(converted_file)
                    scraudio = (sampling_rate, audio_array)

                    logger.debug(f"{idx}, {len(audio_filebody)}, {os.path.basename(audio_filename)}")

                    output_sampling_rate, output_audio = _svc.inference(scraudio, chara=real_dsid, tran=tran,
                                                                        slice_db=th, ns=ns)

                    with io.BytesIO() as wav_buffer:
                        wavfile.write(wav_buffer, output_sampling_rate, output_audio)
                        member_name = self._output_name(audio_filename, used_names)
                        zip_obj.writestr(member_name, wav_buffer.getvalue())
            zip_data = archive.getvalue()

            logger.debug("start output data.")
            self.set_header("Content-Type", "application/zip")
            self.set_header("Content-Disposition", "attachment; filename=output.zip")
            self.write(zip_data)
            await self.flush()
            logger.debug("response completed.")
        except Exception as e:
            logger.exception(e)
            self.set_status(500)
            self.write({
                "code": 500,
                "msg": "system_error",
                "data": None
            })
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: build the engine, scan the checkpoints directory,
    # and print the model names that were found.
    init()
    found_models = _model_paths
    print(found_models)
|
|