# permitt's picture
# feat: demo app
# 41efb11
"""
ModernBERTić Large - HF Space demo
Click any word in BCMS text to mask it and see what the model predicts.
"""
import os
import re
import string
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Model identifier passed to both from_pretrained() calls below.
MODEL_NAME = "permitt/galton-modernbertic-large"

# The access token comes from the environment; startup aborts when the
# variable is absent.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN secret not set. Add it under Space Settings -> Variables and secrets."
    )

# Tokenizer and masked-LM weights are loaded once at import time; the weights
# are requested in bfloat16 and the module is put in eval mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    dtype=torch.bfloat16,
)
model.eval()

# Mask string of the tokenizer, used for the model input and the masked view.
OUR_MASK = tokenizer.mask_token
# Built-in sample passages. Each key is the label shown on an example button;
# each value is the text loaded into the input box when that button is clicked.
EXAMPLES = {
    "Psihometrija (long context)": (
        "Psihometrija je teorijska disciplina, utemeljena na statistici, "
        "koja proučava mogućnosti, zakone i principe merenja psiholoških pojava, "
        "konstrukcijom, standardizacijom i evaluacijom testova i drugih psiholoških "
        "mernih instrumenata, kao i statističkim problemima empirijskih istraživanja. "
        "Psihometrija je oblast psihologije koja se bavi teorijom i tehnikom merenja. "
        "Psihometrija se generalno odnosi na specijalizovane oblasti u okviru "
        "psihologije i obrazovanja posvećene testiranju, merenju, proceni i srodnim "
        "aktivnostima. Psihometrija se bavi objektivnim merenjem latentnih "
        "konstrukata koji se ne mogu direktno posmatrati. Primeri latentnih "
        "konstrukcija uključuju inteligenciju, introvertnost, mentalne poremećaje "
        "i obrazovna postignuća. Nivoi sposobnosti na neuočljivim latentnim "
        "varijablama zaključuju se putem matematičkog modeliranja na osnovu onoga "
        "što se posmatra iz odgovora pojedinaca na stavke na testovima i skalama."
    ),
    "Crna Gora (geography)": (
        "Crna Gora je država na jugoistoku Evrope. Glavni grad je Podgorica, "
        "a istorijska prestonica je Cetinje. Crna Gora se graniči sa Hrvatskom, "
        "Bosnom i Hercegovinom, Srbijom, Kosovom i Albanijom. Skadarsko jezero, "
        "koje deli sa Albanijom, najveće je jezero na Balkanu."
    ),
    "Ivo Andrić (literature)": (
        "Ivo Andrić je bio jugoslovenski književnik i diplomata, rođen u Travniku "
        "u centralnoj Bosni 1892. godine. Najpoznatiji je po romanu Na Drini "
        "ćuprija, za koji je 1961. godine dobio Nobelovu nagradu za književnost. "
        "Smatra se jednim od najvećih pisaca južnoslovenskih književnosti "
        "dvadesetog veka."
    ),
}
# Key of the example preloaded into the input box and the clickable view.
DEFAULT_EXAMPLE = "Psihometrija (long context)"
def split_text(text: str):
    """Break ``text`` into alternating word and whitespace pieces.

    Whitespace runs are kept as their own items (the capturing group makes
    ``re.split`` return them), and empty strings produced at the edges are
    discarded, so ``"".join(split_text(text)) == text``.
    """
    pieces = re.split(r"(\s+)", text)
    return list(filter(None, pieces))
def make_clickable(text: str, masked_index: int | None = None):
    """Turn ``text`` into the ``(piece, label)`` list shown by HighlightedText.

    Args:
        text: Raw input text; an empty value yields an empty list.
        masked_index: Position in ``split_text(text)`` to display as the mask
            token, or ``None`` to display every word as-is.

    Returns:
        One tuple per piece of ``split_text(text)``: whitespace runs carry the
        label ``None``, the piece at ``masked_index`` is replaced by
        ``(OUR_MASK, "mask")``, and every other piece is labelled ``"word"``.
    """
    if not text:
        return []

    def as_span(position, piece):
        # Whitespace is checked first, so a whitespace piece is never masked.
        if piece.strip() == "":
            return (piece, None)
        if position == masked_index:
            return (OUR_MASK, "mask")
        return (piece, "word")

    return [as_span(position, piece) for position, piece in enumerate(split_text(text))]
