"""SHAP-based explanation layer. Given a prediction, computes per-feature SHAP values via the XGBoost tree explainer (single source of truth is cheap and fast; the LGB + CAT signs would closely mirror XGB for this family of tabular models). Returns the top-3 helping and top-3 hurting factors, each paired with natural-language copy that references the Gottman/survival framework rather than raw feature names. """ from __future__ import annotations import numpy as np # Human-readable copy for each known feature name. # If a SHAP-ranked feature isn't in this map, we fall back to a prettified # version of its column name. FEATURE_COPY = { # Gottman dimensions "gottman_proxy_love_maps": "You know each other's inner worlds well", "gottman_proxy_shared_goals": "You're aligned on direction and interests", "gottman_proxy_love_maps_x_shared_goals": "You know each other AND share direction (the strongest signal in the model)", "gottman_proxy_ratio": "Your positive-to-negative interaction ratio (Gottman's 5:1 rule)", "gottman_proxy_repair": "You repair well after conflict", "gottman_proxy_contempt": "Contempt shows up in the relationship", "gottman_proxy_criticism": "Criticism is a recurring pattern", "gottman_proxy_defensiveness": "Defensiveness blocks conflict resolution", "gottman_proxy_stonewalling": "Shutting down / withdrawing in conflict", "gottman_proxy_horsemen": "Gottman's Four Horsemen are present", "gottman_proxy_contempt_x_stonewalling": "Contempt + stonewalling together (the deadliest Gottman interaction)", "gottman_proxy_net_risk": "The negative patterns outweigh the positive ones", "gottman_proxy_deep_contempt": "Deep contempt without empathy buffer", # Survival features "survival_is_first_marriage": "This is your first marriage (74% lower hazard per Cox PH)", "survival_is_love_marriage": "Love-marriage classification (23% lower hazard in the training data)", "survival_age_factor": "Your age at commitment (each year older = 4% lower hazard)", "survival_age_gap_risk": "The age gap between you", "survival_marriage_number_hazard": "Subsequent marriages carry compounding hazard (HR 1.34 per step)", "survival_in_danger_window": "You're in the 3-7 year window where 41% of divorces cluster", "survival_hazard_score": "Overall Cox-model hazard score", # Demographics "age": "Your age at commitment", "age_o": "Your partner's age", "d_age": "Age difference between you", "income": "Career/income proxy", "samerace": "Same-race match", } def explain(prediction: dict, feature_vector: np.ndarray, feature_columns: list[str], loaded_model) -> dict: """Compute SHAP top-k for a single prediction. Returns: { "helping": [{"feature": str, "shap": float, "copy": str}, ...], # len 3 "hurting": [{"feature": str, "shap": float, "copy": str}, ...], # len 3 "mover": {"feature": str, "copy": str}, # the one that would most move the needle } """ try: import shap explainer = shap.TreeExplainer(loaded_model.xgb) x = feature_vector.reshape(1, -1) if feature_vector.ndim == 1 else feature_vector shap_values = explainer.shap_values(x) # For binary classification, shap returns shape (1, n_features) if isinstance(shap_values, list): shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0] shap_vals = np.asarray(shap_values).flatten() except Exception: # Fallback: feature importance from the model, sign-adjusted by # whether the user's value was above or below the feature mean. # This is a graceful degradation, not a silent failure. importances = getattr(loaded_model.xgb, "feature_importances_", None) if importances is None: importances = np.ones(len(feature_columns)) shap_vals = importances * (feature_vector - 0.5) # crude proxy # Pair SHAP values with names and sort pairs = list(zip(feature_columns, shap_vals.tolist())) # Filter out low-magnitude noise pairs = [p for p in pairs if abs(p[1]) > 1e-5] pairs_sorted_desc = sorted(pairs, key=lambda p: p[1], reverse=True) pairs_sorted_asc = sorted(pairs, key=lambda p: p[1]) helping = [_render(f, v) for f, v in pairs_sorted_desc[:3]] hurting = [_render(f, v) for f, v in pairs_sorted_asc[:3]] # Mover = the single feature whose flip would most change the prediction mover_candidates = sorted(pairs, key=lambda p: abs(p[1]), reverse=True) mover = _render(*mover_candidates[0]) if mover_candidates else None return { "helping": helping, "hurting": hurting, "mover": mover, } def _render(feature: str, shap_value: float) -> dict: return { "feature": feature, "shap": shap_value, "copy": FEATURE_COPY.get(feature, _prettify(feature)), } def _prettify(col: str) -> str: # Fallback for un-mapped columns return col.replace("_", " ").replace("proxy", "").strip().capitalize()