import logging
import os
import pickle
import shutil

import gradio as gr
import nibabel as nib
import numpy as np
import pandas as pd


def load_model():
    '''
    Deserialize the trained pipeline stored in 'svm_pipeline.pkl'.

    The path is relative, so it is resolved against the current working
    directory.

    Output: The unpickled object (this app calls its ``predict`` method)
    '''
    # pickle can execute arbitrary code while loading: only ever point this at
    # a trusted, locally shipped file, never at user-supplied data.
    with open('svm_pipeline.pkl', 'rb') as pipeline_file:
        pipeline = pickle.load(pipeline_file)
    return pipeline


# Unpickle the pipeline once, when the module is executed; predict_region
# reuses this single instance for every request.
model = load_model()


def get_image_data(filepath):
    '''
    Load an image with nibabel and return its voxel array.

    Input: Path to an image file that ``nib.load`` can read

    Output: The image data as floating point values (``get_fdata()``)
    '''
    return nib.load(filepath).get_fdata()


def image_to_vector(image_data, atlas_data):
    '''
    Average an image's voxel time series within each atlas region and flatten
    the resulting region-by-time matrix into one vector.

    Input:
    - image_data: Array whose last axis is time; its leading (spatial) axes
      must have exactly the same shape as ``atlas_data``
    - atlas_data: Array holding one region label per voxel

    Output: 1-D array of length ``n_regions * n_timepoints``. It is ordered
    region by region (every label present in the atlas, ascending), with each
    region's time points in order.

    Raises: ValueError if the image's spatial shape does not match the atlas.
    '''
    image_data = np.asarray(image_data)
    atlas_data = np.asarray(atlas_data)

    # A mismatched image/atlas pair would otherwise pair voxels with the wrong
    # labels (or with no label at all) and quietly produce meaningless features.
    if image_data.ndim < 1 or image_data.shape[:-1] != atlas_data.shape:
        raise ValueError(
            f"Image shape {image_data.shape} is not compatible with atlas shape "
            f"{atlas_data.shape}: expected atlas shape + (n_timepoints,)"
        )

    time_dim = image_data.shape[-1]

    # One row per voxel, one column per time point.
    voxels_by_time = pd.DataFrame(
        image_data.reshape(-1, time_dim),
        columns=[f'time_{i}' for i in range(time_dim)],
    )
    voxel_labels = atlas_data.reshape(-1)

    # Mean time series per region. Grouping by the label array directly avoids
    # copying the whole image into a concatenated frame. groupby sorts labels
    # ascending and drops NaN labels; mean() skips NaN voxel values.
    regions_x_time = voxels_by_time.groupby(voxel_labels).mean()
    regions_x_time.index = [f'region_{label}' for label in regions_x_time.index]

    # Row-major flatten: all time points of the first region, then the next, ...
    return regions_x_time.to_numpy().reshape(-1)


def preprocess_and_extract_features(
    nifti_data, atlas_data, *, num_required_features=116
):
    '''
    Build the fixed-length, single-row feature matrix passed to the model.

    Input:
    - nifti_data: The NIfTI image data (passed straight to image_to_vector)
    - atlas_data: The atlas data (passed straight to image_to_vector)
    - num_required_features: Length the feature vector is truncated or
      zero-padded to; defaults to 116, the value this app was written with

    Output: Array of shape (1, num_required_features)

    Raises: ValueError if num_required_features is negative.
    '''
    if num_required_features < 0:
        raise ValueError(
            f"num_required_features must be non-negative, got {num_required_features}"
        )

    features = image_to_vector(nifti_data, atlas_data)

    # NOTE(review): image_to_vector returns the region-by-time matrix flattened
    # region by region (lowest label first). So whenever the image has at least
    # num_required_features time points, every value kept here comes from the
    # first region's time series. 116 is presumably the AAL region count;
    # confirm this truncation matches how the pickled model was trained.
    features = features[:num_required_features]
    if features.size < num_required_features:
        # Zero-pad short vectors on the right up to the required length.
        features = np.pad(features, (0, num_required_features - features.size), 'constant')

    return features.reshape(1, -1)


def predict_region(input_file):
    '''
    Gradio callback: classify one uploaded NIfTI image with the loaded model.

    Input: The uploaded file. This is either a filepath string or a
    tempfile-like object exposing ``.name``, depending on the Gradio version.
    It is None when nothing was uploaded.

    Output: ``str(prediction[0])`` from ``model.predict`` on success; otherwise
    a message starting with "Error: " that is shown in the UI instead.
    '''
    if input_file is None:
        return "Error: no file was uploaded"

    temp_file_path = None
    try:
        # Newer Gradio passes a path string; older versions pass an object with .name.
        source_path = os.fspath(getattr(input_file, "name", input_file))

        # nibabel decides how to read a file from its extension, and the
        # upload's temp name may not carry a usable one. Copy it to a name whose
        # suffix matches the actual content: gzip magic bytes -> ".nii.gz",
        # otherwise ".nii". An uncompressed file named *.nii.gz cannot be read.
        with open(source_path, 'rb') as upload:
            is_gzipped = upload.read(2) == b'\x1f\x8b'
        temp_file_path = source_path + (".nii.gz" if is_gzipped else ".nii")
        shutil.copy(source_path, temp_file_path)

        data = get_image_data(temp_file_path)

        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        features = preprocess_and_extract_features(data, atlas_data)
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # UI boundary: report the failure to the user as text, but keep the
        # full traceback in the server log so it can be diagnosed.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Error: {e}"
    finally:
        # Remove only the copy made above; the original upload is left alone.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


# Web UI: a single file upload goes in, predict_region's text result comes
# out; flagging is disabled.
interface = gr.Interface(
    title="Region Prediction",
    description="Upload a region image in NIfTI format to get the prediction.",
    fn=predict_region,
    inputs=gr.File(label="Region Image (NIfTI file)"),
    outputs="text",
    allow_flagging="never",
)

# Starts the server as soon as this module is executed (there is no
# __main__ guard, so importing the module also launches it).
interface.launch()