""" DFU Suite - PWAT Scoring API Hugging Face Spaces deployment for wound analysis """ import os import json import warnings import logging import tempfile import numpy as np import pandas as pd import cv2 import nrrd import gradio as gr from PIL import Image from joblib import load import xgboost import tensorflow as tf # Use Keras 3.x standalone imports import keras from keras.models import load_model from keras.layers import Conv2D, Multiply, Layer import keras.backend as K import radiomics # Suppress warnings os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' warnings.filterwarnings('ignore') logging.getLogger('radiomics').setLevel(logging.ERROR) logging.getLogger('tensorflow').setLevel(logging.ERROR) # Model directory MODEL_DIR = os.path.join(os.path.dirname(__file__), 'modelos') # Download models from HF Model Hub if not present def download_models_from_hub(): """Download ML models from Hugging Face Model Hub""" import sys from huggingface_hub import hf_hub_download model_repo = "mmarquezsa/dfu-suite-models" model_files = [ "best_model.keras", "Categoria3.pkl", "Categoria4.joblib", "Categoria5.joblib", "Categoria6.pkl", "Categoria7.joblib", "Categoria8.joblib", ] os.makedirs(MODEL_DIR, exist_ok=True) for model_file in model_files: local_path = os.path.join(MODEL_DIR, model_file) # Skip if already downloaded if os.path.exists(local_path): print(f"[MODEL] {model_file} already exists", flush=True) continue try: print(f"[MODEL] Downloading {model_file} from HF Model Hub...", flush=True) sys.stdout.flush() downloaded_path = hf_hub_download( repo_id=model_repo, filename=model_file, local_dir=MODEL_DIR, local_dir_use_symlinks=False ) print(f"[MODEL] Downloaded: {model_file}", flush=True) sys.stdout.flush() except Exception as e: print(f"[MODEL] ERROR downloading {model_file}: {e}", flush=True) sys.stdout.flush() import traceback traceback.print_exc() # Download models before loading print("[MODEL] Checking/downloading models from HF Model Hub...", flush=True) try: download_models_from_hub() print("[MODEL] Model download check complete", flush=True) except Exception as e: print(f"[MODEL] FATAL ERROR in download_models_from_hub: {e}", flush=True) import traceback traceback.print_exc() # Custom Keras layers class SpatialAttention(Layer): def __init__(self, kernel_size=7, filters=1, activation='sigmoid', **kwargs): super(SpatialAttention, self).__init__(**kwargs) self.kernel_size = kernel_size self.filters = filters self.activation = activation def build(self, input_shape): self.conv1 = Conv2D( filters=self.filters, kernel_size=self.kernel_size, padding='same', activation=self.activation, kernel_initializer='he_normal', use_bias=False ) super(SpatialAttention, self).build(input_shape) def call(self, inputs): attention = self.conv1(inputs) return Multiply()([inputs, attention]) def get_config(self): config = super(SpatialAttention, self).get_config() config.update({ 'kernel_size': self.kernel_size, 'filters': self.filters, 'activation': self.activation }) return config # Custom metrics for model loading def dice_coefficient(y_true, y_pred): smooth = 1e-6 y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) y_true_f = K.flatten(y_true) y_pred_f = K.flatten(y_pred) intersection = K.sum(y_true_f * y_pred_f) return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) def iou_metric(y_true, y_pred): y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) y_true_bin = K.cast(K.greater(y_true, 0.5), K.floatx()) y_pred_bin = K.cast(K.greater(y_pred, 0.5), K.floatx()) intersection = K.sum(y_true_bin * y_pred_bin) union = K.sum(y_true_bin) + K.sum(y_pred_bin) - intersection return intersection / (union + K.epsilon()) def precision_metric(y_true, y_pred): y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) return true_positives / (predicted_positives + K.epsilon()) def recall_metric(y_true, y_pred): y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) return true_positives / (possible_positives + K.epsilon()) def f1_score(y_true, y_pred): prec = precision_metric(y_true, y_pred) rec = recall_metric(y_true, y_pred) return 2.0 * ((prec * rec) / (prec + rec + K.epsilon())) def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=4/3): y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) tp = K.sum(y_true * y_pred, axis=[1, 2, 3]) fp = K.sum((1-y_true) * y_pred, axis=[1, 2, 3]) fn = K.sum(y_true * (1-y_pred), axis=[1, 2, 3]) tversky = (tp + K.epsilon()) / (tp + alpha*fp + beta*fn + K.epsilon()) focal_tversky = K.pow((1 - tversky), gamma) return K.mean(focal_tversky) def combined_loss(y_true, y_pred): return focal_tversky_loss(y_true, y_pred) + keras.losses.BinaryCrossentropy()(y_true, y_pred) # Load models print("[MODEL] Loading models into memory...", flush=True) import sys sys.stdout.flush() # Load segmentation model segmentation_model = None seg_model_path = os.path.join(MODEL_DIR, 'best_model.keras') print(f"[MODEL] Checking segmentation model at: {seg_model_path}", flush=True) sys.stdout.flush() if os.path.exists(seg_model_path): file_size_mb = os.path.getsize(seg_model_path) / (1024 * 1024) print(f"[MODEL] Found best_model.keras ({file_size_mb:.1f} MB)", flush=True) print(f"[MODEL] Loading segmentation model (this may take 30-60 seconds)...", flush=True) sys.stdout.flush() try: segmentation_model = load_model(seg_model_path, custom_objects={ 'SpatialAttention': SpatialAttention, 'dice_coefficient': dice_coefficient, 'iou_metric': iou_metric, 'precision_metric': precision_metric, 'recall_metric': recall_metric, 'f1_score': f1_score, 'combined_loss': combined_loss, 'focal_tversky_loss': focal_tversky_loss }) print("[MODEL] OK - Segmentation model loaded successfully", flush=True) print(f"[MODEL] Model type: {type(segmentation_model)}", flush=True) sys.stdout.flush() except Exception as e: print(f"[MODEL] ERROR - Failed to load segmentation model: {e}", flush=True) sys.stdout.flush() import traceback traceback.print_exc() sys.stdout.flush() else: print(f"[MODEL] ERROR - Segmentation model not found at {seg_model_path}", flush=True) print(f"[MODEL] Directory contents: {os.listdir(MODEL_DIR)}", flush=True) sys.stdout.flush() # Load PWAT classification models def load_xgboost_model(model_name): json_path = os.path.join(MODEL_DIR, f"{model_name}.json") pkl_path = os.path.join(MODEL_DIR, f"{model_name}.pkl") try: modelo = xgboost.Booster() modelo.load_model(json_path) return modelo, 'xgboost_json' except: try: modelo = load(pkl_path) return modelo, 'xgboost_pkl' except Exception as e: print(f"[MODEL] ERROR loading {model_name}: {e}", flush=True) sys.stdout.flush() return None, None # Load all PWAT models print("[MODEL] Loading PWAT classification models...", flush=True) sys.stdout.flush() print("[MODEL] Loading Categoria3...", flush=True) Categoria3, tipo_cat3 = load_xgboost_model("Categoria3") print(f"[MODEL] OK - Categoria3 loaded ({tipo_cat3})", flush=True) print("[MODEL] Loading Categoria4...", flush=True) Categoria4 = load(os.path.join(MODEL_DIR, "Categoria4.joblib")) if os.path.exists(os.path.join(MODEL_DIR, "Categoria4.joblib")) else None print("[MODEL] OK - Categoria4 loaded", flush=True) print("[MODEL] Loading Categoria5...", flush=True) Categoria5 = load(os.path.join(MODEL_DIR, "Categoria5.joblib")) if os.path.exists(os.path.join(MODEL_DIR, "Categoria5.joblib")) else None print("[MODEL] OK - Categoria5 loaded", flush=True) print("[MODEL] Loading Categoria6...", flush=True) Categoria6, tipo_cat6 = load_xgboost_model("Categoria6") print(f"[MODEL] OK - Categoria6 loaded ({tipo_cat6})", flush=True) print("[MODEL] Loading Categoria7...", flush=True) Categoria7 = load(os.path.join(MODEL_DIR, "Categoria7.joblib")) if os.path.exists(os.path.join(MODEL_DIR, "Categoria7.joblib")) else None print("[MODEL] OK - Categoria7 loaded", flush=True) print("[MODEL] Loading Categoria8...", flush=True) Categoria8 = load(os.path.join(MODEL_DIR, "Categoria8.joblib")) if os.path.exists(os.path.join(MODEL_DIR, "Categoria8.joblib")) else None print("[MODEL] OK - Categoria8 loaded", flush=True) print("[MODEL] ====================================", flush=True) print("[MODEL] ALL MODELS LOADED SUCCESSFULLY", flush=True) print("[MODEL] ====================================", flush=True) sys.stdout.flush() def segment_wound(image): """Generate wound segmentation mask from image""" if segmentation_model is None: return None # Preprocess img = image.resize((256, 256)) img_array = np.array(img) / 255.0 img_batch = np.expand_dims(img_array, axis=0) # Predict pred = segmentation_model.predict(img_batch, verbose=0) mask = (pred[0] > 0.5).astype(np.uint8) * 255 return Image.fromarray(mask.squeeze(), mode='L') def predict_pwat(image, mask): """Predict PWAT scores from image and mask""" with tempfile.TemporaryDirectory() as tmpdir: # Save images temporarily img_path = os.path.join(tmpdir, "image.jpg") mask_path = os.path.join(tmpdir, "mask.jpg") image.save(img_path) mask.save(mask_path) # Load and preprocess img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) mask_arr = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if img is None or mask_arr is None: return {"error": "Could not load images"} if np.max(mask_arr) == 0: return {"error": "Mask is empty"} # Normalize and resize mask_arr = (mask_arr > 0).astype(np.uint8) img = cv2.resize(img, (256, 256)) mask_arr = cv2.resize(mask_arr, (256, 256), interpolation=cv2.INTER_NEAREST) # Save as NRRD for pyradiomics img_nrrd = os.path.join(tmpdir, "image.nrrd") mask_nrrd = os.path.join(tmpdir, "mask.nrrd") nrrd.write(img_nrrd, img) nrrd.write(mask_nrrd, mask_arr) # Extract radiomics features extractor = radiomics.featureextractor.RadiomicsFeatureExtractor() result = extractor.execute(img_nrrd, mask_nrrd) # Filter features keys_to_exclude = [ 'diagnostics_Versions_PyRadiomics', 'diagnostics_Versions_Numpy', 'diagnostics_Versions_SimpleITK', 'diagnostics_Versions_PyWavelet', 'diagnostics_Versions_Python', 'diagnostics_Configuration_Settings', 'diagnostics_Image-original_Spacing', 'diagnostics_Image-original_Size', 'diagnostics_Image-original_Mean', 'diagnostics_Image-original_Minimum', 'diagnostics_Image-original_Maximum', 'diagnostics_Mask-original_Hash', 'diagnostics_Mask-original_Spacing', 'diagnostics_Mask-original_Size', 'diagnostics_Configuration_EnabledImageTypes', 'diagnostics_Image-original_Hash', 'diagnostics_Image-original_Dimensionality', 'diagnostics_Mask-original_CenterOfMass', 'diagnostics_Mask-original_BoundingBox', 'diagnostics_Mask-original_CenterOfMassIndex' ] filtered_features = {k: v for k, v in result.items() if k not in keys_to_exclude} df = pd.DataFrame([filtered_features]) # Remove non-numeric columns if 'imagen' in df.columns: df = df.drop(['imagen'], axis=1) if len(df.columns) > 2: df = df.drop(df.columns[:2], axis=1) for col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce') df = df.fillna(0) # Predict each category modelos = [Categoria3, Categoria4, Categoria5, Categoria6, Categoria7, Categoria8] tipos_modelo = [tipo_cat3, 'sklearn', 'sklearn', tipo_cat6, 'sklearn', 'sklearn'] resultados = [] for modelo, cat_num, tipo in zip(modelos, range(3, 9), tipos_modelo): try: if modelo is None: resultados.append(0) continue if tipo and tipo.startswith('xgboost'): import xgboost as xgb data_array = df.values.flatten().astype(np.float64).reshape(1, -1) if tipo == 'xgboost_json': dmatrix = xgb.DMatrix(data_array) prediccion = modelo.predict(dmatrix) else: prediccion = modelo.predict_proba(data_array) if isinstance(prediccion, np.ndarray): if prediccion.ndim == 2 and prediccion.shape[1] > 1: resultado = int(np.argmax(prediccion[0])) + 1 elif prediccion.ndim == 1 and len(prediccion) > 1: resultado = int(np.argmax(prediccion)) + 1 else: resultado = int(round(float(prediccion.flatten()[0]))) else: resultado = int(round(float(prediccion))) else: data_array = df.values.astype(np.float32) if data_array.ndim == 1: data_array = data_array.reshape(1, -1) prediccion = modelo.predict(data_array) resultado = int(prediccion[0]) if hasattr(prediccion, '__len__') else int(prediccion) resultados.append(resultado) except Exception as e: print(f"Error predicting category {cat_num}: {e}") resultados.append(0) # Build results results_dict = { "Cat1": 0, # Categories 1-2 not predicted by ML "Cat2": 0, "Cat3": resultados[0] if len(resultados) > 0 else 0, "Cat4": resultados[1] if len(resultados) > 1 else 0, "Cat5": resultados[2] if len(resultados) > 2 else 0, "Cat6": resultados[3] if len(resultados) > 3 else 0, "Cat7": resultados[4] if len(resultados) > 4 else 0, "Cat8": resultados[5] if len(resultados) > 5 else 0, } return results_dict def process_wound(image, mask=None): """Main processing function""" if image is None: return None, {"error": "No image provided"} # Convert to PIL if needed if isinstance(image, np.ndarray): image = Image.fromarray(image).convert('RGB') # Generate mask if not provided if mask is None: mask = segment_wound(image) if mask is None: return None, {"error": "Segmentation model not available"} else: if isinstance(mask, np.ndarray): mask = Image.fromarray(mask).convert('L') # Predict PWAT results = predict_pwat(image, mask) return mask, results # Gradio interface with gr.Blocks(title="DFU Suite - PWAT Scoring") as demo: gr.Markdown("# DFU Suite - Wound Analysis API") gr.Markdown("Upload a wound image to get automatic segmentation and PWAT scoring.") with gr.Row(): with gr.Column(): input_image = gr.Image(label="Wound Image", type="pil") input_mask = gr.Image(label="Mask (optional)", type="pil") submit_btn = gr.Button("Analyze", variant="primary") with gr.Column(): output_mask = gr.Image(label="Segmentation Mask") output_json = gr.JSON(label="PWAT Scores") submit_btn.click( fn=process_wound, inputs=[input_image, input_mask], outputs=[output_mask, output_json] ) gr.Markdown(""" ## API Usage ```python from gradio_client import Client client = Client("YOUR_SPACE_URL") result = client.predict( image="path/to/image.jpg", mask=None, # Optional api_name="/predict" ) ``` """) if __name__ == "__main__": print("[GRADIO] ====================================", flush=True) print("[GRADIO] LAUNCHING GRADIO APPLICATION", flush=True) print("[GRADIO] Application is now ready to use!", flush=True) print("[GRADIO] ====================================", flush=True) import sys sys.stdout.flush() demo.launch()