| import streamlit as st |
| import yaml |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
|
|
class ModelManager:
    """Hold one causal language model and its tokenizer, and generate text with them.

    Only one model is kept at a time; ``switch_model`` replaces the current one.
    """

    def __init__(self, model_name="microsoft/Phi-4-mini-instruct"):
        """Register the selectable models and load ``model_name`` right away.

        Args:
            model_name: Key of ``self.models`` naming the model to load first.

        Raises:
            KeyError: If ``model_name`` is not a key of ``self.models``.
        """
        # Maps the selectable name to the identifier handed to ``from_pretrained``.
        self.models = {
            "microsoft/Phi-4-mini-instruct": "microsoft/Phi-4-mini-instruct",
            "microsoft/Phi-4-multimodal": "microsoft/Phi-4-multimodal",
            "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
        }
        self.current_model_name = model_name
        self.tokenizer = None
        self.model = None
        self.load_model(model_name)

    def load_model(self, model_name):
        """Load the tokenizer and model registered under ``model_name``.

        The manager's state is replaced only after both objects loaded, so a
        failed load leaves the previously loaded model fully usable.

        Args:
            model_name: Key of ``self.models``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``self.models``.
        """
        # Resolve the path first: an unknown name must fail before anything changes.
        model_path = self.models[model_name]
        st.info(f"Cargando modelo: {model_name} ...")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)

        # Commit together so name, tokenizer and model never disagree.
        self.current_model_name = model_name
        self.tokenizer = tokenizer
        self.model = model

    def generate(self, prompt, max_length=50, temperature=0.7):
        """Generate a completion for ``prompt`` with the currently loaded model.

        Args:
            prompt: Text to tokenize and feed to the model.
            max_length: Forwarded as ``max_length`` to ``model.generate``.
            temperature: Sampling temperature. A positive value enables
                sampling; ``None`` or a non-positive value falls back to
                deterministic decoding, where temperature does not apply.

        Returns:
            The first generated sequence decoded to text, special tokens removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generation_kwargs = {"max_length": max_length}

        # Forward the attention mask when the tokenizer produced one.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            generation_kwargs["attention_mask"] = attention_mask

        # Temperature only influences the output when sampling is switched on.
        if temperature is not None and temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        else:
            generation_kwargs["do_sample"] = False

        outputs = self.model.generate(inputs["input_ids"], **generation_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def switch_model(self, model_name):
        """Replace the current model with ``model_name``.

        Args:
            model_name: Key of ``self.models``.

        Raises:
            ValueError: If ``model_name`` is not a registered model.
        """
        if model_name not in self.models:
            raise ValueError(f"El modelo {model_name} no está disponible.")
        self.load_model(model_name)
|
|
@st.cache_data
def load_prompts():
    """Parse ``prompt.yml`` (resolved against the working directory) and return it.

    The result is cached through ``st.cache_data``.

    Returns:
        The parsed YAML content, or an empty dict when the file holds no document.
    """
    with open("prompt.yml", "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)
    # safe_load yields None for an empty file; callers expect a mapping.
    return prompts if prompts is not None else {}
|
|
def main():
    """Render the page: choose a model, edit a prompt and show the generated text."""
    st.title("Switcher de Modelos de Transformers")

    # An empty prompt file parses to None; treat it as "nothing configured".
    prompts_config = load_prompts() or {}

    # "default_prompt" holds the fallback prompt text, not a model name,
    # so it must not be offered as a model choice.
    model_names = [name for name in prompts_config if name != "default_prompt"]

    st.sidebar.title("Selección de Modelo")
    model_choice = st.sidebar.selectbox("Selecciona un modelo", model_names)
    if model_choice is None:
        st.warning("No hay modelos configurados en prompt.yml.")
        return

    # Streamlit reruns this function on every interaction. Keeping the manager
    # in the session state avoids reloading the model on each rerun; it is only
    # replaced when the user picks a different model.
    manager = st.session_state.get("model_manager")
    try:
        if manager is None:
            manager = ModelManager(model_name=model_choice)
            st.session_state["model_manager"] = manager
        elif manager.current_model_name != model_choice:
            manager.switch_model(model_choice)
    except (KeyError, ValueError):
        # Raised when the chosen name is not registered in ModelManager.
        st.error(f"El modelo {model_choice} no está disponible.")
        return

    # Use the model's own prompt when it has one, else the shared default.
    style_prompt = prompts_config.get(model_choice) or prompts_config.get("default_prompt", "")

    st.write(f"**Modelo en uso:** {model_choice}")

    user_prompt = st.text_area("Ingresa tu prompt:", value=style_prompt)

    max_length = st.slider("Longitud máxima", min_value=10, max_value=200, value=50)
    temperature = st.slider("Temperatura", min_value=0.1, max_value=1.0, value=0.7)

    if st.button("Generar respuesta"):
        result = manager.generate(user_prompt, max_length=max_length, temperature=temperature)
        st.text_area("Salida", value=result, height=200)
|
|
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|