| import argparse |
| import os |
| import time |
| from pathlib import Path |
| import csv |
|
|
| import torch |
| import librosa |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor |
|
|
def save_csv(file_path, rows):
    """Write *rows* to a CSV file at *file_path*.

    Args:
        file_path: Destination path of the CSV file (overwritten if present).
        rows: Iterable of row sequences, passed to ``csv.writer.writerows``.
    """
    # newline="" is required when handing a text file to the csv module;
    # without it every row is followed by a blank line on Windows.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
|
|
|
|
def load_audio(audio_path: str, sr: int = 16000):
    """Load an audio file as a mono waveform resampled to *sr* Hz."""
    waveform, _actual_sr = librosa.load(audio_path, sr=sr, mono=True)
    return waveform
|
|
|
|
def transcribe_file(
    audio_path: str,
    model,
    processor,
    language: str = "Chinese",
    task: str = "transcribe",
    timestamps: bool = False,
    max_new_tokens: int = 255,
):
    """Transcribe (or translate) a single audio file with a Whisper model.

    Args:
        audio_path: Path to the audio file; loaded as 16 kHz mono.
        model: A Whisper seq2seq model (any device; located via its parameters).
        processor: The matching WhisperProcessor.
        language: Source language name or code (e.g. "Chinese" / "zh").
        task: "transcribe" or "translate".
        timestamps: Whether to request timestamp tokens from generate().
        max_new_tokens: Cap on generated token count.

    Returns:
        Decoded text of the single sequence in the batch.
    """
    audio = load_audio(audio_path, sr=16000)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    device = next(model.parameters()).device
    input_features = inputs["input_features"].to(device)

    # Autocast only has an effect on CUDA; it is explicitly disabled elsewhere.
    with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
        generated_ids = model.generate(
            input_features=input_features,
            # Fix: language/task were accepted but silently ignored before.
            # Lower-cased because transformers' language lookup table uses
            # lower-case names/codes ("chinese" -> "zh").
            language=language.lower(),
            task=task,
            max_new_tokens=max_new_tokens,
            return_timestamps=timestamps,
        )

    # processor.batch_decode delegates to the tokenizer and accepts tensors
    # directly; the previous .cpu().numpy() round-trip was unnecessary.
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return text[0]
|
|
|
|
def main():
    """CLI entry point: transcribe one audio file, or batch-process a directory."""
    parser = argparse.ArgumentParser("Simple Whisper Inference")
    parser.add_argument("--model_path", type=str, default="whisper-large-v3-turbo-finetune",
                        help="本地合并模型路径或HF模型名")
    parser.add_argument("--input", type=str, required=True,
                        help="音频文件路径,或目录(将批量处理其中的音频)")
    parser.add_argument("--language", type=str, default="Chinese",
                        help="语言(如 Chinese / English / zh / en)")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="任务:转写或翻译")
    parser.add_argument("--timestamps", action="store_true", help="是否返回时间戳(若模型与版本支持)")
    parser.add_argument("--local_files_only", action="store_true", help="仅本地加载,不联网")
    parser.add_argument("--batch_exts", type=str, default=".wav,.mp3,.flac,.m4a",
                        help="当 --input 是目录时,处理这些后缀的文件,逗号分隔")
    args = parser.parse_args()

    processor = WhisperProcessor.from_pretrained(
        args.model_path,
        language=args.language,
        task=args.task,
        no_timestamps=not args.timestamps,
        local_files_only=args.local_files_only,
    )
    # fp16 only when CUDA is available; fp32 on CPU / MPS.
    model = WhisperForConditionalGeneration.from_pretrained(
        args.model_path,
        device_map="auto",
        local_files_only=args.local_files_only,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Clear forced_decoder_ids so the generation_config language/task settings
    # take effect on fine-tuned checkpoints.
    model.generation_config.language = args.language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()

    path = Path(args.input)
    if path.is_file():
        text = transcribe_file(
            str(path), model, processor,
            language=args.language, task=args.task, timestamps=args.timestamps
        )
        print(f"{path.name} -> {text}")
    elif path.is_dir():
        exts = {e.strip().lower() for e in args.batch_exts.split(",")}
        files = [p for p in path.rglob("*") if p.suffix.lower() in exts]
        if not files:
            print("目录中未找到可处理的音频文件。")
            return
        for p in sorted(files):
            try:
                t0 = time.time()
                text = transcribe_file(
                    str(p), model, processor,
                    language=args.language, task=args.task, timestamps=args.timestamps
                )
                t1 = time.time()
                print(f"{p.name} -> {text}; time cost: {t1-t0}")
            except Exception as e:
                print(f"{p.name} -> 失败: {e}")
    else:
        # Fix: a nonexistent path previously fell into the directory branch and
        # misleadingly reported that no audio files were found.
        print(f"输入路径不存在: {path}")
|
|
def load_model(
    model_path: str = "/Users/jeqin/Downloads/whisper-large-v3-turbo-finetune_1219",
    lang: str = "zh",
    device_map: str = "mps",
):
    """Load a local fine-tuned Whisper model and processor for evaluation runs.

    Generalized from hard-coded constants; defaults preserve the previous
    behavior, so existing ``load_model()`` callers are unaffected.

    Args:
        model_path: Local checkpoint directory.
        lang: Language hint for the processor and generation config.
        device_map: Device placement (e.g. "mps", "cpu", "auto").

    Returns:
        (model, processor) tuple, model in eval mode.
    """
    t0 = time.time()
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=lang,
        task="transcribe",
        no_timestamps=True,
        local_files_only=True,
    )
    # NOTE(review): dtype keys off CUDA availability, so on an MPS-only Mac
    # this always loads float32 — confirm whether fp16 on MPS was intended.
    model = WhisperForConditionalGeneration.from_pretrained(
        model_path,
        device_map=device_map,
        local_files_only=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Clear forced_decoder_ids so the generation_config language takes effect.
    model.generation_config.language = lang.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()
    print("load model time: ", time.time() - t0)
    return model, processor
|
|
def run_test_audios():
    """Transcribe the English 16 kHz test clips and dump per-file timings to CSV."""
    model, processor = load_model()
    clip_root = Path("../test_data/audio_clips/")
    results = [["file_name", "inference_time", "inference_result"]]
    for wav in sorted(clip_root.glob("*en-ac1-16k/*.wav")):
        try:
            started = time.time()
            transcript = transcribe_file(str(wav), model, processor)
            elapsed = time.time() - started
            print(f"{wav.name} -> {transcript}; time cost: {elapsed}")
            results.append([f"{wav.parent.name}/{wav.name}", elapsed, transcript])
        except Exception as err:
            print(f"{wav.name} -> 失败: {err}")
    save_csv("csv/fine-tune_whisper-0901.csv", results)
|
|
def run_recordings():
    """Transcribe the numbered recordings and score them against reference text.

    Each CSV row carries the transcript plus the three values returned by
    ``get_text_distance`` (distance, normalized distance, diff).
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance
    model, processor = load_model()
    audios = Path("../test_data/recordings/")
    # Fix: the original header listed only 3 columns while every data row
    # appended 6 values.
    rows = [["file_name", "time", "inference_result",
             "distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    # Recordings are named by integer stem; sort numerically, not lexically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print(audio)
        try:
            t0 = time.time()
            text = transcribe_file(
                str(audio), model, processor
            )
            t = time.time()-t0
            print(text)
            print("inference time:", t)
            d, nd, diff = get_text_distance(original[audio.stem], text)
            rows.append([audio.name, round(t, 3), text, d, round(nd,3), diff])
        except Exception as e:
            print(f"{audio.name} -> 失败: {e}")
    save_csv("csv/fine-tune_whisper.csv", rows)
|
|
|
|
def run_test_dataset():
    """Run the fine-tuned model over the AIShell dataset list and dump JSON results.

    Iterates (audio_path, reference, duration) triples from dataset.txt and
    records per-clip transcription plus timing.  Partial results are still
    written if the loop fails or is interrupted with Ctrl-C.
    """
    from test_data.audios import read_dataset
    model, processor = load_model()
    test_data = Path("../test_data/AIShell/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    # Removed the unused `rows` CSV header left over from the CSV variants.
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_parent/audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        # Ctrl-C falls through to the dump below so partial results survive.
        print(e)
    import json
    with open("csv/whisper_finetuned_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
def run_test_emilia():
    """Evaluate the fine-tuned model on an Emilia shard and dump JSON results."""
    from test_data.audios import read_emilia
    model, processor = load_model()
    shard_dir = Path("../test_data/ZH-B000008")
    records = []
    index = 0
    try:
        for clip_path, reference, clip_duration in read_emilia(shard_dir, count_limit=5000):
            index += 1
            print(f"processing {index}: {clip_path}")
            started = time.time()
            transcript = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(transcript)
            records.append({
                "index": index,
                "audio_path": clip_path.name,
                "reference": reference,
                "duration": clip_duration,
                "inference_time": round(elapsed, 3),
                "inference_result": transcript,
            })
    except Exception as err:
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
|
|
|
|
def run_test_st():
    """Evaluate the fine-tuned model on the ST test set and dump JSON results."""
    from test_data.audios import read_st
    model, processor = load_model()

    records = []
    index = 0
    try:
        for clip_path, reference in read_st(count_limit=5000):
            index += 1
            print(f"processing {index}: {clip_path}")
            started = time.time()
            transcript = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(transcript)
            records.append({
                "index": index,
                "audio_path": clip_path.name,
                "reference": reference,
                "inference_time": round(elapsed, 3),
                "inference_result": transcript,
            })
    except Exception as err:
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_st_results.json", "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
|
|
def run_test_wenet():
    """Evaluate the fine-tuned model on the WeNet test set and dump JSON results."""
    from test_data.audios import read_wenet
    model, processor = load_model()
    records = []
    index = 0
    try:
        for clip_path, reference in read_wenet(count_limit=5000):
            index += 1
            print(f"processing {index}: {clip_path}")
            started = time.time()
            transcript = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(transcript)
            records.append({
                "index": index,
                "audio_path": clip_path.name,
                "reference": reference,
                "inference_time": round(elapsed, 3),
                "inference_result": transcript,
            })
    except Exception as err:
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
|
|
if __name__ == "__main__":
    # Script entry point. Currently hard-wired to the WeNet evaluation run;
    # switch to main() to use the argparse CLI defined above.




    run_test_wenet()
|
|