LH-Tech-AI committed on
Commit
b896b6f
·
verified ·
1 Parent(s): 4776b08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -32
app.py CHANGED
@@ -60,9 +60,12 @@ def build_prompt(history: list, system: str, new_message: str) -> str:
60
  return "".join(parts)
61
 
62
 
63
- # ── Generate ──────────────────────────────────────────────────────────────────
64
 
65
- def chat(
 
 
 
66
  message: str,
67
  history: list,
68
  system_prompt: str,
@@ -70,28 +73,35 @@ def chat(
70
  temperature: float,
71
  top_p: float,
72
  repetition_penalty: float,
73
- ) -> str:
74
  if not message.strip():
75
- return ""
76
 
77
  prompt = build_prompt(history, system_prompt, message)
78
  inputs = tokenizer(prompt, return_tensors="pt")
79
 
80
- with torch.no_grad():
81
- output_ids = model.generate(
82
- **inputs,
83
- max_new_tokens=int(max_new_tokens),
84
- do_sample=temperature > 0.01,
85
- temperature=float(temperature),
86
- top_p=float(top_p),
87
- top_k=50,
88
- repetition_penalty=float(repetition_penalty),
89
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
90
- eos_token_id=tokenizer.eos_token_id,
91
- )
 
 
92
 
93
- new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
94
- return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
 
 
 
 
95
 
96
 
97
  # ── UI ────────────────────────────────────────────────────────────────────────
@@ -99,7 +109,7 @@ def chat(
99
  with gr.Blocks(title="Supra-50M Instruct") as demo:
100
  gr.Markdown(
101
  "# 🦅 Supra-50M Instruct\n"
102
- "50M-parameter chat model by [SupraLabs](https://huggingface.co/SupraLabs), running on CPU (Consider local inference, the quality is much better)"
103
  )
104
 
105
  with gr.Row():
@@ -110,8 +120,11 @@ with gr.Blocks(title="Supra-50M Instruct") as demo:
110
  show_label=False,
111
  lines=1,
112
  max_lines=4,
113
- submit_btn=True,
114
  )
 
 
 
 
115
 
116
  with gr.Column(scale=1, min_width=220):
117
  gr.Markdown("### ⚙️ Parameters")
@@ -132,28 +145,57 @@ with gr.Blocks(title="Supra-50M Instruct") as demo:
132
  repetition_penalty = gr.Slider(
133
  label="Repetition penalty", minimum=1.0, maximum=1.5, value=1.15, step=0.05
134
  )
135
- clear_btn = gr.Button("🗑️ Clear chat", variant="secondary")
136
 
137
  # ── State & wiring ────────────────────────────────────────────────────────
138
 
139
  chat_state = gr.State([])
140
 
141
- def on_submit(message, history, system, max_tok, temp, top_p_val, rep_pen):
142
  if not message.strip():
143
- return history, history, ""
 
 
 
 
 
 
 
 
 
144
 
145
- response = chat(message, history, system, max_tok, temp, top_p_val, rep_pen)
146
 
147
- history = history + [
148
- {"role": "user", "content": message},
149
- {"role": "assistant", "content": response},
150
- ]
151
- return history, history, ""
152
 
153
- msg_box.submit(
154
- fn=on_submit,
155
- inputs=[msg_box, chat_state, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty],
156
  outputs=[chatbot, chat_state, msg_box],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  )
158
 
159
  clear_btn.click(
 
60
  return "".join(parts)
61
 
62
 
63
+ # ── Generate ────────────────────────────────────────────
64
 
65
+ from transformers import TextIteratorStreamer
66
+ from threading import Thread
67
+
68
+ def chat_stream(
69
  message: str,
70
  history: list,
71
  system_prompt: str,
 
73
  temperature: float,
74
  top_p: float,
75
  repetition_penalty: float,
76
+ ):
77
  if not message.strip():
78
+ return
79
 
80
  prompt = build_prompt(history, system_prompt, message)
81
  inputs = tokenizer(prompt, return_tensors="pt")
82
 
83
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
84
+
85
+ generation_kwargs = dict(
86
+ **inputs,
87
+ streamer=streamer,
88
+ max_new_tokens=int(max_new_tokens),
89
+ do_sample=temperature > 0.01,
90
+ temperature=float(temperature),
91
+ top_p=float(top_p),
92
+ top_k=50,
93
+ repetition_penalty=float(repetition_penalty),
94
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
95
+ eos_token_id=tokenizer.eos_token_id,
96
+ )
97
 
98
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
99
+ thread.start()
100
+
101
+ partial_text = ""
102
+ for new_text in streamer:
103
+ partial_text += new_text
104
+ yield partial_text
105
 
106
 
107
  # ── UI ────────────────────────────────────────────────────────────────────────
 
109
  with gr.Blocks(title="Supra-50M Instruct") as demo:
110
  gr.Markdown(
111
  "# 🦅 Supra-50M Instruct\n"
112
+ "50M-parameter chat model by [SupraLabs](https://huggingface.co/SupraLabs), running on CPU."
113
  )
114
 
115
  with gr.Row():
 
120
  show_label=False,
121
  lines=1,
122
  max_lines=4,
 
123
  )
124
+ with gr.Row():
125
+ submit_btn = gr.Button("🚀 Send", variant="primary")
126
+ stop_btn = gr.Button("🛑 Stop", variant="stop")
127
+ clear_btn = gr.Button("🗑️ Clear chat", variant="secondary")
128
 
129
  with gr.Column(scale=1, min_width=220):
130
  gr.Markdown("### ⚙️ Parameters")
 
145
  repetition_penalty = gr.Slider(
146
  label="Repetition penalty", minimum=1.0, maximum=1.5, value=1.15, step=0.05
147
  )
 
148
 
149
  # ── State & wiring ────────────────────────────────────────────────────────
150
 
151
  chat_state = gr.State([])
152
 
153
+ def user_step(message, history):
154
  if not message.strip():
155
+ return gr.update(), history, ""
156
+ new_history = history + [{"role": "user", "content": message}]
157
+ return new_history, new_history, ""
158
+
159
+ def bot_step(history, system, max_tok, temp, top_p_val, rep_pen):
160
+ if not history:
161
+ return history, history
162
+
163
+ user_message = history[-1]["content"]
164
+ history_before = history[:-1]
165
 
166
+ history = history + [{"role": "assistant", "content": ""}]
167
 
168
+ for response_chunk in chat_stream(user_message, history_before, system, max_tok, temp, top_p_val, rep_pen):
169
+ history[-1]["content"] = response_chunk
170
+ yield history, history
 
 
171
 
172
+ submit_event = msg_box.submit(
173
+ fn=user_step,
174
+ inputs=[msg_box, chat_state],
175
  outputs=[chatbot, chat_state, msg_box],
176
+ queue=False
177
+ ).then(
178
+ fn=bot_step,
179
+ inputs=[chat_state, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty],
180
+ outputs=[chatbot, chat_state]
181
+ )
182
+
183
+ click_event = submit_btn.click(
184
+ fn=user_step,
185
+ inputs=[msg_box, chat_state],
186
+ outputs=[chatbot, chat_state, msg_box],
187
+ queue=False
188
+ ).then(
189
+ fn=bot_step,
190
+ inputs=[chat_state, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty],
191
+ outputs=[chatbot, chat_state]
192
+ )
193
+
194
+ stop_btn.click(
195
+ fn=None,
196
+ inputs=None,
197
+ outputs=None,
198
+ cancels=[submit_event, click_event]
199
  )
200
 
201
  clear_btn.click(