"""Interactive chat / smoke-test driver for pretrained RubiNet HSSM checkpoints.

Loads either a RubiNet ``HierarchicalSSM`` or an ``HSSMV2LM`` checkpoint
(auto-detected from the checkpoint's embedded config), pairs it with either a
local JSON vocab tokenizer or a HuggingFace tokenizer, and generates cleaned
English replies with heavy anti-repetition / anti-artifact filtering.

NOTE(review): this file was recovered from a whitespace-mangled paste in which
angle-bracket token literals (e.g. ``<pad>``) had been stripped to ``""`` and
one function header was lost; reconstructed spans are marked with
``NOTE(review)`` comments below — confirm against the original source.
"""

import argparse
import json
import re
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

try:
    from tokenizers import Tokenizer as HFTokenizer
except ImportError:
    # Fast tokenizer backend is optional; fall back to whitespace splitting.
    HFTokenizer = None

DEFAULT_CHECKPOINT = r"D:\Downloads\hssm_fineweb_edu_final.pt"
DEFAULT_TOKENIZER = r"D:\Downloads\simple_tokenizer_20k.json"
RUBINET_HSSM_PATH = r"C:\Users\ASUS\.anaconda"

# Make the project modules importable before importing them.
sys.path.append(RUBINET_HSSM_PATH)

from RubiNet_HSSM import HierarchicalSSM
from hssm_v2_gpu_pretrain import HSSMV2Config, HSSMV2LM

# Force UTF-8 console output (mainly for Windows) so decoded tokens containing
# non-ASCII characters never crash ``print``.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")


class CompatibleTokenizer:
    """Tokenizer loaded from a local JSON vocab file.

    Uses the ``tokenizers`` backend when available; otherwise falls back to
    naive whitespace splitting against the raw vocab map.
    """

    def __init__(self, tokenizer_path: str):
        path = Path(tokenizer_path)
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        self.backend = None
        if HFTokenizer is not None:
            try:
                self.backend = HFTokenizer.from_file(str(path))
            except Exception:
                # Best-effort: fall back to the plain vocab below.
                self.backend = None

        # Accept both HF-tokenizers layout ({"model": {"vocab": ...}}) and a
        # flat {"vocab": ...} layout.
        if "model" in data and isinstance(data["model"], dict) and "vocab" in data["model"]:
            vocab = data["model"]["vocab"]
        elif "vocab" in data:
            vocab = data["vocab"]
        else:
            raise ValueError(f"Unsupported tokenizer format: {path}")

        self.vocab = {str(token): int(idx) for token, idx in vocab.items()}
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        # NOTE(review): the angle-bracket candidates were stripped from the
        # recovered source; "<pad>"/"<unk>" reconstructed — confirm they match
        # the actual vocab's special tokens.
        self.pad_token_id = self._resolve_token_id(["<pad>", "[PAD]"], fallback=0)
        self.unk_token_id = self._resolve_token_id(["<unk>", "[UNK]"], fallback=3)
        print(f"[TOKENIZER] Loaded vocab tokenizer - Vocab: {self.vocab_size:,}")

    def _resolve_token_id(self, candidates, fallback: int):
        """Return the first candidate token's id, or ``fallback`` if none exist."""
        for token in candidates:
            token_id = self.vocab.get(token)
            if token_id is not None:
                return token_id
        return fallback

    def encode(self, text, max_length=128):
        """Encode ``text`` to a fixed-length id list, right-padded with pad ids."""
        if self.backend is not None:
            ids = self.backend.encode(text).ids[:max_length]
        else:
            words = text.split()
            ids = [self.vocab.get(word, self.unk_token_id) for word in words][:max_length]
        if len(ids) < max_length:
            ids += [self.pad_token_id] * (max_length - len(ids))
        return ids

    def decode(self, ids):
        """Decode ids to text, dropping pad tokens first."""
        filtered = [int(i) for i in ids if int(i) != self.pad_token_id]
        if self.backend is not None:
            return self.backend.decode(filtered, skip_special_tokens=False)
        return " ".join(self.id_to_token.get(i, "") for i in filtered)


