// ChatIPC.cpp — uploaded by technician1, commit ec7d802 (verified)
// ChatIPC := Chat Incremental Pattern Constructor
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <mutex>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#else
// Single-threaded stubs so the code compiles and runs without OpenMP.
inline int omp_get_max_threads(){ return 1; }
inline int omp_get_thread_num(){ return 0; }
#endif
extern unsigned char dictionary_json[]; // provide dictionary.cpp to embed dictionary JSON bytes
extern unsigned int dictionary_json_len;
// --------------------------- Short utility functions ----------------------
// True when `c` is whitespace (cast to unsigned char avoids UB for negatives).
static inline bool is_space(char c){
    return 0 != std::isspace(static_cast<unsigned char>(c));
}
// Lower-case one character, locale-aware like std::tolower.
static inline char to_low(char c){
    const int lowered = std::tolower(static_cast<unsigned char>(c));
    return static_cast<char>(lowered);
}
// Explicit flush of an output stream (thin readability wrapper).
static inline void safe_flush(std::ostream &os){
    os.flush();
}
// NEW: dictionary model, normalization, and English-rule helpers.
// One parsed dictionary record: a head word, its part-of-speech tag string,
// and the definition strings attached to it in the embedded dictionary_json.
struct DictionaryEntry {
std::string pos;
std::string word;
std::vector<std::string> definitions;
};
// All entries parsed from the embedded dictionary (see parse_dictionary_json).
static std::vector<DictionaryEntry> global_dictionary_entries;
// normalized word -> sorted, de-duplicated definition tokens (built by build_def_tokens_cache)
static std::unordered_map<std::string, std::vector<std::string>> global_def_tokens_cache;
// normalized word -> sorted, de-duplicated normalized POS tags
static std::unordered_map<std::string, std::vector<std::string>> global_pos_cache;
// A character is part of a dictionary key if it is alphanumeric,
// an apostrophe, or a hyphen.
static inline bool is_word_char_for_key(char c){
    if (c == '\'' || c == '-') return true;
    return std::isalnum(static_cast<unsigned char>(c)) != 0;
}
// Canonical dictionary key: strip non-word characters (anything that is not
// alphanumeric, '\'' or '-') from both ends, then lower-case what remains.
static std::string normalize_dictionary_key(const std::string &s){
    const auto word_char = [](char c){
        const unsigned char uc = static_cast<unsigned char>(c);
        return std::isalnum(uc) != 0 || c == '\'' || c == '-';
    };
    size_t first = 0;
    size_t last = s.size();
    while (first < last && !word_char(s[first])) ++first;
    while (last > first && !word_char(s[last - 1])) --last;
    std::string key;
    key.reserve(last - first);
    for (size_t pos = first; pos < last; ++pos){
        key.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(s[pos]))));
    }
    return key;
}
// Reduce a POS tag to its alphabetic characters, lower-cased ("N." -> "n").
static std::string normalize_pos_tag(const std::string &s){
    std::string tag;
    tag.reserve(s.size());
    for (const char c : s){
        const unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc) == 0) continue;
        tag.push_back(static_cast<char>(std::tolower(uc)));
    }
    return tag;
}
// Coarse part-of-speech classes used by the English heuristics.
enum class PosClass {
Unknown,
Noun,
Verb,
Adj,
Adv,
Pron,
Prep,
Conj,
Det,
Num,
Interj
};
// Map a normalized tag string to its PosClass (PosClass::Unknown if unrecognized).
static PosClass pos_class_from_tag(const std::string &tag){
    static const std::unordered_map<std::string, PosClass> table = {
        {"n", PosClass::Noun}, {"noun", PosClass::Noun},
        {"v", PosClass::Verb}, {"verb", PosClass::Verb},
        {"part", PosClass::Verb}, {"participle", PosClass::Verb}, {"p", PosClass::Verb},
        {"a", PosClass::Adj}, {"adj", PosClass::Adj}, {"adjective", PosClass::Adj},
        {"adv", PosClass::Adv}, {"adverb", PosClass::Adv},
        {"pron", PosClass::Pron}, {"pronoun", PosClass::Pron},
        {"prep", PosClass::Prep}, {"preposition", PosClass::Prep},
        {"conj", PosClass::Conj}, {"conjunction", PosClass::Conj},
        {"art", PosClass::Det}, {"article", PosClass::Det},
        {"det", PosClass::Det}, {"determiner", PosClass::Det},
        {"num", PosClass::Num}, {"number", PosClass::Num},
        {"interj", PosClass::Interj}, {"interjection", PosClass::Interj},
    };
    const auto it = table.find(tag);
    return it == table.end() ? PosClass::Unknown : it->second;
}
// True if any tag in `tags` maps to the class `cls`.
static bool has_pos_class(const std::vector<std::string> &tags, PosClass cls){
    return std::any_of(tags.begin(), tags.end(),
                       [cls](const std::string &t){ return pos_class_from_tag(t) == cls; });
}
// Cached POS tags for a surface token; an empty list is returned when the
// token normalizes to nothing or has no dictionary entry.
static const std::vector<std::string> &dictionary_pos_for_token(const std::string &surface){
    static const std::vector<std::string> none;
    const std::string key = normalize_dictionary_key(surface);
    if (key.empty()) return none;
    const auto hit = global_pos_cache.find(key);
    if (hit == global_pos_cache.end()) return none;
    return hit->second;
}
// Case of the first alphabetic character; false when no letter exists.
static bool first_alpha_is_upper(const std::string &s){
    for (const char ch : s){
        const unsigned char uc = static_cast<unsigned char>(ch);
        if (std::isalpha(uc) == 0) continue;
        return std::isupper(uc) != 0;
    }
    return false;
}
static bool first_alpha_is_lower(const std::string &s){
    for (const char ch : s){
        const unsigned char uc = static_cast<unsigned char>(ch);
        if (std::isalpha(uc) == 0) continue;
        return std::islower(uc) != 0;
    }
    return false;
}
// A token ends a sentence when its final character is '.', '!' or '?'.
static bool is_sentence_boundary_token(const std::string &s){
    if (s.empty()) return false;
    switch (s.back()){
        case '.': case '!': case '?': return true;
        default: return false;
    }
}
// Opening punctuation that begins a bracketed or quoted span.
static bool is_open_punct_token(const std::string &s){
    static const std::unordered_set<std::string> open = {"(", "[", "{", "\"", "'"};
    return open.count(s) != 0;
}
// True when the token contains no alphanumeric character at all.
static bool is_punctuation_only_token(const std::string &s){
    if (s.empty()) return false;
    return std::none_of(s.begin(), s.end(), [](char c){
        return std::isalnum(static_cast<unsigned char>(c)) != 0;
    });
}
// Closed set of common English determiners and possessives.
static bool is_common_determiner(const std::string &s){
    static const std::unordered_set<std::string> dets = {
        "a", "an", "the", "this", "that", "these", "those",
        "my", "your", "his", "her", "its", "our", "their"};
    return dets.count(s) != 0;
}
// Closed set of common English prepositions.
// FIX: the original comparison chain tested "under" twice (a dead duplicate);
// the set form removes the redundancy and makes membership O(1).
static bool is_common_preposition(const std::string &s){
    static const std::unordered_set<std::string> preps = {
        "of", "in", "on", "at", "by", "for", "from", "with", "into", "onto",
        "about", "over", "under", "after", "before", "between", "through",
        "during", "without", "within", "across", "against", "among", "around"};
    return preps.count(s) != 0;
}
// Closed set of auxiliaries, modals, forms of "to be"/"to have"/"to do",
// plus the infinitive marker "to".
static bool is_common_aux_or_modal(const std::string &s){
    static const std::unordered_set<std::string> aux = {
        "to", "be", "am", "is", "are", "was", "were", "been", "being",
        "have", "has", "had", "do", "does", "did",
        "can", "could", "may", "might", "must", "shall", "should", "will", "would"};
    return aux.count(s) != 0;
}
// Heuristic: does `s` (expected lower-case) start with a vowel *sound*?
// Silent-h words ("hour", "honest", ...) count as vowel sounds; words whose
// spelling starts with a vowel but sound consonantal ("user", "one", "euro",
// ...) do not. Otherwise falls back to a plain a/e/i/o/u letter test.
static bool begins_with_vowel_sound(const std::string &s){
    if (s.empty()) return false;
    static const char *vowel_like[] = {"hour", "honest", "honor", "heir", "herb"};
    for (const char *prefix : vowel_like){
        if (s.rfind(prefix, 0) == 0) return true;
    }
    static const char *consonant_like[] = {"uni", "use", "user", "one", "once", "euro"};
    for (const char *prefix : consonant_like){
        if (s.rfind(prefix, 0) == 0) return false;
    }
    switch (s[0]){
        case 'a': case 'e': case 'i': case 'o': case 'u': return true;
        default: return false;
    }
}
// Additive heuristic score for emitting candidate `cand` right after
// `context_tok`. Encodes shallow English rules:
//   * capitalization at sentence start, lower-case elsewhere;
//   * "a"/"an" agreement with the candidate's approximate vowel sound;
//   * determiner -> noun-ish, preposition -> noun-ish, aux/modal -> verb;
//   * mild preferences around punctuation and sentence endings.
// Magnitudes are small hand-tuned constants (|bonus| well below 1) so the
// Jaccard similarity computed by the caller still dominates.
static double english_rule_bonus(const std::string &context_tok, const std::string &cand){
const std::string ctx_key = normalize_dictionary_key(context_tok);
const std::string cand_key = normalize_dictionary_key(cand);
const auto &ctx_tags = dictionary_pos_for_token(context_tok);
const auto &cand_tags = dictionary_pos_for_token(cand);
// Empty context, a token ending '.'/'!'/'?', or opening punctuation all
// count as "start of a sentence" for capitalization purposes.
const bool sentence_start = context_tok.empty() || is_sentence_boundary_token(context_tok) || is_open_punct_token(context_tok);
// "Noun-ish" deliberately includes adjectives, pronouns and numbers — all
// can legally follow a determiner or preposition.
const bool cand_nounish = has_pos_class(cand_tags, PosClass::Noun) ||
has_pos_class(cand_tags, PosClass::Adj) ||
has_pos_class(cand_tags, PosClass::Pron) ||
has_pos_class(cand_tags, PosClass::Num);
const bool cand_verbish = has_pos_class(cand_tags, PosClass::Verb);
const bool cand_advish = has_pos_class(cand_tags, PosClass::Adv);
const bool cand_prepish = has_pos_class(cand_tags, PosClass::Prep);
const bool cand_detish = has_pos_class(cand_tags, PosClass::Det);
double bonus = 0.0;
// Capitalization rule (skipped for punctuation-only candidates, whose key is empty).
if (!cand_key.empty()){
if (sentence_start){
bonus += first_alpha_is_upper(cand) ? 0.22 : -0.08;
} else if (first_alpha_is_upper(cand)){
bonus -= 0.03;
}
}
// Article agreement: reward "an" + vowel sound and "a" + consonant sound.
if (ctx_key == "a" || ctx_key == "an"){
const bool vowel = begins_with_vowel_sound(cand_key.empty() ? cand : cand_key);
bonus += ((ctx_key == "an") == vowel) ? 0.28 : -0.18;
}
// Context classification combines dictionary POS tags with closed word lists.
const bool ctx_det = has_pos_class(ctx_tags, PosClass::Det) || is_common_determiner(ctx_key);
const bool ctx_prep = has_pos_class(ctx_tags, PosClass::Prep) || is_common_preposition(ctx_key);
const bool ctx_aux = is_common_aux_or_modal(ctx_key);
if (ctx_det){
if (cand_nounish) bonus += 0.20;
if (cand_verbish || cand_advish || cand_prepish) bonus -= 0.08;
}
if (ctx_prep){
if (cand_nounish) bonus += 0.16;
if (cand_verbish) bonus -= 0.06;
}
if (ctx_aux){
if (cand_verbish) bonus += 0.18;
if (cand_detish) bonus -= 0.04;
}
// Subject -> verb: mild push toward a verb after a pronoun or noun.
if (has_pos_class(ctx_tags, PosClass::Pron) || has_pos_class(ctx_tags, PosClass::Noun)){
if (cand_verbish) bonus += 0.05;
}
// After a clause-internal pause (',' ';' ':') prefer a lower-case continuation.
if (!context_tok.empty() && (context_tok.back() == ',' || context_tok.back() == ';' || context_tok.back() == ':')){
if (!cand.empty() && first_alpha_is_lower(cand)) bonus += 0.04;
}
// Punctuation-only candidates: discouraged right at sentence start, mildly
// encouraged immediately after a word character.
if (is_punctuation_only_token(cand)){
if (sentence_start) bonus -= 0.05;
else if (!context_tok.empty() && std::isalnum(static_cast<unsigned char>(context_tok.back())) != 0) bonus += 0.03;
}
// Small global bias toward eventually ending sentences.
if (is_sentence_boundary_token(cand)) bonus += 0.06;
return bonus;
}
// Tokenize by whitespace
// Split `s` on runs of whitespace via stream extraction.
static std::vector<std::string> tokenize_whitespace(const std::string &s){
    std::istringstream stream(s);
    return std::vector<std::string>(std::istream_iterator<std::string>(stream),
                                    std::istream_iterator<std::string>());
}
// Tokenize by non-alphanumeric characters (for definitions)
// Split on every character that is not alphanumeric, '-' or '\'', lower-casing
// the kept characters (used for definition text).
static std::vector<std::string> tokenize_non_alnum(const std::string &s){
    std::vector<std::string> tokens;
    std::string word;
    for (const char ch : s){
        const unsigned char uc = static_cast<unsigned char>(ch);
        const bool keep = std::isalnum(uc) != 0 || ch == '-' || ch == '\'';
        if (keep){
            word.push_back(static_cast<char>(std::tolower(uc)));
        } else if (!word.empty()){
            tokens.push_back(word);
            word.clear();
        }
    }
    if (!word.empty()) tokens.push_back(word);
    return tokens;
}
// --------------------------- String interning (short methods) --------------
// Interned strings are addressed both by stable pointer and by dense 32-bit id.
using StrPtr = const std::string*;
using TokenId = std::uint32_t;
// Sentinel for "not interned"; keeps 0 usable as a real id.
static constexpr TokenId TOKEN_ID_INVALID = 0xFFFFFFFFu;
// Portable 64-bit population count (MSVC intrinsic vs GCC/Clang builtin).
static inline std::size_t popcount_u64(std::uint64_t x){
#if defined(_MSC_VER)
return static_cast<std::size_t>(__popcnt64(x));
#else
return static_cast<std::size_t>(__builtin_popcountll(static_cast<unsigned long long>(x)));
#endif
}
// Thread-safe string intern pool. `pool` is node-based (unordered_set of
// std::string), so element addresses are stable and interned StrPtr values
// stay valid for the interner's lifetime. Each distinct string also gets a
// dense TokenId assigned in insertion order.
struct StringInterner {
std::unordered_set<std::string> pool; // owns the canonical string storage
std::unordered_map<std::string, TokenId> id_by_string; // string -> dense id
std::vector<const std::string*> string_by_id; // dense id -> pooled string
mutable std::mutex m; // guards all three containers
// Intern `s`; assigns the next TokenId on first sight. Returns the stable pointer.
const std::string* intern(const std::string &s){
std::lock_guard<std::mutex> lk(m);
auto [it, inserted] = pool.emplace(s);
if (inserted){
const TokenId id = static_cast<TokenId>(string_by_id.size());
id_by_string.emplace(*it, id);
string_by_id.push_back(&*it);
}
return &*it;
}
// TokenId of `s`, or TOKEN_ID_INVALID if never interned.
TokenId id_of(const std::string &s) const {
std::lock_guard<std::mutex> lk(m);
auto it = id_by_string.find(s);
return (it == id_by_string.end()) ? TOKEN_ID_INVALID : it->second;
}
// Pointer overload; null pointers map to TOKEN_ID_INVALID.
TokenId id_of(const std::string *p) const {
return p ? id_of(*p) : TOKEN_ID_INVALID;
}
// Inverse mapping; nullptr for out-of-range ids.
const std::string* ptr_from_id(TokenId id) const {
std::lock_guard<std::mutex> lk(m);
return (id < string_by_id.size()) ? string_by_id[(size_t)id] : nullptr;
}
// Number of distinct interned strings.
std::size_t size() const {
std::lock_guard<std::mutex> lk(m);
return string_by_id.size();
}
};
// Set bit `id` in a bitset of `words` 64-bit words; invalid or out-of-range
// ids are ignored.
static inline void bitset_set(std::uint64_t *bits, std::size_t words, TokenId id){
    if (id == TOKEN_ID_INVALID) return;
    const std::size_t word = static_cast<std::size_t>(id >> 6);
    if (word >= words) return;
    const std::uint64_t mask = 1ULL << (id & 63u);
    bits[word] |= mask;
}
// Total number of set bits.
static inline std::size_t bitset_count(const std::uint64_t *bits, std::size_t words){
    std::size_t n = 0;
    for (const std::uint64_t *p = bits; p != bits + words; ++p){
        n += popcount_u64(*p);
    }
    return n;
}
// Number of bits set in both bitsets (intersection cardinality).
static inline std::size_t bitset_intersection_count(const std::uint64_t *a, const std::uint64_t *b, std::size_t words){
    std::size_t n = 0;
    for (std::size_t w = 0; w < words; ++w){
        n += popcount_u64(a[w] & b[w]);
    }
    return n;
}
// ---------- Global parsed dictionary (populated once in main) ----------
// Build global_def_tokens_cache and global_pos_cache from the parsed
// dictionary entries: for each normalized head word, collect its normalized
// POS tags and the tokenized text of all its definitions, then sort+dedupe
// both lists so later lookups behave like sets. Clears any prior state;
// call after global_dictionary_entries has been populated.
static void build_def_tokens_cache(){
global_def_tokens_cache.clear();
global_pos_cache.clear();
global_def_tokens_cache.reserve(global_dictionary_entries.size());
global_pos_cache.reserve(global_dictionary_entries.size());
for (const auto &entry : global_dictionary_entries){
const std::string key = normalize_dictionary_key(entry.word);
if (key.empty()) continue;
std::string pos = normalize_pos_tag(entry.pos);
if (!pos.empty()) global_pos_cache[key].push_back(std::move(pos));
auto &defs = global_def_tokens_cache[key];
for (const auto &def : entry.definitions){
auto toks = tokenize_non_alnum(def);
defs.insert(defs.end(), toks.begin(), toks.end());
}
}
// sort + unique turns the accumulated lists into canonical sets
for (auto &pr : global_def_tokens_cache){
auto &v = pr.second;
std::sort(v.begin(), v.end());
v.erase(std::unique(v.begin(), v.end()), v.end());
}
for (auto &pr : global_pos_cache){
auto &v = pr.second;
std::sort(v.begin(), v.end());
v.erase(std::unique(v.begin(), v.end()), v.end());
}
}
// --------------------------- Knowledge base (short methods) --------------
// Identity hashing/equality on interned pointers: pointers are canonical
// because StringInterner returns one stable address per distinct string.
struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
// Ordered list of successor tokens observed after a given token.
using NextSet = std::vector<StrPtr>;
// Central model: a token-successor graph over interned strings plus a cache
// of dictionary-definition expansions per token.
// Locking: `m` guards next/next_key_index; `def_m` guards def_index and
// def_depth; the interner carries its own lock.
struct KnowledgeBase {
StringInterner interner;
// token -> ordered list of distinct successors seen after it
std::unordered_map<StrPtr, NextSet, PtrHash, PtrEq> next;
// string value -> canonical interned key pointer, for string-based lookups
std::unordered_map<std::string, StrPtr> next_key_index;
mutable std::mutex m;
// def-index: for each interned word pointer -> list of interned tokens (definition expansion)
std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
mutable std::mutex def_m;
int def_depth = 0; // BFS depth for definition expansion; <= 0 disables it
// Record the ordered pair k -> v; an already-known successor is a no-op.
void add_pair_interned(StrPtr k, StrPtr v){
std::lock_guard<std::mutex> lk(m);
next_key_index.emplace(*k, k);
auto &vec = next[k];
for (auto p : vec) if (p == v) return; // duplicate successor
vec.push_back(v);
}
// set def depth; if changed, drop previously computed def expansions
void set_def_depth(int D){
std::lock_guard<std::mutex> lk(def_m);
if (D != def_depth){
def_index.clear();
def_depth = D;
}
}
// Compute (once) and cache the depth-limited BFS expansion of `wp` through
// dictionary definitions: level 1 is the word's own definition tokens,
// deeper levels are definitions of those tokens.
// NOTE(review): def_depth is read below without holding def_m; this is only
// racy if set_def_depth runs concurrently — confirm callers serialize that.
void ensure_def_for_interned(StrPtr wp){
if (wp == nullptr) return;
if (def_depth <= 0) return;
{
std::lock_guard<std::mutex> lk(def_m);
if (def_index.find(wp) != def_index.end()) return; // already cached
}
std::unordered_set<StrPtr, PtrHash, PtrEq> acc;
std::vector<StrPtr> frontier;
const std::string start_key = normalize_dictionary_key(*wp);
if (!start_key.empty()){
auto it_def = global_def_tokens_cache.find(start_key);
if (it_def != global_def_tokens_cache.end()){
for (const auto &tok : it_def->second){
StrPtr tp = interner.intern(tok);
if (acc.insert(tp).second) frontier.push_back(tp);
}
}
}
// breadth-first over definitions-of-definitions, up to def_depth levels
for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
std::vector<StrPtr> next_frontier;
for (StrPtr w : frontier){
const std::string key = normalize_dictionary_key(*w);
if (key.empty()) continue;
auto it2 = global_def_tokens_cache.find(key);
if (it2 == global_def_tokens_cache.end()) continue;
for (const auto &tok : it2->second){
StrPtr tp = interner.intern(tok);
if (acc.insert(tp).second) next_frontier.push_back(tp);
}
}
frontier.swap(next_frontier);
}
std::vector<StrPtr> out;
out.reserve(acc.size());
for (StrPtr p : acc) out.push_back(p);
{
// Two threads may race to compute the same word; emplace keeps the first.
std::lock_guard<std::mutex> lk(def_m);
def_index.emplace(wp, std::move(out));
}
}
// existing public add_pair but now ensure def-expansion is built immediately
void add_pair(const std::string &k, const std::string &v){
StrPtr kp = interner.intern(k);
StrPtr vp = interner.intern(v);
// ensure definition expansion for both words as soon as they are seen
ensure_def_for_interned(kp);
ensure_def_for_interned(vp);
add_pair_interned(kp, vp);
}
// Successors of the token whose text is `k` (nullopt if unseen).
// Returns a copy so callers can use it without holding the lock.
std::optional<NextSet> lookup_by_string(const std::string &k) const {
std::lock_guard<std::mutex> lk(m);
auto kit = next_key_index.find(k);
if (kit == next_key_index.end()) return std::nullopt;
auto it = next.find(kit->second);
if (it == next.end()) return std::nullopt;
return it->second;
}
// Successors of an interned token pointer (nullopt if none recorded).
std::optional<NextSet> lookup_by_ptr(StrPtr k) const {
std::lock_guard<std::mutex> lk(m);
auto it = next.find(k);
if (it == next.end()) return std::nullopt;
return it->second;
}
};
// Intern every token in `tokens`, preserving order.
static std::vector<StrPtr>
intern_tokens(KnowledgeBase &kb, const std::vector<std::string> &tokens)
{
    std::vector<StrPtr> interned;
    interned.reserve(tokens.size());
    for (const std::string &tok : tokens){
        interned.push_back(kb.interner.intern(tok));
    }
    return interned;
}
// --------------------------- Small JSON parse helpers ----------------------
// Bounds guard shared by the tiny JSON helpers below.
static inline bool json_valid_index(size_t i, size_t n){ return i < n; }
// Parse a double-quoted JSON string starting at text[i] (must be '"');
// advances `i` past the closing quote. Throws std::runtime_error when the
// opening quote is missing; an unterminated string stops at end-of-input.
// FIX: the original decoded only \n and \t, turning \r, \b, \f into the
// literal letters 'r', 'b', 'f'. Those common escapes are now decoded.
// \" and \\ (and any other escaped char, including an undecoded \uXXXX)
// fall through as the escaped character itself, as before.
static std::string parse_quoted_string(const std::string &text, size_t &i){
    std::string out;
    if (!json_valid_index(i, text.size()) || text[i] != '"') throw std::runtime_error("expected '\"'");
    ++i;
    while (json_valid_index(i, text.size())){
        char c = text[i++];
        if (c == '"') break;
        if (c == '\\'){
            if (!json_valid_index(i, text.size())) break;
            char e = text[i++];
            switch (e){
                case 'n': out.push_back('\n'); break;
                case 't': out.push_back('\t'); break;
                case 'r': out.push_back('\r'); break;
                case 'b': out.push_back('\b'); break;
                case 'f': out.push_back('\f'); break;
                default: out.push_back(e); break; // covers \" \\ \/ and others
            }
        } else out.push_back(c);
    }
    return out;
}
// Advance `i` past any run of whitespace in `s`.
static void skip_spaces(const std::string &s, size_t &i){
    const size_t n = s.size();
    while (i < n && std::isspace(static_cast<unsigned char>(s[i])) != 0) ++i;
}
// Very small JSON-like parser tailored to dictionary_json structure
static void skip_json_value(const std::string &s, size_t &i);
// Parse a JSON array of strings at text[i]; elements that are not strings
// are skipped. Returns the collected strings; `i` ends past the closing ']'.
static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
    std::vector<std::string> items;
    const size_t n = text.size();
    if (i >= n || text[i] != '[') return items;
    ++i;
    for (;;){
        skip_spaces(text, i);
        if (i >= n) break;
        if (text[i] == ']'){ ++i; break; }
        if (text[i] == '"'){
            items.push_back(parse_quoted_string(text, i));
        } else {
            skip_json_value(text, i);
        }
        skip_spaces(text, i);
        if (i < n && text[i] == ','){ ++i; continue; }
        if (i < n && text[i] == ']'){ ++i; break; }
    }
    return items;
}
// Skip one JSON value (string, array, object, or bare scalar) starting at
// s[i], advancing `i` just past it. Arrays and objects use explicit loops
// with recursion for nested values; bare scalars are skipped up to the next
// delimiter. Tolerant of malformed input: at worst it stops at end-of-input.
static void skip_json_value(const std::string &s, size_t &i){
skip_spaces(s, i);
if (!json_valid_index(i, s.size())) return;
if (s[i] == '"'){
(void)parse_quoted_string(s, i);
return;
}
if (s[i] == '['){
++i;
// skip each element, honoring ',' separators until ']'
while (json_valid_index(i, s.size())){
skip_spaces(s, i);
if (!json_valid_index(i, s.size())) break;
if (s[i] == ']'){ ++i; break; }
skip_json_value(s, i);
skip_spaces(s, i);
if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
if (json_valid_index(i, s.size()) && s[i] == ']'){ ++i; break; }
}
return;
}
if (s[i] == '{'){
++i;
// skip each "key": value member until '}'
while (json_valid_index(i, s.size())){
skip_spaces(s, i);
if (!json_valid_index(i, s.size())) break;
if (s[i] == '}'){ ++i; break; }
if (s[i] == '"'){
(void)parse_quoted_string(s, i);
skip_spaces(s, i);
if (json_valid_index(i, s.size()) && s[i] == ':') ++i;
skip_json_value(s, i);
skip_spaces(s, i);
if (json_valid_index(i, s.size()) && s[i] == ','){ ++i; continue; }
if (json_valid_index(i, s.size()) && s[i] == '}'){ ++i; break; }
} else {
// unexpected character inside object: advance one char to make progress
++i;
}
}
return;
}
// bare scalar (number / true / false / null): skip to the next delimiter
while (json_valid_index(i, s.size())){
char c = s[i];
if (c == ',' || c == ']' || c == '}' || is_space(c)) break;
++i;
}
}
// Decode the embedded dictionary_json byte blob into DictionaryEntry records.
// Expects a top-level JSON array of objects with "word", "pos" and
// "definitions" fields; unknown fields and non-object array elements are
// skipped. Entries without a "word" are dropped. Tolerant of minor
// malformations — at worst it stops early and returns what it parsed.
static std::vector<DictionaryEntry> parse_dictionary_json(){
std::vector<DictionaryEntry> dict;
if (dictionary_json_len == 0) return dict;
// copy the raw bytes into a std::string for the index-based helpers
std::string text;
text.reserve(dictionary_json_len);
for (unsigned int b = 0; b < dictionary_json_len; ++b){
text.push_back(static_cast<char>(dictionary_json[b]));
}
size_t i = 0;
skip_spaces(text, i);
if (!json_valid_index(i, text.size()) || text[i] != '[') return dict;
++i;
while (true){
skip_spaces(text, i);
if (!json_valid_index(i, text.size())) break;
if (text[i] == ']'){ ++i; break; }
// non-object element: skip it and advance to the next separator
if (text[i] != '{'){
skip_json_value(text, i);
skip_spaces(text, i);
if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
continue;
}
++i;
DictionaryEntry entry;
// field loop for one dictionary object
while (true){
skip_spaces(text, i);
if (!json_valid_index(i, text.size())) break;
if (text[i] == '}'){ ++i; break; }
std::string field = parse_quoted_string(text, i);
skip_spaces(text, i);
if (!json_valid_index(i, text.size()) || text[i] != ':') break;
++i;
skip_spaces(text, i);
if (field == "word"){
entry.word = parse_quoted_string(text, i);
} else if (field == "pos"){
entry.pos = parse_quoted_string(text, i);
} else if (field == "definitions"){
entry.definitions = parse_json_string_array(text, i);
} else {
skip_json_value(text, i);
}
skip_spaces(text, i);
if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
if (json_valid_index(i, text.size()) && text[i] == '}'){ ++i; break; }
}
if (!entry.word.empty()) dict.push_back(std::move(entry));
skip_spaces(text, i);
if (json_valid_index(i, text.size()) && text[i] == ','){ ++i; continue; }
if (json_valid_index(i, text.size()) && text[i] == ']'){ ++i; break; }
}
return dict;
}
// Choose the best next token from `cands`. Score per candidate:
//   Jaccard(candidate token + its definition expansion,
//           prompt tokens + response-so-far + their expansions)
//   + english_rule_bonus(context_tok, candidate)
//   - repeat_penalty * times-already-used.
// Sets are bitsets indexed by interner TokenId, packed into 64-bit words.
// Ties break toward the lexicographically smaller token (deterministic).
static std::string best_candidate_by_similarity(
const NextSet &cands,
const std::vector<StrPtr> &prompt_ptrs,
const std::vector<StrPtr> &resp_ptrs,
const std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> &def_index,
const StringInterner &interner,
const std::unordered_map<std::string,int> &recent_counts,
double repeat_penalty,
const std::string &context_tok)
{
if (cands.empty()) return std::string();
if (cands.size() == 1) return *cands[0];
// one bit per interned string, rounded up to whole 64-bit words
const std::size_t words = std::max<std::size_t>(1, (interner.size() + 63u) / 64u);
std::vector<std::uint64_t> agg_words(words, 0ULL);
// set the bit for a token and for every token in its definition expansion
auto add_token_and_defs = [&](StrPtr t){
TokenId tid = interner.id_of(t);
if (tid != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, tid);
auto it = def_index.find(t);
if (it != def_index.end()){
for (StrPtr d : it->second){
TokenId did = interner.id_of(d);
if (did != TOKEN_ID_INVALID) bitset_set(agg_words.data(), words, did);
}
}
};
for (StrPtr t : prompt_ptrs) add_token_and_defs(t);
for (StrPtr t : resp_ptrs) add_token_and_defs(t);
const std::size_t agg_count = bitset_count(agg_words.data(), words);
const std::size_t M = cands.size();
// row-major matrix: one bitset row per candidate, plus its popcount
std::vector<std::uint64_t> cand_words(M * words, 0ULL);
std::vector<std::size_t> cand_counts(M, 0);
for (std::size_t i = 0; i < M; ++i){
std::uint64_t *row = cand_words.data() + i * words;
const StrPtr cand = cands[i];
TokenId cid = interner.id_of(cand);
if (cid != TOKEN_ID_INVALID) bitset_set(row, words, cid);
auto it = def_index.find(cand);
if (it != def_index.end()){
for (StrPtr d : it->second){
TokenId did = interner.id_of(d);
if (did != TOKEN_ID_INVALID) bitset_set(row, words, did);
}
}
cand_counts[i] = bitset_count(row, words);
}
std::vector<double> scores(M, 0.0);
// Optional accelerator offload, compiled in only with
// CHATIPC_ENABLE_OMP_TARGET and only used for large candidate sets.
#if defined(_OPENMP) && defined(CHATIPC_ENABLE_OMP_TARGET)
const bool use_target = (omp_get_num_devices() > 0) && (M >= 256);
#else
const bool use_target = false;
#endif
if (use_target){
std::uint64_t *agg_ptr = agg_words.data();
std::uint64_t *cand_ptr = cand_words.data();
std::size_t *count_ptr = cand_counts.data();
double *score_ptr = scores.data();
const std::size_t cand_words_total = cand_words.size();
#pragma omp target data map(to: agg_ptr[0:words], cand_ptr[0:cand_words_total], count_ptr[0:M]) map(from: score_ptr[0:M])
{
#pragma omp target teams distribute parallel for
for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
const std::uint64_t *row = cand_ptr + static_cast<std::size_t>(i) * words;
std::size_t inter = 0;
for (std::size_t w = 0; w < words; ++w){
inter += popcount_u64(agg_ptr[w] & row[w]);
}
// Jaccard = |A ∩ B| / |A ∪ B|, with |A ∪ B| = |A| + |B| - |A ∩ B|
const std::size_t uni = agg_count + count_ptr[(size_t)i] - inter;
score_ptr[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
}
}
} else {
// CPU path: same Jaccard computation, optionally OpenMP-parallel
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(M); ++i){
const std::uint64_t *row = cand_words.data() + static_cast<std::size_t>(i) * words;
const std::size_t inter = bitset_intersection_count(agg_words.data(), row, words);
const std::size_t uni = agg_count + cand_counts[(size_t)i] - inter;
scores[(size_t)i] = uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
}
}
// final selection: similarity + grammar bonus - repetition penalty
double best = -1e9;
std::string best_tok;
for (std::size_t i = 0; i < M; ++i){
const std::string &tok = *cands[i];
const std::string tok_key = normalize_dictionary_key(tok);
const std::string count_key = tok_key.empty() ? tok : tok_key;
auto rc_it = recent_counts.find(count_key);
const int cnt = (rc_it == recent_counts.end()) ? 0 : rc_it->second;
const double adjusted =
scores[i] +
english_rule_bonus(context_tok, tok) -
repeat_penalty * static_cast<double>(cnt);
if (adjusted > best || (adjusted == best && tok < best_tok)){
best = adjusted;
best_tok = tok;
}
}
return best_tok;
}
// Greedily construct a response of at most `maxlen` tokens. Each step scores
// the successor set of the last emitted token (falling back to the latest
// prompt token that has known successors) via best_candidate_by_similarity.
// Chosen tokens are streamed to stdout as they are produced. Generation
// stops when no candidates exist, a sole candidate would repeat, or emitting
// the choice would complete an A-B-A-B 2-cycle. `repeat_penalty` scales the
// per-use penalty applied to a token's normalized key.
// NOTE(review): `ssize_t` is POSIX, not standard C++ — an MSVC build would
// need ptrdiff_t here; confirm target platforms.
static std::vector<std::string> construct_response(KnowledgeBase &kb,
const std::vector<std::string> &prompt_toks,
size_t maxlen,
double repeat_penalty)
{
std::vector<std::string> resp;
if (prompt_toks.empty() || maxlen == 0) return resp;
auto prompt_ptrs = intern_tokens(kb, prompt_toks);
std::vector<StrPtr> resp_ptrs;
std::unordered_map<std::string,int> recent_counts;
// would appending `cand` produce the pattern X Y X Y at the tail?
auto would_create_2_cycle = [&](const std::string &cand) -> bool {
if (resp.size() < 3) return false;
return normalize_dictionary_key(cand) == normalize_dictionary_key(resp[resp.size() - 2]) &&
normalize_dictionary_key(resp.back()) == normalize_dictionary_key(resp[resp.size() - 3]);
};
std::string last_printed;
for (size_t step = 0; step < maxlen; ++step){
NextSet candidates;
bool found = false;
std::string context_tok;
if (step == 0){
// seed from the latest prompt token that has known successors
for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
if (opt){
candidates = *opt;
found = true;
context_tok = prompt_toks[(size_t)p];
break;
}
}
} else {
// continue from the last emitted token; fall back to the prompt
auto opt = kb.lookup_by_string(last_printed);
if (opt){
candidates = *opt;
found = true;
context_tok = last_printed;
} else {
for (ssize_t p = static_cast<ssize_t>(prompt_toks.size()) - 1; p >= 0; --p){
auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
if (opt2){
candidates = *opt2;
found = true;
context_tok = prompt_toks[(size_t)p];
break;
}
}
}
}
if (!found || candidates.empty()) break;
// single candidate: emit it unless its key was already used once
if (candidates.size() == 1){
std::string only = *candidates[0];
std::string only_key = normalize_dictionary_key(only);
if (recent_counts[only_key.empty() ? only : only_key] > 0) break;
resp.push_back(only);
resp_ptrs.push_back(kb.interner.intern(only));
recent_counts[only_key.empty() ? only : only_key] += 1;
last_printed = only;
std::cout << only << ' ' << std::flush;
continue;
}
std::string chosen = best_candidate_by_similarity(
candidates, prompt_ptrs, resp_ptrs, kb.def_index, kb.interner,
recent_counts, repeat_penalty, context_tok
);
if (chosen.empty()) break;
if (would_create_2_cycle(chosen)) break;
resp.push_back(chosen);
resp_ptrs.push_back(kb.interner.intern(chosen));
std::string chosen_key = normalize_dictionary_key(chosen);
recent_counts[chosen_key.empty() ? chosen : chosen_key] += 1;
last_printed = chosen;
std::cout << chosen << ' ' << std::flush;
}
return resp;
}
// --------------------------- Learning from files (short) -------------------
// Learn adjacent-token pairs from a whitespace-tokenized text file.
// Missing or unreadable files are silently skipped (best effort).
static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
    std::ifstream in(fname);
    if (!in) return;
    std::string prev_tok;
    if (!(in >> prev_tok)) return;
    std::string cur;
    while (in >> cur){
        kb.add_pair(prev_tok, cur);
        prev_tok = cur;
    }
}
// Learn from many files; iterations are independent, so OpenMP may run them
// concurrently (KnowledgeBase performs its own locking).
static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::string> &files){
    const ptrdiff_t total = static_cast<ptrdiff_t>(files.size());
#pragma omp parallel for schedule(dynamic)
    for (ptrdiff_t idx = 0; idx < total; ++idx){
        learn_from_file(kb, files[static_cast<size_t>(idx)]);
    }
}
// --------------------------- Serialization (binary, versioned) --------------
static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL; // "CHPCKSVA"
static constexpr std::uint64_t KB_VERSION = 1ULL;
// Write one native-endian u64; throws std::runtime_error on stream failure.
static void write_u64(std::ostream &os, std::uint64_t v){
    os.write(reinterpret_cast<const char*>(&v), sizeof v);
    if (!os) throw std::runtime_error("write_u64 failed");
}
// Read one u64 written by write_u64; throws on stream failure.
static std::uint64_t read_u64(std::istream &is){
    std::uint64_t v = 0;
    is.read(reinterpret_cast<char*>(&v), sizeof v);
    if (!is) throw std::runtime_error("read_u64 failed");
    return v;
}
// Length-prefixed string: u64 byte count, then the raw bytes.
static void write_string(std::ostream &os, const std::string &s){
    write_u64(os, static_cast<std::uint64_t>(s.size()));
    if (s.empty()) return;
    os.write(s.data(), static_cast<std::streamsize>(s.size()));
    if (!os) throw std::runtime_error("write_string failed");
}
// Inverse of write_string; rejects absurd lengths to bound allocation.
static std::string read_string(std::istream &is){
    const std::uint64_t len = read_u64(is);
    if (len > (1ULL << 30)) throw std::runtime_error("corrupt save file: string too large");
    std::string s(static_cast<size_t>(len), '\0');
    if (len != 0){
        is.read(&s[0], static_cast<std::streamsize>(len));
        if (!is) throw std::runtime_error("read_string failed");
    }
    return s;
}
static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
const std::string temp = fname + ".tmp";
{
std::ofstream ofs(temp.c_str(), std::ios::binary | std::ios::trunc);
if (!ofs) throw std::runtime_error("cannot open temp save file");
std::vector<std::string> pool;
pool.reserve(kb.interner.pool.size());
for (const auto &s : kb.interner.pool) pool.push_back(s);
std::sort(pool.begin(), pool.end());
std::unordered_map<std::string, std::uint64_t> id;
id.reserve(pool.size());
for (std::uint64_t i = 0; i < static_cast<std::uint64_t>(pool.size()); ++i){
id.emplace(pool[(size_t)i], i);
}
write_u64(ofs, KB_MAGIC);
write_u64(ofs, KB_VERSION);
write_u64(ofs, static_cast<std::uint64_t>(kb.def_depth));
write_u64(ofs, static_cast<std::uint64_t>(pool.size()));
for (const auto &s : pool) write_string(ofs, s);
write_u64(ofs, static_cast<std::uint64_t>(kb.next.size()));
for (const auto &pr : kb.next){
write_u64(ofs, id.at(*pr.first));
write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
for (StrPtr nxt : pr.second){
write_u64(ofs, id.at(*nxt));
}
}
write_u64(ofs, static_cast<std::uint64_t>(kb.def_index.size()));
for (const auto &pr : kb.def_index){
write_u64(ofs, id.at(*pr.first));
write_u64(ofs, static_cast<std::uint64_t>(pr.second.size()));
for (StrPtr tok : pr.second){
write_u64(ofs, id.at(*tok));
}
}
ofs.flush();
if (!ofs) throw std::runtime_error("failed while writing temp save file");
}
std::remove(fname.c_str());
if (std::rename(temp.c_str(), fname.c_str()) != 0){
std::remove(temp.c_str());
throw std::runtime_error("failed to commit save file");
}
}
// Load a knowledge base saved by save_kb_binary. Validates magic/version and
// sanity-bounds every count before allocating. Rebuilds the interner, the
// successor graph and the def_index from the file. If `cli_dict_depth`
// differs from the depth stored in the file, all definition expansions are
// re-derived from the currently embedded dictionary at the requested depth.
// Throws std::runtime_error on any I/O or consistency failure.
static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
    std::ifstream ifs(fname, std::ios::binary);
    if (!ifs) throw std::runtime_error("cannot open load file");
    const std::uint64_t magic = read_u64(ifs);
    if (magic != KB_MAGIC) throw std::runtime_error("bad save file magic");
    const std::uint64_t version = read_u64(ifs);
    if (version != KB_VERSION) throw std::runtime_error("unsupported save file version");
    const std::uint64_t file_def_depth = read_u64(ifs);
    const std::uint64_t N = read_u64(ifs);
    if (N > (1ULL << 26)) throw std::runtime_error("corrupt save file: pool too large");
    std::vector<std::string> strings;
    strings.reserve(static_cast<size_t>(N));
    for (std::uint64_t i = 0; i < N; ++i){
        strings.push_back(read_string(ifs));
    }
    {
        // BUG FIX: the original cleared only `pool`, leaving id_by_string and
        // string_by_id holding stale ids and dangling pointers when loading
        // over an already-populated interner. Reset all three together, under
        // the interner's own lock (intern() below re-takes it per call).
        std::lock_guard<std::mutex> lk(kb.interner.m);
        kb.interner.pool.clear();
        kb.interner.id_by_string.clear();
        kb.interner.string_by_id.clear();
        kb.interner.pool.reserve(static_cast<size_t>(N));
    }
    // Re-intern in file order so ptrs[i] is the pooled string for file id i.
    std::vector<StrPtr> ptrs;
    ptrs.reserve(static_cast<size_t>(N));
    for (const auto &s : strings){
        ptrs.push_back(kb.interner.intern(s));
    }
    // Rebuild the successor graph.
    const std::uint64_t E = read_u64(ifs);
    if (E > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph too large");
    {
        std::lock_guard<std::mutex> lk(kb.m);
        kb.next.clear();
        kb.next_key_index.clear();
        kb.next.reserve(static_cast<size_t>(E));
        kb.next_key_index.reserve(static_cast<size_t>(E));
    }
    for (std::uint64_t i = 0; i < E; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);
        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: graph degree too large");
        StrPtr key_ptr = ptrs[(size_t)key_idx];
        NextSet vec;
        vec.reserve(static_cast<size_t>(M));
        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad graph value");
            vec.push_back(ptrs[(size_t)v_idx]);
        }
        {
            std::lock_guard<std::mutex> lk(kb.m);
            kb.next.emplace(key_ptr, std::move(vec));
            kb.next_key_index.emplace(*key_ptr, key_ptr);
        }
    }
    // Rebuild def_index from file.
    const std::uint64_t K = read_u64(ifs);
    if (K > (1ULL << 26)) throw std::runtime_error("corrupt save file: def_index too large");
    {
        std::lock_guard<std::mutex> lk(kb.def_m);
        kb.def_index.clear();
        kb.def_index.reserve(static_cast<size_t>(K));
        kb.def_depth = static_cast<int>(file_def_depth);
    }
    for (std::uint64_t i = 0; i < K; ++i){
        const std::uint64_t key_idx = read_u64(ifs);
        const std::uint64_t M = read_u64(ifs);
        if (key_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def key");
        if (M > (1ULL << 26)) throw std::runtime_error("corrupt save file: def list too large");
        std::vector<StrPtr> toks;
        toks.reserve(static_cast<size_t>(M));
        for (std::uint64_t j = 0; j < M; ++j){
            const std::uint64_t v_idx = read_u64(ifs);
            if (v_idx >= ptrs.size()) throw std::runtime_error("corrupt save file: bad def value");
            toks.push_back(ptrs[(size_t)v_idx]);
        }
        {
            std::lock_guard<std::mutex> lk(kb.def_m);
            kb.def_index.emplace(ptrs[(size_t)key_idx], std::move(toks));
        }
    }
    // If the caller asks for a different dict depth, recompute with the current embedded dictionary.
    if (cli_dict_depth != static_cast<int>(file_def_depth)){
        kb.set_def_depth(cli_dict_depth);
        std::vector<StrPtr> targets;
        targets.reserve(ptrs.size() + kb.next.size() * 2);
        std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
        seen.reserve(ptrs.size() + kb.next.size() * 2);
        for (StrPtr p : ptrs){
            if (seen.insert(p).second) targets.push_back(p);
        }
        {
            std::lock_guard<std::mutex> lk(kb.m);
            for (const auto &pr : kb.next){
                if (seen.insert(pr.first).second) targets.push_back(pr.first);
                for (StrPtr v : pr.second){
                    if (seen.insert(v).second) targets.push_back(v);
                }
            }
        }
        // Expansion of each target is independent; KB locks internally.
#pragma omp parallel for schedule(dynamic)
        for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i){
            kb.ensure_def_for_interned(targets[(size_t)i]);
        }
    }
}
// --------------------------- CLI + Interactive loop (shorters) -----------
// Print the command-line synopsis and a per-option description to stdout.
// `prog` is the program name to show in the "Usage:" line (usually argv[0]).
static void print_usage(const char *prog){
    // Per-option help lines, emitted verbatim after the synopsis.
    static const char *const kOptionLines[] = {
        " --maxlen N Maximum number of tokens constructed in a response.\n",
        " --save FILE Save the knowledge-base and dictionary expansions to a binary file.\n",
        " --load-kb FILE Load a previously saved knowledge-base (and dictionary expansions) from a binary file.\n",
        " --dict-depth D Depth of dictionary-definition expansion used during learning.\n",
        " --learn f1 f2 ... Learn from one or more text files to update the knowledge base.\n",
        " --repeat-penalty P Penalize repeated tokens during response generation (higher values discourage repetition).\n",
        " --help Show command-line interface options for ChatIPC usage.\n",
    };
    std::cout << "Usage: " << prog
              << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D]"
                 " [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
    for (const char *line : kOptionLines) std::cout << line;
}
// Program entry point: parse CLI options, initialize the knowledge base from
// the embedded dictionary (optionally restoring a saved KB and/or learning
// from text files), then run the interactive prompt loop.
//
// Exit codes: 0 on normal termination or --help; 2 on a malformed numeric
// option value.
int main(int argc, char **argv){
    size_t maxlen = 100;            // cap on tokens per generated response
    std::string savefile;           // if set, KB is re-saved after every prompt
    std::string load_kb;            // if set, KB is restored from this binary file
    int dict_depth = 2;             // dictionary-definition expansion depth
    double repeat_penalty = 0.7;    // default λ: penalty for repeated tokens
    std::vector<std::string> learn_files;
    for (int i = 1; i < argc; ++i){
        const std::string a = argv[i];
        if (a == "--help"){ print_usage(argv[0]); return 0; }
        // Numeric options: std::stoul/stoi/stod throw std::invalid_argument or
        // std::out_of_range on malformed input; report a diagnostic instead of
        // letting the exception escape main and terminate the process.
        try {
            if (a == "--maxlen" && i+1 < argc){ maxlen = std::stoul(argv[++i]); continue; }
            if (a == "--save" && i+1 < argc){ savefile = argv[++i]; continue; }
            if (a == "--load-kb" && i+1 < argc){ load_kb = argv[++i]; continue; }
            if (a == "--dict-depth" && i+1 < argc){ dict_depth = std::stoi(argv[++i]); continue; }
            if (a == "--repeat-penalty" && i+1 < argc){ repeat_penalty = std::stod(argv[++i]); continue; }
        } catch (const std::exception &e){
            std::cerr << "Invalid value for " << a << ": " << e.what() << "\n";
            print_usage(argv[0]);
            return 2;
        }
        if (a == "--learn"){
            // Consume following arguments as file names until the next option.
            ++i;
            for (; i < argc; ++i){
                if (!argv[i]) break;                          // defensive: argv[argc] is null
                std::string s = argv[i];
                if (!s.empty() && s[0] == '-'){ --i; break; } // next option: re-process it
                learn_files.push_back(s);
            }
            continue;
        }
        learn_files.push_back(a);   // bare (non-option) arguments are learn files
    }
    KnowledgeBase kb;
    // Parse the embedded dictionary once for use by per-word expansion.
    global_dictionary_entries = parse_dictionary_json();
    build_def_tokens_cache();
    // Set KB def depth (clears any previous expansion).
    kb.set_def_depth(dict_depth);
    if (!load_kb.empty()){
        try {
            std::cerr << "Loading KB: " << load_kb << "\n";
            load_kb_binary(kb, load_kb, dict_depth);
            std::cerr << "Loaded KB: " << load_kb << "\n";
        } catch (const std::exception &e){
            std::cerr << "Load KB error: " << e.what() << "\n";
        }
    }
    if (!learn_files.empty()){
        std::cerr << "Learning from file/s (" << learn_files.size()
                  << ") using threads=" << omp_get_max_threads() << "\n";
        learn_files_parallel(kb, learn_files);
    }
    // Interactive loop: each prompt is learned as a token chain, a response is
    // generated, and the prompt+response sequence is fed back into the KB.
    std::string line;
    std::cout << "Ready. Enter prompts.\n";
    while (std::cout << "> ", std::getline(std::cin, line)){
        if (line.empty()){ std::cout << "\n"; continue; }
        auto prompt_toks = tokenize_whitespace(line);
        for (size_t i = 1; i < prompt_toks.size(); ++i)
            kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
        // NOTE(review): the response tokens are never printed here, so
        // construct_response presumably streams them to stdout itself; the
        // bare "\n" below terminates that output — TODO confirm.
        auto resp = construct_response(kb, prompt_toks, maxlen, repeat_penalty);
        std::cout << "\n";
        if (!resp.empty()){
            // Reinforce the full prompt+response token chain.
            std::vector<std::string> combined = prompt_toks;
            combined.insert(combined.end(), resp.begin(), resp.end());
            for (size_t i = 1; i < combined.size(); ++i)
                kb.add_pair(combined[i-1], combined[i]);
        }
        if (!savefile.empty()){
            // Persist after every prompt so progress survives an abrupt exit.
            try {
                std::cerr << "Saving KB: " << savefile << "\n";
                save_kb_binary(kb, savefile);
                std::cerr << "Saved KB: " << savefile << "\n";
            } catch (const std::exception &e){
                std::cerr << "Save KB error: " << e.what() << "\n";
            }
        }
    }
    return 0;
}