Builder-Neekhil's picture
UI commit
853e25d verified
"""SHAP-based explanation layer.
Given a prediction, computes per-feature SHAP values with SHAP's
TreeExplainer on the XGBoost model only. A single model is cheap and fast to
explain, and the LGB and CAT models are assumed to produce closely matching
SHAP signs for this family of tabular models. If SHAP is unavailable, a crude
importance-based proxy is used instead.
Returns the top-3 helping and top-3 hurting factors, plus the single
largest-magnitude factor ("mover"), each paired with natural-language copy
that references the Gottman/survival framework rather than raw feature names.
"""
from __future__ import annotations
import numpy as np
# Human-readable copy for each known feature name.
# If a SHAP-ranked feature isn't in this map, we fall back to a prettified
# version of its column name.
FEATURE_COPY = dict(
    # --- Gottman dimensions ---------------------------------------------
    gottman_proxy_love_maps="You know each other's inner worlds well",
    gottman_proxy_shared_goals="You're aligned on direction and interests",
    gottman_proxy_love_maps_x_shared_goals=(
        "You know each other AND share direction (the strongest signal in the model)"
    ),
    gottman_proxy_ratio=(
        "Your positive-to-negative interaction ratio (Gottman's 5:1 rule)"
    ),
    gottman_proxy_repair="You repair well after conflict",
    gottman_proxy_contempt="Contempt shows up in the relationship",
    gottman_proxy_criticism="Criticism is a recurring pattern",
    gottman_proxy_defensiveness="Defensiveness blocks conflict resolution",
    gottman_proxy_stonewalling="Shutting down / withdrawing in conflict",
    gottman_proxy_horsemen="Gottman's Four Horsemen are present",
    gottman_proxy_contempt_x_stonewalling=(
        "Contempt + stonewalling together (the deadliest Gottman interaction)"
    ),
    gottman_proxy_net_risk="The negative patterns outweigh the positive ones",
    gottman_proxy_deep_contempt="Deep contempt without empathy buffer",
    # --- Survival-model features ----------------------------------------
    survival_is_first_marriage=(
        "This is your first marriage (74% lower hazard per Cox PH)"
    ),
    survival_is_love_marriage=(
        "Love-marriage classification (23% lower hazard in the training data)"
    ),
    survival_age_factor=(
        "Your age at commitment (each year older = 4% lower hazard)"
    ),
    survival_age_gap_risk="The age gap between you",
    survival_marriage_number_hazard=(
        "Subsequent marriages carry compounding hazard (HR 1.34 per step)"
    ),
    survival_in_danger_window=(
        "You're in the 3-7 year window where 41% of divorces cluster"
    ),
    survival_hazard_score="Overall Cox-model hazard score",
    # --- Raw demographics -----------------------------------------------
    age="Your age at commitment",
    age_o="Your partner's age",
    d_age="Age difference between you",
    income="Career/income proxy",
    samerace="Same-race match",
)
def explain(prediction: dict, feature_vector: np.ndarray,
            feature_columns: list[str], loaded_model, *,
            top_k: int = 3) -> dict:
    """Compute the top SHAP drivers for a single prediction.

    Args:
        prediction: The prediction being explained. Not read by this
            function; accepted for interface compatibility.
        feature_vector: Model input for one example, shape ``(n_features,)``
            or ``(1, n_features)``, aligned with ``feature_columns``.
        feature_columns: Feature names in model-input order.
        loaded_model: Object exposing the fitted XGBoost model as ``.xgb``.
        top_k: Maximum number of helping / hurting factors to return.

    Returns:
        {
            "helping": [{"feature": str, "shap": float, "copy": str}, ...],
            "hurting": [{"feature": str, "shap": float, "copy": str}, ...],
            "mover": {"feature": str, "shap": float, "copy": str} | None,
        }

        ``helping`` holds up to ``top_k`` features with a positive value,
        largest first. ``hurting`` holds up to ``top_k`` features with a
        negative value, most negative first. ``mover`` is the feature with
        the largest absolute value. Values with ``|v| <= 1e-5`` are ignored,
        so the lists may be shorter than ``top_k``. ``mover`` is ``None``
        when nothing clears that floor.

    If SHAP values cannot be computed (``shap`` not installed, explainer
    error), a crude importance-based proxy is used instead and a warning is
    logged.
    """
    shap_vals = _shap_values(feature_vector, loaded_model)
    if shap_vals is None:
        shap_vals = _proxy_values(feature_vector, feature_columns, loaded_model)

    # Pair values with names, dropping near-zero noise.
    pairs = [
        (name, value)
        for name, value in zip(feature_columns, shap_vals.tolist())
        if abs(value) > 1e-5
    ]

    # Filter by sign so a feature is never reported on the wrong side.
    # Slicing the full sorted list alone would push positive features into
    # "hurting" (or vice versa) when few features are non-zero.
    helping = sorted((p for p in pairs if p[1] > 0),
                     key=lambda p: p[1], reverse=True)
    hurting = sorted((p for p in pairs if p[1] < 0), key=lambda p: p[1])

    # Mover = the single largest |contribution| to this prediction.
    # max() keeps the earliest column on ties.
    mover = max(pairs, key=lambda p: abs(p[1]), default=None)

    return {
        "helping": [_render(f, v) for f, v in helping[:top_k]],
        "hurting": [_render(f, v) for f, v in hurting[:top_k]],
        "mover": _render(*mover) if mover is not None else None,
    }


