"""
Inference API for pg_plan_cache models.

Loads trained models and provides prediction functions for:
    1. Cache benefit (high / medium / low)
    2. Recommended TTL (seconds)
    3. Complexity score (1-100)
"""

import os
import json

import joblib
import numpy as np

from features import extract_features, FEATURE_NAMES

MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")

# Model handles, populated lazily by _load_models() on first use.
_cache_advisor = None
_ttl_recommender = None
_complexity_estimator = None
_label_encoder = None
_loaded = False


def _load_models():
    """Lazy-load all models from MODEL_DIR; subsequent calls are no-ops.

    NOTE(security): joblib.load unpickles the artifacts, which can execute
    arbitrary code. Only point MODEL_DIR at model files you produced/trust.
    """
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if _loaded:
        return
    _cache_advisor = joblib.load(os.path.join(MODEL_DIR, "cache_advisor.joblib"))
    _ttl_recommender = joblib.load(os.path.join(MODEL_DIR, "ttl_recommender.joblib"))
    _complexity_estimator = joblib.load(os.path.join(MODEL_DIR, "complexity_estimator.joblib"))
    _label_encoder = joblib.load(os.path.join(MODEL_DIR, "label_encoder.joblib"))
    # Set only after every artifact loaded, so a failed load is retried next call.
    _loaded = True


def _format_ttl(ttl: int) -> str:
    """Render a non-negative TTL in seconds as a short human-readable string.

    Sub-minute values are shown in seconds (e.g. "45s") instead of "0m";
    otherwise "Xh Ym" when at least an hour, else "Ym" (leftover seconds
    are truncated).
    """
    if ttl < 60:
        return f"{ttl}s"
    hours, mins = divmod(ttl // 60, 60)
    return f"{hours}h {mins}m" if hours else f"{mins}m"


def _complexity_label(score: int) -> str:
    """Map a 1-100 complexity score to its coarse label."""
    if score <= 20:
        return "simple"
    if score <= 45:
        return "moderate"
    if score <= 75:
        return "complex"
    return "very complex"


def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,      # seconds, clamped to >= 0
            "ttl_human": str,            # e.g. "1h 0m", "15m", or "45s"
            "complexity_score": int,     # 1-100
            "complexity_label": str,     # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()
    features = extract_features(sql)
    X = np.array([features])  # single-row 2-D input, as the estimators expect

    # Cache advisor
    benefit_idx = _cache_advisor.predict(X)[0]
    benefit_label = _label_encoder.inverse_transform([benefit_idx])[0]
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    # Column j of predict_proba belongs to classifier.classes_[j], which is not
    # necessarily encoded label j (e.g. if a label was absent in training).
    # Fall back to positional order for models that do not expose classes_.
    class_ids = getattr(_cache_advisor, "classes_", None)
    if class_ids is None:
        class_ids = np.arange(len(benefit_probs))
    class_names = _label_encoder.inverse_transform(class_ids)
    prob_dict = {
        name: round(float(p), 4) for name, p in zip(class_names, benefit_probs)
    }

    # TTL recommender (regression output; negative predictions clamp to 0)
    ttl_raw = _ttl_recommender.predict(X)[0]
    ttl = max(0, int(round(ttl_raw)))
    ttl_human = _format_ttl(ttl)

    # Complexity estimator (regression output clamped into the 1-100 range)
    cplx_raw = _complexity_estimator.predict(X)[0]
    cplx = max(1, min(100, int(round(cplx_raw))))
    cplx_label = _complexity_label(cplx)

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": ttl_human,
        "complexity_score": cplx,
        "complexity_label": cplx_label,
        "features": dict(zip(FEATURE_NAMES, features)),
    }


def predict_batch(queries: list[str]) -> list[dict]:
    """Run predict() on each query, returning results in input order."""
    return [predict(q) for q in queries]


def format_prediction(result: dict) -> str:
    """Format a predict() result as a readable multi-line string.

    The query is truncated to 100 characters (with "...") for display.
    """
    query = result['query']
    lines = [
        f" Query: {query[:100]}{'...' if len(query) > 100 else ''}",
        f" Cache Benefit: {result['cache_benefit'].upper()}",
        f" Probabilities: {result['cache_benefit_probabilities']}",
        f" Recommended TTL: {result['recommended_ttl']}s ({result['ttl_human']})",
        f" Complexity: {result['complexity_score']}/100 ({result['complexity_label']})",
    ]
    return "\n".join(lines)


def get_model_info() -> dict:
    """Return model metadata from MODEL_DIR/metadata.json.

    If the file does not exist, returns a dict with a single "error" key
    instead of raising.
    """
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    try:
        with open(meta_path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"error": "metadata.json not found. Run train.py first."}


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python predict.py \"SELECT * FROM users WHERE id = 42\"")
        sys.exit(1)

    # Allow the query to be passed unquoted as multiple argv words.
    sql = " ".join(sys.argv[1:])
    result = predict(sql)
    print(format_prediction(result))