Spaces:
Sleeping
Sleeping
| """ | |
| DFU Suite - PWAT Scoring API | |
| Hugging Face Spaces deployment for wound analysis | |
| """ | |
import os
import warnings
import logging

# TensorFlow's native logger reads TF_CPP_MIN_LOG_LEVEL once, around the time it
# is first imported, so the variable has to be set BEFORE `tensorflow` (or
# `keras`, which imports TensorFlow as its backend) is imported. Setting it
# after those imports does not reliably silence the C++ log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Suppress Python warnings, including those emitted while importing the
# libraries below.
warnings.filterwarnings('ignore')

# The remaining imports intentionally come after the environment setup above.
import json
import tempfile

import numpy as np
import pandas as pd
import cv2
import nrrd
import gradio as gr
from PIL import Image
from joblib import load
import xgboost
import tensorflow as tf

# Use Keras 3.x standalone imports
import keras
from keras.models import load_model
from keras.layers import Conv2D, Multiply, Layer
import keras.backend as K
import radiomics

# Also silence the Python-side loggers of the ML libraries.
logging.getLogger('radiomics').setLevel(logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Directory next to this file where downloaded model artifacts are stored.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'modelos')
# Download models from HF Model Hub if not present
def download_models_from_hub(
    model_repo="mmarquezsa/dfu-suite-models",
    model_files=(
        "best_model.keras",
        "Categoria3.pkl",
        "Categoria4.joblib",
        "Categoria5.joblib",
        "Categoria6.pkl",
        "Categoria7.joblib",
        "Categoria8.joblib",
    ),
):
    """Download the ML model artifacts from the Hugging Face Model Hub.

    Each file in *model_files* that is not already present in ``MODEL_DIR`` is
    fetched from *model_repo* into ``MODEL_DIR``. This is best effort: a failed
    download is reported and the remaining files are still attempted.

    Args:
        model_repo: Hub repository id to download from.
        model_files: Filenames (relative to the repository root) to fetch.

    Returns:
        list[str]: Filenames that could not be downloaded (empty on success).
    """
    import traceback

    # Create the target directory first so it exists even if the Hub client
    # cannot be imported; later code lists this directory for diagnostics.
    os.makedirs(MODEL_DIR, exist_ok=True)

    from huggingface_hub import hf_hub_download

    failed = []
    for model_file in model_files:
        local_path = os.path.join(MODEL_DIR, model_file)
        # Skip if already downloaded
        if os.path.exists(local_path):
            print(f"[MODEL] {model_file} already exists", flush=True)
            continue
        print(f"[MODEL] Downloading {model_file} from HF Model Hub...", flush=True)
        try:
            # `local_dir_use_symlinks` is deprecated in huggingface_hub; with
            # `local_dir` set, current versions place real files there.
            hf_hub_download(
                repo_id=model_repo,
                filename=model_file,
                local_dir=MODEL_DIR,
            )
        except Exception as e:
            print(f"[MODEL] ERROR downloading {model_file}: {e}", flush=True)
            traceback.print_exc()
            failed.append(model_file)
        else:
            print(f"[MODEL] Downloaded: {model_file}", flush=True)
    return failed
# Download models before loading: fetch any missing artifacts once, at import
# time, and keep going even if the download step blows up entirely.
print("[MODEL] Checking/downloading models from HF Model Hub...", flush=True)
try:
    download_models_from_hub()
except Exception as download_error:
    import traceback
    print(f"[MODEL] FATAL ERROR in download_models_from_hub: {download_error}", flush=True)
    traceback.print_exc()
else:
    print("[MODEL] Model download check complete", flush=True)
