| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import pyarrow.parquet as pq |
| from sklearn.preprocessing import OneHotEncoder, MinMaxScaler |
| from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier |
| from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, RepeatedStratifiedKFold |
| from sklearn.metrics import ( |
| confusion_matrix, classification_report, |
| precision_score, recall_score, f1_score, |
| accuracy_score, balanced_accuracy_score, matthews_corrcoef, |
| roc_auc_score, auc |
| ) |
| from sklearn.utils.class_weight import compute_sample_weight |
| from sklearn.naive_bayes import GaussianNB |
| from sklearn.svm import SVC |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.metrics import brier_score_loss |
| from sklearn.calibration import calibration_curve, CalibratedClassifierCV |
| from sklearn.linear_model import LinearRegression |
| import xgboost as xgb |
| import matplotlib.pyplot as plt |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
# Module-level state shared across Gradio callbacks; populated lazily by
# load_training_data().
training_data = None  # pandas DataFrame of historical records (year6.parquet)
column_names = None   # predictor column names read from final_variable.csv
test_list = []        # outcome column names for the current run (indexed by `track` in run_model)


# Number of bootstrap resamples used when building confidence intervals.
DEFAULT_N_BOOT_CI = 1000
|
|
|
def calibrate_probabilities_undersampling(p_s, beta):
    """Correct probabilities estimated on an undersampled training set.

    Applies the standard prior-correction formula
    ``p = beta * p_s / ((beta - 1) * p_s + 1)``, guarding the denominator
    against division by zero and clipping the result to [0, 1].

    Parameters
    ----------
    p_s : array-like
        Probabilities produced by a model trained on undersampled data.
    beta : float
        Undersampling ratio used to correct back to the original prevalence.

    Returns
    -------
    np.ndarray
        Calibrated probabilities in [0, 1].
    """
    probs = np.asarray(p_s, dtype=float)
    scaled = beta * probs
    denom = (beta - 1.0) * probs + 1.0
    # Guard against a vanishing denominator before dividing.
    safe_denom = np.maximum(denom, 1e-10)
    return np.clip(scaled / safe_denom, 0.0, 1.0)
