| import argparse |
| import json |
| import sys |
| from pathlib import Path |
| import re |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
|
|
| try: |
| from tokenizers import Tokenizer as HFTokenizer |
| except ImportError: |
| HFTokenizer = None |
|
|
| DEFAULT_CHECKPOINT = r"D:\Downloads\hssm_fineweb_edu_final.pt" |
| DEFAULT_TOKENIZER = r"D:\Downloads\simple_tokenizer_20k.json" |
| RUBINET_HSSM_PATH = r"C:\Users\ASUS\.anaconda" |
|
|
| sys.path.append(RUBINET_HSSM_PATH) |
| from RubiNet_HSSM import HierarchicalSSM |
| from hssm_v2_gpu_pretrain import HSSMV2Config, HSSMV2LM |
|
|
# Windows consoles may default to a non-UTF-8 code page; switch both std
# streams to UTF-8 (replacing unencodable characters) where the runtime
# supports reconfigure (Python 3.7+ TextIOWrapper).
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
|
|
|
class CompatibleTokenizer:
    """Tokenizer backed by a local JSON vocab file.

    Uses the `tokenizers` backend when it can parse the file; otherwise
    falls back to whitespace splitting against the raw vocab mapping.
    """

    def __init__(self, tokenizer_path: str):
        path = Path(tokenizer_path)
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        # Optional fast backend; fall back silently when unavailable or
        # the file cannot be parsed by it.
        self.backend = None
        if HFTokenizer is not None:
            try:
                self.backend = HFTokenizer.from_file(str(path))
            except Exception:
                self.backend = None

        # Accept either the HF tokenizers layout ({"model": {"vocab": ...}})
        # or a flat {"vocab": ...} mapping.
        model_section = data.get("model")
        if isinstance(model_section, dict) and "vocab" in model_section:
            vocab = model_section["vocab"]
        elif "vocab" in data:
            vocab = data["vocab"]
        else:
            raise ValueError(f"Unsupported tokenizer format: {path}")

        self.vocab = {str(token): int(idx) for token, idx in vocab.items()}
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        self.pad_token_id = self._resolve_token_id(["<PAD>", "[PAD]"], fallback=0)
        self.unk_token_id = self._resolve_token_id(["<UNK>", "[UNK]"], fallback=3)

        print(f"[TOKENIZER] Loaded vocab tokenizer - Vocab: {self.vocab_size:,}")

    def _resolve_token_id(self, candidates, fallback: int):
        """Return the id of the first candidate token found in the vocab."""
        for candidate in candidates:
            candidate_id = self.vocab.get(candidate)
            if candidate_id is not None:
                return candidate_id
        return fallback

    def encode(self, text, max_length=128):
        """Encode *text* to exactly ``max_length`` ids (truncate, then pad)."""
        if self.backend is not None:
            ids = self.backend.encode(text).ids[:max_length]
        else:
            ids = [self.vocab.get(word, self.unk_token_id) for word in text.split()]
            ids = ids[:max_length]
        shortfall = max_length - len(ids)
        if shortfall > 0:
            ids = ids + [self.pad_token_id] * shortfall
        return ids

    def decode(self, ids):
        """Decode ids back to text, dropping pad tokens first."""
        kept = [int(i) for i in ids if int(i) != self.pad_token_id]
        if self.backend is not None:
            return self.backend.decode(kept, skip_special_tokens=False)
        return " ".join(self.id_to_token.get(i, "<UNK>") for i in kept)
|
|
|
|
class HFTokenizerAdapter:
    """Adapter exposing a HuggingFace `transformers` tokenizer through the
    same encode/decode interface as CompatibleTokenizer."""

    def __init__(self, tokenizer_name: str):
        self.backend = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
        if self.backend.pad_token is None:
            # Some models ship without a pad token; reuse EOS (or UNK).
            self.backend.pad_token = self.backend.eos_token or self.backend.unk_token
        self.vocab = self.backend.get_vocab()
        self.id_to_token = {token_id: token for token, token_id in self.vocab.items()}
        self.vocab_size = int(self.backend.vocab_size)
        self.pad_token_id = int(self.backend.pad_token_id)
        unk_id = self.backend.unk_token_id
        self.unk_token_id = int(unk_id) if unk_id is not None else self.pad_token_id

    def encode(self, text, max_length=128):
        """Encode to exactly ``max_length`` ids, padding with the pad id."""
        ids = self.backend.encode(
            text, add_special_tokens=False, truncation=True, max_length=max_length
        )
        shortfall = max_length - len(ids)
        if shortfall > 0:
            ids = ids + [self.pad_token_id] * shortfall
        return ids

    def decode(self, ids):
        """Decode ids to text, dropping pads and special tokens."""
        kept = [int(i) for i in ids if int(i) != self.pad_token_id]
        return self.backend.decode(kept, skip_special_tokens=True)
