Upload pg_plan_cache models
Browse files- README.md +88 -0
- dataset.py +380 -0
- features.py +172 -0
- predict.py +137 -0
- requirements.txt +4 -0
- train.py +164 -0
- trained/cache_advisor.joblib +3 -0
- trained/complexity_estimator.joblib +3 -0
- trained/feature_importances.json +92 -0
- trained/label_encoder.joblib +3 -0
- trained/metadata.json +61 -0
- trained/ttl_recommender.joblib +3 -0
README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: sklearn
|
| 3 |
+
tags:
|
| 4 |
+
- postgresql
|
| 5 |
+
- sql
|
| 6 |
+
- query-cache
|
| 7 |
+
- plan-cache
|
| 8 |
+
- redis
|
| 9 |
+
- database
|
| 10 |
+
- tabular-classification
|
| 11 |
+
- tabular-regression
|
| 12 |
+
pipeline_tag: tabular-classification
|
| 13 |
+
license: mit
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# pg_plan_cache Models
|
| 17 |
+
|
| 18 |
+
Three machine learning models for the **pg_plan_cache** PostgreSQL extension — a query
|
| 19 |
+
execution plan cache backed by Redis.
|
| 20 |
+
|
| 21 |
+
## Models
|
| 22 |
+
|
| 23 |
+
### 1. SQL Cache Advisor
|
| 24 |
+
- **Task:** Classification (high / medium / low)
|
| 25 |
+
- **Algorithm:** Random Forest (200 trees)
|
| 26 |
+
- **Purpose:** Predicts whether caching a query's execution plan will be beneficial
|
| 27 |
+
|
| 28 |
+
### 2. Cache TTL Recommender
|
| 29 |
+
- **Task:** Regression (seconds)
|
| 30 |
+
- **Algorithm:** Gradient Boosting
|
| 31 |
+
- **Purpose:** Recommends optimal cache TTL based on query characteristics
|
| 32 |
+
|
| 33 |
+
### 3. Query Complexity Estimator
|
| 34 |
+
- **Task:** Regression (1-100 score)
|
| 35 |
+
- **Algorithm:** Gradient Boosting
|
| 36 |
+
- **Purpose:** Estimates query complexity to prioritize caching resources
|
| 37 |
+
|
| 38 |
+
## Features
|
| 39 |
+
|
| 40 |
+
All models use 28 structural features extracted from raw SQL text:
|
| 41 |
+
|
| 42 |
+
| Feature | Description |
|
| 43 |
+
|---------|------------|
|
| 44 |
+
| `query_length` | Character count |
|
| 45 |
+
| `query_type` | SELECT=0, INSERT=1, UPDATE=2, DELETE=3 |
|
| 46 |
+
| `num_tables` | Tables referenced |
|
| 47 |
+
| `num_joins` | JOIN clause count |
|
| 48 |
+
| `num_conditions` | AND/OR conditions |
|
| 49 |
+
| `num_aggregates` | Aggregate function count |
|
| 50 |
+
| `num_subqueries` | Subquery count |
|
| 51 |
+
| `has_window_func` | Window functions present |
|
| 52 |
+
| `has_cte` | Common Table Expressions |
|
| 53 |
+
| `nesting_depth` | Max parenthesis depth |
|
| 54 |
+
| ... | 18 more features |
|
| 55 |
+
|
| 56 |
+
## Usage
|
| 57 |
+
|
| 58 |
+
```python
|
| 59 |
+
from predict import predict, format_prediction
|
| 60 |
+
|
| 61 |
+
result = predict("SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name")
|
| 62 |
+
print(format_prediction(result))
|
| 63 |
+
# Cache Benefit: HIGH
|
| 64 |
+
# Recommended TTL: 4200s (1h 10m)
|
| 65 |
+
# Complexity: 62/100 (complex)
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Training
|
| 69 |
+
|
| 70 |
+
Trained on 8,000 synthetic SQL queries across 18 complexity tiers:
|
| 71 |
+
- Simple SELECTs, filtered queries, ORDER BY
|
| 72 |
+
- Single and multi-table JOINs
|
| 73 |
+
- Aggregations with GROUP BY / HAVING
|
| 74 |
+
- Subqueries, correlated subqueries, EXISTS
|
| 75 |
+
- CTEs, window functions, UNION
|
| 76 |
+
- Complex analytics queries
|
| 77 |
+
- INSERT / UPDATE / DELETE (non-cacheable)
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
pip install -r requirements.txt
|
| 81 |
+
python train.py
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## About pg_plan_cache
|
| 85 |
+
|
| 86 |
+
pg_plan_cache is a PostgreSQL extension that caches query execution plans in Redis.
|
| 87 |
+
It hooks into the PostgreSQL planner, normalizes queries, computes SHA-256 hashes,
|
| 88 |
+
and stores serialized plans with configurable TTL and automatic schema-change invalidation.
|
dataset.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Synthetic training data generator for pg_plan_cache models.
|
| 3 |
+
|
| 4 |
+
Generates realistic SQL queries across a wide range of complexity levels
|
| 5 |
+
with labels for cache benefit, recommended TTL, and complexity score.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------------------------------
|
| 11 |
+
# Building blocks
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
TABLES = [
|
| 15 |
+
"users", "orders", "products", "payments", "sessions",
|
| 16 |
+
"logs", "events", "accounts", "invoices", "shipments",
|
| 17 |
+
"categories", "reviews", "inventory", "notifications", "messages",
|
| 18 |
+
"employees", "departments", "projects", "tasks", "comments",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
SCHEMAS = ["public", "app", "analytics", "billing"]
|
| 22 |
+
|
| 23 |
+
COLUMNS = {
|
| 24 |
+
"users": ["id", "name", "email", "created_at", "status", "age", "country"],
|
| 25 |
+
"orders": ["id", "user_id", "total", "status", "created_at", "shipped_at"],
|
| 26 |
+
"products": ["id", "name", "price", "category_id", "stock", "rating"],
|
| 27 |
+
"payments": ["id", "order_id", "amount", "method", "paid_at", "status"],
|
| 28 |
+
"sessions": ["id", "user_id", "started_at", "ended_at", "ip_address"],
|
| 29 |
+
"logs": ["id", "level", "message", "created_at", "source"],
|
| 30 |
+
"events": ["id", "type", "user_id", "data", "created_at"],
|
| 31 |
+
"accounts": ["id", "owner_id", "balance", "currency", "opened_at"],
|
| 32 |
+
"invoices": ["id", "account_id", "amount", "due_date", "status"],
|
| 33 |
+
"shipments": ["id", "order_id", "carrier", "tracking", "shipped_at"],
|
| 34 |
+
"categories": ["id", "name", "parent_id", "sort_order"],
|
| 35 |
+
"reviews": ["id", "product_id", "user_id", "rating", "body", "created_at"],
|
| 36 |
+
"inventory": ["id", "product_id", "warehouse_id", "quantity", "updated_at"],
|
| 37 |
+
"notifications": ["id", "user_id", "type", "read", "created_at"],
|
| 38 |
+
"messages": ["id", "sender_id", "receiver_id", "body", "sent_at"],
|
| 39 |
+
"employees": ["id", "name", "department_id", "salary", "hired_at"],
|
| 40 |
+
"departments": ["id", "name", "budget", "manager_id"],
|
| 41 |
+
"projects": ["id", "name", "department_id", "deadline", "status"],
|
| 42 |
+
"tasks": ["id", "project_id", "assignee_id", "title", "status", "due_date"],
|
| 43 |
+
"comments": ["id", "task_id", "user_id", "body", "created_at"],
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
AGG_FUNCS = ["COUNT", "SUM", "AVG", "MIN", "MAX"]
|
| 47 |
+
COMPARISONS = ["=", ">", "<", ">=", "<=", "!="]
|
| 48 |
+
STRING_VALS = ["'active'", "'pending'", "'completed'", "'cancelled'", "'new'", "'shipped'"]
|
| 49 |
+
JOIN_TYPES = ["JOIN", "LEFT JOIN", "INNER JOIN", "RIGHT JOIN"]
|
| 50 |
+
WINDOW_FUNCS = ["ROW_NUMBER()", "RANK()", "DENSE_RANK()", "LAG(t.id, 1)", "LEAD(t.id, 1)"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _rand_table():
    """Return one table name drawn uniformly at random from TABLES."""
    return random.choice(TABLES)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _rand_cols(table, n=None):
    """Sample distinct column names for *table*.

    When *n* is falsy, a random count between 1 and min(4, available) is used.
    Unknown tables fall back to a generic ["id", "name"] column list.
    """
    available = COLUMNS.get(table, ["id", "name"])
    n = n or random.randint(1, min(4, len(available)))
    return random.sample(available, min(n, len(available)))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _rand_where(alias="t"):
    """Build one random comparison predicate qualified by *alias*.

    The literal on the right-hand side is chosen to fit the column:
    string status values, integers for numeric-ish columns, dates otherwise.
    """
    column = random.choice(["id", "status", "created_at", "name", "amount", "age"])
    operator = random.choice(COMPARISONS)
    if column == "status":
        value = random.choice(STRING_VALS)
    elif column in ("id", "age", "amount"):
        value = f"{random.randint(1, 10000)}"
    else:
        value = f"'2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}'"
    return f"{alias}.{column} {operator} {value}"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Query generators by complexity tier
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
def _simple_select():
    """Tier 1: Simple SELECT with optional WHERE.

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    cols = ", ".join(_rand_cols(t))
    sql = f"SELECT {cols} FROM {t}"
    if random.random() > 0.3:
        # FROM declares no alias, so qualify predicates with the full table
        # name; the previous one-letter qualifier (t[:1]) referenced an
        # undeclared alias and produced invalid SQL.
        sql += f" WHERE {_rand_where(t)}"
    if random.random() > 0.7:
        sql += f" LIMIT {random.choice([10, 20, 50, 100])}"
    return sql, "low", random.randint(300, 900), random.randint(5, 20)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _select_with_order():
    """Tier 1.5: SELECT with ORDER BY and LIMIT.

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    cols = ", ".join(_rand_cols(t))
    order_col = random.choice(COLUMNS.get(t, ["id"]))
    direction = random.choice(["ASC", "DESC"])
    # Qualify with the table name itself: FROM declares no alias, so the old
    # t[:1] qualifier referenced an undeclared alias (invalid SQL).
    sql = f"SELECT {cols} FROM {t} WHERE {_rand_where(t)} ORDER BY {order_col} {direction} LIMIT {random.choice([10,25,50])}"
    return sql, "low", random.randint(600, 1200), random.randint(10, 25)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _single_join():
    """Tier 2: two-table join (aliases a/b) with an optional WHERE filter."""
    left, right = random.sample(TABLES, 2)
    left_cols = ", ".join(f"a.{c}" for c in _rand_cols(left, 2))
    right_cols = ", ".join(f"b.{c}" for c in _rand_cols(right, 2))
    join_kw = random.choice(JOIN_TYPES)
    # Foreign key is derived by naive singularization: "users" -> "user_id".
    sql = (
        f"SELECT {left_cols}, {right_cols} FROM {left} a "
        f"{join_kw} {right} b ON a.id = b.{left[:-1]}_id"
    )
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where('a')}"
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 45)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _multi_join():
    """Tier 3: chain of 3-5 joined tables with consecutive letter aliases."""
    chain = random.sample(TABLES, random.randint(3, 5))

    # One projected column per table, qualified by its alias (a, b, c, ...).
    projected = []
    for idx, tbl in enumerate(chain):
        projected.append(f"{chr(97 + idx)}.{random.choice(COLUMNS.get(tbl, ['id']))}")

    sql = f"SELECT {', '.join(projected)} FROM {chain[0]} a"
    for idx in range(1, len(chain)):
        cur = chr(97 + idx)
        prev = chr(96 + idx)
        join_kw = random.choice(JOIN_TYPES)
        sql += f" {join_kw} {chain[idx]} {cur} ON {prev}.id = {cur}.{chain[idx-1][:-1]}_id"

    if random.random() > 0.3:
        sql += f" WHERE {_rand_where('a')}"
    if random.random() > 0.5:
        sql += f" ORDER BY a.id LIMIT {random.choice([50, 100, 200])}"
    return sql, "high", random.randint(3600, 7200), random.randint(45, 70)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _aggregate_query():
    """Tier 3: Aggregation with GROUP BY (optional HAVING / ORDER BY).

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    group_col = random.choice(COLUMNS.get(t, ["id"])[:3])
    agg = random.choice(AGG_FUNCS)
    agg_col = random.choice(["id", "amount", "total", "price", "salary"])
    sql = f"SELECT {group_col}, {agg}({agg_col}) FROM {t}"
    if random.random() > 0.4:
        # FROM declares no alias; qualify with the table name itself (the
        # previous t[:1] qualifier referenced an undeclared alias).
        sql += f" WHERE {_rand_where(t)}"
    sql += f" GROUP BY {group_col}"
    if random.random() > 0.6:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 1000)}"
    if random.random() > 0.5:
        sql += f" ORDER BY {agg}({agg_col}) DESC"
    return sql, "high", random.randint(3600, 7200), random.randint(40, 65)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _aggregate_join():
    """Tier 4: JOIN + Aggregation with GROUP BY, optional HAVING."""
    left, right = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    group_col = f"a.{random.choice(COLUMNS.get(left, ['id'])[:2])}"
    agg_col = f"b.{random.choice(['id', 'amount', 'total'])}"
    join_kw = random.choice(JOIN_TYPES)
    sql = "".join([
        f"SELECT {group_col}, {agg}({agg_col}) as agg_val ",
        f"FROM {left} a {join_kw} {right} b ON a.id = b.{left[:-1]}_id ",
        f"WHERE {_rand_where('a')} ",
        f"GROUP BY {group_col}",
    ])
    if random.random() > 0.5:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 500)}"
    sql += f" ORDER BY agg_val DESC LIMIT {random.choice([10, 20, 50])}"
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _subquery():
    """Tier 4: IN (subquery).

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    # Fixes vs. original: the inner query declares no alias, so its predicate
    # is qualified by the table name itself (previously an undeclared
    # one-letter alias); also removed two unused random draws (sub_agg, op)
    # that were never interpolated into the SQL.
    sql = (
        f"SELECT {cols} FROM {t1} "
        f"WHERE id IN (SELECT {t1[:-1]}_id FROM {t2} "
        f"WHERE {_rand_where(t2)})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 75)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _correlated_subquery():
    """Tier 5: correlated scalar subquery in the SELECT list."""
    outer, inner = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    # The subquery correlates on a.id via a naive singular FK name.
    scalar_sub = f"(SELECT {agg}(b.id) FROM {inner} b WHERE b.{outer[:-1]}_id = a.id) as sub_val"
    sql = f"SELECT a.id, a.name, {scalar_sub} FROM {outer} a WHERE {_rand_where('a')}"
    return sql, "high", random.randint(3600, 7200), random.randint(60, 85)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _cte_query():
    """Tier 5: Common Table Expression (WITH) joined back to its base table."""
    base, detail = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    fk = f"{base[:-1]}_id"
    cte_body = f"SELECT {fk}, {agg}(id) as cnt FROM {detail} GROUP BY {fk}"
    sql = (
        f"WITH cte AS ("
        f"{cte_body}"
        f") SELECT a.id, a.name, c.cnt "
        f"FROM {base} a JOIN cte c ON a.id = c.{fk} "
        f"WHERE c.cnt > {random.randint(1, 50)} "
        f"ORDER BY c.cnt DESC"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(65, 85)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _window_query():
    """Tier 5: Window function (PARTITION BY ... ORDER BY ...).

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    wfunc = random.choice(["ROW_NUMBER()", "RANK()", "DENSE_RANK()"])
    partition_col = random.choice(COLUMNS.get(t, ["id"])[:2])
    order_col = random.choice(["id", "created_at"])
    # FROM declares no alias, so the predicate is qualified with the table
    # name itself (the old t[:1] qualifier referenced an undeclared alias).
    sql = (
        f"SELECT id, {partition_col}, "
        f"{wfunc} OVER (PARTITION BY {partition_col} ORDER BY {order_col} DESC) as rn "
        f"FROM {t} WHERE {_rand_where(t)}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _union_query():
    """Tier 4: UNION ALL of two filtered selects.

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t1, t2 = random.sample(TABLES, 2)
    # Qualify predicates with the full table names — neither branch declares
    # an alias, so the old one-letter qualifiers referenced undeclared aliases.
    sql = (
        f"SELECT id, name FROM {t1} WHERE {_rand_where(t1)} "
        f"UNION ALL "
        f"SELECT id, name FROM {t2} WHERE {_rand_where(t2)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(35, 55)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _complex_analytics():
    """Tier 6: Complex analytics query.

    A CTE over a three-table left-join chain with two aggregates and a
    HAVING filter, followed by a windowed RANK() over the CTE.  Labeled
    with the highest cache-benefit / TTL / complexity ranges of any tier.
    """
    t1, t2, t3 = random.sample(TABLES, 3)
    agg1 = random.choice(AGG_FUNCS)
    agg2 = random.choice(AGG_FUNCS)
    # NOTE(review): "as rank" uses a bare keyword-like alias; some SQL
    # dialects treat RANK as reserved — confirm this parses on the target
    # engine if these queries are ever executed (they are training text only).
    sql = (
        f"WITH monthly AS ("
        f"SELECT a.id, a.name, {agg1}(b.id) as cnt, {agg2}(c.id) as total "
        f"FROM {t1} a "
        f"LEFT JOIN {t2} b ON a.id = b.{t1[:-1]}_id "
        f"LEFT JOIN {t3} c ON b.id = c.{t2[:-1]}_id "
        f"WHERE a.created_at >= '2024-01-01' "
        f"GROUP BY a.id, a.name "
        f"HAVING {agg1}(b.id) > {random.randint(1, 20)}"
        f") SELECT name, cnt, total, "
        f"RANK() OVER (ORDER BY cnt DESC) as rank "
        f"FROM monthly ORDER BY rank LIMIT 100"
    )
    return sql, "high", random.randint(5400, 7200), random.randint(80, 100)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _insert_query():
    """INSERT — not cacheable, so the TTL label is fixed at 0."""
    t = _rand_table()
    cols = _rand_cols(t, 3)
    # Render integer literals for id/age, quoted placeholder strings otherwise.
    rendered = []
    for c in cols:
        if c in ("id", "age"):
            rendered.append(f"{random.randint(1, 9999)}")
        else:
            rendered.append(f"'val_{random.randint(1,99)}'")
    vals = ", ".join(rendered)
    sql = f"INSERT INTO {t} ({', '.join(cols)}) VALUES ({vals})"
    return sql, "low", 0, random.randint(5, 15)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _update_query():
    """UPDATE — not cacheable, so the TTL label is fixed at 0."""
    t = _rand_table()
    # Guard against an empty [1:] slice (only possible for tables missing
    # from COLUMNS, where the default ["name"] would leave nothing to pick).
    cols = COLUMNS.get(t, ["name"])
    col = random.choice(cols[1:] or cols)
    # Qualify the predicate with the table name: plain UPDATE declares no
    # alias, so the previous t[:1] qualifier referenced a nonexistent alias.
    sql = f"UPDATE {t} SET {col} = 'updated' WHERE {_rand_where(t)}"
    return sql, "low", 0, random.randint(5, 15)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _delete_query():
    """DELETE — not cacheable, so the TTL label is fixed at 0."""
    t = _rand_table()
    # Qualify the predicate with the table name: DELETE FROM declares no
    # alias, so the previous t[:1] qualifier referenced a nonexistent alias.
    sql = f"DELETE FROM {t} WHERE {_rand_where(t)}"
    return sql, "low", 0, random.randint(5, 10)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _exists_query():
    """Tier 4: EXISTS subquery correlated on a.id."""
    outer, inner = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(outer, 2))
    sql = (
        f"SELECT {cols} FROM {outer} a "
        f"WHERE EXISTS (SELECT 1 FROM {inner} b WHERE b.{outer[:-1]}_id = a.id "
        f"AND {_rand_where('b')})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 70)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _case_query():
    """Tier 3: CASE expression over the status column.

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    # Qualify the predicate with the table name — FROM declares no alias,
    # so the previous t[:1] qualifier referenced an undeclared alias.
    sql = (
        f"SELECT id, "
        f"CASE WHEN status = 'active' THEN 'A' "
        f"WHEN status = 'pending' THEN 'P' "
        f"ELSE 'X' END as status_code, "
        f"name FROM {t} WHERE {_rand_where(t)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 40)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _distinct_query():
    """Tier 2: SELECT DISTINCT with WHERE and ORDER BY.

    Returns (sql, cache_benefit, ttl_seconds, complexity_score).
    """
    t = _rand_table()
    col = random.choice(COLUMNS.get(t, ["name"])[:3])
    # Qualify the predicate with the table name — FROM declares no alias,
    # so the previous t[:1] qualifier referenced an undeclared alias.
    sql = f"SELECT DISTINCT {col} FROM {t} WHERE {_rand_where(t)} ORDER BY {col}"
    return sql, "medium", random.randint(1200, 2400), random.randint(20, 35)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
# Generator registry
|
| 323 |
+
# ---------------------------------------------------------------------------
|
| 324 |
+
|
| 325 |
+
# Each entry pairs a generator callable with its sampling weight; weights
# roughly follow real workload frequency (simple reads dominate).
GENERATORS = [
    (_simple_select, 15),
    (_select_with_order, 10),
    (_single_join, 12),
    (_multi_join, 8),
    (_aggregate_query, 10),
    (_aggregate_join, 8),
    (_subquery, 7),
    (_correlated_subquery, 5),
    (_cte_query, 5),
    (_window_query, 5),
    (_union_query, 4),
    (_complex_analytics, 3),
    (_insert_query, 8),
    (_update_query, 5),
    (_delete_query, 4),
    (_exists_query, 5),
    (_case_query, 4),
    (_distinct_query, 4),
]

# Flattened weighted pool: each generator appears `weight` times so a uniform
# random.choice() samples proportionally to the weights above.
_WEIGHTED = [gen for gen, weight in GENERATORS for _ in range(weight)]
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def generate_sample():
    """Generate one (sql, cache_benefit, ttl, complexity) sample.

    Picks a generator weight-proportionally, then jitters the TTL and
    complexity labels slightly so they are not perfectly tier-separable.
    """
    sql, benefit, ttl, complexity = random.choice(_WEIGHTED)()
    jittered_ttl = max(0, ttl + random.randint(-60, 60))
    jittered_complexity = max(1, min(100, complexity + random.randint(-3, 3)))
    return sql, benefit, jittered_ttl, jittered_complexity
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def generate_dataset(n: int = 5000, seed: int = 42):
    """
    Generate a training dataset of n samples (seeded for reproducibility).

    Returns:
        queries: list[str]
        benefits: list[str] — "low", "medium", "high"
        ttls: list[int] — recommended TTL in seconds
        complexities: list[int] — 1-100 complexity score
    """
    random.seed(seed)
    samples = [generate_sample() for _ in range(n)]
    if not samples:
        return [], [], [], []
    queries, benefits, ttls, complexities = (list(col) for col in zip(*samples))
    return queries, benefits, ttls, complexities
|
features.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL feature extraction for pg_plan_cache models.
|
| 3 |
+
|
| 4 |
+
Extracts structural features from raw SQL query text to feed into
|
| 5 |
+
the Cache Advisor, TTL Recommender, and Complexity Estimator models.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Pre-compiled, case-insensitive regexes used by extract_features().
# All matching is purely lexical; the SQL is never parsed.

# Aggregate function call sites, e.g. "COUNT(" or "string_agg(".
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
# Window function call sites (counted only as present/absent).
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
# All JOIN variants; bare "join" is last so qualified forms match first.
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)   # "(SELECT ..."
CTE_PATTERN = re.compile(r"\bwith\s+\w+\s+as\s*\(", re.IGNORECASE)  # "WITH name AS ("
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# NOTE(review): "\b(cast|::)\b" only matches "::" when it abuts word
# characters (e.g. "col::int"); a spaced cast like "x :: int" is missed
# because \b needs a word/non-word transition — confirm this is intended.
CAST_PATTERN = re.compile(r"\b(cast|::)\b", re.IGNORECASE)
|
| 33 |
+
|
| 34 |
+
# Canonical ordering of the 28-entry feature vector produced by
# extract_features(); the two lists MUST stay index-aligned.
FEATURE_NAMES = [
    "query_length",          # character count of the stripped SQL text
    "query_type",            # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    "num_tables",            # rough FROM-list + JOIN count
    "num_joins",
    "num_conditions",        # AND/OR keyword count
    "num_aggregates",
    "num_subqueries",
    "num_columns",           # top-level SELECT-list items
    "has_distinct",
    "has_order_by",
    "has_group_by",
    "has_having",
    "has_limit",
    "has_offset",
    "has_where",
    "has_like",
    "has_in_clause",
    "has_between",
    "has_exists",
    "has_window_func",
    "has_cte",
    "has_union",
    "has_case",
    "has_cast",
    "nesting_depth",         # max parenthesis depth
    "num_and_or",            # same raw count as num_conditions (duplicated by design)
    "num_string_literals",
    "num_numeric_literals",
]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _count_tables(sql: str) -> int:
    """Roughly count tables referenced: FROM-list commas plus JOIN keywords.

    This is a lexical heuristic — commas inside subqueries within the FROM
    clause are counted too, which is acceptable for feature extraction.
    """
    total = 0
    from_match = re.search(
        r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if from_match:
        # N commas separate N+1 comma-listed tables.
        total += from_match.group(1).count(",") + 1
    total += len(JOIN_PATTERN.findall(sql))
    return max(total, 0)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _count_columns(sql: str) -> int:
|
| 80 |
+
"""Estimate the number of columns in SELECT clause."""
|
| 81 |
+
match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL)
|
| 82 |
+
if not match:
|
| 83 |
+
return 0
|
| 84 |
+
select_clause = match.group(1).strip()
|
| 85 |
+
if select_clause == "*":
|
| 86 |
+
return 1
|
| 87 |
+
# Split by commas not inside parentheses
|
| 88 |
+
depth = 0
|
| 89 |
+
count = 1
|
| 90 |
+
for ch in select_clause:
|
| 91 |
+
if ch == '(':
|
| 92 |
+
depth += 1
|
| 93 |
+
elif ch == ')':
|
| 94 |
+
depth -= 1
|
| 95 |
+
elif ch == ',' and depth == 0:
|
| 96 |
+
count += 1
|
| 97 |
+
return count
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _nesting_depth(sql: str) -> int:
|
| 101 |
+
"""Calculate maximum parenthesis nesting depth."""
|
| 102 |
+
max_depth = 0
|
| 103 |
+
depth = 0
|
| 104 |
+
for ch in sql:
|
| 105 |
+
if ch == '(':
|
| 106 |
+
depth += 1
|
| 107 |
+
max_depth = max(max_depth, depth)
|
| 108 |
+
elif ch == ')':
|
| 109 |
+
depth -= 1
|
| 110 |
+
return max_depth
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def extract_features(sql: str) -> list[float]:
    """
    Extract a fixed-length feature vector from a SQL query string.

    Returns a list of 28 floats whose ordering matches FEATURE_NAMES
    exactly.  All features are derived lexically (regex/char scans); the
    SQL is never parsed or executed.
    """
    sql = sql.strip()
    # Uppercased copy used only to classify the leading keyword below.
    upper = sql.upper().lstrip()

    # Query type
    if upper.startswith("SELECT"):
        qtype = 0
    elif upper.startswith("INSERT"):
        qtype = 1
    elif upper.startswith("UPDATE"):
        qtype = 2
    elif upper.startswith("DELETE"):
        qtype = 3
    else:
        qtype = 4

    num_joins = len(JOIN_PATTERN.findall(sql))
    num_aggs = len(AGGREGATE_FUNCS.findall(sql))
    num_subqueries = len(SUBQUERY_PATTERN.findall(sql))
    # AND/OR keyword count; feeds both num_conditions and num_and_or below.
    num_conditions = len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE))
    num_string_lits = len(re.findall(r"'[^']*'", sql))
    num_numeric_lits = len(re.findall(r"\b\d+(?:\.\d+)?\b", sql))

    # IMPORTANT: the order here must stay aligned with FEATURE_NAMES.
    features = [
        float(len(sql)),                                       # query_length
        float(qtype),                                          # query_type
        float(_count_tables(sql)),                             # num_tables
        float(num_joins),                                      # num_joins
        float(num_conditions),                                 # num_conditions
        float(num_aggs),                                       # num_aggregates
        float(num_subqueries),                                 # num_subqueries
        float(_count_columns(sql)),                            # num_columns
        float(bool(re.search(r"\bdistinct\b", sql, re.I))),    # has_distinct
        float(bool(re.search(r"\border\s+by\b", sql, re.I))),  # has_order_by
        float(bool(re.search(r"\bgroup\s+by\b", sql, re.I))),  # has_group_by
        float(bool(HAVING_PATTERN.search(sql))),               # has_having
        float(bool(re.search(r"\blimit\b", sql, re.I))),       # has_limit
        float(bool(re.search(r"\boffset\b", sql, re.I))),      # has_offset
        float(bool(re.search(r"\bwhere\b", sql, re.I))),       # has_where
        float(bool(LIKE_PATTERN.search(sql))),                 # has_like
        float(bool(IN_PATTERN.search(sql))),                   # has_in_clause
        float(bool(BETWEEN_PATTERN.search(sql))),              # has_between
        float(bool(EXISTS_PATTERN.search(sql))),               # has_exists
        float(bool(WINDOW_FUNCS.search(sql))),                 # has_window_func
        float(bool(CTE_PATTERN.search(sql))),                  # has_cte
        float(bool(UNION_PATTERN.search(sql))),                # has_union
        float(bool(CASE_PATTERN.search(sql))),                 # has_case
        float(bool(CAST_PATTERN.search(sql))),                 # has_cast
        float(_nesting_depth(sql)),                            # nesting_depth
        float(num_conditions),                                 # num_and_or (duplicate of num_conditions by design)
        float(num_string_lits),                                # num_string_literals
        float(num_numeric_lits),                               # num_numeric_literals
    ]

    return features
|
predict.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference API for pg_plan_cache models.
|
| 3 |
+
|
| 4 |
+
Loads trained models and provides prediction functions for:
|
| 5 |
+
1. Cache benefit (high / medium / low)
|
| 6 |
+
2. Recommended TTL (seconds)
|
| 7 |
+
3. Complexity score (1-100)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import joblib
|
| 13 |
+
import numpy as np
|
| 14 |
+
from features import extract_features, FEATURE_NAMES
|
| 15 |
+
|
| 16 |
+
MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")
|
| 17 |
+
|
| 18 |
+
_cache_advisor = None
|
| 19 |
+
_ttl_recommender = None
|
| 20 |
+
_complexity_estimator = None
|
| 21 |
+
_label_encoder = None
|
| 22 |
+
_loaded = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _load_models():
    """Lazy-load the three models and the label encoder (no-op after first call)."""
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if _loaded:
        return

    def _load(filename):
        # All artifacts live side by side in MODEL_DIR.
        return joblib.load(os.path.join(MODEL_DIR, filename))

    _cache_advisor = _load("cache_advisor.joblib")
    _ttl_recommender = _load("ttl_recommender.joblib")
    _complexity_estimator = _load("complexity_estimator.joblib")
    _label_encoder = _load("label_encoder.joblib")
    _loaded = True
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _humanize_ttl(ttl: int) -> str:
    """Format a TTL in whole seconds as 'Hh Mm', or 'Mm' when under an hour."""
    hours, mins = divmod(ttl // 60, 60)
    return f"{hours}h {mins}m" if hours else f"{mins}m"


def _complexity_label(score: int) -> str:
    """Map a 1-100 complexity score onto its human-readable bucket."""
    if score <= 20:
        return "simple"
    if score <= 45:
        return "moderate"
    if score <= 75:
        return "complex"
    return "very complex"


def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Parameters
    ----------
    sql : str
        The SQL query text to analyze.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,  # seconds
            "ttl_human": str,        # e.g. "1h 0m"
            "complexity_score": int, # 1-100
            "complexity_label": str, # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()

    features = extract_features(sql)
    X = np.array([features])

    # Cache advisor: run the forest ONCE via predict_proba and derive both the
    # label and the per-class probabilities from it. The previous version
    # called predict() and predict_proba() separately, doing the whole
    # ensemble pass twice; RandomForestClassifier.predict is defined as the
    # argmax of predict_proba, so this is behaviorally identical and cheaper.
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    class_labels = _label_encoder.inverse_transform(np.arange(len(benefit_probs)))
    benefit_label = class_labels[int(np.argmax(benefit_probs))]
    prob_dict = {
        label: round(float(p), 4)
        for label, p in zip(class_labels, benefit_probs)
    }

    # TTL recommender: regression output in seconds, clamped at zero.
    ttl = max(0, int(round(float(_ttl_recommender.predict(X)[0]))))
    ttl_human = _humanize_ttl(ttl)

    # Complexity estimator: regression output clamped to the 1-100 scale.
    cplx = max(1, min(100, int(round(float(_complexity_estimator.predict(X)[0])))))
    cplx_label = _complexity_label(cplx)

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": ttl_human,
        "complexity_score": cplx,
        "complexity_label": cplx_label,
        "features": dict(zip(FEATURE_NAMES, features)),
    }
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def predict_batch(queries: list[str]) -> list[dict]:
    """Run predictions on multiple queries, preserving input order."""
    return list(map(predict, queries))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def format_prediction(result: dict) -> str:
    """Render a prediction result as a readable multi-line string."""
    query = result["query"]
    # Show at most 100 characters of the query, with an ellipsis when cut.
    shown = query[:100] + ("..." if len(query) > 100 else "")
    return "\n".join([
        f" Query: {shown}",
        f" Cache Benefit: {result['cache_benefit'].upper()}",
        f" Probabilities: {result['cache_benefit_probabilities']}",
        f" Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
        f" Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
    ])
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_model_info() -> dict:
    """Return model metadata, or an error dict when training has not run yet."""
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    try:
        with open(meta_path) as f:
            return json.load(f)
    except FileNotFoundError:
        # metadata.json is written by train.py alongside the model artifacts.
        return {"error": "metadata.json not found. Run train.py first."}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print('Usage: python predict.py "SELECT * FROM users WHERE id = 42"')
        sys.exit(1)

    # Join all arguments so unquoted multi-word queries still work.
    print(format_prediction(predict(" ".join(cli_args))))
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scikit-learn>=1.4.0
|
| 2 |
+
joblib>=1.3.0
|
| 3 |
+
numpy>=1.26.0
|
| 4 |
+
huggingface_hub>=0.24.0
|
train.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Train all three pg_plan_cache models:
|
| 4 |
+
1. SQL Cache Advisor (classification: low / medium / high)
|
| 5 |
+
2. Cache TTL Recommender (regression: seconds)
|
| 6 |
+
3. Query Complexity Estimator (regression: 1-100 score)
|
| 7 |
+
|
| 8 |
+
Saves trained models as joblib files in the ./trained/ directory.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import numpy as np
|
| 14 |
+
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
|
| 15 |
+
from sklearn.model_selection import train_test_split, cross_val_score
|
| 16 |
+
from sklearn.metrics import classification_report, mean_absolute_error, r2_score
|
| 17 |
+
from sklearn.preprocessing import LabelEncoder
|
| 18 |
+
import joblib
|
| 19 |
+
|
| 20 |
+
from features import extract_features, FEATURE_NAMES
|
| 21 |
+
from dataset import generate_dataset
|
| 22 |
+
|
| 23 |
+
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "trained")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def train():
    """Train, evaluate, and persist all three pg_plan_cache models.

    Pipeline:
      1. Generate a synthetic labeled dataset (queries + 3 targets).
      2. Extract the numeric feature vector for every query.
      3. Train the cache-benefit classifier (Random Forest).
      4. Train the TTL recommender (Gradient Boosting regressor).
      5. Train the complexity estimator (Gradient Boosting regressor).

    Writes the fitted models, the label encoder, feature importances, and
    run metadata into OUTPUT_DIR.
    """
    print("=" * 60)
    print(" pg_plan_cache — Model Training")
    print("=" * 60)

    # ── Generate data ─────────────────────────────────────────
    print("\n[1/5] Generating synthetic training data...")
    # Fixed seed keeps the dataset — and therefore the saved models — reproducible.
    queries, benefits, ttls, complexities = generate_dataset(n=8000, seed=42)
    print(f" Generated {len(queries)} samples")

    # ── Extract features ──────────────────────────────────────
    print("\n[2/5] Extracting features...")
    X = np.array([extract_features(q) for q in queries])
    print(f" Feature matrix: {X.shape}")

    # ── Encode labels ─────────────────────────────────────────
    # LabelEncoder orders classes alphabetically, hence the non-intuitive mapping.
    le = LabelEncoder()
    y_benefit = le.fit_transform(benefits)  # low=1, medium=2, high=0
    y_ttl = np.array(ttls, dtype=float)
    y_complexity = np.array(complexities, dtype=float)

    # ── Split ─────────────────────────────────────────────────
    # A single shared split keeps all three target arrays aligned row-for-row.
    X_train, X_test, yb_train, yb_test, yt_train, yt_test, yc_train, yc_test = \
        train_test_split(X, y_benefit, y_ttl, y_complexity, test_size=0.2, random_state=42)

    print(f" Train: {len(X_train)}, Test: {len(X_test)}")

    # ── Model 1: Cache Advisor (classification) ───────────────
    print("\n[3/5] Training SQL Cache Advisor...")
    clf = RandomForestClassifier(
        n_estimators=200,
        max_depth=15,
        min_samples_split=5,
        min_samples_leaf=2,
        random_state=42,
        n_jobs=-1,
    )
    clf.fit(X_train, yb_train)

    yb_pred = clf.predict(X_test)
    print("\n Classification Report:")
    report = classification_report(yb_test, yb_pred, target_names=le.classes_)
    # Indent every report line so it nests under the section header.
    print(" " + report.replace("\n", "\n "))

    # cross_val_score clones the estimator, so the fitted clf above is untouched.
    cv_scores = cross_val_score(clf, X, y_benefit, cv=5, scoring="accuracy")
    print(f" Cross-val accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")

    # ── Model 2: TTL Recommender (regression) ─────────────────
    print("\n[4/5] Training Cache TTL Recommender...")
    reg_ttl = GradientBoostingRegressor(
        n_estimators=200,
        max_depth=8,
        learning_rate=0.1,
        min_samples_split=5,
        random_state=42,
    )
    reg_ttl.fit(X_train, yt_train)

    yt_pred = reg_ttl.predict(X_test)
    mae_ttl = mean_absolute_error(yt_test, yt_pred)
    r2_ttl = r2_score(yt_test, yt_pred)
    print(f" MAE: {mae_ttl:.1f} seconds")
    print(f" R2: {r2_ttl:.3f}")

    # ── Model 3: Complexity Estimator (regression) ────────────
    print("\n[5/5] Training Query Complexity Estimator...")
    reg_cplx = GradientBoostingRegressor(
        n_estimators=200,
        max_depth=8,
        learning_rate=0.1,
        min_samples_split=5,
        random_state=42,
    )
    reg_cplx.fit(X_train, yc_train)

    yc_pred = reg_cplx.predict(X_test)
    mae_cplx = mean_absolute_error(yc_test, yc_pred)
    r2_cplx = r2_score(yc_test, yc_pred)
    print(f" MAE: {mae_cplx:.1f} points")
    print(f" R2: {r2_cplx:.3f}")

    # ── Save models ───────────────────────────────────────────
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    joblib.dump(clf, os.path.join(OUTPUT_DIR, "cache_advisor.joblib"))
    joblib.dump(reg_ttl, os.path.join(OUTPUT_DIR, "ttl_recommender.joblib"))
    joblib.dump(reg_cplx, os.path.join(OUTPUT_DIR, "complexity_estimator.joblib"))
    joblib.dump(le, os.path.join(OUTPUT_DIR, "label_encoder.joblib"))

    # Feature importances — one {feature: importance} map per model.
    importances = {
        "cache_advisor": dict(zip(FEATURE_NAMES, clf.feature_importances_.tolist())),
        "ttl_recommender": dict(zip(FEATURE_NAMES, reg_ttl.feature_importances_.tolist())),
        "complexity_estimator": dict(zip(FEATURE_NAMES, reg_cplx.feature_importances_.tolist())),
    }
    with open(os.path.join(OUTPUT_DIR, "feature_importances.json"), "w") as f:
        json.dump(importances, f, indent=2)

    # Model metadata — evaluation metrics and the feature schema, consumed by
    # predict.get_model_info().
    metadata = {
        "models": {
            "cache_advisor": {
                "type": "RandomForestClassifier",
                "task": "classification",
                "classes": le.classes_.tolist(),
                "accuracy_cv5": round(float(cv_scores.mean()), 4),
            },
            "ttl_recommender": {
                "type": "GradientBoostingRegressor",
                "task": "regression",
                "unit": "seconds",
                "mae": round(float(mae_ttl), 2),
                "r2": round(float(r2_ttl), 4),
            },
            "complexity_estimator": {
                "type": "GradientBoostingRegressor",
                "task": "regression",
                "unit": "score (1-100)",
                "mae": round(float(mae_cplx), 2),
                "r2": round(float(r2_cplx), 4),
            },
        },
        "features": FEATURE_NAMES,
        "n_features": len(FEATURE_NAMES),
        "training_samples": len(queries),
        "test_samples": len(X_test),
    }
    with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"\n Models saved to {OUTPUT_DIR}/")
    print(" Files: cache_advisor.joblib, ttl_recommender.joblib,")
    print(" complexity_estimator.joblib, label_encoder.joblib,")
    print(" feature_importances.json, metadata.json")
    print("\nDone.")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    train()
|
trained/cache_advisor.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e11ba948fd643d426b62362f7fd71e30ec90e4a1f1593b2606ae1e31b7b3b19f
|
| 3 |
+
size 818001
|
trained/complexity_estimator.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bd2a5edfce1496bc23a0686e6822ff3d583c884ad9922d9eed5f369ef0b064b
|
| 3 |
+
size 3038236
|
trained/feature_importances.json
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cache_advisor": {
|
| 3 |
+
"query_length": 0.19116243566746416,
|
| 4 |
+
"query_type": 0.02137394504176744,
|
| 5 |
+
"num_tables": 0.09222282366305111,
|
| 6 |
+
"num_joins": 0.0748793608388074,
|
| 7 |
+
"num_conditions": 0.00154549784133088,
|
| 8 |
+
"num_aggregates": 0.0618503755668228,
|
| 9 |
+
"num_subqueries": 0.05156804724205885,
|
| 10 |
+
"num_columns": 0.07578970828634744,
|
| 11 |
+
"has_distinct": 0.04377194157855687,
|
| 12 |
+
"has_order_by": 0.03645645249300166,
|
| 13 |
+
"has_group_by": 0.04425844049972725,
|
| 14 |
+
"has_having": 0.0022541480803507635,
|
| 15 |
+
"has_limit": 0.042062573427220216,
|
| 16 |
+
"has_offset": 0.0,
|
| 17 |
+
"has_where": 0.008477512665144578,
|
| 18 |
+
"has_like": 0.0,
|
| 19 |
+
"has_in_clause": 0.005441079955562388,
|
| 20 |
+
"has_between": 0.0,
|
| 21 |
+
"has_exists": 0.0009272674367364887,
|
| 22 |
+
"has_window_func": 0.010171898283664462,
|
| 23 |
+
"has_cte": 0.0017415634776680982,
|
| 24 |
+
"has_union": 0.021229522300210402,
|
| 25 |
+
"has_case": 0.010714231584388431,
|
| 26 |
+
"has_cast": 0.0,
|
| 27 |
+
"nesting_depth": 0.1651162458494366,
|
| 28 |
+
"num_and_or": 0.0018247999615881344,
|
| 29 |
+
"num_string_literals": 0.02825993434632328,
|
| 30 |
+
"num_numeric_literals": 0.006900193912770316
|
| 31 |
+
},
|
| 32 |
+
"ttl_recommender": {
|
| 33 |
+
"query_length": 0.49334167936522283,
|
| 34 |
+
"query_type": 0.011472503279799304,
|
| 35 |
+
"num_tables": 0.04121816512371646,
|
| 36 |
+
"num_joins": 0.05664091770080013,
|
| 37 |
+
"num_conditions": 2.6766564086239894e-05,
|
| 38 |
+
"num_aggregates": 0.08454674221524747,
|
| 39 |
+
"num_subqueries": 0.012819407143812049,
|
| 40 |
+
"num_columns": 0.003503947545486143,
|
| 41 |
+
"has_distinct": 0.0058846177923228245,
|
| 42 |
+
"has_order_by": 0.0030112892658353254,
|
| 43 |
+
"has_group_by": 0.11555986501253222,
|
| 44 |
+
"has_having": 0.0005654100636265899,
|
| 45 |
+
"has_limit": 0.020011249481941062,
|
| 46 |
+
"has_offset": 0.0,
|
| 47 |
+
"has_where": 0.0006198304413308254,
|
| 48 |
+
"has_like": 0.0,
|
| 49 |
+
"has_in_clause": 0.006723068906959933,
|
| 50 |
+
"has_between": 0.0,
|
| 51 |
+
"has_exists": 1.5939534844064166e-05,
|
| 52 |
+
"has_window_func": 0.0016085055032078448,
|
| 53 |
+
"has_cte": 2.3841716696771857e-05,
|
| 54 |
+
"has_union": 5.051873650507809e-05,
|
| 55 |
+
"has_case": 2.1925568628142657e-05,
|
| 56 |
+
"has_cast": 0.0,
|
| 57 |
+
"nesting_depth": 0.13173720022142668,
|
| 58 |
+
"num_and_or": 2.27992721191164e-05,
|
| 59 |
+
"num_string_literals": 0.005676787044969987,
|
| 60 |
+
"num_numeric_literals": 0.004897022498882968
|
| 61 |
+
},
|
| 62 |
+
"complexity_estimator": {
|
| 63 |
+
"query_length": 0.5344926759628151,
|
| 64 |
+
"query_type": 0.0015962377188123598,
|
| 65 |
+
"num_tables": 0.031559929024199504,
|
| 66 |
+
"num_joins": 0.02335110657414861,
|
| 67 |
+
"num_conditions": 5.757862902242119e-05,
|
| 68 |
+
"num_aggregates": 0.04750932601796666,
|
| 69 |
+
"num_subqueries": 0.008970394733974358,
|
| 70 |
+
"num_columns": 0.00588104652025957,
|
| 71 |
+
"has_distinct": 0.01062122091510926,
|
| 72 |
+
"has_order_by": 0.0024661023837127443,
|
| 73 |
+
"has_group_by": 0.061828695835283276,
|
| 74 |
+
"has_having": 0.00034502697726715757,
|
| 75 |
+
"has_limit": 0.020807067356268808,
|
| 76 |
+
"has_offset": 0.0,
|
| 77 |
+
"has_where": 0.0004570231775885458,
|
| 78 |
+
"has_like": 0.0,
|
| 79 |
+
"has_in_clause": 0.013672027252240813,
|
| 80 |
+
"has_between": 0.0,
|
| 81 |
+
"has_exists": 7.242098418966911e-05,
|
| 82 |
+
"has_window_func": 0.0009971635825058846,
|
| 83 |
+
"has_cte": 1.4790912091677233e-05,
|
| 84 |
+
"has_union": 0.006250913065401877,
|
| 85 |
+
"has_case": 1.6824403324258042e-05,
|
| 86 |
+
"has_cast": 0.0,
|
| 87 |
+
"nesting_depth": 0.22281668760327789,
|
| 88 |
+
"num_and_or": 7.125446039372882e-05,
|
| 89 |
+
"num_string_literals": 0.003162795018354366,
|
| 90 |
+
"num_numeric_literals": 0.0029816908917914696
|
| 91 |
+
}
|
| 92 |
+
}
|
trained/label_encoder.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dca6130147e0c2d5e5b985a5abb3087d622fbe3da1e3e09ce3c5a79cc5fd15e8
|
| 3 |
+
size 399
|
trained/metadata.json
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"models": {
|
| 3 |
+
"cache_advisor": {
|
| 4 |
+
"type": "RandomForestClassifier",
|
| 5 |
+
"task": "classification",
|
| 6 |
+
"classes": [
|
| 7 |
+
"high",
|
| 8 |
+
"low",
|
| 9 |
+
"medium"
|
| 10 |
+
],
|
| 11 |
+
"accuracy_cv5": 1.0
|
| 12 |
+
},
|
| 13 |
+
"ttl_recommender": {
|
| 14 |
+
"type": "GradientBoostingRegressor",
|
| 15 |
+
"task": "regression",
|
| 16 |
+
"unit": "seconds",
|
| 17 |
+
"mae": 494.56,
|
| 18 |
+
"r2": 0.8994
|
| 19 |
+
},
|
| 20 |
+
"complexity_estimator": {
|
| 21 |
+
"type": "GradientBoostingRegressor",
|
| 22 |
+
"task": "regression",
|
| 23 |
+
"unit": "score (1-100)",
|
| 24 |
+
"mae": 5.57,
|
| 25 |
+
"r2": 0.9216
|
| 26 |
+
}
|
| 27 |
+
},
|
| 28 |
+
"features": [
|
| 29 |
+
"query_length",
|
| 30 |
+
"query_type",
|
| 31 |
+
"num_tables",
|
| 32 |
+
"num_joins",
|
| 33 |
+
"num_conditions",
|
| 34 |
+
"num_aggregates",
|
| 35 |
+
"num_subqueries",
|
| 36 |
+
"num_columns",
|
| 37 |
+
"has_distinct",
|
| 38 |
+
"has_order_by",
|
| 39 |
+
"has_group_by",
|
| 40 |
+
"has_having",
|
| 41 |
+
"has_limit",
|
| 42 |
+
"has_offset",
|
| 43 |
+
"has_where",
|
| 44 |
+
"has_like",
|
| 45 |
+
"has_in_clause",
|
| 46 |
+
"has_between",
|
| 47 |
+
"has_exists",
|
| 48 |
+
"has_window_func",
|
| 49 |
+
"has_cte",
|
| 50 |
+
"has_union",
|
| 51 |
+
"has_case",
|
| 52 |
+
"has_cast",
|
| 53 |
+
"nesting_depth",
|
| 54 |
+
"num_and_or",
|
| 55 |
+
"num_string_literals",
|
| 56 |
+
"num_numeric_literals"
|
| 57 |
+
],
|
| 58 |
+
"n_features": 28,
|
| 59 |
+
"training_samples": 8000,
|
| 60 |
+
"test_samples": 1600
|
| 61 |
+
}
|
trained/ttl_recommender.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ac8fbc0829aba31da6ff9ea299f512b63ed95c065cc2ae7a5779c7a110486aa
|
| 3 |
+
size 3066316
|