| import gradio as gr |
| from datasets import load_dataset |
| import os |
| from threading import Thread |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig |
| from sentence_transformers import SentenceTransformer |
|
|
| |
| token = os.getenv("HF_TOKEN") |
|
|
| |
| ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") |
|
|
| dataset = load_dataset("not-lain/wikipedia",revision = "embedded") |
| data = dataset["train"] |
| data = data.add_faiss_index("embeddings") |
|
|
# 4-bit NF4 quantization config so large checkpoints fit in limited GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# UI display name -> Hugging Face repository id.
MODELOS = {
    "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
}

# Cache of already-loaded (model, tokenizer) pairs, keyed by display name,
# so each checkpoint is downloaded and quantized at most once per process.
MODELOS_CARGADOS = {}
|
| |
def get_model(selected_model):
    """Return the (model, tokenizer) pair for a display name, loading lazily.

    Results are memoized in MODELOS_CARGADOS; a KeyError is raised when the
    name is not a key of MODELOS.
    """
    repo_id = MODELOS[selected_model]
    cached = MODELOS_CARGADOS.get(selected_model)
    if cached is None:
        tok = AutoTokenizer.from_pretrained(repo_id, token=token)
        mdl = AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=bnb_config,
            token=token,
        )
        cached = (mdl, tok)
        MODELOS_CARGADOS[selected_model] = cached
    return cached
|
|
| |
def get_terminators(tokenizer):
    """Return the token ids at which generation should stop.

    Always includes the tokenizer's EOS id, plus the Llama-3 end-of-turn
    token "<|eot_id|>" when the tokenizer actually knows it.

    Bug fix: tokenizers that lack "<|eot_id|>" (e.g. non-Llama models such
    as the DeepSeek entry in MODELOS) return None — or the unk id — from
    convert_tokens_to_ids; a None inside eos_token_id breaks
    model.generate, so such entries are dropped here.
    """
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != getattr(tokenizer, "unk_token_id", None):
        terminators.append(eot_id)
    return terminators
|
|
# System instruction for the chat model.  It carries a "{texto_usuario}"
# placeholder meant to be filled with the user's profession/query text.
SYS_PROMPT = (
    "Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas "
    "con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión "
    "en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' "
    "Don't make up an answer."
)
|
|
def search(query: str, k: int = 3):
    """Embed *query* and return the top-*k* nearest passages.

    Returns the (scores, examples) pair produced by
    Dataset.get_nearest_examples over the "embeddings" FAISS index.
    """
    query_vector = ST.encode(query)
    return data.get_nearest_examples("embeddings", query_vector, k=k)
|
|
def format_prompt(prompt, retrieved_documents, k):
    """Build the model prompt from the question and retrieved context.

    Args:
        prompt: the user's question.
        retrieved_documents: mapping with a "text" column (list of passages),
            as returned by search().
        k: maximum number of passages to include.

    Returns:
        A string "Question: <prompt>\\nContext:\\n" followed by one passage
        per line.

    Robustness fix: the original indexed range(k) blindly and raised
    IndexError whenever fewer than k documents were retrieved; this version
    includes at most the documents actually available.
    """
    texts = retrieved_documents["text"][:k]
    context = "".join(f"{t}\n" for t in texts)
    return f"Question: {prompt}\nContext:\n{context}"
|
|
def talk(prompt, selected_model, history=None):
    """Run one RAG turn: retrieve context, build the prompt, stream the answer.

    Args:
        prompt: the user's question.
        selected_model: display name (a key of MODELOS).
        history: optional prior conversation; currently unused by the logic.

    Yields:
        The partial answer string, growing as tokens stream in.
    """
    # Bug fix: the original used a mutable default (history=[]) that is
    # shared across calls; use None as the sentinel instead.
    if history is None:
        history = []

    model, tokenizer = get_model(selected_model)
    terminators = get_terminators(tokenizer)

    # Retrieve the single closest passage and cap the prompt length.
    k = 1
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    formatted_prompt = formatted_prompt[:2000]  # guard against overly long context

    # Bug fix: SYS_PROMPT carries a {texto_usuario} placeholder that the
    # original never substituted; fill it with the user's text so the system
    # instruction actually names the profession being asked about.
    messages = [
        {"role": "system", "content": SYS_PROMPT.format(texto_usuario=prompt)},
        {"role": "user", "content": formatted_prompt},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate in a background thread so tokens can be yielded as they arrive.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=terminators,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
|
|
| |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: |
| gr.Markdown("# RAG Chatbot") |
| gr.Markdown("Selecciona el modelo de inteligencia artificial:") |
| |
| with gr.Row(): |
| |
| modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo") |
| |
| chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble") |
| |
| |
| prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta") |
| |
| |
| send_btn = gr.Button("Enviar") |
| |
| |
| send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot) |
| |
| if __name__ == "__main__": |
| demo.launch(debug=True) |
|
|