| |
| from __future__ import annotations |
|
|
| import os |
| import argparse |
| import contextlib |
| import inspect |
| import json |
| from pathlib import Path |
| from typing import Any, List, Tuple, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
|
|
| try: |
| from .llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref |
| except Exception: |
| from llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref |
| from transformers.cache_utils import DynamicCache |
|
|
| |
# Environment defaults applied at import time; a value that is already present
# in the environment is left untouched (``setdefault`` semantics).
for _env_key, _env_default in (
    ("TOKENIZERS_PARALLELISM", "false"),
    ("DS_BUILD_AIO", "0"),
    ("DS_BUILD_OPS", "0"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default
|
|
|
|
| def _dtype_kwargs(from_pretrained_fn, dtype: torch.dtype) -> dict: |
| try: |
| params = inspect.signature(from_pretrained_fn).parameters |
| except Exception: |
| return {"torch_dtype": dtype} |
| if "torch_dtype" in params: |
| return {"torch_dtype": dtype} |
| if "dtype" in params: |
| return {"dtype": dtype} |
| |
| |
| |
| if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): |
| return {"torch_dtype": dtype} |
| return {"torch_dtype": dtype} |
|
|
|
|
| def _tokenizer_kwargs(from_pretrained_fn) -> dict: |
| kwargs: dict = {} |
| try: |
| params = inspect.signature(from_pretrained_fn).parameters |
| if "fix_mistral_regex" in params: |
| kwargs["fix_mistral_regex"] = True |
| elif any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): |
| |
| kwargs["fix_mistral_regex"] = True |
| except Exception: |
| pass |
| return kwargs |
|
|
|
|
| def _normalize_dtype_arg(value) -> Optional[str]: |
| if value is None: |
| return None |
| if isinstance(value, str): |
| s = value.strip().lower() |
| elif value is torch.bfloat16: |
| s = "bfloat16" |
| elif value is torch.float16: |
| s = "float16" |
| elif value is torch.float32: |
| s = "float32" |
| else: |
| return None |
| mapping = { |
| "auto": "auto", |
| "bf16": "bf16", |
| "bfloat16": "bf16", |
| "fp16": "fp16", |
| "float16": "fp16", |
| "half": "fp16", |
| "fp32": "fp32", |
| "float32": "fp32", |
| } |
| return mapping.get(s) |
|
|
| |
# Custom assistant-start token that ensure_mistral_special_token() adds to
# Mistral-style tokenizers when it is missing from the vocabulary.
MISTRAL_ASSIST_START = "<Mistral_start>"

# Instruction appended to document-QA questions by build_prompt_parts().
BOXED_ANSWER_INSTRUCTION = "Provide the answer and finish with the short answer inside \\\\boxed{}."

# Instruction appended to (or used alone as) math prompts by build_prompt_parts().
MATH_STEP_BY_STEP_BOXED_INSTRUCTION = "Solve the problem step by step, and put the final answer inside \\\\boxed{}."
|
|
| |
| |
| |
def load_custom_modeling(modeling_path: str, model_family: str):
    """Load a local ``tri_*_modeling.py`` and register it as the HF modeling module.

    The file at ``modeling_path`` is executed under the fully qualified name of
    the stock Transformers modeling module for ``model_family`` and stored in
    ``sys.modules`` under that name, replacing any module already registered.

    Args:
        modeling_path: Path of the Python source file to load.
        model_family: One of ``"llama"``, ``"qwen3"`` or ``"mistral"``.

    Returns:
        The freshly loaded module object.

    Raises:
        ValueError: ``model_family`` is not a supported family. This is checked
            before Transformers is imported.
        RuntimeError: No import spec could be built for ``modeling_path``, or
            the loaded module lacks one of the expected model classes.

    If loading fails for any reason, the previous ``sys.modules`` entry (or its
    absence) is restored, so a broken module is never left registered.
    """
    import importlib
    import importlib.util
    import sys

    # family -> (HF package, modeling module stem, classes the file must define)
    families = {
        "llama": ("transformers.models.llama", "modeling_llama", ("LlamaModel", "LlamaForCausalLM")),
        "qwen3": ("transformers.models.qwen3", "modeling_qwen3", ("Qwen3Model", "Qwen3ForCausalLM")),
        "mistral": ("transformers.models.mistral", "modeling_mistral", ("MistralModel", "MistralForCausalLM")),
    }
    family_entry = families.get(model_family) if isinstance(model_family, str) else None
    if family_entry is None:
        raise ValueError(f"Unknown model_family: {model_family}")
    package_name, module_stem, expected = family_entry

    import transformers  # noqa: F401  (imported for its side effects, as before)

    importlib.import_module(package_name)
    target_name = f"{package_name}.{module_stem}"

    # Remember what was registered so a failed load can be rolled back.
    previous_module = sys.modules.pop(target_name, None)
    try:
        spec = importlib.util.spec_from_file_location(target_name, str(modeling_path))
        if spec is None or spec.loader is None:
            raise RuntimeError(f"Failed to load spec for {modeling_path}")
        module = importlib.util.module_from_spec(spec)
        # Register before executing so imports inside the file resolve to it.
        sys.modules[target_name] = module
        spec.loader.exec_module(module)
        for klass in expected:
            if not hasattr(module, klass):
                raise RuntimeError(f"{modeling_path} does not define {klass}")
    except BaseException:
        if previous_module is not None:
            sys.modules[target_name] = previous_module
        else:
            sys.modules.pop(target_name, None)
        raise
    return module
|
|
|
|
def infer_model_family(model_name: str, model_family_arg: str) -> str:
    """Resolve the model family from an explicit argument or the model name.

    A non-empty ``model_family_arg`` other than ``"auto"`` is returned
    unchanged. Otherwise the lower-cased ``model_name`` is searched for
    ``"qwen"``, ``"llama"`` and ``"mistral"`` in that order, and ``"llama"`` is
    the fallback when none of them matches.
    """
    if model_family_arg and model_family_arg != "auto":
        return model_family_arg

    lowered = (model_name or "").lower()
    # Order matters: the first matching substring decides the family.
    for needle, family in (("qwen", "qwen3"), ("llama", "llama"), ("mistral", "mistral")):
        if needle in lowered:
            return family
    return "llama"
|
|
| |
| |
| |
| def _is_mistral_template(tokenizer) -> bool: |
| tmpl = getattr(tokenizer, "chat_template", "") or "" |
| name = getattr(getattr(tokenizer, "init_kwargs", {}), "get", lambda k, d=None: d)("name_or_path", "") |
| return ("[INST]" in tmpl) or ("mistral" in str(name).lower()) or ("mistral" in tmpl.lower()) |
|
|
| def _is_qwen3_tokenizer(tokenizer) -> bool: |
| name = getattr(tokenizer, "name_or_path", "") or "" |
| cls = tokenizer.__class__.__name__.lower() |
| return ("qwen3" in name.lower()) or ("qwen3" in cls) |
|
|
def ensure_mistral_special_token(tokenizer, model=None):
    """Ensure the custom assistant-start token exists in a Mistral-style tokenizer.

    Nothing happens unless ``_is_mistral_template(tokenizer)`` is true and
    ``MISTRAL_ASSIST_START`` is missing from the vocabulary. In that case the
    token is appended to the tokenizer's additional special tokens and, when a
    ``model`` is given, its token embeddings are resized to ``len(tokenizer)``.

    Args:
        tokenizer: Tokenizer to inspect and possibly extend in place.
        model: Optional model whose embeddings should follow the new vocab size.

    Returns:
        ``True`` if the token was added, ``False`` otherwise.

    A failing embedding resize does not abort the call (best effort, as
    before), but it now emits a ``RuntimeWarning`` instead of being silently
    ignored, because the model and tokenizer sizes may then disagree.
    """
    import warnings

    if not _is_mistral_template(tokenizer):
        return False
    if MISTRAL_ASSIST_START in tokenizer.get_vocab():
        return False

    already_registered = tokenizer.special_tokens_map_extended.get("additional_special_tokens", [])
    tokenizer.add_special_tokens({
        # list(...) keeps this working if the existing container is not a list.
        "additional_special_tokens": list(already_registered) + [MISTRAL_ASSIST_START]
    })

    if model is not None:
        try:
            model.resize_token_embeddings(len(tokenizer))
        except Exception as exc:
            warnings.warn(
                f"Added {MISTRAL_ASSIST_START!r} to the tokenizer but resizing the model "
                f"embeddings failed: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
    return True
|
|
def build_prompt_parts(task: str, question: str, document: str) -> Tuple[str, str]:
    """Build the ``(document_text, query_text)`` pair for a task.

    All three inputs are stripped; a missing ``task`` defaults to ``"qa_doc"``.

    * ``summary``: the document becomes a ``Passage:`` block and the question
      (or a default summarisation instruction) is the query.
    * ``code``: no document part; the question (or a default) is wrapped in a
      coding instruction.
    * ``math``: no document part; the question, falling back to the document,
      is followed by the step-by-step boxed instruction.
    * anything else: a ``Document:`` block plus a ``Question:`` line followed by
      the boxed-answer instruction. Either part is empty when its input is empty.
    """
    kind = (task or "qa_doc").strip()
    query = (question or "").strip()
    doc = (document or "").strip()

    if kind == "summary":
        passage = f"Passage:\n{doc}" if doc else ""
        return passage, (query or "Summarize the passage.")

    if kind == "code":
        problem = query or "Solve the following problem."
        return "", f"Write code to solve the following problem:\n{problem}"

    if kind == "math":
        problem = (query or doc).strip()
        if not problem:
            return "", MATH_STEP_BY_STEP_BOXED_INSTRUCTION
        return "", f"{problem}\n{MATH_STEP_BY_STEP_BOXED_INSTRUCTION}"

    doc_text = f"Document:\n{doc}" if doc else ""
    query_text = f"Question: {query}\n{BOXED_ANSWER_INSTRUCTION}" if query else ""
    return doc_text, query_text
|
|
def build_user_content(task: str, question: str, document: str) -> str:
    """Join the document and query parts of a prompt into one user message.

    The non-empty parts from ``build_prompt_parts`` are separated by a blank
    line; if only one part is non-empty it is returned alone, and if both are
    empty the result is the empty string.
    """
    doc_text, query_text = build_prompt_parts(task, question, document)
    return "\n\n".join(part for part in (doc_text, query_text) if part)
|
|
def build_messages(system: str, document: str, question: str,
                   include_query: bool = True, task: str = "qa_doc"):
    """Return a two-message chat: the system prompt followed by one user turn.

    The user turn comes from ``build_user_content``; with ``include_query``
    false the question is treated as empty. If that yields nothing, the raw
    question, then the raw document, then ``""`` is used instead.
    """
    effective_question = question if include_query else ""
    user_content = build_user_content(task, effective_question, document)
    user_content = user_content or (effective_question or document or "")
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_content},
    ]
|
|
def apply_chat_template(tokenizer, messages, add_generation_prompt: bool):
    """Render chat messages to text with fallbacks for limited tokenizer signatures.

    ``enable_thinking`` is forwarded when the tokenizer carries a non-``None``
    ``_force_enable_thinking`` attribute; Qwen3 tokenizers without that
    attribute get ``enable_thinking=False``.

    Attempts, in order (each one is tried only if the previous raised
    ``TypeError``):

    1. ``add_generation_prompt`` plus ``enable_thinking`` (when applicable).
    2. ``add_generation_prompt`` only. This step exists so that a tokenizer
       which merely rejects ``enable_thinking`` does not lose its generation
       prompt.
    3. No ``add_generation_prompt``, first with and then without
       ``enable_thinking``. If a generation prompt was requested it is emulated:
       Llama-3 style templates (containing ``<|start_header_id|>``) get the
       assistant header appended; Mistral/``[INST]`` and unknown templates get
       nothing appended.

    Returns:
        The rendered prompt string.
    """
    force_thinking = getattr(tokenizer, "_force_enable_thinking", None)
    if force_thinking is None and _is_qwen3_tokenizer(tokenizer):
        force_thinking = False
    thinking_kwargs = {} if force_thinking is None else {"enable_thinking": bool(force_thinking)}

    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **thinking_kwargs,
        )
    except TypeError:
        pass

    if thinking_kwargs:
        # The rejected keyword may have been ``enable_thinking`` rather than
        # ``add_generation_prompt``; keep the latter if the tokenizer allows it.
        try:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=add_generation_prompt
            )
        except TypeError:
            pass

    template = getattr(tokenizer, "chat_template", "") or ""
    try:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, **thinking_kwargs)
    except TypeError:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False)

    if add_generation_prompt and "<|start_header_id|>" in template:
        rendered += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return rendered
|
|
def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False):
    """Render ``messages`` with the chat template and tokenize the result.

    The rendered text is tokenized with ``add_special_tokens=False`` and
    ``return_tensors="pt"``; the resulting ``input_ids`` are moved to ``device``
    and returned.
    """
    rendered = apply_chat_template(tokenizer, messages, add_generation_prompt)
    encoded = tokenizer(rendered, add_special_tokens=False, return_tensors="pt")
    return encoded.input_ids.to(device)
|
|
|
|
def _build_llopa_inputs(tokenizer,
                        system: str,
                        document: str,
                        question: str,
                        *,
                        task: str = "qa_doc",
                        device: str):
    """Tokenize a system/document/question prompt and split it into segments.

    Returns a dict with:

    * ``prompt_ids``: the full chat rendered with a generation prompt.
    * ``system_ids``: the system message rendered on its own.
    * ``system_user_ids``: system plus user turn without the assistant header.
    * ``user_ids_full``: ``system_user_ids`` with the system-length prefix removed.
    * ``user_doc_ids`` / ``user_q_ids``: ``user_ids_full`` split at the estimated
      end of the document part.
    * ``hdr_tail``: the trailing assistant-header tokens of ``prompt_ids``.

    Raises:
        ValueError: the system-only rendering is longer than system plus user.
    """
    doc_text, query_text = build_prompt_parts(task, question, document)
    messages = build_messages(system, document, question, include_query=True, task=task)

    system_ids = tokens_from_messages(
        tokenizer, [{"role": "system", "content": system}], device, add_generation_prompt=False
    )
    prompt_ids = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=True)

    # Prefer stripping the known assistant header off the end of the prompt;
    # fall back to re-rendering without a generation prompt.
    header_tail = _assistant_header_ids(tokenizer, device)
    system_user_ids = None
    if header_tail is not None and header_tail.numel() > 0:
        tail_len = header_tail.size(1)
        ends_with_header = (
            prompt_ids.size(1) >= tail_len
            and torch.equal(prompt_ids[:, -tail_len:], header_tail)
        )
        if ends_with_header:
            system_user_ids = prompt_ids[:, :-tail_len]
    if system_user_ids is None:
        system_user_ids = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False)
        header_tail = prompt_ids[:, system_user_ids.size(1):]

    system_len = system_ids.size(1)
    if system_user_ids.size(1) < system_len:
        raise ValueError("System-only tokens longer than system+user tokens.")
    user_ids_full = system_user_ids[:, system_len:]

    # Text whose token count approximates the document part of the user turn:
    # the document plus separator when a query follows, otherwise everything.
    if query_text:
        doc_prefix_text = f"{doc_text}\n\n" if doc_text else ""
    else:
        doc_prefix_text = messages[-1]["content"]

    user_doc_ids = user_ids_full[:, 0:0]
    user_q_ids = user_ids_full[:, 0:0]
    if user_ids_full.size(1) > 0:
        # The common prefix with an empty user message gives the length of the
        # template's user header.
        empty_user_messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": ""},
        ]
        empty_user_ids = tokens_from_messages(
            tokenizer, empty_user_messages, device, add_generation_prompt=False
        )
        shared_prefix_len = lcp_len(empty_user_ids, system_user_ids)
        user_header_len = max(0, shared_prefix_len - system_len)

        doc_prefix_len = 0
        if doc_prefix_text:
            doc_prefix_len = len(tokenizer(doc_prefix_text, add_special_tokens=False).input_ids)

        # NOTE(review): the source's indentation was lost; the split below is
        # assumed to run whenever the user segment is non-empty - confirm.
        doc_end = min(user_ids_full.size(1), user_header_len + doc_prefix_len)
        user_doc_ids = user_ids_full[:, :doc_end]
        user_q_ids = user_ids_full[:, doc_end:]

    return {
        "prompt_ids": prompt_ids,
        "system_ids": system_ids,
        "system_user_ids": system_user_ids,
        "user_ids_full": user_ids_full,
        "user_doc_ids": user_doc_ids,
        "user_q_ids": user_q_ids,
        "hdr_tail": header_tail,
    }
|
|
|
|
def build_messages_for_llopa(tokenizer,
                             system: str,
                             document: str,
                             question: str,
                             *,
                             task: str = "qa_doc",
                             device: str):
    """Public alias of ``_build_llopa_inputs``; forwards every argument unchanged."""
    forwarded = {
        "system": system,
        "document": document,
        "question": question,
        "task": task,
        "device": device,
    }
    return _build_llopa_inputs(tokenizer, **forwarded)
|
|
|
|
| def _normalize_prompt_messages(messages): |
| out = [] |
| if not isinstance(messages, list): |
| return out |
| for msg in messages: |
| if not isinstance(msg, dict): |
| continue |
| role = str(msg.get("role") or "user").strip().lower() |
| if role not in {"system", "user", "assistant"}: |
| role = "user" |
| content = str(msg.get("content") or "") |
| if role != "assistant": |
| content = content.strip() |
| if not content: |
| continue |
| out.append({"role": role, "content": content}) |
| return out |
|
|
|
|
def _assistant_header_starts_from_messages(
    tokenizer,
    prompt_messages,
    *,
    prompt_add_generation_prompt: bool,
    device,
):
    """Locate where each assistant turn begins in the tokenized prompt.

    The start of an assistant turn is the token length of all messages before
    it, rendered without a generation prompt. When
    ``prompt_add_generation_prompt`` is true and the last message is not an
    assistant message, the length of the whole conversation is added as the
    start of the upcoming assistant turn.

    Returns:
        ``(starts, mask)``: a ``(1, n)`` long tensor and an all-true ``(1, n)``
        bool tensor on ``device``, or ``(None, None)`` when there are no usable
        messages or no assistant starts.
    """
    msgs = _normalize_prompt_messages(prompt_messages)
    if not msgs:
        return None, None

    def _rendered_len(conversation) -> int:
        ids = tokens_from_messages(tokenizer, conversation, device, add_generation_prompt=False)
        return int(ids.size(1))

    starts: list[int] = [
        _rendered_len(msgs[:position])
        for position, message in enumerate(msgs)
        if message["role"] == "assistant"
    ]
    if bool(prompt_add_generation_prompt) and msgs[-1]["role"] != "assistant":
        starts.append(_rendered_len(msgs))

    if not starts:
        return None, None

    starts_tensor = torch.tensor([starts], device=device, dtype=torch.long)
    start_mask = torch.ones((1, len(starts)), device=device, dtype=torch.bool)
    return starts_tensor, start_mask
|
|
|
|
def _assistant_turn_boundaries_from_messages(
    tokenizer,
    prompt_messages,
    *,
    prompt_add_generation_prompt: bool,
    device,
):
    """Compute token start and end offsets of every assistant turn.

    For each assistant message the start is the token length of the messages
    before it and the end is the token length including it, both rendered
    without a generation prompt. When ``prompt_add_generation_prompt`` is true
    and the conversation does not end with an assistant message, one more pair
    is added: the length of the whole conversation without, and with, the
    generation prompt.

    Returns:
        ``(starts, ends, mask)``: two ``(1, n)`` long tensors and an all-true
        ``(1, n)`` bool tensor on ``device``, or ``(None, None, None)`` when
        there is nothing to report.
    """
    msgs = _normalize_prompt_messages(prompt_messages)
    if not msgs:
        return None, None, None

    def _rendered_len(conversation, with_generation_prompt: bool = False) -> int:
        ids = tokens_from_messages(
            tokenizer, conversation, device, add_generation_prompt=with_generation_prompt
        )
        return int(ids.size(1))

    boundaries: list[tuple[int, int]] = []
    for position, message in enumerate(msgs):
        if message["role"] != "assistant":
            continue
        turn_start = _rendered_len(msgs[:position])
        turn_end = _rendered_len(msgs[: position + 1])
        boundaries.append((turn_start, turn_end))

    if bool(prompt_add_generation_prompt) and msgs[-1]["role"] != "assistant":
        pending_start = _rendered_len(msgs)
        pending_end = _rendered_len(msgs, with_generation_prompt=True)
        boundaries.append((pending_start, pending_end))

    if not boundaries:
        return None, None, None

    starts = [start for start, _ in boundaries]
    ends = [end for _, end in boundaries]
    starts_tensor = torch.tensor([starts], device=device, dtype=torch.long)
    ends_tensor = torch.tensor([ends], device=device, dtype=torch.long)
    start_mask = torch.ones((1, len(starts)), device=device, dtype=torch.bool)
    return starts_tensor, ends_tensor, start_mask
|
|
|
|
def _assistant_content_delta_from_messages(tokenizer, prefix_messages, assistant_text: str, su_gen, device):
    """Return the tokens an assistant message adds beyond ``su_gen``.

    The conversation ``prefix_messages`` plus an assistant message containing
    ``assistant_text`` is rendered without a generation prompt, and everything
    after the first ``su_gen.size(1)`` tokens is returned. The result is empty
    when the rendering is not longer than ``su_gen``.
    """
    conversation = [*prefix_messages, {"role": "assistant", "content": assistant_text}]
    rendered_ids = tokens_from_messages(tokenizer, conversation, device, add_generation_prompt=False)
    already_covered = su_gen.size(1)
    if rendered_ids.size(1) <= already_covered:
        return rendered_ids[:, :0]
    return rendered_ids[:, already_covered:]
|
|
|
|
| def _strip_trailing_assistant_stop_tokens(tokenizer, token_ids: torch.Tensor) -> torch.Tensor: |
| if not isinstance(token_ids, torch.Tensor) or token_ids.numel() == 0: |
| return token_ids |
| stop_ids = set() |
| eos_id = getattr(tokenizer, "eos_token_id", None) |
| if eos_id is not None: |
| stop_ids.add(int(eos_id)) |
| with contextlib.suppress(Exception): |
| eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") |
| if eot_id is not None and eot_id != tokenizer.unk_token_id: |
| stop_ids.add(int(eot_id)) |
| if not stop_ids: |
| return token_ids |
| trimmed = token_ids |
| while trimmed.size(1) > 0 and int(trimmed[0, -1].item()) in stop_ids: |
| trimmed = trimmed[:, :-1] |
| return trimmed |
|
|
|
|
| def _resolve_user_replay_layout_from_messages( |
| tokenizer, |
| prefix_messages, |
| *, |
| system_len: int, |
| user_len: int, |
| device, |
| ): |
| system_len = max(int(system_len), 0) |
| user_len = max(int(user_len), 0) |
| if user_len <= 0: |
| return 0, 0, 0 |
|
|
| user_indices = [idx for idx, msg in enumerate(prefix_messages) if msg["role"] == "user"] |
| if not user_indices: |
| return 0, 0, user_len |
|
|
| def _token_len(msgs) -> int: |
| try: |
| return int(tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False).size(1)) |
| except Exception: |
| return 0 |
|
|
| first_user_idx = int(user_indices[0]) |
| latest_user_idx = int(user_indices[-1]) |
|
|
| prefix_before_first_user_len = _token_len(prefix_messages[:first_user_idx]) |
| prefix_before_latest_user_len = _token_len(prefix_messages[:latest_user_idx]) |
| prefix_through_latest_user_len = _token_len(prefix_messages[: latest_user_idx + 1]) |
|
|
| user_prefix_keep_len = max(prefix_before_first_user_len - system_len, 0) |
| user_prefix_keep_len = min(user_prefix_keep_len, user_len) |
|
|
| latest_user_start = max(prefix_before_latest_user_len - system_len, user_prefix_keep_len) |
| latest_user_start = min(latest_user_start, user_len) |
|
|
| latest_user_end = max(prefix_through_latest_user_len - system_len, latest_user_start) |
| latest_user_end = min(latest_user_end, user_len) |
|
|
| latest_user_len = max(latest_user_end - latest_user_start, 0) |
| if latest_user_len <= 0 and user_len > user_prefix_keep_len: |
| latest_user_start = int(user_prefix_keep_len) |
| latest_user_len = int(user_len - user_prefix_keep_len) |
|
|
| return int(user_prefix_keep_len), int(latest_user_start), int(latest_user_len) |
|
|
|
|
def _build_structured_prompt_segments(tokenizer, prompt_messages, *, prompt_add_generation_prompt: bool, device):
    """Tokenize a structured chat prompt and split it into named segments.

    Two prompt shapes are supported:

    * ``prompt_add_generation_prompt=True``: the messages must end with a
      non-assistant turn; the assistant prefill is the generation prompt.
    * ``prompt_add_generation_prompt=False``: the messages must end with an
      assistant prefix preceded by at least one other message; the assistant
      prefill is the assistant header followed by the prefix content with
      trailing stop tokens removed.

    Returns:
        A dict with ``prefix_ids``, ``prompt_ids``, ``system_ids``, ``user_ids``,
        ``assistant_prefill_ids``, the replay layout
        (``replay_user_prefix_keep_len``, ``replay_user_start``,
        ``replay_user_len``) and the assistant turn boundaries
        (``assistant_header_starts``, ``assistant_turn_ends``,
        ``assistant_header_start_mask``).

    Raises:
        ValueError: no usable messages, a prompt shape that does not match the
            flag, or an empty assistant prefill segment.
    """
    msgs = _normalize_prompt_messages(prompt_messages)
    if not msgs:
        raise ValueError("prompt_messages must contain at least one non-empty message.")

    wants_generation_prompt = bool(prompt_add_generation_prompt)
    ends_with_assistant = msgs[-1]["role"] == "assistant"

    if wants_generation_prompt:
        if ends_with_assistant:
            raise ValueError("prompt_add_generation_prompt=True requires prompt_messages to end with a non-assistant role.")
        prefix_messages = msgs
        prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False)
        prompt_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True)
        assistant_prefill_ids = prompt_ids[:, prefix_ids.size(1):]
    else:
        if not ends_with_assistant:
            raise ValueError("prompt_add_generation_prompt=False requires prompt_messages to end with an assistant prefix.")
        if len(msgs) < 2:
            raise ValueError("assistant-prefix prompts require at least one preceding non-assistant message.")
        prefix_messages = msgs[:-1]
        assistant_text = str(msgs[-1]["content"] or "")
        prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False)
        prefix_with_header = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True)
        content_ids = _assistant_content_delta_from_messages(
            tokenizer,
            prefix_messages,
            assistant_text,
            prefix_with_header,
            device,
        )
        content_ids = _strip_trailing_assistant_stop_tokens(tokenizer, content_ids)
        header_ids = prefix_with_header[:, prefix_ids.size(1):]
        assistant_prefill_ids = torch.cat([header_ids, content_ids], dim=1)
        prompt_ids = torch.cat([prefix_ids, assistant_prefill_ids], dim=1)

    if assistant_prefill_ids.size(1) <= 0:
        raise ValueError("Structured direct LLoPA prompt produced an empty assistant prefill segment.")

    # A leading system message is tokenized alone; the rest of the prefix is
    # treated as the user segment.
    has_system = bool(prefix_messages) and prefix_messages[0]["role"] == "system"
    if has_system:
        system_ids = tokens_from_messages(tokenizer, [prefix_messages[0]], device, add_generation_prompt=False)
        user_ids = prefix_ids[:, system_ids.size(1):]
    else:
        system_ids = prefix_ids[:, :0]
        user_ids = prefix_ids

    keep_len, replay_start, replay_len = _resolve_user_replay_layout_from_messages(
        tokenizer,
        prefix_messages,
        system_len=int(system_ids.size(1)),
        user_len=int(user_ids.size(1)),
        device=device,
    )

    header_starts, turn_ends, header_start_mask = _assistant_turn_boundaries_from_messages(
        tokenizer,
        msgs,
        prompt_add_generation_prompt=wants_generation_prompt,
        device=device,
    )

    return {
        "prefix_ids": prefix_ids,
        "prompt_ids": prompt_ids,
        "system_ids": system_ids,
        "user_ids": user_ids,
        "assistant_prefill_ids": assistant_prefill_ids,
        "replay_user_prefix_keep_len": int(keep_len),
        "replay_user_start": int(replay_start),
        "replay_user_len": int(replay_len),
        "assistant_header_starts": header_starts,
        "assistant_turn_ends": turn_ends,
        "assistant_header_start_mask": header_start_mask,
    }
|
|
|
|
def _build_unified_prefill_lower_prompt_bundle(
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    device,
):
    """Package prompt segments into the tensor bundle used by prefill-lower seeding.

    ``structured_prompt_segments`` is used as-is when it is a dict; otherwise
    the segments are built from ``prompt_messages``.

    The assistant header start is resolved in this order: the last valid entry
    of ``assistant_header_starts``; then, overriding it, the length of
    ``prefix_ids`` when that is a single-row 2-D tensor; then the last
    occurrence of the assistant header tokens in the prompt; finally the index
    of the last prompt token.

    Returns:
        A dict holding the segments, ``prompt_ids``, an all-ones
        ``attention_mask``, the resolved ``assistant_header_start`` and the
        split/system/replay lengths as 1-element long tensors, plus the
        assistant turn boundaries. Missing or empty boundary tensors are
        replaced by single-entry defaults. All tensors are placed on the device
        of ``prompt_ids``.
    """
    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )

    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    prefix_ids = segments.get("prefix_ids")
    header_starts = segments.get("assistant_header_starts")
    assistant_turn_ends = segments.get("assistant_turn_ends")
    header_start_mask = segments.get("assistant_header_start_mask")
    keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_start = int(segments.get("replay_user_start", 0) or 0)
    replay_len = int(segments.get("replay_user_len", 0) or 0)

    assistant_header_start: Optional[int] = None

    starts_usable = (
        isinstance(header_starts, torch.Tensor)
        and header_starts.ndim == 2
        and header_starts.size(0) == 1
        and header_starts.numel() > 0
    )
    if starts_usable:
        valid_mask = header_start_mask
        if not isinstance(valid_mask, torch.Tensor) or valid_mask.shape != header_starts.shape:
            valid_mask = header_starts >= 0
        valid_starts = header_starts[0][valid_mask[0]]
        if valid_starts.numel() > 0:
            assistant_header_start = int(valid_starts[-1].item())

    # NOTE(review): the source's indentation was lost; this check is assumed to
    # sit at top level, so it overrides the value derived above - confirm.
    if isinstance(prefix_ids, torch.Tensor) and prefix_ids.ndim == 2 and prefix_ids.size(0) == 1:
        assistant_header_start = int(prefix_ids.size(1))

    if assistant_header_start is None:
        header_ids = _assistant_header_ids(tokenizer, device=device)
        if isinstance(header_ids, torch.Tensor) and header_ids.ndim == 2 and header_ids.size(0) == 1:
            assistant_header_start = _find_last_subsequence_start(prompt_ids, header_ids)
    if assistant_header_start is None:
        assistant_header_start = max(int(prompt_ids.size(1) - 1), 0)

    target_device = prompt_ids.device

    def _long_vector(value) -> torch.Tensor:
        return torch.tensor([int(value)], device=target_device, dtype=torch.long)

    if isinstance(header_starts, torch.Tensor) and header_starts.numel() > 0:
        bundle_header_starts = header_starts.to(device=target_device, dtype=torch.long)
    else:
        bundle_header_starts = torch.tensor(
            [[int(assistant_header_start)]],
            device=target_device,
            dtype=torch.long,
        )

    if isinstance(assistant_turn_ends, torch.Tensor) and assistant_turn_ends.numel() > 0:
        bundle_turn_ends = assistant_turn_ends.to(device=target_device, dtype=torch.long)
    else:
        bundle_turn_ends = torch.tensor(
            [[int(prompt_ids.size(1))]],
            device=target_device,
            dtype=torch.long,
        )

    if isinstance(header_start_mask, torch.Tensor) and header_start_mask.numel() > 0:
        bundle_start_mask = header_start_mask.to(device=target_device, dtype=torch.bool)
    else:
        bundle_start_mask = torch.ones((1, 1), device=target_device, dtype=torch.bool)

    return {
        "segments": segments,
        "prompt_ids": prompt_ids,
        "attention_mask": torch.ones_like(prompt_ids, device=prompt_ids.device),
        "assistant_header_start": int(assistant_header_start),
        "prefill_lower_split_start": _long_vector(assistant_header_start),
        "assistant_header_starts": bundle_header_starts,
        "assistant_turn_ends": bundle_turn_ends,
        "assistant_header_start_mask": bundle_start_mask,
        "prefill_lower_system_len": _long_vector(system_ids.size(1)),
        "prefill_lower_replay_user_prefix_keep_len": _long_vector(keep_len),
        "prefill_lower_replay_user_start": _long_vector(replay_start),
        "prefill_lower_replay_user_len": _long_vector(replay_len),
    }
|
|
|
|
| def _prompt_bundle_has_past_assistant_history(prompt_bundle) -> bool: |
| if not isinstance(prompt_bundle, dict): |
| return False |
|
|
| header_starts = prompt_bundle.get("assistant_header_starts") |
| turn_ends = prompt_bundle.get("assistant_turn_ends") |
| if not isinstance(header_starts, torch.Tensor) or not isinstance(turn_ends, torch.Tensor): |
| return False |
| if header_starts.numel() == 0 or turn_ends.numel() == 0: |
| return False |
| if header_starts.dim() == 1: |
| header_starts = header_starts.view(1, -1) |
| if turn_ends.dim() == 1: |
| turn_ends = turn_ends.view(1, -1) |
| if header_starts.dim() != 2 or turn_ends.dim() != 2: |
| return False |
|
|
| header_mask = prompt_bundle.get("assistant_header_start_mask") |
| if isinstance(header_mask, torch.Tensor) and header_mask.numel() > 0: |
| if header_mask.dim() == 1: |
| header_mask = header_mask.view(1, -1) |
| if header_mask.dim() != 2 or header_mask.shape != header_starts.shape: |
| header_mask = header_starts >= 0 |
| else: |
| header_mask = header_mask.to(device=header_starts.device, dtype=torch.bool) |
| else: |
| header_mask = header_starts >= 0 |
|
|
| split_starts = prompt_bundle.get("effective_prefill_lower_split_start") |
| if not isinstance(split_starts, torch.Tensor) or split_starts.numel() == 0: |
| split_starts = prompt_bundle.get("prefill_lower_split_start") |
| if isinstance(split_starts, torch.Tensor) and split_starts.numel() > 0: |
| split_starts = split_starts.flatten().to(device=header_starts.device, dtype=torch.long) |
| else: |
| split_starts = torch.tensor( |
| [int(prompt_bundle.get("assistant_header_start", 0) or 0)], |
| device=header_starts.device, |
| dtype=torch.long, |
| ) |
|
|
| rows = min(int(header_starts.size(0)), int(turn_ends.size(0))) |
| cols = min(int(header_starts.size(1)), int(turn_ends.size(1))) |
| if rows <= 0 or cols <= 0: |
| return False |
| for row in range(rows): |
| split_idx = min(row, int(split_starts.numel()) - 1) |
| split_start = int(split_starts[split_idx].item()) |
| for col in range(cols): |
| if not bool(header_mask[row, col].item()): |
| continue |
| turn_start = int(header_starts[row, col].item()) |
| turn_end = int(turn_ends[row, col].item()) |
| if turn_end <= turn_start or turn_start >= split_start: |
| continue |
| if min(turn_end, split_start) > turn_start: |
| return True |
| return False |
|
|
|
|
def _direct_prefill_lower_cache_and_logits(
    model,
    *,
    prompt_bundle,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
):
    """Produce a prefill seed for a structured prompt bundle via the model's seed hooks.

    Side effect: when the bundle is a dict with tensor ``prompt_ids``, the keys
    ``effective_prompt_ids``, ``effective_prompt_attention_mask`` and
    ``effective_prefill_lower_split_start`` are written into it. For a
    single-row mask the prompt is truncated to the number of attended tokens.

    Strategies, in order:

    1. If ``see_past_assistant`` is false, ``_matched_inband_prefill_cache_and_logits``.
    2. The function returned by ``_get_llopa_full_prompt_seed(model)``, tried
       when ``seed_mode`` normalises to ``"auto"`` or when past assistant turns
       must be visible. Exceptions are captured rather than raised here.
    3. ``model.llopa_reference_prefill_seed`` using the tokenized segments; any
       exception there results in ``None``.

    ``last_layer_module`` replaces ``replay_module`` when the latter normalises
    to ``"none"``.

    Returns:
        Whatever the first successful strategy returns, or ``None`` when the
        bundle has no ``segments`` dict or no strategy produced a seed.

    Raises:
        RuntimeError: ``see_past_assistant`` is set, the bundle contains earlier
            assistant turns, and the full-prompt seed is unavailable, failed, or
            returned nothing.
    """
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    seed_mode = _normalize_structured_llopa_seed_mode(seed_mode)

    bundle_is_dict = isinstance(prompt_bundle, dict)
    raw_prompt_ids = prompt_bundle.get("prompt_ids") if bundle_is_dict else None
    if isinstance(raw_prompt_ids, torch.Tensor):
        prompt_device = raw_prompt_ids.device

        mask = prompt_bundle.get("attention_mask")
        if isinstance(mask, torch.Tensor) and mask.shape == raw_prompt_ids.shape:
            mask = mask.to(device=prompt_device, dtype=torch.long)
        else:
            mask = torch.ones_like(raw_prompt_ids, device=prompt_device, dtype=torch.long)

        bundle_split = prompt_bundle.get("prefill_lower_split_start")
        if isinstance(bundle_split, torch.Tensor):
            bundle_split = bundle_split.to(device=prompt_device, dtype=torch.long)
        else:
            bundle_split = torch.tensor(
                [int(prompt_bundle.get("assistant_header_start", max(int(raw_prompt_ids.size(1) - 1), 0)))],
                device=prompt_device,
                dtype=torch.long,
            )

        # Only a single-row mask is used to measure the attended length.
        if mask.ndim == 2 and mask.size(0) == 1:
            valid_len = int(mask[0].sum().item())
        else:
            valid_len = int(raw_prompt_ids.size(1))

        trimmed_ids, trimmed_mask = raw_prompt_ids, mask
        if valid_len > 0:
            trimmed_ids = raw_prompt_ids[:, :valid_len]
            trimmed_mask = mask[:, :valid_len]

        prompt_bundle["effective_prompt_ids"] = trimmed_ids
        prompt_bundle["effective_prompt_attention_mask"] = trimmed_mask
        prompt_bundle["effective_prefill_lower_split_start"] = bundle_split

    segments = prompt_bundle.get("segments") if bundle_is_dict else None
    if not isinstance(segments, dict):
        return None

    # Strategy 1: in-band matched seed (skipped when past assistant turns
    # must stay visible).
    if not bool(see_past_assistant):
        inband_seed = _matched_inband_prefill_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=prefill_attn,
            system_prefill=system_prefill,
            no_upper_attn=no_upper_attn,
        )
        if inband_seed is not None:
            return inband_seed

    # Strategy 2: full-prompt seed.
    full_prompt_seed_fn = _get_llopa_full_prompt_seed(model)
    effective_prompt_ids = prompt_bundle.get("effective_prompt_ids")
    effective_mask = prompt_bundle.get("effective_prompt_attention_mask")
    split_starts = prompt_bundle.get("effective_prefill_lower_split_start")
    system_lens = prompt_bundle.get("prefill_lower_system_len")
    replay_user_prefix_keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len")
    replay_user_starts = prompt_bundle.get("prefill_lower_replay_user_start")
    replay_user_lens = prompt_bundle.get("prefill_lower_replay_user_len")
    assistant_header_starts = prompt_bundle.get("assistant_header_starts")
    assistant_turn_ends = prompt_bundle.get("assistant_turn_ends")
    assistant_header_start_mask = prompt_bundle.get("assistant_header_start_mask")

    needs_past_assistant_seed = bool(see_past_assistant) and _prompt_bundle_has_past_assistant_history(prompt_bundle)
    try_full_prompt_seed = seed_mode == "auto" or bool(needs_past_assistant_seed)
    full_prompt_seed_error = None

    if try_full_prompt_seed and callable(full_prompt_seed_fn) and isinstance(effective_prompt_ids, torch.Tensor):
        if not isinstance(effective_mask, torch.Tensor):
            effective_mask = torch.ones_like(
                effective_prompt_ids,
                device=effective_prompt_ids.device,
                dtype=torch.long,
            )
        if not isinstance(split_starts, torch.Tensor):
            split_starts = prompt_bundle.get("prefill_lower_split_start")
        if not isinstance(system_lens, torch.Tensor):
            system_lens = torch.zeros(
                (effective_prompt_ids.size(0),),
                device=effective_prompt_ids.device,
                dtype=torch.long,
            )
        try:
            full_prompt_seed = full_prompt_seed_fn(
                input_ids=effective_prompt_ids,
                attention_mask=effective_mask,
                use_cache=True,
                logits_to_keep=1,
                lower_k=int(lower_k),
                prefill_attn=str(prefill_attn),
                system_prefill=str(system_prefill),
                no_upper_attn=bool(no_upper_attn),
                prefill_lower_split_start=split_starts,
                prefill_lower_system_len=system_lens,
                prefill_lower_replay_user_prefix_keep_len=replay_user_prefix_keep_lens,
                prefill_lower_replay_user_start=replay_user_starts,
                prefill_lower_replay_user_len=replay_user_lens,
                assistant_header_starts=assistant_header_starts,
                assistant_turn_ends=assistant_turn_ends,
                assistant_header_start_mask=assistant_header_start_mask,
                prefill_lower_see_past_assistant=bool(see_past_assistant),
                replay_module=str(replay_module),
                replay_per_layers=int(replay_per_layers),
            )
            if full_prompt_seed is not None:
                return full_prompt_seed
        except Exception as exc:
            # Remembered so the past-assistant path can report the cause.
            full_prompt_seed_error = exc

    # Past assistant turns can only be honoured by the full-prompt seed, so
    # reaching this point with that requirement is an error.
    if bool(needs_past_assistant_seed):
        if not callable(full_prompt_seed_fn):
            raise RuntimeError(
                "LLOPA_SEE_PAST_ASSISTANT=1 requires llopa_full_prompt_prefill_seed "
                "for prefill_header prompts with previous assistant turns."
            )
        if full_prompt_seed_error is not None:
            raise RuntimeError(
                "LLOPA_SEE_PAST_ASSISTANT=1 failed in llopa_full_prompt_prefill_seed "
                "for a prefill_header prompt with previous assistant turns."
            ) from full_prompt_seed_error
        raise RuntimeError(
            "LLOPA_SEE_PAST_ASSISTANT=1 could not build a prefill_header seed "
            "that includes previous assistant turns."
        )

    # Strategy 3: reference seed from the tokenized segments (best effort).
    reference_seed_fn = getattr(model, "llopa_reference_prefill_seed", None)
    if not callable(reference_seed_fn):
        return None
    try:
        return reference_seed_fn(
            system_ids=segments["system_ids"],
            user_ids=segments["user_ids"],
            assistant_ids=segments["assistant_prefill_ids"],
            lower_k=int(lower_k),
            prefill_attn=str(prefill_attn),
            system_prefill=str(system_prefill),
            no_upper_attn=bool(no_upper_attn),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            replay_user_prefix_keep_len=int(segments.get("replay_user_prefix_keep_len", 0) or 0),
            replay_user_start=int(segments.get("replay_user_start", 0) or 0),
            replay_user_len=int(segments.get("replay_user_len", 0) or 0),
        )
    except Exception:
        return None
