"""
Synthetic training data generator for pg_plan_cache models.

Generates realistic SQL queries across a wide range of complexity levels
with labels for cache benefit, recommended TTL, and complexity score.
"""

import random

# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

TABLES = [
    "users", "orders", "products", "payments", "sessions",
    "logs", "events", "accounts", "invoices", "shipments",
    "categories", "reviews", "inventory", "notifications", "messages",
    "employees", "departments", "projects", "tasks", "comments",
]

# NOTE(review): not referenced by any generator in this module.
SCHEMAS = ["public", "app", "analytics", "billing"]

COLUMNS = {
    "users": ["id", "name", "email", "created_at", "status", "age", "country"],
    "orders": ["id", "user_id", "total", "status", "created_at", "shipped_at"],
    "products": ["id", "name", "price", "category_id", "stock", "rating"],
    "payments": ["id", "order_id", "amount", "method", "paid_at", "status"],
    "sessions": ["id", "user_id", "started_at", "ended_at", "ip_address"],
    "logs": ["id", "level", "message", "created_at", "source"],
    "events": ["id", "type", "user_id", "data", "created_at"],
    "accounts": ["id", "owner_id", "balance", "currency", "opened_at"],
    "invoices": ["id", "account_id", "amount", "due_date", "status"],
    "shipments": ["id", "order_id", "carrier", "tracking", "shipped_at"],
    "categories": ["id", "name", "parent_id", "sort_order"],
    "reviews": ["id", "product_id", "user_id", "rating", "body", "created_at"],
    "inventory": ["id", "product_id", "warehouse_id", "quantity", "updated_at"],
    "notifications": ["id", "user_id", "type", "read", "created_at"],
    "messages": ["id", "sender_id", "receiver_id", "body", "sent_at"],
    "employees": ["id", "name", "department_id", "salary", "hired_at"],
    "departments": ["id", "name", "budget", "manager_id"],
    "projects": ["id", "name", "department_id", "deadline", "status"],
    "tasks": ["id", "project_id", "assignee_id", "title", "status", "due_date"],
    "comments": ["id", "task_id", "user_id", "body", "created_at"],
}

AGG_FUNCS = ["COUNT", "SUM", "AVG", "MIN", "MAX"]
COMPARISONS = ["=", ">", "<", ">=", "<=", "!="]
STRING_VALS = ["'active'", "'pending'", "'completed'", "'cancelled'", "'new'", "'shipped'"]
JOIN_TYPES = ["JOIN", "LEFT JOIN", "INNER JOIN", "RIGHT JOIN"]
# NOTE(review): unused in this module; _window_query draws from its own list.
WINDOW_FUNCS = ["ROW_NUMBER()", "RANK()", "DENSE_RANK()", "LAG(t.id, 1)", "LEAD(t.id, 1)"]


def _rand_table():
    """Return a uniformly random table name from TABLES."""
    return random.choice(TABLES)


def _rand_cols(table, n=None):
    """Return a random sample of distinct column names of *table*.

    When *n* is falsy, between 1 and 4 columns are drawn (capped at the
    table's width). Unknown tables fall back to ``["id", "name"]``.
    """
    cols = COLUMNS.get(table, ["id", "name"])
    n = n or random.randint(1, min(4, len(cols)))
    return random.sample(cols, min(n, len(cols)))


def _singular(table):
    """Return a naive singular form of a plural table name.

    ``categories`` -> ``category``, ``users`` -> ``user``; a name that does
    not end in ``s`` (e.g. ``inventory``) is returned unchanged. Only meant
    for the names in TABLES, not for English plurals in general.
    """
    if table.endswith("ies"):
        return table[:-3] + "y"
    if table.endswith("s"):
        return table[:-1]
    return table


def _fk_col(table):
    """Return the conventional ``<singular>_id`` column name pointing at *table*.

    The column is not guaranteed to exist on the table it is joined against;
    the generators use it purely to shape plausible join predicates.
    """
    return f"{_singular(table)}_id"


def _alias(table):
    """Return the one-letter alias declared for *table* in single-table queries."""
    return table[:1]


def _rand_where(alias="t"):
    """Return a random ``alias.col <op> value`` predicate.

    The column pool is table-agnostic, so the predicate may reference a
    column the aliased table does not have. The literal type follows the
    column: a status string, an integer, or a 2024 date string.
    """
    col = random.choice(["id", "status", "created_at", "name", "amount", "age"])
    op = random.choice(COMPARISONS)
    if col == "status":
        return f"{alias}.{col} {op} {random.choice(STRING_VALS)}"
    elif col in ("id", "age", "amount"):
        return f"{alias}.{col} {op} {random.randint(1, 10000)}"
    else:
        return f"{alias}.{col} {op} '2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}'"