# Custom Keras layers
class SpatialAttention(Layer):
    """Spatial attention gate used by the segmentation model.

    A ``Conv2D`` with ``filters`` output channels, a ``kernel_size`` kernel,
    'same' padding and ``activation`` is applied to the input. Its output is
    multiplied element-wise (with broadcasting) with the input.

    Args:
        kernel_size: Convolution kernel size.
        filters: Number of attention channels produced by the convolution.
        activation: Activation applied to the attention map.
    """

    def __init__(self, kernel_size=7, filters=1, activation='sigmoid', **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.filters = filters
        self.activation = activation

    def build(self, input_shape):
        self.conv1 = Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            padding='same',
            activation=self.activation,
            kernel_initializer='he_normal',
            use_bias=False,
        )
        # Build the sub-layer now so its kernel exists as soon as this layer is
        # built. When a saved model is loaded, Keras calls build() and then
        # restores weights; a lazily-built Conv2D would have no variable to
        # receive the saved kernel.
        self.conv1.build(input_shape)
        super().build(input_shape)

    def call(self, inputs):
        attention = self.conv1(inputs)
        # Broadcasted element-wise product. This matches Multiply()([...]) but
        # does not instantiate a new layer object on every call.
        return inputs * attention

    def get_config(self):
        config = super().get_config()
        config.update({
            'kernel_size': self.kernel_size,
            'filters': self.filters,
            'activation': self.activation,
        })
        return config
# Custom metrics for model loading
def dice_coefficient(y_true, y_pred):
    """Smoothed Dice coefficient between the flattened true and predicted masks.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` over all
    elements. TensorFlow ops are used directly because Keras 3's
    ``keras.backend`` no longer provides the legacy tensor helpers
    (``flatten``, ``sum``, ...).
    """
    smooth = 1e-6
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    return (2.0 * intersection + smooth) / denominator
def iou_metric(y_true, y_pred):
    """Intersection-over-union of the masks binarised at 0.5.

    Both inputs are thresholded with ``> 0.5`` before the intersection and
    union are summed over all elements. TensorFlow ops replace the legacy
    ``keras.backend`` helpers, which Keras 3 no longer exposes.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_bin = tf.cast(tf.greater(y_true, 0.5), K.floatx())
    y_pred_bin = tf.cast(tf.greater(y_pred, 0.5), K.floatx())
    intersection = tf.reduce_sum(y_true_bin * y_pred_bin)
    union = tf.reduce_sum(y_true_bin) + tf.reduce_sum(y_pred_bin) - intersection
    return intersection / (union + K.epsilon())
def precision_metric(y_true, y_pred):
    """Pixel precision: rounded true positives over rounded predicted positives.

    Values are clipped to [0, 1] and rounded before summing. TensorFlow ops
    replace the legacy ``keras.backend`` helpers, which Keras 3 no longer
    exposes.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    true_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0.0, 1.0)))
    predicted_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_pred, 0.0, 1.0)))
    return true_positives / (predicted_positives + K.epsilon())
def recall_metric(y_true, y_pred):
    """Pixel recall: rounded true positives over rounded actual positives.

    Values are clipped to [0, 1] and rounded before summing. TensorFlow ops
    replace the legacy ``keras.backend`` helpers, which Keras 3 no longer
    exposes.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    true_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0.0, 1.0)))
    possible_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_true, 0.0, 1.0)))
    return true_positives / (possible_positives + K.epsilon())
def f1_score(y_true, y_pred):
    """F1 score: harmonic mean of precision_metric and recall_metric.

    An epsilon is added to the denominator to avoid division by zero.
    """
    p = precision_metric(y_true, y_pred)
    r = recall_metric(y_true, y_pred)
    denominator = p + r + K.epsilon()
    return 2.0 * (p * r / denominator)
def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=4/3):
    """Focal Tversky loss, averaged over the batch.

    Per sample: ``tversky = (tp + eps) / (tp + alpha*fp + beta*fn + eps)``, with
    tp/fp/fn summed over axes 1-3. The loss is ``mean((1 - tversky) ** gamma)``.

    NOTE(review): here ``alpha`` weights false positives and ``beta`` false
    negatives; many references use the opposite assignment. It is kept as-is
    because the saved model was trained with this definition.

    Assumes rank-4 inputs (reduction over axes 1, 2, 3), presumably
    (batch, H, W, C) — confirm against the model output. TensorFlow ops replace
    the legacy ``keras.backend`` helpers, which Keras 3 no longer exposes.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    fp = tf.reduce_sum((1 - y_true) * y_pred, axis=[1, 2, 3])
    fn = tf.reduce_sum(y_true * (1 - y_pred), axis=[1, 2, 3])
    tversky = (tp + K.epsilon()) / (tp + alpha * fp + beta * fn + K.epsilon())
    focal_tversky = tf.pow(1 - tversky, gamma)
    return tf.reduce_mean(focal_tversky)
