Spaces:
Running
Running
# Process-wide setup. Order matters: the environment variable is assigned
# before torch/transformers are imported, presumably so that TensorFlow's C++
# logging is already quieted if either import happens to pull TensorFlow in.
import os
import warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Ignore every Python warning for the rest of the process.
warnings.filterwarnings("ignore")
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, logging as hf_logging
# Only let transformers' own logger emit records at ERROR level or above.
hf_logging.set_verbosity_error()
# ── Config ─────────────────────────────────────────────────────────────────────
MODEL_ID = "SupraLabs/Supra-50M-Instruct"


# ── Load model ─────────────────────────────────────────────────────────────────
def _load_model_and_tokenizer(model_id):
    """Load the tokenizer and causal-LM weights for ``model_id``.

    The weights are requested as float32 and the model is put into eval mode
    before being returned. Progress is reported on stdout.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    print(f"[*] Loading {model_id} on CPU...")
    tok = AutoTokenizer.from_pretrained(model_id, clean_up_tokenization_spaces=False)
    lm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    lm.eval()
    print("[+] Model ready.")
    return tok, lm


tokenizer, model = _load_model_and_tokenizer(MODEL_ID)
# ── Prompt builder (Alpaca format) ─────────────────────────────────────────────
# Fixed Alpaca preamble that opens every instruction block. It is kept in one
# place so the system, history and new-message blocks cannot drift apart.
_ALPACA_PREAMBLE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
)


def _alpaca_turn(instruction: str) -> str:
    """Return one Alpaca instruction block ending with an open ``### Response:``."""
    return f"{_ALPACA_PREAMBLE}### Instruction:\n{instruction}\n\n### Response:\n"


def build_prompt(history: list, system: str, new_message: str) -> str:
    """Convert chat history + a new user message into Alpaca instruct format.

    Each user turn becomes a full Alpaca block (preamble, ``### Instruction:``,
    ``### Response:``). Each non-empty assistant turn is appended verbatim
    after it, followed by a blank line. A non-blank system prompt is injected
    first as a synthetic exchange whose response is "Understood.". The result
    always ends with an open ``### Response:`` block for ``new_message``, so
    generation continues from there.

    Args:
        history: Prior turns, oldest first. Each item is either a
            ``{"role": ..., "content": ...}`` dict or a ``(role, content)``
            pair. Roles other than "user"/"assistant" are skipped, as are
            assistant turns with empty content. ``None`` is treated as empty.
        system: Optional system prompt. ``None`` or whitespace-only means none.
        new_message: The user message the model should answer.

    Returns:
        The complete prompt string.
    """
    parts = []
    # `system and ...` guards against None (the original crashed on .strip()).
    if system and system.strip():
        parts.append(_alpaca_turn(system) + "Understood.\n\n")
    for msg in history or ():
        if isinstance(msg, dict):
            role, content = msg["role"], msg["content"]
        else:
            role, content = msg[0], msg[1]
        if role == "user":
            parts.append(_alpaca_turn(content))
        elif role == "assistant" and content:
            parts.append(content + "\n\n")
    parts.append(_alpaca_turn(new_message))
    return "".join(parts)
# ── Generate ───────────────────────────────────────────────────────────────────
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Event, Thread


class _EventStoppingCriteria(StoppingCriteria):
    """Stopping criterion that ends generation once ``event`` has been set.

    This lets the streaming consumer abort a background ``model.generate``
    call (e.g. when the UI cancels the request), instead of letting it keep
    running until ``max_new_tokens``.
    """

    def __init__(self, event):
        self._event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One done-flag per sequence in the batch.
        return torch.full(
            (input_ids.shape[0],),
            self._event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def chat_stream(
    message: str,
    history: list,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
):
    """Stream the model's reply to ``message`` as progressively longer strings.

    Generation runs in a background thread and feeds a ``TextIteratorStreamer``.
    This generator yields the reply text accumulated so far after every
    streamed chunk.

    Args:
        message: New user message. Blank (or ``None``) yields nothing.
        history: Prior turns, passed through to ``build_prompt``.
        system_prompt: Optional system prompt, passed to ``build_prompt``.
        max_new_tokens: Cap on generated tokens (coerced to int).
        temperature: Sampling temperature. Values <= 0.01 select greedy
            decoding, in which case temperature/top_p/top_k are not passed.
        top_p: Nucleus-sampling mass (used only when sampling).
        repetition_penalty: Passed to ``generate`` as a float.

    Raises:
        Exception: Any exception raised by ``model.generate`` in the worker
            thread is re-raised here once streaming has finished.
    """
    if not message or not message.strip():
        return
    prompt = build_prompt(history, system_prompt, message)
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_event = Event()

    do_sample = temperature > 0.01
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        # `is not None` rather than `or`: a pad id of 0 is valid and must be kept.
        pad_token_id=(
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        ),
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([_EventStoppingCriteria(stop_event)]),
    )
    if do_sample:
        # Sampling knobs only apply when sampling is enabled.
        generation_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=50,
        )

    errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            errors.append(exc)
            # generate() never reached streamer.end(). Signal end-of-stream
            # ourselves, otherwise the consumer loop would block forever.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    partial_text = ""
    try:
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
    finally:
        # Runs both on normal completion and when the consumer closes this
        # generator (e.g. a cancelled request). The stopping criterion then
        # halts generate() at its next step.
        stop_event.set()
    thread.join()
    if errors:
        raise errors[0]
