File size: 16,119 Bytes
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3872f06
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3872f06
 
 
 
 
 
 
 
 
 
 
 
 
a4f4b5c
 
 
3872f06
a4f4b5c
3872f06
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
3872f06
 
 
 
 
a4f4b5c
 
 
 
 
60e393a
a4f4b5c
 
 
 
 
60e393a
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60e393a
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92c23ad
 
a4f4b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92c23ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
"""
Indus Script β€” Inference & Generation
======================================
Download models from HuggingFace and run:
  1. Sequence validation  β€” is this inscription valid?
  2. Sign prediction      β€” predict a masked sign
  3. Generate synthetic   β€” generate new Indus sequences
  4. Score any sequence   β€” get ensemble confidence score

Install:
    pip install torch transformers huggingface_hub

Usage:
    python inference.py --task validate --sequence "T638 T177 T420 T122"
    python inference.py --task predict  --sequence "T638 [MASK] T420 T122"
    python inference.py --task generate --count 10
    python inference.py --task score    --sequence "T638 T177 T420"
    python inference.py --task demo
"""

import argparse
import math
import os
import pickle
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


# ── Auto-download from HuggingFace ────────────────────────────
# Default repository id used by download_models() when none is given.
HF_REPO = "hellosindh/indus-script-models"   # update after upload

def download_models(repo_id=HF_REPO, local_dir="indus_models"):
    """Download all model files from HuggingFace.

    Args:
        repo_id:   HuggingFace repository id to snapshot.
        local_dir: directory the snapshot is written into.

    Returns:
        The local snapshot path on success.  On any failure (missing
        huggingface_hub, network error, bad repo id) prints a manual
        download hint and exits the process with status 1.
    """
    try:
        from huggingface_hub import snapshot_download
        print(f"Downloading models from {repo_id}...")
        path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
        print(f"βœ“ Downloaded to {path}")
        return path
    except Exception as e:
        print(f"Download failed: {e}")
        # BUG FIX: this was a plain string, so the literal text "{repo_id}"
        # was printed instead of the actual repository id.
        print(f"Manual download: https://huggingface.co/{repo_id}")
        sys.exit(1)


def get_model_dir():
    """
    Locate the model and data directories, trying each known layout in turn.

    Order of preference:
      1. ./models/     (running from a cloned HuggingFace repo)
      2. DATA/models/  (running from the original indus_script folder)
      3. a fresh download from HuggingFace

    Returns:
        (model_dir, data_dir) as Path objects.
    """
    # Cloned-repo layout: models/ sits right next to this script.
    repo_models = Path("models")
    if repo_models.exists() and (repo_models / "nanogpt_indus.pt").exists():
        # data/ may be absent in a cloned repo — fall back to the cwd.
        repo_data = Path("data")
        if not repo_data.exists():
            repo_data = Path(".")
        return repo_models, repo_data

    # Original project layout.
    legacy_models = Path("DATA/models")
    if legacy_models.exists():
        return legacy_models, Path("DATA")

    # Nothing found locally — fetch the snapshot from HuggingFace.
    snapshot_root = Path(download_models())
    return snapshot_root / "models", snapshot_root / "data"


# ── Device ─────────────────────────────────────────────────────
# All models are moved to this device; prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Special token ids for the nanoGPT vocabulary.
# NOTE(review): presumably sign ids occupy the range below 814 — confirm
# against the tokenizer/vocab used during training.
BOS_ID = 814  # beginning-of-sequence
EOS_ID = 815  # end-of-sequence
PAD_ID = 816  # padding


# ── Load helpers ───────────────────────────────────────────────
def load_tokenizer(data_dir):
    """Load the fast tokenizer, preferring data_dir/indus_tokenizer.

    Falls back to data_dir itself when the indus_tokenizer subfolder
    is absent (e.g. when data_dir already is the tokenizer folder).
    """
    from transformers import PreTrainedTokenizerFast
    candidate = data_dir / "indus_tokenizer"
    if not candidate.exists():
        candidate = data_dir
    return PreTrainedTokenizerFast.from_pretrained(str(candidate))


def load_bert_mlm(model_dir):
    """Load the masked-LM BERT from model_dir/mlm onto `device`, in eval mode."""
    from transformers import BertForMaskedLM
    mlm_path = model_dir / "mlm"
    model = BertForMaskedLM.from_pretrained(str(mlm_path))
    return model.to(device).eval()


def load_bert_cls(model_dir):
    """Load the sequence classifier from model_dir/cls onto `device`, in eval mode."""
    from transformers import BertForSequenceClassification
    cls_path = model_dir / "cls"
    model = BertForSequenceClassification.from_pretrained(str(cls_path))
    return model.to(device).eval()


