| from flask import Flask, request |
| import requests |
| import os |
| import re |
| import textwrap |
| from transformers import AutoModelForSeq2SeqLM |
| from transformers import AutoTokenizer |
| from langdetect import detect |
| import subprocess |
|
|
# English pipeline: the stock "facebook/bart-base" tokenizer paired with the
# "worked" revision of the GuysTrans/bart-base-finetuned-xsum weights.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# Vietnamese pipeline: tokenizer and weights both come from the same repo.
vn_tokenizer = AutoTokenizer.from_pretrained(
    "GuysTrans/bart-base-vn-ehealth-vn-tokenizer"
)

model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-finetuned-xsum",
    revision="worked",
)

vn_model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-vn-ehealth-vn-tokenizer",
    revision="worked",
)
|
|
# Phrase -> replacement pairs that post_process applies to model output, one
# after another in this insertion order. The long greeting is listed before
# the bare "Hello"/"Hi" so it is consumed first.
# NOTE(review): "H C M" maps to "Med Forum" but "Ask A Doctor" maps to
# "MedForum"; confirm which spelling is intended.
map_words = dict([
    ("Hello and Welcome to 'Ask A Doctor' service", ""),
    ("Hello,", ""),
    ("Hi,", ""),
    ("Hello", ""),
    ("Hi", ""),
    ("Ask A Doctor", "MedForum"),
    ("H C M", "Med Forum"),
])

# post_process drops every "."-separated sentence that contains one of these
# phrases (compared case-insensitively).
word_remove_sentence = [
    "Welcome to",
]
|
|
|
|
def generate_summary(question, model, tokenizer):
    """Run *question* through a seq2seq *model* and decode the result.

    The input is tokenized to a fixed length of 512 tokens: it is padded up
    to 512 and truncated if longer. Generation uses beam search (4 beams)
    with sampling enabled (``do_sample=True``, ``top_k=50``), so the output
    is not deterministic between calls.

    Returns:
        A tuple ``(outputs, output_str)``. ``outputs`` is the raw value
        returned by ``model.generate``. ``output_str`` is the list of
        strings from ``tokenizer.batch_decode`` with special tokens skipped.
    """
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model's weights.
    ids_on_device = encoded.input_ids.to(model.device)
    mask_on_device = encoded.attention_mask.to(model.device)

    generation_kwargs = dict(
        attention_mask=mask_on_device,
        max_new_tokens=4096,
        do_sample=True,
        num_beams=4,
        top_k=50,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )
    sequences = model.generate(ids_on_device, **generation_kwargs)
    decoded = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return sequences, decoded
|
|
|
|
app = Flask(__name__)

# Graph API Send endpoint. It can be overridden from the environment so the
# pinned v2.6 API version can be bumped without a code change.
FB_API_URL = os.environ.get(
    'FB_API_URL', 'https://graph.facebook.com/v2.6/me/messages')

# Token that verify_webhook compares against hub.verify_token. Prefer setting
# it via the environment. The literal fallback only keeps existing
# deployments working; it is committed to source control and should be
# rotated.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')

# Required: page access token sent with every Send API call. Startup fails
# with KeyError if it is unset.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
|
|
|
|
def get_bot_response(message):
    """Build the chatbot's reply to a user *message*.

    The language of *message* is detected with ``langdetect.detect``:
      - Text classified as Vietnamese ("vi") is answered with the Vietnamese
        model/tokenizer and the Vietnamese template.
      - Everything else uses the English model/tokenizer and template.
      - If langdetect cannot classify the text, the English pair is used.
        This happens when the text has no usable features, e.g. only digits,
        punctuation or emoji.

    Returns:
        The post-processed model output wrapped in a greeting/closing
        template.
    """
    # Function-scope import: only this block needs the exception type.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        lang = detect(message)
    except LangDetectException:
        lang = None

    if lang == "vi":
        model_use, tokenizer_use = vn_model, vn_tokenizer
        template = "Chào mừng bạn đến với dịch vụ MedForum chatbot. %s. Cảm ơn bạn đã sử dụng MedForum."
    else:
        model_use, tokenizer_use = model, tokenizer
        template = "Welcome to MedForum chatbot service. %s. Thanks for asking on MedForum."

    summary = post_process(
        generate_summary(message, model_use, tokenizer_use)[1][0])
    # The template already supplies the closing ".", so drop any trailing
    # period/space from the summary to avoid ".." in the reply.
    return template % summary.rstrip(" .")
