| from typing import Dict, List, Any |
| import sys, os, re |
| from tqdm import tqdm |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig |
| from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch |
| from IndicTransTokenizer.tokenizer import IndicTransTokenizer |
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that translates SRT transcripts to Indic
    languages with ai4bharat's IndicTrans2 en->indic 1B model."""

    def __init__(self, direction="en-indic", quantization=""):
        """Load the IndicTrans2 model and tokenizer.

        Args:
            direction: translation direction passed to IndicTransTokenizer
                (default "en-indic").
            quantization: "4-bit" or "8-bit" for bitsandbytes quantization,
                or "" to load the full model and cast it to fp16.
        """
        self.model_name = "ai4bharat/indictrans2-en-indic-1B"

        # SRT cue-index lines are numeric-only; timestamp lines look like
        # "00:00:01,000 --> 00:00:02,000".
        self.utterance_pattern = re.compile(r"^\d+$")
        self.timestamp_pattern = re.compile(r"(\d+:\d+:\d+,\d+)\s*-->\s*(\d+:\d+:\d+,\d+)")

        self.BATCH_SIZE = 16
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = None
        self.tokenizer = None

        if quantization == "4-bit":
            qconfig = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        elif quantization == "8-bit":
            # NOTE(review): BitsAndBytesConfig defines no bnb_8bit_* parameters;
            # these two kwargs fall into **kwargs and are ignored, so only
            # load_in_8bit takes effect here — confirm intent.
            qconfig = BitsAndBytesConfig(
                load_in_8bit=True,
                bnb_8bit_use_double_quant=True,
                bnb_8bit_compute_dtype=torch.bfloat16,
            )
        else:
            qconfig = None

        self.tokenizer = IndicTransTokenizer(direction=direction)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            quantization_config=qconfig,
        )

        # bitsandbytes-quantized models are placed on device during load;
        # only move/half-cast the unquantized model ourselves.
        if qconfig is None:  # fix: identity check instead of `== None`
            self.model = self.model.to(self.DEVICE)
            self.model.half()

        self.model.eval()
|
|
|
|
| def batch_translate(self, input_sentences, src_lang, tgt_lang): |
| translations = [] |
| for i in range(0, len(input_sentences), self.BATCH_SIZE): |
| batch = input_sentences[i : i + self.BATCH_SIZE] |
|
|
| |
| batch, entity_map = preprocess_batch( |
| batch, src_lang=src_lang, tgt_lang=tgt_lang |
| ) |
|
|
| |
| inputs = self.tokenizer( |
| batch, |
| src=True, |
| truncation=True, |
| padding="longest", |
| return_tensors="pt", |
| return_attention_mask=True, |
| ).to(self.DEVICE) |
|
|
| |
| with torch.no_grad(): |
| generated_tokens = self.model.generate( |
| **inputs, |
| use_cache=True, |
| min_length=0, |
| max_length=256, |
| num_beams=5, |
| num_return_sequences=1, |
| ) |
|
|
| |
| generated_tokens = self.tokenizer.batch_decode( |
| generated_tokens.detach().cpu().tolist(), src=False |
| ) |
|
|
| |
| translations += postprocess_batch( |
| generated_tokens, lang=tgt_lang, placeholder_entity_map=entity_map |
| ) |
|
|
| del inputs |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| return translations |
|
|
|
|
| def read_srt(self, srt_path): |
| data = [] |
| with open(srt_path, 'r', encoding='utf-8') as fp: |
| utterance_ind = "" |
| start_end = "" |
| text = "" |
| for ind, line in enumerate(fp.readlines()): |
| line = line.strip() |
| if re.search(self.utterance_pattern, line) is not None: |
| utterance_ind = line |
| elif re.search(self.timestamp_pattern, line) is not None: |
| start_end = line |
| else: |
| text += line |
|
|
| if utterance_ind!='' and start_end!='' and text!='': |
| data.append({'utterance_ind': utterance_ind, 'start_end': start_end, 'text': text}) |
| utterance_ind = '' |
| start_end = '' |
| text = '' |
|
|
| return data |
|
|
| def test(self, inputs) -> List[Dict[str, Any]]: |
| """ |
| data args: |
| inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str') |
| kwargs |
| Return: |
| A :obj:`list` | `dict`: will be serialized and returned |
| """ |
|
|
| src_lang = inputs["src_lang"] |
| tgt_lang = inputs["tgt_lang"] |
| transcript_path = inputs["transcript_path"] |
|
|
| output_translations = [] |
| if self.model is not None: |
| transcriptions = self.read_srt(transcript_path) |
| trans_sents = [entry['text'] for entry in transcriptions] |
| indic_translations = self.batch_translate(trans_sents, src_lang, tgt_lang) |
|
|
| for i in tqdm(range(len(transcriptions))): |
| entry = transcriptions[i] |
| entry['text'] = indic_translations[i] |
| output_translations.append(entry) |
|
|
| return output_translations |
| else: |
| return [] |
|
|
| def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
| """ |
| data args: |
| inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str') |
| kwargs |
| Return: |
| A :obj:`list` | `dict`: will be serialized and returned |
| """ |
|
|
| inputs = data.pop("inputs",data) |
|
|
| src_lang = inputs["src_lang"] |
| tgt_lang = inputs["tgt_lang"] |
| transcript_path = inputs["transcript_path"] |
|
|
| output_translations = [] |
| if self.model is not None: |
| transcriptions = self.read_srt(transcript_path) |
| trans_sents = [entry['text'] for entry in transcriptions] |
| indic_translations = self.batch_translate(trans_sents, src_lang, tgt_lang) |
|
|
| for i in tqdm(range(len(transcriptions))): |
| entry = transcriptions[i] |
| entry['text'] = indic_translations[i] |
| output_translations.append(entry) |
|
|
| return output_translations |
| else: |
| return [] |