| import argparse |
| import json |
| import re |
| import os |
| from functools import cache |
| from pathlib import Path |
| from typing import Iterator, List, NoReturn, Optional, Tuple, Union |
|
|
| import kenlm |
| import msgspec |
| import sentencepiece |
| from numpy.random import default_rng |
| from scipy.stats import norm |
| from tqdm import tqdm |
|
|
| from normalization import normalize_text |
|
|
|
|
# Process-wide random generator used for probabilistic keep/drop sampling.
RNG = default_rng()
# Language codes with dedicated models; anything else maps to DEFAULT_LANG.
LANGS = ("no", "nn", "nob", "nno", "da", "sv", "is", "en")
DEFAULT_LANG = "no"
# Root directory holding the KenLM binaries and SentencePiece models;
# overridable via the PERPLEXITY_BASEPATH environment variable.
BASEPATH = Path(os.environ.get("PERPLEXITY_BASEPATH", "/nfsmounts/datastore/mimir/perplexity"))
# Per-domain model configuration. Each leaf dict is splatted into
# get_perplexity(**params): "model" (KenLM binary path), optional
# "tokenizer" (SentencePiece model path), and "normalize".
# "harmful" and "wikipedia" are keyed by language (Bokmål/Nynorsk aliases
# share the no/nn artifacts); the remaining domains are language-agnostic.
CONFIG = {
    "harmful": {
        "no": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nn": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nob": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "nno": {"model": BASEPATH / "kenlm" / "harmful" / "no.bin", "normalize": True},
        "da": {"model": BASEPATH / "kenlm" / "harmful" / "da.bin", "normalize": True},
        "sv": {"model": BASEPATH / "kenlm" / "harmful" / "sv.bin", "normalize": True},
        "is": {"model": BASEPATH / "kenlm" / "harmful" / "is.bin", "normalize": True},
        "en": {"model": BASEPATH / "kenlm" / "harmful" / "en.bin", "normalize": True},
    },
    "wikipedia": {
        "no": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True
        },
        "nn": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True
        },
        "nob": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "no.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "no.sp.model",
            "normalize": True
        },
        "nno": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "nn.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "nn.sp.model",
            "normalize": True
        },
        "da": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "da.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "da.sp.model",
            "normalize": True
        },
        "en": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "en.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "en.sp.model",
            "normalize": True
        },
        "is": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "is.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "is.sp.model",
            "normalize": True
        },
        "sv": {
            "model": BASEPATH / "kenlm" / "wikipedia" / "sv.arpa.bin",
            "tokenizer": BASEPATH / "spm" / "wikipedia" / "sv.sp.model",
            "normalize": True
        },
    },
    "books": {
        "model": BASEPATH / "kenlm" / "books.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "books.norm.sp.model",
        "normalize": True
    },
    "newspapers": {
        "model": BASEPATH / "kenlm" / "newspapers.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "newspapers.norm.sp.model",
        "normalize": True
    },
    "maalfrid": {
        "model": BASEPATH / "kenlm" / "maalfrid.norm.sp.arpa.bin",
        "tokenizer": BASEPATH / "spm" / "maalfrid.norm.sp.model",
        "normalize": True
    }
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def should_keep(
    perp: float, dist_norm: float, dist_mean: float, dist_std: float
) -> bool:
    """
    Decide if a doc is to be retained based on its perplexity value.

    The keep probability is the Gaussian density of `perp` (with the given
    mean and std) scaled by 1/dist_norm, sampled against a uniform draw.
    Note: set() must have been called previously.
    """
    keep_probability = norm.pdf(perp, loc=dist_mean, scale=dist_std) / dist_norm
    return RNG.uniform() < keep_probability
|
|
|
|
def fix_language(language: str) -> str:
    """Return `language` if it has dedicated models, else the default code."""
    return language if language in LANGS else DEFAULT_LANG
|
|
|
|
def pp(log_score: float, length: int) -> float:
    """Turn a cumulative log10 score over `length` tokens into a perplexity."""
    avg_negative_log = -log_score / length
    return 10.0 ** avg_negative_log
|
|
|
|
@cache
def load_kenlm(model: str) -> kenlm.Model:
    """Load a KenLM binary from `model`, memoizing per path."""
    cfg = kenlm.Config()
    # load_method 2 — a value from kenlm's LoadMethod enum; presumably an
    # mmap/populate strategy, confirm against the kenlm headers.
    cfg.load_method = 2
    return kenlm.Model(str(model), cfg)