def load_ngram(model_dir):
    """Unpickle the n-gram model from model_dir/ngram_model.pkl.

    NOTE(review): pickle.load executes arbitrary code embedded in the
    file — only load model files obtained from a trusted source.
    """
    # indus_ngram.py must be importable: unpickling resolves the model's
    # class from that module, so put this script's directory on sys.path.
    sys.path.insert(0, str(Path(__file__).parent))
    with open(model_dir / "ngram_model.pkl", "rb") as f:
        return pickle.load(f)


def load_electra(model_dir):
    """Load the ELECTRA-style discriminator and its tokenizer from model_dir/electra.

    The discriminator is a plain BertModel with a per-token 2-way linear
    head (presumably original-vs-replaced, ELECTRA-style — confirm against
    the training code).  Returns (tokenizer, model) with the model on
    `device` in eval mode.
    """
    from transformers import BertModel, BertConfig, PreTrainedTokenizerFast
    import json

    # Attribute names (bert / classifier / dropout) must match the keys of
    # the saved state_dict below — do not rename them.
    class ElectraDisc(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            self.bert       = BertModel(cfg)
            self.classifier = nn.Linear(cfg.hidden_size, 2)
            self.dropout    = nn.Dropout(0.1)

        def forward(self, input_ids, attention_mask):
            # Per-token logits: shape (batch, seq_len, 2).
            out = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
            return self.classifier(self.dropout(out.last_hidden_state))

    p = model_dir / "electra"
    with open(p / "discriminator_config.json") as f:
        cfg = json.load(f)
    m = ElectraDisc(BertConfig(**cfg))
    # weights_only=True: restricted unpickling (tensors only) — safe load.
    m.load_state_dict(torch.load(p / "discriminator.pt",
                                  map_location=device, weights_only=True))
    tok = PreTrainedTokenizerFast.from_pretrained(str(p))
    return tok, m.to(device).eval()


def load_nanogpt(model_dir):
    """Rebuild the nanoGPT architecture inline and load nanogpt_indus.pt.

    The checkpoint stores {"cfg": dict, "model_state": state_dict}; the
    model classes are defined here so this script does not depend on the
    training code.  Attribute names must match the checkpoint's
    state_dict keys — do not rename them.  Returns the model on `device`
    in eval mode.

    NOTE(review): weights_only=False runs full pickle on the checkpoint —
    only load files from a trusted source.
    """
    ckpt = torch.load(model_dir / "nanogpt_indus.pt",
                      map_location=device, weights_only=False)
    cfg  = ckpt["cfg"]

    # Causal self-attention: multi-head, lower-triangular mask.
    class CSA(nn.Module):
        def __init__(self, c):
            super().__init__()
            self.nh = c["n_head"]; self.ne = c["n_embd"]
            self.hd = c["n_embd"] // c["n_head"]  # per-head dimension
            self.qkv  = nn.Linear(c["n_embd"], 3*c["n_embd"], bias=False)
            self.proj = nn.Linear(c["n_embd"], c["n_embd"],   bias=False)
            self.drop = nn.Dropout(c["dropout"])
            ml = c["block_size"]
            # (1, 1, block_size, block_size) causal mask, broadcast over batch/heads.
            self.register_buffer("mask",
                torch.tril(torch.ones(ml, ml)).view(1, 1, ml, ml))

        def forward(self, x):
            B, T, C = x.shape
            q, k, v = self.qkv(x).split(self.ne, dim=2)
            # Reshape each of q/k/v to (B, n_head, T, head_dim).
            sh = lambda t: t.view(B, T, self.nh, self.hd).transpose(1, 2)
            q, k, v = sh(q), sh(k), sh(v)
            # Scaled dot-product attention; mask future positions to -inf.
            a = (q @ k.transpose(-2, -1)) / math.sqrt(self.hd)
            a = a.masked_fill(self.mask[:,:,:T,:T] == 0, float("-inf"))
            return self.proj(
                (self.drop(F.softmax(a, dim=-1)) @ v)
                .transpose(1, 2).contiguous().view(B, T, C))

    # Transformer block: attention sublayer + 4x-expansion MLP.
    class TB(nn.Module):
        def __init__(self, c):
            super().__init__()
            self.ln1  = nn.LayerNorm(c["n_embd"]); self.attn = CSA(c)
            self.ln2  = nn.LayerNorm(c["n_embd"])
            self.ffn  = nn.Sequential(
                nn.Linear(c["n_embd"], 4*c["n_embd"]), nn.GELU(),
                nn.Linear(4*c["n_embd"], c["n_embd"]), nn.Dropout(c["dropout"]))
        def forward(self, x):
            # NOTE(review): non-standard residual wiring — the attention
            # residual is folded inside ln2's argument instead of the usual
            # two separate residual adds.  Kept as-is: it must match the
            # architecture the checkpoint was trained with.
            return x + self.ffn(self.ln2(x + self.attn(self.ln1(x))))

    # Decoder-only GPT with learned positional embeddings.
    class GPT(nn.Module):
        def __init__(self, c):
            super().__init__()
            self.cfg     = c
            self.tok_emb = nn.Embedding(c["vocab_size"], c["n_embd"])
            self.pos_emb = nn.Embedding(c["block_size"], c["n_embd"])
            self.drop    = nn.Dropout(c["dropout"])
            self.blocks  = nn.ModuleList([TB(c) for _ in range(c["n_layer"])])
            self.ln_f    = nn.LayerNorm(c["n_embd"])
            self.head    = nn.Linear(c["n_embd"], c["vocab_size"], bias=False)
            # Weight tying: the token embedding shares the output head's matrix.
            self.tok_emb.weight = self.head.weight

        def forward(self, idx):
            # idx: (B, T) token ids -> (B, T, vocab_size) logits.
            B, T = idx.shape
            x = self.drop(self.tok_emb(idx) + self.pos_emb(
                torch.arange(T, device=idx.device).unsqueeze(0)))
            for b in self.blocks: x = b(x)
            return self.head(self.ln_f(x))

        @torch.no_grad()
        def generate(self, temperature=0.85, top_k=40, max_len=15):
            """Sample one sequence of sign ids (reversed at the end: RTL→LTR)."""
            self.eval()
            idx = torch.tensor([[BOS_ID]], device=device)
            gen = []
            for _ in range(max_len):
                # Last-position logits, temperature-scaled; crop context to block_size.
                logits = self(idx[:, -self.cfg["block_size"]:])[: ,-1, :] / temperature
                # Special tokens are never sampled as content.
                logits[:, PAD_ID] = logits[:, BOS_ID] = logits[:, EOS_ID] = float("-inf")
                if top_k > 0:
                    # Keep only the top_k logits; the rest become -inf.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float("-inf")
                nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
                # NOTE(review): EOS is masked to -inf above, so this break can
                # never fire — every sample runs the full max_len.  Confirm
                # whether EOS was meant to stay sampleable.
                if nxt.item() == EOS_ID: break
                gen.append(nxt.item())
                idx = torch.cat([idx, nxt], dim=1)
            return list(reversed(gen))  # RTL→LTR

    m = GPT(cfg)
    m.load_state_dict(ckpt["model_state"])
    return m.to(device).eval()


# ── Scoring functions ──────────────────────────────────────────
def parse_sequence(seq_str):
    """Parse 'T638 T177 T420' or '638 177 420' into a list of ints.

    '[MASK]' (case-insensitive) becomes None so downstream code can
    find the positions to predict.

    Raises:
        ValueError: if a token is neither '[MASK]' nor a (optionally
        T-prefixed) integer.
    """
    ids = []
    for raw in seq_str.strip().split():
        token = raw.upper()
        if token == "[MASK]":
            ids.append(None)
        else:
            # BUG FIX: was lstrip("T"), which strips *every* leading 'T',
            # so a typo like 'TT638' silently parsed as 638.  Strip exactly
            # one prefix so malformed tokens raise instead.
            if token.startswith("T"):
                token = token[1:]
            ids.append(int(token))
    return ids


def bert_validity_score(seq, tok, cls_model):
    """Class-1 probability from the BERT classifier for `seq` (sign-id list).

    Class 1 is treated as "valid" by the callers in this file.
    """
    sign_text = " ".join(f"T{t}" for t in seq)
    batch = tok(sign_text, return_tensors="pt", truncation=True,
                max_length=32).to(device)
    with torch.no_grad():
        logits = cls_model(**batch).logits
    probs = torch.softmax(logits, dim=-1)
    return float(probs[0][1])


def bert_predict_mask(seq_with_none, tok, mlm_model, top_k=5):
    """Predict each None position in seq_with_none with the masked LM.

    Returns:
        dict mapping input position -> list of (sign_id, rounded prob)
        pairs, best first; non-sign vocabulary tokens are filtered out.
    """
    text = " ".join("[MASK]" if t is None else f"T{t}" for t in seq_with_none)
    enc = tok(text, return_tensors="pt",
              truncation=True, max_length=32).to(device)
    with torch.no_grad():
        logits = mlm_model(**enc).logits
    results = {}
    for pos, val in enumerate(seq_with_none):
        if val is None:
            # pos+1 offsets for the leading special token the tokenizer
            # prepends (presumably [CLS] — confirm against tokenizer config).
            probs = torch.softmax(logits[0, pos + 1], dim=-1)
            top_p, top_i = probs.topk(top_k)
            candidates = []
            for p, tid in zip(top_p.tolist(), top_i.tolist()):
                token_str = tok.convert_ids_to_tokens([tid])[0]
                # Keep only real sign tokens like 'T123'; skip specials.
                if token_str.startswith("T") and token_str[1:].isdigit():
                    candidates.append((int(token_str[1:]), round(p, 4)))
            results[pos] = candidates
    return results


def electra_score(seq, tok, disc):
    """Mean class-0 probability over sequence positions from the discriminator.

    Higher means the discriminator finds the tokens more plausible.
    Position 0 (the tokenizer's leading special token) is skipped, and
    the span is clamped to what actually fits after truncation.
    """
    text = " ".join(f"T{t}" for t in seq)
    enc = tok(text, return_tensors="pt", truncation=True,
              max_length=32).to(device)
    with torch.no_grad():
        logits = disc(enc["input_ids"], enc["attention_mask"])
    probs = torch.softmax(logits[0], dim=-1)
    usable = min(len(seq), probs.shape[0] - 1)
    return float(probs[1:usable + 1, 0].mean())


def ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc):
    """Weighted ensemble of the three validity models.

    Weights: 0.50 BERT / 0.25 n-gram / 0.25 ELECTRA.

    Returns:
        (ensemble, bert_score, ngram_score, electra_score)
    """
    bert_s = bert_validity_score(seq, tok, cls)
    ngram_s = ngram.validity_score(seq)
    elec_s = electra_score(seq, elec_tok, elec_disc)
    combined = 0.50 * bert_s + 0.25 * ngram_s + 0.25 * elec_s
    return combined, bert_s, ngram_s, elec_s