|
|
|
|
def _optimized_prefill_lower_cache_and_logits(
    model,
    *,
    prompt_bundle,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
):
    """Return the first prefill seed produced by the strategies ``seed_mode`` enables.

    Strategies are attempted in this order and the first non-``None`` result wins:

    1. matched in-band seed      -- seed modes "matched" and "auto"
    2. full-prompt seed hook     -- seed modes "tri" and "auto"
    3. direct prefill-lower seed -- seed modes "stable", "auto" and "matched"

    ``last_layer_module`` is only used as the replay module when ``replay_module``
    itself normalizes to "none". Any exception raised by the full-prompt hook is
    swallowed so the next strategy can run. Returns ``None`` when no enabled
    strategy yields a seed.
    """
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    seed_mode = _normalize_optimized_llopa_seed_mode(seed_mode)

    # Strategy 1: matched in-band seed.
    if seed_mode in ("matched", "auto"):
        seed = _matched_inband_prefill_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=prefill_attn,
            system_prefill=system_prefill,
            no_upper_attn=no_upper_attn,
            replay_module=replay_module,
            replay_per_layers=replay_per_layers,
        )
        if seed is not None:
            return seed

    # Strategy 2: full-prompt seed hook exposed by the model (if any).
    full_seed_fn = _get_llopa_full_prompt_seed(model)
    tri_enabled = (
        seed_mode in ("tri", "auto")
        and callable(full_seed_fn)
        and isinstance(prompt_bundle, dict)
    )
    prompt_ids = prompt_bundle.get("effective_prompt_ids") if tri_enabled else None
    if isinstance(prompt_ids, torch.Tensor):

        def _per_row_zeros():
            # One long zero per batch row, on the prompt's device.
            return torch.zeros((prompt_ids.size(0),), device=prompt_ids.device, dtype=torch.long)

        prompt_mask = prompt_bundle.get("effective_prompt_attention_mask")
        if not isinstance(prompt_mask, torch.Tensor):
            prompt_mask = torch.ones_like(prompt_ids, device=prompt_ids.device, dtype=torch.long)

        # Split start: effective value, then the plain value, then the scalar
        # "assistant_header_start" entry (defaulting to the last prompt index).
        split_starts = prompt_bundle.get("effective_prefill_lower_split_start")
        if not isinstance(split_starts, torch.Tensor):
            split_starts = prompt_bundle.get("prefill_lower_split_start")
        if not isinstance(split_starts, torch.Tensor):
            last_index = max(int(prompt_ids.size(1) - 1), 0)
            split_starts = torch.tensor(
                [int(prompt_bundle.get("assistant_header_start", last_index))],
                device=prompt_ids.device,
                dtype=torch.long,
            )

        system_lens = prompt_bundle.get("prefill_lower_system_len")
        if not isinstance(system_lens, torch.Tensor):
            system_lens = _per_row_zeros()

        keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len")
        if not isinstance(keep_lens, torch.Tensor):
            keep_lens = _per_row_zeros()

        replay_starts = prompt_bundle.get("prefill_lower_replay_user_start")
        if not isinstance(replay_starts, torch.Tensor):
            replay_starts = keep_lens.clone()

        replay_lens = prompt_bundle.get("prefill_lower_replay_user_len")
        if not isinstance(replay_lens, torch.Tensor):
            replay_lens = _per_row_zeros()

        header_starts = prompt_bundle.get("assistant_header_starts")
        header_start_mask = prompt_bundle.get("assistant_header_start_mask")

        # Best effort: a failing hook must not prevent the fallback below.
        with contextlib.suppress(Exception):
            seed = full_seed_fn(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                use_cache=True,
                logits_to_keep=1,
                lower_k=int(lower_k),
                prefill_attn=str(prefill_attn),
                system_prefill=str(system_prefill),
                no_upper_attn=bool(no_upper_attn),
                replay_module=str(replay_module),
                replay_per_layers=int(replay_per_layers),
                prefill_lower_split_start=split_starts,
                prefill_lower_system_len=system_lens,
                prefill_lower_replay_user_prefix_keep_len=keep_lens,
                prefill_lower_replay_user_start=replay_starts,
                prefill_lower_replay_user_len=replay_lens,
                assistant_header_starts=header_starts,
                assistant_turn_ends=prompt_bundle.get("assistant_turn_ends"),
                assistant_header_start_mask=header_start_mask,
                prefill_lower_see_past_assistant=bool(see_past_assistant),
            )
            if seed is not None:
                return seed

    # Strategy 3: direct prefill-lower seed.
    if seed_mode in ("stable", "auto", "matched"):
        return _direct_prefill_lower_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=prefill_attn,
            system_prefill=system_prefill,
            no_upper_attn=no_upper_attn,
            see_past_assistant=bool(see_past_assistant),
            replay_module=replay_module,
            replay_per_layers=replay_per_layers,
        )

    return None
|
|
|
|
| def _normalize_structured_llopa_seed_mode(seed_mode: Optional[str]) -> str: |
| normalized = str(seed_mode or "auto").strip().lower() |
| aliases = { |
| "": "auto", |
| "default": "auto", |
| "tri": "auto", |
| "tri_auto": "auto", |
| "reference": "prefill_header", |
| "reference_only": "prefill_header", |
| "prefill-header": "prefill_header", |
| "prefill_header_seed": "prefill_header", |
| "prefill_header_only": "prefill_header", |
| } |
| normalized = aliases.get(normalized, normalized) |
| if normalized not in {"auto", "prefill_header"}: |
| normalized = "auto" |
| return normalized |
|
|
|
|
def _llopa_modeling_module(model) -> Optional[Any]:
    """Return the Python module that defines the LLoPA core's class.

    Yields ``None`` when the model exposes no LLoPA core or when the module
    lookup raises.
    """
    core = _get_llopa_core(model)
    if core is None:
        return None
    try:
        return inspect.getmodule(core.__class__)
    except Exception:
        return None
|
|
|
|
| def _matched_inband_prefill_cache_and_logits( |
| model, |
| *, |
| prompt_bundle, |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| no_upper_attn: bool, |
| ): |
| if bool(no_upper_attn): |
| return None |
|
|
| llopa_core = _get_llopa_core(model) |
| output_head = _get_output_head(model) |
| llopa_mod = _llopa_modeling_module(model) |
| if llopa_core is None or output_head is None or llopa_mod is None: |
| return None |
|
|
| fusion_mode = str( |
| getattr(getattr(llopa_core, "config", None), "capsule_fusion_mode", "upper_only") or "upper_only" |
| ).strip().lower() |
| if fusion_mode != "inband": |
| return None |
|
|
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| if attn != "causal": |
| return None |
|
|
| required_names = ( |
| "_safe_dynamic_cache", |
| "_tri_arange", |
| "_llopa_position_ids_from_mask", |
| "_resolve_attn_impl", |
| "_can_use_implicit_causal_mask", |
| "_llopa_mask_is_all_ones", |
| "_build_tri_mask_local", |
| "_tri_insert_suffix_specials_inband", |
| "_tri_effective_suffix_special_token_ids", |
| "_tri_build_prefill_lower_upper_index_batch", |
| "_tri_pack_indexed_tensor", |
| "create_causal_mask", |
| ) |
| if any(not hasattr(llopa_mod, name) for name in required_names): |
| return None |
|
|
| prompt_ids = prompt_bundle["prompt_ids"] |
| attention_mask = prompt_bundle["attention_mask"].to(device=prompt_ids.device, dtype=torch.long) |
| system_len = int(prompt_bundle["prefill_lower_system_len"][0].item()) if prompt_bundle["prefill_lower_system_len"].numel() > 0 else 0 |
| full_ids = prompt_ids |
| full_attention_mask = attention_mask |
| split_starts = prompt_bundle.get("prefill_lower_split_start") |
| if isinstance(split_starts, torch.Tensor): |
| split_starts = split_starts.to(device=prompt_ids.device, dtype=torch.long) |
| else: |
| split_starts = torch.tensor( |
| [int(prompt_bundle.get("assistant_header_start", max(int(prompt_ids.size(1) - 1), 0)))], |
| device=prompt_ids.device, |
| dtype=torch.long, |
| ) |
|
|
| token_ids = list(getattr(llopa_mod, "_tri_effective_suffix_special_token_ids")(model) or []) |
| if token_ids: |
| header_starts = prompt_bundle.get("assistant_header_starts") |
| header_start_mask = prompt_bundle.get("assistant_header_start_mask") |
| ( |
| full_ids, |
| full_attention_mask, |
| _, |
| remapped_split_starts, |
| _, |
| _, |
| ) = getattr(llopa_mod, "_tri_insert_suffix_specials_inband")( |
| token_ids=token_ids, |
| input_ids=full_ids, |
| attention_mask=full_attention_mask, |
| labels=None, |
| split_starts=split_starts, |
| assistant_header_starts=header_starts, |
| assistant_header_start_mask=header_start_mask, |
| ) |
| if isinstance(remapped_split_starts, torch.Tensor) and remapped_split_starts.numel() > 0: |
| split_starts = remapped_split_starts.to(device=prompt_ids.device, dtype=torch.long) |
|
|
| valid_len = int(full_attention_mask[0].sum().item()) |
| if valid_len <= 0: |
| return None |
| full_ids = full_ids[:, :valid_len] |
| full_attention_mask = full_attention_mask[:, :valid_len] |
| split_start = max(0, min(int(split_starts[0].item()) if split_starts.numel() > 0 else int(valid_len - 1), valid_len)) |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| return None |
| n_layers = len(getattr(llopa_core, "layers", [])) |
| if lower_k <= 0 or n_layers <= 0: |
| return None |
| lower_k = max(0, min(lower_k, n_layers)) |
|
|
| device = full_ids.device |
| pkv = getattr(llopa_mod, "_safe_dynamic_cache")(llopa_core.config) |
| inputs_embeds = llopa_core.embed_tokens(full_ids) |
| cache_position = getattr(llopa_mod, "_tri_arange")(0, inputs_embeds.shape[1], device) |
| position_ids = getattr(llopa_mod, "_llopa_position_ids_from_mask")(full_attention_mask) |
| attn_impl = getattr(llopa_mod, "_resolve_attn_impl")(llopa_core.config) |
| if attn_impl == "flash_attention_2": |
| lower_mask = None if getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask) else full_attention_mask |
| elif getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config) and getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask): |
| lower_mask = None |
| else: |
| lower_mask = getattr(llopa_mod, "create_causal_mask")( |
| config=llopa_core.config, |
| input_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| cache_position=cache_position, |
| past_key_values=None, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = llopa_core.rotary_emb(hidden_states, position_ids) |
| for li in range(lower_k): |
| layer = llopa_core.layers[li] |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=lower_mask, |
| position_ids=position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| ) |
|
|
| split_starts = torch.tensor([split_start], device=hidden_states.device, dtype=torch.long) |
| valid_lens = torch.tensor([valid_len], device=hidden_states.device, dtype=torch.long) |
| system_lens = torch.tensor([system_len], device=hidden_states.device, dtype=torch.long) |
| upper_gather_idx, upper_valid_mask, upper_lens = getattr(llopa_mod, "_tri_build_prefill_lower_upper_index_batch")( |
| split_starts=split_starts, |
| valid_lens=valid_lens, |
| system_lens=system_lens, |
| system_prefill=str(system_prefill), |
| device=hidden_states.device, |
| ) |
| upper_hidden, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")( |
| hidden_states, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0.0, |
| ) |
| upper_position_ids_src = position_ids.to(device=hidden_states.device, dtype=torch.long) |
| upper_position_ids, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")( |
| upper_position_ids_src, |
| gather_idx=upper_gather_idx, |
| valid_mask=upper_valid_mask, |
| pad_value=0, |
| ) |
|
|
| upper_len = int(upper_lens[0].item()) if upper_lens.numel() > 0 else 0 |
| if upper_len <= 0: |
| return None |
| upper_hidden = upper_hidden[:, :upper_len, :] |
| upper_position_ids = upper_position_ids[:, :upper_len] |
| upper_cache_position = upper_position_ids[0] |
|
|
| if lower_k < n_layers: |
| if attn_impl == "flash_attention_2" or getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config): |
| upper_mask = None |
| else: |
| upper_mask = getattr(llopa_mod, "_build_tri_mask_local")( |
| 1, |
| upper_len, |
| 0, |
| upper_hidden.device, |
| upper_hidden.dtype, |
| ) |
| upper_pos_emb = llopa_core.rotary_emb(upper_hidden, upper_position_ids) |
| for li in range(lower_k, n_layers): |
| layer = llopa_core.layers[li] |
| upper_hidden = layer( |
| upper_hidden, |
| attention_mask=upper_mask, |
| position_ids=upper_position_ids, |
| past_key_values=pkv, |
| use_cache=True, |
| cache_position=upper_cache_position, |
| position_embeddings=upper_pos_emb, |
| ) |
|
|
| upper_hidden = llopa_core.norm(upper_hidden) |
| initial_logits = output_head(upper_hidden[:, -1:, :])[:, -1, :].to(torch.float32) |
|
|
| sys_mode = str(system_prefill or "full").strip().lower() |
| if sys_mode == "full": |
| visible_prefix_len = system_len |
| elif sys_mode == "no_system": |
| visible_prefix_len = min(system_len, 1) |
| else: |
| visible_prefix_len = 0 |
| return pkv, int(visible_prefix_len), max(int(split_start) - int(visible_prefix_len), 0), initial_logits |
|
|
|
|
| def _coerce_llopa_inputs(tokenizer, |
| system: str, |
| document: str, |
| question: str, |
| *, |
| task: str, |
| device: str, |
| input_ids): |
| if not isinstance(input_ids, dict): |
| return _build_llopa_inputs( |
| tokenizer, |
| system=system, |
| document=document, |
| question=question, |
| task=task, |
| device=device, |
| ) |
|
|
| def _get_first(mapping, *keys): |
| for k in keys: |
| if k in mapping and mapping[k] is not None: |
| return mapping[k] |
| return None |
|
|
| ids_sys = _get_first(input_ids, "system_ids", "ids_sys") |
| ids_sys_user = _get_first(input_ids, "system_user_ids", "ids_sys_user") |
| user_doc_ids = _get_first(input_ids, "user_doc_ids") |
| user_q_ids = _get_first(input_ids, "user_q_ids") |
| hdr_tail = _get_first(input_ids, "hdr_tail") |
| prompt_ids = _get_first(input_ids, "prompt_ids", "ids_hdr") |
|
|
| if (ids_sys is None or ids_sys_user is None or user_doc_ids is None or |
| user_q_ids is None or hdr_tail is None): |
| built = _build_llopa_inputs( |
| tokenizer, |
| system=system, |
| document=document, |
| question=question, |
| task=task, |
| device=device, |
| ) |
| if ids_sys is None: |
| ids_sys = built["system_ids"] |
| if ids_sys_user is None: |
| ids_sys_user = built["system_user_ids"] |
| if user_doc_ids is None: |
| user_doc_ids = built["user_doc_ids"] |
| if user_q_ids is None: |
| user_q_ids = built["user_q_ids"] |
| if hdr_tail is None: |
| hdr_tail = built["hdr_tail"] |
| if prompt_ids is None: |
| prompt_ids = built.get("prompt_ids") |
|
|
| if prompt_ids is None and ids_sys_user is not None and hdr_tail is not None: |
| try: |
| prompt_ids = torch.cat([ids_sys_user, hdr_tail], dim=1) |
| except Exception: |
| prompt_ids = None |
|
|
| return { |
| "prompt_ids": prompt_ids, |
| "system_ids": ids_sys, |
| "system_user_ids": ids_sys_user, |
| "user_doc_ids": user_doc_ids, |
| "user_q_ids": user_q_ids, |
| "hdr_tail": hdr_tail, |
| } |
|
|
def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int:
    """Return the length of the longest common prefix of row 0 of ``a`` and ``b``.

    Both tensors are indexed as ``[0, :n]``; only the first row is compared.
    """
    span = min(a.size(1), b.size(1))
    mismatches = torch.nonzero(a[0, :span] != b[0, :span], as_tuple=False)
    if mismatches.numel() == 0:
        # No differing position inside the overlap: the whole overlap matches.
        return span
    return int(mismatches[0, 0])
|
|
|
|
def _assistant_header_ids_from_chat_template(tokenizer, device):
    """Infer assistant header ids by diffing chat-template renders.

    Two probe conversations are rendered with and without the generation
    prompt. For each probe the textual suffix is tokenized first; if that
    yields nothing, the token sequences themselves are diffed. The first
    non-empty result is returned, ``None`` if every probe fails.
    """
    probes = (
        [{"role": "system", "content": "system"}, {"role": "user", "content": "user"}],
        [{"role": "user", "content": "user"}],
    )

    for probe in probes:
        try:
            text_plain = apply_chat_template(tokenizer, probe, add_generation_prompt=False)
            text_prompted = apply_chat_template(tokenizer, probe, add_generation_prompt=True)
        except Exception:
            continue

        both_text = isinstance(text_plain, str) and isinstance(text_prompted, str)
        if not both_text or not text_prompted or text_prompted == text_plain:
            continue

        # Attempt 1: tokenize the text the generation prompt appended.
        if text_prompted.startswith(text_plain):
            tail_text = text_prompted[len(text_plain):]
            if tail_text:
                with contextlib.suppress(Exception):
                    tail_ids = tokenizer(
                        tail_text,
                        add_special_tokens=False,
                        return_tensors="pt",
                    ).input_ids.to(device)
                    if tail_ids.numel() > 0:
                        return tail_ids

        # Attempt 2: diff the token ids of both renders.
        try:
            ids_plain = tokens_from_messages(tokenizer, probe, device, add_generation_prompt=False)
            ids_prompted = tokens_from_messages(tokenizer, probe, device, add_generation_prompt=True)
        except Exception:
            continue

        if ids_prompted.numel() == 0 or ids_prompted.size(1) <= ids_plain.size(1):
            continue

        shared = lcp_len(ids_plain, ids_prompted)
        if shared < ids_prompted.size(1):
            extra_ids = ids_prompted[:, shared:]
            if extra_ids.numel() > 0:
                return extra_ids

    return None
|
|
def _assistant_header_ids(tokenizer, device):
    """Best-effort header ids appended by add_generation_prompt for common templates.

    Mistral templates map to the single ``MISTRAL_ASSIST_START`` token (or
    ``None`` if it is unknown to the tokenizer); templates containing
    ``<|start_header_id|>`` use the literal assistant header; everything else is
    inferred from the chat template.
    """
    if _is_mistral_template(tokenizer):
        with contextlib.suppress(Exception):
            assist_id = tokenizer.convert_tokens_to_ids(MISTRAL_ASSIST_START)
            if assist_id is not None and assist_id != tokenizer.unk_token_id:
                return torch.tensor([[int(assist_id)]], device=device, dtype=torch.long)
        return None

    template = getattr(tokenizer, "chat_template", "") or ""
    if "<|start_header_id|>" not in template:
        return _assistant_header_ids_from_chat_template(tokenizer, device)

    header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    try:
        encoded = tokenizer(header, add_special_tokens=False, return_tensors="pt")
        return encoded.input_ids.to(device)
    except Exception:
        return None
|
|
|
|
def split_system_user_ids(tokenizer, system_text: str, user_text: str, device):
    """Tokenize system-only and system+user chats and slice out the user part.

    Returns ``(ids_sys, user_ids, ids_sys_user)`` where ``user_ids`` is the
    columns of ``ids_sys_user`` after the first ``ids_sys.size(1)`` ones.

    Raises:
        ValueError: If the system-only render is longer than the combined one.
    """
    system_only = [{"role": "system", "content": system_text}]
    system_and_user = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    ids_sys = tokens_from_messages(tokenizer, system_only, device, add_generation_prompt=False)
    ids_sys_user = tokens_from_messages(tokenizer, system_and_user, device, add_generation_prompt=False)

    system_width = ids_sys.size(1)
    if system_width > ids_sys_user.size(1):
        raise ValueError("System-only tokens longer than system+user tokens.")
    return ids_sys, ids_sys_user[:, system_width:], ids_sys_user
|
|
| |
| |
| |
def pkv_len(pkv) -> int:
    """Return the number of layers held by a cache object or legacy tuple.

    ``pkv.layers`` is preferred, then ``pkv.key_cache``, then ``len(pkv)``.
    """
    for container_attr in ("layers", "key_cache"):
        if hasattr(pkv, container_attr):
            return len(getattr(pkv, container_attr))
    return len(pkv)
|
|
def pkv_get(pkv, idx: int):
    """Return the ``(keys, values)`` pair stored for layer ``idx``.

    Supports objects exposing ``layers`` (entries with ``keys``/``values``),
    objects exposing ``key_cache``/``value_cache``, and plain indexables.
    """
    if hasattr(pkv, "layers"):
        entry = pkv.layers[idx]
        return entry.keys, entry.values
    if not hasattr(pkv, "key_cache"):
        return pkv[idx]
    return pkv.key_cache[idx], pkv.value_cache[idx]
|
|
def dc_from_subset(pkv_src, idxs: List[int]) -> DynamicCache:
    """Copy the layers listed in ``idxs`` from ``pkv_src`` into a new DynamicCache.

    Each layer keeps its original index in the new cache.
    """
    subset_cache = DynamicCache()
    for layer_idx in idxs:
        key_states, value_states = pkv_get(pkv_src, layer_idx)
        subset_cache.update(key_states, value_states, layer_idx)
    return subset_cache
|
|
|
|
def _safe_dynamic_cache(config=None) -> DynamicCache:
    """Create a ``DynamicCache``, passing ``config`` when the constructor allows it.

    Falls back to a config-less ``DynamicCache()`` when the constructor raises
    a ``TypeError`` that either mentions ``max_cache_len`` or reports ``config``
    as an unexpected keyword argument (``DynamicCache.__init__`` without a
    ``config`` parameter). Any other ``TypeError`` is re-raised.
    """
    try:
        return DynamicCache(config=config)
    except TypeError as exc:
        message = str(exc)
        rejects_config_kwarg = "unexpected keyword argument" in message and "config" in message
        if "max_cache_len" in message or rejects_config_kwarg:
            return DynamicCache()
        raise
|
|
| def _get_inner_model(m): |
| """Return the decoder backbone that owns `.layers` (robust across wrappers).""" |
| |
| if hasattr(m, "module"): |
| m = m.module |
| |
| try: |
| from peft import PeftModel |
| if isinstance(m, PeftModel): |
| try: |
| m = m.get_base_model() |
| except Exception: |
| m = getattr(m, "base_model", m) |
| except Exception: |
| pass |
|
|
| for attr in ("model", "transformer", "backbone", "base_model", "language_model"): |
| if hasattr(m, attr): |
| cand = getattr(m, attr) |
| if hasattr(cand, "layers") and isinstance(getattr(cand, "layers", None), nn.ModuleList): |
| return cand |
| if hasattr(cand, "decoder") and hasattr(cand.decoder, "layers") and isinstance(cand.decoder.layers, nn.ModuleList): |
| return cand.decoder |
| if hasattr(m, "layers") and isinstance(getattr(m, "layers", None), nn.ModuleList): |
| return m |
| for child in m.modules(): |
| if child is m: |
| continue |
| if hasattr(child, "layers") and isinstance(getattr(child, "layers", None), nn.ModuleList): |
| return child |
| raise AttributeError("Could not locate inner base model with a .layers attribute") |
|
|
def _get_llopa_core(model):
    """Return the decoder object that owns the LLoPA prefill/cache hooks.

    ``llopa_prefill_cache`` is looked for first, then the ``tri_build_caches``
    name; for each, the inner model is checked before its ``.model`` child.
    Returns ``None`` when neither hook is found.
    """
    inner = _get_inner_model(model)
    for hook_name in ("llopa_prefill_cache", "tri_build_caches"):
        if hasattr(inner, hook_name):
            return inner
        if hasattr(inner, "model") and hasattr(inner.model, hook_name):
            return inner.model
    return None
|
|
| def _get_llopa_decode_step(model): |
| """Return the cached LLoPA decode-step callable if present.""" |
| for name in ("llopa_decode_step_logits", "tri_step_logits"): |
| if hasattr(model, name): |
| return getattr(model, name) |
| try: |
| from peft import PeftModel |
| if isinstance(model, PeftModel): |
| try: |
| base = model.get_base_model() |
| except Exception: |
| base = getattr(model, "base_model", None) |
| if base is not None: |
| for name in ("llopa_decode_step_logits", "tri_step_logits"): |
| if hasattr(base, name): |
| return getattr(base, name) |
| except Exception: |
| pass |
| return None |
|
|
|
|
| def _get_llopa_full_prompt_seed(model): |
| """Return the full-prompt seed callable used for past-assistant contexts.""" |
| for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"): |
| if hasattr(model, name): |
| return getattr(model, name) |
| try: |
| from peft import PeftModel |
| if isinstance(model, PeftModel): |
| try: |
| base = model.get_base_model() |
| except Exception: |
| base = getattr(model, "base_model", None) |
| if base is not None: |
| for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"): |
| if hasattr(base, name): |
| return getattr(base, name) |
| except Exception: |
| pass |
| return None |
|
|
|
|
def _has_active_llopa_runtime(model) -> bool:
    """Return True when LLoPA hooks are present on the loaded model.

    Both an LLoPA core and a decode-step callable must be resolvable.
    """
    if _get_llopa_core(model) is None:
        return False
    return _get_llopa_decode_step(model) is not None