|
|
|
|
def verify_webhook(req):
    """Answer the webhook verification GET request.

    If the ``hub.verify_token`` query parameter equals ``VERIFY_TOKEN``, the
    ``hub.challenge`` parameter is echoed back. If the challenge is absent,
    an empty string is returned rather than ``None``, which Flask cannot
    turn into a response.

    Returns:
        The challenge string on a token match, otherwise "incorrect".
    """
    import hmac  # function-scope: only needed for the comparison below

    token = req.args.get("hub.verify_token")
    # compare_digest does a constant-time comparison, so response timing
    # does not leak how much of the secret matched. It needs str/bytes, so a
    # missing parameter (None) is rejected first. Encoding to bytes avoids
    # compare_digest's TypeError on non-ASCII str input.
    if token is not None and hmac.compare_digest(token.encode(), VERIFY_TOKEN.encode()):
        return req.args.get("hub.challenge", "")
    return "incorrect"
|
|
|
|
def respond(sender, message):
    """Produce a reply to *message*, deliver it to *sender*, and return it.

    The reply text is generated by ``get_bot_response`` and sent with
    ``send_message``. The return value of ``send_message`` is discarded.
    """
    reply_text = get_bot_response(message)
    send_message(sender, reply_text)
    return reply_text
|
|
|
|
def is_user_message(message):
    """Tell whether a messaging event is a text message sent by the user.

    An event qualifies when it has a truthy ``message`` payload with
    non-empty ``text`` and that payload is not flagged ``is_echo`` (echoes
    are messages the page itself sent).

    Returns:
        True or False once both ``message`` and ``text`` are present and
        truthy. Otherwise the first falsy value encountered (e.g. ``None``,
        ``{}`` or ``""``). Callers use the result only for its truthiness.
    """
    body = message.get('message')
    if not body:
        return body
    text = body.get('text')
    if not text:
        return text
    return not body.get("is_echo")
|
|
|
|
@app.route("/webhook", methods=['GET', 'POST'])
def listen():
    """Main Messenger webhook handler mounted at ``/webhook``.

    GET: delegated to ``verify_webhook`` (verification handshake).

    POST: the JSON body is scanned across *all* ``entry`` items and their
    ``messaging`` lists. Each user text message (per ``is_user_message``) is
    answered through ``respond``. Always returns "ok" for POST.
    """
    if request.method == 'GET':
        return verify_webhook(request)

    # silent=True: a missing or malformed JSON body yields None rather than
    # raising, and is then treated as an empty payload.
    payload = request.get_json(silent=True) or {}
    # A delivery may batch several entries, and an entry may have no
    # "messaging" list. Iterate everything defensively instead of assuming
    # entry[0]['messaging'] exists.
    for entry in payload.get('entry', []):
        for event in entry.get('messaging', []):
            if not is_user_message(event):
                continue
            try:
                respond(event['sender']['id'], event['message']['text'])
            except Exception:
                # Request-level boundary: one failing reply (language
                # detection, generation, Send API) is logged and must not
                # abort the remaining events in this batch.
                app.logger.exception("Failed to answer messaging event")
    return "ok"
