"""TwentyQ: The world's smallest chat model. 2-bit quantized neural network (1988), 156 attention heads, 1200 output classes. Trained on ~75 million conversations. Context window: 20 questions. """ import hashlib import random import torch import torch.nn as nn from transformers import PreTrainedModel, GenerationMixin from .configuration_twentyq import TwentyQConfig # Answer codes: 1=No(pol0), 2=Yes(pol1), 3=Probably(pol0), 4=Doubtful(pol1), 5=Maybe(pol0), 6=Unknown POLARITY = [0, 0, 1, 0, 1, 0, 0] MATCH_BONUS = [0, 4, 4, 3, 3, 1, 0] MISS_PENALTY = [0, 4, 4, 1, 1, 0, 0] ANSWER_WORDS = { "yes": 2, "y": 2, "yeah": 2, "yep": 2, "usually": 2, "no": 1, "n": 1, "nope": 1, "nah": 1, "probably": 3, "prob": 3, "likely": 3, "doubtful": 4, "doubt": 4, "rarely": 4, "maybe": 5, "sometimes": 5, "perhaps": 5, "partly": 5, "unknown": 6, "dunno": 6, "idk": 6, "irrelevant": 6, "skip": 6, "close": -1, } AVM_WORDS = {"animal": 1, "vegetable": 2, "mineral": 3, "other": 4} class TwentyQForCausalLM(PreTrainedModel, GenerationMixin): config_class = TwentyQConfig _tied_weights_keys = [] def __init__(self, config): super().__init__(config) self.all_tied_weights_keys = {} self._dummy = nn.Parameter(torch.zeros(1), requires_grad=False) self.register_buffer("weight_matrix", torch.zeros(config.num_questions, config.num_targets, dtype=torch.uint8)) self._vocab_loaded = False def set_vocab(self, questions, targets): """Set question and target strings (called by tokenizer or manually).""" self.questions_str = list(questions) self.targets_str = list(targets) self._q_lookup = {q.lower(): i for i, q in enumerate(self.questions_str)} self._t_lookup = {t.lower(): i for i, t in enumerate(self.targets_str)} self._vocab_loaded = True def _ensure_strings(self): if self._vocab_loaded: return raise RuntimeError( "Model vocabulary not loaded. Call model.set_vocab(questions, targets) " "or load a tokenizer with vocab.json alongside the model." ) def forward(self, input_ids=None, **kwargs): # Dummy forward — the real work happens in generate() batch = input_ids.shape[0] if input_ids is not None else 1 return {"logits": torch.zeros(batch, 1, self.config.vocab_size)} def generate(self, input_ids=None, attention_mask=None, **kwargs): self._ensure_strings() # Decode input_ids to text (byte-level tokenizer, filter specials > 255) ids = input_ids[0].tolist() raw_bytes = bytes(b for b in ids if b < 256) text = raw_bytes.decode("utf-8", errors="replace") # Parse conversation and get next response answers, qnum, last_was_guess, game_over_msg, unrecognized = self._parse_conversation(text) if unrecognized: response = f"I didn't understand that. Please answer: {unrecognized}" elif game_over_msg: response = game_over_msg else: # Seed RNG from conversation for deterministic play seed = int(hashlib.md5(text.encode()).hexdigest()[:8], 16) self._rng = random.Random(seed) response = self._next_move(answers, qnum, last_was_guess) response_ids = list(response.encode("utf-8")) response_tensor = torch.tensor([response_ids], dtype=input_ids.dtype, device=input_ids.device) return torch.cat([input_ids, response_tensor], dim=1) def _parse_conversation(self, text): """Parse chat-templated text into game state.""" answers = [] # [(q_idx, ans_code, is_guess)] qnum = 0 last_was_guess = False game_over_msg = None unrecognized = None # set to hint string if last answer wasn't understood # Split into turns by [A] and [U] markers parts = text.replace("\r", "").split("\n") turns = [] for line in parts: line = line.strip() if line.startswith("[A] "): turns.append(("a", line[4:].strip())) elif line.startswith("[U] "): turns.append(("u", line[4:].strip())) # Pair up assistant/user turns i = 0 while i < len(turns): if turns[i][0] == "a": a_msg = turns[i][1] u_msg = turns[i + 1][1] if i + 1 < len(turns) and turns[i + 1][0] == "u" else None if u_msg is None: # This is the generation prompt — no user response yet break u_lower = u_msg.lower().strip().rstrip(".") if "animal, vegetable, mineral" in a_msg.lower(): # AVM question avm_code = AVM_WORDS.get(u_lower, 0) if avm_code: answers.append((0, avm_code, False)) qnum += 1 unrecognized = None else: unrecognized = "Animal, Vegetable, Mineral, or Other" i += 2 elif a_msg.lower().startswith("i'm guessing"): # Guess target_name = a_msg.split("...")[-1].strip().rstrip("?").strip() t_idx = self._t_lookup.get(target_name.lower(), -1) ans_code = ANSWER_WORDS.get(u_lower, 0) if ans_code == 2: # Yes — correct guess game_over_msg = f"I win! Got it in {qnum + 1} questions." unrecognized = None elif ans_code == 1 or ans_code == -1: # No or Close if t_idx >= 0: answers.append((t_idx, 0, True)) qnum += 1 unrecognized = None else: unrecognized = "Yes, No, or Close" i += 2 elif a_msg.lower().startswith("i win") or a_msg.lower().startswith("i'm stumped"): # Game already over game_over_msg = a_msg i += 2 else: # Regular question q_text = a_msg.rstrip("?").strip() q_idx = self._q_lookup.get(q_text.lower(), -1) ans_code = ANSWER_WORDS.get(u_lower, 0) if ans_code == -1 or ans_code == 0: unrecognized = "Yes, No, Probably, Doubtful, Maybe, or Unknown" else: unrecognized = None if q_idx >= 0: answers.append((q_idx, ans_code, False)) qnum += 1 i += 2 else: i += 1 return answers, qnum, last_was_guess, game_over_msg, unrecognized def _next_move(self, answers, qnum, last_was_guess): if qnum == 0: return "Is it Animal, Vegetable, Mineral, or Other?" if qnum >= 30: return "I'm stumped! I can't figure out what you're thinking of." nc, best_t, best_s, cidx, cscores = self._rank_targets(answers) if nc == 0: return "I'm stumped! I can't figure out what you're thinking of." should_guess = ( nc == 1 or qnum == 20 or qnum == 24 or qnum == 30 or (qnum >= 18 and nc <= 2) ) if should_guess: return f"I'm guessing... {self.targets_str[best_t]}?" q = self._select_question(answers, nc, cidx) if q < 0: return f"I'm guessing... {self.targets_str[best_t]}?" return f"{self.questions_str[q]}?" def _score(self, answer_code, target, question): w = int(self.weight_matrix[question, target]) if (POLARITY[answer_code] ^ w) & 1: s = -MISS_PENALTY[answer_code] else: s = MATCH_BONUS[answer_code] if w & 2: s *= 2 return s def _rank_targets(self, answers): max_c = 16 if len(answers) <= 10 else (8 if len(answers) <= 12 else 5) c_scores = [0] * max_c c_indices = [0] * max_c nc = 0 best_t, best_s = 0, 0 for t in range(self.config.num_targets): guessed = any(qi == t and ig for qi, _, ig in answers) if guessed: continue score = 0 skip = False for qi, ac, ig in answers: if ig or ac == 0: continue if qi != 0: score += self._score(ac, t, qi) else: for k in range(4): score += self._score(4 if k + 1 == ac else 3, t, k) if len(answers) > 7 and score < 0: skip = True break if skip or score < 0: continue score += self._rng.randint(0, 7) if nc < max_c: slot = nc nc += 1 else: min_s, slot = min((c_scores[j], j) for j in range(max_c)) if min_s >= score: continue c_scores[slot] = score c_indices[slot] = t if score > best_s: best_t, best_s = t, score thresh = best_s // 4 thresh = max(5, min(20, thresh)) cutoff = best_s - thresh pi = [(c_indices[j], c_scores[j]) for j in range(nc) if c_scores[j] > cutoff] if not pi: return 0, best_t, best_s, [], [] idx, sc = zip(*pi) return len(pi), best_t, best_s, list(idx), list(sc) def _select_question(self, answers, nc, cidx): best_s, best_q = -1000, -1 asked = {qi for qi, _, ig in answers if not ig} for q in range(4, self.config.num_questions): if q in asked: continue pos, neg = 0, 0 for t in cidx: w = int(self.weight_matrix[q, t]) wt = 3 if (w & 2) else 1 if w & 1: neg += wt else: pos += wt s = (pos * 2 - neg) if pos <= neg else (neg * 2 - pos) s += self._rng.randint(0, 7) if s > best_s: best_s, best_q = s, q return best_q def play(self, tokenizer=None): """Interactive CLI mode. Pass the tokenizer for proper chat template formatting.""" self._ensure_strings() if tokenizer is None: # Minimal fallback — construct chat text directly from .tokenization_twentyq import TwentyQTokenizer tokenizer = TwentyQTokenizer() tokenizer.chat_template = ( "{% if messages[0]['role'] == 'system' %}{{ messages[0]['content'] }}\n" "{% set loop_messages = messages[1:] %}{% else %}" "{% set loop_messages = messages %}{% endif %}" "{% for message in loop_messages %}" "{% if message['role'] == 'assistant' %}[A] {{ message['content'] }}\n" "{% elif message['role'] == 'user' %}[U] {{ message['content'] }}\n" "{% endif %}{% endfor %}" "{% if add_generation_prompt %}[A] {% endif %}" ) messages = [ {"role": "system", "content": "Think of something and I'll try to guess it in 20 questions."}, ] print("\n Think of something...\n") input(" Press Enter when ready... ") while True: text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) ids = tokenizer.encode(text, return_tensors="pt") out = self.generate(ids) response = tokenizer.decode(out[0, ids.shape[1]:].tolist()) messages.append({"role": "assistant", "content": response}) print(f"\n > {response}") if "I win" in response or "stumped" in response: return if "Animal, Vegetable, Mineral" in response: hint = "(Animal/Vegetable/Mineral/Other)" elif "guessing" in response.lower(): hint = "(Yes/No/Close)" else: hint = "(Yes/No/Probably/Doubtful/Maybe/Unknown)" reply = input(f" {hint}: ").strip() if not reply: return messages.append({"role": "user", "content": reply})