|
|
|
|
def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Shifted-percentile bootstrap CI around a probability point estimate.

    Bootstrap means of ``oof_probs`` are re-centred on ``point_estimate``
    before taking percentiles, and the interval is clipped to [0, 1]. When
    no out-of-fold probabilities are available the interval degenerates to
    the point estimate itself.

    Returns
    -------
    (lo, hi) : tuple of floats
        Lower and upper bounds of the confidence interval.
    """
    # Degenerate case: nothing to resample.
    if oof_probs is None or len(oof_probs) == 0:
        return float(point_estimate), float(point_estimate)

    samples = np.asarray(oof_probs, dtype=float)
    rng = np.random.RandomState(random_state)
    sample_mean = np.mean(samples)
    n = len(samples)

    # Resample with replacement n_boot times, keeping one mean per draw.
    resampled_means = []
    for _ in range(n_boot):
        resampled_means.append(np.mean(rng.choice(samples, size=n, replace=True)))
    resampled_means = np.array(resampled_means)

    # Re-centre the bootstrap distribution on the reported point estimate.
    resampled_means = resampled_means + (point_estimate - sample_mean)

    alpha = 1.0 - confidence
    lo = float(np.clip(np.percentile(resampled_means, 100 * alpha / 2), 0.0, 1.0))
    hi = float(np.clip(np.percentile(resampled_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
    return lo, hi
|
|
|
|
def compute_efs_ci(
    p_dwogf: float,
    p_gf: float,
    oof_dwogf: np.ndarray,
    oof_gf: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
) -> tuple:
    """
    EFS = 1 - P(DWOGF) - P(GF).
    Bootstrap CI uses the same shifted-percentile approach from the first
    codebase, with DWOGF and GF bootstrapped jointly (matching lengths).

    Returns (point estimate, lower bound, upper bound); the point estimate
    is clipped to [0, 1], and a degenerate CI is returned when either
    out-of-fold array is missing.
    """
    p_efs = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))

    # Without out-of-fold samples there is nothing to resample.
    if oof_dwogf is None or oof_gf is None:
        return p_efs, p_efs, p_efs

    dwogf = np.asarray(oof_dwogf, dtype=float)
    gf = np.asarray(oof_gf, dtype=float)

    # Truncate both arrays to a common length so the draws stay paired.
    n_min = min(len(dwogf), len(gf))
    dwogf = dwogf[:n_min]
    gf = gf[:n_min]

    rng = np.random.RandomState(42)
    # Shifts re-centre each bootstrap mean on its reported point estimate.
    shift_dwogf = p_dwogf - np.mean(dwogf)
    shift_gf = p_gf - np.mean(gf)

    draws = []
    for _ in range(n_boot):
        mean_dwogf = np.mean(rng.choice(dwogf, size=n_min, replace=True)) + shift_dwogf
        mean_gf = np.mean(rng.choice(gf, size=n_min, replace=True)) + shift_gf
        draws.append(np.clip(1.0 - mean_dwogf - mean_gf, 0.0, 1.0))
    efs_boot = np.array(draws)

    alpha = 0.05
    efs_lo = float(np.percentile(efs_boot, 100 * alpha / 2))
    efs_hi = float(np.percentile(efs_boot, 100 * (1 - alpha / 2)))
    return p_efs, efs_lo, efs_hi
|
|
|
|
|
|
def rand_for(neww_list, x_te, rf, lab, x_tr, actual, paramss,
             X_Tempp, enco, my_table_str, my_table_num, tabl, tracount):
    """Fit one classifier per undersampled training frame and score x_te.

    Parameters
    ----------
    neww_list : list of pd.DataFrame
        Undersampled training frames, each containing the label column `lab`.
    x_te : array-like
        Test feature matrix; rebuilt as a DataFrame with X_Tempp's columns.
    lab : str
        Name of the binary label column inside each training frame.
    tracount : int
        Model selector: 0 = isotonic-calibrated RandomForest, 1 = XGBoost,
        4 = GaussianNB, 5 = LogisticRegression, 6 = SVC.
    rf, x_tr, actual, paramss, enco, my_table_str, my_table_num, tabl
        Unused here; kept so the signature stays compatible with callers.

    Returns
    -------
    (cl_list, pro_list) : tuple of lists
        Per-model predicted labels and positive-class probabilities, one
        entry per training frame in `neww_list`.

    Raises
    ------
    ValueError
        If `tracount` is not a supported model code (the original code left
        `out`/`probs` undefined and crashed with an opaque NameError).
    """
    cl_list = []
    pro_list = []

    # Bug fix: the test-frame conversion is loop-invariant — build it once
    # instead of reconstructing the DataFrame on every iteration.
    x_te = pd.DataFrame(x_te, columns=X_Tempp.columns)

    for train_df in neww_list:
        dff_copy = train_df.copy()
        y_cl = dff_copy.loc[:, lab]
        X_cl = dff_copy.drop([lab], axis=1)

        if tracount == 0:
            # Random forest wrapped in 5-fold isotonic probability calibration.
            base = RandomForestClassifier(
                n_estimators=100, criterion='entropy', max_features=None,
                random_state=42, bootstrap=True, oob_score=True,
                class_weight='balanced', ccp_alpha=0.01
            )
            model = CalibratedClassifierCV(estimator=base, method='isotonic', cv=5)
            model.fit(X_cl, y_cl)
            out = model.predict(x_te)
            probs = model.predict_proba(x_te)[:, 1]

        elif tracount == 1:
            # XGBoost via the low-level Learning API (100 boosting rounds).
            dtrain = xgb.DMatrix(X_cl.to_numpy(), label=y_cl)
            dtest = xgb.DMatrix(x_te.to_numpy())
            params = {
                'objective': 'binary:logistic', 'eval_metric': 'logloss',
                'max_depth': 60, 'eta': 0.1,
                'subsample': 0.8, 'colsample_bytree': 0.8, 'seed': 42
            }
            booster = xgb.train(params, dtrain, 100)
            probs = booster.predict(dtest)
            out = (probs > 0.5).astype(int)

        elif tracount == 5:
            model = LogisticRegression(penalty='l2', solver='newton-cholesky', max_iter=200)
            model.fit(X_cl, y_cl)
            out = model.predict(x_te)
            probs = model.predict_proba(x_te)[:, 1]

        elif tracount == 4:
            model = GaussianNB(var_smoothing=1e-9)
            model.fit(X_cl, y_cl)
            out = model.predict(x_te)
            probs = model.predict_proba(x_te)[:, 1]

        elif tracount == 6:
            model = SVC(probability=True, C=3)
            model.fit(X_cl, y_cl)
            out = model.predict(x_te)
            probs = model.predict_proba(x_te)[:, 1]

        else:
            # Bug fix: fail loudly on unknown model codes instead of hitting
            # a NameError on undefined `out`/`probs` below.
            raise ValueError(f"Unsupported model code tracount={tracount}")

        cl_list.append(out)
        pro_list.append(probs)

    return cl_list, pro_list
|
|
|
|
def ne_calib(some_prob, down_factor, origin_factor):
    """Correct a probability estimated after undersampling back to the
    original class prevalence.

    Scales the positive and negative components by the ratio of the
    original prevalence (`origin_factor`) to the prevalence in the
    undersampled set (`down_factor`), then renormalizes.
    """
    pos_term = some_prob * origin_factor / down_factor
    neg_term = (1 - some_prob) * (1 - origin_factor) / (1 - down_factor)
    return pos_term / (pos_term + neg_term)
|
|
|
|
def actualll(sl_list, pro_list, delt, down_factor, origin_factor):
    """Fuse per-model ensemble outputs into final labels and probabilities.

    `sl_list` holds per-model -1/+1 votes (from sli_mod) and `pro_list` the
    matching positive-class probabilities. For each test row the mean vote
    is thresholded at `delt`; probabilities are averaged twice — once after
    prior correction via ne_calib and once raw.

    Returns (labels, calibrated_probs, raw_probs), one entry per test row.
    """
    n_models = len(sl_list)
    n_rows = len(sl_list[0])

    ac_list = []
    probab_list = []
    second_probab_list = []

    for row in range(n_rows):
        mean_vote = sum(sl_list[m][row] for m in range(n_models)) / n_models
        mean_calib = sum(
            ne_calib(pro_list[m][row], down_factor, origin_factor)
            for m in range(n_models)
        ) / n_models
        mean_raw = sum(pro_list[m][row] for m in range(n_models)) / n_models

        if mean_vote >= delt:
            # Enough positive votes: predict the event.
            ac_list.append(1)
            probab_list.append(mean_calib)
            second_probab_list.append(mean_raw)
        elif 0 <= mean_vote < delt:
            # Weakly positive vote mass: predict no event, report the
            # complement probability.
            ac_list.append(0)
            probab_list.append(1 - mean_calib)
            second_probab_list.append(1 - mean_raw)
        else:
            # mean_vote < 0 (majority voted -1): predict no event.
            # NOTE(review): unlike the branch above, the probability is NOT
            # flipped here — confirm this asymmetry is intended.
            ac_list.append(0)
            probab_list.append(mean_calib)
            second_probab_list.append(mean_raw)

    return ac_list, probab_list, second_probab_list
|
|
|
|
def sli_mod(c_lisy):
    """Convert each model's 0/1 predictions into signed -1/+1 votes.

    Values >= 0.5 become +1, everything below becomes -1; each inner
    sequence is returned as a plain list of floats.
    """
    signed_votes = []
    for preds in c_lisy:
        arr = np.asarray(preds, dtype=float)
        votes = np.where(arr >= 0.5, 1.0, -1.0)
        signed_votes.append(list(votes))
    return signed_votes
|
|
|
|
def run_model(x_tr, x_te, y_tr, deltaa, lab, rf, X_Tempp, track,
              actual, paramss, enco, my_table_str, my_table_num, tabl,
              tracount, origin_factor):
    """Train an undersampling ensemble for outcome `lab` and score `x_te`.

    `x_tr`/`y_tr` are raw arrays rebuilt into DataFrames using X_Tempp's
    columns and the global test_list[track] as the label name. Majority-class
    rows are resampled with replacement into 10-20 training sets, one model
    per set is fit via rand_for(), votes are binarized by sli_mod(), and the
    outputs are fused by actualll().

    Returns (labels, calibrated probabilities, raw averaged probabilities).
    """

    x_tr = pd.DataFrame(x_tr, columns=X_Tempp.columns)
    y_tr = pd.DataFrame(y_tr, columns=[test_list[track]])
    master_table = pd.concat([x_tr, y_tr], axis=1).copy()

    # Partition rows by class: label 1 is the rare positive event.
    only_minority = master_table.loc[master_table[lab] == 1]
    only_majority = master_table.drop(only_minority.index)
    min_index = only_minority.index
    max_index = only_majority.index

    df_list = []
    down_factor = 0  # minority prevalence within each resampled set

    # Very rare outcomes (<= 60 positives) get 20 resamples at an
    # outcome-specific majority:minority ratio; otherwise 10 resamples at 3:1.
    # The global RNG is reseeded each iteration so draws are reproducible.
    if len(min_index) <= 60:
        for i in range(20):
            np.random.seed(i + 30)
            # NOTE(review): 'VOD' never appears in the outcome lists used by
            # callers (they use 'VOCPSHI'); likewise 'ACSPSHI' — confirm
            # these branches are reachable.
            if test_list[track] in ('VOD', 'STROKEHI'):
                sampled_array = np.random.choice(max_index, size=int(3 * len(min_index)), replace=True)
                down_factor = 0.25
            elif test_list[track] == 'ACSPSHI':
                sampled_array = np.random.choice(max_index, size=int(2.5 * len(min_index)), replace=True)
                down_factor = 1 / (1 + 2.5)
            else:
                sampled_array = np.random.choice(max_index, size=int(2 * len(min_index)), replace=True)
                down_factor = 1 / (1 + 2)
            df_list.append(pd.concat([only_majority.loc[sampled_array], only_minority]))
    else:
        for i in range(10):
            np.random.seed(i + 30)
            sampled_array = np.random.choice(max_index, size=int(3 * len(min_index)), replace=True)
            down_factor = 1 / (1 + 3)
            df_list.append(pd.concat([only_majority.loc[sampled_array], only_minority]))

    # Fit one model per resampled set, binarize votes, then fuse the ensemble.
    c_lisy, pro_lisy = rand_for(df_list, x_te, rf, lab, x_tr, actual, paramss,
                                X_Tempp, enco, my_table_str, my_table_num, tabl, tracount)
    sli_lisy = sli_mod(c_lisy)
    a_lisy, probab_lisy, secondlisy = actualll(sli_lisy, pro_lisy, deltaa, down_factor, origin_factor)
    return a_lisy, probab_lisy, secondlisy
|
|
|
|
def load_training_data():
    """Load the training table and predictor list into module globals.

    Reads 'year6.parquet' into `training_data` (dropping pre-2008 records)
    and the first column of 'final_variable.csv' into `column_names`.
    Returns an error string when the files are missing; returns None on
    success.
    """
    global training_data, column_names, test_list

    try:
        my_table = pq.read_table('year6.parquet').to_pandas()
        print(my_table['YEARGPF'].value_counts())
        # Drop legacy-era records before training.
        my_table = my_table[my_table['YEARGPF'] != '< 2008'].reset_index(drop=True)

        # First CSV column holds the predictor names.
        pa = pd.read_csv('final_variable.csv')
        pali = list(pa.iloc[:, 0])
        print(pali)

        training_data = my_table
        column_names = pali
    except FileNotFoundError:
        # NOTE(review): callers ignore this return value — the failure is
        # effectively silent; confirm that is acceptable.
        return "No training Data"
|
|
|
|
|
|
def train_and_evaluate(input_file):
    """Run the full evaluation pipeline on an uploaded cohort CSV.

    Trains one undersampling ensemble per outcome on the global training
    table, scores the uploaded cohort, derives the composite event-free
    survival (EFS) estimate, and returns:

        (metrics DataFrame, calibration DataFrame, list of matplotlib figures)

    On failure returns (error message string, None, None); when no file is
    supplied returns (None, None, None).
    """
    global training_data, column_names, test_list

    # Lazily load the training table on first use.
    if training_data is None or column_names is None:
        load_training_data()

    if input_file is None:
        return None, None, None

    try:
        # Bug fix: gr.File(type="filepath") passes a plain str path, which
        # has no `.name`; older Gradio versions pass a tempfile wrapper that
        # does. Accept both instead of crashing into the generic handler.
        csv_path = input_file if isinstance(input_file, str) else input_file.name
        input_data = pd.read_csv(csv_path)

        # Keep only predictors present in both the training table and upload.
        available_features = [col for col in column_names if col in training_data.columns]
        available_features_input = [col for col in available_features if col in input_data.columns]

        if not available_features_input:
            return "Error: No matching columns found between datasets", None, None

        # Outcomes to model; DWOGF is appended solely so EFS can be derived.
        base_outcome_cols = ['DEAD', 'GF', 'AGVHD', 'CGVHD', 'VOCPSHI', 'STROKEHI']
        efs_outcomes = ['DWOGF', 'GF']  # NOTE(review): unused — EFS is handled explicitly below
        all_model_outcomes = base_outcome_cols.copy()
        if 'DWOGF' not in all_model_outcomes:
            all_model_outcomes.append('DWOGF')

        # run_model() resolves the label name for index `track` via this global.
        test_list = all_model_outcomes.copy()

        total_cols = available_features + all_model_outcomes
        inter_df = training_data[total_cols].dropna().reset_index(drop=True)

        # Mirror the training-data era filter on the uploaded cohort.
        input_data = input_data[input_data['YEARGPF'] != '< 2008'].reset_index(drop=True)
        inter_input = input_data[total_cols].dropna().reset_index(drop=True)

        my_table = inter_df[available_features]
        X_input = inter_input[available_features][my_table.columns]
        my_test = X_input

        # Convert Yes/No columns to 1/0 and split training features into
        # string vs numeric parts.
        li1 = ['Yes', 'No']
        cols_yes_no_train = [col for col in my_table.columns if my_table[col].isin(li1).all()]
        my_ye_train = my_table[cols_yes_no_train].replace({'Yes': 1, 'No': 0}).astype('int64')
        my_table_modify = pd.concat([my_table.drop(cols_yes_no_train, axis=1), my_ye_train], axis=1)
        my_table_str = my_table_modify.select_dtypes(exclude=['number'])
        my_table_num = my_table_modify.select_dtypes(include=['number'])

        # Same Yes/No conversion for the uploaded cohort.
        cols_yes_no_test = [col for col in my_test.columns if my_test[col].isin(li1).all()]
        my_ye_test = my_test[cols_yes_no_test].replace({'Yes': 1, 'No': 0}).astype('int64')
        my_test_modify = pd.concat([my_test.drop(cols_yes_no_test, axis=1), my_ye_test], axis=1)
        my_test_str_raw = my_test_modify.select_dtypes(exclude=['number'])
        my_test_num = my_test_modify.select_dtypes(include=['number'])

        # One-hot encode train+test string columns together so both share the
        # same encoded feature space, then split the encoding back apart.
        df_combined = pd.concat([my_table_str, my_test_str_raw], axis=0, ignore_index=True)
        enco = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        encoded = enco.fit_transform(df_combined)
        encoded_df = pd.DataFrame(encoded, columns=enco.get_feature_names_out())

        tabl = encoded_df.iloc[:len(my_table_str)].reset_index(drop=True)
        X_train_full = pd.concat([tabl, my_table_num], axis=1)
        my_test_str = encoded_df.iloc[len(my_table_str):].reset_index(drop=True)
        my_test_real = pd.concat([my_test_str, my_test_num], axis=1)

        # Human-readable outcome names for tables and plot titles.
        outcome_display_names = {
            'DEAD': 'Overall Survival',
            'GF': 'Graft Failure',
            'AGVHD': 'Acute GVHD',
            'CGVHD': 'Chronic GVHD',
            'VOCPSHI': 'Vaso-Occlusive Crisis Post-HCT',
            'STROKEHI': 'Stroke Post-HCT',
            'DWOGF': 'Death Without Graft Failure',
        }

        # Per-outcome predictions, collected for the EFS composite below.
        all_pred_proba = {}
        all_pred_labels = {}
        all_y_test = {}

        metrics_results = []
        calibration_results = []
        calibration_plots = []

        for i, outcome_col in enumerate(all_model_outcomes):
            if outcome_col not in training_data.columns:
                print(f"Warning: {outcome_col} not in training data, skipping.")
                continue

            # Binarize labels: minority class -> 1, majority -> 0.
            # NOTE(review): assumes the rarer value is always the positive
            # event — confirm this holds for every outcome column.
            y_train_series = inter_df[outcome_col]
            amaj = y_train_series.value_counts().idxmax()
            amin = y_train_series.value_counts().idxmin()
            y_train = y_train_series.replace({amin: 1, amaj: 0}).astype(int).values

            # Test labels are binarized from the test set's own value counts.
            y_test_series = inter_input[outcome_col]
            amaj = y_test_series.value_counts().idxmax()
            amin = y_test_series.value_counts().idxmin()
            y_test = y_test_series.replace({amin: 1, amaj: 0}).astype(int).values

            vddc = float(np.sum(y_train == 1)) / len(y_train)  # training prevalence (origin_factor)
            deltaa = 0.2  # ensemble vote threshold passed to run_model
            rf = RandomForestClassifier()  # placeholder; rand_for builds its own model
            paramss = {}
            tracount = 0  # model code 0 = isotonic-calibrated random forest

            y_pred, y_pred_proba, _ = run_model(
                X_train_full.values, my_test_real.values, y_train,
                deltaa, outcome_col, rf, X_train_full, i,
                tabl, paramss, enco, my_table_str, my_table_num, tabl,
                tracount, vddc
            )

            y_pred = np.array(y_pred)
            y_pred_proba = np.array(y_pred_proba)

            all_pred_proba[outcome_col] = y_pred_proba
            all_pred_labels[outcome_col] = y_pred
            all_y_test[outcome_col] = y_test

            # DWOGF feeds the EFS composite only — no standalone report row.
            if outcome_col == 'DWOGF':
                continue

            outcome_name = outcome_display_names.get(outcome_col, outcome_col)

            accuracy = accuracy_score(y_test, y_pred)
            balanced_acc = balanced_accuracy_score(y_test, y_pred)
            precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
            recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
            auc_score = roc_auc_score(y_test, y_pred_proba)

            metrics_results.append([
                outcome_name,
                f"{accuracy:.3f}", f"{balanced_acc:.3f}",
                f"{precision:.3f}", f"{recall:.3f}", f"{auc_score:.3f}"
            ])

            # Calibration slope/intercept from a straight-line fit to the
            # reliability curve.
            fraction_pos, mean_pred = calibration_curve(y_test, y_pred_proba, n_bins=10)
            if len(mean_pred) > 1:
                slope, intercept = np.polyfit(mean_pred, fraction_pos, 1)
            else:
                slope, intercept = 1.0, 0.0

            calibration_results.append([outcome_name, f"{slope:.3f}", f"{intercept:.3f}"])

            fig, ax = plt.subplots(figsize=(8, 6))
            ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
            ax.plot(mean_pred, fraction_pos, 'o-', label=outcome_name)
            ax.set_xlabel('Mean Predicted Probability')
            ax.set_ylabel('Fraction of Positives')
            ax.set_title(f'Calibration Plot – {outcome_name}')
            ax.legend()
            ax.grid(True, alpha=0.3)
            plt.tight_layout()
            calibration_plots.append(fig)

        # Composite event-free survival: EFS = 1 - P(DWOGF) - P(GF).
        if 'DWOGF' in all_pred_proba and 'GF' in all_pred_proba:
            proba_dwogf = all_pred_proba['DWOGF']
            proba_gf = all_pred_proba['GF']

            efs_probs = np.clip(1.0 - proba_dwogf - proba_gf, 0.0, 1.0)

            p_efs_point = float(np.mean(efs_probs))  # NOTE(review): unused; compute_efs_ci recomputes it

            p_efs, efs_lo, efs_hi = compute_efs_ci(
                p_dwogf = float(np.mean(proba_dwogf)),
                p_gf = float(np.mean(proba_gf)),
                oof_dwogf = proba_dwogf,
                oof_gf = proba_gf,
                n_boot = DEFAULT_N_BOOT_CI,
            )

            print(
                f"\nEvent-Free Survival (EFS): {p_efs:.3f} "
                f"[95% CI: {efs_lo:.3f} – {efs_hi:.3f}]"
            )

            # EFS calibration: the true event is DWOGF or GF having occurred.
            if 'DWOGF' in all_y_test and 'GF' in all_y_test:
                n_min_efs = min(len(all_y_test['DWOGF']), len(all_y_test['GF']))
                y_efs_true = np.clip(
                    all_y_test['DWOGF'][:n_min_efs] + all_y_test['GF'][:n_min_efs],
                    0, 1
                )
                efs_probs_aligned = efs_probs[:n_min_efs]

                # A calibration curve needs both classes to be present.
                if len(np.unique(y_efs_true)) > 1:
                    try:
                        fraction_pos_efs, mean_pred_efs = calibration_curve(
                            y_efs_true, 1.0 - efs_probs_aligned, n_bins=10
                        )
                        if len(mean_pred_efs) > 1:
                            slope_efs, intercept_efs = np.polyfit(mean_pred_efs, fraction_pos_efs, 1)
                        else:
                            slope_efs, intercept_efs = 1.0, 0.0

                        # EFS rows go first in the output tables/plots.
                        calibration_results.insert(
                            0,
                            ["Event-Free Survival", f"{slope_efs:.3f}", f"{intercept_efs:.3f}"]
                        )

                        fig_efs, ax_efs = plt.subplots(figsize=(8, 6))
                        ax_efs.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
                        ax_efs.plot(mean_pred_efs, fraction_pos_efs, 'o-',
                                    color='darkorange', label='Event-Free Survival')
                        ax_efs.set_xlabel('Mean Predicted Probability')
                        ax_efs.set_ylabel('Fraction of Positives')
                        ax_efs.set_title('Calibration Plot – Event-Free Survival')
                        ax_efs.legend()
                        ax_efs.grid(True, alpha=0.3)
                        plt.tight_layout()
                        calibration_plots.insert(0, fig_efs)
                    except Exception as e:
                        # Best-effort: EFS calibration is optional output.
                        print(f"Warning: EFS calibration curve failed: {e}")

            # EFS classification metrics at a fixed 0.5 event threshold.
            if 'DWOGF' in all_y_test and 'GF' in all_y_test:
                try:
                    n_min_efs = min(len(all_y_test['DWOGF']), len(all_y_test['GF']))
                    y_efs_true = np.clip(
                        all_y_test['DWOGF'][:n_min_efs] + all_y_test['GF'][:n_min_efs],
                        0, 1
                    )
                    efs_event_prob = 1.0 - efs_probs[:n_min_efs]

                    efs_pred_labels = (efs_event_prob >= 0.5).astype(int)

                    accuracy_efs = accuracy_score(y_efs_true, efs_pred_labels)
                    bal_acc_efs = balanced_accuracy_score(y_efs_true, efs_pred_labels)
                    precision_efs = precision_score(y_efs_true, efs_pred_labels,
                                                    average='weighted', zero_division=0)
                    recall_efs = recall_score(y_efs_true, efs_pred_labels,
                                              average='weighted', zero_division=0)
                    # AUC is undefined when only one class is present.
                    auc_efs = roc_auc_score(y_efs_true, efs_event_prob) \
                        if len(np.unique(y_efs_true)) > 1 else float('nan')

                    metrics_results.insert(0, [
                        "Event-Free Survival",
                        f"{accuracy_efs:.3f}", f"{bal_acc_efs:.3f}",
                        f"{precision_efs:.3f}", f"{recall_efs:.3f}", f"{auc_efs:.3f}"
                    ])
                except Exception as e:
                    # Best-effort: EFS metrics are optional output.
                    print(f"Warning: EFS metrics computation failed: {e}")

        # Assemble the display tables.
        metrics_df = pd.DataFrame(
            metrics_results,
            columns=['Outcome', 'Accuracy', 'Balanced Accuracy', 'Precision', 'Recall', 'AUC']
        )
        calibration_df = pd.DataFrame(
            calibration_results,
            columns=['Outcome', 'Slope', 'Intercept']
        )

        return metrics_df, calibration_df, calibration_plots

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing data: {str(e)}", None, None
|
|
|
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI for the evaluation pipeline.

    Lays out a CSV upload, metrics/calibration tables, and seven fixed
    calibration-curve plot slots (EFS first, matching the insert order used
    in train_and_evaluate).
    """
    load_training_data()

    with gr.Blocks(
        css="""
        .gradio-container { max-width: none !important; height: 100vh; overflow-y: auto; }
        .main-container { padding: 20px; }
        .big-title { font-size: 2.5em; font-weight: bold; margin-bottom: 30px; text-align: center; }
        .section-title { font-size: 2em; font-weight: bold; margin: 40px 0 20px 0; color: #2d5aa0; }
        .subsection-title{ font-size: 1.5em; font-weight: bold; margin: 30px 0 15px 0; color: #4a4a4a; }
        """,
        title="ML Model Evaluation Pipeline"
    ) as demo:

        with gr.Column(elem_classes=["main-container"]):
            gr.HTML('<div class="big-title">Input</div>')
            gr.Markdown("### Please upload the dataset:")
            # NOTE(review): type="filepath" passes a str path to the handler —
            # confirm train_and_evaluate accepts that shape.
            file_input = gr.File(label="Upload Dataset (CSV)", file_types=[".csv"], type="filepath")
            process_btn = gr.Button("Process Dataset", variant="primary", size="lg")

            gr.HTML('<div class="section-title">Outputs</div>')

            gr.HTML('<div class="subsection-title">Metrics</div>')
            metrics_table = gr.Dataframe(
                headers=["Outcome", "Accuracy", "Balanced Accuracy", "Precision", "Recall", "AUC"],
                interactive=False, wrap=True
            )

            gr.HTML('<div class="subsection-title">Calibration</div>')
            calibration_table = gr.Dataframe(
                headers=["Outcome", "Slope", "Intercept"],
                interactive=False, wrap=True
            )

            gr.Markdown("#### Calibration Curves")

            # One fixed plot slot per outcome, EFS first.
            plot_efs = gr.Plot(label="Event-Free Survival")
            plot_os = gr.Plot(label="Overall Survival")
            plot_gf = gr.Plot(label="Graft Failure")
            plot_agvhd = gr.Plot(label="Acute GVHD")
            plot_cgvhd = gr.Plot(label="Chronic GVHD")
            plot_voc = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
            plot_stroke = gr.Plot(label="Stroke Post-HCT")

            plots = [plot_efs, plot_os, plot_gf, plot_agvhd, plot_cgvhd, plot_voc, plot_stroke]

            def process_and_display(file):
                # Fan the pipeline results out to the fixed set of widgets.
                metrics_df, calibration_df, calibration_plots = train_and_evaluate(file)

                # A string result signals an error message for the metrics slot.
                if isinstance(metrics_df, str):
                    return (metrics_df, None) + tuple([None] * 7)

                plot_outputs = [None] * 7
                if calibration_plots:
                    for i, plot in enumerate(calibration_plots[:7]):
                        plot_outputs[i] = plot

                return (metrics_df, calibration_df, *plot_outputs)

            process_btn.click(
                fn=process_and_display,
                inputs=[file_input],
                outputs=[metrics_table, calibration_table] + plots
            )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the Gradio app and serve it (public share link, opens a browser).
    demo = create_interface()
    demo.launch(share=True, inbrowser=True, height=800, show_error=True)