import gradio as gr
import pandas as pd
import numpy as np
import joblib
import shap
import matplotlib
import traceback
import warnings
from sklearn.metrics import accuracy_score, confusion_matrix

warnings.filterwarnings('ignore')
# Select the non-interactive Agg backend before pyplot is imported so figures
# can be rendered without a display (server / notebook environments).
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# ==========================================
# 1. LOAD TRAINED ARTIFACTS FROM DISK
# ==========================================
# Every artifact is pre-bound to None so that a failed load leaves the module in
# a well-defined state: downstream code then fails with an explicit error rather
# than a NameError on an undefined global.
# NOTE: joblib.load unpickles arbitrary Python objects -- only load artifacts
# produced by your own trusted training run, never user-supplied files.
best_model = None
scaler = None
imputer = None
encoder = None
FEATURE_NAMES = None
cat_columns = None
explainer = None

print("Loading Model Artifacts...")
try:
    best_model = joblib.load('ensemble_model.pkl')
    scaler = joblib.load('scaler.pkl')
    imputer = joblib.load('imputer.pkl')
    encoder = joblib.load('encoder.pkl')
    FEATURE_NAMES = joblib.load('feature_names.pkl')
    cat_columns = joblib.load('cat_columns.pkl')
    print("All artifacts loaded successfully.")
except Exception as e:
    print(f"Error loading artifacts: {e}. Ensure the training script ran successfully.")

# The SHAP explainer is optional: build it separately so that a problem here
# (e.g. no 'xgb' base estimator) does not throw away an already-loaded model.
if best_model is not None:
    try:
        # Extract XGBoost from StackingClassifier for SHAP explainability
        xgb_base = best_model.named_estimators_['xgb']
        explainer = shap.TreeExplainer(xgb_base)
    except Exception as e:
        print(f"SHAP explainer unavailable: {e}")

# Class labels in the column order of best_model.predict_proba.
target_names = ['Negative', 'Malaria', 'SCA', 'Co-infection']
# ==========================================
# 2. CORE PROCESSING & PREDICTION LOGIC
# ==========================================
def preprocess_input(input_df):
    """Replicates the exact Feature Engineering & Preprocessing from Training.

    Args:
        input_df: DataFrame of raw patient features (one row per patient) using
            the column names built in ``manual_inference``.

    Returns:
        DataFrame with columns ``FEATURE_NAMES`` that has been feature-engineered,
        categorically encoded, imputed and scaled, ready for ``predict_proba``.
    """
    df = input_df.copy()

    # Feature Engineering
    symptom_cols = ['fever', 'chills', 'headache', 'muscle_aches', 'fatigue', 'loss_of_appetite',
                    'jaundice', 'abdominal_pain', 'joint_pain', 'splenomegaly', 'pallor',
                    'lymphadenopathy']
    df['symptom_severity_score'] = df[[c for c in symptom_cols if c in df.columns]].sum(axis=1)
    if 'age' in df.columns:
        df['age_group'] = pd.cut(df['age'], bins=[-1, 5, 12, 55, 120],
                                 labels=[0, 1, 2, 3]).astype(float)
    if 'hb' in df.columns and 'wbc' in df.columns:
        # Small epsilon guards against division by zero when Hb is 0.
        df['infection_anemia_ratio'] = df['wbc'] / (df['hb'] + 1e-5)

    # Align with model input shapes: training features absent here are added as
    # NaN (left for the imputer) and columns are ordered exactly as in training.
    for c in set(FEATURE_NAMES) - set(df.columns):
        df[c] = np.nan
    df_aligned = df[FEATURE_NAMES].copy()

    # Categorical Encoding
    MISSING_STR = 'MISSING_CAT'
    if cat_columns:
        present_cats = [c for c in cat_columns if c in df_aligned.columns]
        if present_cats:
            df_aligned[present_cats] = df_aligned[present_cats].astype(str).replace(['nan', 'None'], np.nan)
            df_aligned[present_cats] = df_aligned[present_cats].fillna(MISSING_STR)
            df_aligned[present_cats] = encoder.transform(df_aligned[present_cats])
            # Turn the placeholder category's code back into NaN so the imputer,
            # not the encoder, supplies the value for missing categories.
            for i, col in enumerate(cat_columns):
                if col in present_cats and MISSING_STR in encoder.categories_[i]:
                    missing_code = list(encoder.categories_[i]).index(MISSING_STR)
                    df_aligned[col] = df_aligned[col].replace(missing_code, np.nan)

    for col in df_aligned.columns:
        df_aligned[col] = pd.to_numeric(df_aligned[col], errors='coerce')

    # Impute and Scale
    X_imp = pd.DataFrame(imputer.transform(df_aligned), columns=FEATURE_NAMES)
    X_scaled = pd.DataFrame(scaler.transform(X_imp), columns=FEATURE_NAMES)
    return X_scaled


