technician1 commited on
Commit
3d5af4f
·
1 Parent(s): 2ea937a

Upload ChatIPC.cpp

Browse files
Files changed (1) hide show
  1. ChatIPC.cpp +663 -182
ChatIPC.cpp CHANGED
@@ -35,6 +35,232 @@ static inline bool is_space(char c){ return std::isspace(static_cast<unsigned ch
35
  static inline char to_low(char c){ return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); }
36
  static inline void safe_flush(std::ostream &os){ os.flush(); }
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  // Tokenize by whitespace
39
  static std::vector<std::string> tokenize_whitespace(const std::string &s){
40
  std::istringstream iss(s);
@@ -71,17 +297,37 @@ struct StringInterner {
71
  };
72
 
73
  // ---------- Global parsed dictionary (populated once in main) ----------
74
- static std::unordered_map<std::string,std::string> global_def_dict;
75
- static std::unordered_map<std::string, std::vector<std::string>> global_def_tokens_cache;
76
-
77
  static void build_def_tokens_cache(){
78
  global_def_tokens_cache.clear();
79
- global_def_tokens_cache.reserve(global_def_dict.size());
80
- for (const auto &pr : global_def_dict){
81
- auto toks = tokenize_non_alnum(pr.second);
82
- std::sort(toks.begin(), toks.end());
83
- toks.erase(std::unique(toks.begin(), toks.end()), toks.end());
84
- global_def_tokens_cache.emplace(pr.first, std::move(toks));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  }
86
  }
87
 
@@ -132,25 +378,34 @@ struct KnowledgeBase {
132
  std::unordered_set<StrPtr, PtrHash, PtrEq> acc;
133
  std::vector<StrPtr> frontier;
134
 
135
- auto it_def = global_def_tokens_cache.find(*wp);
136
- if (it_def != global_def_tokens_cache.end()){
137
- for (const auto &t : it_def->second){
138
- StrPtr tp = interner.intern(t);
139
- if (acc.insert(tp).second) frontier.push_back(tp);
 
 
 
140
  }
141
  }
142
 
143
  for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
144
- std::vector<StrPtr> nextf;
 
145
  for (StrPtr w : frontier){
146
- auto it2 = global_def_tokens_cache.find(*w);
 
 
 
147
  if (it2 == global_def_tokens_cache.end()) continue;
148
- for (const auto &t : it2->second){
149
- StrPtr tp = interner.intern(t);
150
- if (acc.insert(tp).second) nextf.push_back(tp);
 
151
  }
152
  }
153
- frontier.swap(nextf);
 
154
  }
155
 
156
  std::vector<StrPtr> out;
@@ -159,12 +414,9 @@ struct KnowledgeBase {
159
 
160
  {
161
  std::lock_guard<std::mutex> lk(def_m);
162
- if (def_index.find(wp) == def_index.end()){
163
- def_index.emplace(wp, std::move(out));
164
- }
165
  }
166
  }
167
-
168
  // existing public add_pair but now ensure def-expansion is built immediately
169
  void add_pair(const std::string &k, const std::string &v){
170
  StrPtr kp = interner.intern(k);
@@ -206,13 +458,20 @@ aggregate_sets(const std::vector<StrPtr> &tokens,
206
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index)
207
  {
208
  std::unordered_set<std::string> agg;
 
209
  for (StrPtr t : tokens){
210
- agg.insert(*t);
 
 
211
  auto it = def_index.find(t);
212
  if (it != def_index.end()){
213
- for (StrPtr d : it->second) agg.insert(*d);
 
 
 
214
  }
215
  }
 
216
  return agg;
217
  }
218
 
@@ -243,47 +502,154 @@ static void skip_spaces(const std::string &s, size_t &i){
243
  }
244
 
245
  // Very small JSON-like parser tailored to dictionary_json structure
246
- static std::unordered_map<std::string,std::string> parse_dictionary_json(){
247
- std::unordered_map<std::string,std::string> dict;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  if (dictionary_json_len == 0) return dict;
249
- std::string text; text.reserve(dictionary_json_len + 1);
250
- for (unsigned int b=0; b < dictionary_json_len; ++b) text.push_back(static_cast<char>(dictionary_json[b]));
 
 
 
 
 
251
  size_t i = 0;
252
- skip_spaces(text,i);
253
- if (!json_valid_index(i,text.size()) || text[i] != '{') return dict;
254
  ++i;
 
255
  while (true){
256
- skip_spaces(text,i);
257
- if (!json_valid_index(i,text.size())) break;
258
- if (text[i] == '}'){ ++i; break; }
259
- std::string key = parse_quoted_string(text,i);
260
- skip_spaces(text,i);
261
- if (!json_valid_index(i,text.size()) || text[i] != ':') break;
 
 
 
 
 
262
  ++i;
263
- skip_spaces(text,i);
264
- std::string val;
265
- if (json_valid_index(i,text.size()) && text[i] == '"') val = parse_quoted_string(text,i);
266
- else {
267
- size_t start = i;
268
- while (json_valid_index(i,text.size()) && text[i] != ',' && text[i] != '}') ++i;
269
- val = text.substr(start, i-start);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  }
271
- dict.emplace(std::move(key), std::move(val));
272
- skip_spaces(text,i);
273
- if (json_valid_index(i,text.size()) && text[i] == ','){ ++i; continue; }
274
- if (json_valid_index(i,text.size()) && text[i] == '}'){ ++i; break; }
 
 
275
  }
 
276
  return dict;
277
  }
278
 
279
- // --------------------------- Candidate selection (short funcs) ---------------
280
-
281
- static std::string best_candidate_by_similarity(const NextSet &cands,
282
  const std::vector<StrPtr> &prompt_ptrs,
283
  const std::vector<StrPtr> &resp_ptrs,
284
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
285
  const std::unordered_map<std::string,int> &recent_counts,
286
- double repeat_penalty)
 
287
  {
288
  if (cands.empty()) return std::string();
289
  if (cands.size() == 1) return *cands[0];
@@ -292,7 +658,10 @@ static std::string best_candidate_by_similarity(const NextSet &cands,
292
  for (StrPtr r : resp_ptrs){
293
  auto it = def_index.find(r);
294
  if (it != def_index.end()){
295
- for (StrPtr d : it->second) agg.insert(*d);
 
 
 
296
  }
297
  }
298
 
@@ -304,15 +673,17 @@ static std::string best_candidate_by_similarity(const NextSet &cands,
304
  #pragma omp parallel for schedule(static)
305
  for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
306
  const StrPtr cand = cands[(size_t)i];
 
307
 
308
- size_t inter = agg.count(*cand) ? 1 : 0;
309
  size_t cand_size = 1;
310
 
311
  auto it = def_index.find(cand);
312
  if (it != def_index.end()){
313
  cand_size += it->second.size();
314
  for (StrPtr d : it->second){
315
- if (agg.count(*d)) ++inter;
 
316
  }
317
  if (std::find(it->second.begin(), it->second.end(), cand) != it->second.end()){
318
  --cand_size;
@@ -326,83 +697,112 @@ static std::string best_candidate_by_similarity(const NextSet &cands,
326
 
327
  for (size_t i = 0; i < M; ++i){
328
  const std::string &tok = *cands[i];
 
 
 
329
  double s = scores[i];
330
- auto rc_it = recent_counts.find(tok);
331
  int cnt = (rc_it == recent_counts.end() ? 0 : rc_it->second);
332
- double adjusted = s - repeat_penalty * static_cast<double>(cnt);
 
 
 
 
333
  if (adjusted > best || (adjusted == best && tok < best_tok)){
334
  best = adjusted;
335
  best_tok = tok;
336
  }
337
  }
 
338
  return best_tok;
339
  }
340
 
341
- // --------------------------- Response constructor (short units) ---------------
342
-
343
  static std::vector<std::string> construct_response(KnowledgeBase &kb,
344
- const std::vector<std::string> &prompt_toks,
345
- size_t maxlen,
346
- double repeat_penalty)
347
  {
348
  std::vector<std::string> resp;
349
  if (prompt_toks.empty() || maxlen == 0) return resp;
350
 
351
  auto prompt_ptrs = intern_tokens(kb, prompt_toks);
352
  std::vector<StrPtr> resp_ptrs;
353
-
354
  std::unordered_map<std::string,int> recent_counts;
355
 
356
  auto would_create_2_cycle = [&](const std::string &cand) -> bool {
357
- if (resp.size() < 2) return false;
358
- const std::string &last = resp.back();
359
- const std::string &prev = resp[resp.size()-2];
360
- return (cand == prev && last == resp[resp.size()-3 < resp.size() ? resp.size()-3 : 0]);
361
  };
362
 
363
  std::string last_printed;
 
364
  for (size_t step = 0; step < maxlen; ++step){
365
  NextSet candidates;
366
  bool found = false;
 
 
367
  if (step == 0){
368
  for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
369
  auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
370
- if (opt){ candidates = *opt; found = true; break; }
 
 
 
 
 
371
  }
372
  } else {
373
  auto opt = kb.lookup_by_string(last_printed);
374
- if (opt){ candidates = *opt; found = true; }
375
- else {
 
 
 
376
  for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
377
  auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
378
- if (opt2){ candidates = *opt2; found = true; break; }
 
 
 
 
 
379
  }
380
  }
381
  }
 
382
  if (!found || candidates.empty()) break;
383
 
384
  if (candidates.size() == 1){
385
  std::string only = *candidates[0];
386
- if (recent_counts[only] > 0) break;
 
 
387
  resp.push_back(only);
388
  resp_ptrs.push_back(kb.interner.intern(only));
389
- recent_counts[only] += 1;
390
  last_printed = only;
391
  std::cout << only << ' ' << std::flush;
392
  continue;
393
  }
394
 
395
- std::string chosen = best_candidate_by_similarity(candidates, prompt_ptrs, resp_ptrs, kb.def_index, recent_counts, repeat_penalty);
396
- if (chosen.empty()) break;
 
397
 
 
398
  if (would_create_2_cycle(chosen)) break;
399
 
400
  resp.push_back(chosen);
401
  resp_ptrs.push_back(kb.interner.intern(chosen));
402
- recent_counts[chosen] += 1;
 
 
 
403
  last_printed = chosen;
404
  std::cout << chosen << ' ' << std::flush;
405
  }
 
406
  return resp;
407
  }
408
 
@@ -425,144 +825,225 @@ static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::strin
425
  for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
426
  }
427
 
428
- // --------------------------- Serialization (short functions) ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
431
- std::ofstream ofs(fname, std::ios::binary);
432
- if (!ofs) throw std::runtime_error("cannot open save file");
433
-
434
- // interned strings snapshot (must include all tokens used by def_index)
435
- std::vector<const std::string*> interned;
436
- interned.reserve(kb.interner.pool.size());
437
- for (auto &s : kb.interner.pool) interned.push_back(&s);
438
-
439
- uint64_t N = interned.size();
440
- ofs.write(reinterpret_cast<const char*>(&N), sizeof(N));
441
- for (auto p : interned){
442
- uint64_t L = p->size();
443
- ofs.write(reinterpret_cast<const char*>(&L), sizeof(L));
444
- ofs.write(p->data(), static_cast<std::streamsize>(L));
445
- }
446
-
447
- std::unordered_map<StrPtr, uint64_t, PtrHash, PtrEq> ptr_to_idx;
448
- ptr_to_idx.reserve(interned.size());
449
- for (uint64_t i = 0; i < N; ++i){
450
- ptr_to_idx.emplace(interned[(size_t)i], i);
451
- }
452
-
453
- // edges
454
- uint64_t E = kb.next.size();
455
- ofs.write(reinterpret_cast<const char*>(&E), sizeof(E));
456
- for (auto &pr : kb.next){
457
- uint64_t key_idx = ptr_to_idx.at(pr.first);
458
- ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
459
- uint64_t M = pr.second.size();
460
- ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
461
- for (auto nxt : pr.second){
462
- uint64_t v_idx = ptr_to_idx.at(nxt);
463
- ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
464
  }
465
- }
466
- // --- write definition expansion section ---
467
- uint64_t D = static_cast<uint64_t>(kb.def_depth);
468
- ofs.write(reinterpret_cast<const char*>(&D), sizeof(D));
469
-
470
- // def entries: number of keys with a stored expansion
471
- uint64_t K = kb.def_index.size();
472
- ofs.write(reinterpret_cast<const char*>(&K), sizeof(K));
473
- for (auto &pr : kb.def_index){
474
- uint64_t key_idx = ptr_to_idx.at(pr.first);
475
- ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
476
- uint64_t M = pr.second.size();
477
- ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
478
- for (auto tokp : pr.second){
479
- uint64_t v_idx = ptr_to_idx.at(tokp);
480
- ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
 
 
 
 
 
 
 
 
481
  }
 
 
 
482
  }
483
 
484
- safe_flush(ofs);
 
 
 
 
485
  }
486
 
487
  static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
488
  std::ifstream ifs(fname, std::ios::binary);
489
  if (!ifs) throw std::runtime_error("cannot open load file");
490
 
491
- uint64_t N;
492
- ifs.read(reinterpret_cast<char*>(&N), sizeof(N));
493
- std::vector<std::string> strings; strings.reserve((size_t)N);
494
- kb.interner.pool.reserve((size_t)N);
495
- for (uint64_t i=0;i<N;++i){
496
- uint64_t L; ifs.read(reinterpret_cast<char*>(&L), sizeof(L));
497
- std::string s; s.resize((size_t)L);
498
- ifs.read(&s[0], static_cast<std::streamsize>(L));
499
- strings.push_back(std::move(s));
500
- }
501
- std::vector<StrPtr> ptrs; ptrs.reserve(strings.size());
502
- for (auto &s : strings) ptrs.push_back(kb.interner.intern(s));
503
-
504
- uint64_t E; ifs.read(reinterpret_cast<char*>(&E), sizeof(E));
505
- kb.next.reserve((size_t)E);
506
- kb.next_key_index.reserve((size_t)E);
507
- for (uint64_t i=0;i<E;++i){
508
- uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
509
- uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
510
- StrPtr key_ptr = ptrs.at((size_t)key_idx);
511
- NextSet vec; vec.reserve((size_t)M);
512
- for (uint64_t j=0;j<M;++j){
513
- uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
514
- vec.push_back(ptrs.at((size_t)v_idx));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  }
516
- kb.next.emplace(key_ptr, std::move(vec));
517
  }
518
 
519
- // read def-expansion section (new-format)
520
- uint64_t file_def_depth;
521
- ifs.read(reinterpret_cast<char*>(&file_def_depth), sizeof(file_def_depth));
522
- uint64_t K; ifs.read(reinterpret_cast<char*>(&K), sizeof(K));
523
- // populate kb.def_index from file
524
  {
525
  std::lock_guard<std::mutex> lk(kb.def_m);
526
  kb.def_index.clear();
527
- kb.def_index.reserve((size_t)K);
528
  kb.def_depth = static_cast<int>(file_def_depth);
529
  }
530
- for (uint64_t i=0;i<K;++i){
531
- uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
532
- uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
533
- std::vector<StrPtr> tokens; tokens.reserve((size_t)M);
534
- for (uint64_t j=0;j<M;++j){
535
- uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
536
- tokens.push_back(ptrs.at((size_t)v_idx));
 
 
 
 
 
 
 
 
 
 
 
 
 
537
  }
538
- kb.def_index.emplace(ptrs.at((size_t)key_idx), std::move(tokens));
539
  }
540
 
541
- // If CLI requested a different dict depth, clear and recompute expansion for loaded words only
542
- if (cli_dict_depth != kb.def_depth){
543
  kb.set_def_depth(cli_dict_depth);
544
- // --- build deduplicated union of "words present" = saved strings (ptrs) ∪ KB words (keys and neighbors)
545
  std::vector<StrPtr> targets;
546
- targets.reserve(ptrs.size() + kb.next.size()*2);
 
 
 
 
 
 
 
547
 
548
  {
549
- std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
550
- // include all strings from the saved file
551
- for (auto p : ptrs) {
552
- if (seen.insert(p).second) targets.push_back(p);
553
- }
554
- // include all words present in KB edges (keys and their neighbors)
555
- for (auto &pr : kb.next) {
556
  if (seen.insert(pr.first).second) targets.push_back(pr.first);
557
- for (auto v : pr.second) {
558
  if (seen.insert(v).second) targets.push_back(v);
559
  }
560
  }
561
  }
562
 
563
- // --- recompute definition expansion for each target in parallel
564
  #pragma omp parallel for schedule(dynamic)
565
- for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i) {
566
  kb.ensure_def_for_interned(targets[(size_t)i]);
567
  }
568
  }
@@ -614,7 +1095,7 @@ int main(int argc, char **argv){
614
  KnowledgeBase kb;
615
 
616
  // parse the embedded dictionary once for use by per-word expansion
617
- global_def_dict = parse_dictionary_json();
618
  build_def_tokens_cache();
619
  // set KB def depth (clears any previous expansion)
620
  kb.set_def_depth(dict_depth);
 
35
  static inline char to_low(char c){ return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); }
36
  static inline void safe_flush(std::ostream &os){ os.flush(); }
37
 
38
+ // NEW: dictionary model, normalization, and English-rule helpers.
39
+ struct DictionaryEntry {
40
+ std::string pos;
41
+ std::string word;
42
+ std::vector<std::string> definitions;
43
+ };
44
+
45
+ static std::vector<DictionaryEntry> global_dictionary_entries;
46
+ static std::unordered_map<std::string, std::vector<std::string>> global_def_tokens_cache;
47
+ static std::unordered_map<std::string, std::vector<std::string>> global_pos_cache;
48
+
49
+ static inline bool is_word_char_for_key(char c){
50
+ unsigned char uc = static_cast<unsigned char>(c);
51
+ return std::isalnum(uc) != 0 || c == '\'' || c == '-';
52
+ }
53
+
54
+ static std::string normalize_dictionary_key(const std::string &s){
55
+ size_t b = 0, e = s.size();
56
+ while (b < e && !is_word_char_for_key(s[b])) ++b;
57
+ while (e > b && !is_word_char_for_key(s[e - 1])) --e;
58
+
59
+ std::string out;
60
+ out.reserve(e - b);
61
+ for (size_t i = b; i < e; ++i) out.push_back(to_low(s[i]));
62
+ return out;
63
+ }
64
+
65
+ static std::string normalize_pos_tag(const std::string &s){
66
+ std::string out;
67
+ out.reserve(s.size());
68
+ for (char c : s){
69
+ unsigned char uc = static_cast<unsigned char>(c);
70
+ if (std::isalpha(uc) != 0) out.push_back(to_low(c));
71
+ }
72
+ return out;
73
+ }
74
+
75
+ enum class PosClass {
76
+ Unknown,
77
+ Noun,
78
+ Verb,
79
+ Adj,
80
+ Adv,
81
+ Pron,
82
+ Prep,
83
+ Conj,
84
+ Det,
85
+ Num,
86
+ Interj
87
+ };
88
+
89
+ static PosClass pos_class_from_tag(const std::string &tag){
90
+ if (tag == "n" || tag == "noun") return PosClass::Noun;
91
+ if (tag == "v" || tag == "verb" || tag == "part" || tag == "participle" || tag == "p") return PosClass::Verb;
92
+ if (tag == "a" || tag == "adj" || tag == "adjective") return PosClass::Adj;
93
+ if (tag == "adv" || tag == "adverb") return PosClass::Adv;
94
+ if (tag == "pron" || tag == "pronoun") return PosClass::Pron;
95
+ if (tag == "prep" || tag == "preposition") return PosClass::Prep;
96
+ if (tag == "conj" || tag == "conjunction") return PosClass::Conj;
97
+ if (tag == "art" || tag == "article" || tag == "det" || tag == "determiner") return PosClass::Det;
98
+ if (tag == "num" || tag == "number") return PosClass::Num;
99
+ if (tag == "interj" || tag == "interjection") return PosClass::Interj;
100
+ return PosClass::Unknown;
101
+ }
102
+
103
+ static bool has_pos_class(const std::vector<std::string> &tags, PosClass cls){
104
+ for (const auto &t : tags){
105
+ if (pos_class_from_tag(t) == cls) return true;
106
+ }
107
+ return false;
108
+ }
109
+
110
+ static const std::vector<std::string> &dictionary_pos_for_token(const std::string &surface){
111
+ static const std::vector<std::string> empty;
112
+ auto key = normalize_dictionary_key(surface);
113
+ if (key.empty()) return empty;
114
+
115
+ auto it = global_pos_cache.find(key);
116
+ return (it == global_pos_cache.end()) ? empty : it->second;
117
+ }
118
+
119
+ static bool first_alpha_is_upper(const std::string &s){
120
+ for (char c : s){
121
+ unsigned char uc = static_cast<unsigned char>(c);
122
+ if (std::isalpha(uc) != 0) return std::isupper(uc) != 0;
123
+ }
124
+ return false;
125
+ }
126
+
127
+ static bool first_alpha_is_lower(const std::string &s){
128
+ for (char c : s){
129
+ unsigned char uc = static_cast<unsigned char>(c);
130
+ if (std::isalpha(uc) != 0) return std::islower(uc) != 0;
131
+ }
132
+ return false;
133
+ }
134
+
135
+ static bool is_sentence_boundary_token(const std::string &s){
136
+ if (s.empty()) return false;
137
+ char c = s.back();
138
+ return c == '.' || c == '!' || c == '?';
139
+ }
140
+
141
+ static bool is_open_punct_token(const std::string &s){
142
+ return s == "(" || s == "[" || s == "{" || s == "\"" || s == "'";
143
+ }
144
+
145
+ static bool is_punctuation_only_token(const std::string &s){
146
+ if (s.empty()) return false;
147
+ for (char c : s){
148
+ unsigned char uc = static_cast<unsigned char>(c);
149
+ if (std::isalnum(uc) != 0) return false;
150
+ }
151
+ return true;
152
+ }
153
+
154
+ static bool is_common_determiner(const std::string &s){
155
+ return s == "a" || s == "an" || s == "the" || s == "this" || s == "that" ||
156
+ s == "these" || s == "those" || s == "my" || s == "your" || s == "his" ||
157
+ s == "her" || s == "its" || s == "our" || s == "their";
158
+ }
159
+
160
+ static bool is_common_preposition(const std::string &s){
161
+ return s == "of" || s == "in" || s == "on" || s == "at" || s == "by" ||
162
+ s == "for" || s == "from" || s == "with" || s == "into" || s == "onto" ||
163
+ s == "about" || s == "over" || s == "under" || s == "after" || s == "before" ||
164
+ s == "between" || s == "through" || s == "during" || s == "without" || s == "within" ||
165
+ s == "under" || s == "across" || s == "against" || s == "among" || s == "around";
166
+ }
167
+
168
+ static bool is_common_aux_or_modal(const std::string &s){
169
+ return s == "to" || s == "be" || s == "am" || s == "is" || s == "are" || s == "was" ||
170
+ s == "were" || s == "been" || s == "being" || s == "have" || s == "has" ||
171
+ s == "had" || s == "do" || s == "does" || s == "did" || s == "can" ||
172
+ s == "could" || s == "may" || s == "might" || s == "must" || s == "shall" ||
173
+ s == "should" || s == "will" || s == "would";
174
+ }
175
+
176
+ static bool begins_with_vowel_sound(const std::string &s){
177
+ if (s.empty()) return false;
178
+
179
+ if (s.rfind("hour", 0) == 0 || s.rfind("honest", 0) == 0 || s.rfind("honor", 0) == 0 ||
180
+ s.rfind("heir", 0) == 0 || s.rfind("herb", 0) == 0) {
181
+ return true;
182
+ }
183
+
184
+ if (s.rfind("uni", 0) == 0 || s.rfind("use", 0) == 0 || s.rfind("user", 0) == 0 ||
185
+ s.rfind("one", 0) == 0 || s.rfind("once", 0) == 0 || s.rfind("euro", 0) == 0) {
186
+ return false;
187
+ }
188
+
189
+ char c = s[0];
190
+ return c == 'a' || c == 'e' || c == 'i' || c == 'o' || c == 'u';
191
+ }
192
+
193
+ static double english_rule_bonus(const std::string &context_tok, const std::string &cand){
194
+ const std::string ctx_key = normalize_dictionary_key(context_tok);
195
+ const std::string cand_key = normalize_dictionary_key(cand);
196
+
197
+ const auto &ctx_tags = dictionary_pos_for_token(context_tok);
198
+ const auto &cand_tags = dictionary_pos_for_token(cand);
199
+
200
+ const bool sentence_start = context_tok.empty() || is_sentence_boundary_token(context_tok) || is_open_punct_token(context_tok);
201
+
202
+ const bool cand_nounish = has_pos_class(cand_tags, PosClass::Noun) ||
203
+ has_pos_class(cand_tags, PosClass::Adj) ||
204
+ has_pos_class(cand_tags, PosClass::Pron) ||
205
+ has_pos_class(cand_tags, PosClass::Num);
206
+
207
+ const bool cand_verbish = has_pos_class(cand_tags, PosClass::Verb);
208
+ const bool cand_advish = has_pos_class(cand_tags, PosClass::Adv);
209
+ const bool cand_prepish = has_pos_class(cand_tags, PosClass::Prep);
210
+ const bool cand_detish = has_pos_class(cand_tags, PosClass::Det);
211
+
212
+ double bonus = 0.0;
213
+
214
+ if (!cand_key.empty()){
215
+ if (sentence_start){
216
+ bonus += first_alpha_is_upper(cand) ? 0.22 : -0.08;
217
+ } else if (first_alpha_is_upper(cand)){
218
+ bonus -= 0.03;
219
+ }
220
+ }
221
+
222
+ if (ctx_key == "a" || ctx_key == "an"){
223
+ const bool vowel = begins_with_vowel_sound(cand_key.empty() ? cand : cand_key);
224
+ bonus += ((ctx_key == "an") == vowel) ? 0.28 : -0.18;
225
+ }
226
+
227
+ const bool ctx_det = has_pos_class(ctx_tags, PosClass::Det) || is_common_determiner(ctx_key);
228
+ const bool ctx_prep = has_pos_class(ctx_tags, PosClass::Prep) || is_common_preposition(ctx_key);
229
+ const bool ctx_aux = is_common_aux_or_modal(ctx_key);
230
+
231
+ if (ctx_det){
232
+ if (cand_nounish) bonus += 0.20;
233
+ if (cand_verbish || cand_advish || cand_prepish) bonus -= 0.08;
234
+ }
235
+
236
+ if (ctx_prep){
237
+ if (cand_nounish) bonus += 0.16;
238
+ if (cand_verbish) bonus -= 0.06;
239
+ }
240
+
241
+ if (ctx_aux){
242
+ if (cand_verbish) bonus += 0.18;
243
+ if (cand_detish) bonus -= 0.04;
244
+ }
245
+
246
+ if (has_pos_class(ctx_tags, PosClass::Pron) || has_pos_class(ctx_tags, PosClass::Noun)){
247
+ if (cand_verbish) bonus += 0.05;
248
+ }
249
+
250
+ if (!context_tok.empty() && (context_tok.back() == ',' || context_tok.back() == ';' || context_tok.back() == ':')){
251
+ if (!cand.empty() && first_alpha_is_lower(cand)) bonus += 0.04;
252
+ }
253
+
254
+ if (is_punctuation_only_token(cand)){
255
+ if (sentence_start) bonus -= 0.05;
256
+ else if (!context_tok.empty() && std::isalnum(static_cast<unsigned char>(context_tok.back())) != 0) bonus += 0.03;
257
+ }
258
+
259
+ if (is_sentence_boundary_token(cand)) bonus += 0.06;
260
+
261
+ return bonus;
262
+ }
263
+
264
  // Tokenize by whitespace
265
  static std::vector<std::string> tokenize_whitespace(const std::string &s){
266
  std::istringstream iss(s);
 
297
  };
298
 
299
  // ---------- Global parsed dictionary (populated once in main) ----------
 
 
 
300
  static void build_def_tokens_cache(){
301
  global_def_tokens_cache.clear();
302
+ global_pos_cache.clear();
303
+
304
+ global_def_tokens_cache.reserve(global_dictionary_entries.size());
305
+ global_pos_cache.reserve(global_dictionary_entries.size());
306
+
307
+ for (const auto &entry : global_dictionary_entries){
308
+ const std::string key = normalize_dictionary_key(entry.word);
309
+ if (key.empty()) continue;
310
+
311
+ std::string pos = normalize_pos_tag(entry.pos);
312
+ if (!pos.empty()) global_pos_cache[key].push_back(std::move(pos));
313
+
314
+ auto &defs = global_def_tokens_cache[key];
315
+ for (const auto &def : entry.definitions){
316
+ auto toks = tokenize_non_alnum(def);
317
+ defs.insert(defs.end(), toks.begin(), toks.end());
318
+ }
319
+ }
320
+
321
+ for (auto &pr : global_def_tokens_cache){
322
+ auto &v = pr.second;
323
+ std::sort(v.begin(), v.end());
324
+ v.erase(std::unique(v.begin(), v.end()), v.end());
325
+ }
326
+
327
+ for (auto &pr : global_pos_cache){
328
+ auto &v = pr.second;
329
+ std::sort(v.begin(), v.end());
330
+ v.erase(std::unique(v.begin(), v.end()), v.end());
331
  }
332
  }
333
 
 
378
  std::unordered_set<StrPtr, PtrHash, PtrEq> acc;
379
  std::vector<StrPtr> frontier;
380
 
381
+ const std::string start_key = normalize_dictionary_key(*wp);
382
+ if (!start_key.empty()){
383
+ auto it_def = global_def_tokens_cache.find(start_key);
384
+ if (it_def != global_def_tokens_cache.end()){
385
+ for (const auto &tok : it_def->second){
386
+ StrPtr tp = interner.intern(tok);
387
+ if (acc.insert(tp).second) frontier.push_back(tp);
388
+ }
389
  }
390
  }
391
 
392
  for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
393
+ std::vector<StrPtr> next_frontier;
394
+
395
  for (StrPtr w : frontier){
396
+ const std::string key = normalize_dictionary_key(*w);
397
+ if (key.empty()) continue;
398
+
399
+ auto it2 = global_def_tokens_cache.find(key);
400
  if (it2 == global_def_tokens_cache.end()) continue;
401
+
402
+ for (const auto &tok : it2->second){
403
+ StrPtr tp = interner.intern(tok);
404
+ if (acc.insert(tp).second) next_frontier.push_back(tp);
405
  }
406
  }
407
+
408
+ frontier.swap(next_frontier);
409
  }
410
 
411
  std::vector<StrPtr> out;
 
414
 
415
  {
416
  std::lock_guard<std::mutex> lk(def_m);
417
+ def_index.emplace(wp, std::move(out));
 
 
418
  }
419
  }
 
420
  // existing public add_pair but now ensure def-expansion is built immediately
421
  void add_pair(const std::string &k, const std::string &v){
422
  StrPtr kp = interner.intern(k);
 
458
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index)
459
  {
460
  std::unordered_set<std::string> agg;
461
+
462
  for (StrPtr t : tokens){
463
+ const std::string tk = normalize_dictionary_key(*t);
464
+ if (!tk.empty()) agg.insert(tk);
465
+
466
  auto it = def_index.find(t);
467
  if (it != def_index.end()){
468
+ for (StrPtr d : it->second){
469
+ const std::string dk = normalize_dictionary_key(*d);
470
+ if (!dk.empty()) agg.insert(dk);
471
+ }
472
  }
473
  }
474
+
475
  return agg;
476
  }
477
 
 
502
  }
