MABONGALABS committed on
Commit
ef0a1e4
·
verified ·
1 Parent(s): 25eaacd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +318 -159
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import pandas as pd
4
  import numpy as np
@@ -6,34 +5,61 @@ import joblib
6
  import shap
7
  import matplotlib
8
  import traceback
 
 
9
 
10
- # FORCE AGG BACKEND: Prevents thread-safety crashes and `main thread not in main loop` GUI errors in Hugging Face Spaces
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
 
14
  # ==========================================
15
- # 1. LOAD PRODUCTION ASSETS
16
  # ==========================================
 
17
  try:
18
- model = joblib.load('ensemble_model.pkl')
19
  scaler = joblib.load('scaler.pkl')
20
  imputer = joblib.load('imputer.pkl')
21
  encoder = joblib.load('encoder.pkl')
22
- FEATURES = joblib.load('feature_names.pkl')
23
  cat_columns = joblib.load('cat_columns.pkl')
24
- target_names = ['Negative', 'Malaria', 'SCA', 'Co-infection']
25
- xgb_base = model.named_estimators_['xgb']
 
26
  explainer = shap.TreeExplainer(xgb_base)
 
27
  except Exception as e:
28
- raise RuntimeError(f"Error loading model artifacts. Ensure all .pkl files are uploaded to the Hugging Face Space. Details: {e}")
 
 
29
 
30
  # ==========================================
31
- # 2. CORE INFERENCE LOGIC (Phase 4 Deliverable)
32
  # ==========================================
33
- def process_and_predict(input_df):
34
- for c in set(FEATURES) - set(input_df.columns): input_df[c] = np.nan
35
- df_aligned = input_df[FEATURES].copy()
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  MISSING_STR = 'MISSING_CAT'
38
  if cat_columns:
39
  present_cats = [c for c in cat_columns if c in df_aligned.columns]
@@ -41,199 +67,332 @@ def process_and_predict(input_df):
41
  df_aligned[present_cats] = df_aligned[present_cats].astype(str).replace(['nan', 'None'], np.nan)
42
  df_aligned[present_cats] = df_aligned[present_cats].fillna(MISSING_STR)
43
  df_aligned[present_cats] = encoder.transform(df_aligned[present_cats])
 
44
  for i, col in enumerate(cat_columns):
45
  if col in present_cats and MISSING_STR in encoder.categories_[i]:
46
  missing_code = list(encoder.categories_[i]).index(MISSING_STR)
47
  df_aligned[col] = df_aligned[col].replace(missing_code, np.nan)
48
-
49
  for col in df_aligned.columns:
50
- try:
51
- df_aligned[col] = pd.to_numeric(df_aligned[col])
52
- except Exception:
53
- pass
54
-
55
- X_imp = pd.DataFrame(imputer.transform(df_aligned), columns=FEATURES)
56
- X_scaled = pd.DataFrame(scaler.transform(X_imp), columns=FEATURES)
57
 
58
- preds, probs = model.predict(X_scaled), model.predict_proba(X_scaled)
59
- input_df['AI_Diagnosis'] = [target_names[p] for p in preds]
60
- input_df['Confidence'] = [f"{max(pr)*100:.1f}%" for pr in probs]
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- return input_df, X_scaled, preds, probs, X_imp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- def generate_shap_plot(X_scaled, preds, index=0):
65
  try:
66
  shap_values = explainer.shap_values(X_scaled)
67
- pred_idx = preds[index]
68
  if isinstance(shap_values, list):
69
- pat_shap = shap_values[pred_idx][index]
70
- base_val = explainer.expected_value[pred_idx]
 
 
 
71
  else:
72
- pat_shap = shap_values[index, :, pred_idx] if len(shap_values.shape)==3 else shap_values[index]
73
- base_val = explainer.expected_value[pred_idx] if isinstance(explainer.expected_value, list) else explainer.expected_value
74
 
75
- fig, ax = plt.subplots(figsize=(6, 4))
 
 
 
76
  explanation = shap.Explanation(values=pat_shap, base_values=base_val,
77
- data=X_scaled.iloc[index], feature_names=FEATURES)
78
  shap.waterfall_plot(explanation, show=False)
79
- plt.title(f"SHAP Force/Waterfall Explanation for Patient")
80
  plt.tight_layout()
81
  return fig
82
  except Exception as e:
83
  fig, ax = plt.subplots(figsize=(6,4))
84
- ax.text(0.5, 0.5, f"SHAP Error: {str(e)}", ha='center', va='center', color='red')
85
  return fig
86
 
87
- def extract_shap_dict(X_scaled, preds, index=0):
 
88
  try:
