File size: 5,256 Bytes
853e25d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | """SHAP-based explanation layer.
Given a prediction, computes per-feature SHAP values via the XGBoost
tree explainer. Using XGB alone as the single source of truth is cheap
and fast; the LGB + CAT signs would closely mirror XGB for this family
of tabular models.
Returns the top-3 helping and top-3 hurting factors, each paired with
natural-language copy that references the Gottman/survival framework
rather than raw feature names.
"""
from __future__ import annotations
import numpy as np
# User-facing sentence for every feature name the explainer knows about.
# A SHAP-ranked feature that is missing from this table is shown using a
# prettified version of its column name instead.
FEATURE_COPY = {
    # ---- Gottman dimensions ----
    "gottman_proxy_love_maps": "You know each other's inner worlds well",
    "gottman_proxy_shared_goals": "You're aligned on direction and interests",
    "gottman_proxy_love_maps_x_shared_goals": (
        "You know each other AND share direction (the strongest signal in the model)"
    ),
    "gottman_proxy_ratio": (
        "Your positive-to-negative interaction ratio (Gottman's 5:1 rule)"
    ),
    "gottman_proxy_repair": "You repair well after conflict",
    "gottman_proxy_contempt": "Contempt shows up in the relationship",
    "gottman_proxy_criticism": "Criticism is a recurring pattern",
    "gottman_proxy_defensiveness": "Defensiveness blocks conflict resolution",
    "gottman_proxy_stonewalling": "Shutting down / withdrawing in conflict",
    "gottman_proxy_horsemen": "Gottman's Four Horsemen are present",
    "gottman_proxy_contempt_x_stonewalling": (
        "Contempt + stonewalling together (the deadliest Gottman interaction)"
    ),
    "gottman_proxy_net_risk": "The negative patterns outweigh the positive ones",
    "gottman_proxy_deep_contempt": "Deep contempt without empathy buffer",

    # ---- Survival features ----
    "survival_is_first_marriage": (
        "This is your first marriage (74% lower hazard per Cox PH)"
    ),
    "survival_is_love_marriage": (
        "Love-marriage classification (23% lower hazard in the training data)"
    ),
    "survival_age_factor": (
        "Your age at commitment (each year older = 4% lower hazard)"
    ),
    "survival_age_gap_risk": "The age gap between you",
    "survival_marriage_number_hazard": (
        "Subsequent marriages carry compounding hazard (HR 1.34 per step)"
    ),
    "survival_in_danger_window": (
        "You're in the 3-7 year window where 41% of divorces cluster"
    ),
    "survival_hazard_score": "Overall Cox-model hazard score",

    # ---- Demographics ----
    "age": "Your age at commitment",
    "age_o": "Your partner's age",
    "d_age": "Age difference between you",
    "income": "Career/income proxy",
    "samerace": "Same-race match",
}
def explain(prediction: dict, feature_vector: np.ndarray,
            feature_columns: list[str], loaded_model) -> dict:
    """Rank the features that most help and most hurt a single prediction.

    Per-feature contributions come from ``shap.TreeExplainer`` run on
    ``loaded_model.xgb``. If that fails for any reason (shap missing, model
    rejected, ...) the failure is logged and a crude proxy is used instead:
    the model's ``feature_importances_`` (all ones if unavailable) multiplied
    by ``feature value - 0.5``.

    Args:
        prediction: Prediction payload for this request. Not read by this
            function; kept so the call signature stays stable.
        feature_vector: Feature values for the sample, either 1-D or 2-D
            with the sample in the first row.
        feature_columns: Column names, paired positionally with the
            per-feature contributions.
        loaded_model: Object whose ``xgb`` attribute is the tree model
            handed to the explainer.

    Returns:
        {
            "helping": [{"feature": str, "shap": float, "copy": str}, ...],
            "hurting": [{"feature": str, "shap": float, "copy": str}, ...],
            "mover": {"feature": str, "shap": float, "copy": str} | None,
        }
        ``helping`` holds up to 3 features with a positive contribution
        (largest first); ``hurting`` holds up to 3 with a negative
        contribution (most negative first). ``mover`` is the feature with
        the largest absolute contribution, or ``None`` when every
        contribution is below the noise threshold.
    """
    vec = np.asarray(feature_vector)
    try:
        import shap

        explainer = shap.TreeExplainer(loaded_model.xgb)
        x = vec.reshape(1, -1) if vec.ndim == 1 else vec
        shap_values = explainer.shap_values(x)
        # Older shap versions return one array per class for classifiers;
        # keep the positive class when there is one.
        if isinstance(shap_values, list):
            shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
        contrib = np.asarray(shap_values)
        # NOTE(review): newer shap versions can return a single
        # (samples, features, outputs) array instead of a list; select the
        # same output the list branch would so values stay aligned with
        # feature_columns. Confirm against the pinned shap version.
        if contrib.ndim == 3:
            contrib = contrib[..., 1 if contrib.shape[-1] > 1 else 0]
        shap_vals = contrib.flatten()
    except Exception:
        # Graceful degradation: log the reason, then approximate the sign
        # and size of each contribution from the model's feature importances.
        import logging

        logging.getLogger(__name__).warning(
            "SHAP explanation failed; using feature-importance fallback",
            exc_info=True,
        )
        # loaded_model (or its .xgb) may be exactly what was missing, so
        # look both up defensively instead of raising again in the handler.
        xgb_model = getattr(loaded_model, "xgb", None)
        importances = getattr(xgb_model, "feature_importances_", None)
        if importances is None:
            importances = np.ones(len(feature_columns))
        importances = np.asarray(importances)
        # Flatten so a (1, n) input yields scalars, not a nested row.
        first_row = vec.reshape(-1)[:len(importances)]
        shap_vals = importances * (first_row - 0.5)  # crude proxy

    # Pair contributions with names and drop low-magnitude noise.
    pairs = [
        (name, value)
        for name, value in zip(feature_columns, shap_vals.tolist())
        if abs(value) > 1e-5
    ]

    # Split by sign first so a negative contribution can never be reported
    # as "helping" (or a positive one as "hurting") when fewer than three
    # features fall on one side.
    positives = sorted(
        (p for p in pairs if p[1] > 0), key=lambda p: p[1], reverse=True
    )
    negatives = sorted((p for p in pairs if p[1] < 0), key=lambda p: p[1])
    helping = [_render(name, value) for name, value in positives[:3]]
    hurting = [_render(name, value) for name, value in negatives[:3]]

    # Mover = the single feature with the largest absolute contribution.
    mover = _render(*max(pairs, key=lambda p: abs(p[1]))) if pairs else None

    return {
        "helping": helping,
        "hurting": hurting,
        "mover": mover,
    }
def _render(feature: str, shap_value: float) -> dict:
    """Bundle one feature's name, SHAP value and user-facing copy.

    The copy is looked up in ``FEATURE_COPY``; features without an entry
    get a prettified version of their column name.
    """
    fallback_text = _prettify(feature)
    copy_text = FEATURE_COPY.get(feature, fallback_text)
    return {"feature": feature, "shap": shap_value, "copy": copy_text}
def _prettify(col: str) -> str:
    """Turn an un-mapped column name into readable fallback copy.

    Underscores become word breaks, the filler word "proxy" is dropped, and
    the result is capitalized, e.g. ``"gottman_proxy_foo"`` -> ``"Gottman foo"``.
    """
    # Filter whole words and re-join so that removing "proxy" cannot leave a
    # double space behind, and words that merely contain "proxy" stay intact.
    words = [word for word in col.replace("_", " ").split() if word != "proxy"]
    return " ".join(words).capitalize()
|