|
|
|
|
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <mutex>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
// OpenMP shims: when building without OpenMP, provide single-threaded
// stand-ins so call sites need no #ifdef guards of their own.
#ifdef _OPENMP
#include <omp.h>
#else
inline int omp_get_max_threads(){ return 1; }
inline int omp_get_thread_num(){ return 0; }
#endif
|
|
|
// Embedded dictionary blob: raw JSON bytes and their length, linked in from
// another translation unit (presumably generated with `xxd -i` — TODO confirm).
extern unsigned char dictionary_json[];
extern unsigned int dictionary_json_len;
|
|
|
|
|
|
|
// Locale-safe <cctype> wrappers: those functions require a value representable
// as unsigned char, so cast first to avoid UB on negative char values.
static inline bool is_space(char c){
    return std::isspace(static_cast<unsigned char>(c)) != 0;
}

static inline char to_low(char c){
    const int lowered = std::tolower(static_cast<unsigned char>(c));
    return static_cast<char>(lowered);
}

// Flush a stream; wrapper kept for call-site readability.
static inline void safe_flush(std::ostream &os){ os.flush(); }
|
|
|
|
|
// One parsed record from the embedded dictionary JSON.
struct DictionaryEntry {
    std::string pos;                        // raw part-of-speech tag as found in the JSON
    std::string word;                       // headword (surface form)
    std::vector<std::string> definitions;   // one string per gloss
};

// Process-wide dictionary state. Filled by build_def_tokens_cache() from
// global_dictionary_entries; read by dictionary_pos_for_token() and the
// definition-closure code. Assumed filled once before concurrent reads
// begin (no lock protects them) — TODO confirm against startup code.
static std::vector<DictionaryEntry> global_dictionary_entries;
// normalized headword -> sorted, deduplicated definition tokens
static std::unordered_map<std::string, std::vector<std::string>> global_def_tokens_cache;
// normalized headword -> sorted, deduplicated normalized POS tags
static std::unordered_map<std::string, std::vector<std::string>> global_pos_cache;
|
|
|
// Characters allowed inside a dictionary key: alphanumerics plus apostrophe
// and hyphen (so "don't" and "well-known" survive normalization).
static inline bool is_word_char_for_key(char c){
    const unsigned char uc = static_cast<unsigned char>(c);
    if (std::isalnum(uc) != 0) return true;
    return c == '\'' || c == '-';
}

// Strip non-word characters from both ends and lowercase what remains.
// Interior punctuation (e.g. a comma between words) is kept as-is.
static std::string normalize_dictionary_key(const std::string &s){
    size_t first = 0;
    size_t last = s.size();
    while (first < last && !is_word_char_for_key(s[first])) ++first;
    while (last > first && !is_word_char_for_key(s[last - 1])) --last;

    std::string key;
    key.reserve(last - first);
    for (size_t i = first; i < last; ++i){
        key.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(s[i]))));
    }
    return key;
}

// Reduce a raw POS tag to lowercase letters only ("N." -> "n", "Verb" -> "verb").
static std::string normalize_pos_tag(const std::string &s){
    std::string tag;
    tag.reserve(s.size());
    for (char c : s){
        const unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc) != 0){
            tag.push_back(static_cast<char>(std::tolower(uc)));
        }
    }
    return tag;
}
|
|
|
// Coarse part-of-speech classes used by the grammar heuristics.
enum class PosClass {
    Unknown,
    Noun,
    Verb,
    Adj,
    Adv,
    Pron,
    Prep,
    Conj,
    Det,
    Num,
    Interj
};

// Map a normalized tag (see normalize_pos_tag) onto its coarse class.
// Unrecognized tags fall back to PosClass::Unknown.
static PosClass pos_class_from_tag(const std::string &tag){
    static const std::unordered_map<std::string, PosClass> table = {
        {"n", PosClass::Noun}, {"noun", PosClass::Noun},
        {"v", PosClass::Verb}, {"verb", PosClass::Verb},
        {"part", PosClass::Verb}, {"participle", PosClass::Verb}, {"p", PosClass::Verb},
        {"a", PosClass::Adj}, {"adj", PosClass::Adj}, {"adjective", PosClass::Adj},
        {"adv", PosClass::Adv}, {"adverb", PosClass::Adv},
        {"pron", PosClass::Pron}, {"pronoun", PosClass::Pron},
        {"prep", PosClass::Prep}, {"preposition", PosClass::Prep},
        {"conj", PosClass::Conj}, {"conjunction", PosClass::Conj},
        {"art", PosClass::Det}, {"article", PosClass::Det},
        {"det", PosClass::Det}, {"determiner", PosClass::Det},
        {"num", PosClass::Num}, {"number", PosClass::Num},
        {"interj", PosClass::Interj}, {"interjection", PosClass::Interj},
    };
    const auto it = table.find(tag);
    return (it == table.end()) ? PosClass::Unknown : it->second;
}

// True when any tag in `tags` maps onto the class `cls`.
static bool has_pos_class(const std::vector<std::string> &tags, PosClass cls){
    return std::any_of(tags.begin(), tags.end(),
                       [cls](const std::string &t){ return pos_class_from_tag(t) == cls; });
}
|
|
|
| static const std::vector<std::string> &dictionary_pos_for_token(const std::string &surface){
|
| static const std::vector<std::string> empty;
|
| auto key = normalize_dictionary_key(surface);
|
| if (key.empty()) return empty;
|
|
|
| auto it = global_pos_cache.find(key);
|
| return (it == global_pos_cache.end()) ? empty : it->second;
|
| }
|
|
|
// Case of the first alphabetic character in `s`; false when no letter exists.
static bool first_alpha_is_upper(const std::string &s){
    const auto it = std::find_if(s.begin(), s.end(), [](char c){
        return std::isalpha(static_cast<unsigned char>(c)) != 0;
    });
    if (it == s.end()) return false;
    return std::isupper(static_cast<unsigned char>(*it)) != 0;
}

static bool first_alpha_is_lower(const std::string &s){
    const auto it = std::find_if(s.begin(), s.end(), [](char c){
        return std::isalpha(static_cast<unsigned char>(c)) != 0;
    });
    if (it == s.end()) return false;
    return std::islower(static_cast<unsigned char>(*it)) != 0;
}
|
|
|
// A token ends a sentence when its final character is terminal punctuation.
static bool is_sentence_boundary_token(const std::string &s){
    if (s.empty()) return false;
    switch (s.back()){
        case '.': case '!': case '?': return true;
        default: return false;
    }
}

// Single-character opening bracket or quote.
static bool is_open_punct_token(const std::string &s){
    if (s.size() != 1) return false;
    switch (s[0]){
        case '(': case '[': case '{': case '"': case '\'': return true;
        default: return false;
    }
}

