"""
Module 5: Explainability
Generates a SHAP summary plot, a top-10 feature-importance ranking, and a
LIME explanation for a single prediction from the best model.
"""
import os
import sys
sys.path.insert(0, '/app/fraud_detection')
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
import warnings
warnings.filterwarnings('ignore')

# Autoencoder classes must be importable so joblib can unpickle the saved models
from ae_model import AutoencoderWrapper, Autoencoder
from config import DATA_DIR, MODELS_DIR, FIGURES_DIR, FIG_DPI, FIG_BG

plt.style.use('seaborn-v0_8-whitegrid')
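
# Expected inputs, produced by the earlier pipeline modules:
#   DATA_DIR/processed_data.joblib        -> dict with 'X_test', 'y_test', 'feature_names'
#   MODELS_DIR/all_models_with_ae.joblib  -> dict of fitted models keyed by model name
# Outputs land in FIGURES_DIR: shap_summary and shap_top10 (png + pdf),
# shap_feature_importance.csv, and lime_explanation (png + pdf).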


def shap_analysis(model, X_test, feature_names, model_name='XGBoost'):
    """SHAP summary plot for best model."""
    print("=" * 60)
    print(f"SHAP ANALYSIS ({model_name})")
    print("=" * 60)
    
    # Use TreeExplainer for tree-based models
    explainer = shap.TreeExplainer(model)
    
    # Use a sample for speed
    n_samples = min(2000, len(X_test))
    X_sample = X_test.iloc[:n_samples] if isinstance(X_test, pd.DataFrame) else X_test[:n_samples]
    
    shap_values = explainer.shap_values(X_sample)
    
    # Older SHAP releases return a list of per-class arrays for binary
    # classifiers; keep only the array for class 1 (fraud)
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]  # Class 1 (fraud)
    else:
        shap_vals = shap_values
    
    # Summary plot
    fig, ax = plt.subplots(1, 1, figsize=(12, 8), facecolor=FIG_BG)
    shap.summary_plot(shap_vals, X_sample, feature_names=feature_names,
                      show=False, max_display=20)
    plt.title(f'SHAP Summary Plot - {model_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close('all')
    print("Saved: shap_summary.png/pdf")
    
    # Top 10 features
    mean_shap = np.abs(shap_vals).mean(axis=0)
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Mean |SHAP|': mean_shap
    }).sort_values('Mean |SHAP|', ascending=False)
    
    print("\nTop 10 Features Driving Fraud Predictions:")
    print(feature_importance.head(10).to_string(index=False, float_format='%.6f'))
    
    # Plot top 10
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=FIG_BG)
    top10 = feature_importance.head(10)
    # Reverse so the most important feature appears at the top of the chart
    ax.barh(range(10), top10['Mean |SHAP|'].values[::-1], color='steelblue', edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(10))
    ax.set_yticklabels(top10['Feature'].values[::-1], fontsize=10)
    ax.set_xlabel('Mean |SHAP Value|', fontsize=12)
    ax.set_title(f'Top 10 Features Driving Fraud Predictions ({model_name})', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: shap_top10.png/pdf")
    
    feature_importance.to_csv(os.path.join(FIGURES_DIR, "shap_feature_importance.csv"), index=False)
    
    return shap_vals, feature_importance
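
# TreeExplainer only covers tree ensembles. For non-tree models (e.g. the
# autoencoder baseline), a model-agnostic fallback is KernelExplainer. A minimal
# sketch, assuming a predict_proba-style callable; the background size and
# nsamples below are illustrative choices, not tuned values from this project:
#
#   background = shap.sample(X_test, 100, random_state=42)
#   kernel_explainer = shap.KernelExplainer(model.predict_proba, background)
#   kernel_shap = kernel_explainer.shap_values(X_sample, nsamples=200)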


def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'):
    """LIME explanation for one sample prediction."""
    print("\n" + "=" * 60)
    print(f"LIME EXPLANATION ({model_name})")
    print("=" * 60)
    
    # Imported here so the rest of the module works even if lime is not installed
    from lime.lime_tabular import LimeTabularExplainer
    
    # Find a fraud sample that was correctly predicted
    proba = model.predict_proba(X_test)[:, 1]
    fraud_mask = y_test == 1
    fraud_indices = np.where(fraud_mask)[0]
    
    # Find first correctly predicted fraud
    sample_idx = None
    for idx in fraud_indices:
        if proba[idx] > 0.5:
            sample_idx = idx
            break
    
    if sample_idx is None:
        sample_idx = fraud_indices[0]
    
    print(f"Selected sample index: {sample_idx}")
    # y_test may be a pandas Series or a numpy array; index positionally either way
    actual_class = y_test.iloc[sample_idx] if hasattr(y_test, 'iloc') else y_test[sample_idx]
    print(f"Actual class: {actual_class}")
    print(f"Predicted probability: {proba[sample_idx]:.4f}")
    
    # Create LIME explainer
    X_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test
    
    explainer = LimeTabularExplainer(
        X_np,
        feature_names=feature_names,
        class_names=['Legitimate', 'Fraud'],
        discretize_continuous=True,
        random_state=42
    )
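    # discretize_continuous=True bins numeric features so the explanation rules
    # read as interval conditions (e.g. 'feature <= value') rather than raw values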
    
    # Explain this single prediction for the fraud class. Requesting labels=[1]
    # explicitly (rather than top_labels=1) guarantees as_list(label=1) works even
    # when the model scores the sample below 0.5 and the top label is 0.
    explanation = explainer.explain_instance(
        X_np[sample_idx],
        model.predict_proba,
        num_features=15,
        labels=[1]
    )

    # Get the explanation for the fraud class (1)
    label = 1
    exp_list = explanation.as_list(label=label)
    
    print("\nLIME Explanation (Top 15 features for fraud prediction):")
    for feature, weight in exp_list:
        direction = "↑ FRAUD" if weight > 0 else "↓ LEGIT"
        print(f"  {feature:50s}{weight:+.4f} {direction}")
    
    # Plot LIME explanation
    fig, ax = plt.subplots(1, 1, figsize=(12, 7), facecolor=FIG_BG)
    
    features = [f for f, w in exp_list]
    weights = [w for f, w in exp_list]
    colors = ['#e74c3c' if w > 0 else '#2ecc71' for w in weights]
    
    ax.barh(range(len(features)), weights, color=colors, edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(len(features)))
    ax.set_yticklabels(features, fontsize=9)
    ax.set_xlabel('Feature Contribution to Fraud Prediction', fontsize=12)
    ax.set_title(f'LIME Explanation - Single Fraud Sample ({model_name})\n'
                 f'P(Fraud) = {proba[sample_idx]:.4f}', fontsize=12, fontweight='bold')
    ax.axvline(x=0, color='black', linewidth=0.5)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#e74c3c', label='Increases Fraud Risk'),
                       Patch(facecolor='#2ecc71', label='Decreases Fraud Risk')]
    ax.legend(handles=legend_elements, loc='lower right')
    
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: lime_explanation.png/pdf")
    
    return explanation
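
# LIME can also render an interactive HTML report of the same explanation; a
# one-liner, with an illustrative output path:
#
#   explanation.save_to_file(os.path.join(FIGURES_DIR, "lime_explanation.html"))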


def run_explainability():
    """Run complete explainability pipeline."""
    # Load data and models
    data = joblib.load(os.path.join(DATA_DIR, "processed_data.joblib"))
    models = joblib.load(os.path.join(MODELS_DIR, "all_models_with_ae.joblib"))
    
    X_test = data['X_test']
    y_test = data['y_test']
    feature_names = data['feature_names']
    
    # Use best model (XGBoost)
    best_model = models['XGBoost']
    
    # SHAP analysis
    shap_vals, feature_importance = shap_analysis(best_model, X_test, feature_names, 'XGBoost')
    
    # LIME explanation
    explanation = lime_explanation(best_model, X_test, y_test, feature_names, 'XGBoost')
    
    print("\n" + "=" * 60)
    print("EXPLAINABILITY COMPLETE")
    print("=" * 60)
    
    return shap_vals, feature_importance, explanation


if __name__ == "__main__":
    shap_vals, feature_importance, explanation = run_explainability()