| import gradio as gr |
| import torch |
| import os |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
| MODEL_ID = "Batuka0901/mongolian_news_classifier" |
| MAX_LENGTH = 128 |
| HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
| CATEGORY_DESC = { |
| "байгал орчин": "Байгаль орчин, экологи, цаг уур, ногоон хөгжил", |
| "боловсрол": "Сургуули, их дээд сургууль, багш, сурагч, хичээл", |
| "спорт": "Тэмцээн, уралдаан, шигшээ баг, медаль, аварга", |
| "технологи": "Шинэ технологи, программ хангамж, интернет, инновац", |
| "улс төр": "Засгийн газар, сонгууль, хууль тогтоомж, бодлого", |
| "урлаг соёл": "Урлаг, дуу хөгжим, кино, уран зохиол, соёлын арга хэмжээ", |
| "хууль": "Шүүх, хэрэг, дүрэм журам, хуулийн байгууллага", |
| "эдийн засаг": "Санхүү, худалдаа, үйлдвэрлэл, инфляц, ДНБ", |
| "эрүүл мэнд": "Эмнэлэг, эм, өвчин, эмч, эрүүл мэндийн бодлого", |
| } |
|
|
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN) |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval() |
|
|
|
|
| id2label = {int(k): v for k, v in model.config.id2label.items()} |
|
|
| def classify(text: str): |
| text = text.strip() |
| if not text: |
| return "", {}, "" |
|
|
| inputs = tokenizer( |
| text, |
| return_tensors="pt", |
| truncation=True, |
| max_length=MAX_LENGTH, |
| padding=True, |
| ) |
| inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
| with torch.no_grad(): |
| probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0] |
|
|
| pred_id = probs.argmax().item() |
| pred_label = id2label[pred_id] |
| confidence = probs[pred_id].item() |
|
|
| scores = {id2label[i]: float(probs[i]) for i in range(len(id2label))} |
| desc = CATEGORY_DESC.get(pred_label, "") |
| detail = f"Ангилал: {pred_label} | Итгэлцэл: {confidence:.1%}\n{desc}" |
|
|
| return pred_label, scores, detail |
|
|
| EXAMPLES = [ |
| ["Өнөөдөр хөлбөмбөгийн тэмцээнд Монголын шигшээ баг хожлоо."], |
| ["Эдийн засгийн сайд өнөөдөр мэдээлэл хийлээ."], |
| ["Шинэ технологийн компани байгуулагдлаа."], |
| ["Эмнэлэгт шинэ тоног төхөөрөмж суурилуулжээ."], |
| ["Багш нарын цалин нэмэгдэнэ."], |
| ["Нийслэлийн агаарын бохирдол буурлаа."], |
| ] |
|
|
| CSS = """ |
| #title { text-align: center; margin-bottom: 4px; } |
| #result_label { font-size: 1.3em; font-weight: bold; } |
| footer { display: none !important; } |
| """ |
|
|
| with gr.Blocks(css=CSS, title="Монгол мэдээний ангилагч") as demo: |
|
|
| gr.Markdown("## Монгол мэдээний гарчиг ангилагч", elem_id="title") |
|
|
| with gr.Row(): |
| with gr.Column(scale=3): |
| txt_input = gr.Textbox( |
| label="Мэдээний гарчиг", |
| placeholder="Мэдээний гарчгийг энд бичнэ үү...", |
| lines=3, |
| ) |
| with gr.Row(): |
| btn_clear = gr.Button("Арилгах") |
| btn_submit = gr.Button("Ангилах", variant="primary") |
|
|
| gr.Examples(examples=EXAMPLES, inputs=txt_input) |
|
|
| with gr.Column(scale=2): |
| out_label = gr.Textbox(label="Ангилал", elem_id="result_label") |
| out_scores = gr.Label(label="Ангилал бүрийн магадлал", num_top_classes=9) |
| out_detail = gr.Textbox(label="Дэлгэрэнгүй", lines=2) |
|
|
| btn_submit.click(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail]) |
| btn_clear.click(fn=lambda: ("", {}, ""), outputs=[out_label, out_scores, out_detail]) |
| txt_input.submit(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail]) |
|
|
| if __name__ == "__main__": |
| demo.launch() |