def combined_loss(y_true, y_pred):
    """Training loss: focal Tversky loss plus binary cross-entropy."""
    tversky_term = focal_tversky_loss(y_true, y_pred)
    bce = keras.losses.BinaryCrossentropy()
    return tversky_term + bce(y_true, y_pred)
# Load models
import sys

print("[MODEL] Loading models into memory...", flush=True)

# Load segmentation model. `segmentation_model` stays None when the file is
# missing or fails to load; segment_wound() checks for that.
segmentation_model = None
seg_model_path = os.path.join(MODEL_DIR, 'best_model.keras')
print(f"[MODEL] Checking segmentation model at: {seg_model_path}", flush=True)
if os.path.exists(seg_model_path):
    file_size_mb = os.path.getsize(seg_model_path) / (1024 * 1024)
    print(f"[MODEL] Found best_model.keras ({file_size_mb:.1f} MB)", flush=True)
    print("[MODEL] Loading segmentation model (this may take 30-60 seconds)...", flush=True)
    try:
        # The saved model refers to these custom layers/metrics/losses by name.
        segmentation_model = load_model(seg_model_path, custom_objects={
            'SpatialAttention': SpatialAttention,
            'dice_coefficient': dice_coefficient,
            'iou_metric': iou_metric,
            'precision_metric': precision_metric,
            'recall_metric': recall_metric,
            'f1_score': f1_score,
            'combined_loss': combined_loss,
            'focal_tversky_loss': focal_tversky_loss,
        })
    except Exception as e:
        import traceback
        print(f"[MODEL] ERROR - Failed to load segmentation model: {e}", flush=True)
        traceback.print_exc()
    else:
        print("[MODEL] OK - Segmentation model loaded successfully", flush=True)
        print(f"[MODEL] Model type: {type(segmentation_model)}", flush=True)
else:
    print(f"[MODEL] ERROR - Segmentation model not found at {seg_model_path}", flush=True)
    # MODEL_DIR may not exist if the download step failed before creating it;
    # listing it unconditionally would crash the app at import time.
    if os.path.isdir(MODEL_DIR):
        print(f"[MODEL] Directory contents: {os.listdir(MODEL_DIR)}", flush=True)
    else:
        print(f"[MODEL] Directory does not exist: {MODEL_DIR}", flush=True)
sys.stdout.flush()
# Load PWAT classification models
def load_xgboost_model(model_name):
    """Load an XGBoost PWAT model from MODEL_DIR.

    ``<model_name>.json`` (native XGBoost format) is preferred when present;
    otherwise ``<model_name>.pkl`` is loaded with joblib.

    Args:
        model_name: Base filename without extension, e.g. ``"Categoria3"``.

    Returns:
        tuple: ``(model, kind)``. ``kind`` is ``'xgboost_json'`` for an
        ``xgboost.Booster`` or ``'xgboost_pkl'`` for the unpickled object.
        Returns ``(None, None)`` if neither file could be loaded.
    """
    json_path = os.path.join(MODEL_DIR, f"{model_name}.json")
    pkl_path = os.path.join(MODEL_DIR, f"{model_name}.pkl")

    # Only attempt the JSON format when such a file is actually present.
    if os.path.exists(json_path):
        try:
            modelo = xgboost.Booster()
            modelo.load_model(json_path)
            return modelo, 'xgboost_json'
        except Exception as e:
            print(f"[MODEL] WARNING - could not load {json_path}: {e}; trying pickle", flush=True)

    try:
        # NOTE(security): joblib uses pickle, which can execute code on load;
        # only load model files from a trusted source.
        modelo = load(pkl_path)
        return modelo, 'xgboost_pkl'
    except Exception as e:
        print(f"[MODEL] ERROR loading {model_name}: {e}", flush=True)
        return None, None
# Load all PWAT models
def _load_joblib_model(filename):
    """Load a joblib-serialised model from MODEL_DIR, or return None.

    None is returned (after printing the reason) when the file is missing or
    cannot be deserialised. This mirrors load_xgboost_model, which also returns
    None on failure, so one bad artifact does not stop the app from starting.
    """
    path = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(path):
        print(f"[MODEL] ERROR - {filename} not found in {MODEL_DIR}", flush=True)
        return None
    try:
        # NOTE(security): joblib uses pickle; only load trusted model files.
        return load(path)
    except Exception as e:
        print(f"[MODEL] ERROR loading {filename}: {e}", flush=True)
        return None