89
- shap_values = explainer.shap_values(X_scaled)
90
- pred_idx = preds[index]
91
- if isinstance(shap_values, list):
92
- pat_shap = shap_values[pred_idx][index]
93
- else:
94
- pat_shap = shap_values[index, :, pred_idx] if len(shap_values.shape)==3 else shap_values[index]
95
- return {feat: impact for feat, impact in zip(FEATURES, pat_shap)}
96
- except:
97
- return None
98
-
99
- def get_clinical_recs(diag, raw_data, shap_values_dict=None):
100
- base_str = ""
101
- if shap_values_dict:
102
- sorted_feats = sorted(shap_values_dict.items(), key=lambda item: abs(item[1]), reverse=True)[:3]
103
- drivers = ", ".join([f"{feat.upper()} ({impact:+.2f})" for feat, impact in sorted_feats])
104
- base_str = f"**Interpretability Output:** Prediction: {diag}. Key drivers: {drivers}\n\n"
105
-
106
- recs = base_str + f"**Clinical Decision Support for:** {diag}\n\n"
107
- if diag == 'Malaria':
108
- recs += "- **Protocol:** Initiate Artemisinin-based Combination Therapy (ACT).\n"
109
- if raw_data.get('temp', 0) > 38.0: recs += "- **Vitals Alert:** High Fever. Administer antipyretics.\n"
110
- if raw_data.get('hb', 12) < 8.0: recs += "- **Lab Alert:** Severe Anemia present. Prepare for blood transfusion review.\n"
111
- elif diag == 'SCA':
112
- recs += "- **Protocol:** Administer IV Fluids, oxygen therapy, and pain management.\n"
113
- if raw_data.get('hb_s', 0) > 30: recs += "- **Lab Alert:** High HbS detected. Review Hydroxyurea therapy candidacy.\n"
114
- elif diag == 'Co-infection':
115
- recs += "- **URGENT PROTOCOL:** High risk of hyperhemolytic crisis.\n"
116
- recs += "- **Action:** Admit to high-dependency unit. Initiate rapid antimalarials and aggressive hydration.\n"
117
- else:
118
- recs += "- **Action:** Negative for Malaria and SCA.\n"
119
- recs += "- **Follow-up:** Screen for Typhoid, Dengue, or viral infections if symptoms persist.\n"
120
- return recs
121
 
122
- # ==========================================
123
- # 3. GRADIO EVENT HANDLERS
124
- # ==========================================
125
- def manual_inference(age, sex, temp, hb, malaria_rdt, hb_s, wbc, platelets, fever, headache, jaundice):
126
- try:
127
  input_data = pd.DataFrame({
128
- 'age': [age], 'sex': [sex], 'temp': [temp], 'hb': [hb],
129
- 'malaria_rdt': [1.0 if malaria_rdt == "Positive" else 0.0],
130
- 'hb_s': [hb_s], 'wbc': [wbc], 'platelets': [platelets],
131
- 'fever': [1.0 if fever else 0.0], 'headache': [1.0 if headache else 0.0], 'jaundice': [1.0 if jaundice else 0.0]
 
 
 
 
 
 
132
  })
 
 
 
133
 
134
- res_df, X_scaled, preds, probs, X_imp = process_and_predict(input_data)
135
-
136
- diag = res_df['AI_Diagnosis'].iloc[0]
137
- conf = res_df['Confidence'].iloc[0]
138
- raw_data = X_imp.iloc[0]
139
 
