D3V1L1810's picture
Update app.py
7a437ab verified
raw
history blame
3.09 kB
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np
import json
import requests
import gradio as gr
import logging
# Load the fine-tuned tokenizer and sequence classifier once, at import time;
# every request handled by predict_text reuses these two objects.
# The names are passed straight to from_pretrained (presumably checkpoint
# directories shipped next to this script — confirm).
_TOKENIZER_CHECKPOINT = 'MultiTokenizer_ep10'
_MODEL_CHECKPOINT = 'MultiModel_ep10'
bert_tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
bert_model = TFBertForSequenceClassification.from_pretrained(_MODEL_CHECKPOINT)
# Function to send results to API
# def send_results_to_api(data, result_url):
# headers = {'Content-Type':'application/json'}
# response = requests.post(result_url, json = data, headers=headers)
# if response.status_code == 200:
# return response.json
# else:
# return {'error':f"failed to send result to API: {response.status_code}"}
# Class index -> label text. Index 7 ('None') is used below as the sentinel
# for low-confidence predictions.
# NOTE(review): presumably the model itself only emits classes 0-6 — confirm.
_LABELS = {
    0: 'BUSINESS',
    1: 'COMEDY',
    2: 'CRIME',
    3: 'FOOD & DRINK',
    4: 'POLITICS',
    5: 'SPORTS',
    6: 'TRAVEL',
    7: 'None',
}
_NONE_LABEL = 7
# Predictions whose softmax probability is below this are reported as 'None'.
_CONFIDENCE_THRESHOLD = 0.85
# Token length every text is truncated/padded to before inference.
_MAX_LENGTH = 128
_MISSING_TEXTS_ERROR = "Missing required parameters: 'urls' (or 'texts')"


def predict_text(params):
    """Classify a batch of texts supplied as a JSON request string.

    Args:
        params: A JSON object encoded as a string. Recognised keys:
            ``"urls"``: list of texts to classify (``"texts"`` is accepted
            as a fallback, matching the example advertised in the UI);
            ``"normalfileID"``: optional list of IDs paired positionally with
            the texts; missing IDs become ``None``.

    Returns:
        On success, a JSON *string* ``{"solutions": [...]}`` with one entry
        per text: ``{"text", "answer": [label], "qcUser": None,
        "normalfileID"}``. On unusable input, a *dict* ``{"error": message}``.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        message = f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"
        logging.error(message)
        return {"error": message}
    except TypeError:
        # json.loads raises TypeError for non-str/bytes input such as None.
        message = "Invalid JSON input: expected a JSON string"
        logging.error(message)
        return {"error": message}

    # Valid JSON that is not an object (e.g. a list) has no .get().
    if not isinstance(params, dict):
        return {"error": "Invalid JSON input: expected a JSON object"}

    texts = params.get("urls") or params.get("texts") or []
    if isinstance(texts, str):
        # A bare string would otherwise be classified one character at a time.
        texts = [texts]
    if not texts:
        return {"error": _MISSING_TEXTS_ERROR}
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return {"error": "Invalid parameter 'urls': expected a list of strings"}

    file_ids = params.get("normalfileID") or []
    if not isinstance(file_ids, list):
        file_ids = [file_ids]
    # Pad with None so a short ID list never silently drops texts (zip stops
    # at the shorter input); extra IDs are ignored as before.
    file_ids = file_ids + [None] * (len(texts) - len(file_ids))

    # Tokenize the whole batch at once. Padding to a fixed max_length makes
    # each row identical to what per-text encoding would produce.
    encoding = bert_tokenizer(
        texts,
        add_special_tokens=True,
        max_length=_MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='tf',
    )
    # Pass inputs by name. HF TF models map a positional *list* onto
    # (input_ids, attention_mask, token_type_ids), so a list ordered
    # [input_ids, token_type_ids, attention_mask] swaps the last two.
    # NOTE(review): if this checkpoint was fine-tuned with that swapped list,
    # predictions will shift after this fix — confirm against training code.
    model_inputs = {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'token_type_ids': encoding['token_type_ids'],
    }
    logits = bert_model.predict(model_inputs)['logits']
    probabilities = tf.nn.softmax(logits, axis=1).numpy()
    predicted = tf.argmax(logits, axis=1).numpy()

    solutions = []
    for text, file_id, scores, label_index in zip(texts, file_ids, probabilities, predicted):
        confidence = scores[label_index]
        logging.debug("Confidence %.4f for predicted class %d", confidence, label_index)
        if confidence < _CONFIDENCE_THRESHOLD:
            label_index = _NONE_LABEL
        solutions.append({
            'text': text,
            'answer': [_LABELS[int(label_index)]],
            "qcUser": None,
            "normalfileID": file_id,
        })
    return json.dumps({"solutions": solutions})
# Gradio UI: a single textbox carrying the JSON request for predict_text,
# with its result rendered by a JSON output component.
# The example in the label must be valid JSON (double quotes) and use the key
# predict_text reads ("urls"), plus the optional "normalfileID" list.
inputt = gr.Textbox(
    label='Parameters in JSON format, e.g. {"urls": ["text1", "text2"], "normalfileID": ["id1", "id2"]}'
)
outputt = gr.JSON()
application = gr.Interface(
    fn=predict_text,
    inputs=inputt,
    outputs=outputt,
    title='Multi Text Classification with API Integration..',
)
application.launch()