| |
| from __future__ import annotations |
| import os, json, re |
| from typing import List, Dict, Any, Optional, Tuple |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from salamandra_tools import SalamandraClient |
|
|
| |
# Hugging Face model id; overridable via the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-tools")
# Half precision on GPU to halve memory use; full fp32 on CPU.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Lazily-initialized singletons (populated by _lazy_load) kept at module
# scope so the tokenizer/model are loaded at most once per process.
_tok = None
_model = None
|
|
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and model on first use and cache them at module scope.

    Returns the shared (tokenizer, model) pair; subsequent calls reuse the
    cached objects instead of reloading from the hub.
    """
    global _tok, _model
    if _tok is not None and _model is not None:
        return _tok, _model

    _tok = AutoTokenizer.from_pretrained(
        MODEL_ID, use_fast=True, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        trust_remote_code=True,
        device_map=None,
    )
    # Explicit .to(DEVICE) because device_map is intentionally disabled.
    _model = model.to(DEVICE)
    return _tok, _model
|
|
|
|
| |
|
|
| def _render_tools_md(tools: List[Dict[str, Any]]) -> str: |
| """Convierte la especificación OpenAI-style de tools a un bloque breve markdown para el prompt.""" |
| if not tools: |
| return "" |
| lines = ["Herramientas disponibles (formato JSON):"] |
| for t in tools: |
| name = t.get("function", {}).get("name") or t.get("name") or "tool" |
| desc = t.get("function", {}).get("description") or t.get("description") or "" |
| params = t.get("function", {}).get("parameters") or t.get("parameters") or {} |
| lines.append(f"- **{name}**: {desc} | parámetros: {json.dumps(params)[:600]}") |
| return "\n".join(lines) |
|
|
def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
    """Build the generation prompt from OpenAI-style chat messages.

    Accepts [{"role": "system|user|assistant", "content": "..."}].  All
    system messages are merged into one system text (with the tool block
    appended when present).  Uses the tokenizer's chat template when one is
    available, otherwise falls back to a simple markdown layout.
    """
    tok, _ = _lazy_load()

    sys_text = ""
    dialogue: List[Dict[str, str]] = []
    for msg in messages:
        role = msg.get("role", "")
        text = (msg.get("content") or "").strip()
        if role == "system":
            # Concatenate system messages with newlines between non-empty parts.
            sys_text = f"{sys_text}\n{text}" if sys_text else text
        else:
            dialogue.append({"role": role, "content": text})

    if tools_md:
        instruction = (
            "\n\nSi decides llamar a una herramienta, devuelve un objeto JSON con la clave 'tool_calls' "
            "y describe tus razonamientos de forma concisa en 'thought' (opcional)."
        )
        prefix = sys_text + "\n\n" if sys_text else ""
        sys_text = prefix + tools_md + instruction

    conv: List[Dict[str, str]] = []
    if sys_text:
        conv.append({"role": "system", "content": sys_text})
    conv.extend(dialogue)

    if getattr(tok, "chat_template", None):
        return tok.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)

    # Fallback layout for tokenizers that ship no chat template.
    pieces: List[str] = []
    if sys_text:
        pieces.append(f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n")
    for msg in dialogue:
        if msg["role"] == "user":
            pieces.append(f"### Usuario\n{msg['content']}\n\n")
        elif msg["role"] == "assistant":
            pieces.append(f"### Asistente\n{msg['content']}\n\n")
    # Open an assistant turn so generation continues from here.
    pieces.append("### Asistente\n")
    return "".join(pieces)
|
|
|
|
| |
| |
| |
| |
| EXECUTE_TOOLS = True |
|
|
| def _safe_calculator(expr: str) -> str: |
| |
| if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**","^")): |
| return "Rejected expression." |
| |
| expr = expr.replace("^", "**") |
| try: |
| return str(eval(expr, {"__builtins__":{}}, {})) |
| except Exception as e: |
| return f"Error: {e}" |
|
|
# Name -> callable registry of locally executable tools.  Each callable
# receives the parsed `arguments` dict from a tool call and returns a string.
LOCAL_TOOLBOX = {
    "calculator": lambda args: _safe_calculator(str(args.get("expr",""))),
}
|
|
def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run each requested tool from LOCAL_TOOLBOX and collect per-call results.

    Returns [] when execution is disabled via EXECUTE_TOOLS.  Each result
    entry carries the tool name plus either an "output" or an "error" field;
    unknown tools yield {"error": "tool_not_available"}.
    """
    if not EXECUTE_TOOLS:
        return []
    outcomes: List[Dict[str, Any]] = []
    for request in tool_calls:
        tool_name = request.get("name")
        handler = LOCAL_TOOLBOX.get(tool_name)
        if handler is None:
            outcomes.append({"name": tool_name, "error": "tool_not_available"})
            continue
        try:
            value = handler(request.get("arguments", {}))
        except Exception as e:
            outcomes.append({"name": tool_name, "error": str(e)})
        else:
            outcomes.append({"name": tool_name, "output": value})
    return outcomes
|
|
|
|
| |
|
|
@spaces.GPU
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    """Generate a completion and best-effort parse/execute tool calls from it.

    messages: OpenAI-style chat messages.
    tools:    OpenAI-style tool spec, rendered into the system prompt.
    Returns {"text": decoded output, "tool_calls": parsed list (possibly
    empty), "tool_results": local execution results (possibly empty)}.
    """
    tok, model = _lazy_load()
    tools_md = _render_tools_md(tools)
    prompt = _compose_chat_prompt(messages, tools_md)

    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            # Greedy decoding when temperature is 0, sampling otherwise.
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    text = tok.decode(out[0], skip_special_tokens=True).strip()

    tool_calls = _extract_tool_calls(text)
    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []

    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}