// Non-empty token containing no alphanumeric characters at all.
static bool is_punctuation_only_token(const std::string &s){
    if (s.empty()) return false;
    return std::none_of(s.begin(), s.end(), [](char c){
        return std::isalnum(static_cast<unsigned char>(c)) != 0;
    });
}
|
|
|
// Closed list of high-frequency English determiners and possessives.
// Expects an already-lowercased token (callers pass normalized keys).
static bool is_common_determiner(const std::string &s){
    static const std::unordered_set<std::string> words = {
        "a", "an", "the", "this", "that", "these", "those",
        "my", "your", "his", "her", "its", "our", "their",
    };
    return words.count(s) != 0;
}
|
|
|
// Closed list of high-frequency English prepositions; expects a lowercased
// token. Fix: the original equality chain tested "under" twice (once in the
// over/under group and again on the final line), leaving one comparison as
// dead code. A hash set lists each word exactly once and makes membership O(1).
static bool is_common_preposition(const std::string &s){
    static const std::unordered_set<std::string> words = {
        "of", "in", "on", "at", "by", "for", "from", "with", "into", "onto",
        "about", "over", "under", "after", "before", "between", "through",
        "during", "without", "within", "across", "against", "among", "around",
    };
    return words.count(s) != 0;
}
|
|
|
// Closed list of auxiliaries, modals, and infinitival "to"; expects a
// lowercased token.
static bool is_common_aux_or_modal(const std::string &s){
    static const std::unordered_set<std::string> words = {
        "to", "be", "am", "is", "are", "was", "were", "been", "being",
        "have", "has", "had", "do", "does", "did",
        "can", "could", "may", "might", "must",
        "shall", "should", "will", "would",
    };
    return words.count(s) != 0;
}
|
|
|
// Heuristic "a" vs "an" test: does the word start with a vowel *sound*?
// Checks common silent-h words, then vowel-spelled consonant onsets
// ("you-"/"wuh-" words like "uni...", "one", "euro"), and finally falls
// back to the literal first letter. Expects a lowercased word.
static bool begins_with_vowel_sound(const std::string &s){
    if (s.empty()) return false;

    const auto starts_with = [&s](const char *prefix){
        return s.compare(0, std::strlen(prefix), prefix) == 0;
    };

    // Silent 'h': spelled with a consonant but pronounced with a vowel.
    for (const char *p : {"hour", "honest", "honor", "heir", "herb"}){
        if (starts_with(p)) return true;
    }

    // Vowel spelling, consonant sound.
    for (const char *p : {"uni", "use", "user", "one", "once", "euro"}){
        if (starts_with(p)) return false;
    }

    switch (s[0]){
        case 'a': case 'e': case 'i': case 'o': case 'u': return true;
        default: return false;
    }
}
|
|
|
// Heuristic grammar adjustment added to a candidate's similarity score.
// `context_tok` is the token immediately preceding the slot being filled
// (empty string at the very start of generation); `cand` is the candidate.
// Returns a small positive/negative delta; all magnitudes are hand-tuned
// weights and carry no meaning beyond their relative ordering.
static double english_rule_bonus(const std::string &context_tok, const std::string &cand){
    const std::string ctx_key = normalize_dictionary_key(context_tok);
    const std::string cand_key = normalize_dictionary_key(cand);

    const auto &ctx_tags = dictionary_pos_for_token(context_tok);
    const auto &cand_tags = dictionary_pos_for_token(cand);

    // "Sentence start": no context yet, or context ends a sentence / opens a quote.
    const bool sentence_start = context_tok.empty() || is_sentence_boundary_token(context_tok) || is_open_punct_token(context_tok);

    // Coarse candidate categories (a word may belong to several).
    const bool cand_nounish = has_pos_class(cand_tags, PosClass::Noun) ||
                              has_pos_class(cand_tags, PosClass::Adj) ||
                              has_pos_class(cand_tags, PosClass::Pron) ||
                              has_pos_class(cand_tags, PosClass::Num);

    const bool cand_verbish = has_pos_class(cand_tags, PosClass::Verb);
    const bool cand_advish = has_pos_class(cand_tags, PosClass::Adv);
    const bool cand_prepish = has_pos_class(cand_tags, PosClass::Prep);
    const bool cand_detish = has_pos_class(cand_tags, PosClass::Det);

    double bonus = 0.0;

    // Capitalization: reward capitals at sentence start, mildly penalize elsewhere.
    if (!cand_key.empty()){
        if (sentence_start){
            bonus += first_alpha_is_upper(cand) ? 0.22 : -0.08;
        } else if (first_alpha_is_upper(cand)){
            bonus -= 0.03;
        }
    }

    // Article agreement: "an" before vowel sounds, "a" before consonant sounds.
    if (ctx_key == "a" || ctx_key == "an"){
        const bool vowel = begins_with_vowel_sound(cand_key.empty() ? cand : cand_key);
        bonus += ((ctx_key == "an") == vowel) ? 0.28 : -0.18;
    }

    const bool ctx_det = has_pos_class(ctx_tags, PosClass::Det) || is_common_determiner(ctx_key);
    const bool ctx_prep = has_pos_class(ctx_tags, PosClass::Prep) || is_common_preposition(ctx_key);
    const bool ctx_aux = is_common_aux_or_modal(ctx_key);

    // Determiner context favors a noun phrase; verb/adverb/preposition is unlikely.
    if (ctx_det){
        if (cand_nounish) bonus += 0.20;
        if (cand_verbish || cand_advish || cand_prepish) bonus -= 0.08;
    }

    // Preposition context usually takes a nominal object.
    if (ctx_prep){
        if (cand_nounish) bonus += 0.16;
        if (cand_verbish) bonus -= 0.06;
    }

    // Auxiliary/modal (or infinitival "to") expects a verb next.
    if (ctx_aux){
        if (cand_verbish) bonus += 0.18;
        if (cand_detish) bonus -= 0.04;
    }

    // Subject-like context (pronoun/noun) slightly favors a verb.
    if (has_pos_class(ctx_tags, PosClass::Pron) || has_pos_class(ctx_tags, PosClass::Noun)){
        if (cand_verbish) bonus += 0.05;
    }

    // After a clause-internal pause, a lowercase continuation fits better.
    if (!context_tok.empty() && (context_tok.back() == ',' || context_tok.back() == ';' || context_tok.back() == ':')){
        if (!cand.empty() && first_alpha_is_lower(cand)) bonus += 0.04;
    }

    // Bare punctuation: awkward at sentence start, natural right after a word.
    if (is_punctuation_only_token(cand)){
        if (sentence_start) bonus -= 0.05;
        else if (!context_tok.empty() && std::isalnum(static_cast<unsigned char>(context_tok.back())) != 0) bonus += 0.03;
    }

    // Small standing bonus for tokens that can close a sentence.
    if (is_sentence_boundary_token(cand)) bonus += 0.06;

    return bonus;
}
|
|
|
|
|
// Split on whitespace, preserving each token's case and punctuation.
static std::vector<std::string> tokenize_whitespace(const std::string &s){
    std::istringstream in(s);
    return std::vector<std::string>(std::istream_iterator<std::string>(in),
                                    std::istream_iterator<std::string>());
}

