JersonRuizAlva commited on
Commit
97a4bf8
·
1 Parent(s): 4154629

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - Módulo para src
2
+ import streamlit as st
3
+ from streamlit_option_menu import option_menu
4
+ import streamlit_lottie as st_lottie
5
+ import json
6
+ import google.generativeai as genai
7
+ from dotenv import load_dotenv
8
+ import os
9
+ import sys
10
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
11
+
12
+ # Importaciones locales
13
+ from datos.upload import show_upload
14
+ from datos.prepare import show_prepare
15
+ from models.train import show_train
16
+ from models.test import show_test
17
+ from models.unsupervised import show_unsupervised
18
+
19
+ # Configuración inicial
20
+ st.set_page_config(initial_sidebar_state="collapsed", page_title="Machine Learning", page_icon="🤖", layout="wide")
21
+ load_dotenv()
22
+
23
+ # Función para cargar el archivo Lottie
24
+ def load_lottie_file(filepath: str):
25
+ try:
26
+ # Construir ruta absoluta
27
+ base_path = os.path.dirname(os.path.abspath(__file__))
28
+ full_path = os.path.join(base_path, 'assets', filepath)
29
+
30
+ with open(full_path, 'r') as f:
31
+ return json.load(f)
32
+ except FileNotFoundError:
33
+ st.error(f"Archivo Lottie no encontrado: {full_path}")
34
+ return None
35
+
36
+ # Configuración del sidebar
37
+ with st.sidebar:
38
+ # Cargar y mostrar el logo animado
39
+ try:
40
+ gemini_logo = load_lottie_file('gemini_logo.json')
41
+ if gemini_logo:
42
+ st_lottie.st_lottie(
43
+ gemini_logo,
44
+ key='logo',
45
+ height=50,
46
+ width=50,
47
+ loop=True,
48
+ quality="low"
49
+ )
50
+ except Exception as e:
51
+ st.error(f"Error al cargar el logo: {e}")
52
+
53
+ # Sección de API Keys
54
+ st.markdown("### Configuración de APIs")
55
+
56
+ # Gemini API
57
+ st.markdown('''
58
+ [Consigue tu API Key de Google AI Studio](https://aistudio.google.com/app/apikey)
59
+ ''')
60
+ genai_api_key = st.text_input(
61
+ "Gemini API Key",
62
+ type="password",
63
+ placeholder="Ingresa tu API Key de Gemini",
64
+ key='gemini_api_key'
65
+ )
66
+
67
+ # Supabase API
68
+ st.markdown('''
69
+ [Consigue tus credenciales de Supabase](https://supabase.com/dashboard/project/_/settings/api)
70
+ ''')
71
+ supabase_url = st.text_input(
72
+ "Supabase URL",
73
+ type="password",
74
+ placeholder="Ingresa tu Supabase URL",
75
+ key='supabase_url'
76
+ )
77
+
78
+ supabase_key = st.text_input(
79
+ "Supabase Key",
80
+ type="password",
81
+ placeholder="Ingresa tu Supabase Key",
82
+ key='supabase_key'
83
+ )
84
+
85
+ # Validación de credenciales
86
+ if not all([genai_api_key, supabase_url, supabase_key]):
87
+ st.warning("Por favor ingresa todas las credenciales necesarias.")
88
+ else:
89
+ genai.configure(api_key=genai_api_key)
90
+ model = genai.GenerativeModel('gemini-1.5-flash')
91
+ st.success("✅ Credenciales configuradas correctamente")
92
+
93
+ st.sidebar.markdown(
94
+ f'''
95
+ <div style="text-align: center; margin-bottom: 20px;">
96
+ <a href="https://jersonalvr.shinyapps.io/prophet/" target="_blank" style="text-decoration: none; color: inherit;">Analizar series temporales</a>
97
+ <br></br>
98
+ Elaborado por
99
+ <a href="https://www.linkedin.com/in/jersonalvr" target="_blank" style="text-decoration: none; color: inherit;">
100
+ <img src="https://cdn-icons-png.flaticon.com/512/174/174857.png" alt="LinkedIn" width="20" style="vertical-align: middle; margin-right: 5px;"/>
101
+ Jerson Ruiz Alva
102
+ </a>
103
+ </div>
104
+ ''',
105
+ unsafe_allow_html=True
106
+ )
107
+
108
+ # Configuración de estilos de navegación
109
+ pages = ["Upload", "Prepare", "Training", "ModelTest", "Unsupervised"]
110
+
111
+ selected_page = option_menu(
112
+ None,
113
+ options=pages,
114
+ icons=['cloud-upload', 'gear', 'robot', 'folder-check', 'search'],
115
+ default_index=0,
116
+ orientation="horizontal",
117
+ styles={
118
+ "container": {"padding": "0!important", "background-color": None},
119
+ "icon": {"color": None, "font-size": "20px"},
120
+ "nav-link": {
121
+ "font-size": "15px",
122
+ "text-align": "center",
123
+ "margin": "0px",
124
+ "--hover-color": "rgba(15, 21, 34, 0.25)",
125
+ },
126
+ "nav-link-selected": {"background-color": "rgba(15, 21, 34, 1)"},
127
+ }
128
+ )
129
+
130
+ # The rest of the page routing remains the same
131
+ if selected_page == "Upload":
132
+ show_upload()
133
+ elif selected_page == "Prepare":
134
+ show_prepare()
135
+ elif selected_page == "Training":
136
+ show_train()
137
+ elif selected_page == "Test":
138
+ show_test()
139
+ elif selected_page == "Unsupervised":
140
+ show_unsupervised()
assets/gemini_logo.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"v":"4.8.0","meta":{"g":"LottieFiles AE ","a":"","k":"","d":"","tc":""},"fr":25,"ip":0,"op":51,"w":1000,"h":1000,"nm":"Typing Element_Torus 1","ddd":0,"assets":[{"id":"comp_0","layers":[{"ddd":0,"ind":3,"ty":3,"nm":"Null 56","sr":1,"ks":{"o":{"a":0,"k":0,"ix":11},"r":{"a":1,"k":[{"i":{"x":[0.833],"y":[0.833]},"o":{"x":[0.167],"y":[0.167]},"t":0,"s":[0]},{"i":{"x":[0.833],"y":[0.833]},"o":{"x":[0.167],"y":[0.167]},"t":25,"s":[360]},{"t":50,"s":[720]}],"ix":10},"p":{"a":0,"k":[964,550,0],"ix":2},"a":{"a":0,"k":[0,0,0],"ix":1},"s":{"a":0,"k":[280,280,100],"ix":6}},"ao":0,"ip":0,"op":1750,"st":0,"bm":0},{"ddd":0,"ind":4,"ty":4,"nm":"Layer 3 Outlines","parent":3,"sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":0,"k":[-33.838,-5.16,0],"ix":2},"a":{"a":0,"k":[270.284,191.035,0],"ix":1},"s":{"a":0,"k":[110,110,100],"ix":6}},"ao":0,"ef":[{"ty":29,"nm":"Gaussian Blur","np":5,"mn":"ADBE Gaussian Blur 2","ix":3,"en":1,"ef":[{"ty":0,"nm":"Blurriness","mn":"ADBE Gaussian Blur 2-0001","ix":1,"v":{"a":0,"k":150,"ix":1}},{"ty":7,"nm":"Blur Dimensions","mn":"ADBE Gaussian Blur 2-0002","ix":2,"v":{"a":0,"k":1,"ix":2}},{"ty":7,"nm":"Repeat Edge Pixels","mn":"ADBE Gaussian Blur 2-0003","ix":3,"v":{"a":0,"k":1,"ix":3}}]}],"shapes":[{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[-51.873,0],[0,-51.873],[51.873,0],[0,51.873]],"o":[[51.873,0],[0,51.873],[-51.873,0],[0,-51.873]],"v":[[0,-93.925],[93.925,0],[0,93.925],[-93.925,0]],"c":true},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"fl","c":{"a":0,"k":[0.552941176471,0.490196108351,1,1],"ix":4},"o":{"a":0,"k":100,"ix":5},"r":1,"bm":0,"nm":"Fill 1","mn":"ADBE Vector Graphic - Fill","hd":false},{"ty":"tr","p":{"a":0,"k":[305.506,94.175],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[138,138],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 1","np":2,"cix":2,"bm":0,"ix":1,"mn":"ADBE Vector Group","hd":false},{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[-85.37,0],[0,-85.371],[85.37,0],[0,85.37]],"o":[[85.37,0],[0,85.37],[-85.37,0],[0,-85.371]],"v":[[0,-154.577],[154.577,0],[0,154.577],[-154.577,0]],"c":true},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"fl","c":{"a":0,"k":[0.470588265213,0.831372608858,1,1],"ix":4},"o":{"a":0,"k":100,"ix":5},"r":1,"bm":0,"nm":"Fill 1","mn":"ADBE Vector Graphic - Fill","hd":false},{"ty":"tr","p":{"a":0,"k":[154.827,195.919],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[116,116],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 2","np":2,"cix":2,"bm":0,"ix":2,"mn":"ADBE Vector Group","hd":false},{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[-64.842,0],[0,-64.842],[64.842,0],[0,64.842]],"o":[[64.842,0],[0,64.842],[-64.842,0],[0,-64.842]],"v":[[0,-117.406],[117.406,0],[0,117.406],[-117.406,0]],"c":true},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"fl","c":{"a":0,"k":[0.980392216701,0.529411764706,0.901960844152,1],"ix":4},"o":{"a":0,"k":100,"ix":5},"r":1,"bm":0,"nm":"Fill 1","mn":"ADBE Vector Graphic - Fill","hd":false},{"ty":"tr","p":{"a":0,"k":[422.912,264.413],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[142,142],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 3","np":2,"cix":2,"bm":0,"ix":3,"mn":"ADBE Vector Group","hd":false}],"ip":0,"op":500,"st":0,"bm":2}]}],"layers":[{"ddd":0,"ind":1,"ty":4,"nm":"Layer 1 Outlines 2","td":1,"sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":0,"k":[93.132,72.852,0],"ix":2},"a":{"a":0,"k":[0,-4,0],"ix":1},"s":{"a":0,"k":[507,507,100],"ix":6}},"ao":0,"shapes":[{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[-6.408,1.204],[-0.04,0.008],[-1.08,1.379],[0,1.711],[1.083,1.379],[1.789,0.375],[0.034,0.007],[7.959,2.856],[5.245,4.574],[2.402,15.036],[0.055,0.2],[1.312,1.036],[0,0],[1.73,0],[1.327,-1.049],[0.427,-1.57],[0.033,-0.205],[14.812,-13.222],[0,0],[8.02,-2.876],[6.419,-1.146],[0.057,-0.012],[1.081,-1.379],[0,0],[0,-1.698],[-1.082,-1.381],[-1.787,-0.375],[-0.04,-0.008],[-7.961,-2.855],[-5.223,-4.587],[-2.42,-13.771],[-0.005,-0.03],[-1.344,-1.152],[-1.826,0],[-1.355,1.163],[-0.323,1.682],[0,0],[0,0],[-14.883,13.332],[-8.014,2.877]],"o":[[0.04,-0.007],[1.792,-0.376],[1.077,-1.376],[0,-1.707],[-1.079,-1.377],[-0.035,-0.007],[-6.416,-1.223],[-8.014,-2.876],[-14.817,-13.384],[-0.032,-0.205],[-0.429,-1.578],[0,0],[-1.322,-1.043],[-1.726,0],[-1.319,1.041],[-0.055,0.2],[-2.408,15.071],[0,0],[-5.181,4.633],[-7.948,2.851],[-0.058,0.01],[-1.785,0.375],[0,0],[-1.085,1.384],[0,1.704],[1.081,1.382],[0.04,0.008],[6.409,1.204],[8.021,2.876],[14.927,13.447],[0.005,0.03],[0.323,1.682],[1.356,1.164],[1.829,0],[1.343,-1.152],[0,0],[0,0],[2.383,-13.866],[5.225,-4.59],[7.951,-2.854]],"v":[[73.819,8.807],[73.938,8.784],[78.358,6.026],[80,1.253],[78.353,-3.524],[73.938,-6.276],[73.835,-6.297],[51.296,-12.296],[30.625,-23.432],[7.499,-73.717],[7.368,-74.324],[4.705,-78.363],[4.7,-78.366],[0,-80],[-4.705,-78.362],[-7.368,-74.324],[-7.499,-73.717],[-30.637,-23.421],[-30.641,-23.417],[-51.229,-12.186],[-73.768,-6.309],[-73.94,-6.276],[-78.354,-3.524],[-78.358,-3.519],[-80,1.253],[-78.358,6.026],[-73.94,8.783],[-73.82,8.807],[-51.282,14.776],[-30.63,25.934],[-7.449,73.686],[-7.433,73.776],[-4.887,78.173],[0.035,80],[4.958,78.173],[7.503,73.776],[7.51,73.737],[7.523,73.665],[30.659,25.939],[51.301,14.775]],"c":true},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"fl","c":{"a":0,"k":[1,1,1,1],"ix":4},"o":{"a":0,"k":100,"ix":5},"r":1,"bm":0,"nm":"Fill 1","mn":"ADBE Vector Graphic - Fill","hd":false},{"ty":"tr","p":{"a":0,"k":[80.25,80.25],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[100,100],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 1","np":2,"cix":2,"bm":0,"ix":1,"mn":"ADBE Vector Group","hd":false}],"ip":0,"op":250,"st":0,"bm":0},{"ddd":0,"ind":2,"ty":0,"nm":"Typing_fractal nois","tt":1,"refId":"comp_0","sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":0,"k":[500,500,0],"ix":2},"a":{"a":0,"k":[960,540,0],"ix":1},"s":{"a":0,"k":[109,109,100],"ix":6}},"ao":0,"w":1920,"h":1080,"ip":0,"op":250,"st":0,"bm":0},{"ddd":0,"ind":4,"ty":4,"nm":"Layer 1 Outlines 5","sr":1,"ks":{"o":{"a":0,"k":100,"ix":11},"r":{"a":0,"k":0,"ix":10},"p":{"a":0,"k":[93.132,72.852,0],"ix":2},"a":{"a":0,"k":[0,-4,0],"ix":1},"s":{"a":0,"k":[507,507,100],"ix":6}},"ao":0,"shapes":[{"ty":"gr","it":[{"ind":0,"ty":"sh","ix":1,"ks":{"a":0,"k":{"i":[[-6.408,1.204],[-0.04,0.008],[-1.08,1.379],[0,1.711],[1.083,1.379],[1.789,0.375],[0.034,0.007],[7.959,2.856],[5.245,4.574],[2.402,15.036],[0.055,0.2],[1.312,1.036],[0,0],[1.73,0],[1.327,-1.049],[0.427,-1.57],[0.033,-0.205],[14.812,-13.222],[0,0],[8.02,-2.876],[6.419,-1.146],[0.057,-0.012],[1.081,-1.379],[0,0],[0,-1.698],[-1.082,-1.381],[-1.787,-0.375],[-0.04,-0.008],[-7.961,-2.855],[-5.223,-4.587],[-2.42,-13.771],[-0.005,-0.03],[-1.344,-1.152],[-1.826,0],[-1.355,1.163],[-0.323,1.682],[0,0],[0,0],[-14.883,13.332],[-8.014,2.877]],"o":[[0.04,-0.007],[1.792,-0.376],[1.077,-1.376],[0,-1.707],[-1.079,-1.377],[-0.035,-0.007],[-6.416,-1.223],[-8.014,-2.876],[-14.817,-13.384],[-0.032,-0.205],[-0.429,-1.578],[0,0],[-1.322,-1.043],[-1.726,0],[-1.319,1.041],[-0.055,0.2],[-2.408,15.071],[0,0],[-5.181,4.633],[-7.948,2.851],[-0.058,0.01],[-1.785,0.375],[0,0],[-1.085,1.384],[0,1.704],[1.081,1.382],[0.04,0.008],[6.409,1.204],[8.021,2.876],[14.927,13.447],[0.005,0.03],[0.323,1.682],[1.356,1.164],[1.829,0],[1.343,-1.152],[0,0],[0,0],[2.383,-13.866],[5.225,-4.59],[7.951,-2.854]],"v":[[73.819,8.807],[73.938,8.784],[78.358,6.026],[80,1.253],[78.353,-3.524],[73.938,-6.276],[73.835,-6.297],[51.296,-12.296],[30.625,-23.432],[7.499,-73.717],[7.368,-74.324],[4.705,-78.363],[4.7,-78.366],[0,-80],[-4.705,-78.362],[-7.368,-74.324],[-7.499,-73.717],[-30.637,-23.421],[-30.641,-23.417],[-51.229,-12.186],[-73.768,-6.309],[-73.94,-6.276],[-78.354,-3.524],[-78.358,-3.519],[-80,1.253],[-78.358,6.026],[-73.94,8.783],[-73.82,8.807],[-51.282,14.776],[-30.63,25.934],[-7.449,73.686],[-7.433,73.776],[-4.887,78.173],[0.035,80],[4.958,78.173],[7.503,73.776],[7.51,73.737],[7.523,73.665],[30.659,25.939],[51.301,14.775]],"c":true},"ix":2},"nm":"Path 1","mn":"ADBE Vector Shape - Group","hd":false},{"ty":"fl","c":{"a":0,"k":[1,1,1,1],"ix":4},"o":{"a":0,"k":100,"ix":5},"r":1,"bm":0,"nm":"Fill 1","mn":"ADBE Vector Graphic - Fill","hd":false},{"ty":"tr","p":{"a":0,"k":[80.25,80.25],"ix":2},"a":{"a":0,"k":[0,0],"ix":1},"s":{"a":0,"k":[100,100],"ix":3},"r":{"a":0,"k":0,"ix":6},"o":{"a":0,"k":100,"ix":7},"sk":{"a":0,"k":0,"ix":4},"sa":{"a":0,"k":0,"ix":5},"nm":"Transform"}],"nm":"Group 1","np":2,"cix":2,"bm":0,"ix":1,"mn":"ADBE Vector Group","hd":false}],"ip":0,"op":250,"st":0,"bm":0}],"markers":[]}
datos/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # __init__.py - M�dulo para src data
datos/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (178 Bytes). View file
 
datos/__pycache__/prepare.cpython-312.pyc ADDED
Binary file (36.8 kB). View file
 
datos/__pycache__/upload.cpython-312.pyc ADDED
Binary file (33.5 kB). View file
 