class HFTokenizerAdapter:
    """Adapter giving a HuggingFace tokenizer the same surface as CompatibleTokenizer."""

    def __init__(self, tokenizer_name: str):
        self.backend = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
        if self.backend.pad_token is None:
            # Some models (e.g. GPT-2 style) ship without a pad token.
            self.backend.pad_token = self.backend.eos_token or self.backend.unk_token
        self.vocab = self.backend.get_vocab()
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}
        self.vocab_size = int(self.backend.vocab_size)
        self.pad_token_id = int(self.backend.pad_token_id)
        self.unk_token_id = int(
            self.backend.unk_token_id if self.backend.unk_token_id is not None else self.pad_token_id
        )

    def encode(self, text, max_length=128):
        """Encode without special tokens, truncated and right-padded to ``max_length``."""
        ids = self.backend.encode(text, add_special_tokens=False, truncation=True, max_length=max_length)
        if len(ids) < max_length:
            ids += [self.pad_token_id] * (max_length - len(ids))
        return ids

    def decode(self, ids):
        """Decode ids to text, dropping pad tokens and special tokens."""
        filtered = [int(i) for i in ids if int(i) != self.pad_token_id]
        return self.backend.decode(filtered, skip_special_tokens=True)


def build_model(tokenizer):
    """Construct a RubiNet HierarchicalSSM with the fixed pretraining hyperparameters."""
    return HierarchicalSSM(
        vocab_size=tokenizer.vocab_size,
        d_model=512,
        d_state=32,
        num_blocks=6,
        num_experts=8,
        top_k=2,
        chunk_size=4,
        expert_dim=1024,
    )


def build_hssm_v2_model(tokenizer, checkpoint_config: dict):
    """Construct an HSSMV2LM from a checkpoint config dict, with defaults for missing keys."""
    config = HSSMV2Config(
        vocab_size=int(checkpoint_config.get("vocab_size", tokenizer.vocab_size)),
        d_model=int(checkpoint_config.get("d_model", 288)),
        n_layers=int(checkpoint_config.get("n_layers", 10)),
        d_ff=int(checkpoint_config.get("d_ff", 512)),
        state_rank=int(checkpoint_config.get("state_rank", 128)),
        chunk_size=int(checkpoint_config.get("chunk_size", 8)),
        dropout=float(checkpoint_config.get("dropout", 0.0)),
        max_seq_len=int(checkpoint_config.get("max_seq_len", 1024)),
        tie_embeddings=bool(checkpoint_config.get("tie_embeddings", True)),
        num_experts=int(checkpoint_config.get("num_experts", 64)),
        experts_per_token=int(checkpoint_config.get("experts_per_token", 1)),
        expert_dim=int(checkpoint_config.get("expert_dim", 2048)),
        moe_every=int(checkpoint_config.get("moe_every", 4)),
        aux_loss_coef=float(checkpoint_config.get("aux_loss_coef", 1e-2)),
    )
    return HSSMV2LM(config)


def _looks_like_hf_tokenizer_reference(tokenizer_path: str) -> bool:
    """A path that does not exist locally is treated as a HuggingFace model name."""
    path = Path(tokenizer_path)
    return not path.exists()


def _load_tokenizer(tokenizer_path: str):
    """Dispatch to the HF adapter or the local JSON tokenizer."""
    if _looks_like_hf_tokenizer_reference(tokenizer_path):
        return HFTokenizerAdapter(tokenizer_path)
    return CompatibleTokenizer(tokenizer_path)


def _is_hssm_v2_checkpoint(checkpoint: dict) -> bool:
    """True when the checkpoint carries an HSSM-v2 style config dict."""
    config = checkpoint.get("config") if isinstance(checkpoint, dict) else None
    if not isinstance(config, dict):
        return False
    required_keys = {"d_model", "n_layers", "state_rank", "chunk_size"}
    return required_keys.issubset(config.keys())


