D3V1L1810's picture
Upload 2 files
3ed9870 verified
raw
history blame
2.75 kB
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np
import json
import requests
import gradio as gr
import logging
# Load the tokenizer and the sequence-classification model once, at import
# time; predict_text() below reads these module-level names on every call.
# Both identifiers are handed straight to `from_pretrained` -- presumably they
# are local directories saved after fine-tuning (TODO confirm they ship with
# the app).
bert_tokenizer = BertTokenizer.from_pretrained('MultiTokenizer_ep10')
bert_model = TFBertForSequenceClassification.from_pretrained('MultiModel_ep10')
# def send_results_to_api(data, result_url):
# headers = {'Content-Type':'application/json'}
# response = requests.post(result_url, json = data, headers=headers)
# if response.status_code == 200:
# return response.json
# else:
# return {'error':f"failed to send result to API: {response.status_code}"}
# Class index -> category name. Presumably this mirrors the label encoding
# used when the model was trained (TODO confirm). Built once instead of on
# every loop iteration.
_LABELS = {
    0: 'BUSINESS',
    1: 'COMEDY',
    2: 'CRIME',
    3: 'FOOD & DRINK',
    4: 'POLITICS',
    5: 'SPORTS',
    6: 'TRAVEL',
}


def _classify_text(text):
    """Return the predicted category name for a single piece of text."""
    encoding = bert_tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=True,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf',
    )
    # NOTE(review): the tensors are passed positionally in the order
    # [input_ids, token_type_ids, attention_mask]. This order is kept exactly
    # as it was so predictions do not change; verify it matches both the
    # order used during training and the list order the model expects.
    pred = bert_model.predict([
        encoding['input_ids'],
        encoding['token_type_ids'],
        encoding['attention_mask'],
    ])
    # One text per call, so take the argmax of the first (only) row.
    pred_label = int(tf.argmax(pred.logits, axis=1).numpy()[0])
    return _LABELS[pred_label]


def predict_text(params):
    """Classify every text listed in a JSON request string.

    Args:
        params: JSON string encoding an object with a ``"texts"`` key, e.g.
            ``{"texts": ["text1", "text2"]}``. A single string value for
            ``"texts"`` is treated as a one-element list.

    Returns:
        On success, a JSON string of the form
        ``{"solutions": [{"text": ..., "label": [<category>]}, ...]}``.
        On invalid input, a dict ``{"error": <message>}``.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        # `JSONDecodeError` lives in the `json` module; referring to it
        # unqualified raised NameError instead of reporting the bad input.
        message = f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"
        logging.error(message)
        return {"error": message}
    except TypeError:
        # json.loads rejects anything that is not str/bytes/bytearray.
        return {"error": "Invalid input: parameters must be a JSON string"}

    # Valid JSON that is not an object (list, number, ...) has no `.get`.
    if not isinstance(params, dict):
        return {"error": "Invalid input: expected a JSON object with a 'texts' key"}

    texts = params.get("texts", [])
    if not texts:
        return {"error": "Missing required parameters: 'texts'"}
    if isinstance(texts, str):
        # Avoid iterating a bare string character by character.
        texts = [texts]
    elif not isinstance(texts, list):
        return {"error": "Invalid input: 'texts' must be a list of strings"}

    # Forwarding results to an external API (api / job_id) is currently
    # disabled; the solutions are only returned to the caller.
    solutions = [
        {'text': text, 'label': [_classify_text(text)]}
        for text in texts
    ]
    return json.dumps({"solutions": solutions})
# Single free-text input: the request is pasted as a JSON string and handed
# to predict_text unchanged. The example in the label is valid JSON (double
# quotes, balanced braces) so it can be copied into the box as-is.
inputt = gr.Textbox(label='Parameters in Json Format... Eg. {"texts":["text1", "text2"]}')
# predict_text returns either an error dict or a JSON string; both are shown
# through the JSON output component.
outputt = gr.JSON()
application = gr.Interface(
    fn=predict_text,
    inputs=inputt,
    outputs=outputt,
    title='Multi Text Classification with API Integration..',
)
# Start the Gradio server when the script is loaded.
application.launch()