140
- shap_dict = extract_shap_dict(X_scaled, preds, 0)
141
- recs = get_clinical_recs(diag, raw_data, shap_dict)
142
- fig = generate_shap_plot(X_scaled, preds, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- return f"### **AI Diagnosis:** {diag}\n**Confidence:** {conf}", recs, fig
 
 
 
 
 
 
 
 
 
 
 
145
  except Exception as e:
146
- error_msg = f"### Inference Error\n```\n{traceback.format_exc()}\n```"
147
- return error_msg, "System Error. Please review inputs.", None
148
 
149
- def batch_inference(file_path):
 
 
 
 
150
  try:
151
- if file_path is None:
152
- return pd.DataFrame(), "### ⚠️ Upload Error: Please upload a valid CSV file.", "", None
153
-
154
- file_path_str = file_path if isinstance(file_path, str) else file_path.name
155
- df = pd.read_csv(file_path_str)
 
156
 
157
- if df.empty:
158
- return df, "### ⚠️ Data Error: The uploaded CSV is empty.", "", None
 
 
 
 
 
 
159
 
160
- res_df, X_scaled, preds, probs, X_imp = process_and_predict(df)
161
-
162
- diag = res_df['AI_Diagnosis'].iloc[0]
163
- conf = res_df['Confidence'].iloc[0]
164
- raw_data = X_imp.iloc[0]
165
-
166
- shap_dict = extract_shap_dict(X_scaled, preds, 0)
167
- recs = get_clinical_recs(diag, raw_data, shap_dict)
168
- fig = generate_shap_plot(X_scaled, preds, 0)
169
-
170
- report_text = f"**Batch Processing Complete ({len(res_df)} records analyzed)**\n\n**Deep Dive Analysis (Patient 1):**\n**Diagnosis:** {diag} ({conf})"
171
-
172
- display_cols = ['AI_Diagnosis', 'Confidence'] + [c for c in FEATURES if c in res_df.columns]
173
- return res_df[display_cols].head(15), report_text, recs, fig
 
 
 
 
 
 
 
 
 
 
174
  except Exception as e:
175
- error_msg = f"### ❌ Batch Error\n```\n{traceback.format_exc()}\n```"
176
- return pd.DataFrame(), error_msg, "Failed to process batch. Ensure CSV schema matches training format.", None
 
177
 
178
  # ==========================================
179
- # 4. HUGGING FACE GRADIO UI CONSTRUCTION
180
  # ==========================================
181
- with gr.Blocks(theme=gr.themes.Monochrome(), title="Hemaclass XAI Dashboard") as demo:
182
- gr.Markdown("# 🏥 Hemaclass Clinical Decision Support Dashboard")
183
- gr.Markdown("Phase 4 Prototype: Explainable AI-Driven Ensemble Model for Malaria and Sickle Cell Anemia Classification in Western Kenya.")
184
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Tabs():
186
- with gr.TabItem("🩺 Single Patient Evaluation"):
 
187
  with gr.Row():
188
  with gr.Column(scale=1):
189
- gr.Markdown("### Patient Demographics & Vitals")
190
- age_in = gr.Number(label="Patient Age", value=25)
191
- sex_in = gr.Dropdown(["Male", "Female"], label="Sex", value="Female")
192
- temp_in = gr.Slider(minimum=34.0, maximum=42.0, value=37.5, label="Body Temperature (°C)")
193
-
194
- gr.Markdown("### Critical Labs")
195
- hb_in = gr.Number(label="Hemoglobin (Hb) g/dL", value=12.0)
196
- rdt_in = gr.Radio(["Negative", "Positive"], label="Malaria RDT Result", value="Negative")
197
- hbs_in = gr.Slider(minimum=0.0, maximum=100.0, value=0.0, label="Hb S Fraction (%)")
198
- wbc_in = gr.Number(label="WBC Count (x10^9/L)", value=8.0)
199
- platelets_in = gr.Number(label="Platelets (x10^9/L)", value=200)
200
-
201
  gr.Markdown("### Clinical Symptoms")
202
- fever_in = gr.Checkbox(label="Fever")
203
- headache_in = gr.Checkbox(label="Headache")
204
- jaundice_in = gr.Checkbox(label="Jaundice")
205
-
206
- manual_btn = gr.Button("Evaluate Patient", variant="primary")
207
-
208
- with gr.Column(scale=2):
209
- gr.Markdown("### AI Prediction & Explainability")
210
  with gr.Row():
211
- out_diag = gr.Markdown()
 
 
 
212
  with gr.Row():
213
- out_recs = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
214
  with gr.Row():
215
- out_shap = gr.Plot(label="Local Interpretability (SHAP Waterfall)")
216
-
217
- manual_btn.click(manual_inference, inputs=[age_in, sex_in, temp_in, hb_in, rdt_in, hbs_in, wbc_in, platelets_in, fever_in, headache_in, jaundice_in], outputs=[out_diag, out_recs, out_shap])
218
-
219
- with gr.TabItem("📂 Retrospective Batch Evaluation"):
220
- gr.Markdown("Upload anonymized, cleaned clinical data (CSV) for batch processing and population-level analysis.")
221
- file_in = gr.File(label="Upload Cleaned Clinical Dataset", type="filepath", file_types=[".csv"])
222
- batch_btn = gr.Button("Run Batch Inference", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- out_df = gr.Dataframe(label="Evaluation Results (Preview Top 15 rows)")
 
 
 
 
 
225
 
226
- gr.Markdown("---")
227
- gr.Markdown("### Patient 1 XAI Deep Dive Analysis")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  with gr.Row():
229
- with gr.Column():
230
- batch_diag = gr.Markdown()
231
- batch_recs = gr.Markdown()
232
- with gr.Column():
233
- batch_shap = gr.Plot(label="Patient 1 SHAP Interpretation")
234
-
235
- batch_btn.click(batch_inference, inputs=file_in, outputs=[out_df, batch_diag, batch_recs, batch_shap])
236
 
 
237
  if __name__ == "__main__":
238
- demo.launch(server_name="0.0.0.0", server_port=7860)
239
-
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
 
5
  import shap
6
  import matplotlib
7
  import traceback
8
+ import warnings
9
+ from sklearn.metrics import accuracy_score, confusion_matrix
10
 
11
+ warnings.filterwarnings('ignore')
12
  matplotlib.use('Agg')
13
  import matplotlib.pyplot as plt
14
 
15
  # ==========================================
16
+ # 1. LOAD TRAINED ARTIFACTS FROM COLAB MEMORY
17
  # ==========================================
18
+ print("Loading Model Artifacts...")
19
  try:
20
+ best_model = joblib.load('ensemble_model.pkl')
21
  scaler = joblib.load('scaler.pkl')
22
  imputer = joblib.load('imputer.pkl')
23
  encoder = joblib.load('encoder.pkl')
24
+ FEATURE_NAMES = joblib.load('feature_names.pkl')
25
  cat_columns = joblib.load('cat_columns.pkl')
26
+
27
+ # Extract XGBoost from StackingClassifier for SHAP explainability
28
+ xgb_base = best_model.named_estimators_['xgb']
29
  explainer = shap.TreeExplainer(xgb_base)
30
+ print("All artifacts loaded successfully.")
31
  except Exception as e:
32
+ print(f"Error loading artifacts: {e}. Ensure the training script ran successfully.")
33
+
34
+ target_names = ['Negative', 'Malaria', 'SCA', 'Co-infection']
35
 
36
  # ==========================================
37
+ # 2. CORE PROCESSING & PREDICTION LOGIC
38
  # ==========================================
39
+
40
+ def preprocess_input(input_df):
41
+ """Replicates the exact Feature Engineering & Preprocessing from Training"""
42
+ df = input_df.copy()
43
+
44
+ # Feature Engineering
45
+ symptom_cols = ['fever', 'chills', 'headache', 'muscle_aches', 'fatigue',
46
+ 'loss_of_appetite', 'jaundice', 'abdominal_pain', 'joint_pain',
47
+ 'splenomegaly', 'pallor', 'lymphadenopathy']
48
+
49
+ df['symptom_severity_score'] = df[[c for c in symptom_cols if c in df.columns]].sum(axis=1)
50
 
51
+ if 'age' in df.columns:
52
+ df['age_group'] = pd.cut(df['age'], bins=[-1, 5, 12, 55, 120], labels=[0, 1, 2, 3]).astype(float)
53
+
54
+ if 'hb' in df.columns and 'wbc' in df.columns:
55
+ df['infection_anemia_ratio'] = df['wbc'] / (df['hb'] + 1e-5)
56
+
57
+ # Align with model input shapes
58
+ for c in set(FEATURE_NAMES) - set(df.columns):
59
+ df[c] = np.nan
60
+ df_aligned = df[FEATURE_NAMES].copy()
61
+
62
+ # Categorical Encoding
63
  MISSING_STR = 'MISSING_CAT'
64
  if cat_columns:
65
  present_cats = [c for c in cat_columns if c in df_aligned.columns]
 
67
  df_aligned[present_cats] = df_aligned[present_cats].astype(str).replace(['nan', 'None'], np.nan)
68
  df_aligned[present_cats] = df_aligned[present_cats].fillna(MISSING_STR)
69
  df_aligned[present_cats] = encoder.transform(df_aligned[present_cats])
70
+
71
  for i, col in enumerate(cat_columns):
72
  if col in present_cats and MISSING_STR in encoder.categories_[i]:
73
  missing_code = list(encoder.categories_[i]).index(MISSING_STR)
74
  df_aligned[col] = df_aligned[col].replace(missing_code, np.nan)
75
+
76
  for col in df_aligned.columns:
77
+ df_aligned[col] = pd.to_numeric(df_aligned[col], errors='coerce')
78
+
79
+ # Impute and Scale
80
+ X_imp = pd.DataFrame(imputer.transform(df_aligned), columns=FEATURE_NAMES)
81
+ X_scaled = pd.DataFrame(scaler.transform(X_imp), columns=FEATURE_NAMES)
 
 
82
 
83
+ return X_scaled
84
+
85
def get_specific_coinfection_type(hb, retic, hb_decline, hb_s):
    """Return the granular co-infection sub-type label for a patient.

    The markers are checked in priority order and the first match wins:

    1. ``hb`` below 5.0            -> severe hyperhemolytic crisis
    2. ``retic`` above 8.0         -> acute hemolytic crisis
    3. ``hb_decline`` and ``hb_s`` above 0 -> rapidly progressing crisis
    4. otherwise                   -> generic concurrent crisis

    Args:
        hb: Hemoglobin value, or ``None`` when not provided.
        retic: Reticulocyte value, or ``None`` when not provided.
        hb_decline: Truthy when a rapid hemoglobin decline was flagged.
        hb_s: HbS fraction, or ``None`` when not provided.

    Returns:
        A display string beginning with ``"Co-infection: "``.
    """
    # A numeric field may arrive as None (e.g. an emptied form input).
    # Comparing None with a float raises TypeError, so a missing marker is
    # treated as "not measured" and its rule is simply skipped.
    if hb is not None and hb < 5.0:
        return "Co-infection: Severe Hyperhemolytic Malarial Crisis"
    if retic is not None and retic > 8.0:
        return "Co-infection: Acute Hemolytic Malarial Crisis"
    if hb_decline and hb_s is not None and hb_s > 0:
        return "Co-infection: Rapidly Progressing Vaso-occlusive Malarial Crisis"
    return "Co-infection: Concurrent Malaria & Sickle Cell Crisis"
95
+
96
def get_clinical_recs(diag, rule_triggered=None):
    """Build the Markdown clinical-decision-support text for a diagnosis.

    Args:
        diag: Diagnosis label. Any label containing ``'Co-infection'`` gets
            the urgent co-infection protocol; a label containing
            ``'Malaria'`` (and not ``'Co-infection'``) gets the malaria
            protocol; exactly ``'SCA'`` gets the sickle-cell protocol;
            anything else is treated as negative.
        rule_triggered: Optional description of the override rule that
            fired. When truthy it is shown above the protocol.

    Returns:
        A Markdown string: header, optional triggered-rule line, the
        protocol for ``diag``, and a fixed "Diagnostic Context Notes" footer.
    """
    sections = ["### Clinical Decision Support Protocol\n\n"]

    if rule_triggered:
        sections.append(f"**Critical Protocol Triggered:** *{rule_triggered}*\n\n")

    # Co-infection sub-type labels contain the word "Malarial", so the
    # malaria branch must explicitly exclude co-infection labels.
    is_coinfection = 'Co-infection' in diag
    if 'Malaria' in diag and not is_coinfection:
        sections.append(
            "**Protocol:** Initiate Artemisinin-based Combination Therapy (ACT) per WHO guidelines.\n"
        )
    elif diag == 'SCA':
        sections.append(
            "**Protocol:** Administer IV Fluids, oxygen therapy, and comprehensive pain management.\n"
        )
    elif is_coinfection:
        sections.append(
            "**Urgent Protocol:** High risk of hyperhemolytic or severe vaso-occlusive crisis.\n"
        )
        sections.append(
            "- **Action:** Immediate admission to high-dependency unit. Initiate rapid intravenous antimalarials, aggressive hydration, and prepare for potential blood transfusion.\n"
        )
    else:
        sections.append(
            "**Action:** Patient is currently negative for active Malaria and SCA crisis.\n"
        )
        sections.append(
            "- **Follow-up:** Screen for Typhoid, Dengue, or other viral infections if febrile symptoms persist.\n"
        )

    # Fixed footer shown for every diagnosis.
    sections.append("\n---\n### Diagnostic Context Notes\n")
    sections.append(
        "- **Overlapping Symptoms:** Fever, Fatigue, Jaundice, Splenomegaly, and Headache *(Headache is uncommon in SCA unless accompanied by severe anemia, cerebral malaria, or stroke risk).* \n"
    )
    sections.append(
        "- **Co-infection Prevalences:** Key clinical indicators for Co-infection include Severe Pallor + Jaundice, High fever, Splenomegaly + malaria, and Extreme Reticulocyte (>8%) + malaria."
    )

    return "".join(sections)
118
 
119
def generate_shap_plot(X_scaled):
    """Render a SHAP waterfall plot for the first row of ``X_scaled``.

    The explanation is taken for class index 3 (the plot title labels it as
    the co-infection risk) whenever the explainer returns per-class output.

    Args:
        X_scaled: DataFrame of preprocessed features; only row 0 is
            explained.

    Returns:
        A matplotlib figure. If anything fails, a placeholder figure
        carrying the error text is returned instead of raising.
    """
    class_idx = 3
    try:
        shap_values = explainer.shap_values(X_scaled)

        # expected_value may be a scalar, a list, or a NumPy array depending
        # on the SHAP version. Flatten it once so every branch can pick a
        # scalar; passing a whole vector as base_values makes the waterfall
        # plot fail.
        expected = np.ravel(explainer.expected_value)

        if isinstance(shap_values, list):
            # Older SHAP: one (n_samples, n_features) array per class.
            pat_shap = shap_values[class_idx][0]
            base_val = explainer.expected_value[class_idx]
        elif len(shap_values.shape) == 3:
            # Single array indexed as [sample, feature, class].
            pat_shap = shap_values[0, :, class_idx]
            base_val = expected[class_idx] if expected.size > class_idx else expected[0]
        else:
            # Single-output model: 2-D values with one base value.
            pat_shap = shap_values[0]
            base_val = expected[0] if expected.size == 1 else explainer.expected_value

        fig, ax = plt.subplots(figsize=(7, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        explanation = shap.Explanation(values=pat_shap, base_values=base_val,
                                       data=X_scaled.iloc[0], feature_names=FEATURE_NAMES)
        shap.waterfall_plot(explanation, show=False)
        plt.title("XAI Feature Contribution (Impact on Co-Infection Risk)", fontsize=11, fontweight='bold')
        plt.tight_layout()
        return fig
    except Exception as e:
        # Best-effort: the UI still gets a figure, with the reason shown.
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.text(0.5, 0.5, f"Interpretability Module Offline:\n{str(e)}", ha='center', va='center')
        return fig
147
 
148
def manual_inference(age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, malaria_rdt, reticulocyte, hb_rapid_decline,
                     fever, chills, headache, muscle_aches, fatigue, loss_of_appetite, jaundice, abdominal_pain, joint_pain, splenomegaly, pallor, lymphadenopathy):
    """Run single-patient inference and format the results for the UI.

    Hard-coded clinical override rules are evaluated first. If one fires,
    the diagnosis is forced to a co-infection sub-type; otherwise the class
    with the highest model probability is used.

    Args:
        age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, reticulocyte:
            Demographic, vital and laboratory values from the form.
        malaria_rdt: ``"Positive"`` or any other value (treated as negative).
        hb_rapid_decline: Truthy when a rapid Hb decline was ticked.
        fever ... lymphadenopathy: Symptom flags, encoded as 1.0 / 0.0.

    Returns:
        A 3-tuple ``(diagnosis_markdown, recommendations_markdown, figure)``.
        On failure the first element carries an error message and the
        figure is ``None``.
    """
    try:
        # The override rules compare these values numerically, so a missing
        # one is reported explicitly instead of surfacing as a TypeError
        # traceback.
        missing = [
            label
            for label, value in (
                ("Hemoglobin", hb),
                ("Reticulocyte Count", reticulocyte),
                ("HbS Fraction", hb_s),
            )
            if value is None
        ]
        if missing:
            return (
                f"### Input Error\nMissing required laboratory value(s): {', '.join(missing)}.",
                "Please complete the required fields and try again.",
                None,
            )

        rdt_positive = malaria_rdt == "Positive"
        co_infection_flag = False
        rule_triggered = ""
        specific_coinfection_name = ""

        # Hardcoded Critical Clinical Override Rules
        if hb < 5.0:
            co_infection_flag = True
            rule_triggered = "Hemoglobin below critical threshold (5.0 g/dL)"
        elif reticulocyte > 8.0 and rdt_positive:
            co_infection_flag = True
            rule_triggered = "Extreme Reticulocyte (>8%) + Positive Malaria RDT"
        elif hb_rapid_decline and rdt_positive and hb_s > 0:
            co_infection_flag = True
            rule_triggered = "Rapid Hb decline (>1.5g/dL in 48h) + Positive Malaria + SCA Genotype"

        if co_infection_flag:
            specific_coinfection_name = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)

        # Boolean-like inputs are encoded as 1.0 / 0.0 for the model.
        binary_flags = {
            'fever': fever, 'chills': chills, 'headache': headache,
            'muscle_aches': muscle_aches, 'fatigue': fatigue,
            'loss_of_appetite': loss_of_appetite, 'jaundice': jaundice,
            'abdominal_pain': abdominal_pain, 'joint_pain': joint_pain,
            'splenomegaly': splenomegaly, 'pallor': pallor,
            'lymphadenopathy': lymphadenopathy,
        }
        record = {
            'age': age, 'sex': sex, 'temp': temp, 'hb': hb, 'wbc': wbc, 'platelets': platelets,
            'hb_a': hb_a, 'hb_s': hb_s, 'hb_f': hb_f,
            'malaria_rdt': 1.0 if rdt_positive else 0.0,
            'reticulocyte': reticulocyte,
            'hb_rapid_decline': 1.0 if hb_rapid_decline else 0.0,
        }
        record.update({name: 1.0 if present else 0.0 for name, present in binary_flags.items()})
        input_data = pd.DataFrame({column: [value] for column, value in record.items()})

        X_scaled = preprocess_input(input_data)
        probs = best_model.predict_proba(X_scaled)[0]

        # Map probabilities to class names (as percentages)
        prob_dict = {name: probs[i] * 100 for i, name in enumerate(target_names)}

        if co_infection_flag:
            # Clinical override: the rule-based diagnosis replaces the model's.
            primary_diag = specific_coinfection_name
            prob_dict = {
                specific_coinfection_name: 100.0,
                'Malaria (Override)': prob_dict['Malaria'],
                'SCA (Override)': prob_dict['SCA'],
                'Negative': 0.0
            }
        else:
            primary_diag = target_names[np.argmax(probs)]

            # If the model predicted co-infection without a rule firing,
            # still report the specific sub-type name.
            if primary_diag == 'Co-infection':
                primary_diag = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s)
                prob_dict[primary_diag] = prob_dict.pop('Co-infection')

        # Probabilities are listed in descending order of confidence.
        ranked = sorted(prob_dict.items(), key=lambda item: item[1], reverse=True)
        diag_output = (
            f"## Primary Diagnosis: {primary_diag}\n\n### Comprehensive Confidence Breakdown:\n"
            + "".join(f"- **{disease}**: {conf:.1f}%\n" for disease, conf in ranked)
        )

        recs = get_clinical_recs(primary_diag, rule_triggered)
        fig = generate_shap_plot(X_scaled)

        return diag_output, recs, fig
    except Exception:
        # Top-level UI boundary: show the failure instead of crashing the app.
        return f"### Inference Error\n```\n{traceback.format_exc()}\n```", "System Error.", None
 
224
 
225
+ # ==========================================
226
+ # 3. SYSTEM VALIDATION HELPER FUNCTIONS
227
+ # ==========================================
228
+
229
def load_systematic_metrics():
    """Compute held-out accuracy plus macro sensitivity and specificity.

    Reads ``y_test_val.pkl`` (true labels) and ``y_probs_val.pkl``
    (per-class probabilities), takes the arg-max class as the prediction,
    and derives one-vs-rest sensitivity and specificity from the confusion
    matrix before averaging them across classes.

    Returns:
        A Markdown summary string, or an error message string if the
        artifacts cannot be loaded or processed.
    """
    try:
        truth = joblib.load('y_test_val.pkl')
        class_probs = joblib.load('y_probs_val.pkl')
        predicted = np.argmax(class_probs, axis=1)

        acc = accuracy_score(truth, predicted)
        cm = confusion_matrix(truth, predicted)

        # One-vs-rest counts for every class at once.
        true_pos = np.diag(cm)
        false_neg = cm.sum(axis=1) - true_pos
        false_pos = cm.sum(axis=0) - true_pos
        true_neg = cm.sum() - true_pos - false_neg - false_pos

        positives = true_pos + false_neg
        negatives = true_neg + false_pos

        # Classes with an empty denominator contribute 0 to the macro mean.
        per_class_sens = np.divide(true_pos, positives,
                                   out=np.zeros(len(cm)), where=positives > 0)
        per_class_spec = np.divide(true_neg, negatives,
                                   out=np.zeros(len(cm)), where=negatives > 0)

        sens = np.mean(per_class_sens)
        spec = np.mean(per_class_spec)

        return f"### Systematic Evaluation Metrics (Held-out Cohort)\n\n- **Overall Accuracy**: {acc*100:.2f}%\n- **Sensitivity (Macro)**: {sens*100:.2f}%\n- **Specificity (Macro)**: {spec*100:.2f}%"
    except Exception as e:
        return f"Error loading validation metrics: Ensure 'y_test_val.pkl' and 'y_probs_val.pkl' exist in memory. \n({str(e)})"
253
+
254
def check_calibration(class_name):
    """Plot a one-vs-rest reliability (calibration) curve for one class.

    Args:
        class_name: One of the entries in ``target_names``.

    Returns:
        A matplotlib figure with the calibration curve, or a placeholder
        figure carrying the error text if anything fails.
    """
    try:
        from sklearn.calibration import CalibrationDisplay

        # Coerce to arrays so the element-wise comparison and column slicing
        # below work even if the pickles hold plain lists or pandas objects.
        y_test_val = np.asarray(joblib.load('y_test_val.pkl')).ravel()
        y_probs_val = np.asarray(joblib.load('y_probs_val.pkl'))
        class_idx = target_names.index(class_name)

        # Reduce the multi-class problem to "this class vs. the rest".
        y_true_binary = (y_test_val == class_idx).astype(int)
        y_prob_class = y_probs_val[:, class_idx]

        fig, ax = plt.subplots(figsize=(6, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        CalibrationDisplay.from_predictions(y_true_binary, y_prob_class, n_bins=10, ax=ax, name=class_name)
        plt.title(f"Reliability Curve (Calibration) for {class_name}", fontweight='bold')
        plt.tight_layout()
        return fig
    except Exception as e:
        # Best-effort: the UI still gets a figure, with the reason shown.
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, f"Calibration Error:\n{str(e)}", ha='center')
        return fig
275
 
276
  # ==========================================
277
+ # 4. GRADIO UI DEFINITION
278
  # ==========================================
279
+
280
+ custom_theme = gr.themes.Monochrome(
281
+ primary_hue="slate",
282
+ secondary_hue="gray",
283
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
284
+ )
285
+
286
+ # 10 Detailed Clinical Examples spanning all feature variations
287
+ clinical_examples = [
288
+ # [age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, rdt, retic, hb_decline, fever, chills, headache, muscle, fatigue, appetite, jaundice, abd_pain, joint_pain, spleno, pallor, lymph]
289
+ [8, "Male", 39.5, 11.5, 9.5, 150, 98.0, 0.0, 2.0, "Positive", 1.5, False, True, True, True, True, True, True, False, False, False, False, False, False], # 1. Uncomplicated Malaria
290
+ [22, "Female", 39.0, 7.5, 12.0, 90, 95.0, 0.0, 2.0, "Positive", 4.0, False, True, True, True, True, True, True, True, False, False, True, True, False], # 2. Severe Malaria
291
+ [15, "Male", 37.2, 8.0, 11.0, 250, 5.0, 85.0, 10.0, "Negative", 6.0, False, False, False, False, True, True, False, True, True, True, False, True, False], # 3. SCA Vaso-occlusive Crisis
292
+ [18, "Female", 37.5, 4.5, 14.0, 300, 2.0, 90.0, 8.0, "Negative", 10.0, True, False, False, False, False, True, False, True, False, True, True, True, False], # 4. SCA Hyperhemolytic (Trigger Hb<5)
293
+ [12, "Male", 38.8, 6.5, 16.0, 110, 10.0, 80.0, 10.0, "Positive", 9.5, False, True, True, True, True, True, True, True, True, True, True, True, False], # 5. Co-infection (Acute Hemolytic, Retic>8)
294
+ [25, "Female", 39.2, 7.0, 15.0, 100, 5.0, 85.0, 10.0, "Positive", 5.0, True, True, True, True, True, True, True, True, False, True, True, True, False], # 6. Co-infection (Rapidly Progressing)
295
+ [30, "Male", 36.8, 14.0, 6.5, 250, 98.0, 0.0, 2.0, "Negative", 1.0, False, False, False, False, False, False, False, False, False, False, False, False, False], # 7. Healthy Adult
296
+ [45, "Female", 37.8, 13.5, 5.0, 210, 97.0, 0.0, 2.0, "Negative", 1.2, False, True, False, True, True, True, False, False, False, False, False, False, True], # 8. Viral Infection (Non-malarial)
297
+ [10, "Male", 39.8, 6.0, 18.0, 80, 95.0, 0.0, 3.0, "Positive", 7.0, False, True, True, True, False, True, True, True, True, False, True, True, False], # 9. Malaria with Overlapping Symptoms
298
+ [28, "Female", 37.0, 12.5, 7.0, 220, 60.0, 38.0, 2.0, "Negative", 1.5, False, False, False, False, False, False, False, False, False, False, False, False, False] # 10. SCA Trait (Asymptomatic)
299
+ ]
300
+
301
+ with gr.Blocks(theme=custom_theme, title="Hemaclass Clinical Dashboard") as demo:
302
+ gr.Markdown("# Hemaclass Clinical Decision Support System")
303
+ gr.Markdown("Deep Stacking Ensemble Model for Malaria and Sickle Cell Anemia Classification.")
304
+
305
  with gr.Tabs():
306
+ # --- TAB 1: CORE INFERENCE ---
307
+ with gr.TabItem("Single Patient Validation"):
308
  with gr.Row():
309
  with gr.Column(scale=1):
310
+ gr.Markdown("### Demographics & Vitals")
311
+ with gr.Row():
312
+ age_in = gr.Number(label="Age", value=25)
313
+ sex_in = gr.Dropdown(["Male", "Female"], label="Sex", value="Female")
314
+ temp_in = gr.Number(label="Temperature (°C)", value=37.5)
315
+
 
 
 
 
 
 
316
  gr.Markdown("### Clinical Symptoms")
 
 
 
 
 
 
 
 
317
  with gr.Row():
318
+ fever_in = gr.Checkbox(label="Fever")
319
+ chills_in = gr.Checkbox(label="Chills")
320
+ headache_in = gr.Checkbox(label="Headache")
321
+ fatigue_in = gr.Checkbox(label="Fatigue")
322
  with gr.Row():
323
+ jaundice_in = gr.Checkbox(label="Jaundice")
324
+ splenomegaly_in = gr.Checkbox(label="Splenomegaly")
325
+ pallor_in = gr.Checkbox(label="Severe Pallor")
326
+ muscle_in = gr.Checkbox(label="Muscle Aches")
327
+ with gr.Accordion("Additional Symptoms", open=False):
328
+ loss_appetite_in = gr.Checkbox(label="Loss of Appetite")
329
+ abd_pain_in = gr.Checkbox(label="Abdominal Pain")
330
+ joint_pain_in = gr.Checkbox(label="Joint Pain")
331
+ lymph_in = gr.Checkbox(label="Lymphadenopathy")
332
+
333
+ gr.Markdown("### Critical Laboratory Markers")
334
  with gr.Row():
335
+ rdt_in = gr.Radio(["Negative", "Positive"], label="Malaria RDT", value="Negative")
336
+ retic_in = gr.Number(label="Reticulocyte Count (%)", value=2.0)
337
+ with gr.Row():
338
+ hb_in = gr.Number(label="Hemoglobin (g/dL)", value=12.0)
339
+ hb_decline_in = gr.Checkbox(label="Rapid Hb Decline (>1.5g/dl in 48h)")
340
+ with gr.Row():
341
+ hb_a_in = gr.Number(label="HbA Fraction (%)", value=98.0)
342
+ hb_s_in = gr.Number(label="HbS Fraction (%)", value=0.0)
343
+ hb_f_in = gr.Number(label="HbF Fraction (%)", value=2.0)
344
+ with gr.Row():
345
+ wbc_in = gr.Number(label="WBC Count (x10^9/L)", value=8.0)
346
+ platelets_in = gr.Number(label="Platelet Count", value=200)
347
+
348
+ manual_btn = gr.Button("Validate Diagnosis", variant="primary", size="lg")
349
+
350
+ with gr.Column(scale=1):
351
+ gr.Markdown("### System Output")
352
+ out_diag = gr.Markdown()
353
+ out_recs = gr.Markdown()
354
+ out_shap = gr.Plot(label="Feature Contribution Analysis")
355
+
356
+ gr.Markdown("---")
357
+ gr.Markdown("### Load Clinical Scenarios")
358
+ gr.Markdown("Select a predefined clinical case to auto-populate the diagnostic fields.")
359
 
360
+ input_components = [
361
+ age_in, sex_in, temp_in, hb_in, wbc_in, platelets_in, hb_a_in, hb_s_in, hb_f_in,
362
+ rdt_in, retic_in, hb_decline_in, fever_in, chills_in, headache_in, muscle_in,
363
+ fatigue_in, loss_appetite_in, jaundice_in, abd_pain_in, joint_pain_in,
364
+ splenomegaly_in, pallor_in, lymph_in
365
+ ]
366
 
367
+ gr.Examples(
368
+ examples=clinical_examples,
369
+ inputs=input_components,
370
+ label="Predefined Patient Cases"
371
+ )
372
+
373
+ manual_btn.click(
374
+ manual_inference,
375
+ inputs=input_components,
376
+ outputs=[out_diag, out_recs, out_shap]
377
+ )
378
+
379
+ # --- TAB 2: PERFORMANCE METRICS ---
380
+ with gr.TabItem("Systematic Testing"):
381
+ gr.Markdown("### Overall Model Performance on Unseen Test Cohort")
382
+ metrics_btn = gr.Button("Calculate Systematic Metrics", variant="secondary")
383
+ out_metrics = gr.Markdown()
384
+ metrics_btn.click(load_systematic_metrics, inputs=[], outputs=[out_metrics])
385
+
386
+ # --- TAB 3: ADVANCED CALIBRATION ---
387
+ with gr.TabItem("Advanced Validation"):
388
+ gr.Markdown("### Evaluate Diagnosis Calibration")
389
+ gr.Markdown("Select a disease class below to verify the alignment between predicted probabilities and true clinical frequencies.")
390
  with gr.Row():
391
+ class_dropdown = gr.Dropdown(target_names, label="Select Target Class", value="Co-infection")
392
+ calib_btn = gr.Button("Check Calibration", variant="secondary")
393
+ out_calib = gr.Plot()
394
+ calib_btn.click(check_calibration, inputs=[class_dropdown], outputs=[out_calib])
 
 
 
395
 
396
+ # Launch inside Colab
397
  if __name__ == "__main__":
398
+ demo.launch(share=True)