| """Model loading + ensemble inference. |
| |
| Loads the v2 ensemble from Builder-Neekhil/relationship-longevity-predictor |
| at startup, caches in memory. Exposes a single predict() function that |
| returns probability + per-model contributions. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import joblib |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
|
|
# Hugging Face Hub repository that hosts every model artifact loaded below.
MODEL_REPO = "Builder-Neekhil/relationship-longevity-predictor"


# In-repo paths for the v2 "enhanced" ensemble: three boosted models,
# the ensemble config (weights), and the ordered feature-column list.
V2_XGB = "v2_enhanced/enhanced_xgb.joblib"
V2_LGB = "v2_enhanced/enhanced_lgb.joblib"
V2_CAT = "v2_enhanced/enhanced_cat.cbm"  # CatBoost native format, not joblib
V2_CONFIG = "v2_enhanced/enhanced_config.json"
V2_FEATURE_COLS = "v2_enhanced/enhanced_feature_columns.joblib"
# Optional recipe/prior JSONs — load() tolerates these being absent.
GOTTMAN_RECIPE = "phase1_divorce_model/gottman_recipe.json"
SURVIVAL_RECIPE = "phase2_survival_model/survival_recipe.json"
LONGEVITY_PRIORS = "phase2_survival_model/longevity_priors.json"
|
|
|
|
@dataclass
class LoadedModel:
    """In-memory bundle of every artifact the ensemble needs at predict time."""

    xgb: Any  # fitted XGBoost classifier (joblib-loaded; must expose predict_proba)
    lgb: Any  # fitted LightGBM classifier (joblib-loaded; must expose predict_proba)
    cat: Any  # fitted CatBoostClassifier loaded from native .cbm format
    feature_columns: list[str]  # expected input columns; predict() validates width against len()
    config: dict  # raw enhanced_config.json contents (weights may be nested or top-level)
    gottman_recipe: dict  # optional phase-1 recipe; {} when unavailable
    survival_recipe: dict  # optional phase-2 recipe; {} when unavailable
    longevity_priors: dict  # optional phase-2 priors; {} when unavailable
    weights: dict[str, float]  # canonical {"xgb", "lgb", "cat"} blend weights
|
|
|
|
# Process-wide singleton populated by load() on first call; never invalidated.
_CACHE: LoadedModel | None = None
|
|
|
|
def _download(filename: str) -> str:
    """Fetch *filename* from MODEL_REPO and return its local path.

    hf_hub_download handles on-disk caching, so repeat calls are cheap.
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
    return local_path
|
|
|
|
| def _load_json(path: str) -> dict: |
| with open(path) as f: |
| return json.load(f) |
|
|
|
|
def _load_optional_json(filename: str) -> dict:
    """Best-effort download + parse of an optional repo artifact.

    Returns {} on any failure (missing file, network error, bad JSON) so
    callers can treat the artifact as optional — but logs a warning so
    the failure is visible instead of silently swallowed.
    """
    import logging
    try:
        return _load_json(_download(filename))
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "Optional artifact %s could not be loaded (%s); continuing without it.",
            filename, exc,
        )
        return {}


def load() -> LoadedModel:
    """Load the full ensemble + recipes. Cached after first call."""
    global _CACHE
    if _CACHE is not None:
        return _CACHE

    # Required artifacts: any failure here should propagate — the
    # ensemble cannot run without them.
    xgb = joblib.load(_download(V2_XGB))
    lgb = joblib.load(_download(V2_LGB))

    # CatBoost ships its own native format; the local import means the
    # catboost package is not needed until the first load() call.
    from catboost import CatBoostClassifier
    cat = CatBoostClassifier()
    cat.load_model(_download(V2_CAT))

    feature_columns = joblib.load(_download(V2_FEATURE_COLS))
    config = _load_json(_download(V2_CONFIG))

    # Optional enrichment recipes — each degrades gracefully to {}.
    gottman_recipe = _load_optional_json(GOTTMAN_RECIPE)
    survival_recipe = _load_optional_json(SURVIVAL_RECIPE)
    longevity_priors = _load_optional_json(LONGEVITY_PRIORS)

    # Weights may live under config["weights"] or at the config top level,
    # under several historical key spellings; _normalize_weights resolves
    # them to canonical xgb/lgb/cat keys.
    raw_weights = config.get("weights", config)
    weights = _normalize_weights(raw_weights)
    import logging as _log
    _log.getLogger(__name__).info("Resolved ensemble weights: %s", weights)

    _CACHE = LoadedModel(
        xgb=xgb,
        lgb=lgb,
        cat=cat,
        feature_columns=feature_columns,
        config=config,
        gottman_recipe=gottman_recipe,
        survival_recipe=survival_recipe,
        longevity_priors=longevity_priors,
        weights=weights,
    )
    return _CACHE
