| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import joblib |
| import shap |
| import matplotlib |
| import traceback |
| import warnings |
| from sklearn.metrics import accuracy_score, confusion_matrix |
|
|
# Silence library warnings globally to keep the server console readable ...
warnings.filterwarnings('ignore')
# ... but keep scikit-learn's "Trying to unpickle estimator X from version A when
# using version B" warning visible: this app unpickles fitted estimators, and the
# warning itself states such a mismatch may lead to invalid results.
# filterwarnings() inserts at the front of the filter list by default, so this
# rule takes precedence over the blanket 'ignore' above.
warnings.filterwarnings('default', message='Trying to unpickle estimator')
# Headless backend: figures are only rendered into the Gradio UI, never shown in a window.
# Must be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
|
|
| |
| |
| |
print("Loading Model Artifacts...")
# Pre-bind every artifact name so a failed load leaves them as None rather than
# undefined. The Gradio handlers then fail inside their own try/except blocks,
# which already report errors to the UI, instead of raising a NameError.
best_model = scaler = imputer = encoder = None
FEATURE_NAMES = cat_columns = None
xgb_base = explainer = None
try:
    best_model = joblib.load('ensemble_model.pkl')
    scaler = joblib.load('scaler.pkl')
    imputer = joblib.load('imputer.pkl')
    encoder = joblib.load('encoder.pkl')
    FEATURE_NAMES = joblib.load('feature_names.pkl')
    cat_columns = joblib.load('cat_columns.pkl')
except Exception as e:
    print(f"Error loading artifacts: {e}. Ensure the training script ran successfully.")
else:
    # The SHAP explainer only feeds the interpretability plot. Build it separately
    # so a failure here (e.g. no 'xgb' member in the ensemble) is not mislabelled
    # as an artifact-loading error, and the loaded model stays usable.
    try:
        xgb_base = best_model.named_estimators_['xgb']
        explainer = shap.TreeExplainer(xgb_base)
        print("All artifacts loaded successfully.")
    except Exception as e:
        print(f"Model artifacts loaded, but the SHAP explainer could not be built: {e}")


# Class labels; index i is paired with column i of best_model.predict_proba
# throughout this file (assumed to match the training label encoding — confirm).
target_names = ['Negative', 'Malaria', 'SCA', 'Co-infection']
|
|
| |
| |
| |
|
|
def preprocess_input(input_df):
    """Turn raw patient rows into the scaled feature matrix the ensemble expects.

    Replicates the training-time steps:
      1. add derived features (symptom count, age band, WBC/Hb ratio);
      2. align columns to ``FEATURE_NAMES``, creating absent features as NaN;
      3. encode ``cat_columns`` with the fitted ``encoder``, mapping the code of
         the missing-value sentinel back to NaN;
      4. coerce every column to numeric, then apply ``imputer`` and ``scaler``.

    Args:
        input_df: DataFrame of raw inputs. It is copied, never modified.

    Returns:
        DataFrame with columns ``FEATURE_NAMES`` holding imputed, scaled values.
    """
    frame = input_df.copy()

    # --- Derived features ---
    symptom_cols = ['fever', 'chills', 'headache', 'muscle_aches', 'fatigue',
                    'loss_of_appetite', 'jaundice', 'abdominal_pain', 'joint_pain',
                    'splenomegaly', 'pallor', 'lymphadenopathy']
    reported = [name for name in symptom_cols if name in frame.columns]
    frame['symptom_severity_score'] = frame[reported].sum(axis=1)

    if 'age' in frame.columns:
        # Bands (-1, 5], (5, 12], (12, 55], (55, 120] -> 0..3; ages outside the bins become NaN.
        bands = pd.cut(frame['age'], bins=[-1, 5, 12, 55, 120], labels=[0, 1, 2, 3])
        frame['age_group'] = bands.astype(float)

    if 'hb' in frame.columns and 'wbc' in frame.columns:
        # The small epsilon avoids division by zero when hb == 0.
        frame['infection_anemia_ratio'] = frame['wbc'] / (frame['hb'] + 1e-5)

    # --- Align to the training feature order ---
    absent = set(FEATURE_NAMES) - set(frame.columns)
    for name in absent:
        frame[name] = np.nan
    aligned = frame[FEATURE_NAMES].copy()

    # --- Categorical encoding ---
    missing_token = 'MISSING_CAT'
    if cat_columns:
        encodable = [name for name in cat_columns if name in aligned.columns]
        if encodable:
            # Normalise to strings, turn the textual 'nan'/'None' back into real NaN,
            # then give missing cells the sentinel category before encoding.
            as_text = aligned[encodable].astype(str).replace(['nan', 'None'], np.nan)
            aligned[encodable] = as_text.fillna(missing_token)
            aligned[encodable] = encoder.transform(aligned[encodable])

            # Where the encoder knows the sentinel, its code becomes NaN again so the
            # imputer (not the encoder) decides how missing values are filled.
            for position, name in enumerate(cat_columns):
                if name in encodable and missing_token in encoder.categories_[position]:
                    sentinel_code = list(encoder.categories_[position]).index(missing_token)
                    aligned[name] = aligned[name].replace(sentinel_code, np.nan)

    # --- Force numeric dtypes; anything unparsable becomes NaN ---
    for name in aligned.columns:
        aligned[name] = pd.to_numeric(aligned[name], errors='coerce')

    imputed = pd.DataFrame(imputer.transform(aligned), columns=FEATURE_NAMES)
    return pd.DataFrame(scaler.transform(imputed), columns=FEATURE_NAMES)
