anthonyfang commited on
Commit
387be65
·
1 Parent(s): 8e123e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ from text_generation import Client
5
+ from conversation import get_default_conv_template
6
+ from transformers import AutoTokenizer
7
+ from pymongo import MongoClient
8
+
9
# MongoDB (Atlas) connection settings, read from the environment.
# NOTE(review): the default for MONGO_DBNAME, "facebook/blenderbot-400M-distill",
# looks like a copy-pasted model id — it contains "/", which is invalid both as
# a Mongo database name and as the cluster-host fragment interpolated below.
# Confirm the intended default.
DB_NAME = os.getenv("MONGO_DBNAME", "facebook/blenderbot-400M-distill")
USER = os.getenv("MONGO_USER")
PASSWORD = os.getenv("MONGO_PASSWORD")

# Atlas SRV connection string.
# NOTE(review): USER/PASSWORD are interpolated without URL-escaping
# (urllib.parse.quote_plus); credentials containing reserved characters
# such as '@', ':' or '/' will produce a broken URI.
uri = f"mongodb+srv://{USER}:{PASSWORD}@{DB_NAME}.kvwjiok.mongodb.net/?retryWrites=true&w=majority"
mongo_client = MongoClient(uri)
db = mongo_client[DB_NAME]
conversations_collection = db['conversations']  # one document appended per completed bot turn
17
+
18
# Markdown blurb rendered at the top of the demo page (via gr.Markdown below).
DESCRIPTION = """
# Language Models for Taiwanese Culture
<p align="center">
✍️ <a href="https://huggingface.co/spaces/yentinglin/Taiwan-LLaMa2" target="_blank">Online Demo</a>

🤗 <a href="https://huggingface.co/yentinglin" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/yentinglin56" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/pdf/2305.13711.pdf" target="_blank">[Paper Coming Soon]</a>
• 👨️ <a href="https://github.com/MiuLab/Taiwan-LLaMa/tree/main" target="_blank">Github Repo</a>
<br/><br/>
<img src="https://www.csie.ntu.edu.tw/~miulab/taiwan-llama/logo-v2.png" width="100"> <br/>
</p>
Taiwan-LLaMa is a fine-tuned model specifically designed for traditional mandarin applications. It is built upon the LLaMa 2 architecture and includes a pretraining phase with over 5 billion tokens and fine-tuning with over 490k multi-turn conversational data in Traditional Mandarin.
## Key Features
1. **Traditional Mandarin Support**: The model is fine-tuned to understand and generate text in Traditional Mandarin, making it suitable for Taiwanese culture and related applications.
2. **Instruction-Tuned**: Further fine-tuned on conversational data to offer context-aware and instruction-following responses.
3. **Performance on Vicuna Benchmark**: Taiwan-LLaMa's relative performance on Vicuna Benchmark is measured against models like GPT-4 and ChatGPT. It's particularly optimized for Taiwanese culture.
4. **Flexible Customization**: Advanced options for controlling the model's behavior like system prompt, temperature, top-p, and top-k are available in the demo.
## Model Versions
Different versions of Taiwan-LLaMa are available:
- **Taiwan-LLaMa v1.0 (This demo)**: Optimized for Taiwanese Culture
- **Taiwan-LLaMa v0.9**: Partial instruction set
- **Taiwan-LLaMa v0.0**: No Traditional Mandarin pretraining
The models can be accessed from the provided links in the Hugging Face repository.
Try out the demo to interact with Taiwan-LLaMa and experience its capabilities in handling Traditional Mandarin!
"""
42
+
43
# License / acknowledgement markdown rendered at the bottom of the page.
LICENSE = """
## Licenses
- Code is licensed under Apache 2.0 License.
- Models are licensed under the LLAMA 2 Community License.
- By using this model, you agree to the terms and conditions specified in the license.
- By using this demo, you agree to share your input utterances with us to improve the model.
## Acknowledgements
Taiwan-LLaMa project acknowledges the efforts of the [Meta LLaMa team](https://github.com/facebookresearch/llama) and [Vicuna team](https://github.com/lm-sys/FastChat) in democratizing large language models.
"""
52
+
53
# Default system prompt prepended to every conversation (editable in the
# "Advanced options" accordion of the UI).
DEFAULT_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose."

# text-generation-inference endpoint that actually serves the model.
endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
client = Client(endpoint_url, timeout=120)
eos_token = "</s>"  # NOTE(review): defined but never referenced in this file
MAX_MAX_NEW_TOKENS = 1024      # slider upper bound
DEFAULT_MAX_NEW_TOKENS = 1024  # slider initial value

# Prompt budget: 4096-token context window minus generation headroom and slack.
max_prompt_length = 4096 - MAX_MAX_NEW_TOKENS - 10

