| """SHAP-based explanation layer. |
| |
Given a prediction, computes per-feature SHAP values via the XGBoost
tree explainer only. Using a single model as the source of truth is
cheap and fast, and the LGB and CAT models' contribution signs are
expected to closely mirror XGB's for this family of tabular models.

Returns the top-3 helping and top-3 hurting factors, plus the single
largest-magnitude factor ("mover"), each paired with natural-language
copy that references the Gottman/survival framework rather than raw
feature names.
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
|
|
| |
| |
| |
# User-facing copy keyed by model feature column name. Features missing from
# this mapping are shown with a prettified version of the column name instead.
FEATURE_COPY = dict(
    # --- Gottman relationship-interaction proxies -------------------------
    gottman_proxy_love_maps="You know each other's inner worlds well",
    gottman_proxy_shared_goals="You're aligned on direction and interests",
    gottman_proxy_love_maps_x_shared_goals=(
        "You know each other AND share direction (the strongest signal in the model)"
    ),
    gottman_proxy_ratio=(
        "Your positive-to-negative interaction ratio (Gottman's 5:1 rule)"
    ),
    gottman_proxy_repair="You repair well after conflict",
    gottman_proxy_contempt="Contempt shows up in the relationship",
    gottman_proxy_criticism="Criticism is a recurring pattern",
    gottman_proxy_defensiveness="Defensiveness blocks conflict resolution",
    gottman_proxy_stonewalling="Shutting down / withdrawing in conflict",
    gottman_proxy_horsemen="Gottman's Four Horsemen are present",
    gottman_proxy_contempt_x_stonewalling=(
        "Contempt + stonewalling together (the deadliest Gottman interaction)"
    ),
    gottman_proxy_net_risk="The negative patterns outweigh the positive ones",
    gottman_proxy_deep_contempt="Deep contempt without empathy buffer",
    # --- Survival-analysis (Cox model) features ----------------------------
    survival_is_first_marriage=(
        "This is your first marriage (74% lower hazard per Cox PH)"
    ),
    survival_is_love_marriage=(
        "Love-marriage classification (23% lower hazard in the training data)"
    ),
    survival_age_factor=(
        "Your age at commitment (each year older = 4% lower hazard)"
    ),
    survival_age_gap_risk="The age gap between you",
    survival_marriage_number_hazard=(
        "Subsequent marriages carry compounding hazard (HR 1.34 per step)"
    ),
    survival_in_danger_window=(
        "You're in the 3-7 year window where 41% of divorces cluster"
    ),
    survival_hazard_score="Overall Cox-model hazard score",
    # --- Raw demographic columns -------------------------------------------
    age="Your age at commitment",
    age_o="Your partner's age",
    d_age="Age difference between you",
    income="Career/income proxy",
    samerace="Same-race match",
)
|
|
|
|
def explain(prediction: dict, feature_vector: np.ndarray,
            feature_columns: list[str], loaded_model, *,
            top_k: int = 3) -> dict:
    """Compute the top-k SHAP factors for a single prediction.

    Args:
        prediction: Unused by this function; accepted for API compatibility.
        feature_vector: Feature values for one row, either 1-D or 2-D
            (only the first row is explained), ordered like
            ``feature_columns``. Lists are accepted and converted.
        feature_columns: Column names aligned with ``feature_vector``.
        loaded_model: Object exposing the fitted XGBoost model as ``.xgb``.
        top_k: Maximum number of entries in ``helping`` and ``hurting``.

    Returns:
        {
            "helping": [{"feature": str, "shap": float, "copy": str}, ...],
            "hurting": [{"feature": str, "shap": float, "copy": str}, ...],
            "mover": {"feature": str, "shap": float, "copy": str} | None,
        }

        ``helping`` holds up to ``top_k`` features with a positive
        contribution, largest first; ``hurting`` holds up to ``top_k``
        features with a negative contribution, most negative first. A feature
        never appears in both lists. ``mover`` is the feature with the largest
        absolute contribution (the one that would most move the needle), or
        ``None`` when every contribution is ~0. Contributions with
        ``abs(value) <= 1e-5`` are ignored.

        NOTE(review): positive values push the XGBoost output up; this module
        treats that direction as "helping". Confirm it matches the model's
        target encoding.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    contributions = _feature_contributions(feature_vector, loaded_model)

    # zip() pairs each column with its contribution; near-zero values carry no
    # explanatory signal and are dropped.
    pairs = [
        (name, value)
        for name, value in zip(feature_columns, contributions.tolist())
        if abs(value) > 1e-5
    ]

    # Split by sign so a feature can never be reported as both helping and
    # hurting (the previous "top/bottom of one ranking" approach overlapped
    # whenever there were fewer than 2 * top_k non-zero features).
    positive = sorted((p for p in pairs if p[1] > 0),
                      key=lambda p: p[1], reverse=True)
    negative = sorted((p for p in pairs if p[1] < 0),
                      key=lambda p: p[1])

    helping = [_render(f, v) for f, v in positive[:top_k]]
    hurting = [_render(f, v) for f, v in negative[:top_k]]

    # First feature with the largest |contribution| (ties keep column order).
    mover_pair = max(pairs, key=lambda p: abs(p[1]), default=None)
    mover = _render(*mover_pair) if mover_pair is not None else None

    return {
        "helping": helping,
        "hurting": hurting,
        "mover": mover,
    }


