# indus-script-models / inference.py
# (HuggingFace repo-page header preserved as comments so the file parses as Python)
# hellosindh's picture
# Update inference.py
# 92c23ad verified
"""
Indus Script β€” Inference & Generation
======================================
Download models from HuggingFace and run:
1. Sequence validation β€” is this inscription valid?
2. Sign prediction β€” predict a masked sign
3. Generate synthetic β€” generate new Indus sequences
4. Score any sequence β€” get ensemble confidence score
Install:
pip install torch transformers huggingface_hub
Usage:
python inference.py --task validate --sequence "T638 T177 T420 T122"
python inference.py --task predict --sequence "T638 [MASK] T420 T122"
python inference.py --task generate --count 10
python inference.py --task score --sequence "T638 T177 T420"
python inference.py --task demo
"""
import argparse
import math
import os
import pickle
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Auto-download from HuggingFace ────────────────────────────
HF_REPO = "hellosindh/indus-script-models"  # update after upload
def download_models(repo_id=HF_REPO, local_dir="indus_models"):
    """Download all model files from HuggingFace.

    Args:
        repo_id: HuggingFace repository to pull the snapshot from.
        local_dir: local directory the snapshot is written to.

    Returns:
        The local snapshot path on success.

    Exits the process with status 1 if the download (or the
    huggingface_hub import) fails.
    """
    try:
        from huggingface_hub import snapshot_download
        print(f"Downloading models from {repo_id}...")
        path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
        print(f"βœ“ Downloaded to {path}")
        return path
    except Exception as e:
        print(f"Download failed: {e}")
        # BUG FIX: this was a plain string, so the literal text "{repo_id}"
        # was printed instead of the repository name.
        print(f"Manual download: https://huggingface.co/{repo_id}")
        sys.exit(1)
def get_model_dir():
    """
    Locate the model directory and its companion data directory.

    Search order:
      1. ./models/     -- running from a cloned HuggingFace repo
      2. DATA/models/  -- running from the original indus_script folder
      3. fresh snapshot downloaded from HuggingFace

    Returns:
        (model_dir, data_dir) as Path objects.
    """
    # Cloned HF repo: the nanoGPT checkpoint sits directly under ./models
    repo_models = Path("models")
    if repo_models.exists() and (repo_models / "nanogpt_indus.pt").exists():
        data_dir = Path("data")
        return repo_models, (data_dir if data_dir.exists() else Path("."))
    # Original project layout keeps everything under DATA/
    legacy = Path("DATA/models")
    if legacy.exists():
        return legacy, Path("DATA")
    # Nothing found locally -- fetch a snapshot from the Hub
    snapshot_root = Path(download_models())
    return snapshot_root / "models", snapshot_root / "data"
# ── Device ─────────────────────────────────────────────────────
# Prefer GPU when available; all models and tensors below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Special-token ids shared with the nanoGPT checkpoint's vocabulary.
BOS_ID = 814  # beginning-of-sequence
EOS_ID = 815  # end-of-sequence
PAD_ID = 816  # padding
# ── Load helpers ───────────────────────────────────────────────
def load_tokenizer(data_dir):
    """Load the Indus tokenizer from data_dir/indus_tokenizer, falling back
    to data_dir itself when that subdirectory is absent."""
    from transformers import PreTrainedTokenizerFast
    preferred = data_dir / "indus_tokenizer"
    tokenizer_dir = preferred if preferred.exists() else data_dir
    return PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
def load_bert_mlm(model_dir):
    """Load the masked-LM BERT head from model_dir/mlm, in eval mode on `device`."""
    from transformers import BertForMaskedLM
    model = BertForMaskedLM.from_pretrained(str(model_dir / "mlm"))
    return model.to(device).eval()
def load_bert_cls(model_dir):
    """Load the sequence-classification BERT from model_dir/cls, in eval mode on `device`."""
    from transformers import BertForSequenceClassification
    model = BertForSequenceClassification.from_pretrained(str(model_dir / "cls"))
    return model.to(device).eval()