// Split into lowercase word tokens: runs of alphanumerics plus '-' and '\''.
// Everything else acts as a separator and is dropped.
static std::vector<std::string> tokenize_non_alnum(const std::string &s){
    std::vector<std::string> tokens;
    std::string word;
    for (char ch : s){
        const unsigned char uc = static_cast<unsigned char>(ch);
        const bool keep = std::isalnum(uc) != 0 || ch == '-' || ch == '\'';
        if (keep){
            word.push_back(static_cast<char>(std::tolower(uc)));
        } else if (!word.empty()){
            tokens.push_back(word);
            word.clear();
        }
    }
    if (!word.empty()) tokens.push_back(word);
    return tokens;
}
|
|
|
|
|
|
|
// Interned strings are addressed two ways: by stable pointer (StrPtr) and by
// dense numeric id (TokenId). TOKEN_ID_INVALID marks "not interned".
using StrPtr = const std::string*;
using TokenId = std::uint32_t;
static constexpr TokenId TOKEN_ID_INVALID = 0xFFFFFFFFu;

// Portable 64-bit population count (MSVC intrinsic vs GCC/Clang builtin).
static inline std::size_t popcount_u64(std::uint64_t x){
#if defined(_MSC_VER)
    return static_cast<std::size_t>(__popcnt64(x));
#else
    return static_cast<std::size_t>(__builtin_popcountll(static_cast<unsigned long long>(x)));
#endif
}
|
|
|
// Thread-safe string interner. Strings live in an unordered_set, whose
// elements never move once inserted, so the pointers handed out stay valid
// for the interner's lifetime. Each distinct string also receives a dense
// TokenId in insertion order. All state is guarded by one mutex.
struct StringInterner {
    std::unordered_set<std::string> pool;                    // owns every interned string
    std::unordered_map<std::string, TokenId> id_by_string;   // string -> dense id
    std::vector<const std::string*> string_by_id;            // dense id -> pooled string
    mutable std::mutex m;                                    // guards all three containers

    // Intern `s`, returning a stable pointer to the pooled copy.
    const std::string* intern(const std::string &s){
        std::lock_guard<std::mutex> lk(m);

        auto [it, inserted] = pool.emplace(s);
        if (inserted){
            // First sighting: assign the next dense id.
            const TokenId id = static_cast<TokenId>(string_by_id.size());
            id_by_string.emplace(*it, id);
            string_by_id.push_back(&*it);
        }

        return &*it;
    }

    // Dense id for `s`, or TOKEN_ID_INVALID if it was never interned.
    TokenId id_of(const std::string &s) const {
        std::lock_guard<std::mutex> lk(m);

        auto it = id_by_string.find(s);
        return (it == id_by_string.end()) ? TOKEN_ID_INVALID : it->second;
    }

    // Pointer overload; tolerates null.
    TokenId id_of(const std::string *p) const {
        return p ? id_of(*p) : TOKEN_ID_INVALID;
    }

    // Reverse lookup; nullptr for out-of-range ids.
    const std::string* ptr_from_id(TokenId id) const {
        std::lock_guard<std::mutex> lk(m);

        return (id < string_by_id.size()) ? string_by_id[(size_t)id] : nullptr;
    }

    // Number of distinct interned strings.
    std::size_t size() const {
        std::lock_guard<std::mutex> lk(m);
        return string_by_id.size();
    }
};
|
|
|
| static inline void bitset_set(std::uint64_t *bits, std::size_t words, TokenId id){
|
| if (id == TOKEN_ID_INVALID) return;
|
| const std::size_t idx = static_cast<std::size_t>(id >> 6);
|
| if (idx >= words) return;
|
| bits[idx] |= (1ULL << (id & 63u));
|
| }
|
|
|
| static inline std::size_t bitset_count(const std::uint64_t *bits, std::size_t words){
|
| std::size_t total = 0;
|
| for (std::size_t i = 0; i < words; ++i) total += popcount_u64(bits[i]);
|
| return total;
|
| }
|
|
|
| static inline std::size_t bitset_intersection_count(const std::uint64_t *a, const std::uint64_t *b, std::size_t words){
|
| std::size_t total = 0;
|
| for (std::size_t i = 0; i < words; ++i) total += popcount_u64(a[i] & b[i]);
|
| return total;
|
| }
|
|
|
|
|
| static void build_def_tokens_cache(){
|
| global_def_tokens_cache.clear();
|
| global_pos_cache.clear();
|
|
|
| global_def_tokens_cache.reserve(global_dictionary_entries.size());
|
| global_pos_cache.reserve(global_dictionary_entries.size());
|
|
|
| for (const auto &entry : global_dictionary_entries){
|
| const std::string key = normalize_dictionary_key(entry.word);
|
| if (key.empty()) continue;
|
|
|
| std::string pos = normalize_pos_tag(entry.pos);
|
| if (!pos.empty()) global_pos_cache[key].push_back(std::move(pos));
|
|
|
| auto &defs = global_def_tokens_cache[key];
|
| for (const auto &def : entry.definitions){
|
| auto toks = tokenize_non_alnum(def);
|
| defs.insert(defs.end(), toks.begin(), toks.end());
|
| }
|
| }
|
|
|
| for (auto &pr : global_def_tokens_cache){
|
| auto &v = pr.second;
|
| std::sort(v.begin(), v.end());
|
| v.erase(std::unique(v.begin(), v.end()), v.end());
|
| }
|
|
|
| for (auto &pr : global_pos_cache){
|
| auto &v = pr.second;
|
| std::sort(v.begin(), v.end());
|
| v.erase(std::unique(v.begin(), v.end()), v.end());
|
| }
|
| }
|
|
|
|
|
// Hash/equality on interned-pointer identity: StrPtr values from one
// StringInterner are unique per distinct string, so pointer equality is
// string equality and hashing the pointer value suffices.
struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };

// Ordered list of distinct successor tokens observed after a given token.
using NextSet = std::vector<StrPtr>;
|
|
|
// Learned bigram graph plus a lazily built "definition closure" index.
// Thread-safety: `m` guards next/next_key_index, `def_m` guards
// def_index/def_depth, and the interner carries its own internal lock.
struct KnowledgeBase {
    StringInterner interner;
    // token -> ordered set of distinct tokens seen immediately after it
    std::unordered_map<StrPtr, NextSet, PtrHash, PtrEq> next;
    // raw string -> interned key pointer, so string lookups skip the interner
    std::unordered_map<std::string, StrPtr> next_key_index;
    mutable std::mutex m;

