# Spaces:
# Runtime error
# Runtime error
import os

import spacy
from accelerate import PartialState
from accelerate.utils import set_seed
from flask import Flask, request, jsonify

from gpt2_generation import Translator, generate_prompt, MODEL_CLASSES

# Route all outbound HTTP(S) traffic through the local proxy
# (needed to reach the Hugging Face hub from this machine).
for _proxy_var in ("http_proxy", "https_proxy"):
    os.environ[_proxy_var] = "http://127.0.0.1:7890"
app = Flask(__name__)

# Checkpoint directory of the fine-tuned GPT-2 model served by this app.
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"

# Decoding / runtime settings consumed by load_model_and_components()
# and generate_prompt().
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,
    "stop_token": None,
    "temperature": 1.0,
    "length_penalty": 1.2,
    "repetition_penalty": 1.2,
    "k": 3,
    "p": 0.9,
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 1,
    "fp16": False,
    "jit": False,
}

# Let accelerate pick the execution device (CPU when use_cpu is True).
distributed_state = PartialState(cpu=args["use_cpu"])

if args["seed"] is not None:
    set_seed(args["seed"])

# Lazily-populated globals, filled in by load_model_and_components().
tokenizer = model = zh_en_translator = nlp = None
def load_model_and_components():
    """Populate the module-level tokenizer, model, translator and spaCy pipeline.

    Reads configuration from the module-level ``args`` dict, moves the model
    to the device chosen by ``distributed_state`` and optionally casts it to
    fp16.

    Raises:
        KeyError: if ``args["model_type"]`` is not a key of ``MODEL_CLASSES``.
    """
    global tokenizer, model, zh_en_translator, nlp

    # Resolve the model/tokenizer classes for the configured model type.
    try:
        args["model_type"] = args["model_type"].lower()
        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    except KeyError:
        # Bug fix: the original message contained a bare "{}" placeholder
        # that was never filled in; format the offending model type into it.
        raise KeyError(
            "the model {} you specified is not supported. "
            "You are welcome to add it and open a PR :)".format(args["model_type"])
        )

    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
    # GPT-2 style tokenizers ship no pad/mask token; reuse EOS so that
    # left-padding works during generation.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # Chinese -> English translator, passed through to generate_prompt().
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # spaCy pipeline used as the content filter in generate_prompt().
    nlp = spacy.load('en_core_web_sm')
    print("Filter loaded!")

    # Set the model to the right device, then optionally halve precision.
    model.to(distributed_state.device)
    if args["fp16"]:
        model.half()
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a reply for a JSON request body of the form ``{"phrase": ...}``.

    Returns the JSON-serialized messages produced by ``generate_prompt``.

    NOTE(review): the original function carried no route decorator, so Flask
    never exposed it as an endpoint — the server answered nothing. The
    ``/chat`` path above is an assumption; confirm it against the client code.
    """
    payload = request.get_json(silent=True) or {}
    phrase = payload.get('phrase')
    if phrase is None:
        # Reject malformed requests explicitly instead of failing deep
        # inside generate_prompt with an opaque 500.
        return jsonify({"error": "missing 'phrase' in JSON body"}), 400

    # Lazy initialization: keeps the endpoint usable when the module is
    # served by a WSGI runner that never executes the __main__ block.
    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
        load_model_and_components()

    messages = generate_prompt(
        prompt_text=phrase,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )
    return jsonify(messages)
if __name__ == '__main__':
    # Eagerly warm up the model, translator and filter before serving.
    load_model_and_components()
    serve_host = '0.0.0.0'
    serve_port = 10008
    app.run(host=serve_host, port=serve_port, debug=False)