"""
SQL feature extraction for pg_plan_cache models.

Extracts structural features from raw SQL query text to feed into
the Cache Advisor, TTL Recommender, and Complexity Estimator models.

All detection is regex-based and therefore heuristic: the SQL is never
parsed. String literals are masked before structural analysis so that
keywords, digits, commas, or parentheses that merely appear inside quoted
text do not influence the features.
"""

import re

AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)
# Accepts the optional RECURSIVE keyword, an optional column list after the
# CTE name, and the optional [NOT] MATERIALIZED hint before the body.
CTE_PATTERN = re.compile(
    r"\bwith\s+(?:recursive\s+)?\w+(?:\s*\([^()]*\))?\s+as\s*"
    r"(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# "::" is punctuation, so it must NOT be wrapped in \b: a word boundary only
# exists next to a word character, which made casts such as 'x'::text or
# (a + b)::int invisible.
CAST_PATTERN = re.compile(r"\bcast\b|::", re.IGNORECASE)

# Private, precompiled helpers used by the functions below.
_DISTINCT_PATTERN = re.compile(r"\bdistinct\b", re.IGNORECASE)
_ORDER_BY_PATTERN = re.compile(r"\border\s+by\b", re.IGNORECASE)
_GROUP_BY_PATTERN = re.compile(r"\bgroup\s+by\b", re.IGNORECASE)
_LIMIT_PATTERN = re.compile(r"\blimit\b", re.IGNORECASE)
_OFFSET_PATTERN = re.compile(r"\boffset\b", re.IGNORECASE)
_WHERE_PATTERN = re.compile(r"\bwhere\b", re.IGNORECASE)
_AND_OR_PATTERN = re.compile(r"\b(?:and|or)\b", re.IGNORECASE)
_NUMERIC_LITERAL = re.compile(r"\b\d+(?:\.\d+)?\b")
# A single-quoted literal; a doubled quote ('') inside it is an escaped quote,
# so 'it''s' is one literal rather than two.
_STRING_LITERAL = re.compile(r"'[^']*(?:''[^']*)*'")
_SELECT_LIST = re.compile(r"\bselect\s+(.*?)\bfrom\b", re.IGNORECASE | re.DOTALL)
_FROM_CLAUSE = re.compile(
    r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
    re.IGNORECASE | re.DOTALL,
)
# Statement keyword -> query_type code; anything else maps to 4 (OTHER).
_QUERY_TYPE_CODES = (("SELECT", 0), ("INSERT", 1), ("UPDATE", 2), ("DELETE", 3))
_QUERY_TYPE_OTHER = 4

FEATURE_NAMES = [
    "query_length",
    "query_type",  # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    "num_tables",
    "num_joins",
    "num_conditions",
    "num_aggregates",
    "num_subqueries",
    "num_columns",
    "has_distinct",
    "has_order_by",
    "has_group_by",
    "has_having",
    "has_limit",
    "has_offset",
    "has_where",
    "has_like",
    "has_in_clause",
    "has_between",
    "has_exists",
    "has_window_func",
    "has_cte",
    "has_union",
    "has_case",
    "has_cast",
    "nesting_depth",
    "num_and_or",
    "num_string_literals",
    "num_numeric_literals",
]


def _count_top_level_commas(text: str) -> int:
    """Count commas in *text* that are not enclosed in parentheses."""
    depth = 0
    commas = 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            commas += 1
    return commas


def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced.

    Counts the comma-separated items of the first FROM clause plus one table
    per JOIN keyword. Commas nested in parentheses (function arguments,
    derived-table subqueries) are not treated as table separators.
    """
    count = 0
    from_match = _FROM_CLAUSE.search(sql)
    if from_match:
        count += _count_top_level_commas(from_match.group(1)) + 1
    # Every JOIN introduces one more table.
    count += len(JOIN_PATTERN.findall(sql))
    return count


def _count_columns(sql: str) -> int:
    """Estimate the number of columns in the first SELECT list.

    Returns 0 when no ``SELECT ... FROM`` span is found; otherwise the number
    of top-level comma-separated items (so ``*`` counts as one column).
    """
    match = _SELECT_LIST.search(sql)
    if not match:
        return 0
    return _count_top_level_commas(match.group(1)) + 1


def _nesting_depth(sql: str) -> int:
    """Calculate the maximum parenthesis nesting depth."""
    max_depth = 0
    depth = 0
    for ch in sql:
        if ch == "(":
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ")":
            depth -= 1
    return max_depth


def _flag(pattern: "re.Pattern[str]", text: str) -> float:
    """Return 1.0 if *pattern* occurs anywhere in *text*, else 0.0."""
    return 1.0 if pattern.search(text) else 0.0


def extract_features(sql: str) -> list[float]:
    """
    Extract a fixed-length feature vector from a SQL query string.

    The query is stripped of surrounding whitespace, its single-quoted string
    literals are counted and then blanked out, and every structural feature
    is computed on the blanked text. ``query_length`` is the length of the
    stripped, unmasked query.

    Returns a list of floats matching FEATURE_NAMES ordering.
    """
    sql = sql.strip()

    # Query type is decided by the leading keyword only.
    upper = sql.upper()
    qtype = next(
        (code for keyword, code in _QUERY_TYPE_CODES if upper.startswith(keyword)),
        _QUERY_TYPE_OTHER,
    )

    # Count literals on the raw text, then blank their contents so quoted
    # text such as 'order by' or '2020-01-01' cannot trigger keyword flags
    # or inflate the numeric-literal / comma / parenthesis counts.
    num_string_lits = len(_STRING_LITERAL.findall(sql))
    masked = _STRING_LITERAL.sub("''", sql)

    num_joins = len(JOIN_PATTERN.findall(masked))
    num_aggs = len(AGGREGATE_FUNCS.findall(masked))
    num_subqueries = len(SUBQUERY_PATTERN.findall(masked))
    # NOTE: num_conditions and num_and_or intentionally carry the same value
    # (the AND/OR keyword count); both slots are kept so the vector layout
    # stays stable for consumers of FEATURE_NAMES.
    num_conditions = len(_AND_OR_PATTERN.findall(masked))
    num_numeric_lits = len(_NUMERIC_LITERAL.findall(masked))

    features = [
        float(len(sql)),                    # query_length
        float(qtype),                       # query_type
        float(_count_tables(masked)),       # num_tables
        float(num_joins),                   # num_joins
        float(num_conditions),              # num_conditions
        float(num_aggs),                    # num_aggregates
        float(num_subqueries),              # num_subqueries
        float(_count_columns(masked)),      # num_columns
        _flag(_DISTINCT_PATTERN, masked),   # has_distinct
        _flag(_ORDER_BY_PATTERN, masked),   # has_order_by
        _flag(_GROUP_BY_PATTERN, masked),   # has_group_by
        _flag(HAVING_PATTERN, masked),      # has_having
        _flag(_LIMIT_PATTERN, masked),      # has_limit
        _flag(_OFFSET_PATTERN, masked),     # has_offset
        _flag(_WHERE_PATTERN, masked),      # has_where
        _flag(LIKE_PATTERN, masked),        # has_like
        _flag(IN_PATTERN, masked),          # has_in_clause
        _flag(BETWEEN_PATTERN, masked),     # has_between
        _flag(EXISTS_PATTERN, masked),      # has_exists
        _flag(WINDOW_FUNCS, masked),        # has_window_func
        _flag(CTE_PATTERN, masked),         # has_cte
        _flag(UNION_PATTERN, masked),       # has_union
        _flag(CASE_PATTERN, masked),        # has_case
        _flag(CAST_PATTERN, masked),        # has_cast
        float(_nesting_depth(masked)),      # nesting_depth
        float(num_conditions),              # num_and_or
        float(num_string_lits),             # num_string_literals
        float(num_numeric_lits),            # num_numeric_literals
    ]
    return features