|
|
|
|
| def _round_up_to_multiple(value: int, multiple: int) -> int: |
| value = int(value) |
| multiple = int(multiple) |
| if multiple <= 0: |
| return value |
| return ((value + multiple - 1) // multiple) * multiple |
|
|
|
|
| _OPTIMIZED_LLOPA_VARIANT_PRESETS = { |
| "baseline": { |
| "seed_mode": "auto", |
| "upper_prepare_mode": "exact", |
| "upper_bucket_multiple": 0, |
| "seq_bucket_multiple": 256, |
| }, |
| "upper_ws_auto": { |
| "seed_mode": "auto", |
| "upper_prepare_mode": "bucketed_workspace", |
| "upper_bucket_multiple": 256, |
| "seq_bucket_multiple": 256, |
| }, |
| "upper_ws_auto_128": { |
| "seed_mode": "auto", |
| "upper_prepare_mode": "bucketed_workspace", |
| "upper_bucket_multiple": 128, |
| "seq_bucket_multiple": 128, |
| }, |
| "upper_ws_tri": { |
| "seed_mode": "tri", |
| "upper_prepare_mode": "bucketed_workspace", |
| "upper_bucket_multiple": 256, |
| "seq_bucket_multiple": 256, |
| }, |
| "upper_ws_stable": { |
| "seed_mode": "stable", |
| "upper_prepare_mode": "bucketed_workspace", |
| "upper_bucket_multiple": 256, |
| "seq_bucket_multiple": 256, |
| }, |
| "upper_ws_matched": { |
| "seed_mode": "matched", |
| "upper_prepare_mode": "bucketed_workspace", |
| "upper_bucket_multiple": 256, |
| "seq_bucket_multiple": 256, |
| }, |
| } |
|
|
|
|
| def _normalize_optimized_llopa_seed_mode(seed_mode: Optional[str]) -> str: |
| raw = str(seed_mode or "auto").strip().lower() |
| if raw in {"", "default"}: |
| raw = "auto" |
| if raw not in {"auto", "tri", "stable", "matched"}: |
| raw = "auto" |
| return raw |
|
|
|
|
| def _normalize_optimized_llopa_upper_prepare_mode(mode: Optional[str]) -> str: |
| raw = str(mode or "exact").strip().lower() |
| if raw in {"", "default"}: |
| raw = "exact" |
| if raw not in {"exact", "bucketed_workspace"}: |
| raw = "exact" |
| return raw |
|
|
|
|
def _resolve_optimized_llopa_settings(
    *,
    variant: Optional[str],
    seed_mode: Optional[str],
    upper_prepare_mode: Optional[str],
    upper_bucket_multiple: Optional[int],
    seq_bucket_multiple: Optional[int],
):
    """Merge explicit overrides with a named preset into one settings dict.

    ``variant`` selects the preset ("", "default", "auto" and ``None`` mean
    "upper_ws_auto"; unknown names use that preset's values as well). Every
    argument that is not ``None`` overrides the preset value. The upper bucket
    multiple is forced to 0 unless the bucketed workspace is used and is never
    negative; a non-positive sequence bucket multiple becomes 256.

    Returns a dict with keys "variant" (the normalized name as requested),
    "seed_mode", "upper_prepare_mode", "upper_bucket_multiple" and
    "seq_bucket_multiple".
    """
    preset_name = str(variant or "upper_ws_auto").strip().lower()
    if preset_name in ("", "default", "auto"):
        preset_name = "upper_ws_auto"
    fallback_preset = _OPTIMIZED_LLOPA_VARIANT_PRESETS["upper_ws_auto"]
    preset = _OPTIMIZED_LLOPA_VARIANT_PRESETS.get(preset_name, fallback_preset)

    def _override_or_preset(explicit, preset_key):
        return preset.get(preset_key) if explicit is None else explicit

    chosen_seed_mode = _normalize_optimized_llopa_seed_mode(
        _override_or_preset(seed_mode, "seed_mode")
    )
    chosen_prepare_mode = _normalize_optimized_llopa_upper_prepare_mode(
        _override_or_preset(upper_prepare_mode, "upper_prepare_mode")
    )

    if upper_bucket_multiple is None:
        upper_multiple = int(preset.get("upper_bucket_multiple", 0) or 0)
    else:
        upper_multiple = int(upper_bucket_multiple)

    if seq_bucket_multiple is None:
        seq_multiple = int(preset.get("seq_bucket_multiple", 256) or 256)
    else:
        seq_multiple = int(seq_bucket_multiple)

    # Upper bucketing only applies to the bucketed workspace and cannot be negative.
    if chosen_prepare_mode != "bucketed_workspace" or upper_multiple < 0:
        upper_multiple = 0
    if seq_multiple <= 0:
        seq_multiple = 256

    return {
        "variant": preset_name,
        "seed_mode": chosen_seed_mode,
        "upper_prepare_mode": chosen_prepare_mode,
        "upper_bucket_multiple": int(upper_multiple),
        "seq_bucket_multiple": int(seq_multiple),
    }
|
|
|
|
| @contextlib.contextmanager |
| def _temporary_model_attrs(model, **updates): |
| sentinel = object() |
| prior = {} |
| try: |
| for key, value in updates.items(): |
| prior[key] = getattr(model, key, sentinel) |
| setattr(model, key, value) |
| yield |
| finally: |
| for key, old_value in prior.items(): |
| if old_value is sentinel: |
| with contextlib.suppress(Exception): |
| delattr(model, key) |
| else: |
| with contextlib.suppress(Exception): |
| setattr(model, key, old_value) |
|
|
|
|
def _acquire_bucketed_sequence_workspace(
    model,
    *,
    reference_ids: torch.Tensor,
    batch_size: int,
    total_len: int,
    bucket_multiple: int,
):
    """Return a reusable ``(batch_size, bucketed_len)`` buffer and ``bucketed_len``.

    ``total_len`` is rounded up to ``bucket_multiple``. Buffers are cached on
    the model (attribute ``_optimized_llopa_sequence_workspace_cache``) keyed
    by device, dtype, batch size and bucketed length, and take the dtype and
    device of ``reference_ids``. A newly allocated buffer is uninitialized
    (``torch.empty``).
    """
    cache_attr = "_optimized_llopa_sequence_workspace_cache"

    padded_len = _round_up_to_multiple(total_len, bucket_multiple)
    if padded_len <= 0:
        padded_len = int(total_len)

    want_dtype = reference_ids.dtype
    want_device = reference_ids.device
    shape = (int(batch_size), int(padded_len))
    cache_key = (str(want_device), str(want_dtype), shape[0], shape[1])

    store = getattr(model, cache_attr, None)
    if not isinstance(store, dict):
        store = {}

    buffer = store.get(cache_key)
    reusable = (
        isinstance(buffer, torch.Tensor)
        and buffer.device == want_device
        and buffer.dtype == want_dtype
        and buffer.shape == shape
    )
    if not reusable:
        buffer = torch.empty(shape, dtype=want_dtype, device=want_device)
        store[cache_key] = buffer
        # Attaching the cache is best effort; the buffer is still returned.
        with contextlib.suppress(Exception):
            setattr(model, cache_attr, store)
    return buffer, shape[1]
|
|
| def _kv_meta_from_model(model_like): |
| """Return (num_kv_heads, head_dim, dtype).""" |
| try: |
| cfg = getattr(model_like, "config", None) or getattr(_get_inner_model(model_like), "config", None) |
| except Exception: |
| cfg = getattr(_get_inner_model(model_like), "config", None) |
| num_heads = getattr(cfg, "num_attention_heads", None) |
| num_kv = getattr(cfg, "num_key_value_heads", None) or num_heads |
| hidden = getattr(cfg, "hidden_size", None) |
| head_dim = (hidden // num_heads) if (hidden and num_heads) else None |
| try: |
| dtype = next(_get_inner_model(model_like).parameters()).dtype |
| except Exception: |
| dtype = torch.float16 if torch.cuda.is_available() else torch.float32 |
| return int(num_kv), int(head_dim), dtype |
|
|
| def _make_empty_kv(batch: int, num_kv: int, head_dim: int, device, dtype): |
| shape = (batch, num_kv, 0, head_dim) |
| k = torch.empty(shape, device=device, dtype=dtype) |
| v = torch.empty(shape, device=device, dtype=dtype) |
| return k.contiguous(), v.contiguous() |
|
|
| |
| |
| |
| def _llopa_split_system(system_ids: torch.Tensor, system_prefill: str): |
| """Return (system_upper, system_lower_extra) based on system_prefill.""" |
| sys_prefill = (system_prefill or "full").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "full" |
| if sys_prefill == "full": |
| return system_ids, system_ids[:, :0] |
| if sys_prefill == "no_system": |
| if system_ids.size(1) < 1: |
| return system_ids[:, :0], system_ids[:, :0] |
| return system_ids[:, :1], system_ids[:, 1:] |
| |
| return system_ids[:, :0], system_ids |
|
|
|
|
def _llopa_merge_replay_user_span(
    system_ids: torch.Tensor,
    *,
    system_prefill: str,
    replay_user_prefix_keep_len: int,
    replay_user_start: Optional[int],
    replay_user_len: Optional[int],
):
    """Shift replay-span offsets by the system tokens that join the user segment.

    The shift is the width of the ``system_lower_extra`` part produced by
    ``_llopa_split_system``. Negative inputs are clamped to 0 first.

    Returns ``(prefix_keep_len, user_start, user_len)``; ``user_start`` and
    ``user_len`` stay ``None`` when the corresponding argument is ``None``,
    and ``user_len`` is never shifted.
    """
    _, lower_extra = _llopa_split_system(system_ids, system_prefill)
    shift = int(lower_extra.size(1))

    keep_len = shift + max(int(replay_user_prefix_keep_len or 0), 0)
    user_start = None if replay_user_start is None else shift + max(int(replay_user_start), 0)
    user_len = None if replay_user_len is None else max(int(replay_user_len), 0)
    return int(keep_len), user_start, user_len
|
|
|
|
def _llopa_merge_user(system_ids: torch.Tensor, user_ids: torch.Tensor, system_prefill: str):
    """Return ``(system_upper, user_segment)`` for the given system-prefill mode.

    The system tokens that ``_llopa_split_system`` assigns to the lower part
    are prepended to ``user_ids`` along dim 1. When either side is empty the
    other one is returned as is, without concatenation.
    """
    upper, lower_extra = _llopa_split_system(system_ids, system_prefill)
    if lower_extra.numel() == 0:
        merged = user_ids
    elif user_ids.numel() == 0:
        merged = lower_extra
    else:
        merged = torch.cat([lower_extra, user_ids], dim=1)
    return upper, merged
|
|
|
|
def _llopa_prefill_cache(llopa_core,
                         system_ids: torch.Tensor,
                         user_ids: torch.Tensor,
                         assistant_ids: torch.Tensor,
                         *,
                         lower_k: int,
                         prefill_mode: str,
                         prefill_attn: str,
                         system_prefill: str,
                         return_last_assistant_hidden: bool = False,
                         replay_user_prefix_keep_len: int = 0,
                         replay_user_start: Optional[int] = None,
                         replay_user_len: Optional[int] = None):
    """Validate options and delegate to ``llopa_core.llopa_prefill_cache``.

    ``prefill_mode`` defaults to ``"lower"`` and ``prefill_attn`` to
    ``"causal"``; ``"prefix_full"`` is accepted as an alias of ``"full"``.
    The system prompt is split according to ``system_prefill`` and its
    lower-side part is merged into the user ids (and into the replay span
    offsets) before the core call.

    Returns:
        ``(pkv, system_upper_len, merged_user_len)`` or, when
        ``return_last_assistant_hidden`` is truthy,
        ``(pkv, system_upper_len, merged_user_len, last_hidden)``.

    Raises:
        ValueError: if ``prefill_mode`` is not ``"lower"`` or ``prefill_attn``
            is not ``"causal"``/``"full"`` after normalisation.
        RuntimeError: if the core exposes no ``llopa_prefill_cache`` or a
            hidden state was requested but the core did not return a tuple.
    """
    mode = (prefill_mode or "lower").strip().lower()
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"

    if mode != "lower":
        raise ValueError("llopa_prefill requires prefill_mode='lower'.")
    if attn not in {"causal", "full"}:
        raise ValueError("llopa_prefill requires prefill_attn in {'causal','full'}.")

    prefill_impl = getattr(llopa_core, "llopa_prefill_cache", None)
    if prefill_impl is None:
        raise RuntimeError("llopa_prefill_cache not found. Check LLoPA modeling patch.")

    want_hidden = bool(return_last_assistant_hidden)
    sys_upper, user_llopa = _llopa_merge_user(system_ids, user_ids, system_prefill)
    keep_len, span_start, span_len = _llopa_merge_replay_user_span(
        system_ids,
        system_prefill=system_prefill,
        replay_user_prefix_keep_len=int(replay_user_prefix_keep_len or 0),
        replay_user_start=replay_user_start,
        replay_user_len=replay_user_len,
    )

    result = prefill_impl(
        system_ids=sys_upper,
        user_ids=user_llopa,
        assistant_ids=assistant_ids,
        lower_k=lower_k,
        prefill_mode=mode,
        prefill_attn=attn,
        return_last_assistant_hidden=want_hidden,
        replay_user_prefix_keep_len=keep_len,
        replay_user_start=span_start,
        replay_user_len=span_len,
    )

    if not want_hidden:
        return result, sys_upper.size(1), user_llopa.size(1)

    if not isinstance(result, tuple):
        raise RuntimeError("llopa_prefill_cache did not return the requested last assistant hidden state.")
    cache, last_hidden = result
    return cache, sys_upper.size(1), user_llopa.size(1), last_hidden
|
|
| |
| |
| |
| |
| |
| import contextlib |
|
|
@contextlib.contextmanager
def lopa_cache_position_patch(model, past_key_values):
    """Temporarily align per-layer positions with each layer's cache length.

    On entry, the current past length of every decoder layer is read from
    ``past_key_values`` and stored on the layer as ``_lopa_past``; a forward
    pre-hook is registered on each layer. During forward, the hook takes the
    first element of ``cache_position`` (or, failing that, ``position_ids``)
    as ``start_val`` and subtracts ``off = start_val - past_len`` from both
    tensors, so that each layer sees positions starting at its own past
    length. On exit the hooks and the ``_lopa_past`` attributes are removed.

    Supported cache layouts: objects with ``key_cache``/``value_cache``,
    objects with ``layers[i].keys``, and indexable ``[layer][0]`` tuples.
    """
    inner = _get_inner_model(model)
    layers = inner.layers

    def _layer_past_len(idx: int) -> int:
        # dim 2 of the key tensor is used as the cached length
        if hasattr(past_key_values, "key_cache") and hasattr(past_key_values, "value_cache"):
            key_tensor = past_key_values.key_cache[idx]
        elif hasattr(past_key_values, "layers"):
            key_tensor = past_key_values.layers[idx].keys
        else:
            key_tensor = past_key_values[idx][0]
        return int(key_tensor.shape[2])

    cached_lengths = [_layer_past_len(idx) for idx in range(len(layers))]

    def _shift_positions(module, args, kwargs):
        # The hook is stateless: the per-layer length is read from the module.
        layer_past = getattr(module, "_lopa_past", 0)
        cache_pos = kwargs.get("cache_position", None)
        pos_ids = kwargs.get("position_ids", None)
        cache_pos_is_tensor = isinstance(cache_pos, torch.Tensor)
        pos_ids_is_tensor = isinstance(pos_ids, torch.Tensor)

        if cache_pos_is_tensor and cache_pos.numel() > 0:
            first = int(cache_pos.view(-1)[0].item())
        elif pos_ids_is_tensor and pos_ids.numel() > 0:
            first = int(pos_ids.view(-1)[0].item())
        else:
            return args, kwargs

        delta = first - layer_past
        if delta != 0:
            if cache_pos_is_tensor:
                kwargs["cache_position"] = cache_pos - delta
            if pos_ids_is_tensor:
                kwargs["position_ids"] = pos_ids - delta
        return args, kwargs

    hook_handles = []
    for layer, length in zip(layers, cached_lengths):
        layer._lopa_past = int(length)
        hook_handles.append(layer.register_forward_pre_hook(_shift_positions, with_kwargs=True))

    try:
        yield
    finally:
        for handle in hook_handles:
            handle.remove()
        for layer in inner.layers:
            if hasattr(layer, "_lopa_past"):
                delattr(layer, "_lopa_past")
|
|
| |
| |
| |
| @torch.inference_mode() |
|
|
| def _get_peft_wrapper(m): |
| try: |
| from peft import PeftModel |
| except Exception: |
| return None |
| if hasattr(m, "module"): |
| m = m.module |
| return m if isinstance(m, PeftModel) else None |
|
|
|
|
| def _set_prefill_adapter(model, enabled: bool) -> None: |
| if not bool(getattr(model, "_prefill_adapter_only", False)): |
| return |
| peft_model = _get_peft_wrapper(model) |
| if peft_model is None: |
| return |
| base = getattr(peft_model, "base_model", None) |
| if base is None: |
| return |
| try: |
| if enabled: |
| base.enable_adapter_layers() |
| else: |
| base.disable_adapter_layers() |
| except Exception: |
| return |
|
|
|
|
| def _get_output_head(model): |
| getter = getattr(model, "get_output_embeddings", None) |
| if callable(getter): |
| head = getter() |
| if head is not None: |
| return head |
| return getattr(model, "lm_head", None) |
|
|
|
|
| def _env_flag_enabled(name: str, default: str = "1") -> bool: |
| raw = os.environ.get(name, default).strip().lower() |
| return raw not in {"0", "false", "no", "off"} |
|
|
@torch.inference_mode()
def lopa_generate(model,
                  tokenizer,
                  system: str,
                  document: str,
                  question: str,
                  *,
                  task: str = "qa_doc",
                  K: int,
                  prefill_mode: str = "lower",
                  prefill_attn: str = "causal",
                  system_prefill: str = "full",
                  user_prefill: str = "full",
                  device: str,
                  input_ids: Optional[dict] = None,
                  max_new_tokens: int = 256,
                  min_length: int = 16,
                  temperature: float = 0.7,
                  top_p: float = 0.9,
                  top_k: Optional[int] = None,
                  do_sample: bool = True,
                  math_force_final_hash_rule: bool = False,
                  log_cuda_mem: bool = False,
                  log_cuda_tag: Optional[str] = None,
                  debug: bool = False,
                  debug_dir: Optional[Path] = None,
                  llopa_prefill: bool = False,
                  no_upper_attn: Optional[bool] = None,
                  return_tokens: bool = False) -> str | Tuple[str, int]:
    """Generate a completion with the custom LLoPA prefill/decode path.

    The prompt pieces are tokenised by ``_coerce_llopa_inputs``, the KV cache
    is prefilled (either through ``_llopa_prefill_cache`` when
    ``llopa_prefill`` is set, or through the ``tri_*`` methods of the LLoPA
    core otherwise), and tokens are then produced one at a time with the
    LLoPA decode step until a stop token or ``max_new_tokens`` is reached.

    Args:
        model: Model carrying the LLoPA modeling patch.
        tokenizer: Tokenizer used for prompt construction and decoding.
        system, document, question: Prompt pieces handed to
            ``_coerce_llopa_inputs``.
        task: Task name handed to ``_coerce_llopa_inputs``; ``"math"`` also
            enables the optional ``####`` answer-format suffix.
        K: Cast to ``int`` and forwarded as ``lower_k``.
        prefill_mode: Forwarded to the prefill/decode calls.
        prefill_attn: ``"causal"`` or ``"full"`` (``"prefix_full"`` is an
            alias of ``"full"``).
        system_prefill: ``"full"``, ``"no_system"`` or ``"no_bos_system"``;
            unknown values fall back to ``"full"``.
        user_prefill: ``"full"`` or ``"no_question"``; unknown values fall
            back to ``"full"``. With ``"no_question"`` the question ids are
            moved from the prefill into the assistant-side prefix.
        device: Forwarded to ``_coerce_llopa_inputs``.
        input_ids: Optional pre-tokenised inputs forwarded to
            ``_coerce_llopa_inputs``; when given, the math suffix is skipped.
        max_new_tokens: Upper bound on generated tokens.
        min_length: Stop tokens are masked out while fewer than this many
            tokens have been generated.
        temperature, top_p, top_k, do_sample: Sampling controls; with
            ``do_sample=False`` decoding is greedy (argmax).
        math_force_final_hash_rule: Append the ``####`` instruction to the
            system prompt for math tasks if it is not already present.
        log_cuda_mem, log_cuda_tag: Print CUDA memory statistics at a few
            checkpoints when CUDA is available.
        debug, debug_dir: Print (and optionally write to ``debug_dir``) the
            rendered prompts and the generated text.
        llopa_prefill: Use ``_llopa_prefill_cache`` for the prefill.
        no_upper_attn: Defaults to ``model._no_upper_attn``; only honoured
            together with ``llopa_prefill`` and a decode step that accepts it.
        return_tokens: Also return the number of generated tokens.

    Returns:
        The decoded text, or ``(text, num_generated_tokens)`` when
        ``return_tokens`` is true.

    Raises:
        ValueError: invalid ``prefill_attn``, or an empty prompt after an
            LLoPA prefill.
        RuntimeError: the LLoPA core or decode step is not available.
    """
    # Optional math-answer format rule.
    if input_ids is None and task == "math" and math_force_final_hash_rule and "####" not in (system or ""):
        # FIX: the guard above tolerates system=None, so the suffix must too
        # (a bare system.rstrip() raised AttributeError for None).
        system = (
            (system or "").rstrip()
            + " Conclude your explanation with the answer in a '#### {numeric answer}' format, "
            + "where the answer is solely a number."
        )
    user_prefill = (user_prefill or "full").strip().lower()
    if user_prefill not in {"full", "no_question"}:
        user_prefill = "full"
    prefill_attn = (prefill_attn or "causal").strip().lower()
    if prefill_attn == "prefix_full":
        prefill_attn = "full"
    if prefill_attn not in {"causal", "full"}:
        raise ValueError("prefill_attn must be one of: causal | full")

    # Tokenised prompt segments.
    llopa_inputs = _coerce_llopa_inputs(
        tokenizer,
        system=system,
        document=document,
        question=question,
        task=task,
        device=device,
        input_ids=input_ids,
    )
    ids_sys = llopa_inputs["system_ids"]
    ids_sys_user = llopa_inputs["system_user_ids"]
    user_doc_ids = llopa_inputs["user_doc_ids"]
    user_q_ids = llopa_inputs["user_q_ids"]
    hdr_tail = llopa_inputs["hdr_tail"]
    ids_hdr = llopa_inputs.get("prompt_ids", None)

    # Best-effort dump of the rendered prompts; failures are only reported.
    if debug:
        try:
            msgs = build_messages(system, document, question, include_query=True, task=task)
            s_no_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=False)
            s_with_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=True)
            print(f"[debug] render lengths (chars): no_hdr={len(s_no_hdr)}, with_hdr={len(s_with_hdr)}")
            if debug_dir is not None:
                debug_dir.mkdir(parents=True, exist_ok=True)
                (debug_dir / "infer_render_no_header.txt").write_text(s_no_hdr, encoding="utf-8")
                (debug_dir / "infer_render_with_header.txt").write_text(s_with_hdr, encoding="utf-8")
        except Exception as e:
            print(f"[debug] render dump failed: {e}")

    # Assistant-side prefix: everything of the full prompt that follows the
    # system+user ids, optionally preceded by the question ids.
    if ids_hdr is None:
        ids_hdr = torch.cat([ids_sys_user, hdr_tail], dim=1) if hdr_tail is not None else ids_sys_user
    hdr_tail = ids_hdr[:, ids_sys_user.size(1):]
    if user_prefill == "no_question":
        prefix_full = torch.cat([user_q_ids, hdr_tail], dim=1)
    else:
        prefix_full = hdr_tail

    # Resolve the LLoPA entry points and the effective no_upper_attn flag.
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        raise RuntimeError("Custom LLoPA modeling not active. Check --lopa_modeling_path/--modeling_family.")
    if no_upper_attn is None:
        no_upper_attn = bool(getattr(model, "_no_upper_attn", False))
    no_upper_attn = bool(no_upper_attn)
    if no_upper_attn and (not bool(llopa_prefill)):
        print("[infer][warn] no_upper_attn is ignored unless llopa_prefill=True.")
    effective_no_upper_attn = bool(no_upper_attn and bool(llopa_prefill))
    llopa_step_accepts_no_upper_attn = False
    try:
        llopa_step_params = inspect.signature(llopa_step).parameters
        llopa_step_accepts_no_upper_attn = ("no_upper_attn" in llopa_step_params)
    except Exception:
        llopa_step_accepts_no_upper_attn = False
    if effective_no_upper_attn and (not llopa_step_accepts_no_upper_attn):
        print("[infer][warn] no_upper_attn requested but llopa_decode_step_logits does not support it; ignoring.")
        effective_no_upper_attn = False

    def _log_mem(tag: str) -> None:
        """Print CUDA memory counters (GiB) when memory logging is enabled."""
        if not log_cuda_mem or not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
        except Exception:
            pass
        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        max_alloc = torch.cuda.max_memory_allocated() / (1024 ** 3)
        max_reserved = torch.cuda.max_memory_reserved() / (1024 ** 3)
        prefix = "[mem]"
        if log_cuda_tag:
            prefix = f"{prefix}[{log_cuda_tag}]"
        print(f"{prefix} {tag} | alloc={alloc:.2f}GiB reserved={reserved:.2f}GiB "
              f"max_alloc={max_alloc:.2f}GiB max_reserved={max_reserved:.2f}GiB")

    if log_cuda_mem and torch.cuda.is_available():
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    _log_mem("start")

    # Prefill (adapters are switched on for this phase only).
    _set_prefill_adapter(model, True)
    lower_k = int(K)
    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"
    if user_prefill == "no_question":
        user_prefill_ids = user_doc_ids
    else:
        user_prefill_ids = torch.cat([user_doc_ids, user_q_ids], dim=1)
    if llopa_prefill:
        # The first token can be taken straight from the prefill's last hidden
        # state when there is a prefix and nothing forces a separate decode step.
        use_fused_first_token = bool(
            prefix_full.numel() > 0
            and not effective_no_upper_attn
            and not bool(getattr(model, "_prefill_adapter_only", False))
        )
        if use_fused_first_token:
            pkv, S, U, llopa_last_hidden = _llopa_prefill_cache(
                llopa_core,
                ids_sys,
                user_prefill_ids,
                prefix_full,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=True,
            )
        else:
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                ids_sys,
                user_prefill_ids,
                prefix_full,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
                system_prefill=sys_prefill,
            )
            llopa_last_hidden = None
    else:
        use_fused_first_token = False
        llopa_last_hidden = None
        if sys_prefill == "full":
            pkv, S, U = llopa_core.tri_build_caches(
                system_ids=ids_sys,
                user_ids=user_prefill_ids,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
        elif sys_prefill == "no_system":
            # Only the first system token goes through the system prefill;
            # the rest of the system prompt is prefilled like user tokens.
            if ids_sys.size(1) < 1:
                pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None))
            else:
                bos_ids = ids_sys[:, :1]
                rest_ids = ids_sys[:, 1:]
                out = llopa_core.tri_prefill_system_all(
                    bos_ids,
                    past_key_values=None,
                    prefill_attn=prefill_attn,
                )
                pkv = out.past_key_values
                if rest_ids.size(1) > 0:
                    _ = llopa_core.tri_prefill_user_lower(
                        rest_ids,
                        lower_k=lower_k,
                        past_key_values=pkv,
                        prefill_mode=prefill_mode,
                        prefill_attn=prefill_attn,
                    )
            _ = llopa_core.tri_prefill_user_lower(
                user_prefill_ids,
                lower_k=lower_k,
                past_key_values=pkv,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
            S, U = ids_sys.size(1), user_prefill_ids.size(1)
        else:
            # "no_bos_system": the whole system prompt is prefilled like user tokens.
            pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None))
            if ids_sys.size(1) > 0:
                _ = llopa_core.tri_prefill_user_lower(
                    ids_sys,
                    lower_k=lower_k,
                    past_key_values=pkv,
                    prefill_mode=prefill_mode,
                    prefill_attn=prefill_attn,
                )
            _ = llopa_core.tri_prefill_user_lower(
                user_prefill_ids,
                lower_k=lower_k,
                past_key_values=pkv,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
            S, U = ids_sys.size(1), user_prefill_ids.size(1)

    _set_prefill_adapter(model, False)
    _log_mem("prefill_end")
    initial_logits = None
    if use_fused_first_token:
        output_head = _get_output_head(model)
        if output_head is not None and isinstance(llopa_last_hidden, torch.Tensor) and llopa_last_hidden.numel() > 0:
            initial_logits = output_head(llopa_last_hidden)[:, -1, :].to(torch.float32)

    # Seed the decoder: determine the token the first decode step is fed with.
    if llopa_prefill:
        if prefix_full.numel() > 0:
            last_pushed = prefix_full[:, -1:]
        elif ids_sys_user.numel() > 0:
            last_pushed = ids_sys_user[:, -1:]
        else:
            raise ValueError("Empty prompt after LLOPA prefill; cannot start decoding.")
    else:
        if prefix_full.numel() > 0:
            seed_kwargs = dict(
                assistant_ids=prefix_full,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=0,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                seed_kwargs["no_upper_attn"] = True
            out_seed = llopa_step(**seed_kwargs)
            pkv = out_seed.past_key_values or pkv
            last_pushed = prefix_full[:, -1:]
        else:
            step_tok = ids_sys_user[:, -1:]
            seed_kwargs = dict(
                assistant_ids=step_tok,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=0,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                seed_kwargs["no_upper_attn"] = True
            out_seed = llopa_step(**seed_kwargs)
            pkv = out_seed.past_key_values or pkv
            last_pushed = step_tok

    if log_cuda_mem and torch.cuda.is_available():
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    _log_mem("decode_start")

    # Sampling setup.
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper

    eos_id = tokenizer.eos_token_id
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        eot_id = None
    stop_ids = set()
    if eos_id is not None:
        stop_ids.add(int(eos_id))
    # An unknown "<|eot_id|>" maps to the unk id and must not become a stop token.
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        stop_ids.add(int(eot_id))

    procs = None
    if do_sample:
        procs = LogitsProcessorList()
        if temperature and temperature != 1.0:
            procs.append(TemperatureLogitsWarper(temperature=float(temperature)))
        if top_p and top_p < 1.0:
            procs.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))
        if top_k is not None and top_k > 0:
            procs.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))

    device_t = last_pushed.device
    last = last_pushed
    # Output buffer is allocated for a single sequence (batch size 1).
    generated = torch.empty((1, max_new_tokens), dtype=torch.long, device=device_t)
    cur = 0
    stop_reason = None
    pending_logits = initial_logits

    # Token-by-token decode loop.
    while cur < max_new_tokens:
        if pending_logits is None:
            step_kwargs = dict(
                assistant_ids=last,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=1,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                step_kwargs["no_upper_attn"] = True
            out = llopa_step(**step_kwargs)
            pkv = out.past_key_values or pkv
            logits = out.logits[:, -1, :]
        else:
            # First token: reuse the logits produced by the fused prefill.
            logits = pending_logits
            pending_logits = None

        # Forbid stop tokens until min_length tokens have been produced.
        if stop_ids and cur < min_length:
            for sid in stop_ids:
                logits[:, sid] = -float("inf")

        if procs is not None:
            inp_for_proc = generated[:, :cur]
            if logits.dtype != torch.float32:
                logits = logits.float()
            logits = procs(inp_for_proc, logits)

        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        generated[:, cur:cur + 1] = next_tok
        last = next_tok

        if stop_ids and cur >= min_length:
            tok_id = int(next_tok.item())
            if tok_id in stop_ids:
                stop_reason = f"stop_token:{tok_id}"
                # The stop token itself is counted (and later decoded with
                # skip_special_tokens=True).
                cur += 1
                break

        cur += 1

    if stop_reason is None and cur >= max_new_tokens:
        stop_reason = "max_new_tokens"

    gen_ids = generated[:, :cur]
    text = tokenizer.decode(gen_ids[0].tolist(), skip_special_tokens=True)
    _log_mem("decode_end")
    if debug:
        print(f"[debug] finished | tokens={cur} | reason={stop_reason}")
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)
            (debug_dir / "infer_generated.txt").write_text(text, encoding="utf-8")
    if return_tokens:
        return text, int(cur)
    return text
|
|
| |
| |
| |
|
|
def llopa_generate(*args, **kwargs):
    """Capsule interface: identical behavior to lopa_generate.

    When ``device`` is missing or ``None``, it is inferred from the model and
    injected as a string. The model is taken from the first positional
    argument or, if no positional arguments were given, from a ``model=``
    keyword (previously the keyword form skipped inference, which left the
    required ``device`` argument unset).

    Device lookup order: input-embedding weight, first model parameter,
    then ``"cuda"`` if available else ``"cpu"``.
    """
    needs_device = kwargs.get("device") is None
    if needs_device and (args or "model" in kwargs):
        model = args[0] if args else kwargs["model"]
        try:
            dev = model.get_input_embeddings().weight.device
        except Exception:
            try:
                dev = next(model.parameters()).device
            except Exception:
                dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        kwargs["device"] = str(dev)
    return lopa_generate(*args, **kwargs)
|
|
|
|
| def _read_kv_file(path: Path) -> dict[str, str]: |
| info: dict[str, str] = {} |
| try: |
| lines = path.read_text(encoding="utf-8").splitlines() |
| except Exception: |
| return info |
| for raw in lines: |
| line = raw.strip() |
| if not line or line.startswith("#"): |
| continue |
| if "=" not in line: |
| continue |
| k, v = line.split("=", 1) |
| k = k.strip() |
| v = v.strip() |
| if k: |
| info[k] = v |
| return info |
|
|
|
|
| def _read_adapter_backbone_ref(repo_path: Path) -> str: |
| adapter_cfg = repo_path / "adapter_config.json" |
| if not adapter_cfg.is_file(): |
| return "" |
| try: |
| data = json.loads(adapter_cfg.read_text(encoding="utf-8")) |
| except Exception: |
| return "" |
| val = data.get("base_model_name_or_path") |
| if isinstance(val, str): |
| return val.strip() |
| return "" |
|
|
|
|
| def _resolve_repo_path(model_repo: str, cache_dir: Optional[str] = None, |
| revision: Optional[str] = None, token: Optional[str] = None, |
| local_files_only: bool = False) -> Path: |
| repo = Path(model_repo) |
| if repo.exists(): |
| return repo |
| try: |
| from huggingface_hub import snapshot_download |
| except Exception as exc: |
| raise RuntimeError("huggingface_hub is required to load remote repos") from exc |
| path = snapshot_download( |
| repo_id=model_repo, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| ) |
| return Path(path) |
|
|
|
|
| def _repo_has_pretrained_weights(repo_path: Path) -> bool: |
| weight_files = ( |
| "pytorch_model.bin", |
| "pytorch_model.bin.index.json", |
| "model.safetensors", |
| "model.safetensors.index.json", |
| ) |
| if any((repo_path / name).is_file() for name in weight_files): |
| return True |
| return any(repo_path.glob("pytorch_model-*-of-*.bin")) or any(repo_path.glob("model-*-of-*.safetensors")) |
|
|
|
|
| def _pick_vocab_weight_key(keys) -> Optional[str]: |
| preferred = ( |
| "model.embed_tokens.weight", |
| "embed_tokens.weight", |
| "model.decoder.embed_tokens.weight", |
| "transformer.wte.weight", |
| "lm_head.weight", |
| "model.lm_head.weight", |
| ) |
| key_list = list(keys) |
| for cand in preferred: |
| if cand in key_list: |
| return cand |
| suffixes = ( |
| "embed_tokens.weight", |
| "decoder.embed_tokens.weight", |
| "wte.weight", |
| "lm_head.weight", |
| ) |
| for suffix in suffixes: |
| for key in key_list: |
| if str(key).endswith(suffix): |
| return str(key) |
| return None |
|
|
|
|
| def _load_state_dict_meta(path: Path): |
| try: |
| return torch.load(path, map_location="meta", weights_only=True) |
| except TypeError: |
| return torch.load(path, map_location="meta") |
|
|
|
|
def _weight_rows_from_file(path: Path, tensor_key: Optional[str] = None) -> Optional[int]:
    """Return the first-dimension size of a vocabulary tensor in ``path``.

    ``.safetensors`` files are inspected through ``safetensors.safe_open``
    (shape metadata via ``get_slice``); anything else is loaded with
    ``_load_state_dict_meta``. When ``tensor_key`` is not usable, the tensor
    is chosen with ``_pick_vocab_weight_key``. Tensors with fewer than two
    dimensions are ignored.

    Returns:
        The row count, or ``None`` if it cannot be determined (any error
        while reading is treated as "unknown").
    """
    try:
        suffixes = path.suffixes
    except Exception:
        suffixes = []
    is_safetensors = suffixes[-1:] == [".safetensors"]

    try:
        if is_safetensors:
            from safetensors import safe_open

            with safe_open(str(path), framework="pt", device="cpu") as handle:
                name = tensor_key or _pick_vocab_weight_key(handle.keys())
                if not name:
                    return None
                dims = handle.get_slice(name).get_shape()
                return int(dims[0]) if len(dims) >= 2 else None

        state_dict = _load_state_dict_meta(path)
        if not isinstance(state_dict, dict):
            return None
        if tensor_key in state_dict:
            name = tensor_key
        else:
            name = _pick_vocab_weight_key(state_dict.keys())
        if not name:
            return None
        tensor = state_dict.get(name)
        if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2:
            return int(tensor.shape[0])
    except Exception:
        return None
    return None
|
|
|
|
def _infer_checkpoint_vocab_size(repo_path: Path) -> Optional[int]:
    """Infer the embedding row count stored in a checkpoint directory.

    Sources are tried in this order and the first positive row count wins:

    1. shard index files (``*.index.json``): the vocab tensor is looked up in
       ``weight_map`` and read from the shard it points to;
    2. single-file weights (``pytorch_model.bin``, ``model.safetensors``);
    3. sharded ``.bin`` files, then sharded ``.safetensors`` files, each in
       sorted order.

    Returns ``None`` when no source yields a positive count.
    """
    def _is_positive(count: Optional[int]) -> bool:
        return count is not None and count > 0

    for index_name in ("pytorch_model.bin.index.json", "model.safetensors.index.json"):
        index_path = repo_path / index_name
        if not index_path.is_file():
            continue
        try:
            payload = json.loads(index_path.read_text(encoding="utf-8"))
        except Exception:
            continue
        weight_map = payload.get("weight_map")
        if not isinstance(weight_map, dict):
            continue
        tensor_key = _pick_vocab_weight_key(weight_map.keys())
        if not tensor_key:
            continue
        shard_rel = weight_map.get(tensor_key)
        if not shard_rel:
            continue
        count = _weight_rows_from_file(repo_path / shard_rel, tensor_key=tensor_key)
        if _is_positive(count):
            return count

    def _fallback_weight_files():
        # Lazily enumerated so later sources are only touched when needed.
        for single_name in ("pytorch_model.bin", "model.safetensors"):
            single_path = repo_path / single_name
            if single_path.is_file():
                yield single_path
        yield from sorted(repo_path.glob("pytorch_model-*-of-*.bin"))
        yield from sorted(repo_path.glob("model-*-of-*.safetensors"))

    for weight_path in _fallback_weight_files():
        count = _weight_rows_from_file(weight_path)
        if _is_positive(count):
            return count
    return None
|
|
|
|
| def _expand_tokenizer_placeholders(tokenizer, target_vocab_size: int) -> int: |
| try: |
| current_vocab_size = int(len(tokenizer)) |
| except Exception: |
| return 0 |
| if target_vocab_size <= current_vocab_size: |
| return 0 |
| missing = int(target_vocab_size - current_vocab_size) |
| placeholder_tokens = [f"<|capsule_missing_token_{idx}|>" for idx in range(missing)] |
| try: |
| return int(tokenizer.add_tokens(placeholder_tokens, special_tokens=True) or 0) |
| except Exception: |
| return 0 |
|
|
|
|
| def _normalize_special_token_values(raw_value: Any) -> list[str]: |
| if not isinstance(raw_value, list): |
| return [] |
| normalized: list[str] = [] |
| for item in raw_value: |
| token = "" |
| if isinstance(item, str): |
| token = item |
| elif isinstance(item, dict): |
| for key in ("content", "token", "text"): |
| value = item.get(key) |
| if isinstance(value, str) and value: |
| token = value |
| break |
| if token: |
| normalized.append(token) |
| return normalized |
|
|
|
|
def _checkpoint_special_token_candidates(checkpoint_repo_path: Path) -> list[str]:
    """Collect special-token strings declared by a checkpoint's JSON files.

    ``config.json``, ``tokenizer_config.json`` and ``special_tokens_map.json``
    are read in that order (missing, unparsable or non-object files are
    skipped). From each file the ``capsule_suffix_special_tokens`` and
    ``additional_special_tokens`` lists are taken, followed by the entries of
    ``added_tokens_decoder``.

    ``added_tokens_decoder`` entries are ordered with numeric keys first (by
    integer value) and non-numeric keys after them (lexicographically). The
    previous key mixed ``int`` and ``str`` values and raised ``TypeError``
    when both kinds of key were present.

    Returns:
        Token strings in first-seen order, without duplicates or empties.
    """
    def _decoder_order(item):
        raw_key = str(item[0])
        if raw_key.isdigit():
            try:
                return (0, int(raw_key), "")
            except ValueError:
                # isdigit() accepts a few characters int() rejects.
                pass
        return (1, 0, raw_key)

    candidates: list[str] = []
    for filename in ("config.json", "tokenizer_config.json", "special_tokens_map.json"):
        path = checkpoint_repo_path / filename
        if not path.is_file():
            continue
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if not isinstance(payload, dict):
            continue
        candidates.extend(_normalize_special_token_values(payload.get("capsule_suffix_special_tokens")))
        candidates.extend(_normalize_special_token_values(payload.get("additional_special_tokens")))
        raw_decoder = payload.get("added_tokens_decoder")
        if isinstance(raw_decoder, dict):
            for _, value in sorted(raw_decoder.items(), key=_decoder_order):
                candidates.extend(_normalize_special_token_values([value]))

    # dict.fromkeys keeps the first occurrence of each token, in order.
    return list(dict.fromkeys(token for token in candidates if token))
|
|
|
|
def _align_tokenizer_with_checkpoint_vocab(
    tokenizer,
    checkpoint_repo_path: Path,
    target_vocab_size: int,
) -> dict[str, int]:
    """Grow ``tokenizer`` until it has ``target_vocab_size`` entries.

    Special tokens declared by the checkpoint but absent from the tokenizer
    are added first (at most as many as are missing); any remaining gap is
    filled with placeholder tokens. The tokenizer is modified in place.

    Returns:
        A report dict with the sizes before/after each stage and the number
        of tokens added per stage. If the tokenizer length cannot be read,
        the report is returned with only ``checkpoint_vocab_size`` filled in.
    """
    report = dict.fromkeys(
        (
            "checkpoint_vocab_size",
            "tokenizer_vocab_size_before",
            "checkpoint_special_candidate_count",
            "checkpoint_specials_missing_before",
            "added_checkpoint_specials",
            "tokenizer_vocab_size_after_specials",
            "padding_gap_before_placeholders",
            "added_placeholders",
            "tokenizer_vocab_size_after",
        ),
        0,
    )
    report["checkpoint_vocab_size"] = int(target_vocab_size or 0)

    try:
        size_before = int(len(tokenizer))
    except Exception:
        return report
    for field in (
        "tokenizer_vocab_size_before",
        "tokenizer_vocab_size_after_specials",
        "tokenizer_vocab_size_after",
    ):
        report[field] = size_before

    gap = int(target_vocab_size - size_before)
    if gap <= 0:
        return report

    # Stage 1: recover special tokens the checkpoint knows about.
    specials = _checkpoint_special_token_candidates(checkpoint_repo_path)
    report["checkpoint_special_candidate_count"] = len(specials)
    try:
        known = set(tokenizer.get_vocab().keys())
    except Exception:
        known = set()
    absent = [tok for tok in specials if tok not in known]
    report["checkpoint_specials_missing_before"] = len(absent)

    # Never add more specials than there are missing rows.
    to_add = absent[:gap]

    recovered = 0
    if to_add:
        try:
            existing_specials = tokenizer.special_tokens_map_extended.get("additional_special_tokens", []) or []
            recovered = int(
                tokenizer.add_special_tokens(
                    {"additional_special_tokens": list(existing_specials) + to_add}
                )
                or 0
            )
        except Exception:
            # Fall back to the plain add_tokens API.
            try:
                recovered = int(tokenizer.add_tokens(to_add, special_tokens=True) or 0)
            except Exception:
                recovered = 0
    report["added_checkpoint_specials"] = recovered

    try:
        size_after_specials = int(len(tokenizer))
    except Exception:
        size_after_specials = size_before + recovered
    report["tokenizer_vocab_size_after_specials"] = size_after_specials
    report["padding_gap_before_placeholders"] = max(0, int(target_vocab_size - size_after_specials))

    # Stage 2: pad whatever is still missing with placeholders.
    padded = 0
    if target_vocab_size > size_after_specials:
        padded = _expand_tokenizer_placeholders(tokenizer, target_vocab_size)
    report["added_placeholders"] = padded

    try:
        report["tokenizer_vocab_size_after"] = int(len(tokenizer))
    except Exception:
        report["tokenizer_vocab_size_after"] = size_after_specials + padded
    return report
|
|
|
|
| def _log_tokenizer_checkpoint_alignment( |
| log_prefix: str, |
| report: dict[str, int], |
| *, |
| print_fn=print, |
| ) -> None: |
| added_specials = int(report.get("added_checkpoint_specials", 0) or 0) |
| added_placeholders = int(report.get("added_placeholders", 0) or 0) |
| if added_specials <= 0 and added_placeholders <= 0: |
| return |
|
|
| vocab_before = int(report.get("tokenizer_vocab_size_before", 0) or 0) |
| vocab_after_specials = int(report.get("tokenizer_vocab_size_after_specials", vocab_before) or vocab_before) |
| vocab_after = int(report.get("tokenizer_vocab_size_after", vocab_after_specials) or vocab_after_specials) |
| missing_specials_before = int(report.get("checkpoint_specials_missing_before", 0) or 0) |
|
|
| if added_specials > 0: |
| print_fn( |
| f"{log_prefix}[info] tokenizer vocab is smaller than checkpoint embeddings; " |
| f"recovered {added_specials} checkpoint special tokens " |
| f"({vocab_before} -> {vocab_after_specials})." |
| ) |
| if added_placeholders > 0: |
| if missing_specials_before > 0: |
| print_fn( |
| f"{log_prefix}[warn] checkpoint embeddings still exceed tokenizer vocab after recovering " |
| f"checkpoint special tokens; added {added_placeholders} placeholder special tokens " |
| f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})." |
| ) |
| else: |
| print_fn( |
| f"{log_prefix}[info] checkpoint embeddings include {added_placeholders} padded rows " |
| f"beyond tokenizer vocab; added {added_placeholders} placeholder special tokens " |
| f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})." |
| ) |
|
|
|
|
| def _log_config_checkpoint_vocab_alignment( |
| log_prefix: str, |
| *, |
| config_vocab_size: int, |
| checkpoint_vocab_size: int, |
| tokenizer_vocab_size: int, |
| alignment_report: Optional[dict[str, int]] = None, |
| print_fn=print, |
| ) -> None: |
| if checkpoint_vocab_size <= 0 or checkpoint_vocab_size == config_vocab_size: |
| return |
|
|
| final_tokenizer_vocab = int(tokenizer_vocab_size or 0) |
| if alignment_report is not None: |
| final_tokenizer_vocab = int( |
| alignment_report.get("tokenizer_vocab_size_after", final_tokenizer_vocab) or final_tokenizer_vocab |
| ) |
| missing_specials_before = 0 |
| if alignment_report is not None: |
| missing_specials_before = int(alignment_report.get("checkpoint_specials_missing_before", 0) or 0) |
|
|
| if final_tokenizer_vocab >= checkpoint_vocab_size and missing_specials_before <= 0: |
| print_fn( |
| f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags tokenizer/checkpoint embeddings " |
| f"({checkpoint_vocab_size}); using checkpoint size." |
| ) |
| return |
|
|
| padding_rows = max(0, checkpoint_vocab_size - final_tokenizer_vocab) |
| if padding_rows > 0 and missing_specials_before <= 0: |
| print_fn( |
| f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags checkpoint embeddings " |
| f"({checkpoint_vocab_size}); using checkpoint size. tokenizer vocab is {final_tokenizer_vocab}, " |
| f"so {padding_rows} rows are embedding padding/alignment." |
| ) |
| return |
|
|
| print_fn( |
| f"{log_prefix}[warn] config vocab_size ({config_vocab_size}) does not match checkpoint embeddings " |
| f"({checkpoint_vocab_size}); using checkpoint size." |
| ) |
|
|
|
|
def _expand_tokenizer_with_checkpoint_specials(
    tokenizer,
    checkpoint_repo_path: Path,
    target_vocab_size: int,
) -> tuple[int, int]:
    """Align the tokenizer and return only the two "added" counters.

    Thin wrapper around ``_align_tokenizer_with_checkpoint_vocab``.

    Returns:
        ``(added_checkpoint_specials, added_placeholders)``.
    """
    report = _align_tokenizer_with_checkpoint_vocab(
        tokenizer,
        checkpoint_repo_path,
        target_vocab_size,
    )
    added_specials = int(report.get("added_checkpoint_specials", 0) or 0)
    added_placeholders = int(report.get("added_placeholders", 0) or 0)
    return added_specials, added_placeholders