|
|
|
|
@cache
def load_sentencepiece(model: str) -> sentencepiece.SentencePieceProcessor:
    """Load a SentencePiece model from `model`, memoizing per path."""
    processor = sentencepiece.SentencePieceProcessor()
    processor.load(str(model))
    return processor
|
|
|
|
def get_perplexity(
    document: str,
    model: str,
    tokenizer: Optional[str] = None,
    normalize: bool = False
) -> float:
    """
    Compute the per-token perplexity of `document` under a KenLM model.

    Parameters:
    - document: text to score; split on newlines, empty lines are skipped.
    - model: path to the KenLM binary (loaded lazily and cached).
    - tokenizer: optional path to a SentencePiece model; when given, each
      line is tokenized into pieces before scoring.
    - normalize: when True, each line goes through normalize_text() first.

    Returns the perplexity rounded to one decimal, or 0.0 when the model is
    falsy or the document contains no scorable line.
    """
    lines = document.split("\n")
    # Keep the parameter intact; previously the path argument was shadowed
    # by the loaded model object.
    lm = load_kenlm(model)
    if not lines or not lm:
        return 0.0
    sp = load_sentencepiece(tokenizer) if tokenizer else None
    doc_log_score, doc_length = 0.0, 0
    for line in lines:
        if not line:
            continue
        if normalize:
            line = normalize_text(line)
        if sp is not None:
            line = " ".join(sp.encode_as_pieces(line))
        doc_log_score += lm.score(line)
        doc_length += len(line.split()) + 1
    if doc_length == 0:
        # Every line was empty: avoid ZeroDivisionError in pp().
        return 0.0
    return round(pp(doc_log_score, doc_length), 1)
|
|
|
|
def get_perplexity_local(
    document: str,
    model: kenlm.Model,
    tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None,
    normalize: bool = False
) -> float:
    """
    Compute document perplexity using already-loaded model objects.

    Same scoring as get_perplexity(), but the caller supplies the
    kenlm.Model and (optionally) the SentencePiece processor directly,
    bypassing the path-based caches.

    Returns the perplexity rounded to one decimal, or 0.0 when the model is
    falsy or the document contains no scorable line.
    """
    lines = document.split("\n")
    if not lines or not model:
        return 0.0
    doc_log_score, doc_length = 0.0, 0
    for line in lines:
        # Skip empty lines, mirroring get_perplexity(); previously they were
        # still scored and inflated the token count.
        if not line:
            continue
        if normalize:
            line = normalize_text(line)
        if tokenizer is not None:
            line = " ".join(tokenizer.encode_as_pieces(line))
        doc_log_score += model.score(line)
        doc_length += len(line.split()) + 1
    if doc_length == 0:
        # Every line was empty: avoid ZeroDivisionError in pp().
        return 0.0
    return round(pp(doc_log_score, doc_length), 1)
|
|
|
|
def harmful_perplexity(document: str, language: str) -> float:
    """Perplexity of `document` under the "harmful" model for `language`."""
    # Bug fix: the function is named fix_language; fix_lang was undefined
    # and raised NameError at call time.
    params = CONFIG["harmful"][fix_language(language)]
    return get_perplexity(document=document, **params)
|
|
|
|
def wikipedia_perplexity(document: str, language: str) -> float:
    """Perplexity of `document` under the wikipedia model for `language`."""
    # Bug fix: the function is named fix_language; fix_lang was undefined
    # and raised NameError at call time.
    params = CONFIG["wikipedia"][fix_language(language)]
    return get_perplexity(document=document, **params)
|
|
|
|
def books_perplexity(document: str) -> float:
    """Perplexity of `document` under the books model."""
    return get_perplexity(document=document, **CONFIG["books"])
|
|
|
|
def newspapers_perplexity(document: str) -> float:
    """Perplexity of `document` under the newspapers model."""
    return get_perplexity(document=document, **CONFIG["newspapers"])
|
|
|
|
def maalfrid_perplexity(document: str) -> float:
    """Perplexity of `document` under the maalfrid model."""
    return get_perplexity(document=document, **CONFIG["maalfrid"])