|
|
|
|
def build_model(tokenizer):
    """Construct the RubiNet HierarchicalSSM using the fixed hyperparameters
    the v1 checkpoints were pretrained with; only vocab size varies."""
    hyperparams = dict(
        vocab_size=tokenizer.vocab_size,
        d_model=512,
        d_state=32,
        num_blocks=6,
        num_experts=8,
        top_k=2,
        chunk_size=4,
        expert_dim=1024,
    )
    return HierarchicalSSM(**hyperparams)
|
|
|
|
def build_hssm_v2_model(tokenizer, checkpoint_config: dict):
    """Build an HSSM v2 language model from a checkpoint's config dict,
    falling back to the pretraining defaults for any missing field."""

    def _cfg(key, default, cast):
        # Read one field from the checkpoint config with a typed fallback.
        return cast(checkpoint_config.get(key, default))

    config = HSSMV2Config(
        vocab_size=_cfg("vocab_size", tokenizer.vocab_size, int),
        d_model=_cfg("d_model", 288, int),
        n_layers=_cfg("n_layers", 10, int),
        d_ff=_cfg("d_ff", 512, int),
        state_rank=_cfg("state_rank", 128, int),
        chunk_size=_cfg("chunk_size", 8, int),
        dropout=_cfg("dropout", 0.0, float),
        max_seq_len=_cfg("max_seq_len", 1024, int),
        tie_embeddings=_cfg("tie_embeddings", True, bool),
        num_experts=_cfg("num_experts", 64, int),
        experts_per_token=_cfg("experts_per_token", 1, int),
        expert_dim=_cfg("expert_dim", 2048, int),
        moe_every=_cfg("moe_every", 4, int),
        aux_loss_coef=_cfg("aux_loss_coef", 1e-2, float),
    )
    return HSSMV2LM(config)
|
|
|
|
| def _looks_like_hf_tokenizer_reference(tokenizer_path: str) -> bool: |
| path = Path(tokenizer_path) |
| return not path.exists() |
|
|
|
|
def _load_tokenizer(tokenizer_path: str):
    """Choose the tokenizer wrapper: hub name -> HF adapter, local file ->
    JSON-vocab tokenizer."""
    use_hub = _looks_like_hf_tokenizer_reference(tokenizer_path)
    loader = HFTokenizerAdapter if use_hub else CompatibleTokenizer
    return loader(tokenizer_path)