503
 
504
  // Very small JSON-like parser tailored to dictionary_json structure
505
+ static void skip_json_value(const std::string &s, size_t &i);
506
+
507
+ static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
508
+ std::vector<std::string> out;
509
+ if (!json_valid_index(i, text.size()) || text[i] != '[') return out;
510
+
511
+ ++i;
512
+ while (true){
513
+ skip_spaces(text, i);
514
+ if (!json_valid_index(i, text.size())) break;
515
+ if (text[i] == ']'){ ++i; break; }
516
+
517
+ if (text[i] == '"') out.push_back(parse_quoted_string(text, i));
518
+ else skip_json_value(text, i);
519
+
520
+ skip_spaces(text, i);
521
+ if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
522
+ if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
523
+ }
524
+
525
+ return out;
526
+ }
527
+
528
+ static void skip_json_value(const std::string &s, size_t &i){
529
+ skip_spaces(s, i);
530
+ if (!json_valid_index(i, s.size())) return;
531
+
532
+ if (s[i] == '"'){
533
+ (void)parse_quoted_string(s, i);
534
+ return;
535
+ }
536
+
537
+ if (s[i] == '['){
538
+ ++i;
539
+ while (json_valid_index(i, s.size())){
540
+ skip_spaces(s, i);
541
+ if (!json_valid_index(i, s.size())) break;
542
+ if (s[i] == ']'){ ++i; break; }
543
+ skip_json_value(s, i);
544
+ skip_spaces(s, i);
545
+ if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
546
+ if (json_valid_index(i, s.size()) && s[i] == ']'){ ++i; break; }
547
+ }
548
+ return;
549
+ }
550
+
551
+ if (s[i] == '{'){
552
+ ++i;
553
+ while (json_valid_index(i, s.size())){
554
+ skip_spaces(s, i);
555
+ if (!json_valid_index(i, s.size())) break;
556
+ if (s[i] == '}'){ ++i; break; }
557
+ if (s[i] == '"'){
558
+ (void)parse_quoted_string(s, i);
559
+ skip_spaces(s, i);
560
+ if (json_valid_index(i, s.size()) && s[i] == ':') ++i;
561
+ skip_json_value(s, i);
562
+ skip_spaces(s, i);
563
+ if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
564
+ if (json_valid_index(i, s.size()) && s[i] == '}'){ ++i; break; }
565
+ } else {
566
+ ++i;
567
+ }
568
+ }
569
+ return;
570
+ }
571
+
572
+ while (json_valid_index(i, s.size())){
573
+ char c = s[i];
574
+ if (c == ',' || c == ']' || c == '}' || is_space(c)) break;
575
+ ++i;
576
+ }
577
+ }
578
+
579
+ static std::vector<DictionaryEntry> parse_dictionary_json(){
580
+ std::vector<DictionaryEntry> dict;
581
  if (dictionary_json_len == 0) return dict;
582
+
583
+ std::string text;
584
+ text.reserve(dictionary_json_len);
585
+ for (unsigned int b = 0; b < dictionary_json_len; ++b){
586
+ text.push_back(static_cast<char>(dictionary_json[b]));
587
+ }
588
+
589
  size_t i = 0;
590
+ skip_spaces(text, i);
591
+ if (!json_valid_index(i, text.size()) || text[i] != '[') return dict;
592
  ++i;
593
+
594
  while (true){
595
+ skip_spaces(text, i);
596
+ if (!json_valid_index(i, text.size())) break;
597
+ if (text[i] == ']'){ ++i; break; }
598
+ if (text[i] != '{'){
599
+ skip_json_value(text, i);
600
+ skip_spaces(text, i);
601
+ if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
602
+ if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
603
+ continue;
604
+ }
605
+
606
  ++i;
607
+ DictionaryEntry entry;
608
+
609
+ while (true){
610
+ skip_spaces(text, i);
611
+ if (!json_valid_index(i, text.size())) break;
612
+ if (text[i] == '}'){ ++i; break; }
613
+
614
+ std::string field = parse_quoted_string(text, i);
615
+ skip_spaces(text, i);
616
+ if (!json_valid_index(i, text.size()) || text[i] != ':') break;
617
+ ++i;
618
+ skip_spaces(text, i);
619
+
620
+ if (field == "word"){
621
+ entry.word = parse_quoted_string(text, i);
622
+ } else if (field == "pos"){
623
+ entry.pos = parse_quoted_string(text, i);
624
+ } else if (field == "definitions"){
625
+ entry.definitions = parse_json_string_array(text, i);
626
+ } else {
627
+ skip_json_value(text, i);
628
+ }
629
+
630
+ skip_spaces(text, i);
631
+ if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
632
+ if (json_valid_index(i, text.size()) && text[i] == '}'){ ++i; break; }
633
  }
634
+
635
+ if (!entry.word.empty()) dict.push_back(std::move(entry));
636
+
637
+ skip_spaces(text, i);
638
+ if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
639
+ if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
640
  }
641
+
642
  return dict;
643
  }
