permitt commited on
Commit
f5fc858
·
1 Parent(s): ff7e90d

feat: demo app

Browse files
Files changed (1) hide show
  1. app.py +45 -23
app.py CHANGED
@@ -1,32 +1,46 @@
1
  """
2
  ModernBERTić Large - HF Space demo
3
- Three tabs: fill-mask, side-by-side vs BERTić, long-context fill-mask.
4
  """
5
 
 
 
6
  import gradio as gr
 
7
  import torch
8
  import torch.nn.functional as F
9
  from transformers import AutoTokenizer, AutoModelForMaskedLM
10
 
11
  MODEL_NAME = "permitt/galton-modernbertic-large"
12
- BASELINE_NAME = "classla/bcms-bertic"
13
 
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
 
 
16
 
17
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
- model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device).eval()
 
 
 
19
 
20
  baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_NAME)
21
- baseline_model = AutoModelForMaskedLM.from_pretrained(BASELINE_NAME).to(device).eval()
 
 
22
 
23
  OUR_MASK = tokenizer.mask_token
24
  THEIR_MASK = baseline_tokenizer.mask_token
25
 
26
 
27
  @torch.inference_mode()
28
- def fill_mask(text: str, tok, mdl, top_k: int = 5):
29
- inputs = tok(text, return_tensors="pt", truncation=True, max_length=8192).to(device)
 
 
 
30
  mask_id = tok.mask_token_id
31
  pos = (inputs.input_ids == mask_id).nonzero(as_tuple=True)
32
  if len(pos[1]) == 0:
@@ -35,25 +49,30 @@ def fill_mask(text: str, tok, mdl, top_k: int = 5):
35
  mask_logits = logits[0, pos[1][0]]
36
  probs = F.softmax(mask_logits.float(), dim=-1)
37
  top_probs, top_ids = probs.topk(top_k)
38
- return [(tok.decode([tid]).strip(), float(p)) for tid, p in zip(top_ids, top_probs)]
 
 
 
39
 
40
 
41
  def fmt(preds):
42
  return "\n".join(f"{w:<20} {p:.3f}" for w, p in preds)
43
 
44
 
45
- def predict_ours(text):
46
- return fmt(fill_mask(text, tokenizer, model))
 
47
 
48
 
49
- def predict_compare(text):
50
- ours = fill_mask(text, tokenizer, model)
51
- bertic_text = text.replace(OUR_MASK, THEIR_MASK)
52
- theirs = fill_mask(bertic_text, baseline_tokenizer, baseline_model)
 
53
  return fmt(ours), fmt(theirs)
54
 
55
 