|
|
|
|
| def _is_hssm_v2_checkpoint(checkpoint: dict) -> bool: |
| config = checkpoint.get("config") if isinstance(checkpoint, dict) else None |
| if not isinstance(config, dict): |
| return False |
| required_keys = {"d_model", "n_layers", "state_rank", "chunk_size"} |
| return required_keys.issubset(config.keys()) |
|
|
|
|
def load_pretrained(checkpoint_path: str, tokenizer_path: str, device: str):
    """Load the tokenizer and model weights from disk.

    Returns a ``(tokenizer, model)`` tuple with the model moved to *device*
    and set to eval mode. Raises FileNotFoundError when the checkpoint (or
    a local tokenizer file) is missing.
    """
    checkpoint_file = Path(checkpoint_path)

    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_file}")

    # Only validate the tokenizer path when it refers to a local file;
    # otherwise it is treated as a HuggingFace hub name.
    if not _looks_like_hf_tokenizer_reference(tokenizer_path):
        tokenizer_file = Path(tokenizer_path)
        if not tokenizer_file.exists():
            raise FileNotFoundError(f"Tokenizer not found: {tokenizer_file}")

    tokenizer = _load_tokenizer(tokenizer_path)

    # weights_only=False because the checkpoint stores a config dict (and
    # possibly epoch/loss metadata) alongside the state dict.
    # NOTE(review): this unpickles arbitrary objects — only load trusted files.
    checkpoint = torch.load(str(checkpoint_file), map_location=device, weights_only=False)
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    if _is_hssm_v2_checkpoint(checkpoint):
        model = build_hssm_v2_model(tokenizer, checkpoint.get("config", {}))
    else:
        model = build_model(tokenizer)
    # strict=False tolerates minor architecture drift; key counts are
    # reported below so mismatches stay visible.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    model = model.to(device)
    model.eval()

    print("Loaded HSSM checkpoint")
    print(f" Path: {checkpoint_file}")
    print(f" Missing keys: {len(missing)}")
    print(f" Unexpected keys: {len(unexpected)}")
    if "epoch" in checkpoint:
        print(f" Epoch: {checkpoint['epoch']}")
    if "loss" in checkpoint:
        print(f" Loss: {checkpoint['loss']}")
    print(f" Model type: {'HSSM v2' if _is_hssm_v2_checkpoint(checkpoint) else 'RubiNet HSSM'}")

    return tokenizer, model
|
|
|
|
| def _model_chunk_size(model) -> int: |
| if hasattr(model, "chunk_size"): |
| return int(model.chunk_size) |
| if hasattr(model, "config") and hasattr(model.config, "chunk_size"): |
| return int(model.config.chunk_size) |
| return 1 |
|
|
|
|
| def _next_token_logits(model, input_tensor: torch.Tensor, current_len: int) -> torch.Tensor: |
| outputs = model(input_tensor) |
| if isinstance(outputs, dict): |
| logits = outputs.get("logits") |
| if logits is None: |
| raise ValueError("Model returned a dict without logits") |
| return logits[0, current_len - 1, :].clone() |
| chunk_size = _model_chunk_size(model) |
| chunk_idx = max((current_len - 1) // chunk_size, 0) |
| return outputs[0, chunk_idx, :].clone() |
|
|
|
|
def build_prompt(user_text: str, cot_mode: bool = False) -> str:
    """Wrap the user's message in the system/user/assistant chat template.

    CoT mode asks the model for the two-line "Reasoning:/Answer:" format;
    otherwise the system prompt asks for plain short sentences.
    """
    message = user_text.strip()
    if cot_mode:
        system = (
            "system: Reply only in correct English. "
            "Follow English grammar, spelling, punctuation, and sentence structure strictly. "
            "Do not output fragments, corrupted tokens, mixed-language text, or placeholder symbols. "
            "Think step by step briefly and keep the output clean. "
            "Output exactly two lines in this format: "
            "Reasoning: <very short reasoning>. "
            "Answer: <final answer>. "
            "Keep both lines grammatical and concise.\n"
        )
    else:
        system = (
            "system: Reply only in correct English. "
            "Follow English grammar, spelling, punctuation, and sentence structure strictly. "
            "Use short complete sentences. "
            "Do not output broken words, malformed tokens, mixed-language text, or placeholder symbols.\n"
        )
    return f"{system}user: {message}\nassistant:"
|
|
|
|
def safe_print(text: str):
    """Print *text*, replacing unencodable characters when the console's
    encoding cannot represent them."""
    try:
        print(text)
    except UnicodeEncodeError:
        fallback = text.encode("utf-8", errors="replace").decode("utf-8", errors="replace")
        print(fallback)
|
|
|
|
| def _apply_top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor: |
| if top_p >= 1.0: |
| return logits |
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
| sorted_probs = F.softmax(sorted_logits, dim=-1) |
| cumulative_probs = torch.cumsum(sorted_probs, dim=-1) |
| sorted_indices_to_remove = cumulative_probs > top_p |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| sorted_indices_to_remove[..., 0] = False |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] |
| logits[indices_to_remove] = float("-inf") |
| return logits |
|
|
|
|
| def _has_repeat_ngram(token_ids, next_token_id: int, ngram_size: int) -> bool: |
| if ngram_size <= 1 or len(token_ids) < ngram_size - 1: |
| return False |
| candidate = token_ids[-(ngram_size - 1):] + [next_token_id] |
| for i in range(len(token_ids) - ngram_size + 1): |
| if token_ids[i:i + ngram_size] == candidate: |
| return True |
| return False |
|
|
|
|
| def _normalize_word(word: str) -> str: |
| return re.sub(r"[^a-z0-9]+", "", word.lower()) |
|
|
|
|
def _recent_word_counts(text: str, window: int = 12):
    """Count occurrences of each normalized word among the last ``window``
    words of *text* (empty normalizations are dropped first)."""
    normalized = (_normalize_word(part) for part in text.split())
    recent = [word for word in normalized if word][-window:]
    counts = {}
    for word in recent:
        counts[word] = counts.get(word, 0) + 1
    return counts
