| from transformers import AutoTokenizer, GPT2LMHeadModel |
| from datasets import load_dataset, Dataset, DatasetDict |
| import random |
| import string |
| import torch |
|
|
| from torchmetrics.text import WordErrorRate, CharErrorRate |
|
|
# Corpus-level metric objects from torchmetrics; called functionally in
# evaluate_model() as wer(preds, targets) / cer(preds, targets).
wer = WordErrorRate()
cer = CharErrorRate()
|
|
def process(text):
    """Normalize a word/phrase: lowercase, strip punctuation, trim spaces.

    Apostrophes are kept so contractions ("don't") survive; every other
    punctuation character is removed in a single translate() pass.

    Fix: the original trimmed spaces with a manual while-loop over
    text[0] / text[-1], which raised IndexError on empty input, all-space
    input, or input reduced to "" by punctuation removal.  str.strip(' ')
    performs the same space-only trimming (not all whitespace) safely.

    Args:
        text: raw word or phrase.

    Returns:
        The normalized string (possibly empty).
    """
    text = text.lower()

    # Remove all punctuation except the apostrophe (contractions).
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # Trim leading/trailing spaces only, matching the original loop's
    # behavior (it checked ' ' specifically, not arbitrary whitespace).
    return text.strip(' ')
|
|
| import jiwer |
| from edit_distance import SequenceMatcher |
def correct_text(text):
    """Normalize transcript text into lists of words via jiwer.

    Applies, in order: contraction expansion, lowercasing, whitespace
    collapsing/stripping, punctuation removal, then splits into a list of
    word lists (one inner list per input string).
    """
    steps = [
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.RemovePunctuation(),
        jiwer.ReduceToListOfListOfWords(),
    ]
    pipeline = jiwer.Compose(steps)
    return pipeline(text)
|
|
def align_gt_asr(gt, asr):
    """Align ground-truth and ASR word sequences using edit-distance opcodes.

    Args:
        gt:  list of reference (ground-truth) words.
        asr: list of hypothesis (ASR) words.

    Returns:
        A list of [gt_word, asr_word] pairs covering both sequences; ""
        on one side marks a deletion (gt side only) or insertion (asr
        side only).

    Fix: a 'replace' opcode with unequal span lengths previously dropped
    the unmatched tail words because zip() truncates to the shorter range,
    silently under-counting errors.  Leftover gt words are now emitted as
    deletions and leftover asr words as insertions.
    """
    sm = SequenceMatcher(a=gt, b=asr)
    best_path = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "delete":
            for i in range(i1, i2):
                best_path.append([gt[i], ""])
        elif tag == "replace" or tag == "equal":
            # Pair up the overlapping portion one-to-one.
            n = min(i2 - i1, j2 - j1)
            for k in range(n):
                best_path.append([gt[i1 + k], asr[j1 + k]])
            # Unmatched tail of an unequal 'replace' span ('equal' spans
            # always have matching lengths, so these loops are no-ops).
            for i in range(i1 + n, i2):
                best_path.append([gt[i], ""])
            for j in range(j1 + n, j2):
                best_path.append(["", asr[j]])
        elif tag == "insert":
            for j in range(j1, j2):
                best_path.append(["", asr[j]])
    return best_path
|
|
# Half precision dtype.  NOTE(review): defined but not visibly used in this
# chunk — presumably consumed by model-loading code elsewhere; confirm.
dtype = torch.float16

# Pre-tokenized LibriSpeech saved via DatasetDict.save_to_disk.
dataset_name = "./../libripseech_tokenized"
dataset = DatasetDict.load_from_disk(dataset_name)

# Biasing list: one rare word per line, normalized the same way as text
# so membership checks during scoring are consistent.
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]

tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
# Left padding so, in a batch, every prompt ends at the same position and
# generation continues directly from the prompt.
tokenizer.padding_side = "left"

# Markers delimiting the biasing-word section of the prompt.
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
sot_token = tokenizer.encode("<|startoftranscript|>")[0]
eot_token = tokenizer.encode("<|endoftranscript|>")[0]