def load_ngram(model_dir):
    """Unpickle the n-gram model from model_dir/ngram_model.pkl.

    NOTE(review): pickle.load executes arbitrary code embedded in the
    file -- only load model files obtained from a trusted source.
    """
    # indus_ngram.py must be importable -- presumably the pickle references
    # classes defined there; verify against the training script.
    sys.path.insert(0, str(Path(__file__).parent))
    with open(model_dir / "ngram_model.pkl", "rb") as f:
        return pickle.load(f)
def load_electra(model_dir):
    """Rebuild the ELECTRA-style discriminator plus its tokenizer.

    Reads model_dir/electra/{discriminator_config.json, discriminator.pt}
    and returns (tokenizer, discriminator) with the model in eval mode
    on `device`.
    """
    import json
    from transformers import BertModel, BertConfig, PreTrainedTokenizerFast

    class _Discriminator(nn.Module):
        # BERT encoder + per-token 2-way head (attribute names must match
        # the saved state_dict: bert / classifier / dropout).
        def __init__(self, cfg):
            super().__init__()
            self.bert = BertModel(cfg)
            self.classifier = nn.Linear(cfg.hidden_size, 2)
            self.dropout = nn.Dropout(0.1)

        def forward(self, input_ids, attention_mask):
            hidden = self.bert(input_ids=input_ids,
                               attention_mask=attention_mask).last_hidden_state
            return self.classifier(self.dropout(hidden))

    root = model_dir / "electra"
    with open(root / "discriminator_config.json") as cfg_file:
        cfg_dict = json.load(cfg_file)
    disc = _Discriminator(BertConfig(**cfg_dict))
    state = torch.load(root / "discriminator.pt",
                       map_location=device, weights_only=True)
    disc.load_state_dict(state)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(root))
    return tokenizer, disc.to(device).eval()
def load_nanogpt(model_dir):
    """Rebuild the nanoGPT generator from model_dir/nanogpt_indus.pt.

    The checkpoint stores both the architecture config ("cfg") and the
    trained weights ("model_state"), so the classes below must match the
    training-time definitions exactly -- do not "modernize" them.

    NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    objects -- only load trusted checkpoint files.
    """
    ckpt = torch.load(model_dir / "nanogpt_indus.pt",
                      map_location=device, weights_only=False)
    cfg = ckpt["cfg"]
    class CSA(nn.Module):
        """Causal self-attention with a fixed lower-triangular mask."""
        def __init__(self, c):
            super().__init__()
            self.nh = c["n_head"]; self.ne = c["n_embd"]
            self.hd = c["n_embd"] // c["n_head"]  # per-head dimension
            self.qkv = nn.Linear(c["n_embd"], 3*c["n_embd"], bias=False)
            self.proj = nn.Linear(c["n_embd"], c["n_embd"], bias=False)
            self.drop = nn.Dropout(c["dropout"])
            ml = c["block_size"]
            # (1, 1, ml, ml) lower-triangular mask, sliced to (T, T) per forward
            self.register_buffer("mask",
                torch.tril(torch.ones(ml, ml)).view(1, 1, ml, ml))
        def forward(self, x):
            B, T, C = x.shape
            q, k, v = self.qkv(x).split(self.ne, dim=2)
            # reshape (B, T, C) -> (B, n_head, T, head_dim)
            sh = lambda t: t.view(B, T, self.nh, self.hd).transpose(1, 2)
            q, k, v = sh(q), sh(k), sh(v)
            a = (q @ k.transpose(-2, -1)) / math.sqrt(self.hd)
            a = a.masked_fill(self.mask[:,:,:T,:T] == 0, float("-inf"))
            return self.proj(
                (self.drop(F.softmax(a, dim=-1)) @ v)
                .transpose(1, 2).contiguous().view(B, T, C))
    class TB(nn.Module):
        """Pre-norm transformer block: attention then 4x-GELU feed-forward."""
        def __init__(self, c):
            super().__init__()
            self.ln1 = nn.LayerNorm(c["n_embd"]); self.attn = CSA(c)
            self.ln2 = nn.LayerNorm(c["n_embd"])
            self.ffn = nn.Sequential(
                nn.Linear(c["n_embd"], 4*c["n_embd"]), nn.GELU(),
                nn.Linear(4*c["n_embd"], c["n_embd"]), nn.Dropout(c["dropout"]))
        def forward(self, x):
            # NOTE(review): non-standard residual wiring -- the outer skip adds
            # the block INPUT x, not the post-attention activation. It matches
            # what the checkpoint was trained with, so it must stay as-is.
            return x + self.ffn(self.ln2(x + self.attn(self.ln1(x))))
    class GPT(nn.Module):
        """Minimal GPT: token+position embeddings, N blocks, tied LM head."""
        def __init__(self, c):
            super().__init__()
            self.cfg = c
            self.tok_emb = nn.Embedding(c["vocab_size"], c["n_embd"])
            self.pos_emb = nn.Embedding(c["block_size"], c["n_embd"])
            self.drop = nn.Dropout(c["dropout"])
            self.blocks = nn.ModuleList([TB(c) for _ in range(c["n_layer"])])
            self.ln_f = nn.LayerNorm(c["n_embd"])
            self.head = nn.Linear(c["n_embd"], c["vocab_size"], bias=False)
            # weight tying: input embedding shares the LM head matrix
            self.tok_emb.weight = self.head.weight
        def forward(self, idx):
            B, T = idx.shape
            x = self.drop(self.tok_emb(idx) + self.pos_emb(
                torch.arange(T, device=idx.device).unsqueeze(0)))
            for b in self.blocks: x = b(x)
            return self.head(self.ln_f(x))
        @torch.no_grad()
        def generate(self, temperature=0.85, top_k=40, max_len=15):
            """Sample one sequence of sign ids (special tokens excluded)."""
            self.eval()
            idx = torch.tensor([[BOS_ID]], device=device)
            gen = []
            for _ in range(max_len):
                # crop context to block_size; take logits at the last position
                logits = self(idx[:, -self.cfg["block_size"]:])[: ,-1, :] / temperature
                # special tokens are never sampled
                logits[:, PAD_ID] = logits[:, BOS_ID] = logits[:, EOS_ID] = float("-inf")
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float("-inf")
                nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
                # NOTE(review): EOS is masked to -inf above, so this break
                # appears unreachable -- sequences always run to max_len.
                if nxt.item() == EOS_ID: break
                gen.append(nxt.item())
                idx = torch.cat([idx, nxt], dim=1)
            return list(reversed(gen)) # RTL→LTR
    m = GPT(cfg)
    m.load_state_dict(ckpt["model_state"])
    return m.to(device).eval()