def _report_model(name, model, kind=None):
    """Print whether a PWAT model was loaded, with its loader kind when known."""
    if model is None:
        print(f"[MODEL] FAILED - {name} not loaded", flush=True)
    elif kind:
        print(f"[MODEL] OK - {name} loaded ({kind})", flush=True)
    else:
        print(f"[MODEL] OK - {name} loaded", flush=True)


print("[MODEL] Loading PWAT classification models...", flush=True)

print("[MODEL] Loading Categoria3...", flush=True)
Categoria3, tipo_cat3 = load_xgboost_model("Categoria3")
_report_model("Categoria3", Categoria3, tipo_cat3)

print("[MODEL] Loading Categoria4...", flush=True)
Categoria4 = _load_joblib_model("Categoria4.joblib")
_report_model("Categoria4", Categoria4)

print("[MODEL] Loading Categoria5...", flush=True)
Categoria5 = _load_joblib_model("Categoria5.joblib")
_report_model("Categoria5", Categoria5)

print("[MODEL] Loading Categoria6...", flush=True)
Categoria6, tipo_cat6 = load_xgboost_model("Categoria6")
_report_model("Categoria6", Categoria6, tipo_cat6)

print("[MODEL] Loading Categoria7...", flush=True)
Categoria7 = _load_joblib_model("Categoria7.joblib")
_report_model("Categoria7", Categoria7)

print("[MODEL] Loading Categoria8...", flush=True)
Categoria8 = _load_joblib_model("Categoria8.joblib")
_report_model("Categoria8", Categoria8)

# Summarise honestly: only claim success when every model is actually present.
_missing_models = [
    name
    for name, model in (
        ("best_model.keras (segmentation)", segmentation_model),
        ("Categoria3", Categoria3),
        ("Categoria4", Categoria4),
        ("Categoria5", Categoria5),
        ("Categoria6", Categoria6),
        ("Categoria7", Categoria7),
        ("Categoria8", Categoria8),
    )
    if model is None
]
print("[MODEL] ====================================", flush=True)
if _missing_models:
    print(f"[MODEL] WARNING - models not available: {', '.join(_missing_models)}", flush=True)
else:
    print("[MODEL] ALL MODELS LOADED SUCCESSFULLY", flush=True)
print("[MODEL] ====================================", flush=True)
def segment_wound(image, size=(256, 256), threshold=0.5):
    """Generate a binary wound segmentation mask for *image*.

    Args:
        image: PIL image of the wound.
        size: ``(width, height)`` the image is resized to before inference.
            The default matches the original 256x256 input.
        threshold: Pixels whose predicted value exceeds this become foreground.

    Returns:
        A PIL image of size *size* with values 0 (background) and 255 (wound),
        or None if the segmentation model is not loaded.
    """
    if segmentation_model is None:
        return None
    # Normalise to 3-channel RGB so RGBA/grayscale/palette images do not yield
    # an array with the wrong number of channels. Assumes the model takes RGB
    # input — the RGB conversion elsewhere in this app suggests so; confirm.
    img = image.convert('RGB').resize(size)
    img_batch = np.expand_dims(np.asarray(img) / 255.0, axis=0)
    pred = segmentation_model.predict(img_batch, verbose=0)
    mask = (pred[0] > threshold).astype(np.uint8) * 255
    # A 2-D uint8 array is converted to an 'L' image automatically; the explicit
    # `mode=` argument of fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(mask.squeeze())
# Radiomics output entries that are not used as model features (library
# versions, configuration, hashes and geometry diagnostics).
_EXCLUDED_RADIOMICS_KEYS = frozenset([
    'diagnostics_Versions_PyRadiomics', 'diagnostics_Versions_Numpy',
    'diagnostics_Versions_SimpleITK', 'diagnostics_Versions_PyWavelet',
    'diagnostics_Versions_Python', 'diagnostics_Configuration_Settings',
    'diagnostics_Image-original_Spacing', 'diagnostics_Image-original_Size',
    'diagnostics_Image-original_Mean', 'diagnostics_Image-original_Minimum',
    'diagnostics_Image-original_Maximum', 'diagnostics_Mask-original_Hash',
    'diagnostics_Mask-original_Spacing', 'diagnostics_Mask-original_Size',
    'diagnostics_Configuration_EnabledImageTypes', 'diagnostics_Image-original_Hash',
    'diagnostics_Image-original_Dimensionality', 'diagnostics_Mask-original_CenterOfMass',
    'diagnostics_Mask-original_BoundingBox', 'diagnostics_Mask-original_CenterOfMassIndex',
])