def get_specific_coinfection_type(hb, retic, hb_decline, hb_s):
    """Determines granular sub-type of Co-infection based on critical markers.

    Checks are evaluated in priority order: Hb < 5.0, then reticulocytes > 8.0,
    then rapid Hb decline with any HbS fraction; otherwise a generic label.
    """
    if hb < 5.0:
        return "Co-infection: Severe Hyperhemolytic Malarial Crisis"
    elif retic > 8.0:
        return "Co-infection: Acute Hemolytic Malarial Crisis"
    elif hb_decline and hb_s > 0:
        return "Co-infection: Rapidly Progressing Vaso-occlusive Malarial Crisis"
    else:
        return "Co-infection: Concurrent Malaria & Sickle Cell Crisis"


def get_clinical_recs(diag, rule_triggered=None):
    """Build the Markdown clinical-recommendation panel for a diagnosis label.

    Args:
        diag: Diagnosis label ('Negative', 'Malaria', 'SCA' or a 'Co-infection...'
            sub-type string).
        rule_triggered: Description of the override rule that fired, if any.

    Returns:
        Markdown string with the protocol section and diagnostic context notes.
    """
    recs = "### Clinical Decision Support Protocol\n\n"
    if rule_triggered:
        recs += f"**Critical Protocol Triggered:** *{rule_triggered}*\n\n"

    # Co-infection sub-type names contain "Malarial", so they must be excluded
    # explicitly from the plain-malaria branch.
    if 'Malaria' in diag and 'Co-infection' not in diag:
        recs += "**Protocol:** Initiate Artemisinin-based Combination Therapy (ACT) per WHO guidelines.\n"
    elif diag == 'SCA':
        recs += "**Protocol:** Administer IV Fluids, oxygen therapy, and comprehensive pain management.\n"
    elif 'Co-infection' in diag:
        recs += "**Urgent Protocol:** High risk of hyperhemolytic or severe vaso-occlusive crisis.\n"
        recs += "- **Action:** Immediate admission to high-dependency unit. Initiate rapid intravenous antimalarials, aggressive hydration, and prepare for potential blood transfusion.\n"
    else:
        recs += "**Action:** Patient is currently negative for active Malaria and SCA crisis.\n"
        recs += "- **Follow-up:** Screen for Typhoid, Dengue, or other viral infections if febrile symptoms persist.\n"

    recs += "\n---\n### Diagnostic Context Notes\n"
    recs += "- **Overlapping Symptoms:** Fever, Fatigue, Jaundice, Splenomegaly, and Headache *(Headache is uncommon in SCA unless accompanied by severe anemia, cerebral malaria, or stroke risk).* \n"
    recs += "- **Co-infection Prevalences:** Key clinical indicators for Co-infection include Severe Pallor + Jaundice, High fever, Splenomegaly + malaria, and Extreme Reticulocyte (>8%) + malaria."
    return recs