def strip_edge_punct(word: str):
    """Split ``word`` into ``(leading, core, trailing)`` around its edge punctuation.

    This lets a token such as ``'Beograd.'`` or ``'„Beograd“,'`` be masked on
    its core only, with the surrounding punctuation kept in place.

    A character counts as punctuation when it is in ``string.punctuation``
    or when its Unicode general category starts with ``P``. The second rule
    covers typographic quotes, dashes and ellipses that ASCII-only matching
    leaves attached to the core. Punctuation inside the word (for example a
    hyphen between letters) is never touched.

    Args:
        word: A single non-whitespace token.

    Returns:
        A 3-tuple of strings whose concatenation equals ``word``. For a token
        made only of punctuation, everything goes to ``leading`` and ``core``
        is empty.
    """
    import unicodedata

    def is_punct(ch):
        return ch in string.punctuation or unicodedata.category(ch).startswith("P")

    start = 0
    end = len(word)
    # Scan inward from both ends; the left scan runs first, so an
    # all-punctuation token ends up entirely in the leading part.
    while start < end and is_punct(word[start]):
        start += 1
    while end > start and is_punct(word[end - 1]):
        end -= 1
    return word[:start], word[start:end], word[end:]
@spaces.GPU
@torch.inference_mode()
def _predict(text: str, top_k: int = 5):
    """Return the most probable fillers for the first mask token in ``text``.

    Args:
        text: Input text that should contain the tokenizer's mask token.
        top_k: Maximum number of predictions to return.

    Returns:
        A list of ``(decoded_token, probability)`` tuples ordered by
        decreasing probability. Empty and pure-punctuation decodings are
        dropped unless that would leave nothing, in which case the unfiltered
        candidates are used. If the encoded input contains no mask token
        (for instance because it was cut off by truncation), the single
        placeholder ``[("(no mask token)", 0.0)]`` is returned.
    """
    mdl = model.to("cuda")
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=8192
    ).to("cuda")

    # Column indices (sequence positions) of every mask token in the batch.
    mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(
        as_tuple=True
    )[1]
    if len(mask_positions) == 0:
        return [("(no mask token)", 0.0)]

    logits = mdl(**inputs).logits
    # Softmax in float32 over the vocabulary at the first mask position.
    probs = F.softmax(logits[0, mask_positions[0]].float(), dim=-1)

    # Over-fetch so we can drop pure-punctuation predictions and still return
    # top_k, but never ask topk for more entries than the vocabulary holds.
    fetch = min(top_k * 5, probs.size(-1))
    top_probs, top_ids = probs.topk(fetch)

    # Move both result tensors to Python lists in one transfer each, rather
    # than converting element by element.
    raw = [
        (tokenizer.decode([token_id]).strip(), prob)
        for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]
    filtered = [
        (w, p) for w, p in raw
        if w and not all(c in string.punctuation for c in w)
    ]
    return (filtered or raw)[:top_k]
def on_word_click(text: str, evt: gr.SelectData):
    """Mask the clicked word, query the model, and format its predictions.

    Args:
        text: Current content of the input textbox.
        evt: Selection event. ``evt.index`` is used as a position into
            ``split_text(text)`` and ``evt.value`` as the clicked content.

    Returns:
        A pair ``(highlighted_value, markdown)``. A click with an empty
        value, on whitespace, or on a token with no core left after edge
        punctuation is stripped returns two no-op ``gr.update()`` objects.
        An index past the end of the token list leaves the view untouched
        and returns an explanatory message instead of a table.
    """
    clicked = evt.value
    if clicked is None or str(clicked).strip() == "":
        return gr.update(), gr.update()

    position = evt.index
    pieces = split_text(text)
    if position >= len(pieces):
        return gr.update(), "Click registered out of range. Edit the text and try again."

    chosen = pieces[position]
    if chosen.strip() == "":
        return gr.update(), gr.update()

    prefix, core, suffix = strip_edge_punct(chosen)
    if core == "":
        return gr.update(), gr.update()

    # Only the core is replaced by the mask; edge punctuation stays in the
    # text that is sent to the model.
    masked_pieces = pieces.copy()
    masked_pieces[position] = "".join((prefix, OUR_MASK, suffix))
    predictions = _predict("".join(masked_pieces))

    original_lower = core.lower()

    def table_row(rank, guess, prob):
        # Flag the prediction that matches the masked word, ignoring case.
        marker = " ← original" if guess.lower() == original_lower else ""
        return f"| {rank} | `{guess}`{marker} | {prob:.3f} |"

    report = [
        f"### Masked word: `{core}`",
        "",
        "| Rank | Token | Probability |",
        "|------|-------|-------------|",
    ]
    report.extend(
        table_row(rank, guess, prob)
        for rank, (guess, prob) in enumerate(predictions, start=1)
    )
    return make_clickable(text, position), "\n".join(report)
