import threading

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

MODEL_ID = "HuggingFaceTB/nanowhale-100m"

print(f"Loading model {MODEL_ID} ...")
# Load model (recommended: manual load for reliability)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).float()

# Download and load weights
weights_path = hf_hub_download(MODEL_ID, "model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict, strict=True)
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
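
# Note: if the checkpoint loads cleanly through the standard API (an
# assumption, not something verified here), the manual config + safetensors
# steps above collapse to a one-liner:
#
#   model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)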
print("Model loaded.")
DEVICE = next(model.parameters()).device
def build_prompt(system_message: str, history: list[dict[str, str]], user_message: str) -> str:
"""
Try to use the tokenizer's built-in chat template.
Fall back to a simple newline-delimited format if none exists.
"""
messages = [{"role": "system", "content": system_message}]
messages.extend(history)
messages.append({"role": "user", "content": user_message})
if tokenizer.chat_template is not None:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Fallback format
parts = [f"System: {system_message}\n"]
for msg in history:
role = "User" if msg["role"] == "user" else "Assistant"
parts.append(f"{role}: {msg['content']}\n")
parts.append(f"User: {user_message}\nAssistant:")
return "".join(parts)
def respond(
    message: str,
    history: list[dict[str, str]],
    system_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    prompt = build_prompt(system_message, history, message)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # skip_prompt=True makes the streamer emit only newly generated text,
    # not the echoed prompt.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=temperature > 0.0,
        streamer=streamer,
    )

    # Run generation in a background thread so partial responses can be
    # yielded to the UI as soon as the streamer produces them.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for token in streamer:
        response += token
        yield response
    thread.join()
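
# Quick smoke test outside the UI (hypothetical inputs; uncomment to run):
#   for partial in respond("Hi!", [], "You are a friendly chatbot.", 32, 0.7, 0.95):
#       pass
#   print(partial)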
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # pass history as a list of {"role": ..., "content": ...} dicts
    title="Timmy, chat powered by nanowhale-100m",
    additional_inputs=[
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
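
# Streaming relies on Gradio's request queue, which recent Gradio versions
# enable by default; on older versions, call chatbot.queue() before launching.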
if __name__ == "__main__":
    chatbot.launch()