hellosindh committed a4f4b5c (verified) · 1 parent: 2c21c22

Upload 2 files

Files changed (2):
  1. README.md +107 -0
  2. inference.py +399 -0
README.md ADDED
@@ -0,0 +1,107 @@
# Indus Script Models

Five trained models plus a NanoGPT generator for the undeciphered Indus Valley Script (2600–1900 BCE).

## What's in this repo

```
models/
  mlm/best/           TinyBERT masked language model
  cls/best/           TinyBERT sequence classifier (valid vs corrupted)
  ngram_model.pkl     N-gram RTL transition model
  electra/best/       ELECTRA token discriminator
  deberta/best/       DeBERTa sequence discriminator
  nanogpt_indus.pt    NanoGPT generator (153K params)
data/
  indus_tokenizer/    Custom tokenizer (641 Indus sign tokens)
  id_to_glyph.json    Sign ID → glyph character mapping
inference.py          Runs all tasks (see below)
indus_ngram.py        Required to unpickle ngram_model.pkl
```

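The `data/id_to_glyph.json` mapping can be used on its own to render sign-ID sequences as glyphs. A minimal sketch; the two entries below are stand-ins, not the file's actual contents:

```python
import json

# Stand-in for data/id_to_glyph.json: keys are sign IDs as strings,
# values are glyph characters (these two entries are illustrative only).
id_to_glyph = json.loads('{"638": "A", "177": "B"}')

def to_glyphs(sign_ids):
    """Render a list of sign IDs, falling back to '?' for unknown signs."""
    return "".join(id_to_glyph.get(str(i), "?") for i in sign_ids)
```

inference.py performs the same string-keyed lookup, with a fallback for IDs missing from the map.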
## How the pipeline works

**Stage 1 — Real inscriptions (3,310 sequences):**
Five models trained independently on real Indus Script inscriptions.
Each learned a different aspect of the grammar:
- TinyBERT MLM → which signs can fill a masked position
- TinyBERT Classifier → valid sequence vs corrupted
- N-gram RTL → right-to-left transition probabilities
- ELECTRA → token-level real-vs-fake discrimination
- DeBERTa → sequence-level real-vs-fake discrimination

**Stage 2 — Generate + filter:**
NanoGPT generates candidates in RTL order.
Each candidate is scored by BERT (50%) + N-gram (25%) + ELECTRA (25%).
Only sequences with an ensemble score ≥ 85% are kept.
Exact matches to real inscriptions are set aside as validation evidence.

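The Stage 2 filter boils down to a fixed-weight average of the three scores. A minimal sketch, assuming each score is already in [0, 1]:

```python
# Stage 2 ensemble: BERT 50%, N-gram 25%, ELECTRA 25%.
def ensemble(bert, ngram, electra):
    return 0.50 * bert + 0.25 * ngram + 0.25 * electra

def keep(bert, ngram, electra, threshold=0.85):
    """A candidate survives the filter only at >= 85% ensemble score."""
    return ensemble(bert, ngram, electra) >= threshold

# The example scores in this README (0.9650, 0.8930, 0.9410)
# combine to the 0.9410 ensemble shown there.
score = ensemble(0.9650, 0.8930, 0.9410)
```

This mirrors `ensemble_score` in inference.py.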
**Stage 3 — Retrain on combined data (3,310 real + 5,000 synthetic = 8,310):**
All models were retrained: TinyBERT accuracy improved from 78% to 89%, NanoGPT perplexity from 32.5 to 13.3.
The final 5,000 sequences were generated with the retrained models.

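For reference, perplexity here is the standard exponential of the mean per-token negative log-likelihood. A sketch of the metric itself, not the repo's evaluation code:

```python
import math

def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood, in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))

# A model that assigns probability 1 to every token has perplexity 1;
# higher values mean the model is more "surprised" by the data.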
## Quick start

```bash
pip install torch transformers huggingface_hub

# Clone this repo
git clone https://huggingface.co/YOUR_USERNAME/indus-script-models
cd indus-script-models

# Run the demo (validates 5 example sequences)
python inference.py --task demo

# Validate a sequence
python inference.py --task validate --sequence "T638 T177 T420 T122"

# Predict a masked sign
python inference.py --task predict --sequence "T638 [MASK] T420 T122"

# Generate 10 new sequences
python inference.py --task generate --count 10

# Score any sequence
python inference.py --task score --sequence "T604 T123 T609"
```

## Example output

```
Loading models...
 ✓ TinyBERT
 ✓ N-gram
 ✓ ELECTRA

 Sequence : T638 T177 T420 T122
 Glyphs   : 𐦭𐦬𐦰𐦑
 BERT     : 0.9650
 N-gram   : 0.8930
 ELECTRA  : 0.9410
 Ensemble : 0.9410
 Verdict  : ✅ VALID (≥85%)
```

## Model performance

| Model | Metric | Value |
|---|---|---|
| TinyBERT Classifier | Test accuracy | 89.0% |
| TinyBERT MLM | Val loss | 2.06 |
| N-gram RTL | Pairwise accuracy | 88.2% |
| ELECTRA | Token accuracy | 95.1% |
| DeBERTa | Test accuracy | 87.1% |
| NanoGPT | Perplexity | 13.3 |

## Key findings

- **RTL confirmed** — right-to-left reading shows 12% stronger grammatical structure than LTR
- **Grammatical structure** — H1→H2→H3 entropy = 6.03 → 3.41 → 2.39 bits (language-like decay)
- **Zipf's law** — R² = 0.968 (language-like token distribution)
- **752 seal reproductions** — the model independently reproduced real inscriptions
- **Sign roles** — PREFIX (T638, T604), SUFFIX (T123, T122), CORE (T101, T268)

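The H1 → H2 entropy drop can be illustrated with plain counting: on any corpus with positional structure, conditional bigram entropy (H2) falls below unigram entropy (H1). A toy sketch with a made-up two-rule corpus, not the study's data or its exact estimator:

```python
from collections import Counter
from math import log2

def entropy(counts):
    """Shannon entropy (bits) of a Counter of occurrence counts."""
    total = sum(counts.values())
    return -sum(c / total * log2(c / total) for c in counts.values())

def h1(seqs):
    """Unigram (sign-frequency) entropy."""
    return entropy(Counter(t for s in seqs for t in s))

def h2(seqs):
    """Conditional bigram entropy H(next | prev) = H(pair) - H(context)."""
    pairs = Counter(p for s in seqs for p in zip(s, s[1:]))
    contexts = Counter(t for s in seqs for t in s[:-1])
    return entropy(pairs) - entropy(contexts)

# Toy corpus with rigid transitions: every sign fully determines the next,
# so H2 collapses to 0 while H1 stays well above it.
corpus = [("A", "B", "C")] * 5 + [("D", "B", "C")] * 5
```

On real language-like data H2 stays positive but well below H1, which is the shape of the 6.03 → 3.41 decay reported above.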
## Dataset

The 5,000 synthetic sequences are available at:
[YOUR_USERNAME/indus-script-synthetic](https://huggingface.co/datasets/YOUR_USERNAME/indus-script-synthetic)
inference.py ADDED
@@ -0,0 +1,399 @@
"""
Indus Script — Inference & Generation
======================================
Download models from HuggingFace and run:
  1. Sequence validation — is this inscription valid?
  2. Sign prediction — predict a masked sign
  3. Generate synthetic — generate new Indus sequences
  4. Score any sequence — get an ensemble confidence score

Install:
    pip install torch transformers huggingface_hub

Usage:
    python inference.py --task validate --sequence "T638 T177 T420 T122"
    python inference.py --task predict --sequence "T638 [MASK] T420 T122"
    python inference.py --task generate --count 10
    python inference.py --task score --sequence "T638 T177 T420"
    python inference.py --task demo
"""

import argparse
import math
import pickle
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


# ── Auto-download from HuggingFace ─────────────────────────────
HF_REPO = "YOUR_USERNAME/indus-script-models"  # update after upload

def download_models(repo_id=HF_REPO, local_dir="indus_models"):
    """Download all model files from HuggingFace."""
    try:
        from huggingface_hub import snapshot_download
        print(f"Downloading models from {repo_id}...")
        path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
        print(f"✓ Downloaded to {path}")
        return path
    except Exception as e:
        print(f"Download failed: {e}")
        print(f"Manual download: https://huggingface.co/{repo_id}")
        sys.exit(1)


def get_model_dir():
    """Find the model directory — local DATA/models or the downloaded copy."""
    # Try the local development path first
    local = Path("DATA/models")
    if local.exists():
        return local, Path("DATA")
    # Then a previously downloaded copy
    downloaded = Path("indus_models")
    if downloaded.exists():
        return downloaded / "models", downloaded
    # Otherwise auto-download
    path = download_models()
    return Path(path) / "models", Path(path)


# ── Device ─────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BOS_ID = 814
EOS_ID = 815
PAD_ID = 816


# ── Load helpers ───────────────────────────────────────────────
def load_tokenizer(data_dir):
    from transformers import PreTrainedTokenizerFast
    return PreTrainedTokenizerFast.from_pretrained(
        str(data_dir / "indus_tokenizer"))


def load_bert_mlm(model_dir):
    from transformers import BertForMaskedLM
    return BertForMaskedLM.from_pretrained(
        str(model_dir / "mlm" / "best")).to(device).eval()


def load_bert_cls(model_dir):
    from transformers import BertForSequenceClassification
    return BertForSequenceClassification.from_pretrained(
        str(model_dir / "cls" / "best")).to(device).eval()


def load_ngram(model_dir):
    # indus_ngram.py must be importable so pickle can resolve its classes
    sys.path.insert(0, str(Path(__file__).parent))
    with open(model_dir / "ngram_model.pkl", "rb") as f:
        return pickle.load(f)


def load_electra(model_dir):
    from transformers import BertModel, BertConfig, PreTrainedTokenizerFast
    import json

    class ElectraDisc(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            self.bert = BertModel(cfg)
            self.classifier = nn.Linear(cfg.hidden_size, 2)
            self.dropout = nn.Dropout(0.1)

        def forward(self, input_ids, attention_mask):
            out = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
            return self.classifier(self.dropout(out.last_hidden_state))

    p = model_dir / "electra" / "best"
    with open(p / "discriminator_config.json") as f:
        cfg = json.load(f)
    m = ElectraDisc(BertConfig(**cfg))
    m.load_state_dict(torch.load(p / "discriminator.pt",
                                 map_location=device, weights_only=True))
    tok = PreTrainedTokenizerFast.from_pretrained(str(p))
    return tok, m.to(device).eval()


def load_nanogpt(model_dir):
    ckpt = torch.load(model_dir / "nanogpt_indus.pt",
                      map_location=device, weights_only=False)
    cfg = ckpt["cfg"]

    class CSA(nn.Module):
        """Causal self-attention."""
        def __init__(self, c):
            super().__init__()
            self.nh = c["n_head"]; self.ne = c["n_embd"]
            self.hd = c["n_embd"] // c["n_head"]
            self.qkv = nn.Linear(c["n_embd"], 3 * c["n_embd"], bias=False)
            self.proj = nn.Linear(c["n_embd"], c["n_embd"], bias=False)
            self.drop = nn.Dropout(c["dropout"])
            ml = c["block_size"]
            self.register_buffer("mask",
                torch.tril(torch.ones(ml, ml)).view(1, 1, ml, ml))

        def forward(self, x):
            B, T, C = x.shape
            q, k, v = self.qkv(x).split(self.ne, dim=2)
            sh = lambda t: t.view(B, T, self.nh, self.hd).transpose(1, 2)
            q, k, v = sh(q), sh(k), sh(v)
            a = (q @ k.transpose(-2, -1)) / math.sqrt(self.hd)
            a = a.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
            return self.proj(
                (self.drop(F.softmax(a, dim=-1)) @ v)
                .transpose(1, 2).contiguous().view(B, T, C))

    class TB(nn.Module):
        """Transformer block."""
        def __init__(self, c):
            super().__init__()
            self.ln1 = nn.LayerNorm(c["n_embd"]); self.attn = CSA(c)
            self.ln2 = nn.LayerNorm(c["n_embd"])
            self.ffn = nn.Sequential(
                nn.Linear(c["n_embd"], 4 * c["n_embd"]), nn.GELU(),
                nn.Linear(4 * c["n_embd"], c["n_embd"]), nn.Dropout(c["dropout"]))

        def forward(self, x):
            return x + self.ffn(self.ln2(x + self.attn(self.ln1(x))))

    class GPT(nn.Module):
        def __init__(self, c):
            super().__init__()
            self.cfg = c
            self.tok_emb = nn.Embedding(c["vocab_size"], c["n_embd"])
            self.pos_emb = nn.Embedding(c["block_size"], c["n_embd"])
            self.drop = nn.Dropout(c["dropout"])
            self.blocks = nn.ModuleList([TB(c) for _ in range(c["n_layer"])])
            self.ln_f = nn.LayerNorm(c["n_embd"])
            self.head = nn.Linear(c["n_embd"], c["vocab_size"], bias=False)
            self.tok_emb.weight = self.head.weight  # weight tying

        def forward(self, idx):
            B, T = idx.shape
            x = self.drop(self.tok_emb(idx) + self.pos_emb(
                torch.arange(T, device=idx.device).unsqueeze(0)))
            for b in self.blocks:
                x = b(x)
            return self.head(self.ln_f(x))

        @torch.no_grad()
        def generate(self, temperature=0.85, top_k=40, max_len=15):
            self.eval()
            idx = torch.tensor([[BOS_ID]], device=device)
            gen = []
            for _ in range(max_len):
                logits = self(idx[:, -self.cfg["block_size"]:])[:, -1, :] / temperature
                # Ban PAD and BOS; EOS stays sampleable so generation can stop early
                logits[:, PAD_ID] = logits[:, BOS_ID] = float("-inf")
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float("-inf")
                nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
                if nxt.item() == EOS_ID:
                    break
                gen.append(nxt.item())
                idx = torch.cat([idx, nxt], dim=1)
            return list(reversed(gen))  # model emits RTL; return LTR

    m = GPT(cfg)
    m.load_state_dict(ckpt["model_state"])
    return m.to(device).eval()


# ── Scoring functions ──────────────────────────────────────────
def parse_sequence(seq_str):
    """Parse 'T638 T177 T420' or '638 177 420' into a list of sign IDs."""
    ids = []
    for t in seq_str.strip().split():
        if t.upper() == "[MASK]":
            ids.append(None)
        else:
            ids.append(int(t.upper().lstrip("T")))
    return ids


def bert_validity_score(seq, tok, cls_model):
    text = " ".join(f"T{t}" for t in seq)
    enc = tok(text, return_tensors="pt", truncation=True,
              max_length=32).to(device)
    with torch.no_grad():
        return float(torch.softmax(cls_model(**enc).logits, dim=-1)[0][1])


def bert_predict_mask(seq_with_none, tok, mlm_model, top_k=5):
    parts = ["[MASK]" if t is None else f"T{t}" for t in seq_with_none]
    enc = tok(" ".join(parts), return_tensors="pt",
              truncation=True, max_length=32).to(device)
    with torch.no_grad():
        logits = mlm_model(**enc).logits
    results = {}
    for pos, val in enumerate(seq_with_none):
        if val is not None:
            continue
        # pos+1 skips the special token prepended at position 0
        tp, ti = torch.softmax(logits[0, pos + 1], dim=-1).topk(top_k)
        preds = []
        for p, tid in zip(tp.tolist(), ti.tolist()):
            ts = tok.convert_ids_to_tokens([tid])[0]
            if ts.startswith("T") and ts[1:].isdigit():
                preds.append((int(ts[1:]), round(p, 4)))
        results[pos] = preds
    return results


def electra_score(seq, tok, disc):
    enc = tok(" ".join(f"T{t}" for t in seq), return_tensors="pt",
              truncation=True, max_length=32).to(device)
    with torch.no_grad():
        logits = disc(enc["input_ids"], enc["attention_mask"])
    probs = torch.softmax(logits[0], dim=-1)
    n = min(len(seq), probs.shape[0] - 1)
    return float(probs[1:n + 1, 0].mean())


def ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc):
    b = bert_validity_score(seq, tok, cls)
    n = ngram.validity_score(seq)
    e = electra_score(seq, elec_tok, elec_disc)
    return 0.50 * b + 0.25 * n + 0.25 * e, b, n, e


def load_glyph_map(data_dir):
    import json
    p = data_dir / "id_to_glyph.json"
    if p.exists():
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    return {}


# ── Tasks ──────────────────────────────────────────────────────
def task_validate(seq_str, models):
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    seq = parse_sequence(seq_str)
    if any(t is None for t in seq):
        print("Use --task predict for sequences with [MASK]")
        return
    ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
    glyphs = "".join(glyph_map.get(str(t), f"[{t}]") for t in seq)
    if ens >= 0.85:
        verdict = "✅ VALID (≥85%)"
    elif ens >= 0.70:
        verdict = "⚠ UNCERTAIN (≥70%)"
    else:
        verdict = "❌ INVALID (<70%)"
    print(f"\n Sequence : {' '.join(f'T{t}' for t in seq)}")
    print(f" Glyphs   : {glyphs}")
    print(f" BERT     : {b:.4f}")
    print(f" N-gram   : {n:.4f}")
    print(f" ELECTRA  : {e:.4f}")
    print(f" Ensemble : {ens:.4f}")
    print(f" Verdict  : {verdict}")


def task_predict(seq_str, models):
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    model_dir, _ = get_model_dir()
    mlm = load_bert_mlm(model_dir)  # the MLM is loaded on demand for this task
    seq = parse_sequence(seq_str)
    preds = bert_predict_mask(seq, tok, mlm, top_k=5)
    print(f"\n Input: {seq_str}")
    for pos, candidates in preds.items():
        print(f"\n Position {pos} predictions:")
        for sign_id, prob in candidates:
            g = glyph_map.get(str(sign_id), "?")
            print(f"   T{sign_id:<5} {g}  {prob * 100:>6.2f}%")


def task_generate(count, models, threshold=0.85):
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    model_dir, data_dir = get_model_dir()
    gpt = load_nanogpt(model_dir)
    kept = []
    seen = set()
    attempts = 0

    print(f"\n Generating (threshold={threshold:.0%})...\n")
    # Cycle through sampling settings for diversity
    temps = [0.85, 0.90, 1.00, 1.10]
    topks = [40, 50, 60, 80]

    while len(kept) < count and attempts < count * 100:
        i = attempts % len(temps)
        seq = gpt.generate(temperature=temps[i], top_k=topks[i])
        attempts += 1
        if len(seq) < 2 or tuple(seq) in seen:
            continue
        seen.add(tuple(seq))
        ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
        if ens >= threshold:
            glyphs = "".join(glyph_map.get(str(t), "?") for t in seq)
            kept.append((seq, ens, glyphs))
            seq_str = " ".join(f"T{t}" for t in seq)
            print(f" {len(kept):>3}. {glyphs}  |  {seq_str}  |  score={ens:.3f}")

    print(f"\n Generated {len(kept)} sequences in {attempts} attempts")
    return kept


def task_score(seq_str, models):
    task_validate(seq_str, models)


def task_demo(models, glyph_map):
    print("\n" + "=" * 60)
    print("  INDUS SCRIPT — INFERENCE DEMO")
    print("=" * 60)

    examples = [
        ("T638 T177 T420 T122", "Known valid sequence"),
        ("T604 T123 T609", "Known formula (appears on 80+ seals)"),
        ("T406 T638 T243", "Known formula (appears on 37 seals)"),
        ("T122 T638 T177", "Reversed — should score lower"),
        ("T999 T888 T777", "Invalid token IDs"),
    ]

    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    print(f"\n {'Sequence':<35} {'Ensemble':>9} Verdict")
    print(" " + "─" * 58)
    for seq_str, label in examples:
        try:
            seq = [int(t.lstrip("T")) for t in seq_str.split()]
            ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
            g = "".join(glyph_map.get(str(t), "?") for t in seq)
            verdict = "✅" if ens >= 0.85 else "⚠" if ens >= 0.70 else "❌"
            print(f" {seq_str:<35} {ens:>8.3f}  {verdict} {label}")
        except Exception:
            print(f" {seq_str:<35} {'—':>9}  ❌ {label}")


# ── Main ───────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser(description="Indus Script Inference")
    parser.add_argument("--task",
                        choices=["validate", "predict", "generate", "score", "demo"],
                        default="demo")
    parser.add_argument("--sequence", type=str, default="T638 T177 T420 T122",
                        help="Sequence like 'T638 T177 T420' or 'T638 [MASK] T420'")
    parser.add_argument("--count", type=int, default=10,
                        help="Number of sequences to generate")
    parser.add_argument("--threshold", type=float, default=0.85,
                        help="Minimum ensemble score for generated sequences")
    parser.add_argument("--download", action="store_true",
                        help="Force re-download from HuggingFace")
    args = parser.parse_args()

    if args.download:
        download_models()

    print("Loading models...")
    model_dir, data_dir = get_model_dir()

    tok = load_tokenizer(data_dir)
    cls = load_bert_cls(model_dir); print(" ✓ TinyBERT")
    ngram = load_ngram(model_dir); print(" ✓ N-gram")
    elec_tok, elec_disc = load_electra(model_dir); print(" ✓ ELECTRA")
    glyph_map = load_glyph_map(data_dir)

    models = (tok, cls, ngram, elec_tok, elec_disc, glyph_map)

    if args.task == "validate": task_validate(args.sequence, models)
    elif args.task == "predict": task_predict(args.sequence, models)
    elif args.task == "generate": task_generate(args.count, models, args.threshold)
    elif args.task == "score": task_score(args.sequence, models)
    elif args.task == "demo": task_demo(models, glyph_map)


if __name__ == "__main__":
    main()