def _read_grayscale_pair(image, mask, tmpdir):
    """Write image and mask to *tmpdir* and read them back as 8-bit grayscale arrays.

    The image is stored as JPEG (RGB first, since JPEG cannot hold alpha or
    palette modes). The mask is stored as lossless PNG: JPEG compression adds
    small non-zero artefacts around the mask edges, which the later ``> 0``
    binarisation would turn into spurious foreground pixels.
    Returns ``(img, mask_arr)``; either may be None if OpenCV cannot read it.
    """
    img_path = os.path.join(tmpdir, "image.jpg")
    mask_path = os.path.join(tmpdir, "mask.png")
    image.convert('RGB').save(img_path)
    mask.save(mask_path)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    mask_arr = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    return img, mask_arr


def _radiomics_feature_frame(img, mask_arr, tmpdir):
    """Run PyRadiomics on *img*/*mask_arr* and return a one-row numeric DataFrame."""
    # Import the submodule explicitly; `import radiomics` alone does not
    # guarantee `radiomics.featureextractor` is loaded.
    from radiomics import featureextractor

    img_nrrd = os.path.join(tmpdir, "image.nrrd")
    mask_nrrd = os.path.join(tmpdir, "mask.nrrd")
    nrrd.write(img_nrrd, img)
    nrrd.write(mask_nrrd, mask_arr)

    extractor = featureextractor.RadiomicsFeatureExtractor()
    result = extractor.execute(img_nrrd, mask_nrrd)

    filtered_features = {k: v for k, v in result.items() if k not in _EXCLUDED_RADIOMICS_KEYS}
    df = pd.DataFrame([filtered_features])
    # Remove non-numeric columns
    if 'imagen' in df.columns:
        df = df.drop(['imagen'], axis=1)
    # The first two remaining columns are dropped so the column layout matches
    # what the classifiers were trained on (presumably leftover diagnostics
    # entries — verify against the training pipeline before changing).
    if len(df.columns) > 2:
        df = df.drop(df.columns[:2], axis=1)
    for col in df.columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df.fillna(0)


def _score_from_xgboost_output(prediccion):
    """Turn raw XGBoost output into an integer score.

    Class-probability vectors map to ``argmax + 1`` (1-based class). A single
    value is rounded to the nearest integer.
    """
    if isinstance(prediccion, np.ndarray):
        if prediccion.ndim == 2 and prediccion.shape[1] > 1:
            return int(np.argmax(prediccion[0])) + 1
        if prediccion.ndim == 1 and len(prediccion) > 1:
            return int(np.argmax(prediccion)) + 1
        return int(round(float(prediccion.flatten()[0])))
    return int(round(float(prediccion)))


def _predict_category(modelo, tipo, df):
    """Return the integer score one category model predicts for *df* (0 if the model is missing)."""
    if modelo is None:
        return 0
    if tipo and tipo.startswith('xgboost'):
        data_array = df.values.flatten().astype(np.float64).reshape(1, -1)
        if tipo == 'xgboost_json':
            prediccion = modelo.predict(xgboost.DMatrix(data_array))
        else:
            prediccion = modelo.predict_proba(data_array)
        return _score_from_xgboost_output(prediccion)
    # sklearn-style estimator
    data_array = df.values.astype(np.float32)
    if data_array.ndim == 1:
        data_array = data_array.reshape(1, -1)
    prediccion = modelo.predict(data_array)
    return int(prediccion[0]) if hasattr(prediccion, '__len__') else int(prediccion)