def generate_shap_plot(X_scaled):
    """Render a SHAP waterfall plot of the Co-infection class for one patient.

    Args:
        X_scaled: Single-row preprocessed DataFrame (output of ``preprocess_input``).

    Returns:
        A matplotlib Figure: the waterfall plot, or a placeholder figure showing
        the error message if the explainer is unavailable or SHAP fails. The
        figure is detached from pyplot's registry so figures do not pile up
        across requests; it remains renderable by the caller.
    """
    class_idx = target_names.index('Co-infection')
    fig = None
    try:
        if explainer is None:
            raise RuntimeError("SHAP explainer was not loaded.")
        shap_values = explainer.shap_values(X_scaled)
        # expected_value may be a scalar, a list or an ndarray depending on the
        # SHAP / XGBoost versions; normalise to a 1-D array before indexing.
        expected = np.atleast_1d(explainer.expected_value)
        if isinstance(shap_values, list):
            # One (n_samples, n_features) array per class.
            pat_shap = shap_values[class_idx][0]
            base_val = expected[class_idx]
        elif len(shap_values.shape) == 3:
            # A single (n_samples, n_features, n_classes) array.
            pat_shap = shap_values[0, :, class_idx]
            base_val = expected[class_idx] if expected.size > 1 else expected[0]
        else:
            pat_shap = shap_values[0]
            base_val = expected[0]

        fig, ax = plt.subplots(figsize=(7, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        explanation = shap.Explanation(values=pat_shap, base_values=float(base_val),
                                       data=X_scaled.iloc[0], feature_names=FEATURE_NAMES)
        shap.waterfall_plot(explanation, show=False)
        plt.title("XAI Feature Contribution (Impact on Co-Infection Risk)", fontsize=11, fontweight='bold')
        plt.tight_layout()
        # shap draws on pyplot's *current* figure; return that one rather than
        # assuming it is still the figure created above.
        drawn = plt.gcf()
        if drawn is not fig:
            plt.close(fig)
        fig = drawn
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # discard a half-drawn figure
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.text(0.5, 0.5, f"Interpretability Module Offline:\n{str(e)}", ha='center', va='center')
    # Remove from pyplot's global registry (prevents unbounded growth in this
    # long-running server); the Figure object itself can still be saved/rendered.
    plt.close(fig)
    return fig


def _missing_to_nan(value):
    """Map a cleared UI field (None) to NaN; return any other value unchanged."""
    return np.nan if value is None else value


def manual_inference(age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, malaria_rdt, reticulocyte,
                     hb_rapid_decline, fever, chills, headache, muscle_aches, fatigue, loss_of_appetite,
                     jaundice, abdominal_pain, joint_pain, splenomegaly, pallor, lymphadenopathy):
    """Run rule-based overrides plus the ensemble model for one patient.

    Parameters mirror the UI inputs: numeric labs/vitals, ``sex`` as a string,
    ``malaria_rdt`` as "Positive"/"Negative", and boolean symptom/decline flags.

    Returns:
        Tuple ``(diagnosis_markdown, recommendations_markdown, shap_figure)``. On
        any failure the first element holds the traceback, the second is
        "System Error." and the figure is None.
    """
    try:
        # Cleared numeric fields arrive as None; as NaN they make the override
        # rules below evaluate False and are filled by the trained imputer,
        # instead of crashing on comparisons such as ``None < 5.0``.
        age, temp, hb, wbc, platelets = map(_missing_to_nan, (age, temp, hb, wbc, platelets))
        hb_a, hb_s, hb_f, reticulocyte = map(_missing_to_nan, (hb_a, hb_s, hb_f, reticulocyte))
        rdt_positive = malaria_rdt == "Positive"

        co_infection_flag = False
        rule_triggered = ""
        specific_coinfection_name = ""

        # Hardcoded Critical Clinical Override Rules
        # NOTE(review): the Hb < 5.0 rule fires regardless of RDT result or HbS
        # fraction, so a malaria-negative patient is still labelled a malarial
        # co-infection -- confirm this is the intended clinical policy.
        if hb < 5.0:
            co_infection_flag = True
            rule_triggered = "Hemoglobin below critical threshold (5.0 g/dL)"
        elif reticulocyte > 8.0 and rdt_positive:
            co_infection_flag = True
            rule_triggered = "Extreme Reticulocyte (>8%) + Positive Malaria RDT"
        elif hb_rapid_decline and rdt_positive and hb_s > 0:
            co_infection_flag = True
            rule_triggered = "Rapid Hb decline (>1.5g/dL in 48h) + Positive Malaria + SCA Genotype"

        if co_infection_flag:
            specific_coinfection_name = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)

        symptom_flags = {
            'fever': fever, 'chills': chills, 'headache': headache, 'muscle_aches': muscle_aches,
            'fatigue': fatigue, 'loss_of_appetite': loss_of_appetite, 'jaundice': jaundice,
            'abdominal_pain': abdominal_pain, 'joint_pain': joint_pain, 'splenomegaly': splenomegaly,
            'pallor': pallor, 'lymphadenopathy': lymphadenopathy,
        }
        record = {
            'age': age, 'sex': sex, 'temp': temp, 'hb': hb, 'wbc': wbc, 'platelets': platelets,
            'hb_a': hb_a, 'hb_s': hb_s, 'hb_f': hb_f,
            'malaria_rdt': 1.0 if rdt_positive else 0.0,
            'reticulocyte': reticulocyte,
            'hb_rapid_decline': 1.0 if hb_rapid_decline else 0.0,
        }
        record.update({name: 1.0 if flag else 0.0 for name, flag in symptom_flags.items()})
        input_data = pd.DataFrame({key: [value] for key, value in record.items()})

        X_scaled = preprocess_input(input_data)
        probs = best_model.predict_proba(X_scaled)[0]

        # Map probabilities (as percentages) to class names
        prob_dict = {name: p * 100 for name, p in zip(target_names, probs)}

        # Apply Clinical Overrides if necessary
        if co_infection_flag:
            primary_diag = specific_coinfection_name
            # Adjust probabilities to reflect the clinical override
            # NOTE(review): the 100% entry is a rule outcome, not a model
            # confidence, yet it is shown under "Confidence Breakdown".
            prob_dict = {
                specific_coinfection_name: 100.0,
                'Malaria (Override)': prob_dict['Malaria'],
                'SCA (Override)': prob_dict['SCA'],
                'Negative': 0.0,
            }
        else:
            primary_diag = target_names[int(np.argmax(probs))]
            # If AI predicted co-infection without triggering rules, still give it a specific name
            if primary_diag == 'Co-infection':
                primary_diag = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)
                prob_dict[primary_diag] = prob_dict.pop('Co-infection')

        # Formatting Output Markdown, probabilities sorted descending
        output_parts = [f"## Primary Diagnosis: {primary_diag}\n\n### Comprehensive Confidence Breakdown:\n"]
        for disease, conf in sorted(prob_dict.items(), key=lambda x: x[1], reverse=True):
            output_parts.append(f"- **{disease}**: {conf:.1f}%\n")
        diag_output = "".join(output_parts)

        recs = get_clinical_recs(primary_diag, rule_triggered)
        fig = generate_shap_plot(X_scaled)
        return diag_output, recs, fig
    except Exception:
        return f"### Inference Error\n```\n{traceback.format_exc()}\n```", "System Error.", None