# ---------------------------------------------------------------------------
# Query generators by complexity tier
#
# Every generator returns (sql, cache_benefit, ttl_seconds, complexity).
# Single-table queries declare a one-letter alias (see _alias) because the
# predicates produced by _rand_where are always alias-qualified.
# ---------------------------------------------------------------------------

def _simple_select():
    """Tier 1: Simple SELECT with optional WHERE and LIMIT."""
    t = _rand_table()
    alias = _alias(t)
    cols = ", ".join(_rand_cols(t))
    sql = f"SELECT {cols} FROM {t} {alias}"
    if random.random() > 0.3:
        sql += f" WHERE {_rand_where(alias)}"
    if random.random() > 0.7:
        sql += f" LIMIT {random.choice([10, 20, 50, 100])}"
    return sql, "low", random.randint(300, 900), random.randint(5, 20)


def _select_with_order():
    """Tier 1.5: SELECT with WHERE, ORDER BY and LIMIT."""
    t = _rand_table()
    alias = _alias(t)
    cols = ", ".join(_rand_cols(t))
    order_col = random.choice(COLUMNS.get(t, ["id"]))
    direction = random.choice(["ASC", "DESC"])
    sql = (
        f"SELECT {cols} FROM {t} {alias} WHERE {_rand_where(alias)} "
        f"ORDER BY {order_col} {direction} "
        f"LIMIT {random.choice([10, 25, 50])}"
    )
    return sql, "low", random.randint(600, 1200), random.randint(10, 25)


def _single_join():
    """Tier 2: Single JOIN query."""
    t1, t2 = random.sample(TABLES, 2)
    c1 = ", ".join(f"a.{c}" for c in _rand_cols(t1, 2))
    c2 = ", ".join(f"b.{c}" for c in _rand_cols(t2, 2))
    jtype = random.choice(JOIN_TYPES)
    sql = (
        f"SELECT {c1}, {c2} FROM {t1} a "
        f"{jtype} {t2} b ON a.id = b.{_fk_col(t1)}"
    )
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where('a')}"
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 45)


def _multi_join():
    """Tier 3: Chain of 3-5 tables, each joined to the previous one."""
    tables = random.sample(TABLES, random.randint(3, 5))
    aliases = [chr(ord("a") + i) for i in range(len(tables))]
    selects = [
        f"{alias}.{random.choice(COLUMNS.get(t, ['id']))}"
        for alias, t in zip(aliases, tables)
    ]
    sql = f"SELECT {', '.join(selects)} FROM {tables[0]} a"
    pairs = list(zip(tables, aliases))
    for (prev_t, prev_alias), (cur_t, cur_alias) in zip(pairs, pairs[1:]):
        jtype = random.choice(JOIN_TYPES)
        sql += (
            f" {jtype} {cur_t} {cur_alias} "
            f"ON {prev_alias}.id = {cur_alias}.{_fk_col(prev_t)}"
        )
    if random.random() > 0.3:
        sql += f" WHERE {_rand_where('a')}"
    if random.random() > 0.5:
        sql += f" ORDER BY a.id LIMIT {random.choice([50, 100, 200])}"
    return sql, "high", random.randint(3600, 7200), random.randint(45, 70)


def _aggregate_query():
    """Tier 3: Aggregation with GROUP BY and optional HAVING / ORDER BY."""
    t = _rand_table()
    alias = _alias(t)
    group_col = random.choice(COLUMNS.get(t, ["id"])[:3])
    agg = random.choice(AGG_FUNCS)
    agg_col = random.choice(["id", "amount", "total", "price", "salary"])
    agg_expr = f"{agg}({agg_col})"
    sql = f"SELECT {group_col}, {agg_expr} FROM {t} {alias}"
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where(alias)}"
    sql += f" GROUP BY {group_col}"
    if random.random() > 0.6:
        sql += f" HAVING {agg_expr} > {random.randint(1, 1000)}"
    if random.random() > 0.5:
        sql += f" ORDER BY {agg_expr} DESC"
    return sql, "high", random.randint(3600, 7200), random.randint(40, 65)