datos/prepare.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prepare.py - Módulo para datos
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.express as px
6
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
7
+ from datetime import datetime
8
+ from ydata_profiling import ProfileReport
9
+ import os
10
+
11
+ def show_prepare():
12
+ # Crear un contenedor para mensajes de estado
13
+ status_container = st.empty()
14
+
15
+ # Verificar si hay datos cargados
16
+ if 'er_data' not in st.session_state:
17
+ status_container.warning("⚠️ No hay datos cargados. Por favor, carga un dataset en la página Upload primero.")
18
+ return
19
+
20
+ try:
21
+ # Usar los datos preparados si existen, si no, usar los datos originales
22
+ if 'temp_prepared_data' in st.session_state:
23
+ prepare = st.session_state.temp_prepared_data.copy()
24
+ else:
25
+ # Si no hay datos temporales, intentar usar datos preparados permanentes
26
+ if 'prepared_data' in st.session_state:
27
+ prepare = st.session_state.prepared_data.copy()
28
+ else:
29
+ prepare = st.session_state.er_data.copy()
30
+ st.session_state.temp_prepared_data = prepare.copy()
31
+ except AttributeError:
32
+ status_container.warning("⚠️ No hay datos cargados o los datos son inválidos.")
33
+ return
34
+
35
+ # Análisis de valores únicos por columna
36
+ st.markdown("### Análisis de Valores Únicos por Columna")
37
+
38
+ # Selección de columnas para analizar - sin selección por defecto
39
+ all_columns = prepare.columns.tolist()
40
+ selected_columns = st.multiselect(
41
+ "Seleccionar columnas para analizar",
42
+ all_columns,
43
+ default=[], # Sin selección por defecto
44
+ key="unique_values_columns",
45
+ help="Selecciona las columnas que deseas analizar"
46
+ )
47
+
48
+ if not selected_columns:
49
+ st.info("👆 Selecciona una o más columnas para ver su análisis detallado.")
50
+ else:
51
+ # Controles para número de valores a mostrar
52
+ col1, col2 = st.columns(2)
53
+ with col1:
54
+ n_first = st.number_input(
55
+ "Número de primeros valores a mostrar",
56
+ min_value=1,
57
+ max_value=20,
58
+ value=5,
59
+ key="n_first_values"
60
+ )
61
+ with col2:
62
+ n_last = st.number_input(
63
+ "Número de últimos valores a mostrar",
64
+ min_value=1,
65
+ max_value=20,
66
+ value=5,
67
+ key="n_last_values"
68
+ )
69
+
70
+ # Crear tabs para cada columna seleccionada
71
+ tabs = st.tabs([f"📊 {col}" for col in selected_columns])
72
+
73
+ # Análisis por cada columna seleccionada
74
+ for tab, col in zip(tabs, selected_columns):
75
+ with tab:
76
+ try:
77
+ st.markdown(f"### Análisis de {col} ({prepare[col].dtype})")
78
+
79
+ # Safely convert column to handle mixed types
80
+ column_data = prepare[col].fillna('Sin valor').astype(str)
81
+
82
+ valores_unicos = column_data.unique()
83
+ n_valores = len(valores_unicos)
84
+
85
+ # Información general en columnas
86
+ col1, col2, col3 = st.columns([2, 1, 1])
87
+ with col1:
88
+ st.write(f"Total de valores únicos: {n_valores}")
89
+ with col2:
90
+ st.write(f"Valores nulos: {prepare[col].isnull().sum()}")
91
+ with col3:
92
+ st.write(f"% nulos: {(prepare[col].isnull().sum() / len(prepare) * 100).round(2)}%")
93
+
94
+ st.markdown("---")
95
+
96
+ # Visualización de valores únicos
97
+ if n_valores > (n_first + n_last):
98
+ col1, col2 = st.columns(2)
99
+ with col1:
100
+ st.write("🔼 Primeros valores:")
101
+ for valor in valores_unicos[:n_first]:
102
+ st.write(f"• {str(valor)}")
103
+
104
+ with col2:
105
+ st.write("🔽 Últimos valores:")
106
+ for valor in valores_unicos[-n_last:]:
107
+ st.write(f"• {str(valor)}")
108
+ else:
109
+ st.write("📝 Todos los valores únicos:")
110
+ for valor in valores_unicos:
111
+ st.write(f"• '{str(valor)}'")
112
+
113
+ st.markdown("---")
114
+
115
+ # Distribución de frecuencias
116
+ value_counts = column_data.value_counts()
117
+
118
+ # Gráfico de barras con conversión segura
119
+ fig_bar = {
120
+ 'data': [{
121
+ 'type': 'bar',
122
+ 'x': value_counts.index,
123
+ 'y': value_counts.values,
124
+ 'name': 'Frecuencia'
125
+ }],
126
+ 'layout': {
127
+ 'title': f'Distribución de valores en {col}',
128
+ 'xaxis': {'title': 'Valor'},
129
+ 'yaxis': {'title': 'Frecuencia'},
130
+ 'height': 400,
131
+ 'showlegend': False
132
+ }
133
+ }
134
+ st.plotly_chart(fig_bar, use_container_width=True)
135
+
136
+ # Tabla de frecuencias
137
+ freq_df = pd.DataFrame({
138
+ 'Valor': value_counts.index,
139
+ 'Frecuencia': value_counts.values,
140
+ 'Porcentaje': (value_counts.values / len(prepare) * 100).round(2)
141
+ })
142
+ st.dataframe(freq_df, use_container_width=True)
143
+
144
+ except Exception as e:
145
+ st.error(f"Error al procesar la columna {col}")
146
+ st.error(str(e))
147
+ st.write(f"Detalles técnicos del error en la columna {col}:", e)
148
+
149
+ st.subheader("Preparación de Datos")
150
+
151
+ st.subheader("Eliminación de Columnas")
152
+
153
+ # Verificar el estado actual de valores nulos
154
+ current_null_count = prepare.isnull().sum().sum()
155
+
156
+ all_columns = prepare.columns.tolist()
157
+ columns_to_drop = st.multiselect(
158
+ "Seleccionar columnas a eliminar",
159
+ all_columns,
160
+ key="columns_to_drop"
161
+ )
162
+
163
+ if columns_to_drop:
164
+ if st.button("Eliminar columnas seleccionadas", key="drop_columns_button"):
165
+ try:
166
+ # Crear una copia temporal antes de eliminar columnas
167
+ temp_prepare = prepare.copy()
168
+
169
+ # Verificar que las columnas existen antes de eliminarlas
170
+ missing_cols = [col for col in columns_to_drop if col not in temp_prepare.columns]
171
+ if missing_cols:
172
+ st.error(f"❌ Las siguientes columnas no existen: {', '.join(missing_cols)}")
173
+ return
174
+
175
+ # Eliminar las columnas
176
+ temp_prepare = temp_prepare.drop(columns=columns_to_drop)
177
+
178
+ # Verificar valores nulos después de la eliminación
179
+ new_null_count = temp_prepare.isnull().sum().sum()
180
+
181
+ if new_null_count <= current_null_count: # Permitir igual o menor cantidad de nulos
182
+ # Actualizar el DataFrame y el estado
183
+ prepare = temp_prepare
184
+ st.session_state.temp_prepared_data = prepare.copy()
185
+
186
+ # Mostrar mensaje de éxito
187
+ st.success(f"✅ Columnas eliminadas exitosamente: {', '.join(columns_to_drop)}")
188
+
189
+ # Actualizar información sobre valores nulos
190
+ if new_null_count == 0:
191
+ st.success("✅ No hay valores nulos en los datos.")
192
+ else:
193
+ st.warning(f"⚠️ Hay {new_null_count} valores nulos en los datos.")
194
+
195
+ # Mostrar resumen de las columnas restantes
196
+ st.write(f"Columnas restantes: {len(prepare.columns)}")
197
+ if st.checkbox("Ver lista de columnas restantes", key="remaining_columns_checkbox"):
198
+ st.write(prepare.columns.tolist())
199
+ else:
200
+ st.error(f"❌ La operación incrementaría los valores nulos de {current_null_count} a {new_null_count}. Operación cancelada.")
201
+
202
+ except Exception as e:
203
+ st.error(f"Error durante la eliminación de columnas: {str(e)}")
204
+ st.exception(e)
205
+
206
+ # Manejo de valores faltantes
207
+ st.subheader("Manejo de Valores Faltantes")
208
+ missing_values = prepare.isnull().sum()
209
+ missing_percentages = (missing_values / len(prepare) * 100).round(2)
210
+
211
+ # Crear DataFrame y ordenar por número de valores faltantes de mayor a menor
212
+ missing_df = pd.DataFrame({
213
+ 'Columna': missing_values.index,
214
+ 'Valores Faltantes': missing_values.values,
215
+ 'Porcentaje': missing_percentages.values,
216
+ 'Tipo': prepare[missing_values.index].dtypes
217
+ })
218
+ missing_df = missing_df[missing_df['Valores Faltantes'] > 0].sort_values('Valores Faltantes', ascending=False)
219
+
220
+ if not missing_df.empty:
221
+ # Mostrar advertencia si hay columnas de tipo object con valores faltantes
222
+ object_cols = missing_df[missing_df['Tipo'] == 'object']
223
+ if not object_cols.empty:
224
+ st.warning("⚠️ Se detectaron columnas de tipo texto/categórico (object) con valores faltantes. "
225
+ "Se recomienda revisar estos casos con especial atención ya que el método de imputación "
226
+ "podría afectar significativamente el análisis.")
227
+
228
+ st.write("Columnas de tipo texto/categórico con valores faltantes:")
229
+ st.dataframe(object_cols)
230
+
231
+ st.write("Valores faltantes por columna (ordenados de mayor a menor):")
232
+ st.dataframe(missing_df)
233
+
234
+ # Checkbox para manejo especial de columnas object
235
+ handle_objects = st.checkbox("Especificar valor de reemplazo para columnas de texto",
236
+ help="Marca esta opción para especificar un valor personalizado para rellenar "
237
+ "los valores faltantes en columnas de texto/categóricas")
238
+
239
+ object_replacement = None
240
+ if handle_objects:
241
+ object_replacement = st.text_input("Valor de reemplazo para columnas de texto:",
242
+ value="MISSING",
243
+ help="Este valor se usará para rellenar los valores faltantes "
244
+ "en todas las columnas de texto/categóricas")
245
+
246
+ missing_strategy = st.radio(
247
+ "Selecciona estrategia para valores faltantes:",
248
+ ["Eliminar filas", "Rellenar con media", "Rellenar con mediana", "Rellenar con moda"]
249
+ )
250
+
251
+ if st.button("Aplicar estrategia de valores faltantes", key="apply_missing_strategy_button"):
252
+ try:
253
+ # Guardar el estado anterior de prepare para verificación
254
+ nulls_before = prepare.isnull().sum().sum()
255
+
256
+ if missing_strategy == "Eliminar filas":
257
+ # Guardar el número de filas antes
258
+ rows_before = len(prepare)
259
+
260
+ # Crear una copia para no modificar el original
261
+ prepare_cleaned = prepare.copy()
262
+
263
+ # Eliminar filas con valores nulos
264
+ prepare_cleaned = prepare_cleaned.dropna(how='any')
265
+
266
+ # Verificar que no queden valores nulos
267
+ if prepare_cleaned.isnull().sum().sum() == 0:
268
+ prepare = prepare_cleaned # Actualizar prepare solo si la limpieza fue exitosa
269
+ st.session_state.temp_prepared_data = prepare.copy() # Actualizar el estado temporal
270
+ rows_removed = rows_before - len(prepare)
271
+ st.success(f"Se eliminaron {rows_removed} filas con valores faltantes. No quedan valores nulos.")
272
+ else:
273
+ st.error(f"Error: Aún quedan {prepare_cleaned.isnull().sum().sum()} valores faltantes después de la eliminación.")
274
+ return
275
+ else:
276
+ # Separar columnas numéricas y no numéricas
277
+ numeric_cols = prepare.select_dtypes(include=['int64', 'float64']).columns
278
+ non_numeric_cols = prepare.select_dtypes(exclude=['int64', 'float64']).columns
279
+
280
+ # Manejar columnas object primero si se especificó un valor de reemplazo
281
+ if handle_objects and object_replacement is not None:
282
+ object_cols = prepare.select_dtypes(include=['object']).columns
283
+ for col in object_cols:
284
+ prepare[col] = prepare[col].fillna(object_replacement)
285
+
286
+ if missing_strategy == "Rellenar con media":
287
+ # Para columnas numéricas usar media
288
+ if len(numeric_cols) > 0:
289
+ prepare[numeric_cols] = prepare[numeric_cols].fillna(prepare[numeric_cols].mean())
290
+ # Para columnas no numéricas sin valor especificado usar moda
291
+ if len(non_numeric_cols) > 0 and not handle_objects:
292
+ for col in non_numeric_cols:
293
+ prepare[col] = prepare[col].fillna(prepare[col].mode()[0] if not prepare[col].mode().empty else 'NA')
294
+
295
+ elif missing_strategy == "Rellenar con mediana":
296
+ # Para columnas numéricas usar mediana
297
+ if len(numeric_cols) > 0:
298
+ prepare[numeric_cols] = prepare[numeric_cols].fillna(prepare[numeric_cols].median())
299
+ # Para columnas no numéricas sin valor especificado usar moda
300
+ if len(non_numeric_cols) > 0 and not handle_objects:
301
+ for col in non_numeric_cols:
302
+ prepare[col] = prepare[col].fillna(prepare[col].mode()[0] if not prepare[col].mode().empty else 'NA')
303
+
304
+ else: # Rellenar con moda
305
+ # Usar moda para todas las columnas que no son object o no tienen valor especificado
306
+ for col in prepare.columns:
307
+ if prepare[col].dtype in ['object']:
308
+ if handle_objects:
309
+ continue # Ya se manejaron las columnas object
310
+ mode_value = prepare[col].mode()
311
+ prepare[col] = prepare[col].fillna(mode_value[0] if not mode_value.empty else ('NA' if col in non_numeric_cols else 0))
312
+
313
+ # Actualizar el estado temporal después de la imputación
314
+ st.session_state.temp_prepared_data = prepare.copy()
315
+
316
+ # Verificar los cambios
317
+ nulls_after = prepare.isnull().sum().sum()
318
+ values_filled = nulls_before - nulls_after
319
+
320
+ if missing_strategy == "Eliminar filas":
321
+ st.success(f"Se eliminaron {values_filled} filas con valores faltantes")
322
+ else:
323
+ st.success(f"Se rellenaron {values_filled} valores faltantes")
324
+
325
+ # Verificar si quedan valores nulos
326
+ remaining_nulls = prepare.isnull().sum()
327
+ remaining_nulls = remaining_nulls[remaining_nulls > 0]
328
+
329
+ if not remaining_nulls.empty:
330
+ st.error("⚠️ Error: Aún quedan valores faltantes en las siguientes columnas:")
331
+ for col in remaining_nulls.index:
332
+ st.write(f"- {col}: {remaining_nulls[col]} valores faltantes")
333
+ st.write("Por favor, contacta al equipo de desarrollo para revisar este error.")
334
+
335
+ # Mostrar un resumen de los datos actualizados
336
+ st.write("\n### Resumen después del procesamiento:")
337
+ st.write(f"- Total de filas: {len(prepare)}")
338
+ st.write(f"- Total de columnas: {len(prepare.columns)}")
339
+ st.write(f"- Valores faltantes totales: {prepare.isnull().sum().sum()}")
340
+
341
+ # Verificación final de valores nulos
342
+ final_null_check = prepare.isnull().sum().sum()
343
+ if final_null_check == 0:
344
+ st.success("✅ ¡No quedan valores faltantes en el dataset!")
345
+ else:
346
+ st.error(f"⚠️ Aún quedan {final_null_check} valores nulos en el dataset.")
347
+ return
348
+
349
+ # Actualizar la sesión solo si no hay valores nulos
350
+ if final_null_check == 0:
351
+ st.session_state.prepared_data = prepare.copy()
352
+ st.session_state.temp_prepared_data = prepare.copy()
353
+ # No sobrescribir 'er_data'
354
+
355
+ except Exception as e:
356
+ st.error(f"Error al procesar valores faltantes: {str(e)}")
357
+ st.error("Detalles técnicos del error:")
358
+ st.code(str(e))
359
+
360
+ # Manejo de fechas
361
+ st.subheader("Manejo de Fechas")
362
+ with st.expander("Procesamiento de Fechas"):
363
+ date_columns = st.multiselect(
364
+ "Seleccionar columnas de fecha",
365
+ prepare.columns,
366
+ key="date_columns"
367
+ )
368
+
369
+ if date_columns:
370
+ date_format = st.selectbox(
371
+ "Formato de fecha",
372
+ [
373
+ "yyyy-mm-dd",
374
+ "dd-mm-yyyy",
375
+ "mm-dd-yyyy",
376
+ "yyyy-mm-dd hh:mm",
377
+ "dd-mm-yyyy hh:mm",
378
+ "mm-dd-yyyy hh:mm",
379
+ "yyyy-mm-dd hh:mm:ss",
380
+ "dd-mm-yyyy hh:mm:ss",
381
+ "mm-dd-yyyy hh:mm:ss",
382
+ "hh:mm",
383
+ "hh:mm:ss"
384
+ ],
385
+ help="Selecciona el formato que coincida con tus datos de fecha/hora."
386
+ )
387
+
388
+ time_format = st.radio(
389
+ "Formato de hora",
390
+ ["24 horas", "12 horas (AM/PM)"],
391
+ help="Selecciona si el formato de hora está en 12 o 24 horas"
392
+ )
393
+
394
+ # Ajustar las características disponibles según el formato
395
+ if date_format in ["hh:mm", "hh:mm:ss"]:
396
+ available_features = ["Hora del día", "Periodo del día", "Minutos", "Segundos"]
397
+ else:
398
+ if "hh:mm:ss" in date_format:
399
+ available_features = [
400
+ "Año", "Mes", "Día", "Día de la semana", "Trimestre", "Estación",
401
+ "Es fin de semana", "Hora del día", "Periodo del día", "Minutos", "Segundos"
402
+ ]
403
+ else:
404
+ available_features = [
405
+ "Año", "Mes", "Día", "Día de la semana", "Trimestre", "Estación",
406
+ "Es fin de semana", "Hora del día", "Periodo del día", "Minutos"
407
+ ]
408
+
409
+ date_features = st.multiselect(
410
+ "Seleccionar características a extraer",
411
+ available_features
412
+ )
413
+
414
+ if st.button("Procesar fechas", key="process_dates_button"):
415
+ for col in date_columns:
416
+ try:
417
+ if date_format in ["hh:mm", "hh:mm:ss"]:
418
+ # Procesar solo tiempo
419
+ if date_format == "hh:mm:ss":
420
+ time_parse_format = '%I:%M:%S %p' if time_format == "12 horas (AM/PM)" else '%H:%M:%S'
421
+ time_with_seconds = True
422
+ else:
423
+ time_parse_format = '%I:%M %p' if time_format == "12 horas (AM/PM)" else '%H:%M'
424
+ time_with_seconds = False
425
+
426
+ def convert_time(time_str):
427
+ try:
428
+ time_obj = datetime.strptime(time_str.strip(), time_parse_format)
429
+ if time_with_seconds:
430
+ return time_obj.hour, time_obj.minute, time_obj.second
431
+ else:
432
+ return time_obj.hour, time_obj.minute, None
433
+ except ValueError:
434
+ st.warning(f"⚠️ Formato de hora inesperado en {col}: '{time_str}'")
435
+ return None, None, None if time_with_seconds else None
436
+
437
+ # Aplicar la conversión y crear nuevas columnas
438
+ hours_minutes_seconds = prepare[col].apply(convert_time)
439
+
440
+ # Depuración: Mostrar una vista previa de la conversión
441
+ st.write(f"Vista previa de la conversión de tiempo para la columna {col}:")
442
+ st.write(hours_minutes_seconds.head())
443
+
444
+ if "Hora del día" in date_features:
445
+ prepare[f'{col}_hora'] = hours_minutes_seconds.apply(lambda x: x[0] if x and x[0] is not None else None)
446
+ if "Minutos" in date_features:
447
+ prepare[f'{col}_minutos'] = hours_minutes_seconds.apply(lambda x: x[1] if x and x[1] is not None else None)
448
+ if "Segundos" in date_features and time_with_seconds:
449
+ prepare[f'{col}_segundos'] = hours_minutes_seconds.apply(lambda x: x[2] if x and x[2] is not None else None)
450
+
451
+ # Agregar periodo del día si se seleccionó
452
+ if "Periodo del día" in date_features:
453
+ def get_period(hour):
454
+ if hour is None:
455
+ return None
456
+ if 5 <= hour < 12:
457
+ return 'Mañana'
458
+ elif 12 <= hour < 17:
459
+ return 'Tarde'
460
+ elif 17 <= hour < 21:
461
+ return 'Noche'
462
+ else:
463
+ return 'Madrugada'
464
+ prepare[f'{col}_periodo'] = prepare[f'{col}_hora'].apply(get_period)
465
+
466
+ else:
467
+ # Definir el formato de parsing según la selección
468
+ if date_format == "yyyy-mm-dd":
469
+ date_parse_format = '%Y-%m-%d'
470
+ elif date_format == "dd-mm-yyyy":
471
+ date_parse_format = '%d-%m-%Y'
472
+ elif date_format == "mm-dd-yyyy":
473
+ date_parse_format = '%m-%d-%Y'
474
+ elif date_format == "yyyy-mm-dd hh:mm":
475
+ date_parse_format = '%Y-%m-%d %H:%M' if time_format == "24 horas" else '%Y-%m-%d %I:%M %p'
476
+ elif date_format == "dd-mm-yyyy hh:mm":
477
+ date_parse_format = '%d-%m-%Y %H:%M' if time_format == "24 horas" else '%d-%m-%Y %I:%M %p'
478
+ elif date_format == "mm-dd-yyyy hh:mm":
479
+ date_parse_format = '%m-%d-%Y %H:%M' if time_format == "24 horas" else '%m-%d-%Y %I:%M %p'
480
+ elif date_format == "yyyy-mm-dd hh:mm:ss":
481
+ date_parse_format = '%Y-%m-%d %H:%M:%S' if time_format == "24 horas" else '%Y-%m-%d %I:%M:%S %p'
482
+ elif date_format == "dd-mm-yyyy hh:mm:ss":
483
+ date_parse_format = '%d-%m-%Y %H:%M:%S' if time_format == "24 horas" else '%d-%m-%Y %I:%M:%S %p'
484
+ elif date_format == "mm-dd-yyyy hh:mm:ss":
485
+ date_parse_format = '%m-%d-%Y %H:%M:%S' if time_format == "24 horas" else '%m-%d-%Y %I:%M:%S %p'
486
+ else:
487
+ st.error(f"Formato de fecha no reconocido: {date_format}")
488
+ continue
489
+
490
+ # Convertir a datetime con manejo de errores
491
+ temp_dates = pd.to_datetime(prepare[col], format=date_parse_format, errors='coerce')
492
+
493
+ # Depuración: Mostrar una vista previa de las fechas parseadas
494
+ st.write(f"Vista previa de las fechas parseadas para la columna {col}:")
495
+ st.write(temp_dates.head())
496
+
497
+ # Manejo de valores que no se pudieron parsear
498
+ if temp_dates.isnull().any():
499
+ st.warning(f"⚠️ Algunas fechas en la columna {col} no pudieron ser parseadas y se asignaron como NaT.")
500
+
501
+ # Extraer características según selección
502
+ if "Año" in date_features:
503
+ prepare[f'{col}_año'] = temp_dates.dt.year
504
+ if "Mes" in date_features:
505
+ prepare[f'{col}_mes'] = temp_dates.dt.month
506
+ if "Día" in date_features:
507
+ prepare[f'{col}_dia'] = temp_dates.dt.day
508
+ if "Día de la semana" in date_features:
509
+ prepare[f'{col}_dia_semana'] = temp_dates.dt.dayofweek + 1
510
+ if "Trimestre" in date_features:
511
+ prepare[f'{col}_trimestre'] = temp_dates.dt.quarter
512
+ if "Es fin de semana" in date_features:
513
+ prepare[f'{col}_fin_semana'] = temp_dates.dt.dayofweek.isin([5, 6]).astype(int)
514
+ if "Estación" in date_features:
515
+ def get_season(month):
516
+ if month in [12, 1, 2]:
517
+ return 'Invierno'
518
+ elif month in [3, 4, 5]:
519
+ return 'Primavera'
520
+ elif month in [6, 7, 8]:
521
+ return 'Verano'
522
+ else:
523
+ return 'Otoño'
524
+ prepare[f'{col}_estacion'] = temp_dates.dt.month.apply(get_season)
525
+ if "Hora del día" in date_features and any(sub in date_format for sub in ["hh:mm", "hh:mm:ss"]):
526
+ prepare[f'{col}_hora'] = temp_dates.dt.hour
527
+ if "Minutos" in date_features and any(sub in date_format for sub in ["hh:mm", "hh:mm:ss"]):
528
+ prepare[f'{col}_minutos'] = temp_dates.dt.minute
529
+ if "Segundos" in date_features and "hh:mm:ss" in date_format:
530
+ prepare[f'{col}_segundos'] = temp_dates.dt.second
531
+ if "Periodo del día" in date_features and any(sub in date_format for sub in ["hh:mm", "hh:mm:ss"]):
532
+ def get_period(hour):
533
+ if hour is None:
534
+ return None
535
+ if 5 <= hour < 12:
536
+ return 'Mañana'
537
+ elif 12 <= hour < 17:
538
+ return 'Tarde'
539
+ elif 17 <= hour < 21:
540
+ return 'Noche'
541
+ else:
542
+ return 'Madrugada'
543
+ prepare[f'{col}_periodo'] = temp_dates.dt.hour.apply(get_period)
544
+
545
+ # Eliminar la columna original de fecha
546
+ prepare = prepare.drop(columns=[col])
547
+ st.success(f"Columna {col} procesada exitosamente")
548
+
549
+ except Exception as e:
550
+ st.error(f"Error procesando {col}: {str(e)}")
551
+ st.exception(e)
552
+
553
+ # Actualizar el estado temporal después de procesar fechas
554
+ st.session_state.temp_prepared_data = prepare.copy()
555
+
556
+ # Codificación de variables categóricas
557
+ st.subheader("Codificación de Variables Categóricas")
558
+ categorical_columns = prepare.select_dtypes(include=['object']).columns
559
+
560
+ if len(categorical_columns) > 0:
561
+ encoding_method = st.radio(
562
+ "Método de codificación:",
563
+ ["Label Encoding", "One-Hot Encoding"]
564
+ )
565
+
566
+ cols_to_encode = st.multiselect(
567
+ "Seleccionar columnas para codificar",
568
+ categorical_columns,
569
+ key="cols_to_encode"
570
+ )
571
+
572
+ if st.button("Aplicar codificación", key="apply_encoding_button"):
573
+ if encoding_method == "Label Encoding":
574
+ le = LabelEncoder()
575
+ for col in cols_to_encode:
576
+ try:
577
+ prepare[col] = le.fit_transform(prepare[col].astype(str))
578
+ st.success(f"✅ Label Encoding aplicado a la columna '{col}'")
579
+ except Exception as e:
580
+ st.error(f"Error al codificar la columna {col} con Label Encoding: {str(e)}")
581
+ # Actualizar el estado temporal después de la codificación
582
+ st.session_state.temp_prepared_data = prepare.copy()
583
+ else: # One-Hot Encoding
584
+ try:
585
+ prepare = pd.get_dummies(prepare, columns=cols_to_encode)
586
+ st.success("✅ One-Hot Encoding aplicado")
587
+ # Actualizar el estado temporal después de la codificación
588
+ st.session_state.temp_prepared_data = prepare.copy()
589
+ except Exception as e:
590
+ st.error(f"Error al aplicar One-Hot Encoding: {str(e)}")
591
+
592
+ # Normalización de variables numéricas
593
+ st.subheader("Normalización de Variables Numéricas")
594
+ numeric_columns = prepare.select_dtypes(include=['int64', 'float64']).columns
595
+
596
+ if len(numeric_columns) > 0:
597
+ cols_to_normalize = st.multiselect(
598
+ "Seleccionar columnas para normalizar",
599
+ numeric_columns,
600
+ key="cols_to_normalize"
601
+ )
602
+
603
+ if cols_to_normalize and st.button("Aplicar normalización", key="apply_normalization_button"):
604
+ try:
605
+ scaler = StandardScaler()
606
+ prepare[cols_to_normalize] = scaler.fit_transform(prepare[cols_to_normalize])
607
+ st.success("✅ Normalización aplicada")
608
+ # Actualizar el estado temporal después de la normalización
609
+ st.session_state.temp_prepared_data = prepare.copy()
610
+ except Exception as e:
611
+ st.error(f"Error al aplicar normalización: {str(e)}")
612
+
613
+ # Guardar datos preparados y mostrar matriz de correlación
614
+ st.write("### Vista previa de los datos:")
615
+ st.dataframe(prepare.head())
616
+
617
+ # Información sobre valores nulos
618
+ null_count = prepare.isnull().sum().sum()
619
+ if null_count > 0:
620
+ st.warning(f"⚠️ Hay {null_count} valores nulos en los datos.")
621
+ else:
622
+ st.success("✅ No hay valores nulos en los datos.")
623
+
624
+ # Matriz de correlación
625
+ st.subheader("Matriz de Correlación")
626
+ numerical_columns = prepare.select_dtypes(include=['int64', 'float64']).columns.tolist()
627
+
628
+ if len(numerical_columns) > 1:
629
+ corr_variables = st.multiselect(
630
+ "Selecciona las variables para incluir en la matriz de correlación",
631
+ options=numerical_columns,
632
+ default=numerical_columns[:min(5, len(numerical_columns))] # Seleccionar hasta 5 columnas por defecto
633
+ )
634
+
635
+ if corr_variables:
636
+ try:
637
+ # -------------------------------------------
638
+ # NUEVO: Detección de Outliers y Visualización
639
+ # -------------------------------------------
640
+ for var in corr_variables:
641
+ # Cálculo de Q1, Q3 e IQR
642
+ Q1 = prepare[var].quantile(0.25)
643
+ Q3 = prepare[var].quantile(0.75)
644
+ IQR = Q3 - Q1
645
+ lower_bound = Q1 - 1.5 * IQR
646
+ upper_bound = Q3 + 1.5 * IQR
647
+
648
+ # Identificación de outliers
649
+ outliers = prepare[(prepare[var] < lower_bound) | (prepare[var] > upper_bound)][var]
650
+ num_outliers = outliers.shape[0]
651
+
652
+ # Mostrar advertencia si hay outliers
653
+ if num_outliers > 0:
654
+ st.warning(f"⚠️ La variable **{var}** tiene {num_outliers} datos atípicos (outliers) detectados.")
655
+
656
+ # Mostrar boxplot usando Plotly
657
+ fig_box = px.box(prepare, y=var, title=f'Boxplot de {var}')
658
+ st.plotly_chart(fig_box, use_container_width=True)
659
+
660
+ # Calcular y mostrar la matriz de correlación
661
+ corr_matrix = prepare[corr_variables].corr(method='pearson')
662
+
663
+ # Mapa de calor de correlación
664
+ fig_corr = px.imshow(
665
+ corr_matrix,
666
+ text_auto=True,
667
+ aspect="auto",
668
+ color_continuous_scale='RdBu_r',
669
+ title='Matriz de Correlación de Pearson'
670
+ )
671
+ st.plotly_chart(fig_corr, use_container_width=True)
672
+
673
+ # Botón de descarga
674
+ csv_corr = corr_matrix.to_csv(index=True).encode('utf-8')
675
+ st.download_button(
676
+ label="Descargar Matriz de Correlación como CSV",
677
+ data=csv_corr,
678
+ file_name='matriz_correlacion.csv',
679
+ mime='text/csv',
680
+ )
681
+
682
+ # Análisis de correlaciones significativas
683
+ st.write("### Análisis de Correlaciones Significativas")
684
+ threshold = st.slider(
685
+ "Selecciona el umbral mínimo de correlación para considerar significativa",
686
+ min_value=0.0,
687
+ max_value=1.0,
688
+ value=0.5,
689
+ step=0.05
690
+ )
691
+
692
+ # Obtener y mostrar correlaciones significativas
693
+ corr_pairs = corr_matrix.unstack()
694
+ significant_corr = corr_pairs[
695
+ (abs(corr_pairs) >= threshold) &
696
+ (abs(corr_pairs) < 1)
697
+ ].drop_duplicates().sort_values(ascending=False)
698
+
699
+ if not significant_corr.empty:
700
+ st.write(f"Correlaciones significativas (|correlación| ≥ {threshold}):")
701
+ for (var1, var2), corr_value in significant_corr.items():
702
+ st.write(f"- **{var1}** y **{var2}**: correlación de **{corr_value:.2f}**")
703
+ else:
704
+ st.write("No se encontraron correlaciones significativas con el umbral seleccionado.")
705
+
706
+ except Exception as e:
707
+ st.error(f"Error al calcular la matriz de correlación: {str(e)}")
708
+ else:
709
+ st.warning("Por favor, selecciona al menos una variable para mostrar la matriz de correlación.")
710
+ else:
711
+ st.warning("No hay suficientes variables numéricas para calcular correlaciones.")
712
+
713
+ # Button para guardar datos preparados
714
+ if st.button("Guardar datos preparados", key="save_prepared_data_button"):
715
+ try:
716
+ null_count = prepare.isnull().sum().sum()
717
+ if null_count == 0:
718
+ st.session_state.prepared_data = prepare.copy()
719
+ st.session_state.temp_prepared_data = prepare.copy()
720
+ st.session_state.data_saved = True
721
+ st.success("✅ Datos preparados guardados exitosamente")
722
+
723
+ # Generar reporte
724
+ progress_container = st.empty()
725
+ with progress_container:
726
+ with st.spinner('Generando reporte del dataset...'):
727
+ profile = ProfileReport(prepare, title="Dataset Report", explorative=True)
728
+ st.session_state.report_html = profile.to_html()
729
+ st.success("¡Reporte generado exitosamente!")
730
+ else:
731
+ st.error(f"❌ No se pueden guardar los datos. Aún hay {null_count} valores nulos.")
732
+ st.warning("Por favor, aplica una estrategia de manejo de valores faltantes antes de guardar.")
733
+ except Exception as e:
734
+ st.error(f"Error al guardar los datos preparados: {str(e)}")
735
+
736
+ # Botones de descarga fuera del bloque principal
737
+ if 'data_saved' in st.session_state and st.session_state.data_saved:
738
+ col1, col2 = st.columns(2)
739
+
740
+ with col1:
741
+ csv = st.session_state.prepared_data.to_csv(index=False).encode('utf-8')
742
+ st.download_button(
743
+ label="Descargar Dataset Preparado",
744
+ data=csv,
745
+ file_name="prepared_dataset.csv",
746
+ mime="text/csv"
747
+ )
748
+
749
+ with col2:
750
+ st.download_button(
751
+ label="Descargar Reporte del Dataset",
752
+ data=st.session_state.report_html,
753
+ file_name="dataset_report.html",
754
+ mime="text/html"
755
+ )
756
+
757
+ st.info("👆 No te olvides de guardar los datos preparados antes de continuar con el análisis en la página Training o Test.")
datos/upload.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # upload.py - Módulo para datos
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import io
6
+ import requests
7
+ from typing import Optional, Dict, List
8
+ import importlib.util
9
+ import os
10
+ import plotly.express as px
11
+ from supabase import create_client
12
+ import re
13
+ from pygwalker.api.streamlit import StreamlitRenderer
14
+
15
+ # Importaciones específicas del proyecto
16
+ from utils.gemini_explainer import generate_dataset_explanation
17
+
18
+ # Ejemplos de datasets
19
+ DATASET_OPTIONS = {
20
+ "Iris": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv",
21
+ "Titanic": "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv",
22
+ "Boston Housing": "https://raw.githubusercontent.com/scikit-learn/scikit-learn/main/sklearn/datasets/data/boston_house_prices.csv",
23
+ "Wine Quality": "https://raw.githubusercontent.com/uiuc-cse/data-fa14/master/data/wine.csv",
24
+ "Diabetes": "https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv"
25
+ }
26
+
27
+ def check_package(package_name: str) -> bool:
28
+ """Verifica si un paquete está instalado"""
29
+ return importlib.util.find_spec(package_name) is not None
30
+
31
+ def get_supported_formats() -> Dict[str, List[str]]:
32
+ """Retorna un diccionario con los formatos soportados basado en las dependencias instaladas"""
33
+ formats = {
34
+ 'CSV': ['csv'],
35
+ 'Excel': ['xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt'],
36
+ 'JSON': ['json']
37
+ }
38
+
39
+ # Verificar soporte para parquet
40
+ if check_package('pyarrow') or check_package('fastparquet'):
41
+ formats['Parquet'] = ['parquet']
42
+
43
+ # Verificar soporte para feather
44
+ if check_package('pyarrow'):
45
+ formats['Feather'] = ['feather']
46
+
47
+ # Verificar soporte para HDF5
48
+ if check_package('tables'):
49
+ formats['HDF5'] = ['h5', 'hdf5']
50
+
51
+ # Verificar soporte para SQLite
52
+ if check_package('sqlite3'):
53
+ formats['SQLite'] = ['db', 'sqlite', 'sqlite3']
54
+
55
+ # Verificar soporte para Pickle
56
+ formats['Pickle'] = ['pkl', 'pickle']
57
+
58
+ # Verificar soporte para STATA
59
+ if check_package('pandas.io.stata'):
60
+ formats['STATA'] = ['dta']
61
+
62
+ # Verificar soporte para SAS
63
+ if check_package('pandas.io.sas'):
64
+ formats['SAS'] = ['sas7bdat']
65
+
66
+ return formats
67
+
68
+ def load_file(file_obj: io.BytesIO, file_format: str) -> Optional[pd.DataFrame]:
69
+ """Carga un archivo en un DataFrame basado en su formato"""
70
+ try:
71
+ if file_format in ['csv']:
72
+ return pd.read_csv(file_obj)
73
+ elif file_format in ['xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt']:
74
+ return pd.read_excel(file_obj)
75
+ elif file_format in ['json']:
76
+ return pd.read_json(file_obj)
77
+ elif file_format in ['parquet'] and (check_package('pyarrow') or check_package('fastparquet')):
78
+ return pd.read_parquet(file_obj)
79
+ elif file_format in ['feather'] and check_package('pyarrow'):
80
+ return pd.read_feather(file_obj)
81
+ elif file_format in ['h5', 'hdf5'] and check_package('tables'):
82
+ return pd.read_hdf(file_obj)
83
+ elif file_format in ['pkl', 'pickle']:
84
+ return pd.read_pickle(file_obj)
85
+ elif file_format in ['dta'] and check_package('pandas.io.stata'):
86
+ return pd.read_stata(file_obj)
87
+ elif file_format in ['sas7bdat'] and check_package('pandas.io.sas'):
88
+ return pd.read_sas(file_obj)
89
+ elif file_format in ['db', 'sqlite', 'sqlite3'] and check_package('sqlite3'):
90
+ import sqlite3
91
+ conn = sqlite3.connect(file_obj)
92
+ tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
93
+ if len(tables) > 0:
94
+ table_name = st.selectbox("Selecciona una tabla:", tables['name'].tolist())
95
+ return pd.read_sql_query(f"SELECT * FROM {table_name};", conn)
96
+ else:
97
+ st.error("No se encontraron tablas en la base de datos")
98
+ return None
99
+ except Exception as e:
100
+ st.error(f"Error al cargar el archivo: {str(e)}")
101
+ return None
102
+
103
+ def load_gsheet(sharing_link: str) -> pd.DataFrame:
104
+ """Carga un Google Sheet como DataFrame usando su link de compartir"""
105
+ sheet_export = sharing_link.replace("/edit?usp=sharing", "/export?format=csv")
106
+ return pd.read_csv(sheet_export)
107
+
108
+ def convert_to_raw_github_url(url: str) -> str:
109
+ """Convierte una URL de GitHub en su versión 'raw'"""
110
+ # Patrón para URLs de GitHub
111
+ github_pattern = r'https://github\.com/([^/]+/[^/]+)/blob/([^/]+/.*)'
112
+
113
+ if match := re.match(github_pattern, url):
114
+ # Construir la URL raw
115
+ return f'https://raw.githubusercontent.com/{match.group(1)}/{match.group(2)}'
116
+ return url
117
+
118
+ def load_url_file(url: str) -> Optional[pd.DataFrame]:
119
+ """Carga un archivo desde una URL detectando automáticamente el formato"""
120
+ try:
121
+ # Convertir a URL raw si es una URL de GitHub
122
+ raw_url = convert_to_raw_github_url(url)
123
+
124
+ # Configurar headers para simular un navegador
125
+ headers = {
126
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
127
+ }
128
+
129
+ response = requests.get(raw_url, headers=headers, verify=True)
130
+ if response.status_code != 200:
131
+ raise Exception(f"Error al descargar el archivo (Status code: {response.status_code})")
132
+
133
+ content = io.BytesIO(response.content)
134
+
135
+ # Detectar formato basado en la extensión de la URL
136
+ extension = url.split('.')[-1].lower()
137
+
138
+ # Validar la extensión antes de procesar
139
+ supported_formats = ['csv', 'xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt',
140
+ 'json', 'parquet', 'feather', 'h5', 'hdf5', 'pkl', 'pickle',
141
+ 'dta', 'sas7bdat', 'db', 'sqlite', 'sqlite3']
142
+
143
+ if extension not in supported_formats:
144
+ st.error(f"Formato de archivo no soportado: {extension}")
145
+ return None
146
+
147
+ # Verificar que el contenido descargado no esté vacío
148
+ if len(response.content) == 0:
149
+ raise Exception("El archivo descargado está vacío")
150
+
151
+ # Para archivos Excel, usar directamente openpyxl
152
+ if extension in ['xlsx', 'xlsm', 'xlsb']:
153
+ return pd.read_excel(content, engine='openpyxl')
154
+
155
+ return load_file(content, extension)
156
+
157
+ except requests.exceptions.SSLError:
158
+ st.error("Error de seguridad SSL al descargar el archivo. Intente con una URL diferente.")
159
+ return None
160
+ except requests.exceptions.RequestException as e:
161
+ st.error(f"Error en la solicitud HTTP: {str(e)}")
162
+ return None
163
+ except Exception as e:
164
+ st.error(f"Error al cargar la URL: {str(e)}")
165
+ return None
166
+
167
+ def show_supabase_setup_info():
168
+ """Muestra información de configuración para Supabase"""
169
+
170
+ setup_sql = """
171
+ create or replace function get_tables()
172
+ returns table (table_name text)
173
+ language sql
174
+ as $$
175
+ select table_name::text
176
+ from information_schema.tables
177
+ where table_schema = 'public'
178
+ and table_type = 'BASE TABLE';
179
+ $$;
180
+ """
181
+
182
+ with st.expander("ℹ️ Configuración de Supabase", expanded=False):
183
+ st.markdown("""
184
+ ### Pasos para configurar Supabase
185
+
186
+ 1. **Crear función RPC en Supabase:**
187
+ - Ve al Editor SQL de Supabase
188
+ - Copia y ejecuta el siguiente código:
189
+ """)
190
+
191
+ # Mostrar el SQL con botón de copiado
192
+ st.code(setup_sql, language='sql')
193
+
194
+ st.markdown("""
195
+ 2. **Verificar credenciales:**
196
+ - URL del proyecto: `Settings -> API -> Project URL`
197
+ - API Key: `Settings -> API -> Project API keys -> anon/public`
198
+
199
+ 3. **Permisos necesarios:**
200
+ - La función necesita acceso a `information_schema.tables`
201
+ - El usuario debe tener permisos para ejecutar la función RPC
202
+
203
+ 4. **Solución de problemas:**
204
+ - Asegúrate de que existan tablas en el esquema público
205
+ - Verifica que la base de datos esté activa
206
+ - Confirma que las políticas de seguridad permitan el acceso
207
+ """)
208
+
209
+ def get_supabase_tables(supabase_url: str, supabase_key: str) -> Optional[List[str]]:
210
+ """Obtiene la lista de tablas disponibles en Supabase"""
211
+ try:
212
+ from supabase import create_client, Client
213
+
214
+ # Crear cliente de Supabase
215
+ supabase: Client = create_client(supabase_url, supabase_key)
216
+
217
+ try:
218
+ # Intenta primero usando RPC
219
+ result = supabase.rpc('get_tables').execute()
220
+
221
+ if hasattr(result, 'data') and result.data:
222
+ tables = [table['table_name'] for table in result.data]
223
+ if tables:
224
+ return sorted(tables) # Ordenar las tablas alfabéticamente
225
+ except Exception as rpc_error:
226
+ st.warning(f"Método RPC falló: {str(rpc_error)}")
227
+
228
+ try:
229
+ # Si RPC falla, intenta con una consulta SQL directa
230
+ result = supabase.from_('information_schema.tables')\
231
+ .select('table_name')\
232
+ .eq('table_schema', 'public')\
233
+ .eq('table_type', 'BASE TABLE')\
234
+ .execute()
235
+
236
+ if hasattr(result, 'data') and result.data:
237
+ return sorted([table['table_name'] for table in result.data])
238
+ except Exception as sql_error:
239
+ st.warning(f"Consulta SQL directa falló: {str(sql_error)}")
240
+
241
+ # Último intento usando postgREST
242
+ try:
243
+ result = supabase.table('tables').select('*').execute()
244
+ if hasattr(result, 'data') and result.data:
245
+ return sorted([table['name'] for table in result.data])
246
+ except Exception as postgrest_error:
247
+ st.error(f"Todos los métodos de consulta fallaron: {str(postgrest_error)}")
248
+
249
+ st.warning("No se encontraron tablas en el esquema público")
250
+ # Mostrar ayuda de configuración
251
+ show_supabase_setup_info()
252
+ return None
253
+
254
+ except Exception as e:
255
+ st.error(f"Error al conectar con Supabase: {str(e)}")
256
+ st.write("Detalles del error:", str(e))
257
+ # Mostrar ayuda de configuración
258
+ show_supabase_setup_info()
259
+ return None
260
+
261
+ def load_supabase_table(supabase_url: str, supabase_key: str, table_name: str) -> Optional[pd.DataFrame]:
262
+ """Carga una tabla de Supabase como DataFrame"""
263
+ try:
264
+ from supabase import create_client, Client
265
+
266
+ # Crear cliente de Supabase
267
+ supabase: Client = create_client(supabase_url, supabase_key)
268
+
269
+ # Realizar la consulta a la tabla
270
+ response = supabase.table(table_name).select("*").execute()
271
+
272
+ if hasattr(response, 'data'):
273
+ df = pd.DataFrame(response.data)
274
+ if not df.empty:
275
+ return df
276
+ else:
277
+ st.warning(f"La tabla '{table_name}' está vacía")
278
+ return None
279
+ else:
280
+ st.error("No se pudieron obtener datos de la tabla")
281
+ return None
282
+
283
+ except Exception as e:
284
+ st.error(f"Error al cargar la tabla de Supabase: {str(e)}")
285
+ st.write("Detalles del error:", str(e))
286
+ return None
287
+
288
+ def show_upload():
289
+ """Función principal para cargar y analizar datos"""
290
+ st.subheader('Aprenda con sus datos')
291
+
292
+ # Inicializar la variable de estado
293
+ if 'er_data' not in st.session_state:
294
+ st.session_state.er_data = None
295
+
296
+ # Obtener formatos soportados
297
+ SUPPORTED_FORMATS = get_supported_formats()
298
+ accepted_extensions = [ext for formats in SUPPORTED_FORMATS.values() for ext in formats]
299
+
300
+ # Mostrar formatos disponibles
301
+ with st.expander("Ver formatos soportados"):
302
+ for format_type, extensions in SUPPORTED_FORMATS.items():
303
+ st.write(f"**{format_type}**: {', '.join(extensions)}")
304
+
305
+ # Sección de Ejemplos Predeterminados
306
+ st.markdown("#### 0. Ejemplos Predeterminados")
307
+ selected_example = st.selectbox(
308
+ "Selecciona un dataset de ejemplo",
309
+ list(DATASET_OPTIONS.keys()) + ["Ninguno"],
310
+ index=len(DATASET_OPTIONS) # Seleccionar "Ninguno" por defecto
311
+ )
312
+
313
+ if selected_example != "Ninguno":
314
+ example_url = DATASET_OPTIONS[selected_example]
315
+ if st.button(f"Cargar Dataset de {selected_example}"):
316
+ try:
317
+ with st.spinner(f"Cargando dataset {selected_example}..."):
318
+ df = load_url_file(example_url)
319
+ if df is not None:
320
+ st.session_state.er_data = df
321
+ st.success(f"Dataset {selected_example} cargado exitosamente")
322
+ except Exception as e:
323
+ st.error(f"Error al cargar el dataset de ejemplo: {str(e)}")
324
+
325
+ # Secciones de carga de datos
326
+ st.markdown("#### 1. Subir Archivo Local")
327
+ data_file = st.file_uploader("Arrastra o selecciona tu archivo", type=accepted_extensions)
328
+
329
+ if data_file:
330
+ extension = data_file.name.split('.')[-1].lower()
331
+ df = load_file(data_file, extension)
332
+ if df is not None:
333
+ st.session_state.er_data = df
334
+ st.success(f"Archivo local cargado: {data_file.name}")
335
+
336
+ # Carga desde Google Sheet
337
+ st.markdown("#### 2. Cargar desde Google Sheet")
338
+ sharing_link = st.text_input(
339
+ "Link de Google Sheet:",
340
+ placeholder="https://docs.google.com/spreadsheets/d/SHEET-ID/edit?usp=sharing"
341
+ )
342
+ if sharing_link and st.button("Cargar Sheet"):
343
+ try:
344
+ st.session_state.er_data = load_gsheet(sharing_link)
345
+ st.success("Google Sheet cargado exitosamente")
346
+ except Exception as e:
347
+ st.error(f"Error al cargar el Google Sheet: {str(e)}")
348
+
349
+ # Carga desde URL
350
+ st.markdown("#### 3. Cargar desde URL")
351
+ url = st.text_input(
352
+ 'URL del archivo:',
353
+ placeholder='Ejemplo: https://ejemplo.com/datos.csv'
354
+ )
355
+ if url and st.button('Cargar URL'):
356
+ df = load_url_file(url)
357
+ if df is not None:
358
+ st.session_state.er_data = df
359
+
360
+ # Carga desde Supabase
361
+ st.markdown("#### 4. Carga desde Supabase")
362
+
363
+ # Verificar credenciales
364
+ has_credentials = (
365
+ 'supabase_url' in st.session_state and
366
+ 'supabase_key' in st.session_state and
367
+ st.session_state.supabase_url.strip() and
368
+ st.session_state.supabase_key.strip()
369
+ )
370
+
371
+ # Inicializar variables de estado
372
+ if 'supabase_tables' not in st.session_state:
373
+ st.session_state.supabase_tables = None
374
+ if 'supabase_connected' not in st.session_state:
375
+ st.session_state.supabase_connected = False
376
+
377
+ status_container = st.empty()
378
+
379
+ if not has_credentials:
380
+ status_container.warning("👉 Configura tus credenciales de Supabase en la sección superior izquierda antes de continuar.")
381
+ else:
382
+ col1, col2 = st.columns([1, 4])
383
+
384
+ with col1:
385
+ if st.button(
386
+ "Conectar" if not st.session_state.supabase_connected else "Reconectar",
387
+ key="connect_supabase",
388
+ help="Conectar a Supabase y listar tablas disponibles"
389
+ ):
390
+ with st.spinner("Conectando a Supabase..."):
391
+ tables = get_supabase_tables(
392
+ st.session_state.supabase_url,
393
+ st.session_state.supabase_key
394
+ )
395
+
396
+ if tables:
397
+ st.session_state.supabase_tables = tables
398
+ st.session_state.supabase_connected = True
399
+ status_container.success("✅ Conexión exitosa a Supabase")
400
+ else:
401
+ st.session_state.supabase_connected = False
402
+ status_container.error("❌ No se pudieron obtener las tablas. Verifica tus credenciales.")
403
+
404
+ if st.session_state.supabase_connected and st.session_state.supabase_tables:
405
+ table_container = st.container()
406
+
407
+ with table_container:
408
+ selected_table = st.selectbox(
409
+ "Selecciona una tabla:",
410
+ st.session_state.supabase_tables,
411
+ key="supabase_table_selector"
412
+ )
413
+
414
+ if st.button("Cargar Tabla", key="load_supabase_table"):
415
+ try:
416
+ with st.spinner("Cargando datos..."):
417
+ df = load_supabase_table(
418
+ st.session_state.supabase_url,
419
+ st.session_state.supabase_key,
420
+ selected_table
421
+ )
422
+ if df is not None:
423
+ st.session_state.er_data = df
424
+ st.success(f"✅ Tabla '{selected_table}' cargada exitosamente")
425
+ else:
426
+ st.error(f"❌ No se pudo cargar la tabla '{selected_table}'. La tabla puede estar vacía.")
427
+ except Exception as e:
428
+ st.error(f"❌ Error al cargar la tabla: {str(e)}")
429
+ st.write("Detalles del error:", str(e))
430
+
431
+ # Análisis de datos
432
+ if st.session_state.er_data is not None:
433
+ analyze_dataset(st.session_state.er_data)
434
+
435
+ return st.session_state.er_data
436
+
437
+ def analyze_dataset(data):
438
+ """Analizar el dataset cargado"""
439
+ # Generar explicación automática con Gemini
440
+ if 'dataset_explanation' not in st.session_state:
441
+ st.session_state.dataset_explanation = None
442
+
443
+ has_api_key = 'gemini_api_key' in st.session_state and st.session_state.gemini_api_key
444
+
445
+ if st.button(
446
+ "Explicar Dataset",
447
+ key="explain_dataset_button",
448
+ disabled=not has_api_key,
449
+ help="Requiere API key de Gemini para funcionar"
450
+ ):
451
+ st.session_state.dataset_explanation = generate_dataset_explanation(
452
+ data,
453
+ st.session_state.gemini_api_key
454
+ )
455
+
456
+ # Mostrar explicación si existe
457
+ if st.session_state.dataset_explanation:
458
+ st.markdown("### Explicación del Dataset")
459
+ st.write(st.session_state.dataset_explanation)
460
+
461
+ # Botón para limpiar explicación
462
+ if st.button("Limpiar Explicación", key="clear_explanation"):
463
+ st.session_state.dataset_explanation = None
464
+ st.rerun()
465
+
466
+ # Mostrar datos si se han cargado
467
+ # st.markdown("### Dataset Cargado")
468
+ # st.dataframe(data.head())
469
+ renderer = StreamlitRenderer(data)
470
+ renderer.explorer()
471
+ st.info(f"📊 Dimensiones: {data.shape[0]} filas × {data.shape[1]} columnas")
472
+ # Mostrar tipos de datos en columnas múltiples
473
+ with st.expander("📊 Ver tipos de datos por columna", expanded=False):
474
+ # Slider para número de columnas
475
+ num_columns = st.slider(
476
+ "Número de columnas para mostrar tipos de datos",
477
+ min_value=1,
478
+ max_value=10,
479
+ value=5,
480
+ help="Desliza para ajustar el número de columnas en la visualización de tipos de datos",
481
+ key="num_columns_slider"
482
+ )
483
+
484
+ # Obtener tipos de datos de cada columna
485
+ data_types = data.dtypes.reset_index()
486
+ data_types.columns = ["Columna", "Tipo de dato"]
487
+
488
+ st.write("**Tipos de datos por columna:**")
489
+
490
+ # Calcular elementos por columna
491
+ items_per_column = len(data_types) // num_columns + (1 if len(data_types) % num_columns != 0 else 0)
492
+
493
+ # Crear columnas en Streamlit
494
+ cols = st.columns(num_columns)
495
+
496
+ # Distribuir tipos de datos entre columnas
497
+ for col_idx in range(num_columns):
498
+ start_idx = col_idx * items_per_column
499
+ end_idx = min(start_idx + items_per_column, len(data_types))
500
+
501
+ if start_idx < len(data_types):
502
+ with cols[col_idx]:
503
+ for idx in range(start_idx, end_idx):
504
+ st.write(f"**{data_types.iloc[idx]['Columna']}**: {data_types.iloc[idx]['Tipo de dato']}")
505
+
506
+ # Mostrar resumen de tipos de datos
507
+ st.markdown("---")
508
+ st.write("**Resumen de tipos de datos:**")
509
+ type_summary = data.dtypes.value_counts()
510
+ summary_cols = st.columns(len(type_summary))
511
+ for i, (dtype, count) in enumerate(type_summary.items()):
512
+ with summary_cols[i]:
513
+ st.metric(f"Tipo: {dtype}", f"{count} columnas")
514
+
515
+ # Análisis de Variables por Tipo
516
+ st.markdown("### Análisis de Variables por Tipo")
517
+
518
+ # Crear columnas para mostrar variables numéricas y categóricas
519
+ col1, col2 = st.columns(2)
520
+
521
+ with col1:
522
+ st.markdown("#### Variables Numéricas")
523
+ lista_var_numericas = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
524
+
525
+ if lista_var_numericas:
526
+ df_numericas = pd.DataFrame({
527
+ 'Variable': lista_var_numericas,
528
+ 'Tipo': [str(data[col].dtype) for col in lista_var_numericas]
529
+ })
530
+ st.dataframe(df_numericas, hide_index=True)
531
+
532
+ if st.checkbox("Ver estadísticas básicas de variables numéricas", key="show_numeric_stats"):
533
+ st.write(data[lista_var_numericas].describe())
534
+
535
+ selected_num_vars = st.multiselect(
536
+ "Seleccionar variables numéricas para análisis",
537
+ lista_var_numericas,
538
+ default=lista_var_numericas[0] if lista_var_numericas else None,
539
+ key="numeric_vars_select"
540
+ )
541
+
542
+ if selected_num_vars:
543
+ st.write("**Histograma de variables seleccionadas:**")
544
+ for var in selected_num_vars:
545
+ fig_hist = px.histogram(
546
+ data,
547
+ x=var,
548
+ title=f'Histograma de {var}'
549
+ )
550
+ st.plotly_chart(fig_hist)
551
+ else:
552
+ st.info("No se encontraron variables numéricas en el dataset")
553
+
554
+ with col2:
555
+ st.markdown("#### Variables Categóricas")
556
+ lista_var_object = data.select_dtypes(include=['object']).columns.tolist()
557
+
558
+ if lista_var_object:
559
+ df_categoricas = pd.DataFrame({
560
+ 'Variable': lista_var_object,
561
+ 'Tipo': [str(data[col].dtype) for col in lista_var_object]
562
+ })
563
+ st.dataframe(df_categoricas, hide_index=True)
564
+
565
+ if st.checkbox("Ver valores únicos de variables categóricas", key="show_categorical_stats"):
566
+ selected_cat_var = st.selectbox(
567
+ "Seleccionar variable categórica",
568
+ lista_var_object,
569
+ key="categorical_var_select"
570
+ )
571
+ if selected_cat_var:
572
+ unique_values = data[selected_cat_var].value_counts()
573
+ st.write("Valores únicos en {selected_cat_var}")
574
+
575
+ # Gráfico de barras de valores únicos
576
+ fig_bar = px.bar(
577
+ x=unique_values.index,
578
+ y=unique_values.values,
579
+ title=f'Distribución de {selected_cat_var}'
580
+ )
581
+ st.plotly_chart(fig_bar)
582
+
583
+ # Tabla de frecuencia
584
+ freq_df = pd.DataFrame({
585
+ 'Valor': unique_values.index,
586
+ 'Frecuencia': unique_values.values,
587
+ 'Porcentaje': (unique_values.values / len(data) * 100).round(2)
588
+ })
589
+ st.dataframe(freq_df)
590
+ else:
591
+ st.info("No se encontraron variables categóricas en el dataset")
592
+
593
+ # Matriz de Correlación para Variables Numéricas
594
+ st.markdown("### Matriz de Correlación")
595
+ numeric_columns = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
596
+
597
+ if len(numeric_columns) > 1:
598
+ # Selección de variables para correlación
599
+ corr_variables = st.multiselect(
600
+ "Selecciona las variables para la matriz de correlación",
601
+ options=numeric_columns,
602
+ default=numeric_columns[:min(5, len(numeric_columns))]
603
+ )
604
+
605
+ if corr_variables:
606
+ # Calcular matriz de correlación
607
+ corr_matrix = data[corr_variables].corr(method='pearson')
608
+
609
+ # Mapa de calor de correlación
610
+ fig_corr = px.imshow(
611
+ corr_matrix,
612
+ text_auto=True,
613
+ aspect="auto",
614
+ color_continuous_scale='RdBu_r',
615
+ title='Matriz de Correlación de Pearson'
616
+ )
617
+ st.plotly_chart(fig_corr, use_container_width=True)
618
+
619
+ # Análisis de correlaciones significativas
620
+ st.write("### Análisis de Correlaciones Significativas")
621
+ threshold = st.slider(
622
+ "Umbral mínimo de correlación",
623
+ min_value=0.0,
624
+ max_value=1.0,
625
+ value=0.5,
626
+ step=0.05
627
+ )
628
+
629
+ # Obtener correlaciones significativas
630
+ corr_pairs = corr_matrix.unstack()
631
+ significant_corr = corr_pairs[
632
+ (abs(corr_pairs) >= threshold) &
633
+ (abs(corr_pairs) < 1)
634
+ ].sort_values(ascending=False)
635
+
636
+ if not significant_corr.empty:
637
+ st.write(f"Correlaciones significativas (|correlación| ≥ {threshold}):")
638
+ for (var1, var2), corr_value in significant_corr.items():
639
+ st.write(f"- **{var1}** y **{var2}**: correlación de **{corr_value:.2f}**")
640
+ else:
641
+ st.write("No se encontraron correlaciones significativas con el umbral seleccionado.")
642
+
643
+ else:
644
+ st.warning("No hay suficientes variables numéricas para calcular correlaciones.")
645
+
646
+ # Detección de Outliers
647
+ st.markdown("### Detección de Outliers")
648
+ numeric_columns = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
649
+
650
+ if numeric_columns:
651
+ outlier_vars = st.multiselect(
652
+ "Selecciona variables para análisis de outliers",
653
+ options=numeric_columns
654
+ )
655
+
656
+ if outlier_vars:
657
+ for var in outlier_vars:
658
+ # Cálculo de Q1, Q3 e IQR
659
+ Q1 = data[var].quantile(0.25)
660
+ Q3 = data[var].quantile(0.75)
661
+ IQR = Q3 - Q1
662
+ lower_bound = Q1 - 1.5 * IQR
663
+ upper_bound = Q3 + 1.5 * IQR
664
+
665
+ # Identificación de outliers
666
+ outliers = data[(data[var] < lower_bound) | (data[var] > upper_bound)][var]
667
+ num_outliers = outliers.shape[0]
668
+
669
+ st.write(f"### Análisis de Outliers para {var}")
670
+
671
+ # Boxplot
672
+ fig_box = px.box(data, y=var, title=f'Boxplot de {var}')
673
+ st.plotly_chart(fig_box)
674
+
675
+ # Resumen de outliers
676
+ col1, col2, col3 = st.columns(3)
677
+ with col1:
678
+ st.metric("Total de Datos", len(data))
679
+ with col2:
680
+ st.metric("Número de Outliers", num_outliers)
681
+ with col3:
682
+ st.metric("Porcentaje de Outliers", f"{num_outliers/len(data)*100:.2f}%")
683
+
684
+ # Mostrar outliers
685
+ if st.checkbox(f"Mostrar outliers de {var}"):
686
+ st.dataframe(outliers)
687
+ return data
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # __init__.py - M�dulo para src models
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (179 Bytes). View file
 
