nilenpatel committed on
Commit 406cec4 · verified · 1 Parent(s): 82b8f8b

Upload pg_plan_cache models

README.md ADDED
@@ -0,0 +1,88 @@
---
library_name: sklearn
tags:
- postgresql
- sql
- query-cache
- plan-cache
- redis
- database
- tabular-classification
- tabular-regression
pipeline_tag: tabular-classification
license: mit
---

# pg_plan_cache Models

Three machine learning models for the **pg_plan_cache** PostgreSQL extension — a query
execution plan cache backed by Redis.

## Models

### 1. SQL Cache Advisor
- **Task:** Classification (high / medium / low)
- **Algorithm:** Random Forest (200 trees)
- **Purpose:** Predicts whether caching a query's execution plan will be beneficial

### 2. Cache TTL Recommender
- **Task:** Regression (seconds)
- **Algorithm:** Gradient Boosting
- **Purpose:** Recommends an optimal cache TTL based on query characteristics

### 3. Query Complexity Estimator
- **Task:** Regression (1-100 score)
- **Algorithm:** Gradient Boosting
- **Purpose:** Estimates query complexity to prioritize caching resources

## Features

All models use 28 structural features extracted from raw SQL text:

| Feature | Description |
|---------|------------|
| `query_length` | Character count |
| `query_type` | SELECT=0, INSERT=1, UPDATE=2, DELETE=3, OTHER=4 |
| `num_tables` | Tables referenced |
| `num_joins` | JOIN clause count |
| `num_conditions` | AND/OR conditions |
| `num_aggregates` | Aggregate function count |
| `num_subqueries` | Subquery count |
| `has_window_func` | Window functions present |
| `has_cte` | Common Table Expressions present |
| `nesting_depth` | Maximum parenthesis depth |
| ... | 18 more features |

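The `nesting_depth` feature, for instance, is computed by a single left-to-right scan that tracks parenthesis depth. A stand-alone sketch mirroring `_nesting_depth` in `features.py`:

```python
def nesting_depth(sql: str) -> int:
    """Maximum parenthesis nesting depth of a SQL string."""
    max_depth = depth = 0
    for ch in sql:
        if ch == "(":
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ")":
            depth -= 1
    return max_depth

print(nesting_depth(
    "SELECT * FROM t WHERE id IN (SELECT x FROM u WHERE EXISTS (SELECT 1))"
))  # 2
```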
## Usage

```python
from predict import predict, format_prediction

result = predict("SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name")
print(format_prediction(result))
# Cache Benefit: HIGH
# Recommended TTL: 4200s (1h 10m)
# Complexity: 62/100 (complex)
```

## Training

Trained on 8,000 synthetic SQL queries across 18 complexity tiers:
- Simple SELECTs, filtered queries, ORDER BY
- Single- and multi-table JOINs
- Aggregations with GROUP BY / HAVING
- Subqueries, correlated subqueries, EXISTS
- CTEs, window functions, UNION
- Complex analytics queries
- INSERT / UPDATE / DELETE (non-cacheable)

```bash
pip install -r requirements.txt
python train.py
```

## About pg_plan_cache

pg_plan_cache is a PostgreSQL extension that caches query execution plans in Redis.
It hooks into the PostgreSQL planner, normalizes queries, computes SHA-256 hashes,
and stores serialized plans with configurable TTL and automatic schema-change invalidation.
dataset.py ADDED
@@ -0,0 +1,380 @@
"""
Synthetic training data generator for pg_plan_cache models.

Generates realistic SQL queries across a wide range of complexity levels
with labels for cache benefit, recommended TTL, and complexity score.
"""

import random

# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

TABLES = [
    "users", "orders", "products", "payments", "sessions",
    "logs", "events", "accounts", "invoices", "shipments",
    "categories", "reviews", "inventory", "notifications", "messages",
    "employees", "departments", "projects", "tasks", "comments",
]

SCHEMAS = ["public", "app", "analytics", "billing"]

COLUMNS = {
    "users": ["id", "name", "email", "created_at", "status", "age", "country"],
    "orders": ["id", "user_id", "total", "status", "created_at", "shipped_at"],
    "products": ["id", "name", "price", "category_id", "stock", "rating"],
    "payments": ["id", "order_id", "amount", "method", "paid_at", "status"],
    "sessions": ["id", "user_id", "started_at", "ended_at", "ip_address"],
    "logs": ["id", "level", "message", "created_at", "source"],
    "events": ["id", "type", "user_id", "data", "created_at"],
    "accounts": ["id", "owner_id", "balance", "currency", "opened_at"],
    "invoices": ["id", "account_id", "amount", "due_date", "status"],
    "shipments": ["id", "order_id", "carrier", "tracking", "shipped_at"],
    "categories": ["id", "name", "parent_id", "sort_order"],
    "reviews": ["id", "product_id", "user_id", "rating", "body", "created_at"],
    "inventory": ["id", "product_id", "warehouse_id", "quantity", "updated_at"],
    "notifications": ["id", "user_id", "type", "read", "created_at"],
    "messages": ["id", "sender_id", "receiver_id", "body", "sent_at"],
    "employees": ["id", "name", "department_id", "salary", "hired_at"],
    "departments": ["id", "name", "budget", "manager_id"],
    "projects": ["id", "name", "department_id", "deadline", "status"],
    "tasks": ["id", "project_id", "assignee_id", "title", "status", "due_date"],
    "comments": ["id", "task_id", "user_id", "body", "created_at"],
}

AGG_FUNCS = ["COUNT", "SUM", "AVG", "MIN", "MAX"]
COMPARISONS = ["=", ">", "<", ">=", "<=", "!="]
STRING_VALS = ["'active'", "'pending'", "'completed'", "'cancelled'", "'new'", "'shipped'"]
JOIN_TYPES = ["JOIN", "LEFT JOIN", "INNER JOIN", "RIGHT JOIN"]
WINDOW_FUNCS = ["ROW_NUMBER()", "RANK()", "DENSE_RANK()", "LAG(t.id, 1)", "LEAD(t.id, 1)"]


def _rand_table():
    return random.choice(TABLES)


def _rand_cols(table, n=None):
    cols = COLUMNS.get(table, ["id", "name"])
    n = n or random.randint(1, min(4, len(cols)))
    return random.sample(cols, min(n, len(cols)))


def _rand_where(alias="t"):
    col = random.choice(["id", "status", "created_at", "name", "amount", "age"])
    op = random.choice(COMPARISONS)
    if col == "status":
        return f"{alias}.{col} {op} {random.choice(STRING_VALS)}"
    elif col in ("id", "age", "amount"):
        return f"{alias}.{col} {op} {random.randint(1, 10000)}"
    else:
        return f"{alias}.{col} {op} '2024-{random.randint(1, 12):02d}-{random.randint(1, 28):02d}'"


# ---------------------------------------------------------------------------
# Query generators by complexity tier
# ---------------------------------------------------------------------------

def _simple_select():
    """Tier 1: Simple SELECT with optional WHERE."""
    t = _rand_table()
    cols = ", ".join(_rand_cols(t))
    sql = f"SELECT {cols} FROM {t}"
    if random.random() > 0.3:
        # Qualify by table name: the FROM clause declares no alias.
        sql += f" WHERE {_rand_where(t)}"
    if random.random() > 0.7:
        sql += f" LIMIT {random.choice([10, 20, 50, 100])}"
    return sql, "low", random.randint(300, 900), random.randint(5, 20)


def _select_with_order():
    """Tier 1.5: SELECT with ORDER BY and LIMIT."""
    t = _rand_table()
    cols = ", ".join(_rand_cols(t))
    order_col = random.choice(COLUMNS.get(t, ["id"]))
    direction = random.choice(["ASC", "DESC"])
    sql = (
        f"SELECT {cols} FROM {t} WHERE {_rand_where(t)} "
        f"ORDER BY {order_col} {direction} LIMIT {random.choice([10, 25, 50])}"
    )
    return sql, "low", random.randint(600, 1200), random.randint(10, 25)


def _single_join():
    """Tier 2: Single JOIN query."""
    t1, t2 = random.sample(TABLES, 2)
    c1 = ", ".join(f"a.{c}" for c in _rand_cols(t1, 2))
    c2 = ", ".join(f"b.{c}" for c in _rand_cols(t2, 2))
    jtype = random.choice(JOIN_TYPES)
    sql = (
        f"SELECT {c1}, {c2} FROM {t1} a "
        f"{jtype} {t2} b ON a.id = b.{t1[:-1]}_id"
    )
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where('a')}"
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 45)


def _multi_join():
    """Tier 3: Multi-table JOIN."""
    tables = random.sample(TABLES, random.randint(3, 5))
    selects = []
    for i, t in enumerate(tables):
        alias = chr(97 + i)
        col = random.choice(COLUMNS.get(t, ["id"]))
        selects.append(f"{alias}.{col}")

    sql = f"SELECT {', '.join(selects)} FROM {tables[0]} a"
    for i in range(1, len(tables)):
        alias = chr(97 + i)
        prev_alias = chr(97 + i - 1)
        jtype = random.choice(JOIN_TYPES)
        sql += f" {jtype} {tables[i]} {alias} ON {prev_alias}.id = {alias}.{tables[i - 1][:-1]}_id"

    if random.random() > 0.3:
        sql += f" WHERE {_rand_where('a')}"
    if random.random() > 0.5:
        sql += f" ORDER BY a.id LIMIT {random.choice([50, 100, 200])}"
    return sql, "high", random.randint(3600, 7200), random.randint(45, 70)


def _aggregate_query():
    """Tier 3: Aggregation with GROUP BY."""
    t = _rand_table()
    group_col = random.choice(COLUMNS.get(t, ["id"])[:3])
    agg = random.choice(AGG_FUNCS)
    agg_col = random.choice(["id", "amount", "total", "price", "salary"])
    sql = f"SELECT {group_col}, {agg}({agg_col}) FROM {t}"
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where(t)}"
    sql += f" GROUP BY {group_col}"
    if random.random() > 0.6:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 1000)}"
    if random.random() > 0.5:
        sql += f" ORDER BY {agg}({agg_col}) DESC"
    return sql, "high", random.randint(3600, 7200), random.randint(40, 65)


def _aggregate_join():
    """Tier 4: JOIN + aggregation."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    group_col = f"a.{random.choice(COLUMNS.get(t1, ['id'])[:2])}"
    agg_col = f"b.{random.choice(['id', 'amount', 'total'])}"
    jtype = random.choice(JOIN_TYPES)
    sql = (
        f"SELECT {group_col}, {agg}({agg_col}) as agg_val "
        f"FROM {t1} a {jtype} {t2} b ON a.id = b.{t1[:-1]}_id "
        f"WHERE {_rand_where('a')} "
        f"GROUP BY {group_col}"
    )
    if random.random() > 0.5:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 500)}"
    sql += f" ORDER BY agg_val DESC LIMIT {random.choice([10, 20, 50])}"
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)


def _subquery():
    """Tier 4: Subquery."""
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    sql = (
        f"SELECT {cols} FROM {t1} "
        f"WHERE id IN (SELECT {t1[:-1]}_id FROM {t2} "
        f"WHERE {_rand_where(t2)})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 75)


def _correlated_subquery():
    """Tier 5: Correlated subquery."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    sql = (
        f"SELECT a.id, a.name, "
        f"(SELECT {agg}(b.id) FROM {t2} b WHERE b.{t1[:-1]}_id = a.id) as sub_val "
        f"FROM {t1} a WHERE {_rand_where('a')}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(60, 85)


def _cte_query():
    """Tier 5: Common Table Expression (WITH)."""
    t1, t2 = random.sample(TABLES, 2)
    agg = random.choice(AGG_FUNCS)
    sql = (
        f"WITH cte AS ("
        f"SELECT {t1[:-1]}_id, {agg}(id) as cnt FROM {t2} GROUP BY {t1[:-1]}_id"
        f") SELECT a.id, a.name, c.cnt "
        f"FROM {t1} a JOIN cte c ON a.id = c.{t1[:-1]}_id "
        f"WHERE c.cnt > {random.randint(1, 50)} "
        f"ORDER BY c.cnt DESC"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(65, 85)


def _window_query():
    """Tier 5: Window function."""
    t = _rand_table()
    wfunc = random.choice(["ROW_NUMBER()", "RANK()", "DENSE_RANK()"])
    partition_col = random.choice(COLUMNS.get(t, ["id"])[:2])
    order_col = random.choice(["id", "created_at"])
    sql = (
        f"SELECT id, {partition_col}, "
        f"{wfunc} OVER (PARTITION BY {partition_col} ORDER BY {order_col} DESC) as rn "
        f"FROM {t} WHERE {_rand_where(t)}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)


def _union_query():
    """Tier 4: UNION query."""
    t1, t2 = random.sample(TABLES, 2)
    sql = (
        f"SELECT id, name FROM {t1} WHERE {_rand_where(t1)} "
        f"UNION ALL "
        f"SELECT id, name FROM {t2} WHERE {_rand_where(t2)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(35, 55)


def _complex_analytics():
    """Tier 6: Complex analytics query."""
    t1, t2, t3 = random.sample(TABLES, 3)
    agg1 = random.choice(AGG_FUNCS)
    agg2 = random.choice(AGG_FUNCS)
    sql = (
        f"WITH monthly AS ("
        f"SELECT a.id, a.name, {agg1}(b.id) as cnt, {agg2}(c.id) as total "
        f"FROM {t1} a "
        f"LEFT JOIN {t2} b ON a.id = b.{t1[:-1]}_id "
        f"LEFT JOIN {t3} c ON b.id = c.{t2[:-1]}_id "
        f"WHERE a.created_at >= '2024-01-01' "
        f"GROUP BY a.id, a.name "
        f"HAVING {agg1}(b.id) > {random.randint(1, 20)}"
        f") SELECT name, cnt, total, "
        f"RANK() OVER (ORDER BY cnt DESC) as rank "
        f"FROM monthly ORDER BY rank LIMIT 100"
    )
    return sql, "high", random.randint(5400, 7200), random.randint(80, 100)


def _insert_query():
    """INSERT — not cacheable."""
    t = _rand_table()
    cols = _rand_cols(t, 3)
    vals = ", ".join(
        f"{random.randint(1, 9999)}" if c in ("id", "age") else f"'val_{random.randint(1, 99)}'"
        for c in cols
    )
    sql = f"INSERT INTO {t} ({', '.join(cols)}) VALUES ({vals})"
    return sql, "low", 0, random.randint(5, 15)


def _update_query():
    """UPDATE — not cacheable."""
    t = _rand_table()
    col = random.choice(COLUMNS.get(t, ["name"])[1:])
    sql = f"UPDATE {t} SET {col} = 'updated' WHERE {_rand_where(t)}"
    return sql, "low", 0, random.randint(5, 15)


def _delete_query():
    """DELETE — not cacheable."""
    t = _rand_table()
    sql = f"DELETE FROM {t} WHERE {_rand_where(t)}"
    return sql, "low", 0, random.randint(5, 10)


def _exists_query():
    """Tier 4: EXISTS subquery."""
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    sql = (
        f"SELECT {cols} FROM {t1} a "
        f"WHERE EXISTS (SELECT 1 FROM {t2} b WHERE b.{t1[:-1]}_id = a.id "
        f"AND {_rand_where('b')})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 70)


def _case_query():
    """Tier 3: CASE expression."""
    t = _rand_table()
    sql = (
        f"SELECT id, "
        f"CASE WHEN status = 'active' THEN 'A' "
        f"WHEN status = 'pending' THEN 'P' "
        f"ELSE 'X' END as status_code, "
        f"name FROM {t} WHERE {_rand_where(t)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 40)


def _distinct_query():
    """Tier 2: SELECT DISTINCT."""
    t = _rand_table()
    col = random.choice(COLUMNS.get(t, ["name"])[:3])
    sql = f"SELECT DISTINCT {col} FROM {t} WHERE {_rand_where(t)} ORDER BY {col}"
    return sql, "medium", random.randint(1200, 2400), random.randint(20, 35)


# ---------------------------------------------------------------------------
# Generator registry
# ---------------------------------------------------------------------------

GENERATORS = [
    (_simple_select, 15),
    (_select_with_order, 10),
    (_single_join, 12),
    (_multi_join, 8),
    (_aggregate_query, 10),
    (_aggregate_join, 8),
    (_subquery, 7),
    (_correlated_subquery, 5),
    (_cte_query, 5),
    (_window_query, 5),
    (_union_query, 4),
    (_complex_analytics, 3),
    (_insert_query, 8),
    (_update_query, 5),
    (_delete_query, 4),
    (_exists_query, 5),
    (_case_query, 4),
    (_distinct_query, 4),
]

# Build weighted list
_WEIGHTED = []
for gen, weight in GENERATORS:
    _WEIGHTED.extend([gen] * weight)


def generate_sample():
    """Generate one (sql, cache_benefit, ttl, complexity) sample."""
    gen = random.choice(_WEIGHTED)
    sql, benefit, ttl, complexity = gen()
    # Add slight noise to TTL and complexity; keep non-cacheable queries at TTL 0.
    if ttl > 0:
        ttl = max(0, ttl + random.randint(-60, 60))
    complexity = max(1, min(100, complexity + random.randint(-3, 3)))
    return sql, benefit, ttl, complexity


def generate_dataset(n: int = 5000, seed: int = 42):
    """
    Generate a training dataset of n samples.

    Returns:
        queries: list[str]
        benefits: list[str] — "low", "medium", "high"
        ttls: list[int] — recommended TTL in seconds
        complexities: list[int] — 1-100 complexity score
    """
    random.seed(seed)
    queries, benefits, ttls, complexities = [], [], [], []
    for _ in range(n):
        sql, benefit, ttl, complexity = generate_sample()
        queries.append(sql)
        benefits.append(benefit)
        ttls.append(ttl)
        complexities.append(complexity)
    return queries, benefits, ttls, complexities
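The `GENERATORS` registry above implements weighted sampling by repeating each generator `weight` times in a flat pool. A minimal self-contained sketch of the same pattern, using made-up tier names and only the standard library:

```python
import random

# Hypothetical (generator, weight) pairs standing in for the real registry.
GENERATORS = [("simple_select", 15), ("single_join", 12), ("cte_query", 5)]

# Flatten into a weighted pool, exactly as dataset.py does with _WEIGHTED.
pool = [name for name, weight in GENERATORS for _ in range(weight)]

random.seed(42)
picks = [random.choice(pool) for _ in range(1000)]
# Heavier tiers are drawn proportionally more often.
print(picks.count("simple_select"), picks.count("cte_query"))
```

An equivalent without the flattened pool is `random.choices([g for g, _ in GENERATORS], weights=[w for _, w in GENERATORS], k=1)`; the flat-list form trades a little memory for a single `random.choice` per sample.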
features.py ADDED
@@ -0,0 +1,172 @@
"""
SQL feature extraction for pg_plan_cache models.

Extracts structural features from raw SQL query text to feed into
the Cache Advisor, TTL Recommender, and Complexity Estimator models.
"""

import re


AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)
CTE_PATTERN = re.compile(r"\bwith\s+\w+\s+as\s*\(", re.IGNORECASE)
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# `::` contains no word characters, so it cannot sit between \b anchors; match it literally.
CAST_PATTERN = re.compile(r"\bcast\s*\(|::", re.IGNORECASE)

FEATURE_NAMES = [
    "query_length",
    "query_type",  # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    "num_tables",
    "num_joins",
    "num_conditions",
    "num_aggregates",
    "num_subqueries",
    "num_columns",
    "has_distinct",
    "has_order_by",
    "has_group_by",
    "has_having",
    "has_limit",
    "has_offset",
    "has_where",
    "has_like",
    "has_in_clause",
    "has_between",
    "has_exists",
    "has_window_func",
    "has_cte",
    "has_union",
    "has_case",
    "has_cast",
    "nesting_depth",
    "num_and_or",
    "num_string_literals",
    "num_numeric_literals",
]


def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced."""
    count = 0
    # FROM clause tables
    from_match = re.search(
        r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if from_match:
        from_clause = from_match.group(1)
        count += len(re.split(r",", from_clause))
    # JOIN tables
    count += len(JOIN_PATTERN.findall(sql))
    return max(count, 0)


def _count_columns(sql: str) -> int:
    """Estimate the number of columns in the SELECT clause."""
    match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL)
    if not match:
        return 0
    select_clause = match.group(1).strip()
    if select_clause == "*":
        return 1
    # Split by commas not inside parentheses
    depth = 0
    count = 1
    for ch in select_clause:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            count += 1
    return count


def _nesting_depth(sql: str) -> int:
    """Calculate maximum parenthesis nesting depth."""
    max_depth = 0
    depth = 0
    for ch in sql:
        if ch == '(':
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ')':
            depth -= 1
    return max_depth


def extract_features(sql: str) -> list[float]:
    """
    Extract a fixed-length feature vector from a SQL query string.

    Returns a list of floats matching FEATURE_NAMES ordering.
    """
    sql = sql.strip()
    upper = sql.upper()

    # Query type
    if upper.startswith("SELECT"):
        qtype = 0
    elif upper.startswith("INSERT"):
        qtype = 1
    elif upper.startswith("UPDATE"):
        qtype = 2
    elif upper.startswith("DELETE"):
        qtype = 3
    else:
        qtype = 4

    num_joins = len(JOIN_PATTERN.findall(sql))
    num_aggs = len(AGGREGATE_FUNCS.findall(sql))
    num_subqueries = len(SUBQUERY_PATTERN.findall(sql))
    num_conditions = len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE))
    num_string_lits = len(re.findall(r"'[^']*'", sql))
    num_numeric_lits = len(re.findall(r"\b\d+(?:\.\d+)?\b", sql))

    features = [
        float(len(sql)),                                        # query_length
        float(qtype),                                           # query_type
        float(_count_tables(sql)),                              # num_tables
        float(num_joins),                                       # num_joins
        float(num_conditions),                                  # num_conditions
        float(num_aggs),                                        # num_aggregates
        float(num_subqueries),                                  # num_subqueries
        float(_count_columns(sql)),                             # num_columns
        float(bool(re.search(r"\bdistinct\b", sql, re.I))),     # has_distinct
        float(bool(re.search(r"\border\s+by\b", sql, re.I))),   # has_order_by
        float(bool(re.search(r"\bgroup\s+by\b", sql, re.I))),   # has_group_by
        float(bool(HAVING_PATTERN.search(sql))),                # has_having
        float(bool(re.search(r"\blimit\b", sql, re.I))),        # has_limit
        float(bool(re.search(r"\boffset\b", sql, re.I))),       # has_offset
        float(bool(re.search(r"\bwhere\b", sql, re.I))),        # has_where
        float(bool(LIKE_PATTERN.search(sql))),                  # has_like
        float(bool(IN_PATTERN.search(sql))),                    # has_in_clause
        float(bool(BETWEEN_PATTERN.search(sql))),               # has_between
        float(bool(EXISTS_PATTERN.search(sql))),                # has_exists
        float(bool(WINDOW_FUNCS.search(sql))),                  # has_window_func
        float(bool(CTE_PATTERN.search(sql))),                   # has_cte
        float(bool(UNION_PATTERN.search(sql))),                 # has_union
        float(bool(CASE_PATTERN.search(sql))),                  # has_case
        float(bool(CAST_PATTERN.search(sql))),                  # has_cast
        float(_nesting_depth(sql)),                             # nesting_depth
        float(num_conditions),                                  # num_and_or
        float(num_string_lits),                                 # num_string_literals
        float(num_numeric_lits),                                # num_numeric_literals
    ]

    return features
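`_count_columns` above relies on splitting the SELECT list only at commas that sit outside parentheses, so an expression like `COUNT(a, b)` counts as one column. A self-contained sketch of that depth-tracking split:

```python
def count_select_items(select_clause: str) -> int:
    """Count top-level comma-separated items, ignoring commas inside parentheses."""
    depth, count = 0, 1
    for ch in select_clause:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            count += 1
    return count

print(count_select_items("u.name, COUNT(o.id, o.x), SUM(t.v)"))  # 3
```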
predict.py ADDED
@@ -0,0 +1,137 @@
"""
Inference API for pg_plan_cache models.

Loads trained models and provides prediction functions for:
  1. Cache benefit (high / medium / low)
  2. Recommended TTL (seconds)
  3. Complexity score (1-100)
"""

import os
import json
import joblib
import numpy as np
from features import extract_features, FEATURE_NAMES

MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")

_cache_advisor = None
_ttl_recommender = None
_complexity_estimator = None
_label_encoder = None
_loaded = False


def _load_models():
    """Lazy-load all models from disk."""
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if _loaded:
        return

    _cache_advisor = joblib.load(os.path.join(MODEL_DIR, "cache_advisor.joblib"))
    _ttl_recommender = joblib.load(os.path.join(MODEL_DIR, "ttl_recommender.joblib"))
    _complexity_estimator = joblib.load(os.path.join(MODEL_DIR, "complexity_estimator.joblib"))
    _label_encoder = joblib.load(os.path.join(MODEL_DIR, "label_encoder.joblib"))
    _loaded = True


def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,      # seconds
            "ttl_human": str,            # e.g. "1h 0m"
            "complexity_score": int,     # 1-100
            "complexity_label": str,     # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()

    features = extract_features(sql)
    X = np.array([features])

    # Cache advisor
    benefit_idx = _cache_advisor.predict(X)[0]
    benefit_label = _label_encoder.inverse_transform([benefit_idx])[0]
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    prob_dict = {
        _label_encoder.inverse_transform([i])[0]: round(float(p), 4)
        for i, p in enumerate(benefit_probs)
    }

    # TTL recommender
    ttl_raw = _ttl_recommender.predict(X)[0]
    ttl = max(0, int(round(ttl_raw)))
    hours, mins = divmod(ttl // 60, 60)
    ttl_human = f"{hours}h {mins}m" if hours else f"{mins}m"

    # Complexity estimator
    cplx_raw = _complexity_estimator.predict(X)[0]
    cplx = max(1, min(100, int(round(cplx_raw))))
    if cplx <= 20:
        cplx_label = "simple"
    elif cplx <= 45:
        cplx_label = "moderate"
    elif cplx <= 75:
        cplx_label = "complex"
    else:
        cplx_label = "very complex"

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": ttl_human,
        "complexity_score": cplx,
        "complexity_label": cplx_label,
        "features": dict(zip(FEATURE_NAMES, features)),
    }


def predict_batch(queries: list[str]) -> list[dict]:
    """Run predictions on multiple queries."""
    return [predict(q) for q in queries]


def format_prediction(result: dict) -> str:
    """Format a prediction result as a readable string."""
    lines = [
        f"  Query: {result['query'][:100]}{'...' if len(result['query']) > 100 else ''}",
        f"  Cache Benefit: {result['cache_benefit'].upper()}",
        f"  Probabilities: {result['cache_benefit_probabilities']}",
        f"  Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
        f"  Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
    ]
    return "\n".join(lines)


def get_model_info() -> dict:
    """Return model metadata."""
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            return json.load(f)
    return {"error": "metadata.json not found. Run train.py first."}


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print('Usage: python predict.py "SELECT * FROM users WHERE id = 42"')
        sys.exit(1)

    sql = " ".join(sys.argv[1:])
    result = predict(sql)
    print(format_prediction(result))
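The TTL humanization in `predict()` is a single `divmod` over minutes. Extracted as a stand-alone sketch of the same logic:

```python
def ttl_human(ttl_seconds: int) -> str:
    """Render a TTL in seconds as 'Xh Ym', or 'Ym' under an hour."""
    hours, mins = divmod(ttl_seconds // 60, 60)
    return f"{hours}h {mins}m" if hours else f"{mins}m"

print(ttl_human(4200))  # 1h 10m
print(ttl_human(900))   # 15m
```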
requirements.txt ADDED
@@ -0,0 +1,4 @@
scikit-learn>=1.4.0
joblib>=1.3.0
numpy>=1.26.0
huggingface_hub>=0.24.0
train.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/env python3
+ """
+ Train all three pg_plan_cache models:
+ 1. SQL Cache Advisor (classification: low / medium / high)
+ 2. Cache TTL Recommender (regression: seconds)
+ 3. Query Complexity Estimator (regression: 1-100 score)
+
+ Saves trained models as joblib files in the ./trained/ directory.
+ """
+
+ import os
+ import json
+ import numpy as np
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
+ from sklearn.model_selection import train_test_split, cross_val_score
+ from sklearn.metrics import classification_report, mean_absolute_error, r2_score
+ from sklearn.preprocessing import LabelEncoder
+ import joblib
+
+ from features import extract_features, FEATURE_NAMES
+ from dataset import generate_dataset
+
+ OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "trained")
+
+
+ def train():
+     print("=" * 60)
+     print(" pg_plan_cache — Model Training")
+     print("=" * 60)
+
+     # ── Generate data ─────────────────────────────────────────
+     print("\n[1/5] Generating synthetic training data...")
+     queries, benefits, ttls, complexities = generate_dataset(n=8000, seed=42)
+     print(f"      Generated {len(queries)} samples")
+
+     # ── Extract features ──────────────────────────────────────
+     print("\n[2/5] Extracting features...")
+     X = np.array([extract_features(q) for q in queries])
+     print(f"      Feature matrix: {X.shape}")
+
+     # ── Encode labels ─────────────────────────────────────────
+     le = LabelEncoder()
+     y_benefit = le.fit_transform(benefits)  # alphabetical: high=0, low=1, medium=2
+     y_ttl = np.array(ttls, dtype=float)
+     y_complexity = np.array(complexities, dtype=float)
+
+     # ── Split ─────────────────────────────────────────────────
+     X_train, X_test, yb_train, yb_test, yt_train, yt_test, yc_train, yc_test = \
+         train_test_split(X, y_benefit, y_ttl, y_complexity, test_size=0.2, random_state=42)
+
+     print(f"      Train: {len(X_train)}, Test: {len(X_test)}")
+
+     # ── Model 1: Cache Advisor (classification) ───────────────
+     print("\n[3/5] Training SQL Cache Advisor...")
+     clf = RandomForestClassifier(
+         n_estimators=200,
+         max_depth=15,
+         min_samples_split=5,
+         min_samples_leaf=2,
+         random_state=42,
+         n_jobs=-1,
+     )
+     clf.fit(X_train, yb_train)
+
+     yb_pred = clf.predict(X_test)
+     print("\n   Classification Report:")
+     report = classification_report(yb_test, yb_pred, target_names=le.classes_)
+     print("   " + report.replace("\n", "\n   "))
+
+     cv_scores = cross_val_score(clf, X, y_benefit, cv=5, scoring="accuracy")
+     print(f"   Cross-val accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
+
+     # ── Model 2: TTL Recommender (regression) ─────────────────
+     print("\n[4/5] Training Cache TTL Recommender...")
+     reg_ttl = GradientBoostingRegressor(
+         n_estimators=200,
+         max_depth=8,
+         learning_rate=0.1,
+         min_samples_split=5,
+         random_state=42,
+     )
+     reg_ttl.fit(X_train, yt_train)
+
+     yt_pred = reg_ttl.predict(X_test)
+     mae_ttl = mean_absolute_error(yt_test, yt_pred)
+     r2_ttl = r2_score(yt_test, yt_pred)
+     print(f"   MAE: {mae_ttl:.1f} seconds")
+     print(f"   R2:  {r2_ttl:.3f}")
+
+     # ── Model 3: Complexity Estimator (regression) ────────────
+     print("\n[5/5] Training Query Complexity Estimator...")
+     reg_cplx = GradientBoostingRegressor(
+         n_estimators=200,
+         max_depth=8,
+         learning_rate=0.1,
+         min_samples_split=5,
+         random_state=42,
+     )
+     reg_cplx.fit(X_train, yc_train)
+
+     yc_pred = reg_cplx.predict(X_test)
+     mae_cplx = mean_absolute_error(yc_test, yc_pred)
+     r2_cplx = r2_score(yc_test, yc_pred)
+     print(f"   MAE: {mae_cplx:.1f} points")
+     print(f"   R2:  {r2_cplx:.3f}")
+
+     # ── Save models ───────────────────────────────────────────
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+     joblib.dump(clf, os.path.join(OUTPUT_DIR, "cache_advisor.joblib"))
+     joblib.dump(reg_ttl, os.path.join(OUTPUT_DIR, "ttl_recommender.joblib"))
+     joblib.dump(reg_cplx, os.path.join(OUTPUT_DIR, "complexity_estimator.joblib"))
+     joblib.dump(le, os.path.join(OUTPUT_DIR, "label_encoder.joblib"))
+
+     # Feature importances
+     importances = {
+         "cache_advisor": dict(zip(FEATURE_NAMES, clf.feature_importances_.tolist())),
+         "ttl_recommender": dict(zip(FEATURE_NAMES, reg_ttl.feature_importances_.tolist())),
+         "complexity_estimator": dict(zip(FEATURE_NAMES, reg_cplx.feature_importances_.tolist())),
+     }
+     with open(os.path.join(OUTPUT_DIR, "feature_importances.json"), "w") as f:
+         json.dump(importances, f, indent=2)
+
+     # Model metadata
+     metadata = {
+         "models": {
+             "cache_advisor": {
+                 "type": "RandomForestClassifier",
+                 "task": "classification",
+                 "classes": le.classes_.tolist(),
+                 "accuracy_cv5": round(float(cv_scores.mean()), 4),
+             },
+             "ttl_recommender": {
+                 "type": "GradientBoostingRegressor",
+                 "task": "regression",
+                 "unit": "seconds",
+                 "mae": round(float(mae_ttl), 2),
+                 "r2": round(float(r2_ttl), 4),
+             },
+             "complexity_estimator": {
+                 "type": "GradientBoostingRegressor",
+                 "task": "regression",
+                 "unit": "score (1-100)",
+                 "mae": round(float(mae_cplx), 2),
+                 "r2": round(float(r2_cplx), 4),
+             },
+         },
+         "features": FEATURE_NAMES,
+         "n_features": len(FEATURE_NAMES),
+         "training_samples": len(queries),
+         "test_samples": len(X_test),
+     }
+     with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     print(f"\n   Models saved to {OUTPUT_DIR}/")
+     print("   Files: cache_advisor.joblib, ttl_recommender.joblib,")
+     print("          complexity_estimator.joblib, label_encoder.joblib,")
+     print("          feature_importances.json, metadata.json")
+     print("\nDone.")
+
+
+ if __name__ == "__main__":
+     train()
trained/cache_advisor.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e11ba948fd643d426b62362f7fd71e30ec90e4a1f1593b2606ae1e31b7b3b19f
+ size 818001
trained/complexity_estimator.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3bd2a5edfce1496bc23a0686e6822ff3d583c884ad9922d9eed5f369ef0b064b
+ size 3038236
trained/feature_importances.json ADDED
@@ -0,0 +1,92 @@
+ {
+   "cache_advisor": {
+     "query_length": 0.19116243566746416,
+     "query_type": 0.02137394504176744,
+     "num_tables": 0.09222282366305111,
+     "num_joins": 0.0748793608388074,
+     "num_conditions": 0.00154549784133088,
+     "num_aggregates": 0.0618503755668228,
+     "num_subqueries": 0.05156804724205885,
+     "num_columns": 0.07578970828634744,
+     "has_distinct": 0.04377194157855687,
+     "has_order_by": 0.03645645249300166,
+     "has_group_by": 0.04425844049972725,
+     "has_having": 0.0022541480803507635,
+     "has_limit": 0.042062573427220216,
+     "has_offset": 0.0,
+     "has_where": 0.008477512665144578,
+     "has_like": 0.0,
+     "has_in_clause": 0.005441079955562388,
+     "has_between": 0.0,
+     "has_exists": 0.0009272674367364887,
+     "has_window_func": 0.010171898283664462,
+     "has_cte": 0.0017415634776680982,
+     "has_union": 0.021229522300210402,
+     "has_case": 0.010714231584388431,
+     "has_cast": 0.0,
+     "nesting_depth": 0.1651162458494366,
+     "num_and_or": 0.0018247999615881344,
+     "num_string_literals": 0.02825993434632328,
+     "num_numeric_literals": 0.006900193912770316
+   },
+   "ttl_recommender": {
+     "query_length": 0.49334167936522283,
+     "query_type": 0.011472503279799304,
+     "num_tables": 0.04121816512371646,
+     "num_joins": 0.05664091770080013,
+     "num_conditions": 2.6766564086239894e-05,
+     "num_aggregates": 0.08454674221524747,
+     "num_subqueries": 0.012819407143812049,
+     "num_columns": 0.003503947545486143,
+     "has_distinct": 0.0058846177923228245,
+     "has_order_by": 0.0030112892658353254,
+     "has_group_by": 0.11555986501253222,
+     "has_having": 0.0005654100636265899,
+     "has_limit": 0.020011249481941062,
+     "has_offset": 0.0,
+     "has_where": 0.0006198304413308254,
+     "has_like": 0.0,
+     "has_in_clause": 0.006723068906959933,
+     "has_between": 0.0,
+     "has_exists": 1.5939534844064166e-05,
+     "has_window_func": 0.0016085055032078448,
+     "has_cte": 2.3841716696771857e-05,
+     "has_union": 5.051873650507809e-05,
+     "has_case": 2.1925568628142657e-05,
+     "has_cast": 0.0,
+     "nesting_depth": 0.13173720022142668,
+     "num_and_or": 2.27992721191164e-05,
+     "num_string_literals": 0.005676787044969987,
+     "num_numeric_literals": 0.004897022498882968
+   },
+   "complexity_estimator": {
+     "query_length": 0.5344926759628151,
+     "query_type": 0.0015962377188123598,
+     "num_tables": 0.031559929024199504,
+     "num_joins": 0.02335110657414861,
+     "num_conditions": 5.757862902242119e-05,
+     "num_aggregates": 0.04750932601796666,
+     "num_subqueries": 0.008970394733974358,
+     "num_columns": 0.00588104652025957,
+     "has_distinct": 0.01062122091510926,
+     "has_order_by": 0.0024661023837127443,
+     "has_group_by": 0.061828695835283276,
+     "has_having": 0.00034502697726715757,
+     "has_limit": 0.020807067356268808,
+     "has_offset": 0.0,
+     "has_where": 0.0004570231775885458,
+     "has_like": 0.0,
+     "has_in_clause": 0.013672027252240813,
+     "has_between": 0.0,
+     "has_exists": 7.242098418966911e-05,
+     "has_window_func": 0.0009971635825058846,
+     "has_cte": 1.4790912091677233e-05,
+     "has_union": 0.006250913065401877,
+     "has_case": 1.6824403324258042e-05,
+     "has_cast": 0.0,
+     "nesting_depth": 0.22281668760327789,
+     "num_and_or": 7.125446039372882e-05,
+     "num_string_literals": 0.003162795018354366,
+     "num_numeric_literals": 0.0029816908917914696
+   }
+ }
trained/label_encoder.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dca6130147e0c2d5e5b985a5abb3087d622fbe3da1e3e09ce3c5a79cc5fd15e8
+ size 399
trained/metadata.json ADDED
@@ -0,0 +1,61 @@
+ {
+   "models": {
+     "cache_advisor": {
+       "type": "RandomForestClassifier",
+       "task": "classification",
+       "classes": [
+         "high",
+         "low",
+         "medium"
+       ],
+       "accuracy_cv5": 1.0
+     },
+     "ttl_recommender": {
+       "type": "GradientBoostingRegressor",
+       "task": "regression",
+       "unit": "seconds",
+       "mae": 494.56,
+       "r2": 0.8994
+     },
+     "complexity_estimator": {
+       "type": "GradientBoostingRegressor",
+       "task": "regression",
+       "unit": "score (1-100)",
+       "mae": 5.57,
+       "r2": 0.9216
+     }
+   },
+   "features": [
+     "query_length",
+     "query_type",
+     "num_tables",
+     "num_joins",
+     "num_conditions",
+     "num_aggregates",
+     "num_subqueries",
+     "num_columns",
+     "has_distinct",
+     "has_order_by",
+     "has_group_by",
+     "has_having",
+     "has_limit",
+     "has_offset",
+     "has_where",
+     "has_like",
+     "has_in_clause",
+     "has_between",
+     "has_exists",
+     "has_window_func",
+     "has_cte",
+     "has_union",
+     "has_case",
+     "has_cast",
+     "nesting_depth",
+     "num_and_or",
+     "num_string_literals",
+     "num_numeric_literals"
+   ],
+   "n_features": 28,
+   "training_samples": 8000,
+   "test_samples": 1600
+ }
trained/ttl_recommender.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ac8fbc0829aba31da6ff9ea299f512b63ed95c065cc2ae7a5779c7a110486aa
+ size 3066316