Upload indus_ngram.py
Browse files- indus_ngram.py +234 -0
indus_ngram.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
indus_ngram.py — Standalone module for InduNgramModel
|
| 3 |
+
======================================================
|
| 4 |
+
This file MUST exist so that pickle can import InduNgramModel
|
| 5 |
+
when loading ngram_model.pkl in 07_ensemble.py and 08_electra_train.py.
|
| 6 |
+
|
| 7 |
+
The pickle fix:
|
| 8 |
+
When you save a class with pickle, Python records the module path.
|
| 9 |
+
If the class was defined in __main__ (i.e. inside 06_ngram_model.py
|
| 10 |
+
when run directly), pickle saves it as __main__.InduNgramModel.
|
| 11 |
+
When another script tries to load it, __main__ refers to THAT script,
|
| 12 |
+
which doesn't have InduNgramModel — hence the AttributeError.
|
| 13 |
+
|
| 14 |
+
Solution: define the class in THIS standalone module (indus_ngram.py).
|
| 15 |
+
Both 06_ngram_model.py and 07_ensemble.py import from here.
|
| 16 |
+
Pickle records the path as indus_ngram.InduNgramModel — always findable.
|
| 17 |
+
|
| 18 |
+
Do not rename or move this file.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import math
|
| 22 |
+
import pickle
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from collections import Counter, defaultdict
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class InduNgramModel:
    """
    Kneser-Ney smoothed N-gram LM for Indus Script.

    RTL mode (default, recommended):
        Sequences reversed before training/scoring.
        RTL bigram entropy (3.18) < LTR (3.72) → supports RTL hypothesis.

    Sign roles in RTL reading direction:
        INITIAL  = reading-start sign (data position [-1])
        TERMINAL = reading-end sign (data position [0])
        MEDIAL   = appears in middle positions
    """

    def __init__(self, rtl=True):
        self.rtl = rtl
        self.unigram = Counter()             # token -> corpus frequency
        self.bigram = defaultdict(Counter)   # prev -> Counter(next)
        self.trigram = defaultdict(Counter)  # (a, b) -> Counter(next)
        self.start_cnt = Counter()           # reading-initial sign counts
        self.end_cnt = Counter()             # reading-final sign counts
        self.total_seqs = 0
        self.total_tokens = 0
        self.vocab_size = 0
        self.D = 0.75                        # absolute-discount constant
        # Score-calibration stats, filled in by _calibrate() during train().
        self.score_mean = 0.0
        self.score_std = 1.0
        self.score_min = -20.0
        self.score_max = 0.0
        self._pairwise_acc = 0.0
        self._cont_right = Counter()         # KN continuation counts (distinct left contexts per token)
        self._total_bi_types = 0

    def _orient(self, seq):
        """Drop None placeholders and flip into reading order when RTL."""
        cleaned = [t for t in seq if t is not None]
        return list(reversed(cleaned)) if self.rtl else list(cleaned)

    def train(self, sequences):
        """Accumulate unigram/bigram/trigram and boundary counts, then self-calibrate."""
        mode = "RTL" if self.rtl else "LTR"
        print(f" Training [{mode}] on {len(sequences):,} sequences...")

        for seq in sequences:
            s = self._orient(seq)
            if not s:
                continue  # skip empty / all-None sequences
            self.total_seqs += 1
            self.unigram.update(s)
            self.start_cnt[s[0]] += 1
            self.end_cnt[s[-1]] += 1
            for i in range(len(s) - 1):
                self.bigram[s[i]][s[i+1]] += 1
            for i in range(len(s) - 2):
                self.trigram[(s[i], s[i+1])][s[i+2]] += 1

        self.total_tokens = sum(self.unigram.values())
        self.vocab_size = len(self.unigram)

        # Continuation counts: in how many distinct bigram types each token
        # appears as the right-hand word (needed for the KN back-off term).
        for a, followers in self.bigram.items():
            for b in followers:
                self._cont_right[b] += 1
        self._total_bi_types = sum(self._cont_right.values())

        self._calibrate(sequences)
        print(f" Vocab : {self.vocab_size}")
        print(f" Pairwise : {self._pairwise_acc*100:.1f}%")
        print(f" Score range: [{self.score_min:.3f}, {self.score_max:.3f}]")

    def _calibrate(self, sequences):
        """Fit score_min/max/mean/std on real vs. corrupted sequences.

        validity_score() maps raw log-probs into [0.02, 0.98] using these
        bounds; _pairwise_acc records how often a real sequence outscores
        its corrupted counterpart.
        """
        import random, statistics
        random.seed(42)  # deterministic calibration
        all_toks = list(self.unigram.keys())

        def corrupt(seq):
            # Four corruption modes: shuffle, wrong first sign, wrong last
            # sign, or random substitution of about half the positions.
            r = random.randint(0, 3)
            c = list(seq)
            if r == 0:
                random.shuffle(c)
            elif r == 1:
                c[0] = random.choice(list(self.end_cnt.keys()))
            elif r == 2:
                c[-1] = random.choice(list(self.start_cnt.keys()))
            else:
                for p in random.sample(range(len(c)), max(1, len(c)//2)):
                    c[p] = random.choice(all_toks)
            return c

        # FIX: skip empty sequences — corrupt() indexes c[0]/c[-1] and
        # random.sample() raises on an empty population. Bail out entirely
        # when there is nothing to calibrate on (avoids ZeroDivisionError
        # and statistics errors), keeping the conservative __init__ defaults.
        sample = [s for s in sequences if s][:500]
        if not sample or not all_toks:
            return
        good = [self._raw_score(s) for s in sample]
        bad = [self._raw_score(corrupt(s)) for s in sample]
        all_s = good + bad

        self.score_mean = statistics.mean(all_s)
        self.score_std = statistics.stdev(all_s)
        self.score_min = min(all_s)
        self.score_max = max(all_s)
        self._pairwise_acc = sum(g > b for g, b in zip(good, bad)) / len(good)

    def _p_uni_kn(self, w):
        """KN continuation probability of *w*, add-one smoothed."""
        return (self._cont_right[w] + 1) / (self._total_bi_types + self.vocab_size)

    def _p_bi_kn(self, w, given):
        """Interpolated KN bigram P(w | given), backing off to unigram."""
        # FIX: use .get() — indexing the defaultdict for an unseen context
        # silently inserted an empty Counter, mutating (and growing) the
        # model on every query at inference time.
        followers = self.bigram.get(given)
        if not followers:
            return self._p_uni_kn(w)
        gt = sum(followers.values())
        first = max(followers.get(w, 0) - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_uni_kn(w)

    def _p_tri_kn(self, w, a, b):
        """Interpolated KN trigram P(w | a, b), backing off to bigram."""
        followers = self.trigram.get((a, b))  # .get(): see _p_bi_kn
        if not followers:
            return self._p_bi_kn(w, b)
        gt = sum(followers.values())
        first = max(followers.get(w, 0) - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_bi_kn(w, b)

    def _p_initial(self, w):
        """Smoothed probability that *w* opens a sequence (reading order)."""
        return (self.start_cnt[w] + 0.1) / (self.total_seqs + 0.1 * self.vocab_size)

    def _p_terminal(self, w):
        """Smoothed probability that *w* closes a sequence (reading order)."""
        return (self.end_cnt[w] + 0.1) / (self.total_seqs + 0.1 * self.vocab_size)

    def _raw_score(self, seq):
        """Length-normalized log-probability of *seq*; higher = more plausible."""
        s = self._orient(seq)
        # FIX: check AFTER orienting — a non-empty seq of all-None tokens
        # becomes empty here and s[0] would raise IndexError.
        if not s:
            return self.score_min
        eps = 1e-10  # guards log(0) for unseen events
        lp = math.log(self._p_initial(s[0]) + eps)
        lp += math.log(self._p_terminal(s[-1]) + eps)
        for i, w in enumerate(s):
            if i == 0:
                p = self._p_uni_kn(w)
            elif i == 1:
                p = self._p_bi_kn(w, s[i-1])
            else:
                p = self._p_tri_kn(w, s[i-2], s[i-1])
            lp += math.log(p + eps)
        # +2 accounts for the initial/terminal boundary terms above.
        return lp / (len(s) + 2)

    def validity_score(self, seq):
        """Map the raw score into [0.02, 0.98] via the calibration bounds."""
        raw = self._raw_score(seq)
        norm = (raw - self.score_min) / (self.score_max - self.score_min + 1e-10)
        return float(max(0.02, min(0.98, norm)))

    def predict_masked(self, seq_with_none, top_k=10):
        """Rank candidate signs for each None position in *seq_with_none*.

        Returns {original_index: [{"id", "prob", "rank"}, ...]} with
        probabilities renormalized over the leading candidates.
        """
        masked = [i for i, t in enumerate(seq_with_none) if t is None]
        results = {}
        n = len(seq_with_none)

        # FIX: orient WITHOUT dropping the None placeholders. The original
        # used _orient(), which filters Nones out, so the mirrored index
        # (n - 1 - orig_pos) pointed at the wrong sign and the
        # `is not None` context checks below could never trigger.
        oriented = list(reversed(seq_with_none)) if self.rtl else list(seq_with_none)

        for orig_pos in masked:
            ort_pos = (n - 1 - orig_pos) if self.rtl else orig_pos

            # Preceding context in reading order (None if masked/absent).
            prev = oriented[ort_pos-1] if ort_pos > 0 and oriented[ort_pos-1] is not None else None
            prev2 = oriented[ort_pos-2] if ort_pos > 1 and oriented[ort_pos-2] is not None else None

            cands = []
            for cand in self.unigram:
                # Use the longest available context; back off otherwise.
                if prev2 is not None and prev is not None:
                    p = self._p_tri_kn(cand, prev2, prev)
                elif prev is not None:
                    p = self._p_bi_kn(cand, prev)
                else:
                    p = self._p_uni_kn(cand)

                # Boundary positions get a positional prior boost, scaled
                # by vocab size and floored so no candidate zeroes out.
                if ort_pos == 0:
                    p *= max(self._p_initial(cand) * self.vocab_size, 0.01)
                elif ort_pos == n - 1:
                    p *= max(self._p_terminal(cand) * self.vocab_size, 0.01)

                cands.append((cand, p))

            cands.sort(key=lambda x: -x[1])
            # Renormalize over a 3*top_k head so tail mass doesn't distort probs.
            total = sum(p for _, p in cands[:top_k * 3]) or 1
            results[orig_pos] = [
                {"id": c, "prob": p / total, "rank": i + 1}
                for i, (c, p) in enumerate(cands[:top_k])
            ]

        return results

    def sign_role(self, sign_id):
        """Positional role in reading direction."""
        init_p = self.start_cnt[sign_id] / (self.total_seqs + 1)
        term_p = self.end_cnt[sign_id] / (self.total_seqs + 1)
        # A role requires both absolute frequency (>5%) and dominance (2x).
        if init_p > 0.05 and init_p > term_p * 2:
            return "INITIAL"
        elif term_p > 0.05 and term_p > init_p * 2:
            return "TERMINAL"
        elif self.unigram[sign_id] > 5:
            return "MEDIAL"
        return "RARE"

    def save(self, path):
        """Pickle the whole model to *path*, creating parent dirs as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f" Saved → {path}")

    @staticmethod
    def load(path):
        """Unpickle a model saved by save(). Only load trusted files."""
        with open(path, "rb") as f:
            return pickle.load(f)
|