# ==========================================
# 3. SYSTEM VALIDATION HELPER FUNCTIONS
# ==========================================
def load_systematic_metrics():
    """Compute accuracy and macro one-vs-rest sensitivity/specificity on the held-out set.

    Reads ``y_test_val.pkl`` (true class indices) and ``y_probs_val.pkl``
    (per-class probabilities) from the working directory.

    Returns:
        Markdown string with the metrics, or an error message if loading fails.
    """
    try:
        y_test_val = np.asarray(joblib.load('y_test_val.pkl'))
        y_probs_val = np.asarray(joblib.load('y_probs_val.pkl'))
        y_pred_val = np.argmax(y_probs_val, axis=1)

        acc = accuracy_score(y_test_val, y_pred_val)
        cm = confusion_matrix(y_test_val, y_pred_val)

        # One-vs-rest counts for every class at once (rows = true, cols = predicted).
        tp = np.diag(cm)
        fn = cm.sum(axis=1) - tp
        fp = cm.sum(axis=0) - tp
        tn = cm.sum() - tp - fn - fp
        # Classes with an empty denominator contribute 0, as before.
        sens_per_class = np.divide(tp, tp + fn, out=np.zeros(len(tp)), where=(tp + fn) > 0)
        spec_per_class = np.divide(tn, tn + fp, out=np.zeros(len(tn)), where=(tn + fp) > 0)
        sens = np.mean(sens_per_class)
        spec = np.mean(spec_per_class)

        return f"### Systematic Evaluation Metrics (Held-out Cohort)\n\n- **Overall Accuracy**: {acc*100:.2f}%\n- **Sensitivity (Macro)**: {sens*100:.2f}%\n- **Specificity (Macro)**: {spec*100:.2f}%"
    except Exception as e:
        return f"Error loading validation metrics: Ensure 'y_test_val.pkl' and 'y_probs_val.pkl' exist in memory. \n({str(e)})"


