| import argparse |
| import os |
| import time |
| import logging |
|
|
| from fireredasr_axmodel import FireRedASRAxModel |
|
|
# Root-logger setup: attach a stream handler and let INFO-and-above records
# through both the logger and the handler.
logger = logging.getLogger()
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel(logging.INFO)
logger.addHandler(logger_stream_hander)
logger.setLevel(logging.INFO)
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the FireRedASRAxModel test script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, so calling ``parse_args()`` with no
            arguments behaves exactly as before.

    Returns:
        argparse.Namespace with the attributes ``encoder``,
        ``decoder_loop``, ``cmvn``, ``dict``, ``spm_model``, ``wavlist``,
        ``hypo``, ``beam_size``, ``nbest``, ``decode_max_len`` and
        ``max_dur``.
    """
    parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")

    # Model and resource files.
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to axmodel encoder",
    )
    parser.add_argument(
        "--decoder_loop",
        type=str,
        default="axmodel/decoder_loop.axmodel",
        help="Path to axmodel decoder loop",
    )
    parser.add_argument(
        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
    )
    parser.add_argument(
        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model",
    )

    # Input list and output file.
    parser.add_argument(
        "--wavlist", type=str, default="wavlist.txt", help="File to wav path list"
    )
    parser.add_argument(
        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
    )

    # Decoding options.
    parser.add_argument("--beam_size", type=int, default=1, help="")
    parser.add_argument("--nbest", type=int, default=1, help="")
    parser.add_argument(
        "--decode_max_len", type=int, default=128, help="max token len"
    )
    parser.add_argument("--max_dur", type=int, default=10, help="max audio len")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
def parse_wavlist(wavlist: str):
    """Read a wav list file and return the listed paths that exist.

    Every line of ``wavlist`` is stripped of surrounding whitespace and
    treated as one path. Blank lines are skipped silently. A non-blank
    path that does not exist is reported on stdout and skipped.

    Args:
        wavlist: Path to a text file containing one wav path per line.

    Returns:
        List of the existing paths, in the order they appear in the file.

    Raises:
        OSError: If ``wavlist`` itself cannot be opened.
    """
    wavpaths = []
    with open(wavlist) as f:
        for raw_line in f:
            path = raw_line.strip()
            if not path:
                # A blank or whitespace-only line (e.g. a trailing newline)
                # is not a path; reporting it would print a confusing
                # " doesn't exist." message.
                continue
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)

    return wavpaths
|
|
|
|
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if not denominator:
        return float("nan")
    return numerator / denominator


def main():
    """Transcribe every wav in the wav list and report real-time factors.

    Builds a ``FireRedASRAxModel`` from the parsed command-line options,
    transcribes the listed wavs one at a time, writes each result's
    ``"text"`` as one line of the ``--hypo`` file, and logs the per-file
    and overall ratio of transcribe duration to wav duration (RTF).
    """
    args = parse_args()
    print(args)

    model = FireRedASRAxModel(
        args.encoder,
        args.decoder_loop,
        args.cmvn,
        args.dict,
        args.spm_model,
        decode_max_len=args.decode_max_len,
        audio_dur=args.max_dur,
    )

    # Read the wav list before opening the output file: opening with "wt"
    # truncates it, so an unreadable wav list must not wipe earlier results.
    wavlist = parse_wavlist(args.wavlist)

    total_wav_durations = 0
    total_transcribe_durations = 0

    # The context manager closes the hypothesis file even if transcription
    # raises part-way through the list.
    with open(args.hypo, "wt") as wf:
        for wav in wavlist:
            batch_wav = [wav]  # transcribe() is called with a batch of one
            result, wav_durations, transcribe_durations = model.transcribe(
                batch_wav, args.beam_size, args.nbest
            )

            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            # Zero-length audio would otherwise raise ZeroDivisionError.
            rtf = _safe_ratio(transcribe_durations, wav_durations)
            logger.info(f"(Real time factor) RTF: {rtf}")

            text = result["text"]
            logger.info(f"text: {text}")
            logger.info("")
            wf.write(f"{text}\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    # An empty wav list leaves the total at zero; report NaN instead of crashing.
    avg_rtf = _safe_ratio(total_transcribe_durations, total_wav_durations)
    logger.info(f"AVG RTF: {avg_rtf}")
|
|
|
|
# Run the test only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|