def load_pretrained(checkpoint_path: str, tokenizer_path: str, device: str):
    """Load tokenizer + model from disk, move the model to ``device``, and report.

    Returns ``(tokenizer, model)``. Raises FileNotFoundError for missing files.
    """
    checkpoint_file = Path(checkpoint_path)
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_file}")
    if not _looks_like_hf_tokenizer_reference(tokenizer_path):
        tokenizer_file = Path(tokenizer_path)
        if not tokenizer_file.exists():
            raise FileNotFoundError(f"Tokenizer not found: {tokenizer_file}")

    tokenizer = _load_tokenizer(tokenizer_path)
    # weights_only=False: checkpoint contains non-tensor objects (config dict).
    # Only load checkpoints you trust — this deserializes arbitrary pickles.
    checkpoint = torch.load(str(checkpoint_file), map_location=device, weights_only=False)
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint

    if _is_hssm_v2_checkpoint(checkpoint):
        model = build_hssm_v2_model(tokenizer, checkpoint.get("config", {}))
    else:
        model = build_model(tokenizer)

    # strict=False: tolerate partial/renamed keys; counts are reported below.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    model = model.to(device)
    model.eval()

    print("Loaded HSSM checkpoint")
    print(f"  Path: {checkpoint_file}")
    print(f"  Missing keys: {len(missing)}")
    print(f"  Unexpected keys: {len(unexpected)}")
    if "epoch" in checkpoint:
        print(f"  Epoch: {checkpoint['epoch']}")
    if "loss" in checkpoint:
        print(f"  Loss: {checkpoint['loss']}")
    print(f"  Model type: {'HSSM v2' if _is_hssm_v2_checkpoint(checkpoint) else 'RubiNet HSSM'}")
    return tokenizer, model


def _model_chunk_size(model) -> int:
    """Best-effort lookup of the model's chunk size (model attr, then config)."""
    if hasattr(model, "chunk_size"):
        return int(model.chunk_size)
    if hasattr(model, "config") and hasattr(model.config, "chunk_size"):
        return int(model.config.chunk_size)
    return 1


