import argparse
import json
import re
import os
import unicodedata
from typing import List
from multiprocessing import Pool

import fasttext
import pandas as pd
from tqdm import tqdm

# Only use the Kyutai Dactory English FastText model
FASTTEXT_MODEL_PATH = "filter_en.bin"
# Minimum probability threshold for the '__label__books' class
THRESHOLD = 0.3
# Input file extensions that load_data() knows how to read.
SUPPORTED_EXTENSIONS = ('.jsonl', '.json', '.parquet')

# Patterns used by fasttext_preprocess_func, compiled once per process.
_EXCESS_NEWLINES_RE = re.compile(r'\n{3,}')
_MULTI_SPACE_RE = re.compile(r' +')


def parse_args():
    """Parse the command-line options for the filtering run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True,
                        help="Directory or file path containing input data.")
    parser.add_argument("--save-path", type=str, required=True,
                        help="Root directory to save filtered results.")
    parser.add_argument("--content-key", type=str, required=True,
                        help="JSON key for the review or text content.")
    parser.add_argument("--processes-num", type=int, default=64,
                        help="Number of parallel worker processes.")
    parser.add_argument("--write-batch-size", type=int, default=100,
                        help="Batch size for writing to output file.")
    parser.add_argument("--inplace", action="store_true",
                        help="Skip processing files that already exist.")
    return parser.parse_args()


def fasttext_preprocess_func(content: str) -> str:
    """Normalize content for FastText inference.

    Steps: collapse runs of three or more newlines to two, lowercase, drop
    combining marks (Unicode category 'Mn') after NFKD decomposition, replace
    newline/carriage-return/tab with the literal two-character escapes
    '\\n', '\\r', '\\t' (so the text is a single line, as FastText's
    predict() expects), then squeeze repeated spaces and strip the ends.
    """
    content = _EXCESS_NEWLINES_RE.sub('\n\n', content)  # collapse multiple newlines
    content = content.lower()
    content = ''.join(
        c for c in unicodedata.normalize('NFKD', content)
        if unicodedata.category(c) != 'Mn'
    )
    content = content.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')
    content = _MULTI_SPACE_RE.sub(' ', content).strip()
    return content


def fasttext_infer(norm_content: str, model: fasttext.FastText):
    """Run FastText model to get the '__label__books' probability.

    All labels are requested (k=-1) so the books probability is reported even
    when it does not rank among the top few labels; with k=10 a low-ranked
    books label was silently reported as 0.0.

    Returns:
        ('__label__books', probability) if the model emits that label,
        otherwise (None, 0.0).
    """
    labels, probs = model.predict(norm_content, k=-1)
    for label, prob in zip(labels, probs):
        if label == '__label__books':
            return label, float(prob)
    return None, 0.0


def load_data(file_path: str, content_key: str) -> List[str]:
    """Load raw text content from supported files.

    .jsonl/.json files are read as JSON Lines (one JSON value per line); blank
    lines and lines that are not JSON objects are skipped. .parquet files are
    read with pandas and the ``content_key`` column is used if present.
    Only truthy, non-null values are kept, converted with str().

    Raises:
        ValueError: if the file extension is not supported.
        json.JSONDecodeError: if a non-blank JSONL line is not valid JSON.
    """
    samples: List[str] = []
    if file_path.endswith(('.jsonl', '.json')):
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                record = json.loads(line)
                # Non-object lines (lists, strings, numbers) carry no key to read.
                value = record.get(content_key) if isinstance(record, dict) else None
                if value:
                    samples.append(str(value))
    elif file_path.endswith('.parquet'):
        df = pd.read_parquet(file_path)
        for val in df.get(content_key, []):
            if pd.notna(val) and val:
                samples.append(str(val))
    else:
        raise ValueError(f"Unsupported file type: {file_path}")
    return samples


def _write_jsonl_batch(out_f, records: List[dict]) -> None:
    """Write ``records`` to the open text file ``out_f`` as JSON Lines."""
    out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in records) + "\n")


def process_file(
        file_path: str,
        save_path: str,
        item: int,
        content_key: str,
        inplace: bool,
        write_batch_size: int) -> None:
    """Process one file: filter by '__label__books' score > THRESHOLD.

    Kept records are written as JSON objects
    ``{'content': <original text>, 'books_score': <float>}``, one per line, to
    ``<save_path>/<input base name>_filtered.jsonl``.

    Output is first written to a ``.tmp`` sibling and renamed into place only
    after the whole input has been processed. An existing output file
    therefore always marks a completed run, which is what ``inplace`` relies
    on to skip files. The output file is created even when no record passes
    the filter, so such files are not reprocessed on the next ``inplace`` run.
    """
    file_name = os.path.basename(file_path)
    base_name, _ = os.path.splitext(file_name)
    output_file = os.path.join(save_path, f"{base_name}_filtered.jsonl")

    # Decide on skipping before the expensive model and data loads.
    if inplace and os.path.exists(output_file):
        print(f"Skipping existing file: {output_file}")
        return

    fasttext_model = fasttext.load_model(FASTTEXT_MODEL_PATH)
    contents = load_data(file_path, content_key)
    tmp_file = output_file + ".tmp"

    print(f"ID {item}: Processing {file_path} ({len(contents)} records) -> {output_file}")

    buffer: List[dict] = []
    # 'w' truncates any stale temp file left behind by an interrupted run.
    with open(tmp_file, 'w', encoding='utf-8') as out_f:
        for content in tqdm(contents, desc=f"File {item}"):
            norm = fasttext_preprocess_func(content)
            label, score = fasttext_infer(norm, fasttext_model)
            # Keep only if the predicted label is '__label__books' and probability above threshold
            if label == '__label__books' and score > THRESHOLD:
                buffer.append({
                    'content': content,
                    'books_score': score
                })
                if len(buffer) >= write_batch_size:
                    _write_jsonl_batch(out_f, buffer)
                    buffer.clear()
        # Write remaining
        if buffer:
            _write_jsonl_batch(out_f, buffer)

    # Publish the finished output in one step (replaces any previous version).
    os.replace(tmp_file, output_file)


def main():
    """Filter every input file in parallel, one file per pool task."""
    args = parse_args()
    os.makedirs(args.save_path, exist_ok=True)

    # Collect input paths. For a directory, keep only regular files with a
    # supported extension: subdirectories or stray files (e.g. .DS_Store)
    # would make load_data raise inside a worker and abort the whole run.
    # Sorting keeps the per-file IDs stable across runs.
    if os.path.isdir(args.data_path):
        candidates = (os.path.join(args.data_path, fname)
                      for fname in os.listdir(args.data_path))
        paths = sorted(
            p for p in candidates
            if p.endswith(SUPPORTED_EXTENSIONS) and os.path.isfile(p)
        )
    else:
        paths = [args.data_path]

    print("=" * 80)
    print(f"Running with FastText model: {FASTTEXT_MODEL_PATH}")
    print(f"Processing {len(paths)} files, threshold={THRESHOLD} for '__label__books'.")
    print("=" * 80)

    # No point starting more workers than there are files (at least one).
    num_workers = max(1, min(args.processes_num, len(paths)))
    with Pool(processes=num_workers) as pool:
        pool.starmap(
            process_file,
            [(p, args.save_path, i, args.content_key, args.inplace, args.write_batch_size)
             for i, p in enumerate(paths)]
        )

    print("All done.")


if __name__ == "__main__":
    main()