# ── UI ─────────────────────────────────────────────────────────────────────────
# NOTE(review): the emoji in the header/buttons and the footer dashes were
# garbled (mis-decoded UTF-8) in the copy this was edited from. They are
# reconstructed from the surviving bytes; the header emoji and the dash
# style are best guesses, so confirm them against the original file.
with gr.Blocks(title="Supra-50M Instruct") as demo:
    gr.Markdown(
        "# 🦙 Supra-50M Instruct\n"
        "50M-parameter chat model by [SupraLabs](https://huggingface.co/SupraLabs), running on CPU."
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", height=480)
            msg_box = gr.Textbox(
                placeholder="Type your message and press Enter…",
                show_label=False,
                lines=1,
                max_lines=4,
            )
            with gr.Row():
                submit_btn = gr.Button("🚀 Send", variant="primary")
                stop_btn = gr.Button("🛑 Stop", variant="stop")
                clear_btn = gr.Button("🗑️ Clear chat", variant="secondary")
        with gr.Column(scale=1, min_width=220):
            gr.Markdown("### ⚙️ Parameters")
            system_prompt = gr.Textbox(
                label="System prompt",
                value="",
                lines=3,
            )
            max_new_tokens = gr.Slider(
                label="Max new tokens", minimum=32, maximum=1024, value=512, step=32
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.1, maximum=1.5, value=0.35, step=0.05
            )
            top_p = gr.Slider(
                label="Top-p", minimum=0.1, maximum=1.0, value=0.7, step=0.05
            )
            repetition_penalty = gr.Slider(
                label="Repetition penalty", minimum=1.0, maximum=1.5, value=1.3, step=0.05
            )

    # ── State & wiring ─────────────────────────────────────────────────────────
    # Conversation as a list of {"role": ..., "content": ...} dicts.
    chat_state = gr.State([])

    def user_step(message, history):
        """Append the submitted message as a user turn and clear the textbox.

        Returns (chatbot value, new state, textbox value). A blank message
        leaves the chatbot untouched and the history unchanged.
        """
        if not message.strip():
            return gr.update(), history, ""
        new_history = history + [{"role": "user", "content": message}]
        return new_history, new_history, ""

    def bot_step(history, system, max_tok, temp, top_p_val, rep_pen):
        """Stream an assistant reply to the trailing user turn in ``history``.

        Yields ``(history, history)`` (chatbot, state) after every streamed
        chunk. It does nothing unless the last turn is a user message.
        user_step leaves the history unchanged for a blank submission, and
        this chained step must not then treat the previous *assistant* reply
        as a new user prompt.
        """
        if not history or history[-1]["role"] != "user":
            return
        user_message = history[-1]["content"]
        history_before = history[:-1]
        history = history + [{"role": "assistant", "content": ""}]
        for response_chunk in chat_stream(
            user_message, history_before, system, max_tok, temp, top_p_val, rep_pen
        ):
            history[-1]["content"] = response_chunk
            yield history, history

    def _wire_send(trigger):
        """Attach the user_step -> bot_step chain to ``trigger``; return the chained event."""
        return trigger(
            fn=user_step,
            inputs=[msg_box, chat_state],
            outputs=[chatbot, chat_state, msg_box],
            queue=False,
        ).then(
            fn=bot_step,
            inputs=[chat_state, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty],
            outputs=[chatbot, chat_state],
        )

    submit_event = _wire_send(msg_box.submit)
    click_event = _wire_send(submit_btn.click)

    stop_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, click_event],
    )
    clear_btn.click(
        fn=lambda: ([], [], ""),
        outputs=[chatbot, chat_state, msg_box],
        # Cancel any in-flight reply as well. Otherwise bot_step keeps yielding
        # its own copy of the old history and overwrites the cleared chat.
        cancels=[submit_event, click_event],
    )

    gr.Markdown(
        "<p style='text-align:center; color:#aaa; font-size:0.8rem; margin-top:8px;'>"
        "Model: <a href='https://huggingface.co/SupraLabs/Supra-50M-Instruct' target='_blank'>"
        "SupraLabs/Supra-50M-Instruct</a> — Apache 2.0 — © SupraLabs 2026</p>"
    )
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, with server-side
    # rendering turned off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)