def predict_pwat(image, mask):
    """Predict PWAT category scores from a wound image and its mask.

    Radiomics features are extracted from the 256x256 grayscale image inside
    the binarised mask. Each category model (Categoria3-8) then scores them.

    Args:
        image: PIL image of the wound.
        mask: PIL image of the wound mask; any non-zero pixel counts as wound.

    Returns:
        dict: ``{"Cat1": 0, ..., "Cat8": n}``. Categories 1-2 are not predicted
        and are always 0; a missing or failing category model also yields 0.
        Returns ``{"error": message}`` if the inputs cannot be read, the mask
        is empty, or feature extraction fails.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        img, mask_arr = _read_grayscale_pair(image, mask, tmpdir)
        if img is None or mask_arr is None:
            return {"error": "Could not load images"}
        if np.max(mask_arr) == 0:
            return {"error": "Mask is empty"}

        # Normalize and resize
        mask_arr = (mask_arr > 0).astype(np.uint8)
        img = cv2.resize(img, (256, 256))
        mask_arr = cv2.resize(mask_arr, (256, 256), interpolation=cv2.INTER_NEAREST)
        # Nearest-neighbour downsampling can remove a very small mask entirely.
        if not mask_arr.any():
            return {"error": "Mask is empty"}

        try:
            df = _radiomics_feature_frame(img, mask_arr, tmpdir)
        except Exception as e:
            return {"error": f"Feature extraction failed: {e}"}

    modelos = [Categoria3, Categoria4, Categoria5, Categoria6, Categoria7, Categoria8]
    tipos_modelo = [tipo_cat3, 'sklearn', 'sklearn', tipo_cat6, 'sklearn', 'sklearn']
    results_dict = {"Cat1": 0, "Cat2": 0}  # Categories 1-2 not predicted by ML
    for cat_num, modelo, tipo in zip(range(3, 9), modelos, tipos_modelo):
        try:
            score = _predict_category(modelo, tipo, df)
        except Exception as e:
            print(f"Error predicting category {cat_num}: {e}")
            score = 0
        results_dict[f"Cat{cat_num}"] = score
    return results_dict
def process_wound(image, mask=None):
    """Main processing function: segment (if needed) and score a wound image.

    Args:
        image: Wound photo as a PIL image or NumPy array.
        mask: Optional wound mask as a PIL image or NumPy array. When None,
            the segmentation model generates one.

    Returns:
        tuple: ``(mask, results)``. ``mask`` is the PIL mask that was used
        (None on early errors). ``results`` is the dict from predict_pwat, or
        an ``{"error": ...}`` dict.
    """
    if image is None:
        return None, {"error": "No image provided"}

    # Normalise inputs regardless of how they arrive (array or PIL image in any
    # mode, e.g. RGBA): downstream code expects an RGB image and an 'L' mask.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    if mask is None:
        mask = segment_wound(image)
        if mask is None:
            return None, {"error": "Segmentation model not available"}
    else:
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)
        mask = mask.convert('L')

    results = predict_pwat(image, mask)
    return mask, results
# Gradio interface
with gr.Blocks(title="DFU Suite - PWAT Scoring") as demo:
    gr.Markdown("# DFU Suite - Wound Analysis API")
    gr.Markdown("Upload a wound image to get automatic segmentation and PWAT scoring.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Wound Image", type="pil")
            input_mask = gr.Image(label="Mask (optional)", type="pil")
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            output_mask = gr.Image(label="Segmentation Mask")
            output_json = gr.JSON(label="PWAT Scores")
    # No api_name is given, so Gradio exposes this event under the function's
    # name ("/process_wound"); the usage notes below document that endpoint.
    submit_btn.click(
        fn=process_wound,
        inputs=[input_image, input_mask],
        outputs=[output_mask, output_json],
    )
    gr.Markdown("""
## API Usage

```python
from gradio_client import Client, handle_file

client = Client("YOUR_SPACE_URL")
result = client.predict(
    image=handle_file("path/to/image.jpg"),
    mask=None,  # Optional
    api_name="/process_wound"
)
```
""")
if __name__ == "__main__":
    import sys

    # Startup banner, then hand control to the Gradio server.
    separator = "[GRADIO] ===================================="
    for banner_line in (
        separator,
        "[GRADIO] LAUNCHING GRADIO APPLICATION",
        "[GRADIO] Application is now ready to use!",
        separator,
    ):
        print(banner_line, flush=True)
    sys.stdout.flush()
    demo.launch()