644
 
645
+ static std::string best_candidate_by_similarity(
646
+ const NextSet &cands,
 
647
  const std::vector<StrPtr> &prompt_ptrs,
648
  const std::vector<StrPtr> &resp_ptrs,
649
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
650
  const std::unordered_map<std::string,int> &recent_counts,
651
+ double repeat_penalty,
652
+ const std::string &context_tok)
653
  {
654
  if (cands.empty()) return std::string();
655
  if (cands.size() == 1) return *cands[0];
 
658
  for (StrPtr r : resp_ptrs){
659
  auto it = def_index.find(r);
660
  if (it != def_index.end()){
661
+ for (StrPtr d : it->second){
662
+ const std::string dk = normalize_dictionary_key(*d);
663
+ if (!dk.empty()) agg.insert(dk);
664
+ }
665
  }
666
  }
667
 
 
673
  #pragma omp parallel for schedule(static)
674
  for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
675
  const StrPtr cand = cands[(size_t)i];
676
+ const std::string cand_key = normalize_dictionary_key(*cand);
677
 
678
+ size_t inter = (!cand_key.empty() && agg.count(cand_key)) ? 1 : 0;
679
  size_t cand_size = 1;
680
 
681
  auto it = def_index.find(cand);
682
  if (it != def_index.end()){
683
  cand_size += it->second.size();
684
  for (StrPtr d : it->second){
685
+ const std::string dk = normalize_dictionary_key(*d);
686
+ if (!dk.empty() && agg.count(dk)) ++inter;
687
  }
688
  if (std::find(it->second.begin(), it->second.end(), cand) != it->second.end()){
689
  --cand_size;
 
697
 
698
  for (size_t i = 0; i < M; ++i){
699
  const std::string &tok = *cands[i];
700
+ const std::string tok_key = normalize_dictionary_key(tok);
701
+ const std::string count_key = tok_key.empty() ? tok : tok;
702
+
703
  double s = scores[i];
704
+ auto rc_it = recent_counts.find(count_key);
705
  int cnt = (rc_it == recent_counts.end() ? 0 : rc_it->second);
706
+
707
+ double adjusted = s
708
+ + english_rule_bonus(context_tok, tok)
709
+ - repeat_penalty * static_cast<double>(cnt);
710
+
711
  if (adjusted > best || (adjusted == best && tok < best_tok)){
712
  best = adjusted;
713
  best_tok = tok;
714
  }
715
  }
716
+
717
  return best_tok;
718
  }
719
 
 
 
720
  static std::vector<std::string> construct_response(KnowledgeBase &kb,
721
+ const std::vector<std::string> &prompt_toks,
722
+ size_t maxlen,
723
+ double repeat_penalty)
724
  {
725
  std::vector<std::string> resp;
726
  if (prompt_toks.empty() || maxlen == 0) return resp;
727
 
728
  auto prompt_ptrs = intern_tokens(kb, prompt_toks);
729
  std::vector<StrPtr> resp_ptrs;
 
730
  std::unordered_map<std::string,int> recent_counts;
731
 
732
  auto would_create_2_cycle = [&](const std::string &cand) -> bool {
733
+ if (resp.size() < 3) return false;
734
+ return normalize_dictionary_key(cand) == normalize_dictionary_key(resp[resp.size() - 2]) &&
735
+ normalize_dictionary_key(resp.back()) == normalize_dictionary_key(resp[resp.size() - 3]);
 
736
  };
737
 
738
  std::string last_printed;
739
+
740
  for (size_t step = 0; step < maxlen; ++step){
741
  NextSet candidates;
742
  bool found = false;
743
+ std::string context_tok;
744
+
745
  if (step == 0){
746
  for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
747
  auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
748
+ if (opt){
749
+ candidates = *opt;
750
+ found = true;
751
+ context_tok = prompt_toks[(size_t)p];
752
+ break;
753
+ }
754
  }
755
  } else {
756
  auto opt = kb.lookup_by_string(last_printed);
757
+ if (opt){
758
+ candidates = *opt;
759
+ found = true;
760
+ context_tok = last_printed;
761
+ } else {
762
  for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
763
  auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
764
+ if (opt2){
765
+ candidates = *opt2;
766
+ found = true;
767
+ context_tok = prompt_toks[(size_t)p];
768
+ break;
769
+ }
770
  }
771
  }
772
  }
773
+
774
  if (!found || candidates.empty()) break;
775
 
776
  if (candidates.size() == 1){
777
  std::string only = *candidates[0];
778
+ std::string only_key = normalize_dictionary_key(only);
779
+ if (recent_counts[only_key.empty() ? only : only_key] > 0) break;
780
+
781
  resp.push_back(only);
782
  resp_ptrs.push_back(kb.interner.intern(only));
783
+ recent_counts[only_key.empty() ? only : only_key] += 1;
784
  last_printed = only;
785
  std::cout << only << ' ' << std::flush;
786
  continue;
787
  }
788
 
789
+ std::string chosen = best_candidate_by_similarity(
790
+ candidates, prompt_ptrs, resp_ptrs, kb.def_index, recent_counts, repeat_penalty, context_tok
791
+ );
792
 
793
+ if (chosen.empty()) break;
794
  if (would_create_2_cycle(chosen)) break;
795
 
796
  resp.push_back(chosen);
797
  resp_ptrs.push_back(kb.interner.intern(chosen));
798
+
799
+ std::string chosen_key = normalize_dictionary_key(chosen);
800
+ recent_counts[chosen_key.empty() ? chosen : chosen_key] += 1;
801
+
802
  last_printed = chosen;
803
  std::cout << chosen << ' ' << std::flush;
804
  }
805
+
806
  return resp;
807
  }
