# Source: Hugging Face Space app.py by D3V1L1810 (commit 09aaef7) —
# web-page chrome ("raw / history blame / 3.06 kB") stripped so the file parses.
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np
import json
import requests
import gradio as gr
import logging
# Load the fine-tuned multi-class tokenizer/model pair from local checkpoint
# directories; these paths are relative, so both folders must sit next to
# this script (standard HF Space layout). Runs at import time — the app
# fails fast here if the checkpoints are missing.
bert_tokenizer = BertTokenizer.from_pretrained('MultiTokenizer_ep10')
bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')
# def send_results_to_api(data, result_url):
# headers = {'Content-Type':'application/json'}
# response = requests.post(result_url, json = data, headers=headers)
# if response.status_code == 200:
# return response.json
# else:
# return {'error':f"failed to send result to API: {response.status_code}"}
def predict_text(params):
    """Classify each input text with the fine-tuned BERT model.

    Parameters
    ----------
    params : str
        JSON string with keys:
          - "urls": required list of texts to classify.
          - "normalfileID": optional list of file ids parallel to "urls";
            missing/empty means every text gets a ``None`` file id.

    Returns
    -------
    str or dict
        On success, a JSON string ``{"solutions": [...]}`` where each
        solution holds the text, predicted label name, ``qcUser`` (always
        None) and its file id. On bad input, a plain ``{"error": ...}``
        dict (gr.JSON renders both).
    """
    try:
        params = json.loads(params)
    # was bare `JSONDecodeError` (never imported -> NameError on bad input);
    # the qualified name makes the error path actually reachable
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}")
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}

    texts = params.get("urls", [])
    if not texts:
        # message now names the key the endpoint actually reads ("urls");
        # the original said 'texts', which no code path consumes
        return {"error": "Missing required parameters: 'urls'"}

    # empty or absent id list -> one None per text (same semantics as before)
    file_ids = params.get("normalfileID") or [None] * len(texts)

    # index -> class name; 7 is the explicit "no class" bucket.
    # Hoisted out of the loop — it is constant per call.
    label = {0: 'BUSINESS', 1: 'COMEDY', 2: 'CRIME', 3: 'FOOD & DRINK',
             4: 'POLITICS', 5: 'SPORTS', 6: 'TRAVEL', 7: 'None'}

    solutions = []
    for text, file_id in zip(texts, file_ids):
        encoding = bert_tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=128,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='tf'
        )
        # Pass named inputs: the original positional list
        # [input_ids, token_type_ids, attention_mask] silently swapped
        # token_type_ids and attention_mask — the HF TF call order is
        # (input_ids, attention_mask, token_type_ids).
        pred = bert_model.predict({
            'input_ids': encoding['input_ids'],
            'token_type_ids': encoding['token_type_ids'],
            'attention_mask': encoding['attention_mask'],
        })
        pred_label = int(tf.argmax(pred.logits, axis=1).numpy()[0])
        logging.debug("logits=%s pred_label=%s", pred.logits, pred_label)
        # NOTE(review): the original had `if not pred_label: predict_label = 7`
        # — a dead assignment to a typo'd name (no-op). Dropped here. If the
        # intent was to remap class 0 to 'None', restore as `pred_label = 7`.
        solutions.append({
            'text': text,
            'answer': [label[pred_label]],
            "qcUser": None,
            "normalfileID": file_id,
        })
    return json.dumps({"solutions": solutions})
# --- Gradio wiring ------------------------------------------------------
# The example in the label now shows the key predict_text actually reads
# ("urls"); the original said 'texts' and was missing the closing brace,
# so every copy-pasted example failed with the misleading 'texts' error.
inputt = gr.Textbox(label='Parameters in Json Format... Eg. {"urls": ["text1", "text2"]}')
outputt = gr.JSON()
application = gr.Interface(fn=predict_text, inputs=inputt, outputs=outputt,
                           title='Multi Text Classification with API Integration..')
application.launch()