| from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
| import torch |
| import gradio as gr |
|
|
# Pick GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Path to a locally downloaded BlenderBot-1B-distill checkpoint
# (directory must contain the tokenizer + model files).
model_name = "./blenderbot-1B-distill"
# Load tokenizer and seq2seq model once at import time; these module-level
# names are used by get_reply() below.
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
model = BlenderbotForConditionalGeneration.from_pretrained(model_name)
# Move weights to the selected device (in-place; returned reference unused).
model.to(device)
|
|
def get_reply(response, username=None, histories=None):
    """Generate a BlenderBot reply for one user and render the transcript.

    Parameters
    ----------
    response : str
        The user's latest message.
    username : str or None
        Key identifying the per-user conversation; an empty or missing
        username is rejected with a prompt instead of a reply.
    histories : dict or None
        Gradio "state": mapping of username -> list of alternating
        user/bot messages. A fresh dict is created when None.

    Returns
    -------
    tuple[str, dict]
        (HTML transcript fragment, updated histories state).
    """
    # Original used a mutable default ({}), which is shared across calls;
    # also accept None, which Gradio may pass for fresh state.
    if histories is None:
        histories = {}
    # Covers both None and "" (original compared with `== None`).
    if not username:
        return "<div class='chatbot'>Enter a username</div>", histories

    history = histories.get(username, [])
    history.append(response)

    # Greetings and farewells reset this user's conversation.
    if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!",
                          "Hello", "Hi", "hello")):
        histories[username] = []
        return "<div class='chatbot'>Chatbot restarted</div>", histories

    # Keep only the last 4 turns so the prompt stays within model limits.
    if len(history) > 4:
        history = history[-4:]
    inputs = tokenizer(" ".join(history), return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    # skip_special_tokens is robust to however many BOS/EOS tokens the model
    # emits; the original sliced [1:-1], which assumes exactly one of each.
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    history.append(reply)

    # Render alternating user/bot bubbles; join once instead of quadratic +=.
    parts = ["<div class='chatbot'>"]
    for i, msg in enumerate(history):
        cls = "user" if i % 2 == 0 else "bot"
        parts.append("<div class='msg {}'> {}</div>".format(cls, msg))
    parts.append("</div>")
    histories[username] = history
    return "".join(parts), histories
|
|
# Styling for the HTML transcript emitted by get_reply().
# FIX: get_reply() wraps messages in <div class='chatbot'>, but the original
# stylesheet targeted '.chatbox', so the flex-column container rule never
# applied. The selector now matches the generated markup.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white}
.msg.bot {background-color:lightgray;align-self:self-end}
.footer {display:none !important}
"""
|
|
# Wire get_reply into a Gradio UI and start the (blocking) web server.
# Inputs map positionally onto get_reply(response, username, histories);
# the "state" input/output pair round-trips the histories dict per session.
# NOTE(review): gr.inputs.* and enable_queue= are the legacy pre-3.x Gradio
# API — confirm the pinned gradio version before upgrading this call.
gr.Interface(fn=get_reply,
             theme="default",
             inputs=[gr.inputs.Textbox(placeholder="How are you?"),
                     gr.inputs.Textbox(label="Username"),
                     "state"],
             outputs=["html", "state"],
             css=css).launch(debug=True, enable_queue=True)