    // token -> all dictionary-definition tokens reachable within def_depth
    // expansion levels (definitions of definitions, transitively).
    std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
    mutable std::mutex def_m;
    int def_depth = 0;   // 0 disables definition expansion entirely

    // Record the bigram k -> v (both already interned); duplicates dropped.
    void add_pair_interned(StrPtr k, StrPtr v){
        std::lock_guard<std::mutex> lk(m);
        next_key_index.emplace(*k, k);
        auto &vec = next[k];
        // Linear dedup keeps first-seen order; degree is assumed small.
        for (auto p : vec) if (p == v) return;
        vec.push_back(v);
    }

    // Change the expansion depth; a different depth invalidates the cache.
    void set_def_depth(int D){
        std::lock_guard<std::mutex> lk(def_m);
        if (D != def_depth){
            def_index.clear();
            def_depth = D;
        }
    }

    // Compute (once) the definition closure of *wp, up to def_depth levels of
    // BFS over global_def_tokens_cache, and cache it in def_index.
    // NOTE(review): def_depth is read here without holding def_m, and two
    // threads may both miss the fast path and compute the closure redundantly;
    // the final emplace keeps only one copy, so the result stays consistent.
    void ensure_def_for_interned(StrPtr wp){
        if (wp == nullptr) return;
        if (def_depth <= 0) return;

        {
            // Fast path: closure already cached.
            std::lock_guard<std::mutex> lk(def_m);
            if (def_index.find(wp) != def_index.end()) return;
        }

        std::unordered_set<StrPtr, PtrHash, PtrEq> acc;   // every token seen so far
        std::vector<StrPtr> frontier;                     // tokens new in the last level

        // Level 1: the word's own definition tokens.
        const std::string start_key = normalize_dictionary_key(*wp);
        if (!start_key.empty()){
            auto it_def = global_def_tokens_cache.find(start_key);
            if (it_def != global_def_tokens_cache.end()){
                for (const auto &tok : it_def->second){
                    StrPtr tp = interner.intern(tok);
                    if (acc.insert(tp).second) frontier.push_back(tp);
                }
            }
        }

        // Levels 2..def_depth: expand definitions of the previous level's tokens.
        for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
            std::vector<StrPtr> next_frontier;

            for (StrPtr w : frontier){
                const std::string key = normalize_dictionary_key(*w);
                if (key.empty()) continue;

                auto it2 = global_def_tokens_cache.find(key);
                if (it2 == global_def_tokens_cache.end()) continue;

                for (const auto &tok : it2->second){
                    StrPtr tp = interner.intern(tok);
                    if (acc.insert(tp).second) next_frontier.push_back(tp);
                }
            }

            frontier.swap(next_frontier);
        }

        std::vector<StrPtr> out;
        out.reserve(acc.size());
        for (StrPtr p : acc) out.push_back(p);

