# TestTranslator/scripts/run_whisper_finetuned_with_punc_ov.py
# NOTE(review): the lines below were non-Python page residue (author name,
# "update scripts" commit message, hash 42742c6) captured by the extraction;
# commented out so the module parses.
# yujuanqin — update scripts — 42742c6
import argparse
import os
import time
from pathlib import Path
import csv
import json
import yaml
from typing import List, Dict, Optional
import librosa
import openvino_genai
import numpy as np
from scripts.funasr_ct.ct_transformer import CT_Transformer
from scripts.asr_utils import get_origin_text_dict, get_text_distance
def save_csv(file_path, rows):
    """Write *rows* to *file_path* as a UTF-8 CSV file and report the path."""
    with open(file_path, "w", encoding="utf-8", newline="") as out:
        csv.writer(out).writerows(rows)
    print(f"write csv to {file_path}")
def load_audio(audio_path: str, sr: int = 16000):
    """Load an audio file as mono float32 samples resampled to *sr* Hz."""
    samples, _rate = librosa.load(audio_path, sr=sr, mono=True)
    return samples
def read_wav(filepath):
    """Read *filepath* resampled to 16 kHz and return the samples as a list."""
    samples, _rate = librosa.load(filepath, sr=16000)
    return samples.tolist()
def transcribe_file(
    audio_path: str,
    model,
    lang="en"
):
    """Transcribe a single audio file with the given Whisper pipeline.

    Args:
        audio_path: path to the wav file to transcribe.
        model: a loaded openvino_genai.WhisperPipeline (or compatible object
            exposing ``generate(samples, language=...)``).
        lang: language hint passed through to ``generate``.

    Returns:
        The pipeline's decode result rendered via ``str()``.
    """
    samples = read_wav(audio_path)
    result = model.generate(samples, language=lang)
    return str(result)
def load_model(device):
    """Load the Whisper ASR pipeline and the CT-Transformer punctuation model.

    Args:
        device: OpenVINO device string, e.g. "GPU" or "CPU".
            BUG FIX: the original body immediately overwrote this argument
            with the hard-coded value "GPU", making the parameter dead; it is
            now honored. All visible callers already pass "GPU", so existing
            behavior is unchanged.

    Returns:
        Tuple ``(asr, punc)`` — the WhisperPipeline and CT_Transformer.
    """
    # NOTE(review): machine-specific absolute paths; consider promoting these
    # to parameters or a config file.
    model_path = r"D:\yujuan\models\whisper-turbo-39000-int8p\whisper-turbo-39000-int8p"
    punc_model = r"D:\yujuan\models\funasr_ct\ct-punc"
    t0 = time.time()
    asr = openvino_genai.WhisperPipeline(model_path, device)
    punc = CT_Transformer(punc_model, device=device)
    print("load model time: ", time.time() - t0)
    return asr, punc
def inference(audio: Path, asr, punc, lang):
    """Run ASR on one audio file, adding punctuation for Chinese output.

    Args:
        audio: path to the wav file.
        asr: loaded Whisper pipeline (passed to ``transcribe_file``).
        punc: CT_Transformer punctuation model; only invoked for Chinese.
        lang: Whisper language token, e.g. "<|zh|>" or "<|en|>".

    Returns:
        Tuple ``(text, seconds)``. On failure returns ``("", elapsed)``.
        BUG FIX: the original fell through and implicitly returned None on
        error, which crashed every caller doing ``text, t = inference(...)``.
    """
    t0 = time.time()
    try:
        asr_text = transcribe_file(
            str(audio), asr, lang
        )
        t1 = time.time()
        # Punctuation restoration is only applied to Chinese transcripts.
        if lang == "<|zh|>":
            punc_text = punc(asr_text)[0]
        else:
            punc_text = asr_text
        t2 = time.time()
        print(f"{audio.name} -> {asr_text} -> {punc_text}; \n asr cost: {t1-t0}; punc cost: {t2-t1}")
        return punc_text, t2 - t0
    except Exception as e:
        print(f"{audio.name} -> 失败: {e}")
        return "", time.time() - t0
def run_test_audios():
    """Transcribe the English test-audio set and dump results to a CSV file."""
    device = "GPU"  # GPU can be used as well
    lang = "<|en|>"
    asr, punc = load_model(device)
    audios = Path(r"D:\yujuan\TestTranslator\tests\test_data\test_audios")
    header = ["file_name", "time", "inference_result"]
    rows = [header]
    wav_files = sorted(audios.glob("*en*/*.wav"))  # *s/randomforest*.wav"
    for audio in wav_files:
        text, elapsed = inference(audio, asr, punc, lang)
        rows.append([f"{audio.parent.name}/{audio.name}", elapsed, text])
    save_csv("csv/finetune_whisper_with_punc.csv", rows)
def run_recordings():
    """Transcribe Chinese recordings, score them against reference text, and
    save the per-file results (including edit distance) to CSV."""
    device = "GPU"  # GPU can be used as well
    lang = "<|zh|>"
    asr, punc = load_model(device)
    audios = Path(r"D:\yujuan\TestTranslator\tests\test_data\recordings")
    rows = [["file_name", "time", "inference_result"]]
    original = get_origin_text_dict()
    # Recordings are named with integer stems; iterate in numeric order.
    ordered = sorted(audios.glob("*.wav"), key=lambda wav: int(wav.stem))
    for audio in ordered:
        text, elapsed = inference(audio, asr, punc, lang)
        d, nd, diff = get_text_distance(original[audio.stem], text)
        rows.append([audio.name, round(elapsed, 3), text, d, round(nd, 3), diff])
    save_csv("csv/finetune_whisper_with_punc.csv", rows)
def run_test_dataset():
    """Run the fine-tuned Whisper model over the ``dataset.txt`` test set and
    save per-utterance results as JSON.

    Any error or Ctrl+C stops iteration, but results collected so far are
    still written out.

    Fixes vs. original: removed the redundant local ``import json`` (already
    imported at module level) and the unused ``rows`` header list.
    """
    from test_data.audios import read_dataset
    device = "GPU"
    lang = "<|zh|>"
    asr, punc = load_model(device)
    test_data = Path("../tests/test_data/dataset.txt")
    audio_parent = Path("../tests/test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            text, t = inference(audio_parent / audio_path, asr, punc, lang)
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    # KeyboardInterrupt is not an Exception subclass, so both are needed to
    # guarantee partial results are saved on any stop.
    except (KeyboardInterrupt, Exception) as e:
        print(e)
    with open("csv/whisper_finetune_ov_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
def run_test_emilia():
    """Run the fine-tuned Whisper model over the Emilia ZH-B000000 subset
    (up to 5000 utterances) and save per-utterance results as JSON.

    Any error or Ctrl+C stops iteration, but results collected so far are
    still written out.

    Fix vs. original: removed the redundant local ``import json`` (already
    imported at module level).
    """
    from test_data.audios import read_emilia
    device = "GPU"
    lang = "<|zh|>"
    asr, punc = load_model(device)
    parent = Path("../tests/test_data/ZH-B000000")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_emilia(parent, count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path.name}")
            text, t = inference(audio_path, asr, punc, lang)
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    # KeyboardInterrupt is not an Exception subclass, so both are needed to
    # guarantee partial results are saved on any stop.
    except (KeyboardInterrupt, Exception) as e:
        print(e)
    with open("csv/whisper_finetune_emilia_ov_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Script entry point: runs the Emilia ZH evaluation by default.
    # main()  # NOTE(review): no main() is defined in this file — stale reference
    run_test_emilia()