|
|
|
|
def send_message(recipient_id, text, timeout=10):
    """Send *text* to *recipient_id* through the Facebook Send API.

    Args:
        recipient_id: ID of the user to message, as received in the webhook
            event's ``sender.id``.
        text: Message body.
        timeout: Seconds to wait for the Graph API. Without a timeout a
            stalled connection would block the calling request handler
            indefinitely.

    Returns:
        The decoded JSON body of the Graph API response.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
    """
    payload = {
        # The Send API requires messaging_type. RESPONSE marks this as a
        # reply to a message the user sent.
        'messaging_type': 'RESPONSE',
        'message': {'text': text},
        'recipient': {'id': recipient_id},
        'notification_type': 'regular',
    }
    response = requests.post(
        FB_API_URL,
        params={'access_token': PAGE_ACCESS_TOKEN},
        json=payload,
        timeout=timeout,
    )
    return response.json()
|
|
|
|
@app.route("/webhook/chat", methods=['POST'])
def chat():
    """Direct JSON chat endpoint.

    Expects a body like ``{"message": "<text>"}`` and responds with
    ``{"message": <reply>}``. Returns HTTP 400 with an ``error`` field when
    the body is not a JSON object or ``message`` is not a non-empty string.
    Previously such requests crashed with a 500 (KeyError/TypeError, or
    langdetect failing on empty text).
    """
    payload = request.get_json(silent=True)
    message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(message, str) or not message.strip():
        return {"error": "expected a JSON body with a non-empty string 'message'"}, 400
    return {"message": get_bot_response(message)}
|
|
def post_process(output, remove_words=None, replacements=None, width=120):
    """Clean up raw model output before it is shown to the user.

    Steps:
      1. Split *output* on "." and drop every sentence that contains any of
         *remove_words* (case-insensitive substring match).
      2. Apply *replacements* in iteration order. Each key is a literal
         phrase matched case-insensitively, only where it is not part of a
         longer word, and every occurrence is replaced by its value.
      3. Dedent, strip, and wrap the result to *width* columns.

    Args:
        output: Raw decoded text from the model.
        remove_words: Phrases whose sentences are dropped. Defaults to the
            module-level ``word_remove_sentence``.
        replacements: Mapping of phrase -> replacement. Defaults to the
            module-level ``map_words``.
        width: Line width passed to ``textwrap.fill``.

    Returns:
        The cleaned, wrapped text.
    """
    if remove_words is None:
        remove_words = word_remove_sentence
    if replacements is None:
        replacements = map_words

    needles = [word.lower() for word in remove_words]
    # Build a new list rather than calling list.remove() while iterating the
    # same list. That skipped the sentence right after each removed one.
    kept = [
        sentence for sentence in output.split(".")
        if not any(needle in sentence.lower() for needle in needles)
    ]
    output = ".".join(kept)

    for phrase, replacement in replacements.items():
        # re.escape: keys are literal phrases, not regex patterns.
        # Lookarounds keep short keys such as "Hi" from matching inside
        # "His" or "this" once matching is case-insensitive.
        pattern = r"(?<!\w)" + re.escape(phrase) + r"(?!\w)"
        # flags= must be passed by keyword; re.sub's 4th positional argument
        # is `count` (re.I there meant "at most 2 replacements, case-sensitive").
        # The lambda inserts the replacement verbatim (no backslash escapes).
        output = re.sub(pattern, lambda _match: replacement, output,
                        flags=re.IGNORECASE)

    return textwrap.fill(textwrap.dedent(output).strip(), width=width)
| |
|
|
|
|
# Publish the local server on port 7860 through a serveo.net reverse SSH
# tunnel (remote alias "guysmedchatt", port 80). The process is started as a
# side effect of importing this module, and its handle is not kept or waited
# on.
# NOTE(review): "StrictHostKeyChecking=no" disables SSH host-key
# verification, and "id_rsa" is resolved relative to the current working
# directory. Confirm both are acceptable for this deployment.
_TUNNEL_COMMAND = [
    "autossh", "-M", "0", "-tt",
    "-o", "StrictHostKeyChecking=no",
    "-i", "id_rsa",
    "-R", "guysmedchatt:80:localhost:7860",
    "serveo.net",
]
subprocess.Popen(_TUNNEL_COMMAND)
| |
|
|