File size: 4,401 Bytes
f9aebc6
 
f0e3a67
f9aebc6
 
f0e3a67
f9aebc6
 
f0e3a67
f9aebc6
 
 
 
 
 
 
 
 
 
 
 
 
ba9c8e0
f9aebc6
f0e3a67
 
f9aebc6
ba9c8e0
f9aebc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import os
from transformers import AutoModelForSequenceClassification, AutoTokenizer


MODEL_ID = "Batuka0901/mongolian_news_classifier"
MAX_LENGTH = 128
HF_TOKEN = os.environ.get("HF_TOKEN")

CATEGORY_DESC = {
    "байгал орчин": "Байгаль орчин, экологи, цаг уур, ногоон хөгжил",
    "боловсрол":    "Сургуули, их дээд сургууль, багш, сурагч, хичээл",
    "спорт":        "Тэмцээн, уралдаан, шигшээ баг, медаль, аварга",
    "технологи":    "Шинэ технологи, программ хангамж, интернет, инновац",
    "улс төр":      "Засгийн газар, сонгууль, хууль тогтоомж, бодлого",
    "урлаг соёл":   "Урлаг, дуу хөгжим, кино, уран зохиол, соёлын арга хэмжээ",
    "хууль":        "Шүүх, хэрэг, дүрэм журам, хуулийн байгууллага",
    "эдийн засаг":  "Санхүү, худалдаа, үйлдвэрлэл, инфляц, ДНБ",
    "эрүүл мэнд":   "Эмнэлэг, эм, өвчин, эмч, эрүүл мэндийн бодлого",
}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()


id2label = {int(k): v for k, v in model.config.id2label.items()}

def classify(text: str):
    text = text.strip()
    if not text:
        return "", {}, ""

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]

    pred_id    = probs.argmax().item()
    pred_label = id2label[pred_id]
    confidence = probs[pred_id].item()

    scores = {id2label[i]: float(probs[i]) for i in range(len(id2label))}
    desc = CATEGORY_DESC.get(pred_label, "")
    detail = f"Ангилал: {pred_label}  |  Итгэлцэл: {confidence:.1%}\n{desc}"

    return pred_label, scores, detail

EXAMPLES = [
    ["Өнөөдөр хөлбөмбөгийн тэмцээнд Монголын шигшээ баг хожлоо."],
    ["Эдийн засгийн сайд өнөөдөр мэдээлэл хийлээ."],
    ["Шинэ технологийн компани байгуулагдлаа."],
    ["Эмнэлэгт шинэ тоног төхөөрөмж суурилуулжээ."],
    ["Багш нарын цалин нэмэгдэнэ."],
    ["Нийслэлийн агаарын бохирдол буурлаа."],
]

CSS = """
#title { text-align: center; margin-bottom: 4px; }
#result_label { font-size: 1.3em; font-weight: bold; }
footer { display: none !important; }
"""

with gr.Blocks(css=CSS, title="Монгол мэдээний ангилагч") as demo:

    gr.Markdown("## Монгол мэдээний гарчиг ангилагч", elem_id="title")

    with gr.Row():
        with gr.Column(scale=3):
            txt_input = gr.Textbox(
                label="Мэдээний гарчиг",
                placeholder="Мэдээний гарчгийг энд бичнэ үү...",
                lines=3,
            )
            with gr.Row():
                btn_clear  = gr.Button("Арилгах")
                btn_submit = gr.Button("Ангилах", variant="primary")

            gr.Examples(examples=EXAMPLES, inputs=txt_input)

        with gr.Column(scale=2):
            out_label  = gr.Textbox(label="Ангилал", elem_id="result_label")
            out_scores = gr.Label(label="Ангилал бүрийн магадлал", num_top_classes=9)
            out_detail = gr.Textbox(label="Дэлгэрэнгүй", lines=2)

    btn_submit.click(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail])
    btn_clear.click(fn=lambda: ("", {}, ""), outputs=[out_label, out_scores, out_detail])
    txt_input.submit(fn=classify, inputs=txt_input, outputs=[out_label, out_scores, out_detail])

if __name__ == "__main__":
    demo.launch()