808
 
 
825
  for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
826
  }
827
 
828
+ // --------------------------- Serialization (binary, versioned) --------------
829
+ static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL; // "CHPCKSVA"
830
+ static constexpr std::uint64_t KB_VERSION = 1ULL;
831
+
832
+ static void write_u64(std::ostream &os, std::uint64_t v){
833
+ os.write(reinterpret_cast<const char*>(&v), sizeof(v));
834
+ if(!os) throw std::runtime_error("write_u64 failed");
835
+ }
836
+
837
+ static std::uint64_t read_u64(std::istream &is){
838
+ std::uint64_t v = 0;
839
+ is.read(reinterpret_cast<char*>(&v), sizeof(v));
840
+ if(!is) throw std::runtime_error("read_u64 failed");
841
+ return v;
842
+ }
843
+
844
+ static void write_string(std::ostream &os, const std::string &s){
845
+ write_u64(os, static_cast<std::uint64_t>(s.size()));
846
+ if (!s.empty()){
847
+ os.write(s.data(), static_cast<std::streamsize>(s.size()));
848
+ if(!os) throw std::runtime_error("write_string failed");
849
+ }
850
+ }
851
+
852
+ static std::string read_string(std::istream &is){
853
+ std::uint64_t n = read_u64(is);
854
+ if (n > (1ULL << 30)) throw std::runtime_error("corrupt save file: string too large");
855
+
856
+ std::string s;
857
+ s.resize(static_cast<size_t>(n));
858
+
859
+ if (n != 0){
860
+ is.read(&s[0], static_cast<std::streamsize>(n));
861
+ if(!is) throw std::runtime_error("read_string failed");
862
+ }
863
+ return s;
864
+ }
865
 