|
|
|
|
def _violates_word_repeat(decoded_text: str, candidate_piece: str) -> bool:
    """True when the candidate's normalized form already appears at least
    twice within the last 12 words of the decoded text."""
    normalized = _normalize_word(candidate_piece)
    if not normalized:
        return False
    return _recent_word_counts(decoded_text, window=12).get(normalized, 0) >= 2
|
|
|
|
| def _resolve_special_token_ids(tokenizer): |
| special_ids = set() |
| for token in ["<BOS>", "[BOS]", "<PAD>", "[PAD]", "<SEP>", "[SEP]", "<EOS>", "[EOS]", "<UNK>", "[UNK]", "<CLS>", "[CLS]", "<MASK>", "[MASK]", "<MASK>"]: |
| token_id = tokenizer.vocab.get(token) |
| if token_id is not None: |
| special_ids.add(int(token_id)) |
| if getattr(tokenizer, "pad_token_id", None) is not None: |
| special_ids.add(int(tokenizer.pad_token_id)) |
| if getattr(tokenizer, "unk_token_id", None) is not None: |
| special_ids.add(int(tokenizer.unk_token_id)) |
| return special_ids |
|
|
|
|
| def _contains_special_marker(text: str) -> bool: |
| upper_text = text.upper() |
| markers = ["<BOS>", "[BOS]", "<PAD>", "[PAD]", "<SEP>", "[SEP]", "<EOS>", "[EOS]", "<UNK>", "[UNK]", "<CLS>", "[CLS]", "<MASK>", "[MASK]"] |
| return any(marker in upper_text for marker in markers) |
|
|
|
|
| def _looks_like_artifact(text: str) -> bool: |
| stripped = text.strip() |
| if not stripped: |
| return False |
| if "##" in stripped: |
| return True |
| if stripped.startswith("##") or stripped.endswith("##"): |
| return True |
| if stripped.count("#") >= 1 and len(stripped) <= 4: |
| return True |
| if "�" in stripped: |
| return True |
| if any(ch in stripped for ch in ["�", ""]): |
| return True |
| if re.search(r"(.)\1{3,}", stripped.lower()): |
| return True |
| if re.fullmatch(r"[A-Za-z]{1,4}\d{2,}", stripped): |
| return True |
| if re.fullmatch(r"[#\-_=~`|.]+", stripped): |
| return True |
| return False |
|
|
|
|
| def _strip_special_markers(text: str) -> str: |
| cleaned = text |
| for pattern in [r"<\s*BOS\s*>", r"\[\s*BOS\s*\]", r"<\s*PAD\s*>", r"\[\s*PAD\s*\]", r"<\s*SEP\s*>", r"\[\s*SEP\s*\]", r"<\s*EOS\s*>", r"\[\s*EOS\s*\]", r"<\s*UNK\s*>", r"\[\s*UNK\s*\]", r"<\s*CLS\s*>", r"\[\s*CLS\s*\]", r"<\s*MASK\s*>", r"\[\s*MASK\s*\]"]: |
| cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE) |
| cleaned = re.sub(r"#{2,}", " ", cleaned) |
| cleaned = re.sub(r"(?<!\w)#(?!\w)", " ", cleaned) |
| cleaned = cleaned.replace("�", " ") |
| cleaned = re.sub(r"\b([A-Za-z]+)(\s+\1\b){2,}", r"\1", cleaned, flags=re.IGNORECASE) |
| cleaned = re.sub(r"\b(\w{1,20})(\w{1,20})\1\b", r"\1", cleaned) |
| cleaned = re.sub(r"\s*([,;:.!?])\s*", r"\1 ", cleaned) |
| cleaned = re.sub(r"\s+", " ", cleaned) |
| return cleaned.strip() |
|
|
|
|
| def _cleanup_english_grammar(text: str) -> str: |
| cleaned = text.strip() |
| if not cleaned: |
| return cleaned |
|
|
| replacements = { |
| " im ": " I'm ", |
| " ive ": " I've ", |
| " ill ": " I'll ", |
| " id ": " I'd ", |
| " dont ": " don't ", |
| " cant ": " can't ", |
| " wont ": " won't ", |
| " didnt ": " didn't ", |
| " doesnt ": " doesn't ", |
| " isnt ": " isn't ", |
| " arent ": " aren't ", |
| " wasnt ": " wasn't ", |
| " werent ": " weren't ", |
| " thats ": " that's ", |
| " whats ": " what's ", |
| " theres ": " there's ", |
| " ive ": " I've ", |
| } |
|
|
| padded = f" {cleaned} " |
| for source, target in replacements.items(): |
| padded = re.sub(re.escape(source), target, padded, flags=re.IGNORECASE) |
| cleaned = padded.strip() |
|
|
| cleaned = re.sub(r"\bi\b", "I", cleaned) |
| cleaned = re.sub(r"\b([A-Za-z]+)(\s+\1\b){1,}", r"\1", cleaned, flags=re.IGNORECASE) |
| cleaned = re.sub(r"\s+([,;:.!?])", r"\1", cleaned) |
| cleaned = re.sub(r"([,;:.!?])(?!\s|$)", r"\1 ", cleaned) |
| cleaned = re.sub(r"\s+", " ", cleaned).strip() |
|
|
| if cleaned: |
| cleaned = cleaned[0].upper() + cleaned[1:] |
|
|
| sentences = re.split(r"(?<=[.!?])\s+", cleaned) |
| normalized_sentences = [] |
| for sentence in sentences: |
| sentence = sentence.strip() |
| if not sentence: |
| continue |
| if len(sentence) == 1: |
| normalized_sentences.append(sentence.upper()) |
| else: |
| normalized_sentences.append(sentence[0].upper() + sentence[1:]) |
| cleaned = " ".join(normalized_sentences).strip() |
|
|
| if cleaned and cleaned[-1] not in ".!?": |
| cleaned += "." |
|
|
| return cleaned |
|
|
|
|
| def _is_strict_english_output(text: str, cot_mode: bool = False) -> bool: |
| cleaned = text.strip() |
| if not cleaned: |
| return False |
| if any(token in cleaned for token in ["[", "]", "{", "}", "|", "<UNK>", "[UNK]", "<PAD>", "[PAD]"]): |
| return False |
| if re.search(r"[^A-Za-z0-9\s,.;:!?\-\'\"()\n]", cleaned): |
| return False |
| words = re.findall(r"[A-Za-z']+", cleaned) |
| if len(words) < 2: |
| return False |
| long_weird_words = [word for word in words if len(word) > 18] |
| if long_weird_words: |
| return False |
| if re.search(r"([A-Za-z]{2,})([A-Z][a-z]+)", cleaned): |
| return False |
| common_markers = { |
| "the", "a", "an", "is", "are", "am", "i", "you", "we", "they", "it", "to", "of", "and", |
| "that", "this", "can", "will", "do", "not", "yes", "no", "my", "your", "in", "on", "for" |
| } |
| lowered_words = [word.lower() for word in words] |
| if not any(word in common_markers for word in lowered_words): |
| return False |
| if cot_mode: |
| lines = [line.strip() for line in cleaned.splitlines() if line.strip()] |
| if len(lines) != 2: |
| return False |
| if not lines[0].startswith("Reasoning:"): |
| return False |
| if not lines[1].startswith("Answer:"): |
| return False |
| sentences = [segment.strip() for segment in re.split(r"(?<=[.!?])\s+", cleaned) if segment.strip()] |
| if not sentences: |
| return False |
| for sentence in sentences: |
| if not sentence[0].isupper(): |
| return False |
| if sentence[-1] not in ".!?": |
| return False |
| return True |
|
|
|
|
| def _force_cot_shape(text: str) -> str: |
| cleaned = text.strip() |
| if not cleaned: |
| return cleaned |
| lines = [line.strip() for line in cleaned.splitlines() if line.strip()] |
| if len(lines) >= 2 and lines[0].startswith("Reasoning:") and lines[1].startswith("Answer:"): |
| return f"{lines[0]}\n{lines[1]}" |
| parts = re.split(r"(?<=[.!?])\s+", cleaned, maxsplit=1) |
| if len(parts) == 2: |
| reasoning, answer = parts |
| else: |
| reasoning, answer = "Reasoning: Briefly considered the request.", f"Answer: {cleaned}" |
| return f"{reasoning}\n{answer}" |
| reasoning = reasoning if reasoning.startswith("Reasoning:") else f"Reasoning: {reasoning.strip()}" |
| answer = answer if answer.startswith("Answer:") else f"Answer: {answer.strip()}" |
| return f"{reasoning}\n{answer}" |
|
|
|
|
def _ban_low_quality_candidates(tokenizer, logits: torch.Tensor):
    """Mask to -inf every vocab entry whose decoded text contains a
    special-token marker. Mutates ``logits`` in place.

    NOTE(review): this decodes the entire vocabulary on every call, i.e.
    O(vocab) work per generation step; the banned-id set could be cached
    per tokenizer if this shows up in profiles.
    """
    vocab_size = logits.size(0)
    for token_id in range(vocab_size):
        piece = tokenizer.decode([token_id]).strip()
        if piece and _contains_special_marker(piece):
            logits[token_id] = float("-inf")