|
|
|
|
def predict(feature_vector: np.ndarray) -> dict:
    """Run weighted ensemble prediction.

    Args:
        feature_vector: shape (n_features,) or (1, n_features)
    Returns:
        {
            "probability": float in [0, 1],
            "per_model": {"xgb": float, "lgb": float, "cat": float},
            "band": str,  # "Low" | "Moderate" | "Moderate-High" | "High"
        }
    Raises:
        ValueError: if the input is not 1-D/2-D, or its feature count does
            not match the model's expected columns.
    """
    model = load()

    x = np.asarray(feature_vector, dtype=np.float32)
    if x.ndim == 1:
        x = x.reshape(1, -1)
    elif x.ndim != 2:
        # A 3-D+ array would otherwise fail deep inside predict_proba
        # with a confusing message (or silently pass the shape check on
        # the wrong axis); reject it up front with a clear error.
        raise ValueError(
            f"Expected a 1-D or 2-D feature vector, got ndim={x.ndim}."
        )

    expected = len(model.feature_columns)
    if x.shape[1] != expected:
        raise ValueError(
            f"Feature shape mismatch: got {x.shape[1]}, model expects {expected}. "
            f"Check feature_builder output against the feature_columns list."
        )

    # Positive-class probability (column 1) from each base model.
    p_xgb = float(model.xgb.predict_proba(x)[0, 1])
    p_lgb = float(model.lgb.predict_proba(x)[0, 1])
    p_cat = float(model.cat.predict_proba(x)[0, 1])

    # Weighted blend — weights come pre-normalized from _normalize_weights().
    w = model.weights
    p = w["xgb"] * p_xgb + w["lgb"] * p_lgb + w["cat"] * p_cat

    return {
        "probability": p,
        "per_model": {"xgb": p_xgb, "lgb": p_lgb, "cat": p_cat},
        "band": _band(p),
    }
|
|
|
|
| def _band(p: float) -> str: |
| """Map probability to a human-readable band. |
| |
| Bands are intentionally wide — avoids false precision. |
| """ |
| if p < 0.35: |
| return "Low" |
| if p < 0.55: |
| return "Moderate" |
| if p < 0.75: |
| return "Moderate-High" |
| return "High" |
|
|
| def _normalize_weights(raw: dict) -> dict[str, float]: |
| """Map any known naming convention to canonical xgb / lgb / cat keys. |
| |
| Handles: xgboost/xgb/enhanced_xgb, lightgbm/lgb/lgbm/enhanced_lgb, |
| catboost/cat/enhanced_cat, plus any _weight suffix variant. |
| """ |
| import logging as _log |
| log = _log.getLogger(__name__) |
|
|
| aliases = { |
| "xgb": ["xgb", "xgboost", "enhanced_xgb", "xgb_weight", "xgboost_weight"], |
| "lgb": ["lgb", "lgbm", "lightgbm", "enhanced_lgb", "lgb_weight", "lightgbm_weight"], |
| "cat": ["cat", "catboost", "enhanced_cat", "cat_weight", "catboost_weight"], |
| } |
| defaults = {"xgb": 0.40, "lgb": 0.35, "cat": 0.25} |
|
|
| out: dict[str, float] = {} |
| for canonical, keys in aliases.items(): |
| for k in keys: |
| if isinstance(raw, dict) and k in raw and isinstance(raw[k], (int, float)): |
| out[canonical] = float(raw[k]) |
| break |
| if canonical not in out: |
| log.warning( |
| "No weight found for %s in config (tried %s). Using default %.2f. " |
| "Available top-level keys: %s", |
| canonical, keys, defaults[canonical], |
| list(raw.keys()) if isinstance(raw, dict) else "N/A", |
| ) |
| out[canonical] = defaults[canonical] |
|
|
| |
| total = sum(out.values()) |
| if total > 0: |
| out = {k: v / total for k, v in out.items()} |
| return out |