|
|
|
|
def source_perplexities(
    document: str,
    language: str,
    model: str | None = None,
    include_harmful: bool = True) -> dict:
    """
    Calculate all models' perplexities at once.

    The document is normalized line-by-line up front, so per-model
    normalization is disabled for the individual scoring calls.

    Parameters:
    - document: raw document text.
    - language: fastText language code; unknown codes fall back to DEFAULT_LANG.
    - model: when given, score only that CONFIG model; otherwise score
      wikipedia, books, newspapers and maalfrid.
    - include_harmful: additionally score the language's "harmful" model.

    Returns a dict mapping "<model>_pp" to the perplexity value.
    (The previous annotation said float; the function has always returned a dict.)
    """
    normalized_document = "\n".join(normalize_text(line) for line in document.split("\n"))
    language = fix_language(language)

    def _params(name: str) -> dict:
        """Copied scoring kwargs for CONFIG entry `name`, normalization off."""
        entry = CONFIG[name]
        if name in ("wikipedia", "harmful"):
            entry = entry[language]
        # Bug fix: the old code mutated the shared CONFIG dicts in place with
        # params.update({"normalize": False}), permanently disabling
        # normalization for every later caller (e.g. harmful_perplexity).
        return {**entry, "normalize": False}

    if model is not None:
        perplexities = {
            f"{model}_pp": get_perplexity(document=normalized_document, **_params(model)),
        }
    else:
        perplexities = {
            "wikipedia_pp": get_perplexity(document=normalized_document, **_params("wikipedia")),
            "books_pp": get_perplexity(document=normalized_document, **_params("books")),
            "newspapers_pp": get_perplexity(document=normalized_document, **_params("newspapers")),
            "maalfrid_pp": get_perplexity(document=normalized_document, **_params("maalfrid")),
        }
    if include_harmful:
        perplexities["harmful_pp"] = get_perplexity(
            document=normalized_document, **_params("harmful")
        )
    return perplexities
|
|
|
|
def get_model_for(doc_type: str) -> Tuple[str, bool]:
    """
    Map a document type to its perplexity model.

    Returns (model_name, needs_language_variant): the CONFIG key to score
    with, and whether that model is selected per language.
    (The old annotation `(str, bool)` was a tuple literal, not a type hint.)
    """
    # Drop any trailing "_<subtype>" (e.g. "maalfrid_regjeringen" -> "maalfrid").
    doc_type = doc_type.split("_", 1)[0]
    # Drop a leading "<prefix>-" (e.g. "nb-newspaper" -> "newspaper").
    if "-" in doc_type:
        doc_type = doc_type.split("-", 1)[-1]
    if doc_type in ("book", "books"):
        return "books", False
    elif doc_type in ("culturax", "slimpajama", "wikipedia", "digimanus", "pg19", "hplt", "starcoder"):
        return "wikipedia", True
    elif doc_type in ("newspaper", "newspapers"):
        return "newspapers", False
    elif doc_type in ("evalueringsrapport", "lovdata", "maalfrid", "parlamint"):
        return "maalfrid", False
    else:
        # Unknown document types fall back to the general wikipedia model.
        return "wikipedia", True