|
|
|
|
def _select_candidate_id(tokenizer, probs: torch.Tensor, generated, prompt_token_count: int, no_repeat_ngram_size: int):
    """Pick the next token id from the top-24 probability candidates.

    Prefers the highest-probability candidate that passes every quality
    filter (non-empty, no special markers, not an artifact, no recent word
    repetition, no repeated n-gram). Falls back to the best candidate that
    merely avoids special markers, and returns None when every candidate
    decodes to an empty piece or a special marker.
    """
    candidate_count = min(24, probs.size(0))
    top_probs, top_ids = torch.topk(probs, candidate_count)
    # Text generated so far (prompt excluded) — used by the repeat filters.
    decoded_so_far = tokenizer.decode(generated[prompt_token_count:]).strip()

    fallback_clean_id = None
    fallback_clean_prob = -1.0
    fallback_any_id = None
    fallback_any_prob = -1.0
    for prob_value, candidate_id_tensor in zip(top_probs.tolist(), top_ids.tolist()):
        candidate_id = int(candidate_id_tensor)
        candidate_piece = tokenizer.decode([candidate_id]).strip()
        if not candidate_piece:
            continue
        if _contains_special_marker(candidate_piece):
            continue
        # Track the best marker-free candidate as a last resort.
        if fallback_any_id is None or prob_value > fallback_any_prob:
            fallback_any_id = candidate_id
            fallback_any_prob = prob_value
        if _looks_like_artifact(candidate_piece):
            continue
        if _violates_word_repeat(decoded_so_far, candidate_piece):
            continue
        # Enforce at least 4-gram blocking regardless of the caller's setting.
        if _has_repeat_ngram(generated, candidate_id, max(no_repeat_ngram_size, 4)):
            continue
        normalized_piece = _normalize_word(candidate_piece)
        if normalized_piece and decoded_so_far:
            # Reject a word that already occurred in the last 8 decoded words.
            recent_words = [_normalize_word(part) for part in decoded_so_far.split()[-8:]]
            recent_words = [word for word in recent_words if word]
            if recent_words.count(normalized_piece) >= 1:
                continue
        if fallback_clean_id is None or prob_value > fallback_clean_prob:
            fallback_clean_id = candidate_id
            fallback_clean_prob = prob_value

    if fallback_clean_id is not None:
        return fallback_clean_id
    return fallback_any_id
