| """ |
| Inference API for pg_plan_cache models. |
| |
| Loads trained models and provides prediction functions for: |
| 1. Cache benefit (high / medium / low) |
| 2. Recommended TTL (seconds) |
| 3. Complexity score (1-100) |
| """ |
|
|
| import os |
| import json |
| import joblib |
| import numpy as np |
| from features import extract_features, FEATURE_NAMES |
|
|
# Directory where train.py persists the joblib model artifacts.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")


# Module-level model singletons. All start as None and are populated
# exactly once by _load_models() on the first call to predict().
_cache_advisor = None          # classifier: cache benefit (high/medium/low)
_ttl_recommender = None        # regressor: recommended TTL in seconds
_complexity_estimator = None   # regressor: complexity score (1-100)
_label_encoder = None          # maps class indices <-> benefit label strings
_loaded = False                # guard flag so loading happens only once
|
|
|
|
def _load_models():
    """Lazy-load all models from disk into the module-level globals.

    Idempotent: subsequent calls after a successful load are no-ops.

    Raises:
        FileNotFoundError: if any expected model artifact is missing,
            with a hint to run train.py (mirrors get_model_info's
            "Run train.py first" convention).
    """
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if _loaded:
        return

    paths = {
        name: os.path.join(MODEL_DIR, f"{name}.joblib")
        for name in (
            "cache_advisor",
            "ttl_recommender",
            "complexity_estimator",
            "label_encoder",
        )
    }
    # Fail fast with an actionable message instead of a raw joblib error.
    missing = [p for p in paths.values() if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Missing model file(s): {missing}. Run train.py first."
        )

    _cache_advisor = joblib.load(paths["cache_advisor"])
    _ttl_recommender = joblib.load(paths["ttl_recommender"])
    _complexity_estimator = joblib.load(paths["complexity_estimator"])
    _label_encoder = joblib.load(paths["label_encoder"])
    _loaded = True
|
|
|
|
def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Args:
        sql: the SQL query text to analyze.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,    # seconds
            "ttl_human": str,          # e.g. "1h 0m"
            "complexity_score": int,   # 1-100
            "complexity_label": str,   # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()

    features = extract_features(sql)
    X = np.array([features])

    # --- Cache benefit classification ---
    benefit_idx = _cache_advisor.predict(X)[0]
    benefit_label = _label_encoder.inverse_transform([benefit_idx])[0]
    # predict_proba columns correspond to _cache_advisor.classes_, which is
    # NOT guaranteed to be 0..n-1 (e.g. if a class was absent at fit time).
    # Map each column through classes_ rather than assuming positional index.
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    class_labels = _label_encoder.inverse_transform(_cache_advisor.classes_)
    prob_dict = {
        label: round(float(p), 4)
        for label, p in zip(class_labels, benefit_probs)
    }

    # --- TTL recommendation (clamped to non-negative whole seconds) ---
    ttl_raw = _ttl_recommender.predict(X)[0]
    ttl = max(0, int(round(ttl_raw)))
    hours, mins = divmod(ttl // 60, 60)
    ttl_human = f"{hours}h {mins}m" if hours else f"{mins}m"

    # --- Complexity score (clamped to the 1-100 scale) ---
    cplx_raw = _complexity_estimator.predict(X)[0]
    cplx = max(1, min(100, int(round(cplx_raw))))
    if cplx <= 20:
        cplx_label = "simple"
    elif cplx <= 45:
        cplx_label = "moderate"
    elif cplx <= 75:
        cplx_label = "complex"
    else:
        cplx_label = "very complex"

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": ttl_human,
        "complexity_score": cplx,
        "complexity_label": cplx_label,
        "features": dict(zip(FEATURE_NAMES, features)),
    }
|
|
|
|
def predict_batch(queries: list[str]) -> list[dict]:
    """Apply predict() to each query in order and return all results."""
    return list(map(predict, queries))
|
|
|
|
def format_prediction(result: dict) -> str:
    """Render a prediction result dict as a multi-line, human-readable summary."""
    query = result["query"]
    # Show at most 100 characters of the query, with an ellipsis if truncated.
    shown = query[:100] + ("..." if len(query) > 100 else "")
    return "\n".join(
        [
            f" Query: {shown}",
            f" Cache Benefit: {result['cache_benefit'].upper()}",
            f" Probabilities: {result['cache_benefit_probabilities']}",
            f" Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
            f" Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
        ]
    )
|
|
|
|
def get_model_info() -> dict:
    """Return training metadata written by train.py.

    Returns:
        The parsed contents of MODEL_DIR/metadata.json, or an
        {"error": ...} dict (instead of raising) when the file is absent.
    """
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    # EAFP: open directly rather than exists()+open, which is racy if the
    # file disappears between the check and the read.
    try:
        with open(meta_path) as f:
            return json.load(f)
    except FileNotFoundError:
        return {"error": "metadata.json not found. Run train.py first."}
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import sys

    # All arguments after the script name are joined into one SQL string,
    # so the query may be passed unquoted as multiple words.
    args = sys.argv[1:]
    if not args:
        print('Usage: python predict.py "SELECT * FROM users WHERE id = 42"')
        sys.exit(1)

    print(format_prediction(predict(" ".join(args))))
|
|