def _feature_contributions(feature_vector, loaded_model) -> np.ndarray:
    """Return a 1-D array of per-feature contributions for the first row.

    Tries SHAP's ``TreeExplainer`` on ``loaded_model.xgb``. On any failure
    (e.g. ``shap`` not installed or the model unsupported) it logs a warning
    and falls back to an importance-weighted heuristic so callers still get
    an answer.
    """
    # Normalise to 2-D once so both the SHAP and fallback paths see a row
    # matrix; a 1-D vector becomes shape (1, n).
    x = np.atleast_2d(np.asarray(feature_vector))
    try:
        import shap  # optional dependency; ImportError triggers the fallback

        explainer = shap.TreeExplainer(loaded_model.xgb)
        raw = explainer.shap_values(x)
        if isinstance(raw, list):
            # One array per output/class: take the positive class if present.
            raw = raw[1] if len(raw) > 1 else raw[0]
        values = np.asarray(raw)
        if values.ndim == 3:
            # Newer SHAP releases return (rows, features, outputs) for
            # multi-output models; pick the same output as the list branch.
            values = values[..., 1] if values.shape[-1] > 1 else values[..., 0]
        return np.atleast_2d(values)[0]
    except Exception:
        # Deliberate best-effort: explanations must not take down the request,
        # but the failure should be visible in logs rather than silent.
        import logging

        logging.getLogger(__name__).warning(
            "SHAP explanation failed; using importance-weighted fallback",
            exc_info=True,
        )

    importances = getattr(loaded_model.xgb, "feature_importances_", None)
    if importances is None:
        importances = np.ones(x.shape[1])
    # Heuristic pseudo-contribution: importance times distance from 0.5.
    # NOTE(review): assumes features are roughly scaled to [0, 1] so 0.5 acts
    # as a neutral midpoint — confirm against the feature pipeline.
    return np.asarray(importances, dtype=float) * (x[0] - 0.5)
|
|
|
|
def _render(feature: str, shap_value: float) -> dict:
    """Bundle one feature and its contribution into a response entry.

    The ``copy`` text is the curated ``FEATURE_COPY`` entry when one exists,
    otherwise a prettified form of the raw column name.
    """
    fallback_copy = _prettify(feature)
    return dict(
        feature=feature,
        shap=shap_value,
        copy=FEATURE_COPY.get(feature, fallback_copy),
    )
|
|
|
|
| def _prettify(col: str) -> str: |
| |
| return col.replace("_", " ").replace("proxy", "").strip().capitalize() |
|
|