|
|
|
|
| def _resolve_modeling_path(repo_path: Path, user_path: Optional[str], model_family: str) -> Optional[str]: |
| if user_path: |
| cand = Path(user_path) |
| if not cand.is_file(): |
| cand = repo_path / user_path |
| return str(cand) if cand.is_file() else None |
| tri_info = repo_path / "tri_info.txt" |
| if tri_info.is_file(): |
| info = _read_kv_file(tri_info) |
| tri_file = info.get("lopa_modeling_path") or "" |
| if tri_file: |
| cand = repo_path / tri_file |
| if cand.is_file(): |
| return str(cand) |
| default_name = { |
| "llama": "tri_llama3_modeling.py", |
| "qwen3": "tri_qwen3_modeling.py", |
| "mistral": "tri_mistral_modeling.py", |
| }.get(model_family, "tri_llama3_modeling.py") |
| cand = repo_path / default_name |
| if cand.is_file(): |
| return str(cand) |
| return None |
|
|
|
|
| def _attach_llopa_generate(model): |
| try: |
| import types |
| def _llopa_generate(self, tokenizer, system: Optional[str] = None, document: Optional[str] = None, |
| question: Optional[str] = None, **kwargs): |
| if system is None: |
| system = kwargs.pop("system", "") |
| if document is None: |
| document = kwargs.pop("document", "") |
| if question is None: |
| question = kwargs.pop("question", "") |
| if "device" not in kwargs or kwargs.get("device") is None: |
| try: |
| dev = self.get_input_embeddings().weight.device |
| except Exception: |
| try: |
| dev = next(self.parameters()).device |
| except Exception: |
| dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| kwargs["device"] = str(dev) |
| return lopa_generate(self, tokenizer, system=system, document=document, question=question, **kwargs) |
| model.llopa_generate = types.MethodType(_llopa_generate, model) |
| except Exception: |
| pass |
|
|
|
|
| def _maybe_attach_llopa_generate(model): |
| if not hasattr(model, "llopa_generate"): |
| _attach_llopa_generate(model) |
|
|
|
|
| def _attach_prefill_lower_generate( |
| model, |
| *, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| no_upper_attn: bool = False, |
| ) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| sys_prefill = (system_prefill or "no_bos_system").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "no_bos_system" |
| if lower_k <= 0 or attn not in {"causal", "full"}: |
| return |
|
|
| try: |
| setattr(model, "_runtime_prefill_lower_layers", int(lower_k)) |
| setattr(model, "_runtime_prefill_lower_attn", attn) |
| setattr(model, "_runtime_prefill_lower_system_prefill", sys_prefill) |
| setattr(model, "_runtime_prefill_lower_no_upper_attn", bool(no_upper_attn)) |
| except Exception: |
| return |
|
|
| if getattr(model, "_runtime_prefill_generate_attached", False): |
| return |
|
|
| orig_prepare = getattr(model, "prepare_inputs_for_generation", None) |
| if orig_prepare is None: |
| return |
|
|
| def _runtime_prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| attention_mask=None, |
| inputs_embeds=None, |
| cache_position=None, |
| prefill_lower_layers=None, |
| prefill_lower_attn=None, |
| prefill_lower_system_prefill=None, |
| prefill_lower_no_upper_attn=None, |
| prefill_lower_split_start=None, |
| prefill_lower_system_len=None, |
| prefill_lower_replay_user_prefix_keep_len=None, |
| prefill_lower_replay_user_start=None, |
| prefill_lower_replay_user_len=None, |
| assistant_header_start=None, |
| assistant_header_starts=None, |
| assistant_header_start_mask=None, |
| **kwargs, |
| ): |
| model_inputs = orig_prepare( |
| input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
| if prefill_lower_layers is None: |
| prefill_lower_layers = int(getattr(self, "_runtime_prefill_lower_layers", 0) or 0) |
| if prefill_lower_attn is None: |
| prefill_lower_attn = str(getattr(self, "_runtime_prefill_lower_attn", "causal") or "causal") |
| if prefill_lower_system_prefill is None: |
| prefill_lower_system_prefill = str( |
| getattr(self, "_runtime_prefill_lower_system_prefill", "no_bos_system") or "no_bos_system" |
| ) |
| if prefill_lower_no_upper_attn is None: |
| prefill_lower_no_upper_attn = bool( |
| getattr(self, "_runtime_prefill_lower_no_upper_attn", False) |
| ) |
| if int(prefill_lower_layers or 0) > 0: |
| model_inputs["prefill_lower_layers"] = int(prefill_lower_layers) |
| model_inputs["prefill_lower_attn"] = str(prefill_lower_attn) |
| model_inputs["prefill_lower_system_prefill"] = str(prefill_lower_system_prefill) |
| if bool(prefill_lower_no_upper_attn): |
| model_inputs["prefill_lower_no_upper_attn"] = True |
| if prefill_lower_split_start is not None: |
| model_inputs["prefill_lower_split_start"] = prefill_lower_split_start |
| if prefill_lower_system_len is not None: |
| model_inputs["prefill_lower_system_len"] = prefill_lower_system_len |
| if prefill_lower_replay_user_prefix_keep_len is not None: |
| model_inputs["prefill_lower_replay_user_prefix_keep_len"] = prefill_lower_replay_user_prefix_keep_len |
| if prefill_lower_replay_user_start is not None: |
| model_inputs["prefill_lower_replay_user_start"] = prefill_lower_replay_user_start |
| if prefill_lower_replay_user_len is not None: |
| model_inputs["prefill_lower_replay_user_len"] = prefill_lower_replay_user_len |
| if assistant_header_start is not None: |
| model_inputs["assistant_header_start"] = assistant_header_start |
| if assistant_header_starts is not None: |
| model_inputs["assistant_header_starts"] = assistant_header_starts |
| if assistant_header_start_mask is not None: |
| model_inputs["assistant_header_start_mask"] = assistant_header_start_mask |
| return model_inputs |
|
|
| try: |
| model.prepare_inputs_for_generation = types.MethodType( |
| _runtime_prepare_inputs_for_generation, |
| model, |
| ) |
| setattr(model, "_runtime_prefill_generate_attached", True) |
| except Exception: |
| pass |
|
|
|
|
| def _attach_prefill_lower_freeze_generate( |
| model, |
| *, |
| tokenizer=None, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| ) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| sys_prefill = (system_prefill or "no_bos_system").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "no_bos_system" |
| if lower_k <= 0 or attn not in {"causal", "full"}: |
| return |
|
|
| try: |
| setattr(model, "_runtime_prefill_freeze_layers", int(lower_k)) |
| setattr(model, "_runtime_prefill_freeze_attn", attn) |
| setattr(model, "_runtime_prefill_freeze_system_prefill", sys_prefill) |
| setattr(model, "_runtime_structured_freeze_generate_default", True) |
| except Exception: |
| return |
|
|
| if tokenizer is not None: |
| _attach_structured_llopa_generate(model, tokenizer) |
|
|
|
|
| def _attach_prefill_lower_solo_generate( |
| model, |
| *, |
| tokenizer=None, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| ) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| sys_prefill = (system_prefill or "no_bos_system").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "no_bos_system" |
| if lower_k <= 0 or attn not in {"causal", "full"}: |
| return |
|
|
| try: |
| setattr(model, "_runtime_prefill_solo_layers", int(lower_k)) |
| setattr(model, "_runtime_prefill_solo_attn", attn) |
| setattr(model, "_runtime_prefill_solo_system_prefill", sys_prefill) |
| setattr(model, "_runtime_structured_solo_generate_default", True) |
| except Exception: |
| return |
|
|
| if tokenizer is not None: |
| _attach_structured_llopa_generate(model, tokenizer) |
|
|
|
|
| def _attach_prefill_lower_solo_v2_generate( |
| model, |
| *, |
| tokenizer=None, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| system_prefill: str = "no_bos_system", |
| with_bos: bool = False, |
| ) -> None: |
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| sys_prefill = (system_prefill or "no_bos_system").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "no_bos_system" |
| if lower_k <= 0 or attn not in {"causal", "full"}: |
| return |
|
|
| try: |
| setattr(model, "_runtime_prefill_solo_v2_layers", int(lower_k)) |
| setattr(model, "_runtime_prefill_solo_v2_attn", attn) |
| setattr(model, "_runtime_prefill_solo_v2_system_prefill", sys_prefill) |
| setattr(model, "_runtime_prefill_solo_v2_with_bos", bool(with_bos)) |
| setattr(model, "_runtime_structured_solo_v2_generate_default", True) |
| except Exception: |
| return |
|
|
| if tokenizer is not None: |
| _attach_structured_llopa_generate(model, tokenizer) |
|
|
|
|
@torch.inference_mode()
def _direct_freeze_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Token-by-token decoding that re-runs the whole sequence with "freeze" flags.

    The prompt ids come from ``_build_structured_prompt_segments`` (not from
    ``input_ids``, which is only used for a shape check and to pick the
    device). Every step calls ``model`` on prompt + generated tokens with
    ``use_cache=False``, ``logits_to_keep=1`` and
    ``prefill_lower_freeze_runtime=True``.

    ``attention_mask``, ``pad_token_id`` and ``use_cache`` are accepted but not
    read by this implementation.

    Returns:
        ``None`` when ``lower_k`` is not a positive integer, when ``input_ids``
        is a tensor that is not shaped ``(1, seq)``, or when the model call
        yields no logits tensor. Otherwise the full sequence tensor, or a dict
        with ``sequences`` and ``scores`` when ``return_dict_in_generate`` is
        true.
    """
    try:
        num_lower = int(lower_k)
    except Exception:
        num_lower = 0
    if num_lower <= 0:
        return None

    attn_mode = (prefill_attn or "causal").strip().lower()
    if attn_mode == "prefix_full":
        attn_mode = "full"
    elif attn_mode not in {"causal", "full"}:
        attn_mode = "causal"

    system_mode = (system_prefill or "no_bos_system").strip().lower()
    if system_mode not in {"full", "no_system", "no_bos_system"}:
        system_mode = "no_bos_system"

    if isinstance(input_ids, torch.Tensor):
        # Only a single 2-D prompt is supported.
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    else:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"

    segments = _build_structured_prompt_segments(
        tokenizer,
        prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    canonical_input_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    split_start = int(system_ids.size(1) + user_ids.size(1))
    system_len = int(system_ids.size(1))
    total_prompt_len = int(canonical_input_ids.size(1))

    # Token budgets: explicit *_new_tokens win over the *_length variants.
    if max_new_tokens is not None:
        budget = int(max_new_tokens)
    elif max_length is not None:
        budget = max(0, int(max_length) - total_prompt_len)
    else:
        budget = 256
    if min_new_tokens is not None:
        min_required = int(min_new_tokens)
    elif min_length is not None:
        min_required = max(0, int(min_length) - total_prompt_len)
    else:
        min_required = 0

    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        # A non-zero temperature implies sampling when the flag is unset.
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)

    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tokenizer_eos = getattr(tokenizer, "eos_token_id", None)
        if tokenizer_eos is not None:
            stop_ids.add(int(tokenizer_eos))
        # NOTE(review): the reviewed text had lost its indentation; the
        # "<|eot_id|>" lookup is assumed to be part of this fallback branch.
        # Confirm against the canonical file.
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)

    generated = torch.empty((1, int(budget)), dtype=torch.long, device=canonical_input_ids.device)
    score_list: list[torch.Tensor] = []
    keep_scores = bool(output_scores)
    produced = 0

    # Flags that stay the same for every forward pass.
    forward_flags = dict(
        use_cache=False,
        logits_to_keep=1,
        prefill_lower_layers=int(num_lower),
        prefill_lower_attn=str(attn_mode),
        prefill_lower_freeze_runtime=True,
        prefill_lower_split_start=int(split_start),
        prefill_lower_system_len=int(system_len),
        prefill_lower_system_prefill=str(system_mode),
    )

    while produced < int(budget):
        step_ids = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
        step_mask = torch.ones_like(step_ids, device=step_ids.device)
        out = model(input_ids=step_ids, attention_mask=step_mask, **forward_flags)
        if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor):
            return None
        logits = out.logits[:, -1, :].to(torch.float32)

        # Forbid stop tokens until the minimum length has been produced.
        if stop_ids and produced < int(min_required):
            for stop_id in stop_ids:
                logits[:, stop_id] = -float("inf")
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :produced], logits)
        if keep_scores:
            score_list.append(logits.detach().clone())

        if do_sample:
            next_tok = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        generated[:, produced : produced + 1] = next_tok
        produced += 1

        sequences_now = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
        tok_id = int(next_tok.item())
        should_stop = tok_id in stop_ids and produced >= int(min_required)
        if (not should_stop) and stopping_criteria is not None:
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                # Retry without scores for criteria that reject the logits argument.
                should_stop = bool(stopping_criteria(sequences_now, None))
        if should_stop:
            break

    sequences = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
    if not bool(return_dict_in_generate):
        return sequences
    return {
        "sequences": sequences,
        "scores": tuple(score_list) if keep_scores else tuple(),
    }
|
|
|
|
@torch.inference_mode()
def _direct_solo_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
    solo_v2: bool = False,
    with_bos: bool = False,
):
    """Token-by-token decoding that re-runs the whole sequence with "solo" flags.

    The prompt ids come from ``_build_structured_prompt_segments`` (not from
    ``input_ids``, which is only used for a shape check and to pick the
    device). Every step calls ``model`` on prompt + generated tokens with
    ``use_cache=False`` and ``logits_to_keep=1``. ``solo_v2`` selects between
    ``prefill_lower_solo_attention`` and ``prefill_lower_solo_attention_v2``;
    ``with_bos`` is only forwarded when ``solo_v2`` is set.

    ``attention_mask``, ``pad_token_id`` and ``use_cache`` are accepted but not
    read by this implementation.

    Returns:
        ``None`` when ``lower_k`` is not a positive integer, when ``input_ids``
        is a tensor that is not shaped ``(1, seq)``, or when the model call
        yields no logits tensor. Otherwise the full sequence tensor, or a dict
        with ``sequences`` and ``scores`` when ``return_dict_in_generate`` is
        true.
    """
    try:
        num_lower = int(lower_k)
    except Exception:
        num_lower = 0
    if num_lower <= 0:
        return None

    attn_mode = (prefill_attn or "causal").strip().lower()
    if attn_mode == "prefix_full":
        attn_mode = "full"
    elif attn_mode not in {"causal", "full"}:
        attn_mode = "causal"

    system_mode = (system_prefill or "no_bos_system").strip().lower()
    if system_mode not in {"full", "no_system", "no_bos_system"}:
        system_mode = "no_bos_system"

    if isinstance(input_ids, torch.Tensor):
        # Only a single 2-D prompt is supported.
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    else:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"

    segments = _build_structured_prompt_segments(
        tokenizer,
        prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    canonical_input_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    split_start = int(system_ids.size(1) + user_ids.size(1))
    system_len = int(system_ids.size(1))
    total_prompt_len = int(canonical_input_ids.size(1))

    # Token budgets: explicit *_new_tokens win over the *_length variants.
    if max_new_tokens is not None:
        budget = int(max_new_tokens)
    elif max_length is not None:
        budget = max(0, int(max_length) - total_prompt_len)
    else:
        budget = 256
    if min_new_tokens is not None:
        min_required = int(min_new_tokens)
    elif min_length is not None:
        min_required = max(0, int(min_length) - total_prompt_len)
    else:
        min_required = 0

    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        # A non-zero temperature implies sampling when the flag is unset.
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)

    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tokenizer_eos = getattr(tokenizer, "eos_token_id", None)
        if tokenizer_eos is not None:
            stop_ids.add(int(tokenizer_eos))
        # NOTE(review): the reviewed text had lost its indentation; the
        # "<|eot_id|>" lookup is assumed to be part of this fallback branch.
        # Confirm against the canonical file.
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)

    generated = torch.empty((1, int(budget)), dtype=torch.long, device=canonical_input_ids.device)
    score_list: list[torch.Tensor] = []
    keep_scores = bool(output_scores)
    produced = 0

    # Flags that stay the same for every forward pass.
    forward_flags = dict(
        use_cache=False,
        logits_to_keep=1,
        prefill_lower_layers=int(num_lower),
        prefill_lower_attn=str(attn_mode),
        prefill_lower_split_start=int(split_start),
        prefill_lower_system_len=int(system_len),
        prefill_lower_system_prefill=str(system_mode),
        prefill_lower_solo_attention=bool(not solo_v2),
        prefill_lower_solo_attention_v2=bool(solo_v2),
        prefill_lower_solo_attention_v2_with_bos=bool(solo_v2 and with_bos),
    )

    while produced < int(budget):
        step_ids = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
        step_mask = torch.ones_like(step_ids, device=step_ids.device)
        out = model(input_ids=step_ids, attention_mask=step_mask, **forward_flags)
        if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor):
            return None
        logits = out.logits[:, -1, :].to(torch.float32)

        # Forbid stop tokens until the minimum length has been produced.
        if stop_ids and produced < int(min_required):
            for stop_id in stop_ids:
                logits[:, stop_id] = -float("inf")
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :produced], logits)
        if keep_scores:
            score_list.append(logits.detach().clone())

        if do_sample:
            next_tok = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        generated[:, produced : produced + 1] = next_tok
        produced += 1

        sequences_now = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
        tok_id = int(next_tok.item())
        should_stop = tok_id in stop_ids and produced >= int(min_required)
        if (not should_stop) and stopping_criteria is not None:
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                # Retry without scores for criteria that reject the logits argument.
                should_stop = bool(stopping_criteria(sequences_now, None))
        if should_stop:
            break

    sequences = torch.cat([canonical_input_ids, generated[:, :produced]], dim=1)
    if not bool(return_dict_in_generate):
        return sequences
    return {
        "sequences": sequences,
        "scores": tuple(score_list) if keep_scores else tuple(),
    }
|
|
|
|
| def _attach_runtime_llopa_generate( |
| model, |
| *, |
| header_ids: torch.Tensor, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| no_upper_attn: bool = False, |
| ) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| if ( |
| lower_k <= 0 |
| or attn not in {"causal", "full"} |
| or not isinstance(header_ids, torch.Tensor) |
| or header_ids.numel() == 0 |
| ): |
| return |
|
|
| try: |
| setattr(model, "_runtime_llopa_header_ids", header_ids.detach().to(device="cpu", dtype=torch.long)) |
| setattr(model, "_runtime_llopa_layers", int(lower_k)) |
| setattr(model, "_runtime_llopa_attn", attn) |
| setattr(model, "_runtime_llopa_no_upper_attn", bool(no_upper_attn)) |
| except Exception: |
| return |
|
|
| if getattr(model, "_runtime_llopa_generate_attached", False): |
| return |
|
|
| orig_prepare = getattr(model, "prepare_inputs_for_generation", None) |
| if orig_prepare is None: |
| return |
|
|
| def _runtime_prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| attention_mask=None, |
| inputs_embeds=None, |
| cache_position=None, |
| runtime_llopa_prefill=None, |
| runtime_llopa_layers=None, |
| runtime_llopa_attn=None, |
| runtime_llopa_no_upper_attn=None, |
| **kwargs, |
| ): |
| model_inputs = orig_prepare( |
| input_ids, |
| past_key_values=past_key_values, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
| if runtime_llopa_prefill is None: |
| runtime_llopa_prefill = True |
| if runtime_llopa_layers is None: |
| runtime_llopa_layers = int(getattr(self, "_runtime_llopa_layers", 0) or 0) |
| if runtime_llopa_attn is None: |
| runtime_llopa_attn = str(getattr(self, "_runtime_llopa_attn", "causal") or "causal") |
| if runtime_llopa_no_upper_attn is None: |
| runtime_llopa_no_upper_attn = bool(getattr(self, "_runtime_llopa_no_upper_attn", False)) |
| if bool(runtime_llopa_prefill) and int(runtime_llopa_layers or 0) > 0: |
| model_inputs["runtime_llopa_prefill"] = True |
| model_inputs["runtime_llopa_layers"] = int(runtime_llopa_layers) |
| model_inputs["runtime_llopa_attn"] = str(runtime_llopa_attn) |
| if bool(runtime_llopa_no_upper_attn): |
| model_inputs["runtime_llopa_no_upper_attn"] = True |
| return model_inputs |
|
|
| try: |
| model.prepare_inputs_for_generation = types.MethodType( |
| _runtime_prepare_inputs_for_generation, |
| model, |
| ) |
| setattr(model, "_runtime_llopa_generate_attached", True) |
| except Exception: |
| pass |
|
|
|
|
@torch.inference_mode()
def _runtime_llopa_fast_generate_mode(
    model,
    input_ids: torch.LongTensor,
    logits_processor,
    stopping_criteria,
    generation_config,
    synced_gpus: bool = False,
    streamer=None,
    **model_kwargs,
):
    """Greedy / multinomial decoding loop driven by the model's own generation hooks.

    Each step builds inputs with ``model.prepare_inputs_for_generation``, runs
    the model, updates ``model_kwargs`` through
    ``model._update_model_kwargs_for_generation`` and appends one token per
    sequence. NOTE(review): the structure looks modeled on the sampling loop
    of ``transformers``' ``GenerationMixin`` -- confirm when upgrading the
    library.

    Returns:
        ``None`` for encoder-decoder models; otherwise the extended
        ``input_ids`` tensor, or a ``GenerateDecoderOnlyOutput`` when
        ``generation_config.return_dict_in_generate`` is set.
    """
    from transformers.generation.utils import GenerateDecoderOnlyOutput

    if model.config.is_encoder_decoder:
        return None

    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criterion, "eos_token_id") for criterion in stopping_criteria)
    do_sample = generation_config.do_sample

    # Optional per-step records are only collected for dict-style output.
    want_scores = return_dict_in_generate and output_scores
    want_logits = return_dict_in_generate and output_logits
    want_attentions = return_dict_in_generate and output_attentions
    want_hidden = return_dict_in_generate and output_hidden_states

    scores = () if want_scores else None
    raw_logits = () if want_logits else None
    decoder_attentions = () if want_attentions else None
    decoder_hidden_states = () if want_hidden else None

    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        step_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**step_inputs, return_dict=True)

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=False,
        )
        # A finished peer still runs the forward pass above but emits nothing.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :].to(dtype=torch.float32)
        # Snapshot the raw logits before the processors get to touch them.
        raw_next_token_logits = next_token_logits.clone() if want_logits else None
        next_token_scores = logits_processor(input_ids, next_token_logits)

        if want_scores:
            scores += (next_token_scores,)
        if want_logits:
            raw_logits += (raw_next_token_logits,)
        if want_attentions:
            decoder_attentions += (outputs.attentions,)
        if want_hidden:
            decoder_hidden_states += (outputs.hidden_states,)

        if do_sample:
            step_probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(step_probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Sequences that already finished keep receiving the pad token.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # Drop the reference to this step's outputs before the next forward pass.
        del outputs

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids
    return GenerateDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        logits=raw_logits,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=model_kwargs.get("past_key_values"),
    )
|
|
|
|
| @torch.inference_mode() |
| def _llopa_v2_generation_mixin_decode_loop( |
| model, |
| *, |
| canonical_input_ids: torch.LongTensor, |
| attention_mask: Optional[torch.Tensor], |
| past_key_values, |
| initial_logits: torch.Tensor, |
| lower_k: int, |
| no_upper_attn: bool, |
| replay_module: str, |
| replay_per_layers: int, |
| max_new_tokens: int, |
| min_new_tokens: int, |
| do_sample: bool, |
| logits_warpers, |
| stop_ids: set[int], |
| stop_token_ids: Optional[torch.Tensor], |
| stopping_criteria, |
| output_scores: bool, |
| compact_scores: bool, |
| return_dict_in_generate: bool, |
| ): |
| if not isinstance(canonical_input_ids, torch.Tensor) or canonical_input_ids.dim() != 2: |
| return None |
| if canonical_input_ids.size(0) != 1: |
| return None |
| if not isinstance(initial_logits, torch.Tensor): |
| return None |
| if int(max_new_tokens) <= 0: |
| return canonical_input_ids if not bool(return_dict_in_generate) else None |
|
|
| from transformers.generation.utils import GenerateDecoderOnlyOutput |
|
|
| device = canonical_input_ids.device |
| total_prompt_len = int(canonical_input_ids.size(1)) |
| max_new_tokens = int(max_new_tokens) |
| min_new_tokens = int(min_new_tokens) |
| sequences = torch.empty( |
| (canonical_input_ids.size(0), total_prompt_len + max_new_tokens), |
| dtype=canonical_input_ids.dtype, |
| device=device, |
| ) |
| sequences[:, :total_prompt_len] = canonical_input_ids |
| generated = sequences[:, total_prompt_len:] |
|
|
| if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2: |
| prompt_attention_mask = attention_mask.to(device=device) |
| if prompt_attention_mask.size(0) != canonical_input_ids.size(0): |
| prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device) |
| elif prompt_attention_mask.size(1) != total_prompt_len: |
| prompt_attention_mask = prompt_attention_mask[:, -total_prompt_len:] |
| else: |
| prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device) |
|
|
| model_kwargs = { |
| "attention_mask": prompt_attention_mask, |
| "past_key_values": past_key_values, |
| "use_cache": True, |
| "llopa_v2_decode": True, |
| "llopa_v2_decode_layers": int(lower_k), |
| "llopa_v2_decode_no_upper_attn": bool(no_upper_attn), |
| "llopa_v2_decode_replay_module": str(replay_module), |
| "llopa_v2_decode_replay_per_layers": int(replay_per_layers), |
| } |
|
|
| record_scores = bool(output_scores) |
| record_compact_scores = record_scores and bool(compact_scores) |
| score_list: list[torch.Tensor] = [] |
| score_list_append = score_list.append |
| compact_logprob_list: list[torch.Tensor] = [] |
| compact_logprob_append = compact_logprob_list.append |
| cur = 0 |
| pending_logits = initial_logits |
| should_apply_stopping_criteria = stopping_criteria is not None |
|
|
| while cur < max_new_tokens: |
| outputs = None |
| used_prefill_logits = pending_logits is not None |
| if used_prefill_logits: |
| logits = pending_logits.to(dtype=torch.float32, device=device, copy=True) |
| pending_logits = None |
| else: |
| current_input_ids = sequences[:, : total_prompt_len + cur] |
| if not isinstance(model_kwargs.get("cache_position"), torch.Tensor): |
| model_kwargs["cache_position"] = torch.arange( |
| total_prompt_len + cur - 1, |
| total_prompt_len + cur, |
| device=device, |
| dtype=torch.long, |
| ) |
| model_inputs = model.prepare_inputs_for_generation(current_input_ids, **model_kwargs) |
| outputs = model(**model_inputs, return_dict=True) |
| model_kwargs = model._update_model_kwargs_for_generation( |
| outputs, |
| model_kwargs, |
| is_encoder_decoder=False, |
| ) |
| logits = outputs.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True) |
|
|
| if stop_token_ids is not None and cur < min_new_tokens: |
| logits.index_fill_(1, stop_token_ids, -float("inf")) |
| if logits_warpers is not None: |
| logits = logits_warpers(generated[:, :cur], logits) |
| if record_scores and not record_compact_scores: |
| score_list_append(logits.detach().clone()) |
|
|
| if do_sample: |
| probs = torch.softmax(logits, dim=-1) |
| next_tok = torch.multinomial(probs, num_samples=1) |
| else: |
| next_tok = torch.argmax(logits, dim=-1, keepdim=True) |
| if record_compact_scores: |
| next_logit = torch.gather(logits, 1, next_tok) |
| next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) |
| compact_logprob_append(next_logprob.squeeze(-1).detach()) |
|
|
| generated[:, cur : cur + 1] = next_tok |
| cur += 1 |
|
|
| if used_prefill_logits: |
| one_mask = torch.ones( |
| (prompt_attention_mask.size(0), 1), |
| dtype=prompt_attention_mask.dtype, |
| device=device, |
| ) |
| model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], one_mask], dim=-1) |
| model_kwargs["cache_position"] = torch.arange( |
| total_prompt_len, |
| total_prompt_len + 1, |
| device=device, |
| dtype=torch.long, |
| ) |
|
|
| should_stop = False |
| tok_id = int(next_tok.item()) |
| if tok_id in stop_ids and cur >= min_new_tokens: |
| should_stop = True |
| if (not should_stop) and should_apply_stopping_criteria: |
| sequences_now = sequences[:, : total_prompt_len + cur] |
| try: |
| should_stop = bool(stopping_criteria(sequences_now, logits)) |
| except TypeError: |
| should_stop = bool(stopping_criteria(sequences_now, None)) |
| if outputs is not None: |
| del outputs |
| if should_stop: |
| break |
|
|
| sequences = sequences[:, : total_prompt_len + cur] |
| if not bool(return_dict_in_generate): |
| return sequences |
|
|
| output = GenerateDecoderOnlyOutput( |
| sequences=sequences, |
| scores=tuple(score_list) if record_scores and not record_compact_scores else None, |
| past_key_values=model_kwargs.get("past_key_values", past_key_values), |
| ) |
| if record_compact_scores: |
| if compact_logprob_list: |
| compact_logprobs = torch.stack(compact_logprob_list, dim=1) |
| else: |
| compact_logprobs = torch.empty( |
| (sequences.size(0), 0), |
| dtype=torch.float32, |
| device=sequences.device, |
| ) |
| setattr(output, "generated_token_logprobs", compact_logprobs) |
| return output |
|
|
|
|
| def _attach_runtime_llopa_fast_generate( |
| model, |
| *, |
| lower_k: int, |
| prefill_attn: str = "causal", |
| no_upper_attn: bool = False, |
| ) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| if lower_k <= 0 or attn not in {"causal", "full"}: |
| return |
|
|
| try: |
| setattr(model, "_runtime_llopa_fast_layers", int(lower_k)) |
| setattr(model, "_runtime_llopa_fast_attn", attn) |
| setattr(model, "_runtime_llopa_fast_no_upper_attn", bool(no_upper_attn)) |
| except Exception: |
| return |
|
|
| if getattr(model, "_runtime_llopa_fast_generate_attached", False): |
| return |
|
|
| orig_generate = getattr(model, "generate", None) |
| if not callable(orig_generate): |
| return |
|
|
| def _runtime_llopa_fast_generate(self, *args, **kwargs): |
| if bool(getattr(self.config, "is_encoder_decoder", False)): |
| return orig_generate(*args, **kwargs) |
|
|
| if kwargs.get("custom_generate") is not None or kwargs.get("assistant_model") is not None: |
| return orig_generate(*args, **kwargs) |
|
|
| if kwargs.get("inputs_embeds") is not None: |
| return orig_generate(*args, **kwargs) |
|
|
| num_beams = kwargs.get("num_beams", None) |
| if num_beams is None: |
| gen_cfg = kwargs.get("generation_config") |
| num_beams = getattr(gen_cfg, "num_beams", 1) if gen_cfg is not None else 1 |
| if int(num_beams or 1) != 1: |
| _warn_once( |
| self, |
| "_warned_runtime_llopa_fast_num_beams", |
| "[load_llopa_model][warn] runtime_llopa_fast_generate currently supports only num_beams=1; falling back to model.generate().", |
| ) |
| return orig_generate(*args, **kwargs) |
|
|
| fast_enabled = kwargs.pop("runtime_llopa_fast_generate", None) |
| if fast_enabled is None: |
| fast_enabled = True |
| if not bool(fast_enabled): |
| return orig_generate(*args, **kwargs) |
|
|
| runtime_enabled = kwargs.get("runtime_llopa_prefill", None) |
| if runtime_enabled is None: |
| runtime_enabled = True |
| if not bool(runtime_enabled): |
| return orig_generate(*args, **kwargs) |
|
|
| if kwargs.get("runtime_llopa_layers") is None: |
| kwargs["runtime_llopa_layers"] = int(getattr(self, "_runtime_llopa_fast_layers", 0) or 0) |
| if kwargs.get("runtime_llopa_attn") is None: |
| kwargs["runtime_llopa_attn"] = str(getattr(self, "_runtime_llopa_fast_attn", "causal") or "causal") |
| if kwargs.get("runtime_llopa_no_upper_attn") is None: |
| kwargs["runtime_llopa_no_upper_attn"] = bool( |
| getattr(self, "_runtime_llopa_fast_no_upper_attn", False) |
| ) |
| kwargs["runtime_llopa_prefill"] = True |
|
|
| return orig_generate( |
| *args, |
| custom_generate=_runtime_llopa_fast_generate_mode, |
| **kwargs, |
| ) |
|
|
| try: |
| model.generate = types.MethodType(_runtime_llopa_fast_generate, model) |
| setattr(model, "_runtime_llopa_fast_generate_attached", True) |
| except Exception: |
| pass |
|
|
|
|
def _supports_prefill_lower_runtime(model) -> bool:
    """Report whether ``model`` exposes the prefill-lower runtime hooks.

    True only when the model has ``tri_vanilla_prefill_decode_forward`` and the
    inner model resolved by ``_get_inner_model`` has
    ``tri_prefill_lower_cache``. A failure to resolve the inner model counts
    as "not supported".
    """
    try:
        core = _get_inner_model(model)
    except Exception:
        core = None

    if not hasattr(model, "tri_vanilla_prefill_decode_forward"):
        return False
    if core is None:
        return False
    return hasattr(core, "tri_prefill_lower_cache")
|
|
|
|
def _supports_runtime_llopa_prompt_prefill(model) -> bool:
    """Report whether ``model`` can run the runtime LLoPA prompt prefill.

    Requires, in this order: ``tri_runtime_llopa_prompt_prefill_forward`` on
    the model, a non-``None`` result from ``_get_llopa_decode_step``, and an
    inner model that has ``llopa_prefill_cache``. A failure to resolve the
    inner model counts as "not supported".
    """
    try:
        core = _get_inner_model(model)
    except Exception:
        core = None

    if not hasattr(model, "tri_runtime_llopa_prompt_prefill_forward"):
        return False
    if _get_llopa_decode_step(model) is None:
        return False
    if core is None:
        return False
    return hasattr(core, "llopa_prefill_cache")
|
|
|
|
def _supports_direct_llopa_generate(model) -> bool:
    """Report whether ``model`` can be driven by the direct LLoPA generate path.

    Requires a non-``None`` result from ``_get_llopa_decode_step`` and an inner
    model that has ``llopa_prefill_cache``. A failure to resolve the inner
    model counts as "not supported".
    """
    try:
        core = _get_inner_model(model)
    except Exception:
        core = None

    if _get_llopa_decode_step(model) is None:
        return False
    if core is None:
        return False
    return hasattr(core, "llopa_prefill_cache")
|
|
|
|
| def _warn_once(model, flag: str, msg: str) -> None: |
| if getattr(model, flag, False): |
| return |
| try: |
| setattr(model, flag, True) |
| except Exception: |
| pass |
| print(msg) |
|
|
|
|
| def _normalize_eos_token_ids(eos_token_id) -> list[int]: |
| if eos_token_id is None: |
| return [] |
| if isinstance(eos_token_id, torch.Tensor): |
| eos_token_id = eos_token_id.flatten().tolist() |
| if isinstance(eos_token_id, (list, tuple, set)): |
| out: list[int] = [] |
| for item in eos_token_id: |
| try: |
| out.append(int(item)) |
| except Exception: |
| continue |
| return out |
| try: |
| return [int(eos_token_id)] |
| except Exception: |
| return [] |
|
|
|
|
| def _prepare_stop_token_tensor(stop_ids, device) -> Optional[torch.LongTensor]: |
| if not stop_ids: |
| return None |
| try: |
| ordered = [int(tok_id) for tok_id in stop_ids] |
| except Exception: |
| return None |
| if not ordered: |
| return None |
| return torch.tensor(ordered, device=device, dtype=torch.long) |
|
|
|
|
| def _find_last_subsequence_start(input_ids: torch.Tensor, pattern: torch.Tensor) -> Optional[int]: |
| if not isinstance(input_ids, torch.Tensor) or not isinstance(pattern, torch.Tensor): |
| return None |
| if input_ids.dim() == 2: |
| if input_ids.size(0) != 1: |
| return None |
| seq = input_ids[0] |
| elif input_ids.dim() == 1: |
| seq = input_ids |
| else: |
| return None |
| if pattern.dim() == 2: |
| if pattern.size(0) != 1: |
| return None |
| pat = pattern[0] |
| elif pattern.dim() == 1: |
| pat = pattern |
| else: |
| return None |
|
|
| pat_len = int(pat.numel()) |
| seq_len = int(seq.numel()) |
| if pat_len <= 0 or seq_len < pat_len: |
| return None |
|
|
| windows = seq.unfold(0, pat_len, 1) |
| matches = (windows == pat).all(dim=-1).nonzero(as_tuple=False) |
| if matches.numel() == 0: |
| return None |
| return int(matches[-1].item()) |
|
|
|
|
| def _build_sampling_warpers(do_sample: bool, temperature, top_p, top_k): |
| if not do_sample: |
| return None |
| from transformers.generation import LogitsProcessorList |
| from transformers.generation.logits_process import ( |
| TemperatureLogitsWarper, |
| TopKLogitsWarper, |
| TopPLogitsWarper, |
| ) |
|
|
| procs = LogitsProcessorList() |
| temp = float(temperature) |
| if temp != 1.0: |
| procs.append(TemperatureLogitsWarper(temp)) |
| if top_p is not None and float(top_p) < 1.0: |
| procs.append(TopPLogitsWarper(float(top_p), min_tokens_to_keep=1)) |
| if top_k is not None and int(top_k) > 0: |
| procs.append(TopKLogitsWarper(int(top_k), filter_value=-float("inf"))) |
| return procs |
|
|
|
|
@torch.inference_mode()
def _legacy_direct_llopa_generate_impl(
    model,
    tokenizer,
    *,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Generate tokens for one prompt row using the LLoPA prefill cache and decode step.

    The (unpadded) prompt is split at the last occurrence of
    ``model._direct_llopa_header_ids``. Tokens before the split are handed to
    ``llopa_prefill_cache`` as ``user_ids`` (with an empty ``system_ids``), the
    rest as ``assistant_ids``. Decoding then proceeds one token at a time,
    greedy or sampled, until a stop id is produced, ``stopping_criteria``
    fires, or ``max_new_tokens`` tokens have been generated.

    Returns:
        ``None`` whenever a precondition fails (``input_ids`` is not a
        single-row 2-D tensor, ``use_cache is False``, bad attention mask,
        header not found, non-positive ``lower_k``, missing LLoPA hooks) --
        presumably so the caller can fall back to another path; confirm
        against the caller. Otherwise the full sequence tensor (prompt plus
        generated tokens), or a ``GenerateDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is truthy.

    Notes:
        ``pad_token_id`` is accepted but not used in this function.
        ``max_new_tokens`` defaults to 256 when neither it nor ``max_length``
        is given.
    """
    if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) != 1:
        return None
    if use_cache is False:
        return None

    valid_len = int(input_ids.size(1))
    if attention_mask is not None:
        if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0):
            return None
        valid_len = int(attention_mask[0].sum().item())
    if valid_len <= 0:
        return None

    # Keeps the LAST valid_len tokens -- assumes left padding; TODO confirm.
    trimmed_input_ids = input_ids[:, -valid_len:]
    header_ids = getattr(model, "_direct_llopa_header_ids", None)
    if not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0:
        return None
    hdr = header_ids.to(device=trimmed_input_ids.device, dtype=trimmed_input_ids.dtype)
    assistant_start = _find_last_subsequence_start(trimmed_input_ids, hdr)
    if assistant_start is None:
        return None
    # Everything before the last header is the prefix; the header onwards is the assistant part.
    prefix_ids = trimmed_input_ids[:, :assistant_start]
    assistant_ids = trimmed_input_ids[:, assistant_start:]
    if assistant_ids.numel() == 0:
        return None
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None

    # Unknown attention modes silently fall back to "causal".
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"

    # Length budgets are measured against the full (possibly padded) input width.
    total_prompt_len = int(input_ids.size(1))
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)

    # With do_sample unset, sampling is enabled iff a non-zero temperature was given.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)

    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_fn = getattr(llopa_core, "llopa_prefill_cache", None)
    if llopa_fn is None:
        return None

    output_head = _get_output_head(model)
    prefill_out = llopa_fn(
        system_ids=prefix_ids[:, :0],
        user_ids=prefix_ids,
        assistant_ids=assistant_ids,
        lower_k=lower_k,
        prefill_mode="lower",
        prefill_attn=attn,
        return_last_assistant_hidden=bool(output_head is not None),
    )
    initial_logits = None
    # A tuple result is unpacked as (cache, last hidden state); otherwise the result is the cache.
    if isinstance(prefill_out, tuple):
        pkv, last_hidden = prefill_out
        if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
            initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)
    else:
        pkv = prefill_out

    S = 0
    U = int(prefix_ids.size(1))
    last = assistant_ids[:, -1:]
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    # Fall back to the tokenizer's EOS only when the caller gave no usable stop id.
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    total_prompt_len = int(input_ids.size(1))
    record_scores = bool(output_scores)
    should_apply_stopping_criteria = stopping_criteria is not None

    # Preallocate prompt + max_new_tokens; `generated` is a view onto the tail of `sequences`.
    sequences = torch.empty(
        (input_ids.size(0), total_prompt_len + max_new_tokens),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    sequences[:, :total_prompt_len] = input_ids
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    cur = 0
    # Logits obtained from the prefill (if any) serve the first step without a decode call.
    pending_logits = initial_logits

    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            out = llopa_step(
                assistant_ids=last,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=1,
                labels=None,
                prefill_mode="lower",
            )
            pkv = out.past_key_values or pkv
            logits = out.logits[:, -1, :].to(dtype=torch.float32, device=last.device, copy=True)
        else:
            logits = pending_logits
            pending_logits = None

        # Stop ids are masked out until min_new_tokens tokens have been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores:
            score_list_append(logits.detach().clone())

        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        generated[:, cur : cur + 1] = next_tok
        cur += 1

        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry with scores=None for criteria that raise TypeError on the first call.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok

    # Trim the unused tail of the preallocated buffer.
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput

    return GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) else None,
        past_key_values=pkv,
    )
