technician1 committed on
Commit
ec7d802
·
verified ·
1 Parent(s): 7147dc8

Upload ChatIPC.cpp

Browse files
Files changed (1) hide show
  1. ChatIPC.cpp +152 -64
ChatIPC.cpp CHANGED
@@ -286,16 +286,79 @@ static std::vector<std::string> tokenize_non_alnum(const std::string &s){
286
 
287
  // --------------------------- String interning (short methods) --------------
288
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  struct StringInterner {
290
  std::unordered_set<std::string> pool;
291
- std::mutex m;
 
 
 
292
  const std::string* intern(const std::string &s){
293
  std::lock_guard<std::mutex> lk(m);
 
294
  auto [it, inserted] = pool.emplace(s);
 
 
 
 
 
 
295
  return &*it;
296
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  };
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  // ---------- Global parsed dictionary (populated once in main) ----------
300
  static void build_def_tokens_cache(){
301
  global_def_tokens_cache.clear();
@@ -332,7 +395,6 @@ static void build_def_tokens_cache(){
332
  }
333
 
334
  // --------------------------- Knowledge base (short methods) --------------
335
- using StrPtr = const std::string*;
336
  struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
337
  struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
338
 
@@ -453,28 +515,6 @@ intern_tokens(KnowledgeBase &kb, const std::vector<std::string> &tokens)
453
  return out;
454
  }
455
 
456
- static std::unordered_set<std::string>
457
- aggregate_sets(const std::vector<StrPtr> &tokens,
458
- const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index)
459
- {
460
- std::unordered_set<std::string> agg;
461
-
462
- for (StrPtr t : tokens){
463
- const std::string tk = normalize_dictionary_key(*t);
464
- if (!tk.empty()) agg.insert(tk);
465
-
466
- auto it = def_index.find(t);
467
- if (it != def_index.end()){
468
- for (StrPtr d : it->second){
469
- const std::string dk = normalize_dictionary_key(*d);
470
- if (!dk.empty()) agg.insert(dk);
471
- }
472
- }
473
- }
474
-
475
- return agg;
476
- }
477
-
478
  // --------------------------- Small JSON parse helpers ----------------------
479
 
480
  static inline bool json_valid_index(size_t i, size_t n){ return i < n; }
@@ -647,6 +687,7 @@ static std::string best_candidate_by_similarity(
647
  const std::vector<StrPtr> &prompt_ptrs,
648
  const std::vector<StrPtr> &resp_ptrs,
649
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
 
650
  const std::unordered_map<std::string,int> &recent_counts,
651
  double repeat_penalty,
652
  const std::string &context_tok)
@@ -654,59 +695,108 @@ static std::string best_candidate_by_similarity(
654
  if (cands.empty()) return std::string();
655
  if (cands.size() == 1) return *cands[0];
656
 
657
- auto agg = aggregate_sets(prompt_ptrs, def_index);
658
- for (StrPtr r : resp_ptrs){
659
- auto it = def_index.find(r);
 
 
 
 
 
 
660
  if (it != def_index.end()){
661
  for (StrPtr d : it->second){
662
- const std::string dk = normalize_dictionary_key(*d);
663
- if (!dk.empty()) agg.insert(dk);
664
  }
665
  }
666
- }
667
 
668
- double best = -1e9;
669
- std::string best_tok;
670
- size_t M = cands.size();
671
- std::vector<double> scores(M, 0.0);
672
 
673
- #pragma omp parallel for schedule(static)
674
- for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
675
- const StrPtr cand = cands[(size_t)i];
676
- const std::string cand_key = normalize_dictionary_key(*cand);
677
 
678
- size_t inter = (!cand_key.empty() && agg.count(cand_key)) ? 1 : 0;
679
- size_t cand_size = 1;
 
 
 
 
680
 
681
  auto it = def_index.find(cand);
682
  if (it != def_index.end()){
683
- cand_size += it->second.size();
684
  for (StrPtr d : it->second){
685
- const std::string dk = normalize_dictionary_key(*d);
686
- if (!dk.empty() && agg.count(dk)) ++inter;
687
- }
688
- if (std::find(it->second.begin(), it->second.end(), cand) != it->second.end()){
689
- --cand_size;
690
  }
691
  }
692
 
693
- size_t uni = agg.size() + cand_size - inter;
694
- double s = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
695
- scores[(size_t)i] = s;
696
  }
697
 
698
- for (size_t i = 0; i < M; ++i){
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699
  const std::string &tok = *cands[i];
700
  const std::string tok_key = normalize_dictionary_key(tok);
701
- const std::string count_key = tok_key.empty() ? tok : tok;
702
 
703
- double s = scores[i];
704
  auto rc_it = recent_counts.find(count_key);
705
- int cnt = (rc_it == recent_counts.end() ? 0 : rc_it->second);
706
 
707
- double adjusted = s
708
- + english_rule_bonus(context_tok, tok)
709
- - repeat_penalty * static_cast<double>(cnt);
 
710
 
711
  if (adjusted > best || (adjusted == best && tok < best_tok)){
712
  best = adjusted;
@@ -787,7 +877,8 @@ static std::vector<std::string> construct_response(KnowledgeBase &kb,
787
  }
788
 
789
  std::string chosen = best_candidate_by_similarity(
790
- candidates, prompt_ptrs, resp_ptrs, kb.def_index, recent_counts, repeat_penalty, context_tok
 
791
  );
792
 
793
  if (chosen.empty()) break;
@@ -1102,7 +1193,8 @@ int main(int argc, char **argv){
1102
 
1103
 
1104
  if (!load_kb.empty()){
1105
- try { load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
 
1106
  catch (const std::exception &e){ std::cerr << "Load KB error: " << e.what() << "\n"; }
1107
  }
1108
 
@@ -1125,15 +1217,11 @@ int main(int argc, char **argv){
1125
  for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
1126
  }
1127
  if (!savefile.empty()){
1128
- try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
 
1129
  catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
1130
  }
1131
  }
1132
 
1133
- if (!savefile.empty()){
1134
- try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
1135
- catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
1136
- }
1137
-
1138
  return 0;
1139
  }
 
286
 
287
  // --------------------------- String interning (short methods) --------------
288
 
289
+ using StrPtr = const std::string*;
290
+ using TokenId = std::uint32_t;
291
+ static constexpr TokenId TOKEN_ID_INVALID = 0xFFFFFFFFu;
292
+
293
/// Counts the set bits in a 64-bit word.
///
/// Uses a compiler population-count intrinsic where one is known to exist and
/// a portable bit-twiddling routine everywhere else, so the function builds on
/// any C++ compiler and target.
///
/// @param x  Word whose 1-bits are counted.
/// @return   Number of bits set in @p x (0 through 64).
static inline std::size_t popcount_u64(std::uint64_t x){
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_AMD64))
    // __popcnt64 is only provided for x64 targets; other MSVC targets
    // (x86, ARM, ARM64) take one of the branches below instead.
    return static_cast<std::size_t>(__popcnt64(x));
#elif defined(__GNUC__) || defined(__clang__)
    return static_cast<std::size_t>(__builtin_popcountll(static_cast<unsigned long long>(x)));
#else
    // Portable fallback: sum bits in 2-, 4- and 8-bit lanes, then add the
    // eight byte sums together with a single multiply.
    x = x - ((x >> 1) & 0x5555555555555555ULL);
    x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
    x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0FULL;
    return static_cast<std::size_t>((x * 0x0101010101010101ULL) >> 56);
#endif
}
300
+
301
  struct StringInterner {
302
  std::unordered_set<std::string> pool;
303
+ std::unordered_map<std::string, TokenId> id_by_string;
304
+ std::vector<const std::string*> string_by_id;
305
+ mutable std::mutex m;
306
+
307
  const std::string* intern(const std::string &s){
308
  std::lock_guard<std::mutex> lk(m);
309
+
310
  auto [it, inserted] = pool.emplace(s);
311
+ if (inserted){
312
+ const TokenId id = static_cast<TokenId>(string_by_id.size());
313
+ id_by_string.emplace(*it, id);
314
+ string_by_id.push_back(&*it);
315
+ }
316
+
317
  return &*it;
318
  }
319
+
320
+ TokenId id_of(const std::string &s) const {
321
+ std::lock_guard<std::mutex> lk(m);
322
+
323
+ auto it = id_by_string.find(s);
324
+ return (it == id_by_string.end()) ? TOKEN_ID_INVALID : it->second;
325
+ }
326
+
327
+ TokenId id_of(const std::string *p) const {
328
+ return p ? id_of(*p) : TOKEN_ID_INVALID;
329
+ }
330
+
331
+ const std::string* ptr_from_id(TokenId id) const {
332
+ std::lock_guard<std::mutex> lk(m);
333
+
334
+ return (id < string_by_id.size()) ? string_by_id[(size_t)id] : nullptr;
335
+ }
336
+
337
+ std::size_t size() const {
338
+ std::lock_guard<std::mutex> lk(m);
339
+ return string_by_id.size();
340
+ }
341
  };
342
 
343
+ static inline void bitset_set(std::uint64_t *bits, std::size_t words, TokenId id){
344
+ if (id == TOKEN_ID_INVALID) return;
345
+ const std::size_t idx = static_cast<std::size_t>(id >> 6);
346
+ if (idx >= words) return;
347
+ bits[idx] |= (1ULL << (id & 63u));
348
+ }
349
+
350
+ static inline std::size_t bitset_count(const std::uint64_t *bits, std::size_t words){
351
+ std::size_t total = 0;
352
+ for (std::size_t i = 0; i < words; ++i) total += popcount_u64(bits[i]);
353
+ return total;
354
+ }
355
+
356
+ static inline std::size_t bitset_intersection_count(const std::uint64_t *a, const std::uint64_t *b, std::size_t words){
357
+ std::size_t total = 0;
358
+ for (std::size_t i = 0; i < words; ++i) total += popcount_u64(a[i] & b[i]);
359
+ return total;
360
+ }
361
+
362
  // ---------- Global parsed dictionary (populated once in main) ----------
363
  static void build_def_tokens_cache(){
364
  global_def_tokens_cache.clear();
 
395
  }
396
 
397
  // --------------------------- Knowledge base (short methods) --------------
 
398
  struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
399
  struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
400
 
 
515
  return out;
516
  }
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  // --------------------------- Small JSON parse helpers ----------------------
519
 
520
  static inline bool json_valid_index(size_t i, size_t n){ return i < n; }
 
687
  const std::vector<StrPtr> &prompt_ptrs,
688
  const std::vector<StrPtr> &resp_ptrs,
689
  const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
690
+ const StringInterner &interner,
691
  const std::unordered_map<std::string,int> &recent_counts,
692
  double repeat_penalty,
693
  const std::string &context_tok)
 