def check_calibration(class_name):
    """Plot a reliability (calibration) curve for one class on the held-out set.

    Args:
        class_name: One of ``target_names``.

    Returns:
        A matplotlib Figure with the curve, or a placeholder figure carrying the
        error message. The figure is detached from pyplot's registry.
    """
    fig = None
    try:
        from sklearn.calibration import CalibrationDisplay
        # asarray so the element-wise comparison below also works if the labels
        # were saved as a plain list.
        y_test_val = np.asarray(joblib.load('y_test_val.pkl'))
        y_probs_val = np.asarray(joblib.load('y_probs_val.pkl'))
        class_idx = target_names.index(class_name)
        y_true_binary = (y_test_val == class_idx).astype(int)
        y_prob_class = y_probs_val[:, class_idx]

        fig, ax = plt.subplots(figsize=(6, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        CalibrationDisplay.from_predictions(y_true_binary, y_prob_class, n_bins=10, ax=ax, name=class_name)
        plt.title(f"Reliability Curve (Calibration) for {class_name}", fontweight='bold')
        plt.tight_layout()
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # discard a half-drawn figure
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, f"Calibration Error:\n{str(e)}", ha='center')
    # Prevent figures accumulating in pyplot across requests.
    plt.close(fig)
    return fig


# ==========================================
# 4. GRADIO UI DEFINITION
# ==========================================
custom_theme = gr.themes.Monochrome(
    primary_hue="slate",
    secondary_hue="gray",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
)

# 10 Detailed Clinical Examples spanning all feature variations
clinical_examples = [
    # [age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, rdt, retic, hb_decline, fever, chills, headache, muscle, fatigue, appetite, jaundice, abd_pain, joint_pain, spleno, pallor, lymph]
    # 1. Uncomplicated Malaria
    [8, "Male", 39.5, 11.5, 9.5, 150, 98.0, 0.0, 2.0, "Positive", 1.5, False,
     True, True, True, True, True, True, False, False, False, False, False, False],
    # 2. Severe Malaria
    [22, "Female", 39.0, 7.5, 12.0, 90, 95.0, 0.0, 2.0, "Positive", 4.0, False,
     True, True, True, True, True, True, True, False, False, True, True, False],
    # 3. SCA Vaso-occlusive Crisis
    [15, "Male", 37.2, 8.0, 11.0, 250, 5.0, 85.0, 10.0, "Negative", 6.0, False,
     False, False, False, True, True, False, True, True, True, False, True, False],
    # 4. SCA Hyperhemolytic (Trigger Hb<5)
    [18, "Female", 37.5, 4.5, 14.0, 300, 2.0, 90.0, 8.0, "Negative", 10.0, True,
     False, False, False, False, True, False, True, False, True, True, True, False],
    # 5. Co-infection (Acute Hemolytic, Retic>8)
    [12, "Male", 38.8, 6.5, 16.0, 110, 10.0, 80.0, 10.0, "Positive", 9.5, False,
     True, True, True, True, True, True, True, True, True, True, True, False],
    # 6. Co-infection (Rapidly Progressing)
    [25, "Female", 39.2, 7.0, 15.0, 100, 5.0, 85.0, 10.0, "Positive", 5.0, True,
     True, True, True, True, True, True, True, False, True, True, True, False],
    # 7. Healthy Adult
    [30, "Male", 36.8, 14.0, 6.5, 250, 98.0, 0.0, 2.0, "Negative", 1.0, False,
     False, False, False, False, False, False, False, False, False, False, False, False],
    # 8. Viral Infection (Non-malarial)
    [45, "Female", 37.8, 13.5, 5.0, 210, 97.0, 0.0, 2.0, "Negative", 1.2, False,
     True, False, True, True, True, False, False, False, False, False, False, True],
    # 9. Malaria with Overlapping Symptoms
    [10, "Male", 39.8, 6.0, 18.0, 80, 95.0, 0.0, 3.0, "Positive", 7.0, False,
     True, True, True, False, True, True, True, True, False, True, True, False],
    # 10. SCA Trait (Asymptomatic)
    [28, "Female", 37.0, 12.5, 7.0, 220, 60.0, 38.0, 2.0, "Negative", 1.5, False,
     False, False, False, False, False, False, False, False, False, False, False, False],
]

with gr.Blocks(theme=custom_theme, title="Hemaclass Clinical Dashboard") as demo:
    gr.Markdown("# Hemaclass Clinical Decision Support System")
    gr.Markdown("Deep Stacking Ensemble Model for Malaria and Sickle Cell Anemia Classification.")

    with gr.Tabs():
        # --- TAB 1: CORE INFERENCE ---
        with gr.TabItem("Single Patient Validation"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Demographics & Vitals")
                    with gr.Row():
                        age_in = gr.Number(label="Age", value=25)
                        sex_in = gr.Dropdown(["Male", "Female"], label="Sex", value="Female")
                        temp_in = gr.Number(label="Temperature (°C)", value=37.5)

                    gr.Markdown("### Clinical Symptoms")
                    with gr.Row():
                        fever_in = gr.Checkbox(label="Fever")
                        chills_in = gr.Checkbox(label="Chills")
                        headache_in = gr.Checkbox(label="Headache")
                        fatigue_in = gr.Checkbox(label="Fatigue")
                    with gr.Row():
                        jaundice_in = gr.Checkbox(label="Jaundice")
                        splenomegaly_in = gr.Checkbox(label="Splenomegaly")
                        pallor_in = gr.Checkbox(label="Severe Pallor")
                        muscle_in = gr.Checkbox(label="Muscle Aches")
                    with gr.Accordion("Additional Symptoms", open=False):
                        loss_appetite_in = gr.Checkbox(label="Loss of Appetite")
                        abd_pain_in = gr.Checkbox(label="Abdominal Pain")
                        joint_pain_in = gr.Checkbox(label="Joint Pain")
                        lymph_in = gr.Checkbox(label="Lymphadenopathy")

                    gr.Markdown("### Critical Laboratory Markers")
                    with gr.Row():
                        rdt_in = gr.Radio(["Negative", "Positive"], label="Malaria RDT", value="Negative")
                        retic_in = gr.Number(label="Reticulocyte Count (%)", value=2.0)
                    with gr.Row():
                        hb_in = gr.Number(label="Hemoglobin (g/dL)", value=12.0)
                        hb_decline_in = gr.Checkbox(label="Rapid Hb Decline (>1.5g/dl in 48h)")
                    with gr.Row():
                        hb_a_in = gr.Number(label="HbA Fraction (%)", value=98.0)
                        hb_s_in = gr.Number(label="HbS Fraction (%)", value=0.0)
                        hb_f_in = gr.Number(label="HbF Fraction (%)", value=2.0)
                    with gr.Row():
                        wbc_in = gr.Number(label="WBC Count (x10^9/L)", value=8.0)
                        platelets_in = gr.Number(label="Platelet Count", value=200)

                    manual_btn = gr.Button("Validate Diagnosis", variant="primary", size="lg")

                with gr.Column(scale=1):
                    gr.Markdown("### System Output")
                    out_diag = gr.Markdown()
                    out_recs = gr.Markdown()
                    out_shap = gr.Plot(label="Feature Contribution Analysis")

            gr.Markdown("---")
            gr.Markdown("### Load Clinical Scenarios")
            gr.Markdown("Select a predefined clinical case to auto-populate the diagnostic fields.")

            # Order must match both manual_inference's signature and the columns
            # of every row in clinical_examples.
            input_components = [
                age_in, sex_in, temp_in, hb_in, wbc_in, platelets_in, hb_a_in, hb_s_in, hb_f_in,
                rdt_in, retic_in, hb_decline_in, fever_in, chills_in, headache_in, muscle_in,
                fatigue_in, loss_appetite_in, jaundice_in, abd_pain_in, joint_pain_in,
                splenomegaly_in, pallor_in, lymph_in
            ]

            gr.Examples(
                examples=clinical_examples,
                inputs=input_components,
                label="Predefined Patient Cases"
            )

            manual_btn.click(
                manual_inference,
                inputs=input_components,
                outputs=[out_diag, out_recs, out_shap]
            )

        # --- TAB 2: PERFORMANCE METRICS ---
        with gr.TabItem("Systematic Testing"):
            gr.Markdown("### Overall Model Performance on Unseen Test Cohort")
            metrics_btn = gr.Button("Calculate Systematic Metrics", variant="secondary")
            out_metrics = gr.Markdown()
            metrics_btn.click(load_systematic_metrics, inputs=[], outputs=[out_metrics])

        # --- TAB 3: ADVANCED CALIBRATION ---
        with gr.TabItem("Advanced Validation"):
            gr.Markdown("### Evaluate Diagnosis Calibration")
            gr.Markdown("Select a disease class below to verify the alignment between predicted probabilities and true clinical frequencies.")
            with gr.Row():
                class_dropdown = gr.Dropdown(target_names, label="Select Target Class", value="Co-infection")
                calib_btn = gr.Button("Check Calibration", variant="secondary")
            out_calib = gr.Plot()
            calib_btn.click(check_calibration, inputs=[class_dropdown], outputs=[out_calib])

# Launch inside Colab
if __name__ == "__main__":
    # NOTE: share=True publishes a public URL; do not enter real patient data
    # unless that exposure is acceptable.
    demo.launch(share=True)