def _next_token_logits(model, input_tensor: torch.Tensor, current_len: int) -> torch.Tensor:
    """Return the logits vector predicting the token after position ``current_len - 1``.

    Dict-returning models (HSSM v2) expose per-position logits; otherwise the
    output is assumed to be per-chunk and indexed by chunk.
    """
    outputs = model(input_tensor)
    if isinstance(outputs, dict):
        logits = outputs.get("logits")
        if logits is None:
            raise ValueError("Model returned a dict without logits")
        return logits[0, current_len - 1, :].clone()
    chunk_size = _model_chunk_size(model)
    chunk_idx = max((current_len - 1) // chunk_size, 0)
    return outputs[0, chunk_idx, :].clone()


def build_prompt(user_text: str, cot_mode: bool = False) -> str:
    """Wrap the user message in the system/user/assistant chat template.

    ``cot_mode`` requests a two-line "Reasoning:/Answer:" output shape.
    """
    cleaned_user_text = user_text.strip()
    if cot_mode:
        # NOTE(review): the placeholder text after "Reasoning:"/"Answer:" was
        # stripped from the recovered source (likely angle-bracket
        # placeholders); reconstructed generically — confirm original wording.
        return (
            "system: Reply only in correct English. "
            "Follow English grammar, spelling, punctuation, and sentence structure strictly. "
            "Do not output fragments, corrupted tokens, mixed-language text, or placeholder symbols. "
            "Think step by step briefly and keep the output clean. "
            "Output exactly two lines in this format: "
            "Reasoning: <one short sentence>. "
            "Answer: <one short sentence>. "
            "Keep both lines grammatical and concise.\n"
            f"user: {cleaned_user_text}\n"
            "assistant:"
        )
    return (
        "system: Reply only in correct English. "
        "Follow English grammar, spelling, punctuation, and sentence structure strictly. "
        "Use short complete sentences. "
        "Do not output broken words, malformed tokens, mixed-language text, or placeholder symbols.\n"
        f"user: {cleaned_user_text}\n"
        "assistant:"
    )


def safe_print(text: str):
    """Print, degrading gracefully on consoles that cannot encode the text."""
    try:
        print(text)
    except UnicodeEncodeError:
        sanitized = text.encode("utf-8", errors="replace").decode("utf-8", errors="replace")
        print(sanitized)


def _apply_top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filtering: mask (in place) tokens outside the top-p probability mass."""
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token crossing the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = float("-inf")
    return logits


def _has_repeat_ngram(token_ids, next_token_id: int, ngram_size: int) -> bool:
    """True if appending ``next_token_id`` would repeat an existing n-gram."""
    if ngram_size <= 1 or len(token_ids) < ngram_size - 1:
        return False
    candidate = token_ids[-(ngram_size - 1):] + [next_token_id]
    for i in range(len(token_ids) - ngram_size + 1):
        if token_ids[i:i + ngram_size] == candidate:
            return True
    return False


def _normalize_word(word: str) -> str:
    """Lowercase and strip everything except ASCII letters/digits."""
    return re.sub(r"[^a-z0-9]+", "", word.lower())


def _recent_word_counts(text: str, window: int = 12):
    """Count normalized-word occurrences within the last ``window`` words."""
    words = [_normalize_word(part) for part in text.split()]
    words = [word for word in words if word]
    recent = words[-window:]
    counts = {}
    for word in recent:
        counts[word] = counts.get(word, 0) + 1
    return counts


def _violates_word_repeat(decoded_text: str, candidate_piece: str) -> bool:
    """True if the candidate word already appeared >= 2 times in the recent window."""
    candidate_word = _normalize_word(candidate_piece)
    if not candidate_word:
        return False
    counts = _recent_word_counts(decoded_text, window=12)
    return counts.get(candidate_word, 0) >= 2


def _resolve_special_token_ids(tokenizer):
    """Collect the ids of all known special tokens present in the vocab."""
    special_ids = set()
    # NOTE(review): angle-bracket variants were stripped from the recovered
    # source; reconstructed — confirm these match the tokenizer's specials.
    for token in ["<bos>", "[BOS]", "<pad>", "[PAD]", "<sep>", "[SEP]",
                  "<eos>", "[EOS]", "<unk>", "[UNK]", "<cls>", "[CLS]",
                  "<mask>", "[MASK]"]:
        token_id = tokenizer.vocab.get(token)
        if token_id is not None:
            special_ids.add(int(token_id))
    if getattr(tokenizer, "pad_token_id", None) is not None:
        special_ids.add(int(tokenizer.pad_token_id))
    if getattr(tokenizer, "unk_token_id", None) is not None:
        special_ids.add(int(tokenizer.unk_token_id))
    return special_ids


def _contains_special_marker(text: str) -> bool:
    """True if the text contains any special-token marker (case-insensitive)."""
    upper_text = text.upper()
    # NOTE(review): angle-bracket markers reconstructed (stripped in recovery).
    markers = ["<BOS>", "[BOS]", "<PAD>", "[PAD]", "<SEP>", "[SEP]",
               "<EOS>", "[EOS]", "<UNK>", "[UNK]", "<CLS>", "[CLS]",
               "<MASK>", "[MASK]"]
    return any(marker in upper_text for marker in markers)


def _looks_like_artifact(text: str) -> bool:
    """Heuristics for subword debris, mojibake, and filler junk."""
    stripped = text.strip()
    if not stripped:
        return False
    # Any "##" anywhere (covers leading/trailing wordpiece continuations too).
    if "##" in stripped:
        return True
    # Short pieces containing '#' are almost always tokenizer debris.
    if stripped.count("#") >= 1 and len(stripped) <= 4:
        return True
    # Unicode replacement character => decoding failure.
    if "�" in stripped:
        return True
    # Four or more of the same character in a row.
    if re.search(r"(.)\1{3,}", stripped.lower()):
        return True
    # Letter/digit mashups like "ab1234".
    if re.fullmatch(r"[A-Za-z]{1,4}\d{2,}", stripped):
        return True
    # Pure punctuation runs.
    if re.fullmatch(r"[#\-_=~`|.]+", stripped):
        return True
    return False


def _strip_special_markers(text: str) -> str:
    """Remove special-token markers and '#' debris from decoded text."""
    cleaned = text
    for pattern in [r"<\s*BOS\s*>", r"\[\s*BOS\s*\]", r"<\s*PAD\s*>", r"\[\s*PAD\s*\]",
                    r"<\s*SEP\s*>", r"\[\s*SEP\s*\]", r"<\s*EOS\s*>", r"\[\s*EOS\s*\]",
                    r"<\s*UNK\s*>", r"\[\s*UNK\s*\]", r"<\s*CLS\s*>", r"\[\s*CLS\s*\]",
                    r"<\s*MASK\s*>", r"\[\s*MASK\s*\]"]:
        cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"#{2,}", " ", cleaned)
    # NOTE(review): the original source was truncated at this point
    # (`re.sub(r"(?...`); reconstructed as dropping a lone '#' wedged between
    # word characters, then collapsing whitespace — confirm against original.
    cleaned = re.sub(r"(?<=\w)#(?=\w)", " ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned.strip()


def _cleanup_english_grammar(text: str) -> str:
    """Normalize contractions, capitalization, spacing, and final punctuation.

    NOTE(review): the ``def`` header of this function was lost in the recovered
    source (only the trailing ``str:`` survived); signature reconstructed.
    """
    cleaned = text.strip()
    if not cleaned:
        return cleaned
    # Missing-apostrophe contractions (matched with surrounding spaces).
    replacements = {
        " im ": " I'm ",
        " ive ": " I've ",
        " ill ": " I'll ",
        " id ": " I'd ",
        " dont ": " don't ",
        " cant ": " can't ",
        " wont ": " won't ",
        " didnt ": " didn't ",
        " doesnt ": " doesn't ",
        " isnt ": " isn't ",
        " arent ": " aren't ",
        " wasnt ": " wasn't ",
        " werent ": " weren't ",
        " thats ": " that's ",
        " whats ": " what's ",
        " theres ": " there's ",
    }
    padded = f" {cleaned} "
    for source, target in replacements.items():
        padded = re.sub(re.escape(source), target, padded, flags=re.IGNORECASE)
    cleaned = padded.strip()
    # Standalone "i" -> "I".
    cleaned = re.sub(r"\bi\b", "I", cleaned)
    # Collapse immediate word repetitions ("the the" -> "the").
    cleaned = re.sub(r"\b([A-Za-z]+)(\s+\1\b){1,}", r"\1", cleaned, flags=re.IGNORECASE)
    # Tighten space before punctuation; ensure one after.
    cleaned = re.sub(r"\s+([,;:.!?])", r"\1", cleaned)
    cleaned = re.sub(r"([,;:.!?])(?!\s|$)", r"\1 ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if cleaned:
        cleaned = cleaned[0].upper() + cleaned[1:]
    # Capitalize each sentence.
    sentences = re.split(r"(?<=[.!?])\s+", cleaned)
    normalized_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) == 1:
            normalized_sentences.append(sentence.upper())
        else:
            normalized_sentences.append(sentence[0].upper() + sentence[1:])
    cleaned = " ".join(normalized_sentences).strip()
    if cleaned and cleaned[-1] not in ".!?":
        cleaned += "."
    return cleaned


def _is_strict_english_output(text: str, cot_mode: bool = False) -> bool:
    """Gate: does the text look like clean, well-formed English output?"""
    cleaned = text.strip()
    if not cleaned:
        return False
    # NOTE(review): the "<unk>"/"<pad>" entries were stripped in recovery and
    # are reconstructed here.
    if any(token in cleaned for token in ["[", "]", "{", "}", "|", "<unk>", "[UNK]", "<pad>", "[PAD]"]):
        return False
    # Reject any character outside basic English text punctuation.
    if re.search(r"[^A-Za-z0-9\s,.;:!?\-\'\"()\n]", cleaned):
        return False
    words = re.findall(r"[A-Za-z']+", cleaned)
    if len(words) < 2:
        return False
    if [word for word in words if len(word) > 18]:
        return False
    # CamelCase mashups indicate merged subwords.
    if re.search(r"([A-Za-z]{2,})([A-Z][a-z]+)", cleaned):
        return False
    common_markers = {
        "the", "a", "an", "is", "are", "am", "i", "you", "we", "they", "it",
        "to", "of", "and", "that", "this", "can", "will", "do", "not", "yes",
        "no", "my", "your", "in", "on", "for",
    }
    lowered_words = [word.lower() for word in words]
    if not any(word in common_markers for word in lowered_words):
        return False
    if cot_mode:
        lines = [line.strip() for line in cleaned.splitlines() if line.strip()]
        if len(lines) != 2:
            return False
        if not lines[0].startswith("Reasoning:"):
            return False
        if not lines[1].startswith("Answer:"):
            return False
    sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", cleaned) if segment.strip()]
    if not sentences:
        return False
    for sentence in sentences:
        if not sentence[0].isupper():
            return False
        if sentence[-1] not in ".!?":
            return False
    return True


def _force_cot_shape(text: str) -> str:
    """Coerce text into the two-line "Reasoning: ...\\nAnswer: ..." shape."""
    cleaned = text.strip()
    if not cleaned:
        return cleaned
    lines = [line.strip() for line in cleaned.splitlines() if line.strip()]
    if len(lines) >= 2 and lines[0].startswith("Reasoning:") and lines[1].startswith("Answer:"):
        return f"{lines[0]}\n{lines[1]}"
    # Split at the first sentence boundary: first sentence becomes reasoning.
    parts = re.split(r"(?<=[.!?])\s+", cleaned, maxsplit=1)
    if len(parts) == 2:
        reasoning, answer = parts
    else:
        # Single sentence: synthesize a stock reasoning line.
        reasoning = "Reasoning: Briefly considered the request."
        answer = f"Answer: {cleaned}"
        return f"{reasoning}\n{answer}"
    reasoning = reasoning if reasoning.startswith("Reasoning:") else f"Reasoning: {reasoning.strip()}"
    answer = answer if answer.startswith("Answer:") else f"Answer: {answer.strip()}"
    return f"{reasoning}\n{answer}"


def _ban_low_quality_candidates(tokenizer, logits: torch.Tensor):
    """Mask (in place) every vocab token whose decoded piece is a special marker.

    NOTE(review): this decodes the full vocab on every call — O(V) per step.
    """
    for token_id in range(logits.size(0)):
        piece = tokenizer.decode([token_id]).strip()
        if not piece:
            continue
        if _contains_special_marker(piece):
            logits[token_id] = float("-inf")


def _select_candidate_id(tokenizer, probs: torch.Tensor, generated, prompt_token_count: int,
                         no_repeat_ngram_size: int):
    """Pick the best next-token id from the top candidates, skipping junk.

    Prefers the most probable candidate that passes every filter (artifact,
    word-repeat, n-gram repeat, recent-word); falls back to the most probable
    non-special candidate if none pass. Returns None if all candidates decode
    to empty/special pieces.
    """
    candidate_count = min(24, probs.size(0))
    top_probs, top_ids = torch.topk(probs, candidate_count)
    decoded_so_far = tokenizer.decode(generated[prompt_token_count:]).strip()
    fallback_clean_id = None
    fallback_clean_prob = -1.0
    fallback_any_id = None
    fallback_any_prob = -1.0
    for prob_value, candidate_id_tensor in zip(top_probs.tolist(), top_ids.tolist()):
        candidate_id = int(candidate_id_tensor)
        candidate_piece = tokenizer.decode([candidate_id]).strip()
        if not candidate_piece:
            continue
        if _contains_special_marker(candidate_piece):
            continue
        if fallback_any_id is None or prob_value > fallback_any_prob:
            fallback_any_id = candidate_id
            fallback_any_prob = prob_value
        if _looks_like_artifact(candidate_piece):
            continue
        if _violates_word_repeat(decoded_so_far, candidate_piece):
            continue
        if _has_repeat_ngram(generated, candidate_id, max(no_repeat_ngram_size, 4)):
            continue
        # Reject any word already seen in the last 8 words.
        normalized_piece = _normalize_word(candidate_piece)
        if normalized_piece and decoded_so_far:
            recent_words = [_normalize_word(part) for part in decoded_so_far.split()[-8:]]
            recent_words = [word for word in recent_words if word]
            if recent_words.count(normalized_piece) >= 1:
                continue
        if fallback_clean_id is None or prob_value > fallback_clean_prob:
            fallback_clean_id = candidate_id
            fallback_clean_prob = prob_value
    if fallback_clean_id is not None:
        return fallback_clean_id
    return fallback_any_id


def _generate_fallback_reply(model, tokenizer, prompt_tokens, blocked_special_ids, max_length: int):
    """Greedy short generation used when the filtered decode produced nothing."""
    device = next(model.parameters()).device
    generated = list(prompt_tokens)
    with torch.no_grad():
        for _ in range(min(max_length, 16)):
            current_len = len(generated)
            chunk_size = _model_chunk_size(model)
            # Pad to a multiple of the model chunk size.
            pad_len = (chunk_size - current_len % chunk_size) % chunk_size
            padded_input = generated + [tokenizer.pad_token_id] * pad_len
            input_tensor = torch.tensor([padded_input], device=device)
            next_token_logits = _next_token_logits(model, input_tensor, current_len)
            for special_id in blocked_special_ids:
                if 0 <= special_id < next_token_logits.size(0):
                    next_token_logits[special_id] = float("-inf")
            next_token_id = int(torch.argmax(next_token_logits).item())
            if next_token_id == tokenizer.pad_token_id:
                break
            next_piece = tokenizer.decode([next_token_id]).strip()
            if not next_piece or _contains_special_marker(next_piece):
                break
            generated.append(next_token_id)
    return generated


def generate_reply(
    model,
    tokenizer,
    prompt: str,
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    cot_mode: bool = False,
):
    """Generate a cleaned reply for ``prompt`` with filtered greedy/sampled decoding.

    Applies temperature, special-token bans, distance-weighted repetition
    penalty, no-repeat-ngram, top-k, and top-p filtering at each step, then
    grammar cleanup; falls back to a short greedy generation when everything
    was filtered away.
    """
    model.eval()
    device = next(model.parameters()).device
    formatted_prompt = build_prompt(prompt, cot_mode=cot_mode)
    prompt_ids = tokenizer.encode(formatted_prompt, max_length=128)
    generated = [tok for tok in prompt_ids if tok != tokenizer.pad_token_id]
    prompt_token_count = len(generated)
    blocked_special_ids = _resolve_special_token_ids(tokenizer)
    # NOTE(review): lookup key reconstructed (stripped to "" in recovery).
    bos_token_id = tokenizer.vocab.get("<bos>")
    if not generated:
        generated = [tokenizer.unk_token_id]
        prompt_token_count = len(generated)

    with torch.no_grad():
        for _ in range(max_length):
            current_len = len(generated)
            chunk_size = _model_chunk_size(model)
            pad_len = (chunk_size - current_len % chunk_size) % chunk_size
            padded_input = generated + [tokenizer.pad_token_id] * pad_len
            input_tensor = torch.tensor([padded_input], device=device)
            next_token_logits = _next_token_logits(model, input_tensor, current_len)

            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Hard-ban special tokens and BOS.
            for special_id in blocked_special_ids:
                if 0 <= special_id < next_token_logits.size(0):
                    next_token_logits[special_id] = float("-inf")
            if bos_token_id is not None and 0 <= int(bos_token_id) < next_token_logits.size(0):
                next_token_logits[int(bos_token_id)] = float("-inf")
            _ban_low_quality_candidates(tokenizer, next_token_logits)

            # Repetition penalty weighted by recency over the last 48 tokens.
            recent_tokens = generated[-48:]
            recent_weights = {}
            for idx, token_id in enumerate(recent_tokens):
                distance_weight = 1.0 + (idx / max(len(recent_tokens), 1))
                recent_weights[token_id] = max(recent_weights.get(token_id, 1.0), distance_weight)
            for token_id, distance_weight in recent_weights.items():
                if 0 <= token_id < next_token_logits.size(0):
                    penalty = repetition_penalty * distance_weight
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= penalty
                    else:
                        next_token_logits[token_id] *= penalty

            # Ban every token that would complete a repeated n-gram.
            for token_id in range(next_token_logits.size(0)):
                if _has_repeat_ngram(generated, token_id, no_repeat_ngram_size):
                    next_token_logits[token_id] = float("-inf")

            if top_k > 0 and top_k < next_token_logits.size(0):
                threshold = torch.topk(next_token_logits, top_k)[0][..., -1]
                next_token_logits[next_token_logits < threshold] = float("-inf")
            next_token_logits = _apply_top_p_filter(next_token_logits, top_p)

            probs = F.softmax(next_token_logits, dim=-1)
            if torch.isnan(probs).any() or torch.isinf(probs).any() or probs.sum() <= 0:
                break

            next_token = _select_candidate_id(
                tokenizer, probs, generated, prompt_token_count, no_repeat_ngram_size,
            )
            if next_token is None:
                break
            if next_token == tokenizer.pad_token_id:
                break
            generated.append(next_token)

            # Stop if the output degenerates into the same word repeated.
            decoded_output = tokenizer.decode(generated[prompt_token_count:]).strip()
            if len(decoded_output.split()) >= 6:
                tail_words = [_normalize_word(part) for part in decoded_output.split()[-4:]]
                tail_words = [word for word in tail_words if word]
                if len(tail_words) >= 4 and len(set(tail_words)) == 1:
                    break

    output_ids = generated[prompt_token_count:]
    cleaned_output = _strip_special_markers(tokenizer.decode(output_ids).strip())
    if cleaned_output:
        normalized_output = _cleanup_english_grammar(cleaned_output)
        if cot_mode:
            normalized_output = _force_cot_shape(normalized_output)
        return normalized_output

    # Everything was filtered out: retry with the short greedy fallback.
    fallback_generated = _generate_fallback_reply(
        model, tokenizer, generated[:prompt_token_count], blocked_special_ids, max_length,
    )
    fallback_output_ids = fallback_generated[prompt_token_count:]
    fallback_output = _strip_special_markers(tokenizer.decode(fallback_output_ids).strip())
    normalized_fallback_output = _cleanup_english_grammar(fallback_output)
    if cot_mode:
        normalized_fallback_output = _force_cot_shape(normalized_fallback_output)
    return normalized_fallback_output


def main():
    """CLI entry point: one-shot reply with --message, otherwise interactive loop."""
    parser = argparse.ArgumentParser(description="Chat/test with pretrained RubiNet HSSM checkpoint")
    parser.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT)
    parser.add_argument("--tokenizer", default=DEFAULT_TOKENIZER)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--max-length", type=int, default=40)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=4)
    parser.add_argument("--top-p", type=float, default=0.65)
    parser.add_argument("--repetition-penalty", type=float, default=1.9)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=6)
    parser.add_argument("--cot-mode", action="store_true")
    parser.add_argument("--no-cot-mode", action="store_false", dest="cot_mode")
    parser.set_defaults(cot_mode=False)
    parser.add_argument("--message", default="")
    args = parser.parse_args()

    tokenizer, model = load_pretrained(args.checkpoint, args.tokenizer, args.device)

    if args.message:
        output = generate_reply(
            model,
            tokenizer,
            args.message,
            args.max_length,
            args.temperature,
            args.top_k,
            args.top_p,
            args.repetition_penalty,
            args.no_repeat_ngram_size,
            args.cot_mode,
        )
        safe_print(output)
        return

    print("Interactive HSSM chat/test. Type 'exit' to quit.")
    while True:
        user_text = input("You: ").strip()
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        output = generate_reply(
            model,
            tokenizer,
            user_text,
            args.max_length,
            args.temperature,
            args.top_k,
            args.top_p,
            args.repetition_penalty,
            args.no_repeat_ngram_size,
            args.cot_mode,
        )
        safe_print(f"HSSM: {output}\n")


if __name__ == "__main__":
    main()