# Dataset quality & sanity checker for a JSONL text-to-SQL dataset.
# Standard library imports, grouped and sorted per PEP 8.
import json
import os
import re
import sys
from collections import Counter

# Third-party imports.
from tqdm import tqdm

# Make the project root importable so local packages resolve when this
# script is run directly (not installed as a package).
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Project-local import — must come after the sys.path bootstrap above.
from data_factory.validator import SQLValidator

# Input dataset: one JSON record per line (JSONL).
DATASET_FILE = "edge_cases.jsonl"
def main():
    """Audit DATASET_FILE row by row and print a dataset health report.

    Checks performed per record:
      - JSON parseability (corrupt lines are counted, not fatal)
      - presence of a prompt block (>= 2 messages) and a non-empty SQL string
      - domain resolution (metadata first, then prompt-text fallback)
      - SQL execution via a per-domain SQLValidator (failures counted)

    Side effects: reads DATASET_FILE, prints a report to stdout, and
    opens/closes one SQLValidator per distinct domain.
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return

    print("Starting Dataset Quality & Sanity Check...\n")

    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0
    persona_counts = Counter()
    unique_sqls = set()
    unique_questions = set()
    domain_counts = Counter()
    validators = {}  # domain -> SQLValidator, created lazily, closed in finally

    # Hoisted out of the loop: the fallback pattern for extracting the
    # database/domain name from the prompt text.
    domain_pattern = re.compile(r"Database:\s*([a-zA-Z0-9_]+)")

    try:
        # Stream the file instead of readlines() so arbitrarily large
        # datasets don't need to fit in memory at once.
        with open(DATASET_FILE, "r", encoding="utf-8") as f:
            for line in tqdm(f, desc="Analyzing Rows"):
                total_rows += 1
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    corrupt_json += 1
                    continue

                prompt_block = record.get("prompt", [])
                sql = record.get("sql", "").strip()
                metadata = record.get("metadata", {})

                # A usable record needs at least system+user messages and SQL.
                if not prompt_block or len(prompt_block) < 2 or not sql:
                    empty_outputs += 1
                    continue

                user_content = prompt_block[1].get("content", "")
                # Everything after the last "QUESTION: " marker is the NL question.
                question = user_content.split("QUESTION: ")[-1]

                # Smart Domain Extraction: Try metadata first, fallback to prompt parsing
                domain = metadata.get("domain")
                if not domain:
                    match = domain_pattern.search(user_content)
                    domain = match.group(1) if match else "unknown"

                persona = metadata.get("persona", "unknown")
                persona_counts[persona] += 1
                domain_counts[domain] += 1
                unique_sqls.add(sql)
                unique_questions.add(question)

                # Skip validation if domain is completely unknown/corrupted
                if domain == "unknown":
                    missing_domains += 1
                    continue

                # Strict Execution Quality Check
                try:
                    if domain not in validators:
                        validators[domain] = SQLValidator(domain, seed=42)
                    val_result = validators[domain].validate(sql)
                    # A query that "passes" but returns zero rows is still
                    # treated as an execution failure for training quality.
                    if not val_result.passed or val_result.row_count == 0:
                        sql_execution_failures += 1
                except Exception:
                    # If any schema error occurs, mark it as failure
                    missing_domains += 1
    finally:
        # Cleanup validators even if the analysis loop raised.
        for v in validators.values():
            v.close()

    # --- REPORT GENERATION ---
    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions : {len(unique_questions)}")

    # Rows that survived every structural check; guards all divisions below.
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        print(f" - {p}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {p}: {count}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        print(f" - {d}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {d}: {count}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")
    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)
# Script entry point: run the audit only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()