| """ |
| SQL feature extraction for pg_plan_cache models. |
| |
| Extracts structural features from raw SQL query text to feed into |
| the Cache Advisor, TTL Recommender, and Complexity Estimator models. |
| """ |
|
|
| import re |
|
|
|
|
# Pre-compiled, case-insensitive patterns run against raw SQL text by the
# feature extractors below. They are heuristics, not a SQL parser.

# Aggregate function calls, e.g. ``count(`` or ``jsonb_agg (``.
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
# Window function calls.
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
# JOIN keywords; the qualified forms are listed first so ``inner join`` is
# consumed as a single match rather than also matching its bare ``join``.
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
# An opening parenthesis immediately followed by SELECT.
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)
# Common table expressions. Accepts ``WITH RECURSIVE``, an optional column
# list after the name (``x(a, b) AS (``), a double-quoted name, and the
# PostgreSQL ``AS [NOT] MATERIALIZED (`` hint.
CTE_PATTERN = re.compile(
    r'\bwith\s+(?:recursive\s+)?(?:\w+|"[^"]+")(?:\s*\([^)]*\))?'
    r"\s+as\s*(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# Explicit casts: the CAST keyword or PostgreSQL's ``::`` operator. Word
# boundaries apply only to ``cast``. ``::`` is usually preceded by a quote or
# ``)`` (``'x'::text``, ``(a + b)::int``), where ``\b`` cannot match.
CAST_PATTERN = re.compile(r"(\bcast\b|::)", re.IGNORECASE)
|
|
# Labels for the vector returned by extract_features(): element i of that
# vector is described by FEATURE_NAMES[i], so the order here is significant.
FEATURE_NAMES = [
    # Size and statement kind.
    "query_length", "query_type",
    # Structural counts.
    "num_tables", "num_joins", "num_conditions", "num_aggregates",
    "num_subqueries", "num_columns",
    # Clause presence flags.
    "has_distinct", "has_order_by", "has_group_by", "has_having",
    "has_limit", "has_offset", "has_where",
    # Predicate / construct presence flags.
    "has_like", "has_in_clause", "has_between", "has_exists",
    "has_window_func", "has_cte", "has_union", "has_case", "has_cast",
    # Shape and literal counts.
    "nesting_depth", "num_and_or", "num_string_literals",
    "num_numeric_literals",
]
|
|
|
|
def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced by *sql*.

    The estimate is the sum of:

    * 1 for the target table of a statement starting with ``INSERT INTO`` or
      ``UPDATE`` (these name their table before any FROM clause);
    * the items of the first ``FROM`` list, split only on commas outside
      parentheses, so a derived table ``(SELECT a, b FROM t)`` counts once;
    * every JOIN keyword matched anywhere in the statement (subqueries
      included).

    This is a text heuristic, not a SQL parser.
    """
    count = 0

    if re.match(r"\s*(?:insert\s+into|update)\b", sql, re.IGNORECASE):
        count += 1

    # The FROM list ends at the next clause keyword, ';', or end of text.
    # Set operations, WINDOW, OFFSET and RETURNING are included so that their
    # commas are not mistaken for additional FROM items.
    from_match = re.search(
        r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b"
        r"|\bhaving\b|\bunion\b|\bintersect\b|\bexcept\b|\bwindow\b"
        r"|\boffset\b|\breturning\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if from_match:
        depth = 0
        items = 1
        for ch in from_match.group(1):
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif ch == "," and depth == 0:
                items += 1
        count += items

    count += len(JOIN_PATTERN.findall(sql))
    return count
|
|
|
|
| def _count_columns(sql: str) -> int: |
| """Estimate the number of columns in SELECT clause.""" |
| match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL) |
| if not match: |
| return 0 |
| select_clause = match.group(1).strip() |
| if select_clause == "*": |
| return 1 |
| |
| depth = 0 |
| count = 1 |
| for ch in select_clause: |
| if ch == '(': |
| depth += 1 |
| elif ch == ')': |
| depth -= 1 |
| elif ch == ',' and depth == 0: |
| count += 1 |
| return count |
|
|
|
|
| def _nesting_depth(sql: str) -> int: |
| """Calculate maximum parenthesis nesting depth.""" |
| max_depth = 0 |
| depth = 0 |
| for ch in sql: |
| if ch == '(': |
| depth += 1 |
| max_depth = max(max_depth, depth) |
| elif ch == ')': |
| depth -= 1 |
| return max_depth |
|
|
|
|
| def extract_features(sql: str) -> list[float]: |
| """ |
| Extract a fixed-length feature vector from a SQL query string. |
| |
| Returns a list of floats matching FEATURE_NAMES ordering. |
| """ |
| sql = sql.strip() |
| upper = sql.upper().lstrip() |
|
|
| |
| if upper.startswith("SELECT"): |
| qtype = 0 |
| elif upper.startswith("INSERT"): |
| qtype = 1 |
| elif upper.startswith("UPDATE"): |
| qtype = 2 |
| elif upper.startswith("DELETE"): |
| qtype = 3 |
| else: |
| qtype = 4 |
|
|
| num_joins = len(JOIN_PATTERN.findall(sql)) |
| num_aggs = len(AGGREGATE_FUNCS.findall(sql)) |
| num_subqueries = len(SUBQUERY_PATTERN.findall(sql)) |
| num_conditions = len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE)) |
| num_string_lits = len(re.findall(r"'[^']*'", sql)) |
| num_numeric_lits = len(re.findall(r"\b\d+(?:\.\d+)?\b", sql)) |
|
|
| features = [ |
| float(len(sql)), |
| float(qtype), |
| float(_count_tables(sql)), |
| float(num_joins), |
| float(num_conditions), |
| float(num_aggs), |
| float(num_subqueries), |
| float(_count_columns(sql)), |
| float(bool(re.search(r"\bdistinct\b", sql, re.I))), |
| float(bool(re.search(r"\border\s+by\b", sql, re.I))), |
| float(bool(re.search(r"\bgroup\s+by\b", sql, re.I))), |
| float(bool(HAVING_PATTERN.search(sql))), |
| float(bool(re.search(r"\blimit\b", sql, re.I))), |
| float(bool(re.search(r"\boffset\b", sql, re.I))), |
| float(bool(re.search(r"\bwhere\b", sql, re.I))), |
| float(bool(LIKE_PATTERN.search(sql))), |
| float(bool(IN_PATTERN.search(sql))), |
| float(bool(BETWEEN_PATTERN.search(sql))), |
| float(bool(EXISTS_PATTERN.search(sql))), |
| float(bool(WINDOW_FUNCS.search(sql))), |
| float(bool(CTE_PATTERN.search(sql))), |
| float(bool(UNION_PATTERN.search(sql))), |
| float(bool(CASE_PATTERN.search(sql))), |
| float(bool(CAST_PATTERN.search(sql))), |
| float(_nesting_depth(sql)), |
| float(num_conditions), |
| float(num_string_lits), |
| float(num_numeric_lits), |
| ] |
|
|
| return features |
|
|