# Last commit by LH-Tech-AI: "Update app.py" (5a7f82a, verified)
import os
import warnings
# Quiet third-party noise before the heavy imports below run:
# TF_CPP_MIN_LOG_LEVEL=3 hides TensorFlow's C++ log output should TensorFlow get
# imported, and the warnings filter suppresses ALL Python warnings process-wide.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, logging as hf_logging
# Only let transformers' own logger emit errors (drops info/warning chatter).
hf_logging.set_verbosity_error()
# ── Config ────────────────────────────────────────────────────────────────────
# Hugging Face Hub repo id of the model served by this app.
MODEL_ID = "SupraLabs/Supra-50M-Instruct"
# ── Load model ────────────────────────────────────────────────────────────────
# Tokenizer and model are module-level globals, loaded once when the script starts.
# No device_map / .to(...) is used, so the float32 weights stay on the default
# device (CPU, as the log line states).
print(f"[*] Loading {MODEL_ID} on CPU...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, clean_up_tokenization_spaces=False)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
# Inference mode: turns off dropout and other train-only layer behavior.
model.eval()
print("[+] Model ready.")
# ── Prompt builder (Alpaca format) ────────────────────────────────────────────
# Standard Alpaca preamble, placed in front of every instruction turn.
_ALPACA_PREAMBLE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
)


def _alpaca_instruction(instruction: str) -> str:
    """Return one Alpaca instruction turn ending with an open ``### Response:`` header."""
    return f"{_ALPACA_PREAMBLE}### Instruction:\n{instruction}\n\n### Response:\n"


def build_prompt(history: list, system: str, new_message: str) -> str:
    """Render chat history plus a new user message as one Alpaca-format prompt.

    Every user turn (including *new_message*) becomes a full Alpaca
    instruction block; assistant turns are appended verbatim right after the
    preceding ``### Response:`` header. The result ends with an open
    ``### Response:`` header so the model continues with its answer.

    Args:
        history: Previous turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio messages format) or ``(role, content)`` sequences.
            ``None`` is treated as an empty history.
        system: Optional system prompt; blank or ``None`` is skipped. Alpaca
            has no system slot, so it is sent as an instruction followed by a
            canned "Understood." response.
        new_message: The user message to be answered.

    Returns:
        The complete prompt string.
    """
    parts = []
    if system and system.strip():
        parts.append(_alpaca_instruction(system) + "Understood.\n\n")
    for msg in history or ():
        if isinstance(msg, dict):
            role, content = msg["role"], msg["content"]
        else:
            # NOTE(review): Gradio's legacy "tuples" history is [user_text, bot_text]
            # pairs, not (role, content); such pairs would be silently skipped here.
            # Confirm which tuple shape callers actually send.
            role, content = msg[0], msg[1]
        if role == "user":
            parts.append(_alpaca_instruction(content))
        elif role == "assistant" and content:
            parts.append(content + "\n\n")
    parts.append(_alpaca_instruction(new_message))
    return "".join(parts)
# ── Generate ────────────────────────────────────────────
from threading import Event, Thread

