| from pywhispercpp.model import Model |
| import soundfile |
| import numpy as np |
| from logging import getLogger |
| from pathlib import Path |
|
|
| from lib.utils import Timer, read_audio |
|
|
| logger = getLogger(__name__) |
|
|
| MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models") |
| WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。" |
| WHISPER_PROMPT_EN = "" |
|
|
| class WhisperCPP: |
| def __init__(self, model_dir=MODEL_DIR, source_lange: str = 'en') -> None: |
| whisper_model = 'large-v3-turbo-q5_0' |
| with Timer("load whisper"): |
| self.model = Model( |
| model=whisper_model, |
| models_dir=str(model_dir), |
| print_realtime=False, |
| print_progress=False, |
| print_timestamps=False, |
| translate=False, |
| |
| temperature=0., |
| no_context=True |
| ) |
| self._warmup() |
|
|
| def _warmup(self): |
| fake_audio = np.random.randn(16000).astype(np.float32) |
| self.model.transcribe(fake_audio, print_progress=False) |
|
|
| @staticmethod |
| def config_language(language): |
| if language == "zh": |
| return WHISPER_PROMPT_ZH |
| elif language == "en": |
| return WHISPER_PROMPT_EN |
| raise ValueError(f"Unsupported language : {language}") |
|
|
| def transcribe(self, audio: np.ndarray, language): |
| prompt = self.config_language(language) |
| try: |
| with Timer("whisper inference") as t: |
| segments = self.model.transcribe( |
| audio, |
| initial_prompt=prompt, |
| language=language, |
| |
| split_on_word=True, |
| |
| ) |
| text = "".join([s.text for s in segments]) |
| return text, t.duration |
| except Exception as e: |
| logger.error(e) |
| return [] |
|
|
if __name__ == '__main__':
    # Smoke test: transcribe one local recording and print the result.
    # (The redundant local re-import of `read_audio` was removed — it is
    # already imported at the top of the file.)
    whisper = WhisperCPP()
    audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
    text, time_cost = whisper.transcribe(audio, "zh")
    print(text)
    print(time_cost)