import os import gradio as gr import spaces import torch from threading import Thread from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer MODEL_ID = "lballore/llimba-3b-instruct" HF_TOKEN = os.environ.get("HF_TOKEN") DEMO_TOKEN = os.environ.get("DEMO_TOKEN", "") EXTERNAL_TOKENS = set( filter(None, os.environ.get("ALLOWED_API_TOKENS", "").split(",")) ) # --------------------------------------------------------------------------- # System prompts # --------------------------------------------------------------------------- CHAT_SYSTEM_PROMPT = ( "Ses unu assistente chi chistionat in sardu (LSC). " "Risponde in manera curtza, clara e pretzisa, chene repetire sa dimanda. " "Si non connosches una risposta o non ses seguru, nara-lu in manera onesta " "imbetzes de imbentare." ) TRANSLATE_SYSTEM_TEMPLATE = ( "Ses unu tradutore espertu. Traduzi in {tgt} su testu chi sighit. " "Su testu est in {src}. Risponde solu cun sa tradutzione, " "chene cummentos o ispiegatziones." ) LANGUAGES = { "Sardinian (LSC)": "sardu", "Italian": "italianu", "English": "inglesu", "Spanish": "ispagnolu", "French": "frantzesu", "Portuguese": "portoghesu", } # --------------------------------------------------------------------------- # Examples # --------------------------------------------------------------------------- CHAT_EXAMPLES = [ "Salude! Comente ìstas?", "Cale est sa capitale de sa Sardigna?", "Chie fiat Gigi Riva?", "Ite est su «cantu a tenore» sardu?", "Iscrie unu paragrafu in sardu subra de sa Sardigna.", ] TRANSLATE_EXAMPLES = [ "The weather is rough today.", "Buongiorno, come stai? È una bellissima giornata.", "La cultura sarda è ricca di tradizioni antiche.", ] # --------------------------------------------------------------------------- # Model loading (once at startup; ZeroGPU keeps weights on CPU until generate) # --------------------------------------------------------------------------- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16, device_map="auto", token=HF_TOKEN, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _to_text(content): """Normalize Gradio 6.x message content (may be list of dicts) to string.""" if isinstance(content, str): return content if isinstance(content, list): return "".join( part.get("text", "") if isinstance(part, dict) else str(part) for part in content ) return str(content) def _normalize_history(history): return [ {"role": m["role"], "content": _to_text(m["content"])} for m in history ] def _is_authorized(api_token: str, request: gr.Request) -> bool: """ Return True if the call is allowed. UI calls from the demo page itself are auto-authorized using DEMO_TOKEN (server-side, never reaches the browser). External API calls must supply a token present in ALLOWED_API_TOKENS. The UI vs API distinction is made by inspecting the request path: Gradio routes API calls through /gradio_api/call/{name}. If a future Gradio version changes this routing, the heuristic must be updated. """ path = "" try: path = request.request.url.path if request and request.request else "" except AttributeError: path = "" is_api_call = "/gradio_api/call/" in path if is_api_call: return bool(api_token) and api_token in EXTERNAL_TOKENS else: return bool(DEMO_TOKEN) def _stream_response(messages, max_tokens, temperature, top_p, top_k, rep_penalty): """Run model.generate in a background thread and yield tokens as they arrive.""" inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt", return_dict=True, ).to(model.device) streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, ) do_sample = temperature > 0.0 generation_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=max_tokens, do_sample=do_sample, temperature=temperature if do_sample else 1.0, top_p=top_p if do_sample else 1.0, top_k=top_k if do_sample else 50, repetition_penalty=rep_penalty, pad_token_id=tokenizer.eos_token_id, ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() response = "" for new_text in streamer: response += new_text yield response # --------------------------------------------------------------------------- # Per-tab respond functions # --------------------------------------------------------------------------- @spaces.GPU(duration=60) def respond_chat( message, history, system_message, max_tokens, temperature, top_p, top_k, rep_penalty, api_token, request: gr.Request, ): if not _is_authorized(api_token, request): yield "🔒 Valid API token required. Contact the project maintainer for access." return messages = [{"role": "system", "content": system_message}] messages.extend(_normalize_history(history)) messages.append({"role": "user", "content": _to_text(message)}) yield from _stream_response( messages, max_tokens, temperature, top_p, top_k, rep_penalty, ) @spaces.GPU(duration=60) def respond_translate( message, history, source_lang, target_lang, max_tokens, temperature, top_p, top_k, rep_penalty, api_token, request: gr.Request, ): if not _is_authorized(api_token, request): yield "🔒 Valid API token required. Contact the project maintainer for access." return src = LANGUAGES[source_lang] tgt = LANGUAGES[target_lang] system_message = TRANSLATE_SYSTEM_TEMPLATE.format(src=src, tgt=tgt) messages = [{"role": "system", "content": system_message}] messages.extend(_normalize_history(history)) messages.append({"role": "user", "content": _to_text(message)}) yield from _stream_response( messages, max_tokens, temperature, top_p, top_k, rep_penalty, ) # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- DESCRIPTION = ( "Chat with [LLiMba-3B-Instruct](https://huggingface.co/lballore/llimba-3b-instruct), " "an open 3B LLM that speaks **Sardinian** (LSC, with Logudorese and Campidanese " "accepted as input). The model retains the multilingual capabilities of its Qwen2.5 base." ) with gr.Blocks(title="LLiMba 3B Demo") as demo: gr.Markdown("# 💬 LLiMba 3B Demo") gr.Markdown(DESCRIPTION) with gr.Tabs(): # ----- Chat tab ----- with gr.Tab("💬 Chat"): gr.ChatInterface( respond_chat, additional_inputs=[ gr.Textbox( value=CHAT_SYSTEM_PROMPT, label="System message", info="Default tells the model to be concise and admit uncertainty.", lines=4, ), gr.Slider( minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", ), gr.Slider( minimum=0.0, maximum=1.0, value=0.3, step=0.05, label="Temperature", info="0 = greedy. ≤0.5 recommended to limit hallucination and language drift.", ), gr.Slider( minimum=0.05, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)", ), gr.Slider( minimum=1, maximum=100, value=40, step=1, label="Top-k", ), gr.Slider( minimum=1.0, maximum=2.0, value=1.05, step=0.05, label="Repetition penalty", ), gr.Textbox( value="", label="API token", info="Required for API access. Leave empty when using the demo page.", visible=False, ), ], examples=[[p] for p in CHAT_EXAMPLES], cache_examples=False, ) # ----- Translate tab ----- with gr.Tab("🌐 Translate"): gr.ChatInterface( respond_translate, additional_inputs=[ gr.Dropdown( choices=list(LANGUAGES.keys()), value="English", label="Source language", ), gr.Dropdown( choices=list(LANGUAGES.keys()), value="Sardinian (LSC)", label="Target language", ), gr.Slider( minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", ), gr.Slider( minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="Temperature", info="0 = greedy. Recommended for translation.", ), gr.Slider( minimum=0.05, maximum=1.0, value=1.0, step=0.05, label="Top-p", ), gr.Slider( minimum=1, maximum=100, value=1, step=1, label="Top-k", ), gr.Slider( minimum=1.0, maximum=2.0, value=1.0, step=0.05, label="Repetition penalty", ), gr.Textbox( value="", label="API token", info="Required for API access. Leave empty when using the demo page.", visible=False, ), ], examples=[[p] for p in TRANSLATE_EXAMPLES], cache_examples=False, ) if __name__ == "__main__": demo.launch(theme=gr.themes.Soft())