gijl commited on
Commit
70a3062
·
verified ·
1 Parent(s): 16c08e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -169
app.py CHANGED
@@ -14,7 +14,7 @@ model = AutoModelForCausalLM.from_pretrained(model_name,
14
  pipe = pipeline("text-generation",
15
  model=model_name,
16
  tokenizer=tokenizer,
17
- max_new_tokens=150,
18
  temperature=0.7)
19
 
20
  def generate_response(message, history):
@@ -36,173 +36,10 @@ def generate_response(message, history):
36
  response = pipe(messages)
37
  return response[0][0]['generated_text'][2]['content']
38
 
39
- def generate_response_stream(message, history, temperature, top_p, top_k, max_new_tokens, repeat_penalty):
40
- chat_messages = [{"role": "system", "content": "Você é ELIZA, uma terapeuta que responde com empatia e faz perguntas para entender melhor o paciente."}]
41
- for human, assistant in history:
42
- chat_messages.append({"role": "user", "content": human})
43
- if assistant is not None:
44
- chat_messages.append({"role": "assistant", "content": assistant})
45
- chat_messages.append({"role": "user", "content": message})
46
- input_ids = tokenizer.apply_chat_template(chat_messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
47
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
48
- generation_kwargs = dict(
49
- input_ids=input_ids,
50
- streamer=streamer,
51
- max_new_tokens=int(max_new_tokens),
52
- temperature=float(temperature),
53
- top_p=float(top_p),
54
- top_k=int(top_k),
55
- repetition_penalty=float(repeat_penalty),
56
- do_sample=True,
57
- )
58
- thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
59
- thread.start()
60
- partial_text = ""
61
- for new_text in streamer:
62
- partial_text += new_text
63
- yield partial_text
64
-
65
- css = """
66
- body { background-color: #1a1a2e; }
67
- .sidebar { background-color: #16213e; border-right: 1px solid #0f3460; padding: 12px; border-radius: 8px; }
68
- .sidebar label { color: #e0e0e0 !important; font-size: 13px !important; }
69
- .sidebar .gr-slider { accent-color: #e94560; }
70
- .chat-area { background-color: #0f3460; border-radius: 8px; }
71
- .title-bar { color: #e94560; font-family: monospace; }
72
- .model-info { background-color: #0d1117; border: 1px solid #30363d; border-radius: 6px; padding: 8px; color: #58a6ff; font-family: monospace; font-size: 12px; }
73
- .gr-button-primary { background-color: #e94560 !important; border: none !important; }
74
- .gr-button { background-color: #16213e !important; color: #e0e0e0 !important; border: 1px solid #0f3460 !important; }
75
- footer { display: none !important; }
76
- """
77
-
78
- with gr.Blocks(css=css, title="Brain map — llama.cpp style") as demo:
79
-
80
- gr.Markdown(
81
- """
82
- <div class='title-bar'>
83
- <h2>🧠 Brain map &nbsp;|&nbsp; Distinguished Medical Assistant</h2>
84
- <p style='color:#8b949e;font-size:13px;font-family:monospace;'>Task execution • Organize a clear explanation • Streaming enabled ⚡</p>
85
- </div>
86
- """
87
- )
88
-
89
- with gr.Row(equal_height=True):
90
-
91
- with gr.Column(scale=1, min_width=260, elem_classes="sidebar"):
92
- gr.Markdown("### ⚙️ Inference Parameters")
93
- temperature_slider = gr.Slider(
94
- minimum=0.0, maximum=2.0, value=0.7, step=0.01,
95
- label="Temperature",
96
- info="Controls randomness. Lower = more deterministic."
97
- )
98
- top_p_slider = gr.Slider(
99
- minimum=0.0, maximum=1.0, value=0.95, step=0.01,
100
- label="Top-P (nucleus sampling)",
101
- info="Cumulative probability cutoff."
102
- )
103
- top_k_slider = gr.Slider(
104
- minimum=0, maximum=200, value=40, step=1,
105
- label="Top-K",
106
- info="Limits token candidates at each step."
107
- )
108
- max_tokens_slider = gr.Slider(
109
- minimum=1, maximum=2048, value=150, step=1,
110
- label="Max New Tokens",
111
- info="Maximum number of tokens to generate."
112
- )
113
- repeat_penalty_slider = gr.Slider(
114
- minimum=1.0, maximum=2.0, value=1.1, step=0.01,
115
- label="Repeat Penalty",
116
- info="Penalizes repeated tokens."
117
- )
118
- gr.Markdown("---")
119
- gr.Markdown("### 🤖 Model Info")
120
- gr.Textbox(
121
- value=model_name,
122
- label="Loaded Model",
123
- interactive=False,
124
- elem_classes="model-info"
125
- )
126
- gr.Textbox(
127
- value="float16 · auto device map",
128
- label="Precision / Device",
129
- interactive=False,
130
- elem_classes="model-info"
131
- )
132
- gr.Markdown("---")
133
- gr.Markdown("### 📋 Session")
134
- clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
135
- stop_btn = gr.Button("⏹️ Stop Generation", variant="stop")
136
-
137
- with gr.Column(scale=4, elem_classes="chat-area"):
138
- chatbot = gr.Chatbot(
139
- label="Brain map Chat",
140
- height=520,
141
- show_label=True,
142
- avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg"),
143
- bubble_full_width=False,
144
- )
145
- with gr.Row():
146
- msg_box = gr.Textbox(
147
- placeholder="Type your message and press Enter or click Send …",
148
- label="",
149
- lines=2,
150
- scale=5,
151
- show_label=False,
152
- )
153
- send_btn = gr.Button("➤ Send", variant="primary", scale=1, min_width=90)
154
-
155
- gr.Markdown(
156
- "<p style='color:#555;font-size:11px;font-family:monospace;text-align:right;'>"
157
- "⚡ Streaming • 🔒 Local inference • Brain map v1.0"
158
- "</p>"
159
- )
160
-
161
- def user_message_submitted(message, history):
162
- return "", history + [[message, None]]
163
-
164
- def bot_streaming_response(history, temperature, top_p, top_k, max_new_tokens, repeat_penalty):
165
- if not history or history[-1][0] is None:
166
- yield history
167
- return
168
- user_message = history[-1][0]
169
- history[-1][1] = ""
170
- for partial_output in generate_response_stream(
171
- user_message,
172
- history[:-1],
173
- temperature,
174
- top_p,
175
- top_k,
176
- max_new_tokens,
177
- repeat_penalty,
178
- ):
179
- history[-1][1] = partial_output
180
- yield history
181
-
182
- submit_event = msg_box.submit(
183
- fn=user_message_submitted,
184
- inputs=[msg_box, chatbot],
185
- outputs=[msg_box, chatbot],
186
- queue=False,
187
- ).then(
188
- fn=bot_streaming_response,
189
- inputs=[chatbot, temperature_slider, top_p_slider, top_k_slider, max_tokens_slider, repeat_penalty_slider],
190
- outputs=chatbot,
191
- )
192
-
193
- click_event = send_btn.click(
194
- fn=user_message_submitted,
195
- inputs=[msg_box, chatbot],
196
- outputs=[msg_box, chatbot],
197
- queue=False,
198
- ).then(
199
- fn=bot_streaming_response,
200
- inputs=[chatbot, temperature_slider, top_p_slider, top_k_slider, max_tokens_slider, repeat_penalty_slider],
201
- outputs=chatbot,
202
- )
203
-
204
- stop_btn.click(fn=None, cancels=[submit_event, click_event])
205
-
206
- clear_btn.click(fn=lambda: [], inputs=None, outputs=chatbot)
207
 
208
  demo.launch()
 
14
  pipe = pipeline("text-generation",
15
  model=model_name,
16
  tokenizer=tokenizer,
17
+ max_new_tokens=1500,
18
  temperature=0.7)
19
 
20
  def generate_response(message, history):
 
36
  response = pipe(messages)
37
  return response[0][0]['generated_text'][2]['content']
38
 
39
+ demo = gr.ChatInterface(
40
+ generate_response,
41
+ title="ELIZA (com LLM)",
42
+ description="Compartilhe seus pensamentos e ELIZA irá ajudar você a refletir sobre eles."
43
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  demo.launch()