# nl2sql-bench / check_quality.py
# Uploaded by ritvik360 via huggingface_hub (commit a39d8ef, verified)
import json
import os
import sys
import re
from collections import Counter
from tqdm import tqdm
# Add the project root to sys.path so `data_factory` imports resolve when this
# file is run directly as a script (rather than as an installed package).
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from data_factory.validator import SQLValidator
# JSONL dataset audited by this script: one JSON record per line.
DATASET_FILE = "edge_cases.jsonl"
def main():
    """Audit DATASET_FILE and print a dataset health report.

    Reads the JSONL dataset line by line, tallies corruption / coverage /
    diversity statistics, executes every SQL against a lazily-built
    per-domain SQLValidator, and prints a summary report to stdout.

    Returns:
        None. Prints an error and returns early if DATASET_FILE is missing.
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return
    print("Starting Dataset Quality & Sanity Check...\n")

    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0
    persona_counts = Counter()
    unique_sqls = set()
    unique_questions = set()
    domain_counts = Counter()
    validators = {}  # domain -> SQLValidator, created on first use and reused

    # Materializing the lines gives tqdm a known total; QC files are small.
    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    try:
        for line in tqdm(lines, desc="Analyzing Rows"):
            total_rows += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                corrupt_json += 1
                continue

            prompt_block = record.get("prompt", [])
            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})

            # A usable row needs at least [system, user] messages and a SQL.
            if not prompt_block or len(prompt_block) < 2 or not sql:
                empty_outputs += 1
                continue

            user_content = prompt_block[1].get("content", "")
            question = user_content.split("QUESTION: ")[-1]

            # Smart Domain Extraction: try metadata first, fall back to
            # parsing the "Database: <name>" marker out of the prompt text.
            domain = metadata.get("domain")
            if not domain:
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
                domain = match.group(1) if match else "unknown"

            persona = metadata.get("persona", "unknown")
            persona_counts[persona] += 1
            domain_counts[domain] += 1
            unique_sqls.add(sql)
            unique_questions.add(question)

            # Skip validation if the domain is completely unknown/corrupted.
            if domain == "unknown":
                missing_domains += 1
                continue

            # Strict execution quality check: the SQL must both run and
            # return at least one row to count as passing.
            try:
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                val_result = validators[domain].validate(sql)
                if not val_result.passed or val_result.row_count == 0:
                    sql_execution_failures += 1
            except Exception:
                # Any schema/setup error for this domain is treated as a
                # missing domain rather than a SQL failure.
                missing_domains += 1
                continue
    finally:
        # BUGFIX: close validators in a `finally` so database handles are
        # released even when the loop aborts with an unexpected exception.
        for v in validators.values():
            v.close()

    # --- REPORT GENERATION ---
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)

    def _pct(count):
        # Share of valid rows as a percentage; safe when valid_total is 0.
        return (count / valid_total) * 100 if valid_total else 0

    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")
    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions : {len(unique_questions)}")
    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")
    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        print(f" - {p}: {count} ({_pct(count):.1f}%)" if valid_total else f" - {p}: {count}")
    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        print(f" - {d}: {count} ({_pct(count):.1f}%)" if valid_total else f" - {d}: {count}")
    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = _pct(sql_execution_failures)
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")
    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)
# Run the audit only when executed as a script, not on import.
if __name__ == "__main__":
    main()