def _shap_values(feature_vector: np.ndarray, loaded_model) -> np.ndarray | None:
    """Return SHAP values for one row as a flat float array, or ``None``.

    ``None`` means SHAP could not be used (import or explainer failure). The
    caller then falls back to the importance proxy.
    """
    import logging

    try:
        import shap

        explainer = shap.TreeExplainer(loaded_model.xgb)
        x = feature_vector.reshape(1, -1) if feature_vector.ndim == 1 else feature_vector
        values = explainer.shap_values(x)
        # Some shap releases return one array per class for classifiers;
        # keep the positive class for binary classification.
        if isinstance(values, list):
            values = values[1] if len(values) > 1 else values[0]
        values = np.asarray(values, dtype=float)
        # Others return (n_rows, n_features, n_classes); flattening that
        # would misalign values with feature names, so select a class first.
        if values.ndim == 3:
            values = values[..., 1] if values.shape[-1] > 1 else values[..., 0]
        return values.reshape(-1)
    except Exception:
        # Best-effort by design: any SHAP failure degrades to the proxy.
        # The failure is logged so the degradation is not silent.
        logging.getLogger(__name__).warning(
            "SHAP explanation failed; falling back to importance proxy",
            exc_info=True,
        )
        return None


def _proxy_values(feature_vector: np.ndarray, feature_columns: list[str],
                  loaded_model) -> np.ndarray:
    """Crude SHAP stand-in: model importance, signed by the value versus 0.5.

    Uses ``loaded_model.xgb.feature_importances_`` when available and
    uniform weights otherwise. NOTE(review): the 0.5 pivot presumably
    assumes features scaled to [0, 1]; confirm against the feature pipeline.
    """
    model = getattr(loaded_model, "xgb", None)
    importances = getattr(model, "feature_importances_", None)
    if importances is None:
        importances = np.ones(len(feature_columns))
    # Flatten so a (1, n) input yields one value per feature, not a nested row.
    x = np.asarray(feature_vector, dtype=float).reshape(-1)
    return np.asarray(importances, dtype=float) * (x - 0.5)
def _render(feature: str, shap_value: float) -> dict:
    """Bundle a feature name and its SHAP value with user-facing copy.

    The copy comes from ``FEATURE_COPY``. Unmapped features get a
    prettified form of the column name instead.
    """
    fallback_copy = _prettify(feature)
    copy = FEATURE_COPY.get(feature, fallback_copy)
    return dict(feature=feature, shap=shap_value, copy=copy)
def _prettify(col: str) -> str:
# Fallback for un-mapped columns
return col.replace("_", " ").replace("proxy", "").strip().capitalize()