Zpwang-AI's picture
Rename app.py to app_v1.py
9f289bd
import os
import torch
import torch.nn as nn
import gradio as gr
import yaml
# import lightning
from transformers import AutoTokenizer
from lightning.pytorch.utilities.deepspeed import get_fp32_state_dict_from_zero_checkpoint
# from utils import *
from model.model_v2 import Modelv2
def main():
# cuda_id = get_free_gpu()
# device = torch.device(f'cuda:{cuda_id}')
device = torch.device('cpu')
config_file = './hparams.yaml'
ckpt_file = './epoch1-f1score0.56.ckpt/'
with open(config_file, 'r', encoding='utf-8')as file:
config_dic = yaml.load(file, Loader=yaml.FullLoader)
model = Modelv2(
model_name=config_dic['model_name'],
share_encoder=config_dic['share_encoder'],
)
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_file)
new_state_dict = {k.replace('_forward_module.', ''):v for k,v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(config_dic['model_name'], cache_dir=config_dic['pretrained_model_fold'])
def interface_fn(sentence):
with torch.no_grad():
x_input = tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')
x_input = x_input.to(device)
res = model.predict(x_input)[0]
# return res
res = bool(res.cpu().numpy())
# res = sentence+'hello'
if res:
return '存在谩骂类情感'
else:
return '不存在谩骂类情感'
# return str(res)
demo = gr.Interface(
fn=interface_fn,
inputs="text",
outputs="text",
examples=['Hello world', 'Nice to meet you', 'Fuck you', 'Son of bitch']
)
demo.launch(share=False)
if __name__ == '__main__':
main()