|
|
|
|
@torch.inference_mode()
def _direct_llopa_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    user_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    compact_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Generate tokens for one structured prompt via the direct LLoPA path.

    The prompt is taken from ``structured_prompt_segments`` when that is a
    dict, otherwise it is built from ``prompt_messages``. The cache is seeded
    by ``_direct_prefill_lower_cache_and_logits`` when it returns a result,
    else by ``_llopa_prefill_cache``. Decoding then runs one token at a time
    (greedy or sampled) until a stop id is produced, ``stopping_criteria``
    fires, or ``max_new_tokens`` tokens have been generated. When mixin decode
    is enabled and prefill logits are available, the loop is delegated to
    ``_llopa_v2_generation_mixin_decode_loop`` first.

    Returns:
        ``None`` when a precondition fails (``use_cache is False``,
        non-positive ``lower_k``, missing LLoPA hooks, ``input_ids`` that is
        not a single-row 2-D tensor, empty assistant prefill) -- presumably
        so the caller can fall back; confirm against the caller. Otherwise
        the full sequence tensor, or a ``GenerateDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is truthy. With ``output_scores`` and
        ``compact_scores`` both truthy, the output carries a
        ``generated_token_logprobs`` attribute instead of ``scores``.

    Raises:
        ValueError: if ``user_prefill`` is anything other than ``"full"``.

    Notes:
        ``pad_token_id`` is accepted but not used in this function.
        ``last_layer_module`` is used as ``replay_module`` only when
        ``replay_module`` normalizes to ``"none"``.
    """
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    if use_cache is False:
        return None

    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None

    # Unknown attention / system-prefill modes silently fall back to their defaults.
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"

    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"

    user_prefill_norm = (user_prefill or "full").strip().lower()
    if user_prefill_norm != "full":
        raise ValueError("Unified direct LLoPA currently supports only user_prefill='full'.")

    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    # The direct decode step needs the env flag (default "1"), the core hook and an output head.
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )

    # Device: from input_ids if given, else from the model's first parameter, else "cpu".
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"

    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    assistant_prefill_ids = segments["assistant_prefill_ids"]
    replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_user_start = int(segments.get("replay_user_start", 0) or 0)
    replay_user_len = int(segments.get("replay_user_len", 0) or 0)
    if assistant_prefill_ids.numel() == 0:
        return None

    # With do_sample unset, sampling is enabled iff a non-zero temperature was given.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)

    initial_logits = None
    prompt_bundle = _build_unified_prefill_lower_prompt_bundle(
        tokenizer,
        prompt_messages=prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        structured_prompt_segments=segments,
        device=device,
    )
    reference_seed = _direct_prefill_lower_cache_and_logits(
        model,
        prompt_bundle=prompt_bundle,
        lower_k=lower_k,
        prefill_attn=attn,
        system_prefill=sys_prefill,
        no_upper_attn=bool(no_upper_attn),
        see_past_assistant=bool(see_past_assistant),
        replay_module=str(replay_module),
        replay_per_layers=int(replay_per_layers),
        seed_mode=str(seed_mode or "auto"),
    )
    # Prompt ids echoed in the output, in order of preference: the caller's input_ids
    # (last valid_len tokens -- assumes left padding; TODO confirm), then the bundle's
    # "effective_prompt_ids", then the segment's prompt_ids.
    canonical_input_ids = None
    if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1:
        valid_len = int(input_ids.size(1))
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1:
            valid_len = int(attention_mask[0].sum().item())
        if valid_len > 0:
            canonical_input_ids = input_ids[:, -valid_len:]
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_bundle.get("effective_prompt_ids")
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_ids
    total_prompt_len = int(canonical_input_ids.size(1))
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    if reference_seed is not None:
        pkv, S, U, initial_logits = reference_seed
    else:
        output_head = _get_output_head(model)
        # With no_upper_attn no hidden state is requested, so initial_logits stays None
        # and the first token comes from a regular decode step.
        if bool(no_upper_attn):
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
        else:
            pkv, S, U, last_hidden = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=bool(output_head is not None),
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
            if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
                initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)

    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    # Defaults apply only when the caller gave no usable stop id.
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
        # NOTE(review): the nesting of this `<|eot_id|>` block under `if not stop_ids`
        # was inferred from a copy of the file without indentation -- confirm.
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))

    last = assistant_prefill_ids[:, -1:]
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    # Normalize loop invariants once, outside the decode loop.
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    lower_k = int(lower_k)
    no_upper_attn_bool = bool(no_upper_attn)
    replay_module_str = str(replay_module)
    replay_per_layers_int = int(replay_per_layers)
    record_scores = bool(output_scores)
    record_compact_scores = bool(output_scores) and bool(compact_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    # A model-level attribute, when set, overrides the env flag (default "0").
    mixin_decode_default = getattr(model, "_llopa_v2_generation_mixin_decode", None)
    use_mixin_decode = (
        bool(mixin_decode_default)
        if mixin_decode_default is not None
        else _env_flag_enabled("CAPSULE_LLOPA_GENERATION_MIXIN_DECODE", "0")
    )
    if use_mixin_decode and isinstance(initial_logits, torch.Tensor):
        mixin_output = _llopa_v2_generation_mixin_decode_loop(
            model,
            canonical_input_ids=canonical_input_ids,
            attention_mask=attention_mask,
            past_key_values=pkv,
            initial_logits=initial_logits,
            lower_k=lower_k,
            no_upper_attn=no_upper_attn_bool,
            replay_module=replay_module_str,
            replay_per_layers=replay_per_layers_int,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            do_sample=do_sample,
            logits_warpers=logits_warpers,
            stop_ids=stop_ids,
            stop_token_ids=stop_token_ids,
            stopping_criteria=stopping_criteria,
            output_scores=output_scores,
            compact_scores=compact_scores,
            return_dict_in_generate=return_dict_in_generate,
        )
        # A None result means the mixin loop declined; continue with the loop below.
        if mixin_output is not None:
            return mixin_output
    # Preallocate prompt + max_new_tokens; `generated` is a view onto the tail of `sequences`.
    sequences = torch.empty(
        (canonical_input_ids.size(0), total_prompt_len + max_new_tokens),
        dtype=canonical_input_ids.dtype,
        device=canonical_input_ids.device,
    )
    sequences[:, :total_prompt_len] = canonical_input_ids
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    cur = 0
    # Logits obtained from the prefill (if any) serve the first step without a decode call.
    pending_logits = initial_logits

    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = out.logits[:, -1, :]
            # copy=True gives a private float32 tensor that is safe to modify in place below.
            logits = logits.to(dtype=torch.float32, device=last.device, copy=True)
        else:
            logits = pending_logits
            pending_logits = None

        # Stop ids are masked out until min_new_tokens tokens have been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())

        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        # Compact mode keeps only the chosen token's log-probability, computed from
        # the logits after stop-id masking and warping.
        if record_compact_scores:
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())

        generated[:, cur : cur + 1] = next_tok
        cur += 1

        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry with scores=None for criteria that raise TypeError on the first call.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok

    # Trim the unused tail of the preallocated buffer.
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput

    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        # One entry per generated step, stacked along dim 1.
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
|
|
|
|
| def _is_prompt_messages_batch(prompt_messages) -> bool: |
| if not isinstance(prompt_messages, (list, tuple)) or not prompt_messages: |
| return False |
| if all(isinstance(item, dict) for item in prompt_messages): |
| return False |
| return all(isinstance(item, (list, tuple)) for item in prompt_messages) |
|
|
|
|
| def _is_structured_segments_batch(structured_prompt_segments) -> bool: |
| return ( |
| isinstance(structured_prompt_segments, (list, tuple)) |
| and not isinstance(structured_prompt_segments, dict) |
| and all(isinstance(item, dict) for item in structured_prompt_segments) |
| ) |
|
|
|
|
| def _batch_prompt_add_generation_flags(value, batch_size: int) -> list[bool]: |
| if isinstance(value, (list, tuple)): |
| if len(value) != int(batch_size): |
| raise ValueError( |
| f"prompt_add_generation_prompt batch size mismatch: {len(value)} != {int(batch_size)}" |
| ) |
| return [bool(item) for item in value] |
| return [bool(value) for _ in range(int(batch_size))] |
|
|
|
|
| def _as_1d_long_tensor(value: torch.Tensor, *, device) -> torch.LongTensor: |
| value = value.to(device=device, dtype=torch.long) |
| if value.dim() == 2: |
| if value.size(0) != 1: |
| raise ValueError("Expected a single-row prompt segment tensor.") |
| value = value[0] |
| elif value.dim() != 1: |
| value = value.reshape(-1) |
| return value |
|
|
|
|
| def _pad_1d_rows( |
| rows: list[torch.Tensor], |
| *, |
| pad_value: int, |
| device, |
| dtype: torch.dtype = torch.long, |
| ) -> torch.Tensor: |
| batch_size = len(rows) |
| max_len = max((int(row.numel()) for row in rows), default=0) |
| out = torch.full((batch_size, max_len), int(pad_value), device=device, dtype=dtype) |
| for row_idx, row in enumerate(rows): |
| row = row.to(device=device, dtype=dtype).reshape(-1) |
| width = int(row.numel()) |
| if width > 0: |
| out[row_idx, :width] = row |
| return out |
|
|
|
|
| def _pad_prompt_metadata_rows( |
| rows: list[Optional[torch.Tensor]], |
| *, |
| fill_value: int, |
| device, |
| dtype: torch.dtype, |
| ) -> torch.Tensor: |
| batch_size = len(rows) |
| max_len = 0 |
| flat_rows: list[torch.Tensor] = [] |
| for row in rows: |
| if isinstance(row, torch.Tensor) and row.numel() > 0: |
| flat = row.to(device=device, dtype=dtype) |
| if flat.dim() == 2: |
| flat = flat[0] |
| else: |
| flat = flat.reshape(-1) |
| else: |
| flat = torch.empty((0,), device=device, dtype=dtype) |
| flat_rows.append(flat) |
| max_len = max(max_len, int(flat.numel())) |
| if max_len <= 0: |
| return torch.full((batch_size, 1), int(fill_value), device=device, dtype=dtype) |
| out = torch.full((batch_size, max_len), int(fill_value), device=device, dtype=dtype) |
| for row_idx, flat in enumerate(flat_rows): |
| width = int(flat.numel()) |
| if width > 0: |
| out[row_idx, :width] = flat |
| return out |
|
|
|
|
def _build_batched_structured_prompt_segments(
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt,
    structured_prompt_segments,
    device,
) -> list[dict]:
    """Return one structured-segment dict per batch row.

    Sources are tried in this order:

    1. ``structured_prompt_segments`` that is already a batch (copied into a
       new list) or a single dict (wrapped in a list);
    2. ``prompt_messages`` that is a batch of conversations: each row is
       built separately, and a row whose normalized messages end with an
       assistant turn is built without a generation prompt;
    3. ``prompt_messages`` as a single conversation;
    4. otherwise an empty list.

    Raises:
        ValueError: for a single conversation when
            ``prompt_add_generation_prompt`` is ``None``.
    """
    if _is_structured_segments_batch(structured_prompt_segments):
        return list(structured_prompt_segments)
    if isinstance(structured_prompt_segments, dict):
        return [structured_prompt_segments]

    if _is_prompt_messages_batch(prompt_messages):
        per_row_flags = _batch_prompt_add_generation_flags(
            prompt_add_generation_prompt,
            len(prompt_messages),
        )
        batch_segments = []
        for conversation, wants_generation_prompt in zip(prompt_messages, per_row_flags):
            row_messages = list(conversation)
            add_generation_prompt = bool(wants_generation_prompt)
            normalized = _normalize_prompt_messages(row_messages)
            if add_generation_prompt and normalized:
                # The conversation already ends with an assistant turn.
                if normalized[-1]["role"] == "assistant":
                    add_generation_prompt = False
            row_segments = _build_structured_prompt_segments(
                tokenizer,
                row_messages,
                prompt_add_generation_prompt=add_generation_prompt,
                device=device,
            )
            batch_segments.append(row_segments)
        return batch_segments

    if prompt_messages is None:
        return []

    if prompt_add_generation_prompt is None:
        raise ValueError("llopa_v2_batch_generate requires prompt_add_generation_prompt for prompt_messages.")
    single_segments = _build_structured_prompt_segments(
        tokenizer,
        prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    return [single_segments]
|
|
|
|
| def _batched_segments_to_tensors(segments: list[dict], *, device, pad_token_id: int) -> dict: |
| prompt_rows: list[torch.Tensor] = [] |
| prompt_lens: list[int] = [] |
| split_starts: list[int] = [] |
| system_lens: list[int] = [] |
| replay_prefix_keep_lens: list[int] = [] |
| replay_user_starts: list[int] = [] |
| replay_user_lens: list[int] = [] |
| header_start_rows: list[Optional[torch.Tensor]] = [] |
| turn_end_rows: list[Optional[torch.Tensor]] = [] |
| header_mask_rows: list[Optional[torch.Tensor]] = [] |
|
|
| for seg in segments: |
| prompt_ids = seg.get("prompt_ids") |
| assistant_prefill_ids = seg.get("assistant_prefill_ids") |
| if not isinstance(prompt_ids, torch.Tensor) or prompt_ids.numel() == 0: |
| raise ValueError("llopa_v2_batch_generate requires non-empty prompt_ids in each segment.") |
| if not isinstance(assistant_prefill_ids, torch.Tensor) or assistant_prefill_ids.numel() == 0: |
| raise ValueError("llopa_v2_batch_generate requires non-empty assistant_prefill_ids in each segment.") |
| prompt_row = _as_1d_long_tensor(prompt_ids, device=device) |
| assistant_row = _as_1d_long_tensor(assistant_prefill_ids, device=device) |
| prompt_len = int(prompt_row.numel()) |
| assistant_len = int(assistant_row.numel()) |
| split_start = int(prompt_len - assistant_len) |
| prefix_ids = seg.get("prefix_ids") |
| if isinstance(prefix_ids, torch.Tensor) and prefix_ids.numel() > 0: |
| split_start = int(_as_1d_long_tensor(prefix_ids, device=device).numel()) |
| split_start = max(0, min(split_start, prompt_len - 1)) |
|
|
| system_ids = seg.get("system_ids") |
| system_len = int(_as_1d_long_tensor(system_ids, device=device).numel()) if isinstance(system_ids, torch.Tensor) else 0 |
| system_len = max(0, min(system_len, split_start)) |
|
|
| prompt_rows.append(prompt_row) |
| prompt_lens.append(prompt_len) |
| split_starts.append(split_start) |
| system_lens.append(system_len) |
| replay_prefix_keep_lens.append(int(seg.get("replay_user_prefix_keep_len", 0) or 0)) |
| replay_user_starts.append(int(seg.get("replay_user_start", 0) or 0)) |
| replay_user_lens.append(int(seg.get("replay_user_len", 0) or 0)) |
|
|
| header_starts = seg.get("assistant_header_starts") |
| turn_ends = seg.get("assistant_turn_ends") |
| header_mask = seg.get("assistant_header_start_mask") |
| if not isinstance(header_starts, torch.Tensor) or header_starts.numel() == 0: |
| header_starts = torch.tensor([[split_start]], device=device, dtype=torch.long) |
| if not isinstance(turn_ends, torch.Tensor) or turn_ends.numel() == 0: |
| turn_ends = torch.tensor([[prompt_len]], device=device, dtype=torch.long) |
| if not isinstance(header_mask, torch.Tensor) or header_mask.numel() == 0: |
| header_mask = torch.ones_like(header_starts, device=device, dtype=torch.bool) |
| header_start_rows.append(header_starts) |
| turn_end_rows.append(turn_ends) |
| header_mask_rows.append(header_mask) |
|
|
| prompt_ids = _pad_1d_rows(prompt_rows, pad_value=int(pad_token_id), device=device, dtype=torch.long) |
| prompt_lens_tensor = torch.tensor(prompt_lens, device=device, dtype=torch.long) |
| prompt_attention_mask = ( |
| torch.arange(prompt_ids.size(1), device=device, dtype=torch.long).unsqueeze(0) |
| < prompt_lens_tensor.unsqueeze(1) |
| ).to(dtype=torch.long) |
| return { |
| "prompt_rows": prompt_rows, |
| "prompt_ids": prompt_ids, |
| "prompt_attention_mask": prompt_attention_mask, |
| "prompt_lens": prompt_lens_tensor, |
| "split_starts": torch.tensor(split_starts, device=device, dtype=torch.long), |
| "system_lens": torch.tensor(system_lens, device=device, dtype=torch.long), |
| "replay_user_prefix_keep_lens": torch.tensor(replay_prefix_keep_lens, device=device, dtype=torch.long), |
| "replay_user_starts": torch.tensor(replay_user_starts, device=device, dtype=torch.long), |
| "replay_user_lens": torch.tensor(replay_user_lens, device=device, dtype=torch.long), |
| "assistant_header_starts": _pad_prompt_metadata_rows( |
| header_start_rows, |
| fill_value=-1, |
| device=device, |
| dtype=torch.long, |
| ), |
| "assistant_turn_ends": _pad_prompt_metadata_rows( |
| turn_end_rows, |
| fill_value=-1, |
| device=device, |
| dtype=torch.long, |
| ), |
| "assistant_header_start_mask": _pad_prompt_metadata_rows( |
| header_mask_rows, |
| fill_value=0, |
| device=device, |
| dtype=torch.bool, |
| ).to(dtype=torch.bool), |
| } |
|
|
|
|
| def _cache_layer_seq_len(pkv, layer_idx: int) -> int: |
| try: |
| k, _ = pkv_get(pkv, int(layer_idx)) |
| except Exception: |
| return 0 |
| if not isinstance(k, torch.Tensor) or k.dim() < 3: |
| return 0 |
| return int(k.shape[-2]) |
|
|
|
|
| def _merge_optional_batch_sequence_attr(row_pkvs: list, attr_name: str, *, device, dtype=None): |
| rows = [] |
| trailing_shape = None |
| for pkv in row_pkvs: |
| value = getattr(pkv, attr_name, None) |
| if isinstance(value, torch.Tensor) and value.numel() > 0: |
| value = value.to(device=device) |
| if dtype is not None: |
| value = value.to(dtype=dtype) |
| if value.dim() == 1: |
| value = value.view(1, -1) |
| elif value.dim() >= 2 and value.size(0) != 1: |
| value = value[:1] |
| trailing_shape = tuple(value.shape[2:]) |
| else: |
| value = None |
| rows.append(value) |
| if trailing_shape is None: |
| return None |
| max_len = max((int(row.shape[1]) for row in rows if isinstance(row, torch.Tensor)), default=0) |
| if max_len <= 0: |
| return None |
| ref = next(row for row in rows if isinstance(row, torch.Tensor)) |
| out_shape = (len(rows), max_len) + trailing_shape |
| out = torch.zeros(out_shape, device=device, dtype=ref.dtype) |
| for row_idx, row in enumerate(rows): |
| if not isinstance(row, torch.Tensor): |
| continue |
| width = int(row.shape[1]) |
| if width > 0: |
| out[row_idx : row_idx + 1, :width, ...] = row[:, :width, ...] |
| return out |
|
|
|
|
| def _merge_llopa_batch_row_caches(row_pkvs: list, *, device) -> Optional[DynamicCache]: |
| if not row_pkvs: |
| return None |
| try: |
| n_layers = max(int(pkv_len(pkv)) for pkv in row_pkvs) |
| except Exception: |
| return None |
| if n_layers <= 0: |
| return None |
|
|
| merged_pairs = [] |
| layer_valid_masks: list[torch.Tensor] = [] |
| for layer_idx in range(n_layers): |
| row_kvs = [] |
| max_len = 0 |
| ref_k = None |
| ref_v = None |
| for pkv in row_pkvs: |
| try: |
| k, v = pkv_get(pkv, layer_idx) |
| except Exception: |
| k = None |
| v = None |
| if isinstance(k, torch.Tensor) and isinstance(v, torch.Tensor) and k.dim() == 4 and v.dim() == 4: |
| k = k.to(device=device) |
| v = v.to(device=device) |
| if k.size(0) != 1: |
| k = k[:1] |
| v = v[:1] |
| ref_k = k if ref_k is None else ref_k |
| ref_v = v if ref_v is None else ref_v |
| max_len = max(max_len, int(k.shape[-2])) |
| row_kvs.append((k, v)) |
| else: |
| row_kvs.append((None, None)) |
| if ref_k is None or ref_v is None: |
| return None |
| B = len(row_pkvs) |
| merged_k = torch.zeros( |
| (B, int(ref_k.shape[1]), max_len, int(ref_k.shape[-1])), |
| device=device, |
| dtype=ref_k.dtype, |
| ) |
| merged_v = torch.zeros( |
| (B, int(ref_v.shape[1]), max_len, int(ref_v.shape[-1])), |
| device=device, |
| dtype=ref_v.dtype, |
| ) |
| valid_mask = torch.zeros((B, max_len), device=device, dtype=torch.bool) |
| for row_idx, (k, v) in enumerate(row_kvs): |
| if not isinstance(k, torch.Tensor) or not isinstance(v, torch.Tensor): |
| continue |
| width = int(k.shape[-2]) |
| if width > 0: |
| merged_k[row_idx : row_idx + 1, :, :width, :] = k[:, :, :width, :] |
| merged_v[row_idx : row_idx + 1, :, :width, :] = v[:, :, :width, :] |
| valid_mask[row_idx, :width] = True |
| merged_pairs.append((merged_k, merged_v)) |
| layer_valid_masks.append(valid_mask) |
|
|
| try: |
| merged = DynamicCache(ddp_cache_data=merged_pairs) |
| except Exception: |
| merged = DynamicCache() |
| for layer_idx, (k, v) in enumerate(merged_pairs): |
| merged.update(k, v, layer_idx) |
|
|
| try: |
| setattr(merged, "_llopa_batch_layer_valid_masks", layer_valid_masks) |
| except Exception: |
| pass |
|
|
| replay_hidden = _merge_optional_batch_sequence_attr( |
| row_pkvs, |
| "_tri_last_layer_memory_hidden", |
| device=device, |
| ) |
| replay_pos = _merge_optional_batch_sequence_attr( |
| row_pkvs, |
| "_tri_last_layer_memory_position_ids", |
| device=device, |
| dtype=torch.long, |
| ) |
| replay_valid = _merge_optional_batch_sequence_attr( |
| row_pkvs, |
| "_tri_last_layer_memory_valid_mask", |
| device=device, |
| dtype=torch.bool, |
| ) |
| if isinstance(replay_hidden, torch.Tensor) and isinstance(replay_pos, torch.Tensor): |
| with contextlib.suppress(Exception): |
| setattr(merged, "_tri_last_layer_memory_hidden", replay_hidden) |
| setattr(merged, "_tri_last_layer_memory_position_ids", replay_pos) |
| if isinstance(replay_valid, torch.Tensor): |
| setattr(merged, "_tri_last_layer_memory_valid_mask", replay_valid.to(dtype=torch.bool)) |
| module_type = getattr(row_pkvs[0], "_tri_replay_module", getattr(row_pkvs[0], "_tri_last_layer_module", "none")) |
| setattr(merged, "_tri_replay_module", str(module_type or "none")) |
| setattr(merged, "_tri_last_layer_module", str(module_type or "none")) |
| setattr(merged, "_tri_replay_per_layers", int(getattr(row_pkvs[0], "_tri_replay_per_layers", -1) or -1)) |
| return merged |
|
|
|
|
| def _append_llopa_batch_cache_valid_masks(pkv, input_valid: torch.Tensor) -> None: |
| masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None) |
| if not isinstance(masks, list): |
| return |
| input_valid = input_valid.to(dtype=torch.bool) |
| new_masks = [] |
| for layer_idx, mask in enumerate(masks): |
| if not isinstance(mask, torch.Tensor) or mask.dim() != 2: |
| new_masks.append(mask) |
| continue |
| layer_len = _cache_layer_seq_len(pkv, layer_idx) |
| cur_len = int(mask.size(1)) |
| if layer_len > cur_len: |
| add_width = int(layer_len - cur_len) |
| add = input_valid.to(device=mask.device).view(-1, 1).expand(mask.size(0), add_width) |
| mask = torch.cat([mask, add], dim=1) |
| elif layer_len < cur_len: |
| mask = mask[:, :layer_len] |
| new_masks.append(mask) |
| with contextlib.suppress(Exception): |
| setattr(pkv, "_llopa_batch_layer_valid_masks", new_masks) |
| setattr(pkv, "_tri_past_len_cache", None) |
|
|
|
|
| def _direct_llopa_batch_prefill_cache_and_logits( |
| model, |
| *, |
| segments: list[dict], |
| pad_id: int, |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| no_upper_attn: bool, |
| see_past_assistant: bool, |
| replay_module: str, |
| replay_per_layers: int, |
| device, |
| ): |
| if not segments: |
| return None |
| if _normalize_replay_module_value(replay_module) != "none": |
| return None |
| seed_fn = _get_llopa_full_prompt_seed(model) |
| if not callable(seed_fn): |
| return None |
| try: |
| batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=int(pad_id)) |
| except Exception: |
| return None |
| try: |
| seed = seed_fn( |
| input_ids=batch["prompt_ids"], |
| attention_mask=batch["prompt_attention_mask"], |
| use_cache=True, |
| logits_to_keep=1, |
| lower_k=int(lower_k), |
| prefill_attn=str(prefill_attn), |
| system_prefill=str(system_prefill), |
| no_upper_attn=bool(no_upper_attn), |
| prefill_lower_split_start=batch["split_starts"], |
| prefill_lower_system_len=batch["system_lens"], |
| prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"], |
| prefill_lower_replay_user_start=batch["replay_user_starts"], |
| prefill_lower_replay_user_len=batch["replay_user_lens"], |
| assistant_header_starts=batch["assistant_header_starts"], |
| assistant_turn_ends=batch["assistant_turn_ends"], |
| assistant_header_start_mask=batch["assistant_header_start_mask"], |
| prefill_lower_see_past_assistant=bool(see_past_assistant), |
| replay_module=str(replay_module), |
| replay_per_layers=int(replay_per_layers), |
| ) |
| except Exception: |
| return None |
| if not isinstance(seed, tuple) or len(seed) != 4: |
| return None |
| pkv, S, U, logits = seed |
| if pkv is None or not isinstance(logits, torch.Tensor) or logits.numel() == 0: |
| return None |
| if logits.dim() == 3: |
| logits = logits[:, -1, :] |
| if logits.dim() != 2 or int(logits.size(0)) != len(segments): |
| return None |
| return pkv, S, U, logits.to(device=device, dtype=torch.float32) |
|
|
|
|
def _direct_llopa_batch_generate_cached_impl(
    model,
    tokenizer,
    *,
    segments: list[dict],
    canonical_input_ids: torch.Tensor,
    context_width: int,
    batch_size: int,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool,
    replay_module: str,
    replay_per_layers: int,
    last_layer_module: Optional[str],
    max_new_tokens: int,
    min_new_tokens: int,
    do_sample: bool,
    logits_warpers,
    stopping_criteria,
    stop_ids: set[int],
    stop_token_ids: Optional[torch.Tensor],
    pad_id: int,
    output_scores: bool,
    compact_scores: bool,
    return_dict_in_generate: bool,
    device,
):
    """Generate for a batch token by token on top of a prefilled KV cache.

    The cache is seeded with one batched prefill when ``batch_size > 1`` and
    that path succeeds; otherwise every row is prefilled on its own and the
    row caches are merged. Decoding then runs one step per new token for the
    whole batch. Rows that have finished keep emitting ``pad_id``, and their
    newly cached positions are flagged invalid in the cache's per-layer masks.

    Args:
        segments: Structured prompt segments, one dict per batch row.
        canonical_input_ids: Prompt ids copied into the first
            ``context_width`` columns of the returned sequences.
        logits_warpers: Optional callable ``(generated_ids, logits) -> logits``.
        stopping_criteria: Optional callable ``(sequences, logits)`` returning
            a bool or a bool tensor (one value, or one per row).
        stop_ids: Token ids that finish a row once ``min_new_tokens`` is met.
        stop_token_ids: The same ids as a tensor, masked out of the logits
            while fewer than ``min_new_tokens`` tokens exist.
        output_scores: Record per-step scores.
        compact_scores: With ``output_scores``, record only the chosen
            token's log-probability instead of the full logits.

    Returns:
        ``None`` when this path is unavailable (core/decode step missing, a
        prefill seed failed, or a multi-row batch needs the replay module or
        past-assistant history without a batched seed), so the caller can try
        another strategy. Otherwise the sequences tensor, or a
        ``GenerateDecoderOnlyOutput`` when ``return_dict_in_generate`` is set;
        compact scores are attached as ``generated_token_logprobs``.
    """
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )

    prompt_bundles = [
        _build_unified_prefill_lower_prompt_bundle(
            tokenizer,
            prompt_messages=None,
            prompt_add_generation_prompt=True,
            structured_prompt_segments=seg,
            device=device,
        )
        for seg in segments
    ]

    has_past_assistant_history = bool(see_past_assistant) and any(
        _prompt_bundle_has_past_assistant_history(bundle) for bundle in prompt_bundles
    )
    batch_seed = None
    if int(batch_size) > 1:
        batch_seed = _direct_llopa_batch_prefill_cache_and_logits(
            model,
            segments=segments,
            pad_id=int(pad_id),
            lower_k=int(lower_k),
            prefill_attn=str(prefill_attn),
            system_prefill=str(system_prefill),
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            device=device,
        )

    if batch_seed is not None:
        pkv, _S, _U, pending_logits = batch_seed
        pending_logits = pending_logits.to(device=device, dtype=torch.float32)
    else:
        # Without a batched seed, multi-row batches are refused here when
        # they need the replay module or past-assistant history.
        if int(batch_size) > 1 and _normalize_replay_module_value(replay_module) != "none":
            return None
        if int(batch_size) > 1 and bool(has_past_assistant_history):
            return None

        row_pkvs = []
        initial_logits_rows = []
        for prompt_bundle in prompt_bundles:
            seed = _direct_prefill_lower_cache_and_logits(
                model,
                prompt_bundle=prompt_bundle,
                lower_k=int(lower_k),
                prefill_attn=str(prefill_attn),
                system_prefill=str(system_prefill),
                no_upper_attn=bool(no_upper_attn),
                see_past_assistant=bool(see_past_assistant),
                replay_module=str(replay_module),
                replay_per_layers=int(replay_per_layers),
                last_layer_module=last_layer_module,
                seed_mode="prefill_header",
            )
            if seed is None:
                return None
            pkv_row, _S, _U, initial_logits = seed
            if not isinstance(initial_logits, torch.Tensor) or initial_logits.numel() == 0:
                return None
            # Defensive: bring each row to (1, vocab) so the concatenated
            # logits have the same rank the batched prefill path produces.
            if initial_logits.dim() == 3:
                initial_logits = initial_logits[:, -1, :]
            elif initial_logits.dim() == 1:
                initial_logits = initial_logits.unsqueeze(0)
            row_pkvs.append(pkv_row)
            initial_logits_rows.append(initial_logits.to(device=device, dtype=torch.float32))

        pkv = _merge_llopa_batch_row_caches(row_pkvs, device=device)
        if pkv is None:
            return None
        pending_logits = torch.cat(initial_logits_rows, dim=0)
        del row_pkvs, initial_logits_rows
    del prompt_bundles

    context_width = int(context_width)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    pad_id = int(pad_id)

    sequences = torch.full(
        (batch_size, context_width + max_new_tokens),
        pad_id,
        dtype=canonical_input_ids.dtype,
        device=device,
    )
    sequences[:, :context_width] = canonical_input_ids.to(device=device)
    # ``generated`` is a view: writing into it fills ``sequences`` in place.
    generated = sequences[:, context_width:]
    score_list: list[torch.Tensor] = []
    record_scores = bool(output_scores)
    record_compact_scores = record_scores and bool(compact_scores)
    compact_logprob_list: list[torch.Tensor] = []
    unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool)
    last = None
    last_valid = torch.ones((batch_size,), device=device, dtype=torch.bool)
    cur = 0

    while cur < max_new_tokens:
        if pending_logits is None:
            if not isinstance(last, torch.Tensor):
                return None
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=int(lower_k),
                    pkv=pkv,
                    S=0,
                    U=0,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=bool(no_upper_attn),
                    align_cache_position_to_layer_past=True,
                    replay_module=str(replay_module),
                    replay_per_layers=int(replay_per_layers),
                )
                pkv = out.past_key_values or pkv
                _append_llopa_batch_cache_valid_masks(pkv, last_valid)
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=int(lower_k),
                    pkv=pkv,
                    S=0,
                    U=0,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=bool(no_upper_attn),
                    align_cache_position_to_layer_past=True,
                    replay_module=str(replay_module),
                    replay_per_layers=int(replay_per_layers),
                )
                pkv = out.past_key_values or pkv
                _append_llopa_batch_cache_valid_masks(pkv, last_valid)
                logits = out.logits[:, -1, :]
            logits = logits.to(dtype=torch.float32, device=device, copy=True)
        else:
            # First step: consume the logits produced by the prefill.
            logits = pending_logits
            pending_logits = None

        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids.to(device=logits.device), -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list.append(logits.detach().clone())

        if bool(do_sample):
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        if record_compact_scores:
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_list.append(next_logprob.squeeze(-1).detach())

        # Snapshot taken before this step's stop checks: a row's stop token
        # itself still counts as a valid cached position. It must be a copy
        # because ``unfinished`` can be zeroed in place below.
        token_valid = unfinished.clone()
        if not bool(unfinished.all().item()):
            next_tok = torch.where(
                unfinished.view(-1, 1),
                next_tok,
                torch.full_like(next_tok, pad_id),
            )

        generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype)
        cur += 1

        if stop_ids and cur >= min_new_tokens:
            sampled = next_tok.squeeze(1)
            stop_mask = torch.zeros_like(unfinished)
            for stop_id in stop_ids:
                stop_mask |= sampled.eq(int(stop_id))
            unfinished = unfinished & ~stop_mask

        if stopping_criteria is not None:
            sequences_now = sequences[:, : context_width + cur]
            try:
                stop_result = stopping_criteria(sequences_now, logits)
            except TypeError:
                stop_result = stopping_criteria(sequences_now, None)
            if isinstance(stop_result, torch.Tensor):
                stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1)
                if stop_result.numel() == 1:
                    # A single verdict applies to the whole batch.
                    if bool(stop_result.item()):
                        unfinished.zero_()
                elif stop_result.numel() == batch_size:
                    unfinished = unfinished & ~stop_result
                elif bool(stop_result.all().item()):
                    unfinished.zero_()
            elif bool(stop_result):
                unfinished.zero_()

        if not bool(unfinished.any().item()):
            break
        last = next_tok
        last_valid = token_valid

    sequences = sequences[:, : context_width + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput

    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if record_scores and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
|
|
|
|
| def _direct_llopa_batch_generate_serial_cached_fallback_impl( |
| model, |
| tokenizer, |
| *, |
| segments: list[dict], |
| canonical_input_ids: torch.Tensor, |
| context_width: int, |
| batch_size: int, |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| user_prefill: str, |
| no_upper_attn: bool, |
| see_past_assistant: bool, |
| replay_module: str, |
| replay_per_layers: int, |
| last_layer_module: Optional[str], |
| seed_mode: str, |
| max_new_tokens: int, |
| min_new_tokens: int, |
| do_sample: bool, |
| temperature, |
| top_p, |
| top_k, |
| stopping_criteria, |
| pad_id: int, |
| eos_token_id, |
| output_scores: bool, |
| compact_scores: bool, |
| return_dict_in_generate: bool, |
| device, |
| ): |
| if int(batch_size) <= 1: |
| return None |
| row_sequences: list[torch.Tensor] = [] |
| row_score_tuples: list[tuple] = [] |
| row_compact_tensors: list[torch.Tensor] = [] |
| score_template = None |
| max_gen_len = 0 |
| record_scores = bool(output_scores) |
| record_compact_scores = record_scores and bool(compact_scores) |
| for row_idx, seg in enumerate(segments): |
| prompt_ids = seg.get("prompt_ids") if isinstance(seg, dict) else None |
| if not isinstance(prompt_ids, torch.Tensor) or prompt_ids.numel() == 0: |
| return None |
| row_prompt = prompt_ids.to(device=device, dtype=torch.long) |
| if row_prompt.dim() == 1: |
| row_prompt = row_prompt.unsqueeze(0) |
| elif row_prompt.dim() == 2: |
| row_prompt = row_prompt[:1] |
| else: |
| row_prompt = row_prompt.reshape(1, -1) |
| row_prompt_len = int(row_prompt.size(1)) |
| row_out = _direct_llopa_generate_impl( |
| model, |
| tokenizer, |
| prompt_messages=None, |
| prompt_add_generation_prompt=True, |
| structured_prompt_segments=seg, |
| input_ids=row_prompt, |
| attention_mask=torch.ones_like(row_prompt, device=device, dtype=torch.long), |
| lower_k=int(lower_k), |
| prefill_attn=str(prefill_attn), |
| system_prefill=str(system_prefill), |
| user_prefill=str(user_prefill), |
| no_upper_attn=bool(no_upper_attn), |
| see_past_assistant=bool(see_past_assistant), |
| replay_module=str(replay_module), |
| replay_per_layers=int(replay_per_layers), |
| last_layer_module=last_layer_module, |
| seed_mode=str(seed_mode or "prefill_header"), |
| max_new_tokens=int(max_new_tokens), |
| min_new_tokens=int(min_new_tokens), |
| do_sample=bool(do_sample), |
| temperature=temperature, |
| top_p=top_p, |
| top_k=top_k, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=int(pad_id), |
| eos_token_id=eos_token_id, |
| output_scores=record_scores, |
| compact_scores=record_compact_scores, |
| return_dict_in_generate=True, |
| use_cache=True, |
| ) |
| if row_out is None: |
| return None |
| row_seq = getattr(row_out, "sequences", None) |
| if not isinstance(row_seq, torch.Tensor) or row_seq.dim() != 2 or row_seq.size(0) < 1: |
| return None |
| row_gen = row_seq[:1, row_prompt_len:].to(device=device, dtype=canonical_input_ids.dtype) |
| row_sequences.append(row_gen) |
| max_gen_len = max(max_gen_len, int(row_gen.size(1))) |
| if record_compact_scores: |
| compact = getattr(row_out, "generated_token_logprobs", None) |
| if isinstance(compact, torch.Tensor): |
| row_compact_tensors.append(compact[:1].to(device=device, dtype=torch.float32)) |
| else: |
| row_compact_tensors.append(torch.empty((1, 0), device=device, dtype=torch.float32)) |
| elif record_scores: |
| scores = getattr(row_out, "scores", None) |
| if isinstance(scores, tuple): |
| row_score_tuples.append(scores) |
| if score_template is None and len(scores) > 0 and isinstance(scores[0], torch.Tensor): |
| score_template = scores[0][:1].to(device=device) |
| else: |
| row_score_tuples.append(()) |
|
|
| sequences = torch.full( |
| (int(batch_size), int(context_width) + int(max_gen_len)), |
| int(pad_id), |
| dtype=canonical_input_ids.dtype, |
| device=device, |
| ) |
| sequences[:, : int(context_width)] = canonical_input_ids.to(device=device) |
| for row_idx, row_gen in enumerate(row_sequences): |
| width = int(row_gen.size(1)) |
| if width > 0: |
| sequences[row_idx : row_idx + 1, int(context_width) : int(context_width) + width] = row_gen[:, :width] |
|
|
| compact_logprobs = None |
| score_list = None |
| if record_compact_scores: |
| compact_logprobs = torch.zeros( |
| (int(batch_size), int(max_gen_len)), |
| device=device, |
| dtype=torch.float32, |
| ) |
| for row_idx, compact in enumerate(row_compact_tensors): |
| width = min(int(compact.size(1)), int(max_gen_len)) |
| if width > 0: |
| compact_logprobs[row_idx : row_idx + 1, :width] = compact[:, :width] |
| elif record_scores: |
| if score_template is None: |
| score_list = [] |
| else: |
| score_list = [] |
| for step_idx in range(int(max_gen_len)): |
| step_scores = [] |
| for scores in row_score_tuples: |
| if step_idx < len(scores) and isinstance(scores[step_idx], torch.Tensor): |
| step_scores.append(scores[step_idx][:1].to(device=device)) |
| else: |
| step_scores.append(torch.zeros_like(score_template, device=device)) |
| score_list.append(torch.cat(step_scores, dim=0)) |
|
|
| if not bool(return_dict_in_generate): |
| return sequences |
| from transformers.generation.utils import GenerateDecoderOnlyOutput |
|
|
| output = GenerateDecoderOnlyOutput( |
| sequences=sequences, |
| scores=tuple(score_list) if record_scores and not record_compact_scores and score_list is not None else None, |
| past_key_values=None, |
| ) |
| if compact_logprobs is not None: |
| setattr(output, "generated_token_logprobs", compact_logprobs) |
| return output |
|
|
|
|
| @torch.inference_mode() |
| def _direct_llopa_batch_generate_impl( |
| model, |
| tokenizer, |
| *, |
| prompt_messages, |
| prompt_add_generation_prompt: bool, |
| structured_prompt_segments=None, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor], |
| lower_k: int, |
| prefill_attn: str, |
| system_prefill: str, |
| user_prefill: str, |
| no_upper_attn: bool, |
| see_past_assistant: bool = False, |
| replay_module: str = "none", |
| replay_per_layers: int = -1, |
| last_layer_module: Optional[str] = None, |
| seed_mode: str = "prefill_header", |
| max_length=None, |
| max_new_tokens=None, |
| min_length=None, |
| min_new_tokens=None, |
| do_sample=None, |
| temperature=None, |
| top_p=None, |
| top_k=None, |
| stopping_criteria=None, |
| pad_token_id=None, |
| eos_token_id=None, |
| output_scores: bool = False, |
| compact_scores: bool = False, |
| return_dict_in_generate: bool = False, |
| use_cache: Optional[bool] = None, |
| ): |
| allow_cached_paths = use_cache is not False |
| batch_size = int(input_ids.size(0)) if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 else 0 |
| if _is_structured_segments_batch(structured_prompt_segments): |
| batch_size = len(structured_prompt_segments) |
| elif _is_prompt_messages_batch(prompt_messages): |
| batch_size = len(prompt_messages) |
| elif batch_size <= 0: |
| batch_size = 1 |
|
|
| if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none": |
| replay_module = last_layer_module |
| replay_module = _normalize_replay_module_value(replay_module) |
| replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers) |
|
|
| try: |
| lower_k = int(lower_k) |
| except Exception: |
| lower_k = 0 |
| if lower_k <= 0: |
| return None |
|
|
| attn = (prefill_attn or "causal").strip().lower() |
| if attn == "prefix_full": |
| attn = "full" |
| if attn not in {"causal", "full"}: |
| attn = "causal" |
|
|
| sys_prefill = (system_prefill or "full").strip().lower() |
| if sys_prefill not in {"full", "no_system", "no_bos_system"}: |
| sys_prefill = "full" |
|
|
| user_prefill_norm = (user_prefill or "full").strip().lower() |
| if user_prefill_norm != "full": |
| raise ValueError("llopa_v2_batch_generate currently supports only user_prefill='full'.") |
|
|
| if isinstance(input_ids, torch.Tensor): |
| if input_ids.dim() != 2 or input_ids.size(0) != batch_size: |
| return None |
| device = input_ids.device |
| else: |
| try: |
| device = next(model.parameters()).device |
| except Exception: |
| device = "cpu" |
|
|
| pad_id = pad_token_id |
| if pad_id is None: |
| pad_id = getattr(tokenizer, "pad_token_id", None) |
| if pad_id is None: |
| pad_id = getattr(tokenizer, "eos_token_id", None) |
| if pad_id is None: |
| pad_id = 0 |
| pad_id = int(pad_id) |
|
|
| segments = _build_batched_structured_prompt_segments( |
| tokenizer, |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=prompt_add_generation_prompt, |
| structured_prompt_segments=structured_prompt_segments, |
| device=device, |
| ) |
| if len(segments) != batch_size: |
| raise ValueError( |
| f"llopa_v2_batch_generate prompt metadata batch size mismatch: {len(segments)} != {batch_size}" |
| ) |
| batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=pad_id) |
|
|
| if isinstance(input_ids, torch.Tensor): |
| canonical_input_ids = input_ids.to(device=device) |
| context_width = int(canonical_input_ids.size(1)) |
| else: |
| canonical_input_ids = batch["prompt_ids"] |
| context_width = int(canonical_input_ids.size(1)) |
|
|
| raw_temp = 0.0 if temperature is None else float(temperature) |
| if do_sample is None: |
| do_sample = bool(raw_temp != 0.0) |
| do_sample = bool(do_sample) |
| sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp) |
| top_p = 1.0 if top_p is None else float(top_p) |
| top_k = None if top_k is None else int(top_k) |
|
|
| if max_new_tokens is None: |
| if max_length is None: |
| max_new_tokens = 256 |
| else: |
| max_new_tokens = max(0, int(max_length) - int(context_width)) |
| else: |
| max_new_tokens = int(max_new_tokens) |
| if min_new_tokens is None: |
| if min_length is None: |
| min_new_tokens = 0 |
| else: |
| min_new_tokens = max(0, int(min_length) - int(context_width)) |
| else: |
| min_new_tokens = int(min_new_tokens) |
| max_new_tokens = int(max_new_tokens) |
| min_new_tokens = int(min_new_tokens) |
|
|
| stop_ids = set(_normalize_eos_token_ids(eos_token_id)) |
| if not stop_ids: |
| tok_eos = getattr(tokenizer, "eos_token_id", None) |
| if tok_eos is not None: |
| stop_ids.add(int(tok_eos)) |
| with contextlib.suppress(Exception): |
| eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") |
| if eot_id is not None and eot_id != tokenizer.unk_token_id: |
| stop_ids.add(int(eot_id)) |
|
|
| stop_token_ids = _prepare_stop_token_tensor(stop_ids, device) |
| logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k) |
| record_scores = bool(output_scores) |
| record_compact_scores = record_scores and bool(compact_scores) |
| should_apply_stopping_criteria = stopping_criteria is not None |
|
|
| cached_out = None |
| if bool(allow_cached_paths): |
| cached_out = _direct_llopa_batch_generate_cached_impl( |
| model, |
| tokenizer, |
| segments=segments, |
| canonical_input_ids=canonical_input_ids, |
| context_width=int(context_width), |
| batch_size=int(batch_size), |
| lower_k=int(lower_k), |
| prefill_attn=str(attn), |
| system_prefill=str(sys_prefill), |
| no_upper_attn=bool(no_upper_attn), |
| see_past_assistant=bool(see_past_assistant), |
| replay_module=str(replay_module), |
| replay_per_layers=int(replay_per_layers), |
| last_layer_module=last_layer_module, |
| max_new_tokens=int(max_new_tokens), |
| min_new_tokens=int(min_new_tokens), |
| do_sample=bool(do_sample), |
| logits_warpers=logits_warpers, |
| stopping_criteria=stopping_criteria, |
| stop_ids=stop_ids, |
| stop_token_ids=stop_token_ids, |
| pad_id=int(pad_id), |
| output_scores=record_scores, |
| compact_scores=record_compact_scores, |
| return_dict_in_generate=bool(return_dict_in_generate), |
| device=device, |
| ) |
| if cached_out is not None: |
| return cached_out |
|
|
| serial_cached_out = None |
| if bool(allow_cached_paths) and stopping_criteria is None: |
| serial_cached_out = _direct_llopa_batch_generate_serial_cached_fallback_impl( |
| model, |
| tokenizer, |
| segments=segments, |
| canonical_input_ids=canonical_input_ids, |
| context_width=int(context_width), |
| batch_size=int(batch_size), |
| lower_k=int(lower_k), |
| prefill_attn=str(attn), |
| system_prefill=str(sys_prefill), |
| user_prefill=str(user_prefill_norm), |
| no_upper_attn=bool(no_upper_attn), |
| see_past_assistant=bool(see_past_assistant), |
| replay_module=str(replay_module), |
| replay_per_layers=int(replay_per_layers), |
| last_layer_module=last_layer_module, |
| seed_mode=str(seed_mode or "prefill_header"), |
| max_new_tokens=int(max_new_tokens), |
| min_new_tokens=int(min_new_tokens), |
| do_sample=bool(do_sample), |
| temperature=temperature, |
| top_p=top_p, |
| top_k=top_k, |
| stopping_criteria=stopping_criteria, |
| pad_id=int(pad_id), |
| eos_token_id=eos_token_id, |
| output_scores=record_scores, |
| compact_scores=record_compact_scores, |
| return_dict_in_generate=bool(return_dict_in_generate), |
| device=device, |
| ) |
| if serial_cached_out is not None: |
| return serial_cached_out |
|
|
| sequences = torch.full( |
| (batch_size, context_width + max_new_tokens), |
| pad_id, |
| dtype=canonical_input_ids.dtype, |
| device=device, |
| ) |
| sequences[:, :context_width] = canonical_input_ids |
| generated = sequences[:, context_width:] |
| score_list: list[torch.Tensor] = [] |
| score_list_append = score_list.append |
| compact_logprob_list: list[torch.Tensor] = [] |
| compact_logprob_append = compact_logprob_list.append |
|
|
| prompt_rows: list[torch.Tensor] = batch["prompt_rows"] |
| prompt_lens = batch["prompt_lens"] |
| split_starts = batch["split_starts"] |
| system_lens = batch["system_lens"] |
| unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool) |
| finish_steps = torch.full((batch_size,), max_new_tokens, device=device, dtype=torch.long) |
|
|
| def _build_step_tensors(cur_tokens: int) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]: |
| active_lens = [] |
| rows = [] |
| for row_idx, prompt_row in enumerate(prompt_rows): |
| if bool(unfinished[row_idx].item()): |
| gen_len = int(cur_tokens) |
| else: |
| gen_len = min(int(cur_tokens), int(finish_steps[row_idx].item())) |
| gen_prefix = generated[row_idx, :gen_len].to(device=device, dtype=torch.long) |
| row_ids = torch.cat([prompt_row.to(device=device, dtype=torch.long), gen_prefix], dim=0) |
| rows.append(row_ids) |
| active_lens.append(int(row_ids.numel())) |
| step_ids = _pad_1d_rows(rows, pad_value=pad_id, device=device, dtype=torch.long) |
| active_lens_t = torch.tensor(active_lens, device=device, dtype=torch.long) |
| step_mask = ( |
| torch.arange(step_ids.size(1), device=device, dtype=torch.long).unsqueeze(0) |
| < active_lens_t.unsqueeze(1) |
| ).to(dtype=torch.long) |
| labels = torch.full_like(step_ids, -100) |
| for row_idx, active_len in enumerate(active_lens): |
| start = int(split_starts[row_idx].item()) |
| if active_len > start: |
| labels[row_idx, start:active_len] = step_ids[row_idx, start:active_len] |
| return step_ids, step_mask, labels |
|
|
| cur = 0 |
| while cur < max_new_tokens: |
| step_ids, step_mask, labels = _build_step_tensors(cur) |
| out = model( |
| input_ids=step_ids, |
| attention_mask=step_mask, |
| labels=labels, |
| use_cache=False, |
| logits_to_keep=1, |
| prefill_lower_layers=int(lower_k), |
| prefill_lower_attn=str(attn), |
| prefill_lower_system_prefill=str(sys_prefill), |
| prefill_lower_no_upper_attn=bool(no_upper_attn), |
| prefill_lower_split_start=split_starts, |
| prefill_lower_system_len=system_lens, |
| prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"], |
| prefill_lower_replay_user_start=batch["replay_user_starts"], |
| prefill_lower_replay_user_len=batch["replay_user_lens"], |
| assistant_header_starts=batch["assistant_header_starts"], |
| assistant_turn_ends=batch["assistant_turn_ends"], |
| assistant_header_start_mask=batch["assistant_header_start_mask"], |
| prefill_lower_see_past_assistant=bool(see_past_assistant), |
| prefill_lower_replay_module=str(replay_module), |
| prefill_lower_replay_per_layers=int(replay_per_layers), |
| ) |
| if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor): |
| return None |
| logits = out.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True) |
|
|
| if stop_token_ids is not None and cur < min_new_tokens: |
| logits.index_fill_(1, stop_token_ids, -float("inf")) |
| if logits_warpers is not None: |
| logits = logits_warpers(generated[:, :cur], logits) |
| if record_scores and not record_compact_scores: |
| score_list_append(logits.detach().clone()) |
|
|
| if do_sample: |
| probs = torch.softmax(logits, dim=-1) |
| next_tok = torch.multinomial(probs, num_samples=1) |
| else: |
| next_tok = torch.argmax(logits, dim=-1, keepdim=True) |
| if record_compact_scores: |
| next_logit = torch.gather(logits, 1, next_tok) |
| next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True) |
| compact_logprob_append(next_logprob.squeeze(-1).detach()) |
| if not bool(unfinished.all().item()): |
| next_tok = torch.where( |
| unfinished.view(-1, 1), |
| next_tok, |
| torch.full_like(next_tok, pad_id), |
| ) |
|
|
| generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype) |
| cur += 1 |
|
|
| if stop_ids and cur >= min_new_tokens: |
| stop_mask = torch.zeros_like(unfinished) |
| for stop_id in stop_ids: |
| stop_mask |= next_tok.squeeze(1).eq(int(stop_id)) |
| newly_finished = unfinished & stop_mask |
| if bool(newly_finished.any().item()): |
| finish_steps = torch.where( |
| newly_finished, |
| torch.full_like(finish_steps, int(cur)), |
| finish_steps, |
| ) |
| unfinished = unfinished & ~stop_mask |
|
|
| if should_apply_stopping_criteria: |
| sequences_now = sequences[:, : context_width + cur] |
| try: |
| stop_result = stopping_criteria(sequences_now, logits) |
| except TypeError: |
| stop_result = stopping_criteria(sequences_now, None) |
| if isinstance(stop_result, torch.Tensor): |
| stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1) |
| if stop_result.numel() == 1: |
| if bool(stop_result.item()): |
| unfinished.zero_() |
| elif stop_result.numel() == batch_size: |
| newly_finished = unfinished & stop_result |
| if bool(newly_finished.any().item()): |
| finish_steps = torch.where( |
| newly_finished, |
| torch.full_like(finish_steps, int(cur)), |
| finish_steps, |
| ) |
| unfinished = unfinished & ~stop_result |
| elif bool(stop_result.all().item()): |
| unfinished.zero_() |
| elif bool(stop_result): |
| unfinished.zero_() |
|
|
| if not bool(unfinished.any().item()): |
| break |
|
|
| sequences = sequences[:, : context_width + cur] |
| if not bool(return_dict_in_generate): |
| return sequences |
| from transformers.generation.utils import GenerateDecoderOnlyOutput |
|
|
| output = GenerateDecoderOnlyOutput( |
| sequences=sequences, |
| scores=tuple(score_list) if record_scores and not record_compact_scores else None, |
| past_key_values=None, |
| ) |
| if record_compact_scores: |
| if compact_logprob_list: |
| compact_logprobs = torch.stack(compact_logprob_list, dim=1) |
| else: |
| compact_logprobs = torch.empty( |
| (sequences.size(0), 0), |
| dtype=torch.float32, |
| device=sequences.device, |
| ) |
| setattr(output, "generated_token_logprobs", compact_logprobs) |
| return output |
|
|
|
|
def _direct_optimized_llopa_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    user_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    optimized_variant: Optional[str] = None,
    optimized_seed_mode: Optional[str] = None,
    optimized_upper_prepare_mode: Optional[str] = None,
    optimized_upper_bucket_multiple: Optional[int] = None,
    optimized_seq_bucket_multiple: Optional[int] = None,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    compact_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Run the "optimized" LLoPA generation loop for a single prompt.

    The function builds (or reuses) the structured prompt segments, seeds a
    KV cache plus optional first-step logits, and then decodes one token per
    iteration until ``max_new_tokens`` is reached, a stop token is produced
    (once ``min_new_tokens`` have been generated) or ``stopping_criteria``
    returns a truthy value.

    Args:
        model, tokenizer: Model and tokenizer handed to the project helpers.
        prompt_messages, prompt_add_generation_prompt,
        structured_prompt_segments: Prompt metadata. A dict passed as
            ``structured_prompt_segments`` is used directly; otherwise the
            segments are built from ``prompt_messages``.
        input_ids, attention_mask: Optional tokenised prompt. When
            ``input_ids`` is a tensor it must be 2-D with batch size 1.
        lower_k, prefill_attn, system_prefill, user_prefill, no_upper_attn,
        see_past_assistant, replay_module, replay_per_layers,
        last_layer_module: LLoPA runtime options, normalised below.
        optimized_*: Forwarded to ``_resolve_optimized_llopa_settings``.
        max_length, max_new_tokens, min_length, min_new_tokens: Length
            limits; the ``*_new_tokens`` values win over the ``*_length``
            values. ``max_new_tokens`` defaults to 256, ``min_new_tokens``
            to 0.
        do_sample, temperature, top_p, top_k: Sampling controls. When
            ``do_sample`` is ``None`` sampling is enabled iff ``temperature``
            is non-zero.
        stopping_criteria: Optional callable invoked as
            ``stopping_criteria(sequences, logits)``.
        pad_token_id: Accepted for signature compatibility; not read here.
        eos_token_id: Stop token id(s); tokenizer defaults are used when this
            normalises to an empty set.
        output_scores, compact_scores, return_dict_in_generate: Output
            controls (see Returns).
        use_cache: ``False`` makes the function decline (return ``None``).

    Returns:
        ``None`` when this path does not apply (``use_cache is False``,
        non-positive ``lower_k``, missing LLoPA core / decode step, an
        ``input_ids`` tensor that is not ``(1, seq)``, or an empty assistant
        prefill). Otherwise the generated sequence tensor, or a
        ``GenerateDecoderOnlyOutput`` when ``return_dict_in_generate`` is
        true. With ``output_scores`` the output carries per-step ``scores``;
        with ``output_scores`` and ``compact_scores`` it instead carries a
        ``generated_token_logprobs`` attribute.

    Raises:
        ValueError: If ``user_prefill`` normalises to anything but ``"full"``.
    """
    # The legacy ``last_layer_module`` option only applies when no replay
    # module was requested explicitly.
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    # This path always decodes against a KV cache, so an explicit opt-out
    # means the caller has to use another implementation.
    if use_cache is False:
        return None

    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None

    # Unknown attention modes fall back to "causal" rather than failing.
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"

    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"

    user_prefill_norm = (user_prefill or "full").strip().lower()
    if user_prefill_norm != "full":
        raise ValueError("Optimized LLoPA currently supports only user_prefill='full'.")

    optimized_settings = _resolve_optimized_llopa_settings(
        variant=optimized_variant,
        seed_mode=optimized_seed_mode,
        upper_prepare_mode=optimized_upper_prepare_mode,
        upper_bucket_multiple=optimized_upper_bucket_multiple,
        seq_bucket_multiple=optimized_seq_bucket_multiple,
    )

    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    # The direct decode step needs both the core's ``tri_forward_assistant``
    # and an output head; it can be switched off through the environment flag.
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )

    # Prefer the device of the provided ids, then the model's first
    # parameter, then CPU.
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"

    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    assistant_prefill_ids = segments["assistant_prefill_ids"]
    replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_user_start = int(segments.get("replay_user_start", 0) or 0)
    replay_user_len = int(segments.get("replay_user_len", 0) or 0)
    # Without an assistant prefill there is no token to seed decoding from.
    if assistant_prefill_ids.numel() == 0:
        return None

    # Sampling defaults: a missing temperature counts as 0.0, and sampling is
    # enabled implicitly only for a non-zero temperature.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)

    initial_logits = None
    prompt_bundle = _build_unified_prefill_lower_prompt_bundle(
        tokenizer,
        prompt_messages=prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        structured_prompt_segments=segments,
        device=device,
    )
    # NOTE(review): the nesting of this `with` block was reconstructed from
    # whitespace-flattened source; it is assumed to wrap only the seed
    # computation, not the decode loop below -- confirm.
    with _temporary_model_attrs(
        model,
        _optimized_llopa_variant=optimized_settings["variant"],
        _optimized_llopa_seed_mode=optimized_settings["seed_mode"],
        _optimized_llopa_upper_prepare_mode=optimized_settings["upper_prepare_mode"],
        _optimized_llopa_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]),
        _optimized_llopa_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]),
    ):
        reference_seed = _optimized_prefill_lower_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=attn,
            system_prefill=sys_prefill,
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            seed_mode=optimized_settings["seed_mode"],
        )
    # Pick the ids that form the prompt part of the returned sequence:
    # caller-provided input_ids first, then the bundle's effective prompt,
    # then the raw prompt ids from the segments.
    canonical_input_ids = None
    if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1:
        valid_len = int(input_ids.size(1))
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1:
            valid_len = int(attention_mask[0].sum().item())
        if valid_len > 0:
            # Keeps the trailing ``valid_len`` tokens -- presumably the
            # prompt is left-padded; verify against callers.
            canonical_input_ids = input_ids[:, -valid_len:]
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_bundle.get("effective_prompt_ids")
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_ids
    total_prompt_len = int(canonical_input_ids.size(1))
    # Resolve length limits; ``*_length`` values count the prompt as well.
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    if reference_seed is not None:
        pkv, S, U, initial_logits = reference_seed
    else:
        # No optimized seed available: build the cache with the plain
        # prefill helper instead.
        output_head = _get_output_head(model)
        if bool(no_upper_attn):
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
        else:
            pkv, S, U, last_hidden = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=bool(output_head is not None),
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
            # When the prefill returned a hidden state, derive the first-step
            # logits from it so the first loop iteration needs no decode call.
            if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
                initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)

    # Stop tokens: explicit eos ids win; otherwise fall back to the
    # tokenizer's eos token plus "<|eot_id|>" when the tokenizer knows it.
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
        # NOTE(review): nesting reconstructed from flattened source; the
        # "<|eot_id|>" lookup is assumed to be part of this fallback -- confirm.
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))

    # ``last`` is the token fed to the next decode step; it starts as the
    # final assistant-prefill token.
    last = assistant_prefill_ids[:, -1:]
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    # Coerce loop invariants once, outside the decode loop.
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    lower_k = int(lower_k)
    no_upper_attn_bool = bool(no_upper_attn)
    replay_module_str = str(replay_module)
    replay_per_layers_int = int(replay_per_layers)
    record_scores = bool(output_scores)
    record_compact_scores = bool(output_scores) and bool(compact_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    bucket_multiple = int(optimized_settings["seq_bucket_multiple"] or 256)
    # NOTE(review): if the workspace returned here is reused across calls,
    # the tensor returned to the caller would alias it -- confirm against
    # _acquire_bucketed_sequence_workspace.
    sequences_full, _ = _acquire_bucketed_sequence_workspace(
        model,
        reference_ids=canonical_input_ids,
        batch_size=int(canonical_input_ids.size(0)),
        total_len=int(total_prompt_len + max_new_tokens),
        bucket_multiple=bucket_multiple,
    )
    sequences_full[:, :total_prompt_len] = canonical_input_ids
    sequences = sequences_full[:, : total_prompt_len + max_new_tokens]
    # ``generated`` is a view into ``sequences``; writes below land in both.
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    cur = 0
    pending_logits = initial_logits

    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            # No logits carried over: run one decode step on ``last``.
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = out.logits[:, -1, :]
            # Private float32 copy so the in-place masking below cannot touch
            # the model's own output tensor.
            logits = logits.to(dtype=torch.float32, device=last.device, copy=True)
        else:
            # First iteration with seed logits: consume them exactly once.
            # NOTE(review): seed logits are used as-is (no float32/device
            # copy as on the decode path) and index_fill_ below mutates them
            # in place -- confirm that is intended.
            logits = pending_logits
            pending_logits = None

        # Forbid stop tokens until the minimum length has been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())

        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        if record_compact_scores:
            # Log-probability of the chosen token under the processed logits.
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())

        generated[:, cur : cur + 1] = next_tok
        cur += 1

        should_stop = False
        # ``.item()`` requires a single element, i.e. batch size 1.
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry without scores for criteria that reject the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        # Drop the reference to the step output before the next iteration.
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok

    # Trim to prompt plus the tokens actually generated.
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput

    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            # No step ran (e.g. max_new_tokens == 0): return an empty
            # (batch, 0) tensor.
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
|
|
|
|
| def _normalize_replay_module_value(value: Optional[str]) -> str: |
| raw = str(value or "none").strip().lower() |
| aliases = { |
| "": "none", |
| "off": "none", |
| "disabled": "none", |
| "disable": "none", |
| "self-attention": "self", |
| "self_attention": "self", |
| "selfattn": "self", |
| "self_attn": "self", |
| "cross-attention": "cross", |
| "cross_attention": "cross", |
| "crossattn": "cross", |
| "cross_attn": "cross", |
| } |
| raw = aliases.get(raw, raw) |
| if raw in {"none", "self", "cross"}: |
| return raw |
| return "none" |
|
|
|
|
| def _normalize_replay_per_layers_value(value) -> int: |
| try: |
| normalized = int(value) |
| except Exception: |
| return -1 |
| if normalized == -1 or normalized >= 1: |
| return normalized |
| return -1 |
|
|
|
|
def _normalize_structured_llopa_runtime(
    lower_k: int,
    *,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
):
    """Canonicalise the structured LLoPA runtime options.

    Returns ``None`` when ``lower_k`` is not a positive integer or when the
    attention mode is neither ``"causal"`` nor ``"full"`` (``"prefix_full"``
    is first mapped to ``"full"``). Otherwise returns the 7-tuple
    ``(lower_k, attn, system_prefill, user_prefill, no_upper_attn,
    replay_module, replay_per_layers)`` with every entry normalised.

    ``see_past_assistant`` is accepted but is not part of the result.
    ``last_layer_module`` is used as the replay module only when
    ``replay_module`` itself normalises to ``"none"``.
    """
    try:
        layer_count = int(lower_k)
    except Exception:
        layer_count = 0

    attn_mode = (prefill_attn or "causal").strip().lower()
    attn_mode = "full" if attn_mode == "prefix_full" else attn_mode

    # Unknown system-prefill modes silently fall back to "full".
    system_mode = (system_prefill or "full").strip().lower()
    system_mode = system_mode if system_mode in ("full", "no_system", "no_bos_system") else "full"

    user_mode = (user_prefill or "full").strip().lower()

    # The legacy option fills in only when no replay module was requested.
    module_choice = replay_module
    if last_layer_module is not None:
        if _normalize_replay_module_value(module_choice) == "none":
            module_choice = last_layer_module
    module_norm = _normalize_replay_module_value(module_choice)
    per_layers_norm = _normalize_replay_per_layers_value(replay_per_layers)

    if layer_count > 0 and attn_mode in ("causal", "full"):
        return (
            layer_count,
            attn_mode,
            system_mode,
            user_mode,
            bool(no_upper_attn),
            module_norm,
            per_layers_norm,
        )
    return None
