Spaces:
Sleeping
Sleeping
| # Standard library imports | |
| import argparse | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from tensorflow.keras.preprocessing.image import load_img, img_to_array | |
| from tensorflow.keras.layers import Conv2D, Multiply, Layer | |
| from tensorflow.keras.models import load_model | |
| import tensorflow.keras.backend as K | |
| import tensorflow as tf | |
| from joblib import load | |
| import radiomics | |
| import xgboost | |
| from xgboost import XGBClassifier | |
| from imblearn.over_sampling import RandomOverSampler | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.ensemble import RandomForestClassifier | |
| import nrrd | |
| import SimpleITK as sitk | |
| from PIL import Image | |
| import cv2 | |
| import six | |
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| import warnings | |
| import json | |
| import random | |
| import pickle | |
| from glob import glob | |
| import os | |
| # 0 = mostrar todo, 1 = filtrar INFO, 2 = filtrar INFO+WARNING, 3 = filtrar INFO+WARNING+ERROR | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' | |
| # Suprime todas las warnings de Python | |
| warnings.filterwarnings('ignore') | |
| # Opcional: refina solo ciertos módulos | |
| warnings.filterwarnings('ignore', category=UserWarning, module='h5py') | |
| warnings.filterwarnings('ignore', category=UserWarning, module='xgboost') | |
| warnings.filterwarnings('ignore', category=DeprecationWarning) | |
| # También baja el nivel del logger de TensorFlow/keras | |
| logging.getLogger('tensorflow').setLevel(logging.ERROR) | |
| logging.getLogger('keras').setLevel(logging.ERROR) | |
| # Si usas la capa de logging de absl (TF2+), pon: | |
| try: | |
| import absl.logging | |
| absl.logging.set_verbosity(absl.logging.ERROR) | |
| except ImportError: | |
| pass | |
| # Third-party imports | |
| # Image processing | |
| # Machine Learning | |
| # Deep Learning - TensorFlow/Keras | |
| # Visualization | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| MODEL_DIR = os.path.join(BASE_DIR, 'modelos') | |
| IMGS_DIR = os.path.join(BASE_DIR, '../backend/categorizador/predicts', 'imgs') | |
| MASKS_DIR = os.path.join( | |
| BASE_DIR, '../backend/categorizador/predicts', 'masks') | |
| model_path = os.path.join(MODEL_DIR, 'best_model.keras') | |
| # Cargar modelos con sistema de respaldo (JSON primero, PKL como alternativa) | |
| def load_xgboost_model(model_name): | |
| """Carga modelo XGBoost desde JSON, si falla usa PKL como respaldo""" | |
| json_path = os.path.join(MODEL_DIR, f"{model_name}.json") | |
| pkl_path = os.path.join(MODEL_DIR, f"{model_name}.pkl") | |
| try: | |
| # Intentar cargar desde JSON (formato preferido) | |
| modelo = xgboost.Booster() | |
| modelo.load_model(json_path) | |
| print(f"{model_name}: Cargado desde JSON ✓") | |
| return modelo, 'xgboost_json' | |
| except Exception as e: | |
| print(f"{model_name}: Error al cargar JSON ({e}), intentando PKL...") | |
| try: | |
| # Intentar cargar desde PKL (respaldo) | |
| modelo = load(pkl_path) | |
| print(f"{model_name}: Cargado desde PKL (respaldo) ✓") | |
| return modelo, 'xgboost_pkl' | |
| except Exception as e2: | |
| print(f"{model_name}: ERROR - No se pudo cargar ni JSON ni PKL: {e2}") | |
| raise | |
| # Solo mostrar mensajes de carga en modo debug | |
| debug_mode = os.getenv('DEBUG_PWAT') == '1' | |
| # Temporalmente silenciar print para cargar modelos | |
| original_print = print | |
| if not debug_mode: | |
| print = lambda *args, **kwargs: None | |
| Categoria3, tipo_cat3 = load_xgboost_model("Categoria3") | |
| Categoria4 = load(os.path.join(MODEL_DIR, "Categoria4.joblib")) | |
| Categoria5 = load(os.path.join(MODEL_DIR, "Categoria5.joblib")) | |
| Categoria6, tipo_cat6 = load_xgboost_model("Categoria6") | |
| Categoria7 = load(os.path.join(MODEL_DIR, "Categoria7.joblib")) | |
| Categoria8 = load(os.path.join(MODEL_DIR, "Categoria8.joblib")) | |
| # Restaurar print | |
| print = original_print | |
| class SpatialAttention(Layer): | |
| def __init__(self, kernel_size=7, filters=1, activation='sigmoid', **kwargs): | |
| super(SpatialAttention, self).__init__(**kwargs) | |
| self.kernel_size = kernel_size | |
| self.filters = filters | |
| self.activation = activation | |
| def build(self, input_shape): | |
| self.conv1 = Conv2D( | |
| filters=self.filters, | |
| kernel_size=self.kernel_size, | |
| padding='same', | |
| activation=self.activation, | |
| kernel_initializer='he_normal', | |
| use_bias=False | |
| ) | |
| super(SpatialAttention, self).build(input_shape) | |
| def call(self, inputs): | |
| attention = self.conv1(inputs) | |
| return Multiply()([inputs, attention]) | |
| def get_config(self): | |
| config = super(SpatialAttention, self).get_config() | |
| config.update({ | |
| 'kernel_size': self.kernel_size, | |
| 'filters': self.filters, | |
| 'activation': self.activation | |
| }) | |
| return config | |
| def dice_coefficient(y_true, y_pred): | |
| smooth = 1e-6 | |
| y_true = tf.cast(y_true, tf.float32) | |
| y_pred = tf.cast(y_pred, tf.float32) | |
| y_true_f = K.flatten(y_true) | |
| y_pred_f = K.flatten(y_pred) | |
| intersection = K.sum(y_true_f * y_pred_f) | |
| return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) | |
| def precision_metric(y_true, y_pred): | |
| y_true = tf.cast(y_true, tf.float32) | |
| y_pred = tf.cast(y_pred, tf.float32) | |
| true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) | |
| predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) | |
| return true_positives / (predicted_positives + K.epsilon()) | |
| def recall_metric(y_true, y_pred): | |
| y_true = tf.cast(y_true, tf.float32) | |
| y_pred = tf.cast(y_pred, tf.float32) | |
| true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) | |
| possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) | |
| return true_positives / (possible_positives + K.epsilon()) | |
| def f1_score(y_true, y_pred): | |
| prec = precision_metric(y_true, y_pred) | |
| rec = recall_metric(y_true, y_pred) | |
| return 2.0 * ((prec * rec) / (prec + rec + K.epsilon())) | |
| def iou_metric(y_true, y_pred): | |
| y_true = tf.cast(y_true, tf.float32) | |
| y_pred = tf.cast(y_pred, tf.float32) | |
| y_true_bin = K.cast(K.greater(y_true, 0.5), K.floatx()) | |
| y_pred_bin = K.cast(K.greater(y_pred, 0.5), K.floatx()) | |
| intersection = K.sum(y_true_bin * y_pred_bin) | |
| union = K.sum(y_true_bin) + K.sum(y_pred_bin) - intersection | |
| return intersection / (union + K.epsilon()) | |
| def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=4/3): | |
| y_true = tf.cast(y_true, tf.float32) | |
| y_pred = tf.cast(y_pred, tf.float32) | |
| tp = K.sum(y_true * y_pred, axis=[1, 2, 3]) | |
| fp = K.sum((1-y_true) * y_pred, axis=[1, 2, 3]) | |
| fn = K.sum(y_true * (1-y_pred), axis=[1, 2, 3]) | |
| tversky = (tp + K.epsilon()) / (tp + alpha*fp + beta*fn + K.epsilon()) | |
| focal_tversky = K.pow((1 - tversky), gamma) | |
| return K.mean(focal_tversky) | |
| def combined_loss(y_true, y_pred): | |
| return focal_tversky_loss(y_true, y_pred) + tf.keras.losses.BinaryCrossentropy()(y_true, y_pred) | |
| # 3. Cargar el modelo | |
| # Definir el directorio de salida (modifica esta ruta según tus necesidades) | |
| output_dir = '../backend/categorizador/predicts' | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Directorio para guardar las predicciones | |
| predictions_dir = os.path.join(output_dir, "masks") | |
| os.makedirs(predictions_dir, exist_ok=True) | |
| # Cargar el modelo con las capas y funciones personalizadas | |
| # Función para convertir HDF5 a formato Keras nativo si es necesario | |
| def load_and_convert_model(model_path, custom_objects): | |
| try: | |
| # Intentar cargar directamente como archivo Keras nativo | |
| return load_model(model_path, custom_objects=custom_objects) | |
| except ValueError as e: | |
| if "Please ensure the file is an accessible `.keras` zip file" in str(e): | |
| print( | |
| f"El archivo {model_path} está en formato HDF5. Convirtiendo a formato Keras nativo...") | |
| # Cargar el modelo HDF5 usando tf.keras con extensión temporal .h5 | |
| temp_h5_path = model_path.replace('.keras', '_temp.h5') | |
| import shutil | |
| shutil.copy2(model_path, temp_h5_path) | |
| try: | |
| # Cargar el modelo desde el archivo temporal .h5 | |
| model = tf.keras.models.load_model( | |
| temp_h5_path, custom_objects=custom_objects) | |
| # Guardar en formato Keras nativo | |
| model.save(model_path, save_format='keras') | |
| print(f"Modelo convertido y guardado como {model_path}") | |
| # Limpiar archivo temporal | |
| os.remove(temp_h5_path) | |
| return model | |
| except Exception as conversion_error: | |
| # Limpiar archivo temporal si falla | |
| if os.path.exists(temp_h5_path): | |
| os.remove(temp_h5_path) | |
| raise conversion_error | |
| else: | |
| raise e | |
| # Cargar el modelo | |
| try: | |
| model = load_and_convert_model(model_path, { | |
| 'SpatialAttention': SpatialAttention, | |
| 'dice_coefficient': dice_coefficient, | |
| 'iou_metric': iou_metric, | |
| 'precision_metric': precision_metric, | |
| 'recall_metric': recall_metric, | |
| 'f1_score': f1_score, | |
| 'combined_loss': combined_loss, | |
| 'focal_tversky_loss': focal_tversky_loss | |
| }) | |
| except Exception as e: | |
| print(f"Error al cargar el modelo desde {model_path}: {e}") | |
| print("Verifique que el archivo del modelo existe y es válido.") | |
| raise | |
| # 4. Definir funciones de preprocesamiento | |
| def load_and_preprocess_image(image_path, target_size=(256, 256)): | |
| """ | |
| Carga y preprocesa una imagen. | |
| Args: | |
| image_path (str): Ruta a la imagen. | |
| target_size (tuple): Tamaño al que redimensionar la imagen. | |
| Returns: | |
| np.array: Imagen preprocesada. | |
| """ | |
| try: | |
| img = Image.open(image_path).convert('RGB') | |
| except Exception as e: | |
| print(f"Error al abrir la imagen {image_path}: {e}") | |
| return None | |
| img = img.resize(target_size) | |
| img = np.array(img) | |
| img = img / 255.0 # Normalización | |
| return img | |
| def load_and_preprocess_mask(mask_path, target_size=(256, 256)): | |
| """ | |
| Carga y preprocesa una máscara. | |
| Args: | |
| mask_path (str): Ruta a la máscara. | |
| target_size (tuple): Tamaño al que redimensionar la máscara. | |
| Returns: | |
| np.array: Máscara preprocesada. | |
| """ | |
| try: | |
| mask = Image.open(mask_path).convert('L') # Escala de grises | |
| except Exception as e: | |
| print(f"Error al abrir la máscara {mask_path}: {e}") | |
| return None | |
| mask = mask.resize(target_size) | |
| mask = np.array(mask) | |
| mask = (mask > 127).astype(np.float32) # Binarización | |
| mask = np.expand_dims(mask, axis=-1) | |
| return mask | |
| def prepare_image_for_prediction(image): | |
| """ | |
| Prepara la imagen para la predicción añadiendo una dimensión de batch. | |
| Args: | |
| image (np.array): Imagen preprocesada. | |
| Returns: | |
| np.array: Imagen con dimensión de batch. | |
| """ | |
| return np.expand_dims(image, axis=0) | |
| def predict_mask(model, image): | |
| """ | |
| Genera la máscara de predicción para una imagen dada. | |
| Args: | |
| model (tf.keras.Model): Modelo cargado. | |
| image (np.array): Imagen preprocesada. | |
| Returns: | |
| np.array: Máscara de predicción. | |
| """ | |
| preprocessed_image = prepare_image_for_prediction(image) | |
| pred_mask = model.predict(preprocessed_image, verbose=0) | |
| pred_mask = np.squeeze(pred_mask, axis=0) | |
| return pred_mask | |
| def postprocess_mask(pred_mask, threshold=0.5): | |
| """ | |
| Binariza la máscara de predicción utilizando un umbral. | |
| Args: | |
| pred_mask (np.array): Máscara de predicción. | |
| threshold (float): Umbral para binarización. | |
| Returns: | |
| np.array: Máscara binarizada. | |
| """ | |
| return (pred_mask > threshold).astype(np.float32) | |
| # 5. Definir funciones de visualización y guardado | |
| def visualize_prediction(original_image, true_mask, pred_mask, postprocessed_mask, save_path=None): | |
| """ | |
| Visualiza y opcionalmente guarda la imagen original, máscara verdadera, | |
| máscara de predicción y máscara postprocesada. | |
| Args: | |
| original_image (np.array): Imagen original. | |
| true_mask (np.array): Máscara verdadera. | |
| pred_mask (np.array): Máscara de predicción. | |
| postprocessed_mask (np.array): Máscara postprocesada. | |
| save_path (str, optional): Ruta para guardar la visualización. | |
| """ | |
| fig, axes = plt.subplots(1, 4, figsize=(20, 5)) | |
| axes[0].imshow(original_image) | |
| axes[0].set_title("Imagen Original") | |
| axes[0].axis('off') | |
| axes[1].imshow(true_mask.squeeze(), cmap='gray') | |
| axes[1].set_title("Máscara Verdadera") | |
| axes[1].axis('off') | |
| axes[2].imshow(pred_mask.squeeze(), cmap='gray') | |
| axes[2].set_title("Máscara de Predicción") | |
| axes[2].axis('off') | |
| axes[3].imshow(postprocessed_mask.squeeze(), cmap='gray') | |
| axes[3].set_title("Máscara Postprocesada") | |
| axes[3].axis('off') | |
| plt.tight_layout() | |
| if save_path: | |
| plt.savefig(save_path) | |
| plt.show() | |
| def save_mask(pred_mask, save_path): | |
| """ | |
| Guarda la máscara de predicción como una imagen. | |
| Args: | |
| pred_mask (np.array): Máscara postprocesada. | |
| save_path (str): Ruta para guardar la máscara. | |
| """ | |
| mask = (pred_mask * 255).astype(np.uint8) | |
| mask_image = Image.fromarray(mask.squeeze(), mode='L') | |
| mask_image.save(save_path) | |
| def predecir_mascara(imagen_path, modelo=model, target_size=(256, 256), threshold=0.5): | |
| imagen = load_and_preprocess_image(imagen_path, target_size=target_size) | |
| if imagen is None: | |
| raise ValueError(f"No se pudo cargar la imagen: {imagen_path}") | |
| prediccion = predict_mask(modelo, imagen) | |
| mascara_predicha = postprocess_mask(prediccion, threshold=threshold) | |
| nombre_archivo = os.path.basename(imagen_path) | |
| nombre_base, _ = os.path.splitext(nombre_archivo) | |
| ruta_mascara = os.path.join(predictions_dir, f"{nombre_base}.jpg") | |
| save_mask(mascara_predicha, ruta_mascara) | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print(f"Máscara guardada en: {ruta_mascara}") | |
| return ruta_mascara | |
| def predecir(image_path, mask_path): | |
| # Silenciar los mensajes no deseados de PyRadiomics | |
| logging.getLogger('radiomics').setLevel(logging.ERROR) | |
| # Solo mostrar estos mensajes en modo debug | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print(f"Procesando imagen: {os.path.basename(image_path)}") | |
| print(f"Usando máscara: {os.path.basename(mask_path)}") | |
| extractor = radiomics.featureextractor.RadiomicsFeatureExtractor() | |
| img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) | |
| mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) | |
| # Validar que las imágenes se cargaron correctamente | |
| if img is None: | |
| raise ValueError( | |
| f'No se pudo cargar la imagen desde {image_path}. Verifique que el archivo existe y es una imagen válida.') | |
| if mask is None: | |
| raise ValueError( | |
| f'No se pudo cargar la máscara desde {mask_path}. Verifique que el archivo existe y es una imagen válida.') | |
| # Validar que la máscara no esté vacía | |
| if np.max(mask) == 0: | |
| raise ValueError( | |
| f'La máscara está completamente vacía (todos los pixeles son 0). Verifique que la máscara contenga regiones segmentadas.') | |
| # Normalizar a máscara binaria con tipo compatible para OpenCV | |
| # Evitar tipos int64 que provocan error en cv2.resize (func != 0) | |
| mask = (mask > 0).astype(np.uint8) | |
| img = cv2.resize(img, (256, 256)) | |
| # Mantener máscara binaria usando interpolación de vecino más cercano | |
| mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST) | |
| # Detectar la extensión del archivo de máscara para reemplazar correctamente | |
| # Obtiene la extensión (.jpg, .png, etc.) | |
| mask_ext = os.path.splitext(mask_path)[1] | |
| nrrd.write(image_path.replace(".jpg", '.nrrd'), img) | |
| nrrd.write(mask_path.replace(mask_ext, '.nrrd'), mask) | |
| result = extractor.execute(image_path.replace( | |
| ".jpg", '.nrrd'), mask_path.replace(mask_ext, '.nrrd')) | |
| # 5. Lista de claves a excluir | |
| keys_to_exclude = [ | |
| 'diagnostics_Versions_PyRadiomics', | |
| 'diagnostics_Versions_Numpy', | |
| 'diagnostics_Versions_SimpleITK', | |
| 'diagnostics_Versions_PyWavelet', | |
| 'diagnostics_Versions_Python', | |
| 'diagnostics_Configuration_Settings', | |
| 'diagnostics_Image-original_Spacing', | |
| 'diagnostics_Image-original_Size', | |
| 'diagnostics_Image-original_Mean', | |
| 'diagnostics_Image-original_Minimum', | |
| 'diagnostics_Image-original_Maximum', | |
| 'diagnostics_Mask-original_Hash', | |
| 'diagnostics_Mask-original_Spacing', | |
| 'diagnostics_Mask-original_Size', | |
| 'diagnostics_Configuration_EnabledImageTypes', | |
| 'diagnostics_Image-original_Hash', | |
| 'diagnostics_Image-original_Dimensionality', | |
| 'diagnostics_Mask-original_CenterOfMass', | |
| 'diagnostics_Mask-original_BoundingBox', | |
| 'diagnostics_Mask-original_CenterOfMassIndex' | |
| ] | |
| # 6. Filtrar el diccionario 'result' | |
| filtered_features = { | |
| k: v | |
| for k, v in result.items() | |
| if k not in keys_to_exclude | |
| } | |
| # 7. Construir el diccionario final | |
| filtered_data = { | |
| 'imagen': os.path.basename(image_path), | |
| **filtered_features | |
| } | |
| df = pd.DataFrame([filtered_data]) | |
| # Eliminar columnas no numéricas y de diagnóstico | |
| df = df.drop(['imagen'], axis=1) | |
| # Eliminar las primeras 2 columnas de diagnósticos restantes | |
| if len(df.columns) > 2: | |
| df = df.drop(df.columns[:2], axis=1) | |
| # Asegurar que todos los valores sean numéricos | |
| for col in df.columns: | |
| df[col] = pd.to_numeric(df[col], errors='coerce') | |
| # Rellenar NaN con 0 si los hay | |
| df = df.fillna(0) | |
| # Solo mostrar en modo debug | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print(f"Características extraídas: {len(df.columns)} features") | |
| print(f"Shape de datos: {df.shape}") | |
| modelos = [Categoria3, Categoria4, Categoria5, | |
| Categoria6, Categoria7, Categoria8] | |
| tipos_modelo = [tipo_cat3, 'sklearn', | |
| 'sklearn', tipo_cat6, 'sklearn', 'sklearn'] | |
| resultados = [] | |
| for i, z, tipo in zip(modelos, range(3, 9), tipos_modelo): | |
| try: | |
| if tipo.startswith('xgboost'): | |
| # Para modelos XGBoost (tanto JSON como PKL) | |
| import xgboost as xgb | |
| # Aplanar completamente el array y convertir a float64 | |
| data_flat = df.values.flatten().astype(np.float64) | |
| data_array = data_flat.reshape(1, -1) | |
| # Solo mostrar debug si está habilitado | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print( | |
| f"Datos para XGBoost Cat{z} ({tipo}): shape={data_array.shape}, dtype={data_array.dtype}") | |
| if tipo == 'xgboost_json': | |
| # Modelo cargado desde JSON - usar DMatrix (Booster) | |
| dmatrix = xgb.DMatrix(data_array) | |
| prediccion = i.predict(dmatrix) | |
| else: | |
| # Modelo cargado desde PKL - es un XGBClassifier, usar predict_proba | |
| prediccion = i.predict_proba(data_array) | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print( | |
| f"Predicción XGBoost raw: {prediccion}, tipo: {type(prediccion)}") | |
| # Manejar predicción de XGBoost (multiclase) | |
| if isinstance(prediccion, np.ndarray): | |
| if prediccion.ndim == 2 and prediccion.shape[1] > 1: | |
| # Es una predicción multiclase (probabilidades) | |
| # +1 porque las clases empiezan en 1 | |
| resultado = int(np.argmax(prediccion[0])) + 1 | |
| elif prediccion.ndim == 1 and len(prediccion) > 1: | |
| # Es una predicción multiclase en 1D | |
| resultado = int(np.argmax(prediccion)) + 1 | |
| else: | |
| # Es una predicción single value | |
| resultado = int(round(float(prediccion.flatten()[0]))) | |
| else: | |
| resultado = int(round(float(prediccion))) | |
| resultados.append(resultado) | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print( | |
| f"Categoría {z} (XGBoost-{tipo.split('_')[1].upper()}): {resultado}") | |
| else: | |
| # Para modelos sklearn (RandomForest, etc.) | |
| data_array = df.values.astype(np.float32) | |
| if data_array.ndim == 1: | |
| data_array = data_array.reshape(1, -1) | |
| prediccion = i.predict(data_array) | |
| resultado = int(prediccion[0]) if hasattr( | |
| prediccion, '__len__') else int(prediccion) | |
| resultados.append(resultado) | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print(f"Categoría {z} (Sklearn): {resultado}") | |
| except Exception as e: | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print(f"ERROR con la categoría {z}: {e}") | |
| print( | |
| f"Tipo de datos: {type(df.values)}, Shape: {df.values.shape}") | |
| print(f"Usando valor por defecto para categoría {z}") | |
| if z == 3: | |
| resultados.append(2) # Valor por defecto para Cat3 | |
| elif z == 6: | |
| resultados.append(3) # Valor por defecto para Cat6 | |
| else: | |
| resultados.append(1) # Valor por defecto genérico | |
| categories = ["Cat3", "Cat4", "Cat5", "Cat6", "Cat7", "Cat8"] | |
| results_dict = {} | |
| for c, r in zip(categories, resultados): | |
| results_dict[c] = r[0] if hasattr(r, "__getitem__") else r | |
| # Solo mostrar la tabla de resultados en modo debug, siempre imprimir el JSON | |
| if os.getenv('DEBUG_PWAT') == '1': | |
| print("\n" + "="*50) | |
| print("RESULTADOS DE PREDICCIÓN") | |
| print("="*50) | |
| for cat, resultado in results_dict.items(): | |
| print(f"{cat}: {resultado}") | |
| print("="*50) | |
| # SIEMPRE imprimir el JSON para que el backend lo pueda parsear | |
| print(json.dumps(results_dict)) | |
| return results_dict | |
| def mask_predict(image_path): | |
| # Si no es una ruta absoluta, agregar el directorio IMGS_DIR | |
| if not os.path.isabs(image_path): | |
| full_image_path = os.path.join(IMGS_DIR, image_path) | |
| else: | |
| full_image_path = image_path | |
| mask_path = predecir_mascara(full_image_path) | |
| predecir(full_image_path, mask_path) | |
| # mask_predict('./predicts/imgs/mar4.jpg') | |
| # predecir_mascara('./predicts/imgs/mar4 copy.jpg') | |
| # predecir('./predicts/imgs/mar4 copy.jpg','./predicts/masks/mar4 copy.jpg') | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--mode", required=True, | |
| choices=["mask_precit", "predecir_mascara", "predecir"]) | |
| parser.add_argument("--image_path", required=True) | |
| parser.add_argument("--mask_path", required=False) | |
| args = parser.parse_args() | |
| if args.mode == "mask_precit": | |
| mask_precit(args.image_path) | |
| elif args.mode == "predecir_mascara": | |
| result = predecir_mascara(os.path.join(IMGS_DIR, args.image_path)) | |
| print(f"Mask saved at: {result}") | |
| elif args.mode == "predecir": | |
| if not args.mask_path: | |
| raise ValueError( | |
| "Favor de proporcionar la ruta de la máscara con --mask_path") | |
| if not args.image_path: | |
| raise ValueError( | |
| "Favor de proporcionar la ruta de la imagen con --image_path") | |
| predecir(os.path.join(IMGS_DIR, args.image_path), | |
| os.path.join(MASKS_DIR, args.mask_path)) | |