models/__pycache__/test.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
models/__pycache__/train.cpython-312.pyc ADDED
Binary file (31.8 kB). View file
 
models/__pycache__/unsupervised.cpython-312.pyc ADDED
Binary file (20.3 kB). View file
 
models/test.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test.py - Módulo para models
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import pickle
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import google.generativeai as genai
9
+ from sklearn.model_selection import train_test_split
10
+ from sklearn.metrics import (
11
+ mean_squared_error,
12
+ r2_score,
13
+ mean_absolute_percentage_error,
14
+ accuracy_score,
15
+ precision_score,
16
+ recall_score,
17
+ f1_score,
18
+ confusion_matrix,
19
+ classification_report
20
+ )
21
+ from sklearn.preprocessing import LabelEncoder
22
+ import io
23
+
24
+ class ModelTester:
25
+ def __init__(self, model, X, y, problem_type):
26
+ self.model = model
27
+ self.X = X
28
+ self.y = y
29
+ self.problem_type = problem_type
30
+ self.label_encoder = None
31
+
32
+ def _prepare_data(self, test_size=0.2, random_state=42):
33
+ """Preparar datos para prueba"""
34
+ X_train, X_test, y_train, y_test = train_test_split(
35
+ self.X, self.y,
36
+ test_size=test_size,
37
+ random_state=random_state,
38
+ stratify=self.y if self.problem_type == 'classification' else None
39
+ )
40
+ return X_train, X_test, y_train, y_test
41
+
42
+ def _encode_target(self, y):
43
+ """Codificar variable objetivo para clasificación"""
44
+ if self.problem_type == 'classification':
45
+ self.label_encoder = LabelEncoder()
46
+ return self.label_encoder.fit_transform(y)
47
+ return y
48
+
49
+ def evaluate_regression(self, X_test, y_test):
50
+ """Evaluar modelo de regresión"""
51
+ y_pred = self.model.predict(X_test)
52
+
53
+ metrics = {
54
+ 'MSE': mean_squared_error(y_test, y_pred),
55
+ 'R² Score': r2_score(y_test, y_pred),
56
+ 'MAPE': mean_absolute_percentage_error(y_test, y_pred) * 100
57
+ }
58
+
59
+ return metrics, y_pred
60
+
61
+ def evaluate_classification(self, X_test, y_test):
62
+ """Evaluar modelo de clasificación"""
63
+ y_test_encoded = self._encode_target(y_test)
64
+ y_pred = self.model.predict(X_test)
65
+
66
+ metrics = {
67
+ 'Accuracy': accuracy_score(y_test_encoded, y_pred),
68
+ 'Precision': precision_score(y_test_encoded, y_pred, average='weighted'),
69
+ 'Recall': recall_score(y_test_encoded, y_pred, average='weighted'),
70
+ 'F1 Score': f1_score(y_test_encoded, y_pred, average='weighted')
71
+ }
72
+
73
+ return metrics, y_pred
74
+
75
+ def plot_regression_results(self, y_test, y_pred):
76
+ """Crear gráfico de resultados de regresión"""
77
+ fig = go.Figure()
78
+ fig.add_trace(go.Scatter(
79
+ x=y_test, y=y_pred,
80
+ mode='markers',
81
+ name='Predicciones vs Valores Reales'
82
+ ))
83
+ fig.add_trace(go.Scatter(
84
+ x=[y_test.min(), y_test.max()],
85
+ y=[y_test.min(), y_test.max()],
86
+ mode='lines',
87
+ name='Línea Perfecta',
88
+ line=dict(color='red', dash='dash')
89
+ ))
90
+ fig.update_layout(
91
+ title='Predicciones vs Valores Reales',
92
+ xaxis_title='Valores Reales',
93
+ yaxis_title='Predicciones'
94
+ )
95
+ return fig
96
+
97
+ def plot_classification_results(self, y_test, y_pred):
98
+ """Crear matriz de confusión para clasificación"""
99
+ cm = confusion_matrix(
100
+ self._encode_target(y_test),
101
+ y_pred
102
+ )
103
+
104
+ fig = px.imshow(
105
+ cm,
106
+ labels=dict(x="Predicción", y="Real"),
107
+ x=[str(c) for c in self.label_encoder.classes_] if self.label_encoder else None,
108
+ y=[str(c) for c in self.label_encoder.classes_] if self.label_encoder else None,
109
+ title="Matriz de Confusión"
110
+ )
111
+ return fig
112
+
113
+ def load_model(uploaded_file):
114
+ """Cargar modelo desde archivo pickle"""
115
+ try:
116
+ with uploaded_file as f:
117
+ model = pickle.load(f)
118
+ return model
119
+ except Exception as e:
120
+ st.error(f"Error al cargar el modelo: {e}")
121
+ return None
122
+
123
+ def get_model_features(model):
124
+ """Extract feature names from the model if available."""
125
+ if hasattr(model, 'feature_names_in_'):
126
+ return list(model.feature_names_in_)
127
+ return None
128
+
129
+ def align_features(X, model_features):
130
+ """Align input features with model's expected features."""
131
+ if model_features is None:
132
+ return X
133
+
134
+ # Create a new DataFrame with the correct features in the correct order
135
+ missing_cols = set(model_features) - set(X.columns)
136
+ extra_cols = set(X.columns) - set(model_features)
137
+
138
+ if missing_cols:
139
+ st.warning(f"Missing features: {missing_cols}. These will need to be provided.")
140
+ return None
141
+
142
+ if extra_cols:
143
+ st.warning(f"Extra features detected: {extra_cols}. These will be ignored.")
144
+
145
+ return X[model_features]
146
+
147
+ def determine_problem_type(model):
148
+ """Determine if the model is for classification or regression."""
149
+ class_methods = ['predict_proba', 'classes_']
150
+ return 'classification' if any(hasattr(model, method) for method in class_methods) else 'regression'
151
+
152
+ def generate_model_explanation(model, metrics, problem_type):
153
+ """Generar explicación del modelo usando Gemini"""
154
+ try:
155
+ genai.configure(api_key=st.session_state.get('gemini_api_key'))
156
+ model_ai = genai.GenerativeModel('gemini-1.5-flash')
157
+
158
+ metrics_text = "\n".join([f"{k}: {v}" for k, v in metrics.items()])
159
+
160
+ prompt = f"""Analiza los siguientes resultados de un modelo de {problem_type}:
161
+
162
+ Métricas de Rendimiento:
163
+ {metrics_text}
164
+
165
+ Proporciona:
166
+ 1. Interpretación de las métricas
167
+ 2. Fortalezas y debilidades del modelo
168
+ 3. Posibles mejoras o alternativas
169
+ 4. Contexto práctico de estos resultados
170
+ """
171
+
172
+ response = model_ai.generate_content(prompt)
173
+ return response.text
174
+ except Exception as e:
175
+ st.error(f"Error generando explicación: {e}")
176
+ return "No se pudo generar la explicación."
177
+
178
+ def show_test():
179
+ st.title("Prueba de Modelo")
180
+
181
+ # Cargar modelo
182
+ uploaded_model = st.file_uploader(
183
+ "Cargar modelo entrenado",
184
+ type=['pkl']
185
+ )
186
+
187
+ if not uploaded_model:
188
+ st.warning("Por favor, cargue un modelo entrenado")
189
+ return
190
+
191
+ # Cargar datos preparados
192
+ if 'prepared_data' not in st.session_state:
193
+ st.warning("No hay datos preparados. Por favor, prepare los datos primero.")
194
+ return
195
+
196
+ data = st.session_state.prepared_data
197
+
198
+ # Selección de características y objetivo
199
+ st.subheader("Configuración de Prueba")
200
+
201
+ # Columnas numéricas
202
+ model_features = get_model_features(uploaded_model)
203
+ numeric_cols = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
204
+
205
+ if model_features:
206
+ # Pre-select features that match the model's expected features
207
+ default_features = [col for col in model_features if col in numeric_cols]
208
+ feature_cols = st.multiselect(
209
+ "Seleccionar variables predictoras (X):",
210
+ numeric_cols,
211
+ default=default_features
212
+ )
213
+ else:
214
+ feature_cols = st.multiselect(
215
+ "Seleccionar variables predictoras (X):",
216
+ numeric_cols,
217
+ default=st.session_state.get('feature_cols', [])
218
+ )
219
+
220
+ available_targets = [col for col in data.columns if col not in feature_cols]
221
+ target_col = st.selectbox(
222
+ "Seleccionar variable objetivo (y):",
223
+ available_targets,
224
+ index=available_targets.index(st.session_state.get('target_col', available_targets[0]))
225
+ if st.session_state.get('target_col') in available_targets else 0
226
+ )
227
+
228
+ if not feature_cols or not target_col:
229
+ st.warning("Seleccione variables predictoras y objetivo")
230
+ return
231
+
232
+ # Cargar modelo
233
+ model = load_model(uploaded_model)
234
+ if not model:
235
+ return
236
+
237
+ # Preparar datos
238
+ X = data[feature_cols]
239
+ y = data[target_col]
240
+
241
+ # Determinar tipo de problema
242
+ problem_type = 'classification' if y.dtype == 'object' or y.nunique() <= 10 else 'regression'
243
+ st.write(f"Tipo de problema detectado: {problem_type}")
244
+
245
+ # Opciones de prueba
246
+ test_size = st.slider(
247
+ "Tamaño del conjunto de prueba",
248
+ 0.1, 0.5, 0.2
249
+ )
250
+
251
+ # Probar modelo
252
+ if 'model_evaluated' not in st.session_state:
253
+ st.session_state.model_evaluated = False
254
+
255
+ if st.button("Evaluar Modelo"):
256
+ # Crear tester
257
+ model_tester = ModelTester(model, X, y, problem_type)
258
+
259
+ # Preparar datos
260
+ X_train, X_test, y_train, y_test = model_tester._prepare_data(test_size)
261
+
262
+ # Evaluar modelo según el tipo de problema
263
+ if problem_type == 'regression':
264
+ metrics, y_pred = model_tester.evaluate_regression(X_test, y_test)
265
+
266
+ # Métricas de rendimiento
267
+ st.subheader("Métricas de Rendimiento")
268
+ col1, col2, col3 = st.columns(3)
269
+ col1.metric("MSE", f"{metrics['MSE']:.4f}")
270
+ col2.metric("R² Score", f"{metrics['R² Score']:.4f}")
271
+ col3.metric("MAPE", f"{metrics['MAPE']:.2f}%")
272
+
273
+ # Visualización de resultados
274
+ st.subheader("Visualización de Resultados")
275
+ fig = model_tester.plot_regression_results(y_test, y_pred)
276
+ st.plotly_chart(fig, use_container_width=True)
277
+
278
+ else: # Clasificación
279
+ metrics, y_pred = model_tester.evaluate_classification(X_test, y_test)
280
+
281
+ # Métricas de rendimiento
282
+ st.subheader("Métricas de Rendimiento")
283
+ col1, col2, col3, col4 = st.columns(4)
284
+ col1.metric("Accuracy", f"{metrics['Accuracy']:.4f}")
285
+ col2.metric("Precision", f"{metrics['Precision']:.4f}")
286
+ col3.metric("Recall", f"{metrics['Recall']:.4f}")
287
+ col4.metric("F1 Score", f"{metrics['F1 Score']:.4f}")
288
+
289
+ # Matriz de confusión
290
+ st.subheader("Matriz de Confusión")
291
+ fig = model_tester.plot_classification_results(y_test, y_pred)
292
+ st.plotly_chart(fig)
293
+
294
+ # Reporte de clasificación
295
+ st.subheader("Reporte de Clasificación")
296
+ st.text(classification_report(
297
+ model_tester._encode_target(y_test),
298
+ y_pred
299
+ ))
300
+
301
+ # Guardar métricas en session state
302
+ st.session_state.metrics = metrics
303
+ st.session_state.model_evaluated = True
304
+
305
+ # Explicación del modelo con Gemini (fuera del if anterior)
306
+ st.subheader("Análisis de Resultados")
307
+ if st.session_state.get('gemini_api_key'):
308
+ if st.button("Generar Explicación Detallada", disabled=not st.session_state.model_evaluated, help="Evalúa el modelo primero"):
309
+ with st.spinner("Generando explicación..."):
310
+ explanation = generate_model_explanation(
311
+ model, st.session_state.metrics, problem_type
312
+ )
313
+ st.markdown(explanation)
314
+ else:
315
+ st.warning("Configure la API key de Gemini para obtener explicaciones detalladas")
316
+
317
+ # Predicciones de ejemplo
318
+ st.subheader("Predicciones de Ejemplo")
319
+ num_samples = st.slider(
320
+ "Número de muestras a mostrar",
321
+ 5, 50, 10
322
+ )
323
+
324
+ # Seleccionar muestras aleatorias
325
+ sample_indices = np.random.choice(
326
+ len(X_test),
327
+ min(num_samples, len(X_test)),
328
+ replace=False
329
+ )
330
+ sample_X = X_test.iloc[sample_indices]
331
+ sample_y_true = y_test.iloc[sample_indices]
332
+ sample_y_pred = model.predict(sample_X)
333
+
334
+ # Crear DataFrame de comparación
335
+ comparison_df = pd.DataFrame({
336
+ 'Características': [
337
+ ', '.join([f"{col}: {val}" for col, val in row.items()])
338
+ for _, row in sample_X.iterrows()
339
+ ],
340
+ 'Valor Real': sample_y_true,
341
+ 'Predicción': sample_y_pred,
342
+ 'Error Absoluto' if problem_type == 'regression'
343
+ else 'Predicción Correcta':
344
+ np.abs(sample_y_true - sample_y_pred) if problem_type == 'regression'
345
+ else (sample_y_true == sample_y_pred)
346
+ })
347
+
348
+ st.dataframe(comparison_df)
349
+
350
+ # Opciones de descarga
351
+ st.subheader("Descargar Resultados")
352
+
353
+ # Guardar métricas
354
+ metrics_df = pd.DataFrame.from_dict(metrics, orient='index', columns=['Valor'])
355
+
356
+ # Selector de formato
357
+ download_format = st.selectbox(
358
+ "Seleccionar formato de descarga",
359
+ ["CSV", "Excel"]
360
+ )
361
+
362
+ if download_format == "CSV":
363
+ csv_data = metrics_df.to_csv().encode('utf-8')
364
+ st.download_button(
365
+ label="Descargar Métricas (CSV)",
366
+ data=csv_data,
367
+ file_name="model_metrics.csv",
368
+ mime="text/csv"
369
+ )
370
+ else:
371
+ excel_buffer = io.BytesIO()
372
+ with pd.ExcelWriter(excel_buffer, engine='xlsxwriter') as writer:
373
+ metrics_df.to_excel(writer, index=True, sheet_name='Métricas')
374
+ comparison_df.to_excel(writer, index=False, sheet_name='Predicciones')
375
+ excel_buffer.seek(0)
376
+ st.download_button(
377
+ label="Descargar Resultados (Excel)",
378
+ data=excel_buffer,
379
+ file_name="model_results.xlsx",
380
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
381
+ )
382
+
383
+ def main():
384
+ """Función principal para ejecutar la página de prueba de modelos"""
385
+ show_test()
386
+
387
+ if __name__ == "__main__":
388
+ main()
models/train.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/train.py
2
+ import time
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ from sklearn.model_selection import train_test_split, GridSearchCV
7
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
8
+ from sklearn.pipeline import Pipeline
9
+ from imblearn.over_sampling import SMOTE
10
+ from sklearn.utils import resample
11
+ from sklearn.metrics import (
12
+ mean_squared_error, r2_score,
13
+ accuracy_score, classification_report, confusion_matrix
14
+ )
15
+ import os
16
+ import sys
17
+ import pickle
18
+ import io
19
+ import h2o
20
+ from flaml import AutoML
21
+ from typing import Dict, Any, Optional
22
+
23
+ # Importaciones
24
+ from utils.model_utils import (
25
+ ModelTrainer, # Importar la clase
26
+ get_model_options,
27
+ train_model_pipeline,
28
+ process_classification_data,
29
+ create_class_distribution_plot
30
+ )
31
+ from utils.gemini_explainer import initialize_gemini_explainer
32
+ from utils.gemini_explainer import generate_model_explanation
33
+ from utils.shap_explainer import create_shap_analysis_dashboard
34
+
35
+ def safe_init_h2o(url=None, **kwargs):
36
+ """
37
+ Safely initialize H2O cluster if not already running.
38
+
39
+ Args:
40
+ url (str, optional): H2O cluster URL. Defaults to None (local instance).
41
+ **kwargs: Additional arguments to pass to h2o.init()
42
+
43
+ Returns:
44
+ h2o._backend.H2OConnection: The H2O connection object
45
+ """
46
+ # Get current H2O instance if exists
47
+ current = h2o.connection()
48
+
49
+ # Check if H2O is already running
50
+ if current and current.cluster:
51
+ print("H2O is already running at", current.base_url)
52
+ return current
53
+
54
+ # Initialize new H2O instance
55
+ print("Starting new H2O instance...")
56
+ return h2o.init(url=url, **kwargs)
57
+
58
+ def convert_h2o_to_pandas(h2o_df):
59
+ """
60
+ Convierte un H2OFrame a pandas DataFrame utilizando múltiples hilos.
61
+
62
+ Args:
63
+ h2o_df (h2o.H2OFrame): Frame de H2O a convertir.
64
+
65
+ Returns:
66
+ pd.DataFrame: DataFrame de pandas.
67
+ """
68
+ return h2o_df.as_data_frame(use_multi_thread=True)
69
+
70
+ # Obtener la ruta del directorio raíz del proyecto
71
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
72
+ sys.path.insert(0, project_root)
73
+
74
+ def validate_data_preparation(train):
75
+ """
76
+ Validar que los datos estén preparados correctamente
77
+
78
+ Args:
79
+ train (pd.DataFrame): Datos de entrenamiento
80
+
81
+ Returns:
82
+ bool: Indica si los datos están listos para entrenamiento
83
+ """
84
+ if train is None or train.empty:
85
+ st.warning("⚠️ No hay datos preparados en la sesión.")
86
+ return False
87
+ return True
88
+
89
+ def select_features_and_target(train):
90
+ """
91
+ Permitir al usuario seleccionar características y variable objetivo
92
+
93
+ Args:
94
+ train (pd.DataFrame): Datos de entrenamiento
95
+
96
+ Returns:
97
+ tuple: Variables predictoras (X) y variable objetivo (y)
98
+ """
99
+ numeric_cols = train.select_dtypes(include=['int64', 'float64']).columns.tolist()
100
+
101
+ # Mantener las selecciones en session_state
102
+ if 'feature_cols' not in st.session_state:
103
+ st.session_state.feature_cols = []
104
+
105
+ feature_cols = st.multiselect(
106
+ "Selecciona las variables predictoras (X):",
107
+ numeric_cols,
108
+ default=st.session_state.feature_cols
109
+ )
110
+ st.session_state.feature_cols = feature_cols
111
+
112
+ # Obtener TODAS las columnas disponibles para target
113
+ all_cols = train.columns.tolist()
114
+ available_targets = [col for col in all_cols if col not in feature_cols]
115
+
116
+ if not available_targets:
117
+ st.warning("Por favor, deselecciona algunas variables predictoras para poder seleccionar la variable objetivo.")
118
+ return None, None
119
+
120
+ if ('target_col' not in st.session_state or
121
+ st.session_state.target_col not in available_targets):
122
+ st.session_state.target_col = available_targets[0]
123
+
124
+ target_col = st.selectbox(
125
+ "Selecciona la variable objetivo (y):",
126
+ available_targets,
127
+ index=available_targets.index(st.session_state.target_col)
128
+ )
129
+ st.session_state.target_col = target_col
130
+
131
+ if not (feature_cols and target_col):
132
+ st.warning("Por favor selecciona variables predictoras y objetivo.")
133
+ return None, None
134
+
135
+ X = train[feature_cols]
136
+ y = train[target_col]
137
+
138
+ return X, y
139
+
140
+ def determine_problem_type(y):
141
+ """
142
+ Determinar el tipo de problema de machine learning
143
+
144
+ Args:
145
+ y (pd.Series): Variable objetivo
146
+
147
+ Returns:
148
+ str: Tipo de problema ('classification' o 'regression')
149
+ """
150
+ is_categorical = y.dtype == 'object' or (y.dtype.name.startswith(('int', 'float')) and y.nunique() <= 10)
151
+ problem_type = 'classification' if is_categorical else 'regression'
152
+ st.write(f"Tipo de problema identificado: **{problem_type}**")
153
+ return problem_type
154
+
155
+ def handle_data_balancing(X, y, random_state):
156
+ """
157
+ Manejar el desbalanceo de clases
158
+
159
+ Args:
160
+ X (pd.DataFrame): Variables predictoras
161
+ y (pd.Series): Variable objetivo
162
+ random_state (int): Semilla aleatoria
163
+
164
+ Returns:
165
+ tuple: Variables predictoras y objetivo balanceadas
166
+ """
167
+ if y.value_counts().min() / y.value_counts().max() < 0.5:
168
+ st.write("⚠️ Se detectó desbalanceo en las clases")
169
+ balance_method = st.selectbox(
170
+ "Técnica de balanceo:",
171
+ ["Ninguno", "Submuestreo", "Sobremuestreo", "SMOTE"]
172
+ )
173
+
174
+ if balance_method != "Ninguno":
175
+ with st.spinner("Aplicando técnica de balanceo..."):
176
+ if balance_method == "Submuestreo":
177
+ min_class_size = y.value_counts().min()
178
+ X, y = resample(X, y, n_samples=min_class_size*2, stratify=y)
179
+ elif balance_method == "Sobremuestreo":
180
+ max_class_size = y.value_counts().max()
181
+ X, y = resample(X, y, n_samples=max_class_size*2, stratify=y)
182
+ else: # SMOTE
183
+ smote = SMOTE(random_state=random_state)
184
+ X, y = smote.fit_resample(X, y)
185
+ st.success("Balanceo completado!")
186
+
187
+ return X, y
188
+
189
+ def safe_init_h2o(url=None, **kwargs):
190
+ """
191
+ Safely initialize H2O cluster if not already running.
192
+
193
+ Args:
194
+ url (str, optional): H2O cluster URL. Defaults to None (local instance).
195
+ **kwargs: Additional arguments to pass to h2o.init()
196
+
197
+ Returns:
198
+ h2o._backend.H2OConnection: The H2O connection object
199
+ """
200
+ # Get current H2O instance if exists
201
+ current = h2o.connection()
202
+
203
+ # Check if H2O is already running
204
+ if current and current.cluster:
205
+ print("H2O is already running at", current.base_url)
206
+ return current
207
+
208
+ # Initialize new H2O instance
209
+ print("Starting new H2O instance...")
210
+ return h2o.init(url=url, **kwargs)
211
+
212
+ class AutoMLTrainer:
213
+ """Clase para gestionar el entrenamiento automático de modelos"""
214
+
215
+ @staticmethod
216
+ def train_h2o_automl(
217
+ X_train: pd.DataFrame,
218
+ y_train: pd.Series,
219
+ X_test: pd.DataFrame,
220
+ y_test: pd.Series,
221
+ problem_type: str,
222
+ time_limit: int = 3600,
223
+ max_models: int = 20
224
+ ) -> Dict[str, Any]:
225
+ """
226
+ Entrenar modelos usando H2O AutoML con manejo correcto de tipos de datos
227
+ """
228
+ try:
229
+ safe_init_h2o()
230
+
231
+ # Crear un DataFrame combinado con la variable objetivo
232
+ train_df = X_train.copy()
233
+ test_df = X_test.copy()
234
+
235
+ # Manejar la variable objetivo según el tipo de problema
236
+ if problem_type == 'classification':
237
+ train_df['target'] = y_train.astype(str)
238
+ test_df['target'] = y_test.astype(str)
239
+ else:
240
+ train_df['target'] = y_train.astype(float)
241
+ test_df['target'] = y_test.astype(float)
242
+
243
+ # Convertir a H2OFrame
244
+ train = h2o.H2OFrame(train_df)
245
+ test = h2o.H2OFrame(test_df)
246
+
247
+ # Si es clasificación, convertir explícitamente la columna objetivo a factor
248
+ if problem_type == 'classification':
249
+ train['target'] = train['target'].asfactor()
250
+ test['target'] = test['target'].asfactor()
251
+
252
+ # Especificar columnas
253
+ feature_cols = X_train.columns.tolist()
254
+ target_col = 'target'
255
+
256
+ # Configurar AutoML
257
+ aml = h2o.automl.H2OAutoML(
258
+ max_runtime_secs=time_limit,
259
+ max_models=max_models,
260
+ seed=42,
261
+ sort_metric="AUTO"
262
+ )
263
+
264
+ # Entrenar
265
+ start_time = time.time()
266
+ aml.train(x=feature_cols, y=target_col, training_frame=train)
267
+ training_time = time.time() - start_time
268
+
269
+ # Obtener el mejor modelo
270
+ best_model = aml.leader
271
+
272
+ # Obtener hiperparámetros correctamente
273
+ hyperparameters = best_model.params
274
+
275
+ # Obtener predicciones
276
+ preds = best_model.predict(test[feature_cols])
277
+ predictions = preds.as_data_frame(use_pandas=True)
278
+
279
+ if problem_type == 'classification':
280
+ predictions = predictions['predict']
281
+
282
+ # Preparar resultados
283
+ results = {
284
+ 'best_model': best_model,
285
+ 'training_time': training_time,
286
+ 'leaderboard': aml.leaderboard.as_data_frame(),
287
+ 'hyperparameters': hyperparameters,
288
+ 'predictions': predictions
289
+ }
290
+
291
+ # Métricas según tipo de problema
292
+ if problem_type == 'classification':
293
+ results.update({
294
+ 'test_accuracy': accuracy_score(y_test.astype(str), predictions.astype(str)),
295
+ 'classification_report': classification_report(
296
+ y_test.astype(str),
297
+ predictions.astype(str),
298
+ output_dict=True
299
+ )
300
+ })
301
+ else:
302
+ results.update({
303
+ 'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
304
+ 'test_r2': r2_score(y_test, predictions)
305
+ })
306
+
307
+ return results
308
+
309
+ except Exception as e:
310
+ print(f"Error detallado en H2O AutoML: {str(e)}")
311
+ return {'error': str(e)}
312
+
313
+ @staticmethod
314
+ def train_flaml_automl(
315
+ X_train: pd.DataFrame,
316
+ y_train: pd.Series,
317
+ X_test: pd.DataFrame,
318
+ y_test: pd.Series,
319
+ problem_type: str,
320
+ time_limit: int = 3600,
321
+ metric: Optional[str] = None
322
+ ) -> Dict[str, Any]:
323
+ """
324
+ Entrenar modelos usando FLAML AutoML
325
+
326
+ Args:
327
+ X_train: Features de entrenamiento
328
+ y_train: Target de entrenamiento
329
+ X_test: Features de prueba
330
+ y_test: Target de prueba
331
+ problem_type: Tipo de problema
332
+ time_limit: Límite de tiempo en segundos
333
+ metric: Métrica de evaluación
334
+
335
+ Returns:
336
+ Dict con resultados del entrenamiento
337
+ """
338
+ try:
339
+ # Configurar AutoML
340
+ task = 'classification' if problem_type == 'classification' else 'regression'
341
+ metric = metric or ('accuracy' if task == 'classification' else 'r2')
342
+
343
+ automl = AutoML()
344
+
345
+ # Entrenar
346
+ start_time = time.time()
347
+ automl.fit(
348
+ X_train=X_train,
349
+ y_train=y_train,
350
+ task=task,
351
+ time_budget=time_limit,
352
+ metric=metric,
353
+ verbose=1
354
+ )
355
+ training_time = time.time() - start_time
356
+
357
+ # Predicciones
358
+ predictions = automl.predict(X_test)
359
+
360
+ # Preparar resultados
361
+ results = {
362
+ 'best_model': automl.model,
363
+ 'best_config': automl.best_config,
364
+ 'training_time': training_time,
365
+ 'best_estimator': automl.best_estimator,
366
+ 'predictions': predictions
367
+ }
368
+
369
+ # Métricas específicas
370
+ if problem_type == 'classification':
371
+ results.update({
372
+ 'test_accuracy': accuracy_score(y_test, predictions),
373
+ 'classification_report': classification_report(y_test, predictions, output_dict=True)
374
+ })
375
+ else:
376
+ results.update({
377
+ 'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
378
+ 'test_r2': r2_score(y_test, predictions)
379
+ })
380
+
381
+ return results
382
+
383
+ except Exception as e:
384
+ return {'error': str(e)}
385
+
386
+ def descargar_modelo_h2o(modelo_h2o, nombre_modelo):
387
+ """
388
+ Guarda y prepara el modelo H2O para su descarga.
389
+
390
+ Args:
391
+ modelo_h2o: Objeto del modelo H2O.
392
+ nombre_modelo (str): Nombre del modelo para el archivo.
393
+
394
+ Returns:
395
+ bytes: Contenido del archivo del modelo.
396
+ """
397
+ try:
398
+ # Guardar el modelo en una ruta temporal
399
+ modelo_path = h2o.save_model(model=modelo_h2o, path="/tmp", force=True)
400
+
401
+ # Leer el archivo del modelo
402
+ with open(modelo_path, "rb") as file:
403
+ modelo_data = file.read()
404
+
405
+ # Opcional: Eliminar el archivo temporal después de leerlo
406
+ os.remove(modelo_path)
407
+
408
+ return modelo_data
409
+ except Exception as e:
410
+ st.error(f"Error al preparar el modelo para descarga: {str(e)}")
411
+ return None
412
+
413
+ def show_automl_section(X: pd.DataFrame, y: pd.Series, problem_type: str):
414
+ """Mostrar sección de AutoML"""
415
+
416
+ st.header("🤖 Búsqueda Automática del Mejor Modelo")
417
+
418
+ # Parámetros de AutoML
419
+ col1, col2 = st.columns(2)
420
+ with col1:
421
+ time_limit = st.number_input(
422
+ "Límite de tiempo (segundos)",
423
+ min_value=60,
424
+ max_value=7200,
425
+ value=3600,
426
+ step=300,
427
+ key="automl_time_limit"
428
+ )
429
+
430
+ with col2:
431
+ framework = st.selectbox(
432
+ "Framework AutoML",
433
+ ["H2O AutoML", "FLAML"],
434
+ key="automl_framework"
435
+ )
436
+
437
+ # Inicializar estado para modelos AutoML
438
+ if 'automl_models' not in st.session_state:
439
+ st.session_state.automl_models = {}
440
+
441
+ # Botón de entrenamiento
442
+ train_button = st.button(
443
+ "Entrenar Modelos Automáticamente",
444
+ key="train_automl_button",
445
+ use_container_width=True
446
+ )
447
+
448
+ if train_button:
449
+ try:
450
+ # División de datos
451
+ X_train, X_test, y_train, y_test = train_test_split(
452
+ X, y, test_size=0.2, random_state=42,
453
+ stratify=y if problem_type == 'classification' else None
454
+ )
455
+
456
+ with st.spinner("Entrenando modelos automáticamente..."):
457
+ if framework == "H2O AutoML":
458
+ results = AutoMLTrainer.train_h2o_automl(
459
+ X_train, y_train, X_test, y_test,
460
+ problem_type, time_limit
461
+ )
462
+ else: # FLAML
463
+ results = AutoMLTrainer.train_flaml_automl(
464
+ X_train, y_train, X_test, y_test,
465
+ problem_type, time_limit
466
+ )
467
+
468
+ # Almacenar resultados
469
+ st.session_state.automl_models[framework] = results
470
+
471
+ except Exception as e:
472
+ st.error(f"Error en entrenamiento AutoML: {str(e)}")
473
+
474
+ # Mostrar resultados si existen
475
+ if st.session_state.automl_models:
476
+ for framework, results in st.session_state.automl_models.items():
477
+ st.subheader(f"Resultados de {framework}")
478
+
479
+ if 'error' in results:
480
+ st.error(f"Error: {results['error']}")
481
+ continue
482
+
483
+ # Métricas principales
484
+ cols = st.columns(3)
485
+ with cols[0]:
486
+ st.metric(
487
+ "Tiempo de Entrenamiento",
488
+ f"{results['training_time']:.2f}s"
489
+ )
490
+
491
+ with cols[1]:
492
+ if problem_type == 'classification':
493
+ st.metric("Accuracy", f"{results['test_accuracy']:.4f}")
494
+ else:
495
+ st.metric("R² Score", f"{results['test_r2']:.4f}")
496
+
497
+ with cols[2]:
498
+ if problem_type == 'classification':
499
+ st.metric(
500
+ "F1 Score",
501
+ f"{results['classification_report']['macro avg']['f1-score']:.4f}"
502
+ )
503
+ else:
504
+ st.metric("RMSE", f"{results['test_rmse']:.4f}")
505
+
506
+ # Explicación del modelo
507
+ if st.button("Generar Explicación", key=f"{framework}_explain"):
508
+ if 'gemini_api_key' in st.session_state:
509
+ with st.spinner("Generando explicación..."):
510
+ explainer = initialize_gemini_explainer()
511
+ model_info = {
512
+ 'name': framework,
513
+ 'problem_type': problem_type,
514
+ 'hyperparameters': results.get('hyperparameters', 'N/A'),
515
+ 'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
516
+ 'training_time': results.get('training_time', 'N/A')
517
+ }
518
+ explanation = explainer.generate_model_explanation(model_info)
519
+ st.markdown(explanation)
520
+ else:
521
+ st.warning("Configura tu API key de Gemini para generar explicaciones")
522
+
523
+ # Análisis SHAP
524
+ if st.button("Mostrar Análisis SHAP", key=f"{framework}_shap"):
525
+ create_shap_analysis_dashboard(
526
+ results['best_model'],
527
+ X,
528
+ problem_type
529
+ )
530
+
531
+ # Descarga del modelo
532
+ if st.button("Descargar Modelo", key=f"{framework}_download"):
533
+ modelo_data = descargar_modelo_h2o(results['best_model'], framework)
534
+ if modelo_data:
535
+ st.download_button(
536
+ label=f"Descargar {framework}",
537
+ data=modelo_data,
538
+ file_name=f"{framework.lower().replace(' ', '_')}_{int(time.time())}.zip",
539
+ mime="application/zip",
540
+ key=f"{framework}_download_button"
541
+ )
542
+
543
+ def show_train():
544
+ """
545
+ Función principal para mostrar la interfaz de entrenamiento de modelos
546
+ """
547
+ st.title("Desarrollo de Modelos")
548
+
549
+ # Verificar preparación de datos
550
+ if 'prepared_data' not in st.session_state:
551
+ st.warning("⚠️ No hay datos preparados en la sesión. Por favor, carga y prepara los datos primero.")
552
+ return
553
+
554
+ if st.session_state.prepared_data is None:
555
+ st.warning("⚠️ Los datos preparados están vacíos. Por favor, verifica la preparación de datos.")
556
+ return
557
+
558
+ # Inicializar 'trained_models' si no existe
559
+ if 'trained_models' not in st.session_state:
560
+ st.session_state.trained_models = {}
561
+
562
+ train = st.session_state.prepared_data
563
+
564
+ try:
565
+ # Seleccionar características y objetivo
566
+ X, y = select_features_and_target(train)
567
+ if X is None or y is None:
568
+ return
569
+
570
+ # Verificar valores nulos
571
+ if X.isnull().sum().sum() > 0 or y.isnull().sum() > 0:
572
+ st.error("Hay valores nulos en los datos. Por favor, vuelve a la página de preparación y maneja los valores faltantes.")
573
+ return
574
+
575
+ # Determinar tipo de problema
576
+ problem_type = determine_problem_type(y)
577
+
578
+ # Configuraciones de entrenamiento
579
+ col1, col2, col3 = st.columns(3)
580
+ with col1:
581
+ test_size = st.slider("Tamaño del conjunto de prueba:", 0.1, 0.5, 0.2)
582
+ with col2:
583
+ random_state = st.number_input("Random State:", min_value=0, value=42)
584
+ with col3:
585
+ n_folds = st.number_input("Número de folds para validación cruzada:", min_value=2, max_value=10, value=5)
586
+ st.session_state.n_folds = n_folds
587
+
588
+ # Preprocesamiento de datos para clasificación
589
+ if problem_type == 'classification':
590
+ y_original = y
591
+ le = LabelEncoder()
592
+ y = pd.Series(le.fit_transform(y))
593
+ st.session_state.label_encoder = le
594
+ st.write("Mapeo de clases:", dict(enumerate(le.classes_)))
595
+
596
+ # Visualizar distribución de clases
597
+ fig = create_class_distribution_plot(y_original)
598
+ st.plotly_chart(fig)
599
+
600
+ # Manejar desbalanceo de clases
601
+ X, y = handle_data_balancing(X, y, random_state)
602
+
603
+ show_automl_section(X, y, problem_type)
604
+
605
+ # Obtener opciones de modelos
606
+ model_options = get_model_options(problem_type)
607
+ # Gestionar modelos seleccionados
608
+ if 'selected_models' not in st.session_state:
609
+ st.session_state.selected_models = []
610
+
611
+ selected_models = st.multiselect(
612
+ "Selecciona los modelos a entrenar:",
613
+ list(model_options.keys()),
614
+ default=st.session_state.selected_models
615
+ )
616
+ st.session_state.selected_models = selected_models
617
+
618
+ if not selected_models:
619
+ st.warning("Por favor selecciona al menos un modelo para entrenar.")
620
+ return
621
+
622
+ # Configurar re-entrenamiento
623
+ if st.button("Reentrenar Modelos"):
624
+ st.session_state.retrain_models = True
625
+ else:
626
+ # Solo establecer a False si no está ya en sesión
627
+ if 'retrain_models' not in st.session_state:
628
+ st.session_state.retrain_models = False
629
+
630
+ # Dividir datos
631
+ X_train, X_test, y_train, y_test = train_test_split(
632
+ X, y, test_size=test_size, random_state=random_state,
633
+ stratify=y if problem_type == 'classification' else None
634
+ )
635
+
636
+ # Crear columnas para mostrar resultados de modelos
637
+ cols = st.columns(len(selected_models))
638
+
639
+ # Entrenar y mostrar resultados de cada modelo
640
+ for i, model_name in enumerate(selected_models):
641
+ with cols[i]:
642
+ st.write(f"### {model_name}")
643
+
644
+ # Verificar si el modelo ya está entrenado y si no se solicita reentrenamiento
645
+ if (model_name not in st.session_state.trained_models) or st.session_state.retrain_models:
646
+ # Entrenar modelo
647
+ trained_model = train_model_pipeline(
648
+ X_train=X_train,
649
+ y_train=y_train,
650
+ model_config=model_options[model_name],
651
+ X_test=X_test,
652
+ y_test=y_test,
653
+ cv=st.session_state.n_folds,
654
+ scoring=None,
655
+ random_state=random_state, # Pasar random_state
656
+ n_jobs=-1, # Para usar todos los núcleos disponibles
657
+ verbose=1
658
+ )
659
+
660
+ # Almacenar el modelo entrenado en session_state
661
+ if 'trained_models' not in st.session_state:
662
+ st.session_state.trained_models = {}
663
+ st.session_state.trained_models[model_name] = trained_model
664
+ else:
665
+ # Reutilizar el modelo ya entrenado
666
+ trained_model = st.session_state.trained_models[model_name]
667
+
668
+ # Mostrar resultados del modelo
669
+ show_model_results(
670
+ model_name,
671
+ problem_type,
672
+ y_test,
673
+ cols[i],
674
+ trained_model
675
+ )
676
+
677
+ except Exception as e:
678
+ st.error(f"Error inesperado: {str(e)}")
679
+
680
+
681
+ def show_model_results(model_name, problem_type, y_test, col, trained_model):
682
+ """
683
+ Mostrar resultados detallados de un modelo entrenado
684
+
685
+ Args:
686
+ model_name (str): Nombre del modelo
687
+ problem_type (str): Tipo de problema
688
+ y_test (pd.Series): Datos de prueba
689
+ col (streamlit.delta_generator.DeltaGenerator): Columna de Streamlit
690
+ trained_model (dict): Resultados del entrenamiento
691
+ """
692
+ with col:
693
+ # Verificar si el modelo está en la sesión de modelos entrenados
694
+ if model_name in st.session_state.trained_models:
695
+ results = st.session_state.trained_models[model_name]
696
+
697
+ # Verificar si hubo un error durante el entrenamiento
698
+ if 'error' in results:
699
+ st.error(results['error'])
700
+ return
701
+
702
+ # Mostrar métricas de rendimiento
703
+ if 'training_time' in results:
704
+ st.success(f"¡Entrenamiento completado en {results['training_time']:.2f} segundos!")
705
+ st.write("Mejores parámetros:", results.get('best_params', 'N/A'))
706
+
707
+ # Métricas específicas según el tipo de problema
708
+ if problem_type == 'classification':
709
+ st.write("Accuracy:", results.get('test_accuracy', 'N/A'))
710
+ st.text("Reporte de clasificación:")
711
+ st.text(pd.DataFrame(results.get('classification_report', {})).transpose().to_string())
712
+ else:
713
+ st.write("R² Score:", results.get('test_r2', 'N/A'))
714
+ st.write("RMSE:", results.get('test_rmse', 'N/A'))
715
+
716
+ # Sección de explicación de parámetros con Gemini
717
+ st.write("---")
718
+ st.write("### Explicación de Parámetros")
719
+
720
+ # Verificar disponibilidad de API key de Gemini
721
+ has_api_key = 'gemini_api_key' in st.session_state and st.session_state.gemini_api_key
722
+
723
+ if not has_api_key:
724
+ st.warning("Configure su API key de Gemini en la sección superior izquierda para usar la explicación automática de los parámetros.")
725
+
726
+ # Inicializar el explainer si no lo has hecho ya
727
+ if 'explainer' not in st.session_state:
728
+ st.session_state.explainer = initialize_gemini_explainer()
729
+
730
+ explainer = st.session_state.explainer
731
+
732
+ # Inicializar explicaciones en el estado de la sesión
733
+ if 'model_explanations' not in st.session_state:
734
+ st.session_state.model_explanations = {}
735
+
736
+ # Botón para generar explicación
737
+ explain_button = st.button(
738
+ "Explicar Parámetros",
739
+ disabled=not has_api_key,
740
+ key=f"explain_{model_name}"
741
+ )
742
+
743
+ # Mostrar explicación existente si está disponible
744
+ if model_name in st.session_state.model_explanations:
745
+ st.markdown(st.session_state.model_explanations[model_name])
746
+
747
+ # Inicializar el explainer solo cuando se necesite
748
+ if 'explain_button' in locals() and explain_button and has_api_key:
749
+ explainer = initialize_gemini_explainer()
750
+ if explainer: # Verificar que el explainer se inicializó correctamente
751
+ try:
752
+ with st.spinner("Generando explicación..."):
753
+ model_info = {
754
+ 'name': model_name,
755
+ 'problem_type': problem_type,
756
+ 'hyperparameters': results.get('hyperparameters', 'N/A'),
757
+ 'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
758
+ 'training_time': results.get('training_time', 'N/A')
759
+ }
760
+
761
+ explanation = explainer.generate_model_explanation(model_info)
762
+
763
+ # Almacenar explicación
764
+ st.session_state.model_explanations[model_name] = explanation
765
+
766
+ # Mostrar explicación
767
+ st.markdown(explanation)
768
+ except Exception as e:
769
+ st.error(f"Error al generar la explicación: {str(e)}")
770
+ else:
771
+ st.error("No se pudo inicializar el explicador de Gemini")
772
+
773
+ # Sección de análisis SHAP
774
+ st.write("---")
775
+ st.write("### Análisis SHAP")
776
+
777
+ if st.button("Mostrar Análisis SHAP", key=f"shap_button_{model_name}"):
778
+ try:
779
+ # Obtener datos preparados
780
+ X = st.session_state.prepared_data[st.session_state.feature_cols]
781
+
782
+ # Crear dashboard de análisis SHAP
783
+ create_shap_analysis_dashboard(
784
+ results['best_model'], # Usar el mejor modelo
785
+ X,
786
+ problem_type
787
+ )
788
+ except Exception as e:
789
+ st.error(f"Error en el análisis SHAP: {str(e)}")
790
+
791
+ # Sección de descarga del modelo
792
+ st.write("---")
793
+ st.write("### Descarga del modelo")
794
+
795
+ # Generar nombre de archivo
796
+ model_file_key = f"model_file_{model_name}"
797
+ if model_file_key not in st.session_state:
798
+ st.session_state[model_file_key] = f"{model_name.lower().replace(' ', '_')}_{int(time.time())}.pkl"
799
+
800
+ # Input para nombre de archivo
801
+ model_name_input = st.text_input(
802
+ "Nombre del archivo:",
803
+ value=st.session_state[model_file_key],
804
+ key=f"name_input_{model_name}"
805
+ )
806
+ # Botón de descarga
807
+ model_buffer = io.BytesIO()
808
+ pickle.dump(results['best_model'], model_buffer) # Guardar el mejor modelo
809
+ model_buffer.seek(0)
810
+
811
+ download_key = f"download_{model_name}"
812
+ st.download_button(
813
+ label="Descargar Modelo",
814
+ data=model_buffer,
815
+ file_name=model_name_input,
816
+ mime="application/octet-stream",
817
+ key=download_key
818
+ )
819
+
820
+ # # Botón de descarga
821
+ # modelo_data = descargar_modelo_h2o(results['best_model'], model_name)
822
+ # if modelo_data:
823
+ # st.download_button(
824
+ # label="Descargar Modelo",
825
+ # data=modelo_data,
826
+ # file_name=model_name_input,
827
+ # mime="application/zip",
828
+ # key=f"download_{model_name}"
829
+ # )
models/unsupervised.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # unsupervised.py - Módulo para models
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.decomposition import PCA
9
+ from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
10
+ from sklearn.manifold import TSNE
11
+ from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
12
+ import google.generativeai as genai
13
+ import umap
14
+
15
+ class UnsupervisedAnalyzer:
16
+ def __init__(self, data):
17
+ self.data = data
18
+ self.scaler = StandardScaler()
19
+
20
+ def preprocess_data(self, feature_cols):
21
+ """Escalar datos seleccionados"""
22
+ X = self.data[feature_cols]
23
+ return self.scaler.fit_transform(X)
24
+
25
+ def perform_kmeans(self, X_scaled, n_clusters):
26
+ """Realizar clustering K-Means"""
27
+ kmeans = KMeans(
28
+ n_clusters=n_clusters,
29
+ random_state=42,
30
+ n_init=10
31
+ )
32
+ clusters = kmeans.fit_predict(X_scaled)
33
+
34
+ # Calcular métricas
35
+ metrics = {
36
+ 'Silhouette Score': silhouette_score(X_scaled, clusters),
37
+ 'Calinski-Harabasz Score': calinski_harabasz_score(X_scaled, clusters),
38
+ 'Davies-Bouldin Score': davies_bouldin_score(X_scaled, clusters)
39
+ }
40
+
41
+ return {
42
+ 'clusters': clusters,
43
+ 'metrics': metrics,
44
+ 'centroids': kmeans.cluster_centers_
45
+ }
46
+
47
+ def perform_dbscan(self, X_scaled, eps, min_samples):
48
+ """Realizar clustering DBSCAN"""
49
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples)
50
+ clusters = dbscan.fit_predict(X_scaled)
51
+
52
+ # Calcular métricas
53
+ unique_clusters = np.setdiff1d(np.unique(clusters), [-1])
54
+ metrics = {
55
+ 'Noise Points': np.sum(clusters == -1),
56
+ 'Number of Clusters': len(unique_clusters)
57
+ }
58
+
59
+ # Solo calcular métricas si hay clusters válidos
60
+ if len(unique_clusters) > 0:
61
+ non_noise_mask = clusters != -1
62
+ metrics.update({
63
+ 'Silhouette Score': silhouette_score(X_scaled[non_noise_mask], clusters[non_noise_mask]),
64
+ 'Calinski-Harabasz Score': calinski_harabasz_score(X_scaled[non_noise_mask], clusters[non_noise_mask]),
65
+ 'Davies-Bouldin Score': davies_bouldin_score(X_scaled[non_noise_mask], clusters[non_noise_mask])
66
+ })
67
+ else:
68
+ metrics.update({
69
+ 'Silhouette Score': None,
70
+ 'Calinski-Harabasz Score': None,
71
+ 'Davies-Bouldin Score': None
72
+ })
73
+
74
+ return {
75
+ 'clusters': clusters,
76
+ 'metrics': metrics
77
+ }
78
+
79
+ def perform_hierarchical_clustering(self, X_scaled, n_clusters):
80
+ """Realizar clustering jerárquico"""
81
+ hierarchical = AgglomerativeClustering(n_clusters=n_clusters)
82
+ clusters = hierarchical.fit_predict(X_scaled)
83
+
84
+ # Calcular métricas
85
+ metrics = {
86
+ 'Silhouette Score': silhouette_score(X_scaled, clusters),
87
+ 'Calinski-Harabasz Score': calinski_harabasz_score(X_scaled, clusters),
88
+ 'Davies-Bouldin Score': davies_bouldin_score(X_scaled, clusters)
89
+ }
90
+
91
+ return {
92
+ 'clusters': clusters,
93
+ 'metrics': metrics
94
+ }
95
+
96
+ def perform_dimensionality_reduction(self, X_scaled, method='PCA', n_components=2):
97
+ """Realizar reducción de dimensionalidad"""
98
+ if method == 'PCA':
99
+ reducer = PCA(n_components=n_components)
100
+ reduced_data = reducer.fit_transform(X_scaled)
101
+ return {
102
+ 'reduced_data': reduced_data,
103
+ 'explained_variance': reducer.explained_variance_ratio_
104
+ }
105
+ elif method == 't-SNE':
106
+ reducer = TSNE(n_components=n_components, random_state=42)
107
+ reduced_data = reducer.fit_transform(X_scaled)
108
+ return {
109
+ 'reduced_data': reduced_data
110
+ }
111
+ elif method == 'UMAP':
112
+ reducer = umap.UMAP(n_components=n_components, random_state=42)
113
+ reduced_data = reducer.fit_transform(X_scaled)
114
+ return {
115
+ 'reduced_data': reduced_data
116
+ }
117
+
118
+ def generate_method_explanation(method, parameters, metrics):
119
+ """Generar explicación del método usando Gemini"""
120
+ try:
121
+ genai.configure(api_key=st.session_state.gemini_api_key)
122
+ model = genai.GenerativeModel('gemini-1.5-flash')
123
+
124
+ # Preparar prompt basado en el método
125
+ prompt = f"""Explica detalladamente el método de análisis no supervisado: {method}
126
+
127
+ Parámetros:
128
+ {', '.join([f"{k}: {v}" for k, v in parameters.items()])}
129
+
130
+ Métricas:
131
+ {', '.join([f"{k}: {v}" for k, v in metrics.items()])}
132
+
133
+ En tu explicación, incluye:
134
+ 1. Objetivo principal del método
135
+ 2. Cómo funciona el algoritmo
136
+ 3. Interpretación de los parámetros
137
+ 4. Significado de las métricas
138
+ 5. Casos de uso recomendados"""
139
+
140
+ response = model.generate_content(prompt)
141
+ return response.text
142
+ except Exception as e:
143
+ return f"Error al generar explicación: {str(e)}"
144
+
145
+ def visualize_clustering(X_scaled, clusters, method_name, n_components=2):
146
+ """Visualización de clustering"""
147
+ reducer = PCA(n_components=n_components)
148
+ X_reduced = reducer.fit_transform(X_scaled)
149
+
150
+ if n_components == 2:
151
+ fig = px.scatter(
152
+ x=X_reduced[:, 0],
153
+ y=X_reduced[:, 1],
154
+ color=clusters.astype(str),
155
+ title=f'Clustering {method_name} - Visualización PCA',
156
+ labels={'x': 'PCA Componente 1', 'y': 'PCA Componente 2'}
157
+ )
158
+ else:
159
+ fig = go.Figure(data=[
160
+ go.Scatter3d(
161
+ x=X_reduced[:, 0],
162
+ y=X_reduced[:, 1],
163
+ z=X_reduced[:, 2],
164
+ mode='markers',
165
+ marker=dict(
166
+ size=5,
167
+ color=clusters,
168
+ colorscale='Viridis',
169
+ opacity=0.8
170
+ )
171
+ )
172
+ ])
173
+ fig.update_layout(
174
+ title=f'Clustering {method_name} - Visualización 3D',
175
+ scene=dict(
176
+ xaxis_title='PCA 1',
177
+ yaxis_title='PCA 2',
178
+ zaxis_title='PCA 3'
179
+ )
180
+ )
181
+
182
+ return fig
183
+
184
+ def show_unsupervised_analysis():
185
+ st.title("Análisis No Supervisado")
186
+
187
+ # Verificar datos preparados
188
+ if 'prepared_data' not in st.session_state or st.session_state.prepared_data is None:
189
+ st.warning("Por favor, carga y prepara tus datos primero")
190
+ return
191
+
192
+ # Obtener datos
193
+ data = st.session_state.prepared_data
194
+
195
+ # Seleccionar columnas numéricas
196
+ numeric_cols = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
197
+
198
+ if not numeric_cols:
199
+ st.error("No hay variables numéricas para análisis no supervisado")
200
+ return
201
+
202
+ # Selección de características
203
+ feature_cols = st.multiselect(
204
+ "Seleccionar Variables para Análisis",
205
+ numeric_cols,
206
+ default=numeric_cols[:min(5, len(numeric_cols))]
207
+ )
208
+
209
+ if not feature_cols:
210
+ st.warning("Selecciona al menos una variable")
211
+ return
212
+
213
+ # Inicializar analizador
214
+ analyzer = UnsupervisedAnalyzer(data)
215
+ X_scaled = analyzer.preprocess_data(feature_cols)
216
+
217
+ # Selección de métodos
218
+ methods = st.multiselect(
219
+ "Seleccionar Métodos de Análisis",
220
+ [
221
+ "K-Means",
222
+ "DBSCAN",
223
+ "Clustering Jerárquico",
224
+ "PCA",
225
+ "t-SNE",
226
+ "UMAP"
227
+ ]
228
+ )
229
+
230
+ # Contenedor para resultados
231
+ results = {}
232
+
233
+ # Columnas para visualización
234
+ if methods:
235
+ cols = st.columns(len(methods))
236
+
237
+ for i, method in enumerate(methods):
238
+ with cols[i]:
239
+ st.subheader(method)
240
+
241
+ # Parámetros específicos por método
242
+ if method == "K-Means":
243
+ n_clusters = st.slider(
244
+ "Número de Clusters",
245
+ min_value=2,
246
+ max_value=10,
247
+ value=3,
248
+ key=f"kmeans_clusters_{i}"
249
+ )
250
+ result = analyzer.perform_kmeans(X_scaled, n_clusters)
251
+ results['K-Means'] = result
252
+
253
+ # Visualización
254
+ fig = visualize_clustering(X_scaled, result['clusters'], method)
255
+ st.plotly_chart(fig)
256
+
257
+ # Métricas
258
+ st.write("Métricas:")
259
+ for metric, value in result['metrics'].items():
260
+ st.metric(metric, f"{value:.4f}")
261
+
262
+ # Explicación con Gemini
263
+ if st.session_state.get('gemini_api_key'):
264
+ explanation = generate_method_explanation(
265
+ method,
266
+ {'Número de Clusters': n_clusters},
267
+ result['metrics']
268
+ )
269
+ with st.expander("Explicación del Método"):
270
+ st.markdown(explanation)
271
+
272
+ elif method == "DBSCAN":
273
+ eps = st.slider(
274
+ "Epsilon",
275
+ min_value=0.1,
276
+ max_value=2.0,
277
+ value=0.5,
278
+ key=f"dbscan_eps_{i}"
279
+ )
280
+ min_samples = st.slider(
281
+ "Mínimo de Muestras",
282
+ min_value=2,
283
+ max_value=20,
284
+ value=5,
285
+ key=f"dbscan_min_samples_{i}"
286
+ )
287
+ result = analyzer.perform_dbscan(X_scaled, eps, min_samples)
288
+ results['DBSCAN'] = result
289
+
290
+ # Visualización
291
+ fig = visualize_clustering(X_scaled, result['clusters'], method)
292
+ st.plotly_chart(fig)
293
+
294
+ # Métricas
295
+ st.write("Métricas:")
296
+ for metric, value in result['metrics'].items():
297
+ st.metric(metric, str(value))
298
+
299
+ # Explicación con Gemini
300
+ if st.session_state.get('gemini_api_key'):
301
+ explanation = generate_method_explanation(
302
+ method,
303
+ {
304
+ 'Epsilon': eps,
305
+ 'Mínimo de Muestras': min_samples
306
+ },
307
+ result['metrics']
308
+ )
309
+ with st.expander("Explicación del Método"):
310
+ st.markdown(explanation)
311
+
312
+ elif method == "Clustering Jerárquico":
313
+ n_clusters = st.slider(
314
+ "Número de Clusters",
315
+ min_value=2,
316
+ max_value=10,
317
+ value=3,
318
+ key=f"hierarchical_clusters_{i}"
319
+ )
320
+ result = analyzer.perform_hierarchical_clustering(X_scaled, n_clusters)
321
+ results['Clustering Jerárquico'] = result
322
+
323
+ # Visualización
324
+ fig = visualize_clustering(X_scaled, result['clusters'], method)
325
+ st.plotly_chart(fig)
326
+
327
+ # Métricas
328
+ st.write("Métricas:")
329
+ for metric, value in result['metrics'].items():
330
+ st.metric(metric, f"{value:.4f}")
331
+
332
+ # Explicación con Gemini
333
+ if st.session_state.get('gemini_api_key'):
334
+ explanation = generate_method_explanation(
335
+ method,
336
+ {'Número de Clusters': n_clusters},
337
+ result['metrics']
338
+ )
339
+ with st.expander("Explicación del Método"):
340
+ st.markdown(explanation)
341
+
342
+ elif method in ["PCA", "t-SNE", "UMAP"]:
343
+ n_components = st.slider(
344
+ "Número de Componentes",
345
+ min_value=2,
346
+ max_value=3,
347
+ value=2,
348
+ key=f"{method}_components_{i}"
349
+ )
350
+ result = analyzer.perform_dimensionality_reduction(
351
+ X_scaled,
352
+ method=method,
353
+ n_components=n_components
354
+ )
355
+ results[method] = result
356
+
357
+ # Visualización de reducción de dimensionalidad
358
+ fig = px.scatter(
359
+ x=result['reduced_data'][:, 0],
360
+ y=result['reduced_data'][:, 1],
361
+ title=f'Reducción de Dimensionalidad - {method}'
362
+ )
363
+ st.plotly_chart(fig)
364
+
365
+ # Varianza explicada para PCA
366
+ if method == 'PCA':
367
+ st.write("Varianza Explicada:")
368
+ varianza_df = pd.DataFrame({
369
+ 'Componente': range(1, len(result['explained_variance']) + 1),
370
+ 'Varianza Explicada (%)': result['explained_variance'] * 100,
371
+ 'Varianza Acumulada (%)': np.cumsum(result['explained_variance']) * 100
372
+ })
373
+ st.dataframe(varianza_df)
374
+
375
+ # Explicación con Gemini
376
+ if st.session_state.get('gemini_api_key'):
377
+ explanation = generate_method_explanation(
378
+ method,
379
+ {'Número de Componentes': n_components},
380
+ {}
381
+ )
382
+ with st.expander("Explicación del Método"):
383
+ st.markdown(explanation)
384
+
385
+ # Exportar resultados
386
+ if st.button("Exportar Resultados"):
387
+ export_data = []
388
+ for method, result in results.items():
389
+ method_data = {
390
+ 'Método': method,
391
+ 'Variables': ', '.join(feature_cols)
392
+ }
393
+
394
+ # Agregar métricas si están disponibles
395
+ if 'metrics' in result:
396
+ method_data.update(result['metrics'])
397
+
398
+ export_data.append(method_data)
399
+
400
+ export_df = pd.DataFrame(export_data)
401
+ csv = export_df.to_csv(index=False).encode('utf-8')
402
+ st.download_button(
403
+ label="Descargar Resultados",
404
+ data=csv,
405
+ file_name="unsupervised_analysis_results.csv",
406
+ mime="text/csv",
407
+ key="download_unsupervised_results"
408
+ )
409
+
410
+ def show_unsupervised():
411
+ """Función principal para mostrar la página de análisis no supervisado"""
412
+ st.title("🔍 Análisis No Supervisado")
413
+
414
+ # Verificar datos preparados
415
+ if 'prepared_data' not in st.session_state or st.session_state.prepared_data is None:
416
+ st.warning("Por favor, carga y prepara tus datos primero en la página de Preparación.")
417
+ return
418
+
419
+ # Obtener datos preparados
420
+ data = st.session_state.prepared_data
421
+
422
+ # Sección de selección de variables
423
+ st.header("Configuración del Análisis")
424
+
425
+ # Seleccionar columnas numéricas
426
+ numeric_cols = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
427
+
428
+ if not numeric_cols:
429
+ st.error("No hay variables numéricas disponibles para realizar análisis no supervisado.")
430
+ return
431
+
432
+ # Selección de características
433
+ st.subheader("Selección de Variables")
434
+ feature_cols = st.multiselect(
435
+ "Selecciona las variables para el análisis",
436
+ numeric_cols,
437
+ default=numeric_cols[:min(5, len(numeric_cols))]
438
+ )
439
+
440
+ if not feature_cols:
441
+ st.warning("Por favor, selecciona al menos una variable.")
442
+ return
443
+
444
+ # Inicializar analizador
445
+ analyzer = UnsupervisedAnalyzer(data)
446
+ X_scaled = analyzer.preprocess_data(feature_cols)
447
+
448
+ # Sección de métodos de análisis
449
+ st.header("Métodos de Análisis")
450
+
451
+ # Selección de métodos
452
+ metodos = st.multiselect(
453
+ "Elige los métodos de análisis no supervisado",
454
+ [
455
+ "K-Means",
456
+ "DBSCAN",
457
+ "Clustering Jerárquico",
458
+ "Análisis de Componentes Principales (PCA)",
459
+ "t-SNE",
460
+ "UMAP"
461
+ ]
462
+ )
463
+
464
+ # Contenedor de resultados
465
+ resultados = {}
466
+
467
+ # Procesamiento de métodos seleccionados
468
+ if metodos:
469
+ # Crear columnas dinámicamente
470
+ cols = st.columns(len(metodos))
471
+
472
+ for i, metodo in enumerate(metodos):
473
+ with cols[i]:
474
+ st.subheader(metodo)
475
+
476
+ # Parámetros específicos por método
477
+ if metodo == "K-Means":
478
+ n_clusters = st.slider(
479
+ "Número de Clusters",
480
+ min_value=2,
481
+ max_value=10,
482
+ value=3,
483
+ key=f"kmeans_clusters_{i}"
484
+ )
485
+
486
+ # Realizar K-Means
487
+ resultado = analyzer.perform_kmeans(X_scaled, n_clusters)
488
+ resultados[metodo] = resultado
489
+
490
+ # Visualización
491
+ fig = visualize_clustering(X_scaled, resultado['clusters'], metodo)
492
+ st.plotly_chart(fig)
493
+
494
+ # Mostrar métricas
495
+ st.subheader("Métricas")
496
+ for metrica, valor in resultado['metrics'].items():
497
+ st.metric(metrica, f"{valor:.4f}")
498
+
499
+ # Explicación con Gemini
500
+ if st.session_state.get('gemini_api_key'):
501
+ explicacion = generate_method_explanation(
502
+ metodo,
503
+ {'Número de Clusters': n_clusters},
504
+ resultado['metrics']
505
+ )
506
+ with st.expander("Explicación del Método"):
507
+ st.markdown(explicacion)
508
+
509
+ elif metodo == "DBSCAN":
510
+ eps = st.slider(
511
+ "Epsilon",
512
+ min_value=0.1,
513
+ max_value=2.0,
514
+ value=0.5,
515
+ key=f"dbscan_eps_{i}"
516
+ )
517
+ min_samples = st.slider(
518
+ "Mínimo de Muestras",
519
+ min_value=2,
520
+ max_value=20,
521
+ value=5,
522
+ key=f"dbscan_min_samples_{i}"
523
+ )
524
+
525
+ # Realizar DBSCAN
526
+ resultado = analyzer.perform_dbscan(X_scaled, eps, min_samples)
527
+ resultados[metodo] = resultado
528
+
529
+ # Visualización
530
+ fig = visualize_clustering(X_scaled, resultado['clusters'], metodo)
531
+ st.plotly_chart(fig)
532
+
533
+ # Mostrar métricas
534
+ st.subheader("Métricas")
535
+ for metrica, valor in resultado['metrics'].items():
536
+ st.metric(metrica, str(valor))
537
+
538
+ # Explicación con Gemini
539
+ if st.session_state.get('gemini_api_key'):
540
+ explicacion = generate_method_explanation(
541
+ metodo,
542
+ {
543
+ 'Epsilon': eps,
544
+ 'Mínimo de Muestras': min_samples
545
+ },
546
+ resultado['metrics']
547
+ )
548
+ with st.expander("Explicación del Método"):
549
+ st.markdown(explicacion)
550
+
551
+ # Continuar con los demás métodos de manera similar...
552
+
553
+ # Sección de exportación de resultados
554
+ if st.button("Exportar Resultados del Análisis"):
555
+ # Crear DataFrame con resultados
556
+ datos_exportacion = []
557
+ for metodo, resultado in resultados.items():
558
+ datos_metodo = {
559
+ 'Método': metodo,
560
+ 'Variables': ', '.join(feature_cols)
561
+ }
562
+
563
+ # Agregar métricas si están disponibles
564
+ if 'metrics' in resultado:
565
+ datos_metodo.update(resultado['metrics'])
566
+
567
+ datos_exportacion.append(datos_metodo)
568
+
569
+ df_exportacion = pd.DataFrame(datos_exportacion)
570
+
571
+ # Descargar CSV
572
+ csv = df_exportacion.to_csv(index=False).encode('utf-8')
573
+ st.download_button(
574
+ label="Descargar Resultados",
575
+ data=csv,
576
+ file_name="analisis_no_supervisado.csv",
577
+ mime="text/csv"
578
+ )
579
+
580
+ # Función principal para ejecutar el análisis no supervisado
581
+ def main():
582
+ show_unsupervised()
583
+
584
+ if __name__ == "__main__":
585
+ main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libsnappy-dev
2
+ libgl1
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit_option_menu
2
+ streamlit_lottie
3
+ pandas
4
+ numpy
5
+ scikit-learn==1.4.0
6
+ google-generativeai
7
+ plotly
8
+ supabase
9
+ python-dotenv
10
+ shap
11
+ xgboost
12
+ requests
13
+ typing
14
+ streamlit_shap
15
+ matplotlib
16
+ pyarrow
17
+ umap
18
+ imblearn
19
+ openpyxl
20
+ pygwalker
21
+ ydata_profiling
22
+ stqdm
23
+ h2o
24
+ FLAML
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # __init__.py - M�dulo para src utils
2
+ from .model_utils import ModelTrainer
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (229 Bytes). View file
 
