| import argparse |
| import time |
| import wave |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import numpy as np |
| import sherpa_onnx |
| from huggingface_hub import hf_hub_download |
|
|
|
|
def get_args():
    """Build the CLI parser and return the parsed command-line arguments.

    Defaults are displayed in --help via ArgumentDefaultsHelpFormatter.
    """
    ap = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ap.add_argument("--lang", type=str, required=True, help="Language code (e.g., 'en', 'fr', 'de')")
    ap.add_argument("--hf-token", type=str, required=True, help="Hugging Face access token for private model repository")
    ap.add_argument("--num-threads", type=int, default=1, help="Number of threads for neural network computation")
    ap.add_argument("--decoding-method", type=str, default="greedy_search", help="Valid values: greedy_search and modified_beam_search")
    ap.add_argument("--max-active-paths", type=int, default=4, help="Used only when --decoding-method is modified_beam_search.")
    ap.add_argument("--lm", type=str, default="", help="Used only when --decoding-method is modified_beam_search. Path of language model.")
    ap.add_argument("--lm-scale", type=float, default=0.1, help="Used only when --decoding-method is modified_beam_search. Scale of language model.")
    ap.add_argument("--provider", type=str, default="cpu", help="Valid values: cpu, cuda, coreml")
    ap.add_argument("--hotwords-file", type=str, default="", help="The file containing hotwords, one word/phrase per line.")
    ap.add_argument("--hotwords-score", type=float, default=1.5, help="Hotword score for biasing word/phrase. Used only if --hotwords-file is given.")
    ap.add_argument("sound_files", type=str, nargs="+", help="The input sound file(s) to decode. Must be WAVE format, single channel, 16-bit.")

    return ap.parse_args()
|
|
|
|
def assert_file_exists(filename: str):
    """Assert that *filename* points to an existing regular file.

    Args:
        filename: Path to check.

    Raises:
        AssertionError: if the path does not exist or is not a regular file.
    """
    # Include the offending path in the message; the previous f-string had no
    # placeholder, so failures never said which file was missing.
    assert Path(filename).is_file(), f"{filename} does not exist!"
|
|
|
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono 16-bit PCM WAVE file.

    Args:
        wave_filename: Path to a single-channel, 16-bit WAVE file.

    Returns:
        A tuple ``(samples, sample_rate)`` where ``samples`` is a 1-D
        float32 array scaled into [-1, 1).
    """
    with wave.open(wave_filename) as wf:
        # Only mono, 16-bit input is supported.
        assert wf.getnchannels() == 1, wf.getnchannels()
        assert wf.getsampwidth() == 2, wf.getsampwidth()
        raw = wf.readframes(wf.getnframes())
        pcm = np.frombuffer(raw, dtype=np.int16)
        # Normalize int16 PCM to float32 in [-1, 1).
        return pcm.astype(np.float32) / 32768, wf.getframerate()
|
|
|
|
def download_models(language_code, hf_token):
    """Download the transducer model files for a language from Hugging Face.

    Fetches the encoder, decoder, and joiner ONNX models plus the tokens.txt
    vocabulary from the (private) repo, authenticating with *hf_token*.

    Args:
        language_code: Prefix of the model files in the repo (e.g. "en").
        hf_token: Hugging Face access token for the private repository.

    Returns:
        Dict mapping "encoder"/"decoder"/"joiner"/"tokens" to the local
        cached file paths returned by ``hf_hub_download``.
    """
    repo_id = "Banafo/test-onnx"

    model_filenames = {
        "encoder": f"{language_code}_encoder.onnx",
        "decoder": f"{language_code}_decoder.onnx",
        "joiner": f"{language_code}_joiner.onnx",
        "tokens": f"{language_code}_tokens.txt",
    }

    model_paths = {}
    for model_name, filename in model_filenames.items():
        # The previous messages were placeholder-less f-strings that printed
        # a literal "(unknown)"; report the actual file being fetched.
        print(f"Downloading {filename}...")
        model_paths[model_name] = hf_hub_download(
            repo_id=repo_id, filename=filename, token=hf_token
        )
        print(f"Loaded {model_name}")

    return model_paths
|
|
|
|
def main():
    """Entry point: fetch models, build a recognizer, decode the input files."""
    args = get_args()

    paths = download_models(args.lang, args.hf_token)

    # Streaming (online) transducer recognizer; these models use a 16 kHz
    # front-end with 80-dim features.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=paths["tokens"],
        encoder=paths["encoder"],
        decoder=paths["decoder"],
        joiner=paths["joiner"],
        num_threads=args.num_threads,
        provider=args.provider,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.decoding_method,
        max_active_paths=args.max_active_paths,
        lm=args.lm,
        lm_scale=args.lm_scale,
        hotwords_file=args.hotwords_file,
        hotwords_score=args.hotwords_score,
    )

    print("Started!")
    started_at = time.time()

    total_duration = 0
    streams = []
    for sound_file in args.sound_files:
        assert_file_exists(sound_file)
        audio, rate = read_wave(sound_file)
        total_duration += len(audio) / rate

        stream = recognizer.create_stream()
        stream.accept_waveform(rate, audio)
        # Feed ~0.66 s of trailing silence so the model flushes the last
        # frames of real audio before the stream is closed.
        stream.accept_waveform(rate, np.zeros(int(0.66 * rate), dtype=np.float32))
        stream.input_finished()
        streams.append(stream)

    # Batch-decode all streams together until none has frames left.
    while True:
        active = [s for s in streams if recognizer.is_ready(s)]
        if not active:
            break
        recognizer.decode_streams(active)

    transcripts = [recognizer.get_result(s) for s in streams]
    finished_at = time.time()
    print("Done!")

    for sound_file, transcript in zip(args.sound_files, transcripts):
        print(f"{sound_file}\n{transcript}")
        print("-" * 10)

    elapsed_seconds = finished_at - started_at
    rtf = elapsed_seconds / total_duration
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")
|
|
|
|
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|