|
|
|
|
def preload_models_tokenizers() -> dict:
    """
    Eagerly load every model/tokenizer pair into the process-wide caches.

    Returns a dict keyed by "books", "newspapers", "maalfrid",
    "harmful-<lang>" and "wikipedia-<lang>", each value a
    (kenlm.Model, tokenizer_or_None) tuple. (The old `-> List` annotation
    was wrong; a dict has always been returned.)

    NOTE(review): the books/newspapers/maalfrid KenLM paths here
    ("*.norm.arpa.bin") differ from CONFIG's ("*.norm.sp.arpa.bin") —
    confirm which artifact is the correct one.
    """
    print("Preloading models...", end=" ")
    models = {
        "books": (
            load_kenlm(BASEPATH / "kenlm" / "books.norm.arpa.bin"),
            load_sentencepiece(BASEPATH / "spm" / "books.norm.sp.model")
        ),
        "newspapers": (
            load_kenlm(BASEPATH / "kenlm" / "newspapers.norm.arpa.bin"),
            load_sentencepiece(BASEPATH / "spm" / "newspapers.norm.sp.model")
        ),
        "maalfrid": (
            load_kenlm(BASEPATH / "kenlm" / "maalfrid.norm.arpa.bin"),
            load_sentencepiece(BASEPATH / "spm" / "maalfrid.norm.sp.model")
        ),
    }
    # Harmful CONFIG entries have no "tokenizer" key, hence None here.
    for lang, params in CONFIG["harmful"].items():
        models[f"harmful-{lang}"] = load_kenlm(params["model"]), None
    for lang, params in CONFIG["wikipedia"].items():
        models[f"wikipedia-{lang}"] = (
            load_kenlm(params["model"]),
            load_sentencepiece(params["tokenizer"]),
        )
    print("Done")
    return models
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
|
|
def process_file(input_file, output_path, cutoff=None, model=None, overwrite_output=True):
    """
    Enrich a JSON Lines file with perplexity scores and write the result.

    Each input line must be a JSON object with at least "text", "doc_type"
    and "lang_fasttext" fields. For "starcoder" documents the language is
    forced to English. With model="single", the scoring model is chosen per
    document type via get_model_for() and the score is exposed as
    "perplexity"/"perplexity_model"; otherwise source_perplexities() is run
    with the given (possibly None) model name.

    Parameters:
    - input_file (str or Path): input JSONL file.
    - output_path (str or Path): directory for the output file, which keeps
      the input file's name.
    - cutoff (int, optional): stop after processing this many lines.
    - model (str, optional): "single", a specific CONFIG model name, or None
      to run all models.
    - overwrite_output (bool): when False, skip if the output already exists.

    Returns None; writes enriched documents to the output file.
    """
    input_file = Path(input_file)
    output_file = Path(output_path) / input_file.name
    if not overwrite_output and output_file.exists():
        print(f"Skipping {output_file} as it already exists")
        return
    with (open(output_file, 'w', encoding='utf-8') as f,
        open(input_file, 'r', encoding='utf-8') as lines):
        # Count from 1 so `cutoff` bounds the processed lines exactly; the
        # old 0-based count processed cutoff + 1 lines (off-by-one).
        for line_count, line in tqdm(enumerate(lines, start=1), desc=f"Processing {input_file.name}"):
            doc = json.loads(line)
            language = doc["lang_fasttext"]
            if doc["doc_type"] == "starcoder":
                language = "en"
            if model == "single":
                doc_type_model, _ = get_model_for(doc["doc_type"])
                perplexities = source_perplexities(doc["text"], language, model=doc_type_model)
                # Promote the single model's score to a generic field name.
                perplexities["perplexity"] = perplexities.pop(f"{doc_type_model}_pp")
                perplexities["perplexity_model"] = doc_type_model
            else:
                perplexities = source_perplexities(doc["text"], language, model=model)
            doc.update(perplexities)
            f.write(json.dumps(doc) + "\n")
            if cutoff is not None and line_count >= cutoff:
                break
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description='Calculate perplexity values for a given JSON Lines file and output the result to a new file.') |
| parser.add_argument('-i', '--input_file', type=str, |
| help='Input file path') |
| parser.add_argument('-o', '--output_path', type=str, |
| help='Output path to write enriched file') |
| parser.add_argument('-c', '--cutoff', required=False, type=int, |
| help='Max number of lines to process') |
| parser.add_argument('-m', '--model', required=False, type=str, |
| help='Run "single" model per doc type, "all" the models, ' |
| 'or a specific model to choose from ' |
| '"books", "wikipedia", "newspapers" or "maalfrid". ' |
| 'Defaults to "single"') |
| parser.add_argument('--overwrite_output', |
| action=argparse.BooleanOptionalAction, default=True, |
| help="Whether to overwrite the output file if exists.") |
|
|
| args = parser.parse_args() |
|
|
| if args.model == "single": |
| process_file( |
| args.input_file, args.output_path, args.cutoff, |
| model="single", overwrite_output=args.overwrite_output, |
| ) |
| elif args.model in ("books", "wikipedia", "newspapers", "maalfrid"): |
| process_file( |
| args.input_file, args.output_path, args.cutoff, |
| model=args.model, overwrite_output=args.overwrite_output, |
| ) |
| else: |
| process_file( |
| args.input_file, args.output_path, args.cutoff, |
| overwrite_output=args.overwrite_output, |
| ) |
|
|