# pg-plan-cache-models / features.py
# nilenpatel's picture
# Upload pg_plan_cache models
# 406cec4 verified
"""
SQL feature extraction for pg_plan_cache models.
Extracts structural features from raw SQL query text to feed into
the Cache Advisor, TTL Recommender, and Complexity Estimator models.
"""
import re
# Precompiled, case-insensitive patterns used to spot SQL constructs in raw
# query text.  They are lexical heuristics: nothing is parsed, so a keyword
# inside a string literal or a comment is matched like any other occurrence.

# Aggregate function call: a known aggregate name followed by "(".
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)

# Window function call: a known window-function name followed by "(".
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)

# One match per join.  The two-word forms are listed before the bare "join"
# so that "left join" is consumed as a single match, not matched twice.
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)

# Opening parenthesis directly followed by SELECT.
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)

# Start of a common table expression.  Besides the plain "WITH name AS (",
# this accepts the optional RECURSIVE keyword, an optional column list after
# the name ("WITH t(a, b) AS (") and the optional [NOT] MATERIALIZED hint,
# all of which the previous pattern failed to recognise.
CTE_PATTERN = re.compile(
    r"\bwith\s+(?:recursive\s+)?\w+(?:\s*\([^)]*\))?\s+as\s*"
    r"(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)

UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)

