"""Model loading + ensemble inference.
Loads the v2 ensemble from Builder-Neekhil/relationship-longevity-predictor
at startup, caches in memory. Exposes a single predict() function that
returns probability + per-model contributions.
"""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from typing import Any

import joblib
import numpy as np
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

MODEL_REPO = "Builder-Neekhil/relationship-longevity-predictor"
# v2 artifact paths inside the model repo
V2_XGB = "v2_enhanced/enhanced_xgb.joblib"
V2_LGB = "v2_enhanced/enhanced_lgb.joblib"
V2_CAT = "v2_enhanced/enhanced_cat.cbm"
V2_CONFIG = "v2_enhanced/enhanced_config.json"
V2_FEATURE_COLS = "v2_enhanced/enhanced_feature_columns.joblib"
GOTTMAN_RECIPE = "phase1_divorce_model/gottman_recipe.json"
SURVIVAL_RECIPE = "phase2_survival_model/survival_recipe.json"
LONGEVITY_PRIORS = "phase2_survival_model/longevity_priors.json"


@dataclass
class LoadedModel:
    xgb: Any
    lgb: Any
    cat: Any
    feature_columns: list[str]
    config: dict
    gottman_recipe: dict
    survival_recipe: dict
    longevity_priors: dict
    weights: dict[str, float]

_CACHE: LoadedModel | None = None


def _download(filename: str) -> str:
    """Download a file from the model repo, cached by HF hub."""
    return hf_hub_download(repo_id=MODEL_REPO, filename=filename)


def _load_json(path: str) -> dict:
    with open(path) as f:
        return json.load(f)


def load() -> LoadedModel:
    """Load the full ensemble + recipes. Cached after first call."""
    global _CACHE
    if _CACHE is not None:
        return _CACHE

    xgb_path = _download(V2_XGB)
    lgb_path = _download(V2_LGB)
    cat_path = _download(V2_CAT)
    config_path = _download(V2_CONFIG)
    cols_path = _download(V2_FEATURE_COLS)

    xgb = joblib.load(xgb_path)
    lgb = joblib.load(lgb_path)
    # CatBoost models are saved in CatBoost's native format, so load via
    # load_model() rather than joblib; the import is kept local to the loader.
    from catboost import CatBoostClassifier

    cat = CatBoostClassifier()
    cat.load_model(cat_path)

    feature_columns = joblib.load(cols_path)
    config = _load_json(config_path)
    # Recipes are small JSONs that drive feature engineering at runtime.
    # Each one is optional: a missing artifact degrades gracefully to {}.
    try:
        gottman_recipe = _load_json(_download(GOTTMAN_RECIPE))
    except Exception:
        logger.warning("Gottman recipe unavailable; falling back to empty dict.")
        gottman_recipe = {}
    try:
        survival_recipe = _load_json(_download(SURVIVAL_RECIPE))
    except Exception:
        logger.warning("Survival recipe unavailable; falling back to empty dict.")
        survival_recipe = {}
    try:
        longevity_priors = _load_json(_download(LONGEVITY_PRIORS))
    except Exception:
        logger.warning("Longevity priors unavailable; falling back to empty dict.")
        longevity_priors = {}
    # Ensemble weights: normalize whatever key convention the config uses
    # to our canonical xgb/lgb/cat keys. This is defensive against the many
    # possible naming schemes (xgboost, enhanced_xgb, lgbm, etc.).
    raw_weights = config.get("weights", config)  # fall back to top-level keys
    weights = _normalize_weights(raw_weights)
    logger.info("Resolved ensemble weights: %s", weights)

    _CACHE = LoadedModel(
        xgb=xgb,
        lgb=lgb,
        cat=cat,
        feature_columns=feature_columns,
        config=config,
        gottman_recipe=gottman_recipe,
        survival_recipe=survival_recipe,
        longevity_priors=longevity_priors,
        weights=weights,
    )
    return _CACHE


def predict(feature_vector: np.ndarray) -> dict:
    """Run weighted ensemble prediction.

    Args:
        feature_vector: shape (n_features,) or (1, n_features)

    Returns:
        {
            "probability": float in [0, 1],
            "per_model": {"xgb": float, "lgb": float, "cat": float},
            "band": str,  # "Low" | "Moderate" | "Moderate-High" | "High"
        }
    """
    model = load()
    x = np.asarray(feature_vector, dtype=np.float32)
    if x.ndim == 1:
        x = x.reshape(1, -1)
    expected = len(model.feature_columns)
    if x.shape[1] != expected:
        raise ValueError(
            f"Feature shape mismatch: got {x.shape[1]}, model expects {expected}. "
            f"Check feature_builder output against the feature_columns list."
        )
    p_xgb = float(model.xgb.predict_proba(x)[0, 1])
    p_lgb = float(model.lgb.predict_proba(x)[0, 1])
    p_cat = float(model.cat.predict_proba(x)[0, 1])
    w = model.weights
    p = w["xgb"] * p_xgb + w["lgb"] * p_lgb + w["cat"] * p_cat
    return {
        "probability": p,
        "per_model": {"xgb": p_xgb, "lgb": p_lgb, "cat": p_cat},
        "band": _band(p),
    }
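
# Illustrative predict() output, assuming the default 0.40/0.35/0.25 weights
# (numbers invented for the example):
#     {"probability": 0.62,
#      "per_model": {"xgb": 0.60, "lgb": 0.65, "cat": 0.61},
#      "band": "Moderate-High"}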


def _band(p: float) -> str:
    """Map a probability to a human-readable band.

    Bands are intentionally wide to avoid false precision.
    """
    if p < 0.35:
        return "Low"
    if p < 0.55:
        return "Moderate"
    if p < 0.75:
        return "Moderate-High"
    return "High"


def _normalize_weights(raw: dict) -> dict[str, float]:
    """Map any known naming convention to canonical xgb / lgb / cat keys.

    Handles: xgboost/xgb/enhanced_xgb, lightgbm/lgb/lgbm/enhanced_lgb,
    catboost/cat/enhanced_cat, plus common _weight suffix variants.
    """
    aliases = {
        "xgb": ["xgb", "xgboost", "enhanced_xgb", "xgb_weight", "xgboost_weight"],
        "lgb": ["lgb", "lgbm", "lightgbm", "enhanced_lgb", "lgb_weight", "lightgbm_weight"],
        "cat": ["cat", "catboost", "enhanced_cat", "cat_weight", "catboost_weight"],
    }
    defaults = {"xgb": 0.40, "lgb": 0.35, "cat": 0.25}
    out: dict[str, float] = {}
    for canonical, keys in aliases.items():
        for k in keys:
            if isinstance(raw, dict) and k in raw and isinstance(raw[k], (int, float)):
                out[canonical] = float(raw[k])
                break
        if canonical not in out:
            logger.warning(
                "No weight found for %s in config (tried %s). Using default %.2f. "
                "Available top-level keys: %s",
                canonical, keys, defaults[canonical],
                list(raw.keys()) if isinstance(raw, dict) else "N/A",
            )
            out[canonical] = defaults[canonical]
    # Renormalize in case the resolved weights don't sum to 1.0.
    total = sum(out.values())
    if total > 0:
        out = {k: v / total for k, v in out.items()}
    return out
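

# Example (illustrative): a config written as
#     {"xgboost_weight": 0.5, "lgbm": 0.3, "catboost": 0.2}
# resolves to canonical {"xgb": 0.5, "lgb": 0.3, "cat": 0.2}; raw values that
# don't sum to 1.0, e.g. {"xgb": 2, "lgb": 1, "cat": 1}, renormalize to
# 0.50 / 0.25 / 0.25.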