695
  if (cands.empty()) return std::string();
696
  if (cands.size() == 1) return *cands[0];
697
 
698
+ const std::size_t words = std::max<std::size_t>(1, (interner.size() + 63u) / 64u);
699
+
700
+ std::vector<std::uint64_t> agg_words(words, 0ULL);
701
+
702
+ auto add_token_and_defs = [&](StrPtr t){
703
+ TokenId tid = interner.id_of(t);
704
+ if (tid != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, tid);
705
+
706
+ auto it = def_index.find(t);
707
  if (it != def_index.end()){
708
  for (StrPtr d : it->second){
709
+ TokenId did = interner.id_of(d);
710
+ if (did != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, did);
711
  }
712
  }
713
+ };
714
 
715
+ for (StrPtr t : prompt_ptrs) add_token_and_defs(t);
716
+ for (StrPtr t : resp_ptrs) add_token_and_defs(t);
717
+
718
+ const std::size_t agg_count = bitset_count(agg_words.data(), words);
719
 
720
+ const std::size_t M = cands.size();
721
+ std::vector<std::uint64_t> cand_words(M * words, 0ULL);
722
+ std::vector<std::size_t> cand_counts(M, 0);
 
723
 
724
+ for (std::size_t i = 0; i < M; ++i){
725
+ std::uint64_t *row = cand_words.data() + i * words;
726
+ const StrPtr cand = cands[i];
727
+
728
+ TokenId cid = interner.id_of(cand);
729
+ if (cid != TOKEN_ID_INVALID) bitset_set(row, words, cid);
730
 
731
  auto it = def_index.find(cand);
732
  if (it != def_index.end()){
 
733
  for (StrPtr d : it->second){
734
+ TokenId did = interner.id_of(d);
735
+ if (did != TOKEN_ID_INVALID) bitset_set(row, words, did);
 
 
 
736
  }
737
  }
738
 
739
+ cand_counts[i] = bitset_count(row, words);
 
 
740
  }