# CAST keyword or the "::" cast operator.  The word boundaries apply to the
# keyword only: "::" is punctuation, and wrapping it in \b...\b required a
# word character on both sides, which missed casts such as '2020-01-01'::date
# or (a + b)::int.
CAST_PATTERN = re.compile(r"(\bcast\b|::)", re.IGNORECASE)
# Names of the extracted features, in the order in which they appear in the
# feature vector.
FEATURE_NAMES = [
    # --- size and statement kind ---
    "query_length",
    "query_type",  # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    # --- structural counts ---
    "num_tables",
    "num_joins",
    "num_conditions",
    "num_aggregates",
    "num_subqueries",
    "num_columns",
    # --- presence flags ---
    "has_distinct",
    "has_order_by",
    "has_group_by",
    "has_having",
    "has_limit",
    "has_offset",
    "has_where",
    "has_like",
    "has_in_clause",
    "has_between",
    "has_exists",
    "has_window_func",
    "has_cte",
    "has_union",
    "has_case",
    "has_cast",
    # --- nesting, boolean connectives and literals ---
    "nesting_depth",
    "num_and_or",
    "num_string_literals",
    "num_numeric_literals",
]
def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced by *sql*.

    The estimate is the number of comma-separated items in the first FROM
    clause plus one for every JOIN_PATTERN match in the whole query.  The
    FROM clause is taken to run from the first ``FROM`` up to the next
    ``WHERE``/``JOIN``/``GROUP``/``ORDER``/``LIMIT``/``HAVING`` keyword, a
    semicolon, or the end of the text.

    Only commas outside parentheses separate FROM items, so the arguments of
    a set-returning function (``FROM generate_series(1, 10)``) or the select
    list of a derived table are not mistaken for additional tables.

    Args:
        sql: Raw SQL query text.

    Returns:
        The estimated table count; 0 when the text has neither a FROM clause
        nor a join.
    """
    count = 0
    # FROM clause tables
    from_match = re.search(
        r"\bfrom\s+(.+?)"
        r"(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if from_match:
        count = 1  # the captured clause is never empty
        depth = 0
        for ch in from_match.group(1):
            if ch == "(":
                depth += 1
            elif ch == ")":
                # Clamp at zero: when the first FROM sits inside a subquery,
                # the captured text holds a ")" without its "(", and commas
                # after it must still be counted.
                depth = max(depth - 1, 0)
            elif ch == "," and depth == 0:
                count += 1
    # JOIN tables
    count += len(JOIN_PATTERN.findall(sql))
    return count
def _count_columns(sql: str) -> int:
"""Estimate the number of columns in SELECT clause."""
match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL)
if not match:
return 0
select_clause = match.group(1).strip()
if select_clause == "*":
return 1
# Split by commas not inside parentheses
depth = 0
count = 1
for ch in select_clause:
if ch == '(':
depth += 1
elif ch == ')':
depth -= 1
elif ch == ',' and depth == 0:
count += 1
return count
def _nesting_depth(sql: str) -> int:
"""Calculate maximum parenthesis nesting depth."""
max_depth = 0
depth = 0
for ch in sql:
if ch == '(':
depth += 1
max_depth = max(max_depth, depth)
elif ch == ')':
depth -= 1
return max_depth
def extract_features(sql: str) -> list[float]:
    """
    Build the fixed-length numeric feature vector for one SQL query.

    The query text is stripped of surrounding whitespace and then inspected
    with regular expressions and the counting helpers of this module; it is
    not parsed.

    Args:
        sql: Raw SQL query text.

    Returns:
        A list of floats in FEATURE_NAMES order.  Presence flags are 0.0 or
        1.0.  The ``num_conditions`` and ``num_and_or`` slots both hold the
        number of AND/OR keywords.
    """
    sql = sql.strip()

    def flag(hit) -> float:
        # Collapse a match object (or None) into 1.0 / 0.0.
        return float(bool(hit))

    def has(pattern: str) -> float:
        # Case-insensitive presence test for an ad-hoc pattern.
        return flag(re.search(pattern, sql, re.I))

    # Statement kind is decided by the leading keyword only; anything that
    # does not start with one of the four verbs is coded as 4.
    leading = sql.upper().lstrip()
    query_type = next(
        (
            code
            for code, verb in enumerate(("SELECT", "INSERT", "UPDATE", "DELETE"))
            if leading.startswith(verb)
        ),
        4,
    )

    join_total = len(JOIN_PATTERN.findall(sql))
    aggregate_total = len(AGGREGATE_FUNCS.findall(sql))
    subquery_total = len(SUBQUERY_PATTERN.findall(sql))
    and_or_total = len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE))
    string_literal_total = len(re.findall(r"'[^']*'", sql))
    numeric_literal_total = len(re.findall(r"\b\d+(?:\.\d+)?\b", sql))

    return [
        float(len(sql)),                   # query_length
        float(query_type),                 # query_type
        float(_count_tables(sql)),         # num_tables
        float(join_total),                 # num_joins
        float(and_or_total),               # num_conditions
        float(aggregate_total),            # num_aggregates
        float(subquery_total),             # num_subqueries
        float(_count_columns(sql)),        # num_columns
        has(r"\bdistinct\b"),              # has_distinct
        has(r"\border\s+by\b"),            # has_order_by
        has(r"\bgroup\s+by\b"),            # has_group_by
        flag(HAVING_PATTERN.search(sql)),  # has_having
        has(r"\blimit\b"),                 # has_limit
        has(r"\boffset\b"),                # has_offset
        has(r"\bwhere\b"),                 # has_where
        flag(LIKE_PATTERN.search(sql)),    # has_like
        flag(IN_PATTERN.search(sql)),      # has_in_clause
        flag(BETWEEN_PATTERN.search(sql)),  # has_between
        flag(EXISTS_PATTERN.search(sql)),  # has_exists
        flag(WINDOW_FUNCS.search(sql)),    # has_window_func
        flag(CTE_PATTERN.search(sql)),     # has_cte
        flag(UNION_PATTERN.search(sql)),   # has_union
        flag(CASE_PATTERN.search(sql)),    # has_case
        flag(CAST_PATTERN.search(sql)),    # has_cast
        float(_nesting_depth(sql)),        # nesting_depth
        float(and_or_total),               # num_and_or
        float(string_literal_total),       # num_string_literals
        float(numeric_literal_total),      # num_numeric_literals
    ]