# dfu-suite-pwat / app.py
# (Hugging Face Spaces page header left in the file by a copy/paste:
#  "mmarquezsa's picture / Upgrade to Keras 3.x for model compatibility /
#  cdeb80f verified". Commented out — as bare text it is a SyntaxError and
#  the module cannot be imported at all.)
"""
DFU Suite - PWAT Scoring API
Hugging Face Spaces deployment for wound analysis
"""
import os
import json
import warnings
import logging
import tempfile
import numpy as np
import pandas as pd
import cv2
import nrrd
import gradio as gr
from PIL import Image
from joblib import load
import xgboost
import tensorflow as tf
# Use Keras 3.x standalone imports
import keras
from keras.models import load_model
from keras.layers import Conv2D, Multiply, Layer
import keras.backend as K
import radiomics
# Suppress warnings and noisy third-party logging.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is set AFTER `import tensorflow` above,
# so TensorFlow's C++ startup messages have already been emitted by the time
# this runs; for it to take effect it must be set before the import — confirm
# and move if the startup noise matters.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')
logging.getLogger('radiomics').setLevel(logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# Model directory: all model artifacts live in ./modelos next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'modelos')
# Download models from HF Model Hub if not present
def download_models_from_hub():
    """Download the ML model files from the Hugging Face Model Hub.

    Files already present in MODEL_DIR are skipped, so repeated calls are
    cheap. Individual download failures are logged with a traceback but do
    not raise: the module-level loaders below tolerate missing files.
    """
    from huggingface_hub import hf_hub_download
    model_repo = "mmarquezsa/dfu-suite-models"
    model_files = [
        "best_model.keras",
        "Categoria3.pkl",
        "Categoria4.joblib",
        "Categoria5.joblib",
        "Categoria6.pkl",
        "Categoria7.joblib",
        "Categoria8.joblib",
    ]
    os.makedirs(MODEL_DIR, exist_ok=True)
    for model_file in model_files:
        local_path = os.path.join(MODEL_DIR, model_file)
        # Skip if already downloaded
        if os.path.exists(local_path):
            print(f"[MODEL] {model_file} already exists", flush=True)
            continue
        try:
            print(f"[MODEL] Downloading {model_file} from HF Model Hub...", flush=True)
            # local_dir_use_symlinks is deprecated (ignored by recent
            # huggingface_hub releases); local_dir alone copies real files.
            hf_hub_download(
                repo_id=model_repo,
                filename=model_file,
                local_dir=MODEL_DIR,
            )
            print(f"[MODEL] Downloaded: {model_file}", flush=True)
        except Exception as e:
            # Best-effort: log and continue with the remaining files.
            print(f"[MODEL] ERROR downloading {model_file}: {e}", flush=True)
            import traceback
            traceback.print_exc()
# Download models before loading them into memory. A failure here is logged
# as fatal but does not stop the app: the loaders below report what's missing.
print("[MODEL] Checking/downloading models from HF Model Hub...", flush=True)
try:
    download_models_from_hub()
except Exception as e:
    print(f"[MODEL] FATAL ERROR in download_models_from_hub: {e}", flush=True)
    import traceback
    traceback.print_exc()
else:
    print("[MODEL] Model download check complete", flush=True)
# Custom Keras layers
class SpatialAttention(Layer):
    """Spatial attention layer: a single Conv2D produces a per-pixel
    attention map that is multiplied element-wise with the input.

    Registered under its own name in custom_objects when the saved
    segmentation model is deserialized.
    """

    def __init__(self, kernel_size=7, filters=1, activation='sigmoid', **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.filters = filters
        self.activation = activation

    def build(self, input_shape):
        # Attention-map convolution: no bias, he_normal initialization.
        self.conv1 = Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            padding='same',
            activation=self.activation,
            kernel_initializer='he_normal',
            use_bias=False,
        )
        super().build(input_shape)

    def call(self, inputs):
        weights = self.conv1(inputs)
        return Multiply()([inputs, weights])

    def get_config(self):
        base = super().get_config()
        base.update(
            kernel_size=self.kernel_size,
            filters=self.filters,
            activation=self.activation,
        )
        return base
# Custom metrics for model loading
def dice_coefficient(y_true, y_pred):
    """Soft Dice coefficient: 2*|A.B| / (|A| + |B|), smoothed for stability."""
    smooth = 1e-6
    yt = K.flatten(tf.cast(y_true, tf.float32))
    yp = K.flatten(tf.cast(y_pred, tf.float32))
    overlap = K.sum(yt * yp)
    return (2. * overlap + smooth) / (K.sum(yt) + K.sum(yp) + smooth)
def iou_metric(y_true, y_pred):
    """Binary IoU: both tensors are thresholded at 0.5 before the ratio."""
    yt = K.cast(K.greater(tf.cast(y_true, tf.float32), 0.5), K.floatx())
    yp = K.cast(K.greater(tf.cast(y_pred, tf.float32), 0.5), K.floatx())
    inter = K.sum(yt * yp)
    union = K.sum(yt) + K.sum(yp) - inter
    return inter / (union + K.epsilon())
def precision_metric(y_true, y_pred):
    """Precision = TP / predicted positives (values rounded to {0,1} first)."""
    yt = tf.cast(y_true, tf.float32)
    yp = tf.cast(y_pred, tf.float32)
    tp = K.sum(K.round(K.clip(yt * yp, 0, 1)))
    predicted_pos = K.sum(K.round(K.clip(yp, 0, 1)))
    return tp / (predicted_pos + K.epsilon())
def recall_metric(y_true, y_pred):
    """Recall = TP / actual positives (values rounded to {0,1} first)."""
    yt = tf.cast(y_true, tf.float32)
    yp = tf.cast(y_pred, tf.float32)
    tp = K.sum(K.round(K.clip(yt * yp, 0, 1)))
    actual_pos = K.sum(K.round(K.clip(yt, 0, 1)))
    return tp / (actual_pos + K.epsilon())
def f1_score(y_true, y_pred):
    """F1: harmonic mean of precision_metric and recall_metric."""
    p = precision_metric(y_true, y_pred)
    r = recall_metric(y_true, y_pred)
    return 2.0 * ((p * r) / (p + r + K.epsilon()))
def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=4/3):
    """Focal Tversky loss over per-sample sums (axes 1-3), batch-averaged.

    alpha weights false positives, beta false negatives; gamma > 1 focuses
    the loss on harder (low-Tversky) samples.
    """
    yt = tf.cast(y_true, tf.float32)
    yp = tf.cast(y_pred, tf.float32)
    reduce_axes = [1, 2, 3]
    tp = K.sum(yt * yp, axis=reduce_axes)
    fp = K.sum((1 - yt) * yp, axis=reduce_axes)
    fn = K.sum(yt * (1 - yp), axis=reduce_axes)
    tversky = (tp + K.epsilon()) / (tp + alpha * fp + beta * fn + K.epsilon())
    return K.mean(K.pow((1 - tversky), gamma))