|
|
|
|
| def _attach_structured_llopa_generate(model, tokenizer) -> None: |
| try: |
| import types |
| except Exception: |
| return |
|
|
| if getattr(model, "_structured_llopa_generate_attached", False): |
| return |
|
|
| orig_generate = getattr(model, "generate", None) |
| if not callable(orig_generate): |
| return |
|
|
| def _structured_llopa_generate(self, *args, **kwargs): |
| if args: |
| if "input_ids" not in kwargs: |
| kwargs["input_ids"] = args[0] |
| args = args[1:] |
| if args: |
| return orig_generate(*args, **kwargs) |
|
|
| optimized_enabled = kwargs.pop("optimized_llopa_generate", None) |
| llopa_v2_batch_enabled = kwargs.pop("llopa_v2_batch_generate", None) |
| llopa_v2_enabled = kwargs.pop("llopa_v2_generate", None) |
| llopa_v3_enabled = kwargs.pop("llopa_v3_generate", None) |
| runtime_solo_enabled = kwargs.pop("runtime_solo_generate", None) |
| runtime_solo_v2_enabled = kwargs.pop("runtime_solo_v2_generate", None) |
| unified_enabled = kwargs.pop("unified_llopa_generate", None) |
| direct_enabled = kwargs.pop("direct_llopa_generate", None) |
| legacy_search = bool(kwargs.pop("direct_llopa_legacy_search", False)) |
| prompt_messages = kwargs.pop("prompt_messages", None) |
| prompt_add_generation_prompt = kwargs.pop("prompt_add_generation_prompt", None) |
| structured_prompt_segments = kwargs.pop("structured_prompt_segments", None) |
| compact_scores = bool(kwargs.pop("capsule_compact_scores", False)) |
|
|
| mode = None |
| if optimized_enabled is not None: |
| if bool(optimized_enabled): |
| mode = "optimized" |
| elif llopa_v2_batch_enabled is None and llopa_v2_enabled is None and llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if llopa_v2_batch_enabled is not None: |
| if bool(llopa_v2_batch_enabled): |
| mode = "llopa_v2_batch" |
| elif llopa_v2_enabled is None and llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if llopa_v2_enabled is not None: |
| if bool(llopa_v2_enabled): |
| mode = "llopa_v2" |
| elif llopa_v3_enabled is None and runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if llopa_v3_enabled is not None: |
| if bool(llopa_v3_enabled): |
| mode = "llopa_v3" |
| elif runtime_solo_enabled is None and runtime_solo_v2_enabled is None and unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if runtime_solo_v2_enabled is not None: |
| if bool(runtime_solo_v2_enabled): |
| mode = "solo_v2" |
| elif runtime_solo_enabled is None and unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if runtime_solo_enabled is not None: |
| if bool(runtime_solo_enabled): |
| mode = "solo" |
| elif unified_enabled is None and direct_enabled is None: |
| return orig_generate(**kwargs) |
| if unified_enabled is not None: |
| if bool(unified_enabled): |
| mode = "unified" |
| elif direct_enabled is None: |
| return orig_generate(**kwargs) |
| if mode is None and direct_enabled is not None: |
| if bool(direct_enabled): |
| mode = "direct" |
| else: |
| return orig_generate(**kwargs) |
| if mode is None: |
| if bool(getattr(self, "_optimized_llopa_generate_default", False)): |
| mode = "optimized" |
| elif bool(getattr(self, "_llopa_v2_batch_generate_default", False)): |
| mode = "llopa_v2_batch" |
| elif bool(getattr(self, "_llopa_v2_generate_default", False)): |
| mode = "llopa_v3" if str(getattr(self, "_capsule_inference_path", "") or "") == "llopa_v3" else "llopa_v2" |
| elif bool(getattr(self, "_unified_llopa_generate_default", False)): |
| mode = "unified" |
| elif bool(getattr(self, "_runtime_structured_freeze_generate_default", False)): |
| mode = "freeze" |
| elif bool(getattr(self, "_runtime_structured_solo_v2_generate_default", False)): |
| mode = "solo_v2" |
| elif bool(getattr(self, "_runtime_structured_solo_generate_default", False)): |
| mode = "solo" |
| elif bool(getattr(self, "_direct_llopa_generate_default", False)): |
| mode = "direct" |
| else: |
| return orig_generate(**kwargs) |
|
|
| if kwargs.get("inputs_embeds") is not None: |
| return orig_generate(**kwargs) |
|
|
| if mode == "optimized": |
| mode_label = "optimized_llopa_generate" |
| elif mode == "llopa_v2_batch": |
| mode_label = "llopa_v2_batch_generate" |
| elif mode == "llopa_v2": |
| mode_label = "llopa_v2_generate" |
| elif mode == "llopa_v3": |
| mode_label = "llopa_v3_generate" |
| elif mode == "unified": |
| mode_label = "unified_llopa_generate" |
| elif mode == "freeze": |
| mode_label = "runtime_freeze_generate" |
| elif mode == "solo_v2": |
| mode_label = "runtime_solo_v2_generate" |
| elif mode == "solo": |
| mode_label = "runtime_solo_generate" |
| else: |
| mode_label = "direct_llopa_generate" |
| if int(kwargs.get("num_beams", 1) or 1) != 1: |
| _warn_once( |
| self, |
| f"_warned_{mode_label}_num_beams", |
| f"[load_llopa_model][warn] {mode_label} currently supports only num_beams=1; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
| if int(kwargs.get("num_return_sequences", 1) or 1) != 1: |
| _warn_once( |
| self, |
| f"_warned_{mode_label}_num_return_sequences", |
| f"[load_llopa_model][warn] {mode_label} currently supports only num_return_sequences=1; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
|
|
| if mode in {"unified", "optimized", "llopa_v2", "llopa_v3", "llopa_v2_batch"}: |
| attr_prefix = "_optimized_llopa" if mode == "optimized" else "_llopa_v2" if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"} else "_unified_llopa" |
| kw_prefix = "optimized_llopa" if mode == "optimized" else "llopa_v2" if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"} else "unified_llopa" |
| lower_k_attr = f"{attr_prefix}_layers" |
| attn_attr = f"{attr_prefix}_attn" |
| system_attr = f"{attr_prefix}_system_prefill" |
| user_attr = f"{attr_prefix}_user_prefill" |
| no_upper_attr = f"{attr_prefix}_no_upper_attn" |
| see_past_assistant_attr = f"{attr_prefix}_see_past_assistant" |
| replay_module_attr = f"{attr_prefix}_replay_module" |
| replay_per_layers_attr = f"{attr_prefix}_replay_per_layers" |
| last_layer_attr = f"{attr_prefix}_last_layer_module" |
| variant_attr = "_optimized_llopa_variant" |
| seed_attr = "_optimized_llopa_seed_mode" |
| upper_prepare_attr = "_optimized_llopa_upper_prepare_mode" |
| upper_bucket_attr = "_optimized_llopa_upper_bucket_multiple" |
| seq_bucket_attr = "_optimized_llopa_seq_bucket_multiple" |
| lower_k_local = kwargs.pop(f"{kw_prefix}_layers", None) |
| if lower_k_local is None: |
| lower_k_local = int(getattr(self, lower_k_attr, 0) or 0) |
| attn_local = kwargs.pop(f"{kw_prefix}_attn", None) |
| if attn_local is None: |
| attn_local = str(getattr(self, attn_attr, "causal") or "causal") |
| system_prefill_local = kwargs.pop(f"{kw_prefix}_system_prefill", None) |
| if system_prefill_local is None: |
| system_prefill_local = str(getattr(self, system_attr, "full") or "full") |
| user_prefill_local = kwargs.pop(f"{kw_prefix}_user_prefill", None) |
| if user_prefill_local is None: |
| user_prefill_local = str(getattr(self, user_attr, "full") or "full") |
| no_upper_attn_local = kwargs.pop(f"{kw_prefix}_no_upper_attn", None) |
| if no_upper_attn_local is None: |
| no_upper_attn_local = bool(getattr(self, no_upper_attr, False)) |
| see_past_assistant_local = kwargs.pop(f"{kw_prefix}_see_past_assistant", None) |
| if see_past_assistant_local is None: |
| see_past_assistant_local = bool(getattr(self, see_past_assistant_attr, False)) |
| replay_module_local = kwargs.pop(f"{kw_prefix}_replay_module", None) |
| if replay_module_local is None: |
| replay_module_local = kwargs.pop(f"{kw_prefix}_last_layer_module", None) |
| if replay_module_local is None: |
| replay_module_local = getattr(self, replay_module_attr, None) |
| if replay_module_local is None: |
| replay_module_local = str(getattr(self, last_layer_attr, "none") or "none") |
| replay_per_layers_local = kwargs.pop(f"{kw_prefix}_replay_per_layers", None) |
| if replay_per_layers_local is None: |
| replay_per_layers_local = getattr(self, replay_per_layers_attr, -1) |
| structured_seed_mode_local = "auto" |
| if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"}: |
| kwargs.pop("llopa_v2_seed_mode", None) |
| structured_seed_mode_local = "prefill_header" |
| optimized_variant_local = None |
| optimized_seed_mode_local = None |
| optimized_upper_prepare_mode_local = None |
| optimized_upper_bucket_multiple_local = None |
| optimized_seq_bucket_multiple_local = None |
| if mode == "optimized": |
| optimized_variant_local = kwargs.pop("optimized_llopa_variant", None) |
| if optimized_variant_local is None: |
| optimized_variant_local = getattr(self, variant_attr, "upper_ws_auto") |
| optimized_seed_mode_local = kwargs.pop("optimized_llopa_seed_mode", None) |
| if optimized_seed_mode_local is None: |
| optimized_seed_mode_local = getattr(self, seed_attr, "auto") |
| optimized_upper_prepare_mode_local = kwargs.pop("optimized_llopa_upper_prepare_mode", None) |
| if optimized_upper_prepare_mode_local is None: |
| optimized_upper_prepare_mode_local = getattr(self, upper_prepare_attr, "bucketed_workspace") |
| optimized_upper_bucket_multiple_local = kwargs.pop("optimized_llopa_upper_bucket_multiple", None) |
| if optimized_upper_bucket_multiple_local is None: |
| optimized_upper_bucket_multiple_local = getattr(self, upper_bucket_attr, 256) |
| optimized_seq_bucket_multiple_local = kwargs.pop("optimized_llopa_seq_bucket_multiple", None) |
| if optimized_seq_bucket_multiple_local is None: |
| optimized_seq_bucket_multiple_local = getattr(self, seq_bucket_attr, 256) |
| if prompt_messages is None and structured_prompt_segments is None: |
| _warn_once( |
| self, |
| f"_warned_{mode_label}_missing_prompt_metadata", |
| f"[load_llopa_model][warn] {mode_label} requested without structured prompt metadata; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
| if prompt_add_generation_prompt is None and structured_prompt_segments is None: |
| raise ValueError(f"{mode_label} requires prompt_add_generation_prompt when prompt_messages are provided.") |
| generate_impl = ( |
| _direct_optimized_llopa_generate_impl |
| if mode == "optimized" |
| else _direct_llopa_batch_generate_impl |
| if mode == "llopa_v2_batch" |
| else _direct_llopa_generate_impl |
| ) |
| generate_kwargs = dict( |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=bool(prompt_add_generation_prompt), |
| structured_prompt_segments=structured_prompt_segments, |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| system_prefill=str(system_prefill_local), |
| user_prefill=str(user_prefill_local), |
| no_upper_attn=bool(no_upper_attn_local), |
| see_past_assistant=bool(see_past_assistant_local), |
| replay_module=str(replay_module_local), |
| replay_per_layers=int(replay_per_layers_local or -1), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| ) |
| generate_kwargs["compact_scores"] = bool(compact_scores) |
| if mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"}: |
| generate_kwargs["seed_mode"] = str(structured_seed_mode_local) |
| if mode == "optimized": |
| generate_kwargs.update( |
| optimized_variant=optimized_variant_local, |
| optimized_seed_mode=optimized_seed_mode_local, |
| optimized_upper_prepare_mode=optimized_upper_prepare_mode_local, |
| optimized_upper_bucket_multiple=optimized_upper_bucket_multiple_local, |
| optimized_seq_bucket_multiple=optimized_seq_bucket_multiple_local, |
| ) |
| with torch.inference_mode(): |
| previous_mixin_decode = getattr(self, "_llopa_v2_generation_mixin_decode", None) |
| try: |
| if mode == "llopa_v2": |
| setattr(self, "_llopa_v2_generation_mixin_decode", False) |
| elif mode == "llopa_v3": |
| setattr(self, "_llopa_v2_generation_mixin_decode", True) |
| unified_out = generate_impl( |
| self, |
| tokenizer, |
| **generate_kwargs, |
| ) |
| finally: |
| if previous_mixin_decode is None: |
| with contextlib.suppress(Exception): |
| delattr(self, "_llopa_v2_generation_mixin_decode") |
| else: |
| setattr(self, "_llopa_v2_generation_mixin_decode", previous_mixin_decode) |
| if unified_out is not None: |
| return unified_out |
| raise RuntimeError(f"Structured {mode} LLoPA failed unexpectedly for the current prompt.") |
|
|
| if mode == "freeze": |
| lower_k_local = kwargs.pop("runtime_prefill_freeze_layers", None) |
| if lower_k_local is None: |
| lower_k_local = int(getattr(self, "_runtime_prefill_freeze_layers", 0) or 0) |
| attn_local = kwargs.pop("runtime_prefill_freeze_attn", None) |
| if attn_local is None: |
| attn_local = str(getattr(self, "_runtime_prefill_freeze_attn", "causal") or "causal") |
| system_prefill_local = kwargs.pop("runtime_prefill_freeze_system_prefill", None) |
| if system_prefill_local is None: |
| system_prefill_local = str(getattr(self, "_runtime_prefill_freeze_system_prefill", "no_bos_system") or "no_bos_system") |
| if prompt_messages is None: |
| _warn_once( |
| self, |
| "_warned_runtime_freeze_missing_prompt_metadata", |
| "[load_llopa_model][warn] runtime_freeze_generate requested without structured prompt metadata; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
| if prompt_add_generation_prompt is None: |
| raise ValueError("runtime_freeze_generate requires prompt_add_generation_prompt when prompt_messages are provided.") |
|
|
| freeze_out = _direct_freeze_generate_impl( |
| self, |
| tokenizer, |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=bool(prompt_add_generation_prompt), |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| system_prefill=str(system_prefill_local), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| ) |
| if freeze_out is not None: |
| return freeze_out |
| raise RuntimeError("Structured runtime freeze failed unexpectedly for the current prompt.") |
|
|
| if mode == "solo": |
| lower_k_local = kwargs.pop("runtime_prefill_solo_layers", None) |
| if lower_k_local is None: |
| lower_k_local = int(getattr(self, "_runtime_prefill_solo_layers", 0) or 0) |
| attn_local = kwargs.pop("runtime_prefill_solo_attn", None) |
| if attn_local is None: |
| attn_local = str(getattr(self, "_runtime_prefill_solo_attn", "causal") or "causal") |
| system_prefill_local = kwargs.pop("runtime_prefill_solo_system_prefill", None) |
| if system_prefill_local is None: |
| system_prefill_local = str(getattr(self, "_runtime_prefill_solo_system_prefill", "no_bos_system") or "no_bos_system") |
| if prompt_messages is None: |
| _warn_once( |
| self, |
| "_warned_runtime_solo_missing_prompt_metadata", |
| "[load_llopa_model][warn] runtime_solo_generate requested without structured prompt metadata; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
| if prompt_add_generation_prompt is None: |
| raise ValueError("runtime_solo_generate requires prompt_add_generation_prompt when prompt_messages are provided.") |
|
|
| solo_out = _direct_solo_generate_impl( |
| self, |
| tokenizer, |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=bool(prompt_add_generation_prompt), |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| system_prefill=str(system_prefill_local), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| ) |
| if solo_out is not None: |
| return solo_out |
| raise RuntimeError("Structured runtime solo-attn failed unexpectedly for the current prompt.") |
|
|
| if mode == "solo_v2": |
| lower_k_local = kwargs.pop("runtime_prefill_solo_v2_layers", None) |
| if lower_k_local is None: |
| lower_k_local = int(getattr(self, "_runtime_prefill_solo_v2_layers", 0) or 0) |
| attn_local = kwargs.pop("runtime_prefill_solo_v2_attn", None) |
| if attn_local is None: |
| attn_local = str(getattr(self, "_runtime_prefill_solo_v2_attn", "causal") or "causal") |
| system_prefill_local = kwargs.pop("runtime_prefill_solo_v2_system_prefill", None) |
| if system_prefill_local is None: |
| system_prefill_local = str(getattr(self, "_runtime_prefill_solo_v2_system_prefill", "no_bos_system") or "no_bos_system") |
| with_bos_local = kwargs.pop("runtime_prefill_solo_v2_with_bos", None) |
| if with_bos_local is None: |
| with_bos_local = bool(getattr(self, "_runtime_prefill_solo_v2_with_bos", False)) |
| if prompt_messages is None: |
| _warn_once( |
| self, |
| "_warned_runtime_solo_v2_missing_prompt_metadata", |
| "[load_llopa_model][warn] runtime_solo_v2_generate requested without structured prompt metadata; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
| if prompt_add_generation_prompt is None: |
| raise ValueError("runtime_solo_v2_generate requires prompt_add_generation_prompt when prompt_messages are provided.") |
|
|
| solo_out = _direct_solo_generate_impl( |
| self, |
| tokenizer, |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=bool(prompt_add_generation_prompt), |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| system_prefill=str(system_prefill_local), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| solo_v2=True, |
| with_bos=bool(with_bos_local), |
| ) |
| if solo_out is not None: |
| return solo_out |
| raise RuntimeError("Structured runtime solo-attn-v2 failed unexpectedly for the current prompt.") |
|
|
| _warn_once( |
| self, |
| "_warned_deprecated_direct_llopa", |
| "[load_llopa_model][warn] direct_llopa_* is deprecated; use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged.", |
| ) |
| lower_k_local = kwargs.pop("direct_llopa_layers", None) |
| if lower_k_local is None: |
| lower_k_local = int(getattr(self, "_direct_llopa_layers", 0) or 0) |
| attn_local = kwargs.pop("direct_llopa_attn", None) |
| if attn_local is None: |
| attn_local = str(getattr(self, "_direct_llopa_attn", "causal") or "causal") |
| system_prefill_local = kwargs.pop("direct_llopa_system_prefill", None) |
| if system_prefill_local is None: |
| system_prefill_local = str(getattr(self, "_direct_llopa_system_prefill", "full") or "full") |
| user_prefill_local = kwargs.pop("direct_llopa_user_prefill", None) |
| if user_prefill_local is None: |
| user_prefill_local = str(getattr(self, "_direct_llopa_user_prefill", "full") or "full") |
| no_upper_attn_local = kwargs.pop("direct_llopa_no_upper_attn", None) |
| if no_upper_attn_local is None: |
| no_upper_attn_local = bool(getattr(self, "_direct_llopa_no_upper_attn", False)) |
|
|
| if not legacy_search and prompt_messages is None and structured_prompt_segments is None: |
| raise ValueError( |
| "direct_llopa_generate now requires prompt_messages and prompt_add_generation_prompt. " |
| "Use direct_llopa_legacy_search=True only for legacy prompt scanning." |
| ) |
| if not legacy_search and prompt_add_generation_prompt is None and structured_prompt_segments is None: |
| raise ValueError("direct_llopa_generate requires prompt_add_generation_prompt when prompt_messages are provided.") |
|
|
| if legacy_search: |
| with torch.inference_mode(): |
| direct_out = _legacy_direct_llopa_generate_impl( |
| self, |
| tokenizer, |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| ) |
| else: |
| with torch.inference_mode(): |
| direct_out = _direct_llopa_generate_impl( |
| self, |
| tokenizer, |
| prompt_messages=prompt_messages, |
| prompt_add_generation_prompt=bool(prompt_add_generation_prompt), |
| structured_prompt_segments=structured_prompt_segments, |
| input_ids=kwargs.get("input_ids"), |
| attention_mask=kwargs.get("attention_mask"), |
| lower_k=int(lower_k_local), |
| prefill_attn=str(attn_local), |
| system_prefill=str(system_prefill_local), |
| user_prefill=str(user_prefill_local), |
| no_upper_attn=bool(no_upper_attn_local), |
| max_length=kwargs.get("max_length"), |
| max_new_tokens=kwargs.get("max_new_tokens"), |
| min_length=kwargs.get("min_length"), |
| min_new_tokens=kwargs.get("min_new_tokens"), |
| do_sample=kwargs.get("do_sample"), |
| temperature=kwargs.get("temperature"), |
| top_p=kwargs.get("top_p"), |
| top_k=kwargs.get("top_k"), |
| stopping_criteria=kwargs.get("stopping_criteria"), |
| pad_token_id=kwargs.get("pad_token_id"), |
| eos_token_id=kwargs.get("eos_token_id"), |
| output_scores=bool(kwargs.get("output_scores", False)), |
| return_dict_in_generate=bool(kwargs.get("return_dict_in_generate", False)), |
| use_cache=kwargs.get("use_cache"), |
| ) |
| if direct_out is not None: |
| return direct_out |
| if not legacy_search: |
| raise RuntimeError("Structured direct LLoPA failed unexpectedly for the current prompt.") |
| _warn_once( |
| self, |
| "_warned_direct_llopa_fallback", |
| "[load_llopa_model][warn] direct_llopa_generate could not use the current prompt/generation settings; falling back to model.generate().", |
| ) |
| return orig_generate(**kwargs) |
|
|
| try: |
| model.generate = types.MethodType(_structured_llopa_generate, model) |
| setattr(model, "_structured_llopa_generate_attached", True) |
| except Exception: |
| pass |
|
|
|
|
def _attach_unified_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
) -> None:
    """Record unified-LLoPA defaults on ``model`` and attach the structured generate hook.

    The runtime options are first passed through
    ``_normalize_structured_llopa_runtime``.  When that helper returns ``None``
    the function is a no-op.  Otherwise the normalized values are stored on the
    model as ``_unified_llopa_*`` attributes, ``_capsule_inference_path`` is set
    to ``"unified_llopa"``, ``_attach_structured_llopa_generate`` is invoked and
    finally ``_unified_llopa_generate_attached`` is set to ``True``.

    Args:
        model: Object that receives the private configuration attributes.
        tokenizer: Forwarded unchanged to ``_attach_structured_llopa_generate``.
        lower_k: Layer count handed to the normalizer; the normalized value is
            stored as ``_unified_llopa_layers``.
        prefill_attn, system_prefill, user_prefill, no_upper_attn,
        replay_module, replay_per_layers, last_layer_module: Forwarded to the
            normalizer; only its outputs are stored.
        see_past_assistant: Stored as a bool without normalization.

    Any exception raised while storing the attributes is swallowed and the
    structured generate hook is then *not* attached.
    """
    runtime = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if runtime is None:
        return

    (
        layers,
        attn_mode,
        system_mode,
        user_mode,
        skip_upper,
        replay_kind,
        replay_stride,
    ) = runtime

    try:
        model._unified_llopa_layers = int(layers)
        model._unified_llopa_attn = attn_mode
        model._unified_llopa_system_prefill = system_mode
        model._unified_llopa_user_prefill = user_mode
        model._unified_llopa_no_upper_attn = bool(skip_upper)
        model._unified_llopa_see_past_assistant = bool(see_past_assistant)
        model._unified_llopa_replay_module = str(replay_kind)
        # The last-layer-module attribute mirrors the normalized replay module.
        # NOTE(review): presumably the normalizer folds ``last_layer_module``
        # into the replay module value -- confirm against the helper.
        model._unified_llopa_last_layer_module = str(replay_kind)
        model._unified_llopa_replay_per_layers = int(replay_stride)
        model._unified_llopa_generate_default = True
        model._capsule_inference_path = "unified_llopa"
    except Exception:
        # Best effort: if the model rejects attribute writes, leave it as is
        # and skip attaching the generate hook.
        pass
    else:
        _attach_structured_llopa_generate(model, tokenizer)
        with contextlib.suppress(Exception):
            model._unified_llopa_generate_attached = True
|
|
|
|
def _attach_llopa_v2_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "prefill_header",
    generation_mixin_decode: bool = False,
    capsule_inference_path: str = "llopa_v2",
) -> None:
    """Record LLoPA-v2 defaults on ``model`` and attach the structured generate hook.

    The runtime options go through ``_normalize_structured_llopa_runtime``; a
    ``None`` result makes the function a no-op.  Otherwise the normalized
    values are stored as ``_llopa_v2_*`` attributes, the inference path label
    is recorded, ``_attach_structured_llopa_generate`` is invoked and
    ``_llopa_v2_generate_attached`` is set to ``True``.

    Args:
        model: Object that receives the private configuration attributes.
        tokenizer: Forwarded unchanged to ``_attach_structured_llopa_generate``.
        lower_k, prefill_attn, system_prefill, user_prefill, no_upper_attn,
        replay_module, replay_per_layers, last_layer_module: Forwarded to the
            runtime normalizer; only its outputs are stored.
        see_past_assistant: Stored as a bool without normalization.
        seed_mode: Passed to ``_normalize_structured_llopa_seed_mode``; the
            stored value is nevertheless always ``"prefill_header"``.
        generation_mixin_decode: Stored as
            ``_llopa_v2_generation_mixin_decode``.
        capsule_inference_path: Stored as ``_capsule_inference_path``; a falsy
            value falls back to ``"llopa_v2"``.

    Any exception raised while storing the attributes is swallowed and the
    structured generate hook is then *not* attached.
    """
    runtime = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if runtime is None:
        return

    (
        layers,
        attn_mode,
        system_mode,
        user_mode,
        skip_upper,
        replay_kind,
        replay_stride,
    ) = runtime

    # The seed mode is still normalized (the helper is called as before), but
    # any result other than "prefill_header" is overridden with that value.
    seed = _normalize_structured_llopa_seed_mode(seed_mode)
    seed = "prefill_header" if seed != "prefill_header" else seed

    try:
        model._llopa_v2_layers = int(layers)
        model._llopa_v2_attn = attn_mode
        model._llopa_v2_system_prefill = system_mode
        model._llopa_v2_user_prefill = user_mode
        model._llopa_v2_no_upper_attn = bool(skip_upper)
        model._llopa_v2_see_past_assistant = bool(see_past_assistant)
        model._llopa_v2_replay_module = str(replay_kind)
        # The last-layer-module attribute mirrors the normalized replay module.
        model._llopa_v2_last_layer_module = str(replay_kind)
        model._llopa_v2_replay_per_layers = int(replay_stride)
        model._llopa_v2_seed_mode = str(seed)
        model._llopa_v2_generation_mixin_decode = bool(generation_mixin_decode)
        model._llopa_v2_generate_default = True
        model._capsule_inference_path = str(capsule_inference_path or "llopa_v2")
    except Exception:
        # Best effort: a model that rejects attribute writes is left alone and
        # the generate hook is not attached.
        pass
    else:
        _attach_structured_llopa_generate(model, tokenizer)
        with contextlib.suppress(Exception):
            model._llopa_v2_generate_attached = True
|
|
|
|
def _attach_llopa_v2_batch_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "prefill_header",
) -> None:
    """Record batched LLoPA-v2 defaults on ``model`` and attach the structured generate hook.

    The runtime options go through ``_normalize_structured_llopa_runtime``; a
    ``None`` result makes the function a no-op.  Otherwise the normalized
    values are stored under the shared ``_llopa_v2_*`` attribute names,
    ``_llopa_v2_batch_generate_default`` is set, ``_capsule_inference_path``
    becomes ``"llopa_v2_batch"``, ``_attach_structured_llopa_generate`` is
    invoked and ``_llopa_v2_batch_generate_attached`` is set to ``True``.

    Args:
        model: Object that receives the private configuration attributes.
        tokenizer: Forwarded unchanged to ``_attach_structured_llopa_generate``.
        lower_k, prefill_attn, system_prefill, user_prefill, no_upper_attn,
        replay_module, replay_per_layers, last_layer_module: Forwarded to the
            runtime normalizer; only its outputs are stored.
        see_past_assistant: Stored as a bool without normalization.
        seed_mode: Passed to ``_normalize_structured_llopa_seed_mode``; the
            stored value is nevertheless always ``"prefill_header"``.

    Any exception raised while storing the attributes is swallowed and the
    structured generate hook is then *not* attached.
    """
    runtime = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if runtime is None:
        return

    (
        layers,
        attn_mode,
        system_mode,
        user_mode,
        skip_upper,
        replay_kind,
        replay_stride,
    ) = runtime

    # The seed mode is still normalized (the helper is called as before), but
    # any result other than "prefill_header" is overridden with that value.
    seed = _normalize_structured_llopa_seed_mode(seed_mode)
    seed = "prefill_header" if seed != "prefill_header" else seed

    try:
        model._llopa_v2_layers = int(layers)
        model._llopa_v2_attn = attn_mode
        model._llopa_v2_system_prefill = system_mode
        model._llopa_v2_user_prefill = user_mode
        model._llopa_v2_no_upper_attn = bool(skip_upper)
        model._llopa_v2_see_past_assistant = bool(see_past_assistant)
        model._llopa_v2_replay_module = str(replay_kind)
        # The last-layer-module attribute mirrors the normalized replay module.
        model._llopa_v2_last_layer_module = str(replay_kind)
        model._llopa_v2_replay_per_layers = int(replay_stride)
        model._llopa_v2_seed_mode = str(seed)
        model._llopa_v2_batch_generate_default = True
        model._capsule_inference_path = "llopa_v2_batch"
    except Exception:
        # Best effort: a model that rejects attribute writes is left alone and
        # the generate hook is not attached.
        pass
    else:
        _attach_structured_llopa_generate(model, tokenizer)
        with contextlib.suppress(Exception):
            model._llopa_v2_batch_generate_attached = True
|
|
|
|
def _attach_optimized_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    optimized_variant: Optional[str] = None,
    optimized_seed_mode: Optional[str] = None,
    optimized_upper_prepare_mode: Optional[str] = None,
    optimized_upper_bucket_multiple: Optional[int] = None,
    optimized_seq_bucket_multiple: Optional[int] = None,
) -> None:
    """Record optimized-LLoPA defaults on ``model`` and attach the structured generate hook.

    The runtime options go through ``_normalize_structured_llopa_runtime``; a
    ``None`` result makes the function a no-op.  The ``optimized_*`` options
    are resolved by ``_resolve_optimized_llopa_settings``, whose result is read
    by the keys ``variant``, ``seed_mode``, ``upper_prepare_mode``,
    ``upper_bucket_multiple`` and ``seq_bucket_multiple``.  All values are
    stored as ``_optimized_llopa_*`` attributes, ``_capsule_inference_path``
    becomes ``"optimized_llopa"``, ``_attach_structured_llopa_generate`` is
    invoked and ``_optimized_llopa_generate_attached`` is set to ``True``.

    Args:
        model: Object that receives the private configuration attributes.
        tokenizer: Forwarded unchanged to ``_attach_structured_llopa_generate``.
        lower_k, prefill_attn, system_prefill, user_prefill, no_upper_attn,
        replay_module, replay_per_layers, last_layer_module: Forwarded to the
            runtime normalizer; only its outputs are stored.
        see_past_assistant: Stored as a bool without normalization.
        optimized_variant, optimized_seed_mode, optimized_upper_prepare_mode,
        optimized_upper_bucket_multiple, optimized_seq_bucket_multiple:
            Forwarded to ``_resolve_optimized_llopa_settings``.

    Any exception raised while storing the attributes (including a failed
    lookup in the resolved settings) is swallowed and the structured generate
    hook is then *not* attached.
    """
    runtime = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if runtime is None:
        return

    (
        layers,
        attn_mode,
        system_mode,
        user_mode,
        skip_upper,
        replay_kind,
        replay_stride,
    ) = runtime

    resolved = _resolve_optimized_llopa_settings(
        variant=optimized_variant,
        seed_mode=optimized_seed_mode,
        upper_prepare_mode=optimized_upper_prepare_mode,
        upper_bucket_multiple=optimized_upper_bucket_multiple,
        seq_bucket_multiple=optimized_seq_bucket_multiple,
    )

    try:
        model._optimized_llopa_layers = int(layers)
        model._optimized_llopa_attn = attn_mode
        model._optimized_llopa_system_prefill = system_mode
        model._optimized_llopa_user_prefill = user_mode
        model._optimized_llopa_no_upper_attn = bool(skip_upper)
        model._optimized_llopa_see_past_assistant = bool(see_past_assistant)
        model._optimized_llopa_replay_module = str(replay_kind)
        # The last-layer-module attribute mirrors the normalized replay module.
        model._optimized_llopa_last_layer_module = str(replay_kind)
        model._optimized_llopa_replay_per_layers = int(replay_stride)
        model._optimized_llopa_variant = str(resolved["variant"])
        model._optimized_llopa_seed_mode = str(resolved["seed_mode"])
        model._optimized_llopa_upper_prepare_mode = str(resolved["upper_prepare_mode"])
        model._optimized_llopa_upper_bucket_multiple = int(resolved["upper_bucket_multiple"])
        model._optimized_llopa_seq_bucket_multiple = int(resolved["seq_bucket_multiple"])
        model._optimized_llopa_generate_default = True
        model._capsule_inference_path = "optimized_llopa"
    except Exception:
        # Best effort: on any failure the model keeps whatever was written so
        # far and the generate hook is not attached.
        pass
    else:
        _attach_structured_llopa_generate(model, tokenizer)
        with contextlib.suppress(Exception):
            model._optimized_llopa_generate_attached = True
|
|
|
|
def _attach_direct_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
) -> None:
    """Record direct (legacy) LLoPA defaults on ``model`` and attach the structured generate hook.

    The runtime options go through ``_normalize_structured_llopa_runtime``; a
    ``None`` result makes the function a no-op.  Otherwise the first five
    normalized values are stored as ``_direct_llopa_*`` attributes (the last
    two entries of the normalized tuple are ignored),
    ``_capsule_inference_path`` becomes ``"legacy_llopa"``, and -- when
    ``_assistant_header_ids(tokenizer, "cpu")`` yields a non-empty tensor -- a
    detached CPU ``torch.long`` copy is stored as ``_direct_llopa_header_ids``.
    Finally ``_attach_structured_llopa_generate`` is invoked and
    ``_direct_llopa_generate_attached`` is set to ``True``.

    Args:
        model: Object that receives the private configuration attributes.
        tokenizer: Used for the assistant-header lookup and forwarded to
            ``_attach_structured_llopa_generate``.
        lower_k, prefill_attn, system_prefill, user_prefill, no_upper_attn:
            Forwarded to the runtime normalizer; only its outputs are stored.

    Any exception raised while storing the attributes is swallowed and the
    structured generate hook is then *not* attached.
    """
    runtime = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
    )
    if runtime is None:
        return

    # The normalizer yields seven entries; the trailing two are not used by
    # the direct path.
    (
        layers,
        attn_mode,
        system_mode,
        user_mode,
        skip_upper,
        _unused_replay_kind,
        _unused_replay_stride,
    ) = runtime

    assistant_header = _assistant_header_ids(tokenizer, "cpu")

    try:
        model._direct_llopa_layers = int(layers)
        model._direct_llopa_attn = attn_mode
        model._direct_llopa_system_prefill = system_mode
        model._direct_llopa_user_prefill = user_mode
        model._direct_llopa_no_upper_attn = bool(skip_upper)
        model._direct_llopa_generate_default = True
        model._capsule_inference_path = "legacy_llopa"
        has_header = isinstance(assistant_header, torch.Tensor) and assistant_header.numel() > 0
        if has_header:
            # Keep a detached CPU int64 copy of the header ids on the model.
            model._direct_llopa_header_ids = assistant_header.detach().to(
                device="cpu", dtype=torch.long
            )
    except Exception:
        # Best effort: a model that rejects attribute writes is left alone and
        # the generate hook is not attached.
        pass
    else:
        _attach_structured_llopa_generate(model, tokenizer)
        with contextlib.suppress(Exception):
            model._direct_llopa_generate_attached = True
