| """ |
| Module 5: Explainability |
| SHAP summary plot, top 10 features, LIME explanation. |
| """ |
| import os, sys |
| sys.path.insert(0, '/app/fraud_detection') |
| import numpy as np |
| import pandas as pd |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import joblib |
| import shap |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| from ae_model import AutoencoderWrapper, Autoencoder |
| from config import DATA_DIR, MODELS_DIR, FIGURES_DIR, FIG_DPI, FIG_BG |
|
|
| plt.style.use('seaborn-v0_8-whitegrid') |
|
|
|
|
def shap_analysis(model, X_test, feature_names, model_name='XGBoost'):
    """SHAP summary plot and feature ranking for the best model."""
    print("=" * 60)
    print(f"SHAP ANALYSIS ({model_name})")
    print("=" * 60)

    explainer = shap.TreeExplainer(model)

    # Cap the sample size to keep SHAP computation fast.
    n_samples = min(2000, len(X_test))
    X_sample = X_test.iloc[:n_samples] if isinstance(X_test, pd.DataFrame) else X_test[:n_samples]

    shap_values = explainer.shap_values(X_sample)

    # Older SHAP versions return one array per class for classifiers;
    # keep the positive (fraud) class.
    if isinstance(shap_values, list):
        shap_vals = shap_values[1]
    else:
        shap_vals = shap_values

    # shap.summary_plot draws on the current figure when show=False.
    plt.figure(figsize=(12, 8), facecolor=FIG_BG)
    shap.summary_plot(shap_vals, X_sample, feature_names=feature_names,
                      show=False, max_display=20)
    plt.title(f'SHAP Summary Plot - {model_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_summary.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close('all')
    print("Saved: shap_summary.png/pdf")

    # Rank features by mean absolute SHAP value across the sample.
    mean_shap = np.abs(shap_vals).mean(axis=0)
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Mean |SHAP|': mean_shap
    }).sort_values('Mean |SHAP|', ascending=False)

    print("\nTop 10 Features Driving Fraud Predictions:")
    print(feature_importance.head(10).to_string(index=False, float_format='%.6f'))

    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=FIG_BG)
    top10 = feature_importance.head(10)
    ax.barh(range(10), top10['Mean |SHAP|'].values[::-1], color='steelblue', edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(10))
    ax.set_yticklabels(top10['Feature'].values[::-1], fontsize=10)
    ax.set_xlabel('Mean |SHAP Value|', fontsize=12)
    ax.set_title(f'Top 10 Features Driving Fraud Predictions ({model_name})', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "shap_top10.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: shap_top10.png/pdf")

    feature_importance.to_csv(os.path.join(FIGURES_DIR, "shap_feature_importance.csv"), index=False)

    return shap_vals, feature_importance

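
# Optional sketch (not wired into the pipeline): a SHAP dependence plot for
# the single most influential feature, reusing the arrays already computed by
# shap_analysis(). The helper name and the "shap_dependence_top1.png" file
# name are assumptions, not artifacts anything downstream expects.
def shap_dependence_top_feature(shap_vals, X_sample, feature_names, feature_importance):
    """Plot SHAP dependence for the feature with the highest mean |SHAP|."""
    top_feature = feature_importance.iloc[0]['Feature']
    plt.figure(figsize=(8, 6), facecolor=FIG_BG)
    # shap.dependence_plot accepts a feature name when feature_names is given.
    shap.dependence_plot(top_feature, shap_vals, X_sample,
                         feature_names=feature_names, show=False)
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "shap_dependence_top1.png"),
                dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.close('all')
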
def lime_explanation(model, X_test, y_test, feature_names, model_name='XGBoost'):
    """LIME explanation for one sample prediction."""
    print("\n" + "=" * 60)
    print(f"LIME EXPLANATION ({model_name})")
    print("=" * 60)

    from lime.lime_tabular import LimeTabularExplainer

    proba = model.predict_proba(X_test)[:, 1]
    fraud_mask = y_test == 1
    fraud_indices = np.where(fraud_mask)[0]

    # Prefer a true fraud case the model actually flags (P(fraud) > 0.5);
    # fall back to the first fraud sample otherwise.
    sample_idx = None
    for idx in fraud_indices:
        if proba[idx] > 0.5:
            sample_idx = idx
            break

    if sample_idx is None:
        sample_idx = fraud_indices[0]

    # y_test may be a pandas Series or a numpy array; index positionally.
    actual = y_test.iloc[sample_idx] if hasattr(y_test, 'iloc') else y_test[sample_idx]
    print(f"Selected sample index: {sample_idx}")
    print(f"Actual class: {actual}")
    print(f"Predicted probability: {proba[sample_idx]:.4f}")

    X_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test

    explainer = LimeTabularExplainer(
        X_np,
        feature_names=feature_names,
        class_names=['Legitimate', 'Fraud'],
        discretize_continuous=True,
        random_state=42
    )

    # Request the fraud label explicitly: with top_labels=1, as_list(label=1)
    # raises a KeyError whenever the fallback sample is predicted legitimate.
    explanation = explainer.explain_instance(
        X_np[sample_idx],
        model.predict_proba,
        num_features=15,
        labels=(1,)
    )

    exp_list = explanation.as_list(label=1)

    print("\nLIME Explanation (Top 15 features for fraud prediction):")
    for feature, weight in exp_list:
        direction = "↑ FRAUD" if weight > 0 else "↓ LEGIT"
        print(f"  {feature:50s} → {weight:+.4f} {direction}")

    fig, ax = plt.subplots(1, 1, figsize=(12, 7), facecolor=FIG_BG)

    features = [f for f, w in exp_list]
    weights = [w for f, w in exp_list]
    colors = ['#e74c3c' if w > 0 else '#2ecc71' for w in weights]

    ax.barh(range(len(features)), weights, color=colors, edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(len(features)))
    ax.set_yticklabels(features, fontsize=9)
    ax.set_xlabel('Feature Contribution to Fraud Prediction', fontsize=12)
    ax.set_title(f'LIME Explanation - Single Fraud Sample ({model_name})\n'
                 f'P(Fraud) = {proba[sample_idx]:.4f}', fontsize=12, fontweight='bold')
    ax.axvline(x=0, color='black', linewidth=0.5)

    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#e74c3c', label='Increases Fraud Risk'),
                       Patch(facecolor='#2ecc71', label='Decreases Fraud Risk')]
    ax.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.png"), dpi=FIG_DPI, bbox_inches='tight', facecolor=FIG_BG)
    plt.savefig(os.path.join(FIGURES_DIR, "lime_explanation.pdf"), bbox_inches='tight', facecolor=FIG_BG)
    plt.close()
    print("Saved: lime_explanation.png/pdf")

    return explanation

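
# Optional sketch: persist the LIME result as a standalone interactive HTML
# report via lime's Explanation.save_to_file. The "lime_explanation.html"
# file name is an assumption; nothing downstream reads it.
def save_lime_html(explanation, out_dir=FIGURES_DIR):
    """Save the LIME explanation as a self-contained HTML file."""
    out_path = os.path.join(out_dir, "lime_explanation.html")
    explanation.save_to_file(out_path)
    print(f"Saved: {out_path}")
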
def run_explainability():
    """Run the complete explainability pipeline."""
    data = joblib.load(os.path.join(DATA_DIR, "processed_data.joblib"))
    models = joblib.load(os.path.join(MODELS_DIR, "all_models_with_ae.joblib"))

    X_test = data['X_test']
    y_test = data['y_test']
    feature_names = data['feature_names']

    # The pipeline treats XGBoost as the best model, so it is the one explained.
    best_model = models['XGBoost']

    shap_vals, feature_importance = shap_analysis(best_model, X_test, feature_names, 'XGBoost')
    explanation = lime_explanation(best_model, X_test, y_test, feature_names, 'XGBoost')

    print("\n" + "=" * 60)
    print("EXPLAINABILITY COMPLETE")
    print("=" * 60)

    return shap_vals, feature_importance, explanation

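
# Optional cross-check sketch: compare the SHAP ranking with XGBoost's
# built-in (gain-based) importance via the sklearn-API attribute
# feature_importances_. Assumes the loaded model exposes that attribute
# (predict_proba elsewhere implies the sklearn wrapper); the helper name
# and printed layout are illustrative only.
def compare_importances(model, feature_importance, feature_names):
    """Print SHAP and built-in XGBoost importances side by side (top 10)."""
    builtin = pd.DataFrame({
        'Feature': feature_names,
        'Built-in importance': model.feature_importances_,
    })
    merged = feature_importance.merge(builtin, on='Feature')
    print(merged.head(10).to_string(index=False, float_format='%.6f'))
    return merged
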
if __name__ == "__main__":
    shap_vals, feature_importance, explanation = run_explainability()