from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that halts ``generate`` once *event* has been set."""

    def __init__(self, event: Event):
        self._event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One bool per sequence: newer transformers OR these into a per-batch mask,
        # older ones call bool() on it (fine for the batch size of 1 used here).
        return torch.full(
            (input_ids.shape[0],),
            self._event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def chat_stream(
    message: str,
    history: list,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
):
    """Stream the model's reply to *message*, yielding the text accumulated so far.

    Generation runs in a background thread feeding a ``TextIteratorStreamer``.

    Args:
        message: New user message; blank/``None`` yields nothing.
        history: Previous turns in a format accepted by ``build_prompt``.
        system_prompt: Optional system prompt passed to ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; sampling is enabled only when > 0.01.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty passed straight to ``generate``.

    Yields:
        The cumulative decoded reply (not the delta) after each streamed chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised here
        once the stream has been closed.
    """
    if not message or not message.strip():
        return
    prompt = build_prompt(history, system_prompt, message)
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_event = Event()
    # `pad_token_id or eos_token_id` would throw away a legitimate pad id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0.01,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=50,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)
            # generate() died before calling streamer.end(); without this the
            # consuming loop below would block forever waiting for more text.
            streamer.end()

    thread = Thread(target=_generate, daemon=True)
    thread.start()
    partial_text = ""
    try:
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
    finally:
        # Runs on normal completion and when the consumer closes this generator
        # (e.g. Gradio's Stop button): make the worker stop instead of running
        # on to max_new_tokens in the background.
        stop_event.set()
    if errors:
        raise errors[0]
# ── UI ────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Supra-50M Instruct") as demo:
    gr.Markdown(
        "# 🦅 Supra-50M Instruct\n"
        "50M-parameter chat model by [SupraLabs](https://huggingface.co/SupraLabs), running on CPU."
    )
    with gr.Row():
        # Left column: conversation view, input box and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", height=480)
            msg_box = gr.Textbox(
                placeholder="Type your message and press Enter…",
                show_label=False,
                lines=1,
                max_lines=4,
            )
            with gr.Row():
                submit_btn = gr.Button("🚀 Send", variant="primary")
                stop_btn = gr.Button("🛑 Stop", variant="stop")
                clear_btn = gr.Button("🗑️ Clear chat", variant="secondary")
        # Right column: generation settings passed to bot_step on every turn.
        with gr.Column(scale=1, min_width=220):
            gr.Markdown("### ⚙️ Parameters")
            system_prompt = gr.Textbox(
                label="System prompt",
                value="",
                lines=3,
            )
            max_new_tokens = gr.Slider(
                label="Max new tokens", minimum=32, maximum=1024, value=512, step=32
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.1, maximum=1.5, value=0.35, step=0.05
            )
            top_p = gr.Slider(
                label="Top-p", minimum=0.1, maximum=1.0, value=0.7, step=0.05
            )
            repetition_penalty = gr.Slider(
                label="Repetition penalty", minimum=1.0, maximum=1.5, value=1.3, step=0.05
            )

    # ── State & wiring ────────────────────────────────────────────────────────
    # Conversation as a list of {"role": ..., "content": ...} dicts, mirrored into `chatbot`.
    chat_state = gr.State([])

    def user_step(message, history):
        """Append a non-blank user message to the history and clear the input box."""
        if not message or not message.strip():
            # Leave the chat display and state untouched; just clear the textbox.
            return gr.update(), history, ""
        new_history = history + [{"role": "user", "content": message}]
        return new_history, new_history, ""

    def bot_step(history, system, max_tok, temp, top_p_val, rep_pen):
        """Stream the assistant's reply to the trailing user turn into chat + state."""
        # Only answer an unanswered user turn. user_step leaves the state unchanged
        # for blank input; without this guard the previous *assistant* reply would
        # be sent back to the model as though the user had typed it.
        if not history:
            return
        last = history[-1]
        if not isinstance(last, dict) or last.get("role") != "user":
            return
        user_message = last["content"]
        history_before = history[:-1]
        history = history + [{"role": "assistant", "content": ""}]
        for response_chunk in chat_stream(
            user_message, history_before, system, max_tok, temp, top_p_val, rep_pen
        ):
            history[-1]["content"] = response_chunk
            yield history, history

    generation_inputs = [
        chat_state, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty
    ]
    submit_event = msg_box.submit(
        fn=user_step,
        inputs=[msg_box, chat_state],
        outputs=[chatbot, chat_state, msg_box],
        queue=False,
    ).then(
        fn=bot_step,
        inputs=generation_inputs,
        outputs=[chatbot, chat_state],
    )
    click_event = submit_btn.click(
        fn=user_step,
        inputs=[msg_box, chat_state],
        outputs=[chatbot, chat_state, msg_box],
        queue=False,
    ).then(
        fn=bot_step,
        inputs=generation_inputs,
        outputs=[chatbot, chat_state],
    )
    stop_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, click_event],
    )
    # Also cancel any in-flight reply, otherwise it keeps streaming into the cleared chat.
    clear_btn.click(
        fn=lambda: ([], [], ""),
        outputs=[chatbot, chat_state, msg_box],
        cancels=[submit_event, click_event],
    )
    gr.Markdown(
        "<p style='text-align:center; color:#aaa; font-size:0.8rem; margin-top:8px;'>"
        "Model: <a href='https://huggingface.co/SupraLabs/Supra-50M-Instruct' target='_blank'>"
        "SupraLabs/Supra-50M-Instruct</a> — Apache 2.0 — © SupraLabs 2026</p>"
    )
if __name__ == "__main__":
    # Listen on every interface at Gradio's conventional port, with SSR turned off.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
    demo.launch(**launch_options)