866
  static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
867
+ const std::string temp = fname + ".tmp";
868
+
869
+ {
870
+ std::ofstream ofs(temp.c_str(), std::ios::binary | std::ios::trunc);
871
+ if (!ofs) throw std::runtime_error("cannot open temp save file");
872
+
873
+ std::vector<std::string> pool;
874
+ pool.reserve(kb.interner.pool.size());
875
+ for (const auto &s : kb.interner.pool) pool.push_back(s);
876
+
877
+ std::sort(pool.begin(), pool.end());
878
+
879
+ std::unordered_map<std::string, std::uint64_t> id;
880
+ id.reserve(pool.size());
881
+ for (std::uint64_t i = 0; i < static_cast<std::uint64_t>(pool.size()); ++i){
882
+ id.emplace(pool[(size_t)i], i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  }
884
+
885
+ write_u64(ofs, KB_MAGIC);
886
+ write_u64(ofs, KB_VERSION);
887
+ write_u64(ofs, static_cast<std::uint64_t>(kb.def_depth));
888
+
889
+ write_u64(ofs, static_cast<std::uint64_t>(pool.size()));
890
+ for (const auto &s : pool) write_string(ofs, s);
891
+
892
+ write_u64(ofs, static_cast<std::uint64_t>(kb.next.size()));
893
+ for (const auto &pr : kb.next){
894
+ write_u64(ofs, id.at(*pr.first));
895
+ write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
896
+ for (StrPtr nxt : pr.second){
897
+ write_u64(ofs, id.at(*nxt));
898
+ }
899
+ }
900
+
901
+ write_u64(ofs, static_cast<std::uint64_t>(kb.def_index.size()));
902
+ for (const auto &pr : kb.def_index){
903
+ write_u64(ofs, id.at(*pr.first));
904
+ write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
905
+ for (StrPtr tok : pr.second){
906
+ write_u64(ofs, id.at(*tok));
907
+ }
908
  }
909
+
910
+ ofs.flush();
911
+ if (!ofs) throw std::runtime_error("failed while writing temp save file");
912
  }
913
 
914
+ std::remove(fname.c_str());
915
+ if (std::rename(temp.c_str(), fname.c_str()) != 0){
916
+ std::remove(temp.c_str());
917
+ throw std::runtime_error("failed to commit save file");
918
+ }
919
  }