def combined_loss(y_true, y_pred):
    """Training loss: focal Tversky plus binary cross-entropy."""
    bce = keras.losses.BinaryCrossentropy()
    return focal_tversky_loss(y_true, y_pred) + bce(y_true, y_pred)
# Load models
print("[MODEL] Loading models into memory...", flush=True)
import sys
sys.stdout.flush()
# Load segmentation model.
# Remains None if the .keras file is missing or fails to deserialize;
# segment_wound() checks for None before predicting.
segmentation_model = None
seg_model_path = os.path.join(MODEL_DIR, 'best_model.keras')
print(f"[MODEL] Checking segmentation model at: {seg_model_path}", flush=True)
sys.stdout.flush()
if os.path.exists(seg_model_path):
    file_size_mb = os.path.getsize(seg_model_path) / (1024 * 1024)
    print(f"[MODEL] Found best_model.keras ({file_size_mb:.1f} MB)", flush=True)
    print(f"[MODEL] Loading segmentation model (this may take 30-60 seconds)...", flush=True)
    sys.stdout.flush()
    try:
        # Every custom layer/metric/loss name serialized in the .keras file
        # must map back to the definitions above for deserialization to work.
        segmentation_model = load_model(seg_model_path, custom_objects={
            'SpatialAttention': SpatialAttention,
            'dice_coefficient': dice_coefficient,
            'iou_metric': iou_metric,
            'precision_metric': precision_metric,
            'recall_metric': recall_metric,
            'f1_score': f1_score,
            'combined_loss': combined_loss,
            'focal_tversky_loss': focal_tversky_loss
        })
        print("[MODEL] OK - Segmentation model loaded successfully", flush=True)
        print(f"[MODEL] Model type: {type(segmentation_model)}", flush=True)
        sys.stdout.flush()
    except Exception as e:
        # Keep the app running without segmentation: it can still score
        # user-supplied masks.
        print(f"[MODEL] ERROR - Failed to load segmentation model: {e}", flush=True)
        sys.stdout.flush()
        import traceback
        traceback.print_exc()
        sys.stdout.flush()