|
|
|
|
def _generate_fallback_reply(model, tokenizer, prompt_tokens, blocked_special_ids, max_length: int):
    """Greedy fallback decoder used when the filtered sampler produced nothing.

    Generates up to ``min(max_length, 16)`` tokens by pure argmax (special
    ids hard-banned), stopping at pad tokens, empty pieces, or special
    markers. Returns the prompt ids plus whatever was generated.
    """
    device = next(model.parameters()).device
    generated = list(prompt_tokens)

    with torch.no_grad():
        for _ in range(min(max_length, 16)):
            current_len = len(generated)
            # Pad the input up to a multiple of the model's chunk size.
            chunk_size = _model_chunk_size(model)
            pad_len = (chunk_size - current_len % chunk_size) % chunk_size
            padded_input = generated + [tokenizer.pad_token_id] * pad_len
            input_tensor = torch.tensor([padded_input], device=device)
            next_token_logits = _next_token_logits(model, input_tensor, current_len)

            # Never emit special tokens.
            for special_id in blocked_special_ids:
                if 0 <= special_id < next_token_logits.size(0):
                    next_token_logits[special_id] = float("-inf")

            next_token_id = int(torch.argmax(next_token_logits).item())
            if next_token_id == tokenizer.pad_token_id:
                break

            next_piece = tokenizer.decode([next_token_id]).strip()
            if not next_piece or _contains_special_marker(next_piece):
                break

            generated.append(next_token_id)

    return generated