from math import ceil
from tqdm import tqdm

val_bs = 32            # evaluation batch size
n_bwords = 25          # number of biasing words placed in each prompt
context_length = 2048  # total sequence budget (prompt + generation)
|
|
def prepare(element):
    """Build the generation prompt for one dataset example.

    The prompt is: audio tokens, then exactly n_bwords biasing words
    (the example's own words, padded with random distractors from the
    global rare-word list when too few, or subsampled when too many),
    then the transcript-start marker.

    Returns a dict with the prompt string ("data") and the chosen
    biasing words ("context").
    """
    # Serialize audio token ids as special-token text.
    audio_part = "".join(f"<|audio:{tkn}|>" for tkn in element["audio_tokens"])

    bias_words = element["b_words"]
    if len(bias_words) < n_bwords:
        # Pad with random distractors from the global rare-word list.
        chosen = bias_words + random.sample(rarewords, n_bwords - len(bias_words))
    else:
        # Too many: subsample down to the prompt budget.
        chosen = random.sample(bias_words, n_bwords)
    random.shuffle(chosen)

    prompt = (
        audio_part
        + "<|startofprompt|>" + "<|sepofprompt|>".join(chosen) + "<|endofprompt|>"
        + "<|startoftranscript|>"
    )
    return {"data": prompt, "context": chosen}
|
|
@torch.no_grad()
def evaluate_model(model):
    """Evaluate `model` on the test.clean split with biased-word metrics.

    Generates a transcript for every example (batched, greedy decoding via
    model.generate defaults), then scores:
      - corpus WER / CER over raw transcripts (torchmetrics),
      - B-WER: word errors involving the example's biasing words,
      - U-WER: word errors on all remaining words.

    Returns:
        (wer %, cer %, b_wer %, u_wer %) as a 4-tuple of floats.
    """
    transcripts = []
    # Build prompts: audio tokens + biasing words + <|startoftranscript|>.
    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]

    for idx in tqdm(range(ceil(len(data)/val_bs))):
        # Left-padded batch (tokenizer.padding_side = "left"), so every
        # sequence ends at the same position and generation starts there.
        outputs = tokenizer(data[idx * val_bs: (idx + 1) * val_bs], truncation=False, max_length=None, padding=True, return_tensors="pt").to(model.device)
        input_ids = outputs["input_ids"]
        par = input_ids.shape[-1]  # padded prompt length, shared by the batch

        generations = model.generate(
            input_ids,
            # Leave room so prompt + new tokens stays within context_length.
            max_new_tokens=context_length - par - 1,
            eos_token_id = eot_token
        )
        # Keep only the generated continuation (drop the prompt columns).
        transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)

    bias_word_cnt = 0    # reference words that are in the biasing list
    normal_word_cnt = 0  # reference words outside the biasing list
    u_wer = 0.0          # error count on non-biasing words
    b_wer = 0.0          # error count involving biasing words
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    for a, b, c in zip(pred_list, text_list, prompt_list):
        # a = hypothesis words, b = reference words, c = biasing words.
        aligned_pair = align_gt_asr(b, a)
        for gt_word, asr_word in aligned_pair:
            if gt_word in c or asr_word in c:
                # Pair touches the biasing list on either side: any
                # mismatch (substitution/insertion/deletion) counts to B-WER.
                if gt_word != asr_word:
                    b_wer += 1.0
                # Denominator counts only reference-side biasing words.
                if gt_word in c:
                    bias_word_cnt += 1
            else:
                if gt_word != asr_word:
                    u_wer += 1.0
                # Insertions (gt_word == "") are excluded from the denominator.
                if gt_word != "":
                    normal_word_cnt += 1
    # NOTE(review): raises ZeroDivisionError if a bucket is empty (e.g. no
    # biasing words ever matched) — confirm inputs guarantee non-empty counts.
    u_wer = u_wer / normal_word_cnt * 100
    b_wer = b_wer / bias_word_cnt * 100

    return wer(transcripts, processed_data["text"]).item() * 100, cer(transcripts, processed_data["text"]).item() * 100, b_wer, u_wer