920
 
921
  static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
922
  std::ifstream ifs(fname, std::ios::binary);
923
  if (!ifs) throw std::runtime_error("cannot open load file");
924
 
925
+ const std::uint64_t magic = read_u64(ifs);
926
+ if (magic != KB_MAGIC) throw std::runtime_error("bad save file magic");
927
+
928
+ const std::uint64_t version = read_u64(ifs);
929
+ if (version != KB_VERSION) throw std::runtime_error("unsupported save file version");
930
+
931
+ const std::uint64_t file_def_depth = read_u64(ifs);
932
+
933
+ const std::uint64_t N = read_u64(ifs);
934
+ if (N > (1ULL << 26)) throw std::runtime_error("corrupt save file: pool too large");
935
+
936
+ std::vector<std::string> strings;
937
+ strings.reserve(static_cast<size_t>(N));
938
+
939
+ for (std::uint64_t i = 0; i < N; ++i){
940
+ strings.push_back(read_string(ifs));
941
+ }
942
+
943
+ kb.interner.pool.clear();
944
+ kb.interner.pool.reserve(static_cast<size_t>(N));
945
+
946
+ std::vector<StrPtr> ptrs;
947
+ ptrs.reserve(static_cast<size_t>(N));
948
+ for (const auto &s : strings){
949
+ ptrs.push_back(kb.interner.intern(s));
950
+ }
951
+
952
+ // Rebuild next
953
+ const std::uint64_t E = read_u64(ifs);
954
+ if (E > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph too large");
955
+
956
+ {
957
+ std::lock_guard<std::mutex> lk(kb.m);
958
+ kb.next.clear();
959
+ kb.next_key_index.clear();
960
+ kb.next.reserve(static_cast<size_t>(E));
961
+ kb.next_key_index.reserve(static_cast<size_t>(E));
962
+ }
963
+
964
+ for (std::uint64_t i = 0; i < E; ++i){
965
+ const std::uint64_t key_idx = read_u64(ifs);
966
+ const std::uint64_t M = read_u64(ifs);
967
+
968
+ if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph key");
969
+ if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph degree too large");
970
+
971
+ StrPtr key_ptr = ptrs[(size_t)key_idx];
972
+ NextSet vec;
973
+ vec.reserve(static_cast<size_t>(M));
974
+
975
+ for (std::uint64_t j = 0; j < M; ++j){
976
+ const std::uint64_t v_idx = read_u64(ifs);
977
+ if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph value");
978
+ vec.push_back(ptrs[(size_t)v_idx]);
979
+ }
980
+
981
+ {
982
+ std::lock_guard<std::mutex> lk(kb.m);
983
+ kb.next.emplace(key_ptr, std::move(vec));
984
+ kb.next_key_index.emplace(*key_ptr, key_ptr);
985
  }
 
986
  }
987
 
988
+ // Rebuild def_index from file
989
+ const std::uint64_t K = read_u64(ifs);
990
+ if (K > (1ULL << 26)) throw std::runtime_error("corrupt save file: def_index too large");
991
+
 
992
  {
993
  std::lock_guard<std::mutex> lk(kb.def_m);
994
  kb.def_index.clear();
995
+ kb.def_index.reserve(static_cast<size_t>(K));
996
  kb.def_depth = static_cast<int>(file_def_depth);
997
  }
998
+
999
+ for (std::uint64_t i = 0; i < K; ++i){
1000
+ const std::uint64_t key_idx = read_u64(ifs);
1001
+ const std::uint64_t M = read_u64(ifs);
1002
+
1003
+ if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def key");
1004
+ if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: def list too large");
1005
+
1006
+ std::vector<StrPtr> toks;
1007
+ toks.reserve(static_cast<size_t>(M));
1008
+
1009
+ for (std::uint64_t j = 0; j < M; ++j){
1010
+ const std::uint64_t v_idx = read_u64(ifs);
1011
+ if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def value");
1012
+ toks.push_back(ptrs[(size_t)v_idx]);
1013
+ }
1014
+
1015
+ {
1016
+ std::lock_guard<std::mutex> lk(kb.def_m);
1017
+ kb.def_index.emplace(ptrs[(size_t)key_idx], std::move(toks));
1018
  }
 
1019
  }
1020
 
1021
+ // If the caller asks for a different dict depth, recompute with the current embedded dictionary.
1022
+ if (cli_dict_depth != static_cast<int>(file_def_depth)){
1023
  kb.set_def_depth(cli_dict_depth);
1024
+
1025
  std::vector<StrPtr> targets;
1026
+ targets.reserve(ptrs.size() + kb.next.size() * 2);
1027
+
1028
+ std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
1029
+ seen.reserve(ptrs.size() + kb.next.size() * 2);
1030
+
1031
+ for (StrPtr p : ptrs){
1032
+ if (seen.insert(p).second) targets.push_back(p);
1033
+ }
1034
 
1035
  {
1036
+ std::lock_guard<std::mutex> lk(kb.m);
1037
+ for (const auto &pr : kb.next){
 
 
 
 
 
1038
  if (seen.insert(pr.first).second) targets.push_back(pr.first);
1039
+ for (StrPtr v : pr.second){
1040
  if (seen.insert(v).second) targets.push_back(v);
1041
  }
1042
  }
1043
  }
1044
 
 
1045
  #pragma omp parallel for schedule(dynamic)
1046
+ for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i){
1047
  kb.ensure_def_for_interned(targets[(size_t)i]);
1048
  }
1049
  }
 
1095
  KnowledgeBase kb;
1096
 
1097
  // parse the embedded dictionary once for use by per-word expansion
1098
+ global_dictionary_entries = parse_dictionary_json();
1099
  build_def_tokens_cache();
1100
  // set KB def depth (clears any previous expansion)
1101
  kb.set_def_depth(dict_depth);