def _aggregate_join():
    """Tier 4: JOIN + aggregation with ORDER BY ... LIMIT."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    group_col = f"a.{random.choice(COLUMNS.get(t1, ['id'])[:2])}"
    agg_col = f"b.{random.choice(['id', 'amount', 'total'])}"
    jtype = random.choice(JOIN_TYPES)
    sql = (
        f"SELECT {group_col}, {agg}({agg_col}) as agg_val "
        f"FROM {t1} a {jtype} {t2} b ON a.id = b.{_fk_col(t1)} "
        f"WHERE {_rand_where('a')} "
        f"GROUP BY {group_col}"
    )
    if random.random() > 0.5:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 500)}"
    sql += f" ORDER BY agg_val DESC LIMIT {random.choice([10, 20, 50])}"
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)


def _subquery():
    """Tier 4: ``IN (SELECT ...)`` subquery."""
    t1, t2 = random.sample(TABLES, 2)
    inner = _alias(t2)
    cols = ", ".join(_rand_cols(t1, 2))
    # These two draws are intentionally discarded: they keep the random
    # stream (and therefore every seeded dataset's generator/label sequence)
    # identical to earlier versions of this generator.
    random.choice(AGG_FUNCS)
    random.choice([">", "<", ">="])
    sql = (
        f"SELECT {cols} FROM {t1} "
        f"WHERE id IN (SELECT {_fk_col(t1)} FROM {t2} {inner} "
        f"WHERE {_rand_where(inner)})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 75)


def _correlated_subquery():
    """Tier 5: Correlated scalar subquery in the SELECT list."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    sql = (
        f"SELECT a.id, a.name, "
        f"(SELECT {agg}(b.id) FROM {t2} b WHERE b.{_fk_col(t1)} = a.id) as sub_val "
        f"FROM {t1} a WHERE {_rand_where('a')}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(60, 85)


