Batuka0901's picture
Update app.py
ba9c8e0 verified
import gradio as gr
import torch
import os
from transformers import AutoModelForSequenceClassification, AutoTokenizer
MODEL_ID = "Batuka0901/mongolian_news_classifier"
MAX_LENGTH = 128
HF_TOKEN = os.environ.get("HF_TOKEN")
CATEGORY_DESC = {
"байгал орчин": "Байгаль орчин, экологи, цаг уур, ногоон хөгжил",
"боловсрол": "Сургуули, их дээд сургууль, багш, сурагч, хичээл",
"спорт": "Тэмцээн, уралдаан, шигшээ баг, медаль, аварга",
"технологи": "Шинэ технологи, программ хангамж, интернет, инновац",
"улс төр": "Засгийн газар, сонгууль, хууль тогтоомж, бодлого",
"урлаг соёл": "Урлаг, дуу хөгжим, кино, уран зохиол, соёлын арга хэмжээ",
"хууль": "Шүүх, хэрэг, дүрэм журам, хуулийн байгууллага",
"эдийн засаг": "Санхүү, худалдаа, үйлдвэрлэл, инфляц, ДНБ",
"эрүүл мэнд": "Эмнэлэг, эм, өвчин, эмч, эрүүл мэндийн бодлого",
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
id2label = {int(k): v for k, v in model.config.id2label.items()}
def classify(text: str):
text = text.strip()
if not text:
return "", {}, ""
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=MAX_LENGTH,
padding=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]
pred_id = probs.argmax().item()
pred_label = id2label[pred_id]
confidence = probs[pred_id].item()
scores = {id2label[i]: float(probs[i]) for i in range(len(id2label))}
desc = CATEGORY_DESC.get(pred_label, "")
detail = f"Ангилал: {pred_label} | Итгэлцэл: {confidence:.1%}\n{desc}"
return pred_label, scores, detail
EXAMPLES = [
["Өнөөдөр хөлбөмбөгийн тэмцээнд Монголын шигшээ баг хожлоо."],
["Эдийн засгийн сайд өнөөдөр мэдээлэл хийлээ."],
["Шинэ технологийн компани байгуулагдлаа."],
["Эмнэлэгт шинэ тоног төхөөрөмж суурилуулжээ."],
["Багш нарын цалин нэмэгдэнэ."],
["Нийслэлийн агаарын бохирдол буурлаа."],
]
CSS = """
#title { text-align: center; margin-bottom: 4px; }
#result_label { font-size: 1.3em; font-weight: bold; }
footer { display: none !important; }
"""
with gr.Blocks(css=CSS, title="Монгол мэдээний ангилагч") as demo:
gr.Markdown("## Монгол мэдээний гарчиг ангилагч", elem_id="title")
with gr.Row():
with gr.Column(scale=3):
txt_input = gr.Textbox(
label="Мэдээний гарчиг",
placeholder="Мэдээний гарчгийг энд бичнэ үү...",
lines=3,
)
with gr.Row():
btn_clear = gr.Button("Арилгах")
btn_submit = gr.Button("Ангилах", variant="primary")
gr.Examples(examples=EXAMPLES, inputs=txt_input)
with gr.Column(scale=2):
out_label = gr.Textbox(label="Ангилал", elem_id="result_label")
out_scores = gr.Label(label="Ангилал бүрийн магадлал", num_top_classes=9)
out_detail = gr.Textbox(label="Дэлгэрэнгүй", lines=2)
btn_submit.click(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail])
btn_clear.click(fn=lambda: ("", {}, ""), outputs=[out_label, out_scores, out_detail])
txt_input.submit(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail])
if __name__ == "__main__":
demo.launch()