| import os |
| import re |
| import unicodedata |
| from typing import Dict |
|
|
| import kenlm |
| import sentencepiece |
| from huggingface_hub import cached_download, hf_hub_url |
|
|
|
|
class SentencePiece:
    """Thin wrapper around a SentencePiece tokenizer model.

    Loads a trained ``.sp.model`` file and exposes a single ``do`` method
    that converts raw text into space-separated subword pieces, which is
    the input format the KenLM model expects.
    """

    def __init__(
        self,
        model: str,
    ):
        """Load the SentencePiece model from the given file path."""
        super().__init__()
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        """Tokenize ``text`` and return the pieces joined by single spaces.

        NOTE: the original annotations said ``dict -> dict``, but the body
        demonstrably consumes a string (``encode_as_pieces``) and returns a
        string (``" ".join``); the annotations are corrected accordingly.
        """
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)
|
|
|
|
class KenlmModel:
    """Compute document perplexity under a KenLM n-gram language model.

    Text is first normalized (cc_net style: optional lower-casing, accent
    stripping, digit normalization, unicode-punctuation handling, removal
    of control characters), then tokenized with SentencePiece, and finally
    scored line by line with KenLM.
    """

    # Matches any ASCII decimal digit; used to normalize digits to "0".
    digit_re: re.Pattern = re.compile(r"\d")
    # Unicode punctuation -> ASCII(-ish) replacement table.
    # BUG FIX: the original table contained the entry `"1": '"'`, mapping
    # the ASCII digit one to a double quote — almost certainly a
    # mis-encoded punctuation character. It corrupted any text containing
    # "1" whenever number normalization was disabled (and deleted "1"s in
    # remove_unicode_punct). The entry has been dropped.
    unicode_punct: Dict[str, str] = {
        ",": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        ":": ":",
        "?": "?",
        "!": "!",
        "(": "(",
        ")": ")",
        ";": ";",
        "–": "-",
        "—": " - ",
        ".": ". ",
        "~": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "%": "%",
        "►": "-",
    }
    # Matches any single character that is a key of `unicode_punct`.
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    # Matches C0 (0-31) and C1 (127-159) control characters.
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0, 32)) + list(range(127, 160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        """Load ``<language>.arpa.bin`` and ``<language>.sp.model`` from a directory.

        Args:
            model_dataset: Directory containing the model files.
            language: Language code used to build both file names.
            lower_case: Lower-case text during normalization.
            remove_accents: Strip combining accents during normalization.
            normalize_numbers: Replace every ASCII digit with "0".
            punctuation: 1 = replace unicode punctuation with ASCII,
                2 = remove it entirely, any other value = leave it alone.
        """
        self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
        self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
    ):
        """Alternate constructor using the default normalization settings."""
        # Keyword arguments make the opaque positional flags of the
        # original version self-documenting.
        return cls(
            model_dataset,
            language,
            lower_case=False,
            remove_accents=False,
            normalize_numbers=True,
            punctuation=1,
        )

    def pp(self, log_score, length):
        """Convert a base-10 log score and a token count into perplexity."""
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        """Return the perplexity of ``doc``, rounded to one decimal place.

        Args:
            doc: Input document; may contain several newline-separated lines.
            normalize_cc_net: Apply cc_net-style normalization (using the
                flags chosen at construction time) before scoring.
        """
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize after normalizing: KenLM expects space-separated pieces.
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            # +1 accounts for the implicit end-of-sentence token.
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        """Normalize one piece of text with cc_net rules.

        Order matters: lower-casing happens before accent stripping, and
        digit normalization happens before punctuation handling.
        """
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strip combining accents (Unicode category Mn) from ``line``."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        # BUG FIX: the original wrote `if len(output) == line:`, comparing
        # an int to a str — always False, so the fast path never fired.
        # Comparing against the NFD length means "nothing was stripped",
        # in which case the original string is returned unchanged.
        if len(output) == len(nfd):
            return line
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        """Replace each mapped unicode punctuation char with its ASCII form."""
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        """Drop C0/C1 control characters from ``text``."""
        return self.non_printing_chars_re.sub("", text)
|
|