        {
            std::lock_guard<std::mutex> lk(def_m);
            def_index.emplace(wp, std::move(out));
        }
    }

    // Public entry: intern both tokens, make sure their definition closures
    // exist, then record the bigram.
    void add_pair(const std::string &k, const std::string &v){
        StrPtr kp = interner.intern(k);
        StrPtr vp = interner.intern(v);

        ensure_def_for_interned(kp);
        ensure_def_for_interned(vp);
        add_pair_interned(kp, vp);
    }

    // Successors of a raw (uninterned) token, or nullopt when unknown.
    std::optional<NextSet> lookup_by_string(const std::string &k) const {
        std::lock_guard<std::mutex> lk(m);
        auto kit = next_key_index.find(k);
        if (kit == next_key_index.end()) return std::nullopt;
        auto it = next.find(kit->second);
        if (it == next.end()) return std::nullopt;
        return it->second;
    }

    // Successors of an interned token, or nullopt when unknown.
    std::optional<NextSet> lookup_by_ptr(StrPtr k) const {
        std::lock_guard<std::mutex> lk(m);
        auto it = next.find(k);
        if (it == next.end()) return std::nullopt;
        return it->second;
    }
};
|
|
|
| static std::vector<StrPtr>
|
| intern_tokens(KnowledgeBase &kb, const std::vector<std::string> &tokens)
|
| {
|
| std::vector<StrPtr> out;
|
| out.reserve(tokens.size());
|
| for (const auto &t : tokens) out.push_back(kb.interner.intern(t));
|
| return out;
|
| }
|
|
|
|
|
|
|
// Bounds guard used throughout the hand-rolled JSON scanner below.
static inline bool json_valid_index(size_t i, size_t n){
    return i < n;
}
|
|
|
// Decode a double-quoted JSON string starting at text[i] (which must be '"'),
// advancing `i` past the closing quote. Throws std::runtime_error when
// text[i] is not an opening quote; an unterminated string yields the content
// read so far (lenient, matching the rest of this JSON-lite scanner).
// Fix: the JSON escapes \r, \b and \f were previously decoded as the literal
// letters 'r'/'b'/'f', corrupting any definition text containing them; they
// now decode to their control characters. \" \\ and \/ pass through via the
// default case, as before. \uXXXX is still NOT decoded (it yields 'u') —
// acceptable for the embedded ASCII dictionary, not for general JSON.
static std::string parse_quoted_string(const std::string &text, size_t &i){
    std::string out;
    if (i >= text.size() || text[i] != '"') throw std::runtime_error("expected '\"'");
    ++i;
    while (i < text.size()){
        char c = text[i++];
        if (c == '"') break;
        if (c == '\\'){
            if (i >= text.size()) break;   // dangling backslash at end of input
            char e = text[i++];
            switch (e){
                case 'n': out.push_back('\n'); break;
                case 't': out.push_back('\t'); break;
                case 'r': out.push_back('\r'); break;
                case 'b': out.push_back('\b'); break;
                case 'f': out.push_back('\f'); break;
                default:  out.push_back(e);    break;   // \" \\ \/ and anything else
            }
        } else {
            out.push_back(c);
        }
    }
    return out;
}
|
|
|
// Advance `i` past any run of whitespace in `s`.
static void skip_spaces(const std::string &s, size_t &i){
    while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i])) != 0){
        ++i;
    }
}
|
|
|
|
|
| static void skip_json_value(const std::string &s, size_t &i);
|
|
|
// Parse a JSON array of strings at text[i] (must be '['), advancing `i`
// past the closing ']'. Non-string elements are skipped rather than
// collected; returns an empty vector when text[i] is not '[' or input ends.
static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
    std::vector<std::string> out;
    if (!json_valid_index(i, text.size()) || text[i] != '[') return out;

    ++i;
    while (true){
        skip_spaces(text, i);
        if (!json_valid_index(i, text.size())) break;   // truncated input: stop
        if (text[i] == ']'){ ++i; break; }

        if (text[i] == '"') out.push_back(parse_quoted_string(text, i));
        else skip_json_value(text, i);                  // tolerate non-string elements

        skip_spaces(text, i);
        if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
        if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
    }

    return out;
}
|
|
|
// Skip one JSON value of any type starting at s[i] — string, array, object,
// or bare scalar (number / true / false / null) — advancing `i` just past
// it. Tolerant of malformed input; the only throw path is
// parse_quoted_string's missing-opening-quote check.
static void skip_json_value(const std::string &s, size_t &i){
    skip_spaces(s, i);
    if (!json_valid_index(i, s.size())) return;

    // String: decode and discard.
    if (s[i] == '"'){
        (void)parse_quoted_string(s, i);
        return;
    }

    // Array: recursively skip elements until the matching ']'.
    if (s[i] == '['){
        ++i;
        while (json_valid_index(i, s.size())){
            skip_spaces(s, i);
            if (!json_valid_index(i, s.size())) break;
            if (s[i] == ']'){ ++i; break; }
            skip_json_value(s, i);
            skip_spaces(s, i);
            if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
            if (json_valid_index(i, s.size()) && s[i] == ']'){ ++i; break; }
        }
        return;
    }

    // Object: skip "key": value pairs until the matching '}'.
    if (s[i] == '{'){
        ++i;
        while (json_valid_index(i, s.size())){
            skip_spaces(s, i);
            if (!json_valid_index(i, s.size())) break;
            if (s[i] == '}'){ ++i; break; }
            if (s[i] == '"'){
                (void)parse_quoted_string(s, i);
                skip_spaces(s, i);
                if (json_valid_index(i, s.size()) && s[i] == ':') ++i;
                skip_json_value(s, i);
                skip_spaces(s, i);
                if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
                if (json_valid_index(i, s.size()) && s[i] == '}'){ ++i; break; }
            } else {
                ++i;   // unexpected character inside object: step over it
            }
        }
        return;
    }

    // Bare scalar: consume until a delimiter or whitespace.
    while (json_valid_index(i, s.size())){
        char c = s[i];
        if (c == ',' || c == ']' || c == '}' || is_space(c)) break;
        ++i;
    }
}
|
|
|
// Parse the embedded dictionary blob — a JSON array of objects with "word",
// "pos" and "definitions" fields — into DictionaryEntry records. Unknown
// fields and non-object array elements are skipped; entries with an empty
// word are dropped. Returns empty for an empty or non-array blob.
static std::vector<DictionaryEntry> parse_dictionary_json(){
    std::vector<DictionaryEntry> dict;
    if (dictionary_json_len == 0) return dict;

    // Copy the embedded byte array into a std::string for the scanner.
    std::string text;
    text.reserve(dictionary_json_len);
    for (unsigned int b = 0; b < dictionary_json_len; ++b){
        text.push_back(static_cast<char>(dictionary_json[b]));
    }

    size_t i = 0;
    skip_spaces(text, i);
    if (!json_valid_index(i, text.size()) || text[i] != '[') return dict;
    ++i;

    while (true){
        skip_spaces(text, i);
        if (!json_valid_index(i, text.size())) break;
        if (text[i] == ']'){ ++i; break; }
        if (text[i] != '{'){
            // Non-object array element: skip it and advance to the next one.
            skip_json_value(text, i);
            skip_spaces(text, i);
            if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
            if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
            continue;
        }

        ++i;
        DictionaryEntry entry;

        // Scan this object's fields, keeping the three we understand.
        while (true){
            skip_spaces(text, i);
            if (!json_valid_index(i, text.size())) break;
            if (text[i] == '}'){ ++i; break; }

            std::string field = parse_quoted_string(text, i);
            skip_spaces(text, i);
            if (!json_valid_index(i, text.size()) || text[i] != ':') break;
            ++i;
            skip_spaces(text, i);

            if (field == "word"){
                entry.word = parse_quoted_string(text, i);
            } else if (field == "pos"){
                entry.pos = parse_quoted_string(text, i);
            } else if (field == "definitions"){
                entry.definitions = parse_json_string_array(text, i);
            } else {
                skip_json_value(text, i);   // unknown field: discard its value
            }

            skip_spaces(text, i);
            if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
            if (json_valid_index(i, text.size()) && text[i] == '}'){ ++i; break; }
        }

        if (!entry.word.empty()) dict.push_back(std::move(entry));

        skip_spaces(text, i);
        if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
        if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
    }

    return dict;
}
|
|
|
// Pick the best next token out of `cands`: Jaccard similarity between each
// candidate's token+definition bitset and the aggregate bitset of the prompt
// plus response so far, adjusted by grammar heuristics and a repetition
// penalty. Ties break toward the lexicographically smaller token so the
// result is deterministic. Returns "" when `cands` is empty.
// NOTE(review): reads def_index without holding its mutex — callers must
// ensure no concurrent mutation during generation.
static std::string best_candidate_by_similarity(
    const NextSet &cands,
    const std::vector<StrPtr> &prompt_ptrs,
    const std::vector<StrPtr> &resp_ptrs,
    const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
    const StringInterner &interner,
    const std::unordered_map<std::string,int> &recent_counts,
    double repeat_penalty,
    const std::string &context_tok)
{
    if (cands.empty()) return std::string();
    if (cands.size() == 1) return *cands[0];

    // One bit per interned token id, rounded up to whole 64-bit words.
    const std::size_t words = std::max<std::size_t>(1, (interner.size() + 63u) / 64u);

    // Aggregate context bitset: every prompt/response token plus its cached
    // definition-closure tokens.
    std::vector<std::uint64_t> agg_words(words, 0ULL);

    auto add_token_and_defs = [&](StrPtr t){
        TokenId tid = interner.id_of(t);
        if (tid != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, tid);

        auto it = def_index.find(t);
        if (it != def_index.end()){
            for (StrPtr d : it->second){
                TokenId did = interner.id_of(d);
                if (did != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, did);
            }
        }
    };

    for (StrPtr t : prompt_ptrs) add_token_and_defs(t);
    for (StrPtr t : resp_ptrs) add_token_and_defs(t);

    const std::size_t agg_count = bitset_count(agg_words.data(), words);

    // Per-candidate bitsets (token + definition closure), stored as rows of
    // one flat array so the whole batch can be mapped to a device at once.
    const std::size_t M = cands.size();
    std::vector<std::uint64_t> cand_words(M * words, 0ULL);
    std::vector<std::size_t> cand_counts(M, 0);

    for (std::size_t i = 0; i < M; ++i){
        std::uint64_t *row = cand_words.data() + i * words;
        const StrPtr cand = cands[i];

        TokenId cid = interner.id_of(cand);
        if (cid != TOKEN_ID_INVALID) bitset_set(row, words, cid);

        auto it = def_index.find(cand);
        if (it != def_index.end()){
            for (StrPtr d : it->second){
                TokenId did = interner.id_of(d);
                if (did != TOKEN_ID_INVALID) bitset_set(row, words, did);
            }
        }

        cand_counts[i] = bitset_count(row, words);
    }

    std::vector<double> scores(M, 0.0);

    // Offload only when opted in at build time, a device exists, and the
    // batch is large enough to pay for the transfer; otherwise host threads
    // (or serial execution without OpenMP).
#if defined(_OPENMP) && defined(CHATIPC_ENABLE_OMP_TARGET)
    const bool use_target = (omp_get_num_devices() > 0) && (M >= 256);
#else
    const bool use_target = false;
#endif

    if (use_target){
        std::uint64_t *agg_ptr = agg_words.data();
        std::uint64_t *cand_ptr = cand_words.data();
        std::size_t *count_ptr = cand_counts.data();
        double *score_ptr = scores.data();
        const std::size_t cand_words_total = cand_words.size();

        #pragma omp target data map(to: agg_ptr[0:words], cand_ptr[0:cand_words_total], count_ptr[0:M]) map(from: score_ptr[0:M])
        {
            #pragma omp target teams distribute parallel for
            for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
                const std::uint64_t *row = cand_ptr + static_cast<std::size_t>(i) * words;

                std::size_t inter = 0;
                for (std::size_t w = 0; w < words; ++w){
                    inter += popcount_u64(agg_ptr[w] & row[w]);
                }

                // Jaccard: |A^B| / |AvB|, with |AvB| = |A| + |B| - |A^B|.
                const std::size_t uni = agg_count + count_ptr[(size_t)i] - inter;
                score_ptr[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
            }
        }
    } else {
#ifdef _OPENMP
        #pragma omp parallel for schedule(static)
#endif
        for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
            const std::uint64_t *row = cand_words.data() + static_cast<std::size_t>(i) * words;

            const std::size_t inter = bitset_intersection_count(agg_words.data(), row, words);
            const std::size_t uni = agg_count + cand_counts[(size_t)i] - inter;
            scores[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
        }
    }

    // Sequential argmax over similarity + grammar bonus - repetition penalty.
    double best = -1e9;
    std::string best_tok;

    for (std::size_t i = 0; i < M; ++i){
        const std::string &tok = *cands[i];
        const std::string tok_key = normalize_dictionary_key(tok);
        const std::string count_key = tok_key.empty() ? tok : tok_key;

        auto rc_it = recent_counts.find(count_key);
        const int cnt = (rc_it == recent_counts.end()) ? 0 : rc_it->second;

        const double adjusted =
            scores[i] +
            english_rule_bonus(context_tok, tok) -
            repeat_penalty * static_cast<double>(cnt);

        if (adjusted > best || (adjusted == best && tok < best_tok)){
            best = adjusted;
            best_tok = tok;
        }
    }

    return best_tok;
}
|
|
|
// Greedily generate up to `maxlen` tokens. Each step: look up successors of
// the previously emitted token, falling back to the rightmost prompt token
// that has any successors, then pick the best candidate by similarity plus
// grammar heuristics. Every chosen token is echoed to stdout (followed by a
// space and a flush) as it is picked. Generation stops early on: no
// candidates, a sole candidate that was already emitted, an empty choice,
// or an A-B-A-B oscillation.
// NOTE(review): `ssize_t` is POSIX, not standard C++ — consider ptrdiff_t
// for MSVC portability.
static std::vector<std::string> construct_response(KnowledgeBase &kb,
                                                   const std::vector<std::string> &prompt_toks,
                                                   size_t maxlen,
                                                   double repeat_penalty)
{
    std::vector<std::string> resp;
    if (prompt_toks.empty() || maxlen == 0) return resp;

    auto prompt_ptrs = intern_tokens(kb, prompt_toks);
    std::vector<StrPtr> resp_ptrs;
    std::unordered_map<std::string,int> recent_counts;   // normalized token -> times emitted

    // Detects the pattern ... X Y X with candidate == Y (a 2-cycle about to repeat).
    auto would_create_2_cycle = [&](const std::string &cand) -> bool {
        if (resp.size() < 3) return false;
        return normalize_dictionary_key(cand) == normalize_dictionary_key(resp[resp.size() - 2]) &&
               normalize_dictionary_key(resp.back()) == normalize_dictionary_key(resp[resp.size() - 3]);
    };

    std::string last_printed;

    for (size_t step = 0; step < maxlen; ++step){
        NextSet candidates;
        bool found = false;
        std::string context_tok;

        if (step == 0){
            // Seed from the rightmost prompt token with known successors.
            for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
                auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
                if (opt){
                    candidates = *opt;
                    found = true;
                    context_tok = prompt_toks[(size_t)p];
                    break;
                }
            }
        } else {
            // Continue from the last emitted token; if it has no successors,
            // re-seed from the prompt as above.
            auto opt = kb.lookup_by_string(last_printed);
            if (opt){
                candidates = *opt;
                found = true;
                context_tok = last_printed;
            } else {
                for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
                    auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
                    if (opt2){
                        candidates = *opt2;
                        found = true;
                        context_tok = prompt_toks[(size_t)p];
                        break;
                    }
                }
            }
        }

        if (!found || candidates.empty()) break;

        // Single candidate: emit it directly, but never the same token twice.
        if (candidates.size() == 1){
            std::string only = *candidates[0];
            std::string only_key = normalize_dictionary_key(only);
            if (recent_counts[only_key.empty() ? only : only_key] > 0) break;

            resp.push_back(only);
            resp_ptrs.push_back(kb.interner.intern(only));
            recent_counts[only_key.empty() ? only : only_key] += 1;
            last_printed = only;
            std::cout << only << ' ' << std::flush;
            continue;
        }

        std::string chosen = best_candidate_by_similarity(
            candidates, prompt_ptrs, resp_ptrs, kb.def_index, kb.interner,
            recent_counts, repeat_penalty, context_tok
        );

        if (chosen.empty()) break;
        if (would_create_2_cycle(chosen)) break;

        resp.push_back(chosen);
        resp_ptrs.push_back(kb.interner.intern(chosen));

        std::string chosen_key = normalize_dictionary_key(chosen);
        recent_counts[chosen_key.empty() ? chosen : chosen_key] += 1;

        last_printed = chosen;
        std::cout << chosen << ' ' << std::flush;
    }

    return resp;
}
|
|
|
|
|
|
|
| static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
|
| std::ifstream ifs(fname);
|
| if (!ifs) return;
|
| std::string tok;
|
| std::string prev;
|
| bool have_prev = false;
|
| while (ifs >> tok){
|
| if (have_prev) kb.add_pair(prev, tok);
|
| prev = tok; have_prev = true;
|
| }
|
| }
|
|
|
// Learn from many files concurrently; KnowledgeBase does its own locking.
// The pragma is silently ignored when OpenMP is disabled.
static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::string> &files){
#pragma omp parallel for schedule(dynamic)
    for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
}
|
|
|
|
|
// Save-file header constants: magic value (the bytes "CHPCKSVA") and the
// format version for the binary knowledge-base snapshot below.
static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL;
static constexpr std::uint64_t KB_VERSION = 1ULL;
|
|
|
// Binary I/O primitives for the save file. Integers are written in native
// byte order, so snapshots are not portable across endianness.