def _extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Parse the last JSON object in *text* whose "tool_calls" value is a list.

    BUGFIX: the previous non-greedy regex (\\{.*?"tool_calls".*?\\}) stopped
    at the first '}' after the key, so any nested payload such as
    {"tool_calls": [{"name": ...}]} was truncated and failed json.loads,
    silently dropping the calls.  Here we decode balanced JSON objects with
    JSONDecoder.raw_decode starting at each '{' instead.
    """
    decoder = json.JSONDecoder()
    calls: List[Dict[str, Any]] = []
    idx = text.find("{")
    while idx != -1:
        try:
            obj, end = decoder.raw_decode(text, idx)
        except json.JSONDecodeError:
            # Not valid JSON at this brace; try the next one.
            idx = text.find("{", idx + 1)
            continue
        if isinstance(obj, dict):
            tc = obj.get("tool_calls")
            if isinstance(tc, list):
                calls = tc  # last parseable match wins, as before
        idx = text.find("{", end)
    return calls
|
|
|
|
| |
|
|
def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
    """Endpoint expected by ENGINE (ToolsClient.chat).

    messages_json: JSON list of {"role": "user|assistant|system", "content": "..."}
    tools_json:    optional OpenAI-like JSON tool spec
    Returns {"text": "...", "tool_calls": [...], "tool_results": [...]}.
    Malformed or empty JSON inputs degrade to empty lists.
    """
    def _parse(raw: str):
        if not raw:
            return []
        try:
            return json.loads(raw)
        except Exception:
            return []

    # Fixed, moderate sampling settings for the engine-facing endpoint.
    return _generate_with_tools(
        _parse(messages_json),
        _parse(tools_json),
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
    )
|
|
def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
    """Like predict_for_engine but with caller-controlled sampling settings.

    Malformed or empty JSON inputs degrade to empty lists rather than raising.
    """
    def _loads_or_empty(raw: str):
        if not raw:
            return []
        try:
            return json.loads(raw)
        except Exception:
            return []

    return _generate_with_tools(
        _loads_or_empty(messages_json),
        _loads_or_empty(tools_json),
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )
|
|
|
|
| _salamandra = None |
|
|
def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    """Plain prompt -> text endpoint backed by a shared SalamandraClient.

    The client is created on first use and reused; any failure is reported
    in the returned text rather than raised.
    """
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()

    try:
        reply = _salamandra.chat(prompt)
    except Exception as e:
        reply = f"Error ejecutando SalamandraClient: {str(e)}"

    return {"text": reply}
| |
| |
# CSS injected into the Blocks app: gives every <h2> a centered, full-width
# "card" look (rounded box with a soft shadow).  The comments inside the
# string are Catalan and shipped to the browser as-is.
custom_css = """
h2 {
  background: #e3e4e6 !important;
  padding: 14px 22px !important;
  border-radius: 14px !important;
  box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
  display: block !important;   /* ocupa tot l'ample */
  width: 100% !important;      /* assegura 100% */
  margin: 20px auto !important;
  text-align:center;
}
"""
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: three sections separated by horizontal rules —
#   1) advanced chat with an explicit tool spec (API endpoint "chat"),
#   2) an ENGINE-compatible probe with fixed settings (endpoint "predict"),
#   3) a plain prompt box backed by SalamandraClient
#      (endpoint "generate_out_from_prompt").
# ---------------------------------------------------------------------------
with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:

    gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nXat amb especificació d'eines (function-calling).")

    with gr.Row():
        with gr.Column():
            # Raw JSON inputs: the chat message list and an (optional)
            # OpenAI-style tool specification.
            messages = gr.Textbox(
                label="Missatges (JSON)",
                value='[{"role":"user","content":"Quant és (2+2)^3?"}]',
                lines=6
            )

            tools = gr.Textbox(
                label="Eines (JSON, opcional)",
                value='[{"type":"function","function":{"name":"calculator","description":"Avalua expressions aritmètiques bàsiques.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]',
                lines=6
            )

            # Sampling controls forwarded to chat_advanced.
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")

            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")

            topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")

            btn = gr.Button("Generar", variant="primary")

        with gr.Column():
            out = gr.JSON(label="Sortida")

    # Section 1: full control over sampling; serialized one request at a time.
    btn.click(
        chat_advanced,
        [messages, tools, max_new, temp, topp],
        out,
        api_name="chat",
        concurrency_limit=1
    )

    gr.Markdown("---")

    # Section 2: same JSON inputs, fixed generation settings, shared output.
    gr.Button("Provar /predict").click(
        predict_for_engine,
        [messages, tools],
        out,
        api_name="predict",
        concurrency_limit=1
    )

    gr.Markdown("---")

    # Section 3: free-form prompt routed through SalamandraClient.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)

    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")

    with gr.Row():
        out2 = gr.JSON(label="Sortida")

    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1
    )

    gr.Markdown("---")

# Bounded request queue so ZeroGPU workers aren't flooded; runs at import
# time (standard for Hugging Face Spaces entry scripts).
demo.queue(max_size=16).launch()
|
|