# ── Scoring functions ──────────────────────────────────────────
def parse_sequence(seq_str):
    """Parse 'T638 T177 T420' or '638 177 420' into a list of ints.

    A '[MASK]' token (any case) becomes None so downstream code can
    locate the slot to predict.
    """
    ids = []
    for token in seq_str.strip().split():
        upper = token.upper()
        ids.append(None if upper == "[MASK]" else int(upper.lstrip("T")))
    return ids
def bert_validity_score(seq, tok, cls_model):
    """Return the classifier's probability that `seq` is a valid inscription."""
    text = " ".join(f"T{sign}" for sign in seq)
    encoded = tok(text, return_tensors="pt", truncation=True,
                  max_length=32).to(device)
    with torch.no_grad():
        logits = cls_model(**encoded).logits
    # index [0][1]: first (only) batch item, "valid" class probability
    return float(torch.softmax(logits, dim=-1)[0][1])
def bert_predict_mask(seq_with_none, tok, mlm_model, top_k=5):
    """Predict top-k candidate signs for every None (masked) slot.

    Returns {position: [(sign_id, probability), ...]} keyed by the index
    of the masked slot within seq_with_none.
    """
    text = " ".join("[MASK]" if sign is None else f"T{sign}"
                    for sign in seq_with_none)
    encoded = tok(text, return_tensors="pt",
                  truncation=True, max_length=32).to(device)
    with torch.no_grad():
        logits = mlm_model(**encoded).logits
    predictions = {}
    for pos, sign in enumerate(seq_with_none):
        if sign is not None:
            continue
        # pos+1: presumably offsets the tokenizer's leading special token
        # ([CLS]) -- confirm against the tokenizer config
        probs, token_ids = torch.softmax(logits[0, pos + 1], dim=-1).topk(top_k)
        candidates = []
        for prob, tid in zip(probs.tolist(), token_ids.tolist()):
            token_str = tok.convert_ids_to_tokens([tid])[0]
            # keep only real sign tokens like "T638"; drop special tokens
            if token_str.startswith("T") and token_str[1:].isdigit():
                candidates.append((int(token_str[1:]), round(prob, 4)))
        predictions[pos] = candidates
    return predictions