utils/__pycache__/gemini_explainer.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
utils/__pycache__/model_utils.cpython-312.pyc ADDED
Binary file (21.7 kB). View file
 
utils/__pycache__/shap_explainer.cpython-312.pyc ADDED
Binary file (19 kB). View file
 
utils/gemini_explainer.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/gemini_explainer.py
2
+ import streamlit as st
3
+ import google.generativeai as genai
4
+ from typing import Dict, Any, Optional
5
+ import h2o
6
+ import os
7
+
8
+ def generate_dataset_explanation(dataset, api_key=None):
9
+ """
10
+ Generate a dataset explanation using Gemini AI
11
+
12
+ Args:
13
+ dataset (pd.DataFrame): DataFrame to explain
14
+ api_key (str, optional): Gemini API key
15
+
16
+ Returns:
17
+ str: Explanation of the dataset
18
+ """
19
+ try:
20
+ # Prepare dataset information
21
+ dataset_info = {
22
+ 'rows': len(dataset),
23
+ 'columns': len(dataset.columns),
24
+ 'column_names': list(dataset.columns),
25
+ 'data_types': str(dataset.dtypes),
26
+ 'first_rows': dataset.head().to_string(),
27
+ 'basic_stats': dataset.describe().to_string()
28
+ }
29
+
30
+ # Initialize Gemini Explainer
31
+ explainer = GeminiExplainer(api_key)
32
+
33
+ # Generate explanation
34
+ explanation = explainer.generate_dataset_explanation(dataset_info)
35
+
36
+ return explanation
37
+
38
+ except Exception as e:
39
+ return f"Error generating dataset explanation: {str(e)}"
40
+
41
+ def generate_model_explanation(self, model_info: Dict[str, Any]) -> str:
42
+ """
43
+ Generar una explicación detallada de un modelo de machine learning
44
+
45
+ Args:
46
+ model_info (dict): Información del modelo
47
+
48
+ Returns:
49
+ str: Explicación generada por Gemini
50
+ """
51
+ prompt = f"""Proporciona una explicación detallada del modelo de machine learning:
52
+
53
+ Información del Modelo:
54
+ - Nombre del Modelo: {model_info.get('name', 'N/A')}
55
+ - Tipo de Problema: {model_info.get('problem_type', 'N/A')}
56
+ - Hiperparámetros: {model_info.get('hyperparameters', 'N/A')}
57
+ - Métricas de Rendimiento:
58
+ * Accuracy/R²: {model_info.get('performance_metric', 'N/A')}
59
+ * Otras métricas: {model_info.get('additional_metrics', 'N/A')}
60
+
61
+ En tu explicación, incluye:
62
+ 1. Descripción del algoritmo
63
+ 2. Funcionamiento interno del modelo
64
+ 3. Interpretación de los hiperparámetros
65
+ 4. Análisis de las métricas de rendimiento
66
+ 5. Fortalezas y limitaciones del modelo
67
+ 6. Recomendaciones para posibles mejoras"""
68
+
69
+ try:
70
+ response = self.model.generate_content(prompt)
71
+ return response.text
72
+ except Exception as e:
73
+ return f"Error al generar explicación: {str(e)}"
74
+
75
+ class GeminiExplainer:
76
+ def __init__(self, api_key: Optional[str] = None):
77
+ """
78
+ Inicializar el explicador de Gemini
79
+
80
+ Args:
81
+ api_key (str, opcional): API key de Google Generative AI
82
+ """
83
+ self.api_key = api_key or st.session_state.get('gemini_api_key')
84
+
85
+ if not self.api_key:
86
+ raise ValueError("No se ha proporcionado una API key de Gemini")
87
+
88
+ # Configurar la API de Gemini
89
+ genai.configure(api_key=self.api_key)
90
+
91
+ # Seleccionar modelo
92
+ self.model = genai.GenerativeModel('gemini-1.5-flash')
93
+
94
+ def generate_dataset_explanation(self, dataset_info: Dict[str, Any]) -> str:
95
+ """
96
+ Generar una explicación detallada del dataset
97
+
98
+ Args:
99
+ dataset_info (dict): Información del dataset
100
+
101
+ Returns:
102
+ str: Explicación generada por Gemini
103
+ """
104
+ prompt = f"""Analiza este dataset y proporciona una explicación clara y concisa de su estructura y contenido:
105
+
106
+ Información del Dataset:
107
+ - Dimensiones: {dataset_info.get('rows', 'N/A')} filas × {dataset_info.get('columns', 'N/A')} columnas
108
+ - Columnas: {', '.join(dataset_info.get('column_names', []))}
109
+ - Tipos de datos: {dataset_info.get('data_types', 'N/A')}
110
+ - Primeras filas: {dataset_info.get('first_rows', 'N/A')}
111
+ - Estadísticas básicas: {dataset_info.get('basic_stats', 'N/A')}
112
+
113
+ En tu explicación, incluye:
114
+ 1. Descripción general del dataset
115
+ 2. Tipos de variables presentes
116
+ 3. Posibles desafíos o consideraciones para el análisis
117
+ 4. Sugerencias iniciales de preprocesamiento
118
+ 5. Potenciales insights o patrones preliminares"""
119
+
120
+ try:
121
+ response = self.model.generate_content(prompt)
122
+ return response.text
123
+ except Exception as e:
124
+ return f"Error al generar explicación: {str(e)}"
125
+
126
+ def generate_model_explanation(self, model_info: Dict[str, Any]) -> str:
127
+ """
128
+ Generar una explicación detallada de un modelo de machine learning
129
+
130
+ Args:
131
+ model_info (dict): Información del modelo
132
+
133
+ Returns:
134
+ str: Explicación generada por Gemini
135
+ """
136
+ # Extraer hiperparámetros en formato legible
137
+ hyperparameters = model_info.get('hyperparameters', {})
138
+ if isinstance(hyperparameters, dict):
139
+ hyperparams_str = "\n".join([f"- {k}: {v}" for k, v in hyperparameters.items()])
140
+ else:
141
+ hyperparams_str = str(hyperparameters)
142
+
143
+ prompt = f"""Proporciona una explicación detallada del modelo de machine learning:
144
+
145
+ Información del Modelo:
146
+ - Nombre del Modelo: {model_info.get('name', 'N/A')}
147
+ - Tipo de Problema: {model_info.get('problem_type', 'N/A')}
148
+ - Hiperparámetros:
149
+ {hyperparams_str}
150
+ - Métricas de Rendimiento:
151
+ * Accuracy/R²: {model_info.get('performance_metric', 'N/A')}
152
+ * Tiempo de Entrenamiento: {model_info.get('training_time', 'N/A')}
153
+
154
+ En tu explicación, incluye:
155
+ 1. Descripción del algoritmo
156
+ 2. Funcionamiento interno del modelo
157
+ 3. Interpretación de los hiperparámetros
158
+ 4. Análisis de las métricas de rendimiento
159
+ 5. Fortalezas y limitaciones del modelo
160
+ 6. Recomendaciones para posibles mejoras"""
161
+
162
+ try:
163
+ response = self.model.generate_content(prompt)
164
+ return response.text
165
+ except Exception as e:
166
+ return f"Error al generar explicación: {str(e)}"
167
+
168
+ def generate_clustering_explanation(self, clustering_info: Dict[str, Any]) -> str:
169
+ """
170
+ Generar una explicación de resultados de clustering
171
+
172
+ Args:
173
+ clustering_info (dict): Información del clustering
174
+
175
+ Returns:
176
+ str: Explicación generada por Gemini
177
+ """
178
+ prompt = f"""Analiza los resultados del método de clustering:
179
+
180
+ Información del Clustering:
181
+ - Método: {clustering_info.get('method', 'N/A')}
182
+ - Número de Clusters: {clustering_info.get('n_clusters', 'N/A')}
183
+ - Parámetros: {clustering_info.get('parameters', 'N/A')}
184
+ - Métricas:
185
+ * Silhouette Score: {clustering_info.get('silhouette_score', 'N/A')}
186
+ * Calinski-Harabasz: {clustering_info.get('calinski_score', 'N/A')}
187
+ * Davies-Bouldin: {clustering_info.get('davies_bouldin', 'N/A')}
188
+
189
+ En tu explicación, incluye:
190
+ 1. Descripción del método de clustering
191
+ 2. Interpretación de los parámetros utilizados
192
+ 3. Significado de las métricas de evaluación
193
+ 4. Análisis de la calidad de los clusters
194
+ 5. Posibles insights o patrones detectados
195
+ 6. Recomendaciones para ajustar el clustering"""
196
+
197
+ try:
198
+ response = self.model.generate_content(prompt)
199
+ return response.text
200
+ except Exception as e:
201
+ return f"Error al generar explicación: {str(e)}"
202
+
203
+ def generate_feature_importance_explanation(self, feature_importance_info: Dict[str, Any]) -> str:
204
+ """
205
+ Generar una explicación de la importancia de características
206
+
207
+ Args:
208
+ feature_importance_info (dict): Información de importancia de características
209
+
210
+ Returns:
211
+ str: Explicación generada por Gemini
212
+ """
213
+ method = feature_importance_info.get('method', 'N/A')
214
+ features = feature_importance_info.get('features', [])
215
+ importance_values = feature_importance_info.get('importance_values', {})
216
+
217
+ # Formatear la información de importancia
218
+ importance_str = "\n".join([f"- {feat}: {val}" for feat, val in importance_values.items()])
219
+
220
+ prompt = f"""Analiza la importancia de las características en el modelo:
221
+
222
+ Información de Importancia de Características:
223
+ - Método de Evaluación: {method}
224
+ - Características:
225
+ {importance_str}
226
+
227
+ En tu explicación, incluye:
228
+ 1. Descripción del método de evaluación de importancia
229
+ 2. Análisis de las características más importantes
230
+ 3. Interpretación de los valores de importancia
231
+ 4. Posibles implicaciones para el modelado
232
+ 5. Recomendaciones para selección de características"""
233
+
234
+ try:
235
+ response = self.model.generate_content(prompt)
236
+ return response.text
237
+ except Exception as e:
238
+ return f"Error al generar explicación: {str(e)}"
239
+
240
+ def initialize_gemini_explainer():
241
+ """
242
+ Función de utilidad para inicializar el explicador de Gemini en Streamlit
243
+
244
+ Returns:
245
+ GeminiExplainer: Instancia del explicador de Gemini o None si hay error
246
+ """
247
+ try:
248
+ if 'gemini_api_key' not in st.session_state:
249
+ st.warning("Por favor configura tu API key de Gemini primero")
250
+ return None
251
+
252
+ api_key = st.session_state.get('gemini_api_key')
253
+ if not api_key:
254
+ st.warning("API key de Gemini no encontrada")
255
+ return None
256
+
257
+ # Inicializar explicador con la API key
258
+ explainer = GeminiExplainer(api_key=api_key)
259
+ return explainer
260
+
261
+ except Exception as e:
262
+ st.error(f"Error al inicializar el explicador: {str(e)}")
263
+ return None
264
+
265
+ # Ejemplo de uso en Streamlit
266
+ def main():
267
+ st.title("Explicaciones con Gemini")
268
+
269
+ # Verificar configuración de API key
270
+ if 'gemini_api_key' not in st.session_state:
271
+ st.warning("Configura tu API key de Gemini")
272
+ return
273
+
274
+ explainer = initialize_gemini_explainer()
275
+
276
+ if explainer:
277
+ # Ejemplo de uso de métodos de explicación
278
+ st.subheader("Explicación de Dataset")
279
+ dataset_info = {
280
+ 'rows': 100,
281
+ 'columns': 5,
282
+ 'column_names': ['age', 'income', 'education', 'credit_score', 'loan_approved'],
283
+ 'data_types': 'Mixed (numeric and categorical)',
284
+ 'first_rows': 'Sample data preview',
285
+ 'basic_stats': 'Mean, median, standard deviation'
286
+ }
287
+
288
+ if st.button("Explicar Dataset"):
289
+ explanation = explainer.generate_dataset_explanation(dataset_info)
290
+ st.markdown(explanation)
291
+
292
+ st.subheader("Explicación de Modelo")
293
+ model_info = {
294
+ 'name': 'Random Forest Classifier',
295
+ 'problem_type': 'Clasificación binaria',
296
+ 'hyperparameters': {
297
+ 'n_estimators': 100,
298
+ 'max_depth': 5,
299
+ 'learning_rate': 0.1
300
+ },
301
+ 'performance_metric': 0.85,
302
+ 'additional_metrics': {
303
+ 'precision': 0.82,
304
+ 'recall': 0.88,
305
+ 'f1_score': 0.85
306
+ }
307
+ }
308
+
309
+ if st.button("Explicar Modelo"):
310
+ explanation = explainer.generate_model_explanation(model_info)
311
+ st.markdown(explanation)
312
+
313
+ st.subheader("Explicación de Clustering")
314
+ clustering_info = {
315
+ 'method': 'K-Means',
316
+ 'n_clusters': 3,
317
+ 'parameters': {
318
+ 'eps': 0.5,
319
+ 'min_samples': 5
320
+ },
321
+ 'silhouette_score': 0.7,
322
+ 'calinski_score': 150.5,
323
+ 'davies_bouldin': 0.4
324
+ }
325
+
326
+ if st.button("Explicar Clustering"):
327
+ explanation = explainer.generate_clustering_explanation(clustering_info)
328
+ st.markdown(explanation)
329
+
330
+ st.subheader("Explicación de Importancia de Características")
331
+ feature_importance_info = {
332
+ 'method': 'SHAP Values',
333
+ 'features': ['age', 'income', 'education', 'credit_score'],
334
+ 'importance_values': {
335
+ 'age': 0.35,
336
+ 'income': 0.25,
337
+ 'education': 0.2,
338
+ 'credit_score': 0.2
339
+ }
340
+ }
341
+
342
+ if st.button("Explicar Importancia de Características"):
343
+ explanation = explainer.generate_feature_importance_explanation(feature_importance_info)
344
+ st.markdown(explanation)
345
+
346
+ # Función para manejar errores de API key
347
+ def validate_gemini_api_key(api_key: str) -> bool:
348
+ """
349
+ Validar la API key de Gemini
350
+
351
+ Args:
352
+ api_key (str): API key a validar
353
+
354
+ Returns:
355
+ bool: True si la API key es válida, False en caso contrario
356
+ """
357
+ try:
358
+ genai.configure(api_key=api_key)
359
+ model = genai.GenerativeModel('gemini-1.5-flash')
360
+ # Intentar generar una respuesta simple
361
+ response = model.generate_content("Hola, ¿estás funcionando?")
362
+ return True
363
+ except Exception as e:
364
+ st.error(f"Error de validación de API key: {str(e)}")
365
+ return False
366
+
367
+ # Función de configuración de API key en Streamlit
368
+ def setup_gemini_api_key():
369
+ """
370
+ Configurar y validar la API key de Gemini en Streamlit
371
+ """
372
+ st.sidebar.header("🔑 Configuración de Gemini API")
373
+
374
+ # Input para la API key
375
+ api_key = st.sidebar.text_input(
376
+ "Ingresa tu Gemini API Key",
377
+ type="password",
378
+ help="Puedes obtener tu API key en Google AI Studio"
379
+ )
380
+
381
+ # Botón de validación
382
+ if st.sidebar.button("Validar API Key"):
383
+ if api_key:
384
+ if validate_gemini_api_key(api_key):
385
+ st.session_state.gemini_api_key = api_key
386
+ st.sidebar.success("✅ API Key validada correctamente")
387
+ else:
388
+ st.sidebar.error("❌ API Key inválida")
389
+ else:
390
+ st.sidebar.warning("Por favor, ingresa una API Key")
391
+
392
+ # Mostrar estado actual
393
+ if 'gemini_api_key' in st.session_state:
394
+ st.sidebar.info("API Key configurada")
395
+
396
+ # Configuraciones adicionales y documentación
397
+ def get_gemini_documentation():
398
+ """
399
+ Generar documentación sobre el uso de Gemini en el proyecto
400
+
401
+ Returns:
402
+ str: Documentación en formato markdown
403
+ """
404
+ documentation = """
405
+ ## 🤖 Explicaciones con Gemini AI
406
+
407
+ ### Características
408
+ - Generación de explicaciones detalladas para:
409
+ * Datasets
410
+ * Modelos de Machine Learning
411
+ * Resultados de Clustering
412
+ * Importancia de Características
413
+
414
+ ### Requisitos
415
+ - API Key de Google AI Studio
416
+ - Conexión a internet
417
+ - Biblioteca `google-generativeai`
418
+
419
+ ### Configuración
420
+ 1. Obtén tu API Key en [Google AI Studio](https://makersuite.google.com/app/apikey)
421
+ 2. Configura la API Key en la barra lateral
422
+ 3. Valida la conexión con el botón "Validar API Key"
423
+
424
+ ### Limitaciones
425
+ - Depende de la disponibilidad del servicio
426
+ - Consumo de tokens de API
427
+ - Explicaciones generadas por IA pueden no ser 100% precisas
428
+
429
+ ### Mejores Prácticas
430
+ - Usar como complemento, no como única fuente de verdad
431
+ - Verificar siempre las explicaciones generadas
432
+ - Tener contexto del problema al interpretar resultados
433
+ """
434
+ return documentation
435
+
436
+ # Punto de entrada principal
437
+ if __name__ == "__main__":
438
+ main()
utils/model_utils.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/models_utils.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.express as px
6
+ import time
7
+ import pickle
8
+ import io
9
+ from stqdm import stqdm
10
+ from sklearn.model_selection import GridSearchCV, train_test_split
11
+ from sklearn.pipeline import Pipeline
12
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
13
+ from sklearn.linear_model import (
14
+ LinearRegression, LogisticRegression, Lasso, Ridge,
15
+ SGDClassifier, RidgeClassifier, PassiveAggressiveClassifier
16
+ )
17
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
18
+ from sklearn.ensemble import (
19
+ RandomForestRegressor, RandomForestClassifier,
20
+ GradientBoostingClassifier, AdaBoostClassifier,
21
+ BaggingClassifier, ExtraTreesClassifier, ExtraTreesRegressor
22
+ )
23
+ from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB
24
+ from sklearn.neighbors import KNeighborsClassifier
25
+ from sklearn.svm import SVC, SVR
26
+ from sklearn.metrics import (
27
+ mean_squared_error, r2_score, mean_absolute_error,
28
+ accuracy_score, classification_report, confusion_matrix
29
+ )
30
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
31
+ import xgboost as xgb
32
+ import h2o
33
+ import os
34
+
35
+ class ModelTrainer:
36
+ """
37
+ Clase para gestionar el entrenamiento de modelos de machine learning
38
+ """
39
+ @staticmethod
40
+ def get_model_options(problem_type):
41
+ """
42
+ Obtener opciones de modelos según el tipo de problema
43
+
44
+ Args:
45
+ problem_type (str): Tipo de problema ('classification' o 'regression')
46
+
47
+ Returns:
48
+ dict: Diccionario de opciones de modelos
49
+ """
50
+ if problem_type == 'regression':
51
+ return ModelTrainer._get_regression_models()
52
+ else:
53
+ return ModelTrainer._get_classification_models()
54
+
55
+ @staticmethod
56
+ def _get_regression_models():
57
+ """
58
+ Definir opciones de modelos para regresión
59
+
60
+ Returns:
61
+ dict: Modelos de regresión con sus parámetros
62
+ """
63
+ return {
64
+ 'Regresión Lineal': {
65
+ 'model': lambda rs: Pipeline([
66
+ ('scaler', StandardScaler()),
67
+ ('regressor', LinearRegression())
68
+ ]),
69
+ 'params': {
70
+ 'regressor__fit_intercept': [True, False],
71
+ 'regressor__copy_X': [True],
72
+ 'regressor__positive': [True, False],
73
+ 'scaler__with_mean': [True, False],
74
+ 'scaler__with_std': [True, False]
75
+ }
76
+ },
77
+ 'Lasso': {
78
+ 'model': lambda rs: Pipeline([
79
+ ('scaler', StandardScaler()),
80
+ ('regressor', Lasso(random_state=rs))
81
+ ]),
82
+ 'params': {
83
+ 'regressor__alpha': [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0],
84
+ 'regressor__fit_intercept': [True, False],
85
+ 'regressor__max_iter': [1000, 2000, 5000],
86
+ 'regressor__selection': ['cyclic', 'random'],
87
+ 'regressor__tol': [1e-4, 1e-3],
88
+ 'scaler__with_mean': [True, False],
89
+ 'scaler__with_std': [True, False]
90
+ }
91
+ },
92
+ 'Ridge': {
93
+ 'model': lambda rs: Pipeline([
94
+ ('scaler', StandardScaler()),
95
+ ('regressor', Ridge(random_state=rs))
96
+ ]),
97
+ 'params': {
98
+ 'regressor__alpha': [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0],
99
+ 'regressor__fit_intercept': [True, False],
100
+ 'regressor__solver': ['auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga'],
101
+ 'regressor__tol': [1e-4, 1e-3],
102
+ 'scaler__with_mean': [True, False],
103
+ 'scaler__with_std': [True, False]
104
+ }
105
+ },
106
+ 'Árbol de Decisión': {
107
+ 'model': lambda rs: DecisionTreeRegressor(random_state=rs),
108
+ 'params': {
109
+ 'max_depth': [3, 5, 7, 10, 15, None],
110
+ 'min_samples_split': [2, 5, 10, 20],
111
+ 'min_samples_leaf': [1, 2, 4, 8],
112
+ 'criterion': ['squared_error', 'friedman_mse', 'absolute_error', 'poisson'],
113
+ 'splitter': ['best', 'random'],
114
+ 'max_features': ['sqrt', 'log2', None]
115
+ }
116
+ },
117
+ 'Random Forest': {
118
+ 'model': lambda rs: RandomForestRegressor(random_state=rs),
119
+ 'params': {
120
+ 'n_estimators': [100, 200, 300, 500],
121
+ 'max_depth': [3, 5, 7, 10, None],
122
+ 'min_samples_split': [2, 5, 10, 20],
123
+ 'min_samples_leaf': [1, 2, 4],
124
+ 'max_features': ['sqrt', 'log2', None],
125
+ 'bootstrap': [True, False],
126
+ 'criterion': ['squared_error', 'absolute_error', 'poisson']
127
+ }
128
+ },
129
+ 'XGBoost': {
130
+ 'model': lambda rs: xgb.XGBRegressor(
131
+ tree_method='hist',
132
+ device='cuda',
133
+ enable_categorical=True,
134
+ random_state=rs
135
+ ),
136
+ 'params': {
137
+ 'n_estimators': [100, 200, 300, 500],
138
+ 'max_depth': [3, 5, 7, 9],
139
+ 'learning_rate': [0.01, 0.05, 0.1, 0.3],
140
+ 'subsample': [0.8, 0.9, 1.0],
141
+ 'colsample_bytree': [0.8, 0.9, 1.0],
142
+ 'min_child_weight': [1, 3, 5],
143
+ 'gamma': [0, 0.1, 0.2],
144
+ 'reg_alpha': [0, 0.1, 0.5],
145
+ 'reg_lambda': [0.1, 1.0, 5.0]
146
+ }
147
+ }
148
+ }
149
+
150
+ @staticmethod
151
+ def _get_classification_models():
152
+ """
153
+ Definir opciones de modelos para clasificación
154
+
155
+ Returns:
156
+ dict: Modelos de clasificación con sus parámetros
157
+ """
158
+ return {
159
+ 'Regresión Logística': {
160
+ 'model': lambda rs: LogisticRegression(max_iter=1000, random_state=rs),
161
+ 'params': {
162
+ 'C': [0.001, 0.01, 0.1, 1.0, 10.0],
163
+ 'penalty': ['l1', 'l2'],
164
+ 'solver': ['liblinear', 'saga'],
165
+ 'class_weight': [None, 'balanced'],
166
+ 'warm_start': [True, False],
167
+ 'tol': [1e-4, 1e-3, 1e-2]
168
+ }
169
+ },
170
+ 'Random Forest': {
171
+ 'model': lambda rs: RandomForestClassifier(random_state=rs),
172
+ 'params': {
173
+ 'n_estimators': [100, 200, 300, 500],
174
+ 'max_depth': [3, 5, 7, 10, None],
175
+ 'min_samples_split': [2, 5, 10],
176
+ 'min_samples_leaf': [1, 2, 4],
177
+ 'class_weight': [None, 'balanced', 'balanced_subsample'],
178
+ 'criterion': ['gini', 'entropy'],
179
+ 'max_features': ['sqrt', 'log2', None]
180
+ }
181
+ },
182
+ 'XGBoost': {
183
+ 'model': lambda rs: xgb.XGBClassifier(
184
+ tree_method='hist',
185
+ device='cuda',
186
+ enable_categorical=True,
187
+ random_state=rs
188
+ ),
189
+ 'params': {
190
+ 'n_estimators': [100, 200, 300, 500],
191
+ 'max_depth': [3, 5, 7, 9],
192
+ 'learning_rate': [0.01, 0.05, 0.1, 0.3],
193
+ 'subsample': [0.8, 0.9, 1.0],
194
+ 'colsample_bytree': [0.8, 0.9, 1.0],
195
+ 'min_child_weight': [1, 3, 5],
196
+ 'gamma': [0, 0.1, 0.2],
197
+ 'reg_alpha': [0, 0.1, 0.5],
198
+ 'reg_lambda': [0.1, 1.0, 5.0],
199
+ 'scale_pos_weight': [1, 2, 3]
200
+ }
201
+ },
202
+ 'SVM': {
203
+ 'model': lambda rs: SVC(random_state=rs),
204
+ 'params': {
205
+ 'C': [0.1, 1, 10, 100],
206
+ 'kernel': ['linear', 'rbf', 'poly', 'sigmoid'],
207
+ 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001],
208
+ 'class_weight': [None, 'balanced'],
209
+ 'probability': [True]
210
+ }
211
+ },
212
+ 'Naive Bayes': {
213
+ 'model': lambda rs: GaussianNB(),
214
+ 'params': {
215
+ 'var_smoothing': [1e-9, 1e-8, 1e-7, 1e-6]
216
+ }
217
+ }
218
+ }
219
+
220
+ @staticmethod
221
+ def _determine_problem_type(model):
222
+ """
223
+ Determinar el tipo de problema basado en el modelo
224
+
225
+ Args:
226
+ model (BaseEstimator): Modelo a evaluar
227
+
228
+ Returns:
229
+ str: Tipo de problema ('classification', 'regression', 'unknown')
230
+ """
231
+ try:
232
+ if hasattr(model, 'predict_proba'):
233
+ return 'classification'
234
+ elif hasattr(model, 'predict'):
235
+ return 'regression'
236
+ else:
237
+ return 'unknown'
238
+ except ImportError:
239
+ return 'unknown'
240
+
241
+ @staticmethod
242
+ def _get_default_scoring(problem_type):
243
+ """
244
+ Obtener la métrica de scoring predeterminada
245
+
246
+ Args:
247
+ problem_type (str): Tipo de problema
248
+
249
+ Returns:
250
+ str: Métrica de scoring predeterminada
251
+ """
252
+ scoring_map = {
253
+ 'classification': 'accuracy',
254
+ 'regression': 'r2',
255
+ 'unknown': None
256
+ }
257
+ return scoring_map.get(problem_type, None)
258
+
259
+ @staticmethod
260
+ def train_model_pipeline(
261
+ X_train,
262
+ y_train,
263
+ model_config,
264
+ X_test=None,
265
+ y_test=None,
266
+ cv=5,
267
+ scoring=None,
268
+ random_state=42,
269
+ **kwargs
270
+ ):
271
+ """
272
+ Entrenar modelo con validación cruzada y evaluación flexible
273
+
274
+ Args:
275
+ X_train (array-like): Datos de entrenamiento
276
+ y_train (array-like): Etiquetas de entrenamiento
277
+ model_config (dict): Configuración del modelo
278
+ X_test (array-like, optional): Datos de prueba
279
+ y_test (array-like, optional): Etiquetas de prueba
280
+ cv (int, optional): Número de pliegues para validación cruzada
281
+ scoring (str, optional): Métrica de puntuación
282
+ random_state (int, optional): Semilla aleatoria para reproducibilidad
283
+ **kwargs: Argumentos adicionales
284
+
285
+ Returns:
286
+ dict: Resultados detallados del entrenamiento
287
+ """
288
+ # Extraer modelo y parámetros
289
+ model_func = model_config.get('model')
290
+ params = model_config.get('params', {})
291
+
292
+ # Instanciar el modelo si es una función
293
+ if callable(model_func):
294
+ model = model_func(random_state)
295
+ else:
296
+ model = model_func
297
+
298
+ # Verificar que el modelo sea una instancia válida
299
+ if not hasattr(model, 'fit') or not hasattr(model, 'predict'):
300
+ raise ValueError(f"Modelo inválido: {model}. Debe tener métodos 'fit' y 'predict'.")
301
+
302
+ # Determinar tipo de problema
303
+ problem_type = ModelTrainer._determine_problem_type(model)
304
+
305
+ # Configurar scoring
306
+ if scoring is None:
307
+ scoring = ModelTrainer._get_default_scoring(problem_type)
308
+
309
+ # Configurar parámetros de GridSearchCV
310
+ grid_search_params = {
311
+ 'estimator': model,
312
+ 'param_grid': params,
313
+ 'cv': cv,
314
+ 'scoring': scoring
315
+ }
316
+
317
+ # Añadir kwargs adicionales
318
+ grid_search_params.update({
319
+ k: v for k, v in kwargs.items()
320
+ if k in ['n_jobs', 'verbose', 'refit', 'error_score']
321
+ })
322
+
323
+ try:
324
+ # Realizar búsqueda de hiperparámetros
325
+ grid_search = GridSearchCV(**grid_search_params)
326
+ with st.spinner(f"Entrenando modelo {model}..."):
327
+ start_time = time.time()
328
+ grid_search.fit(X_train, y_train)
329
+ training_time = time.time() - start_time
330
+
331
+ except Exception as e:
332
+ return {
333
+ 'error': f"Error durante el entrenamiento: {str(e)}",
334
+ 'problem_type': problem_type
335
+ }
336
+
337
+ # Preparar resultados base
338
+ results = {
339
+ 'problem_type': problem_type,
340
+ 'best_model': grid_search.best_estimator_,
341
+ 'best_params': grid_search.best_params_,
342
+ 'best_score': grid_search.best_score_,
343
+ 'cv_results': grid_search.cv_results_,
344
+ 'training_time': training_time
345
+ }
346
+
347
+ # Evaluación en conjunto de prueba
348
+ if X_test is not None and y_test is not None:
349
+ best_model = grid_search.best_estimator_
350
+ y_pred = best_model.predict(X_test)
351
+
352
+ # Métricas específicas según el tipo de problema
353
+ if problem_type == 'classification':
354
+ results.update({
355
+ 'test_accuracy': accuracy_score(y_test, y_pred),
356
+ 'classification_report': classification_report(y_test, y_pred, output_dict=True),
357
+ 'confusion_matrix': confusion_matrix(y_test, y_pred).tolist(),
358
+ 'y_pred': y_pred
359
+ })
360
+ elif problem_type == 'regression':
361
+ results.update({
362
+ 'test_mse': mean_squared_error(y_test, y_pred),
363
+ 'test_rmse': np.sqrt(mean_squared_error(y_test, y_pred)),
364
+ 'test_mae': mean_absolute_error(y_test, y_pred),
365
+ 'test_r2': r2_score(y_test, y_pred),
366
+ 'y_pred': y_pred
367
+ })
368
+ else:
369
+ results['test_predictions'] = y_pred
370
+
371
+ return results
372
+
373
+ @staticmethod
374
+ def create_class_distribution_plot(y_original):
375
+ """
376
+ Crear un gráfico de distribución de clases
377
+
378
+ Args:
379
+ y_original (pd.Series): Variable objetivo original
380
+
381
+ Returns:
382
+ plotly.graph_objs._figure.Figure: Gráfico de distribución de clases
383
+ """
384
+ class_dist = pd.DataFrame({
385
+ 'Clase': y_original.value_counts().index,
386
+ 'Cantidad': y_original.value_counts().values
387
+ })
388
+
389
+ fig = px.bar(
390
+ class_dist,
391
+ x='Clase',
392
+ y='Cantidad',
393
+ title='Distribución de clases'
394
+ )
395
+
396
+ return fig
397
+
398
+ @staticmethod
399
+ def process_classification_data(y, random_state):
400
+ """
401
+ Procesar datos de clasificación
402
+
403
+ Args:
404
+ y (pd.Series): Variable objetivo
405
+ random_state (int): Semilla aleatoria
406
+
407
+ Returns:
408
+ tuple: Variable objetivo procesada y codificador de etiquetas
409
+ """
410
+ # Codificación de etiquetas
411
+ le = LabelEncoder()
412
+ y_encoded = pd.Series(le.fit_transform(y))
413
+
414
+ return y_encoded, le
415
+
416
+ @staticmethod
417
+ def save_model(model, filename):
418
+ """
419
+ Guardar modelo entrenado en un archivo
420
+
421
+ Args:
422
+ model: Modelo entrenado
423
+ filename (str): Nombre del archivo
424
+ """
425
+ if isinstance(model, h2o.estimators.H2OEstimator):
426
+ # Usar método nativo de H2O para guardar modelos
427
+ h2o.save_model(model=model, path=os.path.dirname(filename), force=True)
428
+ else:
429
+ with open(filename, 'wb') as f:
430
+ pickle.dump(model, f)
431
+
432
+ @staticmethod
433
+ def load_model(filename):
434
+ """
435
+ Cargar modelo desde un archivo
436
+
437
+ Args:
438
+ filename (str): Nombre del archivo
439
+
440
+ Returns:
441
+ Modelo cargado
442
+ """
443
+ if filename.endswith('.zip'):
444
+ # Asumir que es un modelo H2O
445
+ return h2o.load_model(filename)
446
+ else:
447
+ with open(filename, 'rb') as f:
448
+ return pickle.load(f)
449
+
450
+ @staticmethod
451
+ def get_model_performance_metrics(y_true, y_pred, problem_type):
452
+ """
453
+ Obtener métricas de rendimiento del modelo
454
+
455
+ Args:
456
+ y_true (pd.Series): Etiquetas verdaderas
457
+ y_pred (pd.Series): Etiquetas predichas
458
+ problem_type (str): Tipo de problema
459
+
460
+ Returns:
461
+ dict: Métricas de rendimiento
462
+ """
463
+ if problem_type == 'classification':
464
+ return {
465
+ 'accuracy': accuracy_score(y_true, y_pred),
466
+ 'classification_report': classification_report(y_true, y_pred, output_dict=True)
467
+ }
468
+ else: # Regresión
469
+ return {
470
+ 'mse': mean_squared_error(y_true, y_pred),
471
+ 'r2_score': r2_score(y_true, y_pred)
472
+ }
473
+
474
+ @staticmethod
475
+ def split_data(X, y, test_size=0.2, random_state=42):
476
+ """
477
+ Dividir datos en conjuntos de entrenamiento y prueba
478
+
479
+ Args:
480
+ X (pd.DataFrame): Features
481
+ y (pd.Series): Variable objetivo
482
+ test_size (float): Proporción de datos de prueba
483
+ random_state (int): Semilla aleatoria
484
+
485
+ Returns:
486
+ tuple: X_train, X_test, y_train, y_test
487
+ """
488
+ return train_test_split(X, y, test_size=test_size, random_state=random_state)
489
+
490
+ @staticmethod
491
+ def prepare_data_for_ml(df, target_column, problem_type='classification', test_size=0.2, random_state=42):
492
+ """
493
+ Preparar datos para machine learning
494
+
495
+ Args:
496
+ df (pd.DataFrame): DataFrame de datos
497
+ target_column (str): Columna objetivo
498
+ problem_type (str): Tipo de problema
499
+ test_size (float): Proporción de datos de prueba
500
+ random_state (int): Semilla aleatoria
501
+
502
+ Returns:
503
+ dict: Diccionario con datos preparados
504
+ """
505
+ # Separar features y target
506
+ X = df.drop(columns=[target_column])
507
+ y = df[target_column]
508
+
509
+ # Preprocesar datos según el tipo de problema
510
+ if problem_type == 'classification':
511
+ y, label_encoder = ModelTrainer.process_classification_data(y, random_state)
512
+ else:
513
+ label_encoder = None
514
+
515
+ # Dividir datos
516
+ X_train, X_test, y_train, y_test = ModelTrainer.split_data(X, y, test_size, random_state)
517
+
518
+ return {
519
+ 'X_train': X_train,
520
+ 'X_test': X_test,
521
+ 'y_train': y_train,
522
+ 'y_test': y_test,
523
+ 'label_encoder': label_encoder,
524
+ 'features': list(X.columns),
525
+ 'problem_type': problem_type
526
+ }
527
+
528
+ @staticmethod
529
+ def generate_model_comparison_report(trained_models, problem_type):
530
+ """
531
+ Generar informe comparativo de modelos
532
+
533
+ Args:
534
+ trained_models (dict): Modelos entrenados
535
+ problem_type (str): Tipo de problema
536
+
537
+ Returns:
538
+ pd.DataFrame: Informe comparativo de modelos
539
+ """
540
+ comparison_data = []
541
+
542
+ for model_name, model_info in trained_models.items():
543
+ model_metrics = ModelTrainer.get_model_performance_metrics(
544
+ model_info['y_test'],
545
+ model_info['y_pred'],
546
+ problem_type
547
+ )
548
+
549
+ model_entry = {
550
+ 'Modelo': model_name,
551
+ 'Tiempo de Entrenamiento': model_info.get('training_time', 0),
552
+ }
553
+
554
+ # Agregar métricas según el tipo de problema
555
+ if problem_type == 'classification':
556
+ model_entry.update({
557
+ 'Precisión': model_metrics['accuracy'],
558
+ 'Precisión (Macro)': model_metrics['classification_report']['macro avg']['precision'],
559
+ 'Recall (Macro)': model_metrics['classification_report']['macro avg']['recall'],
560
+ 'F1-Score (Macro)': model_metrics['classification_report']['macro avg']['f1-score']
561
+ })
562
+ else:
563
+ model_entry.update({
564
+ 'MSE': model_metrics['mse'],
565
+ 'R2 Score': model_metrics['r2_score']
566
+ })
567
+
568
+ comparison_data.append(model_entry)
569
+
570
+ return pd.DataFrame(comparison_data)
571
+
572
+ @staticmethod
573
+ def plot_model_comparison(comparison_df, problem_type):
574
+ """
575
+ Crear gráfico comparativo de modelos
576
+
577
+ Args:
578
+ comparison_df (pd.DataFrame): DataFrame de comparación de modelos
579
+ problem_type (str): Tipo de problema
580
+
581
+ Returns:
582
+ plotly.graph_objs._figure.Figure: Gráfico comparativo
583
+ """
584
+ metric_column = 'Precisión' if problem_type == 'classification' else 'R2 Score'
585
+
586
+ fig = px.bar(
587
+ comparison_df,
588
+ x='Modelo',
589
+ y=metric_column,
590
+ title=f'Comparación de Modelos - {metric_column}'
591
+ )
592
+
593
+ return fig
594
+
595
+ # Funciones sueltas para importación directa
596
+ def get_model_options(problem_type):
597
+ return ModelTrainer.get_model_options(problem_type)
598
+
599
+ def train_model_pipeline(*args, **kwargs):
600
+ return ModelTrainer.train_model_pipeline(*args, **kwargs)
601
+
602
+ def process_classification_data(y, random_state=42):
603
+ return ModelTrainer.process_classification_data(y, random_state)
604
+
605
+ def create_class_distribution_plot(y):
606
+ return ModelTrainer.create_class_distribution_plot(y)
utils/shap_explainer.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/shap_explainer.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import shap
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ from typing import Dict, Any, Union, Optional
9
+ from sklearn.pipeline import Pipeline
10
+ import h2o
11
+ import streamlit as st
12
+ import pandas as pd
13
+ import numpy as np
14
+ import shap
15
+ import plotly.express as px
16
+ import plotly.graph_objects as go
17
+ from typing import Dict, Any, Optional, Union
18
+ from sklearn.pipeline import Pipeline
19
+
20
+ class SHAPExplainer:
21
+ """
22
+ Clase para realizar explicaciones de modelos usando SHAP (SHapley Additive exPlanations)
23
+ """
24
+ def __init__(self, model, X: pd.DataFrame, problem_type: str = 'classification', explanation_method: str = 'auto'):
25
+ """
26
+ Inicializar el explicador SHAP
27
+
28
+ Args:
29
+ model: Modelo de machine learning entrenado
30
+ X (pd.DataFrame): Datos de entrada para el modelo
31
+ problem_type (str): Tipo de problema ('classification' o 'regression')
32
+ explanation_method (str): Método de explicación ('auto', 'tree', 'linear', 'kernel')
33
+ """
34
+ self.model = model
35
+ self.X = X
36
+ self.problem_type = problem_type
37
+ self.explanation_method = explanation_method
38
+ self.explainer = self._create_explainer()
39
+ self.X_sample = None # Inicializar X_sample
40
+
41
+ def _create_explainer(self):
42
+ """
43
+ Crear el explicador SHAP apropiado según el tipo de modelo y método seleccionado,
44
+ manejando correctamente los Pipelines y modelos de H2O.
45
+
46
+ Returns:
47
+ Explainer de SHAP
48
+ """
49
+ try:
50
+ # Si el modelo es un Pipeline, extraer el estimador final
51
+ if isinstance(self.model, Pipeline):
52
+ estimator = self.model.steps[-1][1]
53
+ else:
54
+ estimator = self.model
55
+
56
+ # Verificar si el modelo es de H2O
57
+ if isinstance(estimator, h2o.estimators.H2OEstimator):
58
+ # Usar KernelExplainer para modelos de H2O
59
+ # Obtener función de predicción compatible con SHAP
60
+ def predict_function(x):
61
+ h2o_frame = h2o.H2OFrame(x)
62
+ preds = estimator.predict(h2o_frame)
63
+ return preds.as_data_frame()['predict'].values
64
+
65
+ return shap.KernelExplainer(predict_function, shap.sample(self.X, 100))
66
+
67
+ # Crear el explicador usando el método seleccionado
68
+ if self.explanation_method.lower() == 'tree':
69
+ return shap.TreeExplainer(estimator)
70
+ elif self.explanation_method.lower() == 'linear':
71
+ return shap.LinearExplainer(estimator, self.X, feature_dependence="independent")
72
+ elif self.explanation_method.lower() == 'kernel':
73
+ return shap.KernelExplainer(estimator.predict, shap.sample(self.X, 100))
74
+ else:
75
+ # 'auto' o cualquier otro valor: usar shap.Explainer que selecciona automáticamente
76
+ return shap.Explainer(estimator, self.X)
77
+
78
+ except Exception as e:
79
+ st.error(f"Error al crear explicador SHAP: {str(e)}")
80
+ return None
81
+
82
+ def compute_shap_values(self, X_sample: Optional[pd.DataFrame] = None, max_samples: int = 100):
83
+ """
84
+ Calcular valores SHAP
85
+
86
+ Args:
87
+ X_sample (pd.DataFrame, opcional): Muestra de datos para calcular SHAP
88
+ max_samples (int): Número máximo de muestras a procesar
89
+
90
+ Returns:
91
+ Valores SHAP
92
+ """
93
+ try:
94
+ # Usar muestra si no se proporciona
95
+ if X_sample is None:
96
+ X_sample = self.X.sample(n=min(max_samples, len(self.X)), random_state=42)
97
+
98
+ # Almacenar el subconjunto de datos utilizado
99
+ self.X_sample = X_sample
100
+
101
+ # Asegurarse de que X_sample es 2D
102
+ if X_sample.ndim != 2:
103
+ raise ValueError(f"Debe pasar una entrada 2D a SHAP. La forma actual es {X_sample.shape}")
104
+
105
+ # Calcular valores SHAP usando el explicador
106
+ shap_values = self.explainer.shap_values(X_sample)
107
+
108
+ # Para clasificación multiclase, SHAP devuelve una lista de arrays
109
+ if isinstance(shap_values, list):
110
+ # Promediar las contribuciones de todas las clases
111
+ shap_values = np.mean(np.abs(shap_values), axis=0)
112
+
113
+ else:
114
+ shap_values = np.abs(shap_values)
115
+
116
+ return shap_values
117
+
118
+ except Exception as e:
119
+ st.error(f"Error al calcular valores SHAP: {str(e)}")
120
+ return None
121
+
122
+ def plot_summary(self, shap_values, title: str = "SHAP Summary Plot"):
123
+ """
124
+ Generar gráfico de resumen de valores SHAP
125
+
126
+ Args:
127
+ shap_values: Valores SHAP calculados
128
+ title (str): Título del gráfico
129
+
130
+ Returns:
131
+ Figura de Plotly
132
+ """
133
+ try:
134
+ feature_names = self.X.columns.tolist()
135
+
136
+ # Calcular importancia de características
137
+ feature_importance = np.mean(shap_values, axis=0)
138
+ importance_df = pd.DataFrame({
139
+ 'feature': feature_names,
140
+ 'importance': feature_importance
141
+ }).sort_values('importance', ascending=False)
142
+
143
+ # Gráfico de barras de importancia
144
+ fig = px.bar(
145
+ importance_df,
146
+ x='importance',
147
+ y='feature',
148
+ orientation='h',
149
+ title=title,
150
+ labels={'importance': 'Importancia SHAP', 'feature': 'Características'}
151
+ )
152
+
153
+ return fig
154
+
155
+ except Exception as e:
156
+ st.error(f"Error al generar gráfico de resumen: {str(e)}")
157
+ return None
158
+
159
+ def plot_dependence(self, shap_values, feature_name: str):
160
+ """
161
+ Generar gráfico de dependencia para una característica
162
+
163
+ Args:
164
+ shap_values: Valores SHAP calculados
165
+ feature_name (str): Nombre de la característica
166
+
167
+ Returns:
168
+ Figura de Plotly
169
+ """
170
+ try:
171
+ feature_idx = self.X.columns.get_loc(feature_name)
172
+
173
+ # Preparar datos usando el mismo subconjunto de datos utilizado para SHAP
174
+ if self.X_sample is not None:
175
+ x = self.X_sample.iloc[:, feature_idx]
176
+ else:
177
+ x = self.X.iloc[:, feature_idx]
178
+
179
+ y = shap_values[:, feature_idx]
180
+
181
+ # Verificar que las longitudes coincidan
182
+ if len(x) != len(y):
183
+ raise ValueError(f"Longitud de 'x' ({len(x)}) y 'y' ({len(y)}) no coinciden.")
184
+
185
+ # Crear scatter plot
186
+ fig = px.scatter(
187
+ x=x,
188
+ y=y,
189
+ title=f'SHAP Dependence Plot - {feature_name}',
190
+ labels={'x': feature_name, 'y': 'SHAP Value'}
191
+ )
192
+
193
+ return fig
194
+
195
+ except Exception as e:
196
+ st.error(f"Error al generar gráfico de dependencia: {str(e)}")
197
+ return None
198
+
199
+ def generate_feature_importance_report(self, shap_values) -> Dict[str, Any]:
200
+ """
201
+ Generar un informe detallado de importancia de características
202
+
203
+ Args:
204
+ shap_values: Valores SHAP calculados
205
+
206
+ Returns:
207
+ Diccionario con información de importancia de características
208
+ """
209
+ try:
210
+ # Calcular importancia
211
+ feature_importance = np.mean(shap_values, axis=0)
212
+
213
+ # Crear DataFrame de importancia
214
+ importance_df = pd.DataFrame({
215
+ 'feature': self.X.columns,
216
+ 'importance': feature_importance
217
+ }).sort_values('importance', ascending=False)
218
+
219
+ # Generar informe
220
+ report = {
221
+ 'top_features': importance_df.head(5).to_dict('records'),
222
+ 'bottom_features': importance_df.tail(5).to_dict('records'),
223
+ 'total_features': len(importance_df),
224
+ 'max_importance': importance_df['importance'].max(),
225
+ 'min_importance': importance_df['importance'].min()
226
+ }
227
+
228
+ return report
229
+
230
+ except Exception as e:
231
+ st.error(f"Error al generar informe de importancia: {str(e)}")
232
+ return {}
233
+
234
+ def create_shap_analysis_dashboard(model, X: pd.DataFrame, problem_type: str = 'classification'):
235
+ """
236
+ Crear un dashboard de análisis SHAP en Streamlit
237
+
238
+ Args:
239
+ model: Modelo de machine learning
240
+ X (pd.DataFrame): Datos de entrada
241
+ problem_type (str): Tipo de problema
242
+ """
243
+ st.title("🔍 Análisis de Explicabilidad SHAP")
244
+
245
+ # Inicializar los valores SHAP en session_state si no existen
246
+ if 'shap_explainer' not in st.session_state:
247
+ # Parámetros por defecto
248
+ explanation_method = 'auto'
249
+ max_samples = 100
250
+
251
+ # Crear y almacenar el explicador SHAP
252
+ st.session_state.shap_explainer = SHAPExplainer(
253
+ model=model,
254
+ X=X,
255
+ problem_type=problem_type,
256
+ explanation_method=explanation_method
257
+ )
258
+
259
+ # Calcular y almacenar los valores SHAP
260
+ st.session_state.shap_values = st.session_state.shap_explainer.compute_shap_values(
261
+ max_samples=max_samples
262
+ )
263
+
264
+ shap_explainer = st.session_state.shap_explainer
265
+ shap_values = st.session_state.shap_values
266
+
267
+ if shap_values is None:
268
+ st.error("No se pudieron calcular los valores SHAP")
269
+ return
270
+
271
+ # Pestañas para diferentes visualizaciones
272
+ tab1, tab2, tab3, tab4 = st.tabs([
273
+ "Resumen de Importancia",
274
+ "Dependencia de Características",
275
+ "Informe Detallado",
276
+ "Configuración Avanzada"
277
+ ])
278
+
279
+ with tab1:
280
+ st.header("Resumen de Importancia de Características")
281
+
282
+ # Gráfico de resumen
283
+ summary_fig = shap_explainer.plot_summary(shap_values)
284
+ if summary_fig:
285
+ st.plotly_chart(summary_fig, use_container_width=True)
286
+
287
+ # Selector de características para análisis detallado
288
+ selected_feature = st.selectbox(
289
+ "Seleccionar característica para análisis detallado",
290
+ X.columns.tolist()
291
+ )
292
+
293
+ # Gráfico de dependencia para la característica seleccionada
294
+ dependence_fig = shap_explainer.plot_dependence(shap_values, selected_feature)
295
+ if dependence_fig:
296
+ st.plotly_chart(dependence_fig, use_container_width=True)
297
+
298
+ with tab2:
299
+ st.header("Análisis de Dependencia de Características")
300
+
301
+ # Matriz de correlación de valores SHAP
302
+ shap_correlation = pd.DataFrame(shap_values).corr()
303
+
304
+ # Heatmap de correlación de valores SHAP
305
+ fig_corr = px.imshow(
306
+ shap_correlation,
307
+ title="Correlación entre Valores SHAP de Características",
308
+ labels=dict(x="Características", y="Características", color="Correlación")
309
+ )
310
+ st.plotly_chart(fig_corr, use_container_width=True)
311
+
312
+ with tab3:
313
+ st.header("Informe Detallado de Importancia")
314
+
315
+ # Generar informe de importancia de características
316
+ importance_report = shap_explainer.generate_feature_importance_report(shap_values)
317
+
318
+ # Mostrar características más importantes
319
+ st.subheader("Top 5 Características Más Importantes")
320
+ top_features_df = pd.DataFrame(importance_report.get('top_features', []))
321
+ st.dataframe(top_features_df)
322
+
323
+ # Visualización de características más importantes
324
+ fig_top_features = px.bar(
325
+ top_features_df,
326
+ x='importance',
327
+ y='feature',
328
+ orientation='h',
329
+ title="Top 5 Características por Importancia SHAP"
330
+ )
331
+ st.plotly_chart(fig_top_features, use_container_width=True)
332
+
333
+ # Métricas de resumen
334
+ col1, col2, col3 = st.columns(3)
335
+ with col1:
336
+ st.metric("Total de Características", importance_report.get('total_features', 'N/A'))
337
+ with col2:
338
+ st.metric("Máxima Importancia", f"{importance_report.get('max_importance', 'N/A'):.4f}")
339
+ with col3:
340
+ st.metric("Mínima Importancia", f"{importance_report.get('min_importance', 'N/A'):.4f}")
341
+
342
+ with tab4:
343
+ st.header("Configuración Avanzada")
344
+
345
+ # Controles de configuración
346
+ st.subheader("Parámetros de Explicación")
347
+
348
+ # Selector de método de explicación
349
+ explanation_method = st.selectbox(
350
+ "Método de Explicación",
351
+ ["auto", "tree", "linear", "kernel"]
352
+ )
353
+
354
+ # Número de muestras para cálculo
355
+ num_samples = st.slider(
356
+ "Número de Muestras para Análisis",
357
+ min_value=10,
358
+ max_value=min(1000, len(X)),
359
+ value=min(100, len(X))
360
+ )
361
+
362
+ # Botón para recalcular con nuevos parámetros
363
+ if st.button("Recalcular SHAP"):
364
+ with st.spinner("Recalculando valores SHAP..."):
365
+ try:
366
+ # Crear y actualizar el explicador SHAP con nuevos parámetros
367
+ shap_explainer = SHAPExplainer(
368
+ model=model,
369
+ X=X,
370
+ problem_type=problem_type,
371
+ explanation_method=explanation_method
372
+ )
373
+ st.session_state.shap_explainer = shap_explainer
374
+
375
+ # Calcular y actualizar los valores SHAP
376
+ shap_values = shap_explainer.compute_shap_values(
377
+ max_samples=num_samples
378
+ )
379
+ st.session_state.shap_values = shap_values
380
+
381
+ st.success("Valores SHAP recalculados correctamente.")
382
+
383
+ except Exception as e:
384
+ st.error(f"Error al recalcular SHAP: {str(e)}")
385
+
386
+ def validate_shap_compatibility(model):
387
+ """
388
+ Validar si un modelo es compatible con SHAP
389
+
390
+ Args:
391
+ model: Modelo de machine learning
392
+
393
+ Returns:
394
+ bool: True si es compatible, False en caso contrario
395
+ """
396
+ compatible_types = [
397
+ 'RandomForestClassifier',
398
+ 'RandomForestRegressor',
399
+ 'GradientBoostingClassifier',
400
+ 'GradientBoostingRegressor',
401
+ 'XGBClassifier',
402
+ 'XGBRegressor',
403
+ 'DecisionTreeClassifier',
404
+ 'DecisionTreeRegressor',
405
+ 'LogisticRegression',
406
+ 'LinearRegression'
407
+ ]
408
+
409
+ return any(
410
+ comp_type in str(type(model).__name__)
411
+ for comp_type in compatible_types
412
+ )
413
+
414
+ def generate_shap_documentation():
415
+ """
416
+ Generar documentación sobre el uso de SHAP
417
+
418
+ Returns:
419
+ str: Documentación en formato markdown
420
+ """
421
+ documentation = """
422
+ ## 🔍 Explicabilidad de Modelos con SHAP
423
+
424
+ ### ¿Qué es SHAP?
425
+ SHAP (SHapley Additive exPlanations) es una metodología para explicar las predicciones
426
+ de modelos de machine learning basada en la teoría de juegos.
427
+
428
+ ### Características Principales
429
+ - Interpretación global y local de modelos
430
+ - Calcula la contribución de cada característica a la predicción
431
+ - Funciona con diferentes tipos de modelos
432
+
433
+ ### Tipos de Visualizaciones
434
+ 1. **Summary Plot**: Importancia general de características
435
+ 2. **Dependence Plot**: Relación entre características y predicciones
436
+ 3. **Force Plot**: Contribución individual de características
437
+
438
+ ### Limitaciones
439
+ - Computacionalmente intensivo para grandes datasets
440
+ - Puede ser lento con modelos complejos
441
+ - Requiere comprensión estadística para interpretación precisa
442
+
443
+ ### Mejores Prácticas
444
+ - Usar como complemento, no como única fuente de verdad
445
+ - Combinar con otras técnicas de explicabilidad
446
+ - Interpretar en contexto del problema de negocio
447
+ """
448
+ return documentation
449
+
450
+ # Punto de entrada principal para pruebas
451
+ def main():
452
+ import streamlit as st
453
+ from sklearn.ensemble import RandomForestClassifier
454
+ from sklearn.datasets import load_iris
455
+
456
+ # Cargar datos de ejemplo
457
+ iris = load_iris()
458
+ X = pd.DataFrame(iris.data, columns=iris.feature_names)
459
+ y = iris.target
460
+
461
+ # Entrenar modelo de ejemplo
462
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
463
+ model.fit(X, y)
464
+
465
+ # Crear dashboard de análisis SHAP
466
+ create_shap_analysis_dashboard(model, X)
467
+
468
+ if __name__ == "__main__":
469
+ main()