56
- with gr.Blocks(title="ModernBERTić Large", theme=gr.themes.Soft()) as demo:
57
  gr.Markdown(
58
  f"""
59
  # ModernBERTić Large
@@ -75,14 +94,17 @@ with gr.Blocks(title="ModernBERTić Large", theme=gr.themes.Soft()) as demo:
75
  f"Glavni grad Srbije je {OUR_MASK}.",
76
  f"Najveći grad u Hrvatskoj je {OUR_MASK}.",
77
  f"Pisac romana 'Na Drini ćuprija' je {OUR_MASK} Andrić.",
78
- f"Главни град Србије је {OUR_MASK}.", # cyrillic
79
  ],
80
  inputs=inp,
81
  )
82
  btn.click(predict_ours, inp, out)
83
 
84
- with gr.Tab("vs BERTić"):
85
- gr.Markdown("Same input, both models. ModernBERTić-large vs `classla/bcms-bertic`.")
 
 
 
86
  inp2 = gr.Textbox(
87
  label="Input",
88
  value=f"Najveće jezero u Crnoj Gori je {OUR_MASK} jezero.",
@@ -91,13 +113,13 @@ with gr.Blocks(title="ModernBERTić Large", theme=gr.themes.Soft()) as demo:
91
  btn2 = gr.Button("Compare", variant="primary")
92
  with gr.Row():
93
  out_ours = gr.Textbox(label="ModernBERTić-large (ours)", lines=6)
94
- out_theirs = gr.Textbox(label="BERTić (Ljubešić et al.)", lines=6)
95
  btn2.click(predict_compare, inp2, [out_ours, out_theirs])
96
 
97
  with gr.Tab("Long context (8192)"):
98
  gr.Markdown(
99
  "Paste a long passage with one mask token deep in the text. "
100
- "BERTić truncates at 512 tokens. ModernBERTić handles up to 8192."
101
  )
102
  inp3 = gr.Textbox(
103
  label="Long input",
@@ -117,4 +139,4 @@ with gr.Blocks(title="ModernBERTić Large", theme=gr.themes.Soft()) as demo:
117
 
118
 
119
  if __name__ == "__main__":
120
- demo.launch()
 
1
  """
2
  ModernBERTić Large - HF Space demo
3
+ Three tabs: fill-mask, side-by-side vs XLM-R, long-context fill-mask.
4
  """
5
 
6
+ import os
7
+
8
  import gradio as gr
9
+ import spaces
10
  import torch
11
  import torch.nn.functional as F
12
  from transformers import AutoTokenizer, AutoModelForMaskedLM
13
 
14
  MODEL_NAME = "permitt/galton-modernbertic-large"
15
+ BASELINE_NAME = "FacebookAI/xlm-roberta-large"
16
 
17
+ HF_TOKEN = os.environ.get("HF_TOKEN")
18
+ if HF_TOKEN is None:
19
+ raise RuntimeError(
20
+ "HF_TOKEN secret not set. Add it under Space Settings -> Variables and secrets."
21
+ )
22
 
23
+ # Load on CPU. ZeroGPU allocates GPU only inside @spaces.GPU functions.
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
25
+ model = AutoModelForMaskedLM.from_pretrained(
26
+ MODEL_NAME, dtype=torch.bfloat16, token=HF_TOKEN
27
+ ).eval()
28
 
29
  baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_NAME)
30
+ baseline_model = AutoModelForMaskedLM.from_pretrained(
31
+ BASELINE_NAME, dtype=torch.bfloat16
32
+ ).eval()
33
 
34
  OUR_MASK = tokenizer.mask_token
35
  THEIR_MASK = baseline_tokenizer.mask_token
36
 
37
 
38
  @torch.inference_mode()
39
+ def _run(text, tok, mdl, top_k=5):
40
+ mdl = mdl.to("cuda")
41
+ inputs = tok(
42
+ text, return_tensors="pt", truncation=True, max_length=8192
43
+ ).to("cuda")
44
  mask_id = tok.mask_token_id
45
  pos = (inputs.input_ids == mask_id).nonzero(as_tuple=True)
46
  if len(pos[1]) == 0:
 
49
  mask_logits = logits[0, pos[1][0]]
50
  probs = F.softmax(mask_logits.float(), dim=-1)
51
  top_probs, top_ids = probs.topk(top_k)
52
+ return [
53
+ (tok.decode([tid]).strip(), float(p))
54
+ for tid, p in zip(top_ids, top_probs)
55
+ ]
56
 
57
 
58
  def fmt(preds):
59
  return "\n".join(f"{w:<20} {p:.3f}" for w, p in preds)
60
 
61
 
62
+ @spaces.GPU
63
+ def predict_ours(text: str) -> str:
64
+ return fmt(_run(text, tokenizer, model))
65
 
66
 
67
+ @spaces.GPU
68
+ def predict_compare(text: str):
69
+ ours = _run(text, tokenizer, model)
70
+ their_text = text.replace(OUR_MASK, THEIR_MASK)
71
+ theirs = _run(their_text, baseline_tokenizer, baseline_model)
72
  return fmt(ours), fmt(theirs)
73
 
74
 
75
+ with gr.Blocks(title="ModernBERTić Large") as demo:
76
  gr.Markdown(
77
  f"""
78
  # ModernBERTić Large
 
94
  f"Glavni grad Srbije je {OUR_MASK}.",
95
  f"Najveći grad u Hrvatskoj je {OUR_MASK}.",
96
  f"Pisac romana 'Na Drini ćuprija' je {OUR_MASK} Andrić.",
97
+ f"Главни град Србије је {OUR_MASK}.",
98
  ],
99
  inputs=inp,
100
  )
101
  btn.click(predict_ours, inp, out)
102
 
103
+ with gr.Tab("vs XLM-R large"):
104
+ gr.Markdown(
105
+ "Same input, both models. ModernBERTić-large vs `xlm-roberta-large` "
106
+ "(the standard multilingual MLM baseline for BCMS)."
107
+ )
108
  inp2 = gr.Textbox(
109
  label="Input",
110
  value=f"Najveće jezero u Crnoj Gori je {OUR_MASK} jezero.",
 
113
  btn2 = gr.Button("Compare", variant="primary")
114
  with gr.Row():
115
  out_ours = gr.Textbox(label="ModernBERTić-large (ours)", lines=6)
116
+ out_theirs = gr.Textbox(label="XLM-R large", lines=6)
117
  btn2.click(predict_compare, inp2, [out_ours, out_theirs])
118
 
119
  with gr.Tab("Long context (8192)"):
120
  gr.Markdown(
121
  "Paste a long passage with one mask token deep in the text. "
122
+ "Most BCMS encoders truncate at 512 tokens. ModernBERTić handles up to 8192."
123
  )
124
  inp3 = gr.Textbox(
125
  label="Long input",
 
139
 
140
 
141
  if __name__ == "__main__":
142
+ demo.launch(theme=gr.themes.Soft())