"""
Module 5: Explainability
SHAP summary plot, top 10 features, LIME explanation.
"""
import os
import sys
sys.path.insert(0, '/app/fraud_detection')
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import joblib
import shap
import warnings
warnings.filterwarnings('ignore')
# AutoencoderWrapper / Autoencoder must be importable so joblib can unpickle
# the saved autoencoder model bundled in all_models_with_ae.joblib
from ae_model import AutoencoderWrapper, Autoencoder
from config import DATA_DIR, MODELS_DIR, FIGURES_DIR, FIG_DPI, FIG_BG
plt.style.use('seaborn-v0_8-whitegrid')
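
# Run this module after preprocessing and model training: run_explainability
# expects processed_data.joblib (DATA_DIR) and all_models_with_ae.joblib
# (MODELS_DIR) to already exist on disk.
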
def shap_analysis(model, X_test, feature_names, model_name='XGBoost'):
    """SHAP summary plot for the best model."""
    print("=" * 60)
    print(f"SHAP ANALYSIS ({model_name})")
    print("=" * 60)
    # Use TreeExplainer for tree-based models
    explainer = shap.TreeExplainer(model)
    # Use a sample for speed
    n_samples = min(2000, len(X_test))
    X_sample = X_test.iloc[:n_samples] if isinstance(X_test, pd.DataFrame) else X_test[:n_samples]
    shap_values = explainer.shap_values(X_sample)
    # SHAP's return type varies across versions: older releases return a list
    # of per-class arrays for binary classifiers, while newer ones may return
    # a single 2-D array (XGBoost) or a 3-D (samples, features, classes) array
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]  # class 1 (fraud)
    elif getattr(shap_values, 'ndim', 2) == 3:
        shap_vals = shap_values[:, :, 1]  # class 1 (fraud)
    else:
        shap_vals = shap_values
    # Summary plot (summary_plot draws onto the current figure when show=False)
    plt.figure(figsize=(12, 8), facecolor=FIG_BG)
    shap.summary_plot(shap_vals, X_sample, feature_names=feature_names,
                      show=False, max_display=20)
    plt.title(f'SHAP Summary Plot - {model_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close('all')
    print("Saved: shap_summary.png/pdf")
    # Top 10 features by mean absolute SHAP value
    mean_shap = np.abs(shap_vals).mean(axis=0)
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Mean |SHAP|': mean_shap
    }).sort_values('Mean |SHAP|', ascending=False)
    print("\nTop 10 Features Driving Fraud Predictions:")
    print(feature_importance.head(10).to_string(index=False, float_format='%.6f'))
    # Plot top 10
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=FIG_BG)
    top10 = feature_importance.head(10)
    ax.barh(range(10), top10['Mean |SHAP|'].values[::-1], color='steelblue', edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(10))
    ax.set_yticklabels(top10['Feature'].values[::-1], fontsize=10)
    ax.set_xlabel('Mean |SHAP Value|', fontsize=12)
    ax.set_title(f'Top 10 Features Driving Fraud Predictions ({model_name})', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: shap_top10.png/pdf")
    feature_importance.to_csv(os.path.join(FIGURES_DIR, "shap_feature_importance.csv"), index=False)
    return shap_vals, feature_importance
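
# Model-agnostic fallback: a minimal sketch assuming shap's generic Explainer
# API (permutation-based when handed a callable). Hypothetical helper, not
# wired into run_explainability; only relevant for models TreeExplainer
# cannot handle, e.g. the autoencoder wrapper.
def shap_analysis_model_agnostic(predict_fn, X_background, X_sample):
    """SHAP values for an arbitrary predict function (sketch)."""
    explainer = shap.Explainer(predict_fn, X_background)
    return explainer(X_sample).values
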
def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'):
    """LIME explanation for one sample prediction."""
    print("\n" + "=" * 60)
    print(f"LIME EXPLANATION ({model_name})")
    print("=" * 60)
    from lime.lime_tabular import LimeTabularExplainer
    # Find a fraud sample that was correctly predicted
    proba = model.predict_proba(X_test)[:, 1]
    fraud_mask = y_test == 1
    fraud_indices = np.where(fraud_mask)[0]
    # Take the first correctly predicted fraud; fall back to the first fraud
    sample_idx = None
    for idx in fraud_indices:
        if proba[idx] > 0.5:
            sample_idx = idx
            break
    if sample_idx is None:
        sample_idx = fraud_indices[0]
print(f"Selected sample index: {sample_idx}")
print(f"Actual class: {y_test.iloc[sample_idx]}")
print(f"Predicted probability: {proba[sample_idx]:.4f}")
    # Create LIME explainer
    X_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test
    explainer = LimeTabularExplainer(
        X_np,
        feature_names=feature_names,
        class_names=['Legitimate', 'Fraud'],
        discretize_continuous=True,
        random_state=42
    )
    # Explain the single prediction. Request the fraud label explicitly:
    # with top_labels=1, as_list(label=1) raises a KeyError whenever the
    # model's top class for this sample is 0 (the fallback case above).
    explanation = explainer.explain_instance(
        X_np[sample_idx],
        model.predict_proba,
        num_features=15,
        labels=(1,)
    )
    # Get the explanation for the fraud class (1)
    exp_list = explanation.as_list(label=1)
print(f"\nLIME Explanation (Top 15 features for fraud prediction):")
for feature, weight in exp_list:
direction = "↑ FRAUD" if weight > 0 else "↓ LEGIT"
print(f" {feature:50s}{weight:+.4f} {direction}")
# Plot LIME explanation
fig, ax = plt.subplots(1, 1, figsize=(12, 7), facecolor=FIG_BG)
features = [f for f, w in exp_list]
weights = [w for f, w in exp_list]
colors = ['#e74c3c' if w > 0 else '#2ecc71' for w in weights]
ax.barh(range(len(features)), weights, color=colors, edgecolor='black', linewidth=0.3)
ax.set_yticks(range(len(features)))
ax.set_yticklabels(features, fontsize=9)
ax.set_xlabel('Feature Contribution to Fraud Prediction', fontsize=12)
ax.set_title(f'LIME Explanation - Single Fraud Sample ({model_name})\n'
f'P(Fraud) = {proba[sample_idx]:.4f}', fontsize=12, fontweight='bold')
ax.axvline(x=0, color='black', linewidth=0.5)
# Add legend
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor='#e74c3c', label='Increases Fraud Risk'),
Patch(facecolor='#2ecc71', label='Decreases Fraud Risk')]
ax.legend(handles=legend_elements, loc='lower right')
plt.tight_layout()
plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.pdf"), bbox_inches='tight', facecolor=FIG_BG)
plt.close()
print("Saved: lime_explanation.png/pdf")
return explanation
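
# Note: lime's Explanation object can also emit an interactive HTML report via
# explanation.save_to_file(<path>); this module exports only static png/pdf
# figures, so that is not used here.
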
def run_explainability():
    """Run complete explainability pipeline."""
    # Load data and models
    data = joblib.load(os.path.join(DATA_DIR, "processed_data.joblib"))
    models = joblib.load(os.path.join(MODELS_DIR, "all_models_with_ae.joblib"))
    X_test = data['X_test']
    y_test = data['y_test']
    feature_names = data['feature_names']
    # Use the best model (XGBoost)
    best_model = models['XGBoost']
    # SHAP analysis
    shap_vals, feature_importance = shap_analysis(best_model, X_test, feature_names, 'XGBoost')
    # LIME explanation
    explanation = lime_explanation(best_model, X_test, y_test, feature_names, 'XGBoost')
    print("\n" + "=" * 60)
    print("EXPLAINABILITY COMPLETE")
    print("=" * 60)
    return shap_vals, feature_importance, explanation

if __name__ == "__main__":
    shap_vals, feature_importance, explanation = run_explainability()