permitt committed on
Commit
ff7e90d
·
1 Parent(s): 3afcada

feat: demo app

Browse files
Files changed (2) hide show
  1. app.py +120 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ModernBERTić Large - HF Space demo
3
+ Three tabs: fill-mask, side-by-side vs BERTić, long-context fill-mask.
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
10
+
11
# Hub identifiers: the model being demoed and the baseline it is compared with.
MODEL_NAME = "permitt/galton-modernbertic-large"
BASELINE_NAME = "classla/bcms-bertic"

# Prefer the GPU when one is visible; bfloat16 is only selected together with CUDA.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16
else:
    device = "cpu"
    dtype = torch.float32

# Demoed model: loaded in the dtype chosen above, moved to the device, put in eval mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model = model.to(device).eval()

# Baseline model: loaded without an explicit dtype, same device, eval mode.
baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_NAME)
baseline_model = AutoModelForMaskedLM.from_pretrained(BASELINE_NAME)
baseline_model = baseline_model.to(device).eval()

# Mask strings as reported by each tokenizer; used to build and translate prompts.
OUR_MASK = tokenizer.mask_token
THEIR_MASK = baseline_tokenizer.mask_token
25
+
26
+
27
@torch.inference_mode()
def fill_mask(text: str, tok, mdl, top_k: int = 5, max_length: int = 8192):
    """Return the most probable fillers for the first mask token in *text*.

    Args:
        text: Input sentence containing the tokenizer's mask token. Only the
            first mask position is scored.
        tok: Tokenizer paired with ``mdl``.
        mdl: Masked-LM model whose forward output exposes ``.logits``.
        top_k: Number of candidates to return; clamped to the vocabulary size.
        max_length: Upper bound on the tokenized length. It is lowered to
            ``mdl.config.max_position_embeddings`` when the model declares one.

    Returns:
        A list of ``(token_text, probability)`` tuples ordered from most to
        least probable, or a single placeholder row when the tokenized
        (possibly truncated) input contains no mask token.
    """
    # The same function serves both models. The demo text says the baseline
    # stops at 512 tokens, so a fixed 8192 must not be applied to every model:
    # take the limit from the model config when it provides one.
    declared_limit = getattr(getattr(mdl, "config", None), "max_position_embeddings", None)
    if isinstance(declared_limit, int) and declared_limit > 0:
        max_length = min(max_length, declared_limit)

    inputs = tok(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)

    # Single-sequence input: look for mask positions in row 0 only.
    mask_positions = (inputs.input_ids[0] == tok.mask_token_id).nonzero(as_tuple=True)[0]
    if mask_positions.numel() == 0:
        return [("(no mask token in input)", 0.0)]

    logits = mdl(**inputs).logits
    # Softmax in float32 so reduced-precision weights do not distort probabilities.
    probs = F.softmax(logits[0, mask_positions[0]].float(), dim=-1)

    # topk raises if k exceeds the vocabulary size, so clamp it.
    top_probs, top_ids = probs.topk(min(top_k, probs.shape[-1]))
    return [
        (tok.decode([token_id]).strip(), prob)
        for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]
39
+
40
+
41
def fmt(preds):
    """Render (word, probability) pairs as text, one aligned row per pair."""
    rows = []
    for word, prob in preds:
        # Word left-justified to 20 columns, then the probability with 3 decimals.
        rows.append("{:<20} {:.3f}".format(word, prob))
    return "\n".join(rows)
43
+
44
+
45
def predict_ours(text):
    """Run fill-mask on *text* with the demoed model and return the formatted rows."""
    predictions = fill_mask(text, tokenizer, model)
    return fmt(predictions)
47
+
48
+
49
def predict_compare(text):
    """Score the same sentence with both models.

    Returns:
        A ``(ours, theirs)`` pair of formatted prediction tables: the demoed
        model first, the baseline second.
    """
    # The prompt is written with our mask string; swap in the baseline
    # tokenizer's mask string before handing the text to the baseline model.
    baseline_text = text.replace(OUR_MASK, THEIR_MASK)

    our_table = fmt(fill_mask(text, tokenizer, model))
    their_table = fmt(fill_mask(baseline_text, baseline_tokenizer, baseline_model))
    return our_table, their_table
54
+
55
+
56
# UI layout: a header, three tabs, and a footer. Components are created in
# display order inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft(), title="ModernBERTić Large") as demo:
    # Header; the mask string shown to the user comes from the loaded tokenizer.
    gr.Markdown(
        f"""
        # ModernBERTić Large
        First ModernBERT-style encoder for **Bosnian / Croatian / Montenegrin / Serbian**.
        Pretrained on ~66B tokens with 8192 context window. Use `{OUR_MASK}` as the mask token.
        """
    )

    # Tab 1: fill-mask with the demoed model only.
    with gr.Tab("Fill mask"):
        inp = gr.Textbox(
            value=f"Glavni grad Crne Gore je {OUR_MASK}.",
            label="Input",
            lines=2,
        )
        btn = gr.Button("Predict", variant="primary")
        out = gr.Textbox(label="Top-5 predictions", lines=6)
        example_sentences = [
            f"Glavni grad Srbije je {OUR_MASK}.",
            f"Najveći grad u Hrvatskoj je {OUR_MASK}.",
            f"Pisac romana 'Na Drini ćuprija' je {OUR_MASK} Andrić.",
            f"Главни град Србије је {OUR_MASK}.",  # cyrillic
        ]
        gr.Examples(examples=example_sentences, inputs=inp)
        btn.click(fn=predict_ours, inputs=inp, outputs=out)

    # Tab 2: the same input scored by both models, shown side by side.
    with gr.Tab("vs BERTić"):
        gr.Markdown("Same input, both models. ModernBERTić-large vs `classla/bcms-bertic`.")
        inp2 = gr.Textbox(
            value=f"Najveće jezero u Crnoj Gori je {OUR_MASK} jezero.",
            label="Input",
            lines=2,
        )
        btn2 = gr.Button("Compare", variant="primary")
        with gr.Row():
            out_ours = gr.Textbox(label="ModernBERTić-large (ours)", lines=6)
            out_theirs = gr.Textbox(label="BERTić (Ljubešić et al.)", lines=6)
        btn2.click(fn=predict_compare, inputs=inp2, outputs=[out_ours, out_theirs])

    # Tab 3: long-passage fill-mask, reusing the single-model handler.
    with gr.Tab("Long context (8192)"):
        gr.Markdown(
            "Paste a long passage with one mask token deep in the text. "
            "BERTić truncates at 512 tokens. ModernBERTić handles up to 8192."
        )
        inp3 = gr.Textbox(
            placeholder=f"Paste a Wikipedia paragraph and place {OUR_MASK} somewhere late in the text...",
            label="Long input",
            lines=15,
        )
        btn3 = gr.Button("Predict", variant="primary")
        out3 = gr.Textbox(label="Top-5 predictions", lines=6)
        btn3.click(fn=predict_ours, inputs=inp3, outputs=out3)

    # Footer shown below the tabs.
    gr.Markdown(
        """
        ---
        Trained on EuroHPC Leonardo (64× A100). Paper, checkpoints and SuperGLUE-SR results: [link].
        """
    )


# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.48
3
+ gradio>=4.0
4
+ spaces
5
+