import os
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "lballore/llimba-3b-instruct"
HF_TOKEN = os.environ.get("HF_TOKEN")
DEMO_TOKEN = os.environ.get("DEMO_TOKEN", "")
EXTERNAL_TOKENS = set(
    filter(None, os.environ.get("ALLOWED_API_TOKENS", "").split(","))
)
# ---------------------------------------------------------------------------
# System prompts
# ---------------------------------------------------------------------------
CHAT_SYSTEM_PROMPT = (
    "Ses unu assistente chi chistionat in sardu (LSC). "
    "Risponde in manera curtza, clara e pretzisa, chene repetire sa dimanda. "
    "Si non connosches una risposta o non ses seguru, nara-lu in manera onesta "
    "imbetzes de imbentare."
)
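# English gloss of CHAT_SYSTEM_PROMPT (the prompt itself must stay in Sardinian):
# "You are an assistant that speaks Sardinian (LSC). Answer briefly, clearly
# and precisely, without repeating the question. If you don't know an answer
# or aren't sure, say so honestly instead of making things up."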
TRANSLATE_SYSTEM_TEMPLATE = (
    "Ses unu tradutore espertu. Traduzi in {tgt} su testu chi sighit. "
    "Su testu est in {src}. Risponde solu cun sa tradutzione, "
    "chene cummentos o ispiegatziones."
)
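# English gloss of TRANSLATE_SYSTEM_TEMPLATE:
# "You are an expert translator. Translate the text that follows into {tgt}.
# The text is in {src}. Reply only with the translation, without comments
# or explanations."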
LANGUAGES = {
    "Sardinian (LSC)": "sardu",
    "Italian": "italianu",
    "English": "inglesu",
    "Spanish": "ispagnolu",
    "French": "frantzesu",
    "Portuguese": "portoghesu",
}
# ---------------------------------------------------------------------------
# Examples
# ---------------------------------------------------------------------------
CHAT_EXAMPLES = [
    "Salude! Comente ìstas?",
    "Cale est sa capitale de sa Sardigna?",
    "Chie fiat Gigi Riva?",
    "Ite est su «cantu a tenore» sardu?",
    "Iscrie unu paragrafu in sardu subra de sa Sardigna.",
]

TRANSLATE_EXAMPLES = [
    "The weather is rough today.",
    "Buongiorno, come stai? È una bellissima giornata.",
    "La cultura sarda è ricca di tradizioni antiche.",
]
# ---------------------------------------------------------------------------
# Model loading (once at startup; ZeroGPU keeps weights on CPU until generate)
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,  # recent transformers name for the older `torch_dtype` argument
    device_map="auto",
    token=HF_TOKEN,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _to_text(content):
    """Normalize Gradio 6.x message content (may be a list of dicts) to a string."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    return str(content)
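# Example (illustrative): multi-part message content arrives as a list of dicts,
# and the text parts are concatenated:
#   _to_text([{"type": "text", "text": "Salude"}, {"type": "text", "text": "!"}])
#   -> "Salude!"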

def _normalize_history(history):
    return [
        {"role": m["role"], "content": _to_text(m["content"])}
        for m in history
    ]

def _is_authorized(api_token: str, request: gr.Request) -> bool:
    """
    Return True if the call is allowed.

    UI calls from the demo page itself are auto-authorized using DEMO_TOKEN
    (server-side; it never reaches the browser). External API calls must
    supply a token present in ALLOWED_API_TOKENS.

    The UI vs. API distinction is made by inspecting the request path: Gradio
    routes API calls through /gradio_api/call/{name}. If a future Gradio
    version changes this routing, the heuristic must be updated.
    """
    path = ""
    try:
        path = request.request.url.path if request and request.request else ""
    except AttributeError:
        path = ""
    is_api_call = "/gradio_api/call/" in path
    if is_api_call:
        return bool(api_token) and api_token in EXTERNAL_TOKENS
    return bool(DEMO_TOKEN)
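# Example (illustrative, not executed by this app): how an external client would
# authenticate. The Space id and the "/chat" endpoint name are assumptions;
# check the Space's "Use via API" page for the actual values.
#
#   from gradio_client import Client
#   client = Client("lballore/llimba-3b-demo")  # hypothetical Space id
#   reply = client.predict(
#       "Salude!",                # message
#       CHAT_SYSTEM_PROMPT,       # system message
#       512, 0.3, 0.9, 40, 1.05,  # max_tokens, temperature, top_p, top_k, rep_penalty
#       "my-api-token",           # api_token: must appear in ALLOWED_API_TOKENS
#       api_name="/chat",
#   )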

@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call
def _stream_response(messages, max_tokens, temperature, top_p, top_k, rep_penalty):
    """Run model.generate in a background thread and yield tokens as they arrive."""
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    # Temperature 0 means greedy decoding; neutralize the sampling knobs so
    # transformers does not warn about parameters that are unused with do_sample=False.
    do_sample = temperature > 0.0
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else 1.0,
        top_p=top_p if do_sample else 1.0,
        top_k=top_k if do_sample else 50,
        repetition_penalty=rep_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response  # yield the cumulative text so the UI re-renders the full reply
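# Example (illustrative): consuming the stream directly, outside Gradio.
# Greedy decoding (temperature 0.0), 64 new tokens at most:
#   for partial in _stream_response(
#       [{"role": "user", "content": "Salude!"}], 64, 0.0, 1.0, 1, 1.0
#   ):
#       print(partial)  # cumulative response so far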
# ---------------------------------------------------------------------------
# Per-tab respond functions
# ---------------------------------------------------------------------------
def respond_chat(
    message, history,
    system_message, max_tokens, temperature, top_p, top_k, rep_penalty,
    api_token,
    request: gr.Request,
):
    if not _is_authorized(api_token, request):
        yield "🔒 Valid API token required. Contact the project maintainer for access."
        return
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_normalize_history(history))
    messages.append({"role": "user", "content": _to_text(message)})
    yield from _stream_response(
        messages, max_tokens, temperature, top_p, top_k, rep_penalty,
    )

def respond_translate(
    message, history,
    source_lang, target_lang,
    max_tokens, temperature, top_p, top_k, rep_penalty,
    api_token,
    request: gr.Request,
):
    if not _is_authorized(api_token, request):
        yield "🔒 Valid API token required. Contact the project maintainer for access."
        return
    src = LANGUAGES[source_lang]
    tgt = LANGUAGES[target_lang]
    system_message = TRANSLATE_SYSTEM_TEMPLATE.format(src=src, tgt=tgt)
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_normalize_history(history))
    messages.append({"role": "user", "content": _to_text(message)})
    yield from _stream_response(
        messages, max_tokens, temperature, top_p, top_k, rep_penalty,
    )
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
DESCRIPTION = (
    "Chat with [LLiMba-3B-Instruct](https://huggingface.co/lballore/llimba-3b-instruct), "
    "an open 3B LLM that speaks **Sardinian** (LSC, with Logudorese and Campidanese "
    "accepted as input). The model retains the multilingual capabilities of its Qwen2.5 base."
)
with gr.Blocks(title="LLiMba 3B Demo", theme=gr.themes.Soft()) as demo:
| gr.Markdown("# 💬 LLiMba 3B Demo") | |
| gr.Markdown(DESCRIPTION) | |
| with gr.Tabs(): | |
| # ----- Chat tab ----- | |
| with gr.Tab("💬 Chat"): | |
| gr.ChatInterface( | |
| respond_chat, | |
| additional_inputs=[ | |
| gr.Textbox( | |
| value=CHAT_SYSTEM_PROMPT, | |
| label="System message", | |
| info="Default tells the model to be concise and admit uncertainty.", | |
| lines=4, | |
| ), | |
| gr.Slider( | |
| minimum=1, maximum=2048, value=512, step=1, | |
| label="Max new tokens", | |
| ), | |
| gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.3, step=0.05, | |
| label="Temperature", | |
| info="0 = greedy. ≤0.5 recommended to limit hallucination and language drift.", | |
| ), | |
| gr.Slider( | |
| minimum=0.05, maximum=1.0, value=0.9, step=0.05, | |
| label="Top-p (nucleus sampling)", | |
| ), | |
| gr.Slider( | |
| minimum=1, maximum=100, value=40, step=1, | |
| label="Top-k", | |
| ), | |
| gr.Slider( | |
| minimum=1.0, maximum=2.0, value=1.05, step=0.05, | |
| label="Repetition penalty", | |
| ), | |
| gr.Textbox( | |
| value="", | |
| label="API token", | |
| info="Required for API access. Leave empty when using the demo page.", | |
| visible=False, | |
| ), | |
| ], | |
| examples=[[p] for p in CHAT_EXAMPLES], | |
| cache_examples=False, | |
| ) | |
        # ----- Translate tab -----
        with gr.Tab("🌐 Translate"):
            gr.ChatInterface(
                respond_translate,
                additional_inputs=[
                    gr.Dropdown(
                        choices=list(LANGUAGES.keys()),
                        value="English",
                        label="Source language",
                    ),
                    gr.Dropdown(
                        choices=list(LANGUAGES.keys()),
                        value="Sardinian (LSC)",
                        label="Target language",
                    ),
                    gr.Slider(
                        minimum=1, maximum=2048, value=512, step=1,
                        label="Max new tokens",
                    ),
                    gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.0, step=0.05,
                        label="Temperature",
                        info="0 = greedy. Recommended for translation.",
                    ),
                    gr.Slider(
                        minimum=0.05, maximum=1.0, value=1.0, step=0.05,
                        label="Top-p",
                    ),
                    gr.Slider(
                        minimum=1, maximum=100, value=1, step=1,
                        label="Top-k",
                    ),
                    gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.0, step=0.05,
                        label="Repetition penalty",
                    ),
                    gr.Textbox(
                        value="",
                        label="API token",
                        info="Required for API access. Leave empty when using the demo page.",
                        visible=False,
                    ),
                ],
                examples=[[p] for p in TRANSLATE_EXAMPLES],
                cache_examples=False,
            )

if __name__ == "__main__":
    # The theme belongs to the gr.Blocks constructor above; demo.launch()
    # does not accept a `theme` argument.
    demo.launch()
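# Running locally (illustrative): _is_authorized only passes UI calls when
# DEMO_TOKEN is non-empty, and HF_TOKEN is needed only if the model repo is
# gated. For example:
#   DEMO_TOKEN=local-dev ALLOWED_API_TOKENS=tok1,tok2 python app.py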