"""Sweep a JSONL NL2SQL dataset, keeping only records whose SQL validates
cleanly and returns at least one row."""

import json
import os
import re
import sys

from tqdm import tqdm

# Make the project root importable so data_factory resolves when this
# script is run directly.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator

INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"


def main():
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found!")
        return

    print("Sweeping dataset to remove bad SQL...")
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    validators = {}  # one validator per domain, created lazily and reused
    cleaned_count = 0
    failed_count = 0

    with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
        for line in tqdm(lines, desc="Filtering Garbage"):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                failed_count += 1
                continue

            sql = record.get("sql", "").strip()
            if not sql:
                failed_count += 1
                continue

            metadata = record.get("metadata", {})
            domain = metadata.get("domain")

            # Fallback: recover the domain from the prompt's second message
            # when the metadata field is missing or unreliable.
            if not domain or domain == "unknown":
                prompt = record.get("prompt", [])
                content = prompt[1].get("content", "") if len(prompt) > 1 else ""
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
                domain = match.group(1) if match else "unknown"

            if domain == "unknown":
                failed_count += 1
                continue

            # Cache validators so each domain's database is opened only once.
            if domain not in validators:
                validators[domain] = SQLValidator(domain, seed=42)

            try:
                val_result = validators[domain].validate(sql)
                # Keep the record ONLY if the SQL executes without error
                # and returns at least one row of data.
                if val_result.passed and val_result.row_count > 0:
                    out_f.write(line)
                    cleaned_count += 1
                else:
                    failed_count += 1
            except Exception:
                failed_count += 1

    for v in validators.values():
        v.close()

    print("\n" + "=" * 50)
    print("DATASET CLEANUP COMPLETE")
    print("=" * 50)
    print(f"Original Rows : {len(lines)}")
    print(f"Cleaned Rows  : {cleaned_count} (100% Valid SQL)")
    print(f"Removed Rows  : {failed_count}")
    print(f"Saved To      : {OUTPUT_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    main()
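# ---------------------------------------------------------------------------
# Illustrative sketch only. `data_factory.validator` is project-internal and
# not shown here, so everything below is an assumption documenting the
# contract main() relies on: SQLValidator(domain, seed) exposing
# validate(sql) -> result with .passed / .row_count, plus close(). The real
# validator presumably executes the SQL against a seeded, per-domain
# database; this stub substitutes an in-memory SQLite connection.
# ---------------------------------------------------------------------------
from dataclasses import dataclass
import sqlite3


@dataclass
class _StubValidationResult:
    passed: bool     # the query parsed and executed without error
    row_count: int   # number of rows the query returned


class _StubSQLValidator:
    """Minimal stand-in mirroring the assumed SQLValidator interface."""

    def __init__(self, domain: str, seed: int = 42):
        # Assumption: the real class opens (or synthesizes) a database for
        # `domain`, seeded for reproducibility; here, an empty in-memory DB.
        self.domain = domain
        self.conn = sqlite3.connect(":memory:")

    def validate(self, sql: str) -> _StubValidationResult:
        try:
            rows = self.conn.execute(sql).fetchall()
            return _StubValidationResult(passed=True, row_count=len(rows))
        except sqlite3.Error:
            return _StubValidationResult(passed=False, row_count=0)

    def close(self):
        self.conn.close()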