# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Yusen Sun,
# Xiao Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text-to-speech-token inference over ``|``-separated SeedTTS-style test lists.

Each input line is split on ``|``; fields 0-3 are kept as
``(ref_wav, ref_wav_tokens, id, text)`` and the last field is converted to
phones for generation. Each output line has the form
``ref_wav|ref_wav_tokens|id.wav|speech_tokens|text``.
"""
import argparse
from simple_infer import Text2TokenGenerator, dummy_encode_fn
from fairseq.dataclass.configs import FairseqConfig
import fileinput
from fairseq import utils, options
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from collections import namedtuple
import time
import logging
import sys
import os
from tqdm import tqdm

Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")


class T2USeedTTS(Text2TokenGenerator):
    """Text2TokenGenerator that reads ``|``-separated test-list files."""

    def __init__(self, args):
        super().__init__(args)

    def buffered_read(self, input, buffer_size):
        """Yield lists of parsed records, each list holding at most ``buffer_size``.

        Every non-blank line of ``input`` (read as UTF-8) is split on ``|``
        and turned into ``[fields[0], fields[1], fields[2], fields[3], phones]``,
        i.e. ``(ref_wav, ref_wav_tokens, id, text, phones)``, where ``phones``
        is ``self.text2phone`` applied to the LAST field. With exactly four
        fields, the last field and ``fields[3]`` are the same text.

        Blank lines are skipped; lines with fewer than four fields raise
        IndexError.
        """
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                line = src_str.strip()
                if not line:
                    # A blank (e.g. trailing) line has no fields to index.
                    continue
                fields = line.split("|")
                phones = self.text2phone(fields[-1])
                # (ref_wav, ref_wav_tokens, id, text, phones)
                buffer.append([fields[0], fields[1], fields[2], fields[3], phones])
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []
        if buffer:
            yield buffer

    def generate_for_text_file_input(self, input):
        """Generate speech tokens for each record of ``input`` without segmenting.

        Returns one dict per input record, in input order, with keys:
          * ``src_info``: the record produced by ``buffered_read``;
          * ``src_tokens``: the decoded source tokens (``[]`` when there is
            no source dictionary);
          * ``hypotheses``: up to ``cfg.generation.nbest`` dicts with
            ``hypo_tokens`` (list of str) and ``alignment``.
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []
        start_id = 0
        # Loop-invariant: which symbols to drop when detokenizing hypotheses.
        symbols_to_strip = get_symbols_to_strip_from_output(self.generator)
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            phone_lines = [record[-1] for record in inputs]
            results = []
            for batch in self.make_batches(phone_lines, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                logging.info(f"Processing batch of size: {bsz}")
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time

                # Constraints are not unpacked for reporting; record an empty
                # list per sentence.
                list_constraints = [[] for _ in range(bsz)]
                for i, (batch_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)
                ):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    results.append(
                        (
                            start_id + batch_id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": list_constraints[i],
                                "time": translate_time / len(translations),
                            },
                        )
                    )

            # Sort output to match input order.
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                output = {"src_tokens": []}
                # Ids were offset by start_id, so this recovers the position
                # of the record inside the current buffer.
                input_info = inputs[id_ - start_id]
                output["src_info"] = input_info
                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens, self.cfg.common_eval.post_process
                    )
                    # Sanity check: the decoded source should round-trip to
                    # the phone string fed to the model.
                    if src_str != input_info[-1]:
                        logging.warning("ERROR, input output mismatch!!")
                        logging.warning(f"{src_str}")
                        logging.warning(f"{input_info[-1]}")
                    output["src_tokens"] = src_str.split()

                # Process top predictions.
                output["hypotheses"] = []
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=symbols_to_strip,
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )
                hypo_outputs.append(output)
            logging.info(f"output records: {len(hypo_outputs)}")

            # Update running id_ counter.
            start_id += len(inputs)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs

    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        """Generate speech tokens for each record, splitting long inputs.

        Delegates each buffer of phone strings to
        ``self.generate_for_long_input_text(phones, max_segment_len=...)``,
        which returns ``(hypo_tokens, translate_time)``. Returns one dict per
        record: ``{"hypotheses": <tokens for that record>, "src_info": record}``.
        The per-record tokens are presumably a sequence of str (the caller
        joins them with spaces) — TODO confirm against the base class.
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")
            phones = [input_info[-1] for input_info in inputs]
            hypo_tokens, translate_time = self.generate_for_long_input_text(
                phones, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            for tok, info in zip(hypo_tokens, inputs):
                hypo_outputs.append({"hypotheses": tok, "src_info": info})
        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs


def _format_output_line(src_info, token_str):
    """Build one ``ref_wav|ref_tokens|id.wav|speech_tokens|text`` output line."""
    ref_wav = src_info[0]
    ref_token = src_info[1]
    test_id = src_info[2]
    text = src_info[3]
    return f"{ref_wav}|{ref_token}|{test_id}.wav|{token_str}|{text}"


def infer(unk_args, output_file, max_seg_len):
    """Run inference on ``cfg.interactive.input`` and write one line per record.

    Args:
        unk_args: remaining command-line arguments, forwarded to T2USeedTTS.
        output_file: path to write results to (UTF-8); stdout when None.
        max_seg_len: when <= 0, each record is generated in one pass and the
            top hypothesis is written; otherwise long inputs are split into
            segments of this length.
    """
    output_fp = sys.stdout
    if output_file is not None:
        # Text may be non-ASCII; write UTF-8 to match how the input is read.
        output_fp = open(output_file, "w", encoding="utf-8")
    try:
        t2u = T2USeedTTS(unk_args)
        logging.info(f"Using max-seg-len = {max_seg_len}")
        if max_seg_len <= 0:
            speech_tokens_info = t2u.generate_for_text_file_input(
                t2u.cfg.interactive.input
            )
            for infor in speech_tokens_info:
                token_str = " ".join(infor["hypotheses"][0]["hypo_tokens"])
                output_fp.write(_format_output_line(infor["src_info"], token_str) + "\n")
        else:
            logging.info(f"Split long text into segments of length: {max_seg_len}")
            speech_tokens_info = t2u.generate_for_long_text_input_file(
                t2u.cfg.interactive.input, max_segment_len=max_seg_len
            )
            for infor in speech_tokens_info:
                token_str = " ".join(infor["hypotheses"])
                output_fp.write(_format_output_line(infor["src_info"], token_str) + "\n")
        output_fp.flush()
    finally:
        # Never close the process-wide stdout; only close a file we opened.
        if output_fp is not sys.stdout:
            output_fp.close()
    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output",
        dest="output",
        required=False,
        default=None,
        help="output file",
    )
    parser.add_argument(
        "--max-seg-len",
        dest="max_seg_len",
        required=False,
        default=0,
        type=int,
        help="max segment length",
    )
    args, unknown_args = parser.parse_known_args()
    infer(unknown_args, args.output, args.max_seg_len)