static void write_u64(std::ostream &os, std::uint64_t v){
    os.write(reinterpret_cast<const char*>(&v), sizeof v);
    if (!os) throw std::runtime_error("write_u64 failed");
}

static std::uint64_t read_u64(std::istream &is){
    std::uint64_t v = 0;
    is.read(reinterpret_cast<char*>(&v), sizeof v);
    if (!is) throw std::runtime_error("read_u64 failed");
    return v;
}

// Length-prefixed string: u64 byte count, then the raw bytes.
static void write_string(std::ostream &os, const std::string &s){
    write_u64(os, static_cast<std::uint64_t>(s.size()));
    if (s.empty()) return;
    os.write(s.data(), static_cast<std::streamsize>(s.size()));
    if (!os) throw std::runtime_error("write_string failed");
}

static std::string read_string(std::istream &is){
    const std::uint64_t n = read_u64(is);
    // Reject absurd lengths from corrupt files before allocating.
    if (n > (1ULL << 30)) throw std::runtime_error("corrupt save file: string too large");

    std::string s(static_cast<size_t>(n), '\0');
    if (n != 0){
        is.read(&s[0], static_cast<std::streamsize>(n));
        if (!is) throw std::runtime_error("read_string failed");
    }
    return s;
}
|
|
|
// Serialize the knowledge base to `fname` near-atomically: write everything
// to fname.tmp, then rename it over the target (the old file is removed
// first, matching Windows rename semantics). Strings are written as a
// sorted pool and both graphs reference them by index, so the output is
// deterministic for a given content regardless of hash-table order.
// NOTE(review): reads kb.interner.pool / kb.next / kb.def_index without
// taking any of kb's mutexes — callers must ensure no concurrent mutation.
static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
    const std::string temp = fname + ".tmp";

    {
        std::ofstream ofs(temp.c_str(), std::ios::binary | std::ios::trunc);
        if (!ofs) throw std::runtime_error("cannot open temp save file");

        // Sorted string pool gives stable indices independent of hash order.
        std::vector<std::string> pool;
        pool.reserve(kb.interner.pool.size());
        for (const auto &s : kb.interner.pool) pool.push_back(s);

        std::sort(pool.begin(), pool.end());

        // string -> index within the sorted pool
        std::unordered_map<std::string, std::uint64_t> id;
        id.reserve(pool.size());
        for (std::uint64_t i = 0; i < static_cast<std::uint64_t>(pool.size()); ++i){
            id.emplace(pool[(size_t)i], i);
        }

        // Header.
        write_u64(ofs, KB_MAGIC);
        write_u64(ofs, KB_VERSION);
        write_u64(ofs, static_cast<std::uint64_t>(kb.def_depth));

        // String pool.
        write_u64(ofs, static_cast<std::uint64_t>(pool.size()));
        for (const auto &s : pool) write_string(ofs, s);

        // Bigram graph: key index, degree, then successor indices.
        write_u64(ofs, static_cast<std::uint64_t>(kb.next.size()));
        for (const auto &pr : kb.next){
            write_u64(ofs, id.at(*pr.first));
            write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
            for (StrPtr nxt : pr.second){
                write_u64(ofs, id.at(*nxt));
            }
        }

        // Definition-closure index, same layout as the graph.
        write_u64(ofs, static_cast<std::uint64_t>(kb.def_index.size()));
        for (const auto &pr : kb.def_index){
            write_u64(ofs, id.at(*pr.first));
            write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
            for (StrPtr tok : pr.second){
                write_u64(ofs, id.at(*tok));
            }
        }

        ofs.flush();
        if (!ofs) throw std::runtime_error("failed while writing temp save file");
    }

    // Commit: replace any existing file, cleaning up the temp on failure.
    std::remove(fname.c_str());
    if (std::rename(temp.c_str(), fname.c_str()) != 0){
        std::remove(temp.c_str());
        throw std::runtime_error("failed to commit save file");
    }
}
|
|
|
// Deserialize a knowledge base written by save_kb_binary.
//
// Reads, in order: header (magic, version, def_depth), the interned string
// pool, the bigram graph (kb.next), and the definition-token index
// (kb.def_index).  Every index read from the file is bounds-checked against
// the pool, and every count is capped, so a corrupt/truncated file raises
// std::runtime_error rather than over-allocating or indexing out of range.
//
// If the depth requested on the command line differs from the depth stored
// in the file, dictionary definitions are re-expanded for every known string
// at the new depth (parallelized with OpenMP when available).
//
// Throws std::runtime_error on open failure, bad magic/version, or any
// corruption detected while reading.
static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
    std::ifstream ifs(fname, std::ios::binary);
    if (!ifs) throw std::runtime_error("cannot open load file");

    // --- header: reject foreign or incompatible files up front ---
    const std::uint64_t magic = read_u64(ifs);
    if (magic != KB_MAGIC) throw std::runtime_error("bad save file magic");

    const std::uint64_t version = read_u64(ifs);
    if (version != KB_VERSION) throw std::runtime_error("unsupported save file version");

    // Depth the file was saved with; compared against cli_dict_depth below.
    const std::uint64_t file_def_depth = read_u64(ifs);

    // --- string pool ---
    const std::uint64_t N = read_u64(ifs);
    if (N > (1ULL << 26)) throw std::runtime_error("corrupt save file: pool too large");

    std::vector<std::string> strings;
    strings.reserve(static_cast<size_t>(N));

    for (std::uint64_t i = 0; i < N; ++i){
        strings.push_back(read_string(ifs));
    }

    kb.interner.pool.clear();
    kb.interner.pool.reserve(static_cast<size_t>(N));

    // Re-intern every string; ptrs[i] is the stable pointer for file id i,
    // used to resolve all the integer ids that follow.
    std::vector<StrPtr> ptrs;
    ptrs.reserve(static_cast<size_t>(N));
    for (const auto &s : strings){
        ptrs.push_back(kb.interner.intern(s));
    }

    // --- bigram graph (kb.next) ---
    const std::uint64_t E = read_u64(ifs);
    if (E > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph too large");

    {
        // kb.m guards the graph; cleared and pre-sized under the lock.
        std::lock_guard<std::mutex> lk(kb.m);
        kb.next.clear();
        kb.next_key_index.clear();
        kb.next.reserve(static_cast<size_t>(E));
        kb.next_key_index.reserve(static_cast<size_t>(E));
    }

    for (std::uint64_t i = 0; i < E; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);

        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph degree too large");

        StrPtr key_ptr = ptrs[(size_t)key_idx];
        NextSet vec;
        vec.reserve(static_cast<size_t>(M));

        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph value");
            vec.push_back(ptrs[(size_t)v_idx]);
        }

        {
            // Each entry is built outside the lock, then inserted under it.
            std::lock_guard<std::mutex> lk(kb.m);
            kb.next.emplace(key_ptr, std::move(vec));
            kb.next_key_index.emplace(*key_ptr, key_ptr);
        }
    }

    // --- definition-token index (kb.def_index) ---
    const std::uint64_t K = read_u64(ifs);
    if (K > (1ULL << 26)) throw std::runtime_error("corrupt save file: def_index too large");

    {
        // kb.def_m guards the definition index and def_depth.
        std::lock_guard<std::mutex> lk(kb.def_m);
        kb.def_index.clear();
        kb.def_index.reserve(static_cast<size_t>(K));
        kb.def_depth = static_cast<int>(file_def_depth);
    }

    for (std::uint64_t i = 0; i < K; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);

        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: def list too large");

        std::vector<StrPtr> toks;
        toks.reserve(static_cast<size_t>(M));

        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def value");
            toks.push_back(ptrs[(size_t)v_idx]);
        }

        {
            std::lock_guard<std::mutex> lk(kb.def_m);
            kb.def_index.emplace(ptrs[(size_t)key_idx], std::move(toks));
        }
    }

    // --- optional re-expansion at the CLI-requested depth ---
    if (cli_dict_depth != static_cast<int>(file_def_depth)){
        kb.set_def_depth(cli_dict_depth);

        // Collect every distinct interned string reachable via the pool or
        // the graph; each becomes a target for definition re-expansion.
        std::vector<StrPtr> targets;
        targets.reserve(ptrs.size() + kb.next.size() * 2);

        std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
        seen.reserve(ptrs.size() + kb.next.size() * 2);

        for (StrPtr p : ptrs){
            if (seen.insert(p).second) targets.push_back(p);
        }

        {
            std::lock_guard<std::mutex> lk(kb.m);
            for (const auto &pr : kb.next){
                if (seen.insert(pr.first).second) targets.push_back(pr.first);
                for (StrPtr v : pr.second){
                    if (seen.insert(v).second) targets.push_back(v);
                }
            }
        }

        // Re-expand in parallel; ensure_def_for_interned presumably does its
        // own locking (it is called concurrently here) — confirm in its
        // definition.  No-op pragma when built without OpenMP.
        #pragma omp parallel for schedule(dynamic)
        for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i){
            kb.ensure_def_for_interned(targets[(size_t)i]);
        }
    }
}
|
|
|
|
|
|
|
// Print the command-line synopsis and per-flag descriptions to stdout.
// `p` is the program name (typically argv[0]).
static void print_usage(const char *p){
    std::cout << "Usage: " << p << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";

    // One description line per flag, emitted in declaration order.
    static const char *const kFlagHelp[] = {
        "  --maxlen N Maximum number of tokens constructed in a response.\n",
        "  --save FILE Save the knowledge-base and dictionary expansions to a binary file.\n",
        "  --load-kb FILE Load a previously saved knowledge-base (and dictionary expansions) from a binary file.\n",
        "  --dict-depth D Depth of dictionary-definition expansion used during learning.\n",
        "  --learn f1 f2 ... Learn from one or more text files to update the knowledge base.\n",
        "  --repeat-penalty P Penalize repeated tokens during response generation (higher values discourage repetition).\n",
        "  --help Show command-line interface options for ChatIPC usage.\n",
    };
    for (const char *line : kFlagHelp) std::cout << line;
}
|
|
|
| int main(int argc, char **argv){
|
| size_t maxlen = 100;
|
| std::string savefile;
|
| std::string load_txt;
|
| std::string load_kb;
|
| int dict_depth = 2;
|
| double repeat_penalty = 0.7;
|
| std::vector<std::string> learn_files;
|
|
|
| for (int i=1;i<argc;++i){
|
| std::string a = argv[i];
|
| if (a=="--help"){ print_usage(argv[0]); return 0; }
|
| if (a=="--maxlen" && i+1<argc){ maxlen = std::stoul(argv[++i]); continue; }
|
| if (a=="--save" && i+1<argc){ savefile = argv[++i]; continue; }
|
| if (a=="--load-kb" && i+1<argc){ load_kb = argv[++i]; continue; }
|
| if (a=="--dict-depth" && i+1<argc){ dict_depth = std::stoi(argv[++i]); continue; }
|
| if (a=="--repeat-penalty" && i+1<argc){ repeat_penalty = std::stod(argv[++i]); continue; }
|
| if (a=="--learn"){
|
| ++i;
|
| for (; i<argc; ++i){
|
| if (!argv[i]) break;
|
| std::string s = argv[i];
|
| if (!s.empty() && s[0]=='-'){ --i; break; }
|
| learn_files.push_back(s);
|
| }
|
| continue;
|
| }
|
| learn_files.push_back(a);
|
| }
|
|
|
| KnowledgeBase kb;
|
|
|
|
|
| global_dictionary_entries = parse_dictionary_json();
|
| build_def_tokens_cache();
|
|
|
| kb.set_def_depth(dict_depth);
|
|
|
|
|
| if (!load_kb.empty()){
|
| try { std::cerr << "Loading KB: " << load_kb << "\n";
|
| load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
|
| catch (const std::exception &e){ std::cerr << "Load KB error: " << e.what() << "\n"; }
|
| }
|
|
|
| if (!learn_files.empty()){
|
| std::cerr << "Learning from file/s (" << learn_files.size() << ") using threads=" << omp_get_max_threads() << "\n";
|
| learn_files_parallel(kb, learn_files);
|
| }
|
|
|
| std::string line;
|
| std::cout << "Ready. Enter prompts.\n";
|
| while (std::cout << "> " , std::getline(std::cin, line)){
|
| if (line.empty()){ std::cout << "\n"; continue; }
|
| auto prompt_toks = tokenize_whitespace(line);
|
| for (size_t i=1;i<prompt_toks.size();++i) kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
|
| auto resp = construct_response(kb, prompt_toks, maxlen, repeat_penalty);
|
| std::cout << "\n";
|
| if (!resp.empty()){
|
| std::vector<std::string> combined = prompt_toks;
|
| combined.insert(combined.end(), resp.begin(), resp.end());
|
| for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
|
| }
|
| if (!savefile.empty()){
|
| try { std::cerr << "Saving KB: " << savefile << "\n";
|
| save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
|
| catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
|
| }
|
| }
|
|
|
| return 0;
|
| }
|
|
|