741
 
742
+ std::vector<double> scores(M, 0.0);
743
+
744
+ #if defined(_OPENMP) && defined(CHATIPC_ENABLE_OMP_TARGET)
745
+ const bool use_target = (omp_get_num_devices() > 0) && (M >= 256);
746
+ #else
747
+ const bool use_target = false;
748
+ #endif
749
+
750
+ if (use_target){
751
+ std::uint64_t *agg_ptr = agg_words.data();
752
+ std::uint64_t *cand_ptr = cand_words.data();
753
+ std::size_t *count_ptr = cand_counts.data();
754
+ double *score_ptr = scores.data();
755
+ const std::size_t cand_words_total = cand_words.size();
756
+
757
+ #pragma omp target data map(to: agg_ptr[0:words], cand_ptr[0:cand_words_total], count_ptr[0:M]) map(from: score_ptr[0:M])
758
+ {
759
+ #pragma omp target teams distribute parallel for
760
+ for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
761
+ const std::uint64_t *row = cand_ptr + static_cast<std::size_t>(i) * words;
762
+
763
+ std::size_t inter = 0;
764
+ for (std::size_t w = 0; w < words; ++w){
765
+ inter += popcount_u64(agg_ptr[w] & row[w]);
766
+ }
767
+
768
+ const std::size_t uni = agg_count + count_ptr[(size_t)i] - inter;
769
+ score_ptr[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
770
+ }
771
+ }
772
+ } else {
773
+ #ifdef _OPENMP
774
+ #pragma omp parallel for schedule(static)
775
+ #endif
776
+ for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
777
+ const std::uint64_t *row = cand_words.data() + static_cast<std::size_t>(i) * words;
778
+
779
+ const std::size_t inter = bitset_intersection_count(agg_words.data(), row, words);
780
+ const std::size_t uni = agg_count + cand_counts[(size_t)i] - inter;
781
+ scores[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
782
+ }
783
+ }
784
+
785
+ double best = -1e9;
786
+ std::string best_tok;
787
+
788
+ for (std::size_t i = 0; i < M; ++i){
789
  const std::string &tok = *cands[i];
790
  const std::string tok_key = normalize_dictionary_key(tok);
791
+ const std::string count_key = tok_key.empty() ? tok : tok_key;
792
 
 
793
  auto rc_it = recent_counts.find(count_key);
794
+ const int cnt = (rc_it == recent_counts.end()) ? 0 : rc_it->second;
795
 
796
+ const double adjusted =
797
+ scores[i] +
798
+ english_rule_bonus(context_tok, tok) -
799
+ repeat_penalty * static_cast<double>(cnt);
800
 
801
  if (adjusted > best || (adjusted == best && tok < best_tok)){
802
  best = adjusted;
 
877
  }
878
 
879
  std::string chosen = best_candidate_by_similarity(
880
+ candidates, prompt_ptrs, resp_ptrs, kb.def_index, kb.interner,
881
+ recent_counts, repeat_penalty, context_tok
882
  );
883
 
884
  if (chosen.empty()) break;
 
1193
 
1194
 
1195
  if (!load_kb.empty()){
1196
+ try { std::cerr << "Loading KB: " << load_kb << "\n";
1197
+ load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
1198
  catch (const std::exception &e){ std::cerr << "Load KB error: " << e.what() << "\n"; }
1199
  }
1200
 
 
1217
  for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
1218
  }
1219
  if (!savefile.empty()){
1220
+ try { std::cerr << "Saving KB: " << savefile << "\n";
1221
+ save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
1222
  catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
1223
  }
1224
  }
1225
 
 
 
 
 
 
1226
  return 0;
1227
  }