|
|
def get_specific_coinfection_type(hb, retic, hb_decline, hb_s):
    """Name the co-infection sub-type from the critical lab markers.

    Checks run in priority order and the first match wins:
    Hb < 5.0, then reticulocytes > 8.0, then rapid Hb decline with HbS > 0.
    Each predicate is evaluated lazily, so later inputs (e.g. ``hb_s``) are
    only compared when every earlier check has failed.

    Args:
        hb: Haemoglobin (g/dL).
        retic: Reticulocyte count (%).
        hb_decline: Truthy when a rapid Hb decline was reported.
        hb_s: HbS fraction (%).

    Returns:
        A human-readable "Co-infection: ..." label.
    """
    ordered_checks = (
        (lambda: hb < 5.0,
         "Co-infection: Severe Hyperhemolytic Malarial Crisis"),
        (lambda: retic > 8.0,
         "Co-infection: Acute Hemolytic Malarial Crisis"),
        (lambda: hb_decline and hb_s > 0,
         "Co-infection: Rapidly Progressing Vaso-occlusive Malarial Crisis"),
    )
    for matches, label in ordered_checks:
        if matches():
            return label
    return "Co-infection: Concurrent Malaria & Sickle Cell Crisis"
|
|
def get_clinical_recs(diag, rule_triggered=None):
    """Compose the Markdown clinical-decision-support panel for a diagnosis.

    Args:
        diag: Diagnosis label. 'SCA' is matched exactly. Any label containing
            'Co-infection' gets the urgent protocol. Other labels containing
            'Malaria' get the ACT protocol. Everything else is treated as negative.
        rule_triggered: Optional description of the safety rule that forced the
            diagnosis; rendered as a banner only when truthy.

    Returns:
        Markdown text: a protocol section followed by fixed diagnostic context notes.
    """
    parts = ["### Clinical Decision Support Protocol\n\n"]
    if rule_triggered:
        parts.append(f"**Critical Protocol Triggered:** *{rule_triggered}*\n\n")

    # Co-infection labels also contain the word "Malaria", so exclude them explicitly.
    if 'Malaria' in diag and 'Co-infection' not in diag:
        parts.append("**Protocol:** Initiate Artemisinin-based Combination Therapy (ACT) per WHO guidelines.\n")
    elif diag == 'SCA':
        parts.append("**Protocol:** Administer IV Fluids, oxygen therapy, and comprehensive pain management.\n")
    elif 'Co-infection' in diag:
        parts.extend([
            "**Urgent Protocol:** High risk of hyperhemolytic or severe vaso-occlusive crisis.\n",
            "- **Action:** Immediate admission to high-dependency unit. Initiate rapid intravenous antimalarials, aggressive hydration, and prepare for potential blood transfusion.\n",
        ])
    else:
        parts.extend([
            "**Action:** Patient is currently negative for active Malaria and SCA crisis.\n",
            "- **Follow-up:** Screen for Typhoid, Dengue, or other viral infections if febrile symptoms persist.\n",
        ])

    # Static context notes appended to every recommendation.
    parts.extend([
        "\n---\n### Diagnostic Context Notes\n",
        "- **Overlapping Symptoms:** Fever, Fatigue, Jaundice, Splenomegaly, and Headache *(Headache is uncommon in SCA unless accompanied by severe anemia, cerebral malaria, or stroke risk).* \n",
        "- **Co-infection Prevalences:** Key clinical indicators for Co-infection include Severe Pallor + Jaundice, High fever, Splenomegaly + malaria, and Extreme Reticulocyte (>8%) + malaria.",
    ])
    return "".join(parts)
