import json
import os
import re
import sys
from collections import Counter

from tqdm import tqdm

# Make the project root importable so `data_factory` resolves when this
# script is run directly from its own directory.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator

# JSONL dataset audited by this script (one record per line).
DATASET_FILE = "edge_cases.jsonl"


def main():
    """Run a quality/sanity audit over DATASET_FILE and print a health report.

    Per-row checks:
      * JSON parseability
      * presence of a prompt block (>= 2 messages) and a non-empty SQL string
      * domain extraction (metadata first, prompt-text regex fallback)
      * strict SQL execution via a per-domain SQLValidator

    Prints aggregate counts, diversity metrics (unique SQL / NL questions,
    duplication rate), persona/domain distributions, and an execution
    failure rate with a verdict. Returns None.
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return

    print("Starting Dataset Quality & Sanity Check...\n")

    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0
    persona_counts = Counter()
    unique_sqls = set()
    unique_questions = set()
    domain_counts = Counter()
    validators = {}  # domain -> SQLValidator, created lazily, closed at the end

    # Hoisted out of the loop: compiled once instead of per row.
    domain_pattern = re.compile(r"Database:\s*([a-zA-Z0-9_]+)")

    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    try:
        for line in tqdm(lines, desc="Analyzing Rows"):
            # FIX: skip blank lines (e.g. a trailing newline) instead of
            # letting json.loads("") count them as corrupt JSON rows.
            if not line.strip():
                continue
            total_rows += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                corrupt_json += 1
                continue

            prompt_block = record.get("prompt", [])
            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})

            # A usable row needs at least [system, user] messages and SQL text.
            if not prompt_block or len(prompt_block) < 2 or not sql:
                empty_outputs += 1
                continue

            user_content = prompt_block[1].get("content", "")
            question = user_content.split("QUESTION: ")[-1]

            # Smart domain extraction: try metadata first, fall back to
            # parsing "Database: <name>" out of the prompt text.
            domain = metadata.get("domain")
            if not domain:
                match = domain_pattern.search(user_content)
                domain = match.group(1) if match else "unknown"

            persona = metadata.get("persona", "unknown")
            persona_counts[persona] += 1
            domain_counts[domain] += 1
            unique_sqls.add(sql)
            unique_questions.add(question)

            # Skip validation if the domain is completely unknown/corrupted.
            if domain == "unknown":
                missing_domains += 1
                continue

            # Strict execution quality check: the SQL must run and return rows.
            try:
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                val_result = validators[domain].validate(sql)
                if not val_result.passed or val_result.row_count == 0:
                    sql_execution_failures += 1
            except Exception:
                # Schema/validator errors are tallied as missing domains
                # (not execution failures): the domain itself is unusable.
                missing_domains += 1
                continue
    finally:
        # FIX: guarantee validator cleanup even if the loop raises.
        for v in validators.values():
            v.close()

    # --- REPORT GENERATION ---
    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions : {len(unique_questions)}")

    # Rows that survived parsing/structure/domain checks; guards below avoid
    # division by zero when nothing valid was found.
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        print(f" - {p}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {p}: {count}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        print(f" - {d}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {d}: {count}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")
    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)


if __name__ == "__main__":
    main()