|
|
|
|
| import re |
| import bw2ar |
| import torch |
| import xer |
|
|
| |
# Unicode code points for the Arabic diacritics (tashkeel / harakat).
FATHATAN = u'\u064b'  # ARABIC FATHATAN (tanween fath)
DAMMATAN = u'\u064c'  # ARABIC DAMMATAN (tanween damm)
KASRATAN = u'\u064d'  # ARABIC KASRATAN (tanween kasr)
FATHA = u'\u064e'     # ARABIC FATHA
DAMMA = u'\u064f'     # ARABIC DAMMA
KASRA = u'\u0650'     # ARABIC KASRA
SHADDA = u'\u0651'    # ARABIC SHADDA (gemination mark)
SUKUN = u'\u0652'     # ARABIC SUKUN
TATWEEL = u'\u0640'   # ARABIC TATWEEL (kashida); intentionally NOT in HARAKAT_PAT

# Character class matching any single diacritic above (tatweel excluded).
HARAKAT_PAT = re.compile(u"["+u"".join([FATHATAN, DAMMATAN, KASRATAN,
                                        FATHA, DAMMA, KASRA, SUKUN,
                                        SHADDA])+u"]")
|
|
|
|
class TashkeelTokenizer:
    """Tokenizer that splits Arabic text into aligned (letter, tashkeel) pairs
    and maps both streams to integer ids for a diacritization model.

    Letters are processed in Buckwalter transliteration via the project-local
    `bw2ar` module (presumably a Buckwalter<->Arabic converter — confirm
    against that module). Diacritics are represented either as single
    Buckwalter characters ('a'/'u'/'i' harakat, 'F'/'N'/'K' tanween,
    'o' sukun, '~' bare shaddah) or as combined shaddah+haraka tags
    ('<SF>', '<SD>', '<SK>', '<SFF>', '<SDD>', '<SKK>').
    """

    def __init__(self):
        # Buckwalter transliteration symbols for the Arabic letters
        # (space included so word boundaries survive tokenization).
        self.letters = [' ', '$', '&', "'", '*', '<', '>', 'A', 'D', 'E', 'H', 'S', 'T', 'Y', 'Z',
                        'b', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 't',
                        'v', 'w', 'x', 'y', 'z', '|', '}'
                        ]
        # Special tokens: ids 0..2 are <PAD>/<BOS>/<EOS>; <MASK> is last.
        self.letters = ['<PAD>', '<BOS>', '<EOS>'] + self.letters + ['<MASK>']

        # Tag for a letter that carries no diacritic.
        self.no_tashkeel_tag = '<NT>'
        # Diacritic vocabulary: combined shaddah tags plus single Buckwalter
        # diacritic characters.
        self.tashkeel_list = ['<NT>', '<SD>', '<SDD>', '<SF>', '<SFF>', '<SK>',
                              '<SKK>', 'F', 'K', 'N', 'a', 'i', 'o', 'u', '~']
        self.tashkeel_list = ['<PAD>', '<BOS>', '<EOS>'] + self.tashkeel_list

        # symbol/tag -> integer id lookup tables.
        self.tashkeel_map = {c:i for i,c in enumerate(self.tashkeel_list)}
        self.letters_map = {c:i for i,c in enumerate(self.letters)}
        # Two-character shaddah+haraka sequence -> single combined tag.
        self.inverse_tags = {
            '~a': '<SF>',
            '~u': '<SD>',
            '~i': '<SK>',
            '~F': '<SFF>',
            '~N': '<SDD>',
            '~K': '<SKK>'
        }
        # Combined tag -> two-character sequence (used when re-synthesizing text).
        self.tags = {v:k for k,v in self.inverse_tags.items()}
        # The two possible orderings of a haraka around a shaddah in raw text.
        self.shaddah_last = ['a~', 'u~', 'i~', 'F~', 'N~', 'K~']
        self.shaddah_first = ['~a', '~u', '~i', '~F', '~N', '~K']
        # All single diacritic characters. NOTE(review): attribute name has a
        # typo ("tahkeel"); kept as-is because it is part of the public API.
        self.tahkeel_chars = ['F','N','K','a', 'u', 'i', '~', 'o']

    def clean_text(self, text):
        """Normalize raw Arabic text: drop tatweel, map alef wasla to plain
        alef, strip characters outside the accepted Arabic range, and collapse
        all whitespace runs to single spaces."""
        text = re.sub(u'[%s]' % u'\u0640', '', text)  # remove tatweel (kashida)
        text = text.replace('ٱ', 'ا')  # alef wasla -> plain alef
        # Keep Arabic letters (0621-063A), diacritics (0640-0652), superscript
        # alef (0670), alef wasla (0671), lam-alef ligatures, and spaces;
        # everything else becomes a space, then whitespace is collapsed.
        return ' '.join(re.sub(u"[^\u0621-\u063A\u0640-\u0652\u0670\u0671\ufefb\ufef7\ufef5\ufef9 ]", " ", text, flags=re.UNICODE).split())

    def check_match(self, text_with_tashkeel, letter_n_tashkeel_pairs):
        """Return True if recombining `letter_n_tashkeel_pairs` reproduces
        `text_with_tashkeel` either directly or after unifying the shaddah
        ordering of the reference text."""
        text_with_tashkeel = text_with_tashkeel.strip()
        # Re-synthesize the text from the pairs and compare.
        syn_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
        return syn_text == text_with_tashkeel or syn_text == self.unify_shaddah_position(text_with_tashkeel)

    def unify_shaddah_position(self, text_with_tashkeel):
        """Rewrite every haraka+shaddah sequence as shaddah+haraka so both
        possible orderings collapse to the canonical shaddah-first form."""
        for i in range(len(self.shaddah_first)):
            text_with_tashkeel = text_with_tashkeel.replace(self.shaddah_last[i], self.shaddah_first[i])
        return text_with_tashkeel

    def split_tashkeel_from_text(self, text_with_tashkeel, test_match=True):
        """Split diacritized Arabic text into a list of (letter, tag) pairs.

        The text is cleaned, transliterated to Buckwalter, normalized to
        shaddah-first ordering with doubled diacritics collapsed, then scanned
        character by character. When `test_match` is True, asserts that the
        pairs recombine into the normalized input (AssertionError otherwise).
        The result is wrapped in ('<BOS>','<BOS>') / ('<EOS>','<EOS>') pairs.
        """
        text_with_tashkeel = self.clean_text(text_with_tashkeel)
        # Arabic -> Buckwalter transliteration (project-local bw2ar module).
        text_with_tashkeel = bw2ar.transliterate_text(text_with_tashkeel, 'ar2bw')
        # Drop '`' (Buckwalter superscript alef), which is not modeled.
        text_with_tashkeel = text_with_tashkeel.replace('`', '')

        # Canonicalize: shaddah always precedes its haraka.
        text_with_tashkeel = self.unify_shaddah_position(text_with_tashkeel)

        # Collapse accidental doubled diacritics (e.g. 'aa' -> 'a').
        for i in range(len(self.tahkeel_chars)):
            text_with_tashkeel = text_with_tashkeel.replace(self.tahkeel_chars[i]*2, self.tahkeel_chars[i])

        letter_n_tashkeel_pairs = []
        for i in range(len(text_with_tashkeel)):
            # Case 1: a letter directly followed by a diacritic character.
            if i < (len(text_with_tashkeel) - 1) and not text_with_tashkeel[i] in self.tashkeel_list and text_with_tashkeel[i+1] in self.tashkeel_list:
                if text_with_tashkeel[i+1] == '~':
                    # Shaddah: if a haraka follows it, emit the combined tag
                    # (e.g. '~a' -> '<SF>'); otherwise emit bare '~'.
                    if i+2 < len(text_with_tashkeel) and f'~{text_with_tashkeel[i+2]}' in self.inverse_tags:
                        letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.inverse_tags[f'~{text_with_tashkeel[i+2]}']))
                    else:
                        letter_n_tashkeel_pairs.append((text_with_tashkeel[i], '~'))
                else:
                    letter_n_tashkeel_pairs.append((text_with_tashkeel[i], text_with_tashkeel[i+1]))
            # Case 2: a letter with no diacritic after it.
            elif not text_with_tashkeel[i] in self.tashkeel_list:
                letter_n_tashkeel_pairs.append((text_with_tashkeel[i], self.no_tashkeel_tag))
            # Diacritic characters themselves produce no pair: they were
            # consumed together with the preceding letter.

        if test_match:
            # Round-trip sanity check; raises AssertionError on mismatch.
            assert self.check_match(text_with_tashkeel, letter_n_tashkeel_pairs)
        return [('<BOS>', '<BOS>')] + letter_n_tashkeel_pairs + [('<EOS>', '<EOS>')]

    def combine_tashkeel_with_text(self, letter_n_tashkeel_pairs):
        """Inverse of the splitting step: rebuild a Buckwalter string by
        appending each letter followed by its diacritic (combined tags are
        expanded back to shaddah+haraka; '<NT>' adds nothing)."""
        combined_with_tashkeel = []
        for letter, tashkeel in letter_n_tashkeel_pairs:
            combined_with_tashkeel.append(letter)
            if tashkeel in self.tags:
                combined_with_tashkeel.append(self.tags[tashkeel])
            elif tashkeel != self.no_tashkeel_tag:
                combined_with_tashkeel.append(tashkeel)
        text = ''.join(combined_with_tashkeel)
        return text

    def encode(self, text_with_tashkeel, test_match=True):
        """Encode diacritized Arabic text into two aligned LongTensors:
        letter ids (model input) and tashkeel-tag ids (model target)."""
        letter_n_tashkeel_pairs = self.split_tashkeel_from_text(text_with_tashkeel, test_match)
        text, tashkeel = zip(*letter_n_tashkeel_pairs)
        input_ids = [self.letters_map[c] for c in text]
        target_ids = [self.tashkeel_map[c] for c in tashkeel]
        return torch.LongTensor(input_ids), torch.LongTensor(target_ids)

    def filter_tashkeel(self, tashkeel):
        """Replace stray '<BOS>'/'<EOS>' predictions that occur away from the
        first/last position with the no-tashkeel tag, so decoding stays
        aligned with the letter stream."""
        tmp = []
        for i, t in enumerate(tashkeel):
            if i != 0 and t == '<BOS>':
                t = self.no_tashkeel_tag
            elif i != (len(tashkeel) - 1) and t == '<EOS>':
                t = self.no_tashkeel_tag
            tmp.append(t)
        tashkeel = tmp
        return tashkeel

    def decode(self, input_ids, target_ids):
        """Convert batched id tensors back into Arabic strings with tashkeel.

        `input_ids` / `target_ids` are batched integer tensors (one row per
        sentence); they are moved to CPU and converted to nested lists.
        Returns a list of Arabic strings, one per row.
        """
        input_ids = input_ids.cpu().tolist()
        target_ids = target_ids.cpu().tolist()
        ar_texts = []
        for j in range(len(input_ids)):
            letters = [self.letters[i] for i in input_ids[j]]
            tashkeel = [self.tashkeel_list[i] for i in target_ids[j]]

            # Drop sentinel/padding tokens from both streams before pairing;
            # stray mid-sequence BOS/EOS tags become '<NT>' first so the two
            # streams keep the same length.
            letters = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', letters))
            tashkeel = self.filter_tashkeel(tashkeel)
            tashkeel = list(filter(lambda x: x != '<BOS>' and x != '<EOS>' and x != '<PAD>', tashkeel))

            letter_n_tashkeel_pairs = list(zip(letters, tashkeel))
            bw_text = self.combine_tashkeel_with_text(letter_n_tashkeel_pairs)
            # Buckwalter -> Arabic script.
            ar_text = bw2ar.transliterate_text(bw_text, 'bw2ar')
            ar_texts.append(ar_text)
        return ar_texts

    def get_tashkeel_with_case_ending(self, text, case_ending=True):
        """Return parallel (letters, tashkeel) tuples for `text`.

        When `case_ending` is False, the diacritic of each word-final letter
        (the letter immediately before a space token) is replaced by '<NT>'.
        NOTE(review): the last word's final letter precedes '<EOS>', not a
        space, so its diacritic is kept even with case_ending=False — confirm
        this is the intended metric behavior.
        """
        text_split = self.split_tashkeel_from_text(text, test_match=False)
        # Indices of space tokens; the letter before each carries the word's
        # case ending.
        text_spaces_indecies = [i for i, el in enumerate(text_split) if el == (' ', '<NT>')]
        new_text_split = []
        for i, el in enumerate(text_split):
            if not case_ending and (i+1) in text_spaces_indecies:
                el = (el[0], '<NT>')
            new_text_split.append(el)
        letters, tashkeel = zip(*new_text_split)
        return letters, tashkeel

    def compute_der(self, ref, hyp, case_ending=True):
        """Diacritic Error Rate: error rate over the space-joined tashkeel
        tag sequences of reference vs hypothesis, computed by the project
        `xer.wer` function."""
        _, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        _, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_tashkeel = ' '.join(ref_tashkeel)
        hyp_tashkeel = ' '.join(hyp_tashkeel)
        return xer.wer(ref_tashkeel, hyp_tashkeel)

    def compute_wer(self, ref, hyp, case_ending=True):
        """Word Error Rate between the fully recombined (letters + tashkeel)
        reference and hypothesis texts, via project `xer.wer`."""
        ref_letters, ref_tashkeel = self.get_tashkeel_with_case_ending(ref, case_ending=case_ending)
        hyp_letters, hyp_tashkeel = self.get_tashkeel_with_case_ending(hyp, case_ending=case_ending)
        ref_text_combined = self.combine_tashkeel_with_text(zip(ref_letters, ref_tashkeel))
        hyp_text_combined = self.combine_tashkeel_with_text(zip(hyp_letters, hyp_tashkeel))
        return xer.wer(ref_text_combined, hyp_text_combined)

    def remove_tashkeel(self, text):
        """Strip all diacritics from Arabic `text` (plus alef wasla U+0671)."""
        text = HARAKAT_PAT.sub('', text)
        # FATHA is already covered by HARAKAT_PAT; kept as in the original.
        text = re.sub(u"[\u064E]", "", text, flags=re.UNICODE)
        text = re.sub(u"[\u0671]", "", text, flags=re.UNICODE)  # alef wasla
        return text
|
|
|
|
|
|
if __name__ == '__main__':
    import utils
    from tqdm import tqdm

    tokenizer = TashkeelTokenizer()

    # Pass 1: read every training text file, clean each line, and keep the
    # non-empty results.
    txt_folder_path = 'dataset/train'
    prepared_lines = []
    for filepath in utils.get_files(txt_folder_path, '*.txt'):
        print(f'Reading file: {filepath}')
        with open(filepath) as f1:
            cleaned = (tokenizer.clean_text(raw_line) for raw_line in f1)
            prepared_lines.extend(line for line in cleaned if line != '')
        print(f'completed file: {filepath}')

    # Pass 2: tokenize each cleaned line; lines whose round-trip assertion
    # fails are collected separately instead of aborting the run.
    good_sentences, bad_sentences, tokenized_sentences = [], [], []
    for line in tqdm(prepared_lines):
        try:
            pairs = tokenizer.split_tashkeel_from_text(line, test_match=True)
        except AssertionError:
            bad_sentences.append(line)
        else:
            tokenized_sentences.append(pairs)
            good_sentences.append(line)

    print('len(good_sentences), len(bad_sentences):', len(good_sentences), len(bad_sentences))
|
|
|
|
|
|
|
|