File size: 5,256 Bytes
853e25d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""SHAP-based explanation layer.

Given a prediction, computes per-feature SHAP values with a tree explainer
over the XGBoost model only. Using one model as the single source of truth
is cheap and fast, and the LGB and CAT models' signs are expected to closely
mirror XGB's for this family of tabular models.

Returns up to three helping and up to three hurting factors, each paired
with natural-language copy that references the Gottman/survival framework
rather than raw feature names.
"""
from __future__ import annotations

import numpy as np

# Display copy keyed by model feature/column name. Lookups for names missing
# from this table fall back to a prettified form of the column name itself.
# Grouped as: Gottman dimensions, survival (Cox) features, demographics.
FEATURE_COPY = dict(
    # --- Gottman dimensions ---
    gottman_proxy_love_maps="You know each other's inner worlds well",
    gottman_proxy_shared_goals="You're aligned on direction and interests",
    gottman_proxy_love_maps_x_shared_goals=(
        "You know each other AND share direction "
        "(the strongest signal in the model)"
    ),
    gottman_proxy_ratio=(
        "Your positive-to-negative interaction ratio (Gottman's 5:1 rule)"
    ),
    gottman_proxy_repair="You repair well after conflict",
    gottman_proxy_contempt="Contempt shows up in the relationship",
    gottman_proxy_criticism="Criticism is a recurring pattern",
    gottman_proxy_defensiveness="Defensiveness blocks conflict resolution",
    gottman_proxy_stonewalling="Shutting down / withdrawing in conflict",
    gottman_proxy_horsemen="Gottman's Four Horsemen are present",
    gottman_proxy_contempt_x_stonewalling=(
        "Contempt + stonewalling together "
        "(the deadliest Gottman interaction)"
    ),
    gottman_proxy_net_risk="The negative patterns outweigh the positive ones",
    gottman_proxy_deep_contempt="Deep contempt without empathy buffer",
    # --- Survival features ---
    survival_is_first_marriage=(
        "This is your first marriage (74% lower hazard per Cox PH)"
    ),
    survival_is_love_marriage=(
        "Love-marriage classification "
        "(23% lower hazard in the training data)"
    ),
    survival_age_factor=(
        "Your age at commitment (each year older = 4% lower hazard)"
    ),
    survival_age_gap_risk="The age gap between you",
    survival_marriage_number_hazard=(
        "Subsequent marriages carry compounding hazard (HR 1.34 per step)"
    ),
    survival_in_danger_window=(
        "You're in the 3-7 year window where 41% of divorces cluster"
    ),
    survival_hazard_score="Overall Cox-model hazard score",
    # --- Demographics ---
    age="Your age at commitment",
    age_o="Your partner's age",
    d_age="Age difference between you",
    income="Career/income proxy",
    samerace="Same-race match",
)


def explain(prediction: dict, feature_vector: np.ndarray,
            feature_columns: list[str], loaded_model) -> dict:
    """Compute SHAP top-k factors for a single prediction.

    Args:
      prediction: The prediction this explanation accompanies. Not read
        here; kept so the call signature stays stable.
      feature_vector: Model input for one row, shape ``(n_features,)`` or
        ``(1, n_features)``. If more rows are passed, only the first row is
        explained.
      feature_columns: Column names, aligned positionally with the features.
      loaded_model: Object whose ``.xgb`` attribute is the fitted XGBoost
        model handed to ``shap.TreeExplainer``.

    Returns:
      {
        # up to 3, shap > 0, largest first
        "helping": [{"feature": str, "shap": float, "copy": str}, ...],
        # up to 3, shap < 0, most negative first
        "hurting": [{"feature": str, "shap": float, "copy": str}, ...],
        # feature with the largest |shap|; None if every value is ~0
        "mover": {"feature": str, "shap": float, "copy": str} | None,
      }

    If SHAP values cannot be computed (package missing, unsupported model,
    ...), a crude importance-based proxy is used instead and a warning with
    the traceback is logged.
    """
    # Normalise to exactly one row of shape (1, n_features) so both the SHAP
    # path and the fallback yield a flat per-feature vector (a 2-D input
    # previously produced a nested list and crashed the fallback).
    row = np.atleast_2d(np.asarray(feature_vector))[:1]

    try:
        import shap
        explainer = shap.TreeExplainer(loaded_model.xgb)
        shap_values = explainer.shap_values(row)
        # shap may return one array per class as a list; for a binary model
        # take the positive class.
        if isinstance(shap_values, list):
            shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
        shap_arr = np.asarray(shap_values)
        # shap may instead return (n_rows, n_features, n_outputs); select the
        # positive-class slice rather than flattening the classes together.
        if shap_arr.ndim == 3:
            shap_arr = shap_arr[..., 1 if shap_arr.shape[-1] > 1 else 0]
        shap_vals = np.atleast_2d(shap_arr)[0]
    except Exception:
        # Graceful degradation: record why SHAP failed instead of hiding it.
        import logging
        logging.getLogger(__name__).warning(
            "SHAP explanation failed; falling back to importance-based proxy",
            exc_info=True,
        )
        model = getattr(loaded_model, "xgb", None)
        importances = getattr(model, "feature_importances_", None)
        if importances is None:
            importances = np.ones(len(feature_columns))
        # Crude proxy: importance signed by whether the user's value lies
        # above or below a fixed 0.5 midpoint (not a training mean). This
        # presumably assumes features are scaled to roughly [0, 1] — TODO
        # confirm against the feature pipeline.
        shap_vals = np.asarray(importances, dtype=float) * (row[0] - 0.5)

    # Pair each column with its value, dropping near-zero noise.
    pairs = [
        (name, value)
        for name, value in zip(feature_columns, shap_vals.tolist())
        if abs(value) > 1e-5
    ]

    # Only positive contributions "help" and only negative ones "hurt";
    # without this split, a feature of the wrong sign would fill the list
    # whenever fewer than three features of one sign exist.
    helping_pairs = sorted(
        (p for p in pairs if p[1] > 0), key=lambda p: p[1], reverse=True
    )
    hurting_pairs = sorted((p for p in pairs if p[1] < 0), key=lambda p: p[1])

    helping = [_render(f, v) for f, v in helping_pairs[:3]]
    hurting = [_render(f, v) for f, v in hurting_pairs[:3]]

    # Mover: the feature with the largest absolute contribution, i.e. the
    # single strongest push on this prediction in either direction.
    strongest = max(pairs, key=lambda p: abs(p[1]), default=None)
    mover = _render(*strongest) if strongest is not None else None

    return {
        "helping": helping,
        "hurting": hurting,
        "mover": mover,
    }


def _render(feature: str, shap_value: float) -> dict:
    """Package one (feature, SHAP value) pair for the explanation payload.

    The ``"copy"`` entry is the curated sentence from ``FEATURE_COPY`` when
    the feature is listed there; otherwise it is ``_prettify(feature)``.
    """
    text = FEATURE_COPY.get(feature)
    if text is None:
        # Un-mapped column: derive readable copy from its name.
        text = _prettify(feature)
    return dict(feature=feature, shap=shap_value, copy=text)


def _prettify(col: str) -> str:
    """Turn an un-mapped column name into readable fallback copy.

    Underscores become spaces, every occurrence of ``"proxy"`` is removed,
    runs of whitespace collapse to a single space, and the result is
    sentence-cased with ``str.capitalize`` (which also lowercases everything
    after the first character). Example: ``"gottman_proxy_deep_contempt"``
    becomes ``"Gottman deep contempt"``.
    """
    # split()/join() collapses the double space that removing "proxy" from
    # between two words would otherwise leave, and trims the ends.
    words = col.replace("_", " ").replace("proxy", "").split()
    return " ".join(words).capitalize()