def load_glyph_map(data_dir):
    """Load the id→glyph mapping from data_dir/id_to_glyph.json.

    Returns an empty dict when the file does not exist.
    """
    import json
    mapping_file = data_dir / "id_to_glyph.json"
    if not mapping_file.exists():
        return {}
    with open(mapping_file, encoding="utf-8") as f:
        return json.load(f)


# ── Tasks ──────────────────────────────────────────────────────
def task_validate(seq_str, models):
    """Score seq_str with the full ensemble and print a per-model report."""
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    seq = parse_sequence(seq_str)
    if None in seq:
        print("Use --task predict for sequences with [MASK]")
        return
    ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
    glyphs = "".join(glyph_map.get(str(t), f"[{t}]") for t in seq)
    # Three-way verdict on the ensemble score.
    if ens >= 0.85:
        verdict = "βœ… VALID (β‰₯85%)"
    elif ens >= 0.70:
        verdict = "⚠ UNCERTAIN (β‰₯70%)"
    else:
        verdict = "❌ INVALID (<70%)"
    sign_text = " ".join(f"T{t}" for t in seq)
    print(f"\n  Sequence  : {sign_text}")
    print(f"  Glyphs    : {glyphs}")
    print(f"  BERT      : {b:.4f}")
    print(f"  N-gram    : {n:.4f}")
    print(f"  ELECTRA   : {e:.4f}")
    print(f"  Ensemble  : {ens:.4f}")
    print(f"  Verdict   : {verdict}")


