#!/usr/bin/env python3 from __future__ import annotations import os import argparse import contextlib import inspect import json from pathlib import Path from typing import Any, List, Tuple, Optional import torch import torch.nn as nn from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer try: from .llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref except Exception: from llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref from transformers.cache_utils import DynamicCache # Be safe with tokenizers threads when forking os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("DS_BUILD_AIO", "0") os.environ.setdefault("DS_BUILD_OPS", "0") def _dtype_kwargs(from_pretrained_fn, dtype: torch.dtype) -> dict: try: params = inspect.signature(from_pretrained_fn).parameters except Exception: return {"torch_dtype": dtype} if "torch_dtype" in params: return {"torch_dtype": dtype} if "dtype" in params: return {"dtype": dtype} # With **kwargs-only signatures, prefer torch_dtype. Passing raw `dtype` # can leak into GenerationConfig kwargs on some transformers versions and # trigger JSON serialization errors for torch.dtype. if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): return {"torch_dtype": dtype} return {"torch_dtype": dtype} def _tokenizer_kwargs(from_pretrained_fn) -> dict: kwargs: dict = {} try: params = inspect.signature(from_pretrained_fn).parameters if "fix_mistral_regex" in params: kwargs["fix_mistral_regex"] = True elif any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): # Older/newer tokenizer loaders often hide optional args behind **kwargs. kwargs["fix_mistral_regex"] = True except Exception: pass return kwargs def _normalize_dtype_arg(value) -> Optional[str]: if value is None: return None if isinstance(value, str): s = value.strip().lower() elif value is torch.bfloat16: s = "bfloat16" elif value is torch.float16: s = "float16" elif value is torch.float32: s = "float32" else: return None mapping = { "auto": "auto", "bf16": "bf16", "bfloat16": "bf16", "fp16": "fp16", "float16": "fp16", "half": "fp16", "fp32": "fp32", "float32": "fp32", } return mapping.get(s) # Special assistant-start token for Mistral-style templates MISTRAL_ASSIST_START = "" # Instruction appended after the question in prompts (requested). BOXED_ANSWER_INSTRUCTION = ( "Provide the answer and finish with the short answer inside \\\\boxed{}." ) MATH_STEP_BY_STEP_BOXED_INSTRUCTION = ( "Solve the problem step by step, and put the final answer inside \\\\boxed{}." ) # ----------------------------- # Custom modeling loader (TRI) # ----------------------------- def load_custom_modeling(modeling_path: str, model_family: str): """Load local tri_*_modeling.py and register it as the HF module.""" import importlib.util, sys import transformers # noqa: F401 if model_family == "llama": import transformers.models.llama # ensure package exists target_name = "transformers.models.llama.modeling_llama" expected = ("LlamaModel", "LlamaForCausalLM") elif model_family == "qwen3": import transformers.models.qwen3 # ensure package exists target_name = "transformers.models.qwen3.modeling_qwen3" expected = ("Qwen3Model", "Qwen3ForCausalLM") elif model_family == "mistral": import transformers.models.mistral # ensure package exists target_name = "transformers.models.mistral.modeling_mistral" expected = ("MistralModel", "MistralForCausalLM") else: raise ValueError(f"Unknown model_family: {model_family}") if target_name in sys.modules: del sys.modules[target_name] spec = importlib.util.spec_from_file_location(target_name, str(modeling_path)) if spec is None or spec.loader is None: raise RuntimeError(f"Failed to load spec for {modeling_path}") module = importlib.util.module_from_spec(spec) sys.modules[target_name] = module spec.loader.exec_module(module) for klass in expected: if not hasattr(module, klass): raise RuntimeError(f"{modeling_path} does not define {klass}") return module def infer_model_family(model_name: str, model_family_arg: str) -> str: if model_family_arg and model_family_arg != "auto": return model_family_arg name = (model_name or "").lower() if "qwen" in name: return "qwen3" if "llama" in name: return "llama" if "mistral" in name: return "mistral" return "llama" # ----------------------------- # Template helpers # ----------------------------- def _is_mistral_template(tokenizer) -> bool: tmpl = getattr(tokenizer, "chat_template", "") or "" name = getattr(getattr(tokenizer, "init_kwargs", {}), "get", lambda k, d=None: d)("name_or_path", "") return ("[INST]" in tmpl) or ("mistral" in str(name).lower()) or ("mistral" in tmpl.lower()) def _is_qwen3_tokenizer(tokenizer) -> bool: name = getattr(tokenizer, "name_or_path", "") or "" cls = tokenizer.__class__.__name__.lower() return ("qwen3" in name.lower()) or ("qwen3" in cls) def ensure_mistral_special_token(tokenizer, model=None): """Ensure the custom assistant-start token exists in tokenizer (and resize model embeddings if provided).""" if not _is_mistral_template(tokenizer): return False add_tok = [] cur = set(tokenizer.get_vocab().keys()) if MISTRAL_ASSIST_START not in cur: add_tok.append(MISTRAL_ASSIST_START) if add_tok: tokenizer.add_special_tokens({ "additional_special_tokens": tokenizer.special_tokens_map_extended.get("additional_special_tokens", []) + add_tok }) if model is not None: try: model.resize_token_embeddings(len(tokenizer)) except Exception: pass return True return False def build_prompt_parts(task: str, question: str, document: str) -> Tuple[str, str]: t = (task or "qa_doc").strip() q = (question or "").strip() d = (document or "").strip() if t == "summary": instruction = q if q else "Summarize the passage." doc_text = f"Passage:\n{d}" if d else "" return doc_text, instruction if t == "code": if not q: q = "Solve the following problem." return "", f"Write code to solve the following problem:\n{q}" if t == "math": mq = (q if q else d).strip() if mq: mq = f"{mq}\n{MATH_STEP_BY_STEP_BOXED_INSTRUCTION}" else: mq = MATH_STEP_BY_STEP_BOXED_INSTRUCTION return "", mq doc_text = f"Document:\n{d}" if d else "" query_text = f"Question: {q}\n{BOXED_ANSWER_INSTRUCTION}" if q else "" return doc_text, query_text def build_user_content(task: str, question: str, document: str) -> str: doc_text, query_text = build_prompt_parts(task, question, document) if doc_text and query_text: return f"{doc_text}\n\n{query_text}" return doc_text or query_text def build_messages(system: str, document: str, question: str, include_query: bool = True, task: str = "qa_doc"): q = question if include_query else "" user = build_user_content(task, q, document) if not user: user = (q or document or "") return [{"role": "system", "content": system}, {"role": "user", "content": user}] def apply_chat_template(tokenizer, messages, add_generation_prompt: bool): """Render chat with robust fallback across templates. - Prefer tokenizer.apply_chat_template(..., add_generation_prompt=...) - If that signature is unsupported, detect template style: * Llama-3 style → append assistant header tokens * Mistral/INST style → no explicit assistant header to append * Unknown → do not append anything """ force_thinking = getattr(tokenizer, "_force_enable_thinking", None) if force_thinking is None and _is_qwen3_tokenizer(tokenizer): force_thinking = False try: if force_thinking is None: return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=add_generation_prompt ) return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=add_generation_prompt, enable_thinking=bool(force_thinking), ) except TypeError: tmpl = getattr(tokenizer, "chat_template", "") or "" try: if force_thinking is None: s = tokenizer.apply_chat_template(messages, tokenize=False) else: s = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=bool(force_thinking)) except TypeError: s = tokenizer.apply_chat_template(messages, tokenize=False) if add_generation_prompt: if "<|start_header_id|>" in tmpl: # Llama 3 style s += "<|start_header_id|>assistant<|end_header_id|>\n\n" elif "[INST]" in tmpl or "" in tmpl: # Mistral style: no explicit header s += "" else: # Unknown template → safest is to append nothing s += "" return s def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False): s = apply_chat_template(tokenizer, messages, add_generation_prompt) ids = tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device) # If Mistral template and generation prompt requested, append our assistant-start header token # if add_generation_prompt and _is_mistral_template(tokenizer): # try: # tok_id = tokenizer.convert_tokens_to_ids(MISTRAL_ASSIST_START) # if tok_id is not None and tok_id != tokenizer.unk_token_id: # extra = torch.tensor([[int(tok_id)]], device=ids.device, dtype=ids.dtype) # ids = torch.cat([ids, extra], dim=1) # except Exception: # pass return ids def _build_llopa_inputs(tokenizer, system: str, document: str, question: str, *, task: str = "qa_doc", device: str): doc_text, query_text = build_prompt_parts(task, question, document) msgs = build_messages(system, document, question, include_query=True, task=task) ids_sys = tokens_from_messages( tokenizer, [{"role": "system", "content": system}], device, add_generation_prompt=False ) ids_hdr = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=True) hdr_tail = _assistant_header_ids(tokenizer, device) ids_sys_user = None if hdr_tail is not None and hdr_tail.numel() > 0: tail_len = hdr_tail.size(1) if ids_hdr.size(1) >= tail_len and torch.equal(ids_hdr[:, -tail_len:], hdr_tail): ids_sys_user = ids_hdr[:, :-tail_len] if ids_sys_user is None: ids_sys_user = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False) hdr_tail = ids_hdr[:, ids_sys_user.size(1):] if ids_sys_user.size(1) < ids_sys.size(1): raise ValueError("System-only tokens longer than system+user tokens.") user_ids_full = ids_sys_user[:, ids_sys.size(1):] user_content_full = msgs[-1]["content"] if query_text: doc_prefix_text = f"{doc_text}\n\n" if doc_text else "" else: doc_prefix_text = user_content_full user_doc_ids = user_ids_full[:, 0:0] user_q_ids = user_ids_full[:, 0:0] if user_ids_full.size(1) > 0: msgs_empty = [{"role": "system", "content": system}, {"role": "user", "content": ""}] ids_empty = tokens_from_messages(tokenizer, msgs_empty, device, add_generation_prompt=False) header_prefix_len = lcp_len(ids_empty, ids_sys_user) user_header_len = max(0, header_prefix_len - ids_sys.size(1)) doc_prefix_len = 0 if doc_prefix_text: doc_prefix_len = len(tokenizer(doc_prefix_text, add_special_tokens=False).input_ids) doc_end = min(user_ids_full.size(1), user_header_len + doc_prefix_len) user_doc_ids = user_ids_full[:, :doc_end] user_q_ids = user_ids_full[:, doc_end:] return { "prompt_ids": ids_hdr, "system_ids": ids_sys, "system_user_ids": ids_sys_user, "user_ids_full": user_ids_full, "user_doc_ids": user_doc_ids, "user_q_ids": user_q_ids, "hdr_tail": hdr_tail, } def build_messages_for_llopa(tokenizer, system: str, document: str, question: str, *, task: str = "qa_doc", device: str): return _build_llopa_inputs( tokenizer, system=system, document=document, question=question, task=task, device=device, ) def _normalize_prompt_messages(messages): out = [] if not isinstance(messages, list): return out for msg in messages: if not isinstance(msg, dict): continue role = str(msg.get("role") or "user").strip().lower() if role not in {"system", "user", "assistant"}: role = "user" content = str(msg.get("content") or "") if role != "assistant": content = content.strip() if not content: continue out.append({"role": role, "content": content}) return out def _assistant_header_starts_from_messages( tokenizer, prompt_messages, *, prompt_add_generation_prompt: bool, device, ): msgs = _normalize_prompt_messages(prompt_messages) if not msgs: return None, None starts: list[int] = [] for idx, msg in enumerate(msgs): if msg["role"] != "assistant": continue prefix_ids = tokens_from_messages(tokenizer, msgs[:idx], device, add_generation_prompt=False) starts.append(int(prefix_ids.size(1))) if bool(prompt_add_generation_prompt) and msgs[-1]["role"] != "assistant": prefix_ids = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False) starts.append(int(prefix_ids.size(1))) if not starts: return None, None starts_tensor = torch.tensor([starts], device=device, dtype=torch.long) start_mask = torch.ones((1, len(starts)), device=device, dtype=torch.bool) return starts_tensor, start_mask def _assistant_turn_boundaries_from_messages( tokenizer, prompt_messages, *, prompt_add_generation_prompt: bool, device, ): msgs = _normalize_prompt_messages(prompt_messages) if not msgs: return None, None, None starts: list[int] = [] ends: list[int] = [] for idx, msg in enumerate(msgs): if msg["role"] != "assistant": continue prefix_ids = tokens_from_messages(tokenizer, msgs[:idx], device, add_generation_prompt=False) turn_ids = tokens_from_messages(tokenizer, msgs[: idx + 1], device, add_generation_prompt=False) starts.append(int(prefix_ids.size(1))) ends.append(int(turn_ids.size(1))) if bool(prompt_add_generation_prompt) and msgs[-1]["role"] != "assistant": prefix_ids = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False) prompt_ids = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=True) starts.append(int(prefix_ids.size(1))) ends.append(int(prompt_ids.size(1))) if not starts: return None, None, None starts_tensor = torch.tensor([starts], device=device, dtype=torch.long) ends_tensor = torch.tensor([ends], device=device, dtype=torch.long) start_mask = torch.ones((1, len(starts)), device=device, dtype=torch.bool) return starts_tensor, ends_tensor, start_mask def _assistant_content_delta_from_messages(tokenizer, prefix_messages, assistant_text: str, su_gen, device): msgs_ass = list(prefix_messages) + [{"role": "assistant", "content": assistant_text}] full_ids = tokens_from_messages(tokenizer, msgs_ass, device, add_generation_prompt=False) if full_ids.size(1) <= su_gen.size(1): return full_ids[:, :0] return full_ids[:, su_gen.size(1):] def _strip_trailing_assistant_stop_tokens(tokenizer, token_ids: torch.Tensor) -> torch.Tensor: if not isinstance(token_ids, torch.Tensor) or token_ids.numel() == 0: return token_ids stop_ids = set() eos_id = getattr(tokenizer, "eos_token_id", None) if eos_id is not None: stop_ids.add(int(eos_id)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) if not stop_ids: return token_ids trimmed = token_ids while trimmed.size(1) > 0 and int(trimmed[0, -1].item()) in stop_ids: trimmed = trimmed[:, :-1] return trimmed def _resolve_user_replay_layout_from_messages( tokenizer, prefix_messages, *, system_len: int, user_len: int, device, ): system_len = max(int(system_len), 0) user_len = max(int(user_len), 0) if user_len <= 0: return 0, 0, 0 user_indices = [idx for idx, msg in enumerate(prefix_messages) if msg["role"] == "user"] if not user_indices: return 0, 0, user_len def _token_len(msgs) -> int: try: return int(tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False).size(1)) except Exception: return 0 first_user_idx = int(user_indices[0]) latest_user_idx = int(user_indices[-1]) prefix_before_first_user_len = _token_len(prefix_messages[:first_user_idx]) prefix_before_latest_user_len = _token_len(prefix_messages[:latest_user_idx]) prefix_through_latest_user_len = _token_len(prefix_messages[: latest_user_idx + 1]) user_prefix_keep_len = max(prefix_before_first_user_len - system_len, 0) user_prefix_keep_len = min(user_prefix_keep_len, user_len) latest_user_start = max(prefix_before_latest_user_len - system_len, user_prefix_keep_len) latest_user_start = min(latest_user_start, user_len) latest_user_end = max(prefix_through_latest_user_len - system_len, latest_user_start) latest_user_end = min(latest_user_end, user_len) latest_user_len = max(latest_user_end - latest_user_start, 0) if latest_user_len <= 0 and user_len > user_prefix_keep_len: latest_user_start = int(user_prefix_keep_len) latest_user_len = int(user_len - user_prefix_keep_len) return int(user_prefix_keep_len), int(latest_user_start), int(latest_user_len) def _build_structured_prompt_segments(tokenizer, prompt_messages, *, prompt_add_generation_prompt: bool, device): msgs = _normalize_prompt_messages(prompt_messages) if not msgs: raise ValueError("prompt_messages must contain at least one non-empty message.") if bool(prompt_add_generation_prompt): if msgs[-1]["role"] == "assistant": raise ValueError("prompt_add_generation_prompt=True requires prompt_messages to end with a non-assistant role.") prefix_messages = msgs prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False) prompt_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True) assistant_prefill_ids = prompt_ids[:, prefix_ids.size(1):] else: if msgs[-1]["role"] != "assistant": raise ValueError("prompt_add_generation_prompt=False requires prompt_messages to end with an assistant prefix.") if len(msgs) < 2: raise ValueError("assistant-prefix prompts require at least one preceding non-assistant message.") prefix_messages = msgs[:-1] assistant_text = str(msgs[-1]["content"] or "") prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False) prompt_prefix = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True) assistant_content_ids = _assistant_content_delta_from_messages( tokenizer, prefix_messages, assistant_text, prompt_prefix, device, ) assistant_content_ids = _strip_trailing_assistant_stop_tokens(tokenizer, assistant_content_ids) assistant_header_ids = prompt_prefix[:, prefix_ids.size(1):] assistant_prefill_ids = torch.cat([assistant_header_ids, assistant_content_ids], dim=1) prompt_ids = torch.cat([prefix_ids, assistant_prefill_ids], dim=1) if assistant_prefill_ids.size(1) <= 0: raise ValueError("Structured direct LLoPA prompt produced an empty assistant prefill segment.") if prefix_messages and prefix_messages[0]["role"] == "system": system_ids = tokens_from_messages(tokenizer, [prefix_messages[0]], device, add_generation_prompt=False) user_ids = prefix_ids[:, system_ids.size(1):] else: system_ids = prefix_ids[:, :0] user_ids = prefix_ids replay_user_prefix_keep_len, replay_user_start, replay_user_len = _resolve_user_replay_layout_from_messages( tokenizer, prefix_messages, system_len=int(system_ids.size(1)), user_len=int(user_ids.size(1)), device=device, ) assistant_header_starts, assistant_turn_ends, assistant_header_start_mask = _assistant_turn_boundaries_from_messages( tokenizer, msgs, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) return { "prefix_ids": prefix_ids, "prompt_ids": prompt_ids, "system_ids": system_ids, "user_ids": user_ids, "assistant_prefill_ids": assistant_prefill_ids, "replay_user_prefix_keep_len": int(replay_user_prefix_keep_len), "replay_user_start": int(replay_user_start), "replay_user_len": int(replay_user_len), "assistant_header_starts": assistant_header_starts, "assistant_turn_ends": assistant_turn_ends, "assistant_header_start_mask": assistant_header_start_mask, } def _build_unified_prefill_lower_prompt_bundle( tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, structured_prompt_segments=None, device, ): segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None if segments is None: segments = _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) prompt_ids = segments["prompt_ids"] system_ids = segments["system_ids"] prefix_ids = segments.get("prefix_ids") header_starts = segments.get("assistant_header_starts") assistant_turn_ends = segments.get("assistant_turn_ends") header_start_mask = segments.get("assistant_header_start_mask") replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0) replay_user_start = int(segments.get("replay_user_start", 0) or 0) replay_user_len = int(segments.get("replay_user_len", 0) or 0) assistant_header_start: Optional[int] = None if ( isinstance(header_starts, torch.Tensor) and header_starts.ndim == 2 and header_starts.size(0) == 1 and header_starts.numel() > 0 ): valid_mask = header_start_mask if not isinstance(valid_mask, torch.Tensor) or valid_mask.shape != header_starts.shape: valid_mask = header_starts >= 0 valid_starts = header_starts[0][valid_mask[0]] if valid_starts.numel() > 0: assistant_header_start = int(valid_starts[-1].item()) if isinstance(prefix_ids, torch.Tensor) and prefix_ids.ndim == 2 and prefix_ids.size(0) == 1: assistant_header_start = int(prefix_ids.size(1)) if assistant_header_start is None: header_ids = _assistant_header_ids(tokenizer, device=device) if isinstance(header_ids, torch.Tensor) and header_ids.ndim == 2 and header_ids.size(0) == 1: assistant_header_start = _find_last_subsequence_start(prompt_ids, header_ids) if assistant_header_start is None: assistant_header_start = max(int(prompt_ids.size(1) - 1), 0) return { "segments": segments, "prompt_ids": prompt_ids, "attention_mask": torch.ones_like(prompt_ids, device=prompt_ids.device), "assistant_header_start": int(assistant_header_start), "prefill_lower_split_start": torch.tensor( [int(assistant_header_start)], device=prompt_ids.device, dtype=torch.long, ), "assistant_header_starts": ( header_starts.to(device=prompt_ids.device, dtype=torch.long) if isinstance(header_starts, torch.Tensor) and header_starts.numel() > 0 else torch.tensor( [[int(assistant_header_start)]], device=prompt_ids.device, dtype=torch.long, ) ), "assistant_turn_ends": ( assistant_turn_ends.to(device=prompt_ids.device, dtype=torch.long) if isinstance(assistant_turn_ends, torch.Tensor) and assistant_turn_ends.numel() > 0 else torch.tensor( [[int(prompt_ids.size(1))]], device=prompt_ids.device, dtype=torch.long, ) ), "assistant_header_start_mask": ( header_start_mask.to(device=prompt_ids.device, dtype=torch.bool) if isinstance(header_start_mask, torch.Tensor) and header_start_mask.numel() > 0 else torch.ones((1, 1), device=prompt_ids.device, dtype=torch.bool) ), "prefill_lower_system_len": torch.tensor( [int(system_ids.size(1))], device=prompt_ids.device, dtype=torch.long, ), "prefill_lower_replay_user_prefix_keep_len": torch.tensor( [int(replay_user_prefix_keep_len)], device=prompt_ids.device, dtype=torch.long, ), "prefill_lower_replay_user_start": torch.tensor( [int(replay_user_start)], device=prompt_ids.device, dtype=torch.long, ), "prefill_lower_replay_user_len": torch.tensor( [int(replay_user_len)], device=prompt_ids.device, dtype=torch.long, ), } def _prompt_bundle_has_past_assistant_history(prompt_bundle) -> bool: if not isinstance(prompt_bundle, dict): return False header_starts = prompt_bundle.get("assistant_header_starts") turn_ends = prompt_bundle.get("assistant_turn_ends") if not isinstance(header_starts, torch.Tensor) or not isinstance(turn_ends, torch.Tensor): return False if header_starts.numel() == 0 or turn_ends.numel() == 0: return False if header_starts.dim() == 1: header_starts = header_starts.view(1, -1) if turn_ends.dim() == 1: turn_ends = turn_ends.view(1, -1) if header_starts.dim() != 2 or turn_ends.dim() != 2: return False header_mask = prompt_bundle.get("assistant_header_start_mask") if isinstance(header_mask, torch.Tensor) and header_mask.numel() > 0: if header_mask.dim() == 1: header_mask = header_mask.view(1, -1) if header_mask.dim() != 2 or header_mask.shape != header_starts.shape: header_mask = header_starts >= 0 else: header_mask = header_mask.to(device=header_starts.device, dtype=torch.bool) else: header_mask = header_starts >= 0 split_starts = prompt_bundle.get("effective_prefill_lower_split_start") if not isinstance(split_starts, torch.Tensor) or split_starts.numel() == 0: split_starts = prompt_bundle.get("prefill_lower_split_start") if isinstance(split_starts, torch.Tensor) and split_starts.numel() > 0: split_starts = split_starts.flatten().to(device=header_starts.device, dtype=torch.long) else: split_starts = torch.tensor( [int(prompt_bundle.get("assistant_header_start", 0) or 0)], device=header_starts.device, dtype=torch.long, ) rows = min(int(header_starts.size(0)), int(turn_ends.size(0))) cols = min(int(header_starts.size(1)), int(turn_ends.size(1))) if rows <= 0 or cols <= 0: return False for row in range(rows): split_idx = min(row, int(split_starts.numel()) - 1) split_start = int(split_starts[split_idx].item()) for col in range(cols): if not bool(header_mask[row, col].item()): continue turn_start = int(header_starts[row, col].item()) turn_end = int(turn_ends[row, col].item()) if turn_end <= turn_start or turn_start >= split_start: continue if min(turn_end, split_start) > turn_start: return True return False def _direct_prefill_lower_cache_and_logits( model, *, prompt_bundle, lower_k: int, prefill_attn: str, system_prefill: str, no_upper_attn: bool, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "auto", ): if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module = _normalize_replay_module_value(replay_module) replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) seed_mode = _normalize_structured_llopa_seed_mode(seed_mode) if isinstance(prompt_bundle, dict): prompt_ids = prompt_bundle.get("prompt_ids") if isinstance(prompt_ids, torch.Tensor): attention_mask = prompt_bundle.get("attention_mask") if not isinstance(attention_mask, torch.Tensor) or attention_mask.shape != prompt_ids.shape: attention_mask = torch.ones_like(prompt_ids, device=prompt_ids.device, dtype=torch.long) else: attention_mask = attention_mask.to(device=prompt_ids.device, dtype=torch.long) effective_prompt_ids = prompt_ids effective_prompt_attention_mask = attention_mask split_starts = prompt_bundle.get("prefill_lower_split_start") if isinstance(split_starts, torch.Tensor): split_starts = split_starts.to(device=prompt_ids.device, dtype=torch.long) else: split_starts = torch.tensor( [int(prompt_bundle.get("assistant_header_start", max(int(prompt_ids.size(1) - 1), 0)))], device=prompt_ids.device, dtype=torch.long, ) valid_len = ( int(effective_prompt_attention_mask[0].sum().item()) if effective_prompt_attention_mask.ndim == 2 and effective_prompt_attention_mask.size(0) == 1 else int(effective_prompt_ids.size(1)) ) if valid_len > 0: effective_prompt_ids = effective_prompt_ids[:, :valid_len] effective_prompt_attention_mask = effective_prompt_attention_mask[:, :valid_len] prompt_bundle["effective_prompt_ids"] = effective_prompt_ids prompt_bundle["effective_prompt_attention_mask"] = effective_prompt_attention_mask prompt_bundle["effective_prefill_lower_split_start"] = split_starts segments = prompt_bundle.get("segments") if isinstance(prompt_bundle, dict) else None if not isinstance(segments, dict): return None if not bool(see_past_assistant): matched_inband_seed = _matched_inband_prefill_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, no_upper_attn=no_upper_attn, ) if matched_inband_seed is not None: return matched_inband_seed llopa_full_prompt_seed_fn = _get_llopa_full_prompt_seed(model) effective_prompt_ids = prompt_bundle.get("effective_prompt_ids") effective_prompt_attention_mask = prompt_bundle.get("effective_prompt_attention_mask") split_starts = prompt_bundle.get("effective_prefill_lower_split_start") system_lens = prompt_bundle.get("prefill_lower_system_len") replay_user_prefix_keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len") replay_user_starts = prompt_bundle.get("prefill_lower_replay_user_start") replay_user_lens = prompt_bundle.get("prefill_lower_replay_user_len") assistant_header_starts = prompt_bundle.get("assistant_header_starts") assistant_turn_ends = prompt_bundle.get("assistant_turn_ends") assistant_header_start_mask = prompt_bundle.get("assistant_header_start_mask") needs_past_assistant_seed = bool(see_past_assistant) and _prompt_bundle_has_past_assistant_history(prompt_bundle) should_try_full_prompt_seed = seed_mode == "auto" or bool(needs_past_assistant_seed) full_prompt_seed_error = None if should_try_full_prompt_seed and callable(llopa_full_prompt_seed_fn) and isinstance(effective_prompt_ids, torch.Tensor): if not isinstance(effective_prompt_attention_mask, torch.Tensor): effective_prompt_attention_mask = torch.ones_like( effective_prompt_ids, device=effective_prompt_ids.device, dtype=torch.long, ) if not isinstance(split_starts, torch.Tensor): split_starts = prompt_bundle.get("prefill_lower_split_start") if not isinstance(system_lens, torch.Tensor): system_lens = torch.zeros( (effective_prompt_ids.size(0),), device=effective_prompt_ids.device, dtype=torch.long, ) try: full_prompt_seed = llopa_full_prompt_seed_fn( input_ids=effective_prompt_ids, attention_mask=effective_prompt_attention_mask, use_cache=True, logits_to_keep=1, lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), prefill_lower_split_start=split_starts, prefill_lower_system_len=system_lens, prefill_lower_replay_user_prefix_keep_len=replay_user_prefix_keep_lens, prefill_lower_replay_user_start=replay_user_starts, prefill_lower_replay_user_len=replay_user_lens, assistant_header_starts=assistant_header_starts, assistant_turn_ends=assistant_turn_ends, assistant_header_start_mask=assistant_header_start_mask, prefill_lower_see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), ) if full_prompt_seed is not None: return full_prompt_seed except Exception as exc: full_prompt_seed_error = exc if bool(needs_past_assistant_seed): if not callable(llopa_full_prompt_seed_fn): raise RuntimeError( "LLOPA_SEE_PAST_ASSISTANT=1 requires llopa_full_prompt_prefill_seed " "for prefill_header prompts with previous assistant turns." ) if full_prompt_seed_error is not None: raise RuntimeError( "LLOPA_SEE_PAST_ASSISTANT=1 failed in llopa_full_prompt_prefill_seed " "for a prefill_header prompt with previous assistant turns." ) from full_prompt_seed_error raise RuntimeError( "LLOPA_SEE_PAST_ASSISTANT=1 could not build a prefill_header seed " "that includes previous assistant turns." ) seed_fn = getattr(model, "llopa_reference_prefill_seed", None) if not callable(seed_fn): return None try: return seed_fn( system_ids=segments["system_ids"], user_ids=segments["user_ids"], assistant_ids=segments["assistant_prefill_ids"], lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), replay_user_prefix_keep_len=int(segments.get("replay_user_prefix_keep_len", 0) or 0), replay_user_start=int(segments.get("replay_user_start", 0) or 0), replay_user_len=int(segments.get("replay_user_len", 0) or 0), ) except Exception: return None def _optimized_prefill_lower_cache_and_logits( model, *, prompt_bundle, lower_k: int, prefill_attn: str, system_prefill: str, no_upper_attn: bool, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "auto", ): if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module = _normalize_replay_module_value(replay_module) replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) seed_mode = _normalize_optimized_llopa_seed_mode(seed_mode) if seed_mode in {"matched", "auto"}: matched_seed = _matched_inband_prefill_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, no_upper_attn=no_upper_attn, replay_module=replay_module, replay_per_layers=replay_per_layers, ) if matched_seed is not None: return matched_seed llopa_full_prompt_seed_fn = _get_llopa_full_prompt_seed(model) if seed_mode in {"tri", "auto"} and callable(llopa_full_prompt_seed_fn) and isinstance(prompt_bundle, dict): effective_prompt_ids = prompt_bundle.get("effective_prompt_ids") effective_prompt_attention_mask = prompt_bundle.get("effective_prompt_attention_mask") split_starts = prompt_bundle.get("effective_prefill_lower_split_start") system_lens = prompt_bundle.get("prefill_lower_system_len") replay_user_prefix_keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len") replay_user_starts = prompt_bundle.get("prefill_lower_replay_user_start") replay_user_lens = prompt_bundle.get("prefill_lower_replay_user_len") assistant_header_starts = prompt_bundle.get("assistant_header_starts") assistant_header_start_mask = prompt_bundle.get("assistant_header_start_mask") if isinstance(effective_prompt_ids, torch.Tensor): if not isinstance(effective_prompt_attention_mask, torch.Tensor): effective_prompt_attention_mask = torch.ones_like( effective_prompt_ids, device=effective_prompt_ids.device, dtype=torch.long, ) if not isinstance(split_starts, torch.Tensor): split_starts = prompt_bundle.get("prefill_lower_split_start") if not isinstance(split_starts, torch.Tensor): split_starts = torch.tensor( [int(prompt_bundle.get("assistant_header_start", max(int(effective_prompt_ids.size(1) - 1), 0)))], device=effective_prompt_ids.device, dtype=torch.long, ) if not isinstance(system_lens, torch.Tensor): system_lens = torch.zeros( (effective_prompt_ids.size(0),), device=effective_prompt_ids.device, dtype=torch.long, ) if not isinstance(replay_user_prefix_keep_lens, torch.Tensor): replay_user_prefix_keep_lens = torch.zeros( (effective_prompt_ids.size(0),), device=effective_prompt_ids.device, dtype=torch.long, ) if not isinstance(replay_user_starts, torch.Tensor): replay_user_starts = replay_user_prefix_keep_lens.clone() if not isinstance(replay_user_lens, torch.Tensor): replay_user_lens = torch.zeros( (effective_prompt_ids.size(0),), device=effective_prompt_ids.device, dtype=torch.long, ) try: full_prompt_seed = llopa_full_prompt_seed_fn( input_ids=effective_prompt_ids, attention_mask=effective_prompt_attention_mask, use_cache=True, logits_to_keep=1, lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), prefill_lower_split_start=split_starts, prefill_lower_system_len=system_lens, prefill_lower_replay_user_prefix_keep_len=replay_user_prefix_keep_lens, prefill_lower_replay_user_start=replay_user_starts, prefill_lower_replay_user_len=replay_user_lens, assistant_header_starts=assistant_header_starts, assistant_turn_ends=prompt_bundle.get("assistant_turn_ends"), assistant_header_start_mask=assistant_header_start_mask, prefill_lower_see_past_assistant=bool(see_past_assistant), ) if full_prompt_seed is not None: return full_prompt_seed except Exception: pass if seed_mode in {"stable", "auto", "matched"}: return _direct_prefill_lower_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, no_upper_attn=no_upper_attn, see_past_assistant=bool(see_past_assistant), replay_module=replay_module, replay_per_layers=replay_per_layers, ) return None def _normalize_structured_llopa_seed_mode(seed_mode: Optional[str]) -> str: normalized = str(seed_mode or "auto").strip().lower() aliases = { "": "auto", "default": "auto", "tri": "auto", "tri_auto": "auto", "reference": "prefill_header", "reference_only": "prefill_header", "prefill-header": "prefill_header", "prefill_header_seed": "prefill_header", "prefill_header_only": "prefill_header", } normalized = aliases.get(normalized, normalized) if normalized not in {"auto", "prefill_header"}: normalized = "auto" return normalized def _llopa_modeling_module(model) -> Optional[Any]: llopa_core = _get_llopa_core(model) if llopa_core is None: return None with contextlib.suppress(Exception): return inspect.getmodule(llopa_core.__class__) return None def _matched_inband_prefill_cache_and_logits( model, *, prompt_bundle, lower_k: int, prefill_attn: str, system_prefill: str, no_upper_attn: bool, ): if bool(no_upper_attn): return None llopa_core = _get_llopa_core(model) output_head = _get_output_head(model) llopa_mod = _llopa_modeling_module(model) if llopa_core is None or output_head is None or llopa_mod is None: return None fusion_mode = str( getattr(getattr(llopa_core, "config", None), "capsule_fusion_mode", "upper_only") or "upper_only" ).strip().lower() if fusion_mode != "inband": return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn != "causal": return None required_names = ( "_safe_dynamic_cache", "_tri_arange", "_llopa_position_ids_from_mask", "_resolve_attn_impl", "_can_use_implicit_causal_mask", "_llopa_mask_is_all_ones", "_build_tri_mask_local", "_tri_insert_suffix_specials_inband", "_tri_effective_suffix_special_token_ids", "_tri_build_prefill_lower_upper_index_batch", "_tri_pack_indexed_tensor", "create_causal_mask", ) if any(not hasattr(llopa_mod, name) for name in required_names): return None prompt_ids = prompt_bundle["prompt_ids"] attention_mask = prompt_bundle["attention_mask"].to(device=prompt_ids.device, dtype=torch.long) system_len = int(prompt_bundle["prefill_lower_system_len"][0].item()) if prompt_bundle["prefill_lower_system_len"].numel() > 0 else 0 full_ids = prompt_ids full_attention_mask = attention_mask split_starts = prompt_bundle.get("prefill_lower_split_start") if isinstance(split_starts, torch.Tensor): split_starts = split_starts.to(device=prompt_ids.device, dtype=torch.long) else: split_starts = torch.tensor( [int(prompt_bundle.get("assistant_header_start", max(int(prompt_ids.size(1) - 1), 0)))], device=prompt_ids.device, dtype=torch.long, ) token_ids = list(getattr(llopa_mod, "_tri_effective_suffix_special_token_ids")(model) or []) if token_ids: header_starts = prompt_bundle.get("assistant_header_starts") header_start_mask = prompt_bundle.get("assistant_header_start_mask") ( full_ids, full_attention_mask, _, remapped_split_starts, _, _, ) = getattr(llopa_mod, "_tri_insert_suffix_specials_inband")( token_ids=token_ids, input_ids=full_ids, attention_mask=full_attention_mask, labels=None, split_starts=split_starts, assistant_header_starts=header_starts, assistant_header_start_mask=header_start_mask, ) if isinstance(remapped_split_starts, torch.Tensor) and remapped_split_starts.numel() > 0: split_starts = remapped_split_starts.to(device=prompt_ids.device, dtype=torch.long) valid_len = int(full_attention_mask[0].sum().item()) if valid_len <= 0: return None full_ids = full_ids[:, :valid_len] full_attention_mask = full_attention_mask[:, :valid_len] split_start = max(0, min(int(split_starts[0].item()) if split_starts.numel() > 0 else int(valid_len - 1), valid_len)) try: lower_k = int(lower_k) except Exception: return None n_layers = len(getattr(llopa_core, "layers", [])) if lower_k <= 0 or n_layers <= 0: return None lower_k = max(0, min(lower_k, n_layers)) device = full_ids.device pkv = getattr(llopa_mod, "_safe_dynamic_cache")(llopa_core.config) inputs_embeds = llopa_core.embed_tokens(full_ids) cache_position = getattr(llopa_mod, "_tri_arange")(0, inputs_embeds.shape[1], device) position_ids = getattr(llopa_mod, "_llopa_position_ids_from_mask")(full_attention_mask) attn_impl = getattr(llopa_mod, "_resolve_attn_impl")(llopa_core.config) if attn_impl == "flash_attention_2": lower_mask = None if getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask) else full_attention_mask elif getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config) and getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask): lower_mask = None else: lower_mask = getattr(llopa_mod, "create_causal_mask")( config=llopa_core.config, input_embeds=inputs_embeds, attention_mask=full_attention_mask, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) hidden_states = inputs_embeds position_embeddings = llopa_core.rotary_emb(hidden_states, position_ids) for li in range(lower_k): layer = llopa_core.layers[li] hidden_states = layer( hidden_states, attention_mask=lower_mask, position_ids=position_ids, past_key_values=pkv, use_cache=True, cache_position=cache_position, position_embeddings=position_embeddings, ) split_starts = torch.tensor([split_start], device=hidden_states.device, dtype=torch.long) valid_lens = torch.tensor([valid_len], device=hidden_states.device, dtype=torch.long) system_lens = torch.tensor([system_len], device=hidden_states.device, dtype=torch.long) upper_gather_idx, upper_valid_mask, upper_lens = getattr(llopa_mod, "_tri_build_prefill_lower_upper_index_batch")( split_starts=split_starts, valid_lens=valid_lens, system_lens=system_lens, system_prefill=str(system_prefill), device=hidden_states.device, ) upper_hidden, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")( hidden_states, gather_idx=upper_gather_idx, valid_mask=upper_valid_mask, pad_value=0.0, ) upper_position_ids_src = position_ids.to(device=hidden_states.device, dtype=torch.long) upper_position_ids, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")( upper_position_ids_src, gather_idx=upper_gather_idx, valid_mask=upper_valid_mask, pad_value=0, ) upper_len = int(upper_lens[0].item()) if upper_lens.numel() > 0 else 0 if upper_len <= 0: return None upper_hidden = upper_hidden[:, :upper_len, :] upper_position_ids = upper_position_ids[:, :upper_len] upper_cache_position = upper_position_ids[0] if lower_k < n_layers: if attn_impl == "flash_attention_2" or getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config): upper_mask = None else: upper_mask = getattr(llopa_mod, "_build_tri_mask_local")( 1, upper_len, 0, upper_hidden.device, upper_hidden.dtype, ) upper_pos_emb = llopa_core.rotary_emb(upper_hidden, upper_position_ids) for li in range(lower_k, n_layers): layer = llopa_core.layers[li] upper_hidden = layer( upper_hidden, attention_mask=upper_mask, position_ids=upper_position_ids, past_key_values=pkv, use_cache=True, cache_position=upper_cache_position, position_embeddings=upper_pos_emb, ) upper_hidden = llopa_core.norm(upper_hidden) initial_logits = output_head(upper_hidden[:, -1:, :])[:, -1, :].to(torch.float32) sys_mode = str(system_prefill or "full").strip().lower() if sys_mode == "full": visible_prefix_len = system_len elif sys_mode == "no_system": visible_prefix_len = min(system_len, 1) else: visible_prefix_len = 0 return pkv, int(visible_prefix_len), max(int(split_start) - int(visible_prefix_len), 0), initial_logits def _coerce_llopa_inputs(tokenizer, system: str, document: str, question: str, *, task: str, device: str, input_ids): if not isinstance(input_ids, dict): return _build_llopa_inputs( tokenizer, system=system, document=document, question=question, task=task, device=device, ) def _get_first(mapping, *keys): for k in keys: if k in mapping and mapping[k] is not None: return mapping[k] return None ids_sys = _get_first(input_ids, "system_ids", "ids_sys") ids_sys_user = _get_first(input_ids, "system_user_ids", "ids_sys_user") user_doc_ids = _get_first(input_ids, "user_doc_ids") user_q_ids = _get_first(input_ids, "user_q_ids") hdr_tail = _get_first(input_ids, "hdr_tail") prompt_ids = _get_first(input_ids, "prompt_ids", "ids_hdr") if (ids_sys is None or ids_sys_user is None or user_doc_ids is None or user_q_ids is None or hdr_tail is None): built = _build_llopa_inputs( tokenizer, system=system, document=document, question=question, task=task, device=device, ) if ids_sys is None: ids_sys = built["system_ids"] if ids_sys_user is None: ids_sys_user = built["system_user_ids"] if user_doc_ids is None: user_doc_ids = built["user_doc_ids"] if user_q_ids is None: user_q_ids = built["user_q_ids"] if hdr_tail is None: hdr_tail = built["hdr_tail"] if prompt_ids is None: prompt_ids = built.get("prompt_ids") if prompt_ids is None and ids_sys_user is not None and hdr_tail is not None: try: prompt_ids = torch.cat([ids_sys_user, hdr_tail], dim=1) except Exception: prompt_ids = None return { "prompt_ids": prompt_ids, "system_ids": ids_sys, "system_user_ids": ids_sys_user, "user_doc_ids": user_doc_ids, "user_q_ids": user_q_ids, "hdr_tail": hdr_tail, } def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int: L = min(a.size(1), b.size(1)) eq = (a[0, :L] == b[0, :L]) nz = (~eq).nonzero(as_tuple=False) return int(nz[0, 0]) if nz.numel() else L def _assistant_header_ids_from_chat_template(tokenizer, device): """Infer assistant header ids by diffing chat-template renders.""" probe_messages_list = [ [{"role": "system", "content": "system"}, {"role": "user", "content": "user"}], [{"role": "user", "content": "user"}], ] for messages in probe_messages_list: try: rendered_no_prompt = apply_chat_template(tokenizer, messages, add_generation_prompt=False) rendered_with_prompt = apply_chat_template(tokenizer, messages, add_generation_prompt=True) except Exception: continue if not isinstance(rendered_no_prompt, str) or not isinstance(rendered_with_prompt, str): continue if not rendered_with_prompt or rendered_with_prompt == rendered_no_prompt: continue if rendered_with_prompt.startswith(rendered_no_prompt): suffix_text = rendered_with_prompt[len(rendered_no_prompt):] if suffix_text: try: ids = tokenizer( suffix_text, add_special_tokens=False, return_tensors="pt", ).input_ids.to(device) if ids.numel() > 0: return ids except Exception: pass try: ids_no_prompt = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False) ids_with_prompt = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=True) except Exception: continue if ids_with_prompt.numel() == 0 or ids_with_prompt.size(1) <= ids_no_prompt.size(1): continue prefix_len = lcp_len(ids_no_prompt, ids_with_prompt) if prefix_len < ids_with_prompt.size(1): delta_ids = ids_with_prompt[:, prefix_len:] if delta_ids.numel() > 0: return delta_ids return None def _assistant_header_ids(tokenizer, device): """Best-effort header ids appended by add_generation_prompt for common templates.""" if _is_mistral_template(tokenizer): try: tok_id = tokenizer.convert_tokens_to_ids(MISTRAL_ASSIST_START) if tok_id is None or tok_id == tokenizer.unk_token_id: return None return torch.tensor([[int(tok_id)]], device=device, dtype=torch.long) except Exception: return None tmpl = getattr(tokenizer, "chat_template", "") or "" if "<|start_header_id|>" in tmpl: header = "<|start_header_id|>assistant<|end_header_id|>\n\n" try: return tokenizer(header, add_special_tokens=False, return_tensors="pt").input_ids.to(device) except Exception: return None return _assistant_header_ids_from_chat_template(tokenizer, device) def split_system_user_ids(tokenizer, system_text: str, user_text: str, device): msgs_sys = [{"role": "system", "content": system_text}] msgs_sys_user = [{"role": "system", "content": system_text}, {"role": "user", "content": user_text}] ids_sys = tokens_from_messages(tokenizer, msgs_sys, device, add_generation_prompt=False) ids_sys_user = tokens_from_messages(tokenizer, msgs_sys_user, device, add_generation_prompt=False) if ids_sys_user.size(1) < ids_sys.size(1): raise ValueError("System-only tokens longer than system+user tokens.") user_ids = ids_sys_user[:, ids_sys.size(1):] return ids_sys, user_ids, ids_sys_user # ----------------------------- # DynamicCache helpers # ----------------------------- def pkv_len(pkv) -> int: if hasattr(pkv, "layers"): return len(pkv.layers) if hasattr(pkv, "key_cache"): return len(pkv.key_cache) return len(pkv) def pkv_get(pkv, idx: int): if hasattr(pkv, "layers"): layer = pkv.layers[idx] return layer.keys, layer.values if hasattr(pkv, "key_cache"): return pkv.key_cache[idx], pkv.value_cache[idx] return pkv[idx] def dc_from_subset(pkv_src, idxs: List[int]) -> DynamicCache: dc = DynamicCache() for li in idxs: k, v = pkv_get(pkv_src, li) dc.update(k, v, li) return dc def _safe_dynamic_cache(config=None) -> DynamicCache: try: return DynamicCache(config=config) except TypeError as exc: if "max_cache_len" in str(exc): return DynamicCache() raise def _get_inner_model(m): """Return the decoder backbone that owns `.layers` (robust across wrappers).""" # unwrap DDP/Accelerate if hasattr(m, "module"): m = m.module # unwrap PEFT try: from peft import PeftModel if isinstance(m, PeftModel): try: m = m.get_base_model() except Exception: m = getattr(m, "base_model", m) except Exception: pass for attr in ("model", "transformer", "backbone", "base_model", "language_model"): if hasattr(m, attr): cand = getattr(m, attr) if hasattr(cand, "layers") and isinstance(getattr(cand, "layers", None), nn.ModuleList): return cand if hasattr(cand, "decoder") and hasattr(cand.decoder, "layers") and isinstance(cand.decoder.layers, nn.ModuleList): return cand.decoder if hasattr(m, "layers") and isinstance(getattr(m, "layers", None), nn.ModuleList): return m for child in m.modules(): if child is m: continue if hasattr(child, "layers") and isinstance(getattr(child, "layers", None), nn.ModuleList): return child raise AttributeError("Could not locate inner base model with a .layers attribute") def _get_llopa_core(model): """Return the decoder object that owns the LLoPA prefill/cache hooks.""" inner = _get_inner_model(model) if hasattr(inner, "llopa_prefill_cache"): return inner if hasattr(inner, "model") and hasattr(inner.model, "llopa_prefill_cache"): return inner.model if hasattr(inner, "tri_build_caches"): return inner if hasattr(inner, "model") and hasattr(inner.model, "tri_build_caches"): return inner.model return None def _get_llopa_decode_step(model): """Return the cached LLoPA decode-step callable if present.""" for name in ("llopa_decode_step_logits", "tri_step_logits"): if hasattr(model, name): return getattr(model, name) try: from peft import PeftModel if isinstance(model, PeftModel): try: base = model.get_base_model() except Exception: base = getattr(model, "base_model", None) if base is not None: for name in ("llopa_decode_step_logits", "tri_step_logits"): if hasattr(base, name): return getattr(base, name) except Exception: pass return None def _get_llopa_full_prompt_seed(model): """Return the full-prompt seed callable used for past-assistant contexts.""" for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"): if hasattr(model, name): return getattr(model, name) try: from peft import PeftModel if isinstance(model, PeftModel): try: base = model.get_base_model() except Exception: base = getattr(model, "base_model", None) if base is not None: for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"): if hasattr(base, name): return getattr(base, name) except Exception: pass return None def _has_active_llopa_runtime(model) -> bool: """Return True when LLoPA hooks are present on the loaded model.""" return _get_llopa_core(model) is not None and _get_llopa_decode_step(model) is not None def _round_up_to_multiple(value: int, multiple: int) -> int: value = int(value) multiple = int(multiple) if multiple <= 0: return value return ((value + multiple - 1) // multiple) * multiple _OPTIMIZED_LLOPA_VARIANT_PRESETS = { "baseline": { "seed_mode": "auto", "upper_prepare_mode": "exact", "upper_bucket_multiple": 0, "seq_bucket_multiple": 256, }, "upper_ws_auto": { "seed_mode": "auto", "upper_prepare_mode": "bucketed_workspace", "upper_bucket_multiple": 256, "seq_bucket_multiple": 256, }, "upper_ws_auto_128": { "seed_mode": "auto", "upper_prepare_mode": "bucketed_workspace", "upper_bucket_multiple": 128, "seq_bucket_multiple": 128, }, "upper_ws_tri": { "seed_mode": "tri", "upper_prepare_mode": "bucketed_workspace", "upper_bucket_multiple": 256, "seq_bucket_multiple": 256, }, "upper_ws_stable": { "seed_mode": "stable", "upper_prepare_mode": "bucketed_workspace", "upper_bucket_multiple": 256, "seq_bucket_multiple": 256, }, "upper_ws_matched": { "seed_mode": "matched", "upper_prepare_mode": "bucketed_workspace", "upper_bucket_multiple": 256, "seq_bucket_multiple": 256, }, } def _normalize_optimized_llopa_seed_mode(seed_mode: Optional[str]) -> str: raw = str(seed_mode or "auto").strip().lower() if raw in {"", "default"}: raw = "auto" if raw not in {"auto", "tri", "stable", "matched"}: raw = "auto" return raw def _normalize_optimized_llopa_upper_prepare_mode(mode: Optional[str]) -> str: raw = str(mode or "exact").strip().lower() if raw in {"", "default"}: raw = "exact" if raw not in {"exact", "bucketed_workspace"}: raw = "exact" return raw def _resolve_optimized_llopa_settings( *, variant: Optional[str], seed_mode: Optional[str], upper_prepare_mode: Optional[str], upper_bucket_multiple: Optional[int], seq_bucket_multiple: Optional[int], ): preset_name = str(variant or "upper_ws_auto").strip().lower() if preset_name in {"", "default", "auto"}: preset_name = "upper_ws_auto" preset = _OPTIMIZED_LLOPA_VARIANT_PRESETS.get( preset_name, _OPTIMIZED_LLOPA_VARIANT_PRESETS["upper_ws_auto"], ) resolved_seed_mode = _normalize_optimized_llopa_seed_mode( seed_mode if seed_mode is not None else preset.get("seed_mode") ) resolved_upper_prepare_mode = _normalize_optimized_llopa_upper_prepare_mode( upper_prepare_mode if upper_prepare_mode is not None else preset.get("upper_prepare_mode") ) resolved_upper_bucket_multiple = ( int(upper_bucket_multiple) if upper_bucket_multiple is not None else int(preset.get("upper_bucket_multiple", 0) or 0) ) resolved_seq_bucket_multiple = ( int(seq_bucket_multiple) if seq_bucket_multiple is not None else int(preset.get("seq_bucket_multiple", 256) or 256) ) if resolved_upper_prepare_mode != "bucketed_workspace": resolved_upper_bucket_multiple = 0 if resolved_upper_bucket_multiple < 0: resolved_upper_bucket_multiple = 0 if resolved_seq_bucket_multiple <= 0: resolved_seq_bucket_multiple = 256 return { "variant": preset_name, "seed_mode": resolved_seed_mode, "upper_prepare_mode": resolved_upper_prepare_mode, "upper_bucket_multiple": int(resolved_upper_bucket_multiple), "seq_bucket_multiple": int(resolved_seq_bucket_multiple), } @contextlib.contextmanager def _temporary_model_attrs(model, **updates): sentinel = object() prior = {} try: for key, value in updates.items(): prior[key] = getattr(model, key, sentinel) setattr(model, key, value) yield finally: for key, old_value in prior.items(): if old_value is sentinel: with contextlib.suppress(Exception): delattr(model, key) else: with contextlib.suppress(Exception): setattr(model, key, old_value) def _acquire_bucketed_sequence_workspace( model, *, reference_ids: torch.Tensor, batch_size: int, total_len: int, bucket_multiple: int, ): bucket_total_len = _round_up_to_multiple(total_len, bucket_multiple) if bucket_total_len <= 0: bucket_total_len = int(total_len) dtype = reference_ids.dtype device = reference_ids.device key = (str(device), str(dtype), int(batch_size), int(bucket_total_len)) cache = getattr(model, "_optimized_llopa_sequence_workspace_cache", None) if not isinstance(cache, dict): cache = {} workspace = cache.get(key) if ( not isinstance(workspace, torch.Tensor) or workspace.device != device or workspace.dtype != dtype or workspace.shape != (int(batch_size), int(bucket_total_len)) ): workspace = torch.empty((int(batch_size), int(bucket_total_len)), dtype=dtype, device=device) cache[key] = workspace try: setattr(model, "_optimized_llopa_sequence_workspace_cache", cache) except Exception: pass return workspace, int(bucket_total_len) def _kv_meta_from_model(model_like): """Return (num_kv_heads, head_dim, dtype).""" try: cfg = getattr(model_like, "config", None) or getattr(_get_inner_model(model_like), "config", None) except Exception: cfg = getattr(_get_inner_model(model_like), "config", None) num_heads = getattr(cfg, "num_attention_heads", None) num_kv = getattr(cfg, "num_key_value_heads", None) or num_heads hidden = getattr(cfg, "hidden_size", None) head_dim = (hidden // num_heads) if (hidden and num_heads) else None try: dtype = next(_get_inner_model(model_like).parameters()).dtype except Exception: dtype = torch.float16 if torch.cuda.is_available() else torch.float32 return int(num_kv), int(head_dim), dtype def _make_empty_kv(batch: int, num_kv: int, head_dim: int, device, dtype): shape = (batch, num_kv, 0, head_dim) k = torch.empty(shape, device=device, dtype=dtype) v = torch.empty(shape, device=device, dtype=dtype) return k.contiguous(), v.contiguous() # ----------------------------- # LLOPA helpers (encapsulated) # ----------------------------- def _llopa_split_system(system_ids: torch.Tensor, system_prefill: str): """Return (system_upper, system_lower_extra) based on system_prefill.""" sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" if sys_prefill == "full": return system_ids, system_ids[:, :0] if sys_prefill == "no_system": if system_ids.size(1) < 1: return system_ids[:, :0], system_ids[:, :0] return system_ids[:, :1], system_ids[:, 1:] # no_bos_system return system_ids[:, :0], system_ids def _llopa_merge_replay_user_span( system_ids: torch.Tensor, *, system_prefill: str, replay_user_prefix_keep_len: int, replay_user_start: Optional[int], replay_user_len: Optional[int], ): _, sys_lower_extra = _llopa_split_system(system_ids, system_prefill) merged_prefix_keep_len = int(sys_lower_extra.size(1)) + max(int(replay_user_prefix_keep_len or 0), 0) merged_user_start = None if replay_user_start is not None: merged_user_start = int(sys_lower_extra.size(1)) + max(int(replay_user_start), 0) merged_user_len = None if replay_user_len is None else max(int(replay_user_len), 0) return int(merged_prefix_keep_len), merged_user_start, merged_user_len def _llopa_merge_user(system_ids: torch.Tensor, user_ids: torch.Tensor, system_prefill: str): sys_upper, sys_lower_extra = _llopa_split_system(system_ids, system_prefill) if sys_lower_extra.numel() == 0: return sys_upper, user_ids if user_ids.numel() == 0: return sys_upper, sys_lower_extra return sys_upper, torch.cat([sys_lower_extra, user_ids], dim=1) def _llopa_prefill_cache(llopa_core, system_ids: torch.Tensor, user_ids: torch.Tensor, assistant_ids: torch.Tensor, *, lower_k: int, prefill_mode: str, prefill_attn: str, system_prefill: str, return_last_assistant_hidden: bool = False, replay_user_prefix_keep_len: int = 0, replay_user_start: Optional[int] = None, replay_user_len: Optional[int] = None): prefill_mode = (prefill_mode or "lower").strip().lower() prefill_attn = (prefill_attn or "causal").strip().lower() if prefill_attn == "prefix_full": prefill_attn = "full" if prefill_mode != "lower": raise ValueError("llopa_prefill requires prefill_mode='lower'.") if prefill_attn not in {"causal", "full"}: raise ValueError("llopa_prefill requires prefill_attn in {'causal','full'}.") llopa_fn = getattr(llopa_core, "llopa_prefill_cache", None) if llopa_fn is None: raise RuntimeError("llopa_prefill_cache not found. Check LLoPA modeling patch.") sys_upper, user_llopa = _llopa_merge_user(system_ids, user_ids, system_prefill) merged_replay_prefix_keep_len, merged_replay_user_start, merged_replay_user_len = _llopa_merge_replay_user_span( system_ids, system_prefill=system_prefill, replay_user_prefix_keep_len=int(replay_user_prefix_keep_len or 0), replay_user_start=replay_user_start, replay_user_len=replay_user_len, ) prefill_out = llopa_fn( system_ids=sys_upper, user_ids=user_llopa, assistant_ids=assistant_ids, lower_k=lower_k, prefill_mode=prefill_mode, prefill_attn=prefill_attn, return_last_assistant_hidden=bool(return_last_assistant_hidden), replay_user_prefix_keep_len=merged_replay_prefix_keep_len, replay_user_start=merged_replay_user_start, replay_user_len=merged_replay_user_len, ) if bool(return_last_assistant_hidden): if not isinstance(prefill_out, tuple): raise RuntimeError("llopa_prefill_cache did not return the requested last assistant hidden state.") pkv, last_hidden = prefill_out return pkv, sys_upper.size(1), user_llopa.size(1), last_hidden pkv = prefill_out return pkv, sys_upper.size(1), user_llopa.size(1) # --------------------------------------------------------------------------- # LoPA per-layer cache_position/position_ids adjustment # Aligns per-layer positions when lower layers have prefill past and upper # layers start from zero. Mirrors the trainer's runtime patch. # --------------------------------------------------------------------------- import contextlib @contextlib.contextmanager def lopa_cache_position_patch(model, past_key_values): """ Match trainer's dynamic position alignment: For each decoder layer, compute its current past length from the provided past_key_values, and during forward adjust cache_position/position_ids by off = start_val - past_len so that lower-K layers (with past=L_sys+L_doc) and upper layers (with past=0) align logically for the current token. """ inner = _get_inner_model(model) # Per-layer past length from the provided cache snapshot def _pkv_past_len(li: int) -> int: if hasattr(past_key_values, "key_cache") and hasattr(past_key_values, "value_cache"): return int(past_key_values.key_cache[li].shape[2]) if hasattr(past_key_values, "layers"): return int(past_key_values.layers[li].keys.shape[2]) return int(past_key_values[li][0].shape[2]) n_layers = len(inner.layers) past_lens = [_pkv_past_len(li) for li in range(n_layers)] handles = [] for li, layer in enumerate(inner.layers): layer._lopa_past = int(past_lens[li]) def _pre_hook(module, args, kwargs): past_len = getattr(module, "_lopa_past", 0) cp = kwargs.get("cache_position", None) pi = kwargs.get("position_ids", None) start_val = None if isinstance(cp, torch.Tensor) and cp.numel() > 0: start_val = int(cp.view(-1)[0].item()) elif isinstance(pi, torch.Tensor) and pi.numel() > 0: start_val = int(pi.view(-1)[0].item()) if start_val is not None: off = start_val - past_len if off != 0: if isinstance(cp, torch.Tensor): kwargs["cache_position"] = cp - off if isinstance(pi, torch.Tensor): kwargs["position_ids"] = pi - off return args, kwargs h = layer.register_forward_pre_hook(_pre_hook, with_kwargs=True) handles.append(h) try: yield finally: for h in handles: h.remove() for layer in inner.layers: if hasattr(layer, "_lopa_past"): delattr(layer, "_lopa_past") # ----------------------------- # TRI inference core # ----------------------------- @torch.inference_mode() def _get_peft_wrapper(m): try: from peft import PeftModel except Exception: return None if hasattr(m, "module"): m = m.module return m if isinstance(m, PeftModel) else None def _set_prefill_adapter(model, enabled: bool) -> None: if not bool(getattr(model, "_prefill_adapter_only", False)): return peft_model = _get_peft_wrapper(model) if peft_model is None: return base = getattr(peft_model, "base_model", None) if base is None: return try: if enabled: base.enable_adapter_layers() else: base.disable_adapter_layers() except Exception: return def _get_output_head(model): getter = getattr(model, "get_output_embeddings", None) if callable(getter): head = getter() if head is not None: return head return getattr(model, "lm_head", None) def _env_flag_enabled(name: str, default: str = "1") -> bool: raw = os.environ.get(name, default).strip().lower() return raw not in {"0", "false", "no", "off"} @torch.inference_mode() def lopa_generate(model, tokenizer, system: str, document: str, question: str, *, task: str = "qa_doc", K: int, prefill_mode: str = "lower", prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", device: str, input_ids: Optional[dict] = None, max_new_tokens: int = 256, min_length: int = 16, temperature: float = 0.7, top_p: float = 0.9, top_k: Optional[int] = None, do_sample: bool = True, math_force_final_hash_rule: bool = False, log_cuda_mem: bool = False, log_cuda_tag: Optional[str] = None, debug: bool = False, debug_dir: Optional[Path] = None, llopa_prefill: bool = False, no_upper_attn: Optional[bool] = None, return_tokens: bool = False) -> str | Tuple[str, int]: # Build ids if input_ids is None and task == "math" and math_force_final_hash_rule and "####" not in (system or ""): system = ( system.rstrip() + " Conclude your explanation with the answer in a '#### {numeric answer}' format, " + "where the answer is solely a number." ) user_prefill = (user_prefill or "full").strip().lower() if user_prefill not in {"full", "no_question"}: user_prefill = "full" prefill_attn = (prefill_attn or "causal").strip().lower() if prefill_attn == "prefix_full": prefill_attn = "full" if prefill_attn not in {"causal", "full"}: raise ValueError("prefill_attn must be one of: causal | full") llopa_inputs = _coerce_llopa_inputs( tokenizer, system=system, document=document, question=question, task=task, device=device, input_ids=input_ids, ) ids_sys = llopa_inputs["system_ids"] ids_sys_user = llopa_inputs["system_user_ids"] user_doc_ids = llopa_inputs["user_doc_ids"] user_q_ids = llopa_inputs["user_q_ids"] hdr_tail = llopa_inputs["hdr_tail"] ids_hdr = llopa_inputs.get("prompt_ids", None) if debug: try: msgs = build_messages(system, document, question, include_query=True, task=task) s_no_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=False) s_with_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=True) print(f"[debug] render lengths (chars): no_hdr={len(s_no_hdr)}, with_hdr={len(s_with_hdr)}") if debug_dir is not None: debug_dir.mkdir(parents=True, exist_ok=True) (debug_dir / "infer_render_no_header.txt").write_text(s_no_hdr, encoding="utf-8") (debug_dir / "infer_render_with_header.txt").write_text(s_with_hdr, encoding="utf-8") except Exception as e: print(f"[debug] render dump failed: {e}") # Assistant header tokens if ids_hdr is None: ids_hdr = torch.cat([ids_sys_user, hdr_tail], dim=1) if hdr_tail is not None else ids_sys_user hdr_tail = ids_hdr[:, ids_sys_user.size(1):] if user_prefill == "no_question": prefix_full = torch.cat([user_q_ids, hdr_tail], dim=1) else: prefix_full = hdr_tail # Require LLoPA runtime API llopa_core = _get_llopa_core(model) llopa_step = _get_llopa_decode_step(model) if llopa_core is None or llopa_step is None: raise RuntimeError("Custom LLoPA modeling not active. Check --lopa_modeling_path/--modeling_family.") if no_upper_attn is None: no_upper_attn = bool(getattr(model, "_no_upper_attn", False)) no_upper_attn = bool(no_upper_attn) if no_upper_attn and (not bool(llopa_prefill)): print("[infer][warn] no_upper_attn is ignored unless llopa_prefill=True.") effective_no_upper_attn = bool(no_upper_attn and bool(llopa_prefill)) llopa_step_accepts_no_upper_attn = False try: llopa_step_params = inspect.signature(llopa_step).parameters llopa_step_accepts_no_upper_attn = ("no_upper_attn" in llopa_step_params) except Exception: llopa_step_accepts_no_upper_attn = False if effective_no_upper_attn and (not llopa_step_accepts_no_upper_attn): print("[infer][warn] no_upper_attn requested but llopa_decode_step_logits does not support it; ignoring.") effective_no_upper_attn = False def _log_mem(tag: str) -> None: if not log_cuda_mem or not torch.cuda.is_available(): return try: torch.cuda.synchronize() except Exception: pass alloc = torch.cuda.memory_allocated() / (1024 ** 3) reserved = torch.cuda.memory_reserved() / (1024 ** 3) max_alloc = torch.cuda.max_memory_allocated() / (1024 ** 3) max_reserved = torch.cuda.max_memory_reserved() / (1024 ** 3) prefix = "[mem]" if log_cuda_tag: prefix = f"{prefix}[{log_cuda_tag}]" print(f"{prefix} {tag} | alloc={alloc:.2f}GiB reserved={reserved:.2f}GiB " f"max_alloc={max_alloc:.2f}GiB max_reserved={max_reserved:.2f}GiB") if log_cuda_mem and torch.cuda.is_available(): try: torch.cuda.reset_peak_memory_stats() except Exception: pass _log_mem("start") # 1) TRI prefill: system all + user lower-K _set_prefill_adapter(model, True) lower_k = int(K) sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" if user_prefill == "no_question": user_prefill_ids = user_doc_ids else: user_prefill_ids = torch.cat([user_doc_ids, user_q_ids], dim=1) if llopa_prefill: use_fused_first_token = bool( prefix_full.numel() > 0 and not effective_no_upper_attn and not bool(getattr(model, "_prefill_adapter_only", False)) ) if use_fused_first_token: pkv, S, U, llopa_last_hidden = _llopa_prefill_cache( llopa_core, ids_sys, user_prefill_ids, prefix_full, lower_k=lower_k, prefill_mode=prefill_mode, prefill_attn=prefill_attn, system_prefill=sys_prefill, return_last_assistant_hidden=True, ) else: pkv, S, U = _llopa_prefill_cache( llopa_core, ids_sys, user_prefill_ids, prefix_full, lower_k=lower_k, prefill_mode=prefill_mode, prefill_attn=prefill_attn, system_prefill=sys_prefill, ) llopa_last_hidden = None else: use_fused_first_token = False llopa_last_hidden = None if sys_prefill == "full": pkv, S, U = llopa_core.tri_build_caches( system_ids=ids_sys, user_ids=user_prefill_ids, lower_k=lower_k, prefill_mode=prefill_mode, prefill_attn=prefill_attn, ) elif sys_prefill == "no_system": if ids_sys.size(1) < 1: pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None)) else: bos_ids = ids_sys[:, :1] rest_ids = ids_sys[:, 1:] out = llopa_core.tri_prefill_system_all( bos_ids, past_key_values=None, prefill_attn=prefill_attn, ) pkv = out.past_key_values if rest_ids.size(1) > 0: _ = llopa_core.tri_prefill_user_lower( rest_ids, lower_k=lower_k, past_key_values=pkv, prefill_mode=prefill_mode, prefill_attn=prefill_attn, ) _ = llopa_core.tri_prefill_user_lower( user_prefill_ids, lower_k=lower_k, past_key_values=pkv, prefill_mode=prefill_mode, prefill_attn=prefill_attn, ) S, U = ids_sys.size(1), user_prefill_ids.size(1) else: pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None)) if ids_sys.size(1) > 0: _ = llopa_core.tri_prefill_user_lower( ids_sys, lower_k=lower_k, past_key_values=pkv, prefill_mode=prefill_mode, prefill_attn=prefill_attn, ) _ = llopa_core.tri_prefill_user_lower( user_prefill_ids, lower_k=lower_k, past_key_values=pkv, prefill_mode=prefill_mode, prefill_attn=prefill_attn, ) S, U = ids_sys.size(1), user_prefill_ids.size(1) _set_prefill_adapter(model, False) _log_mem("prefill_end") initial_logits = None if use_fused_first_token: output_head = _get_output_head(model) if output_head is not None and isinstance(llopa_last_hidden, torch.Tensor) and llopa_last_hidden.numel() > 0: initial_logits = output_head(llopa_last_hidden)[:, -1, :].to(torch.float32) # 2) Push assistant header if present (or fallback to last user token) if llopa_prefill: if prefix_full.numel() > 0: last_pushed = prefix_full[:, -1:] elif ids_sys_user.numel() > 0: last_pushed = ids_sys_user[:, -1:] else: raise ValueError("Empty prompt after LLOPA prefill; cannot start decoding.") else: if prefix_full.numel() > 0: seed_kwargs = dict( assistant_ids=prefix_full, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=0, labels=None, prefill_mode=prefill_mode, ) if effective_no_upper_attn and llopa_step_accepts_no_upper_attn: seed_kwargs["no_upper_attn"] = True out_seed = llopa_step(**seed_kwargs) pkv = out_seed.past_key_values or pkv last_pushed = prefix_full[:, -1:] else: step_tok = ids_sys_user[:, -1:] seed_kwargs = dict( assistant_ids=step_tok, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=0, labels=None, prefill_mode=prefill_mode, ) if effective_no_upper_attn and llopa_step_accepts_no_upper_attn: seed_kwargs["no_upper_attn"] = True out_seed = llopa_step(**seed_kwargs) pkv = out_seed.past_key_values or pkv last_pushed = step_tok if log_cuda_mem and torch.cuda.is_available(): try: torch.cuda.reset_peak_memory_stats() except Exception: pass _log_mem("decode_start") # 5) decoding from transformers.generation import LogitsProcessorList from transformers.generation.logits_process import TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper eos_id = tokenizer.eos_token_id try: eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") except Exception: eot_id = None stop_ids = set() if eos_id is not None: stop_ids.add(int(eos_id)) if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) procs = None if do_sample: procs = LogitsProcessorList() if temperature and temperature != 1.0: procs.append(TemperatureLogitsWarper(temperature=float(temperature))) if top_p and top_p < 1.0: procs.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1)) if top_k is not None and top_k > 0: procs.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf"))) device_t = last_pushed.device last = last_pushed generated = torch.empty((1, max_new_tokens), dtype=torch.long, device=device_t) cur = 0 stop_reason = None pending_logits = initial_logits while cur < max_new_tokens: if pending_logits is None: step_kwargs = dict( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=1, labels=None, prefill_mode=prefill_mode, ) if effective_no_upper_attn and llopa_step_accepts_no_upper_attn: step_kwargs["no_upper_attn"] = True out = llopa_step(**step_kwargs) pkv = out.past_key_values or pkv logits = out.logits[:, -1, :] else: logits = pending_logits pending_logits = None # force min_length if stop_ids and cur < min_length: for sid in stop_ids: logits[:, sid] = -float("inf") if procs is not None: inp_for_proc = generated[:, :cur] if logits.dtype != torch.float32: logits = logits.float() logits = procs(inp_for_proc, logits) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) generated[:, cur:cur + 1] = next_tok last = next_tok if stop_ids and cur >= min_length: tok_id = int(next_tok.item()) if tok_id in stop_ids: stop_reason = f"stop_token:{tok_id}" cur += 1 break cur += 1 if stop_reason is None and cur >= max_new_tokens: stop_reason = "max_new_tokens" gen_ids = generated[:, :cur] text = tokenizer.decode(gen_ids[0].tolist(), skip_special_tokens=True) _log_mem("decode_end") if debug: print(f"[debug] finished | tokens={cur} | reason={stop_reason}") if debug_dir is not None: debug_dir.mkdir(parents=True, exist_ok=True) (debug_dir / "infer_generated.txt").write_text(text, encoding="utf-8") if return_tokens: return text, int(cur) return text # ----------------------------- # LLOPA helpers (Capsule interface + HF packaging) # ----------------------------- def llopa_generate(*args, **kwargs): """Capsule interface: identical behavior to lopa_generate.""" if "device" not in kwargs or kwargs.get("device") is None: if len(args) >= 1: model = args[0] try: dev = model.get_input_embeddings().weight.device except Exception: try: dev = next(model.parameters()).device except Exception: dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") kwargs["device"] = str(dev) return lopa_generate(*args, **kwargs) def _read_kv_file(path: Path) -> dict[str, str]: info: dict[str, str] = {} try: lines = path.read_text(encoding="utf-8").splitlines() except Exception: return info for raw in lines: line = raw.strip() if not line or line.startswith("#"): continue if "=" not in line: continue k, v = line.split("=", 1) k = k.strip() v = v.strip() if k: info[k] = v return info def _read_adapter_backbone_ref(repo_path: Path) -> str: adapter_cfg = repo_path / "adapter_config.json" if not adapter_cfg.is_file(): return "" try: data = json.loads(adapter_cfg.read_text(encoding="utf-8")) except Exception: return "" val = data.get("base_model_name_or_path") if isinstance(val, str): return val.strip() return "" def _resolve_repo_path(model_repo: str, cache_dir: Optional[str] = None, revision: Optional[str] = None, token: Optional[str] = None, local_files_only: bool = False) -> Path: repo = Path(model_repo) if repo.exists(): return repo try: from huggingface_hub import snapshot_download except Exception as exc: raise RuntimeError("huggingface_hub is required to load remote repos") from exc path = snapshot_download( repo_id=model_repo, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, ) return Path(path) def _repo_has_pretrained_weights(repo_path: Path) -> bool: weight_files = ( "pytorch_model.bin", "pytorch_model.bin.index.json", "model.safetensors", "model.safetensors.index.json", ) if any((repo_path / name).is_file() for name in weight_files): return True return any(repo_path.glob("pytorch_model-*-of-*.bin")) or any(repo_path.glob("model-*-of-*.safetensors")) def _pick_vocab_weight_key(keys) -> Optional[str]: preferred = ( "model.embed_tokens.weight", "embed_tokens.weight", "model.decoder.embed_tokens.weight", "transformer.wte.weight", "lm_head.weight", "model.lm_head.weight", ) key_list = list(keys) for cand in preferred: if cand in key_list: return cand suffixes = ( "embed_tokens.weight", "decoder.embed_tokens.weight", "wte.weight", "lm_head.weight", ) for suffix in suffixes: for key in key_list: if str(key).endswith(suffix): return str(key) return None def _load_state_dict_meta(path: Path): try: return torch.load(path, map_location="meta", weights_only=True) except TypeError: return torch.load(path, map_location="meta") def _weight_rows_from_file(path: Path, tensor_key: Optional[str] = None) -> Optional[int]: try: suffixes = path.suffixes except Exception: suffixes = [] try: if suffixes[-1:] == [".safetensors"] or suffixes[-2:] == [".model", ".safetensors"]: from safetensors import safe_open with safe_open(str(path), framework="pt", device="cpu") as handle: chosen = tensor_key or _pick_vocab_weight_key(handle.keys()) if not chosen: return None shape = handle.get_slice(chosen).get_shape() if len(shape) >= 2: return int(shape[0]) return None state_dict = _load_state_dict_meta(path) if not isinstance(state_dict, dict): return None chosen = tensor_key if tensor_key in state_dict else _pick_vocab_weight_key(state_dict.keys()) if not chosen: return None tensor = state_dict.get(chosen) if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2: return int(tensor.shape[0]) except Exception: return None return None def _infer_checkpoint_vocab_size(repo_path: Path) -> Optional[int]: index_files = ( "pytorch_model.bin.index.json", "model.safetensors.index.json", ) single_weight_files = ( "pytorch_model.bin", "model.safetensors", ) for name in index_files: index_path = repo_path / name if not index_path.is_file(): continue try: data = json.loads(index_path.read_text(encoding="utf-8")) except Exception: continue weight_map = data.get("weight_map") if not isinstance(weight_map, dict): continue tensor_key = _pick_vocab_weight_key(weight_map.keys()) if not tensor_key: continue shard_rel = weight_map.get(tensor_key) if not shard_rel: continue rows = _weight_rows_from_file(repo_path / shard_rel, tensor_key=tensor_key) if rows is not None and rows > 0: return rows for name in single_weight_files: weight_path = repo_path / name if not weight_path.is_file(): continue rows = _weight_rows_from_file(weight_path) if rows is not None and rows > 0: return rows for weight_path in sorted(repo_path.glob("pytorch_model-*-of-*.bin")): rows = _weight_rows_from_file(weight_path) if rows is not None and rows > 0: return rows for weight_path in sorted(repo_path.glob("model-*-of-*.safetensors")): rows = _weight_rows_from_file(weight_path) if rows is not None and rows > 0: return rows return None def _expand_tokenizer_placeholders(tokenizer, target_vocab_size: int) -> int: try: current_vocab_size = int(len(tokenizer)) except Exception: return 0 if target_vocab_size <= current_vocab_size: return 0 missing = int(target_vocab_size - current_vocab_size) placeholder_tokens = [f"<|capsule_missing_token_{idx}|>" for idx in range(missing)] try: return int(tokenizer.add_tokens(placeholder_tokens, special_tokens=True) or 0) except Exception: return 0 def _normalize_special_token_values(raw_value: Any) -> list[str]: if not isinstance(raw_value, list): return [] normalized: list[str] = [] for item in raw_value: token = "" if isinstance(item, str): token = item elif isinstance(item, dict): for key in ("content", "token", "text"): value = item.get(key) if isinstance(value, str) and value: token = value break if token: normalized.append(token) return normalized def _checkpoint_special_token_candidates(checkpoint_repo_path: Path) -> list[str]: candidates: list[str] = [] for filename in ("config.json", "tokenizer_config.json", "special_tokens_map.json"): path = checkpoint_repo_path / filename if not path.is_file(): continue try: payload = json.loads(path.read_text(encoding="utf-8")) except Exception: continue if not isinstance(payload, dict): continue candidates.extend(_normalize_special_token_values(payload.get("capsule_suffix_special_tokens"))) candidates.extend(_normalize_special_token_values(payload.get("additional_special_tokens"))) raw_decoder = payload.get("added_tokens_decoder") if isinstance(raw_decoder, dict): for _, value in sorted( raw_decoder.items(), key=lambda item: int(item[0]) if str(item[0]).isdigit() else str(item[0]), ): candidates.extend(_normalize_special_token_values([value])) deduped: list[str] = [] seen: set[str] = set() for token in candidates: if not token or token in seen: continue seen.add(token) deduped.append(token) return deduped def _align_tokenizer_with_checkpoint_vocab( tokenizer, checkpoint_repo_path: Path, target_vocab_size: int, ) -> dict[str, int]: report = { "checkpoint_vocab_size": int(target_vocab_size or 0), "tokenizer_vocab_size_before": 0, "checkpoint_special_candidate_count": 0, "checkpoint_specials_missing_before": 0, "added_checkpoint_specials": 0, "tokenizer_vocab_size_after_specials": 0, "padding_gap_before_placeholders": 0, "added_placeholders": 0, "tokenizer_vocab_size_after": 0, } try: current_vocab_size = int(len(tokenizer)) except Exception: return report report["tokenizer_vocab_size_before"] = current_vocab_size report["tokenizer_vocab_size_after_specials"] = current_vocab_size report["tokenizer_vocab_size_after"] = current_vocab_size missing = int(target_vocab_size - current_vocab_size) if missing <= 0: return report checkpoint_specials = _checkpoint_special_token_candidates(checkpoint_repo_path) report["checkpoint_special_candidate_count"] = len(checkpoint_specials) try: existing_vocab = set(tokenizer.get_vocab().keys()) except Exception: existing_vocab = set() missing_checkpoint_specials = [token for token in checkpoint_specials if token not in existing_vocab] report["checkpoint_specials_missing_before"] = len(missing_checkpoint_specials) tokens_to_add: list[str] = [] for token in missing_checkpoint_specials: tokens_to_add.append(token) existing_vocab.add(token) if len(tokens_to_add) >= missing: break added_checkpoint_specials = 0 if tokens_to_add: try: current_specials = tokenizer.special_tokens_map_extended.get("additional_special_tokens", []) or [] added_checkpoint_specials = int( tokenizer.add_special_tokens( {"additional_special_tokens": list(current_specials) + tokens_to_add} ) or 0 ) except Exception: try: added_checkpoint_specials = int(tokenizer.add_tokens(tokens_to_add, special_tokens=True) or 0) except Exception: added_checkpoint_specials = 0 report["added_checkpoint_specials"] = added_checkpoint_specials try: post_special_vocab_size = int(len(tokenizer)) except Exception: post_special_vocab_size = current_vocab_size + added_checkpoint_specials report["tokenizer_vocab_size_after_specials"] = post_special_vocab_size report["padding_gap_before_placeholders"] = max(0, int(target_vocab_size - post_special_vocab_size)) added_placeholders = 0 if target_vocab_size > post_special_vocab_size: added_placeholders = _expand_tokenizer_placeholders(tokenizer, target_vocab_size) report["added_placeholders"] = added_placeholders try: report["tokenizer_vocab_size_after"] = int(len(tokenizer)) except Exception: report["tokenizer_vocab_size_after"] = post_special_vocab_size + added_placeholders return report def _log_tokenizer_checkpoint_alignment( log_prefix: str, report: dict[str, int], *, print_fn=print, ) -> None: added_specials = int(report.get("added_checkpoint_specials", 0) or 0) added_placeholders = int(report.get("added_placeholders", 0) or 0) if added_specials <= 0 and added_placeholders <= 0: return vocab_before = int(report.get("tokenizer_vocab_size_before", 0) or 0) vocab_after_specials = int(report.get("tokenizer_vocab_size_after_specials", vocab_before) or vocab_before) vocab_after = int(report.get("tokenizer_vocab_size_after", vocab_after_specials) or vocab_after_specials) missing_specials_before = int(report.get("checkpoint_specials_missing_before", 0) or 0) if added_specials > 0: print_fn( f"{log_prefix}[info] tokenizer vocab is smaller than checkpoint embeddings; " f"recovered {added_specials} checkpoint special tokens " f"({vocab_before} -> {vocab_after_specials})." ) if added_placeholders > 0: if missing_specials_before > 0: print_fn( f"{log_prefix}[warn] checkpoint embeddings still exceed tokenizer vocab after recovering " f"checkpoint special tokens; added {added_placeholders} placeholder special tokens " f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})." ) else: print_fn( f"{log_prefix}[info] checkpoint embeddings include {added_placeholders} padded rows " f"beyond tokenizer vocab; added {added_placeholders} placeholder special tokens " f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})." ) def _log_config_checkpoint_vocab_alignment( log_prefix: str, *, config_vocab_size: int, checkpoint_vocab_size: int, tokenizer_vocab_size: int, alignment_report: Optional[dict[str, int]] = None, print_fn=print, ) -> None: if checkpoint_vocab_size <= 0 or checkpoint_vocab_size == config_vocab_size: return final_tokenizer_vocab = int(tokenizer_vocab_size or 0) if alignment_report is not None: final_tokenizer_vocab = int( alignment_report.get("tokenizer_vocab_size_after", final_tokenizer_vocab) or final_tokenizer_vocab ) missing_specials_before = 0 if alignment_report is not None: missing_specials_before = int(alignment_report.get("checkpoint_specials_missing_before", 0) or 0) if final_tokenizer_vocab >= checkpoint_vocab_size and missing_specials_before <= 0: print_fn( f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags tokenizer/checkpoint embeddings " f"({checkpoint_vocab_size}); using checkpoint size." ) return padding_rows = max(0, checkpoint_vocab_size - final_tokenizer_vocab) if padding_rows > 0 and missing_specials_before <= 0: print_fn( f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags checkpoint embeddings " f"({checkpoint_vocab_size}); using checkpoint size. tokenizer vocab is {final_tokenizer_vocab}, " f"so {padding_rows} rows are embedding padding/alignment." ) return print_fn( f"{log_prefix}[warn] config vocab_size ({config_vocab_size}) does not match checkpoint embeddings " f"({checkpoint_vocab_size}); using checkpoint size." ) def _expand_tokenizer_with_checkpoint_specials( tokenizer, checkpoint_repo_path: Path, target_vocab_size: int, ) -> tuple[int, int]: report = _align_tokenizer_with_checkpoint_vocab( tokenizer, checkpoint_repo_path, target_vocab_size, ) return ( int(report.get("added_checkpoint_specials", 0) or 0), int(report.get("added_placeholders", 0) or 0), ) def _resolve_modeling_path(repo_path: Path, user_path: Optional[str], model_family: str) -> Optional[str]: if user_path: cand = Path(user_path) if not cand.is_file(): cand = repo_path / user_path return str(cand) if cand.is_file() else None tri_info = repo_path / "tri_info.txt" if tri_info.is_file(): info = _read_kv_file(tri_info) tri_file = info.get("lopa_modeling_path") or "" if tri_file: cand = repo_path / tri_file if cand.is_file(): return str(cand) default_name = { "llama": "tri_llama3_modeling.py", "qwen3": "tri_qwen3_modeling.py", "mistral": "tri_mistral_modeling.py", }.get(model_family, "tri_llama3_modeling.py") cand = repo_path / default_name if cand.is_file(): return str(cand) return None def _attach_llopa_generate(model): try: import types def _llopa_generate(self, tokenizer, system: Optional[str] = None, document: Optional[str] = None, question: Optional[str] = None, **kwargs): if system is None: system = kwargs.pop("system", "") if document is None: document = kwargs.pop("document", "") if question is None: question = kwargs.pop("question", "") if "device" not in kwargs or kwargs.get("device") is None: try: dev = self.get_input_embeddings().weight.device except Exception: try: dev = next(self.parameters()).device except Exception: dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") kwargs["device"] = str(dev) return lopa_generate(self, tokenizer, system=system, document=document, question=question, **kwargs) model.llopa_generate = types.MethodType(_llopa_generate, model) except Exception: pass def _maybe_attach_llopa_generate(model): if not hasattr(model, "llopa_generate"): _attach_llopa_generate(model) def _attach_prefill_lower_generate( model, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "no_bos_system", no_upper_attn: bool = False, ) -> None: try: import types except Exception: return try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" if lower_k <= 0 or attn not in {"causal", "full"}: return try: setattr(model, "_runtime_prefill_lower_layers", int(lower_k)) setattr(model, "_runtime_prefill_lower_attn", attn) setattr(model, "_runtime_prefill_lower_system_prefill", sys_prefill) setattr(model, "_runtime_prefill_lower_no_upper_attn", bool(no_upper_attn)) except Exception: return if getattr(model, "_runtime_prefill_generate_attached", False): return orig_prepare = getattr(model, "prepare_inputs_for_generation", None) if orig_prepare is None: return def _runtime_prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, prefill_lower_layers=None, prefill_lower_attn=None, prefill_lower_system_prefill=None, prefill_lower_no_upper_attn=None, prefill_lower_split_start=None, prefill_lower_system_len=None, prefill_lower_replay_user_prefix_keep_len=None, prefill_lower_replay_user_start=None, prefill_lower_replay_user_len=None, assistant_header_start=None, assistant_header_starts=None, assistant_header_start_mask=None, **kwargs, ): model_inputs = orig_prepare( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, **kwargs, ) if prefill_lower_layers is None: prefill_lower_layers = int(getattr(self, "_runtime_prefill_lower_layers", 0) or 0) if prefill_lower_attn is None: prefill_lower_attn = str(getattr(self, "_runtime_prefill_lower_attn", "causal") or "causal") if prefill_lower_system_prefill is None: prefill_lower_system_prefill = str( getattr(self, "_runtime_prefill_lower_system_prefill", "no_bos_system") or "no_bos_system" ) if prefill_lower_no_upper_attn is None: prefill_lower_no_upper_attn = bool( getattr(self, "_runtime_prefill_lower_no_upper_attn", False) ) if int(prefill_lower_layers or 0) > 0: model_inputs["prefill_lower_layers"] = int(prefill_lower_layers) model_inputs["prefill_lower_attn"] = str(prefill_lower_attn) model_inputs["prefill_lower_system_prefill"] = str(prefill_lower_system_prefill) if bool(prefill_lower_no_upper_attn): model_inputs["prefill_lower_no_upper_attn"] = True if prefill_lower_split_start is not None: model_inputs["prefill_lower_split_start"] = prefill_lower_split_start if prefill_lower_system_len is not None: model_inputs["prefill_lower_system_len"] = prefill_lower_system_len if prefill_lower_replay_user_prefix_keep_len is not None: model_inputs["prefill_lower_replay_user_prefix_keep_len"] = prefill_lower_replay_user_prefix_keep_len if prefill_lower_replay_user_start is not None: model_inputs["prefill_lower_replay_user_start"] = prefill_lower_replay_user_start if prefill_lower_replay_user_len is not None: model_inputs["prefill_lower_replay_user_len"] = prefill_lower_replay_user_len if assistant_header_start is not None: model_inputs["assistant_header_start"] = assistant_header_start if assistant_header_starts is not None: model_inputs["assistant_header_starts"] = assistant_header_starts if assistant_header_start_mask is not None: model_inputs["assistant_header_start_mask"] = assistant_header_start_mask return model_inputs try: model.prepare_inputs_for_generation = types.MethodType( _runtime_prepare_inputs_for_generation, model, ) setattr(model, "_runtime_prefill_generate_attached", True) except Exception: pass def _attach_prefill_lower_freeze_generate( model, *, tokenizer=None, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "no_bos_system", ) -> None: try: import types except Exception: return try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" if lower_k <= 0 or attn not in {"causal", "full"}: return try: setattr(model, "_runtime_prefill_freeze_layers", int(lower_k)) setattr(model, "_runtime_prefill_freeze_attn", attn) setattr(model, "_runtime_prefill_freeze_system_prefill", sys_prefill) setattr(model, "_runtime_structured_freeze_generate_default", True) except Exception: return if tokenizer is not None: _attach_structured_llopa_generate(model, tokenizer) def _attach_prefill_lower_solo_generate( model, *, tokenizer=None, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "no_bos_system", ) -> None: try: import types except Exception: return try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" if lower_k <= 0 or attn not in {"causal", "full"}: return try: setattr(model, "_runtime_prefill_solo_layers", int(lower_k)) setattr(model, "_runtime_prefill_solo_attn", attn) setattr(model, "_runtime_prefill_solo_system_prefill", sys_prefill) setattr(model, "_runtime_structured_solo_generate_default", True) except Exception: return if tokenizer is not None: _attach_structured_llopa_generate(model, tokenizer) def _attach_prefill_lower_solo_v2_generate( model, *, tokenizer=None, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "no_bos_system", with_bos: bool = False, ) -> None: try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" if lower_k <= 0 or attn not in {"causal", "full"}: return try: setattr(model, "_runtime_prefill_solo_v2_layers", int(lower_k)) setattr(model, "_runtime_prefill_solo_v2_attn", attn) setattr(model, "_runtime_prefill_solo_v2_system_prefill", sys_prefill) setattr(model, "_runtime_prefill_solo_v2_with_bos", bool(with_bos)) setattr(model, "_runtime_structured_solo_v2_generate_default", True) except Exception: return if tokenizer is not None: _attach_structured_llopa_generate(model, tokenizer) @torch.inference_mode() def _direct_freeze_generate_impl( model, tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, system_prefill: str, max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, ): try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" device = None if isinstance(input_ids, torch.Tensor): if input_ids.dim() != 2 or input_ids.size(0) != 1: return None device = input_ids.device if device is None: try: device = next(model.parameters()).device except Exception: device = "cpu" segments = _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) prompt_ids = segments["prompt_ids"] system_ids = segments["system_ids"] user_ids = segments["user_ids"] canonical_input_ids = prompt_ids split_start = int(system_ids.size(1) + user_ids.size(1)) system_len = int(system_ids.size(1)) total_prompt_len = int(canonical_input_ids.size(1)) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - total_prompt_len) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - total_prompt_len) else: min_new_tokens = int(min_new_tokens) raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) generated = torch.empty((1, int(max_new_tokens)), dtype=torch.long, device=canonical_input_ids.device) score_list: list[torch.Tensor] = [] cur = 0 while cur < int(max_new_tokens): current_ids = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) current_attn = torch.ones_like(current_ids, device=current_ids.device) out = model( input_ids=current_ids, attention_mask=current_attn, use_cache=False, logits_to_keep=1, prefill_lower_layers=int(lower_k), prefill_lower_attn=str(attn), prefill_lower_freeze_runtime=True, prefill_lower_split_start=int(split_start), prefill_lower_system_len=int(system_len), prefill_lower_system_prefill=str(sys_prefill), ) if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor): return None logits = out.logits[:, -1, :].to(torch.float32) if stop_ids and cur < int(min_new_tokens): for sid in stop_ids: logits[:, sid] = -float("inf") if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if bool(output_scores): score_list.append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) generated[:, cur : cur + 1] = next_tok cur += 1 sequences_now = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= int(min_new_tokens): should_stop = True if (not should_stop) and stopping_criteria is not None: try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if should_stop: break sequences = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) if not bool(return_dict_in_generate): return sequences return { "sequences": sequences, "scores": tuple(score_list) if bool(output_scores) else tuple(), } @torch.inference_mode() def _direct_solo_generate_impl( model, tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, system_prefill: str, max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, solo_v2: bool = False, with_bos: bool = False, ): try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" sys_prefill = (system_prefill or "no_bos_system").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "no_bos_system" device = None if isinstance(input_ids, torch.Tensor): if input_ids.dim() != 2 or input_ids.size(0) != 1: return None device = input_ids.device if device is None: try: device = next(model.parameters()).device except Exception: device = "cpu" segments = _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) prompt_ids = segments["prompt_ids"] system_ids = segments["system_ids"] user_ids = segments["user_ids"] canonical_input_ids = prompt_ids split_start = int(system_ids.size(1) + user_ids.size(1)) system_len = int(system_ids.size(1)) total_prompt_len = int(canonical_input_ids.size(1)) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - total_prompt_len) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - total_prompt_len) else: min_new_tokens = int(min_new_tokens) raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) generated = torch.empty((1, int(max_new_tokens)), dtype=torch.long, device=canonical_input_ids.device) score_list: list[torch.Tensor] = [] cur = 0 while cur < int(max_new_tokens): current_ids = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) current_attn = torch.ones_like(current_ids, device=current_ids.device) out = model( input_ids=current_ids, attention_mask=current_attn, use_cache=False, logits_to_keep=1, prefill_lower_layers=int(lower_k), prefill_lower_attn=str(attn), prefill_lower_split_start=int(split_start), prefill_lower_system_len=int(system_len), prefill_lower_system_prefill=str(sys_prefill), prefill_lower_solo_attention=bool(not solo_v2), prefill_lower_solo_attention_v2=bool(solo_v2), prefill_lower_solo_attention_v2_with_bos=bool(solo_v2 and with_bos), ) if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor): return None logits = out.logits[:, -1, :].to(torch.float32) if stop_ids and cur < int(min_new_tokens): for sid in stop_ids: logits[:, sid] = -float("inf") if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if bool(output_scores): score_list.append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) generated[:, cur : cur + 1] = next_tok cur += 1 sequences_now = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= int(min_new_tokens): should_stop = True if (not should_stop) and stopping_criteria is not None: try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if should_stop: break sequences = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1) if not bool(return_dict_in_generate): return sequences return { "sequences": sequences, "scores": tuple(score_list) if bool(output_scores) else tuple(), } def _attach_runtime_llopa_generate( model, *, header_ids: torch.Tensor, lower_k: int, prefill_attn: str = "causal", no_upper_attn: bool = False, ) -> None: try: import types except Exception: return try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if ( lower_k <= 0 or attn not in {"causal", "full"} or not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0 ): return try: setattr(model, "_runtime_llopa_header_ids", header_ids.detach().to(device="cpu", dtype=torch.long)) setattr(model, "_runtime_llopa_layers", int(lower_k)) setattr(model, "_runtime_llopa_attn", attn) setattr(model, "_runtime_llopa_no_upper_attn", bool(no_upper_attn)) except Exception: return if getattr(model, "_runtime_llopa_generate_attached", False): return orig_prepare = getattr(model, "prepare_inputs_for_generation", None) if orig_prepare is None: return def _runtime_prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, runtime_llopa_prefill=None, runtime_llopa_layers=None, runtime_llopa_attn=None, runtime_llopa_no_upper_attn=None, **kwargs, ): model_inputs = orig_prepare( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, **kwargs, ) if runtime_llopa_prefill is None: runtime_llopa_prefill = True if runtime_llopa_layers is None: runtime_llopa_layers = int(getattr(self, "_runtime_llopa_layers", 0) or 0) if runtime_llopa_attn is None: runtime_llopa_attn = str(getattr(self, "_runtime_llopa_attn", "causal") or "causal") if runtime_llopa_no_upper_attn is None: runtime_llopa_no_upper_attn = bool(getattr(self, "_runtime_llopa_no_upper_attn", False)) if bool(runtime_llopa_prefill) and int(runtime_llopa_layers or 0) > 0: model_inputs["runtime_llopa_prefill"] = True model_inputs["runtime_llopa_layers"] = int(runtime_llopa_layers) model_inputs["runtime_llopa_attn"] = str(runtime_llopa_attn) if bool(runtime_llopa_no_upper_attn): model_inputs["runtime_llopa_no_upper_attn"] = True return model_inputs try: model.prepare_inputs_for_generation = types.MethodType( _runtime_prepare_inputs_for_generation, model, ) setattr(model, "_runtime_llopa_generate_attached", True) except Exception: pass @torch.inference_mode() def _runtime_llopa_fast_generate_mode( model, input_ids: torch.LongTensor, logits_processor, stopping_criteria, generation_config, synced_gpus: bool = False, streamer=None, **model_kwargs, ): from transformers.generation.utils import GenerateDecoderOnlyOutput if model.config.is_encoder_decoder: return None pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None batch_size, cur_len = input_ids.shape[:2] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = model(**model_inputs, return_dict=True) model_kwargs = model._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=False, ) if synced_gpus and this_peer_finished: continue next_token_logits = outputs.logits[:, -1, :].to(dtype=torch.float32) raw_next_token_logits = next_token_logits.clone() if (return_dict_in_generate and output_logits) else None next_token_scores = logits_processor(input_ids, next_token_logits) if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_logits: raw_logits += (raw_next_token_logits,) if output_attentions: decoder_attentions += (outputs.attentions,) if output_hidden_states: decoder_hidden_states += (outputs.hidden_states,) if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 cur_len += 1 del outputs if streamer is not None: streamer.end() if return_dict_in_generate: return GenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), ) return input_ids @torch.inference_mode() def _llopa_v2_generation_mixin_decode_loop( model, *, canonical_input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor], past_key_values, initial_logits: torch.Tensor, lower_k: int, no_upper_attn: bool, replay_module: str, replay_per_layers: int, max_new_tokens: int, min_new_tokens: int, do_sample: bool, logits_warpers, stop_ids: set[int], stop_token_ids: Optional[torch.Tensor], stopping_criteria, output_scores: bool, compact_scores: bool, return_dict_in_generate: bool, ): if not isinstance(canonical_input_ids, torch.Tensor) or canonical_input_ids.dim() != 2: return None if canonical_input_ids.size(0) != 1: return None if not isinstance(initial_logits, torch.Tensor): return None if int(max_new_tokens) <= 0: return canonical_input_ids if not bool(return_dict_in_generate) else None from transformers.generation.utils import GenerateDecoderOnlyOutput device = canonical_input_ids.device total_prompt_len = int(canonical_input_ids.size(1)) max_new_tokens = int(max_new_tokens) min_new_tokens = int(min_new_tokens) sequences = torch.empty( (canonical_input_ids.size(0), total_prompt_len + max_new_tokens), dtype=canonical_input_ids.dtype, device=device, ) sequences[:, :total_prompt_len] = canonical_input_ids generated = sequences[:, total_prompt_len:] if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: prompt_attention_mask = attention_mask.to(device=device) if prompt_attention_mask.size(0) != canonical_input_ids.size(0): prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device) elif prompt_attention_mask.size(1) != total_prompt_len: prompt_attention_mask = prompt_attention_mask[:, -total_prompt_len:] else: prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device) model_kwargs = { "attention_mask": prompt_attention_mask, "past_key_values": past_key_values, "use_cache": True, "llopa_v2_decode": True, "llopa_v2_decode_layers": int(lower_k), "llopa_v2_decode_no_upper_attn": bool(no_upper_attn), "llopa_v2_decode_replay_module": str(replay_module), "llopa_v2_decode_replay_per_layers": int(replay_per_layers), } record_scores = bool(output_scores) record_compact_scores = record_scores and bool(compact_scores) score_list: list[torch.Tensor] = [] score_list_append = score_list.append compact_logprob_list: list[torch.Tensor] = [] compact_logprob_append = compact_logprob_list.append cur = 0 pending_logits = initial_logits should_apply_stopping_criteria = stopping_criteria is not None while cur < max_new_tokens: outputs = None used_prefill_logits = pending_logits is not None if used_prefill_logits: logits = pending_logits.to(dtype=torch.float32, device=device, copy=True) pending_logits = None else: current_input_ids = sequences[:, : total_prompt_len + cur] if not isinstance(model_kwargs.get("cache_position"), torch.Tensor): model_kwargs["cache_position"] = torch.arange( total_prompt_len + cur - 1, total_prompt_len + cur, device=device, dtype=torch.long, ) model_inputs = model.prepare_inputs_for_generation(current_input_ids, **model_kwargs) outputs = model(**model_inputs, return_dict=True) model_kwargs = model._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=False, ) logits = outputs.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True) if stop_token_ids is not None and cur < min_new_tokens: logits.index_fill_(1, stop_token_ids, -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores and not record_compact_scores: score_list_append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) if record_compact_scores: next_logit = torch.gather(logits, 1, next_tok) next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) compact_logprob_append(next_logprob.squeeze(-1).detach()) generated[:, cur : cur + 1] = next_tok cur += 1 if used_prefill_logits: one_mask = torch.ones( (prompt_attention_mask.size(0), 1), dtype=prompt_attention_mask.dtype, device=device, ) model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], one_mask], dim=-1) model_kwargs["cache_position"] = torch.arange( total_prompt_len, total_prompt_len + 1, device=device, dtype=torch.long, ) should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= min_new_tokens: should_stop = True if (not should_stop) and should_apply_stopping_criteria: sequences_now = sequences[:, : total_prompt_len + cur] try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if outputs is not None: del outputs if should_stop: break sequences = sequences[:, : total_prompt_len + cur] if not bool(return_dict_in_generate): return sequences output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if record_scores and not record_compact_scores else None, past_key_values=model_kwargs.get("past_key_values", past_key_values), ) if record_compact_scores: if compact_logprob_list: compact_logprobs = torch.stack(compact_logprob_list, dim=1) else: compact_logprobs = torch.empty( (sequences.size(0), 0), dtype=torch.float32, device=sequences.device, ) setattr(output, "generated_token_logprobs", compact_logprobs) return output def _attach_runtime_llopa_fast_generate( model, *, lower_k: int, prefill_attn: str = "causal", no_upper_attn: bool = False, ) -> None: try: import types except Exception: return try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if lower_k <= 0 or attn not in {"causal", "full"}: return try: setattr(model, "_runtime_llopa_fast_layers", int(lower_k)) setattr(model, "_runtime_llopa_fast_attn", attn) setattr(model, "_runtime_llopa_fast_no_upper_attn", bool(no_upper_attn)) except Exception: return if getattr(model, "_runtime_llopa_fast_generate_attached", False): return orig_generate = getattr(model, "generate", None) if not callable(orig_generate): return def _runtime_llopa_fast_generate(self, *args, **kwargs): if bool(getattr(self.config, "is_encoder_decoder", False)): return orig_generate(*args, **kwargs) if kwargs.get("custom_generate") is not None or kwargs.get("assistant_model") is not None: return orig_generate(*args, **kwargs) if kwargs.get("inputs_embeds") is not None: return orig_generate(*args, **kwargs) num_beams = kwargs.get("num_beams", None) if num_beams is None: gen_cfg = kwargs.get("generation_config") num_beams = getattr(gen_cfg, "num_beams", 1) if gen_cfg is not None else 1 if int(num_beams or 1) != 1: _warn_once( self, "_warned_runtime_llopa_fast_num_beams", "[load_llopa_model][warn] runtime_llopa_fast_generate currently supports only num_beams=1; falling back to model.generate().", ) return orig_generate(*args, **kwargs) fast_enabled = kwargs.pop("runtime_llopa_fast_generate", None) if fast_enabled is None: fast_enabled = True if not bool(fast_enabled): return orig_generate(*args, **kwargs) runtime_enabled = kwargs.get("runtime_llopa_prefill", None) if runtime_enabled is None: runtime_enabled = True if not bool(runtime_enabled): return orig_generate(*args, **kwargs) if kwargs.get("runtime_llopa_layers") is None: kwargs["runtime_llopa_layers"] = int(getattr(self, "_runtime_llopa_fast_layers", 0) or 0) if kwargs.get("runtime_llopa_attn") is None: kwargs["runtime_llopa_attn"] = str(getattr(self, "_runtime_llopa_fast_attn", "causal") or "causal") if kwargs.get("runtime_llopa_no_upper_attn") is None: kwargs["runtime_llopa_no_upper_attn"] = bool( getattr(self, "_runtime_llopa_fast_no_upper_attn", False) ) kwargs["runtime_llopa_prefill"] = True return orig_generate( *args, custom_generate=_runtime_llopa_fast_generate_mode, **kwargs, ) try: model.generate = types.MethodType(_runtime_llopa_fast_generate, model) setattr(model, "_runtime_llopa_fast_generate_attached", True) except Exception: pass def _supports_prefill_lower_runtime(model) -> bool: try: inner = _get_inner_model(model) except Exception: inner = None return bool( hasattr(model, "tri_vanilla_prefill_decode_forward") and inner is not None and hasattr(inner, "tri_prefill_lower_cache") ) def _supports_runtime_llopa_prompt_prefill(model) -> bool: try: inner = _get_inner_model(model) except Exception: inner = None return bool( hasattr(model, "tri_runtime_llopa_prompt_prefill_forward") and _get_llopa_decode_step(model) is not None and inner is not None and hasattr(inner, "llopa_prefill_cache") ) def _supports_direct_llopa_generate(model) -> bool: try: inner = _get_inner_model(model) except Exception: inner = None return bool( _get_llopa_decode_step(model) is not None and inner is not None and hasattr(inner, "llopa_prefill_cache") ) def _warn_once(model, flag: str, msg: str) -> None: if getattr(model, flag, False): return try: setattr(model, flag, True) except Exception: pass print(msg) def _normalize_eos_token_ids(eos_token_id) -> list[int]: if eos_token_id is None: return [] if isinstance(eos_token_id, torch.Tensor): eos_token_id = eos_token_id.flatten().tolist() if isinstance(eos_token_id, (list, tuple, set)): out: list[int] = [] for item in eos_token_id: try: out.append(int(item)) except Exception: continue return out try: return [int(eos_token_id)] except Exception: return [] def _prepare_stop_token_tensor(stop_ids, device) -> Optional[torch.LongTensor]: if not stop_ids: return None try: ordered = [int(tok_id) for tok_id in stop_ids] except Exception: return None if not ordered: return None return torch.tensor(ordered, device=device, dtype=torch.long) def _find_last_subsequence_start(input_ids: torch.Tensor, pattern: torch.Tensor) -> Optional[int]: if not isinstance(input_ids, torch.Tensor) or not isinstance(pattern, torch.Tensor): return None if input_ids.dim() == 2: if input_ids.size(0) != 1: return None seq = input_ids[0] elif input_ids.dim() == 1: seq = input_ids else: return None if pattern.dim() == 2: if pattern.size(0) != 1: return None pat = pattern[0] elif pattern.dim() == 1: pat = pattern else: return None pat_len = int(pat.numel()) seq_len = int(seq.numel()) if pat_len <= 0 or seq_len < pat_len: return None windows = seq.unfold(0, pat_len, 1) matches = (windows == pat).all(dim=-1).nonzero(as_tuple=False) if matches.numel() == 0: return None return int(matches[-1].item()) def _build_sampling_warpers(do_sample: bool, temperature, top_p, top_k): if not do_sample: return None from transformers.generation import LogitsProcessorList from transformers.generation.logits_process import ( TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, ) procs = LogitsProcessorList() temp = float(temperature) if temp != 1.0: procs.append(TemperatureLogitsWarper(temp)) if top_p is not None and float(top_p) < 1.0: procs.append(TopPLogitsWarper(float(top_p), min_tokens_to_keep=1)) if top_k is not None and int(top_k) > 0: procs.append(TopKLogitsWarper(int(top_k), filter_value=-float("inf"))) return procs @torch.inference_mode() def _legacy_direct_llopa_generate_impl( model, tokenizer, *, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, ): if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) != 1: return None if use_cache is False: return None valid_len = int(input_ids.size(1)) if attention_mask is not None: if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0): return None valid_len = int(attention_mask[0].sum().item()) if valid_len <= 0: return None trimmed_input_ids = input_ids[:, -valid_len:] header_ids = getattr(model, "_direct_llopa_header_ids", None) if not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0: return None hdr = header_ids.to(device=trimmed_input_ids.device, dtype=trimmed_input_ids.dtype) assistant_start = _find_last_subsequence_start(trimmed_input_ids, hdr) if assistant_start is None: return None prefix_ids = trimmed_input_ids[:, :assistant_start] assistant_ids = trimmed_input_ids[:, assistant_start:] if assistant_ids.numel() == 0: return None try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" total_prompt_len = int(input_ids.size(1)) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - total_prompt_len) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - total_prompt_len) else: min_new_tokens = int(min_new_tokens) raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) llopa_core = _get_llopa_core(model) llopa_step = _get_llopa_decode_step(model) if llopa_core is None or llopa_step is None: return None llopa_fn = getattr(llopa_core, "llopa_prefill_cache", None) if llopa_fn is None: return None output_head = _get_output_head(model) prefill_out = llopa_fn( system_ids=prefix_ids[:, :0], user_ids=prefix_ids, assistant_ids=assistant_ids, lower_k=lower_k, prefill_mode="lower", prefill_attn=attn, return_last_assistant_hidden=bool(output_head is not None), ) initial_logits = None if isinstance(prefill_out, tuple): pkv, last_hidden = prefill_out if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0: initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32) else: pkv = prefill_out S = 0 U = int(prefix_ids.size(1)) last = assistant_ids[:, -1:] stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) max_new_tokens = int(max_new_tokens) min_new_tokens = int(min_new_tokens) total_prompt_len = int(input_ids.size(1)) record_scores = bool(output_scores) should_apply_stopping_criteria = stopping_criteria is not None sequences = torch.empty( (input_ids.size(0), total_prompt_len + max_new_tokens), dtype=input_ids.dtype, device=input_ids.device, ) sequences[:, :total_prompt_len] = input_ids generated = sequences[:, total_prompt_len:] score_list: list[torch.Tensor] = [] score_list_append = score_list.append cur = 0 pending_logits = initial_logits while cur < max_new_tokens: out = None if pending_logits is None: out = llopa_step( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=1, labels=None, prefill_mode="lower", ) pkv = out.past_key_values or pkv logits = out.logits[:, -1, :].to(dtype=torch.float32, device=last.device, copy=True) else: logits = pending_logits pending_logits = None if stop_token_ids is not None and cur < min_new_tokens: logits.index_fill_(1, stop_token_ids, -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores: score_list_append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) generated[:, cur : cur + 1] = next_tok cur += 1 should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= min_new_tokens: should_stop = True if (not should_stop) and should_apply_stopping_criteria: sequences_now = sequences[:, : total_prompt_len + cur] try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if out is not None: del out if should_stop: break last = next_tok sequences = sequences[:, : total_prompt_len + cur] if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput return GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if bool(output_scores) else None, past_key_values=pkv, ) @torch.inference_mode() def _direct_llopa_generate_impl( model, tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, structured_prompt_segments=None, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, system_prefill: str, user_prefill: str, no_upper_attn: bool, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "auto", max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, compact_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, ): if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module = _normalize_replay_module_value(replay_module) replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) if use_cache is False: return None try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" user_prefill_norm = (user_prefill or "full").strip().lower() if user_prefill_norm != "full": raise ValueError("Unified direct LLoPA currently supports only user_prefill='full'.") llopa_core = _get_llopa_core(model) llopa_step = _get_llopa_decode_step(model) if llopa_core is None or llopa_step is None: return None llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None) decode_output_head = _get_output_head(model) use_direct_decode_step = ( _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1") and callable(llopa_forward_assistant) and decode_output_head is not None ) device = None if isinstance(input_ids, torch.Tensor): if input_ids.dim() != 2 or input_ids.size(0) != 1: return None device = input_ids.device if device is None: try: device = next(model.parameters()).device except Exception: device = "cpu" segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None if segments is None: segments = _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) prompt_ids = segments["prompt_ids"] system_ids = segments["system_ids"] user_ids = segments["user_ids"] assistant_prefill_ids = segments["assistant_prefill_ids"] replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0) replay_user_start = int(segments.get("replay_user_start", 0) or 0) replay_user_len = int(segments.get("replay_user_len", 0) or 0) if assistant_prefill_ids.numel() == 0: return None raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) initial_logits = None prompt_bundle = _build_unified_prefill_lower_prompt_bundle( tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), structured_prompt_segments=segments, device=device, ) reference_seed = _direct_prefill_lower_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=lower_k, prefill_attn=attn, system_prefill=sys_prefill, no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), seed_mode=str(seed_mode or "auto"), ) canonical_input_ids = None if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1: valid_len = int(input_ids.size(1)) if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1: valid_len = int(attention_mask[0].sum().item()) if valid_len > 0: canonical_input_ids = input_ids[:, -valid_len:] if not isinstance(canonical_input_ids, torch.Tensor): canonical_input_ids = prompt_bundle.get("effective_prompt_ids") if not isinstance(canonical_input_ids, torch.Tensor): canonical_input_ids = prompt_ids total_prompt_len = int(canonical_input_ids.size(1)) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - total_prompt_len) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - total_prompt_len) else: min_new_tokens = int(min_new_tokens) if reference_seed is not None: pkv, S, U, initial_logits = reference_seed else: output_head = _get_output_head(model) if bool(no_upper_attn): pkv, S, U = _llopa_prefill_cache( llopa_core, system_ids, user_ids, assistant_prefill_ids, lower_k=lower_k, prefill_mode="lower", prefill_attn=attn, system_prefill=sys_prefill, replay_user_prefix_keep_len=replay_user_prefix_keep_len, replay_user_start=replay_user_start, replay_user_len=replay_user_len, ) else: pkv, S, U, last_hidden = _llopa_prefill_cache( llopa_core, system_ids, user_ids, assistant_prefill_ids, lower_k=lower_k, prefill_mode="lower", prefill_attn=attn, system_prefill=sys_prefill, return_last_assistant_hidden=bool(output_head is not None), replay_user_prefix_keep_len=replay_user_prefix_keep_len, replay_user_start=replay_user_start, replay_user_len=replay_user_len, ) if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0: initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32) stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) last = assistant_prefill_ids[:, -1:] stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) max_new_tokens = int(max_new_tokens) min_new_tokens = int(min_new_tokens) lower_k = int(lower_k) no_upper_attn_bool = bool(no_upper_attn) replay_module_str = str(replay_module) replay_per_layers_int = int(replay_per_layers) record_scores = bool(output_scores) record_compact_scores = bool(output_scores) and bool(compact_scores) should_apply_stopping_criteria = stopping_criteria is not None mixin_decode_default = getattr(model, "_llopa_v2_generation_mixin_decode", None) use_mixin_decode = ( bool(mixin_decode_default) if mixin_decode_default is not None else _env_flag_enabled("CAPSULE_LLOPA_GENERATION_MIXIN_DECODE", "0") ) if use_mixin_decode and isinstance(initial_logits, torch.Tensor): mixin_output = _llopa_v2_generation_mixin_decode_loop( model, canonical_input_ids=canonical_input_ids, attention_mask=attention_mask, past_key_values=pkv, initial_logits=initial_logits, lower_k=lower_k, no_upper_attn=no_upper_attn_bool, replay_module=replay_module_str, replay_per_layers=replay_per_layers_int, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, do_sample=do_sample, logits_warpers=logits_warpers, stop_ids=stop_ids, stop_token_ids=stop_token_ids, stopping_criteria=stopping_criteria, output_scores=output_scores, compact_scores=compact_scores, return_dict_in_generate=return_dict_in_generate, ) if mixin_output is not None: return mixin_output sequences = torch.empty( (canonical_input_ids.size(0), total_prompt_len + max_new_tokens), dtype=canonical_input_ids.dtype, device=canonical_input_ids.device, ) sequences[:, :total_prompt_len] = canonical_input_ids generated = sequences[:, total_prompt_len:] score_list: list[torch.Tensor] = [] score_list_append = score_list.append compact_logprob_list: list[torch.Tensor] = [] compact_logprob_append = compact_logprob_list.append cur = 0 pending_logits = initial_logits while cur < max_new_tokens: out = None if pending_logits is None: if use_direct_decode_step: out = llopa_forward_assistant( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, write_cache=True, prefill_mode="lower", no_upper_attn=no_upper_attn_bool, align_cache_position_to_layer_past=False, replay_module=replay_module_str, replay_per_layers=replay_per_layers_int, ) pkv = out.past_key_values or pkv logits = decode_output_head(out.last_hidden_state[:, -1, :]) if logits.dim() == 3: logits = logits[:, -1, :] else: out = llopa_step( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=1, labels=None, prefill_mode="lower", no_upper_attn=no_upper_attn_bool, align_cache_position_to_layer_past=False, replay_module=replay_module_str, replay_per_layers=replay_per_layers_int, ) pkv = out.past_key_values or pkv logits = out.logits[:, -1, :] logits = logits.to(dtype=torch.float32, device=last.device, copy=True) else: logits = pending_logits pending_logits = None if stop_token_ids is not None and cur < min_new_tokens: logits.index_fill_(1, stop_token_ids, -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores and not record_compact_scores: score_list_append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) if record_compact_scores: next_logit = torch.gather(logits, 1, next_tok) next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) compact_logprob_append(next_logprob.squeeze(-1).detach()) generated[:, cur : cur + 1] = next_tok cur += 1 should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= min_new_tokens: should_stop = True if (not should_stop) and should_apply_stopping_criteria: sequences_now = sequences[:, : total_prompt_len + cur] try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if out is not None: del out if should_stop: break last = next_tok sequences = sequences[:, : total_prompt_len + cur] if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None, past_key_values=pkv, ) if record_compact_scores: if compact_logprob_list: compact_logprobs = torch.stack(compact_logprob_list, dim=1) else: compact_logprobs = torch.empty( (sequences.size(0), 0), dtype=torch.float32, device=sequences.device, ) setattr(output, "generated_token_logprobs", compact_logprobs) return output def _is_prompt_messages_batch(prompt_messages) -> bool: if not isinstance(prompt_messages, (list, tuple)) or not prompt_messages: return False if all(isinstance(item, dict) for item in prompt_messages): return False return all(isinstance(item, (list, tuple)) for item in prompt_messages) def _is_structured_segments_batch(structured_prompt_segments) -> bool: return ( isinstance(structured_prompt_segments, (list, tuple)) and not isinstance(structured_prompt_segments, dict) and all(isinstance(item, dict) for item in structured_prompt_segments) ) def _batch_prompt_add_generation_flags(value, batch_size: int) -> list[bool]: if isinstance(value, (list, tuple)): if len(value) != int(batch_size): raise ValueError( f"prompt_add_generation_prompt batch size mismatch: {len(value)} != {int(batch_size)}" ) return [bool(item) for item in value] return [bool(value) for _ in range(int(batch_size))] def _as_1d_long_tensor(value: torch.Tensor, *, device) -> torch.LongTensor: value = value.to(device=device, dtype=torch.long) if value.dim() == 2: if value.size(0) != 1: raise ValueError("Expected a single-row prompt segment tensor.") value = value[0] elif value.dim() != 1: value = value.reshape(-1) return value def _pad_1d_rows( rows: list[torch.Tensor], *, pad_value: int, device, dtype: torch.dtype = torch.long, ) -> torch.Tensor: batch_size = len(rows) max_len = max((int(row.numel()) for row in rows), default=0) out = torch.full((batch_size, max_len), int(pad_value), device=device, dtype=dtype) for row_idx, row in enumerate(rows): row = row.to(device=device, dtype=dtype).reshape(-1) width = int(row.numel()) if width > 0: out[row_idx, :width] = row return out def _pad_prompt_metadata_rows( rows: list[Optional[torch.Tensor]], *, fill_value: int, device, dtype: torch.dtype, ) -> torch.Tensor: batch_size = len(rows) max_len = 0 flat_rows: list[torch.Tensor] = [] for row in rows: if isinstance(row, torch.Tensor) and row.numel() > 0: flat = row.to(device=device, dtype=dtype) if flat.dim() == 2: flat = flat[0] else: flat = flat.reshape(-1) else: flat = torch.empty((0,), device=device, dtype=dtype) flat_rows.append(flat) max_len = max(max_len, int(flat.numel())) if max_len <= 0: return torch.full((batch_size, 1), int(fill_value), device=device, dtype=dtype) out = torch.full((batch_size, max_len), int(fill_value), device=device, dtype=dtype) for row_idx, flat in enumerate(flat_rows): width = int(flat.numel()) if width > 0: out[row_idx, :width] = flat return out def _build_batched_structured_prompt_segments( tokenizer, *, prompt_messages, prompt_add_generation_prompt, structured_prompt_segments, device, ) -> list[dict]: if _is_structured_segments_batch(structured_prompt_segments): return list(structured_prompt_segments) if isinstance(structured_prompt_segments, dict): return [structured_prompt_segments] if _is_prompt_messages_batch(prompt_messages): batch_size = len(prompt_messages) flags = _batch_prompt_add_generation_flags(prompt_add_generation_prompt, batch_size) out = [] for row_idx, messages in enumerate(prompt_messages): row_messages = list(messages) add_generation_prompt = bool(flags[row_idx]) normalized = _normalize_prompt_messages(row_messages) if add_generation_prompt and normalized and normalized[-1]["role"] == "assistant": add_generation_prompt = False out.append( _build_structured_prompt_segments( tokenizer, row_messages, prompt_add_generation_prompt=add_generation_prompt, device=device, ) ) return out if prompt_messages is not None: if prompt_add_generation_prompt is None: raise ValueError("llopa_v2_batch_generate requires prompt_add_generation_prompt for prompt_messages.") return [ _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) ] return [] def _batched_segments_to_tensors(segments: list[dict], *, device, pad_token_id: int) -> dict: prompt_rows: list[torch.Tensor] = [] prompt_lens: list[int] = [] split_starts: list[int] = [] system_lens: list[int] = [] replay_prefix_keep_lens: list[int] = [] replay_user_starts: list[int] = [] replay_user_lens: list[int] = [] header_start_rows: list[Optional[torch.Tensor]] = [] turn_end_rows: list[Optional[torch.Tensor]] = [] header_mask_rows: list[Optional[torch.Tensor]] = [] for seg in segments: prompt_ids = seg.get("prompt_ids") assistant_prefill_ids = seg.get("assistant_prefill_ids") if not isinstance(prompt_ids, torch.Tensor) or prompt_ids.numel() == 0: raise ValueError("llopa_v2_batch_generate requires non-empty prompt_ids in each segment.") if not isinstance(assistant_prefill_ids, torch.Tensor) or assistant_prefill_ids.numel() == 0: raise ValueError("llopa_v2_batch_generate requires non-empty assistant_prefill_ids in each segment.") prompt_row = _as_1d_long_tensor(prompt_ids, device=device) assistant_row = _as_1d_long_tensor(assistant_prefill_ids, device=device) prompt_len = int(prompt_row.numel()) assistant_len = int(assistant_row.numel()) split_start = int(prompt_len - assistant_len) prefix_ids = seg.get("prefix_ids") if isinstance(prefix_ids, torch.Tensor) and prefix_ids.numel() > 0: split_start = int(_as_1d_long_tensor(prefix_ids, device=device).numel()) split_start = max(0, min(split_start, prompt_len - 1)) system_ids = seg.get("system_ids") system_len = int(_as_1d_long_tensor(system_ids, device=device).numel()) if isinstance(system_ids, torch.Tensor) else 0 system_len = max(0, min(system_len, split_start)) prompt_rows.append(prompt_row) prompt_lens.append(prompt_len) split_starts.append(split_start) system_lens.append(system_len) replay_prefix_keep_lens.append(int(seg.get("replay_user_prefix_keep_len", 0) or 0)) replay_user_starts.append(int(seg.get("replay_user_start", 0) or 0)) replay_user_lens.append(int(seg.get("replay_user_len", 0) or 0)) header_starts = seg.get("assistant_header_starts") turn_ends = seg.get("assistant_turn_ends") header_mask = seg.get("assistant_header_start_mask") if not isinstance(header_starts, torch.Tensor) or header_starts.numel() == 0: header_starts = torch.tensor([[split_start]], device=device, dtype=torch.long) if not isinstance(turn_ends, torch.Tensor) or turn_ends.numel() == 0: turn_ends = torch.tensor([[prompt_len]], device=device, dtype=torch.long) if not isinstance(header_mask, torch.Tensor) or header_mask.numel() == 0: header_mask = torch.ones_like(header_starts, device=device, dtype=torch.bool) header_start_rows.append(header_starts) turn_end_rows.append(turn_ends) header_mask_rows.append(header_mask) prompt_ids = _pad_1d_rows(prompt_rows, pad_value=int(pad_token_id), device=device, dtype=torch.long) prompt_lens_tensor = torch.tensor(prompt_lens, device=device, dtype=torch.long) prompt_attention_mask = ( torch.arange(prompt_ids.size(1), device=device, dtype=torch.long).unsqueeze(0) < prompt_lens_tensor.unsqueeze(1) ).to(dtype=torch.long) return { "prompt_rows": prompt_rows, "prompt_ids": prompt_ids, "prompt_attention_mask": prompt_attention_mask, "prompt_lens": prompt_lens_tensor, "split_starts": torch.tensor(split_starts, device=device, dtype=torch.long), "system_lens": torch.tensor(system_lens, device=device, dtype=torch.long), "replay_user_prefix_keep_lens": torch.tensor(replay_prefix_keep_lens, device=device, dtype=torch.long), "replay_user_starts": torch.tensor(replay_user_starts, device=device, dtype=torch.long), "replay_user_lens": torch.tensor(replay_user_lens, device=device, dtype=torch.long), "assistant_header_starts": _pad_prompt_metadata_rows( header_start_rows, fill_value=-1, device=device, dtype=torch.long, ), "assistant_turn_ends": _pad_prompt_metadata_rows( turn_end_rows, fill_value=-1, device=device, dtype=torch.long, ), "assistant_header_start_mask": _pad_prompt_metadata_rows( header_mask_rows, fill_value=0, device=device, dtype=torch.bool, ).to(dtype=torch.bool), } def _cache_layer_seq_len(pkv, layer_idx: int) -> int: try: k, _ = pkv_get(pkv, int(layer_idx)) except Exception: return 0 if not isinstance(k, torch.Tensor) or k.dim() < 3: return 0 return int(k.shape[-2]) def _merge_optional_batch_sequence_attr(row_pkvs: list, attr_name: str, *, device, dtype=None): rows = [] trailing_shape = None for pkv in row_pkvs: value = getattr(pkv, attr_name, None) if isinstance(value, torch.Tensor) and value.numel() > 0: value = value.to(device=device) if dtype is not None: value = value.to(dtype=dtype) if value.dim() == 1: value = value.view(1, -1) elif value.dim() >= 2 and value.size(0) != 1: value = value[:1] trailing_shape = tuple(value.shape[2:]) else: value = None rows.append(value) if trailing_shape is None: return None max_len = max((int(row.shape[1]) for row in rows if isinstance(row, torch.Tensor)), default=0) if max_len <= 0: return None ref = next(row for row in rows if isinstance(row, torch.Tensor)) out_shape = (len(rows), max_len) + trailing_shape out = torch.zeros(out_shape, device=device, dtype=ref.dtype) for row_idx, row in enumerate(rows): if not isinstance(row, torch.Tensor): continue width = int(row.shape[1]) if width > 0: out[row_idx : row_idx + 1, :width, ...] = row[:, :width, ...] return out def _merge_llopa_batch_row_caches(row_pkvs: list, *, device) -> Optional[DynamicCache]: if not row_pkvs: return None try: n_layers = max(int(pkv_len(pkv)) for pkv in row_pkvs) except Exception: return None if n_layers <= 0: return None merged_pairs = [] layer_valid_masks: list[torch.Tensor] = [] for layer_idx in range(n_layers): row_kvs = [] max_len = 0 ref_k = None ref_v = None for pkv in row_pkvs: try: k, v = pkv_get(pkv, layer_idx) except Exception: k = None v = None if isinstance(k, torch.Tensor) and isinstance(v, torch.Tensor) and k.dim() == 4 and v.dim() == 4: k = k.to(device=device) v = v.to(device=device) if k.size(0) != 1: k = k[:1] v = v[:1] ref_k = k if ref_k is None else ref_k ref_v = v if ref_v is None else ref_v max_len = max(max_len, int(k.shape[-2])) row_kvs.append((k, v)) else: row_kvs.append((None, None)) if ref_k is None or ref_v is None: return None B = len(row_pkvs) merged_k = torch.zeros( (B, int(ref_k.shape[1]), max_len, int(ref_k.shape[-1])), device=device, dtype=ref_k.dtype, ) merged_v = torch.zeros( (B, int(ref_v.shape[1]), max_len, int(ref_v.shape[-1])), device=device, dtype=ref_v.dtype, ) valid_mask = torch.zeros((B, max_len), device=device, dtype=torch.bool) for row_idx, (k, v) in enumerate(row_kvs): if not isinstance(k, torch.Tensor) or not isinstance(v, torch.Tensor): continue width = int(k.shape[-2]) if width > 0: merged_k[row_idx : row_idx + 1, :, :width, :] = k[:, :, :width, :] merged_v[row_idx : row_idx + 1, :, :width, :] = v[:, :, :width, :] valid_mask[row_idx, :width] = True merged_pairs.append((merged_k, merged_v)) layer_valid_masks.append(valid_mask) try: merged = DynamicCache(ddp_cache_data=merged_pairs) except Exception: merged = DynamicCache() for layer_idx, (k, v) in enumerate(merged_pairs): merged.update(k, v, layer_idx) try: setattr(merged, "_llopa_batch_layer_valid_masks", layer_valid_masks) except Exception: pass replay_hidden = _merge_optional_batch_sequence_attr( row_pkvs, "_tri_last_layer_memory_hidden", device=device, ) replay_pos = _merge_optional_batch_sequence_attr( row_pkvs, "_tri_last_layer_memory_position_ids", device=device, dtype=torch.long, ) replay_valid = _merge_optional_batch_sequence_attr( row_pkvs, "_tri_last_layer_memory_valid_mask", device=device, dtype=torch.bool, ) if isinstance(replay_hidden, torch.Tensor) and isinstance(replay_pos, torch.Tensor): with contextlib.suppress(Exception): setattr(merged, "_tri_last_layer_memory_hidden", replay_hidden) setattr(merged, "_tri_last_layer_memory_position_ids", replay_pos) if isinstance(replay_valid, torch.Tensor): setattr(merged, "_tri_last_layer_memory_valid_mask", replay_valid.to(dtype=torch.bool)) module_type = getattr(row_pkvs[0], "_tri_replay_module", getattr(row_pkvs[0], "_tri_last_layer_module", "none")) setattr(merged, "_tri_replay_module", str(module_type or "none")) setattr(merged, "_tri_last_layer_module", str(module_type or "none")) setattr(merged, "_tri_replay_per_layers", int(getattr(row_pkvs[0], "_tri_replay_per_layers", -1) or -1)) return merged def _append_llopa_batch_cache_valid_masks(pkv, input_valid: torch.Tensor) -> None: masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None) if not isinstance(masks, list): return input_valid = input_valid.to(dtype=torch.bool) new_masks = [] for layer_idx, mask in enumerate(masks): if not isinstance(mask, torch.Tensor) or mask.dim() != 2: new_masks.append(mask) continue layer_len = _cache_layer_seq_len(pkv, layer_idx) cur_len = int(mask.size(1)) if layer_len > cur_len: add_width = int(layer_len - cur_len) add = input_valid.to(device=mask.device).view(-1, 1).expand(mask.size(0), add_width) mask = torch.cat([mask, add], dim=1) elif layer_len < cur_len: mask = mask[:, :layer_len] new_masks.append(mask) with contextlib.suppress(Exception): setattr(pkv, "_llopa_batch_layer_valid_masks", new_masks) setattr(pkv, "_tri_past_len_cache", None) def _direct_llopa_batch_prefill_cache_and_logits( model, *, segments: list[dict], pad_id: int, lower_k: int, prefill_attn: str, system_prefill: str, no_upper_attn: bool, see_past_assistant: bool, replay_module: str, replay_per_layers: int, device, ): if not segments: return None if _normalize_replay_module_value(replay_module) != "none": return None seed_fn = _get_llopa_full_prompt_seed(model) if not callable(seed_fn): return None try: batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=int(pad_id)) except Exception: return None try: seed = seed_fn( input_ids=batch["prompt_ids"], attention_mask=batch["prompt_attention_mask"], use_cache=True, logits_to_keep=1, lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), prefill_lower_split_start=batch["split_starts"], prefill_lower_system_len=batch["system_lens"], prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"], prefill_lower_replay_user_start=batch["replay_user_starts"], prefill_lower_replay_user_len=batch["replay_user_lens"], assistant_header_starts=batch["assistant_header_starts"], assistant_turn_ends=batch["assistant_turn_ends"], assistant_header_start_mask=batch["assistant_header_start_mask"], prefill_lower_see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), ) except Exception: return None if not isinstance(seed, tuple) or len(seed) != 4: return None pkv, S, U, logits = seed if pkv is None or not isinstance(logits, torch.Tensor) or logits.numel() == 0: return None if logits.dim() == 3: logits = logits[:, -1, :] if logits.dim() != 2 or int(logits.size(0)) != len(segments): return None return pkv, S, U, logits.to(device=device, dtype=torch.float32) def _direct_llopa_batch_generate_cached_impl( model, tokenizer, *, segments: list[dict], canonical_input_ids: torch.Tensor, context_width: int, batch_size: int, lower_k: int, prefill_attn: str, system_prefill: str, no_upper_attn: bool, see_past_assistant: bool, replay_module: str, replay_per_layers: int, last_layer_module: Optional[str], max_new_tokens: int, min_new_tokens: int, do_sample: bool, logits_warpers, stopping_criteria, stop_ids: set[int], stop_token_ids: Optional[torch.Tensor], pad_id: int, output_scores: bool, compact_scores: bool, return_dict_in_generate: bool, device, ): llopa_core = _get_llopa_core(model) llopa_step = _get_llopa_decode_step(model) if llopa_core is None or llopa_step is None: return None llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None) decode_output_head = _get_output_head(model) use_direct_decode_step = ( _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1") and callable(llopa_forward_assistant) and decode_output_head is not None ) row_pkvs = [] initial_logits_rows = [] prompt_bundles = [] for seg in segments: prompt_bundle = _build_unified_prefill_lower_prompt_bundle( tokenizer, prompt_messages=None, prompt_add_generation_prompt=True, structured_prompt_segments=seg, device=device, ) prompt_bundles.append(prompt_bundle) has_past_assistant_history = bool(see_past_assistant) and any( _prompt_bundle_has_past_assistant_history(bundle) for bundle in prompt_bundles ) batch_seed = None if int(batch_size) > 1: batch_seed = _direct_llopa_batch_prefill_cache_and_logits( model, segments=segments, pad_id=int(pad_id), lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), device=device, ) if batch_seed is not None: pkv, _S, _U, pending_logits = batch_seed pending_logits = pending_logits.to(device=device, dtype=torch.float32) else: if int(batch_size) > 1 and _normalize_replay_module_value(replay_module) != "none": return None if int(batch_size) > 1 and bool(has_past_assistant_history): return None for prompt_bundle in prompt_bundles: seed = _direct_prefill_lower_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), last_layer_module=last_layer_module, seed_mode="prefill_header", ) if seed is None: return None pkv_row, _S, _U, initial_logits = seed if not isinstance(initial_logits, torch.Tensor) or initial_logits.numel() == 0: return None row_pkvs.append(pkv_row) initial_logits_rows.append(initial_logits.to(device=device, dtype=torch.float32)) pkv = _merge_llopa_batch_row_caches(row_pkvs, device=device) if pkv is None: return None pending_logits = torch.cat(initial_logits_rows, dim=0) del row_pkvs, initial_logits_rows del prompt_bundles sequences = torch.full( (batch_size, int(context_width) + int(max_new_tokens)), int(pad_id), dtype=canonical_input_ids.dtype, device=device, ) sequences[:, : int(context_width)] = canonical_input_ids.to(device=device) generated = sequences[:, int(context_width) :] score_list: list[torch.Tensor] = [] score_list_append = score_list.append record_scores = bool(output_scores) record_compact_scores = record_scores and bool(compact_scores) compact_logprob_list: list[torch.Tensor] = [] compact_logprob_append = compact_logprob_list.append unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool) finish_steps = torch.full((batch_size,), int(max_new_tokens), device=device, dtype=torch.long) last = None last_valid = torch.ones((batch_size,), device=device, dtype=torch.bool) cur = 0 while cur < int(max_new_tokens): if pending_logits is None: if not isinstance(last, torch.Tensor): return None if use_direct_decode_step: out = llopa_forward_assistant( assistant_ids=last, lower_k=int(lower_k), pkv=pkv, S=0, U=0, write_cache=True, prefill_mode="lower", no_upper_attn=bool(no_upper_attn), align_cache_position_to_layer_past=True, replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), ) pkv = out.past_key_values or pkv _append_llopa_batch_cache_valid_masks(pkv, last_valid) logits = decode_output_head(out.last_hidden_state[:, -1, :]) if logits.dim() == 3: logits = logits[:, -1, :] else: out = llopa_step( assistant_ids=last, lower_k=int(lower_k), pkv=pkv, S=0, U=0, logits_to_keep=1, labels=None, prefill_mode="lower", no_upper_attn=bool(no_upper_attn), align_cache_position_to_layer_past=True, replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), ) pkv = out.past_key_values or pkv _append_llopa_batch_cache_valid_masks(pkv, last_valid) logits = out.logits[:, -1, :] logits = logits.to(dtype=torch.float32, device=device, copy=True) else: logits = pending_logits pending_logits = None if stop_token_ids is not None and cur < int(min_new_tokens): logits.index_fill_(1, stop_token_ids.to(device=logits.device), -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores and not record_compact_scores: score_list_append(logits.detach().clone()) if bool(do_sample): probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) if record_compact_scores: next_logit = torch.gather(logits, 1, next_tok) next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) compact_logprob_append(next_logprob.squeeze(-1).detach()) token_valid = unfinished.clone() if not bool(unfinished.all().item()): next_tok = torch.where( unfinished.view(-1, 1), next_tok, torch.full_like(next_tok, int(pad_id)), ) token_valid = unfinished.clone() generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype) cur += 1 if stop_ids and cur >= int(min_new_tokens): stop_mask = torch.zeros_like(unfinished) for stop_id in stop_ids: stop_mask |= next_tok.squeeze(1).eq(int(stop_id)) newly_finished = unfinished & stop_mask if bool(newly_finished.any().item()): finish_steps = torch.where( newly_finished, torch.full_like(finish_steps, int(cur)), finish_steps, ) unfinished = unfinished & ~stop_mask if stopping_criteria is not None: sequences_now = sequences[:, : int(context_width) + cur] try: stop_result = stopping_criteria(sequences_now, logits) except TypeError: stop_result = stopping_criteria(sequences_now, None) if isinstance(stop_result, torch.Tensor): stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1) if stop_result.numel() == 1: if bool(stop_result.item()): unfinished.zero_() elif stop_result.numel() == batch_size: newly_finished = unfinished & stop_result if bool(newly_finished.any().item()): finish_steps = torch.where( newly_finished, torch.full_like(finish_steps, int(cur)), finish_steps, ) unfinished = unfinished & ~stop_result elif bool(stop_result.all().item()): unfinished.zero_() elif bool(stop_result): unfinished.zero_() if not bool(unfinished.any().item()): break last = next_tok last_valid = token_valid sequences = sequences[:, : int(context_width) + cur] if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if record_scores and not record_compact_scores else None, past_key_values=pkv, ) if record_compact_scores: if compact_logprob_list: compact_logprobs = torch.stack(compact_logprob_list, dim=1) else: compact_logprobs = torch.empty( (sequences.size(0), 0), dtype=torch.float32, device=sequences.device, ) setattr(output, "generated_token_logprobs", compact_logprobs) return output def _direct_llopa_batch_generate_serial_cached_fallback_impl( model, tokenizer, *, segments: list[dict], canonical_input_ids: torch.Tensor, context_width: int, batch_size: int, lower_k: int, prefill_attn: str, system_prefill: str, user_prefill: str, no_upper_attn: bool, see_past_assistant: bool, replay_module: str, replay_per_layers: int, last_layer_module: Optional[str], seed_mode: str, max_new_tokens: int, min_new_tokens: int, do_sample: bool, temperature, top_p, top_k, stopping_criteria, pad_id: int, eos_token_id, output_scores: bool, compact_scores: bool, return_dict_in_generate: bool, device, ): if int(batch_size) <= 1: return None row_sequences: list[torch.Tensor] = [] row_score_tuples: list[tuple] = [] row_compact_tensors: list[torch.Tensor] = [] score_template = None max_gen_len = 0 record_scores = bool(output_scores) record_compact_scores = record_scores and bool(compact_scores) for row_idx, seg in enumerate(segments): prompt_ids = seg.get("prompt_ids") if isinstance(seg, dict) else None if not isinstance(prompt_ids, torch.Tensor) or prompt_ids.numel() == 0: return None row_prompt = prompt_ids.to(device=device, dtype=torch.long) if row_prompt.dim() == 1: row_prompt = row_prompt.unsqueeze(0) elif row_prompt.dim() == 2: row_prompt = row_prompt[:1] else: row_prompt = row_prompt.reshape(1, -1) row_prompt_len = int(row_prompt.size(1)) row_out = _direct_llopa_generate_impl( model, tokenizer, prompt_messages=None, prompt_add_generation_prompt=True, structured_prompt_segments=seg, input_ids=row_prompt, attention_mask=torch.ones_like(row_prompt, device=device, dtype=torch.long), lower_k=int(lower_k), prefill_attn=str(prefill_attn), system_prefill=str(system_prefill), user_prefill=str(user_prefill), no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), last_layer_module=last_layer_module, seed_mode=str(seed_mode or "prefill_header"), max_new_tokens=int(max_new_tokens), min_new_tokens=int(min_new_tokens), do_sample=bool(do_sample), temperature=temperature, top_p=top_p, top_k=top_k, stopping_criteria=stopping_criteria, pad_token_id=int(pad_id), eos_token_id=eos_token_id, output_scores=record_scores, compact_scores=record_compact_scores, return_dict_in_generate=True, use_cache=True, ) if row_out is None: return None row_seq = getattr(row_out, "sequences", None) if not isinstance(row_seq, torch.Tensor) or row_seq.dim() != 2 or row_seq.size(0) < 1: return None row_gen = row_seq[:1, row_prompt_len:].to(device=device, dtype=canonical_input_ids.dtype) row_sequences.append(row_gen) max_gen_len = max(max_gen_len, int(row_gen.size(1))) if record_compact_scores: compact = getattr(row_out, "generated_token_logprobs", None) if isinstance(compact, torch.Tensor): row_compact_tensors.append(compact[:1].to(device=device, dtype=torch.float32)) else: row_compact_tensors.append(torch.empty((1, 0), device=device, dtype=torch.float32)) elif record_scores: scores = getattr(row_out, "scores", None) if isinstance(scores, tuple): row_score_tuples.append(scores) if score_template is None and len(scores) > 0 and isinstance(scores[0], torch.Tensor): score_template = scores[0][:1].to(device=device) else: row_score_tuples.append(()) sequences = torch.full( (int(batch_size), int(context_width) + int(max_gen_len)), int(pad_id), dtype=canonical_input_ids.dtype, device=device, ) sequences[:, : int(context_width)] = canonical_input_ids.to(device=device) for row_idx, row_gen in enumerate(row_sequences): width = int(row_gen.size(1)) if width > 0: sequences[row_idx : row_idx + 1, int(context_width) : int(context_width) + width] = row_gen[:, :width] compact_logprobs = None score_list = None if record_compact_scores: compact_logprobs = torch.zeros( (int(batch_size), int(max_gen_len)), device=device, dtype=torch.float32, ) for row_idx, compact in enumerate(row_compact_tensors): width = min(int(compact.size(1)), int(max_gen_len)) if width > 0: compact_logprobs[row_idx : row_idx + 1, :width] = compact[:, :width] elif record_scores: if score_template is None: score_list = [] else: score_list = [] for step_idx in range(int(max_gen_len)): step_scores = [] for scores in row_score_tuples: if step_idx < len(scores) and isinstance(scores[step_idx], torch.Tensor): step_scores.append(scores[step_idx][:1].to(device=device)) else: step_scores.append(torch.zeros_like(score_template, device=device)) score_list.append(torch.cat(step_scores, dim=0)) if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if record_scores and not record_compact_scores and score_list is not None else None, past_key_values=None, ) if compact_logprobs is not None: setattr(output, "generated_token_logprobs", compact_logprobs) return output @torch.inference_mode() def _direct_llopa_batch_generate_impl( model, tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, structured_prompt_segments=None, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, system_prefill: str, user_prefill: str, no_upper_attn: bool, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "prefill_header", max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, compact_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, ): allow_cached_paths = use_cache is not False batch_size = int(input_ids.size(0)) if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 else 0 if _is_structured_segments_batch(structured_prompt_segments): batch_size = len(structured_prompt_segments) elif _is_prompt_messages_batch(prompt_messages): batch_size = len(prompt_messages) elif batch_size <= 0: batch_size = 1 if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module = _normalize_replay_module_value(replay_module) replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" user_prefill_norm = (user_prefill or "full").strip().lower() if user_prefill_norm != "full": raise ValueError("llopa_v2_batch_generate currently supports only user_prefill='full'.") if isinstance(input_ids, torch.Tensor): if input_ids.dim() != 2 or input_ids.size(0) != batch_size: return None device = input_ids.device else: try: device = next(model.parameters()).device except Exception: device = "cpu" pad_id = pad_token_id if pad_id is None: pad_id = getattr(tokenizer, "pad_token_id", None) if pad_id is None: pad_id = getattr(tokenizer, "eos_token_id", None) if pad_id is None: pad_id = 0 pad_id = int(pad_id) segments = _build_batched_structured_prompt_segments( tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=prompt_add_generation_prompt, structured_prompt_segments=structured_prompt_segments, device=device, ) if len(segments) != batch_size: raise ValueError( f"llopa_v2_batch_generate prompt metadata batch size mismatch: {len(segments)} != {batch_size}" ) batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=pad_id) if isinstance(input_ids, torch.Tensor): canonical_input_ids = input_ids.to(device=device) context_width = int(canonical_input_ids.size(1)) else: canonical_input_ids = batch["prompt_ids"] context_width = int(canonical_input_ids.size(1)) raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - int(context_width)) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - int(context_width)) else: min_new_tokens = int(min_new_tokens) max_new_tokens = int(max_new_tokens) min_new_tokens = int(min_new_tokens) stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) stop_token_ids = _prepare_stop_token_tensor(stop_ids, device) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) record_scores = bool(output_scores) record_compact_scores = record_scores and bool(compact_scores) should_apply_stopping_criteria = stopping_criteria is not None cached_out = None if bool(allow_cached_paths): cached_out = _direct_llopa_batch_generate_cached_impl( model, tokenizer, segments=segments, canonical_input_ids=canonical_input_ids, context_width=int(context_width), batch_size=int(batch_size), lower_k=int(lower_k), prefill_attn=str(attn), system_prefill=str(sys_prefill), no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), last_layer_module=last_layer_module, max_new_tokens=int(max_new_tokens), min_new_tokens=int(min_new_tokens), do_sample=bool(do_sample), logits_warpers=logits_warpers, stopping_criteria=stopping_criteria, stop_ids=stop_ids, stop_token_ids=stop_token_ids, pad_id=int(pad_id), output_scores=record_scores, compact_scores=record_compact_scores, return_dict_in_generate=bool(return_dict_in_generate), device=device, ) if cached_out is not None: return cached_out serial_cached_out = None if bool(allow_cached_paths) and stopping_criteria is None: serial_cached_out = _direct_llopa_batch_generate_serial_cached_fallback_impl( model, tokenizer, segments=segments, canonical_input_ids=canonical_input_ids, context_width=int(context_width), batch_size=int(batch_size), lower_k=int(lower_k), prefill_attn=str(attn), system_prefill=str(sys_prefill), user_prefill=str(user_prefill_norm), no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), last_layer_module=last_layer_module, seed_mode=str(seed_mode or "prefill_header"), max_new_tokens=int(max_new_tokens), min_new_tokens=int(min_new_tokens), do_sample=bool(do_sample), temperature=temperature, top_p=top_p, top_k=top_k, stopping_criteria=stopping_criteria, pad_id=int(pad_id), eos_token_id=eos_token_id, output_scores=record_scores, compact_scores=record_compact_scores, return_dict_in_generate=bool(return_dict_in_generate), device=device, ) if serial_cached_out is not None: return serial_cached_out sequences = torch.full( (batch_size, context_width + max_new_tokens), pad_id, dtype=canonical_input_ids.dtype, device=device, ) sequences[:, :context_width] = canonical_input_ids generated = sequences[:, context_width:] score_list: list[torch.Tensor] = [] score_list_append = score_list.append compact_logprob_list: list[torch.Tensor] = [] compact_logprob_append = compact_logprob_list.append prompt_rows: list[torch.Tensor] = batch["prompt_rows"] prompt_lens = batch["prompt_lens"] split_starts = batch["split_starts"] system_lens = batch["system_lens"] unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool) finish_steps = torch.full((batch_size,), max_new_tokens, device=device, dtype=torch.long) def _build_step_tensors(cur_tokens: int) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]: active_lens = [] rows = [] for row_idx, prompt_row in enumerate(prompt_rows): if bool(unfinished[row_idx].item()): gen_len = int(cur_tokens) else: gen_len = min(int(cur_tokens), int(finish_steps[row_idx].item())) gen_prefix = generated[row_idx, :gen_len].to(device=device, dtype=torch.long) row_ids = torch.cat([prompt_row.to(device=device, dtype=torch.long), gen_prefix], dim=0) rows.append(row_ids) active_lens.append(int(row_ids.numel())) step_ids = _pad_1d_rows(rows, pad_value=pad_id, device=device, dtype=torch.long) active_lens_t = torch.tensor(active_lens, device=device, dtype=torch.long) step_mask = ( torch.arange(step_ids.size(1), device=device, dtype=torch.long).unsqueeze(0) < active_lens_t.unsqueeze(1) ).to(dtype=torch.long) labels = torch.full_like(step_ids, -100) for row_idx, active_len in enumerate(active_lens): start = int(split_starts[row_idx].item()) if active_len > start: labels[row_idx, start:active_len] = step_ids[row_idx, start:active_len] return step_ids, step_mask, labels cur = 0 while cur < max_new_tokens: step_ids, step_mask, labels = _build_step_tensors(cur) out = model( input_ids=step_ids, attention_mask=step_mask, labels=labels, use_cache=False, logits_to_keep=1, prefill_lower_layers=int(lower_k), prefill_lower_attn=str(attn), prefill_lower_system_prefill=str(sys_prefill), prefill_lower_no_upper_attn=bool(no_upper_attn), prefill_lower_split_start=split_starts, prefill_lower_system_len=system_lens, prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"], prefill_lower_replay_user_start=batch["replay_user_starts"], prefill_lower_replay_user_len=batch["replay_user_lens"], assistant_header_starts=batch["assistant_header_starts"], assistant_turn_ends=batch["assistant_turn_ends"], assistant_header_start_mask=batch["assistant_header_start_mask"], prefill_lower_see_past_assistant=bool(see_past_assistant), prefill_lower_replay_module=str(replay_module), prefill_lower_replay_per_layers=int(replay_per_layers), ) if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor): return None logits = out.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True) if stop_token_ids is not None and cur < min_new_tokens: logits.index_fill_(1, stop_token_ids, -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores and not record_compact_scores: score_list_append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) if record_compact_scores: next_logit = torch.gather(logits, 1, next_tok) next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) compact_logprob_append(next_logprob.squeeze(-1).detach()) if not bool(unfinished.all().item()): next_tok = torch.where( unfinished.view(-1, 1), next_tok, torch.full_like(next_tok, pad_id), ) generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype) cur += 1 if stop_ids and cur >= min_new_tokens: stop_mask = torch.zeros_like(unfinished) for stop_id in stop_ids: stop_mask |= next_tok.squeeze(1).eq(int(stop_id)) newly_finished = unfinished & stop_mask if bool(newly_finished.any().item()): finish_steps = torch.where( newly_finished, torch.full_like(finish_steps, int(cur)), finish_steps, ) unfinished = unfinished & ~stop_mask if should_apply_stopping_criteria: sequences_now = sequences[:, : context_width + cur] try: stop_result = stopping_criteria(sequences_now, logits) except TypeError: stop_result = stopping_criteria(sequences_now, None) if isinstance(stop_result, torch.Tensor): stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1) if stop_result.numel() == 1: if bool(stop_result.item()): unfinished.zero_() elif stop_result.numel() == batch_size: newly_finished = unfinished & stop_result if bool(newly_finished.any().item()): finish_steps = torch.where( newly_finished, torch.full_like(finish_steps, int(cur)), finish_steps, ) unfinished = unfinished & ~stop_result elif bool(stop_result.all().item()): unfinished.zero_() elif bool(stop_result): unfinished.zero_() if not bool(unfinished.any().item()): break sequences = sequences[:, : context_width + cur] if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if record_scores and not record_compact_scores else None, past_key_values=None, ) if record_compact_scores: if compact_logprob_list: compact_logprobs = torch.stack(compact_logprob_list, dim=1) else: compact_logprobs = torch.empty( (sequences.size(0), 0), dtype=torch.float32, device=sequences.device, ) setattr(output, "generated_token_logprobs", compact_logprobs) return output def _direct_optimized_llopa_generate_impl( model, tokenizer, *, prompt_messages, prompt_add_generation_prompt: bool, structured_prompt_segments=None, input_ids: Optional[torch.LongTensor], attention_mask: Optional[torch.Tensor], lower_k: int, prefill_attn: str, system_prefill: str, user_prefill: str, no_upper_attn: bool, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, optimized_variant: Optional[str] = None, optimized_seed_mode: Optional[str] = None, optimized_upper_prepare_mode: Optional[str] = None, optimized_upper_bucket_multiple: Optional[int] = None, optimized_seq_bucket_multiple: Optional[int] = None, max_length=None, max_new_tokens=None, min_length=None, min_new_tokens=None, do_sample=None, temperature=None, top_p=None, top_k=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None, output_scores: bool = False, compact_scores: bool = False, return_dict_in_generate: bool = False, use_cache: Optional[bool] = None, ): if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module = _normalize_replay_module_value(replay_module) replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) if use_cache is False: return None try: lower_k = int(lower_k) except Exception: lower_k = 0 if lower_k <= 0: return None attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" if attn not in {"causal", "full"}: attn = "causal" sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" user_prefill_norm = (user_prefill or "full").strip().lower() if user_prefill_norm != "full": raise ValueError("Optimized LLoPA currently supports only user_prefill='full'.") optimized_settings = _resolve_optimized_llopa_settings( variant=optimized_variant, seed_mode=optimized_seed_mode, upper_prepare_mode=optimized_upper_prepare_mode, upper_bucket_multiple=optimized_upper_bucket_multiple, seq_bucket_multiple=optimized_seq_bucket_multiple, ) llopa_core = _get_llopa_core(model) llopa_step = _get_llopa_decode_step(model) if llopa_core is None or llopa_step is None: return None llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None) decode_output_head = _get_output_head(model) use_direct_decode_step = ( _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1") and callable(llopa_forward_assistant) and decode_output_head is not None ) device = None if isinstance(input_ids, torch.Tensor): if input_ids.dim() != 2 or input_ids.size(0) != 1: return None device = input_ids.device if device is None: try: device = next(model.parameters()).device except Exception: device = "cpu" segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None if segments is None: segments = _build_structured_prompt_segments( tokenizer, prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), device=device, ) prompt_ids = segments["prompt_ids"] system_ids = segments["system_ids"] user_ids = segments["user_ids"] assistant_prefill_ids = segments["assistant_prefill_ids"] replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0) replay_user_start = int(segments.get("replay_user_start", 0) or 0) replay_user_len = int(segments.get("replay_user_len", 0) or 0) if assistant_prefill_ids.numel() == 0: return None raw_temp = 0.0 if temperature is None else float(temperature) if do_sample is None: do_sample = bool(raw_temp != 0.0) do_sample = bool(do_sample) sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) top_p = 1.0 if top_p is None else float(top_p) top_k = None if top_k is None else int(top_k) initial_logits = None prompt_bundle = _build_unified_prefill_lower_prompt_bundle( tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), structured_prompt_segments=segments, device=device, ) with _temporary_model_attrs( model, _optimized_llopa_variant=optimized_settings["variant"], _optimized_llopa_seed_mode=optimized_settings["seed_mode"], _optimized_llopa_upper_prepare_mode=optimized_settings["upper_prepare_mode"], _optimized_llopa_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]), _optimized_llopa_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]), ): reference_seed = _optimized_prefill_lower_cache_and_logits( model, prompt_bundle=prompt_bundle, lower_k=lower_k, prefill_attn=attn, system_prefill=sys_prefill, no_upper_attn=bool(no_upper_attn), see_past_assistant=bool(see_past_assistant), replay_module=str(replay_module), replay_per_layers=int(replay_per_layers), seed_mode=optimized_settings["seed_mode"], ) canonical_input_ids = None if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1: valid_len = int(input_ids.size(1)) if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1: valid_len = int(attention_mask[0].sum().item()) if valid_len > 0: canonical_input_ids = input_ids[:, -valid_len:] if not isinstance(canonical_input_ids, torch.Tensor): canonical_input_ids = prompt_bundle.get("effective_prompt_ids") if not isinstance(canonical_input_ids, torch.Tensor): canonical_input_ids = prompt_ids total_prompt_len = int(canonical_input_ids.size(1)) if max_new_tokens is None: if max_length is None: max_new_tokens = 256 else: max_new_tokens = max(0, int(max_length) - total_prompt_len) else: max_new_tokens = int(max_new_tokens) if min_new_tokens is None: if min_length is None: min_new_tokens = 0 else: min_new_tokens = max(0, int(min_length) - total_prompt_len) else: min_new_tokens = int(min_new_tokens) if reference_seed is not None: pkv, S, U, initial_logits = reference_seed else: output_head = _get_output_head(model) if bool(no_upper_attn): pkv, S, U = _llopa_prefill_cache( llopa_core, system_ids, user_ids, assistant_prefill_ids, lower_k=lower_k, prefill_mode="lower", prefill_attn=attn, system_prefill=sys_prefill, replay_user_prefix_keep_len=replay_user_prefix_keep_len, replay_user_start=replay_user_start, replay_user_len=replay_user_len, ) else: pkv, S, U, last_hidden = _llopa_prefill_cache( llopa_core, system_ids, user_ids, assistant_prefill_ids, lower_k=lower_k, prefill_mode="lower", prefill_attn=attn, system_prefill=sys_prefill, return_last_assistant_hidden=bool(output_head is not None), replay_user_prefix_keep_len=replay_user_prefix_keep_len, replay_user_start=replay_user_start, replay_user_len=replay_user_len, ) if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0: initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32) stop_ids = set(_normalize_eos_token_ids(eos_token_id)) if not stop_ids: tok_eos = getattr(tokenizer, "eos_token_id", None) if tok_eos is not None: stop_ids.add(int(tok_eos)) with contextlib.suppress(Exception): eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id)) last = assistant_prefill_ids[:, -1:] stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device) logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) max_new_tokens = int(max_new_tokens) min_new_tokens = int(min_new_tokens) lower_k = int(lower_k) no_upper_attn_bool = bool(no_upper_attn) replay_module_str = str(replay_module) replay_per_layers_int = int(replay_per_layers) record_scores = bool(output_scores) record_compact_scores = bool(output_scores) and bool(compact_scores) should_apply_stopping_criteria = stopping_criteria is not None bucket_multiple = int(optimized_settings["seq_bucket_multiple"] or 256) sequences_full, _ = _acquire_bucketed_sequence_workspace( model, reference_ids=canonical_input_ids, batch_size=int(canonical_input_ids.size(0)), total_len=int(total_prompt_len + max_new_tokens), bucket_multiple=bucket_multiple, ) sequences_full[:, :total_prompt_len] = canonical_input_ids sequences = sequences_full[:, : total_prompt_len + max_new_tokens] generated = sequences[:, total_prompt_len:] score_list: list[torch.Tensor] = [] score_list_append = score_list.append compact_logprob_list: list[torch.Tensor] = [] compact_logprob_append = compact_logprob_list.append cur = 0 pending_logits = initial_logits while cur < max_new_tokens: out = None if pending_logits is None: if use_direct_decode_step: out = llopa_forward_assistant( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, write_cache=True, prefill_mode="lower", no_upper_attn=no_upper_attn_bool, align_cache_position_to_layer_past=False, replay_module=replay_module_str, replay_per_layers=replay_per_layers_int, ) pkv = out.past_key_values or pkv logits = decode_output_head(out.last_hidden_state[:, -1, :]) if logits.dim() == 3: logits = logits[:, -1, :] else: out = llopa_step( assistant_ids=last, lower_k=lower_k, pkv=pkv, S=S, U=U, logits_to_keep=1, labels=None, prefill_mode="lower", no_upper_attn=no_upper_attn_bool, align_cache_position_to_layer_past=False, replay_module=replay_module_str, replay_per_layers=replay_per_layers_int, ) pkv = out.past_key_values or pkv logits = out.logits[:, -1, :] logits = logits.to(dtype=torch.float32, device=last.device, copy=True) else: logits = pending_logits pending_logits = None if stop_token_ids is not None and cur < min_new_tokens: logits.index_fill_(1, stop_token_ids, -float("inf")) if logits_warpers is not None: logits = logits_warpers(generated[:, :cur], logits) if record_scores and not record_compact_scores: score_list_append(logits.detach().clone()) if do_sample: probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) if record_compact_scores: next_logit = torch.gather(logits, 1, next_tok) next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) compact_logprob_append(next_logprob.squeeze(-1).detach()) generated[:, cur : cur + 1] = next_tok cur += 1 should_stop = False tok_id = int(next_tok.item()) if tok_id in stop_ids and cur >= min_new_tokens: should_stop = True if (not should_stop) and should_apply_stopping_criteria: sequences_now = sequences[:, : total_prompt_len + cur] try: should_stop = bool(stopping_criteria(sequences_now, logits)) except TypeError: should_stop = bool(stopping_criteria(sequences_now, None)) if out is not None: del out if should_stop: break last = next_tok sequences = sequences[:, : total_prompt_len + cur] if not bool(return_dict_in_generate): return sequences from transformers.generation.utils import GenerateDecoderOnlyOutput output = GenerateDecoderOnlyOutput( sequences=sequences, scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None, past_key_values=pkv, ) if record_compact_scores: if compact_logprob_list: compact_logprobs = torch.stack(compact_logprob_list, dim=1) else: compact_logprobs = torch.empty( (sequences.size(0), 0), dtype=torch.float32, device=sequences.device, ) setattr(output, "generated_token_logprobs", compact_logprobs) return output def _normalize_replay_module_value(value: Optional[str]) -> str: raw = str(value or "none").strip().lower() aliases = { "": "none", "off": "none", "disabled": "none", "disable": "none", "self-attention": "self", "self_attention": "self", "selfattn": "self", "self_attn": "self", "cross-attention": "cross", "cross_attention": "cross", "crossattn": "cross", "cross_attn": "cross", } raw = aliases.get(raw, raw) if raw in {"none", "self", "cross"}: return raw return "none" def _normalize_replay_per_layers_value(value) -> int: try: normalized = int(value) except Exception: return -1 if normalized == -1 or normalized >= 1: return normalized return -1 def _normalize_structured_llopa_runtime( lower_k: int, *, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, ): try: lower_k = int(lower_k) except Exception: lower_k = 0 attn = (prefill_attn or "causal").strip().lower() if attn == "prefix_full": attn = "full" sys_prefill = (system_prefill or "full").strip().lower() if sys_prefill not in {"full", "no_system", "no_bos_system"}: sys_prefill = "full" user_prefill_norm = (user_prefill or "full").strip().lower() if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": replay_module = last_layer_module replay_module_norm = _normalize_replay_module_value(replay_module) replay_per_layers_norm = _normalize_replay_per_layers_value(replay_per_layers) if lower_k <= 0 or attn not in {"causal", "full"}: return None return ( int(lower_k), attn, sys_prefill, user_prefill_norm, bool(no_upper_attn), replay_module_norm, replay_per_layers_norm, ) def _attach_structured_llopa_generate(model, tokenizer) -> None: try: import types except Exception: return if getattr(model, "_structured_llopa_generate_attached", False): return orig_generate = getattr(model, "generate", None) if not callable(orig_generate): return def _structured_llopa_generate(self, *args, **kwargs): if args: if "input_ids" not in kwargs: kwargs["input_ids"] = args[0] args = args[1:] if args: return orig_generate(*args, **kwargs) optimized_enabled = kwargs.pop("optimized_llopa_generate", None) llopa_v2_batch_enabled = kwargs.pop("llopa_v2_batch_generate", None) llopa_v2_enabled = kwargs.pop("llopa_v2_generate", None) llopa_v3_enabled = kwargs.pop("llopa_v3_generate", None) runtime_solo_enabled = kwargs.pop("runtime_solo_generate", None) runtime_solo_v2_enabled = kwargs.pop("runtime_solo_v2_generate", None) unified_enabled = kwargs.pop("unified_llopa_generate", None) direct_enabled = kwargs.pop("direct_llopa_generate", None) legacy_search = bool(kwargs.pop("direct_llopa_legacy_search", False)) prompt_messages = kwargs.pop("prompt_messages", None) prompt_add_generation_prompt = kwargs.pop("prompt_add_generation_prompt", None) structured_prompt_segments = kwargs.pop("structured_prompt_segments", None) compact_scores = bool(kwargs.pop("capsule_compact_scores", False)) mode = None if optimized_enabled is not None: if bool(optimized_enabled): mode = "optimized" elif llopa_v2_batch_enabled is None and llopa_v2_enabled is None and llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if llopa_v2_batch_enabled is not None: if bool(llopa_v2_batch_enabled): mode = "llopa_v2_batch" elif llopa_v2_enabled is None and llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if llopa_v2_enabled is not None: if bool(llopa_v2_enabled): mode = "llopa_v2" elif llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if llopa_v3_enabled is not None: if bool(llopa_v3_enabled): mode = "llopa_v3" elif runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if runtime_solo_v2_enabled is not None: if bool(runtime_solo_v2_enabled): mode = "solo_v2" elif runtime_solo_enabled is None and unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if runtime_solo_enabled is not None: if bool(runtime_solo_enabled): mode = "solo" elif unified_enabled is None and direct_enabled is None: return orig_generate(**kwargs) if unified_enabled is not None: if bool(unified_enabled): mode = "unified" elif direct_enabled is None: return orig_generate(**kwargs) if mode is None and direct_enabled is not None: if bool(direct_enabled): mode = "direct" else: return orig_generate(**kwargs) if mode is None: if bool(getattr(self, "_optimized_llopa_generate_default", False)): mode = "optimized" elif bool(getattr(self, "_llopa_v2_batch_generate_default", False)): mode = "llopa_v2_batch" elif bool(getattr(self, "_llopa_v2_generate_default", False)): mode = "llopa_v3" if str(getattr(self, "_capsule_inference_path", "") or "") == "llopa_v3" else "llopa_v2" elif bool(getattr(self, "_unified_llopa_generate_default", False)): mode = "unified" elif bool(getattr(self, "_runtime_structured_freeze_generate_default", False)): mode = "freeze" elif bool(getattr(self, "_runtime_structured_solo_v2_generate_default", False)): mode = "solo_v2" elif bool(getattr(self, "_runtime_structured_solo_generate_default", False)): mode = "solo" elif bool(getattr(self, "_direct_llopa_generate_default", False)): mode = "direct" else: return orig_generate(**kwargs) if kwargs.get("inputs_embeds") is not None: return orig_generate(**kwargs) if mode == "optimized": mode_label = "optimized_llopa_generate" elif mode == "llopa_v2_batch": mode_label = "llopa_v2_batch_generate" elif mode == "llopa_v2": mode_label = "llopa_v2_generate" elif mode == "llopa_v3": mode_label = "llopa_v3_generate" elif mode == "unified": mode_label = "unified_llopa_generate" elif mode == "freeze": mode_label = "runtime_freeze_generate" elif mode == "solo_v2": mode_label = "runtime_solo_v2_generate" elif mode == "solo": mode_label = "runtime_solo_generate" else: mode_label = "direct_llopa_generate" if int(kwargs.get("num_beams", 1) or 1) != 1: _warn_once( self, f"_warned_{mode_label}_num_beams", f"[load_llopa_model][warn] {mode_label} currently supports only num_beams=1; falling back to model.generate().", ) return orig_generate(**kwargs) if int(kwargs.get("num_return_sequences", 1) or 1) != 1: _warn_once( self, f"_warned_{mode_label}_num_return_sequences", f"[load_llopa_model][warn] {mode_label} currently supports only num_return_sequences=1; falling back to model.generate().", ) return orig_generate(**kwargs) if mode in {"unified", "optimized", "llopa_v2", "llopa_v3", "llopa_v2_batch"}: attr_prefix = "_optimized_llopa" if mode == "optimized" else "_llopa_v2" if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"} else "_unified_llopa" kw_prefix = "optimized_llopa" if mode == "optimized" else "llopa_v2" if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"} else "unified_llopa" lower_k_attr = f"{attr_prefix}_layers" attn_attr = f"{attr_prefix}_attn" system_attr = f"{attr_prefix}_system_prefill" user_attr = f"{attr_prefix}_user_prefill" no_upper_attr = f"{attr_prefix}_no_upper_attn" see_past_assistant_attr = f"{attr_prefix}_see_past_assistant" replay_module_attr = f"{attr_prefix}_replay_module" replay_per_layers_attr = f"{attr_prefix}_replay_per_layers" last_layer_attr = f"{attr_prefix}_last_layer_module" variant_attr = "_optimized_llopa_variant" seed_attr = "_optimized_llopa_seed_mode" upper_prepare_attr = "_optimized_llopa_upper_prepare_mode" upper_bucket_attr = "_optimized_llopa_upper_bucket_multiple" seq_bucket_attr = "_optimized_llopa_seq_bucket_multiple" lower_k_local = kwargs.pop(f"{kw_prefix}_layers", None) if lower_k_local is None: lower_k_local = int(getattr(self, lower_k_attr, 0) or 0) attn_local = kwargs.pop(f"{kw_prefix}_attn", None) if attn_local is None: attn_local = str(getattr(self, attn_attr, "causal") or "causal") system_prefill_local = kwargs.pop(f"{kw_prefix}_system_prefill", None) if system_prefill_local is None: system_prefill_local = str(getattr(self, system_attr, "full") or "full") user_prefill_local = kwargs.pop(f"{kw_prefix}_user_prefill", None) if user_prefill_local is None: user_prefill_local = str(getattr(self, user_attr, "full") or "full") no_upper_attn_local = kwargs.pop(f"{kw_prefix}_no_upper_attn", None) if no_upper_attn_local is None: no_upper_attn_local = bool(getattr(self, no_upper_attr, False)) see_past_assistant_local = kwargs.pop(f"{kw_prefix}_see_past_assistant", None) if see_past_assistant_local is None: see_past_assistant_local = bool(getattr(self, see_past_assistant_attr, False)) replay_module_local = kwargs.pop(f"{kw_prefix}_replay_module", None) if replay_module_local is None: replay_module_local = kwargs.pop(f"{kw_prefix}_last_layer_module", None) if replay_module_local is None: replay_module_local = getattr(self, replay_module_attr, None) if replay_module_local is None: replay_module_local = str(getattr(self, last_layer_attr, "none") or "none") replay_per_layers_local = kwargs.pop(f"{kw_prefix}_replay_per_layers", None) if replay_per_layers_local is None: replay_per_layers_local = getattr(self, replay_per_layers_attr, -1) structured_seed_mode_local = "auto" if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"}: kwargs.pop("llopa_v2_seed_mode", None) structured_seed_mode_local = "prefill_header" optimized_variant_local = None optimized_seed_mode_local = None optimized_upper_prepare_mode_local = None optimized_upper_bucket_multiple_local = None optimized_seq_bucket_multiple_local = None if mode == "optimized": optimized_variant_local = kwargs.pop("optimized_llopa_variant", None) if optimized_variant_local is None: optimized_variant_local = getattr(self, variant_attr, "upper_ws_auto") optimized_seed_mode_local = kwargs.pop("optimized_llopa_seed_mode", None) if optimized_seed_mode_local is None: optimized_seed_mode_local = getattr(self, seed_attr, "auto") optimized_upper_prepare_mode_local = kwargs.pop("optimized_llopa_upper_prepare_mode", None) if optimized_upper_prepare_mode_local is None: optimized_upper_prepare_mode_local = getattr(self, upper_prepare_attr, "bucketed_workspace") optimized_upper_bucket_multiple_local = kwargs.pop("optimized_llopa_upper_bucket_multiple", None) if optimized_upper_bucket_multiple_local is None: optimized_upper_bucket_multiple_local = getattr(self, upper_bucket_attr, 256) optimized_seq_bucket_multiple_local = kwargs.pop("optimized_llopa_seq_bucket_multiple", None) if optimized_seq_bucket_multiple_local is None: optimized_seq_bucket_multiple_local = getattr(self, seq_bucket_attr, 256) if prompt_messages is None and structured_prompt_segments is None: _warn_once( self, f"_warned_{mode_label}_missing_prompt_metadata", f"[load_llopa_model][warn] {mode_label} requested without structured prompt metadata; falling back to model.generate().", ) return orig_generate(**kwargs) if prompt_add_generation_prompt is None and structured_prompt_segments is None: raise ValueError(f"{mode_label} requires prompt_add_generation_prompt when prompt_messages are provided.") generate_impl = ( _direct_optimized_llopa_generate_impl if mode == "optimized" else _direct_llopa_batch_generate_impl if mode == "llopa_v2_batch" else _direct_llopa_generate_impl ) generate_kwargs = dict( prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), structured_prompt_segments=structured_prompt_segments, input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), system_prefill=str(system_prefill_local), user_prefill=str(user_prefill_local), no_upper_attn=bool(no_upper_attn_local), see_past_assistant=bool(see_past_assistant_local), replay_module=str(replay_module_local), replay_per_layers=int(replay_per_layers_local or -1), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), ) generate_kwargs["compact_scores"] = bool(compact_scores) if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"}: generate_kwargs["seed_mode"] = str(structured_seed_mode_local) if mode == "optimized": generate_kwargs.update( optimized_variant=optimized_variant_local, optimized_seed_mode=optimized_seed_mode_local, optimized_upper_prepare_mode=optimized_upper_prepare_mode_local, optimized_upper_bucket_multiple=optimized_upper_bucket_multiple_local, optimized_seq_bucket_multiple=optimized_seq_bucket_multiple_local, ) with torch.inference_mode(): previous_mixin_decode = getattr(self, "_llopa_v2_generation_mixin_decode", None) try: if mode == "llopa_v2": setattr(self, "_llopa_v2_generation_mixin_decode", False) elif mode == "llopa_v3": setattr(self, "_llopa_v2_generation_mixin_decode", True) unified_out = generate_impl( self, tokenizer, **generate_kwargs, ) finally: if previous_mixin_decode is None: with contextlib.suppress(Exception): delattr(self, "_llopa_v2_generation_mixin_decode") else: setattr(self, "_llopa_v2_generation_mixin_decode", previous_mixin_decode) if unified_out is not None: return unified_out raise RuntimeError(f"Structured {mode} LLoPA failed unexpectedly for the current prompt.") if mode == "freeze": lower_k_local = kwargs.pop("runtime_prefill_freeze_layers", None) if lower_k_local is None: lower_k_local = int(getattr(self, "_runtime_prefill_freeze_layers", 0) or 0) attn_local = kwargs.pop("runtime_prefill_freeze_attn", None) if attn_local is None: attn_local = str(getattr(self, "_runtime_prefill_freeze_attn", "causal") or "causal") system_prefill_local = kwargs.pop("runtime_prefill_freeze_system_prefill", None) if system_prefill_local is None: system_prefill_local = str(getattr(self, "_runtime_prefill_freeze_system_prefill", "no_bos_system") or "no_bos_system") if prompt_messages is None: _warn_once( self, "_warned_runtime_freeze_missing_prompt_metadata", "[load_llopa_model][warn] runtime_freeze_generate requested without structured prompt metadata; falling back to model.generate().", ) return orig_generate(**kwargs) if prompt_add_generation_prompt is None: raise ValueError("runtime_freeze_generate requires prompt_add_generation_prompt when prompt_messages are provided.") freeze_out = _direct_freeze_generate_impl( self, tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), system_prefill=str(system_prefill_local), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), ) if freeze_out is not None: return freeze_out raise RuntimeError("Structured runtime freeze failed unexpectedly for the current prompt.") if mode == "solo": lower_k_local = kwargs.pop("runtime_prefill_solo_layers", None) if lower_k_local is None: lower_k_local = int(getattr(self, "_runtime_prefill_solo_layers", 0) or 0) attn_local = kwargs.pop("runtime_prefill_solo_attn", None) if attn_local is None: attn_local = str(getattr(self, "_runtime_prefill_solo_attn", "causal") or "causal") system_prefill_local = kwargs.pop("runtime_prefill_solo_system_prefill", None) if system_prefill_local is None: system_prefill_local = str(getattr(self, "_runtime_prefill_solo_system_prefill", "no_bos_system") or "no_bos_system") if prompt_messages is None: _warn_once( self, "_warned_runtime_solo_missing_prompt_metadata", "[load_llopa_model][warn] runtime_solo_generate requested without structured prompt metadata; falling back to model.generate().", ) return orig_generate(**kwargs) if prompt_add_generation_prompt is None: raise ValueError("runtime_solo_generate requires prompt_add_generation_prompt when prompt_messages are provided.") solo_out = _direct_solo_generate_impl( self, tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), system_prefill=str(system_prefill_local), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), ) if solo_out is not None: return solo_out raise RuntimeError("Structured runtime solo-attn failed unexpectedly for the current prompt.") if mode == "solo_v2": lower_k_local = kwargs.pop("runtime_prefill_solo_v2_layers", None) if lower_k_local is None: lower_k_local = int(getattr(self, "_runtime_prefill_solo_v2_layers", 0) or 0) attn_local = kwargs.pop("runtime_prefill_solo_v2_attn", None) if attn_local is None: attn_local = str(getattr(self, "_runtime_prefill_solo_v2_attn", "causal") or "causal") system_prefill_local = kwargs.pop("runtime_prefill_solo_v2_system_prefill", None) if system_prefill_local is None: system_prefill_local = str(getattr(self, "_runtime_prefill_solo_v2_system_prefill", "no_bos_system") or "no_bos_system") with_bos_local = kwargs.pop("runtime_prefill_solo_v2_with_bos", None) if with_bos_local is None: with_bos_local = bool(getattr(self, "_runtime_prefill_solo_v2_with_bos", False)) if prompt_messages is None: _warn_once( self, "_warned_runtime_solo_v2_missing_prompt_metadata", "[load_llopa_model][warn] runtime_solo_v2_generate requested without structured prompt metadata; falling back to model.generate().", ) return orig_generate(**kwargs) if prompt_add_generation_prompt is None: raise ValueError("runtime_solo_v2_generate requires prompt_add_generation_prompt when prompt_messages are provided.") solo_out = _direct_solo_generate_impl( self, tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), system_prefill=str(system_prefill_local), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), solo_v2=True, with_bos=bool(with_bos_local), ) if solo_out is not None: return solo_out raise RuntimeError("Structured runtime solo-attn-v2 failed unexpectedly for the current prompt.") _warn_once( self, "_warned_deprecated_direct_llopa", "[load_llopa_model][warn] direct_llopa_* is deprecated; use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged.", ) lower_k_local = kwargs.pop("direct_llopa_layers", None) if lower_k_local is None: lower_k_local = int(getattr(self, "_direct_llopa_layers", 0) or 0) attn_local = kwargs.pop("direct_llopa_attn", None) if attn_local is None: attn_local = str(getattr(self, "_direct_llopa_attn", "causal") or "causal") system_prefill_local = kwargs.pop("direct_llopa_system_prefill", None) if system_prefill_local is None: system_prefill_local = str(getattr(self, "_direct_llopa_system_prefill", "full") or "full") user_prefill_local = kwargs.pop("direct_llopa_user_prefill", None) if user_prefill_local is None: user_prefill_local = str(getattr(self, "_direct_llopa_user_prefill", "full") or "full") no_upper_attn_local = kwargs.pop("direct_llopa_no_upper_attn", None) if no_upper_attn_local is None: no_upper_attn_local = bool(getattr(self, "_direct_llopa_no_upper_attn", False)) if not legacy_search and prompt_messages is None and structured_prompt_segments is None: raise ValueError( "direct_llopa_generate now requires prompt_messages and prompt_add_generation_prompt. " "Use direct_llopa_legacy_search=True only for legacy prompt scanning." ) if not legacy_search and prompt_add_generation_prompt is None and structured_prompt_segments is None: raise ValueError("direct_llopa_generate requires prompt_add_generation_prompt when prompt_messages are provided.") if legacy_search: with torch.inference_mode(): direct_out = _legacy_direct_llopa_generate_impl( self, tokenizer, input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), ) else: with torch.inference_mode(): direct_out = _direct_llopa_generate_impl( self, tokenizer, prompt_messages=prompt_messages, prompt_add_generation_prompt=bool(prompt_add_generation_prompt), structured_prompt_segments=structured_prompt_segments, input_ids=kwargs.get("input_ids"), attention_mask=kwargs.get("attention_mask"), lower_k=int(lower_k_local), prefill_attn=str(attn_local), system_prefill=str(system_prefill_local), user_prefill=str(user_prefill_local), no_upper_attn=bool(no_upper_attn_local), max_length=kwargs.get("max_length"), max_new_tokens=kwargs.get("max_new_tokens"), min_length=kwargs.get("min_length"), min_new_tokens=kwargs.get("min_new_tokens"), do_sample=kwargs.get("do_sample"), temperature=kwargs.get("temperature"), top_p=kwargs.get("top_p"), top_k=kwargs.get("top_k"), stopping_criteria=kwargs.get("stopping_criteria"), pad_token_id=kwargs.get("pad_token_id"), eos_token_id=kwargs.get("eos_token_id"), output_scores=bool(kwargs.get("output_scores", False)), return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), use_cache=kwargs.get("use_cache"), ) if direct_out is not None: return direct_out if not legacy_search: raise RuntimeError("Structured direct LLoPA failed unexpectedly for the current prompt.") _warn_once( self, "_warned_direct_llopa_fallback", "[load_llopa_model][warn] direct_llopa_generate could not use the current prompt/generation settings; falling back to model.generate().", ) return orig_generate(**kwargs) try: model.generate = types.MethodType(_structured_llopa_generate, model) setattr(model, "_structured_llopa_generate_attached", True) except Exception: pass def _attach_unified_llopa_generate( model, tokenizer, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, ) -> None: normalized = _normalize_structured_llopa_runtime( lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, user_prefill=user_prefill, no_upper_attn=no_upper_attn, replay_module=replay_module, replay_per_layers=replay_per_layers, last_layer_module=last_layer_module, ) if normalized is None: return lower_k, attn, sys_prefill, user_prefill_norm, no_upper_attn, replay_module, replay_per_layers = normalized try: setattr(model, "_unified_llopa_layers", int(lower_k)) setattr(model, "_unified_llopa_attn", attn) setattr(model, "_unified_llopa_system_prefill", sys_prefill) setattr(model, "_unified_llopa_user_prefill", user_prefill_norm) setattr(model, "_unified_llopa_no_upper_attn", bool(no_upper_attn)) setattr(model, "_unified_llopa_see_past_assistant", bool(see_past_assistant)) setattr(model, "_unified_llopa_replay_module", str(replay_module)) setattr(model, "_unified_llopa_last_layer_module", str(replay_module)) setattr(model, "_unified_llopa_replay_per_layers", int(replay_per_layers)) setattr(model, "_unified_llopa_generate_default", True) setattr(model, "_capsule_inference_path", "unified_llopa") except Exception: return _attach_structured_llopa_generate(model, tokenizer) try: setattr(model, "_unified_llopa_generate_attached", True) except Exception: pass def _attach_llopa_v2_generate( model, tokenizer, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "prefill_header", generation_mixin_decode: bool = False, capsule_inference_path: str = "llopa_v2", ) -> None: normalized = _normalize_structured_llopa_runtime( lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, user_prefill=user_prefill, no_upper_attn=no_upper_attn, replay_module=replay_module, replay_per_layers=replay_per_layers, last_layer_module=last_layer_module, ) if normalized is None: return lower_k, attn, sys_prefill, user_prefill_norm, no_upper_attn, replay_module, replay_per_layers = normalized normalized_seed_mode = _normalize_structured_llopa_seed_mode(seed_mode) if normalized_seed_mode != "prefill_header": normalized_seed_mode = "prefill_header" try: setattr(model, "_llopa_v2_layers", int(lower_k)) setattr(model, "_llopa_v2_attn", attn) setattr(model, "_llopa_v2_system_prefill", sys_prefill) setattr(model, "_llopa_v2_user_prefill", user_prefill_norm) setattr(model, "_llopa_v2_no_upper_attn", bool(no_upper_attn)) setattr(model, "_llopa_v2_see_past_assistant", bool(see_past_assistant)) setattr(model, "_llopa_v2_replay_module", str(replay_module)) setattr(model, "_llopa_v2_last_layer_module", str(replay_module)) setattr(model, "_llopa_v2_replay_per_layers", int(replay_per_layers)) setattr(model, "_llopa_v2_seed_mode", str(normalized_seed_mode)) setattr(model, "_llopa_v2_generation_mixin_decode", bool(generation_mixin_decode)) setattr(model, "_llopa_v2_generate_default", True) setattr(model, "_capsule_inference_path", str(capsule_inference_path or "llopa_v2")) except Exception: return _attach_structured_llopa_generate(model, tokenizer) try: setattr(model, "_llopa_v2_generate_attached", True) except Exception: pass def _attach_llopa_v2_batch_generate( model, tokenizer, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, seed_mode: str = "prefill_header", ) -> None: normalized = _normalize_structured_llopa_runtime( lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, user_prefill=user_prefill, no_upper_attn=no_upper_attn, replay_module=replay_module, replay_per_layers=replay_per_layers, last_layer_module=last_layer_module, ) if normalized is None: return lower_k, attn, sys_prefill, user_prefill_norm, no_upper_attn, replay_module, replay_per_layers = normalized normalized_seed_mode = _normalize_structured_llopa_seed_mode(seed_mode) if normalized_seed_mode != "prefill_header": normalized_seed_mode = "prefill_header" try: setattr(model, "_llopa_v2_layers", int(lower_k)) setattr(model, "_llopa_v2_attn", attn) setattr(model, "_llopa_v2_system_prefill", sys_prefill) setattr(model, "_llopa_v2_user_prefill", user_prefill_norm) setattr(model, "_llopa_v2_no_upper_attn", bool(no_upper_attn)) setattr(model, "_llopa_v2_see_past_assistant", bool(see_past_assistant)) setattr(model, "_llopa_v2_replay_module", str(replay_module)) setattr(model, "_llopa_v2_last_layer_module", str(replay_module)) setattr(model, "_llopa_v2_replay_per_layers", int(replay_per_layers)) setattr(model, "_llopa_v2_seed_mode", str(normalized_seed_mode)) setattr(model, "_llopa_v2_batch_generate_default", True) setattr(model, "_capsule_inference_path", "llopa_v2_batch") except Exception: return _attach_structured_llopa_generate(model, tokenizer) try: setattr(model, "_llopa_v2_batch_generate_attached", True) except Exception: pass def _attach_optimized_llopa_generate( model, tokenizer, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, see_past_assistant: bool = False, replay_module: str = "none", replay_per_layers: int = -1, last_layer_module: Optional[str] = None, optimized_variant: Optional[str] = None, optimized_seed_mode: Optional[str] = None, optimized_upper_prepare_mode: Optional[str] = None, optimized_upper_bucket_multiple: Optional[int] = None, optimized_seq_bucket_multiple: Optional[int] = None, ) -> None: normalized = _normalize_structured_llopa_runtime( lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, user_prefill=user_prefill, no_upper_attn=no_upper_attn, replay_module=replay_module, replay_per_layers=replay_per_layers, last_layer_module=last_layer_module, ) if normalized is None: return lower_k, attn, sys_prefill, user_prefill_norm, no_upper_attn, replay_module, replay_per_layers = normalized optimized_settings = _resolve_optimized_llopa_settings( variant=optimized_variant, seed_mode=optimized_seed_mode, upper_prepare_mode=optimized_upper_prepare_mode, upper_bucket_multiple=optimized_upper_bucket_multiple, seq_bucket_multiple=optimized_seq_bucket_multiple, ) try: setattr(model, "_optimized_llopa_layers", int(lower_k)) setattr(model, "_optimized_llopa_attn", attn) setattr(model, "_optimized_llopa_system_prefill", sys_prefill) setattr(model, "_optimized_llopa_user_prefill", user_prefill_norm) setattr(model, "_optimized_llopa_no_upper_attn", bool(no_upper_attn)) setattr(model, "_optimized_llopa_see_past_assistant", bool(see_past_assistant)) setattr(model, "_optimized_llopa_replay_module", str(replay_module)) setattr(model, "_optimized_llopa_last_layer_module", str(replay_module)) setattr(model, "_optimized_llopa_replay_per_layers", int(replay_per_layers)) setattr(model, "_optimized_llopa_variant", str(optimized_settings["variant"])) setattr(model, "_optimized_llopa_seed_mode", str(optimized_settings["seed_mode"])) setattr(model, "_optimized_llopa_upper_prepare_mode", str(optimized_settings["upper_prepare_mode"])) setattr(model, "_optimized_llopa_upper_bucket_multiple", int(optimized_settings["upper_bucket_multiple"])) setattr(model, "_optimized_llopa_seq_bucket_multiple", int(optimized_settings["seq_bucket_multiple"])) setattr(model, "_optimized_llopa_generate_default", True) setattr(model, "_capsule_inference_path", "optimized_llopa") except Exception: return _attach_structured_llopa_generate(model, tokenizer) try: setattr(model, "_optimized_llopa_generate_attached", True) except Exception: pass def _attach_direct_llopa_generate( model, tokenizer, *, lower_k: int, prefill_attn: str = "causal", system_prefill: str = "full", user_prefill: str = "full", no_upper_attn: bool = False, ) -> None: normalized = _normalize_structured_llopa_runtime( lower_k, prefill_attn=prefill_attn, system_prefill=system_prefill, user_prefill=user_prefill, no_upper_attn=no_upper_attn, ) if normalized is None: return lower_k, attn, sys_prefill, user_prefill_norm, no_upper_attn, _, _ = normalized header_ids = _assistant_header_ids(tokenizer, "cpu") try: setattr(model, "_direct_llopa_layers", int(lower_k)) setattr(model, "_direct_llopa_attn", attn) setattr(model, "_direct_llopa_system_prefill", sys_prefill) setattr(model, "_direct_llopa_user_prefill", user_prefill_norm) setattr(model, "_direct_llopa_no_upper_attn", bool(no_upper_attn)) setattr(model, "_direct_llopa_generate_default", True) setattr(model, "_capsule_inference_path", "legacy_llopa") if isinstance(header_ids, torch.Tensor) and header_ids.numel() > 0: setattr(model, "_direct_llopa_header_ids", header_ids.detach().to(device="cpu", dtype=torch.long)) except Exception: return _attach_structured_llopa_generate(model, tokenizer) try: setattr(model, "_direct_llopa_generate_attached", True) except Exception: pass def load_llopa_model(model_repo: str, *, model_name: str = "", tokenizer_name: str = "", num_specials: Optional[int] = None, backbone_dir: str = "", lopa_modeling_path: str = "", modeling_family: str = "auto", dtype: str = "auto", torch_dtype=None, device: str = "", device_map: Optional[str] = None, attn_impl: str = "auto", attn_implementation: Optional[str] = None, _attn_implementation: Optional[str] = None, trust_remote_code: bool = False, cache_dir: Optional[str] = None, revision: Optional[str] = None, token: Optional[str] = None, local_files_only: bool = False, force_custom_modeling: bool = False, number_of_lora: int = 1, use_lora: bool = True, merge_on_cpu: bool = True, enable_thinking: Optional[bool] = None, no_upper_attn: Optional[bool] = None, runtime_prefill_lower: Optional[bool] = None, runtime_prefill_layers: Optional[int] = None, runtime_prefill_attn: Optional[str] = None, runtime_prefill_system_prefill: Optional[str] = None, runtime_prefill_freeze: Optional[bool] = None, runtime_prefill_freeze_layers: Optional[int] = None, runtime_prefill_freeze_attn: Optional[str] = None, runtime_prefill_freeze_system_prefill: Optional[str] = None, runtime_prefill_solo: Optional[bool] = None, runtime_prefill_solo_layers: Optional[int] = None, runtime_prefill_solo_attn: Optional[str] = None, runtime_prefill_solo_system_prefill: Optional[str] = None, runtime_prefill_solo_v2: Optional[bool] = None, runtime_prefill_solo_v2_layers: Optional[int] = None, runtime_prefill_solo_v2_attn: Optional[str] = None, runtime_prefill_solo_v2_system_prefill: Optional[str] = None, runtime_prefill_solo_v2_with_bos: Optional[bool] = None, runtime_llopa_prefill: Optional[bool] = None, runtime_llopa_layers: Optional[int] = None, runtime_llopa_attn: Optional[str] = None, runtime_llopa_no_upper_attn: Optional[bool] = None, unified_llopa_generate: Optional[bool] = None, unified_llopa_layers: Optional[int] = None, unified_llopa_attn: Optional[str] = None, unified_llopa_system_prefill: Optional[str] = None, unified_llopa_user_prefill: Optional[str] = None, unified_llopa_no_upper_attn: Optional[bool] = None, unified_llopa_see_past_assistant: Optional[bool] = None, unified_llopa_replay_module: Optional[str] = None, unified_llopa_replay_per_layers: Optional[int] = None, unified_llopa_last_layer_module: Optional[str] = None, llopa_v2_batch_generate: Optional[bool] = None, llopa_v2_generate: Optional[bool] = None, llopa_v3_generate: Optional[bool] = None, llopa_v2_layers: Optional[int] = None, llopa_v2_attn: Optional[str] = None, llopa_v2_system_prefill: Optional[str] = None, llopa_v2_user_prefill: Optional[str] = None, llopa_v2_no_upper_attn: Optional[bool] = None, llopa_v2_see_past_assistant: Optional[bool] = None, llopa_v2_replay_module: Optional[str] = None, llopa_v2_replay_per_layers: Optional[int] = None, llopa_v2_last_layer_module: Optional[str] = None, llopa_v2_seed_mode: Optional[str] = None, optimized_llopa_generate: Optional[bool] = None, optimized_llopa_layers: Optional[int] = None, optimized_llopa_attn: Optional[str] = None, optimized_llopa_system_prefill: Optional[str] = None, optimized_llopa_user_prefill: Optional[str] = None, optimized_llopa_no_upper_attn: Optional[bool] = None, optimized_llopa_see_past_assistant: Optional[bool] = None, optimized_llopa_replay_module: Optional[str] = None, optimized_llopa_replay_per_layers: Optional[int] = None, optimized_llopa_last_layer_module: Optional[str] = None, optimized_llopa_variant: Optional[str] = None, optimized_llopa_seed_mode: Optional[str] = None, optimized_llopa_upper_prepare_mode: Optional[str] = None, optimized_llopa_upper_bucket_multiple: Optional[int] = None, optimized_llopa_seq_bucket_multiple: Optional[int] = None, runtime_llopa_fast_generate: Optional[bool] = None, direct_llopa_generate: Optional[bool] = None, direct_llopa_layers: Optional[int] = None, direct_llopa_attn: Optional[str] = None, direct_llopa_system_prefill: Optional[str] = None, direct_llopa_user_prefill: Optional[str] = None, direct_llopa_no_upper_attn: Optional[bool] = None): repo_path = _resolve_repo_path(model_repo, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only) info: dict[str, str] = {} tri_info = repo_path / "tri_info.txt" if tri_info.is_file(): info = _read_kv_file(tri_info) if not model_name: model_name = info.get("model_name", "") if num_specials is None: try: num_specials = int(info.get("num_specials", "") or 0) except Exception: num_specials = 0 if runtime_prefill_lower is None: runtime_prefill_lower = False if runtime_prefill_freeze is None: runtime_prefill_freeze = False if runtime_prefill_solo is None: runtime_prefill_solo = False if runtime_prefill_solo_v2 is None: runtime_prefill_solo_v2 = False if runtime_llopa_prefill is None: runtime_llopa_prefill = False if unified_llopa_generate is None: unified_llopa_generate = False if llopa_v2_batch_generate is None: llopa_v2_batch_generate = False if llopa_v2_generate is None: llopa_v2_generate = False if llopa_v3_generate is None: llopa_v3_generate = False if bool(llopa_v3_generate): llopa_v2_generate = True if optimized_llopa_generate is None: optimized_llopa_generate = False if runtime_llopa_fast_generate is None: runtime_llopa_fast_generate = False if direct_llopa_generate is None: direct_llopa_generate = False backbone_ref = ( backbone_dir or read_backbone_ref(repo_path) or _read_adapter_backbone_ref(repo_path) or model_name or model_repo ) num_specials_arg = num_specials config = None config_source = str(repo_path) if (repo_path / "config.json").is_file() else backbone_ref try: config = AutoConfig.from_pretrained( config_source, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, ) if num_specials_arg is not None: config.llopa_num_specials = int(num_specials_arg) except Exception: config = None if num_specials_arg is None: num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0 else: num_specials = int(num_specials_arg) def _config_str(name: str) -> str: if config is None: return "" raw = getattr(config, name, "") return str(raw or "").strip() def _config_bool(name: str): if config is None: return None value = getattr(config, name, None) if value is None: return None return bool(value) def _normalize_attention_gate_mode(mode: str) -> str: normalized = str(mode or "off").strip().lower() aliases = { "": "off", "none": "off", "disabled": "off", "disable": "off", "false": "off", "0": "off", "paper": "sdpa_sigmoid", "sdpa_gate": "sdpa_sigmoid", "sdpa-gate": "sdpa_sigmoid", "sigmoid_after_sdpa": "sdpa_sigmoid", "sdpa_elementwise_sigmoid": "sdpa_sigmoid", } normalized = aliases.get(normalized, normalized) if normalized not in {"off", "sdpa_sigmoid"}: normalized = "off" return normalized attention_gate_mode = _normalize_attention_gate_mode( info.get("attention_gate_mode") or _config_str("capsule_attention_gate_mode") or "off" ) if config is not None: with contextlib.suppress(Exception): setattr(config, "capsule_attention_gate_mode", attention_gate_mode) if no_upper_attn is None: raw = (info.get("no_upper_attn") or _config_str("capsule_no_upper_attn")).strip().lower() if raw in {"1", "true", "yes", "on"}: no_upper_attn = True elif raw in {"0", "false", "no", "off"}: no_upper_attn = False else: cfg_bool = _config_bool("capsule_no_upper_attn") if cfg_bool is not None: no_upper_attn = cfg_bool if runtime_prefill_layers is None: raw = (info.get("lower_k") or _config_str("capsule_lower_layers")).strip() if raw: with contextlib.suppress(Exception): runtime_prefill_layers = int(raw) if runtime_prefill_attn is None: runtime_prefill_attn = ( (info.get("prefill_attn") or _config_str("capsule_prefill_attn") or "causal").strip().lower() or "causal" ) if runtime_prefill_system_prefill is None: raw_system_prefill = (info.get("system_prefill") or _config_str("capsule_system_prefill")).strip().lower() if raw_system_prefill in {"full", "no_system", "no_bos_system"}: runtime_prefill_system_prefill = raw_system_prefill if runtime_llopa_layers is None: runtime_llopa_layers = runtime_prefill_layers if runtime_llopa_attn is None: runtime_llopa_attn = runtime_prefill_attn if runtime_llopa_no_upper_attn is None: runtime_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False if unified_llopa_layers is None: unified_llopa_layers = runtime_prefill_layers if unified_llopa_attn is None: unified_llopa_attn = runtime_prefill_attn if unified_llopa_system_prefill is None: unified_llopa_system_prefill = runtime_prefill_system_prefill if unified_llopa_user_prefill is None: raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() if raw_user_prefill: unified_llopa_user_prefill = raw_user_prefill if unified_llopa_no_upper_attn is None: unified_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False if unified_llopa_see_past_assistant is None: unified_llopa_see_past_assistant = False if unified_llopa_replay_module is None: unified_llopa_replay_module = unified_llopa_last_layer_module if unified_llopa_replay_module is None: unified_llopa_replay_module = ( info.get("replay_module") or _config_str("capsule_replay_module") or info.get("last_layer_module") or _config_str("capsule_last_layer_module") ) unified_llopa_replay_module = _normalize_replay_module_value(unified_llopa_replay_module) unified_llopa_last_layer_module = str(unified_llopa_replay_module) if unified_llopa_replay_per_layers is None: unified_llopa_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 unified_llopa_replay_per_layers = _normalize_replay_per_layers_value(unified_llopa_replay_per_layers) if llopa_v2_layers is None: llopa_v2_layers = runtime_prefill_layers if llopa_v2_attn is None: llopa_v2_attn = runtime_prefill_attn if llopa_v2_system_prefill is None: llopa_v2_system_prefill = runtime_prefill_system_prefill if llopa_v2_user_prefill is None: raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() if raw_user_prefill: llopa_v2_user_prefill = raw_user_prefill if llopa_v2_no_upper_attn is None: llopa_v2_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False if llopa_v2_see_past_assistant is None: llopa_v2_see_past_assistant = False if llopa_v2_replay_module is None: llopa_v2_replay_module = llopa_v2_last_layer_module if llopa_v2_replay_module is None: llopa_v2_replay_module = ( info.get("replay_module") or _config_str("capsule_replay_module") or info.get("last_layer_module") or _config_str("capsule_last_layer_module") ) llopa_v2_replay_module = _normalize_replay_module_value(llopa_v2_replay_module) llopa_v2_last_layer_module = str(llopa_v2_replay_module) if llopa_v2_replay_per_layers is None: llopa_v2_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 llopa_v2_replay_per_layers = _normalize_replay_per_layers_value(llopa_v2_replay_per_layers) llopa_v2_seed_mode = "prefill_header" if optimized_llopa_layers is None: optimized_llopa_layers = runtime_prefill_layers if optimized_llopa_attn is None: optimized_llopa_attn = runtime_prefill_attn if optimized_llopa_system_prefill is None: optimized_llopa_system_prefill = runtime_prefill_system_prefill if optimized_llopa_user_prefill is None: raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() if raw_user_prefill: optimized_llopa_user_prefill = raw_user_prefill if optimized_llopa_no_upper_attn is None: optimized_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False if optimized_llopa_see_past_assistant is None: optimized_llopa_see_past_assistant = False if optimized_llopa_replay_module is None: optimized_llopa_replay_module = optimized_llopa_last_layer_module if optimized_llopa_replay_module is None: optimized_llopa_replay_module = ( info.get("replay_module") or _config_str("capsule_replay_module") or info.get("last_layer_module") or _config_str("capsule_last_layer_module") ) optimized_llopa_replay_module = _normalize_replay_module_value(optimized_llopa_replay_module) optimized_llopa_last_layer_module = str(optimized_llopa_replay_module) if optimized_llopa_replay_per_layers is None: optimized_llopa_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 optimized_llopa_replay_per_layers = _normalize_replay_per_layers_value(optimized_llopa_replay_per_layers) optimized_settings = _resolve_optimized_llopa_settings( variant=optimized_llopa_variant, seed_mode=optimized_llopa_seed_mode, upper_prepare_mode=optimized_llopa_upper_prepare_mode, upper_bucket_multiple=optimized_llopa_upper_bucket_multiple, seq_bucket_multiple=optimized_llopa_seq_bucket_multiple, ) if direct_llopa_layers is None: direct_llopa_layers = runtime_prefill_layers if direct_llopa_attn is None: direct_llopa_attn = runtime_prefill_attn if direct_llopa_system_prefill is None: direct_llopa_system_prefill = runtime_prefill_system_prefill if direct_llopa_user_prefill is None: raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() if raw_user_prefill: direct_llopa_user_prefill = raw_user_prefill if direct_llopa_no_upper_attn is None: direct_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False if runtime_prefill_solo_layers is None: runtime_prefill_solo_layers = runtime_prefill_layers if runtime_prefill_solo_attn is None: runtime_prefill_solo_attn = runtime_prefill_attn if runtime_prefill_solo_system_prefill is None: runtime_prefill_solo_system_prefill = runtime_prefill_system_prefill if runtime_prefill_solo_v2_layers is None: runtime_prefill_solo_v2_layers = runtime_prefill_layers if runtime_prefill_solo_v2_attn is None: runtime_prefill_solo_v2_attn = runtime_prefill_attn if runtime_prefill_solo_v2_system_prefill is None: runtime_prefill_solo_v2_system_prefill = runtime_prefill_system_prefill if runtime_prefill_solo_v2_with_bos is None: runtime_prefill_solo_v2_with_bos = False config_kwargs = {"config": config} if config is not None else {} dtype_norm = _normalize_dtype_arg(dtype) or "auto" torch_dtype_norm = _normalize_dtype_arg(torch_dtype) if dtype_norm == "auto" and torch_dtype_norm is not None: dtype_norm = torch_dtype_norm if dtype_norm == "fp32": torch_dtype = torch.float32 elif dtype_norm == "bf16": torch_dtype = torch.bfloat16 elif dtype_norm == "fp16": torch_dtype = torch.float16 else: if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): torch_dtype = torch.bfloat16 elif torch.cuda.is_available(): torch_dtype = torch.float16 else: torch_dtype = torch.float32 # Accept HF-style aliases so AutoModel.from_pretrained(...) kwargs work as-is. for cand in (attn_implementation, _attn_implementation): if cand not in (None, "", "auto"): attn_impl = str(cand) break if attn_impl != "auto" and config is not None: for k in ("attn_implementation", "_attn_implementation"): with contextlib.suppress(Exception): setattr(config, k, attn_impl) # `device_map="cuda:0"` is a single-device placement request, not a sharding map. if isinstance(device_map, str): dm = device_map.strip() if dm in ("", "none", "None", "null", "NULL"): device_map = None elif dm in ("cuda", "cpu", "mps") or dm.startswith(("cuda:", "xpu:", "npu:")): if not device: device = dm device_map = None if device_map is None and not device: device = "cuda" if torch.cuda.is_available() else "cpu" if tokenizer_name: tok_src = tokenizer_name else: tok_src = str(repo_path) if (repo_path / "tokenizer.json").is_file() else backbone_ref tokenizer = AutoTokenizer.from_pretrained( tok_src, use_fast=True, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, **_tokenizer_kwargs(AutoTokenizer.from_pretrained), ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" checkpoint_repo_path = repo_path if not _repo_has_pretrained_weights(checkpoint_repo_path): with contextlib.suppress(Exception): resolved_backbone = _resolve_repo_path( backbone_ref, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, ) if _repo_has_pretrained_weights(resolved_backbone): checkpoint_repo_path = resolved_backbone checkpoint_vocab_size = 0 alignment_report = None with contextlib.suppress(Exception): checkpoint_vocab_size = int(_infer_checkpoint_vocab_size(checkpoint_repo_path) or 0) if checkpoint_vocab_size > 0: try: tokenizer_vocab_size = int(len(tokenizer)) except Exception: tokenizer_vocab_size = 0 if tokenizer_vocab_size > 0 and checkpoint_vocab_size > tokenizer_vocab_size: alignment_report = _align_tokenizer_with_checkpoint_vocab( tokenizer, checkpoint_repo_path, checkpoint_vocab_size, ) _log_tokenizer_checkpoint_alignment("[load_llopa_model]", alignment_report) if config is not None: try: tokenizer_vocab_size = int(len(tokenizer)) except Exception: tokenizer_vocab_size = 0 config_vocab_size = 0 try: config_vocab_size = int(getattr(config, "vocab_size", 0) or 0) except Exception: config_vocab_size = 0 target_vocab_size = max(config_vocab_size, tokenizer_vocab_size, int(checkpoint_vocab_size or 0)) if checkpoint_vocab_size > 0 and checkpoint_vocab_size != config_vocab_size: _log_config_checkpoint_vocab_alignment( "[load_llopa_model]", config_vocab_size=config_vocab_size, checkpoint_vocab_size=checkpoint_vocab_size, tokenizer_vocab_size=tokenizer_vocab_size, alignment_report=alignment_report, ) if target_vocab_size > 0 and target_vocab_size != config_vocab_size: with contextlib.suppress(Exception): config.vocab_size = int(target_vocab_size) if enable_thinking is not None: if hasattr(tokenizer, "enable_thinking"): try: tokenizer.enable_thinking(enable_thinking) except Exception: try: tokenizer.enable_thinking = enable_thinking except Exception: pass elif hasattr(tokenizer, "set_enable_thinking"): try: tokenizer.set_enable_thinking(enable_thinking) except Exception: pass try: setattr(tokenizer, "_force_enable_thinking", enable_thinking) except Exception: pass model_family = infer_model_family(backbone_ref, modeling_family) modeling_path = _resolve_modeling_path(repo_path, lopa_modeling_path, model_family) lora_path = repo_path / "lora" has_lora = bool(use_lora and lora_path.exists() and any(lora_path.iterdir())) if not has_lora and bool(use_lora): top_level_adapter = ( (repo_path / "adapter_config.json").is_file() and ( (repo_path / "adapter_model.safetensors").is_file() or (repo_path / "adapter_model.bin").is_file() ) ) if top_level_adapter: lora_path = repo_path has_lora = True load_device_map = device_map merge_on_cpu_active = bool(merge_on_cpu and has_lora and device_map is None) if merge_on_cpu and has_lora and device_map is not None: print("[LoRA] merge_on_cpu ignored because sharded device_map is requested.") if merge_on_cpu_active: print("[LoRA] Merging adapters on CPU to reduce CUDA peak memory.") load_device_map = None custom_mod = None if modeling_path: try: custom_mod = load_custom_modeling(modeling_path, model_family=model_family) except Exception: custom_mod = None base = None custom_load_exc = None if custom_mod is not None: try: if model_family == "qwen3": base = custom_mod.Qwen3ForCausalLM.from_pretrained( backbone_ref, **_dtype_kwargs(custom_mod.Qwen3ForCausalLM.from_pretrained, torch_dtype), trust_remote_code=trust_remote_code, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, device_map=load_device_map, **config_kwargs, ) elif model_family == "mistral": base = custom_mod.MistralForCausalLM.from_pretrained( backbone_ref, **_dtype_kwargs(custom_mod.MistralForCausalLM.from_pretrained, torch_dtype), trust_remote_code=trust_remote_code, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, device_map=load_device_map, **config_kwargs, ) else: base = custom_mod.LlamaForCausalLM.from_pretrained( backbone_ref, **_dtype_kwargs(custom_mod.LlamaForCausalLM.from_pretrained, torch_dtype), trust_remote_code=trust_remote_code, cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, device_map=load_device_map, **config_kwargs, ) except Exception as exc: custom_load_exc = exc base = None if base is None and force_custom_modeling and custom_mod is not None and custom_load_exc is not None: raise RuntimeError( f"Failed to load base model with custom LLoPA modeling from {backbone_ref}" ) from custom_load_exc if base is None: base = AutoModelForCausalLM.from_pretrained( backbone_ref, trust_remote_code=trust_remote_code, **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, torch_dtype), cache_dir=cache_dir, revision=revision, token=token, local_files_only=local_files_only, device_map=load_device_map, **config_kwargs, ) ensure_mistral_special_token(tokenizer, base) load_embedding_layer(base, repo_path) if merge_on_cpu_active: try: base = base.to("cpu") except Exception: pass model = None if has_lora: num_lora = int(number_of_lora or 1) if num_lora == 2: gen_dir = lora_path / "gen" prefill_dir = lora_path / "prefill" if gen_dir.is_dir() and prefill_dir.is_dir(): try: from peft import PeftModel peft_gen = PeftModel.from_pretrained(base, str(gen_dir), adapter_name="gen") try: peft_gen.set_adapter("gen") except Exception: pass merged_base = peft_gen.merge_and_unload() peft_prefill = PeftModel.from_pretrained(merged_base, str(prefill_dir), adapter_name="prefill") model = peft_prefill setattr(model, "_prefill_adapter_only", True) except Exception: model = base else: model = base if model is None or model is base: try: from peft import PeftModel peft = PeftModel.from_pretrained(base, str(lora_path)) try: model = peft.merge_and_unload() except Exception: model = peft except Exception: model = base else: model = base if num_specials > 0: if not load_llopa_specials(model, repo_path): print("[Warn] Failed to load LLOPA specials.") try: p0 = next(model.parameters()) print(f"[load_llopa_model] model dtype={p0.dtype}, on={p0.device}") except Exception: pass if device_map is None: model = model.to(device).eval() else: model = model.eval() if no_upper_attn is not None: try: setattr(model, "_no_upper_attn", bool(no_upper_attn)) except Exception: pass if bool(runtime_prefill_lower): if _supports_prefill_lower_runtime(model): _attach_prefill_lower_generate( model, lower_k=int(runtime_prefill_layers or 0), prefill_attn=str(runtime_prefill_attn or "causal"), system_prefill=str(runtime_prefill_system_prefill or "no_bos_system"), ) try: print( f"[load_llopa_model] standard generate runtime enabled " f"(prefill_lower_layers={int(runtime_prefill_layers or 0)}, " f"prefill_attn={str(runtime_prefill_attn or 'causal')})" ) except Exception: pass try: setattr(model, "_capsule_inference_path", "runtime_lower") except Exception: pass else: print("[load_llopa_model][warn] runtime_prefill_lower requested but TRI prefill-lower runtime is unavailable.") if bool(runtime_prefill_freeze): if _supports_prefill_lower_runtime(model): _attach_prefill_lower_freeze_generate( model, tokenizer=tokenizer, lower_k=int(runtime_prefill_freeze_layers or 0), prefill_attn=str(runtime_prefill_freeze_attn or "causal"), system_prefill=str(runtime_prefill_freeze_system_prefill or "no_bos_system"), ) try: print( f"[load_llopa_model] freeze-faithful generate runtime enabled " f"(prefill_lower_layers={int(runtime_prefill_freeze_layers or 0)}, " f"prefill_attn={str(runtime_prefill_freeze_attn or 'causal')}, " f"system_prefill={str(runtime_prefill_freeze_system_prefill or 'no_bos_system')})" ) except Exception: pass try: setattr(model, "_capsule_inference_path", "runtime_freeze") except Exception: pass else: print("[load_llopa_model][warn] runtime_prefill_freeze requested but TRI prefill-freeze runtime is unavailable.") if bool(runtime_prefill_solo): if _supports_prefill_lower_runtime(model): _attach_prefill_lower_solo_generate( model, tokenizer=tokenizer, lower_k=int(runtime_prefill_solo_layers or 0), prefill_attn=str(runtime_prefill_solo_attn or "causal"), system_prefill=str(runtime_prefill_solo_system_prefill or "no_bos_system"), ) try: print( f"[load_llopa_model] solo-attn generate runtime enabled " f"(prefill_lower_layers={int(runtime_prefill_solo_layers or 0)}, " f"prefill_attn={str(runtime_prefill_solo_attn or 'causal')}, " f"system_prefill={str(runtime_prefill_solo_system_prefill or 'no_bos_system')})" ) except Exception: pass try: setattr(model, "_capsule_inference_path", "runtime_solo") except Exception: pass else: print("[load_llopa_model][warn] runtime_prefill_solo requested but TRI prefill-solo runtime is unavailable.") if bool(runtime_prefill_solo_v2): if _supports_prefill_lower_runtime(model): _attach_prefill_lower_solo_v2_generate( model, tokenizer=tokenizer, lower_k=int(runtime_prefill_solo_v2_layers or 0), prefill_attn=str(runtime_prefill_solo_v2_attn or "causal"), system_prefill=str(runtime_prefill_solo_v2_system_prefill or "no_bos_system"), with_bos=bool(runtime_prefill_solo_v2_with_bos), ) try: print( f"[load_llopa_model] solo-attn-v2 generate runtime enabled " f"(prefill_lower_layers={int(runtime_prefill_solo_v2_layers or 0)}, " f"prefill_attn={str(runtime_prefill_solo_v2_attn or 'causal')}, " f"system_prefill={str(runtime_prefill_solo_v2_system_prefill or 'no_bos_system')}, " f"with_bos={int(bool(runtime_prefill_solo_v2_with_bos))})" ) except Exception: pass try: setattr(model, "_capsule_inference_path", "runtime_solo_v2") except Exception: pass else: print("[load_llopa_model][warn] runtime_prefill_solo_v2 requested but TRI prefill-solo-v2 runtime is unavailable.") if bool(runtime_llopa_prefill): if _supports_runtime_llopa_prompt_prefill(model): header_ids = _assistant_header_ids(tokenizer, "cpu") if isinstance(header_ids, torch.Tensor) and header_ids.numel() > 0: _attach_runtime_llopa_generate( model, header_ids=header_ids, lower_k=int(runtime_llopa_layers or 0), prefill_attn=str(runtime_llopa_attn or "causal"), no_upper_attn=bool(runtime_llopa_no_upper_attn), ) if bool(runtime_llopa_fast_generate): _attach_runtime_llopa_fast_generate( model, lower_k=int(runtime_llopa_layers or 0), prefill_attn=str(runtime_llopa_attn or "causal"), no_upper_attn=bool(runtime_llopa_no_upper_attn), ) try: print( f"[load_llopa_model] standard generate runtime enabled " f"(llopa_prefill_layers={int(runtime_llopa_layers or 0)}, " f"prefill_attn={str(runtime_llopa_attn or 'causal')})" ) except Exception: pass try: setattr(model, "_capsule_inference_path", "legacy_llopa") except Exception: pass if bool(runtime_llopa_fast_generate): try: print("[load_llopa_model] runtime_llopa_fast_generate enabled") except Exception: pass try: setattr(model, "_capsule_inference_path", "legacy_llopa") except Exception: pass else: print("[load_llopa_model][warn] runtime_llopa_prefill requested but assistant header ids are unavailable.") else: print("[load_llopa_model][warn] runtime_llopa_prefill requested but TRI LLoPA runtime is unavailable.") if bool(unified_llopa_generate): if _supports_direct_llopa_generate(model): _attach_unified_llopa_generate( model, tokenizer, lower_k=int(unified_llopa_layers or 0), prefill_attn=str(unified_llopa_attn or "causal"), system_prefill=str(unified_llopa_system_prefill or "full"), user_prefill=str(unified_llopa_user_prefill or "full"), no_upper_attn=bool(unified_llopa_no_upper_attn), see_past_assistant=bool(unified_llopa_see_past_assistant), replay_module=str(unified_llopa_replay_module or "none"), replay_per_layers=int(unified_llopa_replay_per_layers or -1), ) try: print( f"[load_llopa_model] unified_llopa generate enabled " f"(llopa_prefill_layers={int(unified_llopa_layers or 0)}, " f"prefill_attn={str(unified_llopa_attn or 'causal')}, " f"replay_module={str(unified_llopa_replay_module or 'none')}, " f"replay_per_layers={int(unified_llopa_replay_per_layers or -1)})" ) except Exception: pass else: print("[load_llopa_model][warn] unified_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") if bool(llopa_v2_batch_generate): if _supports_direct_llopa_generate(model): _attach_llopa_v2_batch_generate( model, tokenizer, lower_k=int(llopa_v2_layers or 0), prefill_attn=str(llopa_v2_attn or "causal"), system_prefill=str(llopa_v2_system_prefill or "full"), user_prefill=str(llopa_v2_user_prefill or "full"), no_upper_attn=bool(llopa_v2_no_upper_attn), see_past_assistant=bool(llopa_v2_see_past_assistant), replay_module=str(llopa_v2_replay_module or "none"), replay_per_layers=int(llopa_v2_replay_per_layers or -1), seed_mode=str(llopa_v2_seed_mode or "prefill_header"), ) try: print( f"[load_llopa_model] llopa_v2_batch generate enabled " f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, " f"prefill_attn={str(llopa_v2_attn or 'causal')}, " f"replay_module={str(llopa_v2_replay_module or 'none')}, " f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, " f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')})" ) except Exception: pass else: print("[load_llopa_model][warn] llopa_v2_batch_generate requested but LLoPA direct prompt prefill is unavailable.") if bool(llopa_v2_generate): if _supports_direct_llopa_generate(model): llopa_v3_active = bool(llopa_v3_generate) _attach_llopa_v2_generate( model, tokenizer, lower_k=int(llopa_v2_layers or 0), prefill_attn=str(llopa_v2_attn or "causal"), system_prefill=str(llopa_v2_system_prefill or "full"), user_prefill=str(llopa_v2_user_prefill or "full"), no_upper_attn=bool(llopa_v2_no_upper_attn), see_past_assistant=bool(llopa_v2_see_past_assistant), replay_module=str(llopa_v2_replay_module or "none"), replay_per_layers=int(llopa_v2_replay_per_layers or -1), seed_mode=str(llopa_v2_seed_mode or "prefill_header"), generation_mixin_decode=llopa_v3_active, capsule_inference_path="llopa_v3" if llopa_v3_active else "llopa_v2", ) try: label = "llopa_v3" if llopa_v3_active else "llopa_v2" print( f"[load_llopa_model] {label} generate enabled " f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, " f"prefill_attn={str(llopa_v2_attn or 'causal')}, " f"replay_module={str(llopa_v2_replay_module or 'none')}, " f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, " f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')}, " f"generation_mixin_decode={int(llopa_v3_active)})" ) except Exception: pass else: label = "llopa_v3_generate" if bool(llopa_v3_generate) else "llopa_v2_generate" print(f"[load_llopa_model][warn] {label} requested but LLoPA direct prompt prefill is unavailable.") if bool(optimized_llopa_generate): if _supports_direct_llopa_generate(model): _attach_optimized_llopa_generate( model, tokenizer, lower_k=int(optimized_llopa_layers or 0), prefill_attn=str(optimized_llopa_attn or "causal"), system_prefill=str(optimized_llopa_system_prefill or "full"), user_prefill=str(optimized_llopa_user_prefill or "full"), no_upper_attn=bool(optimized_llopa_no_upper_attn), see_past_assistant=bool(optimized_llopa_see_past_assistant), replay_module=str(optimized_llopa_replay_module or "none"), replay_per_layers=int(optimized_llopa_replay_per_layers or -1), optimized_variant=str(optimized_settings["variant"]), optimized_seed_mode=str(optimized_settings["seed_mode"]), optimized_upper_prepare_mode=str(optimized_settings["upper_prepare_mode"]), optimized_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]), optimized_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]), ) try: print( f"[load_llopa_model] optimized_llopa generate enabled " f"(llopa_prefill_layers={int(optimized_llopa_layers or 0)}, " f"prefill_attn={str(optimized_llopa_attn or 'causal')}, " f"replay_module={str(optimized_llopa_replay_module or 'none')}, " f"replay_per_layers={int(optimized_llopa_replay_per_layers or -1)}, " f"variant={str(optimized_settings['variant'])})" ) except Exception: pass else: print("[load_llopa_model][warn] optimized_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") if bool(direct_llopa_generate): print( "[load_llopa_model][warn] direct_llopa_* is deprecated; " "use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged." ) if _supports_direct_llopa_generate(model): _attach_direct_llopa_generate( model, tokenizer, lower_k=int(direct_llopa_layers or 0), prefill_attn=str(direct_llopa_attn or "causal"), system_prefill=str(direct_llopa_system_prefill or "full"), user_prefill=str(direct_llopa_user_prefill or "full"), no_upper_attn=bool(direct_llopa_no_upper_attn), ) try: print( f"[load_llopa_model] direct_llopa generate enabled " f"(llopa_prefill_layers={int(direct_llopa_layers or 0)}, " f"prefill_attn={str(direct_llopa_attn or 'causal')})" ) except Exception: pass else: print("[load_llopa_model][warn] direct_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") if force_custom_modeling: try: has_llopa = _has_active_llopa_runtime(model) except Exception: has_llopa = False if not has_llopa: raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.") if attn_impl != "auto": impl = attn_impl for k in ("attn_implementation", "_attn_implementation"): try: setattr(model.config, k, impl) inner = getattr(model, "model", None) or getattr(model, "transformer", None) if inner is not None and hasattr(inner, "config"): setattr(inner.config, k, impl) except Exception: pass _attach_llopa_generate(model) return model, tokenizer # ----------------------------- # CLI # ----------------------------- def main(): ap = argparse.ArgumentParser("TRI inference helper") ap.add_argument("--best_dir", type=str, required=True, help="Path to best/ folder produced by training") ap.add_argument("--backbone_dir", type=str, default=None, help="Optional backbone path/ID (overrides best_dir/backbone.json and --model_name).") ap.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct") ap.add_argument("--tokenizer_name", type=str, default="", help="Optional tokenizer name or path") ap.add_argument("--prefill_layers", type=int, default=4) ap.add_argument("--prefill_mode", type=str, choices=["lower", "periodic"], default="lower", help="Prefill mode for user tokens: 'lower' uses first K layers, 'periodic' uses every K-th layer.") ap.add_argument("--prefill_attn", type=str, choices=["causal", "full"], default="causal", help="Prefill attention for system/user tokens (training must match).") ap.add_argument("--system_prefill", type=str, choices=["full", "no_system", "no_bos_system"], default="full", help="System prefill mode (must match training).") ap.add_argument("--user_prefill", type=str, choices=["full", "no_question"], default="full", help="User prefill ablation: full=doc+question, no_question=doc-only (question runs in full layers).") ap.add_argument("--llopa_prefill", action="store_true", help="Use single-forward LLOPA prefill (causal + lower only).") ap.add_argument("--no_upper_attn", action="store_true", help="Skip upper-layer attention during decode (effective only with --llopa_prefill).") ap.add_argument("--lopa_modeling_path", type=str, default="tri_llama3_modeling.py", help="Path to custom LLoPA modeling file used in training") ap.add_argument("--modeling_family", type=str, choices=["auto", "llama", "qwen3", "mistral"], default="auto", help="Model family for custom modeling injection (auto detects from --model_name)") ap.add_argument("--force_custom_modeling", action="store_true", help="Require custom LLoPA modeling to be active; error if not.") ap.add_argument("--system", type=str, default="You are a helpful assistant that answers questions based on the given document. ") ap.add_argument("--task", type=str, default="qa_doc", help="Prompt template task: qa_doc | math | summary | code") ap.add_argument("--math_force_final_hash_rule", action="store_true", help="If set and task=math, append the #### answer rule to the system prompt.") ap.add_argument("--document", type=str, required=True) ap.add_argument("--question", type=str, required=True) ap.add_argument("--max_new_tokens", type=int, default=256) ap.add_argument("--min_length", type=int, default=16) ap.add_argument("--temperature", type=float, default=0.7) ap.add_argument("--top_p", type=float, default=0.9) ap.add_argument("--top_k", type=int, default=None) ap.add_argument("--do_sample", action="store_true") # numeric controls for reproducibility ap.add_argument("--dtype", type=str, choices=["auto","bf16","fp16","fp32"], default="auto") ap.add_argument("--no_tf32", action="store_true") ap.add_argument("--sdpa_math_only", action="store_true") ap.add_argument("--debug", action="store_true") ap.add_argument("--attn_impl", type=str, choices=["sdpa", "eager", "auto"], default="sdpa", help="Attention implementation override (auto keeps model default).") args = ap.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" if args.dtype == "fp32": dtype = torch.float32 elif args.dtype == "bf16": dtype = torch.bfloat16 elif args.dtype == "fp16": dtype = torch.float16 else: dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else (torch.float16 if device == "cuda" else torch.float32) # global numeric toggles if args.no_tf32 and torch.cuda.is_available(): try: torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False except Exception: pass if args.sdpa_math_only and torch.cuda.is_available(): try: torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(True) except Exception: pass best_dir = Path(args.best_dir) if getattr(args, "tokenizer_name", ""): tok_src = args.tokenizer_name else: tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else args.model_name tokenizer = AutoTokenizer.from_pretrained( tok_src, use_fast=True, **_tokenizer_kwargs(AutoTokenizer.from_pretrained), ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Prefer a saved base backbone under best_dir/base if present (captures resized embeddings) backbone_ref = None if args.backbone_dir: backbone_ref = args.backbone_dir else: backbone_ref = read_backbone_ref(best_dir) base_path = best_dir / "base" if backbone_ref: base_load_src = backbone_ref elif base_path.exists() and any(base_path.iterdir()): base_load_src = str(base_path) else: base_load_src = args.model_name num_specials = None tri_no_upper_attn = None tri_info = best_dir / "tri_info.txt" if tri_info.is_file(): info = _read_kv_file(tri_info) try: num_specials = int(info.get("num_specials", "") or 0) except Exception: num_specials = None raw_no_upper = (info.get("no_upper_attn") or "").strip().lower() if raw_no_upper in {"1", "true", "yes", "on"}: tri_no_upper_attn = True elif raw_no_upper in {"0", "false", "no", "off"}: tri_no_upper_attn = False if (not bool(getattr(args, "no_upper_attn", False))) and (tri_no_upper_attn is not None): args.no_upper_attn = bool(tri_no_upper_attn) config = None try: config = AutoConfig.from_pretrained(base_load_src) if num_specials is not None: config.llopa_num_specials = int(num_specials) except Exception: config = None if num_specials is None: num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0 config_kwargs = {"config": config} if config is not None else {} # Try loading custom LLoPA modeling before model load custom_mod = None model_family = infer_model_family(args.model_name, args.modeling_family) try: custom_mod = load_custom_modeling(args.lopa_modeling_path, model_family=model_family) except Exception: custom_mod = None base = None if custom_mod is not None: try: # Prefer explicit class if available (ensures we really use custom class) if model_family == "qwen3": base = custom_mod.Qwen3ForCausalLM.from_pretrained( base_load_src, **_dtype_kwargs(custom_mod.Qwen3ForCausalLM.from_pretrained, dtype), **config_kwargs, ) elif model_family == "mistral": base = custom_mod.MistralForCausalLM.from_pretrained( base_load_src, **_dtype_kwargs(custom_mod.MistralForCausalLM.from_pretrained, dtype), **config_kwargs, ) else: base = custom_mod.LlamaForCausalLM.from_pretrained( base_load_src, **_dtype_kwargs(custom_mod.LlamaForCausalLM.from_pretrained, dtype), **config_kwargs, ) except Exception: base = None if base is None: base = AutoModelForCausalLM.from_pretrained( base_load_src, trust_remote_code=False, **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, dtype), **config_kwargs, ) # Ensure special token availability for Mistral ensure_mistral_special_token(tokenizer, base) # Apply saved embedding layer (special tokens) if present. loaded_emb = load_embedding_layer(base, best_dir) # attach LoRA if exists lora_path = best_dir / "lora" model = None merged_lora = False if lora_path.exists() and any(lora_path.iterdir()): try: from peft import PeftModel peft = PeftModel.from_pretrained(base, str(lora_path)) try: model = peft.merge_and_unload() merged_lora = True except Exception: # fallback: keep PEFT wrapper without merge model = peft except Exception: model = base else: model = base if num_specials > 0: if not load_llopa_specials(model, best_dir): print("[Warn] Failed to load LLOPA specials.") # device & eval model = model.to(device).eval() try: setattr(model, "_no_upper_attn", bool(getattr(args, "no_upper_attn", False))) except Exception: pass # Validate custom modeling presence if requested if args.force_custom_modeling: has_llopa = _has_active_llopa_runtime(model) if not has_llopa: raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.") # Allow sdpa/eager; avoid flash_attention_2 for LoPA masks. if args.attn_impl != "auto": impl = args.attn_impl for k in ("attn_implementation", "_attn_implementation"): try: setattr(model.config, k, impl) inner = getattr(model, "model", None) or getattr(model, "transformer", None) if inner is not None and hasattr(inner, "config"): setattr(inner.config, k, impl) except Exception: pass print(f"[infer] Forcing attn_implementation='{impl}' for all models.") else: print("[infer] Using model default attn_implementation (auto).") if args.debug: print(f"[debug] load base from: {base_load_src}") if loaded_emb: print("[debug] loaded embedding layer from best_dir") print(f"[debug] lora path: {lora_path} | merged={merged_lora}") tmpl = getattr(tokenizer, "chat_template", "") or "" print(f"[debug] template contains Llama3 header? {('<|start_header_id|>' in tmpl)} | Mistral? {('[INST]' in tmpl)}") # debug dir under best_dir dbg_dir = (best_dir / "debug_infer") if args.debug else None text = lopa_generate( model, tokenizer, system=args.system, document=args.document, question=args.question, task=str(getattr(args, "task", "qa_doc")), K=int(args.prefill_layers), prefill_mode=str(args.prefill_mode), prefill_attn=str(getattr(args, "prefill_attn", "causal")), system_prefill=str(getattr(args, "system_prefill", "full")), user_prefill=str(getattr(args, "user_prefill", "full")), device=device, max_new_tokens=args.max_new_tokens, min_length=args.min_length, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, do_sample=bool(args.do_sample), math_force_final_hash_rule=bool(getattr(args, "math_force_final_hash_rule", False)), llopa_prefill=bool(getattr(args, "llopa_prefill", False)), no_upper_attn=bool(getattr(args, "no_upper_attn", False)), debug=bool(args.debug), debug_dir=dbg_dir, ) print(text) if __name__ == "__main__": main()