|
|
def generate_shap_plot(X_scaled):
    """Render a SHAP waterfall plot of the Co-infection contributions for the first row.

    Args:
        X_scaled: Output of ``preprocess_input`` (DataFrame with ``FEATURE_NAMES``
            columns). Only row 0 is explained.

    Returns:
        A matplotlib Figure: the waterfall plot, or a placeholder figure showing
        the error text if anything fails. Never raises.
    """
    coinfection_idx = 3  # position of 'Co-infection' in target_names
    fig = None
    try:
        shap_values = explainer.shap_values(X_scaled)

        if isinstance(shap_values, list):
            # Older SHAP releases: one (n_samples, n_features) array per class.
            pat_shap = shap_values[coinfection_idx][0]
        elif len(shap_values.shape) == 3:
            # Newer SHAP releases: one (n_samples, n_features, n_classes) array.
            pat_shap = shap_values[0, :, coinfection_idx]
        else:
            # Single-output explainer: (n_samples, n_features).
            pat_shap = shap_values[0]

        # expected_value is a scalar for single-output explainers but a per-class
        # list *or ndarray* for multiclass ones (recent SHAP returns an ndarray).
        # The waterfall plot needs exactly one scalar base value.
        base_vals = np.ravel(explainer.expected_value)
        base_val = base_vals[coinfection_idx] if base_vals.size > 1 else base_vals[0]

        fig, ax = plt.subplots(figsize=(7, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        explanation = shap.Explanation(values=pat_shap, base_values=float(base_val),
                                       data=X_scaled.iloc[0], feature_names=FEATURE_NAMES)
        shap.waterfall_plot(explanation, show=False)
        plt.title("XAI Feature Contribution (Impact on Co-Infection Risk)", fontsize=11, fontweight='bold')
        plt.tight_layout()
        # Drop the figure from pyplot's global registry so each request does not leak
        # a figure for the life of the server. The Figure object stays valid for
        # gr.Plot to serialise. NOTE(review): verify on the deployed Gradio version.
        plt.close(fig)
        return fig
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # discard the partially drawn figure
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.text(0.5, 0.5, f"Interpretability Module Offline:\n{str(e)}", ha='center', va='center')
        plt.close(fig)
        return fig
|
|
def _evaluate_critical_rules(hb, reticulocyte, hb_rapid_decline, malaria_rdt, hb_s):
    """Return the first hard-coded co-infection safety rule that fires, or "" if none do.

    Rules are checked in priority order; once a rule fires, the inputs of later
    rules are not compared at all.
    """
    # NOTE(review): the first rule fires on Hb alone (no positive RDT or HbS
    # required), yet the resulting sub-type label says "Malarial" — confirm this is
    # clinically intended.
    if hb < 5.0:
        return "Hemoglobin below critical threshold (5.0 g/dL)"
    if reticulocyte > 8.0 and malaria_rdt == "Positive":
        return "Extreme Reticulocyte (>8%) + Positive Malaria RDT"
    if hb_rapid_decline and malaria_rdt == "Positive" and hb_s > 0:
        return "Rapid Hb decline (>1.5g/dL in 48h) + Positive Malaria + SCA Genotype"
    return ""


def _format_confidence_breakdown(primary_diag, prob_dict):
    """Render the diagnosis header plus one Markdown bullet per class, highest confidence first."""
    lines = [f"## Primary Diagnosis: {primary_diag}\n\n### Comprehensive Confidence Breakdown:\n"]
    for disease, conf in sorted(prob_dict.items(), key=lambda item: item[1], reverse=True):
        lines.append(f"- **{disease}**: {conf:.1f}%\n")
    return "".join(lines)


def manual_inference(age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, malaria_rdt, reticulocyte, hb_rapid_decline,
                     fever, chills, headache, muscle_aches, fatigue, loss_of_appetite, jaundice, abdominal_pain, joint_pain, splenomegaly, pallor, lymphadenopathy):
    """Gradio handler for the "Validate Diagnosis" button.

    First applies the hard-coded co-infection safety rules, then runs the ensemble
    on the preprocessed inputs. When a rule fires, its specific co-infection
    sub-type overrides the model's arg-max prediction.

    Args:
        The form values, in the order of ``input_components``:
        numeric vitals/labs, ``sex`` ("Male"/"Female"),
        ``malaria_rdt`` ("Positive"/"Negative"), and truthy/falsy values for the
        Hb-decline flag and the 12 symptom checkboxes.

    Returns:
        ``(diagnosis_markdown, recommendations_markdown, matplotlib Figure)``.
        On any error: ``(markdown with traceback, "System Error.", None)``.
    """
    try:
        rule_triggered = _evaluate_critical_rules(hb, reticulocyte, hb_rapid_decline, malaria_rdt, hb_s)
        co_infection_flag = bool(rule_triggered)
        specific_coinfection_name = ""
        if co_infection_flag:
            specific_coinfection_name = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)

        symptoms = {
            'fever': fever, 'chills': chills, 'headache': headache,
            'muscle_aches': muscle_aches, 'fatigue': fatigue,
            'loss_of_appetite': loss_of_appetite, 'jaundice': jaundice,
            'abdominal_pain': abdominal_pain, 'joint_pain': joint_pain,
            'splenomegaly': splenomegaly, 'pallor': pallor,
            'lymphadenopathy': lymphadenopathy,
        }
        input_data = pd.DataFrame({
            'age': [age], 'sex': [sex], 'temp': [temp], 'hb': [hb], 'wbc': [wbc], 'platelets': [platelets],
            'hb_a': [hb_a], 'hb_s': [hb_s], 'hb_f': [hb_f],
            'malaria_rdt': [1.0 if malaria_rdt == "Positive" else 0.0],
            'reticulocyte': [reticulocyte],
            'hb_rapid_decline': [1.0 if hb_rapid_decline else 0.0],
            # Checkbox symptoms are encoded as 1.0 (present) / 0.0 (absent).
            **{name: [1.0 if present else 0.0] for name, present in symptoms.items()},
        })

        X_scaled = preprocess_input(input_data)
        probs = best_model.predict_proba(X_scaled)[0]
        prob_dict = {name: probs[i] * 100 for i, name in enumerate(target_names)}

        if co_infection_flag:
            primary_diag = specific_coinfection_name
            # Safety-rule override: the rule's sub-type is shown at 100%, next to the
            # model's own Malaria/SCA scores. NOTE(review): the displayed percentages
            # therefore no longer sum to 100 in this branch.
            prob_dict = {
                specific_coinfection_name: 100.0,
                'Malaria (Override)': prob_dict['Malaria'],
                'SCA (Override)': prob_dict['SCA'],
                'Negative': 0.0,
            }
        else:
            primary_diag = target_names[np.argmax(probs)]
            if primary_diag == 'Co-infection':
                # Refine the generic class into a specific sub-type, keeping its score.
                primary_diag = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)
                prob_dict[primary_diag] = prob_dict.pop('Co-infection')

        diag_output = _format_confidence_breakdown(primary_diag, prob_dict)
        recs = get_clinical_recs(primary_diag, rule_triggered)
        fig = generate_shap_plot(X_scaled)
        return diag_output, recs, fig
    except Exception:
        return f"### Inference Error\n```\n{traceback.format_exc()}\n```", "System Error.", None
