| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import joblib |
| import onnxruntime as ort |
| import os |
| import logging |
|
|
| |
# Timestamped, INFO-level log lines for loading and prediction diagnostics.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Column order used when building the model input. Presumably this matches the
# order used when the scaler/model were trained -- confirm if either is re-exported.
feature_names = [
    'Age',
    'Sex',
    'CD4+ T-cell count',
    'Viral load',
    'WBC count',
    'Hemoglobin',
    'Platelet count',
]

# Filled in by the model-loading code; they stay None/False if loading fails.
ort_session = scaler = None
model_loaded = scaler_loaded = False
# Load the ONNX model and the fitted scaler once, at import time. On any failure
# both objects end up as None and both *_loaded flags as False, so predict_risk()
# can report the problem instead of crashing.
try:
    # Resolve artifact paths next to this script rather than calling os.chdir(),
    # which would change the working directory of the entire process.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    logging.info(f"Loading model artifacts from: {script_dir}")

    model_path = os.path.join(script_dir, "hiv_model.onnx")
    scaler_path = os.path.join(script_dir, "hiv_scaler.pkl")

    # Fail fast with a message naming the specific missing file.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(scaler_path):
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    ort_session = ort.InferenceSession(model_path)
    # NOTE: joblib.load unpickles the file and can execute arbitrary code;
    # only deploy a scaler file from a trusted source.
    scaler = joblib.load(scaler_path)
except FileNotFoundError as e:
    # Logged once here; the raise sites above do not log separately.
    logging.error(f"File not found: {e}")
    ort_session = None
    scaler = None
except Exception as e:
    logging.exception(f"An error occurred while loading the model or scaler: {e}")
    ort_session = None
    scaler = None
else:
    logging.info("Model and scaler loaded successfully.")

# Derive the flags from the loaded objects so they can never disagree with them.
model_loaded = ort_session is not None
scaler_loaded = scaler is not None
|
|
def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """Predict the HIV risk probability for one set of blood-report values.

    Args:
        age, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count:
            Numeric values from the Gradio form. A cleared number field may
            arrive as ``None``; such inputs are rejected with a message.
        sex: ``"Female"`` is encoded as 0; any other value is encoded as 1.

    Returns:
        A display string: ``"HIV Risk Probability: <p>"`` with ``p`` clipped to
        [0, 1] and formatted to four decimals, or a message explaining why no
        prediction could be made. Errors are logged and reported in the string
        rather than raised.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Check the logs for errors. Ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    numeric_inputs = {
        'Age': age,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }
    # Reject missing values up front: a None here would reach the scaler/model
    # as a NaN rather than a real measurement.
    missing = [name for name, value in numeric_inputs.items() if value is None]
    if missing:
        return f"Please provide a value for: {', '.join(missing)}."

    try:
        row = {name: [value] for name, value in numeric_inputs.items()}
        row['Sex'] = [0 if sex == "Female" else 1]
        # Select columns via feature_names so the scaler/model see a fixed order.
        input_df = pd.DataFrame(row)[feature_names]

        scaled = scaler.transform(input_df)
        input_array = np.asarray(scaled, dtype=np.float32)

        input_name = ort_session.get_inputs()[0].name
        ort_outs = ort_session.run(None, {input_name: input_array})

        # NOTE(review): assumes the first model output is per-sample class
        # probabilities with index 1 = positive class -- confirm against the
        # exported ONNX graph.
        risk_probability = float(ort_outs[0][0][1])
    except Exception as e:
        logging.exception(f"An error occurred during prediction: {e}")
        return f"An error occurred during prediction: {e}. Check the logs for details."

    # Never present a NaN/inf model output as a valid (e.g. zero) risk value.
    if not np.isfinite(risk_probability):
        logging.error(f"Model returned a non-finite probability: {risk_probability}")
        return "The model returned an invalid probability. Check the logs for details."

    # A probability lies in [0, 1]; clip any numerical overshoot to that range.
    risk_probability = min(max(risk_probability, 0.0), 1.0)
    return f"HIV Risk Probability: {risk_probability:.4f}"
def _number_field(label, default):
    """Return a numeric Gradio input labelled *label* with initial value *default*."""
    return gr.Number(label=label, value=default)


# Form widgets. The Interface passes their values to predict_risk positionally,
# so the order of the inputs list there must match predict_risk's parameters.
age_input = _number_field("Age", 30)
sex_input = gr.Radio(choices=["Female", "Male"], value="Female", label="Sex")
cd4_input = _number_field("CD4+ T-cell count", 500)
viral_input = _number_field("Viral load", 10000)
wbc_input = _number_field("WBC count", 7000)
hemoglobin_input = _number_field("Hemoglobin", 14.0)
platelet_input = _number_field("Platelet count", 250000)
# Web UI: the inputs are passed to predict_risk in this order, and its returned
# string is shown as plain text.
iface = gr.Interface(
    fn=predict_risk,
    inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description=(
        "Enter blood report values to estimate HIV risk. "
        "This is a demonstration model and should not be used for medical advice. "
        "Low risk: <1% probability of HIV infection, "
        "Moderate risk: 1% to 5% probability, "
        "High risk: >5% probability"
    ),
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()