| import argparse |
| import json |
| import re |
| import os |
| import unicodedata |
| from typing import List |
| from multiprocessing import Pool |
|
|
| import fasttext |
| import pandas as pd |
| from tqdm import tqdm |
|
|
| |
| FASTTEXT_MODEL_PATH = "filter_en.bin" |
| |
| THRESHOLD = 0.3 |
|
|
|
|
def parse_args():
    """Build and parse the command-line interface.

    Returns:
        argparse.Namespace with attributes: data_path, save_path,
        content_key, processes_num, write_batch_size, inplace.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-path", required=True, type=str,
                     help="Directory or file path containing input data.")
    cli.add_argument("--save-path", required=True, type=str,
                     help="Root directory to save filtered results.")
    cli.add_argument("--content-key", required=True, type=str,
                     help="JSON key for the review or text content.")
    cli.add_argument("--processes-num", default=64, type=int,
                     help="Number of parallel worker processes.")
    cli.add_argument("--write-batch-size", default=100, type=int,
                     help="Batch size for writing to output file.")
    cli.add_argument("--inplace", action="store_true",
                     help="Skip processing files that already exist.")
    return cli.parse_args()
|
|
|
|
def fasttext_preprocess_func(content: str) -> str:
    """Normalize content for FastText inference."""
    # Collapse runs of three or more newlines into a single blank line.
    text = re.sub(r'\n{3,}', '\n\n', content)
    text = text.lower()
    # Strip diacritics: decompose (NFKD), then drop combining marks ('Mn').
    decomposed = unicodedata.normalize('NFKD', text)
    text = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')
    # Escape control whitespace so each sample stays on a single line.
    for raw, escaped in (('\n', '\\n'), ('\r', '\\r'), ('\t', '\\t')):
        text = text.replace(raw, escaped)
    # Squeeze repeated spaces and trim the ends.
    return re.sub(r' +', ' ', text).strip()
|
|
|
|
def fasttext_infer(norm_content: str, model: fasttext.FastText):
    """Run FastText model to get the '__label__books' probability.

    Returns ('__label__books', prob) when that label appears in the
    top-10 predictions, otherwise (None, 0.0).
    """
    predicted_labels, predicted_probs = model.predict(norm_content, k=10)
    # Labels are unique, so a mapping lookup matches the linear scan.
    scores = dict(zip(predicted_labels, predicted_probs))
    if '__label__books' in scores:
        return '__label__books', float(scores['__label__books'])
    return None, 0.0
|
|
|
|
def load_data(file_path: str, content_key: str) -> List[str]:
    """Load raw text content from supported files.

    Args:
        file_path: Path to a ``.jsonl``/``.json`` (JSON-lines) or
            ``.parquet`` file.
        content_key: Field name holding the text content.

    Returns:
        Non-empty content values as strings, in file order.

    Raises:
        ValueError: If the file extension is not supported.
    """
    samples: List[str] = []
    if file_path.endswith(('.jsonl', '.json')):
        # NOTE(review): .json files are read line-by-line, i.e. they are
        # assumed to be JSON-lines as well — confirm against the data.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Skip blank/whitespace-only lines; the original fed them
                # straight to json.loads, which raises JSONDecodeError.
                if not line.strip():
                    continue
                data = json.loads(line)
                if content_key in data and data[content_key]:
                    samples.append(str(data[content_key]))
    elif file_path.endswith('.parquet'):
        df = pd.read_parquet(file_path)
        for val in df.get(content_key, []):
            if pd.notna(val) and val:
                samples.append(str(val))
    else:
        raise ValueError(f"Unsupported file type: {file_path}")
    return samples
|
|
|
|
def process_file(
        file_path: str,
        save_path: str,
        item: int,
        content_key: str,
        inplace: bool,
        write_batch_size: int) -> None:
    """Process one file: filter by '__label__books' score > THRESHOLD.

    Kept records are appended as JSON lines to
    ``<save_path>/<input_basename>_filtered.jsonl``.

    Args:
        file_path: Input file (.jsonl/.json or .parquet).
        save_path: Output directory (must already exist).
        item: File index, used only for log/progress labels.
        content_key: Field name holding the text content.
        inplace: When True, skip inputs whose output file already exists.
        write_batch_size: Number of kept records buffered per disk write.
    """
    file_name = os.path.basename(file_path)
    base_name, _ = os.path.splitext(file_name)
    output_file = os.path.join(save_path, f"{base_name}_filtered.jsonl")

    # Decide whether to skip BEFORE loading the model and parsing the input:
    # the original did that expensive work first and then threw it away.
    if inplace and os.path.exists(output_file):
        print(f"Skipping existing file: {output_file}")
        return
    # Outputs are opened in append mode below, so clear any stale file.
    if os.path.exists(output_file):
        os.remove(output_file)

    # Each worker process loads its own model copy.
    fasttext_model = fasttext.load_model(FASTTEXT_MODEL_PATH)
    contents = load_data(file_path, content_key)

    print(f"ID {item}: Processing {file_path} ({len(contents)} records) -> {output_file}")
    buffer: List[dict] = []

    def _flush() -> None:
        # One write per batch keeps the number of I/O calls low.
        with open(output_file, 'a', encoding='utf-8') as out_f:
            out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in buffer) + "\n")
        buffer.clear()

    for content in tqdm(contents, desc=f"File {item}"):
        norm = fasttext_preprocess_func(content)
        label, score = fasttext_infer(norm, fasttext_model)
        if label == '__label__books' and score > THRESHOLD:
            buffer.append({
                'content': content,
                'books_score': score
            })
        if len(buffer) >= write_batch_size:
            _flush()
    if buffer:
        _flush()
|
|
|
|
def main():
    """Entry point: collect input paths and fan out over a process pool."""
    args = parse_args()
    os.makedirs(args.save_path, exist_ok=True)

    if os.path.isdir(args.data_path):
        # Keep only regular files with extensions load_data() supports.
        # The original passed every directory entry through, so a stray
        # sub-directory or README crashed the pool with ValueError.
        # sorted() makes the file -> worker-index mapping deterministic.
        paths = [
            os.path.join(args.data_path, fname)
            for fname in sorted(os.listdir(args.data_path))
            if fname.endswith(('.jsonl', '.json', '.parquet'))
            and os.path.isfile(os.path.join(args.data_path, fname))
        ]
    else:
        paths = [args.data_path]

    print("=" * 80)
    print(f"Running with FastText model: {FASTTEXT_MODEL_PATH}")
    print(f"Processing {len(paths)} files, threshold={THRESHOLD} for '__label__books'.")
    print("=" * 80)

    with Pool(processes=args.processes_num) as pool:
        pool.starmap(
            process_file,
            [(p, args.save_path, i, args.content_key, args.inplace, args.write_batch_size)
             for i, p in enumerate(paths)]
        )
    print("All done.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|