model_name = "yentinglin/Taiwan-LLaMa-v1.0"
# Tokenizer is used only to count and left-truncate prompt tokens; generation
# itself happens on the remote TGI endpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
65
+
66
# ---- UI layout -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot()
    with gr.Row():
        # Message entry box plus submit button share one row.
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
        )
        submit_button = gr.Button('Submit',
                                  variant='primary',
                                  scale=1,
                                  min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the user message removed by Retry/Undo so it can be re-displayed.
    saved_input = gr.State()

    # Generation knobs; values are passed straight through to the TGI client.
    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )
122
def user(user_message, history):
    """Accept a submitted message: clear the textbox and extend the history.

    Returns a pair for the (msg, chatbot) outputs: the new textbox value
    ("" — cleared) and a new history list ending in a ``[message, None]``
    slot that ``bot`` will stream the reply into.
    """
    extended = [*history, [user_message, None]]
    return "", extended
124
+
125
+
126
def bot(history, max_new_tokens, temperature, top_p, top_k, system_prompt):
    """Stream the assistant reply for the last user turn in *history*.

    Builds a Vicuna-style prompt from the whole chat history plus the system
    prompt, streams tokens from the text-generation-inference endpoint into
    ``history[-1][1]`` (yielding the updated history after each token so the
    Chatbot widget live-updates), then logs the finished conversation to
    MongoDB.

    Generator: yields the mutated *history* list repeatedly.
    """
    conv = get_default_conv_template("vicuna").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human -> USER, gpt -> ASSISTANT
    conv.system = system_prompt
    # FIX: loop variables renamed from `user`/`bot` — the originals shadowed
    # the sibling callback functions of the same names, and the local `msg`
    # shadowed the Textbox component.
    for user_turn, bot_turn in history:
        conv.append_message(roles["human"], user_turn)
        # On the final turn bot_turn is None (the pending reply slot), which
        # the conversation template renders as an open assistant prompt.
        conv.append_message(roles["gpt"], bot_turn)
    prompt = conv.get_prompt()

    # Left-truncate so the prompt fits the context budget.
    # FIX: original sliced [-max_prompt_length + 1:], keeping one token fewer
    # than the budget allows.
    prompt_tokens = tokenizer.encode(prompt)
    if len(prompt_tokens) > max_prompt_length:
        prompt = tokenizer.decode(prompt_tokens[-max_prompt_length:])

    history[-1][1] = ""
    for response in client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    ):
        if not response.token.special:
            history[-1][1] += response.token.text
            yield history

    # Best-effort logging: a database hiccup must not surface as a chat error
    # after the answer has already been streamed to the user.
    conversation_document = {
        "model_name": model_name,
        "history": history,
        "system_prompt": system_prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }
    try:
        conversations_collection.insert_one(conversation_document)
    except Exception as exc:  # noqa: BLE001 — logging is non-critical
        print(f"Failed to store conversation in MongoDB: {exc}")
163
+
164
# Wire both Enter-in-textbox and the Submit button to the same two-step
# pipeline: `user` echoes the message into the chat (unqueued, instant),
# then `bot` streams the model reply into the last history slot.
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
    fn=bot,
    inputs=[
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        system_prompt,
    ],
    outputs=chatbot
)
submit_button.click(
    user, [msg, chatbot], [msg, chatbot], queue=False
).then(
    fn=bot,
    inputs=[
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        system_prompt,
    ],
    outputs=chatbot
)
190
+
191
+
192
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent (user, bot) exchange from *history* in place.

    Returns the shortened history together with the removed user message
    ('' when the history was already empty or the message was falsy) so the
    caller can stash it in ``saved_input`` for Retry/Undo.
    """
    if not history:
        return history, ''
    last_user_msg, _discarded_reply = history.pop()
    return history, last_user_msg or ''
199
+
200
+
201
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Re-show *message* in the chat by appending it with a blank reply slot.

    Mutates *history* in place and returns it for the gradio output binding.
    """
    history += [(message, '')]
    return history
205
+
206
# Retry: drop the last exchange, re-display its user message, regenerate.
retry_button.click(
    fn=delete_prev_fn,
    inputs=chatbot,
    outputs=[chatbot, saved_input],
    api_name=False,
    queue=False,
).then(
    fn=display_input,
    inputs=[saved_input, chatbot],
    outputs=chatbot,
    api_name=False,
    queue=False,
).then(
    fn=bot,
    inputs=[
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        system_prompt,
    ],
    outputs=chatbot,
)

# Undo: drop the last exchange and put its user message back in the textbox.
undo_button.click(
    fn=delete_prev_fn,
    inputs=chatbot,
    outputs=[chatbot, saved_input],
    api_name=False,
    queue=False,
).then(
    fn=lambda x: x,  # copy saved_input back into the textbox
    inputs=[saved_input],
    outputs=msg,
    api_name=False,
    queue=False,
)

# Clear: wipe the chat display entirely.
clear.click(lambda: None, None, chatbot, queue=False)

gr.Markdown(LICENSE)

# Queue enables streaming generators; then start the app.
demo.queue(concurrency_count=4, max_size=128)
demo.launch()