def task_predict(seq_str, models):
    """Fill [MASK] positions in seq_str using the BERT masked LM."""
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    # The MLM is loaded lazily — only this task needs it.
    model_dir, data_dir = get_model_dir()
    mlm = load_bert_mlm(model_dir)
    seq = parse_sequence(seq_str)
    predictions = bert_predict_mask(seq, tok, mlm, top_k=5)
    print(f"\n  Input: {seq_str}")
    for pos in predictions:
        print(f"\n  Position {pos} predictions:")
        for sign_id, prob in predictions[pos]:
            glyph = glyph_map.get(str(sign_id), "?")
            print(f"    T{sign_id:<5} {glyph}  {prob*100:>6.2f}%")


def task_generate(count, models, threshold=0.85):
    """Sample sequences from nanoGPT and keep those scoring >= threshold.

    Cycles through several (temperature, top_k) settings for variety,
    skips duplicates and sequences shorter than 2 signs, and caps the
    work at count*100 sampling attempts.

    Returns:
        list of (sequence, ensemble_score, glyph_string) tuples.
    """
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    model_dir, data_dir = get_model_dir()
    gpt = load_nanogpt(model_dir)

    print(f"\n  Generating (threshold={threshold:.0%})...\n")
    # (temperature, top_k) pairs cycled across attempts.
    sampling = [(0.85, 40), (0.90, 50), (1.00, 60), (1.10, 80)]

    kept = []
    seen = set()
    attempts = 0
    max_attempts = count * 100
    while len(kept) < count and attempts < max_attempts:
        temp, tk = sampling[attempts % len(sampling)]
        seq = gpt.generate(temperature=temp, top_k=tk)
        attempts += 1
        key = tuple(seq)
        if len(seq) < 2 or key in seen:
            continue
        seen.add(key)
        ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
        if ens < threshold:
            continue
        glyphs = "".join(glyph_map.get(str(t), "?") for t in seq)
        kept.append((seq, ens, glyphs))
        seq_str = " ".join(f"T{t}" for t in seq)
        print(f"  {len(kept):>3}. {glyphs}  |  {seq_str}  |  score={ens:.3f}")

    print(f"\n  Generated {len(kept)} sequences in {attempts} attempts")
    return kept