|
|
| |
| |
| |
|
|
def load_systematic_metrics():
    """Summarise held-out performance as a Markdown block.

    Loads ``y_test_val.pkl`` (true labels) and ``y_probs_val.pkl`` (a per-class
    probability matrix, one column per class) with joblib from the working
    directory. The arg-max column is taken as the prediction. Reports overall
    accuracy plus one-vs-rest sensitivity and specificity, macro-averaged over
    the classes in the confusion matrix.

    Returns:
        Markdown text with the metrics, or an error message. Never raises.
    """
    try:
        y_test_val = joblib.load('y_test_val.pkl')
        y_probs_val = joblib.load('y_probs_val.pkl')
        y_pred_val = np.argmax(y_probs_val, axis=1)

        acc = accuracy_score(y_test_val, y_pred_val)
        # Rows are true classes, columns predicted classes (sorted union of the
        # labels seen in either array).
        cm = confusion_matrix(y_test_val, y_pred_val)

        # One-vs-rest counts for every class at once.
        tp = np.diag(cm)
        fn = cm.sum(axis=1) - tp
        fp = cm.sum(axis=0) - tp
        tn = cm.sum() - tp - fn - fp

        def _ratio_or_zero(num, den):
            # Per-class num/den, with 0 wherever the denominator is empty.
            return np.divide(num, den, out=np.zeros(len(num), dtype=float), where=den > 0)

        sens = np.mean(_ratio_or_zero(tp, tp + fn))
        spec = np.mean(_ratio_or_zero(tn, tn + fp))

        return f"### Systematic Evaluation Metrics (Held-out Cohort)\n\n- **Overall Accuracy**: {acc*100:.2f}%\n- **Sensitivity (Macro)**: {sens*100:.2f}%\n- **Specificity (Macro)**: {spec*100:.2f}%"
    except Exception as e:
        # The pickles are read from disk, so point the user at the working directory.
        return f"Error loading validation metrics: Ensure 'y_test_val.pkl' and 'y_probs_val.pkl' exist in the working directory. \n({str(e)})"