def electra_score(seq, tok, disc):
    """Mean per-token probability of class 0 from the discriminator head."""
    text = " ".join(f"T{sign}" for sign in seq)
    encoded = tok(text, return_tensors="pt",
                  truncation=True, max_length=32).to(device)
    with torch.no_grad():
        logits = disc(encoded["input_ids"], encoded["attention_mask"])
    probs = torch.softmax(logits[0], dim=-1)
    # rows 1..n: presumably skips the leading special token -- verify
    usable = min(len(seq), probs.shape[0] - 1)
    return float(probs[1:usable + 1, 0].mean())
def ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc):
    """Weighted ensemble: 50% BERT classifier, 25% n-gram, 25% ELECTRA.

    Returns (ensemble, bert, ngram, electra) scores in that order.
    """
    bert_s = bert_validity_score(seq, tok, cls)
    ngram_s = ngram.validity_score(seq)
    electra_s = electra_score(seq, elec_tok, elec_disc)
    combined = 0.50*bert_s + 0.25*ngram_s + 0.25*electra_s
    return combined, bert_s, ngram_s, electra_s
def load_glyph_map(data_dir):
    """Load the id->glyph mapping from data_dir/id_to_glyph.json.

    Returns an empty dict when the file does not exist.
    """
    import json
    mapping_file = data_dir / "id_to_glyph.json"
    if not mapping_file.exists():
        return {}
    with open(mapping_file, encoding="utf-8") as f:
        return json.load(f)
# ── Tasks ──────────────────────────────────────────────────────
def task_validate(seq_str, models):
    """Score a complete sequence with the ensemble and print a breakdown."""
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    seq = parse_sequence(seq_str)
    # Masked slots are handled by the predict task, not here.
    if any(sign is None for sign in seq):
        print("Use --task predict for sequences with [MASK]")
        return
    ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
    glyphs = "".join(glyph_map.get(str(sign), f"[{sign}]") for sign in seq)
    if ens >= 0.85:
        verdict = "βœ… VALID (β‰₯85%)"
    elif ens >= 0.70:
        verdict = "⚠ UNCERTAIN (β‰₯70%)"
    else:
        verdict = "❌ INVALID (<70%)"
    seq_text = " ".join(f"T{sign}" for sign in seq)
    print(f"\n Sequence : {seq_text}")
    print(f" Glyphs : {glyphs}")
    print(f" BERT : {b:.4f}")
    print(f" N-gram : {n:.4f}")
    print(f" ELECTRA : {e:.4f}")
    print(f" Ensemble : {ens:.4f}")
    print(f" Verdict : {verdict}")
def task_predict(seq_str, models):
    """Fill [MASK] slots with the MLM's top-5 candidates and print them."""
    tok, _cls, _ngram, _elec_tok, _elec_disc, glyph_map = models
    model_dir, _data_dir = get_model_dir()
    mlm = load_bert_mlm(model_dir)
    seq = parse_sequence(seq_str)
    predictions = bert_predict_mask(seq, tok, mlm, top_k=5)
    print(f"\n Input: {seq_str}")
    for pos, candidates in predictions.items():
        print(f"\n Position {pos} predictions:")
        for sign_id, prob in candidates:
            glyph = glyph_map.get(str(sign_id), "?")
            print(f" T{sign_id:<5} {glyph} {prob*100:>6.2f}%")
