# ritvik360's picture
# Upload folder using huggingface_hub
# a39d8ef verified
"""
data_factory/config.py
======================
Central configuration for the NL2SQL Synthetic Data Factory.
Design philosophy:
- SQL ALWAYS comes from human-verified templates → zero SQL errors
- LLM ONLY generates natural language paraphrases → no SQL hallucination
- Every SQL is execution-validated before saving → guaranteed correctness
"""
from __future__ import annotations
from pathlib import Path
# ── Paths ────────────────────────────────────────────────────────────────
# ROOT_DIR is the parent of the directory that contains this file. All other
# paths hang off a single "generated_data" folder below that root.
ROOT_DIR: Path = Path(__file__).parent.parent
DATA_DIR: Path = ROOT_DIR.joinpath("generated_data")
CHECKPOINT_DIR: Path = DATA_DIR.joinpath("checkpoints")
OUTPUT_DIR: Path = DATA_DIR.joinpath("output")
# ── vLLM / Model ─────────────────────────────────────────────────────────
# For H100 with 80GB VRAM — run Llama-3-70B or Qwen-72B at full bf16.
# NOTE(review): 70B parameters at 2 bytes each is roughly 140 GB of weights,
# more than one 80GB card holds — presumably why TENSOR_PARALLEL is > 1; confirm.

# Model identifier for the generator; change to your preferred model.
GENERATOR_MODEL: str = "meta-llama/Meta-Llama-3-70B-Instruct"

# Number of GPUs for tensor parallelism (H100 cluster).
TENSOR_PARALLEL: int = 4

# Max context length.
MAX_MODEL_LEN: int = 4096

# Leave 10% headroom.
GPU_MEMORY_UTIL: float = 0.90
# ── Generation settings ──────────────────────────────────────────────────
# Personas for the natural-language variants (one variant per persona).
PERSONAS = [
    "ceo",
    "chatty",
    "lazy_typist",
    "non_techie",
    "analyst",
]
# One per persona. Derived from PERSONAS rather than hard-coded so the count
# cannot silently drift when a persona is added or removed.
NL_VARIANTS_PER_TEMPLATE = len(PERSONAS)
AUGMENTATIONS_PER_NL = 3  # Rule-based variations per NL string
TEMPERATURE = 0.85  # Slightly high for diversity
MAX_NEW_TOKENS = 150  # NL questions are short
# ── Scale targets ────────────────────────────────────────────────────────
# 56 base SQL templates × 5 personas × 3 augmentations = 840 "original" records
# With vLLM generating more NL variants, target: ~500K-1M clean records

# Additional vLLM NL variants per template beyond personas.
VLLM_EXTRA_VARIANTS: int = 10
# ── Validation ───────────────────────────────────────────────────────────
# Fixed seed value; its consumers are outside this file (presumably used to
# make random choices reproducible — confirm against callers).
RANDOM_SEED: int = 42
# ── Domains ──────────────────────────────────────────────────────────────
DOMAINS = [
    "ecommerce",
    "healthcare",
    "finance",
    "hr",
]

# Short description of the SQL features in each difficulty tier, keyed by
# tier name.
DIFFICULTY_LABELS = dict(
    easy="Single-table SELECT with basic WHERE/ORDER/LIMIT.",
    medium="Multi-table JOIN with GROUP BY/HAVING/aggregates.",
    hard="CTEs, window functions, subqueries.",
)