|
|
def check_calibration(class_name):
    """Plot a reliability (calibration) curve for one class on the held-out cohort.

    Args:
        class_name: One of ``target_names``. Its index selects both the positive
            label in ``y_test_val.pkl`` and the probability column in ``y_probs_val.pkl``.

    Returns:
        A matplotlib Figure with the curve, or a placeholder figure showing the
        error text if anything fails (missing files, unknown class, ...). Never raises.
    """
    fig = None
    try:
        from sklearn.calibration import CalibrationDisplay
        # Coerce to arrays so list-typed pickles support == and 2-D column slicing.
        y_test_val = np.asarray(joblib.load('y_test_val.pkl'))
        y_probs_val = np.asarray(joblib.load('y_probs_val.pkl'))
        class_idx = target_names.index(class_name)

        # One-vs-rest view: the selected class is the positive label.
        y_true_binary = (y_test_val == class_idx).astype(int)
        y_prob_class = y_probs_val[:, class_idx]

        fig, ax = plt.subplots(figsize=(6, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        CalibrationDisplay.from_predictions(y_true_binary, y_prob_class, n_bins=10, ax=ax, name=class_name)
        plt.title(f"Reliability Curve (Calibration) for {class_name}", fontweight='bold')
        plt.tight_layout()
        # Release pyplot's global reference so repeated clicks do not accumulate figures.
        plt.close(fig)
        return fig
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # discard the partially drawn figure
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, f"Calibration Error:\n{str(e)}", ha='center')
        plt.close(fig)
        return fig
|
|
| |
| |
| |
|
|
# Monochrome base theme with slate/gray accents. The Inter Google Font comes first,
# with generic sans-serif stacks as fallbacks.
custom_theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Inter"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    primary_hue="slate",
    secondary_hue="gray",
)
|
|
| |
# Predefined cases for gr.Examples. Each row follows the input_components order.
# First line of each row (12 values):
#   age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, malaria_rdt, reticulocyte, hb_rapid_decline
# Second line (12 symptom checkboxes):
#   fever, chills, headache, muscle_aches, fatigue, loss_of_appetite,
#   jaundice, abdominal_pain, joint_pain, splenomegaly, pallor, lymphadenopathy
clinical_examples = [
    # RDT positive, HbS 0%, Hb 11.5, temp 39.5
    [8, "Male", 39.5, 11.5, 9.5, 150, 98.0, 0.0, 2.0, "Positive", 1.5, False,
     True, True, True, True, True, True, False, False, False, False, False, False],
    # RDT positive, HbS 0%, Hb 7.5, platelets 90
    [22, "Female", 39.0, 7.5, 12.0, 90, 95.0, 0.0, 2.0, "Positive", 4.0, False,
     True, True, True, True, True, True, True, False, False, True, True, False],
    # RDT negative, HbS 85%, Hb 8.0, retic 6.0
    [15, "Male", 37.2, 8.0, 11.0, 250, 5.0, 85.0, 10.0, "Negative", 6.0, False,
     False, False, False, True, True, False, True, True, True, False, True, False],
    # RDT negative, HbS 90%, Hb 4.5, retic 10.0, rapid Hb decline
    [18, "Female", 37.5, 4.5, 14.0, 300, 2.0, 90.0, 8.0, "Negative", 10.0, True,
     False, False, False, False, True, False, True, False, True, True, True, False],
    # RDT positive, HbS 80%, Hb 6.5, retic 9.5
    [12, "Male", 38.8, 6.5, 16.0, 110, 10.0, 80.0, 10.0, "Positive", 9.5, False,
     True, True, True, True, True, True, True, True, True, True, True, False],
    # RDT positive, HbS 85%, Hb 7.0, rapid Hb decline
    [25, "Female", 39.2, 7.0, 15.0, 100, 5.0, 85.0, 10.0, "Positive", 5.0, True,
     True, True, True, True, True, True, True, False, True, True, True, False],
    # RDT negative, HbS 0%, Hb 14.0, temp 36.8, no symptoms
    [30, "Male", 36.8, 14.0, 6.5, 250, 98.0, 0.0, 2.0, "Negative", 1.0, False,
     False, False, False, False, False, False, False, False, False, False, False, False],
    # RDT negative, HbS 0%, Hb 13.5, temp 37.8
    [45, "Female", 37.8, 13.5, 5.0, 210, 97.0, 0.0, 2.0, "Negative", 1.2, False,
     True, False, True, True, True, False, False, False, False, False, False, True],
    # RDT positive, HbS 0%, Hb 6.0, platelets 80
    [10, "Male", 39.8, 6.0, 18.0, 80, 95.0, 0.0, 3.0, "Positive", 7.0, False,
     True, True, True, False, True, True, True, True, False, True, True, False],
    # RDT negative, HbA 60% / HbS 38%, Hb 12.5, no symptoms
    [28, "Female", 37.0, 12.5, 7.0, 220, 60.0, 38.0, 2.0, "Negative", 1.5, False,
     False, False, False, False, False, False, False, False, False, False, False, False],
]
|
|
# --- Dashboard layout ---
# Gradio builds the page from the order in which components are created inside
# these nested `with` contexts, so statement order here determines the layout.
with gr.Blocks(theme=custom_theme, title="Hemaclass Clinical Dashboard") as demo:
    gr.Markdown("# Hemaclass Clinical Decision Support System")
    gr.Markdown("Deep Stacking Ensemble Model for Malaria and Sickle Cell Anemia Classification.")

    with gr.Tabs():
        # Tab 1: enter one patient's data by hand and run manual_inference on it.
        with gr.TabItem("Single Patient Validation"):
            with gr.Row():
                # Left column: input form.
                with gr.Column(scale=1):
                    gr.Markdown("### Demographics & Vitals")
                    with gr.Row():
                        age_in = gr.Number(label="Age", value=25)
                        sex_in = gr.Dropdown(["Male", "Female"], label="Sex", value="Female")
                        temp_in = gr.Number(label="Temperature (°C)", value=37.5)

                    gr.Markdown("### Clinical Symptoms")
                    with gr.Row():
                        fever_in = gr.Checkbox(label="Fever")
                        chills_in = gr.Checkbox(label="Chills")
                        headache_in = gr.Checkbox(label="Headache")
                        fatigue_in = gr.Checkbox(label="Fatigue")
                    with gr.Row():
                        jaundice_in = gr.Checkbox(label="Jaundice")
                        splenomegaly_in = gr.Checkbox(label="Splenomegaly")
                        pallor_in = gr.Checkbox(label="Severe Pallor")
                        muscle_in = gr.Checkbox(label="Muscle Aches")
                    # Less common symptoms are collapsed by default.
                    with gr.Accordion("Additional Symptoms", open=False):
                        loss_appetite_in = gr.Checkbox(label="Loss of Appetite")
                        abd_pain_in = gr.Checkbox(label="Abdominal Pain")
                        joint_pain_in = gr.Checkbox(label="Joint Pain")
                        lymph_in = gr.Checkbox(label="Lymphadenopathy")

                    # RDT, reticulocytes, Hb, Hb decline and HbS feed the hard-coded
                    # co-infection rules in manual_inference as well as the model.
                    gr.Markdown("### Critical Laboratory Markers")
                    with gr.Row():
                        rdt_in = gr.Radio(["Negative", "Positive"], label="Malaria RDT", value="Negative")
                        retic_in = gr.Number(label="Reticulocyte Count (%)", value=2.0)
                    with gr.Row():
                        hb_in = gr.Number(label="Hemoglobin (g/dL)", value=12.0)
                        hb_decline_in = gr.Checkbox(label="Rapid Hb Decline (>1.5g/dl in 48h)")
                    with gr.Row():
                        hb_a_in = gr.Number(label="HbA Fraction (%)", value=98.0)
                        hb_s_in = gr.Number(label="HbS Fraction (%)", value=0.0)
                        hb_f_in = gr.Number(label="HbF Fraction (%)", value=2.0)
                    with gr.Row():
                        wbc_in = gr.Number(label="WBC Count (x10^9/L)", value=8.0)
                        platelets_in = gr.Number(label="Platelet Count", value=200)

                    manual_btn = gr.Button("Validate Diagnosis", variant="primary", size="lg")

                # Right column: the three outputs returned by manual_inference.
                with gr.Column(scale=1):
                    gr.Markdown("### System Output")
                    out_diag = gr.Markdown()
                    out_recs = gr.Markdown()
                    out_shap = gr.Plot(label="Feature Contribution Analysis")

            gr.Markdown("---")
            gr.Markdown("### Load Clinical Scenarios")
            gr.Markdown("Select a predefined clinical case to auto-populate the diagnostic fields.")

            # Order MUST match manual_inference's positional parameters and the
            # column order of every row in clinical_examples.
            input_components = [
                age_in, sex_in, temp_in, hb_in, wbc_in, platelets_in, hb_a_in, hb_s_in, hb_f_in,
                rdt_in, retic_in, hb_decline_in, fever_in, chills_in, headache_in, muscle_in,
                fatigue_in, loss_appetite_in, jaundice_in, abd_pain_in, joint_pain_in,
                splenomegaly_in, pallor_in, lymph_in
            ]

            # No fn is passed, so selecting an example only fills in the inputs;
            # the user still presses "Validate Diagnosis" to run inference.
            gr.Examples(
                examples=clinical_examples,
                inputs=input_components,
                label="Predefined Patient Cases"
            )

            manual_btn.click(
                manual_inference,
                inputs=input_components,
                outputs=[out_diag, out_recs, out_shap]
            )

        # Tab 2: aggregate metrics computed from the saved held-out predictions.
        with gr.TabItem("Systematic Testing"):
            gr.Markdown("### Overall Model Performance on Unseen Test Cohort")
            metrics_btn = gr.Button("Calculate Systematic Metrics", variant="secondary")
            out_metrics = gr.Markdown()
            metrics_btn.click(load_systematic_metrics, inputs=[], outputs=[out_metrics])

        # Tab 3: per-class reliability curve from the same held-out predictions.
        with gr.TabItem("Advanced Validation"):
            gr.Markdown("### Evaluate Diagnosis Calibration")
            gr.Markdown("Select a disease class below to verify the alignment between predicted probabilities and true clinical frequencies.")
            with gr.Row():
                class_dropdown = gr.Dropdown(target_names, label="Select Target Class", value="Co-infection")
                calib_btn = gr.Button("Check Calibration", variant="secondary")
            out_calib = gr.Plot()
            calib_btn.click(check_calibration, inputs=[class_dropdown], outputs=[out_calib])
|
| |
if __name__ == "__main__":
    import os

    # share=True asks Gradio for a publicly shareable link. That remains the default,
    # but a deployment handling real patient data can opt out without editing code
    # by setting HEMACLASS_SHARE to 0 / false / no / off.
    share_link = os.environ.get("HEMACLASS_SHARE", "1").strip().lower() not in {"0", "false", "no", "off"}
    demo.launch(share=share_link)