def _cte_query():
    """Tier 5: Common Table Expression (WITH) joined back to its parent table."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    fk = _fk_col(t1)
    sql = (
        f"WITH cte AS ("
        f"SELECT {fk}, {agg}(id) as cnt FROM {t2} GROUP BY {fk}"
        f") SELECT a.id, a.name, c.cnt "
        f"FROM {t1} a JOIN cte c ON a.id = c.{fk} "
        f"WHERE c.cnt > {random.randint(1, 50)} "
        f"ORDER BY c.cnt DESC"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(65, 85)


def _window_query():
    """Tier 5: Ranking window function over a single table."""
    t = _rand_table()
    alias = _alias(t)
    wfunc = random.choice(["ROW_NUMBER()", "RANK()", "DENSE_RANK()"])
    partition_col = random.choice(COLUMNS.get(t, ["id"])[:2])
    order_col = random.choice(["id", "created_at"])
    sql = (
        f"SELECT id, {partition_col}, "
        f"{wfunc} OVER (PARTITION BY {partition_col} ORDER BY {order_col} DESC) as rn "
        f"FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)


def _union_query():
    """Tier 4: UNION ALL of two filtered single-table SELECTs."""
    t1, t2 = random.sample(TABLES, 2)
    a1, a2 = _alias(t1), _alias(t2)
    sql = (
        f"SELECT id, name FROM {t1} {a1} WHERE {_rand_where(a1)} "
        f"UNION ALL "
        f"SELECT id, name FROM {t2} {a2} WHERE {_rand_where(a2)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(35, 55)


def _complex_analytics():
    """Tier 6: CTE with two LEFT JOINs, GROUP BY/HAVING and a RANK() window."""
    t1, t2, t3 = random.sample(TABLES, 3)
    agg1 = random.choice(AGG_FUNCS)
    agg2 = random.choice(AGG_FUNCS)
    sql = (
        f"WITH monthly AS ("
        f"SELECT a.id, a.name, {agg1}(b.id) as cnt, {agg2}(c.id) as total "
        f"FROM {t1} a "
        f"LEFT JOIN {t2} b ON a.id = b.{_fk_col(t1)} "
        f"LEFT JOIN {t3} c ON b.id = c.{_fk_col(t2)} "
        f"WHERE a.created_at >= '2024-01-01' "
        f"GROUP BY a.id, a.name "
        f"HAVING {agg1}(b.id) > {random.randint(1, 20)}"
        f") SELECT name, cnt, total, "
        f"RANK() OVER (ORDER BY cnt DESC) as rank "
        f"FROM monthly ORDER BY rank LIMIT 100"
    )
    return sql, "high", random.randint(5400, 7200), random.randint(80, 100)


def _insert_query():
    """INSERT — not cacheable (TTL 0)."""
    t = _rand_table()
    cols = _rand_cols(t, 3)
    vals = ", ".join(
        f"{random.randint(1, 9999)}" if c in ("id", "age") else f"'val_{random.randint(1,99)}'"
        for c in cols
    )
    sql = f"INSERT INTO {t} ({', '.join(cols)}) VALUES ({vals})"
    return sql, "low", 0, random.randint(5, 15)


def _update_query():
    """UPDATE — not cacheable (TTL 0)."""
    t = _rand_table()
    alias = _alias(t)
    # Skip the first column (``id`` for every table in COLUMNS).
    col = random.choice(COLUMNS.get(t, ["name"])[1:])
    sql = f"UPDATE {t} {alias} SET {col} = 'updated' WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 15)


def _delete_query():
    """DELETE — not cacheable (TTL 0)."""
    t = _rand_table()
    alias = _alias(t)
    sql = f"DELETE FROM {t} {alias} WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 10)


def _exists_query():
    """Tier 4: Correlated EXISTS subquery."""
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    sql = (
        f"SELECT {cols} FROM {t1} a "
        f"WHERE EXISTS (SELECT 1 FROM {t2} b WHERE b.{_fk_col(t1)} = a.id "
        f"AND {_rand_where('b')})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 70)


def _case_query():
    """Tier 3: CASE expression over ``status``."""
    t = _rand_table()
    alias = _alias(t)
    sql = (
        f"SELECT id, "
        f"CASE WHEN status = 'active' THEN 'A' "
        f"WHEN status = 'pending' THEN 'P' "
        f"ELSE 'X' END as status_code, "
        f"name FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 40)


def _distinct_query():
    """Tier 2: SELECT DISTINCT with ORDER BY."""
    t = _rand_table()
    alias = _alias(t)
    col = random.choice(COLUMNS.get(t, ["name"])[:3])
    sql = f"SELECT DISTINCT {col} FROM {t} {alias} WHERE {_rand_where(alias)} ORDER BY {col}"
    return sql, "medium", random.randint(1200, 2400), random.randint(20, 35)


# ---------------------------------------------------------------------------
# Generator registry
# ---------------------------------------------------------------------------

# (generator, relative weight) pairs; weights are integer sampling counts.
GENERATORS = [
    (_simple_select, 15),
    (_select_with_order, 10),
    (_single_join, 12),
    (_multi_join, 8),
    (_aggregate_query, 10),
    (_aggregate_join, 8),
    (_subquery, 7),
    (_correlated_subquery, 5),
    (_cte_query, 5),
    (_window_query, 5),
    (_union_query, 4),
    (_complex_analytics, 3),
    (_insert_query, 8),
    (_update_query, 5),
    (_delete_query, 4),
    (_exists_query, 5),
    (_case_query, 4),
    (_distinct_query, 4),
]

# Each generator repeated `weight` times, so a uniform random.choice over
# this list samples generators in proportion to their weights.
_WEIGHTED = [gen for gen, weight in GENERATORS for _ in range(weight)]


def generate_sample():
    """Generate one (sql, cache_benefit, ttl, complexity) sample.

    Uniform TTL noise in [-60, 60] is applied only to cacheable samples
    (ttl > 0), so write queries keep their "not cacheable" TTL of 0.
    Complexity gets noise in [-3, 3] and is clamped to [1, 100].
    """
    gen = random.choice(_WEIGHTED)
    sql, benefit, ttl, complexity = gen()
    # Both noise values are drawn unconditionally so the random stream
    # consumed per sample does not depend on the label.
    ttl_noise = random.randint(-60, 60)
    complexity_noise = random.randint(-3, 3)
    if ttl > 0:
        ttl = max(0, ttl + ttl_noise)
    complexity = max(1, min(100, complexity + complexity_noise))
    return sql, benefit, ttl, complexity


def generate_dataset(n: int = 5000, seed: int = 42):
    """
    Generate a training dataset of n samples.

    Reseeds the module-level ``random`` generator with *seed*, so the same
    (n, seed) pair always yields the same dataset (and the global random
    state is modified as a side effect).

    Returns:
        queries: list[str]
        benefits: list[str] — "low", "medium", "high"
        ttls: list[int] — recommended TTL in seconds (0 for write queries)
        complexities: list[int] — 1-100 complexity score
    """
    random.seed(seed)
    queries, benefits, ttls, complexities = [], [], [], []
    for _ in range(n):
        sql, benefit, ttl, complexity = generate_sample()
        queries.append(sql)
        benefits.append(benefit)
        ttls.append(ttl)
        complexities.append(complexity)
    return queries, benefits, ttls, complexities