import gradio as gr import torch import os from transformers import AutoModelForSequenceClassification, AutoTokenizer MODEL_ID = "Batuka0901/mongolian_news_classifier" MAX_LENGTH = 128 HF_TOKEN = os.environ.get("HF_TOKEN") CATEGORY_DESC = { "байгал орчин": "Байгаль орчин, экологи, цаг уур, ногоон хөгжил", "боловсрол": "Сургуули, их дээд сургууль, багш, сурагч, хичээл", "спорт": "Тэмцээн, уралдаан, шигшээ баг, медаль, аварга", "технологи": "Шинэ технологи, программ хангамж, интернет, инновац", "улс төр": "Засгийн газар, сонгууль, хууль тогтоомж, бодлого", "урлаг соёл": "Урлаг, дуу хөгжим, кино, уран зохиол, соёлын арга хэмжээ", "хууль": "Шүүх, хэрэг, дүрэм журам, хуулийн байгууллага", "эдийн засаг": "Санхүү, худалдаа, үйлдвэрлэл, инфляц, ДНБ", "эрүүл мэнд": "Эмнэлэг, эм, өвчин, эмч, эрүүл мэндийн бодлого", } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval() id2label = {int(k): v for k, v in model.config.id2label.items()} def classify(text: str): text = text.strip() if not text: return "", {}, "" inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH, padding=True, ) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0] pred_id = probs.argmax().item() pred_label = id2label[pred_id] confidence = probs[pred_id].item() scores = {id2label[i]: float(probs[i]) for i in range(len(id2label))} desc = CATEGORY_DESC.get(pred_label, "") detail = f"Ангилал: {pred_label} | Итгэлцэл: {confidence:.1%}\n{desc}" return pred_label, scores, detail EXAMPLES = [ ["Өнөөдөр хөлбөмбөгийн тэмцээнд Монголын шигшээ баг хожлоо."], ["Эдийн засгийн сайд өнөөдөр мэдээлэл хийлээ."], ["Шинэ технологийн компани байгуулагдлаа."], ["Эмнэлэгт шинэ тоног төхөөрөмж суурилуулжээ."], ["Багш нарын цалин нэмэгдэнэ."], ["Нийслэлийн агаарын бохирдол буурлаа."], ] CSS = """ #title { text-align: center; margin-bottom: 4px; } #result_label { font-size: 1.3em; font-weight: bold; } footer { display: none !important; } """ with gr.Blocks(css=CSS, title="Монгол мэдээний ангилагч") as demo: gr.Markdown("## Монгол мэдээний гарчиг ангилагч", elem_id="title") with gr.Row(): with gr.Column(scale=3): txt_input = gr.Textbox( label="Мэдээний гарчиг", placeholder="Мэдээний гарчгийг энд бичнэ үү...", lines=3, ) with gr.Row(): btn_clear = gr.Button("Арилгах") btn_submit = gr.Button("Ангилах", variant="primary") gr.Examples(examples=EXAMPLES, inputs=txt_input) with gr.Column(scale=2): out_label = gr.Textbox(label="Ангилал", elem_id="result_label") out_scores = gr.Label(label="Ангилал бүрийн магадлал", num_top_classes=9) out_detail = gr.Textbox(label="Дэлгэрэнгүй", lines=2) btn_submit.click(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail]) btn_clear.click(fn=lambda: ("", {}, ""), outputs=[out_label, out_scores, out_detail]) txt_input.submit(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail]) if __name__ == "__main__": demo.launch()