# Repository: indus-script-models / indus_ngram.py
# Uploaded by hellosindh, commit 5a95e08 ("Upload indus_ngram.py", verified)
"""
indus_ngram.py — Standalone module for InduNgramModel
======================================================
This file MUST exist so that pickle can import InduNgramModel
when loading ngram_model.pkl in 07_ensemble.py and 08_electra_train.py.
The pickle fix:
When you save a class with pickle, Python records the module path.
If the class was defined in __main__ (i.e. inside 06_ngram_model.py
when run directly), pickle saves it as __main__.InduNgramModel.
When another script tries to load it, __main__ refers to THAT script,
which doesn't have InduNgramModel — hence the AttributeError.
Solution: define the class in THIS standalone module (indus_ngram.py).
Both 06_ngram_model.py and 07_ensemble.py import from here.
Pickle records the path as indus_ngram.InduNgramModel — always findable.
Do not rename or move this file.
"""
import math
import pickle
from pathlib import Path
from collections import Counter, defaultdict
class InduNgramModel:
    """
    Kneser-Ney smoothed N-gram LM for Indus Script.

    RTL mode (default, recommended):
        Sequences are reversed before training/scoring.
        RTL bigram entropy (3.18) < LTR (3.72) → supports RTL hypothesis.

    Sign roles in RTL reading direction:
        INITIAL  = reading-start sign (data position [-1])
        TERMINAL = reading-end sign   (data position [0])
        MEDIAL   = appears in middle positions

    Tokens may be any hashable value; ``None`` is reserved as the
    "masked / missing" placeholder and is dropped before counting and
    scoring.

    NOTE: instances are persisted with pickle. Do not rename or remove
    instance attributes, otherwise previously saved models stop working.
    """

    def __init__(self, rtl: bool = True):
        """Create an untrained model.

        Args:
            rtl: When True, every sequence is reversed before it is counted
                or scored (right-to-left reading order).
        """
        self.rtl = rtl
        # Raw n-gram counts, all in reading order.
        self.unigram = Counter()
        self.bigram = defaultdict(Counter)   # prev -> Counter(next)
        self.trigram = defaultdict(Counter)  # (prev2, prev) -> Counter(next)
        # How often each sign opens / closes a sequence (reading order).
        self.start_cnt = Counter()
        self.end_cnt = Counter()
        self.total_seqs = 0
        self.total_tokens = 0
        self.vocab_size = 0
        # Absolute discount used by the Kneser-Ney estimates.
        self.D = 0.75
        # Calibration statistics filled in by _calibrate().
        self.score_mean = 0.0
        self.score_std = 1.0
        self.score_min = -20.0
        self.score_max = 0.0
        self._pairwise_acc = 0.0
        # Continuation counts: number of distinct left contexts per sign.
        self._cont_right = Counter()
        self._total_bi_types = 0

    def _orient(self, seq) -> list:
        """Return *seq* without ``None`` entries, in reading order."""
        cleaned = [t for t in seq if t is not None]
        if self.rtl:
            cleaned.reverse()
        return cleaned

    def train(self, sequences) -> None:
        """Accumulate n-gram counts from *sequences* and calibrate scores.

        Args:
            sequences: A list of token sequences in data order. Sequences
                that are empty after removing ``None`` are ignored.

        Calling this more than once adds the new counts to the existing
        ones; the derived statistics are recomputed from the totals.
        """
        mode = "RTL" if self.rtl else "LTR"
        print(f" Training [{mode}] on {len(sequences):,} sequences...")
        for seq in sequences:
            s = self._orient(seq)
            if not s:
                continue
            self.total_seqs += 1
            self.unigram.update(s)
            self.start_cnt[s[0]] += 1
            self.end_cnt[s[-1]] += 1
            for i in range(len(s) - 1):
                self.bigram[s[i]][s[i + 1]] += 1
            for i in range(len(s) - 2):
                self.trigram[(s[i], s[i + 1])][s[i + 2]] += 1
        self.total_tokens = sum(self.unigram.values())
        self.vocab_size = len(self.unigram)
        # Rebuild (not increment) the continuation counts so that a second
        # call to train() does not count every bigram type twice.
        self._cont_right = Counter()
        for followers in self.bigram.values():
            self._cont_right.update(followers.keys())
        self._total_bi_types = sum(self._cont_right.values())
        self._calibrate(sequences)
        print(f" Vocab : {self.vocab_size}")
        print(f" Pairwise : {self._pairwise_acc*100:.1f}%")
        print(f" Score range: [{self.score_min:.3f}, {self.score_max:.3f}]")

    def _calibrate(self, sequences) -> None:
        """Estimate the score range from real vs. corrupted sequences.

        Scores the first 500 usable sequences and one randomly corrupted
        copy of each, then stores mean / stdev / min / max of all scores
        and the fraction of pairs where the real sequence outscores its
        corrupted copy. If there is nothing usable to score, the existing
        calibration values are left untouched.
        """
        import random
        import statistics

        # Private generator: same stream as random.seed(42) would give, but
        # the caller's global random state is not disturbed.
        rng = random.Random(42)
        all_toks = list(self.unigram.keys())
        end_toks = list(self.end_cnt.keys())
        start_toks = list(self.start_cnt.keys())

        def corrupt(seq):
            """Return a copy of *seq* damaged in one of four random ways."""
            r = rng.randint(0, 3)
            c = list(seq)
            if r == 0:
                rng.shuffle(c)
            elif r == 1:
                # Data position 0 gets a sign seen at a reading-order end.
                c[0] = rng.choice(end_toks)
            elif r == 2:
                # Data position -1 gets a sign seen at a reading-order start.
                c[-1] = rng.choice(start_toks)
            else:
                # Overwrite about half of the positions with random signs.
                for p in rng.sample(range(len(c)), max(1, len(c) // 2)):
                    c[p] = rng.choice(all_toks)
            return c

        # Sequences with no real token cannot be corrupted or scored.
        sample = [
            s for s in sequences[:500]
            if any(t is not None for t in s)
        ]
        if not sample:
            return
        good = [self._raw_score(s) for s in sample]
        bad = [self._raw_score(corrupt(s)) for s in sample]
        all_s = good + bad
        self.score_mean = statistics.mean(all_s)
        self.score_std = statistics.stdev(all_s)
        self.score_min = min(all_s)
        self.score_max = max(all_s)
        self._pairwise_acc = sum(g > b for g, b in zip(good, bad)) / len(good)

    def _p_uni_kn(self, w) -> float:
        """Add-one smoothed continuation probability of *w*."""
        return (self._cont_right[w] + 1) / (self._total_bi_types + self.vocab_size)

    def _p_bi_kn(self, w, given) -> float:
        """Discounted P(w | given), backing off to the unigram estimate."""
        # .get() instead of indexing: indexing a defaultdict would insert an
        # empty Counter for every unseen context and grow the saved model.
        followers = self.bigram.get(given)
        gt = sum(followers.values()) if followers else 0
        if gt == 0:
            return self._p_uni_kn(w)
        cnt = followers.get(w, 0)
        first = max(cnt - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_uni_kn(w)

    def _p_tri_kn(self, w, a, b) -> float:
        """Discounted P(w | a, b), backing off to the bigram estimate."""
        followers = self.trigram.get((a, b))
        gt = sum(followers.values()) if followers else 0
        if gt == 0:
            return self._p_bi_kn(w, b)
        cnt = followers.get(w, 0)
        first = max(cnt - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_bi_kn(w, b)

    def _p_initial(self, w) -> float:
        """Smoothed probability that *w* is the first sign in reading order."""
        return (self.start_cnt[w] + 0.1) / (self.total_seqs + 0.1 * self.vocab_size)

    def _p_terminal(self, w) -> float:
        """Smoothed probability that *w* is the last sign in reading order."""
        return (self.end_cnt[w] + 0.1) / (self.total_seqs + 0.1 * self.vocab_size)

    def _raw_score(self, seq) -> float:
        """Return the length-normalised log-probability of *seq*.

        The sum covers one initial-position term, one terminal-position
        term and one n-gram term per token, divided by ``len + 2``.
        Sequences with no real token return ``self.score_min``.
        """
        if not seq:
            return self.score_min
        s = self._orient(seq)
        if not s:
            # Only None placeholders: nothing to score.
            return self.score_min
        eps = 1e-10
        lp = math.log(self._p_initial(s[0]) + eps)
        lp += math.log(self._p_terminal(s[-1]) + eps)
        for i, w in enumerate(s):
            if i == 0:
                p = self._p_uni_kn(w)
            elif i == 1:
                p = self._p_bi_kn(w, s[i - 1])
            else:
                p = self._p_tri_kn(w, s[i - 2], s[i - 1])
            lp += math.log(p + eps)
        return lp / (len(s) + 2)

    def validity_score(self, seq) -> float:
        """Map the raw score of *seq* into the range [0.02, 0.98].

        Uses min-max scaling against the calibrated score range.
        """
        raw = self._raw_score(seq)
        norm = (raw - self.score_min) / (self.score_max - self.score_min + 1e-10)
        return float(max(0.02, min(0.98, norm)))

    def predict_masked(self, seq_with_none, top_k: int = 10) -> dict:
        """Rank candidate signs for every ``None`` position in a sequence.

        Args:
            seq_with_none: Sequence in data order, ``None`` marking each
                position to predict.
            top_k: Number of candidates returned per masked position.

        Returns:
            Dict mapping each masked data-order index to a list of
            ``{"id", "prob", "rank"}`` dicts, best first. ``prob`` is
            normalised over the best ``top_k * 3`` candidates, so the
            returned values need not sum to 1.
        """
        n = len(seq_with_none)
        # Keep the None placeholders here (unlike _orient) so reading-order
        # indices stay aligned with the data-order indices.
        oriented = list(seq_with_none)
        if self.rtl:
            oriented.reverse()
        results = {}
        for orig_pos, tok in enumerate(seq_with_none):
            if tok is not None:
                continue
            ort_pos = (n - 1 - orig_pos) if self.rtl else orig_pos
            # A neighbouring mask is None too and simply shortens the context.
            prev = oriented[ort_pos - 1] if ort_pos > 0 else None
            prev2 = oriented[ort_pos - 2] if ort_pos > 1 else None
            cands = []
            for cand in self.unigram:
                if prev2 is not None and prev is not None:
                    p = self._p_tri_kn(cand, prev2, prev)
                elif prev is not None:
                    p = self._p_bi_kn(cand, prev)
                else:
                    p = self._p_uni_kn(cand)
                # Boost signs that are typical for the edge being predicted.
                if ort_pos == 0:
                    p *= max(self._p_initial(cand) * self.vocab_size, 0.01)
                elif ort_pos == n - 1:
                    p *= max(self._p_terminal(cand) * self.vocab_size, 0.01)
                cands.append((cand, p))
            cands.sort(key=lambda x: -x[1])
            total = sum(p for _, p in cands[:top_k * 3]) or 1
            results[orig_pos] = [
                {"id": c, "prob": p / total, "rank": i + 1}
                for i, (c, p) in enumerate(cands[:top_k])
            ]
        return results

    def sign_role(self, sign_id) -> str:
        """Positional role in reading direction.

        Returns "INITIAL" or "TERMINAL" when the sign opens / closes more
        than 5% of sequences and does so more than twice as often as the
        opposite role, "MEDIAL" when it occurs more than 5 times overall,
        otherwise "RARE".
        """
        init_p = self.start_cnt[sign_id] / (self.total_seqs + 1)
        term_p = self.end_cnt[sign_id] / (self.total_seqs + 1)
        if init_p > 0.05 and init_p > term_p * 2:
            return "INITIAL"
        elif term_p > 0.05 and term_p > init_p * 2:
            return "TERMINAL"
        elif self.unigram[sign_id] > 5:
            return "MEDIAL"
        return "RARE"

    def save(self, path) -> None:
        """Pickle this model to *path*, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f" Saved → {path}")

    @staticmethod
    def load(path) -> "InduNgramModel":
        """Load a model previously written by :meth:`save`.

        SECURITY: unpickling can execute arbitrary code. Only load files
        that come from a trusted source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)