import threading

import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

MODEL_ID = "HuggingFaceTB/nanowhale-100m"

print(f"Loading model {MODEL_ID} ...")

# Load model (recommended: manual load for reliability)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).float()

# Download the safetensors checkpoint and load the weights into the model
weights_path = hf_hub_download(MODEL_ID, "model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict, strict=True)
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
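# Note: the usual one-liner would be
#   model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
# assuming the Hub repo loads cleanly that way; the explicit config + safetensors
# route above is kept here for reliability.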


print("Model loaded.")

# Track the device the weights live on (CPU here, since they were never moved).
DEVICE = next(model.parameters()).device


def build_prompt(system_message: str, history: list[dict[str, str]], user_message: str) -> str:
    """
    Try to use the tokenizer's built-in chat template.
    Fall back to a simple newline-delimited format if none exists.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_message})

    if tokenizer.chat_template is not None:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback format
    parts = [f"System: {system_message}\n"]
    for msg in history:
        role = "User" if msg["role"] == "user" else "Assistant"
        parts.append(f"{role}: {msg['content']}\n")
    parts.append(f"User: {user_message}\nAssistant:")
    return "".join(parts)


def respond(
    message: str,
    history: list[dict[str, str]],
    system_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    prompt = build_prompt(system_message, history, message)

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
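    # skip_prompt=True drops the echoed input prompt from the stream;
    # skip_special_tokens=True strips BOS/EOS markers from the output.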

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=temperature > 0.0,
        streamer=streamer,
    )

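    # model.generate blocks until generation finishes, so run it in a
    # background thread and consume the streamer incrementally here.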
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for token in streamer:
        response += token
        yield response

    thread.join()


chatbot = gr.ChatInterface(
    respond,
    type="messages",  # deliver history as [{"role": ..., "content": ...}] dicts
    title="Timmy, chat powered by nanowhale-100m",
    additional_inputs=[
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    chatbot.launch()
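
# Note: when running locally, chatbot.launch(share=True) would also create a
# temporary public link; on a hosted Space the plain launch() above suffices.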