"""Model loading + ensemble inference. Loads the v2 ensemble from Builder-Neekhil/relationship-longevity-predictor at startup, caches in memory. Exposes a single predict() function that returns probability + per-model contributions. """ from __future__ import annotations import json import os from dataclasses import dataclass from pathlib import Path from typing import Any import joblib import numpy as np from huggingface_hub import hf_hub_download MODEL_REPO = "Builder-Neekhil/relationship-longevity-predictor" # v2 artifact paths inside the model repo V2_XGB = "v2_enhanced/enhanced_xgb.joblib" V2_LGB = "v2_enhanced/enhanced_lgb.joblib" V2_CAT = "v2_enhanced/enhanced_cat.cbm" V2_CONFIG = "v2_enhanced/enhanced_config.json" V2_FEATURE_COLS = "v2_enhanced/enhanced_feature_columns.joblib" GOTTMAN_RECIPE = "phase1_divorce_model/gottman_recipe.json" SURVIVAL_RECIPE = "phase2_survival_model/survival_recipe.json" LONGEVITY_PRIORS = "phase2_survival_model/longevity_priors.json" @dataclass class LoadedModel: xgb: Any lgb: Any cat: Any feature_columns: list[str] config: dict gottman_recipe: dict survival_recipe: dict longevity_priors: dict weights: dict[str, float] _CACHE: LoadedModel | None = None def _download(filename: str) -> str: """Download a file from the model repo, cached by HF hub.""" return hf_hub_download(repo_id=MODEL_REPO, filename=filename) def _load_json(path: str) -> dict: with open(path) as f: return json.load(f) def load() -> LoadedModel: """Load the full ensemble + recipes. Cached after first call.""" global _CACHE if _CACHE is not None: return _CACHE xgb_path = _download(V2_XGB) lgb_path = _download(V2_LGB) cat_path = _download(V2_CAT) config_path = _download(V2_CONFIG) cols_path = _download(V2_FEATURE_COLS) xgb = joblib.load(xgb_path) lgb = joblib.load(lgb_path) # CatBoost uses its own format from catboost import CatBoostClassifier cat = CatBoostClassifier() cat.load_model(cat_path) feature_columns = joblib.load(cols_path) config = _load_json(config_path) # Recipes are small JSONs that drive feature engineering at runtime try: gottman_recipe = _load_json(_download(GOTTMAN_RECIPE)) except Exception: gottman_recipe = {} try: survival_recipe = _load_json(_download(SURVIVAL_RECIPE)) except Exception: survival_recipe = {} try: longevity_priors = _load_json(_download(LONGEVITY_PRIORS)) except Exception: longevity_priors = {} # Ensemble weights — normalize whatever key convention the config uses # to our canonical xgb/lgb/cat keys. This is defensive against the many # possible naming schemes (xgboost, enhanced_xgb, lgbm, etc.) raw_weights = config.get("weights", config) # fall back to top-level keys weights = _normalize_weights(raw_weights) import logging as _log _log.getLogger(__name__).info("Resolved ensemble weights: %s", weights) _CACHE = LoadedModel( xgb=xgb, lgb=lgb, cat=cat, feature_columns=feature_columns, config=config, gottman_recipe=gottman_recipe, survival_recipe=survival_recipe, longevity_priors=longevity_priors, weights=weights, ) return _CACHE def predict(feature_vector: np.ndarray) -> dict: """Run weighted ensemble prediction. 
def predict(feature_vector: np.ndarray) -> dict:
    """Run weighted ensemble prediction.

    Args:
        feature_vector: shape (n_features,) or (1, n_features)

    Returns:
        {
            "probability": float in [0, 1],
            "per_model": {"xgb": float, "lgb": float, "cat": float},
            "band": str,  # "Low" | "Moderate" | "Moderate-High" | "High"
        }
    """
    model = load()

    x = np.asarray(feature_vector, dtype=np.float32)
    if x.ndim == 1:
        x = x.reshape(1, -1)

    expected = len(model.feature_columns)
    if x.shape[1] != expected:
        raise ValueError(
            f"Feature shape mismatch: got {x.shape[1]}, model expects {expected}. "
            f"Check feature_builder output against the feature_columns list."
        )

    p_xgb = float(model.xgb.predict_proba(x)[0, 1])
    p_lgb = float(model.lgb.predict_proba(x)[0, 1])
    p_cat = float(model.cat.predict_proba(x)[0, 1])

    w = model.weights
    p = w["xgb"] * p_xgb + w["lgb"] * p_lgb + w["cat"] * p_cat

    return {
        "probability": p,
        "per_model": {"xgb": p_xgb, "lgb": p_lgb, "cat": p_cat},
        "band": _band(p),
    }


def _band(p: float) -> str:
    """Map a probability to a human-readable band.

    Bands are intentionally wide to avoid false precision.
    """
    if p < 0.35:
        return "Low"
    if p < 0.55:
        return "Moderate"
    if p < 0.75:
        return "Moderate-High"
    return "High"


def _normalize_weights(raw: dict) -> dict[str, float]:
    """Map any known naming convention to canonical xgb / lgb / cat keys.

    Handles: xgboost/xgb/enhanced_xgb, lightgbm/lgb/lgbm/enhanced_lgb,
    catboost/cat/enhanced_cat, plus any _weight suffix variant.
    """
    aliases = {
        "xgb": ["xgb", "xgboost", "enhanced_xgb", "xgb_weight", "xgboost_weight"],
        "lgb": ["lgb", "lgbm", "lightgbm", "enhanced_lgb", "lgb_weight", "lightgbm_weight"],
        "cat": ["cat", "catboost", "enhanced_cat", "cat_weight", "catboost_weight"],
    }
    defaults = {"xgb": 0.40, "lgb": 0.35, "cat": 0.25}

    out: dict[str, float] = {}
    for canonical, keys in aliases.items():
        for k in keys:
            if isinstance(raw, dict) and k in raw and isinstance(raw[k], (int, float)):
                out[canonical] = float(raw[k])
                break
        if canonical not in out:
            logger.warning(
                "No weight found for %s in config (tried %s). Using default %.2f. "
                "Available top-level keys: %s",
                canonical,
                keys,
                defaults[canonical],
                list(raw.keys()) if isinstance(raw, dict) else "N/A",
            )
            out[canonical] = defaults[canonical]

    # Renormalize in case the resolved weights don't sum to 1.0.
    total = sum(out.values())
    if total > 0:
        out = {k: v / total for k, v in out.items()}
    return out
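

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): push a random vector of the
    # right width through the ensemble. Random features are meaningless, so
    # this checks wiring and shapes, not prediction quality. The first run
    # downloads artifacts from the HF hub, so it needs network access or a
    # warm local hub cache.
    rng = np.random.default_rng(0)
    demo_vector = rng.random(len(load().feature_columns))
    print(json.dumps(predict(demo_vector), indent=2))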