else:
    print(f"[MODEL] ERROR - Segmentation model not found at {seg_model_path}", flush=True)
    print(f"[MODEL] Directory contents: {os.listdir(MODEL_DIR)}", flush=True)
    sys.stdout.flush()
# Load PWAT classification models
def load_xgboost_model(model_name):
    """Load a PWAT category model, trying the XGBoost-native JSON dump
    first and falling back to a pickled/joblib estimator.

    Returns (model, type_tag) where type_tag is 'xgboost_json' or
    'xgboost_pkl', or (None, None) when neither file can be loaded.
    """
    json_path = os.path.join(MODEL_DIR, f"{model_name}.json")
    pkl_path = os.path.join(MODEL_DIR, f"{model_name}.pkl")
    try:
        modelo = xgboost.Booster()
        modelo.load_model(json_path)
        return modelo, 'xgboost_json'
    # `except Exception` instead of a bare `except:` so KeyboardInterrupt /
    # SystemExit are not swallowed; kept broad because xgboost may raise
    # several error types for a missing/corrupt file.
    except Exception:
        try:
            modelo = load(pkl_path)
            return modelo, 'xgboost_pkl'
        except Exception as e:
            print(f"[MODEL] ERROR loading {model_name}: {e}", flush=True)
            return None, None
# Load all PWAT models.
print("[MODEL] Loading PWAT classification models...", flush=True)
sys.stdout.flush()

def _load_joblib_model(model_name):
    """Load a joblib-serialized model from MODEL_DIR, or None if the file is absent."""
    path = os.path.join(MODEL_DIR, f"{model_name}.joblib")
    return load(path) if os.path.exists(path) else None

print("[MODEL] Loading Categoria3...", flush=True)
Categoria3, tipo_cat3 = load_xgboost_model("Categoria3")
print(f"[MODEL] OK - Categoria3 loaded ({tipo_cat3})", flush=True)
print("[MODEL] Loading Categoria4...", flush=True)
Categoria4 = _load_joblib_model("Categoria4")
print("[MODEL] OK - Categoria4 loaded", flush=True)
print("[MODEL] Loading Categoria5...", flush=True)
Categoria5 = _load_joblib_model("Categoria5")
print("[MODEL] OK - Categoria5 loaded", flush=True)
print("[MODEL] Loading Categoria6...", flush=True)
Categoria6, tipo_cat6 = load_xgboost_model("Categoria6")
print(f"[MODEL] OK - Categoria6 loaded ({tipo_cat6})", flush=True)
print("[MODEL] Loading Categoria7...", flush=True)
Categoria7 = _load_joblib_model("Categoria7")
print("[MODEL] OK - Categoria7 loaded", flush=True)
print("[MODEL] Loading Categoria8...", flush=True)
Categoria8 = _load_joblib_model("Categoria8")
print("[MODEL] OK - Categoria8 loaded", flush=True)
print("[MODEL] ====================================", flush=True)
print("[MODEL] ALL MODELS LOADED SUCCESSFULLY", flush=True)
print("[MODEL] ====================================", flush=True)
sys.stdout.flush()
def segment_wound(image):
    """Generate a wound segmentation mask for a PIL image, or None when the
    segmentation model is unavailable."""
    if segmentation_model is None:
        return None
    # Preprocess: the model expects a (1, 256, 256, C) float batch in [0, 1].
    resized = image.resize((256, 256))
    batch = np.expand_dims(np.array(resized) / 255.0, axis=0)
    # Predict, binarize at 0.5, and return an 8-bit grayscale PIL image.
    prediction = segmentation_model.predict(batch, verbose=0)
    binary = (prediction[0] > 0.5).astype(np.uint8) * 255
    return Image.fromarray(binary.squeeze(), mode='L')
