# Timmy / app.py
# philipp-zettl's picture
# Update app.py
# 632096b verified
import threading
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face Hub repository id of the chat model served by this app.
MODEL_ID: str = "HuggingFaceTB/nanowhale-100m"
print(f"Loading model {MODEL_ID} ...")
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
# Build the model by hand rather than with from_pretrained (the original
# author recommends this "for reliability"):
#   1. instantiate the architecture from its config; trust_remote_code=True
#      lets transformers run the modelling code shipped in the repo,
#   2. upcast to float32,
#   3. load the safetensors checkpoint with strict=True, so any missing or
#      unexpected key fails loudly instead of leaving weights uninitialised.
# Every Hub lookup goes through MODEL_ID so the "Loading model ..." message
# and the model actually loaded can never disagree.
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).float()
# Download (or reuse from the local HF cache) and load the weights.
weights_path = hf_hub_download(MODEL_ID, "model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict, strict=True)
# load_state_dict copies the tensors into the model's own parameters, so the
# loaded dict is no longer needed; drop the reference to free that memory.
del state_dict
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Model loaded.")
# Inputs are moved to wherever the parameters live. Nothing above moves the
# model, so this is the device from_config created it on.
DEVICE = next(model.parameters()).device
def _message_text(content) -> str:
    """Return the text of one chat message's ``content`` value.

    A plain string is returned unchanged and ``None`` becomes ``""``. A list or
    tuple is treated as a sequence of content parts: only dict parts with a
    string ``"text"`` value are kept (joined by newlines), and anything else
    (e.g. attached files) is dropped, so text-only chat templates never see a
    Python repr. NOTE(review): which of these shapes Gradio actually sends
    depends on the installed Gradio version -- confirm.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, (list, tuple)):
        return "\n".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return str(content)


def _normalize_history(history) -> list[dict[str, str]]:
    """Convert Gradio chat history into ``{"role": ..., "content": str}`` dicts.

    Accepts both history formats Gradio has used:

    * "messages": dicts with ``role`` and ``content`` keys. Any other keys
      (Gradio may attach e.g. ``metadata``) are discarded.
    * legacy "tuples": ``[user, assistant]`` pairs, where either side may be
      ``None`` (for instance a turn that has no reply yet).
    """
    messages: list[dict[str, str]] = []
    for entry in history or ():
        if isinstance(entry, dict):
            messages.append(
                {"role": entry["role"], "content": _message_text(entry.get("content"))}
            )
            continue
        user_part, assistant_part = entry
        if user_part is not None:
            messages.append({"role": "user", "content": _message_text(user_part)})
        if assistant_part is not None:
            messages.append(
                {"role": "assistant", "content": _message_text(assistant_part)}
            )
    return messages


def build_prompt(system_message: str, history: list, user_message: str) -> str:
    """Render the conversation so far into a single prompt string.

    Uses the tokenizer's chat template when one is set (with the generation
    prompt appended, so the model continues as the assistant). Otherwise it
    falls back to a simple newline-delimited ``Role: text`` transcript that
    ends with ``"Assistant:"``.

    Args:
        system_message: Text for the leading system turn.
        history: Previous turns as supplied by Gradio, either in "messages"
            format (role/content dicts) or legacy ``[user, assistant]`` pairs.
        user_message: The new user message to answer.

    Returns:
        The prompt text, not yet tokenized.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_normalize_history(history))
    messages.append({"role": "user", "content": user_message})

    # Truthiness (not just `is not None`): an empty template cannot render.
    if tokenizer.chat_template:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback format: every earlier turn on its own line, then the new user
    # message and an open "Assistant:" turn for the model to complete.
    role_labels = {"system": "System", "user": "User"}
    parts = [
        f"{role_labels.get(msg['role'], 'Assistant')}: {msg['content']}\n"
        for msg in messages[:-1]
    ]
    parts.append(f"User: {user_message}\nAssistant:")
    return "".join(parts)
def respond(
    message: str,
    history: list[dict[str, str]],
    system_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a model reply to ``message`` for ``gr.ChatInterface``.

    Generation runs in a background thread feeding a ``TextIteratorStreamer``.
    This generator yields the accumulated reply after every decoded chunk.

    Args:
        message: The new user message.
        history: Previous turns, passed through to ``build_prompt``.
        system_message: System prompt text.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, in which case ``temperature``/``top_p`` are not passed.
        top_p: Nucleus-sampling probability mass (used only when sampling).

    Yields:
        The reply text generated so far.

    Raises:
        Exception: any error raised by ``model.generate`` in the worker thread
            is re-raised here after the stream has been closed.
    """
    prompt = build_prompt(system_message, history, message)
    # A chat template already inserts the special tokens the model expects.
    # Letting the tokenizer add them again can duplicate e.g. BOS, so only
    # add them for the plain-text fallback prompt.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not tokenizer.chat_template,
    ).to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0.0
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # These only affect sampling; they are left out for greedy decoding.
        generation_kwargs.update(temperature=temperature, top_p=top_p)

    errors: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # A failed generate() never signals end-of-stream. The streamer
            # has no timeout, so without this the loop below would block
            # forever.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    response = ""
    for chunk in streamer:
        response += chunk
        yield response

    thread.join()
    if errors:
        raise errors[0]
# Extra controls shown under the chat box. ChatInterface passes their values
# to `respond` after (message, history), in this order: system_message,
# max_new_tokens, temperature, top_p.
_generation_controls = [
    gr.Textbox(value="You are a friendly chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

chatbot = gr.ChatInterface(
    fn=respond,
    title="Timmy, chat powered by nanowhale-100m",
    additional_inputs=_generation_controls,
)
def main() -> None:
    """Start the Gradio server for the chat UI."""
    chatbot.launch()


# Launch only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()