Spaces:
Build error
Build error
| import os | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| import yaml | |
| # import lightning | |
| from transformers import AutoTokenizer | |
| from lightning.pytorch.utilities.deepspeed import get_fp32_state_dict_from_zero_checkpoint | |
| # from utils import * | |
| from model.model_v2 import Modelv2 | |
| def main(): | |
| # cuda_id = get_free_gpu() | |
| # device = torch.device(f'cuda:{cuda_id}') | |
| device = torch.device('cpu') | |
| config_file = './hparams.yaml' | |
| ckpt_file = './epoch1-f1score0.56.ckpt/' | |
| with open(config_file, 'r', encoding='utf-8')as file: | |
| config_dic = yaml.load(file, Loader=yaml.FullLoader) | |
| model = Modelv2( | |
| model_name=config_dic['model_name'], | |
| share_encoder=config_dic['share_encoder'], | |
| ) | |
| state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_file) | |
| new_state_dict = {k.replace('_forward_module.', ''):v for k,v in state_dict.items()} | |
| model.load_state_dict(new_state_dict) | |
| model.to(device) | |
| model.eval() | |
| tokenizer = AutoTokenizer.from_pretrained(config_dic['model_name'], cache_dir=config_dic['pretrained_model_fold']) | |
| def interface_fn(sentence): | |
| with torch.no_grad(): | |
| x_input = tokenizer([sentence], padding=True, truncation=True, return_tensors='pt') | |
| x_input = x_input.to(device) | |
| res = model.predict(x_input)[0] | |
| # return res | |
| res = bool(res.cpu().numpy()) | |
| # res = sentence+'hello' | |
| if res: | |
| return '存在谩骂类情感' | |
| else: | |
| return '不存在谩骂类情感' | |
| # return str(res) | |
| demo = gr.Interface( | |
| fn=interface_fn, | |
| inputs="text", | |
| outputs="text", | |
| examples=['Hello world', 'Nice to meet you', 'Fuck you', 'Son of bitch'] | |
| ) | |
| demo.launch(share=False) | |
| if __name__ == '__main__': | |
| main() |