|
|
|
|
| def load_llopa_model(model_repo: str, |
| *, |
| model_name: str = "", |
| tokenizer_name: str = "", |
| num_specials: Optional[int] = None, |
| backbone_dir: str = "", |
| lopa_modeling_path: str = "", |
| modeling_family: str = "auto", |
| dtype: str = "auto", |
| torch_dtype=None, |
| device: str = "", |
| device_map: Optional[str] = None, |
| attn_impl: str = "auto", |
| attn_implementation: Optional[str] = None, |
| _attn_implementation: Optional[str] = None, |
| trust_remote_code: bool = False, |
| cache_dir: Optional[str] = None, |
| revision: Optional[str] = None, |
| token: Optional[str] = None, |
| local_files_only: bool = False, |
| force_custom_modeling: bool = False, |
| number_of_lora: int = 1, |
| use_lora: bool = True, |
| merge_on_cpu: bool = True, |
| enable_thinking: Optional[bool] = None, |
| no_upper_attn: Optional[bool] = None, |
| runtime_prefill_lower: Optional[bool] = None, |
| runtime_prefill_layers: Optional[int] = None, |
| runtime_prefill_attn: Optional[str] = None, |
| runtime_prefill_system_prefill: Optional[str] = None, |
| runtime_prefill_freeze: Optional[bool] = None, |
| runtime_prefill_freeze_layers: Optional[int] = None, |
| runtime_prefill_freeze_attn: Optional[str] = None, |
| runtime_prefill_freeze_system_prefill: Optional[str] = None, |
| runtime_prefill_solo: Optional[bool] = None, |
| runtime_prefill_solo_layers: Optional[int] = None, |
| runtime_prefill_solo_attn: Optional[str] = None, |
| runtime_prefill_solo_system_prefill: Optional[str] = None, |
| runtime_prefill_solo_v2: Optional[bool] = None, |
| runtime_prefill_solo_v2_layers: Optional[int] = None, |
| runtime_prefill_solo_v2_attn: Optional[str] = None, |
| runtime_prefill_solo_v2_system_prefill: Optional[str] = None, |
| runtime_prefill_solo_v2_with_bos: Optional[bool] = None, |
| runtime_llopa_prefill: Optional[bool] = None, |
| runtime_llopa_layers: Optional[int] = None, |
| runtime_llopa_attn: Optional[str] = None, |
| runtime_llopa_no_upper_attn: Optional[bool] = None, |
| unified_llopa_generate: Optional[bool] = None, |
| unified_llopa_layers: Optional[int] = None, |
| unified_llopa_attn: Optional[str] = None, |
| unified_llopa_system_prefill: Optional[str] = None, |
| unified_llopa_user_prefill: Optional[str] = None, |
| unified_llopa_no_upper_attn: Optional[bool] = None, |
| unified_llopa_see_past_assistant: Optional[bool] = None, |
| unified_llopa_replay_module: Optional[str] = None, |
| unified_llopa_replay_per_layers: Optional[int] = None, |
| unified_llopa_last_layer_module: Optional[str] = None, |
| llopa_v2_batch_generate: Optional[bool] = None, |
| llopa_v2_generate: Optional[bool] = None, |
| llopa_v3_generate: Optional[bool] = None, |
| llopa_v2_layers: Optional[int] = None, |
| llopa_v2_attn: Optional[str] = None, |
| llopa_v2_system_prefill: Optional[str] = None, |
| llopa_v2_user_prefill: Optional[str] = None, |
| llopa_v2_no_upper_attn: Optional[bool] = None, |
| llopa_v2_see_past_assistant: Optional[bool] = None, |
| llopa_v2_replay_module: Optional[str] = None, |
| llopa_v2_replay_per_layers: Optional[int] = None, |
| llopa_v2_last_layer_module: Optional[str] = None, |
| llopa_v2_seed_mode: Optional[str] = None, |
| optimized_llopa_generate: Optional[bool] = None, |
| optimized_llopa_layers: Optional[int] = None, |
| optimized_llopa_attn: Optional[str] = None, |
| optimized_llopa_system_prefill: Optional[str] = None, |
| optimized_llopa_user_prefill: Optional[str] = None, |
| optimized_llopa_no_upper_attn: Optional[bool] = None, |
| optimized_llopa_see_past_assistant: Optional[bool] = None, |
| optimized_llopa_replay_module: Optional[str] = None, |
| optimized_llopa_replay_per_layers: Optional[int] = None, |
| optimized_llopa_last_layer_module: Optional[str] = None, |
| optimized_llopa_variant: Optional[str] = None, |
| optimized_llopa_seed_mode: Optional[str] = None, |
| optimized_llopa_upper_prepare_mode: Optional[str] = None, |
| optimized_llopa_upper_bucket_multiple: Optional[int] = None, |
| optimized_llopa_seq_bucket_multiple: Optional[int] = None, |
| runtime_llopa_fast_generate: Optional[bool] = None, |
| direct_llopa_generate: Optional[bool] = None, |
| direct_llopa_layers: Optional[int] = None, |
| direct_llopa_attn: Optional[str] = None, |
| direct_llopa_system_prefill: Optional[str] = None, |
| direct_llopa_user_prefill: Optional[str] = None, |
| direct_llopa_no_upper_attn: Optional[bool] = None): |
| repo_path = _resolve_repo_path(model_repo, cache_dir=cache_dir, revision=revision, |
| token=token, local_files_only=local_files_only) |
|
|
| info: dict[str, str] = {} |
| tri_info = repo_path / "tri_info.txt" |
| if tri_info.is_file(): |
| info = _read_kv_file(tri_info) |
| if not model_name: |
| model_name = info.get("model_name", "") |
| if num_specials is None: |
| try: |
| num_specials = int(info.get("num_specials", "") or 0) |
| except Exception: |
| num_specials = 0 |
| if runtime_prefill_lower is None: |
| runtime_prefill_lower = False |
| if runtime_prefill_freeze is None: |
| runtime_prefill_freeze = False |
| if runtime_prefill_solo is None: |
| runtime_prefill_solo = False |
| if runtime_prefill_solo_v2 is None: |
| runtime_prefill_solo_v2 = False |
| if runtime_llopa_prefill is None: |
| runtime_llopa_prefill = False |
| if unified_llopa_generate is None: |
| unified_llopa_generate = False |
| if llopa_v2_batch_generate is None: |
| llopa_v2_batch_generate = False |
| if llopa_v2_generate is None: |
| llopa_v2_generate = False |
| if llopa_v3_generate is None: |
| llopa_v3_generate = False |
| if bool(llopa_v3_generate): |
| llopa_v2_generate = True |
| if optimized_llopa_generate is None: |
| optimized_llopa_generate = False |
| if runtime_llopa_fast_generate is None: |
| runtime_llopa_fast_generate = False |
| if direct_llopa_generate is None: |
| direct_llopa_generate = False |
|
|
| backbone_ref = ( |
| backbone_dir |
| or read_backbone_ref(repo_path) |
| or _read_adapter_backbone_ref(repo_path) |
| or model_name |
| or model_repo |
| ) |
| num_specials_arg = num_specials |
|
|
| config = None |
| config_source = str(repo_path) if (repo_path / "config.json").is_file() else backbone_ref |
| try: |
| config = AutoConfig.from_pretrained( |
| config_source, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| ) |
| if num_specials_arg is not None: |
| config.llopa_num_specials = int(num_specials_arg) |
| except Exception: |
| config = None |
| if num_specials_arg is None: |
| num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0 |
| else: |
| num_specials = int(num_specials_arg) |
|
|
| def _config_str(name: str) -> str: |
| if config is None: |
| return "" |
| raw = getattr(config, name, "") |
| return str(raw or "").strip() |
|
|
| def _config_bool(name: str): |
| if config is None: |
| return None |
| value = getattr(config, name, None) |
| if value is None: |
| return None |
| return bool(value) |
|
|
| def _normalize_attention_gate_mode(mode: str) -> str: |
| normalized = str(mode or "off").strip().lower() |
| aliases = { |
| "": "off", |
| "none": "off", |
| "disabled": "off", |
| "disable": "off", |
| "false": "off", |
| "0": "off", |
| "paper": "sdpa_sigmoid", |
| "sdpa_gate": "sdpa_sigmoid", |
| "sdpa-gate": "sdpa_sigmoid", |
| "sigmoid_after_sdpa": "sdpa_sigmoid", |
| "sdpa_elementwise_sigmoid": "sdpa_sigmoid", |
| } |
| normalized = aliases.get(normalized, normalized) |
| if normalized not in {"off", "sdpa_sigmoid"}: |
| normalized = "off" |
| return normalized |
|
|
| attention_gate_mode = _normalize_attention_gate_mode( |
| info.get("attention_gate_mode") or _config_str("capsule_attention_gate_mode") or "off" |
| ) |
| if config is not None: |
| with contextlib.suppress(Exception): |
| setattr(config, "capsule_attention_gate_mode", attention_gate_mode) |
|
|
| if no_upper_attn is None: |
| raw = (info.get("no_upper_attn") or _config_str("capsule_no_upper_attn")).strip().lower() |
| if raw in {"1", "true", "yes", "on"}: |
| no_upper_attn = True |
| elif raw in {"0", "false", "no", "off"}: |
| no_upper_attn = False |
| else: |
| cfg_bool = _config_bool("capsule_no_upper_attn") |
| if cfg_bool is not None: |
| no_upper_attn = cfg_bool |
| if runtime_prefill_layers is None: |
| raw = (info.get("lower_k") or _config_str("capsule_lower_layers")).strip() |
| if raw: |
| with contextlib.suppress(Exception): |
| runtime_prefill_layers = int(raw) |
| if runtime_prefill_attn is None: |
| runtime_prefill_attn = ( |
| (info.get("prefill_attn") or _config_str("capsule_prefill_attn") or "causal").strip().lower() or "causal" |
| ) |
| if runtime_prefill_system_prefill is None: |
| raw_system_prefill = (info.get("system_prefill") or _config_str("capsule_system_prefill")).strip().lower() |
| if raw_system_prefill in {"full", "no_system", "no_bos_system"}: |
| runtime_prefill_system_prefill = raw_system_prefill |
| if runtime_llopa_layers is None: |
| runtime_llopa_layers = runtime_prefill_layers |
| if runtime_llopa_attn is None: |
| runtime_llopa_attn = runtime_prefill_attn |
| if runtime_llopa_no_upper_attn is None: |
| runtime_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False |
| if unified_llopa_layers is None: |
| unified_llopa_layers = runtime_prefill_layers |
| if unified_llopa_attn is None: |
| unified_llopa_attn = runtime_prefill_attn |
| if unified_llopa_system_prefill is None: |
| unified_llopa_system_prefill = runtime_prefill_system_prefill |
| if unified_llopa_user_prefill is None: |
| raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() |
| if raw_user_prefill: |
| unified_llopa_user_prefill = raw_user_prefill |
| if unified_llopa_no_upper_attn is None: |
| unified_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False |
| if unified_llopa_see_past_assistant is None: |
| unified_llopa_see_past_assistant = False |
| if unified_llopa_replay_module is None: |
| unified_llopa_replay_module = unified_llopa_last_layer_module |
| if unified_llopa_replay_module is None: |
| unified_llopa_replay_module = ( |
| info.get("replay_module") |
| or _config_str("capsule_replay_module") |
| or info.get("last_layer_module") |
| or _config_str("capsule_last_layer_module") |
| ) |
| unified_llopa_replay_module = _normalize_replay_module_value(unified_llopa_replay_module) |
| unified_llopa_last_layer_module = str(unified_llopa_replay_module) |
| if unified_llopa_replay_per_layers is None: |
| unified_llopa_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 |
| unified_llopa_replay_per_layers = _normalize_replay_per_layers_value(unified_llopa_replay_per_layers) |
| if llopa_v2_layers is None: |
| llopa_v2_layers = runtime_prefill_layers |
| if llopa_v2_attn is None: |
| llopa_v2_attn = runtime_prefill_attn |
| if llopa_v2_system_prefill is None: |
| llopa_v2_system_prefill = runtime_prefill_system_prefill |
| if llopa_v2_user_prefill is None: |
| raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() |
| if raw_user_prefill: |
| llopa_v2_user_prefill = raw_user_prefill |
| if llopa_v2_no_upper_attn is None: |
| llopa_v2_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False |
| if llopa_v2_see_past_assistant is None: |
| llopa_v2_see_past_assistant = False |
| if llopa_v2_replay_module is None: |
| llopa_v2_replay_module = llopa_v2_last_layer_module |
| if llopa_v2_replay_module is None: |
| llopa_v2_replay_module = ( |
| info.get("replay_module") |
| or _config_str("capsule_replay_module") |
| or info.get("last_layer_module") |
| or _config_str("capsule_last_layer_module") |
| ) |
| llopa_v2_replay_module = _normalize_replay_module_value(llopa_v2_replay_module) |
| llopa_v2_last_layer_module = str(llopa_v2_replay_module) |
| if llopa_v2_replay_per_layers is None: |
| llopa_v2_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 |
| llopa_v2_replay_per_layers = _normalize_replay_per_layers_value(llopa_v2_replay_per_layers) |
| llopa_v2_seed_mode = "prefill_header" |
| if optimized_llopa_layers is None: |
| optimized_llopa_layers = runtime_prefill_layers |
| if optimized_llopa_attn is None: |
| optimized_llopa_attn = runtime_prefill_attn |
| if optimized_llopa_system_prefill is None: |
| optimized_llopa_system_prefill = runtime_prefill_system_prefill |
| if optimized_llopa_user_prefill is None: |
| raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() |
| if raw_user_prefill: |
| optimized_llopa_user_prefill = raw_user_prefill |
| if optimized_llopa_no_upper_attn is None: |
| optimized_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False |
| if optimized_llopa_see_past_assistant is None: |
| optimized_llopa_see_past_assistant = False |
| if optimized_llopa_replay_module is None: |
| optimized_llopa_replay_module = optimized_llopa_last_layer_module |
| if optimized_llopa_replay_module is None: |
| optimized_llopa_replay_module = ( |
| info.get("replay_module") |
| or _config_str("capsule_replay_module") |
| or info.get("last_layer_module") |
| or _config_str("capsule_last_layer_module") |
| ) |
| optimized_llopa_replay_module = _normalize_replay_module_value(optimized_llopa_replay_module) |
| optimized_llopa_last_layer_module = str(optimized_llopa_replay_module) |
| if optimized_llopa_replay_per_layers is None: |
| optimized_llopa_replay_per_layers = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1 |
| optimized_llopa_replay_per_layers = _normalize_replay_per_layers_value(optimized_llopa_replay_per_layers) |
| optimized_settings = _resolve_optimized_llopa_settings( |
| variant=optimized_llopa_variant, |
| seed_mode=optimized_llopa_seed_mode, |
| upper_prepare_mode=optimized_llopa_upper_prepare_mode, |
| upper_bucket_multiple=optimized_llopa_upper_bucket_multiple, |
| seq_bucket_multiple=optimized_llopa_seq_bucket_multiple, |
| ) |
| if direct_llopa_layers is None: |
| direct_llopa_layers = runtime_prefill_layers |
| if direct_llopa_attn is None: |
| direct_llopa_attn = runtime_prefill_attn |
| if direct_llopa_system_prefill is None: |
| direct_llopa_system_prefill = runtime_prefill_system_prefill |
| if direct_llopa_user_prefill is None: |
| raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower() |
| if raw_user_prefill: |
| direct_llopa_user_prefill = raw_user_prefill |
| if direct_llopa_no_upper_attn is None: |
| direct_llopa_no_upper_attn = bool(no_upper_attn) if no_upper_attn is not None else False |
| if runtime_prefill_solo_layers is None: |
| runtime_prefill_solo_layers = runtime_prefill_layers |
| if runtime_prefill_solo_attn is None: |
| runtime_prefill_solo_attn = runtime_prefill_attn |
| if runtime_prefill_solo_system_prefill is None: |
| runtime_prefill_solo_system_prefill = runtime_prefill_system_prefill |
| if runtime_prefill_solo_v2_layers is None: |
| runtime_prefill_solo_v2_layers = runtime_prefill_layers |
| if runtime_prefill_solo_v2_attn is None: |
| runtime_prefill_solo_v2_attn = runtime_prefill_attn |
| if runtime_prefill_solo_v2_system_prefill is None: |
| runtime_prefill_solo_v2_system_prefill = runtime_prefill_system_prefill |
| if runtime_prefill_solo_v2_with_bos is None: |
| runtime_prefill_solo_v2_with_bos = False |
|
|
| config_kwargs = {"config": config} if config is not None else {} |
|
|
| dtype_norm = _normalize_dtype_arg(dtype) or "auto" |
| torch_dtype_norm = _normalize_dtype_arg(torch_dtype) |
| if dtype_norm == "auto" and torch_dtype_norm is not None: |
| dtype_norm = torch_dtype_norm |
|
|
| if dtype_norm == "fp32": |
| torch_dtype = torch.float32 |
| elif dtype_norm == "bf16": |
| torch_dtype = torch.bfloat16 |
| elif dtype_norm == "fp16": |
| torch_dtype = torch.float16 |
| else: |
| if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): |
| torch_dtype = torch.bfloat16 |
| elif torch.cuda.is_available(): |
| torch_dtype = torch.float16 |
| else: |
| torch_dtype = torch.float32 |
|
|
| |
| for cand in (attn_implementation, _attn_implementation): |
| if cand not in (None, "", "auto"): |
| attn_impl = str(cand) |
| break |
| if attn_impl != "auto" and config is not None: |
| for k in ("attn_implementation", "_attn_implementation"): |
| with contextlib.suppress(Exception): |
| setattr(config, k, attn_impl) |
|
|
| |
| if isinstance(device_map, str): |
| dm = device_map.strip() |
| if dm in ("", "none", "None", "null", "NULL"): |
| device_map = None |
| elif dm in ("cuda", "cpu", "mps") or dm.startswith(("cuda:", "xpu:", "npu:")): |
| if not device: |
| device = dm |
| device_map = None |
|
|
| if device_map is None and not device: |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| if tokenizer_name: |
| tok_src = tokenizer_name |
| else: |
| tok_src = str(repo_path) if (repo_path / "tokenizer.json").is_file() else backbone_ref |
| tokenizer = AutoTokenizer.from_pretrained( |
| tok_src, |
| use_fast=True, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| **_tokenizer_kwargs(AutoTokenizer.from_pretrained), |
| ) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| tokenizer.padding_side = "right" |
| checkpoint_repo_path = repo_path |
| if not _repo_has_pretrained_weights(checkpoint_repo_path): |
| with contextlib.suppress(Exception): |
| resolved_backbone = _resolve_repo_path( |
| backbone_ref, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| ) |
| if _repo_has_pretrained_weights(resolved_backbone): |
| checkpoint_repo_path = resolved_backbone |
| checkpoint_vocab_size = 0 |
| alignment_report = None |
| with contextlib.suppress(Exception): |
| checkpoint_vocab_size = int(_infer_checkpoint_vocab_size(checkpoint_repo_path) or 0) |
| if checkpoint_vocab_size > 0: |
| try: |
| tokenizer_vocab_size = int(len(tokenizer)) |
| except Exception: |
| tokenizer_vocab_size = 0 |
| if tokenizer_vocab_size > 0 and checkpoint_vocab_size > tokenizer_vocab_size: |
| alignment_report = _align_tokenizer_with_checkpoint_vocab( |
| tokenizer, |
| checkpoint_repo_path, |
| checkpoint_vocab_size, |
| ) |
| _log_tokenizer_checkpoint_alignment("[load_llopa_model]", alignment_report) |
| if config is not None: |
| try: |
| tokenizer_vocab_size = int(len(tokenizer)) |
| except Exception: |
| tokenizer_vocab_size = 0 |
| config_vocab_size = 0 |
| try: |
| config_vocab_size = int(getattr(config, "vocab_size", 0) or 0) |
| except Exception: |
| config_vocab_size = 0 |
| target_vocab_size = max(config_vocab_size, tokenizer_vocab_size, int(checkpoint_vocab_size or 0)) |
| if checkpoint_vocab_size > 0 and checkpoint_vocab_size != config_vocab_size: |
| _log_config_checkpoint_vocab_alignment( |
| "[load_llopa_model]", |
| config_vocab_size=config_vocab_size, |
| checkpoint_vocab_size=checkpoint_vocab_size, |
| tokenizer_vocab_size=tokenizer_vocab_size, |
| alignment_report=alignment_report, |
| ) |
| if target_vocab_size > 0 and target_vocab_size != config_vocab_size: |
| with contextlib.suppress(Exception): |
| config.vocab_size = int(target_vocab_size) |
| if enable_thinking is not None: |
| if hasattr(tokenizer, "enable_thinking"): |
| try: |
| tokenizer.enable_thinking(enable_thinking) |
| except Exception: |
| try: |
| tokenizer.enable_thinking = enable_thinking |
| except Exception: |
| pass |
| elif hasattr(tokenizer, "set_enable_thinking"): |
| try: |
| tokenizer.set_enable_thinking(enable_thinking) |
| except Exception: |
| pass |
| try: |
| setattr(tokenizer, "_force_enable_thinking", enable_thinking) |
| except Exception: |
| pass |
|
|
| model_family = infer_model_family(backbone_ref, modeling_family) |
| modeling_path = _resolve_modeling_path(repo_path, lopa_modeling_path, model_family) |
| lora_path = repo_path / "lora" |
| has_lora = bool(use_lora and lora_path.exists() and any(lora_path.iterdir())) |
| if not has_lora and bool(use_lora): |
| top_level_adapter = ( |
| (repo_path / "adapter_config.json").is_file() |
| and ( |
| (repo_path / "adapter_model.safetensors").is_file() |
| or (repo_path / "adapter_model.bin").is_file() |
| ) |
| ) |
| if top_level_adapter: |
| lora_path = repo_path |
| has_lora = True |
| load_device_map = device_map |
| merge_on_cpu_active = bool(merge_on_cpu and has_lora and device_map is None) |
| if merge_on_cpu and has_lora and device_map is not None: |
| print("[LoRA] merge_on_cpu ignored because sharded device_map is requested.") |
| if merge_on_cpu_active: |
| print("[LoRA] Merging adapters on CPU to reduce CUDA peak memory.") |
| load_device_map = None |
|
|
| custom_mod = None |
| if modeling_path: |
| try: |
| custom_mod = load_custom_modeling(modeling_path, model_family=model_family) |
| except Exception: |
| custom_mod = None |
|
|
| base = None |
| custom_load_exc = None |
| if custom_mod is not None: |
| try: |
| if model_family == "qwen3": |
| base = custom_mod.Qwen3ForCausalLM.from_pretrained( |
| backbone_ref, |
| **_dtype_kwargs(custom_mod.Qwen3ForCausalLM.from_pretrained, torch_dtype), |
| trust_remote_code=trust_remote_code, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| device_map=load_device_map, |
| **config_kwargs, |
| ) |
| elif model_family == "mistral": |
| base = custom_mod.MistralForCausalLM.from_pretrained( |
| backbone_ref, |
| **_dtype_kwargs(custom_mod.MistralForCausalLM.from_pretrained, torch_dtype), |
| trust_remote_code=trust_remote_code, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| device_map=load_device_map, |
| **config_kwargs, |
| ) |
| else: |
| base = custom_mod.LlamaForCausalLM.from_pretrained( |
| backbone_ref, |
| **_dtype_kwargs(custom_mod.LlamaForCausalLM.from_pretrained, torch_dtype), |
| trust_remote_code=trust_remote_code, |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| device_map=load_device_map, |
| **config_kwargs, |
| ) |
| except Exception as exc: |
| custom_load_exc = exc |
| base = None |
|
|
| if base is None and force_custom_modeling and custom_mod is not None and custom_load_exc is not None: |
| raise RuntimeError( |
| f"Failed to load base model with custom LLoPA modeling from {backbone_ref}" |
| ) from custom_load_exc |
|
|
| if base is None: |
| base = AutoModelForCausalLM.from_pretrained( |
| backbone_ref, |
| trust_remote_code=trust_remote_code, |
| **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, torch_dtype), |
| cache_dir=cache_dir, |
| revision=revision, |
| token=token, |
| local_files_only=local_files_only, |
| device_map=load_device_map, |
| **config_kwargs, |
| ) |
|
|
| ensure_mistral_special_token(tokenizer, base) |
| load_embedding_layer(base, repo_path) |
| if merge_on_cpu_active: |
| try: |
| base = base.to("cpu") |
| except Exception: |
| pass |
|
|
| model = None |
| if has_lora: |
| num_lora = int(number_of_lora or 1) |
| if num_lora == 2: |
| gen_dir = lora_path / "gen" |
| prefill_dir = lora_path / "prefill" |
| if gen_dir.is_dir() and prefill_dir.is_dir(): |
| try: |
| from peft import PeftModel |
| peft_gen = PeftModel.from_pretrained(base, str(gen_dir), adapter_name="gen") |
| try: |
| peft_gen.set_adapter("gen") |
| except Exception: |
| pass |
| merged_base = peft_gen.merge_and_unload() |
| peft_prefill = PeftModel.from_pretrained(merged_base, str(prefill_dir), adapter_name="prefill") |
| model = peft_prefill |
| setattr(model, "_prefill_adapter_only", True) |
| except Exception: |
| model = base |
| else: |
| model = base |
| if model is None or model is base: |
| try: |
| from peft import PeftModel |
| peft = PeftModel.from_pretrained(base, str(lora_path)) |
| try: |
| model = peft.merge_and_unload() |
| except Exception: |
| model = peft |
| except Exception: |
| model = base |
| else: |
| model = base |
|
|
| if num_specials > 0: |
| if not load_llopa_specials(model, repo_path): |
| print("[Warn] Failed to load LLOPA specials.") |
|
|
| try: |
| p0 = next(model.parameters()) |
| print(f"[load_llopa_model] model dtype={p0.dtype}, on={p0.device}") |
| except Exception: |
| pass |
|
|
| if device_map is None: |
| model = model.to(device).eval() |
| else: |
| model = model.eval() |
| if no_upper_attn is not None: |
| try: |
| setattr(model, "_no_upper_attn", bool(no_upper_attn)) |
| except Exception: |
| pass |
| if bool(runtime_prefill_lower): |
| if _supports_prefill_lower_runtime(model): |
| _attach_prefill_lower_generate( |
| model, |
| lower_k=int(runtime_prefill_layers or 0), |
| prefill_attn=str(runtime_prefill_attn or "causal"), |
| system_prefill=str(runtime_prefill_system_prefill or "no_bos_system"), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] standard generate runtime enabled " |
| f"(prefill_lower_layers={int(runtime_prefill_layers or 0)}, " |
| f"prefill_attn={str(runtime_prefill_attn or 'causal')})" |
| ) |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "runtime_lower") |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] runtime_prefill_lower requested but TRI prefill-lower runtime is unavailable.") |
| if bool(runtime_prefill_freeze): |
| if _supports_prefill_lower_runtime(model): |
| _attach_prefill_lower_freeze_generate( |
| model, |
| tokenizer=tokenizer, |
| lower_k=int(runtime_prefill_freeze_layers or 0), |
| prefill_attn=str(runtime_prefill_freeze_attn or "causal"), |
| system_prefill=str(runtime_prefill_freeze_system_prefill or "no_bos_system"), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] freeze-faithful generate runtime enabled " |
| f"(prefill_lower_layers={int(runtime_prefill_freeze_layers or 0)}, " |
| f"prefill_attn={str(runtime_prefill_freeze_attn or 'causal')}, " |
| f"system_prefill={str(runtime_prefill_freeze_system_prefill or 'no_bos_system')})" |
| ) |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "runtime_freeze") |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] runtime_prefill_freeze requested but TRI prefill-freeze runtime is unavailable.") |
| if bool(runtime_prefill_solo): |
| if _supports_prefill_lower_runtime(model): |
| _attach_prefill_lower_solo_generate( |
| model, |
| tokenizer=tokenizer, |
| lower_k=int(runtime_prefill_solo_layers or 0), |
| prefill_attn=str(runtime_prefill_solo_attn or "causal"), |
| system_prefill=str(runtime_prefill_solo_system_prefill or "no_bos_system"), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] solo-attn generate runtime enabled " |
| f"(prefill_lower_layers={int(runtime_prefill_solo_layers or 0)}, " |
| f"prefill_attn={str(runtime_prefill_solo_attn or 'causal')}, " |
| f"system_prefill={str(runtime_prefill_solo_system_prefill or 'no_bos_system')})" |
| ) |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "runtime_solo") |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] runtime_prefill_solo requested but TRI prefill-solo runtime is unavailable.") |
| if bool(runtime_prefill_solo_v2): |
| if _supports_prefill_lower_runtime(model): |
| _attach_prefill_lower_solo_v2_generate( |
| model, |
| tokenizer=tokenizer, |
| lower_k=int(runtime_prefill_solo_v2_layers or 0), |
| prefill_attn=str(runtime_prefill_solo_v2_attn or "causal"), |
| system_prefill=str(runtime_prefill_solo_v2_system_prefill or "no_bos_system"), |
| with_bos=bool(runtime_prefill_solo_v2_with_bos), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] solo-attn-v2 generate runtime enabled " |
| f"(prefill_lower_layers={int(runtime_prefill_solo_v2_layers or 0)}, " |
| f"prefill_attn={str(runtime_prefill_solo_v2_attn or 'causal')}, " |
| f"system_prefill={str(runtime_prefill_solo_v2_system_prefill or 'no_bos_system')}, " |
| f"with_bos={int(bool(runtime_prefill_solo_v2_with_bos))})" |
| ) |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "runtime_solo_v2") |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] runtime_prefill_solo_v2 requested but TRI prefill-solo-v2 runtime is unavailable.") |
| if bool(runtime_llopa_prefill): |
| if _supports_runtime_llopa_prompt_prefill(model): |
| header_ids = _assistant_header_ids(tokenizer, "cpu") |
| if isinstance(header_ids, torch.Tensor) and header_ids.numel() > 0: |
| _attach_runtime_llopa_generate( |
| model, |
| header_ids=header_ids, |
| lower_k=int(runtime_llopa_layers or 0), |
| prefill_attn=str(runtime_llopa_attn or "causal"), |
| no_upper_attn=bool(runtime_llopa_no_upper_attn), |
| ) |
| if bool(runtime_llopa_fast_generate): |
| _attach_runtime_llopa_fast_generate( |
| model, |
| lower_k=int(runtime_llopa_layers or 0), |
| prefill_attn=str(runtime_llopa_attn or "causal"), |
| no_upper_attn=bool(runtime_llopa_no_upper_attn), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] standard generate runtime enabled " |
| f"(llopa_prefill_layers={int(runtime_llopa_layers or 0)}, " |
| f"prefill_attn={str(runtime_llopa_attn or 'causal')})" |
| ) |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "legacy_llopa") |
| except Exception: |
| pass |
| if bool(runtime_llopa_fast_generate): |
| try: |
| print("[load_llopa_model] runtime_llopa_fast_generate enabled") |
| except Exception: |
| pass |
| try: |
| setattr(model, "_capsule_inference_path", "legacy_llopa") |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] runtime_llopa_prefill requested but assistant header ids are unavailable.") |
| else: |
| print("[load_llopa_model][warn] runtime_llopa_prefill requested but TRI LLoPA runtime is unavailable.") |
| if bool(unified_llopa_generate): |
| if _supports_direct_llopa_generate(model): |
| _attach_unified_llopa_generate( |
| model, |
| tokenizer, |
| lower_k=int(unified_llopa_layers or 0), |
| prefill_attn=str(unified_llopa_attn or "causal"), |
| system_prefill=str(unified_llopa_system_prefill or "full"), |
| user_prefill=str(unified_llopa_user_prefill or "full"), |
| no_upper_attn=bool(unified_llopa_no_upper_attn), |
| see_past_assistant=bool(unified_llopa_see_past_assistant), |
| replay_module=str(unified_llopa_replay_module or "none"), |
| replay_per_layers=int(unified_llopa_replay_per_layers or -1), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] unified_llopa generate enabled " |
| f"(llopa_prefill_layers={int(unified_llopa_layers or 0)}, " |
| f"prefill_attn={str(unified_llopa_attn or 'causal')}, " |
| f"replay_module={str(unified_llopa_replay_module or 'none')}, " |
| f"replay_per_layers={int(unified_llopa_replay_per_layers or -1)})" |
| ) |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] unified_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") |
| if bool(llopa_v2_batch_generate): |
| if _supports_direct_llopa_generate(model): |
| _attach_llopa_v2_batch_generate( |
| model, |
| tokenizer, |
| lower_k=int(llopa_v2_layers or 0), |
| prefill_attn=str(llopa_v2_attn or "causal"), |
| system_prefill=str(llopa_v2_system_prefill or "full"), |
| user_prefill=str(llopa_v2_user_prefill or "full"), |
| no_upper_attn=bool(llopa_v2_no_upper_attn), |
| see_past_assistant=bool(llopa_v2_see_past_assistant), |
| replay_module=str(llopa_v2_replay_module or "none"), |
| replay_per_layers=int(llopa_v2_replay_per_layers or -1), |
| seed_mode=str(llopa_v2_seed_mode or "prefill_header"), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] llopa_v2_batch generate enabled " |
| f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, " |
| f"prefill_attn={str(llopa_v2_attn or 'causal')}, " |
| f"replay_module={str(llopa_v2_replay_module or 'none')}, " |
| f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, " |
| f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')})" |
| ) |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] llopa_v2_batch_generate requested but LLoPA direct prompt prefill is unavailable.") |
| if bool(llopa_v2_generate): |
| if _supports_direct_llopa_generate(model): |
| llopa_v3_active = bool(llopa_v3_generate) |
| _attach_llopa_v2_generate( |
| model, |
| tokenizer, |
| lower_k=int(llopa_v2_layers or 0), |
| prefill_attn=str(llopa_v2_attn or "causal"), |
| system_prefill=str(llopa_v2_system_prefill or "full"), |
| user_prefill=str(llopa_v2_user_prefill or "full"), |
| no_upper_attn=bool(llopa_v2_no_upper_attn), |
| see_past_assistant=bool(llopa_v2_see_past_assistant), |
| replay_module=str(llopa_v2_replay_module or "none"), |
| replay_per_layers=int(llopa_v2_replay_per_layers or -1), |
| seed_mode=str(llopa_v2_seed_mode or "prefill_header"), |
| generation_mixin_decode=llopa_v3_active, |
| capsule_inference_path="llopa_v3" if llopa_v3_active else "llopa_v2", |
| ) |
| try: |
| label = "llopa_v3" if llopa_v3_active else "llopa_v2" |
| print( |
| f"[load_llopa_model] {label} generate enabled " |
| f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, " |
| f"prefill_attn={str(llopa_v2_attn or 'causal')}, " |
| f"replay_module={str(llopa_v2_replay_module or 'none')}, " |
| f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, " |
| f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')}, " |
| f"generation_mixin_decode={int(llopa_v3_active)})" |
| ) |
| except Exception: |
| pass |
| else: |
| label = "llopa_v3_generate" if bool(llopa_v3_generate) else "llopa_v2_generate" |
| print(f"[load_llopa_model][warn] {label} requested but LLoPA direct prompt prefill is unavailable.") |
| if bool(optimized_llopa_generate): |
| if _supports_direct_llopa_generate(model): |
| _attach_optimized_llopa_generate( |
| model, |
| tokenizer, |
| lower_k=int(optimized_llopa_layers or 0), |
| prefill_attn=str(optimized_llopa_attn or "causal"), |
| system_prefill=str(optimized_llopa_system_prefill or "full"), |
| user_prefill=str(optimized_llopa_user_prefill or "full"), |
| no_upper_attn=bool(optimized_llopa_no_upper_attn), |
| see_past_assistant=bool(optimized_llopa_see_past_assistant), |
| replay_module=str(optimized_llopa_replay_module or "none"), |
| replay_per_layers=int(optimized_llopa_replay_per_layers or -1), |
| optimized_variant=str(optimized_settings["variant"]), |
| optimized_seed_mode=str(optimized_settings["seed_mode"]), |
| optimized_upper_prepare_mode=str(optimized_settings["upper_prepare_mode"]), |
| optimized_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]), |
| optimized_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] optimized_llopa generate enabled " |
| f"(llopa_prefill_layers={int(optimized_llopa_layers or 0)}, " |
| f"prefill_attn={str(optimized_llopa_attn or 'causal')}, " |
| f"replay_module={str(optimized_llopa_replay_module or 'none')}, " |
| f"replay_per_layers={int(optimized_llopa_replay_per_layers or -1)}, " |
| f"variant={str(optimized_settings['variant'])})" |
| ) |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] optimized_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") |
| if bool(direct_llopa_generate): |
| print( |
| "[load_llopa_model][warn] direct_llopa_* is deprecated; " |
| "use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged." |
| ) |
| if _supports_direct_llopa_generate(model): |
| _attach_direct_llopa_generate( |
| model, |
| tokenizer, |
| lower_k=int(direct_llopa_layers or 0), |
| prefill_attn=str(direct_llopa_attn or "causal"), |
| system_prefill=str(direct_llopa_system_prefill or "full"), |
| user_prefill=str(direct_llopa_user_prefill or "full"), |
| no_upper_attn=bool(direct_llopa_no_upper_attn), |
| ) |
| try: |
| print( |
| f"[load_llopa_model] direct_llopa generate enabled " |
| f"(llopa_prefill_layers={int(direct_llopa_layers or 0)}, " |
| f"prefill_attn={str(direct_llopa_attn or 'causal')})" |
| ) |
| except Exception: |
| pass |
| else: |
| print("[load_llopa_model][warn] direct_llopa_generate requested but LLoPA direct prompt prefill is unavailable.") |
|
|
| if force_custom_modeling: |
| try: |
| has_llopa = _has_active_llopa_runtime(model) |
| except Exception: |
| has_llopa = False |
| if not has_llopa: |
| raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.") |
|
|
| if attn_impl != "auto": |
| impl = attn_impl |
| for k in ("attn_implementation", "_attn_implementation"): |
| try: |
| setattr(model.config, k, impl) |
| inner = getattr(model, "model", None) or getattr(model, "transformer", None) |
| if inner is not None and hasattr(inner, "config"): |
| setattr(inner.config, k, impl) |
| except Exception: |
| pass |
|
|
| _attach_llopa_generate(model) |
|
|
| return model, tokenizer |
|
|
|
|
| |
| |
| |
| def main(): |
| ap = argparse.ArgumentParser("TRI inference helper") |
| ap.add_argument("--best_dir", type=str, required=True, help="Path to best/ folder produced by training") |
| ap.add_argument("--backbone_dir", type=str, default=None, |
| help="Optional backbone path/ID (overrides best_dir/backbone.json and --model_name).") |
| ap.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct") |
| ap.add_argument("--tokenizer_name", type=str, default="", help="Optional tokenizer name or path") |
| ap.add_argument("--prefill_layers", type=int, default=4) |
| ap.add_argument("--prefill_mode", type=str, choices=["lower", "periodic"], default="lower", |
| help="Prefill mode for user tokens: 'lower' uses first K layers, 'periodic' uses every K-th layer.") |
| ap.add_argument("--prefill_attn", type=str, choices=["causal", "full"], default="causal", |
| help="Prefill attention for system/user tokens (training must match).") |
| ap.add_argument("--system_prefill", type=str, choices=["full", "no_system", "no_bos_system"], default="full", |
| help="System prefill mode (must match training).") |
| ap.add_argument("--user_prefill", type=str, choices=["full", "no_question"], default="full", |
| help="User prefill ablation: full=doc+question, no_question=doc-only (question runs in full layers).") |
| ap.add_argument("--llopa_prefill", action="store_true", |
| help="Use single-forward LLOPA prefill (causal + lower only).") |
| ap.add_argument("--no_upper_attn", action="store_true", |
| help="Skip upper-layer attention during decode (effective only with --llopa_prefill).") |
| ap.add_argument("--lopa_modeling_path", type=str, default="tri_llama3_modeling.py", |
| help="Path to custom LLoPA modeling file used in training") |
| ap.add_argument("--modeling_family", type=str, choices=["auto", "llama", "qwen3", "mistral"], default="auto", |
| help="Model family for custom modeling injection (auto detects from --model_name)") |
| ap.add_argument("--force_custom_modeling", action="store_true", |
| help="Require custom LLoPA modeling to be active; error if not.") |
| ap.add_argument("--system", type=str, default="You are a helpful assistant that answers questions based on the given document. ") |
| ap.add_argument("--task", type=str, default="qa_doc", |
| help="Prompt template task: qa_doc | math | summary | code") |
| ap.add_argument("--math_force_final_hash_rule", action="store_true", |
| help="If set and task=math, append the #### answer rule to the system prompt.") |
| ap.add_argument("--document", type=str, required=True) |
| ap.add_argument("--question", type=str, required=True) |
| ap.add_argument("--max_new_tokens", type=int, default=256) |
| ap.add_argument("--min_length", type=int, default=16) |
| ap.add_argument("--temperature", type=float, default=0.7) |
| ap.add_argument("--top_p", type=float, default=0.9) |
| ap.add_argument("--top_k", type=int, default=None) |
| ap.add_argument("--do_sample", action="store_true") |
| |
| ap.add_argument("--dtype", type=str, choices=["auto","bf16","fp16","fp32"], default="auto") |
| ap.add_argument("--no_tf32", action="store_true") |
| ap.add_argument("--sdpa_math_only", action="store_true") |
| ap.add_argument("--debug", action="store_true") |
| ap.add_argument("--attn_impl", type=str, choices=["sdpa", "eager", "auto"], default="sdpa", |
| help="Attention implementation override (auto keeps model default).") |
| args = ap.parse_args() |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| if args.dtype == "fp32": |
| dtype = torch.float32 |
| elif args.dtype == "bf16": |
| dtype = torch.bfloat16 |
| elif args.dtype == "fp16": |
| dtype = torch.float16 |
| else: |
| dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else (torch.float16 if device == "cuda" else torch.float32) |
|
|
| |
| if args.no_tf32 and torch.cuda.is_available(): |
| try: |
| torch.backends.cuda.matmul.allow_tf32 = False |
| torch.backends.cudnn.allow_tf32 = False |
| except Exception: |
| pass |
| if args.sdpa_math_only and torch.cuda.is_available(): |
| try: |
| torch.backends.cuda.enable_flash_sdp(False) |
| torch.backends.cuda.enable_mem_efficient_sdp(False) |
| torch.backends.cuda.enable_math_sdp(True) |
| except Exception: |
| pass |
|
|
| best_dir = Path(args.best_dir) |
| if getattr(args, "tokenizer_name", ""): |
| tok_src = args.tokenizer_name |
| else: |
| tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else args.model_name |
| tokenizer = AutoTokenizer.from_pretrained( |
| tok_src, |
| use_fast=True, |
| **_tokenizer_kwargs(AutoTokenizer.from_pretrained), |
| ) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| tokenizer.padding_side = "right" |
|
|
| |
| backbone_ref = None |
| if args.backbone_dir: |
| backbone_ref = args.backbone_dir |
| else: |
| backbone_ref = read_backbone_ref(best_dir) |
| base_path = best_dir / "base" |
| if backbone_ref: |
| base_load_src = backbone_ref |
| elif base_path.exists() and any(base_path.iterdir()): |
| base_load_src = str(base_path) |
| else: |
| base_load_src = args.model_name |
|
|
| num_specials = None |
| tri_no_upper_attn = None |
| tri_info = best_dir / "tri_info.txt" |
| if tri_info.is_file(): |
| info = _read_kv_file(tri_info) |
| try: |
| num_specials = int(info.get("num_specials", "") or 0) |
| except Exception: |
| num_specials = None |
| raw_no_upper = (info.get("no_upper_attn") or "").strip().lower() |
| if raw_no_upper in {"1", "true", "yes", "on"}: |
| tri_no_upper_attn = True |
| elif raw_no_upper in {"0", "false", "no", "off"}: |
| tri_no_upper_attn = False |
| if (not bool(getattr(args, "no_upper_attn", False))) and (tri_no_upper_attn is not None): |
| args.no_upper_attn = bool(tri_no_upper_attn) |
|
|
| config = None |
| try: |
| config = AutoConfig.from_pretrained(base_load_src) |
| if num_specials is not None: |
| config.llopa_num_specials = int(num_specials) |
| except Exception: |
| config = None |
| if num_specials is None: |
| num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0 |
| config_kwargs = {"config": config} if config is not None else {} |
|
|
| |
| custom_mod = None |
| model_family = infer_model_family(args.model_name, args.modeling_family) |
| try: |
| custom_mod = load_custom_modeling(args.lopa_modeling_path, model_family=model_family) |
| except Exception: |
| custom_mod = None |
|
|
| base = None |
| if custom_mod is not None: |
| try: |
| |
| if model_family == "qwen3": |
| base = custom_mod.Qwen3ForCausalLM.from_pretrained( |
| base_load_src, |
| **_dtype_kwargs(custom_mod.Qwen3ForCausalLM.from_pretrained, dtype), |
| **config_kwargs, |
| ) |
| elif model_family == "mistral": |
| base = custom_mod.MistralForCausalLM.from_pretrained( |
| base_load_src, |
| **_dtype_kwargs(custom_mod.MistralForCausalLM.from_pretrained, dtype), |
| **config_kwargs, |
| ) |
| else: |
| base = custom_mod.LlamaForCausalLM.from_pretrained( |
| base_load_src, |
| **_dtype_kwargs(custom_mod.LlamaForCausalLM.from_pretrained, dtype), |
| **config_kwargs, |
| ) |
| except Exception: |
| base = None |
| if base is None: |
| base = AutoModelForCausalLM.from_pretrained( |
| base_load_src, |
| trust_remote_code=False, |
| **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, dtype), |
| **config_kwargs, |
| ) |
| |
| ensure_mistral_special_token(tokenizer, base) |
| |
| loaded_emb = load_embedding_layer(base, best_dir) |
|
|
| |
| lora_path = best_dir / "lora" |
| model = None |
| merged_lora = False |
| if lora_path.exists() and any(lora_path.iterdir()): |
| try: |
| from peft import PeftModel |
| peft = PeftModel.from_pretrained(base, str(lora_path)) |
| try: |
| model = peft.merge_and_unload() |
| merged_lora = True |
| except Exception: |
| |
| model = peft |
| except Exception: |
| model = base |
| else: |
| model = base |
|
|
| if num_specials > 0: |
| if not load_llopa_specials(model, best_dir): |
| print("[Warn] Failed to load LLOPA specials.") |
|
|
| |
| model = model.to(device).eval() |
| try: |
| setattr(model, "_no_upper_attn", bool(getattr(args, "no_upper_attn", False))) |
| except Exception: |
| pass |
|
|
| |
| if args.force_custom_modeling: |
| has_llopa = _has_active_llopa_runtime(model) |
| if not has_llopa: |
| raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.") |
|
|
| |
| if args.attn_impl != "auto": |
| impl = args.attn_impl |
| for k in ("attn_implementation", "_attn_implementation"): |
| try: |
| setattr(model.config, k, impl) |
| inner = getattr(model, "model", None) or getattr(model, "transformer", None) |
| if inner is not None and hasattr(inner, "config"): |
| setattr(inner.config, k, impl) |
| except Exception: |
| pass |
| print(f"[infer] Forcing attn_implementation='{impl}' for all models.") |
| else: |
| print("[infer] Using model default attn_implementation (auto).") |
| if args.debug: |
| print(f"[debug] load base from: {base_load_src}") |
| if loaded_emb: |
| print("[debug] loaded embedding layer from best_dir") |
| print(f"[debug] lora path: {lora_path} | merged={merged_lora}") |
| tmpl = getattr(tokenizer, "chat_template", "") or "" |
| print(f"[debug] template contains Llama3 header? {('<|start_header_id|>' in tmpl)} | Mistral? {('[INST]' in tmpl)}") |
|
|
| |
| dbg_dir = (best_dir / "debug_infer") if args.debug else None |
| text = lopa_generate( |
| model, tokenizer, |
| system=args.system, document=args.document, question=args.question, |
| task=str(getattr(args, "task", "qa_doc")), |
| K=int(args.prefill_layers), |
| prefill_mode=str(args.prefill_mode), |
| prefill_attn=str(getattr(args, "prefill_attn", "causal")), |
| system_prefill=str(getattr(args, "system_prefill", "full")), |
| user_prefill=str(getattr(args, "user_prefill", "full")), |
| device=device, |
| max_new_tokens=args.max_new_tokens, min_length=args.min_length, |
| temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, |
| do_sample=bool(args.do_sample), |
| math_force_final_hash_rule=bool(getattr(args, "math_force_final_hash_rule", False)), |
| llopa_prefill=bool(getattr(args, "llopa_prefill", False)), |
| no_upper_attn=bool(getattr(args, "no_upper_attn", False)), |
| debug=bool(args.debug), debug_dir=dbg_dir, |
| ) |
| print(text) |
|
|
# Script entry point: call main() only when this file is executed directly, not when it is imported as a module.
| if __name__ == "__main__": |
| main() |
|
|