|
|
|
|
def generate_reply(
    model,
    tokenizer,
    prompt: str,
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    cot_mode: bool = False,
):
    """Generate a cleaned English reply to *prompt*.

    Pipeline per step: temperature scaling, special-token/BOS bans, a
    distance-weighted repetition penalty over recent tokens, n-gram
    blocking, top-k then top-p filtering, and finally a quality-filtered
    candidate pick. If the pipeline yields no text, a greedy fallback pass
    is used instead. The output is post-processed for grammar and
    (optionally) forced into the two-line CoT shape.
    """
    model.eval()
    device = next(model.parameters()).device

    formatted_prompt = build_prompt(prompt, cot_mode=cot_mode)
    prompt_ids = tokenizer.encode(formatted_prompt, max_length=128)
    generated = [tok for tok in prompt_ids if tok != tokenizer.pad_token_id]
    prompt_token_count = len(generated)
    blocked_special_ids = _resolve_special_token_ids(tokenizer)
    bos_token_id = tokenizer.vocab.get("<BOS>")
    if not generated:
        # Degenerate case: the whole prompt encoded to padding.
        generated = [tokenizer.unk_token_id]
        prompt_token_count = len(generated)

    with torch.no_grad():
        for _ in range(max_length):
            current_len = len(generated)
            # Pad the input up to a multiple of the model's chunk size.
            chunk_size = _model_chunk_size(model)
            pad_len = (chunk_size - current_len % chunk_size) % chunk_size
            padded_input = generated + [tokenizer.pad_token_id] * pad_len
            input_tensor = torch.tensor([padded_input], device=device)

            next_token_logits = _next_token_logits(model, input_tensor, current_len)

            # Temperature scaling (skipped when temperature == 0).
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Hard-ban special tokens and the BOS token.
            for special_id in blocked_special_ids:
                if 0 <= special_id < next_token_logits.size(0):
                    next_token_logits[special_id] = float("-inf")

            if bos_token_id is not None and 0 <= int(bos_token_id) < next_token_logits.size(0):
                next_token_logits[int(bos_token_id)] = float("-inf")

            _ban_low_quality_candidates(tokenizer, next_token_logits)

            # Distance-weighted repetition penalty over the last 48 tokens:
            # more recent occurrences get a larger weight (up to ~2x).
            recent_tokens = generated[-48:]
            recent_weights = {}
            for idx, token_id in enumerate(recent_tokens):
                distance_weight = 1.0 + (idx / max(len(recent_tokens), 1))
                recent_weights[token_id] = max(recent_weights.get(token_id, 1.0), distance_weight)

            for token_id, distance_weight in recent_weights.items():
                if 0 <= token_id < next_token_logits.size(0):
                    penalty = repetition_penalty * distance_weight
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= penalty
                    else:
                        next_token_logits[token_id] *= penalty

            # Block any token that would complete an already-seen n-gram.
            # NOTE(review): this scans the whole vocab per step (slow path).
            for token_id in range(next_token_logits.size(0)):
                if _has_repeat_ngram(generated, token_id, no_repeat_ngram_size):
                    next_token_logits[token_id] = float("-inf")

            # Top-k filtering, then top-p (nucleus) filtering.
            if top_k > 0 and top_k < next_token_logits.size(0):
                threshold = torch.topk(next_token_logits, top_k)[0][..., -1]
                next_token_logits[next_token_logits < threshold] = float("-inf")

            next_token_logits = _apply_top_p_filter(next_token_logits, top_p)
            probs = F.softmax(next_token_logits, dim=-1)

            # Bail out if filtering left no valid probability mass.
            if torch.isnan(probs).any() or torch.isinf(probs).any() or probs.sum() <= 0:
                break

            next_token = _select_candidate_id(
                tokenizer,
                probs,
                generated,
                prompt_token_count,
                no_repeat_ngram_size,
            )

            if next_token is None:
                break

            if next_token == tokenizer.pad_token_id:
                break
            generated.append(next_token)

            # Stop when the tail degenerates into a single repeated word.
            decoded_output = tokenizer.decode(generated[prompt_token_count:]).strip()
            if len(decoded_output.split()) >= 6:
                tail_words = [_normalize_word(part) for part in decoded_output.split()[-4:]]
                tail_words = [word for word in tail_words if word]
                if len(tail_words) >= 4 and len(set(tail_words)) == 1:
                    break

    output_ids = generated[prompt_token_count:]
    cleaned_output = _strip_special_markers(tokenizer.decode(output_ids).strip())
    if cleaned_output:
        normalized_output = _cleanup_english_grammar(cleaned_output)
        if cot_mode:
            normalized_output = _force_cot_shape(normalized_output)
        return normalized_output

    # Nothing survived the filters: retry with the greedy fallback decoder.
    fallback_generated = _generate_fallback_reply(
        model,
        tokenizer,
        generated[:prompt_token_count],
        blocked_special_ids,
        max_length,
    )
    fallback_output_ids = fallback_generated[prompt_token_count:]
    fallback_output = _strip_special_markers(tokenizer.decode(fallback_output_ids).strip())
    normalized_fallback_output = _cleanup_english_grammar(fallback_output)
    if cot_mode:
        normalized_fallback_output = _force_cot_shape(normalized_fallback_output)
    return normalized_fallback_output