def task_generate(count, models, threshold=0.85):
    """Sample sequences from nanoGPT and keep those scoring >= threshold.

    Cycles through several (temperature, top_k) settings for diversity,
    de-duplicates, and caps the effort at count*100 sampling attempts.
    Returns a list of (sequence, ensemble_score, glyph_string) tuples.
    """
    tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
    model_dir, _data_dir = get_model_dir()
    gpt = load_nanogpt(model_dir)
    kept, seen, attempts = [], set(), 0
    print(f"\n Generating (threshold={threshold:.0%})...\n")
    # (temperature, top_k) settings rotated per attempt
    settings = [(0.85, 40), (0.90, 50), (1.00, 60), (1.10, 80)]
    while len(kept) < count and attempts < count * 100:
        temperature, top_k = settings[attempts % len(settings)]
        seq = gpt.generate(temperature=temperature, top_k=top_k)
        attempts += 1
        key = tuple(seq)
        if len(seq) < 2 or key in seen:
            continue
        seen.add(key)
        ens, b, n, e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
        if ens >= threshold:
            glyphs = "".join(glyph_map.get(str(sign), "?") for sign in seq)
            kept.append((seq, ens, glyphs))
            seq_str = " ".join(f"T{sign}" for sign in seq)
            print(f" {len(kept):>3}. {glyphs} | {seq_str} | score={ens:.3f}")
    print(f"\n Generated {len(kept)} sequences in {attempts} attempts")
    return kept
def task_score(seq_str, models):
    """Alias for task_validate: print the ensemble score breakdown."""
    task_validate(seq_str, models)
def task_demo(models, glyph_map):
    """Score a fixed set of demo sequences and print a summary table.

    Args:
        models: the (tok, cls, ngram, elec_tok, elec_disc, glyph_map) tuple.
        glyph_map: id->glyph mapping used for display.
    """
    print("\n" + "="*60)
    print(" INDUS SCRIPT β€” INFERENCE DEMO")
    print("="*60)
    examples = [
        ("T638 T177 T420 T122", "Known valid sequence"),
        ("T604 T123 T609", "Known formula (appears on 80+ seals)"),
        ("T406 T638 T243", "Known formula (appears on 37 seals)"),
        ("T122 T638 T177", "Reversed β€” should score lower"),
        ("T999 T888 T777", "Invalid token IDs"),
    ]
    # BUG FIX: the original re-unpacked glyph_map from models, shadowing the
    # glyph_map parameter; unpack without clobbering it. (main() passes the
    # same object in both places, so output is unchanged.)
    tok, cls, ngram, elec_tok, elec_disc, _ = models
    print(f"\n {'Sequence':<35} {'Ensemble':>9} Verdict")
    print(" " + "─"*58)
    for seq_str, label in examples:
        try:
            seq = [int(t.lstrip("T")) for t in seq_str.split()]
            ens, _b, _n, _e = ensemble_score(seq, tok, cls, ngram, elec_tok, elec_disc)
            verdict = "βœ…" if ens >= 0.85 else "⚠" if ens >= 0.70 else "❌"
            print(f" {seq_str:<35} {ens:>8.3f} {verdict} {label}")
        except Exception:
            # Best-effort demo: invalid ids make model calls fail; show a dash.
            print(f" {seq_str:<35} {'β€”':>9} ❌ {label}")
# ── Main ───────────────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments, load shared models, dispatch a task."""
    parser = argparse.ArgumentParser(description="Indus Script Inference")
    parser.add_argument("--task",
                        choices=["validate", "predict", "generate", "score", "demo"],
                        default="demo")
    parser.add_argument("--sequence", type=str, default="T638 T177 T420 T122",
                        help="Sequence like 'T638 T177 T420' or 'T638 [MASK] T420'")
    parser.add_argument("--count", type=int, default=10,
                        help="Number of sequences to generate")
    parser.add_argument("--threshold", type=float, default=0.85)
    parser.add_argument("--download", action="store_true",
                        help="Force re-download from HuggingFace")
    args = parser.parse_args()
    if args.download:
        download_models()
    print("Loading models...")
    model_dir, data_dir = get_model_dir()
    tok = load_tokenizer(data_dir)
    cls = load_bert_cls(model_dir)
    print(" βœ“ TinyBERT")
    ngram = load_ngram(model_dir)
    print(" βœ“ N-gram")
    elec_tok, elec_disc = load_electra(model_dir)
    print(" βœ“ ELECTRA")
    glyph_map = load_glyph_map(data_dir)
    models = (tok, cls, ngram, elec_tok, elec_disc, glyph_map)
    if args.task == "validate":
        task_validate(args.sequence, models)
    elif args.task == "predict":
        task_predict(args.sequence, models)
    elif args.task == "generate":
        task_generate(args.count, models, args.threshold)
    elif args.task == "score":
        task_score(args.sequence, models)
    else:  # "demo" is the only remaining choice
        task_demo(models, glyph_map)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()