# pg-plan-cache-models / predict.py
# Uploaded by nilenpatel — commit 406cec4 (verified): "Upload pg_plan_cache models"
"""
Inference API for pg_plan_cache models.
Loads trained models and provides prediction functions for:
1. Cache benefit (high / medium / low)
2. Recommended TTL (seconds)
3. Complexity score (1-100)
"""
import os
import json
import joblib
import numpy as np
from features import extract_features, FEATURE_NAMES
# Directory containing the trained *.joblib artifacts (written by train.py).
MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")
# Module-level singletons, populated lazily by _load_models() on first predict().
_cache_advisor = None  # classifier: cache benefit class (predict / predict_proba)
_ttl_recommender = None  # regressor: recommended TTL in seconds
_complexity_estimator = None  # regressor: complexity score on a 1-100 scale
_label_encoder = None  # encoder mapping class indices back to string labels (inverse_transform)
_loaded = False  # guard flag so models are loaded from disk only once
def _load_models():
    """Load all model artifacts from MODEL_DIR once; no-op on later calls."""
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if _loaded:
        return

    def _load(filename):
        # All artifacts live side by side in the trained-model directory.
        return joblib.load(os.path.join(MODEL_DIR, filename))

    _cache_advisor = _load("cache_advisor.joblib")
    _ttl_recommender = _load("ttl_recommender.joblib")
    _complexity_estimator = _load("complexity_estimator.joblib")
    _label_encoder = _load("label_encoder.joblib")
    _loaded = True
def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Args:
        sql: The SQL query text to analyze.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,   # seconds
            "ttl_human": str,         # e.g. "1h 0m"
            "complexity_score": int,  # 1-100
            "complexity_label": str,  # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()
    features = extract_features(sql)
    X = np.array([features])

    # --- Cache advisor (classification) ---
    benefit_label = _label_encoder.inverse_transform(_cache_advisor.predict(X))[0]
    # predict_proba columns follow classifier.classes_, which is NOT guaranteed
    # to be 0..n-1 in enumeration order; map each column through its actual
    # class index instead of assuming column i == encoded label i.
    class_labels = _label_encoder.inverse_transform(_cache_advisor.classes_)
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    prob_dict = {
        label: round(float(p), 4)
        for label, p in zip(class_labels, benefit_probs)
    }

    # --- TTL recommender (regression, clamped to non-negative seconds) ---
    ttl = max(0, int(round(_ttl_recommender.predict(X)[0])))

    # --- Complexity estimator (regression, clamped to the 1-100 scale) ---
    cplx = max(1, min(100, int(round(_complexity_estimator.predict(X)[0]))))

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": _format_ttl(ttl),
        "complexity_score": cplx,
        "complexity_label": _complexity_label(cplx),
        "features": dict(zip(FEATURE_NAMES, features)),
    }


def _format_ttl(ttl: int) -> str:
    """Format a TTL in seconds as a short human-readable string.

    Sub-minute TTLs are shown in seconds — the previous formatting rendered
    e.g. ttl=30 as the misleading "0m". Otherwise "<h>h <m>m" / "<m>m".
    """
    if ttl < 60:
        return f"{ttl}s"
    hours, mins = divmod(ttl // 60, 60)
    return f"{hours}h {mins}m" if hours else f"{mins}m"


def _complexity_label(score: int) -> str:
    """Map a 1-100 complexity score onto a coarse human-readable bucket."""
    if score <= 20:
        return "simple"
    if score <= 45:
        return "moderate"
    if score <= 75:
        return "complex"
    return "very complex"
def predict_batch(queries: list[str]) -> list[dict]:
    """Run predictions on multiple queries, one result dict per query."""
    results: list[dict] = []
    for query in queries:
        results.append(predict(query))
    return results
def format_prediction(result: dict) -> str:
    """Render a prediction result dict as a human-readable multi-line string."""
    query = result["query"]
    # Keep the display compact: show at most 100 characters of the query.
    shown_query = query[:100] + ("..." if len(query) > 100 else "")
    parts = (
        f" Query: {shown_query}",
        f" Cache Benefit: {result['cache_benefit'].upper()}",
        f" Probabilities: {result['cache_benefit_probabilities']}",
        f" Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
        f" Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
    )
    return "\n".join(parts)
def get_model_info() -> dict:
    """Return training metadata from MODEL_DIR/metadata.json, or an error dict."""
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    try:
        # EAFP: attempt the read and handle a missing file directly.
        with open(meta_path) as f:
            return json.load(f)
    except FileNotFoundError:
        return {"error": "metadata.json not found. Run train.py first."}
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        # No query supplied on the command line: show usage and fail.
        print('Usage: python predict.py "SELECT * FROM users WHERE id = 42"')
        sys.exit(1)
    print(format_prediction(predict(" ".join(cli_args))))