| from pathlib import Path |
| import time |
| import csv |
| from funasr import AutoModel |
|
|
|
|
def main():
    """Smoke-test the Fun-ASR-Nano model on a single mp3 clip.

    Runs the same clip through three decoding variants (explicit language
    hint, auto-detect, hotword biasing) and prints each transcript.
    NOTE(review): this function is defined but never called by the
    ``__main__`` guard below — invoke it manually if needed.
    """
    model = AutoModel(
        model="/Users/jeqin/work/code/Fun-ASR-Nano-2512",
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
    )

    # Plain string: the original used an f-string with no placeholders.
    wav_path = "/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/zhengyaowei-part1.mp3"

    # Each entry is one decoding variant layered on the shared base kwargs;
    # the original repeated the whole generate/print sequence three times.
    variants = [
        {"language": "中文", "itn": True},   # explicit language hint
        {"itn": True},                        # language auto-detect
        {"itn": True, "hotwords": ["头数", "llama", "decode", "query"]},  # hotword biasing
    ]
    for extra in variants:
        res = model.generate(input=[wav_path], cache={}, batch_size=1, **extra)
        print(res[0]["text"])
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def save_csv(file_path, rows):
    """Write ``rows`` (an iterable of row sequences) to ``file_path`` as CSV.

    Fix: ``csv.writer`` requires the file opened with ``newline=""``;
    without it every row is followed by a blank line on Windows.
    """
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
|
|
def load_model():
    """Build the FunASR AutoModel from the local MLT-Nano checkpoint on MPS.

    Prints the load time and returns the ready-to-use model instance.
    """
    start = time.time()
    asr_model = AutoModel(
        model="/Users/jeqin/work/code/Fun-ASR-MLT-Nano-2512",
        trust_remote_code=True,
        remote_code="./model.py",
        device="mps",
        disable_update=True,
    )
    print("load model cost:", time.time() - start)
    return asr_model
|
|
def inference(model, wav_path):
    """Run one ASR pass over ``wav_path`` with ``model``.

    Returns a ``(transcript, elapsed_seconds)`` tuple.
    """
    started = time.time()
    output = model.generate(input=[str(wav_path)], cache={}, batch_size=1)
    transcript = output[0]["text"]
    return transcript, time.time() - started
|
|
def run_audio_clips():
    """Transcribe every wav in the 10s-mix clip folder and save a CSV report.

    Fix: the original built ``rows`` and a ``file_name`` but never wrote the
    file; the missing ``save_csv`` call (as done in ``run_recordings``) is
    added so the results are actually persisted.
    """
    model = load_model()
    clip_dir = Path("/Users/jeqin/work/code/TestTranslator/test_data/audio_clips/10s-mix")
    rows = [["file_name", "time", "inference_result"]]
    for audio in sorted(clip_dir.glob("*.wav")):
        print(audio)
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        rows.append([audio.name, round(cost, 3), text])
    save_csv("csv/funasr_nano.csv", rows)
| |
|
|
|
|
def run_recordings():
    """Transcribe the numbered recordings, score each transcript against its
    reference text, and write the results to csv/funasr_nano.csv."""
    from scripts.asr_utils import get_origin_text_dict, get_text_distance

    model = load_model()
    recordings = Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/")
    references = get_origin_text_dict()
    rows = [["file_name", "time", "inference_result"]]
    # Files are named by integer index, so sort numerically on the stem.
    for wav in sorted(recordings.glob("*.wav"), key=lambda p: int(p.stem)):
        print("processing: ", wav)
        text, cost = inference(model, wav)
        print("inference cost: ", cost)
        print(text)
        distance, _normalized, diff = get_text_distance(references[wav.stem], text)
        rows.append([wav.name, round(cost, 3), text, distance, diff])
    save_csv("csv/funasr_nano.csv", rows)
|
|
def run_test_wenet():
    """Transcribe up to 5000 WeNet test clips and dump a JSON results report.

    Fixes: the per-clip cost/transcript were logged twice (once right after
    inference and again at the end of the loop body); the manual counter is
    replaced by ``enumerate``; the local ``json`` import is hoisted to the
    top of the function instead of sitting just before its first use.
    """
    import json

    from test_data.audios import read_wenet

    model = load_model()
    result_list = []
    for index, (audio, sentence) in enumerate(read_wenet(count_limit=5000), start=1):
        print(f"processing {index}: {audio}")
        text, cost = inference(model, audio)
        print("inference cost: ", cost)
        print(text)
        result_list.append({
            "index": index,
            "audio_path": audio.name,
            "reference": sentence,
            "inference_time": round(cost, 3),
            "inference_result": text,
        })

    with open("csv/funasr_nano_wenet.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
if __name__ == "__main__":
    # Script entry point: only the recordings benchmark is enabled here.
    # main(), run_audio_clips(), and run_test_wenet() are defined above but
    # must be invoked manually (swap the call below) to run them.
    run_recordings()
| |
| |