""" Module 5: Explainability SHAP summary plot, top 10 features, LIME explanation. """ import os, sys sys.path.insert(0, '/app/fraud_detection') import numpy as np import pandas as pd import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import seaborn as sns import joblib import shap import warnings warnings.filterwarnings('ignore') from ae_model import AutoencoderWrapper, Autoencoder from config import DATA_DIR, MODELS_DIR, FIGURES_DIR, FIG_DPI, FIG_BG plt.style.use('seaborn-v0_8-whitegrid') def shap_analysis(model, X_test, feature_names, model_name='XGBoost'): """SHAP summary plot for best model.""" print("=" * 60) print(f"SHAP ANALYSIS ({model_name})") print("=" * 60) # Use TreeExplainer for tree-based models explainer = shap.TreeExplainer(model) # Use a sample for speed n_samples = min(2000, len(X_test)) X_sample = X_test.iloc[:n_samples] if isinstance(X_test, pd.DataFrame) else X_test[:n_samples] shap_values = explainer.shap_values(X_sample) # For binary classification, shap_values might be a list if isinstance(shap_values, list): shap_vals = shap_values[1] # Class 1 (fraud) else: shap_vals = shap_values # Summary plot fig, ax = plt.subplots(1, 1, figsize=(12, 8), facecolor=FIG_BG) shap.summary_plot(shap_vals, X_sample, feature_names=feature_names, show=False, max_display=20) plt.title(f'SHAP Summary Plot - {model_name}', fontsize=14, fontweight='bold') plt.tight_layout() plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG) plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.pdf"), bbox_inches='tight', facecolor=FIG_BG) plt.close('all') print("Saved: shap_summary.png/pdf") # Top 10 features mean_shap = np.abs(shap_vals).mean(axis=0) feature_importance = pd.DataFrame({ 'Feature': feature_names, 'Mean |SHAP|': mean_shap }).sort_values('Mean |SHAP|', ascending=False) print(f"\nTop 10 Features Driving Fraud Predictions:") print(feature_importance.head(10).to_string(index=False, float_format='%.6f')) # Plot top 10 fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=FIG_BG) top10 = feature_importance.head(10) ax.barh(range(10), top10['Mean |SHAP|'].values[::-1], color='steelblue', edgecolor='black', linewidth=0.3) ax.set_yticks(range(10)) ax.set_yticklabels(top10['Feature'].values[::-1], fontsize=10) ax.set_xlabel('Mean |SHAP Value|', fontsize=12) ax.set_title(f'Top 10 Features Driving Fraud Predictions ({model_name})', fontsize=13, fontweight='bold') plt.tight_layout() plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG) plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.pdf"), bbox_inches='tight', facecolor=FIG_BG) plt.close() print("Saved: shap_top10.png/pdf") feature_importance.to_csv(os.path.join(FIGURES_DIR, "shap_feature_importance.csv"), index=False) return shap_vals, feature_importance def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'): """LIME explanation for one sample prediction.""" print("\n" + "=" * 60) print(f"LIME EXPLANATION ({model_name})") print("=" * 60) from lime.lime_tabular import LimeTabularExplainer # Find a fraud sample that was correctly predicted proba = model.predict_proba(X_test)[:, 1] fraud_mask = y_test == 1 fraud_indices = np.where(fraud_mask)[0] # Find first correctly predicted fraud sample_idx = None for idx in fraud_indices: if proba[idx] > 0.5: sample_idx = idx break if sample_idx is None: sample_idx = fraud_indices[0] print(f"Selected sample index: {sample_idx}") print(f"Actual class: 
def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'):
    """LIME explanation for a single fraud prediction."""
    print("\n" + "=" * 60)
    print(f"LIME EXPLANATION ({model_name})")
    print("=" * 60)

    from lime.lime_tabular import LimeTabularExplainer

    # Pick a fraud sample that the model predicts correctly (P(fraud) > 0.5);
    # fall back to the first fraud sample if none qualifies.
    proba = model.predict_proba(X_test)[:, 1]
    fraud_mask = np.asarray(y_test) == 1
    fraud_indices = np.where(fraud_mask)[0]

    sample_idx = None
    for idx in fraud_indices:
        if proba[idx] > 0.5:
            sample_idx = idx
            break
    if sample_idx is None:
        sample_idx = fraud_indices[0]

    # Positional indexing works for both pandas Series and NumPy arrays.
    actual = y_test.iloc[sample_idx] if hasattr(y_test, 'iloc') else y_test[sample_idx]
    print(f"Selected sample index: {sample_idx}")
    print(f"Actual class: {actual}")
    print(f"Predicted probability: {proba[sample_idx]:.4f}")

    # LIME needs a NumPy array of background data.
    X_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test
    explainer = LimeTabularExplainer(
        X_np,
        feature_names=feature_names,
        class_names=['Legitimate', 'Fraud'],
        discretize_continuous=True,
        random_state=42
    )

    # Explain the single selected prediction. Request label 1 explicitly so
    # as_list(label=1) below cannot fail, which it could with top_labels=1
    # if fraud were not the top predicted class.
    explanation = explainer.explain_instance(
        X_np[sample_idx],
        model.predict_proba,
        num_features=15,
        labels=[1]
    )

    # Feature weights for the fraud class (label 1).
    exp_list = explanation.as_list(label=1)
    print("\nLIME Explanation (Top 15 features for fraud prediction):")
    for feature, weight in exp_list:
        direction = "↑ FRAUD" if weight > 0 else "↓ LEGIT"
        print(f"  {feature:50s} → {weight:+.4f} {direction}")

    # Plot the LIME explanation as a signed horizontal bar chart.
    fig, ax = plt.subplots(1, 1, figsize=(12, 7), facecolor=FIG_BG)
    features = [f for f, w in exp_list]
    weights = [w for f, w in exp_list]
    colors = ['#e74c3c' if w > 0 else '#2ecc71' for w in weights]
    ax.barh(range(len(features)), weights, color=colors,
            edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(len(features)))
    ax.set_yticklabels(features, fontsize=9)
    ax.set_xlabel('Feature Contribution to Fraud Prediction', fontsize=12)
    ax.set_title(f'LIME Explanation - Single Fraud Sample ({model_name})\n'
                 f'P(Fraud) = {proba[sample_idx]:.4f}',
                 fontsize=12, fontweight='bold')
    ax.axvline(x=0, color='black', linewidth=0.5)

    # Legend mapping bar color to direction of contribution.
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#e74c3c', label='Increases Fraud Risk'),
                       Patch(facecolor='#2ecc71', label='Decreases Fraud Risk')]
    ax.legend(handles=legend_elements, loc='lower right')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.png"),
                dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.pdf"),
                bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: lime_explanation.png/pdf")
    return explanation


def run_explainability():
    """Run the complete explainability pipeline."""
    # Load the processed data split and the trained models.
    data = joblib.load(os.path.join(DATA_DIR, "processed_data.joblib"))
    models = joblib.load(os.path.join(MODELS_DIR, "all_models_with_ae.joblib"))
    X_test = data['X_test']
    y_test = data['y_test']
    feature_names = data['feature_names']

    # Explain the best-performing model (XGBoost).
    best_model = models['XGBoost']

    shap_vals, feature_importance = shap_analysis(best_model, X_test,
                                                  feature_names, 'XGBoost')
    explanation = lime_explanation(best_model, X_test, y_test,
                                   feature_names, 'XGBoost')

    print("\n" + "=" * 60)
    print("EXPLAINABILITY COMPLETE")
    print("=" * 60)
    return shap_vals, feature_importance, explanation


if __name__ == "__main__":
    shap_vals, feature_importance, explanation = run_explainability()
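

# --- Optional: local SHAP cross-check (sketch) ------------------------------
# A minimal sketch, not wired into the pipeline above, for cross-checking the
# LIME explanation with a local SHAP waterfall plot of the same sample.
# Assumes shap >= 0.40 (shap.plots.waterfall and the Explanation API); the
# function name and out_path argument are illustrative, not part of the
# original module.
def shap_waterfall_for_sample(model, X_test, sample_idx, out_path):
    """Save a SHAP waterfall plot for one test sample (sketch)."""
    row = (X_test.iloc[[sample_idx]] if isinstance(X_test, pd.DataFrame)
           else X_test[sample_idx:sample_idx + 1])
    explanation = shap.Explainer(model)(row)
    shap.plots.waterfall(explanation[0], show=False)
    plt.savefig(out_path, dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.close()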