Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -231,10 +231,13 @@ chatbot = gr.Chatbot(
|
|
| 231 |
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
|
| 232 |
# Ensure the tokenizer is accessible within the function scope
|
| 233 |
global tokenizer
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
# Perform model inference
|
| 236 |
response = model_inference(
|
| 237 |
-
user_prompt=
|
| 238 |
chat_history=history,
|
| 239 |
web_search=web_search,
|
| 240 |
temperature=temperature,
|
|
@@ -244,18 +247,18 @@ def chat_interface(user_input, history, web_search, decoding_strategy, temperatu
|
|
| 244 |
tokenizer=tokenizer # Pass tokenizer to the model_inference function
|
| 245 |
)
|
| 246 |
|
| 247 |
-
# Update
|
| 248 |
-
history.append(
|
| 249 |
|
| 250 |
-
# Return the
|
| 251 |
-
return
|
| 252 |
|
| 253 |
# Define the Gradio interface components
|
| 254 |
interface = gr.Interface(
|
| 255 |
fn=chat_interface,
|
| 256 |
inputs=[
|
| 257 |
gr.Textbox(label="User Input", placeholder="Type your message here..."),
|
| 258 |
-
gr.State([]), #
|
| 259 |
gr.Checkbox(label="Perform Web Search"),
|
| 260 |
gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
|
| 261 |
gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
|
|
@@ -263,10 +266,11 @@ interface = gr.Interface(
|
|
| 263 |
gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1),
|
| 264 |
gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9)
|
| 265 |
],
|
| 266 |
-
outputs=[
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
| 270 |
)
|
| 271 |
|
| 272 |
# Launch the Gradio interface
|
|
|
|
| 231 |
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    """Handle one chat turn: run model inference and append the exchange to history.

    Parameters mirror the Gradio inputs defined below:
        user_input: raw text typed by the user.
        history: list of (user, assistant) tuples held in gr.State.
        web_search: bool, whether model_inference should perform a web search.
        decoding_strategy / temperature / max_new_tokens / repetition_penalty / top_p:
            generation settings forwarded to model_inference.

    Returns:
        (response, history): the assistant's reply and the updated history list.
    """
    # Ensure the tokenizer is accessible within the function scope.
    # NOTE(review): `global` is only required when assigning; kept for parity
    # with the original, but a plain read of the module-level `tokenizer`
    # would work without it.
    global tokenizer

    # model_inference expects the prompt wrapped as a dict with text + files.
    user_prompt = {"text": user_input, "files": []}

    # Perform model inference.
    # NOTE(review): the middle kwargs below were elided between diff hunks in
    # the source view; they are reconstructed from the function's own
    # parameters — confirm against the full app.py.
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        decoding_strategy=decoding_strategy,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # Pass tokenizer to the model_inference function
    )

    # Update history with the user input and model response.
    history.append((user_input, response))

    # Return the response and updated history.
    return response, history
|
| 255 |
|
| 256 |
# Define the Gradio interface components.
# Inputs are ordered to match chat_interface's signature exactly:
# (user_input, history, web_search, decoding_strategy, temperature,
#  max_new_tokens, repetition_penalty, top_p).
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # Initialize the chat history as an empty list
        gr.Checkbox(label="Perform Web Search"),
        gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
        # NOTE(review): one input line was elided between diff hunks here; it
        # must supply max_new_tokens — reconstructed as a slider, confirm the
        # exact bounds/default against the full app.py.
        gr.Slider(minimum=8, maximum=1024, step=1, label="Max new tokens", value=512),
        gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9),
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([]),  # Update the chat history
    ],
    live=True,
)
|
| 275 |
|
| 276 |
# Launch the Gradio interface
|