def task_score(seq_str, models):
    """Alias for task_validate — scoring and validating print the same report."""
    task_validate(seq_str, models)


def task_demo(models, glyph_map):
    """Score a canned set of example sequences and print a verdict table.

    `glyph_map` duplicates models[-1] (callers pass the same object); the
    parameter is kept for backward compatibility.
    """
    print("\n" + "="*60)
    print("  INDUS SCRIPT β€” INFERENCE DEMO")
    print("="*60)

    examples = [
        ("T638 T177 T420 T122",  "Known valid sequence"),
        ("T604 T123 T609",       "Known formula (appears on 80+ seals)"),
        ("T406 T638 T243",       "Known formula (appears on 37 seals)"),
        ("T122 T638 T177",       "Reversed β€” should score lower"),
        ("T999 T888 T777",       "Invalid token IDs"),
    ]

    # BUG FIX: the original unpacking rebound `glyph_map`, silently
    # shadowing the parameter; unpack without clobbering it.
    tok, cls, ngram, elec_tok, elec_disc, _ = models
    print(f"\n  {'Sequence':<35} {'Ensemble':>9}  Verdict")
    print("  " + "─"*58)
    for seq_str, label in examples:
        try:
            # Use the shared parser for consistency with the other tasks.
            seq = parse_sequence(seq_str)
            ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
            g = "".join(glyph_map.get(str(t), "?") for t in seq)
            verdict = "βœ…" if ens >= 0.85 else "⚠" if ens >= 0.70 else "❌"
            print(f"  {seq_str:<35} {ens:>8.3f}  {verdict}  {label}")
        except Exception:
            # Any scoring failure (e.g. out-of-vocab ids) renders as a ❌ row.
            print(f"  {seq_str:<35} {'β€”':>9}  ❌  {label}")


# ── Main ───────────────────────────────────────────────────────
def main():
    """CLI entry point: parse args, load the shared models, dispatch the task."""
    parser = argparse.ArgumentParser(description="Indus Script Inference")
    parser.add_argument(
        "--task",
        choices=["validate","predict","generate","score","demo"],
        default="demo")
    parser.add_argument(
        "--sequence", type=str, default="T638 T177 T420 T122",
        help="Sequence like 'T638 T177 T420' or 'T638 [MASK] T420'")
    parser.add_argument(
        "--count", type=int, default=10,
        help="Number of sequences to generate")
    parser.add_argument("--threshold", type=float, default=0.85)
    parser.add_argument(
        "--download", action="store_true",
        help="Force re-download from HuggingFace")
    args = parser.parse_args()

    if args.download:
        download_models()

    print("Loading models...")
    model_dir, data_dir = get_model_dir()

    # Load everything every task shares; the MLM and nanoGPT are loaded
    # lazily inside their respective tasks.
    tok = load_tokenizer(data_dir)
    cls = load_bert_cls(model_dir)
    print("  βœ“ TinyBERT")
    ngram = load_ngram(model_dir)
    print("  βœ“ N-gram")
    elec_tok, elec_disc = load_electra(model_dir)
    print("  βœ“ ELECTRA")
    glyph_map = load_glyph_map(data_dir)

    models = (tok, cls, ngram, elec_tok, elec_disc, glyph_map)

    if args.task == "validate":
        task_validate(args.sequence, models)
    elif args.task == "predict":
        task_predict(args.sequence, models)
    elif args.task == "generate":
        task_generate(args.count, models, args.threshold)
    elif args.task == "score":
        task_score(args.sequence, models)
    elif args.task == "demo":
        task_demo(models, glyph_map)


# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()