"""
Module 5: Explainability
SHAP summary plot and top-10 feature ranking for the best model, plus a LIME
explanation of a single prediction.
"""
import os
import sys
sys.path.insert(0, '/app/fraud_detection')

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are saved, never shown
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
import warnings
warnings.filterwarnings('ignore')

# AutoencoderWrapper/Autoencoder are unused directly but must be importable
# for joblib to unpickle the autoencoder stored in all_models_with_ae.joblib.
from ae_model import AutoencoderWrapper, Autoencoder
from config import DATA_DIR, MODELS_DIR, FIGURES_DIR, FIG_DPI, FIG_BG

plt.style.use('seaborn-v0_8-whitegrid')


def shap_analysis(model, X_test, feature_names, model_name='XGBoost'):
    """SHAP summary plot and mean-|SHAP| feature ranking for the best model."""
    print("=" * 60)
    print(f"SHAP ANALYSIS ({model_name})")
    print("=" * 60)

    # Use TreeExplainer for tree-based models (exact, fast SHAP values)
    explainer = shap.TreeExplainer(model)

    # Explain a subsample for speed
    n_samples = min(2000, len(X_test))
    X_sample = X_test.iloc[:n_samples] if isinstance(X_test, pd.DataFrame) else X_test[:n_samples]
    shap_values = explainer.shap_values(X_sample)

    # Older SHAP versions return a list of per-class arrays for binary
    # classification; keep the array for class 1 (fraud)
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]
    else:
        shap_vals = shap_values

    # Summary plot
    fig, ax = plt.subplots(1, 1, figsize=(12, 8), facecolor=FIG_BG)
    shap.summary_plot(shap_vals, X_sample, feature_names=feature_names,
                      show=False, max_display=20)
    plt.title(f'SHAP Summary Plot - {model_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close('all')
    print("Saved: shap_summary.png/pdf")

    # Rank features by mean absolute SHAP value
    mean_shap = np.abs(shap_vals).mean(axis=0)
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Mean |SHAP|': mean_shap
    }).sort_values('Mean |SHAP|', ascending=False)
    print("\nTop 10 Features Driving Fraud Predictions:")
    print(feature_importance.head(10).to_string(index=False, float_format='%.6f'))

    # Plot top 10 (values reversed so the strongest feature sits on top)
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=FIG_BG)
    top10 = feature_importance.head(10)
    ax.barh(range(10), top10['Mean |SHAP|'].values[::-1], color='steelblue', edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(10))
    ax.set_yticklabels(top10['Feature'].values[::-1], fontsize=10)
    ax.set_xlabel('Mean |SHAP Value|', fontsize=12)
    ax.set_title(f'Top 10 Features Driving Fraud Predictions ({model_name})', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: shap_top10.png/pdf")

    feature_importance.to_csv(os.path.join(FIGURES_DIR, "shap_feature_importance.csv"), index=False)
    return shap_vals, feature_importance
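

# Optional sanity check (a sketch, not part of the original pipeline): for
# TreeExplainer, SHAP values are additive in the model's margin (log-odds)
# space, so the base value plus the row-sum of SHAP values should match the
# raw margin prediction. The helper name and the XGBoost-specific
# `output_margin=True` call are our assumptions.
def check_shap_additivity(explainer, shap_vals, model, X_sample):
    base = explainer.expected_value
    if isinstance(base, (list, np.ndarray)):
        base = np.ravel(base)[-1]  # class-1 base value for list-style output
    reconstructed = base + shap_vals.sum(axis=1)
    margin = model.predict(X_sample, output_margin=True)  # XGBoost sklearn API
    print(f"Max additivity error: {np.abs(reconstructed - margin).max():.2e}")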


def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'):
    """LIME explanation for a single sample prediction."""
    print("\n" + "=" * 60)
    print(f"LIME EXPLANATION ({model_name})")
    print("=" * 60)
    from lime.lime_tabular import LimeTabularExplainer

    # Pick a fraud sample that was correctly predicted; fall back to the
    # first fraud sample if none crosses the 0.5 threshold
    proba = model.predict_proba(X_test)[:, 1]
    y_np = y_test.values if isinstance(y_test, pd.Series) else y_test
    fraud_indices = np.where(y_np == 1)[0]
    sample_idx = None
    for idx in fraud_indices:
        if proba[idx] > 0.5:
            sample_idx = idx
            break
    if sample_idx is None:
        sample_idx = fraud_indices[0]
    print(f"Selected sample index: {sample_idx}")
    print(f"Actual class: {y_np[sample_idx]}")
    print(f"Predicted probability: {proba[sample_idx]:.4f}")

    # Create LIME explainer; discretize_continuous=True reports features as
    # value ranges (e.g. 'V14 <= -1.23') rather than raw values
    X_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test
    explainer = LimeTabularExplainer(
        X_np,
        feature_names=feature_names,
        class_names=['Legitimate', 'Fraud'],
        discretize_continuous=True,
        random_state=42
    )

    # Explain the single prediction; request the fraud class explicitly so
    # as_list(label=1) is always available, even for the fallback sample
    explanation = explainer.explain_instance(
        X_np[sample_idx],
        model.predict_proba,
        num_features=15,
        labels=[1]
    )
    exp_list = explanation.as_list(label=1)
    print("\nLIME Explanation (Top 15 features for fraud prediction):")
    for feature, weight in exp_list:
        direction = "↑ FRAUD" if weight > 0 else "↓ LEGIT"
        print(f"  {feature:50s} → {weight:+.4f} {direction}")

    # Plot LIME explanation
    fig, ax = plt.subplots(1, 1, figsize=(12, 7), facecolor=FIG_BG)
    features = [f for f, w in exp_list]
    weights = [w for f, w in exp_list]
    colors = ['#e74c3c' if w > 0 else '#2ecc71' for w in weights]
    ax.barh(range(len(features)), weights, color=colors, edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(len(features)))
    ax.set_yticklabels(features, fontsize=9)
    ax.set_xlabel('Feature Contribution to Fraud Prediction', fontsize=12)
    ax.set_title(f'LIME Explanation - Single Fraud Sample ({model_name})\n'
                 f'P(Fraud) = {proba[sample_idx]:.4f}', fontsize=12, fontweight='bold')
    ax.axvline(x=0, color='black', linewidth=0.5)

    # Legend mapping bar colors to direction of contribution
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#e74c3c', label='Increases Fraud Risk'),
                       Patch(facecolor='#2ecc71', label='Decreases Fraud Risk')]
    ax.legend(handles=legend_elements, loc='lower right')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: lime_explanation.png/pdf")
    return explanation
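

# A sketch for gauging LIME stability (our addition, not part of the original
# pipeline): LIME perturbs the sample and fits a local surrogate, so weights
# vary with the random seed. Re-explaining the same row under several seeds
# and comparing the top features is one way to check robustness; the helper
# name and `n_seeds` are illustrative.
def lime_stability_check(model, X_np, sample_idx, feature_names, n_seeds=5):
    from lime.lime_tabular import LimeTabularExplainer
    tops = []
    for seed in range(n_seeds):
        explainer = LimeTabularExplainer(
            X_np, feature_names=feature_names,
            class_names=['Legitimate', 'Fraud'],
            discretize_continuous=True, random_state=seed)
        exp = explainer.explain_instance(X_np[sample_idx], model.predict_proba,
                                         num_features=5, labels=[1])
        tops.append([f for f, _ in exp.as_list(label=1)])
    return tops  # list of top-5 feature lists, one per seed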


def run_explainability():
    """Run complete explainability pipeline."""
    # Load data and models
    data = joblib.load(os.path.join(DATA_DIR, "processed_data.joblib"))
    models = joblib.load(os.path.join(MODELS_DIR, "all_models_with_ae.joblib"))
    X_test = data['X_test']
    y_test = data['y_test']
    feature_names = data['feature_names']

    # Use best model (XGBoost)
    best_model = models['XGBoost']

    # SHAP analysis
    shap_vals, feature_importance = shap_analysis(best_model, X_test, feature_names, 'XGBoost')

    # LIME explanation
    explanation = lime_explanation(best_model, X_test, y_test, feature_names, 'XGBoost')

    print("\n" + "=" * 60)
    print("EXPLAINABILITY COMPLETE")
    print("=" * 60)
    return shap_vals, feature_importance, explanation


if __name__ == "__main__":
    shap_vals, feature_importance, explanation = run_explainability()