def reset_on_change(text):
    """Rebuild the clickable view for ``text`` and restore the usage hint.

    Returns a pair: the unmasked ``make_clickable(text)`` value and the
    placeholder Markdown shown before any word has been clicked.
    """
    hint = "*Click any word above to mask it and see predictions.*"
    unmasked_view = make_clickable(text)
    return unmasked_view, hint
# Stylesheet passed to launch(): spans inside elements with the
# "clickable-text" class get a pointer cursor and fade on hover.
CSS = """
.clickable-text .textspan {
cursor: pointer;
transition: opacity 0.15s;
}
.clickable-text .textspan:hover {
opacity: 0.55;
}
"""
# Assemble the demo page: intro text, example buttons, input box, clickable
# view, predictions panel, event wiring, and footer.
with gr.Blocks(title="ModernBERTić Large") as demo:
    # Intro and usage instructions.
    gr.Markdown(
        """
# ModernBERTić Large
First ModernBERT-style encoder for **Bosnian / Croatian / Montenegrin / Serbian**.
Pretrained on ~60B tokens with **8192-token context window**.
**How to use:** pick an example below or paste your own BCMS text, then **click any word** in the highlighted view to mask it. The model will predict what fits.
"""
    )
    # One small button per entry in EXAMPLES, created inside a single Row.
    with gr.Row():
        example_buttons = [gr.Button(label, size="sm") for label in EXAMPLES]
    # Editable source text, preloaded with the default example.
    inp = gr.Textbox(
        label="Input text (paste anything in BCMS)",
        value=EXAMPLES[DEFAULT_EXAMPLE],
        lines=8,
    )
    # Word-level view of the same text. The "word" and "mask" labels emitted
    # by make_clickable are the keys of color_map.
    clickable = gr.HighlightedText(
        label="Click any word to mask it",
        value=make_clickable(EXAMPLES[DEFAULT_EXAMPLE]),
        color_map={
            "word": "rgba(99, 102, 241, 0.18)",
            "mask": "rgba(239, 68, 68, 0.55)",
        },
        show_legend=False,
        combine_adjacent=False,
        elem_classes=["clickable-text"],
    )
    # Predictions table, or the usage hint before any word is clicked.
    output = gr.Markdown(
        value="*Click any word above to mask it and see predictions.*"
    )
    # Each example button overwrites the textbox, the clickable view and the
    # hint. `l=label` binds the current label as a default argument so every
    # lambda keeps its own example rather than the loop's last value.
    for label, btn in zip(EXAMPLES.keys(), example_buttons):
        btn.click(
            fn=lambda l=label: (
                EXAMPLES[l],
                make_clickable(EXAMPLES[l]),
                "*Click any word above to mask it and see predictions.*",
            ),
            outputs=[inp, clickable, output],
        )
    # Editing the textbox rebuilds the clickable view and resets the hint.
    inp.change(reset_on_change, inp, [clickable, output])
    # Selecting a span runs on_word_click with the textbox content as input.
    clickable.select(on_word_click, inp, [clickable, output])
    # Footer with credits and links.
    gr.Markdown(
        "---\n"
        "Trained on EuroHPC Leonardo (64× A100) at Recrewty. https://recrewty.com \n"
        "You can find the results at SuperGLUE-SR results: https://balkanbench.com/leaderboard. \n"
        "Link to blogposts and release: https://permitt.io. \n"
    )
# Script entry point: call queue() on the app, then launch it with
# ssr_mode=False and the custom stylesheet.
if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False, css=CSS)