LoloSemper committed on
Commit
6ac44b9
·
verified ·
1 Parent(s): f8c6941

Upload 4 files

Browse files
ada_learn_malben.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:441760abd9a3a6b143917c68cf1709a1d96323fa334b4952009f50a34ddb57c2
3
+ size 87548073
ada_learn_skin_norm2000.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:820252747a54e3e7494fdc86c998d1d67996cb31bde244ddd4260307e80dd819
3
+ size 87739881
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastai.learner import load_learner
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Pretrained ViT skin-lesion classifier pulled from the Hugging Face Hub.
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()

# Two Fast.ai learners shipped with this Space as local .pkl files.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# Human-readable lesion classes, indexed in the order the ViT head emits
# (assumed to match the HAM10000 label order -- TODO confirm on the model card).
CLASSES = [
    "Queratosis actínica / Bowen", "Carcinoma células basales",
    "Lesión queratósica benigna", "Dermatofibroma",
    "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
]

# Per-class triage metadata: display level, bar-chart colour, and the weight
# used when folding the ViT probabilities into a single risk score.
_RISK_TABLE = [
    ('Moderado', '#ffaa00', 0.6),
    ('Alto', '#ff4444', 0.8),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
    ('Crítico', '#cc0000', 1.0),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
]
RISK_LEVELS = {
    idx: {'level': level, 'color': color, 'weight': weight}
    for idx, (level, color, weight) in enumerate(_RISK_TABLE)
}
36
+
37
def analizar_lesion_combined(img):
    """Run all three models on one lesion image and build an HTML report.

    Parameters
    ----------
    img : PIL.Image.Image
        Image provided by the Gradio ``gr.Image(type="pil")`` input.

    Returns
    -------
    tuple[str, str]
        ``(informe, html_chart)``: the combined HTML report and an
        ``<img>`` tag embedding the base64-encoded ViT probability chart.
    """
    # --- ViT transformer prediction ------------------------------------
    inputs = feature_extractor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model_vit(**inputs)
    probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
    pred_idx_vit = int(np.argmax(probs_vit))
    pred_class_vit = CLASSES[pred_idx_vit]
    confidence_vit = probs_vit[pred_idx_vit]

    # --- Fast.ai malignancy model --------------------------------------
    # Index 1 of the probability tensor is taken as the "malignant" class
    # -- assumption inherited from the original code; TODO confirm against
    # the learner's vocab.
    _, _, probs_fast_mal = model_malignancy.predict(img)
    prob_malignant = float(probs_fast_mal[1])

    # --- Fast.ai lesion-type classification ----------------------------
    pred_fast_type, _, _ = model_norm2000.predict(img)

    # --- ViT probability bar chart -------------------------------------
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
    ax.set_title("Probabilidad ViT por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    # Pin the tick locations before relabelling: calling set_xticklabels
    # alone raises Matplotlib's "FixedFormatter should only be used
    # together with FixedLocator" warning.
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    # Operate on the explicit figure handle rather than pyplot's implicit
    # "current figure" -- safer if anything else touches pyplot state.
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'

    # --- Combined HTML report ------------------------------------------
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por 3 modelos de IA</h2>
    <table style="border-collapse: collapse; width:100%; font-size:16px">
    <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
    <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
    <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
    <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
    </table>
    <br>
    <b>🩺 Recomendación automática:</b><br>
    """

    # --- Automatic recommendation --------------------------------------
    # Weighted sum of ViT class probabilities; weights come from RISK_LEVELS.
    cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES)))
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    elif prob_malignant > 0.4 or cancer_risk_score > 0.4:
        informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    elif cancer_risk_score > 0.2:
        informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    else:
        informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"

    informe += "</div>"

    return informe, html_chart
98
+
99
# Gradio UI: one PIL-image input, two HTML outputs (combined report + chart).
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=gr.Image(type="pil", label="Sube una imagen de la lesión"),
    outputs=[gr.HTML(label="Informe combinado"), gr.HTML(label="Gráfico ViT")],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai)",
    description="Comparación entre ViT transformer (HAM10000) y dos modelos Fast.ai entrenados sobre distintos datasets.",
    # NOTE(review): flagging_mode is the Gradio 5.x keyword; older Gradio
    # releases spelled this allow_flagging -- confirm the pinned version.
    flagging_mode="never"
)

# Launch the app only when run as a script (Hugging Face Spaces also
# executes this entry point).
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ fastai
5
+ gradio
6
+ matplotlib