hellosindh commited on
Commit
5a95e08
·
verified ·
1 Parent(s): a4f4b5c

Upload indus_ngram.py

Browse files
Files changed (1) hide show
  1. indus_ngram.py +234 -0
indus_ngram.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ indus_ngram.py — Standalone module for InduNgramModel
3
+ ======================================================
4
+ This file MUST exist so that pickle can import InduNgramModel
5
+ when loading ngram_model.pkl in 07_ensemble.py and 08_electra_train.py.
6
+
7
+ The pickle fix:
8
+ When you save a class with pickle, Python records the module path.
9
+ If the class was defined in __main__ (i.e. inside 06_ngram_model.py
10
+ when run directly), pickle saves it as __main__.InduNgramModel.
11
+ When another script tries to load it, __main__ refers to THAT script,
12
+ which doesn't have InduNgramModel — hence the AttributeError.
13
+
14
+ Solution: define the class in THIS standalone module (indus_ngram.py).
15
+ Both 06_ngram_model.py and 07_ensemble.py import from here.
16
+ Pickle records the path as indus_ngram.InduNgramModel — always findable.
17
+
18
+ Do not rename or move this file.
19
+ """
20
+
21
+ import math
22
+ import pickle
23
+ from pathlib import Path
24
+ from collections import Counter, defaultdict
25
+
26
+
27
class InduNgramModel:
    """
    Kneser-Ney smoothed N-gram LM for Indus Script.

    RTL mode (default, recommended):
        Sequences reversed before training/scoring.
        RTL bigram entropy (3.18) < LTR (3.72) → supports RTL hypothesis.

    Sign roles in RTL reading direction:
        INITIAL  = reading-start sign (data position [-1])
        TERMINAL = reading-end sign (data position [0])
        MEDIAL   = appears in middle positions
    """

    def __init__(self, rtl=True):
        self.rtl = rtl
        # Raw n-gram counts (token -> count, context -> Counter of followers).
        self.unigram = Counter()
        self.bigram = defaultdict(Counter)
        self.trigram = defaultdict(Counter)
        # Boundary counts over *oriented* sequences (first / last sign).
        self.start_cnt = Counter()
        self.end_cnt = Counter()
        self.total_seqs = 0
        self.total_tokens = 0
        self.vocab_size = 0
        # Absolute-discount constant for Kneser-Ney smoothing.
        self.D = 0.75
        # Calibration statistics, filled in by _calibrate() during train().
        self.score_mean = 0.0
        self.score_std = 1.0
        self.score_min = -20.0
        self.score_max = 0.0
        self._pairwise_acc = 0.0
        # Continuation counts: number of distinct left contexts each token
        # follows (the Kneser-Ney "novel continuation" statistic).
        self._cont_right = Counter()
        self._total_bi_types = 0

    def _orient(self, seq):
        """Drop None tokens and reverse the sequence when in RTL mode."""
        cleaned = [t for t in seq if t is not None]
        return list(reversed(cleaned)) if self.rtl else list(cleaned)

    def train(self, sequences):
        """Count n-grams over *sequences* and calibrate the score range.

        Each sequence is a list of hashable sign ids; None entries are
        ignored and empty sequences are skipped.
        """
        mode = "RTL" if self.rtl else "LTR"
        print(f" Training [{mode}] on {len(sequences):,} sequences...")

        for seq in sequences:
            s = self._orient(seq)
            if not s:
                continue
            self.total_seqs += 1
            self.unigram.update(s)
            self.start_cnt[s[0]] += 1
            self.end_cnt[s[-1]] += 1
            for i in range(len(s) - 1):
                self.bigram[s[i]][s[i + 1]] += 1
            for i in range(len(s) - 2):
                self.trigram[(s[i], s[i + 1])][s[i + 2]] += 1

        self.total_tokens = sum(self.unigram.values())
        self.vocab_size = len(self.unigram)

        # Kneser-Ney continuation counts: for each token, in how many
        # distinct bigram contexts it appears as the follower.
        for followers in self.bigram.values():
            for b in followers:
                self._cont_right[b] += 1
        self._total_bi_types = sum(self._cont_right.values())

        self._calibrate(sequences)
        print(f" Vocab : {self.vocab_size}")
        print(f" Pairwise : {self._pairwise_acc*100:.1f}%")
        print(f" Score range: [{self.score_min:.3f}, {self.score_max:.3f}]")

    def _calibrate(self, sequences):
        """Estimate raw-score statistics on real vs. corrupted sequences.

        Corrupted copies (shuffled / boundary-swapped / token-replaced)
        form a negative class; the min/max raw score over both classes
        defines the normalisation range used by validity_score().
        """
        import statistics
        import random

        # Local RNG: deterministic calibration without clobbering the
        # caller's global random state.
        rng = random.Random(42)
        all_toks = list(self.unigram.keys())
        if not all_toks:
            # Nothing was trained; keep the default calibration values.
            return

        def corrupt(seq):
            r = rng.randint(0, 3)
            c = list(seq)
            if r == 0:
                rng.shuffle(c)
            elif r == 1:
                c[0] = rng.choice(list(self.end_cnt.keys()))
            elif r == 2:
                c[-1] = rng.choice(list(self.start_cnt.keys()))
            else:
                for p in rng.sample(range(len(c)), max(1, len(c) // 2)):
                    c[p] = rng.choice(all_toks)
            return c

        # Only non-empty sequences can be corrupted: indexing c[0]/c[-1]
        # or sampling from range(0) on an empty list would raise.
        sample = [s for s in sequences if len(s) > 0][:500]
        if not sample:
            return
        good = [self._raw_score(s) for s in sample]
        bad = [self._raw_score(corrupt(s)) for s in sample]
        all_s = good + bad

        self.score_mean = statistics.mean(all_s)
        # stdev needs at least two data points.
        self.score_std = statistics.stdev(all_s) if len(all_s) > 1 else 1.0
        self.score_min = min(all_s)
        self.score_max = max(all_s)
        self._pairwise_acc = sum(g > b for g, b in zip(good, bad)) / len(good)

    def _p_uni_kn(self, w):
        """Add-one smoothed Kneser-Ney continuation unigram probability."""
        denom = self._total_bi_types + self.vocab_size
        if denom == 0:
            # Untrained model: return a tiny floor instead of dividing by 0.
            return 1e-10
        return (self._cont_right[w] + 1) / denom

    def _p_bi_kn(self, w, given):
        """P(w | given) with absolute discounting, backing off to unigram."""
        # .get() avoids inserting empty Counters into the defaultdict for
        # every unseen context queried at scoring time.
        followers = self.bigram.get(given)
        if not followers:
            return self._p_uni_kn(w)
        gt = sum(followers.values())
        cnt = followers.get(w, 0)
        first = max(cnt - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_uni_kn(w)

    def _p_tri_kn(self, w, a, b):
        """P(w | a, b) with absolute discounting, backing off to bigram."""
        followers = self.trigram.get((a, b))
        if not followers:
            return self._p_bi_kn(w, b)
        gt = sum(followers.values())
        cnt = followers.get(w, 0)
        first = max(cnt - self.D, 0) / gt
        lam = (self.D / gt) * len(followers)
        return first + lam * self._p_bi_kn(w, b)

    def _p_initial(self, w):
        """Smoothed probability that w starts a sequence (reading order)."""
        denom = self.total_seqs + 0.1 * self.vocab_size
        return (self.start_cnt[w] + 0.1) / denom if denom else 1e-10

    def _p_terminal(self, w):
        """Smoothed probability that w ends a sequence (reading order)."""
        denom = self.total_seqs + 0.1 * self.vocab_size
        return (self.end_cnt[w] + 0.1) / denom if denom else 1e-10

    def _raw_score(self, seq):
        """Length-normalised log-probability of *seq* (boundary-aware)."""
        if not seq:
            return self.score_min
        s = self._orient(seq)
        eps = 1e-10
        # Boundary terms reward plausible reading-start / reading-end signs.
        lp = math.log(self._p_initial(s[0]) + eps)
        lp += math.log(self._p_terminal(s[-1]) + eps)
        for i, w in enumerate(s):
            if i == 0:
                p = self._p_uni_kn(w)
            elif i == 1:
                p = self._p_bi_kn(w, s[i - 1])
            else:
                p = self._p_tri_kn(w, s[i - 2], s[i - 1])
            lp += math.log(p + eps)
        # +2 accounts for the two boundary terms added above.
        return lp / (len(s) + 2)

    def validity_score(self, seq):
        """Return a calibrated validity score in [0.02, 0.98]."""
        raw = self._raw_score(seq)
        norm = (raw - self.score_min) / (self.score_max - self.score_min + 1e-10)
        return float(max(0.02, min(0.98, norm)))

    def predict_masked(self, seq_with_none, top_k=10):
        """Predict candidate signs for each None slot in *seq_with_none*.

        Returns {original_position: [{"id", "prob", "rank"}, ...]} with at
        most top_k candidates per slot, ranked by model probability.
        """
        masked = [i for i, t in enumerate(seq_with_none) if t is None]
        results = {}
        n = len(seq_with_none)
        # Keep the None placeholders while orienting so indices stay
        # aligned with the input (self._orient would drop them, shifting
        # every position and silently skipping masks near the reading start).
        oriented = list(reversed(seq_with_none)) if self.rtl else list(seq_with_none)

        for orig_pos in masked:
            ort_pos = (n - 1 - orig_pos) if self.rtl else orig_pos

            # A None neighbour simply contributes no conditioning context.
            prev = oriented[ort_pos - 1] if ort_pos > 0 else None
            prev2 = oriented[ort_pos - 2] if ort_pos > 1 else None

            cands = []
            for cand in self.unigram:
                if prev2 is not None and prev is not None:
                    p = self._p_tri_kn(cand, prev2, prev)
                elif prev is not None:
                    p = self._p_bi_kn(cand, prev)
                else:
                    p = self._p_uni_kn(cand)

                # Boost boundary-appropriate signs at sequence edges.
                if ort_pos == 0:
                    p *= max(self._p_initial(cand) * self.vocab_size, 0.01)
                elif ort_pos == n - 1:
                    p *= max(self._p_terminal(cand) * self.vocab_size, 0.01)

                cands.append((cand, p))

            cands.sort(key=lambda x: -x[1])
            # Normalise over a slightly wider pool than top_k so the
            # reported probabilities are not artificially inflated.
            total = sum(p for _, p in cands[:top_k * 3]) or 1
            results[orig_pos] = [
                {"id": c, "prob": p / total, "rank": i + 1}
                for i, (c, p) in enumerate(cands[:top_k])
            ]

        return results

    def sign_role(self, sign_id):
        """Positional role in reading direction."""
        init_p = self.start_cnt[sign_id] / (self.total_seqs + 1)
        term_p = self.end_cnt[sign_id] / (self.total_seqs + 1)
        if init_p > 0.05 and init_p > term_p * 2:
            return "INITIAL"
        elif term_p > 0.05 and term_p > init_p * 2:
            return "TERMINAL"
        elif self.unigram[sign_id] > 5:
            return "MEDIAL"
        return "RARE"

    def save(self, path):
        """Pickle the whole model to *path*, creating parent dirs."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f" Saved → {path}")

    @staticmethod
    def load(path):
        """Unpickle a model previously written by save()."""
        with open(path, "rb") as f:
            return pickle.load(f)