import argparse
import os
import re

import numpy as np
import torch
from tqdm import tqdm
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
|
|
# Maximum number of characters packed into one chunk sent to the translator.
MAX_LENGTH = 500
# Sentinel inserted in place of speaker tags; used to split/rejoin segments.
MAGIC_STRING = "[$&]"
# When True, intermediate raw/translated text is also written to disk.
DEBUG = False
|
|
# Map every human-readable language name to the code pair it needs:
# "transcriber" is the whisper-style short code (e.g. "en"), "translator"
# is the full mbart code (e.g. "en_XX").  Languages that whisper does not
# support are skipped entirely.
language_dict = {
    name: {
        "transcriber": code.split("_")[0].lower(),
        "translator": code,
    }
    for name, code in LANGUAGE_NAME_TO_CODE.items()
    if code.split("_")[0].lower() in WHISPER_LANGUAGES
}
|
|
def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"):
    """Translate ``transcribed_text`` with an MBart-50 many-to-many model.

    Args:
        transcribed_text: Source-language text to translate.
        source_languaje: Human-readable source language name (key of ``language_dict``).
        target_languaje: Human-readable target language name (key of ``language_dict``).
        translate_model: Loaded ``MBartForConditionalGeneration`` instance.
        translate_tokenizer: Matching ``MBart50TokenizerFast`` instance.
        device: Torch device string the model lives on (default ``"cpu"``).

    Returns:
        The translated text as a single string.
    """
    source_languaje_code = language_dict[source_languaje]["translator"]
    target_languaje_code = language_dict[target_languaje]["translator"]

    # BUG FIX: the source code was computed but never applied, so the tokenizer
    # kept its default src_lang and any non-default source language was
    # tokenized as the wrong language. MBart-50 requires src_lang to be set
    # before encoding.
    translate_tokenizer.src_lang = source_languaje_code

    encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device)
    # Inference only: disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        generated_tokens = translate_model.generate(
            **encoded,
            # Force the decoder to start in the target language.
            forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code],
        )
    return translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
def _strip_to_raw_text(transcription):
    """Drop srt structure (indices, timestamps, blanks) and join the spoken text,
    replacing each speaker tag with MAGIC_STRING as a segment separator."""
    index_re = re.compile(r"\d+$")
    stamp_re = re.compile(r"\d\d:\d\d:\d\d,\d+ --> \d\d:\d\d:\d\d,\d+")
    pieces = []
    progress_bar = tqdm(total=len(transcription), desc='Concatenate transcriptions progress')
    for line in transcription:
        if not (index_re.match(line) or stamp_re.match(line) or line == ""):
            pieces.append(re.sub(r"\[SPEAKER_\d\d\]:", MAGIC_STRING, line) + " ")
        progress_bar.update(1)
    progress_bar.close()
    return "".join(pieces)


def _translate_segments(segments, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Greedily pack segments into chunks under MAX_LENGTH characters and
    translate chunk by chunk, keeping MAGIC_STRING separators in the text."""
    translated = ""
    chunk = segments[0] + MAGIC_STRING
    progress_bar = tqdm(total=len(segments), desc='Translate transcriptions progress')
    progress_bar.update(1)
    for segment in segments[1:]:
        if len(chunk) + len(segment) < MAX_LENGTH:
            chunk += segment + MAGIC_STRING
        else:
            translated += translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
            chunk = segment + MAGIC_STRING
        progress_bar.update(1)
    # Flush the last (possibly only) chunk.
    translated += translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
    progress_bar.close()
    return translated


def _rebuild_srt(transcription, translated_segments):
    """Reattach srt indices, timestamps and speaker tags to translated segments,
    preserving the original file's line structure."""
    index_re = re.compile(r"\d+$")
    stamp_re = re.compile(r"\d\d:\d\d:\d\d,\d+ --> \d\d:\d\d:\d\d,\d+")
    speaker_re = re.compile(r"\[SPEAKER_\d\d\]:")
    out = []
    i = 0
    progress_bar = tqdm(total=len(transcription), desc='Add time stamps to translated transcriptions progress')
    for line in transcription:
        if index_re.match(line) or stamp_re.match(line) or line == "":
            # Structural lines are copied through unchanged.
            out.append(f"{line}\n")
        elif i < len(translated_segments):
            # ROBUSTNESS FIX: the original indexed translated_segments[i]
            # unconditionally and raised IndexError when the translation
            # produced fewer segments than the srt had text lines.
            segment = translated_segments[i]
            # The split leaves one leading space on each segment; drop it.
            if segment.startswith(" "):
                segment = segment[1:]
            speaker_match = speaker_re.match(line)
            speaker = speaker_match.group(0) if speaker_match else ""
            out.append(f"{speaker} {segment}\n")
            i += 1
        progress_bar.update(1)
    progress_bar.close()
    return "".join(out)


def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate an srt transcription file and write a translated srt.

    Reads ``transcription_file``, strips the srt structure, translates the
    text in MAX_LENGTH-sized chunks, then rebuilds the srt with the original
    timestamps and speaker tags.  Output goes to
    ``translated_transcriptions/<stem>_<target_languaje>.srt``.

    Args:
        transcription_file: Path to the input .srt file.
        source_languaje / target_languaje: Human-readable language names.
        translate_model / translate_tokenizer: Loaded MBart-50 model pair.
        device: Torch device string the model lives on.
    """
    output_folder = "translated_transcriptions"
    # PATH FIX: the original split("/")/split(".") crashed on bare filenames,
    # nested paths, and names containing extra dots.
    transcription_file_name = os.path.splitext(os.path.basename(transcription_file))[0]
    # Make sure the destination folder exists before any write.
    os.makedirs(output_folder, exist_ok=True)

    with open(transcription_file, "r") as f:
        transcription = f.read().splitlines()

    raw_transcription = _strip_to_raw_text(transcription)

    if DEBUG:
        with open(f"{output_folder}/{transcription_file_name}_raw.srt", "w") as f:
            f.write(raw_transcription)

    # Split back into per-speaker segments; drop the empty leading piece
    # produced when the text starts with a speaker tag.
    raw_transcription_list = raw_transcription.split(MAGIC_STRING)
    if raw_transcription_list[0] == "":
        raw_transcription_list = raw_transcription_list[1:]

    translated_transcription = _translate_segments(
        raw_transcription_list, source_languaje, target_languaje,
        translate_model, translate_tokenizer, device)

    if DEBUG:
        with open(f"{output_folder}/{transcription_file_name}_{target_languaje}_raw.srt", "w") as f:
            f.write(translated_transcription)

    translated_srt = _rebuild_srt(transcription, translated_transcription.split(MAGIC_STRING))

    output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}.srt"
    with open(output_file, "w") as f:
        f.write(translated_srt)
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the MBart-50 many-to-many model
    # once, then delegate all work to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("transcription_file", help="Transcribed text")
    parser.add_argument("--source_languaje", type=str, required=True)
    parser.add_argument("--target_languaje", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    print("Loading translation model")
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    translate_model = MBartForConditionalGeneration.from_pretrained(model_name).to(args.device)
    translate_tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    print("Translation model loaded")

    main(
        args.transcription_file,
        args.source_languaje,
        args.target_languaje,
        translate_model,
        translate_tokenizer,
        args.device,
    )
|
|