|
|
|
|
def main():
    """CLI entry point: one-shot reply via --message, else interactive chat."""
    parser = argparse.ArgumentParser(description="Chat/test with pretrained RubiNet HSSM checkpoint")
    parser.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT)
    parser.add_argument("--tokenizer", default=DEFAULT_TOKENIZER)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--max-length", type=int, default=40)
    # temperature 0 disables scaling entirely (see generate_reply).
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=4)
    parser.add_argument("--top-p", type=float, default=0.65)
    parser.add_argument("--repetition-penalty", type=float, default=1.9)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=6)
    # Paired on/off flags writing to the same dest; default is off.
    parser.add_argument("--cot-mode", action="store_true")
    parser.add_argument("--no-cot-mode", action="store_false", dest="cot_mode")
    parser.set_defaults(cot_mode=False)
    parser.add_argument("--message", default="")
    args = parser.parse_args()

    tokenizer, model = load_pretrained(args.checkpoint, args.tokenizer, args.device)

    # One-shot mode: answer --message and exit.
    if args.message:
        output = generate_reply(
            model,
            tokenizer,
            args.message,
            args.max_length,
            args.temperature,
            args.top_k,
            args.top_p,
            args.repetition_penalty,
            args.no_repeat_ngram_size,
            args.cot_mode,
        )
        safe_print(output)
        return

    # Interactive REPL until the user types exit/quit.
    print("Interactive HSSM chat/test. Type 'exit' to quit.")
    while True:
        user_text = input("You: ").strip()
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        output = generate_reply(
            model,
            tokenizer,
            user_text,
            args.max_length,
            args.temperature,
            args.top_k,
            args.top_p,
            args.repetition_penalty,
            args.no_repeat_ngram_size,
            args.cot_mode,
        )
        safe_print(f"HSSM: {output}\n")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|