def predict_pwat(image, mask):
    """Predict PWAT scores from a wound image and its segmentation mask.

    Parameters
    ----------
    image : PIL.Image
        Wound photograph.
    mask : PIL.Image
        Grayscale mask; any non-zero pixel is treated as wound.

    Returns
    -------
    dict
        {"Cat1"..."Cat8": int} (Cat1/Cat2 are always 0: not ML-predicted),
        or {"error": str} when the inputs are unusable.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save images temporarily. The mask MUST be saved lossless (PNG):
        # JPEG compression artifacts create spurious non-zero pixels that
        # survive the `> 0` threshold below and corrupt the mask geometry.
        # NOTE(review): the photo itself is still saved as JPEG to match the
        # (presumed) training-time preprocessing — confirm before changing.
        img_path = os.path.join(tmpdir, "image.jpg")
        mask_path = os.path.join(tmpdir, "mask.png")
        image.save(img_path)
        mask.save(mask_path)
        # Reload with OpenCV in grayscale for the radiomics pipeline.
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        mask_arr = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if img is None or mask_arr is None:
            return {"error": "Could not load images"}
        if np.max(mask_arr) == 0:
            return {"error": "Mask is empty"}
        # Normalize and resize (nearest-neighbour keeps the mask binary).
        mask_arr = (mask_arr > 0).astype(np.uint8)
        img = cv2.resize(img, (256, 256))
        mask_arr = cv2.resize(mask_arr, (256, 256), interpolation=cv2.INTER_NEAREST)
        # Save as NRRD, the format pyradiomics consumes.
        img_nrrd = os.path.join(tmpdir, "image.nrrd")
        mask_nrrd = os.path.join(tmpdir, "mask.nrrd")
        nrrd.write(img_nrrd, img)
        nrrd.write(mask_nrrd, mask_arr)
        # Extract radiomics features.
        extractor = radiomics.featureextractor.RadiomicsFeatureExtractor()
        result = extractor.execute(img_nrrd, mask_nrrd)
        # Drop pyradiomics diagnostic metadata; keep only actual features.
        keys_to_exclude = {
            'diagnostics_Versions_PyRadiomics', 'diagnostics_Versions_Numpy',
            'diagnostics_Versions_SimpleITK', 'diagnostics_Versions_PyWavelet',
            'diagnostics_Versions_Python', 'diagnostics_Configuration_Settings',
            'diagnostics_Image-original_Spacing', 'diagnostics_Image-original_Size',
            'diagnostics_Image-original_Mean', 'diagnostics_Image-original_Minimum',
            'diagnostics_Image-original_Maximum', 'diagnostics_Mask-original_Hash',
            'diagnostics_Mask-original_Spacing', 'diagnostics_Mask-original_Size',
            'diagnostics_Configuration_EnabledImageTypes', 'diagnostics_Image-original_Hash',
            'diagnostics_Image-original_Dimensionality', 'diagnostics_Mask-original_CenterOfMass',
            'diagnostics_Mask-original_BoundingBox', 'diagnostics_Mask-original_CenterOfMassIndex',
        }
        filtered_features = {k: v for k, v in result.items() if k not in keys_to_exclude}
        df = pd.DataFrame([filtered_features])
        # Remove non-numeric columns and coerce everything else to numbers.
        if 'imagen' in df.columns:
            df = df.drop(['imagen'], axis=1)
        if len(df.columns) > 2:
            # First two remaining columns are residual metadata — drop them.
            df = df.drop(df.columns[:2], axis=1)
        for col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df = df.fillna(0)
        # Predict categories 3-8. XGBoost models need a flat float64 row
        # (Boosters via DMatrix); sklearn models take the 2-D matrix directly.
        modelos = [Categoria3, Categoria4, Categoria5, Categoria6, Categoria7, Categoria8]
        tipos_modelo = [tipo_cat3, 'sklearn', 'sklearn', tipo_cat6, 'sklearn', 'sklearn']
        resultados = []
        for modelo, cat_num, tipo in zip(modelos, range(3, 9), tipos_modelo):
            try:
                if modelo is None:
                    resultados.append(0)
                    continue
                if tipo and tipo.startswith('xgboost'):
                    data_array = df.values.flatten().astype(np.float64).reshape(1, -1)
                    if tipo == 'xgboost_json':
                        # Module-level `import xgboost` is reused here; no
                        # per-iteration re-import needed.
                        prediccion = modelo.predict(xgboost.DMatrix(data_array))
                    else:
                        prediccion = modelo.predict_proba(data_array)
                    if isinstance(prediccion, np.ndarray):
                        if prediccion.ndim == 2 and prediccion.shape[1] > 1:
                            # Per-class probabilities -> argmax, 1-based score.
                            resultado = int(np.argmax(prediccion[0])) + 1
                        elif prediccion.ndim == 1 and len(prediccion) > 1:
                            resultado = int(np.argmax(prediccion)) + 1
                        else:
                            # Scalar regression-style output -> round to int.
                            resultado = int(round(float(prediccion.flatten()[0])))
                    else:
                        resultado = int(round(float(prediccion)))
                else:
                    data_array = df.values.astype(np.float32)
                    if data_array.ndim == 1:
                        data_array = data_array.reshape(1, -1)
                    prediccion = modelo.predict(data_array)
                    resultado = int(prediccion[0]) if hasattr(prediccion, '__len__') else int(prediccion)
                resultados.append(resultado)
            except Exception as e:
                # Best-effort per category: log and score 0 rather than fail all.
                print(f"Error predicting category {cat_num}: {e}")
                resultados.append(0)
        # Build results; categories 1-2 are not predicted by ML.
        results_dict = {"Cat1": 0, "Cat2": 0}
        for i in range(6):
            results_dict[f"Cat{i + 3}"] = resultados[i] if len(resultados) > i else 0
        return results_dict
def process_wound(image, mask=None):
    """Run the full pipeline: optional segmentation, then PWAT scoring.

    Returns (mask, results) where results is the PWAT dict or an error dict;
    mask is None when no image was given or segmentation is unavailable.
    """
    if image is None:
        return None, {"error": "No image provided"}
    # Gradio may hand us numpy arrays; normalize inputs to PIL images.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    if mask is None:
        # No user-supplied mask: fall back to the segmentation model.
        mask = segment_wound(image)
        if mask is None:
            return None, {"error": "Segmentation model not available"}
    elif isinstance(mask, np.ndarray):
        mask = Image.fromarray(mask).convert('L')
    return mask, predict_pwat(image, mask)
# Gradio interface: left column takes the wound photo plus an optional
# pre-made mask; right column shows the predicted mask and the PWAT scores.
with gr.Blocks(title="DFU Suite - PWAT Scoring") as demo:
    gr.Markdown("# DFU Suite - Wound Analysis API")
    gr.Markdown("Upload a wound image to get automatic segmentation and PWAT scoring.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Wound Image", type="pil")
            input_mask = gr.Image(label="Mask (optional)", type="pil")
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            output_mask = gr.Image(label="Segmentation Mask")
            output_json = gr.JSON(label="PWAT Scores")
    # Wire the button to the full pipeline; the mask input may be None.
    submit_btn.click(
        fn=process_wound,
        inputs=[input_image, input_mask],
        outputs=[output_mask, output_json]
    )
    gr.Markdown("""
    ## API Usage
    ```python
    from gradio_client import Client
    client = Client("YOUR_SPACE_URL")
    result = client.predict(
    image="path/to/image.jpg",
    mask=None, # Optional
    api_name="/predict"
    )
    ```
    """)
if __name__ == "__main__":
    # Banner so the Spaces startup log makes the ready state obvious.
    banner = "[GRADIO] ===================================="
    print(banner, flush=True)
    print("[GRADIO] LAUNCHING GRADIO APPLICATION", flush=True)
    print("[GRADIO] Application is now ready to use!", flush=True)
    print(banner, flush=True)
    import sys
    sys.stdout.flush()
    demo.launch()