jeongseokoh's picture
Add files using upload-large-folder tool
9751720 verified
#!/usr/bin/env python3
from __future__ import annotations
import os
import argparse
import contextlib
import inspect
import json
from pathlib import Path
from typing import Any, List, Tuple, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Prefer the package-relative import; fall back to a top-level import when the
# relative one fails (presumably when this file is run as a standalone script).
try:
    from .llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref
except Exception:
    from llopa_utils.saving_utils import load_embedding_layer, load_llopa_specials, read_backbone_ref
from transformers.cache_utils import DynamicCache
# Be safe with tokenizers threads when forking
# (setdefault: values already set in the environment are left untouched).
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# NOTE(review): presumably disables DeepSpeed op pre-builds; only relevant if
# DeepSpeed is imported after this point — confirm.
os.environ.setdefault("DS_BUILD_AIO", "0")
os.environ.setdefault("DS_BUILD_OPS", "0")
def _dtype_kwargs(from_pretrained_fn, dtype: torch.dtype) -> dict:
try:
params = inspect.signature(from_pretrained_fn).parameters
except Exception:
return {"dtype": dtype}
if "dtype" in params:
return {"dtype": dtype}
if "torch_dtype" in params:
return {"torch_dtype": dtype}
# Some transformers versions expose from_pretrained(*args, **kwargs)
# while still accepting `dtype` at runtime.
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
return {"dtype": dtype}
return {"dtype": dtype}
def _tokenizer_kwargs(from_pretrained_fn) -> dict:
kwargs: dict = {}
try:
params = inspect.signature(from_pretrained_fn).parameters
if "fix_mistral_regex" in params:
kwargs["fix_mistral_regex"] = True
elif any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
# Older/newer tokenizer loaders often hide optional args behind **kwargs.
kwargs["fix_mistral_regex"] = True
except Exception:
pass
return kwargs
def _normalize_dtype_arg(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
s = value.strip().lower()
elif value is torch.bfloat16:
s = "bfloat16"
elif value is torch.float16:
s = "float16"
elif value is torch.float32:
s = "float32"
else:
return None
mapping = {
"auto": "auto",
"bf16": "bf16",
"bfloat16": "bf16",
"fp16": "fp16",
"float16": "fp16",
"half": "fp16",
"fp32": "fp32",
"float32": "fp32",
}
return mapping.get(s)
# Special assistant-start token for Mistral-style templates
MISTRAL_ASSIST_START = "<Mistral_start>"
# Instruction appended after the question in qa_doc-style prompts (requested).
# NOTE(review): the source literal "\\\\boxed{}" yields the runtime text
# `\\boxed{}` (two backslashes). If a single LaTeX `\boxed{}` was intended, the
# literal should be "\\boxed{}". Confirm against the prompts used for
# training/evaluation before changing, since it alters model inputs.
BOXED_ANSWER_INSTRUCTION = (
    "Provide the answer and finish with the short answer inside \\\\boxed{}."
)
# Instruction appended to math-task problems (same double-backslash caveat).
MATH_STEP_BY_STEP_BOXED_INSTRUCTION = (
    "Solve the problem step by step, and put the final answer inside \\\\boxed{}."
)
# -----------------------------
# Custom modeling loader (TRI)
# -----------------------------
def load_custom_modeling(modeling_path: str, model_family: str):
    """Load a local tri_*_modeling.py and register it as the HF modeling module.

    The file at ``modeling_path`` is executed under the module name that
    transformers uses for the chosen family (for example
    ``transformers.models.llama.modeling_llama``). It is registered in
    ``sys.modules`` so later imports of that name resolve to it.

    Args:
        modeling_path: Path to the replacement modeling source file.
        model_family: One of ``"llama"``, ``"qwen3"`` or ``"mistral"``.

    Returns:
        The executed module object.

    Raises:
        ValueError: If ``model_family`` is not supported.
        RuntimeError: If no import spec can be built for ``modeling_path`` or
            the module does not define the expected model classes.

    On any failure, the module previously registered under the target name
    (if any) is restored. Without this, a half-initialized module, or no
    module at all, would be left in ``sys.modules``.
    """
    import importlib
    import importlib.util
    import sys

    families = {
        "llama": ("transformers.models.llama", "modeling_llama", ("LlamaModel", "LlamaForCausalLM")),
        "qwen3": ("transformers.models.qwen3", "modeling_qwen3", ("Qwen3Model", "Qwen3ForCausalLM")),
        "mistral": ("transformers.models.mistral", "modeling_mistral", ("MistralModel", "MistralForCausalLM")),
    }
    if model_family not in families:
        raise ValueError(f"Unknown model_family: {model_family}")
    package_name, submodule_name, expected = families[model_family]

    import transformers  # noqa: F401
    importlib.import_module(package_name)  # ensure the parent package exists
    target_name = f"{package_name}.{submodule_name}"

    previous = sys.modules.pop(target_name, None)
    loaded = False
    try:
        spec = importlib.util.spec_from_file_location(target_name, str(modeling_path))
        if spec is None or spec.loader is None:
            raise RuntimeError(f"Failed to load spec for {modeling_path}")
        module = importlib.util.module_from_spec(spec)
        # Register before exec so the module's own relative/self imports resolve.
        sys.modules[target_name] = module
        spec.loader.exec_module(module)
        for klass in expected:
            if not hasattr(module, klass):
                raise RuntimeError(f"{modeling_path} does not define {klass}")
        loaded = True
    finally:
        if not loaded:
            # Put back whatever was registered before instead of leaving a broken module.
            if previous is not None:
                sys.modules[target_name] = previous
            else:
                sys.modules.pop(target_name, None)
    return module
def infer_model_family(model_name: str, model_family_arg: str) -> str:
    """Resolve the model family, preferring an explicit ``model_family_arg``.

    A truthy ``model_family_arg`` other than ``"auto"`` is returned unchanged.
    Otherwise the family is guessed from ``model_name`` by case-insensitive
    substring match, checked in this order: ``qwen`` -> ``"qwen3"``,
    ``llama`` -> ``"llama"``, ``mistral`` -> ``"mistral"``. The fallback is
    ``"llama"``.
    """
    wants_inference = (not model_family_arg) or model_family_arg == "auto"
    if not wants_inference:
        return model_family_arg
    lowered_name = (model_name or "").lower()
    for marker, family in (("qwen", "qwen3"), ("llama", "llama"), ("mistral", "mistral")):
        if marker in lowered_name:
            return family
    return "llama"
# -----------------------------
# Template helpers
# -----------------------------
def _is_mistral_template(tokenizer) -> bool:
tmpl = getattr(tokenizer, "chat_template", "") or ""
name = getattr(getattr(tokenizer, "init_kwargs", {}), "get", lambda k, d=None: d)("name_or_path", "")
return ("[INST]" in tmpl) or ("mistral" in str(name).lower()) or ("mistral" in tmpl.lower())
def _is_qwen3_tokenizer(tokenizer) -> bool:
name = getattr(tokenizer, "name_or_path", "") or ""
cls = tokenizer.__class__.__name__.lower()
return ("qwen3" in name.lower()) or ("qwen3" in cls)
def ensure_mistral_special_token(tokenizer, model=None):
    """Register ``MISTRAL_ASSIST_START`` as a special token on Mistral-style tokenizers.

    Tokenizers not detected by ``_is_mistral_template`` are left untouched.
    When the token is missing from the vocabulary, it is appended to the
    tokenizer's existing ``additional_special_tokens``. If ``model`` is given,
    the model's token embeddings are then resized to ``len(tokenizer)``.

    A failed resize is reported with a ``RuntimeWarning`` rather than being
    silently ignored. Otherwise the tokenizer can emit an id with no matching
    embedding row, and the error only surfaces much later.

    Returns:
        True if the token was added by this call, False otherwise.
    """
    import warnings

    if not _is_mistral_template(tokenizer):
        return False
    # get_vocab() returns a dict, so membership checks its keys directly.
    if MISTRAL_ASSIST_START in tokenizer.get_vocab():
        return False
    existing = list(tokenizer.special_tokens_map_extended.get("additional_special_tokens", []))
    # Pass the existing extras too so they are kept alongside the new token.
    tokenizer.add_special_tokens({
        "additional_special_tokens": existing + [MISTRAL_ASSIST_START]
    })
    if model is not None:
        try:
            model.resize_token_embeddings(len(tokenizer))
        except Exception as exc:
            warnings.warn(
                f"Added {MISTRAL_ASSIST_START!r} to the tokenizer but could not resize "
                f"the model's token embeddings: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
    return True
def build_prompt_parts(task: str, question: str, document: str) -> Tuple[str, str]:
    """Split a task prompt into ``(document_part, query_part)`` strings.

    All inputs are stripped; ``None`` is treated as empty, and a missing task
    means ``"qa_doc"``.

    - ``summary``: ``("Passage:\\n<doc>", question or "Summarize the passage.")``.
    - ``code``: no document part; the question (or a default) wrapped in a
      "Write code ..." instruction. The document is ignored.
    - ``math``: no document part; the question (or, if empty, the document)
      followed by ``MATH_STEP_BY_STEP_BOXED_INSTRUCTION``.
    - anything else (qa_doc): ``"Document:\\n<doc>"`` and
      ``"Question: <q>\\n" + BOXED_ANSWER_INSTRUCTION``. Each part is empty
      when its input is empty.
    """
    kind = (task or "qa_doc").strip()
    query = (question or "").strip()
    doc = (document or "").strip()

    if kind == "summary":
        passage = f"Passage:\n{doc}" if doc else ""
        return passage, (query or "Summarize the passage.")

    if kind == "code":
        problem = query or "Solve the following problem."
        return "", f"Write code to solve the following problem:\n{problem}"

    if kind == "math":
        problem = (query or doc).strip()
        if not problem:
            return "", MATH_STEP_BY_STEP_BOXED_INSTRUCTION
        return "", f"{problem}\n{MATH_STEP_BY_STEP_BOXED_INSTRUCTION}"

    doc_part = f"Document:\n{doc}" if doc else ""
    query_part = f"Question: {query}\n{BOXED_ANSWER_INSTRUCTION}" if query else ""
    return doc_part, query_part
def build_user_content(task: str, question: str, document: str) -> str:
    """Return the user-turn text built by ``build_prompt_parts``.

    The non-empty document and query parts are joined with a blank line.
    Returns an empty string when both parts are empty.
    """
    non_empty = [part for part in build_prompt_parts(task, question, document) if part]
    return "\n\n".join(non_empty)
def build_messages(system: str, document: str, question: str,
                   include_query: bool = True, task: str = "qa_doc"):
    """Build a two-message ``[system, user]`` chat for the given task.

    When ``include_query`` is false, the question is dropped. If the
    formatted user content comes out empty, the raw question (or document, or
    "") is used as the user message instead.
    """
    effective_question = question if include_query else ""
    user_text = build_user_content(task, effective_question, document) or (
        effective_question or document or ""
    )
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_text},
    ]
def apply_chat_template(tokenizer, messages, add_generation_prompt: bool):
    """Render ``messages`` to prompt text, tolerating older template APIs.

    Thinking mode: ``tokenizer._force_enable_thinking`` (if set) is passed as
    ``enable_thinking``. Qwen3 tokenizers without that attribute get
    ``enable_thinking=False``. Other tokenizers get no such kwarg.

    Rendering:
    - First try ``tokenizer.apply_chat_template(..., add_generation_prompt=...)``.
    - If that raises ``TypeError`` (unsupported keyword), render without
      ``add_generation_prompt``, first with and then without
      ``enable_thinking``. If a generation prompt was requested and the
      template uses Llama-3 ``<|start_header_id|>`` markers, append the
      assistant header by hand. Mistral/[INST]-style and unrecognized
      templates get nothing appended.
    """
    thinking_flag = getattr(tokenizer, "_force_enable_thinking", None)
    if thinking_flag is None and _is_qwen3_tokenizer(tokenizer):
        thinking_flag = False
    thinking_kwargs = {} if thinking_flag is None else {"enable_thinking": bool(thinking_flag)}

    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **thinking_kwargs,
        )
    except TypeError:
        template_src = getattr(tokenizer, "chat_template", "") or ""
        try:
            rendered = tokenizer.apply_chat_template(messages, tokenize=False, **thinking_kwargs)
        except TypeError:
            rendered = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt and "<|start_header_id|>" in template_src:
            # Llama-3 style: the assistant header has to be appended manually.
            rendered += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        # Mistral/[INST] templates have no explicit assistant header, and for
        # unknown templates appending nothing is the safest choice.
        return rendered
def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False):
    """Render ``messages`` with the chat template and return its token ids on ``device``.

    The rendered text is tokenized with ``add_special_tokens=False`` and
    ``return_tensors="pt"``. The ``input_ids`` tensor is returned after moving
    it to ``device``. ``MISTRAL_ASSIST_START`` is deliberately not appended for
    Mistral templates.
    """
    prompt_text = apply_chat_template(tokenizer, messages, add_generation_prompt)
    encoding = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")
    return encoding.input_ids.to(device)
def _build_llopa_inputs(tokenizer,
                        system: str,
                        document: str,
                        question: str,
                        *,
                        task: str = "qa_doc",
                        device: str):
    """Tokenize a system/document/question prompt and split it into LLoPA segments.

    Returns a dict of token-id tensors. They are assumed to be ``(1, L)``,
    since each comes from a single-prompt ``tokens_from_messages`` call:
    - ``prompt_ids``: full prompt including the assistant generation header.
    - ``system_ids``: tokens of the system message rendered alone.
    - ``system_user_ids``: system + user turn, without the assistant header.
    - ``user_ids_full``: ``system_user_ids`` after the system tokens.
    - ``user_doc_ids``: leading part of ``user_ids_full`` (user-turn header
      tokens + document text).
    - ``user_q_ids``: the remainder of ``user_ids_full`` (question/instruction
      and whatever the template emits after the user text).
    - ``hdr_tail``: assistant-header tokens that follow the user turn.

    Raises:
        ValueError: if the system-only tokenization is longer than the
            system+user tokenization.
    """
    doc_text, query_text = build_prompt_parts(task, question, document)
    msgs = build_messages(system, document, question, include_query=True, task=task)
    # System message alone; its length marks where the user turn begins.
    ids_sys = tokens_from_messages(
        tokenizer, [{"role": "system", "content": system}], device, add_generation_prompt=False
    )
    # Full prompt: system + user + assistant generation header.
    ids_hdr = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=True)
    # NOTE(review): `_assistant_header_ids` is defined elsewhere; presumably it
    # returns the expected assistant-header ids as a (1, n) tensor, or None.
    hdr_tail = _assistant_header_ids(tokenizer, device)
    ids_sys_user = None
    # Fast path: if the full prompt ends with the known header, strip it.
    if hdr_tail is not None and hdr_tail.numel() > 0:
        tail_len = hdr_tail.size(1)
        if ids_hdr.size(1) >= tail_len and torch.equal(ids_hdr[:, -tail_len:], hdr_tail):
            ids_sys_user = ids_hdr[:, :-tail_len]
    # Fallback: render without the generation prompt and treat the extra
    # trailing tokens of the full prompt as the header tail.
    if ids_sys_user is None:
        ids_sys_user = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False)
        hdr_tail = ids_hdr[:, ids_sys_user.size(1):]
    if ids_sys_user.size(1) < ids_sys.size(1):
        raise ValueError("System-only tokens longer than system+user tokens.")
    # Assumes the system-only rendering is a token prefix of the system+user
    # rendering; only the lengths are checked above.
    user_ids_full = ids_sys_user[:, ids_sys.size(1):]
    user_content_full = msgs[-1]["content"]
    # Text counted on the "document" side of the split: the document plus its
    # "\n\n" separator when a query follows, otherwise the whole user message.
    if query_text:
        doc_prefix_text = f"{doc_text}\n\n" if doc_text else ""
    else:
        doc_prefix_text = user_content_full
    # Empty (1, 0) slices as defaults when the user turn has no tokens.
    user_doc_ids = user_ids_full[:, 0:0]
    user_q_ids = user_ids_full[:, 0:0]
    if user_ids_full.size(1) > 0:
        # Render system + empty user turn. Its common prefix with the real
        # prompt gives the user-turn header length after the system tokens.
        # NOTE(review): `lcp_len` is defined elsewhere; presumably it returns
        # the longest-common-prefix length of two token tensors.
        msgs_empty = [{"role": "system", "content": system}, {"role": "user", "content": ""}]
        ids_empty = tokens_from_messages(tokenizer, msgs_empty, device, add_generation_prompt=False)
        header_prefix_len = lcp_len(ids_empty, ids_sys_user)
        user_header_len = max(0, header_prefix_len - ids_sys.size(1))
        doc_prefix_len = 0
        if doc_prefix_text:
            # NOTE(review): tokenized in isolation, so merges across the
            # doc/question boundary may differ from in-context tokenization;
            # the split point is an approximation.
            doc_prefix_len = len(tokenizer(doc_prefix_text, add_special_tokens=False).input_ids)
        doc_end = min(user_ids_full.size(1), user_header_len + doc_prefix_len)
        user_doc_ids = user_ids_full[:, :doc_end]
        user_q_ids = user_ids_full[:, doc_end:]
    return {
        "prompt_ids": ids_hdr,
        "system_ids": ids_sys,
        "system_user_ids": ids_sys_user,
        "user_ids_full": user_ids_full,
        "user_doc_ids": user_doc_ids,
        "user_q_ids": user_q_ids,
        "hdr_tail": hdr_tail,
    }
def build_messages_for_llopa(tokenizer,
                             system: str,
                             document: str,
                             question: str,
                             *,
                             task: str = "qa_doc",
                             device: str):
    """Public entry point that returns the LLoPA prompt segments dict.

    All arguments are forwarded unchanged to ``_build_llopa_inputs``.
    """
    return _build_llopa_inputs(tokenizer, system, document, question, task=task, device=device)
def _normalize_prompt_messages(messages):
out = []
if not isinstance(messages, list):
return out
for msg in messages:
if not isinstance(msg, dict):
continue
role = str(msg.get("role") or "user").strip().lower()
if role not in {"system", "user", "assistant"}:
role = "user"
content = str(msg.get("content") or "")
if role != "assistant":
content = content.strip()
if not content:
continue
out.append({"role": role, "content": content})
return out
def _assistant_header_starts_from_messages(
    tokenizer,
    prompt_messages,
    *,
    prompt_add_generation_prompt: bool,
    device,
):
    """Return the token offsets at which assistant turns begin.

    Each offset is the rendered token length (without generation prompt) of
    all messages before that assistant message. When
    ``prompt_add_generation_prompt`` is true and the conversation does not end
    with an assistant message, one more offset is added for the generation
    header that follows the last message.

    Returns:
        ``(starts, mask)`` as ``(1, n)`` long and all-True bool tensors on
        ``device``, or ``(None, None)`` when there are no usable messages or
        no assistant positions.
    """
    conversation = _normalize_prompt_messages(prompt_messages)
    if not conversation:
        return None, None

    def _rendered_len(messages) -> int:
        return int(tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False).size(1))

    header_offsets = [
        _rendered_len(conversation[:pos])
        for pos, message in enumerate(conversation)
        if message["role"] == "assistant"
    ]
    wants_generation_header = bool(prompt_add_generation_prompt) and conversation[-1]["role"] != "assistant"
    if wants_generation_header:
        header_offsets.append(_rendered_len(conversation))
    if not header_offsets:
        return None, None
    offsets_tensor = torch.tensor([header_offsets], device=device, dtype=torch.long)
    all_valid = torch.ones((1, len(header_offsets)), device=device, dtype=torch.bool)
    return offsets_tensor, all_valid
def _assistant_turn_boundaries_from_messages(
    tokenizer,
    prompt_messages,
    *,
    prompt_add_generation_prompt: bool,
    device,
):
    """Return ``[start, end)`` token spans of each assistant turn.

    For an assistant message at position ``i``, ``start`` is the rendered
    length of messages ``[:i]`` and ``end`` is the rendered length of
    ``[:i + 1]`` (both rendered without a generation prompt). When
    ``prompt_add_generation_prompt`` is true and the conversation does not end
    with an assistant message, one more span covers the generation header:
    from the length without it to the length with it.

    Returns:
        ``(starts, ends, mask)`` as ``(1, n)`` long/long/all-True bool tensors
        on ``device``, or ``(None, None, None)`` when nothing applies.
    """
    conversation = _normalize_prompt_messages(prompt_messages)
    if not conversation:
        return None, None, None

    def _rendered_len(messages, with_header: bool = False) -> int:
        ids = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=with_header)
        return int(ids.size(1))

    spans = []
    for pos, message in enumerate(conversation):
        if message["role"] == "assistant":
            spans.append((_rendered_len(conversation[:pos]), _rendered_len(conversation[: pos + 1])))
    if bool(prompt_add_generation_prompt) and conversation[-1]["role"] != "assistant":
        spans.append((_rendered_len(conversation), _rendered_len(conversation, with_header=True)))
    if not spans:
        return None, None, None

    span_starts = [begin for begin, _ in spans]
    span_ends = [finish for _, finish in spans]
    return (
        torch.tensor([span_starts], device=device, dtype=torch.long),
        torch.tensor([span_ends], device=device, dtype=torch.long),
        torch.ones((1, len(spans)), device=device, dtype=torch.bool),
    )
def _assistant_content_delta_from_messages(tokenizer, prefix_messages, assistant_text: str, su_gen, device):
    """Return the tokens that follow ``su_gen`` once an assistant message is appended.

    Renders ``prefix_messages`` plus an assistant message with
    ``assistant_text`` (no generation prompt) and drops the first
    ``su_gen.size(1)`` tokens. Callers presumably pass the prefix rendered
    with its generation header as ``su_gen``. Returns an empty ``(rows, 0)``
    slice when the rendering is not longer than ``su_gen``.
    """
    conversation = [*prefix_messages, {"role": "assistant", "content": assistant_text}]
    rendered_ids = tokens_from_messages(tokenizer, conversation, device, add_generation_prompt=False)
    offset = su_gen.size(1)
    if rendered_ids.size(1) > offset:
        return rendered_ids[:, offset:]
    return rendered_ids[:, :0]
def _strip_trailing_assistant_stop_tokens(tokenizer, token_ids: torch.Tensor) -> torch.Tensor:
if not isinstance(token_ids, torch.Tensor) or token_ids.numel() == 0:
return token_ids
stop_ids = set()
eos_id = getattr(tokenizer, "eos_token_id", None)
if eos_id is not None:
stop_ids.add(int(eos_id))
with contextlib.suppress(Exception):
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
if eot_id is not None and eot_id != tokenizer.unk_token_id:
stop_ids.add(int(eot_id))
if not stop_ids:
return token_ids
trimmed = token_ids
while trimmed.size(1) > 0 and int(trimmed[0, -1].item()) in stop_ids:
trimmed = trimmed[:, :-1]
return trimmed
def _resolve_user_replay_layout_from_messages(
    tokenizer,
    prefix_messages,
    *,
    system_len: int,
    user_len: int,
    device,
):
    """Locate the first and latest user turns inside the user-token segment.

    Offsets are measured from the start of the user segment, i.e. after the
    first ``system_len`` tokens of the rendered prefix. They are clamped to
    ``[0, user_len]``.

    Args:
        tokenizer: Tokenizer passed to ``tokens_from_messages``.
        prefix_messages: Normalized messages (dicts with a ``"role"`` key)
            preceding the assistant prefill.
        system_len: Number of leading system-prompt tokens (negatives -> 0).
        user_len: Number of tokens in the user segment (negatives -> 0).
        device: Device passed to ``tokens_from_messages``.

    Returns:
        ``(user_prefix_keep_len, latest_user_start, latest_user_len)``:
        - user_prefix_keep_len: user-segment tokens before the first user message.
        - latest_user_start: start offset of the most recent user message.
        - latest_user_len: length of the most recent user message.
        Returns ``(0, 0, 0)`` when ``user_len`` is 0, and ``(0, 0, user_len)``
        when there is no user message.
    """
    system_len = max(int(system_len), 0)
    user_len = max(int(user_len), 0)
    if user_len <= 0:
        return 0, 0, 0
    user_indices = [idx for idx, msg in enumerate(prefix_messages) if msg["role"] == "user"]
    if not user_indices:
        return 0, 0, user_len
    # Rendered token length of `msgs` (no generation prompt); 0 if rendering
    # or tokenization raises.
    def _token_len(msgs) -> int:
        try:
            return int(tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False).size(1))
        except Exception:
            return 0
    first_user_idx = int(user_indices[0])
    latest_user_idx = int(user_indices[-1])
    prefix_before_first_user_len = _token_len(prefix_messages[:first_user_idx])
    prefix_before_latest_user_len = _token_len(prefix_messages[:latest_user_idx])
    prefix_through_latest_user_len = _token_len(prefix_messages[: latest_user_idx + 1])
    # Everything in the user segment before the first user message.
    user_prefix_keep_len = max(prefix_before_first_user_len - system_len, 0)
    user_prefix_keep_len = min(user_prefix_keep_len, user_len)
    # Latest user message span, kept at or after the kept prefix and within user_len.
    latest_user_start = max(prefix_before_latest_user_len - system_len, user_prefix_keep_len)
    latest_user_start = min(latest_user_start, user_len)
    latest_user_end = max(prefix_through_latest_user_len - system_len, latest_user_start)
    latest_user_end = min(latest_user_end, user_len)
    latest_user_len = max(latest_user_end - latest_user_start, 0)
    # Fallback: if the computed span is empty but tokens remain after the kept
    # prefix, treat all remaining user tokens as the latest user span.
    if latest_user_len <= 0 and user_len > user_prefix_keep_len:
        latest_user_start = int(user_prefix_keep_len)
        latest_user_len = int(user_len - user_prefix_keep_len)
    return int(user_prefix_keep_len), int(latest_user_start), int(latest_user_len)
def _build_structured_prompt_segments(tokenizer, prompt_messages, *, prompt_add_generation_prompt: bool, device):
    """Split a chat prompt into system / user / assistant-prefill token segments.

    Two modes, selected by ``prompt_add_generation_prompt``:
    - True: messages must end with a non-assistant message. The assistant
      prefill is the generation header the template appends.
    - False: messages must end with an assistant message whose content is a
      prefix to continue. The prefill is the assistant header plus that
      content's tokens, with trailing stop tokens removed.

    Returns:
        dict with:
        - ``prefix_ids``: messages before the prefill.
        - ``prompt_ids``: prefix + prefill.
        - ``system_ids``: the leading system message rendered alone, or empty.
        - ``user_ids``: ``prefix_ids`` after the system tokens.
        - ``assistant_prefill_ids``.
        - The replay layout ints from ``_resolve_user_replay_layout_from_messages``.
        - ``assistant_header_starts`` / ``assistant_turn_ends`` /
          ``assistant_header_start_mask`` from
          ``_assistant_turn_boundaries_from_messages``; these may be None.

    Raises:
        ValueError: For empty messages, a last role that does not match the
            mode, an assistant prefix with no preceding message, or an empty
            assistant prefill.
    """
    msgs = _normalize_prompt_messages(prompt_messages)
    if not msgs:
        raise ValueError("prompt_messages must contain at least one non-empty message.")
    if bool(prompt_add_generation_prompt):
        if msgs[-1]["role"] == "assistant":
            raise ValueError("prompt_add_generation_prompt=True requires prompt_messages to end with a non-assistant role.")
        prefix_messages = msgs
        prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False)
        prompt_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True)
        # Prefill = whatever the generation prompt adds after the prefix.
        assistant_prefill_ids = prompt_ids[:, prefix_ids.size(1):]
    else:
        if msgs[-1]["role"] != "assistant":
            raise ValueError("prompt_add_generation_prompt=False requires prompt_messages to end with an assistant prefix.")
        if len(msgs) < 2:
            raise ValueError("assistant-prefix prompts require at least one preceding non-assistant message.")
        prefix_messages = msgs[:-1]
        assistant_text = str(msgs[-1]["content"] or "")
        prefix_ids = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=False)
        prompt_prefix = tokens_from_messages(tokenizer, prefix_messages, device, add_generation_prompt=True)
        # Content tokens = full rendered turn minus (prefix + assistant header).
        assistant_content_ids = _assistant_content_delta_from_messages(
            tokenizer,
            prefix_messages,
            assistant_text,
            prompt_prefix,
            device,
        )
        # Drop the template's turn-closing tokens so generation continues the prefix.
        assistant_content_ids = _strip_trailing_assistant_stop_tokens(tokenizer, assistant_content_ids)
        assistant_header_ids = prompt_prefix[:, prefix_ids.size(1):]
        assistant_prefill_ids = torch.cat([assistant_header_ids, assistant_content_ids], dim=1)
        # Built by concatenation, not re-rendering, so the open assistant turn is preserved.
        prompt_ids = torch.cat([prefix_ids, assistant_prefill_ids], dim=1)
    if assistant_prefill_ids.size(1) <= 0:
        raise ValueError("Structured direct LLoPA prompt produced an empty assistant prefill segment.")
    # Leading system message (if any) is tokenized alone; the rest of the prefix is the user segment.
    # Assumes the system-only rendering is a token prefix of prefix_ids.
    if prefix_messages and prefix_messages[0]["role"] == "system":
        system_ids = tokens_from_messages(tokenizer, [prefix_messages[0]], device, add_generation_prompt=False)
        user_ids = prefix_ids[:, system_ids.size(1):]
    else:
        system_ids = prefix_ids[:, :0]
        user_ids = prefix_ids
    replay_user_prefix_keep_len, replay_user_start, replay_user_len = _resolve_user_replay_layout_from_messages(
        tokenizer,
        prefix_messages,
        system_len=int(system_ids.size(1)),
        user_len=int(user_ids.size(1)),
        device=device,
    )
    # NOTE(review): boundaries are computed over all `msgs`. In assistant-prefix
    # mode the last turn's end comes from the full (un-stripped) rendering, so
    # it may exceed prompt_ids.size(1) — confirm downstream tolerates this.
    assistant_header_starts, assistant_turn_ends, assistant_header_start_mask = _assistant_turn_boundaries_from_messages(
        tokenizer,
        msgs,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    return {
        "prefix_ids": prefix_ids,
        "prompt_ids": prompt_ids,
        "system_ids": system_ids,
        "user_ids": user_ids,
        "assistant_prefill_ids": assistant_prefill_ids,
        "replay_user_prefix_keep_len": int(replay_user_prefix_keep_len),
        "replay_user_start": int(replay_user_start),
        "replay_user_len": int(replay_user_len),
        "assistant_header_starts": assistant_header_starts,
        "assistant_turn_ends": assistant_turn_ends,
        "assistant_header_start_mask": assistant_header_start_mask,
    }
def _build_unified_prefill_lower_prompt_bundle(
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    device,
):
    """Assemble the prompt bundle consumed by the prefill-lower seeding paths.

    Uses ``structured_prompt_segments`` when it is a dict. Otherwise the
    segments are built from ``prompt_messages`` with
    ``_build_structured_prompt_segments``.

    Returns:
        dict with:
        - ``segments`` and ``prompt_ids``.
        - An all-ones ``attention_mask``.
        - ``assistant_header_start`` (int) and its 1-element tensor form
          ``prefill_lower_split_start``.
        - Per-turn ``assistant_header_starts`` / ``assistant_turn_ends`` /
          ``assistant_header_start_mask``. These default to one turn that
          starts at ``assistant_header_start`` and ends at the prompt end.
        - 1-element long tensors for the system length and the replay layout.
    """
    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    prefix_ids = segments.get("prefix_ids")
    header_starts = segments.get("assistant_header_starts")
    assistant_turn_ends = segments.get("assistant_turn_ends")
    header_start_mask = segments.get("assistant_header_start_mask")
    replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_user_start = int(segments.get("replay_user_start", 0) or 0)
    replay_user_len = int(segments.get("replay_user_len", 0) or 0)
    # Resolve the split point (start of the final assistant header).
    assistant_header_start: Optional[int] = None
    # 1) Last valid per-turn header start (batch of one only).
    if (
        isinstance(header_starts, torch.Tensor)
        and header_starts.ndim == 2
        and header_starts.size(0) == 1
        and header_starts.numel() > 0
    ):
        valid_mask = header_start_mask
        if not isinstance(valid_mask, torch.Tensor) or valid_mask.shape != header_starts.shape:
            valid_mask = header_starts >= 0
        valid_starts = header_starts[0][valid_mask[0]]
        if valid_starts.numel() > 0:
            assistant_header_start = int(valid_starts[-1].item())
    # 2) The prefix length, when available, overrides step 1 unconditionally.
    # NOTE(review): for segments built by _build_structured_prompt_segments the
    # two values come from rendering the same prefix messages, so they should
    # agree; with caller-supplied segments they may not. Confirm the override
    # is intended.
    if isinstance(prefix_ids, torch.Tensor) and prefix_ids.ndim == 2 and prefix_ids.size(0) == 1:
        assistant_header_start = int(prefix_ids.size(1))
    # 3) Search the prompt for the assistant-header token pattern.
    # NOTE(review): `_assistant_header_ids` / `_find_last_subsequence_start` are
    # defined elsewhere; the latter presumably returns None when not found.
    if assistant_header_start is None:
        header_ids = _assistant_header_ids(tokenizer, device=device)
        if isinstance(header_ids, torch.Tensor) and header_ids.ndim == 2 and header_ids.size(0) == 1:
            assistant_header_start = _find_last_subsequence_start(prompt_ids, header_ids)
    # 4) Last resort: the final prompt position.
    if assistant_header_start is None:
        assistant_header_start = max(int(prompt_ids.size(1) - 1), 0)
    return {
        "segments": segments,
        "prompt_ids": prompt_ids,
        "attention_mask": torch.ones_like(prompt_ids, device=prompt_ids.device),
        "assistant_header_start": int(assistant_header_start),
        "prefill_lower_split_start": torch.tensor(
            [int(assistant_header_start)],
            device=prompt_ids.device,
            dtype=torch.long,
        ),
        "assistant_header_starts": (
            header_starts.to(device=prompt_ids.device, dtype=torch.long)
            if isinstance(header_starts, torch.Tensor) and header_starts.numel() > 0
            else torch.tensor(
                [[int(assistant_header_start)]],
                device=prompt_ids.device,
                dtype=torch.long,
            )
        ),
        "assistant_turn_ends": (
            assistant_turn_ends.to(device=prompt_ids.device, dtype=torch.long)
            if isinstance(assistant_turn_ends, torch.Tensor) and assistant_turn_ends.numel() > 0
            else torch.tensor(
                [[int(prompt_ids.size(1))]],
                device=prompt_ids.device,
                dtype=torch.long,
            )
        ),
        "assistant_header_start_mask": (
            header_start_mask.to(device=prompt_ids.device, dtype=torch.bool)
            if isinstance(header_start_mask, torch.Tensor) and header_start_mask.numel() > 0
            else torch.ones((1, 1), device=prompt_ids.device, dtype=torch.bool)
        ),
        "prefill_lower_system_len": torch.tensor(
            [int(system_ids.size(1))],
            device=prompt_ids.device,
            dtype=torch.long,
        ),
        "prefill_lower_replay_user_prefix_keep_len": torch.tensor(
            [int(replay_user_prefix_keep_len)],
            device=prompt_ids.device,
            dtype=torch.long,
        ),
        "prefill_lower_replay_user_start": torch.tensor(
            [int(replay_user_start)],
            device=prompt_ids.device,
            dtype=torch.long,
        ),
        "prefill_lower_replay_user_len": torch.tensor(
            [int(replay_user_len)],
            device=prompt_ids.device,
            dtype=torch.long,
        ),
    }
def _prompt_bundle_has_past_assistant_history(prompt_bundle) -> bool:
if not isinstance(prompt_bundle, dict):
return False
header_starts = prompt_bundle.get("assistant_header_starts")
turn_ends = prompt_bundle.get("assistant_turn_ends")
if not isinstance(header_starts, torch.Tensor) or not isinstance(turn_ends, torch.Tensor):
return False
if header_starts.numel() == 0 or turn_ends.numel() == 0:
return False
if header_starts.dim() == 1:
header_starts = header_starts.view(1, -1)
if turn_ends.dim() == 1:
turn_ends = turn_ends.view(1, -1)
if header_starts.dim() != 2 or turn_ends.dim() != 2:
return False
header_mask = prompt_bundle.get("assistant_header_start_mask")
if isinstance(header_mask, torch.Tensor) and header_mask.numel() > 0:
if header_mask.dim() == 1:
header_mask = header_mask.view(1, -1)
if header_mask.dim() != 2 or header_mask.shape != header_starts.shape:
header_mask = header_starts >= 0
else:
header_mask = header_mask.to(device=header_starts.device, dtype=torch.bool)
else:
header_mask = header_starts >= 0
split_starts = prompt_bundle.get("effective_prefill_lower_split_start")
if not isinstance(split_starts, torch.Tensor) or split_starts.numel() == 0:
split_starts = prompt_bundle.get("prefill_lower_split_start")
if isinstance(split_starts, torch.Tensor) and split_starts.numel() > 0:
split_starts = split_starts.flatten().to(device=header_starts.device, dtype=torch.long)
else:
split_starts = torch.tensor(
[int(prompt_bundle.get("assistant_header_start", 0) or 0)],
device=header_starts.device,
dtype=torch.long,
)
rows = min(int(header_starts.size(0)), int(turn_ends.size(0)))
cols = min(int(header_starts.size(1)), int(turn_ends.size(1)))
if rows <= 0 or cols <= 0:
return False
for row in range(rows):
split_idx = min(row, int(split_starts.numel()) - 1)
split_start = int(split_starts[split_idx].item())
for col in range(cols):
if not bool(header_mask[row, col].item()):
continue
turn_start = int(header_starts[row, col].item())
turn_end = int(turn_ends[row, col].item())
if turn_end <= turn_start or turn_start >= split_start:
continue
if min(turn_end, split_start) > turn_start:
return True
return False
def _direct_prefill_lower_cache_and_logits(
    model,
    *,
    prompt_bundle,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
):
    """Build the prefill-lower seed (whatever the model's seed functions return) for a prompt bundle.

    Side effect: when ``prompt_bundle`` is a dict holding a ``prompt_ids``
    tensor, it is updated in place with ``effective_prompt_ids``,
    ``effective_prompt_attention_mask`` (long dtype) and
    ``effective_prefill_lower_split_start``.

    Seeding strategies, tried in order; the first non-None result is returned:
    1. ``_matched_inband_prefill_cache_and_logits``, only when
       ``see_past_assistant`` is false.
    2. The full-prompt seed function from ``_get_llopa_full_prompt_seed``,
       when it is callable, ``effective_prompt_ids`` is a tensor, and either
       the normalized ``seed_mode`` is ``"auto"`` or ``see_past_assistant``
       is set and the bundle has assistant turns before the split.
    3. ``model.llopa_reference_prefill_seed`` on the system / user /
       assistant-prefill segments. Any exception it raises yields ``None``.

    Returns:
        The chosen seed function's result. ``None`` when the bundle has no
        ``segments`` dict or no strategy produced a result.

    Raises:
        RuntimeError: When past assistant turns must be visible but the
            full-prompt seed function is missing, raised, or produced no seed.
    """
    # `last_layer_module` is a legacy alias, used only when no replay module was given.
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    seed_mode = _normalize_structured_llopa_seed_mode(seed_mode)
    # Derive and cache the "effective" prompt tensors on the bundle.
    if isinstance(prompt_bundle, dict):
        prompt_ids = prompt_bundle.get("prompt_ids")
        if isinstance(prompt_ids, torch.Tensor):
            attention_mask = prompt_bundle.get("attention_mask")
            if not isinstance(attention_mask, torch.Tensor) or attention_mask.shape != prompt_ids.shape:
                attention_mask = torch.ones_like(prompt_ids, device=prompt_ids.device, dtype=torch.long)
            else:
                attention_mask = attention_mask.to(device=prompt_ids.device, dtype=torch.long)
            effective_prompt_ids = prompt_ids
            effective_prompt_attention_mask = attention_mask
            split_starts = prompt_bundle.get("prefill_lower_split_start")
            if isinstance(split_starts, torch.Tensor):
                split_starts = split_starts.to(device=prompt_ids.device, dtype=torch.long)
            else:
                split_starts = torch.tensor(
                    [int(prompt_bundle.get("assistant_header_start", max(int(prompt_ids.size(1) - 1), 0)))],
                    device=prompt_ids.device,
                    dtype=torch.long,
                )
            # For a single row, the valid length is the mask sum; otherwise the full width.
            valid_len = (
                int(effective_prompt_attention_mask[0].sum().item())
                if effective_prompt_attention_mask.ndim == 2 and effective_prompt_attention_mask.size(0) == 1
                else int(effective_prompt_ids.size(1))
            )
            # NOTE(review): keeps the first `valid_len` positions, i.e. assumes
            # right padding — confirm for left-padded inputs.
            if valid_len > 0:
                effective_prompt_ids = effective_prompt_ids[:, :valid_len]
                effective_prompt_attention_mask = effective_prompt_attention_mask[:, :valid_len]
            prompt_bundle["effective_prompt_ids"] = effective_prompt_ids
            prompt_bundle["effective_prompt_attention_mask"] = effective_prompt_attention_mask
            prompt_bundle["effective_prefill_lower_split_start"] = split_starts
    segments = prompt_bundle.get("segments") if isinstance(prompt_bundle, dict) else None
    if not isinstance(segments, dict):
        return None
    # Strategy 1: matched in-band seed (skipped when past assistant turns must be visible).
    # NOTE(review): unlike `_optimized_prefill_lower_cache_and_logits`, this call
    # does not forward replay_module / replay_per_layers — confirm intended.
    if not bool(see_past_assistant):
        matched_inband_seed = _matched_inband_prefill_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=prefill_attn,
            system_prefill=system_prefill,
            no_upper_attn=no_upper_attn,
        )
        if matched_inband_seed is not None:
            return matched_inband_seed
    # Strategy 2: full-prompt seed over the effective (trimmed) prompt.
    llopa_full_prompt_seed_fn = _get_llopa_full_prompt_seed(model)
    effective_prompt_ids = prompt_bundle.get("effective_prompt_ids")
    effective_prompt_attention_mask = prompt_bundle.get("effective_prompt_attention_mask")
    split_starts = prompt_bundle.get("effective_prefill_lower_split_start")
    system_lens = prompt_bundle.get("prefill_lower_system_len")
    replay_user_prefix_keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len")
    replay_user_starts = prompt_bundle.get("prefill_lower_replay_user_start")
    replay_user_lens = prompt_bundle.get("prefill_lower_replay_user_len")
    assistant_header_starts = prompt_bundle.get("assistant_header_starts")
    assistant_turn_ends = prompt_bundle.get("assistant_turn_ends")
    assistant_header_start_mask = prompt_bundle.get("assistant_header_start_mask")
    needs_past_assistant_seed = bool(see_past_assistant) and _prompt_bundle_has_past_assistant_history(prompt_bundle)
    should_try_full_prompt_seed = seed_mode == "auto" or bool(needs_past_assistant_seed)
    full_prompt_seed_error = None
    if should_try_full_prompt_seed and callable(llopa_full_prompt_seed_fn) and isinstance(effective_prompt_ids, torch.Tensor):
        # Fill in any missing tensors with defaults (all-ones mask, zero system length).
        if not isinstance(effective_prompt_attention_mask, torch.Tensor):
            effective_prompt_attention_mask = torch.ones_like(
                effective_prompt_ids,
                device=effective_prompt_ids.device,
                dtype=torch.long,
            )
        if not isinstance(split_starts, torch.Tensor):
            split_starts = prompt_bundle.get("prefill_lower_split_start")
        if not isinstance(system_lens, torch.Tensor):
            system_lens = torch.zeros(
                (effective_prompt_ids.size(0),),
                device=effective_prompt_ids.device,
                dtype=torch.long,
            )
        try:
            full_prompt_seed = llopa_full_prompt_seed_fn(
                input_ids=effective_prompt_ids,
                attention_mask=effective_prompt_attention_mask,
                use_cache=True,
                logits_to_keep=1,
                lower_k=int(lower_k),
                prefill_attn=str(prefill_attn),
                system_prefill=str(system_prefill),
                no_upper_attn=bool(no_upper_attn),
                prefill_lower_split_start=split_starts,
                prefill_lower_system_len=system_lens,
                prefill_lower_replay_user_prefix_keep_len=replay_user_prefix_keep_lens,
                prefill_lower_replay_user_start=replay_user_starts,
                prefill_lower_replay_user_len=replay_user_lens,
                assistant_header_starts=assistant_header_starts,
                assistant_turn_ends=assistant_turn_ends,
                assistant_header_start_mask=assistant_header_start_mask,
                prefill_lower_see_past_assistant=bool(see_past_assistant),
                replay_module=str(replay_module),
                replay_per_layers=int(replay_per_layers),
            )
            if full_prompt_seed is not None:
                return full_prompt_seed
        except Exception as exc:
            # Kept so the strict past-assistant path below can chain it.
            full_prompt_seed_error = exc
    # Strict mode: past assistant turns must be visible, and only the
    # full-prompt seed can provide that, so never fall back silently.
    if bool(needs_past_assistant_seed):
        if not callable(llopa_full_prompt_seed_fn):
            raise RuntimeError(
                "LLOPA_SEE_PAST_ASSISTANT=1 requires llopa_full_prompt_prefill_seed "
                "for prefill_header prompts with previous assistant turns."
            )
        if full_prompt_seed_error is not None:
            raise RuntimeError(
                "LLOPA_SEE_PAST_ASSISTANT=1 failed in llopa_full_prompt_prefill_seed "
                "for a prefill_header prompt with previous assistant turns."
            ) from full_prompt_seed_error
        raise RuntimeError(
            "LLOPA_SEE_PAST_ASSISTANT=1 could not build a prefill_header seed "
            "that includes previous assistant turns."
        )
    # Strategy 3: segment-based reference seed; best effort, failures yield None.
    seed_fn = getattr(model, "llopa_reference_prefill_seed", None)
    if not callable(seed_fn):
        return None
    try:
        return seed_fn(
            system_ids=segments["system_ids"],
            user_ids=segments["user_ids"],
            assistant_ids=segments["assistant_prefill_ids"],
            lower_k=int(lower_k),
            prefill_attn=str(prefill_attn),
            system_prefill=str(system_prefill),
            no_upper_attn=bool(no_upper_attn),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            replay_user_prefix_keep_len=int(segments.get("replay_user_prefix_keep_len", 0) or 0),
            replay_user_start=int(segments.get("replay_user_start", 0) or 0),
            replay_user_len=int(segments.get("replay_user_len", 0) or 0),
        )
    except Exception:
        return None
def _optimized_prefill_lower_cache_and_logits(
    model,
    *,
    prompt_bundle,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
):
    """Build the prefill KV cache and first-step logits with the selected seed strategy.

    Strategies are tried in order, each gated by the normalized ``seed_mode``:

    1. ``matched`` / ``auto``: ``_matched_inband_prefill_cache_and_logits``
       (only when no replay module is active, since that path has no replay
       arguments).
    2. ``tri`` / ``auto``: the model's full-prompt seed callable
       (``_get_llopa_full_prompt_seed``) fed with the ``effective_*`` tensors
       of ``prompt_bundle``; missing tensors get default values.
    3. ``stable`` / ``auto`` / ``matched``:
       ``_direct_prefill_lower_cache_and_logits``.

    Exceptions from strategies 1 and 2 are treated as "not applicable" so
    that the next strategy still runs.

    Args:
        model: Loaded (possibly wrapped) LLoPA model.
        prompt_bundle: Prompt dict produced elsewhere in this file.
        last_layer_module: Legacy alias for ``replay_module``; used only when
            ``replay_module`` normalizes to ``"none"``.
        seed_mode: ``auto | tri | stable | matched`` (unknown -> ``auto``).

    Returns:
        The result of the first strategy that produced one, or ``None`` when
        the selected mode has nothing left to try (e.g. ``tri`` that failed).
    """
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    seed_mode = _normalize_optimized_llopa_seed_mode(seed_mode)
    # The matched in-band path takes no replay arguments (passing them raised
    # TypeError), so it can only serve requests without a replay module.
    if seed_mode in {"matched", "auto"} and replay_module == "none":
        try:
            matched_seed = _matched_inband_prefill_cache_and_logits(
                model,
                prompt_bundle=prompt_bundle,
                lower_k=lower_k,
                prefill_attn=prefill_attn,
                system_prefill=system_prefill,
                no_upper_attn=no_upper_attn,
            )
        except Exception:
            # Best effort, like the tri path below: fall through to the
            # remaining strategies instead of aborting the request.
            matched_seed = None
        if matched_seed is not None:
            return matched_seed
    llopa_full_prompt_seed_fn = _get_llopa_full_prompt_seed(model)
    if seed_mode in {"tri", "auto"} and callable(llopa_full_prompt_seed_fn) and isinstance(prompt_bundle, dict):
        effective_prompt_ids = prompt_bundle.get("effective_prompt_ids")
        effective_prompt_attention_mask = prompt_bundle.get("effective_prompt_attention_mask")
        split_starts = prompt_bundle.get("effective_prefill_lower_split_start")
        system_lens = prompt_bundle.get("prefill_lower_system_len")
        replay_user_prefix_keep_lens = prompt_bundle.get("prefill_lower_replay_user_prefix_keep_len")
        replay_user_starts = prompt_bundle.get("prefill_lower_replay_user_start")
        replay_user_lens = prompt_bundle.get("prefill_lower_replay_user_len")
        assistant_header_starts = prompt_bundle.get("assistant_header_starts")
        assistant_header_start_mask = prompt_bundle.get("assistant_header_start_mask")
        if isinstance(effective_prompt_ids, torch.Tensor):
            # Fill any missing per-sample tensors with defaults so the seed
            # callable always receives tensors.
            if not isinstance(effective_prompt_attention_mask, torch.Tensor):
                effective_prompt_attention_mask = torch.ones_like(
                    effective_prompt_ids,
                    device=effective_prompt_ids.device,
                    dtype=torch.long,
                )
            if not isinstance(split_starts, torch.Tensor):
                split_starts = prompt_bundle.get("prefill_lower_split_start")
            if not isinstance(split_starts, torch.Tensor):
                split_starts = torch.tensor(
                    [int(prompt_bundle.get("assistant_header_start", max(int(effective_prompt_ids.size(1) - 1), 0)))],
                    device=effective_prompt_ids.device,
                    dtype=torch.long,
                )
            if not isinstance(system_lens, torch.Tensor):
                system_lens = torch.zeros(
                    (effective_prompt_ids.size(0),),
                    device=effective_prompt_ids.device,
                    dtype=torch.long,
                )
            if not isinstance(replay_user_prefix_keep_lens, torch.Tensor):
                replay_user_prefix_keep_lens = torch.zeros(
                    (effective_prompt_ids.size(0),),
                    device=effective_prompt_ids.device,
                    dtype=torch.long,
                )
            if not isinstance(replay_user_starts, torch.Tensor):
                replay_user_starts = replay_user_prefix_keep_lens.clone()
            if not isinstance(replay_user_lens, torch.Tensor):
                replay_user_lens = torch.zeros(
                    (effective_prompt_ids.size(0),),
                    device=effective_prompt_ids.device,
                    dtype=torch.long,
                )
            try:
                full_prompt_seed = llopa_full_prompt_seed_fn(
                    input_ids=effective_prompt_ids,
                    attention_mask=effective_prompt_attention_mask,
                    use_cache=True,
                    logits_to_keep=1,
                    lower_k=int(lower_k),
                    prefill_attn=str(prefill_attn),
                    system_prefill=str(system_prefill),
                    no_upper_attn=bool(no_upper_attn),
                    replay_module=str(replay_module),
                    replay_per_layers=int(replay_per_layers),
                    prefill_lower_split_start=split_starts,
                    prefill_lower_system_len=system_lens,
                    prefill_lower_replay_user_prefix_keep_len=replay_user_prefix_keep_lens,
                    prefill_lower_replay_user_start=replay_user_starts,
                    prefill_lower_replay_user_len=replay_user_lens,
                    assistant_header_starts=assistant_header_starts,
                    assistant_turn_ends=prompt_bundle.get("assistant_turn_ends"),
                    assistant_header_start_mask=assistant_header_start_mask,
                    prefill_lower_see_past_assistant=bool(see_past_assistant),
                )
                if full_prompt_seed is not None:
                    return full_prompt_seed
            except Exception:
                # Best effort: the callable may be a legacy variant that does
                # not accept every keyword; fall back to the stable path.
                pass
    if seed_mode in {"stable", "auto", "matched"}:
        return _direct_prefill_lower_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=prefill_attn,
            system_prefill=system_prefill,
            no_upper_attn=no_upper_attn,
            see_past_assistant=bool(see_past_assistant),
            replay_module=replay_module,
            replay_per_layers=replay_per_layers,
        )
    return None
def _normalize_structured_llopa_seed_mode(seed_mode: Optional[str]) -> str:
normalized = str(seed_mode or "auto").strip().lower()
aliases = {
"": "auto",
"default": "auto",
"tri": "auto",
"tri_auto": "auto",
"reference": "prefill_header",
"reference_only": "prefill_header",
"prefill-header": "prefill_header",
"prefill_header_seed": "prefill_header",
"prefill_header_only": "prefill_header",
}
normalized = aliases.get(normalized, normalized)
if normalized not in {"auto", "prefill_header"}:
normalized = "auto"
return normalized
def _llopa_modeling_module(model) -> Optional[Any]:
    """Return the Python module that defines the LLoPA core's class.

    Returns ``None`` when no LLoPA core is found or ``inspect.getmodule``
    fails. The module is used to look up private modeling helpers by name.
    """
    core = _get_llopa_core(model)
    if core is None:
        return None
    try:
        return inspect.getmodule(core.__class__)
    except Exception:
        return None
def _matched_inband_prefill_cache_and_logits(
    model,
    *,
    prompt_bundle,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    replay_module: str = "none",
    replay_per_layers: int = -1,
):
    """Seed a prefill KV cache by running the LLoPA core layer by layer, in-band.

    Lower layers ``[0, lower_k)`` run causally over the whole prompt, with
    suffix special tokens inserted by the modeling module. Upper layers
    ``[lower_k, n_layers)`` run over the subset selected by the module's
    ``_tri_build_prefill_lower_upper_index_batch``. All layers write into one
    cache, and the first-step logits come from the last upper position.

    Returns ``None`` (so the caller can fall back to another path) when:
      * ``no_upper_attn`` is set, or a replay module is requested;
      * the core, output head or modeling module is missing, or the module
        lacks a required helper;
      * ``capsule_fusion_mode`` is not ``"inband"`` or prefill attention is
        not causal;
      * ``lower_k`` is not a positive int, or the prompt/upper span is empty.

    Args:
        prompt_bundle: dict with ``prompt_ids``. Optional keys:
            ``attention_mask`` (defaults to all ones),
            ``prefill_lower_system_len`` (defaults to 0) and
            ``prefill_lower_split_start`` (else ``assistant_header_start``).
            Also optional: ``assistant_header_starts`` and
            ``assistant_header_start_mask``. Lengths are read from batch
            row 0 only; presumably the batch size is 1 (TODO confirm).
        replay_module, replay_per_layers: Accepted for signature parity with
            the other seed paths. This path implements no replay, so any
            module other than ``"none"`` makes it decline.

    Returns:
        ``(pkv, visible_prefix_len, split_start - visible_prefix_len, logits)``
        with float32 logits, or ``None`` when not applicable.
    """
    if bool(no_upper_attn):
        return None
    # No replay support here: decline instead of silently ignoring the request.
    if _normalize_replay_module_value(replay_module) != "none":
        return None
    llopa_core = _get_llopa_core(model)
    output_head = _get_output_head(model)
    llopa_mod = _llopa_modeling_module(model)
    if llopa_core is None or output_head is None or llopa_mod is None:
        return None
    fusion_mode = str(
        getattr(getattr(llopa_core, "config", None), "capsule_fusion_mode", "upper_only") or "upper_only"
    ).strip().lower()
    if fusion_mode != "inband":
        return None
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn != "causal":
        return None
    # Private helpers resolved from the LLoPA modeling module at runtime.
    required_names = (
        "_safe_dynamic_cache",
        "_tri_arange",
        "_llopa_position_ids_from_mask",
        "_resolve_attn_impl",
        "_can_use_implicit_causal_mask",
        "_llopa_mask_is_all_ones",
        "_build_tri_mask_local",
        "_tri_insert_suffix_specials_inband",
        "_tri_effective_suffix_special_token_ids",
        "_tri_build_prefill_lower_upper_index_batch",
        "_tri_pack_indexed_tensor",
        "create_causal_mask",
    )
    if any(not hasattr(llopa_mod, name) for name in required_names):
        return None
    prompt_ids = prompt_bundle["prompt_ids"]
    attention_mask = prompt_bundle.get("attention_mask")
    if isinstance(attention_mask, torch.Tensor):
        attention_mask = attention_mask.to(device=prompt_ids.device, dtype=torch.long)
    else:
        # Same default as the tri path: every prompt token is valid.
        attention_mask = torch.ones_like(prompt_ids, dtype=torch.long)
    system_lens_src = prompt_bundle.get("prefill_lower_system_len")
    if isinstance(system_lens_src, torch.Tensor) and system_lens_src.numel() > 0:
        system_len = int(system_lens_src.reshape(-1)[0].item())
    else:
        system_len = 0
    full_ids = prompt_ids
    full_attention_mask = attention_mask
    split_starts = prompt_bundle.get("prefill_lower_split_start")
    if isinstance(split_starts, torch.Tensor):
        split_starts = split_starts.to(device=prompt_ids.device, dtype=torch.long)
    else:
        split_starts = torch.tensor(
            [int(prompt_bundle.get("assistant_header_start", max(int(prompt_ids.size(1) - 1), 0)))],
            device=prompt_ids.device,
            dtype=torch.long,
        )
    token_ids = list(getattr(llopa_mod, "_tri_effective_suffix_special_token_ids")(model) or [])
    if token_ids:
        header_starts = prompt_bundle.get("assistant_header_starts")
        header_start_mask = prompt_bundle.get("assistant_header_start_mask")
        (
            full_ids,
            full_attention_mask,
            _,
            remapped_split_starts,
            _,
            _,
        ) = getattr(llopa_mod, "_tri_insert_suffix_specials_inband")(
            token_ids=token_ids,
            input_ids=full_ids,
            attention_mask=full_attention_mask,
            labels=None,
            split_starts=split_starts,
            assistant_header_starts=header_starts,
            assistant_header_start_mask=header_start_mask,
        )
        # Inserting tokens shifts positions; use the remapped split when given.
        if isinstance(remapped_split_starts, torch.Tensor) and remapped_split_starts.numel() > 0:
            split_starts = remapped_split_starts.to(device=prompt_ids.device, dtype=torch.long)
    # Trims to the first `valid_len` columns; assumes right padding.
    valid_len = int(full_attention_mask[0].sum().item())
    if valid_len <= 0:
        return None
    full_ids = full_ids[:, :valid_len]
    full_attention_mask = full_attention_mask[:, :valid_len]
    split_start = max(0, min(int(split_starts[0].item()) if split_starts.numel() > 0 else int(valid_len - 1), valid_len))
    try:
        lower_k = int(lower_k)
    except Exception:
        return None
    n_layers = len(getattr(llopa_core, "layers", []))
    if lower_k <= 0 or n_layers <= 0:
        return None
    lower_k = max(0, min(lower_k, n_layers))
    device = full_ids.device
    pkv = getattr(llopa_mod, "_safe_dynamic_cache")(llopa_core.config)
    inputs_embeds = llopa_core.embed_tokens(full_ids)
    cache_position = getattr(llopa_mod, "_tri_arange")(0, inputs_embeds.shape[1], device)
    position_ids = getattr(llopa_mod, "_llopa_position_ids_from_mask")(full_attention_mask)
    attn_impl = getattr(llopa_mod, "_resolve_attn_impl")(llopa_core.config)
    if attn_impl == "flash_attention_2":
        lower_mask = None if getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask) else full_attention_mask
    elif getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config) and getattr(llopa_mod, "_llopa_mask_is_all_ones")(full_attention_mask):
        lower_mask = None
    else:
        lower_mask = getattr(llopa_mod, "create_causal_mask")(
            config=llopa_core.config,
            input_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
    # Lower stack: full prompt, caches written for layers [0, lower_k).
    hidden_states = inputs_embeds
    position_embeddings = llopa_core.rotary_emb(hidden_states, position_ids)
    for li in range(lower_k):
        layer = llopa_core.layers[li]
        hidden_states = layer(
            hidden_states,
            attention_mask=lower_mask,
            position_ids=position_ids,
            past_key_values=pkv,
            use_cache=True,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
    split_starts = torch.tensor([split_start], device=hidden_states.device, dtype=torch.long)
    valid_lens = torch.tensor([valid_len], device=hidden_states.device, dtype=torch.long)
    system_lens = torch.tensor([system_len], device=hidden_states.device, dtype=torch.long)
    upper_gather_idx, upper_valid_mask, upper_lens = getattr(llopa_mod, "_tri_build_prefill_lower_upper_index_batch")(
        split_starts=split_starts,
        valid_lens=valid_lens,
        system_lens=system_lens,
        system_prefill=str(system_prefill),
        device=hidden_states.device,
    )
    # Gather the upper-visible positions (hidden states and their position ids).
    upper_hidden, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")(
        hidden_states,
        gather_idx=upper_gather_idx,
        valid_mask=upper_valid_mask,
        pad_value=0.0,
    )
    upper_position_ids_src = position_ids.to(device=hidden_states.device, dtype=torch.long)
    upper_position_ids, _ = getattr(llopa_mod, "_tri_pack_indexed_tensor")(
        upper_position_ids_src,
        gather_idx=upper_gather_idx,
        valid_mask=upper_valid_mask,
        pad_value=0,
    )
    upper_len = int(upper_lens[0].item()) if upper_lens.numel() > 0 else 0
    if upper_len <= 0:
        return None
    upper_hidden = upper_hidden[:, :upper_len, :]
    upper_position_ids = upper_position_ids[:, :upper_len]
    upper_cache_position = upper_position_ids[0]
    # Upper stack: gathered subset, caches written for layers [lower_k, n_layers).
    if lower_k < n_layers:
        if attn_impl == "flash_attention_2" or getattr(llopa_mod, "_can_use_implicit_causal_mask")(llopa_core.config):
            upper_mask = None
        else:
            upper_mask = getattr(llopa_mod, "_build_tri_mask_local")(
                1,
                upper_len,
                0,
                upper_hidden.device,
                upper_hidden.dtype,
            )
        upper_pos_emb = llopa_core.rotary_emb(upper_hidden, upper_position_ids)
        for li in range(lower_k, n_layers):
            layer = llopa_core.layers[li]
            upper_hidden = layer(
                upper_hidden,
                attention_mask=upper_mask,
                position_ids=upper_position_ids,
                past_key_values=pkv,
                use_cache=True,
                cache_position=upper_cache_position,
                position_embeddings=upper_pos_emb,
            )
    upper_hidden = llopa_core.norm(upper_hidden)
    initial_logits = output_head(upper_hidden[:, -1:, :])[:, -1, :].to(torch.float32)
    # Number of system tokens that stay visible to the upper stack.
    sys_mode = str(system_prefill or "full").strip().lower()
    if sys_mode == "full":
        visible_prefix_len = system_len
    elif sys_mode == "no_system":
        visible_prefix_len = min(system_len, 1)
    else:
        visible_prefix_len = 0
    return pkv, int(visible_prefix_len), max(int(split_start) - int(visible_prefix_len), 0), initial_logits
def _coerce_llopa_inputs(tokenizer,
system: str,
document: str,
question: str,
*,
task: str,
device: str,
input_ids):
if not isinstance(input_ids, dict):
return _build_llopa_inputs(
tokenizer,
system=system,
document=document,
question=question,
task=task,
device=device,
)
def _get_first(mapping, *keys):
for k in keys:
if k in mapping and mapping[k] is not None:
return mapping[k]
return None
ids_sys = _get_first(input_ids, "system_ids", "ids_sys")
ids_sys_user = _get_first(input_ids, "system_user_ids", "ids_sys_user")
user_doc_ids = _get_first(input_ids, "user_doc_ids")
user_q_ids = _get_first(input_ids, "user_q_ids")
hdr_tail = _get_first(input_ids, "hdr_tail")
prompt_ids = _get_first(input_ids, "prompt_ids", "ids_hdr")
if (ids_sys is None or ids_sys_user is None or user_doc_ids is None or
user_q_ids is None or hdr_tail is None):
built = _build_llopa_inputs(
tokenizer,
system=system,
document=document,
question=question,
task=task,
device=device,
)
if ids_sys is None:
ids_sys = built["system_ids"]
if ids_sys_user is None:
ids_sys_user = built["system_user_ids"]
if user_doc_ids is None:
user_doc_ids = built["user_doc_ids"]
if user_q_ids is None:
user_q_ids = built["user_q_ids"]
if hdr_tail is None:
hdr_tail = built["hdr_tail"]
if prompt_ids is None:
prompt_ids = built.get("prompt_ids")
if prompt_ids is None and ids_sys_user is not None and hdr_tail is not None:
try:
prompt_ids = torch.cat([ids_sys_user, hdr_tail], dim=1)
except Exception:
prompt_ids = None
return {
"prompt_ids": prompt_ids,
"system_ids": ids_sys,
"system_user_ids": ids_sys_user,
"user_doc_ids": user_doc_ids,
"user_q_ids": user_q_ids,
"hdr_tail": hdr_tail,
}
def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int:
    """Length of the longest common prefix of row 0 of two ``(batch, seq)`` id tensors."""
    span = min(a.size(1), b.size(1))
    mismatches = (a[0, :span] != b[0, :span]).nonzero(as_tuple=True)[0]
    if mismatches.numel() == 0:
        return span
    return int(mismatches[0])
def _assistant_header_ids_from_chat_template(tokenizer, device):
    """Infer the assistant-header token ids from the tokenizer's chat template.

    Two probe conversations are tried: system+user, then user only. Each is
    rendered with and without ``add_generation_prompt``, and the first one
    that yields ids wins:

    1. If the prompted text extends the plain text, the appended text suffix
       is tokenized (without special tokens).
    2. Otherwise, or if that fails, both conversations are tokenized and the
       prompted ids after their longest common prefix are used.

    A probe whose text rendering fails, is not a string, or is unchanged by
    the generation prompt is skipped entirely.

    Returns:
        A ``(1, n)`` id tensor on ``device``, or ``None`` if no probe works.
    """
    probes = (
        [{"role": "system", "content": "system"}, {"role": "user", "content": "user"}],
        [{"role": "user", "content": "user"}],
    )
    for messages in probes:
        try:
            plain = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
            prompted = apply_chat_template(tokenizer, messages, add_generation_prompt=True)
        except Exception:
            continue
        both_text = isinstance(plain, str) and isinstance(prompted, str)
        if not both_text or not prompted or prompted == plain:
            continue

        # Strategy 1: tokenize the textual suffix the generation prompt adds.
        suffix = prompted[len(plain):] if prompted.startswith(plain) else ""
        if suffix:
            try:
                header = tokenizer(
                    suffix,
                    add_special_tokens=False,
                    return_tensors="pt",
                ).input_ids.to(device)
                if header.numel() > 0:
                    return header
            except Exception:
                pass

        # Strategy 2: token-level diff of the two tokenized conversations.
        try:
            plain_ids = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False)
            prompted_ids = tokens_from_messages(tokenizer, messages, device, add_generation_prompt=True)
        except Exception:
            continue
        prompted_width = prompted_ids.size(1) if prompted_ids.numel() else 0
        if prompted_ids.numel() == 0 or prompted_width <= plain_ids.size(1):
            continue
        shared = lcp_len(plain_ids, prompted_ids)
        if shared < prompted_width:
            tail = prompted_ids[:, shared:]
            if tail.numel() > 0:
                return tail
    return None
def _assistant_header_ids(tokenizer, device):
    """Return the assistant-header ids that ``add_generation_prompt`` appends (best effort).

    * Mistral-style templates (``_is_mistral_template``): a ``(1, 1)`` tensor
      holding the id of ``MISTRAL_ASSIST_START``. Returns ``None`` if the
      token maps to ``None``/``unk`` or the lookup errors.
    * Templates containing ``<|start_header_id|>``: the tokenized
      ``<|start_header_id|>assistant<|end_header_id|>`` string followed by two
      newlines, or ``None`` if tokenization errors.
    * Anything else: ``_assistant_header_ids_from_chat_template``.
    """
    if _is_mistral_template(tokenizer):
        try:
            assist_id = tokenizer.convert_tokens_to_ids(MISTRAL_ASSIST_START)
            if assist_id is not None and assist_id != tokenizer.unk_token_id:
                return torch.tensor([[int(assist_id)]], device=device, dtype=torch.long)
        except Exception:
            pass
        return None

    template = getattr(tokenizer, "chat_template", "") or ""
    if "<|start_header_id|>" not in template:
        return _assistant_header_ids_from_chat_template(tokenizer, device)

    header_text = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    try:
        encoded = tokenizer(header_text, add_special_tokens=False, return_tensors="pt")
        return encoded.input_ids.to(device)
    except Exception:
        return None
def split_system_user_ids(tokenizer, system_text: str, user_text: str, device):
    """Tokenize a system turn and a system+user conversation, then split off the user part.

    Returns ``(ids_sys, user_ids, ids_sys_user)`` where ``user_ids`` is
    ``ids_sys_user`` with the first ``ids_sys.size(1)`` tokens removed. Only
    lengths are compared; the system-only render is assumed (not checked) to
    be a token prefix of the system+user render.

    Raises:
        ValueError: if the system-only ids are longer than the combined ids.
    """
    ids_sys = tokens_from_messages(
        tokenizer,
        [{"role": "system", "content": system_text}],
        device,
        add_generation_prompt=False,
    )
    ids_sys_user = tokens_from_messages(
        tokenizer,
        [{"role": "system", "content": system_text}, {"role": "user", "content": user_text}],
        device,
        add_generation_prompt=False,
    )
    system_width = ids_sys.size(1)
    if system_width > ids_sys_user.size(1):
        raise ValueError("System-only tokens longer than system+user tokens.")
    return ids_sys, ids_sys_user[:, system_width:], ids_sys_user
# -----------------------------
# DynamicCache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    """Return the number of layers stored in a KV cache.

    Checks a layered cache (``.layers``) first, then a legacy
    ``.key_cache`` list, and finally treats ``pkv`` as a per-layer sequence.
    """
    for container in ("layers", "key_cache"):
        if hasattr(pkv, container):
            return len(getattr(pkv, container))
    return len(pkv)
def pkv_get(pkv, idx: int):
    """Return the ``(keys, values)`` stored for layer ``idx`` of a KV cache.

    Supports layered caches (``.layers[idx].keys/.values``), legacy
    ``key_cache``/``value_cache`` lists, and plain per-layer sequences.
    """
    if hasattr(pkv, "layers"):
        entry = pkv.layers[idx]
        return entry.keys, entry.values
    if not hasattr(pkv, "key_cache"):
        return pkv[idx]
    return pkv.key_cache[idx], pkv.value_cache[idx]
def dc_from_subset(pkv_src, idxs: List[int]) -> DynamicCache:
    """Build a new ``DynamicCache`` containing only the layers listed in ``idxs``.

    Each selected layer's keys/values (read via ``pkv_get``) are written back
    under the same layer index. This function makes no tensor copies itself.
    """
    subset_cache = DynamicCache()
    for layer_index in idxs:
        layer_keys, layer_values = pkv_get(pkv_src, layer_index)
        subset_cache.update(layer_keys, layer_values, layer_index)
    return subset_cache
def _safe_dynamic_cache(config=None) -> DynamicCache:
    """Create an empty ``DynamicCache``, forwarding ``config`` when it is accepted.

    Falls back to a config-less ``DynamicCache()`` when:
      * ``config`` is ``None`` (nothing to forward), or
      * the constructor raises ``TypeError`` about ``max_cache_len`` or rejects
        the ``config`` keyword (transformers releases whose ``DynamicCache``
        has no ``config`` parameter).

    Any other ``TypeError`` propagates.
    """
    if config is None:
        # Avoids passing the keyword at all, so releases without a `config`
        # parameter still work.
        return DynamicCache()
    try:
        return DynamicCache(config=config)
    except TypeError as exc:
        message = str(exc)
        if "max_cache_len" in message or "unexpected keyword argument 'config'" in message:
            return DynamicCache()
        raise
def _get_inner_model(m):
"""Return the decoder backbone that owns `.layers` (robust across wrappers)."""
# unwrap DDP/Accelerate
if hasattr(m, "module"):
m = m.module
# unwrap PEFT
try:
from peft import PeftModel
if isinstance(m, PeftModel):
try:
m = m.get_base_model()
except Exception:
m = getattr(m, "base_model", m)
except Exception:
pass
for attr in ("model", "transformer", "backbone", "base_model", "language_model"):
if hasattr(m, attr):
cand = getattr(m, attr)
if hasattr(cand, "layers") and isinstance(getattr(cand, "layers", None), nn.ModuleList):
return cand
if hasattr(cand, "decoder") and hasattr(cand.decoder, "layers") and isinstance(cand.decoder.layers, nn.ModuleList):
return cand.decoder
if hasattr(m, "layers") and isinstance(getattr(m, "layers", None), nn.ModuleList):
return m
for child in m.modules():
if child is m:
continue
if hasattr(child, "layers") and isinstance(getattr(child, "layers", None), nn.ModuleList):
return child
raise AttributeError("Could not locate inner base model with a .layers attribute")
def _get_llopa_core(model):
    """Return the decoder module that exposes the LLoPA prefill/cache hooks, or ``None``.

    ``llopa_prefill_cache`` is preferred over the legacy ``tri_build_caches``.
    For each hook, the backbone from ``_get_inner_model`` is checked before
    its ``.model`` child. ``AttributeError`` from ``_get_inner_model``
    propagates when no backbone can be located.
    """
    inner = _get_inner_model(model)
    for hook_name in ("llopa_prefill_cache", "tri_build_caches"):
        if hasattr(inner, hook_name):
            return inner
        child = getattr(inner, "model", None) if hasattr(inner, "model") else None
        if child is not None and hasattr(child, hook_name):
            return child
    return None
def _get_llopa_decode_step(model):
"""Return the cached LLoPA decode-step callable if present."""
for name in ("llopa_decode_step_logits", "tri_step_logits"):
if hasattr(model, name):
return getattr(model, name)
try:
from peft import PeftModel
if isinstance(model, PeftModel):
try:
base = model.get_base_model()
except Exception:
base = getattr(model, "base_model", None)
if base is not None:
for name in ("llopa_decode_step_logits", "tri_step_logits"):
if hasattr(base, name):
return getattr(base, name)
except Exception:
pass
return None
def _get_llopa_full_prompt_seed(model):
"""Return the full-prompt seed callable used for past-assistant contexts."""
for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"):
if hasattr(model, name):
return getattr(model, name)
try:
from peft import PeftModel
if isinstance(model, PeftModel):
try:
base = model.get_base_model()
except Exception:
base = getattr(model, "base_model", None)
if base is not None:
for name in ("llopa_full_prompt_prefill_seed", "tri_reference_prefill_seed"):
if hasattr(base, name):
return getattr(base, name)
except Exception:
pass
return None
def _has_active_llopa_runtime(model) -> bool:
    """Return True when LLoPA hooks are present on the loaded model.

    Both a LLoPA core (``_get_llopa_core``) and a decode-step callable
    (``_get_llopa_decode_step``) must be found. If no decoder backbone can be
    located at all (``_get_inner_model`` raises ``AttributeError``), the model
    is reported as not LLoPA-enabled instead of propagating the error.
    """
    try:
        core = _get_llopa_core(model)
    except AttributeError:
        return False
    return core is not None and _get_llopa_decode_step(model) is not None
def _round_up_to_multiple(value: int, multiple: int) -> int:
value = int(value)
multiple = int(multiple)
if multiple <= 0:
return value
return ((value + multiple - 1) // multiple) * multiple
_OPTIMIZED_LLOPA_VARIANT_PRESETS = {
"baseline": {
"seed_mode": "auto",
"upper_prepare_mode": "exact",
"upper_bucket_multiple": 0,
"seq_bucket_multiple": 256,
},
"upper_ws_auto": {
"seed_mode": "auto",
"upper_prepare_mode": "bucketed_workspace",
"upper_bucket_multiple": 256,
"seq_bucket_multiple": 256,
},
"upper_ws_auto_128": {
"seed_mode": "auto",
"upper_prepare_mode": "bucketed_workspace",
"upper_bucket_multiple": 128,
"seq_bucket_multiple": 128,
},
"upper_ws_tri": {
"seed_mode": "tri",
"upper_prepare_mode": "bucketed_workspace",
"upper_bucket_multiple": 256,
"seq_bucket_multiple": 256,
},
"upper_ws_stable": {
"seed_mode": "stable",
"upper_prepare_mode": "bucketed_workspace",
"upper_bucket_multiple": 256,
"seq_bucket_multiple": 256,
},
"upper_ws_matched": {
"seed_mode": "matched",
"upper_prepare_mode": "bucketed_workspace",
"upper_bucket_multiple": 256,
"seq_bucket_multiple": 256,
},
}
def _normalize_optimized_llopa_seed_mode(seed_mode: Optional[str]) -> str:
raw = str(seed_mode or "auto").strip().lower()
if raw in {"", "default"}:
raw = "auto"
if raw not in {"auto", "tri", "stable", "matched"}:
raw = "auto"
return raw
def _normalize_optimized_llopa_upper_prepare_mode(mode: Optional[str]) -> str:
raw = str(mode or "exact").strip().lower()
if raw in {"", "default"}:
raw = "exact"
if raw not in {"exact", "bucketed_workspace"}:
raw = "exact"
return raw
def _resolve_optimized_llopa_settings(
    *,
    variant: Optional[str],
    seed_mode: Optional[str],
    upper_prepare_mode: Optional[str],
    upper_bucket_multiple: Optional[int],
    seq_bucket_multiple: Optional[int],
):
    """Merge a named preset from ``_OPTIMIZED_LLOPA_VARIANT_PRESETS`` with explicit overrides.

    ``variant`` matching ignores case and surrounding whitespace. ``None``,
    blanks, ``default``, ``auto`` and unknown names all select
    ``"upper_ws_auto"``. Each other argument overrides the preset value when
    it is not ``None``.

    Post-processing:
      * ``upper_bucket_multiple`` is forced to 0 unless the prepare mode is
        ``"bucketed_workspace"``, and negative values become 0;
      * a non-positive ``seq_bucket_multiple`` becomes 256.

    Returns:
        dict with ``variant`` (name of the preset actually applied),
        ``seed_mode``, ``upper_prepare_mode``, ``upper_bucket_multiple``,
        ``seq_bucket_multiple``.
    """
    preset_name = str(variant or "upper_ws_auto").strip().lower()
    if preset_name in {"", "default", "auto"} or preset_name not in _OPTIMIZED_LLOPA_VARIANT_PRESETS:
        # Unknown names use the default preset; report that preset's name so
        # callers and logs see the configuration that is actually in effect.
        preset_name = "upper_ws_auto"
    preset = _OPTIMIZED_LLOPA_VARIANT_PRESETS[preset_name]
    resolved_seed_mode = _normalize_optimized_llopa_seed_mode(
        seed_mode if seed_mode is not None else preset.get("seed_mode")
    )
    resolved_upper_prepare_mode = _normalize_optimized_llopa_upper_prepare_mode(
        upper_prepare_mode if upper_prepare_mode is not None else preset.get("upper_prepare_mode")
    )
    resolved_upper_bucket_multiple = (
        int(upper_bucket_multiple)
        if upper_bucket_multiple is not None
        else int(preset.get("upper_bucket_multiple", 0) or 0)
    )
    resolved_seq_bucket_multiple = (
        int(seq_bucket_multiple)
        if seq_bucket_multiple is not None
        else int(preset.get("seq_bucket_multiple", 256) or 256)
    )
    # Upper bucketing only has meaning for the bucketed workspace mode.
    if resolved_upper_prepare_mode != "bucketed_workspace":
        resolved_upper_bucket_multiple = 0
    if resolved_upper_bucket_multiple < 0:
        resolved_upper_bucket_multiple = 0
    if resolved_seq_bucket_multiple <= 0:
        resolved_seq_bucket_multiple = 256
    return {
        "variant": preset_name,
        "seed_mode": resolved_seed_mode,
        "upper_prepare_mode": resolved_upper_prepare_mode,
        "upper_bucket_multiple": int(resolved_upper_bucket_multiple),
        "seq_bucket_multiple": int(resolved_seq_bucket_multiple),
    }
@contextlib.contextmanager
def _temporary_model_attrs(model, **updates):
sentinel = object()
prior = {}
try:
for key, value in updates.items():
prior[key] = getattr(model, key, sentinel)
setattr(model, key, value)
yield
finally:
for key, old_value in prior.items():
if old_value is sentinel:
with contextlib.suppress(Exception):
delattr(model, key)
else:
with contextlib.suppress(Exception):
setattr(model, key, old_value)
def _acquire_bucketed_sequence_workspace(
    model,
    *,
    reference_ids: torch.Tensor,
    batch_size: int,
    total_len: int,
    bucket_multiple: int,
):
    """Return a reusable, uninitialized workspace sized to a bucketed sequence length.

    ``total_len`` is rounded up with ``_round_up_to_multiple``; if that yields
    a non-positive length, ``int(total_len)`` is used instead. The tensor
    (from ``torch.empty``) matches ``reference_ids``' dtype and device. It is
    cached in the dict ``model._optimized_llopa_sequence_workspace_cache``
    under the key ``(str(device), str(dtype), batch_size, bucket_len)``.
    Entries are never evicted here, so every distinct key keeps its tensor
    alive. A failure to attach the cache dict to ``model`` is ignored.

    Returns:
        ``(workspace, bucket_len)``, where ``workspace.shape == (batch_size, bucket_len)``.
    """
    bucket_len = _round_up_to_multiple(total_len, bucket_multiple)
    if bucket_len <= 0:
        bucket_len = int(total_len)
    rows, cols = int(batch_size), int(bucket_len)
    dtype, device = reference_ids.dtype, reference_ids.device
    cache_key = (str(device), str(dtype), rows, cols)

    cache = getattr(model, "_optimized_llopa_sequence_workspace_cache", None)
    if not isinstance(cache, dict):
        cache = {}

    workspace = cache.get(cache_key)
    reusable = (
        isinstance(workspace, torch.Tensor)
        and workspace.device == device
        and workspace.dtype == dtype
        and workspace.shape == (rows, cols)
    )
    if not reusable:
        workspace = torch.empty((rows, cols), dtype=dtype, device=device)
        cache[cache_key] = workspace

    try:
        setattr(model, "_optimized_llopa_sequence_workspace_cache", cache)
    except Exception:
        pass
    return workspace, cols
def _kv_meta_from_model(model_like):
"""Return (num_kv_heads, head_dim, dtype)."""
try:
cfg = getattr(model_like, "config", None) or getattr(_get_inner_model(model_like), "config", None)
except Exception:
cfg = getattr(_get_inner_model(model_like), "config", None)
num_heads = getattr(cfg, "num_attention_heads", None)
num_kv = getattr(cfg, "num_key_value_heads", None) or num_heads
hidden = getattr(cfg, "hidden_size", None)
head_dim = (hidden // num_heads) if (hidden and num_heads) else None
try:
dtype = next(_get_inner_model(model_like).parameters()).dtype
except Exception:
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
return int(num_kv), int(head_dim), dtype
def _make_empty_kv(batch: int, num_kv: int, head_dim: int, device, dtype):
shape = (batch, num_kv, 0, head_dim)
k = torch.empty(shape, device=device, dtype=dtype)
v = torch.empty(shape, device=device, dtype=dtype)
return k.contiguous(), v.contiguous()
# -----------------------------
# LLOPA helpers (encapsulated)
# -----------------------------
def _llopa_split_system(system_ids: torch.Tensor, system_prefill: str):
"""Return (system_upper, system_lower_extra) based on system_prefill."""
sys_prefill = (system_prefill or "full").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "full"
if sys_prefill == "full":
return system_ids, system_ids[:, :0]
if sys_prefill == "no_system":
if system_ids.size(1) < 1:
return system_ids[:, :0], system_ids[:, :0]
return system_ids[:, :1], system_ids[:, 1:]
# no_bos_system
return system_ids[:, :0], system_ids
def _llopa_merge_replay_user_span(
    system_ids: torch.Tensor,
    *,
    system_prefill: str,
    replay_user_prefix_keep_len: int,
    replay_user_start: Optional[int],
    replay_user_len: Optional[int],
):
    """Shift user-relative replay offsets into the merged lower-user sequence.

    System tokens that ``_llopa_split_system`` moves out of the upper part
    are placed before the user ids in the lower sequence. The prefix-keep
    length and the replay start are therefore offset by that token count,
    after clamping the inputs to >= 0. The replay length is clamped to >= 0
    but not offset. ``None`` start/length stay ``None``.

    Returns:
        ``(prefix_keep_len, user_start, user_len)``.
    """
    shift = int(_llopa_split_system(system_ids, system_prefill)[1].size(1))
    keep_len = shift + max(int(replay_user_prefix_keep_len or 0), 0)
    start = None if replay_user_start is None else shift + max(int(replay_user_start), 0)
    length = None if replay_user_len is None else max(int(replay_user_len), 0)
    return int(keep_len), start, length
def _llopa_merge_user(system_ids: torch.Tensor, user_ids: torch.Tensor, system_prefill: str):
    """Return ``(system_upper, lower_user_ids)``.

    ``lower_user_ids`` is the system tokens that ``_llopa_split_system`` moves
    out of the upper part, followed by ``user_ids``. When either piece is
    empty, the other is returned as-is without concatenation.
    """
    sys_upper, moved = _llopa_split_system(system_ids, system_prefill)
    if moved.numel() == 0:
        return sys_upper, user_ids
    if user_ids.numel() == 0:
        return sys_upper, moved
    return sys_upper, torch.cat([moved, user_ids], dim=1)
def _llopa_prefill_cache(
    llopa_core,
    system_ids: torch.Tensor,
    user_ids: torch.Tensor,
    assistant_ids: torch.Tensor,
    *,
    lower_k: int,
    prefill_mode: str,
    prefill_attn: str,
    system_prefill: str,
    return_last_assistant_hidden: bool = False,
    replay_user_prefix_keep_len: int = 0,
    replay_user_start: Optional[int] = None,
    replay_user_len: Optional[int] = None,
):
    """Run ``llopa_core.llopa_prefill_cache`` after routing system tokens per ``system_prefill``.

    System tokens moved out of the upper part are prepended to the user ids
    (``_llopa_merge_user``), and the replay span is re-offset accordingly
    (``_llopa_merge_replay_user_span``).

    Returns:
        ``(pkv, upper_system_len, lower_user_len)``. When
        ``return_last_assistant_hidden`` is set, the last assistant hidden
        state is appended as a fourth element.

    Raises:
        ValueError: if ``prefill_mode`` is not ``"lower"``, or ``prefill_attn``
            is not causal/full (``prefix_full`` is accepted as ``full``).
        RuntimeError: if the core has no ``llopa_prefill_cache``, or it does
            not return a tuple when the hidden state was requested.
    """
    mode = (prefill_mode or "lower").strip().lower()
    attn = (prefill_attn or "causal").strip().lower()
    attn = "full" if attn == "prefix_full" else attn
    if mode != "lower":
        raise ValueError("llopa_prefill requires prefill_mode='lower'.")
    if attn not in ("causal", "full"):
        raise ValueError("llopa_prefill requires prefill_attn in {'causal','full'}.")

    prefill_fn = getattr(llopa_core, "llopa_prefill_cache", None)
    if prefill_fn is None:
        raise RuntimeError("llopa_prefill_cache not found. Check LLoPA modeling patch.")

    sys_upper, lower_user = _llopa_merge_user(system_ids, user_ids, system_prefill)
    keep_len, span_start, span_len = _llopa_merge_replay_user_span(
        system_ids,
        system_prefill=system_prefill,
        replay_user_prefix_keep_len=int(replay_user_prefix_keep_len or 0),
        replay_user_start=replay_user_start,
        replay_user_len=replay_user_len,
    )
    want_hidden = bool(return_last_assistant_hidden)
    result = prefill_fn(
        system_ids=sys_upper,
        user_ids=lower_user,
        assistant_ids=assistant_ids,
        lower_k=lower_k,
        prefill_mode=mode,
        prefill_attn=attn,
        return_last_assistant_hidden=want_hidden,
        replay_user_prefix_keep_len=keep_len,
        replay_user_start=span_start,
        replay_user_len=span_len,
    )
    if not want_hidden:
        return result, sys_upper.size(1), lower_user.size(1)
    if not isinstance(result, tuple):
        raise RuntimeError("llopa_prefill_cache did not return the requested last assistant hidden state.")
    pkv, last_hidden = result
    return pkv, sys_upper.size(1), lower_user.size(1), last_hidden
# ---------------------------------------------------------------------------
# LoPA per-layer cache_position/position_ids adjustment
# Aligns per-layer positions when lower layers have prefill past and upper
# layers start from zero. Mirrors the trainer's runtime patch.
# ---------------------------------------------------------------------------
import contextlib
@contextlib.contextmanager
def lopa_cache_position_patch(model, past_key_values):
    """
    Match trainer's dynamic position alignment.

    For each decoder layer of ``_get_inner_model(model)``, the current past
    length is read from ``past_key_values``. While the context is active, a
    forward pre-hook shifts ``cache_position`` / ``position_ids`` by
    ``off = start_val - past_len``, where ``start_val`` is the first element
    of ``cache_position`` (else of ``position_ids``). This aligns lower-K
    layers (past = L_sys + L_doc) and upper layers (past = 0) logically for
    the current token.

    A layer with no cache entry, or one whose entry holds no keys tensor yet,
    counts as past length 0. All hooks and the temporary ``_lopa_past``
    attribute are removed on exit, even if registration fails part-way.

    NOTE(review): layers in this file are also called with precomputed
    ``position_embeddings``, which this hook does not adjust. Whether the
    shift reaches RoPE therefore depends on the layer implementation —
    confirm.
    """
    inner = _get_inner_model(model)

    def _pkv_past_len(li: int) -> int:
        # Per-layer past length from the provided cache snapshot.
        try:
            if hasattr(past_key_values, "key_cache") and hasattr(past_key_values, "value_cache"):
                keys = past_key_values.key_cache[li]
            elif hasattr(past_key_values, "layers"):
                keys = past_key_values.layers[li].keys
            else:
                keys = past_key_values[li][0]
        except (IndexError, KeyError, TypeError):
            # Layer absent from the cache (or no cache at all).
            return 0
        if not isinstance(keys, torch.Tensor) or keys.dim() < 3:
            # Placeholder / not-yet-initialized cache entry.
            return 0
        return int(keys.shape[2])

    def _pre_hook(module, args, kwargs):
        past_len = getattr(module, "_lopa_past", 0)
        cp = kwargs.get("cache_position", None)
        pi = kwargs.get("position_ids", None)
        start_val = None
        if isinstance(cp, torch.Tensor) and cp.numel() > 0:
            start_val = int(cp.reshape(-1)[0].item())
        elif isinstance(pi, torch.Tensor) and pi.numel() > 0:
            start_val = int(pi.reshape(-1)[0].item())
        if start_val is not None:
            off = start_val - past_len
            if off != 0:
                if isinstance(cp, torch.Tensor):
                    kwargs["cache_position"] = cp - off
                if isinstance(pi, torch.Tensor):
                    kwargs["position_ids"] = pi - off
        return args, kwargs

    layers = list(inner.layers)
    past_lens = [_pkv_past_len(li) for li in range(len(layers))]
    handles = []
    try:
        # Registration lives inside the try so a failure part-way still
        # removes the hooks that were already installed.
        for layer, past_len in zip(layers, past_lens):
            layer._lopa_past = int(past_len)
            handles.append(layer.register_forward_pre_hook(_pre_hook, with_kwargs=True))
        yield
    finally:
        for h in handles:
            h.remove()
        for layer in layers:
            if hasattr(layer, "_lopa_past"):
                delattr(layer, "_lopa_past")
# -----------------------------
# TRI inference core
# -----------------------------
@torch.inference_mode()
def _get_peft_wrapper(m):
try:
from peft import PeftModel
except Exception:
return None
if hasattr(m, "module"):
m = m.module
return m if isinstance(m, PeftModel) else None
def _set_prefill_adapter(model, enabled: bool) -> None:
if not bool(getattr(model, "_prefill_adapter_only", False)):
return
peft_model = _get_peft_wrapper(model)
if peft_model is None:
return
base = getattr(peft_model, "base_model", None)
if base is None:
return
try:
if enabled:
base.enable_adapter_layers()
else:
base.disable_adapter_layers()
except Exception:
return
def _get_output_head(model):
getter = getattr(model, "get_output_embeddings", None)
if callable(getter):
head = getter()
if head is not None:
return head
return getattr(model, "lm_head", None)
def _env_flag_enabled(name: str, default: str = "1") -> bool:
raw = os.environ.get(name, default).strip().lower()
return raw not in {"0", "false", "no", "off"}
@torch.inference_mode()
def lopa_generate(model,
                  tokenizer,
                  system: str,
                  document: str,
                  question: str,
                  *,
                  task: str = "qa_doc",
                  K: int,
                  prefill_mode: str = "lower",
                  prefill_attn: str = "causal",
                  system_prefill: str = "full",
                  user_prefill: str = "full",
                  device: str,
                  input_ids: Optional[dict] = None,
                  max_new_tokens: int = 256,
                  min_length: int = 16,
                  temperature: float = 0.7,
                  top_p: float = 0.9,
                  top_k: Optional[int] = None,
                  do_sample: bool = True,
                  math_force_final_hash_rule: bool = False,
                  log_cuda_mem: bool = False,
                  log_cuda_tag: Optional[str] = None,
                  debug: bool = False,
                  debug_dir: Optional[Path] = None,
                  llopa_prefill: bool = False,
                  no_upper_attn: Optional[bool] = None,
                  return_tokens: bool = False) -> str | Tuple[str, int]:
    """Generate a reply using the LLoPA/TRI split prefill plus a manual decode loop.

    ``_coerce_llopa_inputs`` splits the prompt into system, user-document,
    user-question and assistant-header pieces. The system and user pieces are
    then prefilled into a KV cache. When ``llopa_prefill`` is true this goes
    through ``_llopa_prefill_cache``. Otherwise it goes through the ``tri_*``
    methods of the object returned by ``_get_llopa_core``. Next the assistant
    header is pushed. Finally, tokens are produced one at a time with the
    callable returned by ``_get_llopa_decode_step``.

    Args:
        model: Model carrying the custom LLoPA runtime API. Both
            ``_get_llopa_core`` and ``_get_llopa_decode_step`` must return
            non-None.
        tokenizer: Tokenizer used to build the prompt and decode the output.
        system, document, question: Prompt pieces. When ``task == "math"``,
            ``math_force_final_hash_rule`` is set and no ``input_ids`` are
            given, a ``'#### {numeric answer}'`` instruction is appended to
            ``system``. This is skipped if ``system`` already contains
            ``"####"``.
        task: Task name forwarded to ``_coerce_llopa_inputs`` /
            ``build_messages``.
        K: Passed as ``lower_k`` to every prefill and decode call. It is
            presumably the number of lower layers used for the user prefill;
            confirm against the modeling file.
        prefill_mode: Forwarded unchanged to the prefill/decode calls.
        prefill_attn: ``"causal"`` or ``"full"``. ``"prefix_full"`` is an
            alias for ``"full"``.
        system_prefill: ``"full"``, ``"no_system"`` or ``"no_bos_system"``.
            Any other value falls back to ``"full"``.
        user_prefill: ``"full"`` prefills document + question.
            ``"no_question"`` prefills only the document and pushes the
            question together with the assistant header. Other values fall
            back to ``"full"``.
        device: Device passed to ``_coerce_llopa_inputs``.
        input_ids: Optional pre-tokenized inputs forwarded to
            ``_coerce_llopa_inputs``.
        max_new_tokens: Upper bound on generated tokens.
        min_length: Stop ids are masked out while fewer than this many tokens
            have been generated.
        temperature, top_p, top_k, do_sample: Sampling controls. The warpers
            are applied only when ``do_sample`` is true. A temperature of
            0/None skips the temperature warper, i.e. samples at temperature
            1.0.
        math_force_final_hash_rule: See ``system`` above.
        log_cuda_mem, log_cuda_tag: Print CUDA memory stats at phase
            boundaries.
        debug, debug_dir: Print/write the rendered prompt and the generated
            text.
        llopa_prefill: Use ``_llopa_prefill_cache`` for the prefill. When
            possible, the first token's logits are taken from its returned
            last hidden state.
        no_upper_attn: ``None`` reads ``model._no_upper_attn``. The flag is
            honoured only when ``llopa_prefill`` is true and the decode step
            accepts a ``no_upper_attn`` keyword.
        return_tokens: Also return the number of generated tokens.

    Returns:
        The decoded text with special tokens skipped, or ``(text, n_tokens)``
        when ``return_tokens`` is true.

    Raises:
        ValueError: If ``prefill_attn`` is not causal/full, or if the prompt
            is empty on the ``llopa_prefill`` path.
        RuntimeError: If the custom LLoPA runtime API is not available.

    NOTE(review): the decode loop assumes batch size 1. ``generated`` is
    allocated as ``(1, max_new_tokens)`` and ``next_tok.item()`` is used.
    """
    # Build ids
    # Optionally force a "#### <number>" final-answer format for math prompts.
    # NOTE(review): `(system or "")` tolerates system=None in the check, but
    # `system.rstrip()` below would still fail on None.
    if input_ids is None and task == "math" and math_force_final_hash_rule and "####" not in (system or ""):
        system = (
            system.rstrip()
            + " Conclude your explanation with the answer in a '#### {numeric answer}' format, "
            + "where the answer is solely a number."
        )
    # Normalize mode strings; unknown user_prefill silently means "full".
    user_prefill = (user_prefill or "full").strip().lower()
    if user_prefill not in {"full", "no_question"}:
        user_prefill = "full"
    prefill_attn = (prefill_attn or "causal").strip().lower()
    if prefill_attn == "prefix_full":
        prefill_attn = "full"
    if prefill_attn not in {"causal", "full"}:
        raise ValueError("prefill_attn must be one of: causal | full")
    llopa_inputs = _coerce_llopa_inputs(
        tokenizer,
        system=system,
        document=document,
        question=question,
        task=task,
        device=device,
        input_ids=input_ids,
    )
    ids_sys = llopa_inputs["system_ids"]
    ids_sys_user = llopa_inputs["system_user_ids"]
    user_doc_ids = llopa_inputs["user_doc_ids"]
    user_q_ids = llopa_inputs["user_q_ids"]
    hdr_tail = llopa_inputs["hdr_tail"]
    ids_hdr = llopa_inputs.get("prompt_ids", None)
    if debug:
        # Best-effort dump of the chat-template render; failures only print.
        try:
            msgs = build_messages(system, document, question, include_query=True, task=task)
            s_no_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=False)
            s_with_hdr = apply_chat_template(tokenizer, msgs, add_generation_prompt=True)
            print(f"[debug] render lengths (chars): no_hdr={len(s_no_hdr)}, with_hdr={len(s_with_hdr)}")
            if debug_dir is not None:
                debug_dir.mkdir(parents=True, exist_ok=True)
                (debug_dir / "infer_render_no_header.txt").write_text(s_no_hdr, encoding="utf-8")
                (debug_dir / "infer_render_with_header.txt").write_text(s_with_hdr, encoding="utf-8")
        except Exception as e:
            print(f"[debug] render dump failed: {e}")
    # Assistant header tokens
    # Re-derive hdr_tail as everything in the full prompt after system+user.
    if ids_hdr is None:
        ids_hdr = torch.cat([ids_sys_user, hdr_tail], dim=1) if hdr_tail is not None else ids_sys_user
    hdr_tail = ids_hdr[:, ids_sys_user.size(1):]
    # Tokens pushed after the prefill: the header, preceded by the question
    # when the question was held back from the prefill.
    if user_prefill == "no_question":
        prefix_full = torch.cat([user_q_ids, hdr_tail], dim=1)
    else:
        prefix_full = hdr_tail
    # Require LLoPA runtime API
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        raise RuntimeError("Custom LLoPA modeling not active. Check --lopa_modeling_path/--modeling_family.")
    # Resolve no_upper_attn: it only takes effect on the llopa_prefill path and
    # only if the decode step's signature has a `no_upper_attn` parameter.
    if no_upper_attn is None:
        no_upper_attn = bool(getattr(model, "_no_upper_attn", False))
    no_upper_attn = bool(no_upper_attn)
    if no_upper_attn and (not bool(llopa_prefill)):
        print("[infer][warn] no_upper_attn is ignored unless llopa_prefill=True.")
    effective_no_upper_attn = bool(no_upper_attn and bool(llopa_prefill))
    llopa_step_accepts_no_upper_attn = False
    try:
        llopa_step_params = inspect.signature(llopa_step).parameters
        llopa_step_accepts_no_upper_attn = ("no_upper_attn" in llopa_step_params)
    except Exception:
        llopa_step_accepts_no_upper_attn = False
    if effective_no_upper_attn and (not llopa_step_accepts_no_upper_attn):
        print("[infer][warn] no_upper_attn requested but llopa_decode_step_logits does not support it; ignoring.")
        effective_no_upper_attn = False

    def _log_mem(tag: str) -> None:
        """Print current/peak CUDA memory (GiB) when log_cuda_mem is on and CUDA exists."""
        if not log_cuda_mem or not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
        except Exception:
            pass
        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        max_alloc = torch.cuda.max_memory_allocated() / (1024 ** 3)
        max_reserved = torch.cuda.max_memory_reserved() / (1024 ** 3)
        prefix = "[mem]"
        if log_cuda_tag:
            prefix = f"{prefix}[{log_cuda_tag}]"
        print(f"{prefix} {tag} | alloc={alloc:.2f}GiB reserved={reserved:.2f}GiB "
              f"max_alloc={max_alloc:.2f}GiB max_reserved={max_reserved:.2f}GiB")

    # Reset peaks so "prefill_end" reports the prefill phase only.
    if log_cuda_mem and torch.cuda.is_available():
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    _log_mem("start")
    # 1) TRI prefill: system all + user lower-K
    # Adapters (if prefill-adapter-only) are enabled for the prefill only.
    _set_prefill_adapter(model, True)
    lower_k = int(K)
    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"
    if user_prefill == "no_question":
        user_prefill_ids = user_doc_ids
    else:
        user_prefill_ids = torch.cat([user_doc_ids, user_q_ids], dim=1)
    if llopa_prefill:
        # Fused path: the prefill also consumes prefix_full and returns the last
        # hidden state, so the first token needs no separate decode step.
        use_fused_first_token = bool(
            prefix_full.numel() > 0
            and not effective_no_upper_attn
            and not bool(getattr(model, "_prefill_adapter_only", False))
        )
        if use_fused_first_token:
            pkv, S, U, llopa_last_hidden = _llopa_prefill_cache(
                llopa_core,
                ids_sys,
                user_prefill_ids,
                prefix_full,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=True,
            )
        else:
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                ids_sys,
                user_prefill_ids,
                prefix_full,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
                system_prefill=sys_prefill,
            )
            llopa_last_hidden = None
    else:
        use_fused_first_token = False
        llopa_last_hidden = None
        if sys_prefill == "full":
            pkv, S, U = llopa_core.tri_build_caches(
                system_ids=ids_sys,
                user_ids=user_prefill_ids,
                lower_k=lower_k,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
        elif sys_prefill == "no_system":
            # Only the first system token goes through tri_prefill_system_all;
            # the remaining system tokens are prefilled like user tokens.
            if ids_sys.size(1) < 1:
                pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None))
            else:
                bos_ids = ids_sys[:, :1]
                rest_ids = ids_sys[:, 1:]
                out = llopa_core.tri_prefill_system_all(
                    bos_ids,
                    past_key_values=None,
                    prefill_attn=prefill_attn,
                )
                pkv = out.past_key_values
                if rest_ids.size(1) > 0:
                    _ = llopa_core.tri_prefill_user_lower(
                        rest_ids,
                        lower_k=lower_k,
                        past_key_values=pkv,
                        prefill_mode=prefill_mode,
                        prefill_attn=prefill_attn,
                    )
            _ = llopa_core.tri_prefill_user_lower(
                user_prefill_ids,
                lower_k=lower_k,
                past_key_values=pkv,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
            S, U = ids_sys.size(1), user_prefill_ids.size(1)
        else:
            # "no_bos_system": every system token is prefilled like a user token.
            pkv = _safe_dynamic_cache(getattr(llopa_core, "config", None))
            if ids_sys.size(1) > 0:
                _ = llopa_core.tri_prefill_user_lower(
                    ids_sys,
                    lower_k=lower_k,
                    past_key_values=pkv,
                    prefill_mode=prefill_mode,
                    prefill_attn=prefill_attn,
                )
            _ = llopa_core.tri_prefill_user_lower(
                user_prefill_ids,
                lower_k=lower_k,
                past_key_values=pkv,
                prefill_mode=prefill_mode,
                prefill_attn=prefill_attn,
            )
            S, U = ids_sys.size(1), user_prefill_ids.size(1)
    _set_prefill_adapter(model, False)
    _log_mem("prefill_end")
    # First-token logits from the fused prefill's last hidden state (if any).
    initial_logits = None
    if use_fused_first_token:
        output_head = _get_output_head(model)
        if output_head is not None and isinstance(llopa_last_hidden, torch.Tensor) and llopa_last_hidden.numel() > 0:
            initial_logits = output_head(llopa_last_hidden)[:, -1, :].to(torch.float32)
    # 2) Push assistant header if present (or fallback to last user token)
    if llopa_prefill:
        # _llopa_prefill_cache already received prefix_full; only remember the
        # last token as the decode-loop seed.
        # NOTE(review): if the fused path produced no initial_logits, the loop
        # feeds this token through llopa_step again. Whether it is already in
        # pkv depends on _llopa_prefill_cache; confirm.
        if prefix_full.numel() > 0:
            last_pushed = prefix_full[:, -1:]
        elif ids_sys_user.numel() > 0:
            last_pushed = ids_sys_user[:, -1:]
        else:
            raise ValueError("Empty prompt after LLOPA prefill; cannot start decoding.")
    else:
        if prefix_full.numel() > 0:
            seed_kwargs = dict(
                assistant_ids=prefix_full,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=0,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                seed_kwargs["no_upper_attn"] = True
            out_seed = llopa_step(**seed_kwargs)
            pkv = out_seed.past_key_values or pkv
            last_pushed = prefix_full[:, -1:]
        else:
            step_tok = ids_sys_user[:, -1:]
            seed_kwargs = dict(
                assistant_ids=step_tok,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=0,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                seed_kwargs["no_upper_attn"] = True
            out_seed = llopa_step(**seed_kwargs)
            pkv = out_seed.past_key_values or pkv
            last_pushed = step_tok
    if log_cuda_mem and torch.cuda.is_available():
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    _log_mem("decode_start")
    # 3) decoding
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper
    # Stop on EOS and, when the tokenizer knows it, Llama-3's <|eot_id|>.
    eos_id = tokenizer.eos_token_id
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        eot_id = None
    stop_ids = set()
    if eos_id is not None: stop_ids.add(int(eos_id))
    if eot_id is not None and eot_id != tokenizer.unk_token_id: stop_ids.add(int(eot_id))
    # Sampling warpers; temperature 0/None or 1.0 adds no temperature warper.
    procs = None
    if do_sample:
        procs = LogitsProcessorList()
        if temperature and temperature != 1.0:
            procs.append(TemperatureLogitsWarper(temperature=float(temperature)))
        if top_p and top_p < 1.0:
            procs.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))
        if top_k is not None and top_k > 0:
            procs.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    device_t = last_pushed.device
    last = last_pushed
    generated = torch.empty((1, max_new_tokens), dtype=torch.long, device=device_t)
    cur = 0
    stop_reason = None
    pending_logits = initial_logits
    while cur < max_new_tokens:
        # Use the fused first-token logits once; afterwards step the model.
        if pending_logits is None:
            step_kwargs = dict(
                assistant_ids=last,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=1,
                labels=None,
                prefill_mode=prefill_mode,
            )
            if effective_no_upper_attn and llopa_step_accepts_no_upper_attn:
                step_kwargs["no_upper_attn"] = True
            out = llopa_step(**step_kwargs)
            pkv = out.past_key_values or pkv
            logits = out.logits[:, -1, :]
        else:
            logits = pending_logits
            pending_logits = None
        # force min_length
        if stop_ids and cur < min_length:
            for sid in stop_ids:
                logits[:, sid] = -float("inf")
        if procs is not None:
            inp_for_proc = generated[:, :cur]
            if logits.dtype != torch.float32:
                logits = logits.float()
            logits = procs(inp_for_proc, logits)
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        generated[:, cur:cur + 1] = next_tok
        last = next_tok
        # The stop token itself is counted in `cur` (and later skipped by decode).
        if stop_ids and cur >= min_length:
            tok_id = int(next_tok.item())
            if tok_id in stop_ids:
                stop_reason = f"stop_token:{tok_id}"
                cur += 1
                break
        cur += 1
    if stop_reason is None and cur >= max_new_tokens:
        stop_reason = "max_new_tokens"
    gen_ids = generated[:, :cur]
    text = tokenizer.decode(gen_ids[0].tolist(), skip_special_tokens=True)
    _log_mem("decode_end")
    if debug:
        print(f"[debug] finished | tokens={cur} | reason={stop_reason}")
        if debug_dir is not None:
            debug_dir.mkdir(parents=True, exist_ok=True)
            (debug_dir / "infer_generated.txt").write_text(text, encoding="utf-8")
    if return_tokens:
        return text, int(cur)
    return text
# -----------------------------
# LLOPA helpers (Capsule interface + HF packaging)
# -----------------------------
def llopa_generate(*args, **kwargs):
    """Capsule interface: identical behavior to lopa_generate.

    When ``device`` is missing or ``None`` it is inferred from the model. The
    model may be given positionally or as ``model=``. The device is taken from
    the input-embedding weight, else the first parameter, else CUDA/CPU by
    availability. All arguments are then forwarded to ``lopa_generate``.
    """
    if kwargs.get("device") is None:
        # Previously only a positional model was inspected, so
        # llopa_generate(model=m, ...) reached lopa_generate without a device.
        has_model = bool(args) or "model" in kwargs
        if has_model:
            model = args[0] if args else kwargs["model"]
            try:
                dev = model.get_input_embeddings().weight.device
            except Exception:
                try:
                    dev = next(model.parameters()).device
                except Exception:
                    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            kwargs["device"] = str(dev)
    return lopa_generate(*args, **kwargs)
def _read_kv_file(path: Path) -> dict[str, str]:
info: dict[str, str] = {}
try:
lines = path.read_text(encoding="utf-8").splitlines()
except Exception:
return info
for raw in lines:
line = raw.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
k, v = line.split("=", 1)
k = k.strip()
v = v.strip()
if k:
info[k] = v
return info
def _read_adapter_backbone_ref(repo_path: Path) -> str:
adapter_cfg = repo_path / "adapter_config.json"
if not adapter_cfg.is_file():
return ""
try:
data = json.loads(adapter_cfg.read_text(encoding="utf-8"))
except Exception:
return ""
val = data.get("base_model_name_or_path")
if isinstance(val, str):
return val.strip()
return ""
def _resolve_repo_path(model_repo: str, cache_dir: Optional[str] = None,
revision: Optional[str] = None, token: Optional[str] = None,
local_files_only: bool = False) -> Path:
repo = Path(model_repo)
if repo.exists():
return repo
try:
from huggingface_hub import snapshot_download
except Exception as exc:
raise RuntimeError("huggingface_hub is required to load remote repos") from exc
path = snapshot_download(
repo_id=model_repo,
cache_dir=cache_dir,
revision=revision,
token=token,
local_files_only=local_files_only,
)
return Path(path)
def _repo_has_pretrained_weights(repo_path: Path) -> bool:
weight_files = (
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"model.safetensors",
"model.safetensors.index.json",
)
if any((repo_path / name).is_file() for name in weight_files):
return True
return any(repo_path.glob("pytorch_model-*-of-*.bin")) or any(repo_path.glob("model-*-of-*.safetensors"))
def _pick_vocab_weight_key(keys) -> Optional[str]:
preferred = (
"model.embed_tokens.weight",
"embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"transformer.wte.weight",
"lm_head.weight",
"model.lm_head.weight",
)
key_list = list(keys)
for cand in preferred:
if cand in key_list:
return cand
suffixes = (
"embed_tokens.weight",
"decoder.embed_tokens.weight",
"wte.weight",
"lm_head.weight",
)
for suffix in suffixes:
for key in key_list:
if str(key).endswith(suffix):
return str(key)
return None
def _load_state_dict_meta(path: Path):
try:
return torch.load(path, map_location="meta", weights_only=True)
except TypeError:
return torch.load(path, map_location="meta")
def _weight_rows_from_file(path: Path, tensor_key: Optional[str] = None) -> Optional[int]:
try:
suffixes = path.suffixes
except Exception:
suffixes = []
try:
if suffixes[-1:] == [".safetensors"] or suffixes[-2:] == [".model", ".safetensors"]:
from safetensors import safe_open
with safe_open(str(path), framework="pt", device="cpu") as handle:
chosen = tensor_key or _pick_vocab_weight_key(handle.keys())
if not chosen:
return None
shape = handle.get_slice(chosen).get_shape()
if len(shape) >= 2:
return int(shape[0])
return None
state_dict = _load_state_dict_meta(path)
if not isinstance(state_dict, dict):
return None
chosen = tensor_key if tensor_key in state_dict else _pick_vocab_weight_key(state_dict.keys())
if not chosen:
return None
tensor = state_dict.get(chosen)
if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2:
return int(tensor.shape[0])
except Exception:
return None
return None
def _infer_checkpoint_vocab_size(repo_path: Path) -> Optional[int]:
index_files = (
"pytorch_model.bin.index.json",
"model.safetensors.index.json",
)
single_weight_files = (
"pytorch_model.bin",
"model.safetensors",
)
for name in index_files:
index_path = repo_path / name
if not index_path.is_file():
continue
try:
data = json.loads(index_path.read_text(encoding="utf-8"))
except Exception:
continue
weight_map = data.get("weight_map")
if not isinstance(weight_map, dict):
continue
tensor_key = _pick_vocab_weight_key(weight_map.keys())
if not tensor_key:
continue
shard_rel = weight_map.get(tensor_key)
if not shard_rel:
continue
rows = _weight_rows_from_file(repo_path / shard_rel, tensor_key=tensor_key)
if rows is not None and rows > 0:
return rows
for name in single_weight_files:
weight_path = repo_path / name
if not weight_path.is_file():
continue
rows = _weight_rows_from_file(weight_path)
if rows is not None and rows > 0:
return rows
for weight_path in sorted(repo_path.glob("pytorch_model-*-of-*.bin")):
rows = _weight_rows_from_file(weight_path)
if rows is not None and rows > 0:
return rows
for weight_path in sorted(repo_path.glob("model-*-of-*.safetensors")):
rows = _weight_rows_from_file(weight_path)
if rows is not None and rows > 0:
return rows
return None
def _expand_tokenizer_placeholders(tokenizer, target_vocab_size: int) -> int:
try:
current_vocab_size = int(len(tokenizer))
except Exception:
return 0
if target_vocab_size <= current_vocab_size:
return 0
missing = int(target_vocab_size - current_vocab_size)
placeholder_tokens = [f"<|capsule_missing_token_{idx}|>" for idx in range(missing)]
try:
return int(tokenizer.add_tokens(placeholder_tokens, special_tokens=True) or 0)
except Exception:
return 0
def _normalize_special_token_values(raw_value: Any) -> list[str]:
if not isinstance(raw_value, list):
return []
normalized: list[str] = []
for item in raw_value:
token = ""
if isinstance(item, str):
token = item
elif isinstance(item, dict):
for key in ("content", "token", "text"):
value = item.get(key)
if isinstance(value, str) and value:
token = value
break
if token:
normalized.append(token)
return normalized
def _checkpoint_special_token_candidates(checkpoint_repo_path: Path) -> list[str]:
candidates: list[str] = []
for filename in ("config.json", "tokenizer_config.json", "special_tokens_map.json"):
path = checkpoint_repo_path / filename
if not path.is_file():
continue
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except Exception:
continue
if not isinstance(payload, dict):
continue
candidates.extend(_normalize_special_token_values(payload.get("capsule_suffix_special_tokens")))
candidates.extend(_normalize_special_token_values(payload.get("additional_special_tokens")))
raw_decoder = payload.get("added_tokens_decoder")
if isinstance(raw_decoder, dict):
for _, value in sorted(
raw_decoder.items(),
key=lambda item: int(item[0]) if str(item[0]).isdigit() else str(item[0]),
):
candidates.extend(_normalize_special_token_values([value]))
deduped: list[str] = []
seen: set[str] = set()
for token in candidates:
if not token or token in seen:
continue
seen.add(token)
deduped.append(token)
return deduped
def _align_tokenizer_with_checkpoint_vocab(
tokenizer,
checkpoint_repo_path: Path,
target_vocab_size: int,
) -> dict[str, int]:
report = {
"checkpoint_vocab_size": int(target_vocab_size or 0),
"tokenizer_vocab_size_before": 0,
"checkpoint_special_candidate_count": 0,
"checkpoint_specials_missing_before": 0,
"added_checkpoint_specials": 0,
"tokenizer_vocab_size_after_specials": 0,
"padding_gap_before_placeholders": 0,
"added_placeholders": 0,
"tokenizer_vocab_size_after": 0,
}
try:
current_vocab_size = int(len(tokenizer))
except Exception:
return report
report["tokenizer_vocab_size_before"] = current_vocab_size
report["tokenizer_vocab_size_after_specials"] = current_vocab_size
report["tokenizer_vocab_size_after"] = current_vocab_size
missing = int(target_vocab_size - current_vocab_size)
if missing <= 0:
return report
checkpoint_specials = _checkpoint_special_token_candidates(checkpoint_repo_path)
report["checkpoint_special_candidate_count"] = len(checkpoint_specials)
try:
existing_vocab = set(tokenizer.get_vocab().keys())
except Exception:
existing_vocab = set()
missing_checkpoint_specials = [token for token in checkpoint_specials if token not in existing_vocab]
report["checkpoint_specials_missing_before"] = len(missing_checkpoint_specials)
tokens_to_add: list[str] = []
for token in missing_checkpoint_specials:
tokens_to_add.append(token)
existing_vocab.add(token)
if len(tokens_to_add) >= missing:
break
added_checkpoint_specials = 0
if tokens_to_add:
try:
current_specials = tokenizer.special_tokens_map_extended.get("additional_special_tokens", []) or []
added_checkpoint_specials = int(
tokenizer.add_special_tokens(
{"additional_special_tokens": list(current_specials) + tokens_to_add}
)
or 0
)
except Exception:
try:
added_checkpoint_specials = int(tokenizer.add_tokens(tokens_to_add, special_tokens=True) or 0)
except Exception:
added_checkpoint_specials = 0
report["added_checkpoint_specials"] = added_checkpoint_specials
try:
post_special_vocab_size = int(len(tokenizer))
except Exception:
post_special_vocab_size = current_vocab_size + added_checkpoint_specials
report["tokenizer_vocab_size_after_specials"] = post_special_vocab_size
report["padding_gap_before_placeholders"] = max(0, int(target_vocab_size - post_special_vocab_size))
added_placeholders = 0
if target_vocab_size > post_special_vocab_size:
added_placeholders = _expand_tokenizer_placeholders(tokenizer, target_vocab_size)
report["added_placeholders"] = added_placeholders
try:
report["tokenizer_vocab_size_after"] = int(len(tokenizer))
except Exception:
report["tokenizer_vocab_size_after"] = post_special_vocab_size + added_placeholders
return report
def _log_tokenizer_checkpoint_alignment(
log_prefix: str,
report: dict[str, int],
*,
print_fn=print,
) -> None:
added_specials = int(report.get("added_checkpoint_specials", 0) or 0)
added_placeholders = int(report.get("added_placeholders", 0) or 0)
if added_specials <= 0 and added_placeholders <= 0:
return
vocab_before = int(report.get("tokenizer_vocab_size_before", 0) or 0)
vocab_after_specials = int(report.get("tokenizer_vocab_size_after_specials", vocab_before) or vocab_before)
vocab_after = int(report.get("tokenizer_vocab_size_after", vocab_after_specials) or vocab_after_specials)
missing_specials_before = int(report.get("checkpoint_specials_missing_before", 0) or 0)
if added_specials > 0:
print_fn(
f"{log_prefix}[info] tokenizer vocab is smaller than checkpoint embeddings; "
f"recovered {added_specials} checkpoint special tokens "
f"({vocab_before} -> {vocab_after_specials})."
)
if added_placeholders > 0:
if missing_specials_before > 0:
print_fn(
f"{log_prefix}[warn] checkpoint embeddings still exceed tokenizer vocab after recovering "
f"checkpoint special tokens; added {added_placeholders} placeholder special tokens "
f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})."
)
else:
print_fn(
f"{log_prefix}[info] checkpoint embeddings include {added_placeholders} padded rows "
f"beyond tokenizer vocab; added {added_placeholders} placeholder special tokens "
f"to preserve id alignment ({vocab_after_specials} -> {vocab_after})."
)
def _log_config_checkpoint_vocab_alignment(
log_prefix: str,
*,
config_vocab_size: int,
checkpoint_vocab_size: int,
tokenizer_vocab_size: int,
alignment_report: Optional[dict[str, int]] = None,
print_fn=print,
) -> None:
if checkpoint_vocab_size <= 0 or checkpoint_vocab_size == config_vocab_size:
return
final_tokenizer_vocab = int(tokenizer_vocab_size or 0)
if alignment_report is not None:
final_tokenizer_vocab = int(
alignment_report.get("tokenizer_vocab_size_after", final_tokenizer_vocab) or final_tokenizer_vocab
)
missing_specials_before = 0
if alignment_report is not None:
missing_specials_before = int(alignment_report.get("checkpoint_specials_missing_before", 0) or 0)
if final_tokenizer_vocab >= checkpoint_vocab_size and missing_specials_before <= 0:
print_fn(
f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags tokenizer/checkpoint embeddings "
f"({checkpoint_vocab_size}); using checkpoint size."
)
return
padding_rows = max(0, checkpoint_vocab_size - final_tokenizer_vocab)
if padding_rows > 0 and missing_specials_before <= 0:
print_fn(
f"{log_prefix}[info] config vocab_size ({config_vocab_size}) lags checkpoint embeddings "
f"({checkpoint_vocab_size}); using checkpoint size. tokenizer vocab is {final_tokenizer_vocab}, "
f"so {padding_rows} rows are embedding padding/alignment."
)
return
print_fn(
f"{log_prefix}[warn] config vocab_size ({config_vocab_size}) does not match checkpoint embeddings "
f"({checkpoint_vocab_size}); using checkpoint size."
)
def _expand_tokenizer_with_checkpoint_specials(
    tokenizer,
    checkpoint_repo_path: Path,
    target_vocab_size: int,
) -> tuple[int, int]:
    """Compatibility wrapper around ``_align_tokenizer_with_checkpoint_vocab``.

    Returns:
        ``(added_checkpoint_specials, added_placeholders)`` taken from the
        alignment report. Missing or falsy entries count as 0.
    """
    report = _align_tokenizer_with_checkpoint_vocab(tokenizer, checkpoint_repo_path, target_vocab_size)
    added_specials, added_placeholders = (
        int(report.get(field, 0) or 0) for field in ("added_checkpoint_specials", "added_placeholders")
    )
    return added_specials, added_placeholders
def _resolve_modeling_path(repo_path: Path, user_path: Optional[str], model_family: str) -> Optional[str]:
if user_path:
cand = Path(user_path)
if not cand.is_file():
cand = repo_path / user_path
return str(cand) if cand.is_file() else None
tri_info = repo_path / "tri_info.txt"
if tri_info.is_file():
info = _read_kv_file(tri_info)
tri_file = info.get("lopa_modeling_path") or ""
if tri_file:
cand = repo_path / tri_file
if cand.is_file():
return str(cand)
default_name = {
"llama": "tri_llama3_modeling.py",
"qwen3": "tri_qwen3_modeling.py",
"mistral": "tri_mistral_modeling.py",
}.get(model_family, "tri_llama3_modeling.py")
cand = repo_path / default_name
if cand.is_file():
return str(cand)
return None
def _attach_llopa_generate(model):
try:
import types
def _llopa_generate(self, tokenizer, system: Optional[str] = None, document: Optional[str] = None,
question: Optional[str] = None, **kwargs):
if system is None:
system = kwargs.pop("system", "")
if document is None:
document = kwargs.pop("document", "")
if question is None:
question = kwargs.pop("question", "")
if "device" not in kwargs or kwargs.get("device") is None:
try:
dev = self.get_input_embeddings().weight.device
except Exception:
try:
dev = next(self.parameters()).device
except Exception:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs["device"] = str(dev)
return lopa_generate(self, tokenizer, system=system, document=document, question=question, **kwargs)
model.llopa_generate = types.MethodType(_llopa_generate, model)
except Exception:
pass
def _maybe_attach_llopa_generate(model):
if not hasattr(model, "llopa_generate"):
_attach_llopa_generate(model)
def _attach_prefill_lower_generate(
model,
*,
lower_k: int,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
no_upper_attn: bool = False,
) -> None:
try:
import types
except Exception:
return
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
sys_prefill = (system_prefill or "no_bos_system").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "no_bos_system"
if lower_k <= 0 or attn not in {"causal", "full"}:
return
try:
setattr(model, "_runtime_prefill_lower_layers", int(lower_k))
setattr(model, "_runtime_prefill_lower_attn", attn)
setattr(model, "_runtime_prefill_lower_system_prefill", sys_prefill)
setattr(model, "_runtime_prefill_lower_no_upper_attn", bool(no_upper_attn))
except Exception:
return
if getattr(model, "_runtime_prefill_generate_attached", False):
return
orig_prepare = getattr(model, "prepare_inputs_for_generation", None)
if orig_prepare is None:
return
def _runtime_prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
prefill_lower_layers=None,
prefill_lower_attn=None,
prefill_lower_system_prefill=None,
prefill_lower_no_upper_attn=None,
prefill_lower_split_start=None,
prefill_lower_system_len=None,
prefill_lower_replay_user_prefix_keep_len=None,
prefill_lower_replay_user_start=None,
prefill_lower_replay_user_len=None,
assistant_header_start=None,
assistant_header_starts=None,
assistant_header_start_mask=None,
**kwargs,
):
model_inputs = orig_prepare(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
**kwargs,
)
if prefill_lower_layers is None:
prefill_lower_layers = int(getattr(self, "_runtime_prefill_lower_layers", 0) or 0)
if prefill_lower_attn is None:
prefill_lower_attn = str(getattr(self, "_runtime_prefill_lower_attn", "causal") or "causal")
if prefill_lower_system_prefill is None:
prefill_lower_system_prefill = str(
getattr(self, "_runtime_prefill_lower_system_prefill", "no_bos_system") or "no_bos_system"
)
if prefill_lower_no_upper_attn is None:
prefill_lower_no_upper_attn = bool(
getattr(self, "_runtime_prefill_lower_no_upper_attn", False)
)
if int(prefill_lower_layers or 0) > 0:
model_inputs["prefill_lower_layers"] = int(prefill_lower_layers)
model_inputs["prefill_lower_attn"] = str(prefill_lower_attn)
model_inputs["prefill_lower_system_prefill"] = str(prefill_lower_system_prefill)
if bool(prefill_lower_no_upper_attn):
model_inputs["prefill_lower_no_upper_attn"] = True
if prefill_lower_split_start is not None:
model_inputs["prefill_lower_split_start"] = prefill_lower_split_start
if prefill_lower_system_len is not None:
model_inputs["prefill_lower_system_len"] = prefill_lower_system_len
if prefill_lower_replay_user_prefix_keep_len is not None:
model_inputs["prefill_lower_replay_user_prefix_keep_len"] = prefill_lower_replay_user_prefix_keep_len
if prefill_lower_replay_user_start is not None:
model_inputs["prefill_lower_replay_user_start"] = prefill_lower_replay_user_start
if prefill_lower_replay_user_len is not None:
model_inputs["prefill_lower_replay_user_len"] = prefill_lower_replay_user_len
if assistant_header_start is not None:
model_inputs["assistant_header_start"] = assistant_header_start
if assistant_header_starts is not None:
model_inputs["assistant_header_starts"] = assistant_header_starts
if assistant_header_start_mask is not None:
model_inputs["assistant_header_start_mask"] = assistant_header_start_mask
return model_inputs
try:
model.prepare_inputs_for_generation = types.MethodType(
_runtime_prepare_inputs_for_generation,
model,
)
setattr(model, "_runtime_prefill_generate_attached", True)
except Exception:
pass
def _attach_prefill_lower_freeze_generate(
model,
*,
tokenizer=None,
lower_k: int,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
) -> None:
try:
import types
except Exception:
return
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
sys_prefill = (system_prefill or "no_bos_system").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "no_bos_system"
if lower_k <= 0 or attn not in {"causal", "full"}:
return
try:
setattr(model, "_runtime_prefill_freeze_layers", int(lower_k))
setattr(model, "_runtime_prefill_freeze_attn", attn)
setattr(model, "_runtime_prefill_freeze_system_prefill", sys_prefill)
setattr(model, "_runtime_structured_freeze_generate_default", True)
except Exception:
return
if tokenizer is not None:
_attach_structured_llopa_generate(model, tokenizer)
def _attach_prefill_lower_solo_generate(
model,
*,
tokenizer=None,
lower_k: int,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
) -> None:
try:
import types
except Exception:
return
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
sys_prefill = (system_prefill or "no_bos_system").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "no_bos_system"
if lower_k <= 0 or attn not in {"causal", "full"}:
return
try:
setattr(model, "_runtime_prefill_solo_layers", int(lower_k))
setattr(model, "_runtime_prefill_solo_attn", attn)
setattr(model, "_runtime_prefill_solo_system_prefill", sys_prefill)
setattr(model, "_runtime_structured_solo_generate_default", True)
except Exception:
return
if tokenizer is not None:
_attach_structured_llopa_generate(model, tokenizer)
def _attach_prefill_lower_solo_v2_generate(
model,
*,
tokenizer=None,
lower_k: int,
prefill_attn: str = "causal",
system_prefill: str = "no_bos_system",
with_bos: bool = False,
) -> None:
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
sys_prefill = (system_prefill or "no_bos_system").strip().lower()
if sys_prefill not in {"full", "no_system", "no_bos_system"}:
sys_prefill = "no_bos_system"
if lower_k <= 0 or attn not in {"causal", "full"}:
return
try:
setattr(model, "_runtime_prefill_solo_v2_layers", int(lower_k))
setattr(model, "_runtime_prefill_solo_v2_attn", attn)
setattr(model, "_runtime_prefill_solo_v2_system_prefill", sys_prefill)
setattr(model, "_runtime_prefill_solo_v2_with_bos", bool(with_bos))
setattr(model, "_runtime_structured_solo_v2_generate_default", True)
except Exception:
return
if tokenizer is not None:
_attach_structured_llopa_generate(model, tokenizer)
@torch.inference_mode()
def _direct_freeze_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Decode one chat prompt with the "freeze runtime" prefill-lower forward.

    The prompt is rebuilt from ``prompt_messages`` by
    ``_build_structured_prompt_segments``; ``input_ids`` is only used for a
    batch-size check and to pick the device. No KV cache is kept: every step
    re-runs ``model(...)`` on the full prompt plus all tokens generated so far,
    with ``use_cache=False`` and ``prefill_lower_freeze_runtime=True``. The cost
    therefore grows quadratically with the number of generated tokens.

    Args:
        model: model whose forward accepts the ``prefill_lower_*`` kwargs used below.
        tokenizer: used to build the prompt segments and for default stop tokens.
        prompt_messages: chat messages passed to ``_build_structured_prompt_segments``.
        prompt_add_generation_prompt: forwarded to the segment builder.
        input_ids: optional tensor; when given it must be 2-D with batch size 1.
        lower_k: number of lower layers; values that are not a positive int
            make the function return ``None``.
        prefill_attn: ``"causal"`` or ``"full"`` (``"prefix_full"`` is an alias);
            anything else falls back to ``"causal"``.
        system_prefill: ``"full"``, ``"no_system"`` or ``"no_bos_system"``
            (the fallback for unknown values).
        max_length / max_new_tokens / min_length / min_new_tokens: length limits.
            ``*_new_tokens`` take precedence; ``*_length`` values are measured
            against the rebuilt prompt length.
        do_sample / temperature / top_p / top_k: sampling controls; see body.
        stopping_criteria: optional callable ``(sequences, scores) -> bool``.
        eos_token_id: int or collection of ints used as stop tokens.
        output_scores / return_dict_in_generate: control the return format.
        attention_mask, pad_token_id, use_cache: accepted for signature
            compatibility; not read by this function.

    Returns:
        ``None`` when ``lower_k`` is not a positive int, when ``input_ids`` is a
        tensor that is not 2-D with batch size 1, or when a model output has no
        tensor ``logits``. Otherwise it returns the prompt+generated sequence
        tensor. When ``return_dict_in_generate`` is set, it instead returns a
        plain dict ``{"sequences", "scores"}``; ``scores`` holds the per-step
        post-warper logits if ``output_scores``, else an empty tuple.
    """
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    # Normalize the prefill attention mode; unknown values fall back to "causal".
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    sys_prefill = (system_prefill or "no_bos_system").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "no_bos_system"
    # input_ids only gates batch size and picks the device; the prompt itself is
    # rebuilt from prompt_messages below.
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"
    segments = _build_structured_prompt_segments(
        tokenizer,
        prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    canonical_input_ids = prompt_ids
    # Token offset just past the system + user segments (presumably where the
    # assistant part of the prompt begins - TODO confirm against
    # _build_structured_prompt_segments).
    split_start = int(system_ids.size(1) + user_ids.size(1))
    system_len = int(system_ids.size(1))
    total_prompt_len = int(canonical_input_ids.size(1))
    # Generation budget: max_new_tokens wins, else derived from max_length,
    # else 256.
    # NOTE(review): an explicit negative max_new_tokens is not clamped and makes
    # torch.empty below raise.
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    # Sampling defaults to on only when a non-zero temperature is given; a zero
    # temperature (or greedy decoding) is treated as 1.0, i.e. no scaling.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    # Stop tokens: the explicit eos_token_id, otherwise the tokenizer's EOS plus
    # "<|eot_id|>" when the tokenizer maps it to something other than unk.
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    # Preallocated token buffer; only the first `cur` columns are meaningful.
    generated = torch.empty((1, int(max_new_tokens)), dtype=torch.long, device=canonical_input_ids.device)
    score_list: list[torch.Tensor] = []
    cur = 0
    while cur < int(max_new_tokens):
        # No KV cache: re-run the whole prompt + generated prefix every step.
        current_ids = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
        current_attn = torch.ones_like(current_ids, device=current_ids.device)
        out = model(
            input_ids=current_ids,
            attention_mask=current_attn,
            use_cache=False,
            logits_to_keep=1,
            prefill_lower_layers=int(lower_k),
            prefill_lower_attn=str(attn),
            prefill_lower_freeze_runtime=True,
            prefill_lower_split_start=int(split_start),
            prefill_lower_system_len=int(system_len),
            prefill_lower_system_prefill=str(sys_prefill),
        )
        if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor):
            return None
        logits = out.logits[:, -1, :].to(torch.float32)
        # Forbid stop tokens until min_new_tokens have been produced.
        if stop_ids and cur < int(min_new_tokens):
            for sid in stop_ids:
                logits[:, sid] = -float("inf")
        if logits_warpers is not None:
            # Warpers see only the generated tokens (not the prompt) as input_ids.
            logits = logits_warpers(generated[:, :cur], logits)
        if bool(output_scores):
            score_list.append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        sequences_now = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= int(min_new_tokens):
            should_stop = True
        if (not should_stop) and stopping_criteria is not None:
            # Retry with scores=None if the criteria rejects the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if should_stop:
            break
    sequences = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
    if not bool(return_dict_in_generate):
        return sequences
    return {
        "sequences": sequences,
        "scores": tuple(score_list) if bool(output_scores) else tuple(),
    }
@torch.inference_mode()
def _direct_solo_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
    solo_v2: bool = False,
    with_bos: bool = False,
):
    """Decode one chat prompt with the "solo" prefill-lower attention forward.

    The prompt is rebuilt from ``prompt_messages`` by
    ``_build_structured_prompt_segments``; ``input_ids`` is only used for a
    batch-size check and to pick the device. No KV cache is kept: every step
    re-runs ``model(...)`` on the full prompt plus the generated prefix with
    ``use_cache=False``. Exactly one of ``prefill_lower_solo_attention`` /
    ``prefill_lower_solo_attention_v2`` is enabled, selected by ``solo_v2``.

    Args:
        model: model whose forward accepts the ``prefill_lower_*`` kwargs used below.
        tokenizer: used to build the prompt segments and for default stop tokens.
        prompt_messages: chat messages passed to ``_build_structured_prompt_segments``.
        prompt_add_generation_prompt: forwarded to the segment builder.
        input_ids: optional tensor; when given it must be 2-D with batch size 1.
        lower_k: number of lower layers; values that are not a positive int
            make the function return ``None``.
        prefill_attn: ``"causal"`` or ``"full"`` (``"prefix_full"`` is an alias);
            anything else falls back to ``"causal"``.
        system_prefill: ``"full"``, ``"no_system"`` or ``"no_bos_system"``
            (the fallback for unknown values).
        max_length / max_new_tokens / min_length / min_new_tokens: length limits;
            ``*_new_tokens`` take precedence over ``*_length``.
        do_sample / temperature / top_p / top_k: sampling controls; see body.
        stopping_criteria: optional callable ``(sequences, scores) -> bool``.
        eos_token_id: int or collection of ints used as stop tokens.
        output_scores / return_dict_in_generate: control the return format.
        solo_v2: select the v2 solo attention flag instead of v1.
        with_bos: forwarded as ``prefill_lower_solo_attention_v2_with_bos``, and
            only when ``solo_v2`` is set.
        attention_mask, pad_token_id, use_cache: accepted for signature
            compatibility; not read by this function.

    Returns:
        ``None`` when ``lower_k`` is not a positive int, when ``input_ids`` is a
        tensor that is not 2-D with batch size 1, or when a model output has no
        tensor ``logits``. Otherwise it returns the prompt+generated sequence
        tensor. When ``return_dict_in_generate`` is set, it instead returns a
        plain dict ``{"sequences", "scores"}``; ``scores`` holds the per-step
        post-warper logits if ``output_scores``, else an empty tuple.
    """
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    # Normalize the prefill attention mode; unknown values fall back to "causal".
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    sys_prefill = (system_prefill or "no_bos_system").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "no_bos_system"
    # input_ids only gates batch size and picks the device.
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"
    segments = _build_structured_prompt_segments(
        tokenizer,
        prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        device=device,
    )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    canonical_input_ids = prompt_ids
    # Token offset just past the system + user segments (presumably where the
    # assistant part of the prompt begins - TODO confirm).
    split_start = int(system_ids.size(1) + user_ids.size(1))
    system_len = int(system_ids.size(1))
    total_prompt_len = int(canonical_input_ids.size(1))
    # Generation budget: max_new_tokens wins, else derived from max_length,
    # else 256.
    # NOTE(review): an explicit negative max_new_tokens is not clamped and makes
    # torch.empty below raise.
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    # Sampling defaults to on only when a non-zero temperature is given; a zero
    # temperature (or greedy decoding) is treated as 1.0, i.e. no scaling.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    # Stop tokens: the explicit eos_token_id, otherwise the tokenizer's EOS plus
    # "<|eot_id|>" when the tokenizer maps it to something other than unk.
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
        with contextlib.suppress(Exception):
            eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
            if eot_id is not None and eot_id != tokenizer.unk_token_id:
                stop_ids.add(int(eot_id))
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    # Preallocated token buffer; only the first `cur` columns are meaningful.
    generated = torch.empty((1, int(max_new_tokens)), dtype=torch.long, device=canonical_input_ids.device)
    score_list: list[torch.Tensor] = []
    cur = 0
    while cur < int(max_new_tokens):
        # No KV cache: re-run the whole prompt + generated prefix every step.
        current_ids = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
        current_attn = torch.ones_like(current_ids, device=current_ids.device)
        out = model(
            input_ids=current_ids,
            attention_mask=current_attn,
            use_cache=False,
            logits_to_keep=1,
            prefill_lower_layers=int(lower_k),
            prefill_lower_attn=str(attn),
            prefill_lower_split_start=int(split_start),
            prefill_lower_system_len=int(system_len),
            prefill_lower_system_prefill=str(sys_prefill),
            # v1 and v2 solo attention are mutually exclusive; with_bos only
            # applies to v2.
            prefill_lower_solo_attention=bool(not solo_v2),
            prefill_lower_solo_attention_v2=bool(solo_v2),
            prefill_lower_solo_attention_v2_with_bos=bool(solo_v2 and with_bos),
        )
        if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor):
            return None
        logits = out.logits[:, -1, :].to(torch.float32)
        # Forbid stop tokens until min_new_tokens have been produced.
        if stop_ids and cur < int(min_new_tokens):
            for sid in stop_ids:
                logits[:, sid] = -float("inf")
        if logits_warpers is not None:
            # Warpers see only the generated tokens (not the prompt) as input_ids.
            logits = logits_warpers(generated[:, :cur], logits)
        if bool(output_scores):
            score_list.append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        sequences_now = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= int(min_new_tokens):
            should_stop = True
        if (not should_stop) and stopping_criteria is not None:
            # Retry with scores=None if the criteria rejects the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if should_stop:
            break
    sequences = torch.cat([canonical_input_ids, generated[:, :cur]], dim=1)
    if not bool(return_dict_in_generate):
        return sequences
    return {
        "sequences": sequences,
        "scores": tuple(score_list) if bool(output_scores) else tuple(),
    }
def _attach_runtime_llopa_generate(
model,
*,
header_ids: torch.Tensor,
lower_k: int,
prefill_attn: str = "causal",
no_upper_attn: bool = False,
) -> None:
try:
import types
except Exception:
return
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
if (
lower_k <= 0
or attn not in {"causal", "full"}
or not isinstance(header_ids, torch.Tensor)
or header_ids.numel() == 0
):
return
try:
setattr(model, "_runtime_llopa_header_ids", header_ids.detach().to(device="cpu", dtype=torch.long))
setattr(model, "_runtime_llopa_layers", int(lower_k))
setattr(model, "_runtime_llopa_attn", attn)
setattr(model, "_runtime_llopa_no_upper_attn", bool(no_upper_attn))
except Exception:
return
if getattr(model, "_runtime_llopa_generate_attached", False):
return
orig_prepare = getattr(model, "prepare_inputs_for_generation", None)
if orig_prepare is None:
return
def _runtime_prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
runtime_llopa_prefill=None,
runtime_llopa_layers=None,
runtime_llopa_attn=None,
runtime_llopa_no_upper_attn=None,
**kwargs,
):
model_inputs = orig_prepare(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
**kwargs,
)
if runtime_llopa_prefill is None:
runtime_llopa_prefill = True
if runtime_llopa_layers is None:
runtime_llopa_layers = int(getattr(self, "_runtime_llopa_layers", 0) or 0)
if runtime_llopa_attn is None:
runtime_llopa_attn = str(getattr(self, "_runtime_llopa_attn", "causal") or "causal")
if runtime_llopa_no_upper_attn is None:
runtime_llopa_no_upper_attn = bool(getattr(self, "_runtime_llopa_no_upper_attn", False))
if bool(runtime_llopa_prefill) and int(runtime_llopa_layers or 0) > 0:
model_inputs["runtime_llopa_prefill"] = True
model_inputs["runtime_llopa_layers"] = int(runtime_llopa_layers)
model_inputs["runtime_llopa_attn"] = str(runtime_llopa_attn)
if bool(runtime_llopa_no_upper_attn):
model_inputs["runtime_llopa_no_upper_attn"] = True
return model_inputs
try:
model.prepare_inputs_for_generation = types.MethodType(
_runtime_prepare_inputs_for_generation,
model,
)
setattr(model, "_runtime_llopa_generate_attached", True)
except Exception:
pass
@torch.inference_mode()
def _runtime_llopa_fast_generate_mode(
    model,
    input_ids: torch.LongTensor,
    logits_processor,
    stopping_criteria,
    generation_config,
    synced_gpus: bool = False,
    streamer=None,
    **model_kwargs,
):
    """Decoder-only decoding loop handed to ``model.generate`` as ``custom_generate``.

    ``_attach_runtime_llopa_fast_generate`` passes this function to the stock
    ``generate()``. ``logits_processor``, ``stopping_criteria`` and
    ``generation_config`` are assumed to arrive already built and resolved by
    ``generate()`` (assumption - TODO confirm for the pinned transformers
    version). Each step runs four stages:
    1. Call ``model.prepare_inputs_for_generation``. When it has been patched
       by the runtime-LLoPA hook it may add ``runtime_llopa_*`` inputs.
    2. Run the model.
    3. Apply ``logits_processor``.
    4. Pick the next token by multinomial sampling (``do_sample``) or argmax.

    The loop relies on private ``GenerationMixin`` helpers
    (``_get_initial_cache_position``, ``_has_unfinished_sequences``,
    ``_update_model_kwargs_for_generation``) and on
    ``generation_config._pad_token_tensor``. It is therefore tied to the
    transformers internals it was written against.

    Returns:
        ``None`` for encoder-decoder models. Otherwise it returns the final
        ``input_ids`` (prompt + generated tokens). When
        ``generation_config.return_dict_in_generate`` is set, it instead returns
        a ``GenerateDecoderOnlyOutput`` with the requested scores / raw logits /
        attentions / hidden states and the final ``past_key_values``.
    """
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    if model.config.is_encoder_decoder:
        return None
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    # Finished rows are overwritten with pad_token_id only when some stopping
    # criterion carries an eos_token_id.
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**model_inputs, return_dict=True)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=False,
        )
        # With synced_gpus a finished peer still runs the forward pass above but
        # skips token selection.
        if synced_gpus and this_peer_finished:
            continue
        next_token_logits = outputs.logits[:, -1, :].to(dtype=torch.float32)
        raw_next_token_logits = next_token_logits.clone() if (return_dict_in_generate and output_logits) else None
        next_token_scores = logits_processor(input_ids, next_token_logits)
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (raw_next_token_logits,)
            if output_attentions:
                decoder_attentions += (outputs.attentions,)
            if output_hidden_states:
                decoder_hidden_states += (outputs.hidden_states,)
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        # `scores` here is the accumulated tuple (or None), not this step's logits.
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1
        del outputs
    if streamer is not None:
        streamer.end()
    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return input_ids
@torch.inference_mode()
def _llopa_v2_generation_mixin_decode_loop(
    model,
    *,
    canonical_input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor],
    past_key_values,
    initial_logits: torch.Tensor,
    lower_k: int,
    no_upper_attn: bool,
    replay_module: str,
    replay_per_layers: int,
    max_new_tokens: int,
    min_new_tokens: int,
    do_sample: bool,
    logits_warpers,
    stop_ids: set[int],
    stop_token_ids: Optional[torch.Tensor],
    stopping_criteria,
    output_scores: bool,
    compact_scores: bool,
    return_dict_in_generate: bool,
):
    """Decode with the LLoPA-v2 per-step path after a prefill produced cache and logits.

    The first token is picked from ``initial_logits`` without a forward pass.
    After it, the attention mask grows by one column and ``cache_position`` is
    set to ``[total_prompt_len]``. Every later step calls
    ``model.prepare_inputs_for_generation`` and ``model(...)`` with the
    ``llopa_v2_decode*`` kwargs, then
    ``model._update_model_kwargs_for_generation``. Only batch size 1 is
    supported.

    Args:
        canonical_input_ids: 2-D prompt ids with batch size 1; copied into the
            front of the returned sequences.
        attention_mask: optional 2-D prompt mask. It is replaced by all-ones when
            its batch size differs, and keeps only its last
            ``prompt_len`` columns when it is longer.
            NOTE(review): a mask shorter than the prompt is kept as-is, leaving a
            length mismatch.
        past_key_values: cache produced by the prefill; passed to the model.
        initial_logits: logits for the first generated token.
        lower_k / no_upper_attn / replay_module / replay_per_layers: forwarded as
            ``llopa_v2_decode_*`` model kwargs.
        max_new_tokens / min_new_tokens: generation length bounds.
        do_sample: multinomial sampling when true, argmax otherwise.
        logits_warpers: optional callable ``(generated_ids, logits) -> logits``.
        stop_ids: token ids that end generation once ``min_new_tokens`` is reached.
        stop_token_ids: tensor of ids masked to ``-inf`` before ``min_new_tokens``.
        stopping_criteria: optional callable ``(sequences, scores) -> bool``.
        output_scores: record per-step scores.
        compact_scores: with ``output_scores``, record only the log-probability
            of each chosen token instead of full logits.
        return_dict_in_generate: return a ``GenerateDecoderOnlyOutput``.

    Returns:
        ``None`` when the inputs are malformed. When ``max_new_tokens <= 0`` it
        returns ``canonical_input_ids``, or ``None`` if
        ``return_dict_in_generate`` is set. Otherwise it returns the
        sequences tensor, or a ``GenerateDecoderOnlyOutput``. In compact mode that
        output has ``scores=None`` and a ``generated_token_logprobs`` attribute of
        shape ``(batch, n_generated)`` (float32).
    """
    if not isinstance(canonical_input_ids, torch.Tensor) or canonical_input_ids.dim() != 2:
        return None
    if canonical_input_ids.size(0) != 1:
        return None
    if not isinstance(initial_logits, torch.Tensor):
        return None
    if int(max_new_tokens) <= 0:
        return canonical_input_ids if not bool(return_dict_in_generate) else None
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    device = canonical_input_ids.device
    total_prompt_len = int(canonical_input_ids.size(1))
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    # One buffer holds prompt + generated tokens; `generated` is a view of its tail.
    sequences = torch.empty(
        (canonical_input_ids.size(0), total_prompt_len + max_new_tokens),
        dtype=canonical_input_ids.dtype,
        device=device,
    )
    sequences[:, :total_prompt_len] = canonical_input_ids
    generated = sequences[:, total_prompt_len:]
    if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
        prompt_attention_mask = attention_mask.to(device=device)
        if prompt_attention_mask.size(0) != canonical_input_ids.size(0):
            prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device)
        elif prompt_attention_mask.size(1) != total_prompt_len:
            prompt_attention_mask = prompt_attention_mask[:, -total_prompt_len:]
    else:
        prompt_attention_mask = torch.ones_like(canonical_input_ids, dtype=torch.long, device=device)
    model_kwargs = {
        "attention_mask": prompt_attention_mask,
        "past_key_values": past_key_values,
        "use_cache": True,
        "llopa_v2_decode": True,
        "llopa_v2_decode_layers": int(lower_k),
        "llopa_v2_decode_no_upper_attn": bool(no_upper_attn),
        "llopa_v2_decode_replay_module": str(replay_module),
        "llopa_v2_decode_replay_per_layers": int(replay_per_layers),
    }
    record_scores = bool(output_scores)
    record_compact_scores = record_scores and bool(compact_scores)
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    cur = 0
    # The first step consumes the prefill logits instead of running the model.
    pending_logits = initial_logits
    should_apply_stopping_criteria = stopping_criteria is not None
    while cur < max_new_tokens:
        outputs = None
        used_prefill_logits = pending_logits is not None
        if used_prefill_logits:
            logits = pending_logits.to(dtype=torch.float32, device=device, copy=True)
            pending_logits = None
        else:
            current_input_ids = sequences[:, : total_prompt_len + cur]
            # Defensive: cache_position is normally set after the first step
            # and then advanced by _update_model_kwargs_for_generation.
            if not isinstance(model_kwargs.get("cache_position"), torch.Tensor):
                model_kwargs["cache_position"] = torch.arange(
                    total_prompt_len + cur - 1,
                    total_prompt_len + cur,
                    device=device,
                    dtype=torch.long,
                )
            model_inputs = model.prepare_inputs_for_generation(current_input_ids, **model_kwargs)
            outputs = model(**model_inputs, return_dict=True)
            model_kwargs = model._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=False,
            )
            logits = outputs.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True)
        # Forbid stop tokens until min_new_tokens have been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        if record_compact_scores:
            # log-softmax of the (masked, warped) logits at the chosen token.
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        if used_prefill_logits:
            # The token picked from prefill logits is not in the cache yet: extend
            # the mask by one column and point cache_position at its slot so the
            # next forward pass feeds it.
            one_mask = torch.ones(
                (prompt_attention_mask.size(0), 1),
                dtype=prompt_attention_mask.dtype,
                device=device,
            )
            model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], one_mask], dim=-1)
            model_kwargs["cache_position"] = torch.arange(
                total_prompt_len,
                total_prompt_len + 1,
                device=device,
                dtype=torch.long,
            )
        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry with scores=None if the criteria rejects the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if outputs is not None:
            del outputs
        if should_stop:
            break
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if record_scores and not record_compact_scores else None,
        past_key_values=model_kwargs.get("past_key_values", past_key_values),
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
def _attach_runtime_llopa_fast_generate(
model,
*,
lower_k: int,
prefill_attn: str = "causal",
no_upper_attn: bool = False,
) -> None:
try:
import types
except Exception:
return
try:
lower_k = int(lower_k)
except Exception:
lower_k = 0
attn = (prefill_attn or "causal").strip().lower()
if attn == "prefix_full":
attn = "full"
if lower_k <= 0 or attn not in {"causal", "full"}:
return
try:
setattr(model, "_runtime_llopa_fast_layers", int(lower_k))
setattr(model, "_runtime_llopa_fast_attn", attn)
setattr(model, "_runtime_llopa_fast_no_upper_attn", bool(no_upper_attn))
except Exception:
return
if getattr(model, "_runtime_llopa_fast_generate_attached", False):
return
orig_generate = getattr(model, "generate", None)
if not callable(orig_generate):
return
def _runtime_llopa_fast_generate(self, *args, **kwargs):
if bool(getattr(self.config, "is_encoder_decoder", False)):
return orig_generate(*args, **kwargs)
if kwargs.get("custom_generate") is not None or kwargs.get("assistant_model") is not None:
return orig_generate(*args, **kwargs)
if kwargs.get("inputs_embeds") is not None:
return orig_generate(*args, **kwargs)
num_beams = kwargs.get("num_beams", None)
if num_beams is None:
gen_cfg = kwargs.get("generation_config")
num_beams = getattr(gen_cfg, "num_beams", 1) if gen_cfg is not None else 1
if int(num_beams or 1) != 1:
_warn_once(
self,
"_warned_runtime_llopa_fast_num_beams",
"[load_llopa_model][warn] runtime_llopa_fast_generate currently supports only num_beams=1; falling back to model.generate().",
)
return orig_generate(*args, **kwargs)
fast_enabled = kwargs.pop("runtime_llopa_fast_generate", None)
if fast_enabled is None:
fast_enabled = True
if not bool(fast_enabled):
return orig_generate(*args, **kwargs)
runtime_enabled = kwargs.get("runtime_llopa_prefill", None)
if runtime_enabled is None:
runtime_enabled = True
if not bool(runtime_enabled):
return orig_generate(*args, **kwargs)
if kwargs.get("runtime_llopa_layers") is None:
kwargs["runtime_llopa_layers"] = int(getattr(self, "_runtime_llopa_fast_layers", 0) or 0)
if kwargs.get("runtime_llopa_attn") is None:
kwargs["runtime_llopa_attn"] = str(getattr(self, "_runtime_llopa_fast_attn", "causal") or "causal")
if kwargs.get("runtime_llopa_no_upper_attn") is None:
kwargs["runtime_llopa_no_upper_attn"] = bool(
getattr(self, "_runtime_llopa_fast_no_upper_attn", False)
)
kwargs["runtime_llopa_prefill"] = True
return orig_generate(
*args,
custom_generate=_runtime_llopa_fast_generate_mode,
**kwargs,
)
try:
model.generate = types.MethodType(_runtime_llopa_fast_generate, model)
setattr(model, "_runtime_llopa_fast_generate_attached", True)
except Exception:
pass
def _supports_prefill_lower_runtime(model) -> bool:
try:
inner = _get_inner_model(model)
except Exception:
inner = None
return bool(
hasattr(model, "tri_vanilla_prefill_decode_forward")
and inner is not None
and hasattr(inner, "tri_prefill_lower_cache")
)
def _supports_runtime_llopa_prompt_prefill(model) -> bool:
try:
inner = _get_inner_model(model)
except Exception:
inner = None
return bool(
hasattr(model, "tri_runtime_llopa_prompt_prefill_forward")
and _get_llopa_decode_step(model) is not None
and inner is not None
and hasattr(inner, "llopa_prefill_cache")
)
def _supports_direct_llopa_generate(model) -> bool:
    """Tell whether *model* can use the direct LLoPA generate implementation.

    True when ``_get_llopa_decode_step(model)`` is not ``None`` and the inner
    model from ``_get_inner_model`` (lookup errors mean "none") exposes
    ``llopa_prefill_cache``.
    """
    try:
        core = _get_inner_model(model)
    except Exception:
        core = None
    if _get_llopa_decode_step(model) is None or core is None:
        return False
    return hasattr(core, "llopa_prefill_cache")
def _warn_once(model, flag: str, msg: str) -> None:
if getattr(model, flag, False):
return
try:
setattr(model, flag, True)
except Exception:
pass
print(msg)
def _normalize_eos_token_ids(eos_token_id) -> list[int]:
if eos_token_id is None:
return []
if isinstance(eos_token_id, torch.Tensor):
eos_token_id = eos_token_id.flatten().tolist()
if isinstance(eos_token_id, (list, tuple, set)):
out: list[int] = []
for item in eos_token_id:
try:
out.append(int(item))
except Exception:
continue
return out
try:
return [int(eos_token_id)]
except Exception:
return []
def _prepare_stop_token_tensor(stop_ids, device) -> Optional[torch.LongTensor]:
if not stop_ids:
return None
try:
ordered = [int(tok_id) for tok_id in stop_ids]
except Exception:
return None
if not ordered:
return None
return torch.tensor(ordered, device=device, dtype=torch.long)
def _find_last_subsequence_start(input_ids: torch.Tensor, pattern: torch.Tensor) -> Optional[int]:
if not isinstance(input_ids, torch.Tensor) or not isinstance(pattern, torch.Tensor):
return None
if input_ids.dim() == 2:
if input_ids.size(0) != 1:
return None
seq = input_ids[0]
elif input_ids.dim() == 1:
seq = input_ids
else:
return None
if pattern.dim() == 2:
if pattern.size(0) != 1:
return None
pat = pattern[0]
elif pattern.dim() == 1:
pat = pattern
else:
return None
pat_len = int(pat.numel())
seq_len = int(seq.numel())
if pat_len <= 0 or seq_len < pat_len:
return None
windows = seq.unfold(0, pat_len, 1)
matches = (windows == pat).all(dim=-1).nonzero(as_tuple=False)
if matches.numel() == 0:
return None
return int(matches[-1].item())
def _build_sampling_warpers(do_sample: bool, temperature, top_p, top_k):
if not do_sample:
return None
from transformers.generation import LogitsProcessorList
from transformers.generation.logits_process import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
procs = LogitsProcessorList()
temp = float(temperature)
if temp != 1.0:
procs.append(TemperatureLogitsWarper(temp))
if top_p is not None and float(top_p) < 1.0:
procs.append(TopPLogitsWarper(float(top_p), min_tokens_to_keep=1))
if top_k is not None and int(top_k) > 0:
procs.append(TopKLogitsWarper(int(top_k), filter_value=-float("inf")))
return procs
@torch.inference_mode()
def _legacy_direct_llopa_generate_impl(
    model,
    tokenizer,
    *,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Legacy direct LLoPA decoding driven by raw ``input_ids`` (batch size 1).

    The prompt is split at the last occurrence of ``model._direct_llopa_header_ids``:
    - everything before the header is the "user" prefix;
    - the header onward is the "assistant" part.
    Both go through ``llopa_core.llopa_prefill_cache`` with an empty system
    segment and ``prefill_mode="lower"``. Decoding then proceeds one token at a
    time via the LLoPA decode step, with the KV cache threaded through ``pkv``.

    When an output head is available, the first token's logits come from the
    last assistant hidden state returned by the prefill. Otherwise the first
    step calls the decode step on the last prompt token (presumably the prefill
    cache omits it in that mode - TODO confirm against ``llopa_prefill_cache``).

    Args:
        input_ids: 2-D prompt ids with batch size 1.
        attention_mask: optional 2-D mask. Its row sum gives the number of
            valid tokens, which are taken from the right end of ``input_ids``
            (assumes left padding - TODO confirm).
        lower_k: number of lower layers; must convert to a positive int.
        prefill_attn: ``"causal"`` or ``"full"`` (``"prefix_full"`` is an alias);
            anything else falls back to ``"causal"``.
        max_length / max_new_tokens / min_length / min_new_tokens: length limits;
            ``*_length`` values are measured against the untrimmed prompt length.
        do_sample / temperature / top_p / top_k: sampling controls.
        stopping_criteria: optional callable ``(sequences, scores) -> bool``.
        eos_token_id: stop tokens; defaults to ``tokenizer.eos_token_id``.
        use_cache: ``False`` makes the function return ``None`` (cache required).
        pad_token_id: accepted for signature compatibility; not read here.

    Returns:
        ``None`` when any precondition fails:
        - the shape or batch size is unsupported, or ``use_cache`` is ``False``;
        - the mask has no valid tokens;
        - the header is missing, not found, or nothing follows the split;
        - ``lower_k`` is not positive;
        - the LLoPA core, decode step or prefill function is missing.
        Otherwise it returns the sequences (the original, untrimmed
        ``input_ids`` followed by the generated tokens), or a
        ``GenerateDecoderOnlyOutput`` with per-step post-warper ``scores``
        (when ``output_scores``) and the final cache.
    """
    if input_ids is None or input_ids.dim() != 2 or input_ids.size(0) != 1:
        return None
    if use_cache is False:
        return None
    # Drop padding: keep only the last `valid_len` tokens (assumes left padding).
    valid_len = int(input_ids.size(1))
    if attention_mask is not None:
        if attention_mask.dim() != 2 or attention_mask.size(0) != input_ids.size(0):
            return None
        valid_len = int(attention_mask[0].sum().item())
    if valid_len <= 0:
        return None
    trimmed_input_ids = input_ids[:, -valid_len:]
    # Split the prompt at the last assistant-header occurrence.
    header_ids = getattr(model, "_direct_llopa_header_ids", None)
    if not isinstance(header_ids, torch.Tensor) or header_ids.numel() == 0:
        return None
    hdr = header_ids.to(device=trimmed_input_ids.device, dtype=trimmed_input_ids.dtype)
    assistant_start = _find_last_subsequence_start(trimmed_input_ids, hdr)
    if assistant_start is None:
        return None
    prefix_ids = trimmed_input_ids[:, :assistant_start]
    assistant_ids = trimmed_input_ids[:, assistant_start:]
    if assistant_ids.numel() == 0:
        return None
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    total_prompt_len = int(input_ids.size(1))
    # NOTE(review): an explicit negative max_new_tokens is not clamped and makes
    # torch.empty below raise.
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    # Sampling defaults to on only when a non-zero temperature is given; a zero
    # temperature (or greedy decoding) is treated as 1.0, i.e. no scaling.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_fn = getattr(llopa_core, "llopa_prefill_cache", None)
    if llopa_fn is None:
        return None
    output_head = _get_output_head(model)
    # The whole prefix is treated as "user"; the system segment is empty.
    prefill_out = llopa_fn(
        system_ids=prefix_ids[:, :0],
        user_ids=prefix_ids,
        assistant_ids=assistant_ids,
        lower_k=lower_k,
        prefill_mode="lower",
        prefill_attn=attn,
        return_last_assistant_hidden=bool(output_head is not None),
    )
    initial_logits = None
    if isinstance(prefill_out, tuple):
        pkv, last_hidden = prefill_out
        if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
            initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)
    else:
        pkv = prefill_out
    # Segment lengths for the decode step: S = system (empty), U = user prefix.
    S = 0
    U = int(prefix_ids.size(1))
    last = assistant_ids[:, -1:]
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    total_prompt_len = int(input_ids.size(1))
    record_scores = bool(output_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    # One buffer holds the untrimmed prompt + generated tokens; `generated` views its tail.
    sequences = torch.empty(
        (input_ids.size(0), total_prompt_len + max_new_tokens),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    sequences[:, :total_prompt_len] = input_ids
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    cur = 0
    pending_logits = initial_logits
    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            # Feed the previously chosen token (initially the last prompt token).
            out = llopa_step(
                assistant_ids=last,
                lower_k=lower_k,
                pkv=pkv,
                S=S,
                U=U,
                logits_to_keep=1,
                labels=None,
                prefill_mode="lower",
            )
            pkv = out.past_key_values or pkv
            logits = out.logits[:, -1, :].to(dtype=torch.float32, device=last.device, copy=True)
        else:
            # First step: reuse the prefill logits instead of running the model.
            logits = pending_logits
            pending_logits = None
        # Forbid stop tokens until min_new_tokens have been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores:
            score_list_append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        should_stop = False
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry with scores=None if the criteria rejects the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    return GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) else None,
        past_key_values=pkv,
    )
@torch.inference_mode()
def _direct_llopa_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    user_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "auto",
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    compact_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Generate for a single prompt row through the LLoPA lower-prefill cache.

    Builds structured prompt segments (unless a segment dict is supplied),
    seeds the KV cache and, when possible, the first-step logits, then decodes
    up to ``max_new_tokens`` tokens, either greedily or by sampling.

    Returns None when ``use_cache`` is False, ``lower_k`` is not a positive
    int, the LLoPA core or decode step is unavailable, ``input_ids`` is a
    tensor that is not a single 2-D row, or the assistant prefill is empty.
    The None return presumably lets the caller fall back to another
    generation path.

    When the optional "mixin" decode loop is enabled and returns a result,
    that result is returned as-is. Otherwise the function returns the
    ``sequences`` tensor (prompt followed by generated tokens). When
    ``return_dict_in_generate`` is set it returns a
    ``GenerateDecoderOnlyOutput`` instead. If both ``output_scores`` and
    ``compact_scores`` are set, per-token log-probabilities are attached as
    ``generated_token_logprobs`` in place of full scores.

    ``pad_token_id`` is accepted but not used by this function.

    Raises:
        ValueError: if ``user_prefill`` is anything other than "full".
    """
    # `last_layer_module` is used as the replay module only when `replay_module`
    # normalizes to "none".
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    # This path always decodes with a KV cache; an explicit use_cache=False opts out.
    if use_cache is False:
        return None
    # A non-integer or non-positive lower_k disables this path.
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    # Normalize the prefill attention mode: "prefix_full" is an alias of "full",
    # and any other unrecognized value falls back to "causal".
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    # Unrecognized system_prefill values fall back to "full".
    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"
    user_prefill_norm = (user_prefill or "full").strip().lower()
    if user_prefill_norm != "full":
        raise ValueError("Unified direct LLoPA currently supports only user_prefill='full'.")
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    # Decode through tri_forward_assistant + the output head instead of
    # llopa_step when both are available and the CAPSULE_LLOPA_DIRECT_DECODE_STEP
    # flag allows it (the flag default passed here is "1").
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )
    # Resolve the working device. Try input_ids first (it must be a single
    # 2-D row), then the model's first parameter, then "cpu".
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"
    # Reuse caller-supplied segments when given as a dict; otherwise build them.
    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    assistant_prefill_ids = segments["assistant_prefill_ids"]
    # Replay bookkeeping; missing or None entries default to 0.
    replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_user_start = int(segments.get("replay_user_start", 0) or 0)
    replay_user_len = int(segments.get("replay_user_len", 0) or 0)
    # An assistant prefill is required: its last token feeds the first decode step.
    if assistant_prefill_ids.numel() == 0:
        return None
    # Sampling defaults. do_sample defaults to True only for a non-zero
    # temperature. Greedy decoding, or a zero temperature, uses scale 1.0.
    raw_temp = 0.0 if temperature is None else float(temperature)
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    initial_logits = None
    # Preferred seeding path. When it yields None, the cache is instead built
    # with _llopa_prefill_cache further below.
    prompt_bundle = _build_unified_prefill_lower_prompt_bundle(
        tokenizer,
        prompt_messages=prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        structured_prompt_segments=segments,
        device=device,
    )
    reference_seed = _direct_prefill_lower_cache_and_logits(
        model,
        prompt_bundle=prompt_bundle,
        lower_k=lower_k,
        prefill_attn=attn,
        system_prefill=sys_prefill,
        no_upper_attn=bool(no_upper_attn),
        see_past_assistant=bool(see_past_assistant),
        replay_module=str(replay_module),
        replay_per_layers=int(replay_per_layers),
        seed_mode=str(seed_mode or "auto"),
    )
    # Prompt echoed at the front of the returned sequences, chosen in order:
    #   1. the trailing `valid_len` tokens of input_ids, where valid_len is the
    #      single-row attention-mask sum when such a mask is given (assumes left
    #      padding — TODO confirm);
    #   2. else the bundle's effective prompt ids;
    #   3. else the structured prompt ids.
    canonical_input_ids = None
    if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1:
        valid_len = int(input_ids.size(1))
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1:
            valid_len = int(attention_mask[0].sum().item())
        if valid_len > 0:
            canonical_input_ids = input_ids[:, -valid_len:]
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_bundle.get("effective_prompt_ids")
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_ids
    total_prompt_len = int(canonical_input_ids.size(1))
    # Token budgets. max_new_tokens defaults to 256, or to max_length minus the
    # prompt length (floored at 0). min_new_tokens is derived the same way from
    # min_length, defaulting to 0.
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    if reference_seed is not None:
        pkv, S, U, initial_logits = reference_seed
    else:
        # Fallback: build the lower-layer cache directly. Only the branch with
        # upper attention requests the last assistant hidden state, which is
        # projected through the output head into the first-step logits.
        output_head = _get_output_head(model)
        if bool(no_upper_attn):
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
        else:
            pkv, S, U, last_hidden = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=bool(output_head is not None),
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
            if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
                initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)
    # Stop ids: the normalized eos_token_id values, else the tokenizer's eos_token_id.
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
    # Also add "<|eot_id|>" when the tokenizer maps it to a real (non-unk) id.
    # NOTE(review): the source indentation was lost. This block is assumed to run
    # unconditionally rather than only inside `if not stop_ids`; confirm.
    with contextlib.suppress(Exception):
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id:
            stop_ids.add(int(eot_id))
    # Model input for the first decode step when there are no seed logits.
    last = assistant_prefill_ids[:, -1:]
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    lower_k = int(lower_k)
    no_upper_attn_bool = bool(no_upper_attn)
    replay_module_str = str(replay_module)
    replay_per_layers_int = int(replay_per_layers)
    record_scores = bool(output_scores)
    record_compact_scores = bool(output_scores) and bool(compact_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    # Optional alternate decode loop. It is selected by the model attribute
    # `_llopa_v2_generation_mixin_decode`, or, when that is unset, by the
    # CAPSULE_LLOPA_GENERATION_MIXIN_DECODE env flag (default "0"). It requires
    # seed logits; a None result falls through to the loop below.
    mixin_decode_default = getattr(model, "_llopa_v2_generation_mixin_decode", None)
    use_mixin_decode = (
        bool(mixin_decode_default)
        if mixin_decode_default is not None
        else _env_flag_enabled("CAPSULE_LLOPA_GENERATION_MIXIN_DECODE", "0")
    )
    if use_mixin_decode and isinstance(initial_logits, torch.Tensor):
        mixin_output = _llopa_v2_generation_mixin_decode_loop(
            model,
            canonical_input_ids=canonical_input_ids,
            attention_mask=attention_mask,
            past_key_values=pkv,
            initial_logits=initial_logits,
            lower_k=lower_k,
            no_upper_attn=no_upper_attn_bool,
            replay_module=replay_module_str,
            replay_per_layers=replay_per_layers_int,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            do_sample=do_sample,
            logits_warpers=logits_warpers,
            stop_ids=stop_ids,
            stop_token_ids=stop_token_ids,
            stopping_criteria=stopping_criteria,
            output_scores=output_scores,
            compact_scores=compact_scores,
            return_dict_in_generate=return_dict_in_generate,
        )
        if mixin_output is not None:
            return mixin_output
    # Preallocate the full output buffer; `generated` is a view of its tail.
    sequences = torch.empty(
        (canonical_input_ids.size(0), total_prompt_len + max_new_tokens),
        dtype=canonical_input_ids.dtype,
        device=canonical_input_ids.device,
    )
    sequences[:, :total_prompt_len] = canonical_input_ids
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    cur = 0
    pending_logits = initial_logits
    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            # Run one decode step on `last`, writing its keys/values into pkv.
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = out.logits[:, -1, :]
            logits = logits.to(dtype=torch.float32, device=last.device, copy=True)
        else:
            # The first step consumes the seed logits without running the model.
            logits = pending_logits
            pending_logits = None
        # Forbid stop tokens until min_new_tokens tokens have been produced.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        if record_compact_scores:
            # Log-probability of the chosen token under the logits after any warpers.
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        should_stop = False
        # Single-row path: the chosen token is read back as a Python int.
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                # Retry with scores=None when the criteria raise TypeError on the logits.
                should_stop = bool(stopping_criteria(sequences_now, None))
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok
    # Drop the unused tail of the preallocated buffer.
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        # Compact mode: one log-probability per generated token, shape (batch, steps),
        # attached as `generated_token_logprobs` in place of full score tensors.
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
def _is_prompt_messages_batch(prompt_messages) -> bool:
if not isinstance(prompt_messages, (list, tuple)) or not prompt_messages:
return False
if all(isinstance(item, dict) for item in prompt_messages):
return False
return all(isinstance(item, (list, tuple)) for item in prompt_messages)
def _is_structured_segments_batch(structured_prompt_segments) -> bool:
return (
isinstance(structured_prompt_segments, (list, tuple))
and not isinstance(structured_prompt_segments, dict)
and all(isinstance(item, dict) for item in structured_prompt_segments)
)
def _batch_prompt_add_generation_flags(value, batch_size: int) -> list[bool]:
if isinstance(value, (list, tuple)):
if len(value) != int(batch_size):
raise ValueError(
f"prompt_add_generation_prompt batch size mismatch: {len(value)} != {int(batch_size)}"
)
return [bool(item) for item in value]
return [bool(value) for _ in range(int(batch_size))]
def _as_1d_long_tensor(value: torch.Tensor, *, device) -> torch.LongTensor:
value = value.to(device=device, dtype=torch.long)
if value.dim() == 2:
if value.size(0) != 1:
raise ValueError("Expected a single-row prompt segment tensor.")
value = value[0]
elif value.dim() != 1:
value = value.reshape(-1)
return value
def _pad_1d_rows(
rows: list[torch.Tensor],
*,
pad_value: int,
device,
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
batch_size = len(rows)
max_len = max((int(row.numel()) for row in rows), default=0)
out = torch.full((batch_size, max_len), int(pad_value), device=device, dtype=dtype)
for row_idx, row in enumerate(rows):
row = row.to(device=device, dtype=dtype).reshape(-1)
width = int(row.numel())
if width > 0:
out[row_idx, :width] = row
return out
def _pad_prompt_metadata_rows(
rows: list[Optional[torch.Tensor]],
*,
fill_value: int,
device,
dtype: torch.dtype,
) -> torch.Tensor:
batch_size = len(rows)
max_len = 0
flat_rows: list[torch.Tensor] = []
for row in rows:
if isinstance(row, torch.Tensor) and row.numel() > 0:
flat = row.to(device=device, dtype=dtype)
if flat.dim() == 2:
flat = flat[0]
else:
flat = flat.reshape(-1)
else:
flat = torch.empty((0,), device=device, dtype=dtype)
flat_rows.append(flat)
max_len = max(max_len, int(flat.numel()))
if max_len <= 0:
return torch.full((batch_size, 1), int(fill_value), device=device, dtype=dtype)
out = torch.full((batch_size, max_len), int(fill_value), device=device, dtype=dtype)
for row_idx, flat in enumerate(flat_rows):
width = int(flat.numel())
if width > 0:
out[row_idx, :width] = flat
return out
def _build_batched_structured_prompt_segments(
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt,
    structured_prompt_segments,
    device,
) -> list[dict]:
    """Return one structured prompt-segment dict per batch row.

    Inputs are resolved in this order:
      1. ``structured_prompt_segments`` as a list/tuple of dicts: used as-is.
      2. ``structured_prompt_segments`` as a single dict: a one-row batch.
      3. ``prompt_messages`` as a batch of conversations: one segment per row.
         A row's generation-prompt flag is turned off when the row's normalized
         messages already end with an assistant message.
      4. ``prompt_messages`` as a single conversation: a one-row batch.
         ``prompt_add_generation_prompt`` must be given in this case.
      5. Anything else: an empty list.

    Raises:
        ValueError: for a single conversation without
            ``prompt_add_generation_prompt``, or when per-row flags do not
            match the batch size.
    """
    if _is_structured_segments_batch(structured_prompt_segments):
        return list(structured_prompt_segments)
    if isinstance(structured_prompt_segments, dict):
        return [structured_prompt_segments]
    if prompt_messages is None:
        return []
    if not _is_prompt_messages_batch(prompt_messages):
        if prompt_add_generation_prompt is None:
            raise ValueError("llopa_v2_batch_generate requires prompt_add_generation_prompt for prompt_messages.")
        single_row = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
        return [single_row]
    row_flags = _batch_prompt_add_generation_flags(prompt_add_generation_prompt, len(prompt_messages))
    built: list[dict] = []
    for messages, wants_generation_prompt in zip(prompt_messages, row_flags):
        row_messages = list(messages)
        add_generation_prompt = bool(wants_generation_prompt)
        normalized = _normalize_prompt_messages(row_messages)
        # Do not open a new assistant turn after a row that already ends on one.
        if add_generation_prompt and normalized and normalized[-1]["role"] == "assistant":
            add_generation_prompt = False
        built.append(
            _build_structured_prompt_segments(
                tokenizer,
                row_messages,
                prompt_add_generation_prompt=add_generation_prompt,
                device=device,
            )
        )
    return built
def _batched_segments_to_tensors(segments: list[dict], *, device, pad_token_id: int) -> dict:
    """Collate per-row structured prompt segments into right-padded batch tensors.

    For each segment this derives:
      * the 1-D prompt row and its length;
      * ``split_start``: the ``prefix_ids`` length when present, otherwise
        ``prompt_len - assistant_prefill_len``, clamped to ``[0, prompt_len - 1]``;
      * ``system_len``: the ``system_ids`` length (0 if absent), clamped to
        ``[0, split_start]``;
      * the replay-user integers, where missing or None entries become 0;
      * assistant header-start, turn-end and header-mask metadata. These
        default to a single header at ``split_start`` that ends at ``prompt_len``.

    Returns a dict with the unpadded ``prompt_rows`` list plus tensors on
    *device*. ``prompt_ids`` is right-padded with *pad_token_id*, and
    ``prompt_attention_mask`` is 1 on real tokens. Metadata rows are padded
    with -1 (indices) or False (mask).

    Raises:
        ValueError: if a segment lacks non-empty ``prompt_ids`` or
            ``assistant_prefill_ids``.
    """

    def _opt_int(seg: dict, key: str) -> int:
        # Missing or falsy values count as 0.
        return int(seg.get(key, 0) or 0)

    prompt_rows: list[torch.Tensor] = []
    lengths: list[int] = []
    splits: list[int] = []
    system_lengths: list[int] = []
    keep_lens: list[int] = []
    user_starts: list[int] = []
    user_lens: list[int] = []
    header_start_rows: list[Optional[torch.Tensor]] = []
    turn_end_rows: list[Optional[torch.Tensor]] = []
    header_mask_rows: list[Optional[torch.Tensor]] = []
    for seg in segments:
        raw_prompt = seg.get("prompt_ids")
        raw_assistant = seg.get("assistant_prefill_ids")
        if not isinstance(raw_prompt, torch.Tensor) or raw_prompt.numel() == 0:
            raise ValueError("llopa_v2_batch_generate requires non-empty prompt_ids in each segment.")
        if not isinstance(raw_assistant, torch.Tensor) or raw_assistant.numel() == 0:
            raise ValueError("llopa_v2_batch_generate requires non-empty assistant_prefill_ids in each segment.")
        prompt_row = _as_1d_long_tensor(raw_prompt, device=device)
        n_assistant = int(_as_1d_long_tensor(raw_assistant, device=device).numel())
        n_prompt = int(prompt_row.numel())
        raw_prefix = seg.get("prefix_ids")
        if isinstance(raw_prefix, torch.Tensor) and raw_prefix.numel() > 0:
            split = int(_as_1d_long_tensor(raw_prefix, device=device).numel())
        else:
            split = int(n_prompt - n_assistant)
        split = max(0, min(split, n_prompt - 1))
        raw_system = seg.get("system_ids")
        n_system = 0
        if isinstance(raw_system, torch.Tensor):
            n_system = int(_as_1d_long_tensor(raw_system, device=device).numel())
        n_system = max(0, min(n_system, split))

        prompt_rows.append(prompt_row)
        lengths.append(n_prompt)
        splits.append(split)
        system_lengths.append(n_system)
        keep_lens.append(_opt_int(seg, "replay_user_prefix_keep_len"))
        user_starts.append(_opt_int(seg, "replay_user_start"))
        user_lens.append(_opt_int(seg, "replay_user_len"))

        starts = seg.get("assistant_header_starts")
        ends = seg.get("assistant_turn_ends")
        start_mask = seg.get("assistant_header_start_mask")
        if not (isinstance(starts, torch.Tensor) and starts.numel() > 0):
            starts = torch.tensor([[split]], device=device, dtype=torch.long)
        if not (isinstance(ends, torch.Tensor) and ends.numel() > 0):
            ends = torch.tensor([[n_prompt]], device=device, dtype=torch.long)
        if not (isinstance(start_mask, torch.Tensor) and start_mask.numel() > 0):
            start_mask = torch.ones_like(starts, device=device, dtype=torch.bool)
        header_start_rows.append(starts)
        turn_end_rows.append(ends)
        header_mask_rows.append(start_mask)

    padded_prompts = _pad_1d_rows(prompt_rows, pad_value=int(pad_token_id), device=device, dtype=torch.long)
    length_tensor = torch.tensor(lengths, device=device, dtype=torch.long)
    positions = torch.arange(padded_prompts.size(1), device=device, dtype=torch.long).unsqueeze(0)
    attention = (positions < length_tensor.unsqueeze(1)).to(dtype=torch.long)

    def _as_long(values: list[int]) -> torch.Tensor:
        return torch.tensor(values, device=device, dtype=torch.long)

    return {
        "prompt_rows": prompt_rows,
        "prompt_ids": padded_prompts,
        "prompt_attention_mask": attention,
        "prompt_lens": length_tensor,
        "split_starts": _as_long(splits),
        "system_lens": _as_long(system_lengths),
        "replay_user_prefix_keep_lens": _as_long(keep_lens),
        "replay_user_starts": _as_long(user_starts),
        "replay_user_lens": _as_long(user_lens),
        "assistant_header_starts": _pad_prompt_metadata_rows(
            header_start_rows, fill_value=-1, device=device, dtype=torch.long
        ),
        "assistant_turn_ends": _pad_prompt_metadata_rows(
            turn_end_rows, fill_value=-1, device=device, dtype=torch.long
        ),
        "assistant_header_start_mask": _pad_prompt_metadata_rows(
            header_mask_rows, fill_value=0, device=device, dtype=torch.bool
        ).to(dtype=torch.bool),
    }
def _cache_layer_seq_len(pkv, layer_idx: int) -> int:
try:
k, _ = pkv_get(pkv, int(layer_idx))
except Exception:
return 0
if not isinstance(k, torch.Tensor) or k.dim() < 3:
return 0
return int(k.shape[-2])
def _merge_optional_batch_sequence_attr(row_pkvs: list, attr_name: str, *, device, dtype=None):
rows = []
trailing_shape = None
for pkv in row_pkvs:
value = getattr(pkv, attr_name, None)
if isinstance(value, torch.Tensor) and value.numel() > 0:
value = value.to(device=device)
if dtype is not None:
value = value.to(dtype=dtype)
if value.dim() == 1:
value = value.view(1, -1)
elif value.dim() >= 2 and value.size(0) != 1:
value = value[:1]
trailing_shape = tuple(value.shape[2:])
else:
value = None
rows.append(value)
if trailing_shape is None:
return None
max_len = max((int(row.shape[1]) for row in rows if isinstance(row, torch.Tensor)), default=0)
if max_len <= 0:
return None
ref = next(row for row in rows if isinstance(row, torch.Tensor))
out_shape = (len(rows), max_len) + trailing_shape
out = torch.zeros(out_shape, device=device, dtype=ref.dtype)
for row_idx, row in enumerate(rows):
if not isinstance(row, torch.Tensor):
continue
width = int(row.shape[1])
if width > 0:
out[row_idx : row_idx + 1, :width, ...] = row[:, :width, ...]
return out
def _merge_llopa_batch_row_caches(row_pkvs: list, *, device) -> Optional[DynamicCache]:
if not row_pkvs:
return None
try:
n_layers = max(int(pkv_len(pkv)) for pkv in row_pkvs)
except Exception:
return None
if n_layers <= 0:
return None
merged_pairs = []
layer_valid_masks: list[torch.Tensor] = []
for layer_idx in range(n_layers):
row_kvs = []
max_len = 0
ref_k = None
ref_v = None
for pkv in row_pkvs:
try:
k, v = pkv_get(pkv, layer_idx)
except Exception:
k = None
v = None
if isinstance(k, torch.Tensor) and isinstance(v, torch.Tensor) and k.dim() == 4 and v.dim() == 4:
k = k.to(device=device)
v = v.to(device=device)
if k.size(0) != 1:
k = k[:1]
v = v[:1]
ref_k = k if ref_k is None else ref_k
ref_v = v if ref_v is None else ref_v
max_len = max(max_len, int(k.shape[-2]))
row_kvs.append((k, v))
else:
row_kvs.append((None, None))
if ref_k is None or ref_v is None:
return None
B = len(row_pkvs)
merged_k = torch.zeros(
(B, int(ref_k.shape[1]), max_len, int(ref_k.shape[-1])),
device=device,
dtype=ref_k.dtype,
)
merged_v = torch.zeros(
(B, int(ref_v.shape[1]), max_len, int(ref_v.shape[-1])),
device=device,
dtype=ref_v.dtype,
)
valid_mask = torch.zeros((B, max_len), device=device, dtype=torch.bool)
for row_idx, (k, v) in enumerate(row_kvs):
if not isinstance(k, torch.Tensor) or not isinstance(v, torch.Tensor):
continue
width = int(k.shape[-2])
if width > 0:
merged_k[row_idx : row_idx + 1, :, :width, :] = k[:, :, :width, :]
merged_v[row_idx : row_idx + 1, :, :width, :] = v[:, :, :width, :]
valid_mask[row_idx, :width] = True
merged_pairs.append((merged_k, merged_v))
layer_valid_masks.append(valid_mask)
try:
merged = DynamicCache(ddp_cache_data=merged_pairs)
except Exception:
merged = DynamicCache()
for layer_idx, (k, v) in enumerate(merged_pairs):
merged.update(k, v, layer_idx)
try:
setattr(merged, "_llopa_batch_layer_valid_masks", layer_valid_masks)
except Exception:
pass
replay_hidden = _merge_optional_batch_sequence_attr(
row_pkvs,
"_tri_last_layer_memory_hidden",
device=device,
)
replay_pos = _merge_optional_batch_sequence_attr(
row_pkvs,
"_tri_last_layer_memory_position_ids",
device=device,
dtype=torch.long,
)
replay_valid = _merge_optional_batch_sequence_attr(
row_pkvs,
"_tri_last_layer_memory_valid_mask",
device=device,
dtype=torch.bool,
)
if isinstance(replay_hidden, torch.Tensor) and isinstance(replay_pos, torch.Tensor):
with contextlib.suppress(Exception):
setattr(merged, "_tri_last_layer_memory_hidden", replay_hidden)
setattr(merged, "_tri_last_layer_memory_position_ids", replay_pos)
if isinstance(replay_valid, torch.Tensor):
setattr(merged, "_tri_last_layer_memory_valid_mask", replay_valid.to(dtype=torch.bool))
module_type = getattr(row_pkvs[0], "_tri_replay_module", getattr(row_pkvs[0], "_tri_last_layer_module", "none"))
setattr(merged, "_tri_replay_module", str(module_type or "none"))
setattr(merged, "_tri_last_layer_module", str(module_type or "none"))
setattr(merged, "_tri_replay_per_layers", int(getattr(row_pkvs[0], "_tri_replay_per_layers", -1) or -1))
return merged
def _append_llopa_batch_cache_valid_masks(pkv, input_valid: torch.Tensor) -> None:
masks = getattr(pkv, "_llopa_batch_layer_valid_masks", None)
if not isinstance(masks, list):
return
input_valid = input_valid.to(dtype=torch.bool)
new_masks = []
for layer_idx, mask in enumerate(masks):
if not isinstance(mask, torch.Tensor) or mask.dim() != 2:
new_masks.append(mask)
continue
layer_len = _cache_layer_seq_len(pkv, layer_idx)
cur_len = int(mask.size(1))
if layer_len > cur_len:
add_width = int(layer_len - cur_len)
add = input_valid.to(device=mask.device).view(-1, 1).expand(mask.size(0), add_width)
mask = torch.cat([mask, add], dim=1)
elif layer_len < cur_len:
mask = mask[:, :layer_len]
new_masks.append(mask)
with contextlib.suppress(Exception):
setattr(pkv, "_llopa_batch_layer_valid_masks", new_masks)
setattr(pkv, "_tri_past_len_cache", None)
def _direct_llopa_batch_prefill_cache_and_logits(
model,
*,
segments: list[dict],
pad_id: int,
lower_k: int,
prefill_attn: str,
system_prefill: str,
no_upper_attn: bool,
see_past_assistant: bool,
replay_module: str,
replay_per_layers: int,
device,
):
if not segments:
return None
if _normalize_replay_module_value(replay_module) != "none":
return None
seed_fn = _get_llopa_full_prompt_seed(model)
if not callable(seed_fn):
return None
try:
batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=int(pad_id))
except Exception:
return None
try:
seed = seed_fn(
input_ids=batch["prompt_ids"],
attention_mask=batch["prompt_attention_mask"],
use_cache=True,
logits_to_keep=1,
lower_k=int(lower_k),
prefill_attn=str(prefill_attn),
system_prefill=str(system_prefill),
no_upper_attn=bool(no_upper_attn),
prefill_lower_split_start=batch["split_starts"],
prefill_lower_system_len=batch["system_lens"],
prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"],
prefill_lower_replay_user_start=batch["replay_user_starts"],
prefill_lower_replay_user_len=batch["replay_user_lens"],
assistant_header_starts=batch["assistant_header_starts"],
assistant_turn_ends=batch["assistant_turn_ends"],
assistant_header_start_mask=batch["assistant_header_start_mask"],
prefill_lower_see_past_assistant=bool(see_past_assistant),
replay_module=str(replay_module),
replay_per_layers=int(replay_per_layers),
)
except Exception:
return None
if not isinstance(seed, tuple) or len(seed) != 4:
return None
pkv, S, U, logits = seed
if pkv is None or not isinstance(logits, torch.Tensor) or logits.numel() == 0:
return None
if logits.dim() == 3:
logits = logits[:, -1, :]
if logits.dim() != 2 or int(logits.size(0)) != len(segments):
return None
return pkv, S, U, logits.to(device=device, dtype=torch.float32)
def _direct_llopa_batch_generate_cached_impl(
    model,
    tokenizer,
    *,
    segments: list[dict],
    canonical_input_ids: torch.Tensor,
    context_width: int,
    batch_size: int,
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool,
    replay_module: str,
    replay_per_layers: int,
    last_layer_module: Optional[str],
    max_new_tokens: int,
    min_new_tokens: int,
    do_sample: bool,
    logits_warpers,
    stopping_criteria,
    stop_ids: set[int],
    stop_token_ids: Optional[torch.Tensor],
    pad_id: int,
    output_scores: bool,
    compact_scores: bool,
    return_dict_in_generate: bool,
    device,
):
    """Batched LLoPA generation over pre-built structured prompt segments.

    Seeding: when ``batch_size > 1``, one batched prefill is tried first via
    ``_direct_llopa_batch_prefill_cache_and_logits``. If that is unavailable,
    each row is seeded separately with ``_direct_prefill_lower_cache_and_logits``
    and the per-row caches are merged (right-padded) with
    ``_merge_llopa_batch_row_caches``. For ``batch_size > 1`` this per-row
    fallback is refused when a replay module is active, or when
    ``see_past_assistant`` is set and some prompt has past assistant history.

    Decoding: one token per row per step, greedy or sampled. A row finishes
    when it emits a stop id (once ``min_new_tokens`` is reached) or when
    ``stopping_criteria`` flags it. After that, its tokens are written as
    ``pad_id`` and fed to the model marked invalid in the cache validity masks.
    Generation ends when every row has finished or ``max_new_tokens`` is hit.

    Args:
        canonical_input_ids: prompt ids copied into the first ``context_width``
            columns of the output (assumed ``(batch_size, context_width)``;
            TODO confirm).
        stop_ids: token ids that finish a row.
        stop_token_ids: ids whose logits are set to -inf while fewer than
            ``min_new_tokens`` tokens have been generated (or None).
        The remaining arguments are forwarded to the prefill/decode helpers.

    Returns:
        None when a required helper or seed is unavailable. Otherwise returns
        the ``(batch_size, context_width + steps)`` sequences tensor, or a
        ``GenerateDecoderOnlyOutput`` when ``return_dict_in_generate`` is set.
        When both ``output_scores`` and ``compact_scores`` are set, the output
        carries ``generated_token_logprobs`` of shape ``(batch, steps)``.
    """
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    # Decode through tri_forward_assistant + the output head instead of
    # llopa_step when both exist and the env flag (default "1") allows it.
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )
    row_pkvs = []
    initial_logits_rows = []
    prompt_bundles = []
    for seg in segments:
        prompt_bundle = _build_unified_prefill_lower_prompt_bundle(
            tokenizer,
            prompt_messages=None,
            prompt_add_generation_prompt=True,
            structured_prompt_segments=seg,
            device=device,
        )
        prompt_bundles.append(prompt_bundle)
    has_past_assistant_history = bool(see_past_assistant) and any(
        _prompt_bundle_has_past_assistant_history(bundle) for bundle in prompt_bundles
    )
    batch_seed = None
    if int(batch_size) > 1:
        batch_seed = _direct_llopa_batch_prefill_cache_and_logits(
            model,
            segments=segments,
            pad_id=int(pad_id),
            lower_k=int(lower_k),
            prefill_attn=str(prefill_attn),
            system_prefill=str(system_prefill),
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            device=device,
        )
    if batch_seed is not None:
        pkv, _S, _U, pending_logits = batch_seed
        pending_logits = pending_logits.to(device=device, dtype=torch.float32)
    else:
        # Per-row seeding cannot reproduce replay memory or cross-row assistant
        # history for a multi-row batch, so give up in those cases.
        if int(batch_size) > 1 and _normalize_replay_module_value(replay_module) != "none":
            return None
        if int(batch_size) > 1 and bool(has_past_assistant_history):
            return None
        for prompt_bundle in prompt_bundles:
            seed = _direct_prefill_lower_cache_and_logits(
                model,
                prompt_bundle=prompt_bundle,
                lower_k=int(lower_k),
                prefill_attn=str(prefill_attn),
                system_prefill=str(system_prefill),
                no_upper_attn=bool(no_upper_attn),
                see_past_assistant=bool(see_past_assistant),
                replay_module=str(replay_module),
                replay_per_layers=int(replay_per_layers),
                last_layer_module=last_layer_module,
                seed_mode="prefill_header",
            )
            if seed is None:
                return None
            pkv_row, _S, _U, initial_logits = seed
            if not isinstance(initial_logits, torch.Tensor) or initial_logits.numel() == 0:
                return None
            row_pkvs.append(pkv_row)
            initial_logits_rows.append(initial_logits.to(device=device, dtype=torch.float32))
        pkv = _merge_llopa_batch_row_caches(row_pkvs, device=device)
        if pkv is None:
            return None
        pending_logits = torch.cat(initial_logits_rows, dim=0)
        del row_pkvs, initial_logits_rows
    del prompt_bundles
    sequences = torch.full(
        (batch_size, int(context_width) + int(max_new_tokens)),
        int(pad_id),
        dtype=canonical_input_ids.dtype,
        device=device,
    )
    sequences[:, : int(context_width)] = canonical_input_ids.to(device=device)
    generated = sequences[:, int(context_width) :]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    record_scores = bool(output_scores)
    record_compact_scores = record_scores and bool(compact_scores)
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool)
    last = None
    last_valid = torch.ones((batch_size,), device=device, dtype=torch.bool)
    # Loop invariants, hoisted. Logits always live on `device` (seed and decode
    # logits are both moved there), so the min-length mask ids are transferred
    # once. Stop ids become one tensor, so membership is a single comparison
    # instead of a kernel launch per id.
    min_length_mask_ids = stop_token_ids.to(device=device) if stop_token_ids is not None else None
    stop_id_tensor = (
        torch.tensor(sorted(int(stop_id) for stop_id in stop_ids), device=device, dtype=torch.long)
        if stop_ids
        else None
    )
    cur = 0
    while cur < int(max_new_tokens):
        if pending_logits is None:
            if not isinstance(last, torch.Tensor):
                return None
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=int(lower_k),
                    pkv=pkv,
                    S=0,
                    U=0,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=bool(no_upper_attn),
                    align_cache_position_to_layer_past=True,
                    replay_module=str(replay_module),
                    replay_per_layers=int(replay_per_layers),
                )
                pkv = out.past_key_values or pkv
                _append_llopa_batch_cache_valid_masks(pkv, last_valid)
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=int(lower_k),
                    pkv=pkv,
                    S=0,
                    U=0,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=bool(no_upper_attn),
                    align_cache_position_to_layer_past=True,
                    replay_module=str(replay_module),
                    replay_per_layers=int(replay_per_layers),
                )
                pkv = out.past_key_values or pkv
                _append_llopa_batch_cache_valid_masks(pkv, last_valid)
                logits = out.logits[:, -1, :]
            logits = logits.to(dtype=torch.float32, device=device, copy=True)
        else:
            # The first step consumes the seed logits without running the model.
            logits = pending_logits
            pending_logits = None
        # Forbid stop tokens until min_new_tokens tokens have been produced.
        if min_length_mask_ids is not None and cur < int(min_new_tokens):
            logits.index_fill_(1, min_length_mask_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())
        if bool(do_sample):
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        if record_compact_scores:
            # Log-probability of the chosen token under the logits after any warpers.
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())
        # Rows that were still running when this token was chosen feed a valid
        # token next step; finished rows get pad_id, masked invalid in the cache.
        token_valid = unfinished.clone()
        if not bool(unfinished.all().item()):
            next_tok = torch.where(
                unfinished.view(-1, 1),
                next_tok,
                torch.full_like(next_tok, int(pad_id)),
            )
        generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype)
        cur += 1
        if stop_id_tensor is not None and cur >= int(min_new_tokens):
            stop_mask = (next_tok.reshape(-1, 1) == stop_id_tensor.view(1, -1)).any(dim=1)
            unfinished = unfinished & ~stop_mask
        if stopping_criteria is not None:
            sequences_now = sequences[:, : int(context_width) + cur]
            try:
                stop_result = stopping_criteria(sequences_now, logits)
            except TypeError:
                # Retry with scores=None when the criteria raise TypeError on the logits.
                stop_result = stopping_criteria(sequences_now, None)
            if isinstance(stop_result, torch.Tensor):
                stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1)
                if stop_result.numel() == 1:
                    # One flag for the whole batch.
                    if bool(stop_result.item()):
                        unfinished.zero_()
                elif stop_result.numel() == batch_size:
                    # One flag per row.
                    unfinished = unfinished & ~stop_result
                elif bool(stop_result.all().item()):
                    unfinished.zero_()
            elif bool(stop_result):
                unfinished.zero_()
        if not bool(unfinished.any().item()):
            break
        last = next_tok
        last_valid = token_valid
    # Drop the unused tail of the preallocated buffer.
    sequences = sequences[:, : int(context_width) + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if record_scores and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
def _direct_llopa_batch_generate_serial_cached_fallback_impl(
model,
tokenizer,
*,
segments: list[dict],
canonical_input_ids: torch.Tensor,
context_width: int,
batch_size: int,
lower_k: int,
prefill_attn: str,
system_prefill: str,
user_prefill: str,
no_upper_attn: bool,
see_past_assistant: bool,
replay_module: str,
replay_per_layers: int,
last_layer_module: Optional[str],
seed_mode: str,
max_new_tokens: int,
min_new_tokens: int,
do_sample: bool,
temperature,
top_p,
top_k,
stopping_criteria,
pad_id: int,
eos_token_id,
output_scores: bool,
compact_scores: bool,
return_dict_in_generate: bool,
device,
):
if int(batch_size) <= 1:
return None
row_sequences: list[torch.Tensor] = []
row_score_tuples: list[tuple] = []
row_compact_tensors: list[torch.Tensor] = []
score_template = None
max_gen_len = 0
record_scores = bool(output_scores)
record_compact_scores = record_scores and bool(compact_scores)
for row_idx, seg in enumerate(segments):
prompt_ids = seg.get("prompt_ids") if isinstance(seg, dict) else None
if not isinstance(prompt_ids, torch.Tensor) or prompt_ids.numel() == 0:
return None
row_prompt = prompt_ids.to(device=device, dtype=torch.long)
if row_prompt.dim() == 1:
row_prompt = row_prompt.unsqueeze(0)
elif row_prompt.dim() == 2:
row_prompt = row_prompt[:1]
else:
row_prompt = row_prompt.reshape(1, -1)
row_prompt_len = int(row_prompt.size(1))
row_out = _direct_llopa_generate_impl(
model,
tokenizer,
prompt_messages=None,
prompt_add_generation_prompt=True,
structured_prompt_segments=seg,
input_ids=row_prompt,
attention_mask=torch.ones_like(row_prompt, device=device, dtype=torch.long),
lower_k=int(lower_k),
prefill_attn=str(prefill_attn),
system_prefill=str(system_prefill),
user_prefill=str(user_prefill),
no_upper_attn=bool(no_upper_attn),
see_past_assistant=bool(see_past_assistant),
replay_module=str(replay_module),
replay_per_layers=int(replay_per_layers),
last_layer_module=last_layer_module,
seed_mode=str(seed_mode or "prefill_header"),
max_new_tokens=int(max_new_tokens),
min_new_tokens=int(min_new_tokens),
do_sample=bool(do_sample),
temperature=temperature,
top_p=top_p,
top_k=top_k,
stopping_criteria=stopping_criteria,
pad_token_id=int(pad_id),
eos_token_id=eos_token_id,
output_scores=record_scores,
compact_scores=record_compact_scores,
return_dict_in_generate=True,
use_cache=True,
)
if row_out is None:
return None
row_seq = getattr(row_out, "sequences", None)
if not isinstance(row_seq, torch.Tensor) or row_seq.dim() != 2 or row_seq.size(0) < 1:
return None
row_gen = row_seq[:1, row_prompt_len:].to(device=device, dtype=canonical_input_ids.dtype)
row_sequences.append(row_gen)
max_gen_len = max(max_gen_len, int(row_gen.size(1)))
if record_compact_scores:
compact = getattr(row_out, "generated_token_logprobs", None)
if isinstance(compact, torch.Tensor):
row_compact_tensors.append(compact[:1].to(device=device, dtype=torch.float32))
else:
row_compact_tensors.append(torch.empty((1, 0), device=device, dtype=torch.float32))
elif record_scores:
scores = getattr(row_out, "scores", None)
if isinstance(scores, tuple):
row_score_tuples.append(scores)
if score_template is None and len(scores) > 0 and isinstance(scores[0], torch.Tensor):
score_template = scores[0][:1].to(device=device)
else:
row_score_tuples.append(())
sequences = torch.full(
(int(batch_size), int(context_width) + int(max_gen_len)),
int(pad_id),
dtype=canonical_input_ids.dtype,
device=device,
)
sequences[:, : int(context_width)] = canonical_input_ids.to(device=device)
for row_idx, row_gen in enumerate(row_sequences):
width = int(row_gen.size(1))
if width > 0:
sequences[row_idx : row_idx + 1, int(context_width) : int(context_width) + width] = row_gen[:, :width]
compact_logprobs = None
score_list = None
if record_compact_scores:
compact_logprobs = torch.zeros(
(int(batch_size), int(max_gen_len)),
device=device,
dtype=torch.float32,
)
for row_idx, compact in enumerate(row_compact_tensors):
width = min(int(compact.size(1)), int(max_gen_len))
if width > 0:
compact_logprobs[row_idx : row_idx + 1, :width] = compact[:, :width]
elif record_scores:
if score_template is None:
score_list = []
else:
score_list = []
for step_idx in range(int(max_gen_len)):
step_scores = []
for scores in row_score_tuples:
if step_idx < len(scores) and isinstance(scores[step_idx], torch.Tensor):
step_scores.append(scores[step_idx][:1].to(device=device))
else:
step_scores.append(torch.zeros_like(score_template, device=device))
score_list.append(torch.cat(step_scores, dim=0))
if not bool(return_dict_in_generate):
return sequences
from transformers.generation.utils import GenerateDecoderOnlyOutput
output = GenerateDecoderOnlyOutput(
sequences=sequences,
scores=tuple(score_list) if record_scores and not record_compact_scores and score_list is not None else None,
past_key_values=None,
)
if compact_logprobs is not None:
setattr(output, "generated_token_logprobs", compact_logprobs)
return output
@torch.inference_mode()
def _direct_llopa_batch_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    user_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "prefill_header",
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    compact_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Batched structured LLoPA generation, trying three strategies in order.

    1. ``_direct_llopa_batch_generate_cached_impl``, unless ``use_cache`` is
       ``False``.
    2. ``_direct_llopa_batch_generate_serial_cached_fallback_impl``, which
       decodes one row at a time. It is tried only when caching is allowed and
       no ``stopping_criteria`` was given.
    3. An uncached loop that re-runs the full model forward over each row's
       prompt plus its generated prefix at every step (``use_cache=False``).

    The batch size comes from the structured-segments batch or the
    prompt-messages batch when those are batched. Otherwise it is taken from
    ``input_ids`` (defaulting to 1).

    Returns ``None`` in any of these cases:
    - ``lower_k`` is not a positive int.
    - ``input_ids`` is a tensor that is not 2-D with the resolved batch size.
    - The model returns no ``logits`` tensor in the uncached loop.

    Otherwise returns the sequences (``canonical_input_ids`` followed by the
    generated tokens; finished rows are filled with the pad id). When
    ``return_dict_in_generate`` is set, a ``GenerateDecoderOnlyOutput`` is
    returned instead.

    Raises:
        ValueError: if ``user_prefill`` is not ``"full"``, or if the built
            prompt metadata does not match the resolved batch size.

    NOTE(review): ``attention_mask`` is accepted for interface compatibility
    but is not read in this function.
    """
    allow_cached_paths = use_cache is not False
    batch_size = int(input_ids.size(0)) if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 else 0
    # Batched prompt metadata takes precedence over input_ids for the batch size.
    if _is_structured_segments_batch(structured_prompt_segments):
        batch_size = len(structured_prompt_segments)
    elif _is_prompt_messages_batch(prompt_messages):
        batch_size = len(prompt_messages)
    elif batch_size <= 0:
        batch_size = 1
    # last_layer_module is used only when replay_module resolves to "none".
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"
    user_prefill_norm = (user_prefill or "full").strip().lower()
    if user_prefill_norm != "full":
        raise ValueError("llopa_v2_batch_generate currently supports only user_prefill='full'.")
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != batch_size:
            return None
        device = input_ids.device
    else:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"
    # Pad id resolution order: explicit arg -> tokenizer pad -> tokenizer eos -> 0.
    pad_id = pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None:
        pad_id = 0
    pad_id = int(pad_id)
    segments = _build_batched_structured_prompt_segments(
        tokenizer,
        prompt_messages=prompt_messages,
        prompt_add_generation_prompt=prompt_add_generation_prompt,
        structured_prompt_segments=structured_prompt_segments,
        device=device,
    )
    if len(segments) != batch_size:
        raise ValueError(
            f"llopa_v2_batch_generate prompt metadata batch size mismatch: {len(segments)} != {batch_size}"
        )
    batch = _batched_segments_to_tensors(segments, device=device, pad_token_id=pad_id)
    # The returned sequences start with input_ids when given, else the built prompt ids.
    if isinstance(input_ids, torch.Tensor):
        canonical_input_ids = input_ids.to(device=device)
        context_width = int(canonical_input_ids.size(1))
    else:
        canonical_input_ids = batch["prompt_ids"]
        context_width = int(canonical_input_ids.size(1))
    raw_temp = 0.0 if temperature is None else float(temperature)
    # Sampling defaults to on whenever a non-zero temperature was given.
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - int(context_width))
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - int(context_width))
    else:
        min_new_tokens = int(min_new_tokens)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
    # Also stop on "<|eot_id|>" when the tokenizer knows that token.
    with contextlib.suppress(Exception):
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id:
            stop_ids.add(int(eot_id))
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    record_scores = bool(output_scores)
    record_compact_scores = record_scores and bool(compact_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    # Strategy 1: fully batched cached decoding.
    cached_out = None
    if bool(allow_cached_paths):
        cached_out = _direct_llopa_batch_generate_cached_impl(
            model,
            tokenizer,
            segments=segments,
            canonical_input_ids=canonical_input_ids,
            context_width=int(context_width),
            batch_size=int(batch_size),
            lower_k=int(lower_k),
            prefill_attn=str(attn),
            system_prefill=str(sys_prefill),
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            last_layer_module=last_layer_module,
            max_new_tokens=int(max_new_tokens),
            min_new_tokens=int(min_new_tokens),
            do_sample=bool(do_sample),
            logits_warpers=logits_warpers,
            stopping_criteria=stopping_criteria,
            stop_ids=stop_ids,
            stop_token_ids=stop_token_ids,
            pad_id=int(pad_id),
            output_scores=record_scores,
            compact_scores=record_compact_scores,
            return_dict_in_generate=bool(return_dict_in_generate),
            device=device,
        )
    if cached_out is not None:
        return cached_out
    # Strategy 2: cached decoding one row at a time. It is skipped when
    # stopping_criteria is set, since that criterion would be applied per row
    # rather than batch-wide.
    serial_cached_out = None
    if bool(allow_cached_paths) and stopping_criteria is None:
        serial_cached_out = _direct_llopa_batch_generate_serial_cached_fallback_impl(
            model,
            tokenizer,
            segments=segments,
            canonical_input_ids=canonical_input_ids,
            context_width=int(context_width),
            batch_size=int(batch_size),
            lower_k=int(lower_k),
            prefill_attn=str(attn),
            system_prefill=str(sys_prefill),
            user_prefill=str(user_prefill_norm),
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            last_layer_module=last_layer_module,
            seed_mode=str(seed_mode or "prefill_header"),
            max_new_tokens=int(max_new_tokens),
            min_new_tokens=int(min_new_tokens),
            do_sample=bool(do_sample),
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stopping_criteria=stopping_criteria,
            pad_id=int(pad_id),
            eos_token_id=eos_token_id,
            output_scores=record_scores,
            compact_scores=record_compact_scores,
            return_dict_in_generate=bool(return_dict_in_generate),
            device=device,
        )
    if serial_cached_out is not None:
        return serial_cached_out
    # Strategy 3: uncached loop. `generated` is a view into `sequences`, so
    # writes to it fill the output tensor directly.
    sequences = torch.full(
        (batch_size, context_width + max_new_tokens),
        pad_id,
        dtype=canonical_input_ids.dtype,
        device=device,
    )
    sequences[:, :context_width] = canonical_input_ids
    generated = sequences[:, context_width:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    prompt_rows: list[torch.Tensor] = batch["prompt_rows"]
    # NOTE(review): prompt_lens is read but not used below.
    prompt_lens = batch["prompt_lens"]
    split_starts = batch["split_starts"]
    system_lens = batch["system_lens"]
    unfinished = torch.ones((batch_size,), device=device, dtype=torch.bool)
    # Per-row count of generated tokens (including the stop token) at finish.
    finish_steps = torch.full((batch_size,), max_new_tokens, device=device, dtype=torch.long)

    def _build_step_tensors(cur_tokens: int) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
        # Rebuild each row as prompt + generated prefix. Finished rows stop
        # growing at their finish step. Reads the enclosing `unfinished` /
        # `finish_steps`, which are rebound inside the loop.
        active_lens = []
        rows = []
        for row_idx, prompt_row in enumerate(prompt_rows):
            if bool(unfinished[row_idx].item()):
                gen_len = int(cur_tokens)
            else:
                gen_len = min(int(cur_tokens), int(finish_steps[row_idx].item()))
            gen_prefix = generated[row_idx, :gen_len].to(device=device, dtype=torch.long)
            row_ids = torch.cat([prompt_row.to(device=device, dtype=torch.long), gen_prefix], dim=0)
            rows.append(row_ids)
            active_lens.append(int(row_ids.numel()))
        step_ids = _pad_1d_rows(rows, pad_value=pad_id, device=device, dtype=torch.long)
        active_lens_t = torch.tensor(active_lens, device=device, dtype=torch.long)
        # Mask is 1 for positions < active_len, i.e. real tokens sit at the left.
        step_mask = (
            torch.arange(step_ids.size(1), device=device, dtype=torch.long).unsqueeze(0)
            < active_lens_t.unsqueeze(1)
        ).to(dtype=torch.long)
        # Labels cover [split_start, active_len); everything else is -100.
        labels = torch.full_like(step_ids, -100)
        for row_idx, active_len in enumerate(active_lens):
            start = int(split_starts[row_idx].item())
            if active_len > start:
                labels[row_idx, start:active_len] = step_ids[row_idx, start:active_len]
        return step_ids, step_mask, labels

    cur = 0
    while cur < max_new_tokens:
        step_ids, step_mask, labels = _build_step_tensors(cur)
        out = model(
            input_ids=step_ids,
            attention_mask=step_mask,
            labels=labels,
            use_cache=False,
            logits_to_keep=1,
            prefill_lower_layers=int(lower_k),
            prefill_lower_attn=str(attn),
            prefill_lower_system_prefill=str(sys_prefill),
            prefill_lower_no_upper_attn=bool(no_upper_attn),
            prefill_lower_split_start=split_starts,
            prefill_lower_system_len=system_lens,
            prefill_lower_replay_user_prefix_keep_len=batch["replay_user_prefix_keep_lens"],
            prefill_lower_replay_user_start=batch["replay_user_starts"],
            prefill_lower_replay_user_len=batch["replay_user_lens"],
            assistant_header_starts=batch["assistant_header_starts"],
            assistant_turn_ends=batch["assistant_turn_ends"],
            assistant_header_start_mask=batch["assistant_header_start_mask"],
            prefill_lower_see_past_assistant=bool(see_past_assistant),
            prefill_lower_replay_module=str(replay_module),
            prefill_lower_replay_per_layers=int(replay_per_layers),
        )
        if out is None or not isinstance(getattr(out, "logits", None), torch.Tensor):
            return None
        # NOTE(review): step_ids rows are right-padded (see step_mask), yet the
        # logits are read at the last column. For rows shorter than the longest
        # row that column is a pad position, unless the model's LLoPA forward
        # re-aligns logits_to_keep per row. Confirm against the model.
        logits = out.logits[:, -1, :].to(dtype=torch.float32, device=device, copy=True)
        # Forbid stop tokens until min_new_tokens have been generated.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        # Compact mode stores only log p(chosen token) rather than full logits.
        if record_compact_scores:
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())
        # Rows that already finished emit the pad id.
        if not bool(unfinished.all().item()):
            next_tok = torch.where(
                unfinished.view(-1, 1),
                next_tok,
                torch.full_like(next_tok, pad_id),
            )
        generated[:, cur : cur + 1] = next_tok.to(dtype=generated.dtype)
        cur += 1
        if stop_ids and cur >= min_new_tokens:
            stop_mask = torch.zeros_like(unfinished)
            for stop_id in stop_ids:
                stop_mask |= next_tok.squeeze(1).eq(int(stop_id))
            newly_finished = unfinished & stop_mask
            if bool(newly_finished.any().item()):
                finish_steps = torch.where(
                    newly_finished,
                    torch.full_like(finish_steps, int(cur)),
                    finish_steps,
                )
            unfinished = unfinished & ~stop_mask
        if should_apply_stopping_criteria:
            sequences_now = sequences[:, : context_width + cur]
            # Retry with scores=None for criteria that reject the logits argument.
            try:
                stop_result = stopping_criteria(sequences_now, logits)
            except TypeError:
                stop_result = stopping_criteria(sequences_now, None)
            if isinstance(stop_result, torch.Tensor):
                stop_result = stop_result.to(device=device, dtype=torch.bool).reshape(-1)
                if stop_result.numel() == 1:
                    if bool(stop_result.item()):
                        unfinished.zero_()
                elif stop_result.numel() == batch_size:
                    # Per-row stop decision.
                    newly_finished = unfinished & stop_result
                    if bool(newly_finished.any().item()):
                        finish_steps = torch.where(
                            newly_finished,
                            torch.full_like(finish_steps, int(cur)),
                            finish_steps,
                        )
                    unfinished = unfinished & ~stop_result
                elif bool(stop_result.all().item()):
                    unfinished.zero_()
            elif bool(stop_result):
                unfinished.zero_()
        if not bool(unfinished.any().item()):
            break
    sequences = sequences[:, : context_width + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if record_scores and not record_compact_scores else None,
        past_key_values=None,
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
def _direct_optimized_llopa_generate_impl(
    model,
    tokenizer,
    *,
    prompt_messages,
    prompt_add_generation_prompt: bool,
    structured_prompt_segments=None,
    input_ids: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    lower_k: int,
    prefill_attn: str,
    system_prefill: str,
    user_prefill: str,
    no_upper_attn: bool,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    optimized_variant: Optional[str] = None,
    optimized_seed_mode: Optional[str] = None,
    optimized_upper_prepare_mode: Optional[str] = None,
    optimized_upper_bucket_multiple: Optional[int] = None,
    optimized_seq_bucket_multiple: Optional[int] = None,
    max_length=None,
    max_new_tokens=None,
    min_length=None,
    min_new_tokens=None,
    do_sample=None,
    temperature=None,
    top_p=None,
    top_k=None,
    stopping_criteria=None,
    pad_token_id=None,
    eos_token_id=None,
    output_scores: bool = False,
    compact_scores: bool = False,
    return_dict_in_generate: bool = False,
    use_cache: Optional[bool] = None,
):
    """Single-row LLoPA generation via the optimized prefill and cached decode.

    Prefill:
    - ``_optimized_prefill_lower_cache_and_logits`` is tried first, with the
      resolved ``optimized_*`` settings temporarily set as attributes on
      ``model``.
    - If it returns ``None``, ``_llopa_prefill_cache`` builds the cache instead.

    Decoding then advances one token per step through one of two routes:
    - ``llopa_core.tri_forward_assistant`` plus the output head, when the
      ``CAPSULE_LLOPA_DIRECT_DECODE_STEP`` flag is enabled (default ``"1"``)
      and both are available.
    - The model's LLoPA decode step, otherwise.

    Returns ``None`` when this path does not apply:
    - ``use_cache is False``;
    - ``lower_k`` is not a positive int;
    - no LLoPA core or decode step is found;
    - ``input_ids`` is a tensor that is not single-row 2-D;
    - the assistant prefill is empty.

    Otherwise returns the prompt + generated sequences. When
    ``return_dict_in_generate`` is set, a ``GenerateDecoderOnlyOutput`` is
    returned instead. With ``output_scores`` and ``compact_scores`` that
    output carries ``generated_token_logprobs`` instead of ``scores``.

    Raises:
        ValueError: if ``user_prefill`` is anything other than ``"full"``.

    NOTE(review): unlike ``_direct_llopa_batch_generate_impl`` this function
    has no ``@torch.inference_mode()``; confirm callers disable autograd.
    """
    if last_layer_module is not None and _normalize_replay_module_value(replay_module) == "none":
        replay_module = last_layer_module
    replay_module = _normalize_replay_module_value(replay_module)
    replay_per_layers = _normalize_replay_per_layers_value(replay_per_layers)
    if use_cache is False:
        return None
    try:
        lower_k = int(lower_k)
    except Exception:
        lower_k = 0
    if lower_k <= 0:
        return None
    attn = (prefill_attn or "causal").strip().lower()
    if attn == "prefix_full":
        attn = "full"
    if attn not in {"causal", "full"}:
        attn = "causal"
    sys_prefill = (system_prefill or "full").strip().lower()
    if sys_prefill not in {"full", "no_system", "no_bos_system"}:
        sys_prefill = "full"
    user_prefill_norm = (user_prefill or "full").strip().lower()
    if user_prefill_norm != "full":
        raise ValueError("Optimized LLoPA currently supports only user_prefill='full'.")
    optimized_settings = _resolve_optimized_llopa_settings(
        variant=optimized_variant,
        seed_mode=optimized_seed_mode,
        upper_prepare_mode=optimized_upper_prepare_mode,
        upper_bucket_multiple=optimized_upper_bucket_multiple,
        seq_bucket_multiple=optimized_seq_bucket_multiple,
    )
    llopa_core = _get_llopa_core(model)
    llopa_step = _get_llopa_decode_step(model)
    if llopa_core is None or llopa_step is None:
        return None
    llopa_forward_assistant = getattr(llopa_core, "tri_forward_assistant", None)
    decode_output_head = _get_output_head(model)
    use_direct_decode_step = (
        _env_flag_enabled("CAPSULE_LLOPA_DIRECT_DECODE_STEP", "1")
        and callable(llopa_forward_assistant)
        and decode_output_head is not None
    )
    device = None
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            return None
        device = input_ids.device
    if device is None:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = "cpu"
    segments = structured_prompt_segments if isinstance(structured_prompt_segments, dict) else None
    if segments is None:
        segments = _build_structured_prompt_segments(
            tokenizer,
            prompt_messages,
            prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
            device=device,
        )
    prompt_ids = segments["prompt_ids"]
    system_ids = segments["system_ids"]
    user_ids = segments["user_ids"]
    assistant_prefill_ids = segments["assistant_prefill_ids"]
    replay_user_prefix_keep_len = int(segments.get("replay_user_prefix_keep_len", 0) or 0)
    replay_user_start = int(segments.get("replay_user_start", 0) or 0)
    replay_user_len = int(segments.get("replay_user_len", 0) or 0)
    if assistant_prefill_ids.numel() == 0:
        return None
    raw_temp = 0.0 if temperature is None else float(temperature)
    # Sampling defaults to on whenever a non-zero temperature was given.
    if do_sample is None:
        do_sample = bool(raw_temp != 0.0)
    do_sample = bool(do_sample)
    sample_temp = 1.0 if (not do_sample or raw_temp == 0.0) else float(raw_temp)
    top_p = 1.0 if top_p is None else float(top_p)
    top_k = None if top_k is None else int(top_k)
    initial_logits = None
    prompt_bundle = _build_unified_prefill_lower_prompt_bundle(
        tokenizer,
        prompt_messages=prompt_messages,
        prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
        structured_prompt_segments=segments,
        device=device,
    )
    with _temporary_model_attrs(
        model,
        _optimized_llopa_variant=optimized_settings["variant"],
        _optimized_llopa_seed_mode=optimized_settings["seed_mode"],
        _optimized_llopa_upper_prepare_mode=optimized_settings["upper_prepare_mode"],
        _optimized_llopa_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]),
        _optimized_llopa_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]),
    ):
        reference_seed = _optimized_prefill_lower_cache_and_logits(
            model,
            prompt_bundle=prompt_bundle,
            lower_k=lower_k,
            prefill_attn=attn,
            system_prefill=sys_prefill,
            no_upper_attn=bool(no_upper_attn),
            see_past_assistant=bool(see_past_assistant),
            replay_module=str(replay_module),
            replay_per_layers=int(replay_per_layers),
            seed_mode=optimized_settings["seed_mode"],
        )
    # Prompt echoed in the output: the unpadded tail of input_ids when given,
    # else the bundle's effective prompt, else the structured prompt ids.
    canonical_input_ids = None
    if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2 and input_ids.size(0) == 1:
        valid_len = int(input_ids.size(1))
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2 and attention_mask.size(0) == 1:
            valid_len = int(attention_mask[0].sum().item())
        if valid_len > 0:
            canonical_input_ids = input_ids[:, -valid_len:]
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_bundle.get("effective_prompt_ids")
    if not isinstance(canonical_input_ids, torch.Tensor):
        canonical_input_ids = prompt_ids
    total_prompt_len = int(canonical_input_ids.size(1))
    if max_new_tokens is None:
        if max_length is None:
            max_new_tokens = 256
        else:
            max_new_tokens = max(0, int(max_length) - total_prompt_len)
    else:
        max_new_tokens = int(max_new_tokens)
    if min_new_tokens is None:
        if min_length is None:
            min_new_tokens = 0
        else:
            min_new_tokens = max(0, int(min_length) - total_prompt_len)
    else:
        min_new_tokens = int(min_new_tokens)
    if reference_seed is not None:
        pkv, S, U, initial_logits = reference_seed
    else:
        # Fallback prefill when the optimized seed is unavailable.
        output_head = _get_output_head(model)
        if bool(no_upper_attn):
            pkv, S, U = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
        else:
            pkv, S, U, last_hidden = _llopa_prefill_cache(
                llopa_core,
                system_ids,
                user_ids,
                assistant_prefill_ids,
                lower_k=lower_k,
                prefill_mode="lower",
                prefill_attn=attn,
                system_prefill=sys_prefill,
                return_last_assistant_hidden=bool(output_head is not None),
                replay_user_prefix_keep_len=replay_user_prefix_keep_len,
                replay_user_start=replay_user_start,
                replay_user_len=replay_user_len,
            )
            if output_head is not None and isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0:
                initial_logits = output_head(last_hidden)[:, -1, :].to(torch.float32)
    stop_ids = set(_normalize_eos_token_ids(eos_token_id))
    if not stop_ids:
        tok_eos = getattr(tokenizer, "eos_token_id", None)
        if tok_eos is not None:
            stop_ids.add(int(tok_eos))
    # Also stop on "<|eot_id|>" when the tokenizer knows that token.
    with contextlib.suppress(Exception):
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id:
            stop_ids.add(int(eot_id))
    last = assistant_prefill_ids[:, -1:]
    stop_token_ids = _prepare_stop_token_tensor(stop_ids, last.device)
    logits_warpers = _build_sampling_warpers(do_sample, sample_temp, top_p, top_k)
    max_new_tokens = int(max_new_tokens)
    min_new_tokens = int(min_new_tokens)
    lower_k = int(lower_k)
    no_upper_attn_bool = bool(no_upper_attn)
    replay_module_str = str(replay_module)
    replay_per_layers_int = int(replay_per_layers)
    record_scores = bool(output_scores)
    record_compact_scores = bool(output_scores) and bool(compact_scores)
    should_apply_stopping_criteria = stopping_criteria is not None
    bucket_multiple = int(optimized_settings["seq_bucket_multiple"] or 256)
    sequences_full, _ = _acquire_bucketed_sequence_workspace(
        model,
        reference_ids=canonical_input_ids,
        batch_size=int(canonical_input_ids.size(0)),
        total_len=int(total_prompt_len + max_new_tokens),
        bucket_multiple=bucket_multiple,
    )
    sequences_full[:, :total_prompt_len] = canonical_input_ids
    sequences = sequences_full[:, : total_prompt_len + max_new_tokens]
    # `generated` is a view into `sequences`; writes fill the output directly.
    generated = sequences[:, total_prompt_len:]
    score_list: list[torch.Tensor] = []
    score_list_append = score_list.append
    compact_logprob_list: list[torch.Tensor] = []
    compact_logprob_append = compact_logprob_list.append
    cur = 0
    # Prefill logits (if any) are consumed for the first token instead of a decode step.
    pending_logits = initial_logits
    while cur < max_new_tokens:
        out = None
        if pending_logits is None:
            if use_direct_decode_step:
                out = llopa_forward_assistant(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    write_cache=True,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = decode_output_head(out.last_hidden_state[:, -1, :])
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            else:
                out = llopa_step(
                    assistant_ids=last,
                    lower_k=lower_k,
                    pkv=pkv,
                    S=S,
                    U=U,
                    logits_to_keep=1,
                    labels=None,
                    prefill_mode="lower",
                    no_upper_attn=no_upper_attn_bool,
                    align_cache_position_to_layer_past=False,
                    replay_module=replay_module_str,
                    replay_per_layers=replay_per_layers_int,
                )
                pkv = out.past_key_values or pkv
                logits = out.logits[:, -1, :]
            logits = logits.to(dtype=torch.float32, device=last.device, copy=True)
        else:
            # Prefill logits may live on another device than the stop-token
            # tensor, and index_fill_ below mutates logits in place. Normalize
            # them exactly like decode-step logits: float32, on last.device, as
            # a private copy, so the prefill helper's tensor is never modified.
            logits = pending_logits.to(dtype=torch.float32, device=last.device, copy=True)
            pending_logits = None
        # Forbid stop tokens until min_new_tokens have been generated.
        if stop_token_ids is not None and cur < min_new_tokens:
            logits.index_fill_(1, stop_token_ids, -float("inf"))
        if logits_warpers is not None:
            logits = logits_warpers(generated[:, :cur], logits)
        if record_scores and not record_compact_scores:
            score_list_append(logits.detach().clone())
        if do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        # Compact mode stores only log p(chosen token) rather than full logits.
        if record_compact_scores:
            next_logit = torch.gather(logits, 1, next_tok)
            next_logprob = next_logit - torch.logsumexp(logits, dim=-1, keepdim=True)
            compact_logprob_append(next_logprob.squeeze(-1).detach())
        generated[:, cur : cur + 1] = next_tok
        cur += 1
        should_stop = False
        # Single-row path: input_ids was checked to have exactly one row.
        tok_id = int(next_tok.item())
        if tok_id in stop_ids and cur >= min_new_tokens:
            should_stop = True
        if (not should_stop) and should_apply_stopping_criteria:
            sequences_now = sequences[:, : total_prompt_len + cur]
            # Retry with scores=None for criteria that reject the logits argument.
            try:
                should_stop = bool(stopping_criteria(sequences_now, logits))
            except TypeError:
                should_stop = bool(stopping_criteria(sequences_now, None))
        if out is not None:
            del out
        if should_stop:
            break
        last = next_tok
    sequences = sequences[:, : total_prompt_len + cur]
    if not bool(return_dict_in_generate):
        return sequences
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    output = GenerateDecoderOnlyOutput(
        sequences=sequences,
        scores=tuple(score_list) if bool(output_scores) and not record_compact_scores else None,
        past_key_values=pkv,
    )
    if record_compact_scores:
        if compact_logprob_list:
            compact_logprobs = torch.stack(compact_logprob_list, dim=1)
        else:
            compact_logprobs = torch.empty(
                (sequences.size(0), 0),
                dtype=torch.float32,
                device=sequences.device,
            )
        setattr(output, "generated_token_logprobs", compact_logprobs)
    return output
def _normalize_replay_module_value(value: Optional[str]) -> str:
raw = str(value or "none").strip().lower()
aliases = {
"": "none",
"off": "none",
"disabled": "none",
"disable": "none",
"self-attention": "self",
"self_attention": "self",
"selfattn": "self",
"self_attn": "self",
"cross-attention": "cross",
"cross_attention": "cross",
"crossattn": "cross",
"cross_attn": "cross",
}
raw = aliases.get(raw, raw)
if raw in {"none", "self", "cross"}:
return raw
return "none"
def _normalize_replay_per_layers_value(value) -> int:
try:
normalized = int(value)
except Exception:
return -1
if normalized == -1 or normalized >= 1:
return normalized
return -1
def _normalize_structured_llopa_runtime(
    lower_k: int,
    *,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
):
    """Canonicalize structured-LLoPA runtime options into a flat tuple.

    Returns ``None`` when ``lower_k`` is not a positive integer, or when
    ``prefill_attn`` is neither ``"causal"`` nor ``"full"``. ``"prefix_full"``
    is accepted as an alias for ``"full"``.

    Otherwise returns ``(lower_k, prefill_attn, system_prefill, user_prefill,
    no_upper_attn, replay_module, replay_per_layers)``:
    - ``system_prefill`` values outside {"full", "no_system",
      "no_bos_system"} fall back to ``"full"``.
    - ``user_prefill`` is only stripped and lower-cased.
    - ``last_layer_module`` replaces ``replay_module`` only when the latter
      normalizes to ``"none"``.

    ``see_past_assistant`` is accepted but does not appear in the result.
    """
    try:
        depth = int(lower_k)
    except Exception:
        depth = 0

    attn_mode = (prefill_attn or "causal").strip().lower()
    attn_mode = "full" if attn_mode == "prefix_full" else attn_mode

    system_mode = (system_prefill or "full").strip().lower()
    if system_mode not in ("full", "no_system", "no_bos_system"):
        system_mode = "full"

    user_mode = (user_prefill or "full").strip().lower()

    module_choice = replay_module
    if last_layer_module is not None and _normalize_replay_module_value(module_choice) == "none":
        module_choice = last_layer_module
    module_norm = _normalize_replay_module_value(module_choice)
    per_layers_norm = _normalize_replay_per_layers_value(replay_per_layers)

    is_valid = depth > 0 and attn_mode in ("causal", "full")
    if not is_valid:
        return None
    return (
        depth,
        attn_mode,
        system_mode,
        user_mode,
        bool(no_upper_attn),
        module_norm,
        per_layers_norm,
    )
# Knob suffixes shared by the optimized / llopa_v2 / unified / direct structured-LLoPA families,
# e.g. ``unified_llopa_layers`` or ``llopa_v2_replay_module``.
_STRUCTURED_LLOPA_KNOB_SUFFIXES = (
    "layers",
    "attn",
    "system_prefill",
    "user_prefill",
    "no_upper_attn",
    "see_past_assistant",
    "replay_module",
    "replay_per_layers",
    "last_layer_module",
)

# LLoPA-only keyword arguments recognised by the structured ``generate`` wrapper. The stock
# transformers ``generate`` treats unknown keywords as model kwargs and rejects ones the model
# does not use, so these are removed before every fallback to it.
_STRUCTURED_LLOPA_ONLY_KWARGS = frozenset(
    {
        f"{family}_{suffix}"
        for family in ("optimized_llopa", "llopa_v2", "unified_llopa", "direct_llopa")
        for suffix in _STRUCTURED_LLOPA_KNOB_SUFFIXES
    }
    | {
        f"runtime_prefill_{variant}_{suffix}"
        for variant in ("freeze", "solo", "solo_v2")
        for suffix in ("layers", "attn", "system_prefill")
    }
    | {
        # Per-call mode switches and prompt metadata.
        "optimized_llopa_generate",
        "llopa_v2_batch_generate",
        "llopa_v2_generate",
        "llopa_v3_generate",
        "runtime_solo_generate",
        "runtime_solo_v2_generate",
        "unified_llopa_generate",
        "direct_llopa_generate",
        "direct_llopa_legacy_search",
        "prompt_messages",
        "prompt_add_generation_prompt",
        "structured_prompt_segments",
        "capsule_compact_scores",
        # Mode-specific extras.
        "llopa_v2_seed_mode",
        "optimized_llopa_variant",
        "optimized_llopa_seed_mode",
        "optimized_llopa_upper_prepare_mode",
        "optimized_llopa_upper_bucket_multiple",
        "optimized_llopa_seq_bucket_multiple",
        "runtime_prefill_solo_v2_with_bos",
    }
)

# Label used in warnings/errors for each structured mode; "direct" uses the fallback label.
_STRUCTURED_LLOPA_MODE_LABELS = {
    "optimized": "optimized_llopa_generate",
    "llopa_v2_batch": "llopa_v2_batch_generate",
    "llopa_v2": "llopa_v2_generate",
    "llopa_v3": "llopa_v3_generate",
    "unified": "unified_llopa_generate",
    "freeze": "runtime_freeze_generate",
    "solo_v2": "runtime_solo_v2_generate",
    "solo": "runtime_solo_generate",
}

# Names used in the RuntimeError raised when a runtime-prefill implementation returns None.
_RUNTIME_PREFILL_FAILURE_NAMES = {
    "freeze": "runtime freeze",
    "solo": "runtime solo-attn",
    "solo_v2": "runtime solo-attn-v2",
}


def _without_structured_llopa_kwargs(kwargs: dict) -> dict:
    """Return a copy of ``kwargs`` without any structured-LLoPA-only keyword argument."""
    return {key: value for key, value in kwargs.items() if key not in _STRUCTURED_LLOPA_ONLY_KWARGS}


def _structured_llopa_decode_kwargs(kwargs: dict) -> dict:
    """Collect the decoding options forwarded from ``generate()`` to the structured implementations.

    ``output_scores`` and ``return_dict_in_generate`` are coerced to ``bool`` (default ``False``).
    Every other option is forwarded as given, or ``None`` when absent.
    """
    return {
        "input_ids": kwargs.get("input_ids"),
        "attention_mask": kwargs.get("attention_mask"),
        "max_length": kwargs.get("max_length"),
        "max_new_tokens": kwargs.get("max_new_tokens"),
        "min_length": kwargs.get("min_length"),
        "min_new_tokens": kwargs.get("min_new_tokens"),
        "do_sample": kwargs.get("do_sample"),
        "temperature": kwargs.get("temperature"),
        "top_p": kwargs.get("top_p"),
        "top_k": kwargs.get("top_k"),
        "stopping_criteria": kwargs.get("stopping_criteria"),
        "pad_token_id": kwargs.get("pad_token_id"),
        "eos_token_id": kwargs.get("eos_token_id"),
        "output_scores": bool(kwargs.get("output_scores", False)),
        "return_dict_in_generate": bool(kwargs.get("return_dict_in_generate", False)),
        "use_cache": kwargs.get("use_cache"),
    }


def _resolve_explicit_structured_llopa_mode(flags, direct_enabled) -> Tuple[Optional[str], bool]:
    """Resolve the per-call mode switches passed to ``generate()``.

    Args:
        flags: ``(mode_name, value)`` pairs in precedence order. ``None`` means "not passed".
        direct_enabled: value of ``direct_llopa_generate`` (``None`` when not passed).

    Returns:
        ``(mode, use_stock_generate)``.

        - The last truthy flag wins. A falsy flag never discards a mode that an earlier
          truthy flag already selected.
        - ``direct`` is considered only when no other flag is truthy.
        - If switches were passed but none is truthy, ``use_stock_generate`` is True.
        - If nothing was passed, returns ``(None, False)`` so the model-level defaults apply.
    """
    mode = None
    for name, enabled in flags:
        if enabled is not None and bool(enabled):
            mode = name
    if mode is not None:
        return mode, False
    if direct_enabled is not None:
        return ("direct", False) if bool(direct_enabled) else (None, True)
    if any(enabled is not None for _, enabled in flags):
        return None, True
    return None, False


def _attach_structured_llopa_generate(model, tokenizer) -> None:
    """Replace ``model.generate`` with a dispatcher over the structured LLoPA generation modes.

    What the wrapper does:

    1. Pops its own keyword arguments: mode switches (e.g. ``unified_llopa_generate``), per-mode
       knobs (e.g. ``unified_llopa_layers``) and prompt metadata (``prompt_messages``,
       ``prompt_add_generation_prompt``, ``structured_prompt_segments``).
    2. Picks a mode from the per-call switches. Failing that, it uses the ``_*_generate_default``
       attributes that the ``_attach_*_generate`` helpers set.
    3. Forwards the call to the matching ``_direct_*_generate_impl``.

    It falls back to the original ``generate``, with the LLoPA-only keyword arguments removed, when:

    - no mode applies;
    - ``inputs_embeds`` is given;
    - ``num_beams`` or ``num_return_sequences`` is not 1;
    - the structured prompt metadata is missing.

    Attaching is idempotent: a model already marked ``_structured_llopa_generate_attached``, or
    one without a callable ``generate``, is left untouched. Installing the wrapper is best-effort;
    failures while patching are ignored.
    """
    import types

    if getattr(model, "_structured_llopa_generate_attached", False):
        return
    orig_generate = getattr(model, "generate", None)
    if not callable(orig_generate):
        return

    def _structured_llopa_generate(self, *args, **kwargs):
        def _fallback(*positional):
            # Stock generate() rejects unknown model kwargs, so strip every LLoPA-only knob first.
            return orig_generate(*positional, **_without_structured_llopa_kwargs(kwargs))

        def _pop(key, default_factory):
            # A per-call keyword wins; otherwise fall back to the model-level attribute default.
            value = kwargs.pop(key, None)
            return default_factory() if value is None else value

        # A leading positional argument is treated as input_ids. Any further positional arguments
        # route the call straight to the original generate().
        if args:
            if "input_ids" not in kwargs:
                kwargs["input_ids"] = args[0]
                args = args[1:]
            if args:
                return _fallback(*args)

        optimized_enabled = kwargs.pop("optimized_llopa_generate", None)
        llopa_v2_batch_enabled = kwargs.pop("llopa_v2_batch_generate", None)
        llopa_v2_enabled = kwargs.pop("llopa_v2_generate", None)
        llopa_v3_enabled = kwargs.pop("llopa_v3_generate", None)
        runtime_solo_enabled = kwargs.pop("runtime_solo_generate", None)
        runtime_solo_v2_enabled = kwargs.pop("runtime_solo_v2_generate", None)
        unified_enabled = kwargs.pop("unified_llopa_generate", None)
        direct_enabled = kwargs.pop("direct_llopa_generate", None)
        legacy_search = bool(kwargs.pop("direct_llopa_legacy_search", False))
        prompt_messages = kwargs.pop("prompt_messages", None)
        prompt_add_generation_prompt = kwargs.pop("prompt_add_generation_prompt", None)
        structured_prompt_segments = kwargs.pop("structured_prompt_segments", None)
        compact_scores = bool(kwargs.pop("capsule_compact_scores", False))

        mode, use_stock = _resolve_explicit_structured_llopa_mode(
            (
                ("optimized", optimized_enabled),
                ("llopa_v2_batch", llopa_v2_batch_enabled),
                ("llopa_v2", llopa_v2_enabled),
                ("llopa_v3", llopa_v3_enabled),
                ("solo_v2", runtime_solo_v2_enabled),
                ("solo", runtime_solo_enabled),
                ("unified", unified_enabled),
            ),
            direct_enabled,
        )
        if use_stock:
            return _fallback()
        if mode is None:
            # No per-call switch: use the default recorded by the _attach_*_generate helpers.
            if bool(getattr(self, "_optimized_llopa_generate_default", False)):
                mode = "optimized"
            elif bool(getattr(self, "_llopa_v2_batch_generate_default", False)):
                mode = "llopa_v2_batch"
            elif bool(getattr(self, "_llopa_v2_generate_default", False)):
                mode = "llopa_v3" if str(getattr(self, "_capsule_inference_path", "") or "") == "llopa_v3" else "llopa_v2"
            elif bool(getattr(self, "_unified_llopa_generate_default", False)):
                mode = "unified"
            elif bool(getattr(self, "_runtime_structured_freeze_generate_default", False)):
                mode = "freeze"
            elif bool(getattr(self, "_runtime_structured_solo_v2_generate_default", False)):
                mode = "solo_v2"
            elif bool(getattr(self, "_runtime_structured_solo_generate_default", False)):
                mode = "solo"
            elif bool(getattr(self, "_direct_llopa_generate_default", False)):
                mode = "direct"
            else:
                return _fallback()
        if kwargs.get("inputs_embeds") is not None:
            return _fallback()

        mode_label = _STRUCTURED_LLOPA_MODE_LABELS.get(mode, "direct_llopa_generate")
        if int(kwargs.get("num_beams", 1) or 1) != 1:
            _warn_once(
                self,
                f"_warned_{mode_label}_num_beams",
                f"[load_llopa_model][warn] {mode_label} currently supports only num_beams=1; falling back to model.generate().",
            )
            return _fallback()
        if int(kwargs.get("num_return_sequences", 1) or 1) != 1:
            _warn_once(
                self,
                f"_warned_{mode_label}_num_return_sequences",
                f"[load_llopa_model][warn] {mode_label} currently supports only num_return_sequences=1; falling back to model.generate().",
            )
            return _fallback()

        if mode in {"unified", "optimized", "llopa_v2", "llopa_v3", "llopa_v2_batch"}:
            v2_family = mode in {"llopa_v2", "llopa_v3", "llopa_v2_batch"}
            if mode == "optimized":
                family = "optimized_llopa"
            elif v2_family:
                family = "llopa_v2"
            else:
                family = "unified_llopa"
            lower_k_local = _pop(f"{family}_layers", lambda: int(getattr(self, f"_{family}_layers", 0) or 0))
            attn_local = _pop(f"{family}_attn", lambda: str(getattr(self, f"_{family}_attn", "causal") or "causal"))
            system_prefill_local = _pop(
                f"{family}_system_prefill",
                lambda: str(getattr(self, f"_{family}_system_prefill", "full") or "full"),
            )
            user_prefill_local = _pop(
                f"{family}_user_prefill",
                lambda: str(getattr(self, f"_{family}_user_prefill", "full") or "full"),
            )
            no_upper_attn_local = _pop(
                f"{family}_no_upper_attn",
                lambda: bool(getattr(self, f"_{family}_no_upper_attn", False)),
            )
            see_past_assistant_local = _pop(
                f"{family}_see_past_assistant",
                lambda: bool(getattr(self, f"_{family}_see_past_assistant", False)),
            )
            # ``*_last_layer_module`` is the legacy spelling of ``*_replay_module``.
            replay_module_local = kwargs.pop(f"{family}_replay_module", None)
            if replay_module_local is None:
                replay_module_local = kwargs.pop(f"{family}_last_layer_module", None)
            if replay_module_local is None:
                replay_module_local = getattr(self, f"_{family}_replay_module", None)
            if replay_module_local is None:
                replay_module_local = str(getattr(self, f"_{family}_last_layer_module", "none") or "none")
            replay_per_layers_local = _pop(
                f"{family}_replay_per_layers",
                lambda: getattr(self, f"_{family}_replay_per_layers", -1),
            )
            if v2_family:
                # The v2 family always seeds from the prefill header; a per-call override is ignored.
                kwargs.pop("llopa_v2_seed_mode", None)
            optimized_kwargs = {}
            if mode == "optimized":
                optimized_kwargs = dict(
                    optimized_variant=_pop(
                        "optimized_llopa_variant",
                        lambda: getattr(self, "_optimized_llopa_variant", "upper_ws_auto"),
                    ),
                    optimized_seed_mode=_pop(
                        "optimized_llopa_seed_mode",
                        lambda: getattr(self, "_optimized_llopa_seed_mode", "auto"),
                    ),
                    optimized_upper_prepare_mode=_pop(
                        "optimized_llopa_upper_prepare_mode",
                        lambda: getattr(self, "_optimized_llopa_upper_prepare_mode", "bucketed_workspace"),
                    ),
                    optimized_upper_bucket_multiple=_pop(
                        "optimized_llopa_upper_bucket_multiple",
                        lambda: getattr(self, "_optimized_llopa_upper_bucket_multiple", 256),
                    ),
                    optimized_seq_bucket_multiple=_pop(
                        "optimized_llopa_seq_bucket_multiple",
                        lambda: getattr(self, "_optimized_llopa_seq_bucket_multiple", 256),
                    ),
                )
            if prompt_messages is None and structured_prompt_segments is None:
                _warn_once(
                    self,
                    f"_warned_{mode_label}_missing_prompt_metadata",
                    f"[load_llopa_model][warn] {mode_label} requested without structured prompt metadata; falling back to model.generate().",
                )
                return _fallback()
            if prompt_add_generation_prompt is None and structured_prompt_segments is None:
                raise ValueError(f"{mode_label} requires prompt_add_generation_prompt when prompt_messages are provided.")
            if mode == "optimized":
                generate_impl = _direct_optimized_llopa_generate_impl
            elif mode == "llopa_v2_batch":
                generate_impl = _direct_llopa_batch_generate_impl
            else:
                generate_impl = _direct_llopa_generate_impl
            generate_kwargs = dict(
                prompt_messages=prompt_messages,
                prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
                structured_prompt_segments=structured_prompt_segments,
                lower_k=int(lower_k_local),
                prefill_attn=str(attn_local),
                system_prefill=str(system_prefill_local),
                user_prefill=str(user_prefill_local),
                no_upper_attn=bool(no_upper_attn_local),
                see_past_assistant=bool(see_past_assistant_local),
                replay_module=str(replay_module_local),
                replay_per_layers=int(replay_per_layers_local or -1),
                compact_scores=bool(compact_scores),
                **_structured_llopa_decode_kwargs(kwargs),
                **optimized_kwargs,
            )
            if v2_family:
                generate_kwargs["seed_mode"] = "prefill_header"
            with torch.inference_mode():
                # llopa_v2 / llopa_v3 differ only in this model flag; set it for the call and
                # restore (or remove) the previous value afterwards.
                previous_mixin_decode = getattr(self, "_llopa_v2_generation_mixin_decode", None)
                try:
                    if mode == "llopa_v2":
                        setattr(self, "_llopa_v2_generation_mixin_decode", False)
                    elif mode == "llopa_v3":
                        setattr(self, "_llopa_v2_generation_mixin_decode", True)
                    unified_out = generate_impl(self, tokenizer, **generate_kwargs)
                finally:
                    if previous_mixin_decode is None:
                        with contextlib.suppress(Exception):
                            delattr(self, "_llopa_v2_generation_mixin_decode")
                    else:
                        setattr(self, "_llopa_v2_generation_mixin_decode", previous_mixin_decode)
            if unified_out is not None:
                return unified_out
            raise RuntimeError(f"Structured {mode} LLoPA failed unexpectedly for the current prompt.")

        if mode in {"freeze", "solo", "solo_v2"}:
            # Knobs: runtime_prefill_<mode>_* keywords, defaulting to _runtime_prefill_<mode>_* attributes.
            knob = f"runtime_prefill_{mode}"
            lower_k_local = _pop(f"{knob}_layers", lambda: int(getattr(self, f"_{knob}_layers", 0) or 0))
            attn_local = _pop(f"{knob}_attn", lambda: str(getattr(self, f"_{knob}_attn", "causal") or "causal"))
            system_prefill_local = _pop(
                f"{knob}_system_prefill",
                lambda: str(getattr(self, f"_{knob}_system_prefill", "no_bos_system") or "no_bos_system"),
            )
            with_bos_local = None
            if mode == "solo_v2":
                with_bos_local = _pop(f"{knob}_with_bos", lambda: bool(getattr(self, f"_{knob}_with_bos", False)))
            if prompt_messages is None:
                _warn_once(
                    self,
                    f"_warned_runtime_{mode}_missing_prompt_metadata",
                    f"[load_llopa_model][warn] {mode_label} requested without structured prompt metadata; falling back to model.generate().",
                )
                return _fallback()
            if prompt_add_generation_prompt is None:
                raise ValueError(f"{mode_label} requires prompt_add_generation_prompt when prompt_messages are provided.")
            variant_kwargs = {"solo_v2": True, "with_bos": bool(with_bos_local)} if mode == "solo_v2" else {}
            runtime_impl = _direct_freeze_generate_impl if mode == "freeze" else _direct_solo_generate_impl
            runtime_out = runtime_impl(
                self,
                tokenizer,
                prompt_messages=prompt_messages,
                prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
                lower_k=int(lower_k_local),
                prefill_attn=str(attn_local),
                system_prefill=str(system_prefill_local),
                **_structured_llopa_decode_kwargs(kwargs),
                **variant_kwargs,
            )
            if runtime_out is not None:
                return runtime_out
            raise RuntimeError(f"Structured {_RUNTIME_PREFILL_FAILURE_NAMES[mode]} failed unexpectedly for the current prompt.")

        # Remaining mode: deprecated direct LLoPA.
        _warn_once(
            self,
            "_warned_deprecated_direct_llopa",
            "[load_llopa_model][warn] direct_llopa_* is deprecated; use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged.",
        )
        lower_k_local = _pop("direct_llopa_layers", lambda: int(getattr(self, "_direct_llopa_layers", 0) or 0))
        attn_local = _pop("direct_llopa_attn", lambda: str(getattr(self, "_direct_llopa_attn", "causal") or "causal"))
        system_prefill_local = _pop(
            "direct_llopa_system_prefill",
            lambda: str(getattr(self, "_direct_llopa_system_prefill", "full") or "full"),
        )
        user_prefill_local = _pop(
            "direct_llopa_user_prefill",
            lambda: str(getattr(self, "_direct_llopa_user_prefill", "full") or "full"),
        )
        no_upper_attn_local = _pop(
            "direct_llopa_no_upper_attn",
            lambda: bool(getattr(self, "_direct_llopa_no_upper_attn", False)),
        )
        if not legacy_search and prompt_messages is None and structured_prompt_segments is None:
            raise ValueError(
                "direct_llopa_generate now requires prompt_messages and prompt_add_generation_prompt. "
                "Use direct_llopa_legacy_search=True only for legacy prompt scanning."
            )
        if not legacy_search and prompt_add_generation_prompt is None and structured_prompt_segments is None:
            raise ValueError("direct_llopa_generate requires prompt_add_generation_prompt when prompt_messages are provided.")
        with torch.inference_mode():
            if legacy_search:
                direct_out = _legacy_direct_llopa_generate_impl(
                    self,
                    tokenizer,
                    lower_k=int(lower_k_local),
                    prefill_attn=str(attn_local),
                    **_structured_llopa_decode_kwargs(kwargs),
                )
            else:
                direct_out = _direct_llopa_generate_impl(
                    self,
                    tokenizer,
                    prompt_messages=prompt_messages,
                    prompt_add_generation_prompt=bool(prompt_add_generation_prompt),
                    structured_prompt_segments=structured_prompt_segments,
                    lower_k=int(lower_k_local),
                    prefill_attn=str(attn_local),
                    system_prefill=str(system_prefill_local),
                    user_prefill=str(user_prefill_local),
                    no_upper_attn=bool(no_upper_attn_local),
                    **_structured_llopa_decode_kwargs(kwargs),
                )
        if direct_out is not None:
            return direct_out
        if not legacy_search:
            raise RuntimeError("Structured direct LLoPA failed unexpectedly for the current prompt.")
        _warn_once(
            self,
            "_warned_direct_llopa_fallback",
            "[load_llopa_model][warn] direct_llopa_generate could not use the current prompt/generation settings; falling back to model.generate().",
        )
        return _fallback()

    # Best-effort install: a model that refuses attribute assignment keeps its original generate().
    try:
        model.generate = types.MethodType(_structured_llopa_generate, model)
        setattr(model, "_structured_llopa_generate_attached", True)
    except Exception:
        pass
def _attach_unified_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
) -> None:
    """Make unified LLoPA the default ``generate`` path of ``model``.

    Steps:

    1. Normalize the runtime settings. Do nothing if they normalize to ``None``.
    2. Record them as ``_unified_llopa_*`` attributes. The normalized replay module is stored
       under both ``replay_module`` and the legacy ``last_layer_module`` name.
    3. Set ``_unified_llopa_generate_default`` and ``_capsule_inference_path = "unified_llopa"``.
    4. Install the structured ``generate`` wrapper.

    If any attribute assignment fails, the function returns without installing the wrapper.
    """
    settings = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if settings is None:
        return
    layers, attn, sys_prefill, usr_prefill, upper_off, replay_mod, replay_every = settings

    def _put(suffix, value):
        setattr(model, f"_unified_llopa_{suffix}", value)

    try:
        _put("layers", int(layers))
        _put("attn", attn)
        _put("system_prefill", sys_prefill)
        _put("user_prefill", usr_prefill)
        _put("no_upper_attn", bool(upper_off))
        _put("see_past_assistant", bool(see_past_assistant))
        _put("replay_module", str(replay_mod))
        _put("last_layer_module", str(replay_mod))
        _put("replay_per_layers", int(replay_every))
        _put("generate_default", True)
        setattr(model, "_capsule_inference_path", "unified_llopa")
    except Exception:
        return
    _attach_structured_llopa_generate(model, tokenizer)
    with contextlib.suppress(Exception):
        setattr(model, "_unified_llopa_generate_attached", True)
def _attach_llopa_v2_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "prefill_header",
    generation_mixin_decode: bool = False,
    capsule_inference_path: str = "llopa_v2",
) -> None:
    """Make LLoPA v2 the default ``generate`` path of ``model``.

    Steps:

    1. Normalize the runtime settings. Do nothing if they normalize to ``None``.
    2. Record them as ``_llopa_v2_*`` attributes. The seed mode is always stored as
       ``"prefill_header"``, whatever ``seed_mode`` normalizes to.
    3. Set ``_llopa_v2_generate_default``.
    4. Store ``capsule_inference_path`` (falsy -> ``"llopa_v2"``) in ``_capsule_inference_path``;
       the structured wrapper picks the v3 path when that value is ``"llopa_v3"``.
    5. Install the structured ``generate`` wrapper.

    If any attribute assignment fails, the function returns without installing the wrapper.
    """
    settings = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if settings is None:
        return
    layers, attn, sys_prefill, usr_prefill, upper_off, replay_mod, replay_every = settings
    normalized_seed = _normalize_structured_llopa_seed_mode(seed_mode)
    # Only prefill-header seeding is supported; any other normalized value is pinned to it.
    seed = "prefill_header" if normalized_seed != "prefill_header" else normalized_seed

    def _put(suffix, value):
        setattr(model, f"_llopa_v2_{suffix}", value)

    try:
        _put("layers", int(layers))
        _put("attn", attn)
        _put("system_prefill", sys_prefill)
        _put("user_prefill", usr_prefill)
        _put("no_upper_attn", bool(upper_off))
        _put("see_past_assistant", bool(see_past_assistant))
        _put("replay_module", str(replay_mod))
        _put("last_layer_module", str(replay_mod))
        _put("replay_per_layers", int(replay_every))
        _put("seed_mode", str(seed))
        _put("generation_mixin_decode", bool(generation_mixin_decode))
        _put("generate_default", True)
        setattr(model, "_capsule_inference_path", str(capsule_inference_path or "llopa_v2"))
    except Exception:
        return
    _attach_structured_llopa_generate(model, tokenizer)
    with contextlib.suppress(Exception):
        setattr(model, "_llopa_v2_generate_attached", True)
def _attach_llopa_v2_batch_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    seed_mode: str = "prefill_header",
) -> None:
    """Make batched LLoPA v2 the default ``generate`` path of ``model``.

    Shares the ``_llopa_v2_*`` settings attributes with the non-batched v2 path, with the seed mode
    pinned to ``"prefill_header"``. It then sets ``_llopa_v2_batch_generate_default`` and
    ``_capsule_inference_path = "llopa_v2_batch"``, and installs the structured ``generate``
    wrapper.

    Nothing happens when the settings normalize to ``None``. If any attribute assignment fails,
    the function returns without installing the wrapper.
    """
    settings = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if settings is None:
        return
    layers, attn, sys_prefill, usr_prefill, upper_off, replay_mod, replay_every = settings
    normalized_seed = _normalize_structured_llopa_seed_mode(seed_mode)
    # Only prefill-header seeding is supported; any other normalized value is pinned to it.
    seed = "prefill_header" if normalized_seed != "prefill_header" else normalized_seed

    def _put(suffix, value):
        setattr(model, f"_llopa_v2_{suffix}", value)

    try:
        _put("layers", int(layers))
        _put("attn", attn)
        _put("system_prefill", sys_prefill)
        _put("user_prefill", usr_prefill)
        _put("no_upper_attn", bool(upper_off))
        _put("see_past_assistant", bool(see_past_assistant))
        _put("replay_module", str(replay_mod))
        _put("last_layer_module", str(replay_mod))
        _put("replay_per_layers", int(replay_every))
        _put("seed_mode", str(seed))
        _put("batch_generate_default", True)
        setattr(model, "_capsule_inference_path", "llopa_v2_batch")
    except Exception:
        return
    _attach_structured_llopa_generate(model, tokenizer)
    with contextlib.suppress(Exception):
        setattr(model, "_llopa_v2_batch_generate_attached", True)
def _attach_optimized_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
    see_past_assistant: bool = False,
    replay_module: str = "none",
    replay_per_layers: int = -1,
    last_layer_module: Optional[str] = None,
    optimized_variant: Optional[str] = None,
    optimized_seed_mode: Optional[str] = None,
    optimized_upper_prepare_mode: Optional[str] = None,
    optimized_upper_bucket_multiple: Optional[int] = None,
    optimized_seq_bucket_multiple: Optional[int] = None,
) -> None:
    """Make optimized LLoPA the default ``generate`` path of ``model``.

    Steps:

    1. Normalize the shared runtime settings. Do nothing if they normalize to ``None``.
    2. Resolve the optimized-only options through ``_resolve_optimized_llopa_settings``.
    3. Record both sets as ``_optimized_llopa_*`` attributes.
    4. Set ``_optimized_llopa_generate_default`` and
       ``_capsule_inference_path = "optimized_llopa"``.
    5. Install the structured ``generate`` wrapper.

    If any attribute assignment fails, the function returns without installing the wrapper.
    """
    settings = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
        replay_module=replay_module,
        replay_per_layers=replay_per_layers,
        last_layer_module=last_layer_module,
    )
    if settings is None:
        return
    layers, attn, sys_prefill, usr_prefill, upper_off, replay_mod, replay_every = settings
    resolved = _resolve_optimized_llopa_settings(
        variant=optimized_variant,
        seed_mode=optimized_seed_mode,
        upper_prepare_mode=optimized_upper_prepare_mode,
        upper_bucket_multiple=optimized_upper_bucket_multiple,
        seq_bucket_multiple=optimized_seq_bucket_multiple,
    )

    def _put(suffix, value):
        setattr(model, f"_optimized_llopa_{suffix}", value)

    try:
        _put("layers", int(layers))
        _put("attn", attn)
        _put("system_prefill", sys_prefill)
        _put("user_prefill", usr_prefill)
        _put("no_upper_attn", bool(upper_off))
        _put("see_past_assistant", bool(see_past_assistant))
        _put("replay_module", str(replay_mod))
        _put("last_layer_module", str(replay_mod))
        _put("replay_per_layers", int(replay_every))
        _put("variant", str(resolved["variant"]))
        _put("seed_mode", str(resolved["seed_mode"]))
        _put("upper_prepare_mode", str(resolved["upper_prepare_mode"]))
        _put("upper_bucket_multiple", int(resolved["upper_bucket_multiple"]))
        _put("seq_bucket_multiple", int(resolved["seq_bucket_multiple"]))
        _put("generate_default", True)
        setattr(model, "_capsule_inference_path", "optimized_llopa")
    except Exception:
        return
    _attach_structured_llopa_generate(model, tokenizer)
    with contextlib.suppress(Exception):
        setattr(model, "_optimized_llopa_generate_attached", True)
def _attach_direct_llopa_generate(
    model,
    tokenizer,
    *,
    lower_k: int,
    prefill_attn: str = "causal",
    system_prefill: str = "full",
    user_prefill: str = "full",
    no_upper_attn: bool = False,
) -> None:
    """Make the legacy direct LLoPA path the default ``generate`` of ``model``.

    Steps:

    1. Normalize the runtime settings (replay settings are ignored). Do nothing if they
       normalize to ``None``.
    2. Record them as ``_direct_llopa_*`` attributes.
    3. Set ``_direct_llopa_generate_default`` and ``_capsule_inference_path = "legacy_llopa"``.
    4. If ``_assistant_header_ids`` returns a non-empty tensor, cache a detached CPU ``long`` copy
       as ``_direct_llopa_header_ids``.
    5. Install the structured ``generate`` wrapper.

    If any attribute assignment fails, the function returns without installing the wrapper.
    """
    settings = _normalize_structured_llopa_runtime(
        lower_k,
        prefill_attn=prefill_attn,
        system_prefill=system_prefill,
        user_prefill=user_prefill,
        no_upper_attn=no_upper_attn,
    )
    if settings is None:
        return
    layers, attn, sys_prefill, usr_prefill, upper_off, _, _ = settings
    header_ids = _assistant_header_ids(tokenizer, "cpu")

    def _put(suffix, value):
        setattr(model, f"_direct_llopa_{suffix}", value)

    try:
        _put("layers", int(layers))
        _put("attn", attn)
        _put("system_prefill", sys_prefill)
        _put("user_prefill", usr_prefill)
        _put("no_upper_attn", bool(upper_off))
        _put("generate_default", True)
        setattr(model, "_capsule_inference_path", "legacy_llopa")
        has_header = isinstance(header_ids, torch.Tensor) and header_ids.numel() > 0
        if has_header:
            _put("header_ids", header_ids.detach().to(device="cpu", dtype=torch.long))
    except Exception:
        return
    _attach_structured_llopa_generate(model, tokenizer)
    with contextlib.suppress(Exception):
        setattr(model, "_direct_llopa_generate_attached", True)
def load_llopa_model(model_repo: str,
                     *,
                     model_name: str = "",
                     tokenizer_name: str = "",
                     num_specials: Optional[int] = None,
                     backbone_dir: str = "",
                     lopa_modeling_path: str = "",
                     modeling_family: str = "auto",
                     dtype: str = "auto",
                     torch_dtype=None,
                     device: str = "",
                     device_map: Optional[str] = None,
                     attn_impl: str = "auto",
                     attn_implementation: Optional[str] = None,
                     _attn_implementation: Optional[str] = None,
                     trust_remote_code: bool = False,
                     cache_dir: Optional[str] = None,
                     revision: Optional[str] = None,
                     token: Optional[str] = None,
                     local_files_only: bool = False,
                     force_custom_modeling: bool = False,
                     number_of_lora: int = 1,
                     use_lora: bool = True,
                     merge_on_cpu: bool = True,
                     enable_thinking: Optional[bool] = None,
                     no_upper_attn: Optional[bool] = None,
                     runtime_prefill_lower: Optional[bool] = None,
                     runtime_prefill_layers: Optional[int] = None,
                     runtime_prefill_attn: Optional[str] = None,
                     runtime_prefill_system_prefill: Optional[str] = None,
                     runtime_prefill_freeze: Optional[bool] = None,
                     runtime_prefill_freeze_layers: Optional[int] = None,
                     runtime_prefill_freeze_attn: Optional[str] = None,
                     runtime_prefill_freeze_system_prefill: Optional[str] = None,
                     runtime_prefill_solo: Optional[bool] = None,
                     runtime_prefill_solo_layers: Optional[int] = None,
                     runtime_prefill_solo_attn: Optional[str] = None,
                     runtime_prefill_solo_system_prefill: Optional[str] = None,
                     runtime_prefill_solo_v2: Optional[bool] = None,
                     runtime_prefill_solo_v2_layers: Optional[int] = None,
                     runtime_prefill_solo_v2_attn: Optional[str] = None,
                     runtime_prefill_solo_v2_system_prefill: Optional[str] = None,
                     runtime_prefill_solo_v2_with_bos: Optional[bool] = None,
                     runtime_llopa_prefill: Optional[bool] = None,
                     runtime_llopa_layers: Optional[int] = None,
                     runtime_llopa_attn: Optional[str] = None,
                     runtime_llopa_no_upper_attn: Optional[bool] = None,
                     unified_llopa_generate: Optional[bool] = None,
                     unified_llopa_layers: Optional[int] = None,
                     unified_llopa_attn: Optional[str] = None,
                     unified_llopa_system_prefill: Optional[str] = None,
                     unified_llopa_user_prefill: Optional[str] = None,
                     unified_llopa_no_upper_attn: Optional[bool] = None,
                     unified_llopa_see_past_assistant: Optional[bool] = None,
                     unified_llopa_replay_module: Optional[str] = None,
                     unified_llopa_replay_per_layers: Optional[int] = None,
                     unified_llopa_last_layer_module: Optional[str] = None,
                     llopa_v2_batch_generate: Optional[bool] = None,
                     llopa_v2_generate: Optional[bool] = None,
                     llopa_v3_generate: Optional[bool] = None,
                     llopa_v2_layers: Optional[int] = None,
                     llopa_v2_attn: Optional[str] = None,
                     llopa_v2_system_prefill: Optional[str] = None,
                     llopa_v2_user_prefill: Optional[str] = None,
                     llopa_v2_no_upper_attn: Optional[bool] = None,
                     llopa_v2_see_past_assistant: Optional[bool] = None,
                     llopa_v2_replay_module: Optional[str] = None,
                     llopa_v2_replay_per_layers: Optional[int] = None,
                     llopa_v2_last_layer_module: Optional[str] = None,
                     llopa_v2_seed_mode: Optional[str] = None,
                     optimized_llopa_generate: Optional[bool] = None,
                     optimized_llopa_layers: Optional[int] = None,
                     optimized_llopa_attn: Optional[str] = None,
                     optimized_llopa_system_prefill: Optional[str] = None,
                     optimized_llopa_user_prefill: Optional[str] = None,
                     optimized_llopa_no_upper_attn: Optional[bool] = None,
                     optimized_llopa_see_past_assistant: Optional[bool] = None,
                     optimized_llopa_replay_module: Optional[str] = None,
                     optimized_llopa_replay_per_layers: Optional[int] = None,
                     optimized_llopa_last_layer_module: Optional[str] = None,
                     optimized_llopa_variant: Optional[str] = None,
                     optimized_llopa_seed_mode: Optional[str] = None,
                     optimized_llopa_upper_prepare_mode: Optional[str] = None,
                     optimized_llopa_upper_bucket_multiple: Optional[int] = None,
                     optimized_llopa_seq_bucket_multiple: Optional[int] = None,
                     runtime_llopa_fast_generate: Optional[bool] = None,
                     direct_llopa_generate: Optional[bool] = None,
                     direct_llopa_layers: Optional[int] = None,
                     direct_llopa_attn: Optional[str] = None,
                     direct_llopa_system_prefill: Optional[str] = None,
                     direct_llopa_user_prefill: Optional[str] = None,
                     direct_llopa_no_upper_attn: Optional[bool] = None):
    """Load a LLoPA checkpoint and attach the requested inference runtime(s).

    Steps:
      1. Resolve ``model_repo`` to a local path and read optional training
         metadata from ``tri_info.txt``.
      2. Back-fill every runtime knob the caller left as ``None``. Explicit
         arguments win, then ``tri_info.txt`` keys, then ``capsule_*``
         attributes of the loaded config, then hard-coded defaults.
      3. Load the tokenizer and align its vocabulary with the checkpoint.
         Reconcile ``config.vocab_size`` with the tokenizer and checkpoint.
      4. Load the backbone. The custom LLoPA modeling module is tried first,
         with ``AutoModelForCausalLM`` as the fallback. Then load the embedding
         layer and optional LoRA adapters (merged when possible) and the LLoPA
         special tokens.
      5. Move the model to ``device`` (when no ``device_map`` is used) and attach
         each runtime whose enable flag is truthy (``runtime_prefill_*``,
         ``runtime_llopa_*``, ``unified_llopa_*``, ``llopa_v2_*``/``llopa_v3``,
         ``optimized_llopa_*``, ``direct_llopa_*``).

    Args:
        model_repo: Hub id or local path of the LLoPA checkpoint.
        backbone_dir: Optional backbone override. Otherwise the backbone comes
            from the checkpoint's backbone reference, ``model_name``, or
            ``model_repo``.
        dtype / torch_dtype: ``"fp32"``, ``"bf16"``, ``"fp16"`` or ``"auto"``.
            ``"auto"`` picks bf16 on bf16-capable CUDA, fp16 on other CUDA,
            else fp32.
        device / device_map: A single-device ``device_map`` string
            (``"cuda"``, ``"cuda:0"``, ``"cpu"``, ...) is treated as ``device``.
        force_custom_modeling: Raise instead of silently falling back when the
            custom LLoPA modeling cannot be used.
        number_of_lora: ``2`` selects the ``lora/gen`` + ``lora/prefill`` layout.
        Remaining keyword arguments configure the individual runtimes; ``None``
            means "use the recorded / default value".

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.

    Raises:
        RuntimeError: If ``force_custom_modeling`` is set and the custom backbone
            class fails to load, or no LLoPA runtime is active afterwards.
    """
    repo_path = _resolve_repo_path(model_repo, cache_dir=cache_dir, revision=revision,
                                   token=token, local_files_only=local_files_only)
    # Key/value metadata recorded at training time; it back-fills unset arguments.
    info: dict[str, str] = {}
    tri_info = repo_path / "tri_info.txt"
    if tri_info.is_file():
        info = _read_kv_file(tri_info)
        if not model_name:
            model_name = info.get("model_name", "")
        if num_specials is None:
            try:
                num_specials = int(info.get("num_specials", "") or 0)
            except Exception:
                num_specials = 0
    if runtime_prefill_lower is None:
        runtime_prefill_lower = False
    if runtime_prefill_freeze is None:
        runtime_prefill_freeze = False
    if runtime_prefill_solo is None:
        runtime_prefill_solo = False
    if runtime_prefill_solo_v2 is None:
        runtime_prefill_solo_v2 = False
    if runtime_llopa_prefill is None:
        runtime_llopa_prefill = False
    if unified_llopa_generate is None:
        unified_llopa_generate = False
    if llopa_v2_batch_generate is None:
        llopa_v2_batch_generate = False
    if llopa_v2_generate is None:
        llopa_v2_generate = False
    if llopa_v3_generate is None:
        llopa_v3_generate = False
    # v3 reuses the v2 attachment (with GenerationMixin decoding), so it implies v2.
    if bool(llopa_v3_generate):
        llopa_v2_generate = True
    if optimized_llopa_generate is None:
        optimized_llopa_generate = False
    if runtime_llopa_fast_generate is None:
        runtime_llopa_fast_generate = False
    if direct_llopa_generate is None:
        direct_llopa_generate = False
    backbone_ref = (
        backbone_dir
        or read_backbone_ref(repo_path)
        or _read_adapter_backbone_ref(repo_path)
        or model_name
        or model_repo
    )
    num_specials_arg = num_specials
    config = None
    config_source = str(repo_path) if (repo_path / "config.json").is_file() else backbone_ref
    try:
        config = AutoConfig.from_pretrained(
            config_source,
            cache_dir=cache_dir,
            revision=revision,
            token=token,
            local_files_only=local_files_only,
        )
        if num_specials_arg is not None:
            config.llopa_num_specials = int(num_specials_arg)
    except Exception:
        config = None
    if num_specials_arg is None:
        num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0
    else:
        num_specials = int(num_specials_arg)

    def _config_str(name: str) -> str:
        """Return config attribute ``name`` as a stripped string ("" when absent/falsy)."""
        if config is None:
            return ""
        raw = getattr(config, name, "")
        return str(raw or "").strip()

    def _config_bool(name: str):
        """Return config attribute ``name`` coerced to bool, or None when absent."""
        if config is None:
            return None
        value = getattr(config, name, None)
        if value is None:
            return None
        return bool(value)

    def _normalize_attention_gate_mode(mode: str) -> str:
        """Map attention-gate aliases to "off" or "sdpa_sigmoid"; unknown values become "off"."""
        normalized = str(mode or "off").strip().lower()
        aliases = {
            "": "off",
            "none": "off",
            "disabled": "off",
            "disable": "off",
            "false": "off",
            "0": "off",
            "paper": "sdpa_sigmoid",
            "sdpa_gate": "sdpa_sigmoid",
            "sdpa-gate": "sdpa_sigmoid",
            "sigmoid_after_sdpa": "sdpa_sigmoid",
            "sdpa_elementwise_sigmoid": "sdpa_sigmoid",
        }
        normalized = aliases.get(normalized, normalized)
        if normalized not in {"off", "sdpa_sigmoid"}:
            normalized = "off"
        return normalized

    def _user_prefill_default() -> Optional[str]:
        """Recorded user-prefill mode (lowercased, default "full"); None if it strips to ""."""
        raw_user_prefill = (info.get("user_prefill") or _config_str("capsule_user_prefill") or "full").strip().lower()
        return raw_user_prefill or None

    def _replay_module_for(explicit, last_layer_module):
        """Resolve a replay module: explicit > legacy last_layer_module > recorded metadata."""
        value = explicit if explicit is not None else last_layer_module
        if value is None:
            value = (
                info.get("replay_module")
                or _config_str("capsule_replay_module")
                or info.get("last_layer_module")
                or _config_str("capsule_last_layer_module")
            )
        return _normalize_replay_module_value(value)

    def _replay_per_layers_for(explicit):
        """Resolve replay-per-layers: explicit > recorded metadata > -1, then normalize."""
        if explicit is None:
            explicit = info.get("replay_per_layers") or _config_str("capsule_replay_per_layers") or -1
        return _normalize_replay_per_layers_value(explicit)

    attention_gate_mode = _normalize_attention_gate_mode(
        info.get("attention_gate_mode") or _config_str("capsule_attention_gate_mode") or "off"
    )
    if config is not None:
        with contextlib.suppress(Exception):
            setattr(config, "capsule_attention_gate_mode", attention_gate_mode)
    if no_upper_attn is None:
        raw = (info.get("no_upper_attn") or _config_str("capsule_no_upper_attn")).strip().lower()
        if raw in {"1", "true", "yes", "on"}:
            no_upper_attn = True
        elif raw in {"0", "false", "no", "off"}:
            no_upper_attn = False
        else:
            cfg_bool = _config_bool("capsule_no_upper_attn")
            if cfg_bool is not None:
                no_upper_attn = cfg_bool
    # Default for every runtime's no_upper_attn knob (bool(None) is False).
    default_no_upper_attn = bool(no_upper_attn)
    if runtime_prefill_layers is None:
        raw = (info.get("lower_k") or _config_str("capsule_lower_layers")).strip()
        if raw:
            with contextlib.suppress(Exception):
                runtime_prefill_layers = int(raw)
    if runtime_prefill_attn is None:
        runtime_prefill_attn = (
            (info.get("prefill_attn") or _config_str("capsule_prefill_attn") or "causal").strip().lower() or "causal"
        )
    if runtime_prefill_system_prefill is None:
        raw_system_prefill = (info.get("system_prefill") or _config_str("capsule_system_prefill")).strip().lower()
        if raw_system_prefill in {"full", "no_system", "no_bos_system"}:
            runtime_prefill_system_prefill = raw_system_prefill
    # The runtime_prefill_* values resolved above act as shared defaults for every runtime below.
    if runtime_llopa_layers is None:
        runtime_llopa_layers = runtime_prefill_layers
    if runtime_llopa_attn is None:
        runtime_llopa_attn = runtime_prefill_attn
    if runtime_llopa_no_upper_attn is None:
        runtime_llopa_no_upper_attn = default_no_upper_attn
    # --- unified_llopa defaults ---
    if unified_llopa_layers is None:
        unified_llopa_layers = runtime_prefill_layers
    if unified_llopa_attn is None:
        unified_llopa_attn = runtime_prefill_attn
    if unified_llopa_system_prefill is None:
        unified_llopa_system_prefill = runtime_prefill_system_prefill
    if unified_llopa_user_prefill is None:
        unified_llopa_user_prefill = _user_prefill_default()
    if unified_llopa_no_upper_attn is None:
        unified_llopa_no_upper_attn = default_no_upper_attn
    if unified_llopa_see_past_assistant is None:
        unified_llopa_see_past_assistant = False
    unified_llopa_replay_module = _replay_module_for(unified_llopa_replay_module, unified_llopa_last_layer_module)
    unified_llopa_replay_per_layers = _replay_per_layers_for(unified_llopa_replay_per_layers)
    # --- llopa_v2 / v3 defaults ---
    if llopa_v2_layers is None:
        llopa_v2_layers = runtime_prefill_layers
    if llopa_v2_attn is None:
        llopa_v2_attn = runtime_prefill_attn
    if llopa_v2_system_prefill is None:
        llopa_v2_system_prefill = runtime_prefill_system_prefill
    if llopa_v2_user_prefill is None:
        llopa_v2_user_prefill = _user_prefill_default()
    if llopa_v2_no_upper_attn is None:
        llopa_v2_no_upper_attn = default_no_upper_attn
    if llopa_v2_see_past_assistant is None:
        llopa_v2_see_past_assistant = False
    llopa_v2_replay_module = _replay_module_for(llopa_v2_replay_module, llopa_v2_last_layer_module)
    llopa_v2_replay_per_layers = _replay_per_layers_for(llopa_v2_replay_per_layers)
    # NOTE(review): this unconditionally overrides the caller's llopa_v2_seed_mode;
    # presumably only "prefill_header" is supported now — confirm before honoring the argument.
    llopa_v2_seed_mode = "prefill_header"
    # --- optimized_llopa defaults ---
    if optimized_llopa_layers is None:
        optimized_llopa_layers = runtime_prefill_layers
    if optimized_llopa_attn is None:
        optimized_llopa_attn = runtime_prefill_attn
    if optimized_llopa_system_prefill is None:
        optimized_llopa_system_prefill = runtime_prefill_system_prefill
    if optimized_llopa_user_prefill is None:
        optimized_llopa_user_prefill = _user_prefill_default()
    if optimized_llopa_no_upper_attn is None:
        optimized_llopa_no_upper_attn = default_no_upper_attn
    if optimized_llopa_see_past_assistant is None:
        optimized_llopa_see_past_assistant = False
    optimized_llopa_replay_module = _replay_module_for(optimized_llopa_replay_module, optimized_llopa_last_layer_module)
    optimized_llopa_replay_per_layers = _replay_per_layers_for(optimized_llopa_replay_per_layers)
    optimized_settings = _resolve_optimized_llopa_settings(
        variant=optimized_llopa_variant,
        seed_mode=optimized_llopa_seed_mode,
        upper_prepare_mode=optimized_llopa_upper_prepare_mode,
        upper_bucket_multiple=optimized_llopa_upper_bucket_multiple,
        seq_bucket_multiple=optimized_llopa_seq_bucket_multiple,
    )
    # --- direct_llopa (deprecated) defaults ---
    if direct_llopa_layers is None:
        direct_llopa_layers = runtime_prefill_layers
    if direct_llopa_attn is None:
        direct_llopa_attn = runtime_prefill_attn
    if direct_llopa_system_prefill is None:
        direct_llopa_system_prefill = runtime_prefill_system_prefill
    if direct_llopa_user_prefill is None:
        direct_llopa_user_prefill = _user_prefill_default()
    if direct_llopa_no_upper_attn is None:
        direct_llopa_no_upper_attn = default_no_upper_attn
    # --- runtime_prefill_solo / solo_v2 defaults ---
    if runtime_prefill_solo_layers is None:
        runtime_prefill_solo_layers = runtime_prefill_layers
    if runtime_prefill_solo_attn is None:
        runtime_prefill_solo_attn = runtime_prefill_attn
    if runtime_prefill_solo_system_prefill is None:
        runtime_prefill_solo_system_prefill = runtime_prefill_system_prefill
    if runtime_prefill_solo_v2_layers is None:
        runtime_prefill_solo_v2_layers = runtime_prefill_layers
    if runtime_prefill_solo_v2_attn is None:
        runtime_prefill_solo_v2_attn = runtime_prefill_attn
    if runtime_prefill_solo_v2_system_prefill is None:
        runtime_prefill_solo_v2_system_prefill = runtime_prefill_system_prefill
    if runtime_prefill_solo_v2_with_bos is None:
        runtime_prefill_solo_v2_with_bos = False

    config_kwargs = {"config": config} if config is not None else {}
    # `dtype` wins unless it is "auto", in which case an explicit `torch_dtype` is used.
    dtype_norm = _normalize_dtype_arg(dtype) or "auto"
    torch_dtype_norm = _normalize_dtype_arg(torch_dtype)
    if dtype_norm == "auto" and torch_dtype_norm is not None:
        dtype_norm = torch_dtype_norm
    if dtype_norm == "fp32":
        torch_dtype = torch.float32
    elif dtype_norm == "bf16":
        torch_dtype = torch.bfloat16
    elif dtype_norm == "fp16":
        torch_dtype = torch.float16
    elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
    elif torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    # Accept HF-style aliases so AutoModel.from_pretrained(...) kwargs work as-is.
    for cand in (attn_implementation, _attn_implementation):
        if cand not in (None, "", "auto"):
            attn_impl = str(cand)
            break
    if attn_impl != "auto" and config is not None:
        for k in ("attn_implementation", "_attn_implementation"):
            with contextlib.suppress(Exception):
                setattr(config, k, attn_impl)
    # `device_map="cuda:0"` is a single-device placement request, not a sharding map.
    if isinstance(device_map, str):
        dm = device_map.strip()
        if dm in ("", "none", "None", "null", "NULL"):
            device_map = None
        elif dm in ("cuda", "cpu", "mps") or dm.startswith(("cuda:", "xpu:", "npu:")):
            if not device:
                device = dm
            device_map = None
    if device_map is None and not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if tokenizer_name:
        tok_src = tokenizer_name
    else:
        tok_src = str(repo_path) if (repo_path / "tokenizer.json").is_file() else backbone_ref
    tokenizer = AutoTokenizer.from_pretrained(
        tok_src,
        use_fast=True,
        cache_dir=cache_dir,
        revision=revision,
        token=token,
        local_files_only=local_files_only,
        **_tokenizer_kwargs(AutoTokenizer.from_pretrained),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Weights may live in the backbone repo rather than the LLoPA repo (adapter-only checkpoints).
    checkpoint_repo_path = repo_path
    if not _repo_has_pretrained_weights(checkpoint_repo_path):
        with contextlib.suppress(Exception):
            resolved_backbone = _resolve_repo_path(
                backbone_ref,
                cache_dir=cache_dir,
                revision=revision,
                token=token,
                local_files_only=local_files_only,
            )
            if _repo_has_pretrained_weights(resolved_backbone):
                checkpoint_repo_path = resolved_backbone
    checkpoint_vocab_size = 0
    alignment_report = None
    with contextlib.suppress(Exception):
        checkpoint_vocab_size = int(_infer_checkpoint_vocab_size(checkpoint_repo_path) or 0)
    if checkpoint_vocab_size > 0:
        try:
            tokenizer_vocab_size = int(len(tokenizer))
        except Exception:
            tokenizer_vocab_size = 0
        if tokenizer_vocab_size > 0 and checkpoint_vocab_size > tokenizer_vocab_size:
            alignment_report = _align_tokenizer_with_checkpoint_vocab(
                tokenizer,
                checkpoint_repo_path,
                checkpoint_vocab_size,
            )
            _log_tokenizer_checkpoint_alignment("[load_llopa_model]", alignment_report)
    if config is not None:
        try:
            tokenizer_vocab_size = int(len(tokenizer))
        except Exception:
            tokenizer_vocab_size = 0
        config_vocab_size = 0
        try:
            config_vocab_size = int(getattr(config, "vocab_size", 0) or 0)
        except Exception:
            config_vocab_size = 0
        # Grow config.vocab_size so it covers both the tokenizer and the checkpoint.
        target_vocab_size = max(config_vocab_size, tokenizer_vocab_size, int(checkpoint_vocab_size or 0))
        if checkpoint_vocab_size > 0 and checkpoint_vocab_size != config_vocab_size:
            _log_config_checkpoint_vocab_alignment(
                "[load_llopa_model]",
                config_vocab_size=config_vocab_size,
                checkpoint_vocab_size=checkpoint_vocab_size,
                tokenizer_vocab_size=tokenizer_vocab_size,
                alignment_report=alignment_report,
            )
        if target_vocab_size > 0 and target_vocab_size != config_vocab_size:
            with contextlib.suppress(Exception):
                config.vocab_size = int(target_vocab_size)
    if enable_thinking is not None:
        # Tokenizers expose this either as a method, a plain attribute, or a setter.
        if hasattr(tokenizer, "enable_thinking"):
            try:
                tokenizer.enable_thinking(enable_thinking)
            except Exception:
                try:
                    tokenizer.enable_thinking = enable_thinking
                except Exception:
                    pass
        elif hasattr(tokenizer, "set_enable_thinking"):
            try:
                tokenizer.set_enable_thinking(enable_thinking)
            except Exception:
                pass
        try:
            setattr(tokenizer, "_force_enable_thinking", enable_thinking)
        except Exception:
            pass

    model_family = infer_model_family(backbone_ref, modeling_family)
    modeling_path = _resolve_modeling_path(repo_path, lopa_modeling_path, model_family)
    lora_path = repo_path / "lora"
    # is_dir() (not exists()) so a stray regular file named "lora" cannot crash iterdir().
    has_lora = bool(use_lora and lora_path.is_dir() and any(lora_path.iterdir()))
    if not has_lora and bool(use_lora):
        # Fall back to a PEFT adapter saved directly at the repo root.
        top_level_adapter = (
            (repo_path / "adapter_config.json").is_file()
            and (
                (repo_path / "adapter_model.safetensors").is_file()
                or (repo_path / "adapter_model.bin").is_file()
            )
        )
        if top_level_adapter:
            lora_path = repo_path
            has_lora = True
    load_device_map = device_map
    merge_on_cpu_active = bool(merge_on_cpu and has_lora and device_map is None)
    if merge_on_cpu and has_lora and device_map is not None:
        print("[LoRA] merge_on_cpu ignored because sharded device_map is requested.")
    if merge_on_cpu_active:
        print("[LoRA] Merging adapters on CPU to reduce CUDA peak memory.")
        load_device_map = None
    custom_mod = None
    if modeling_path:
        try:
            custom_mod = load_custom_modeling(modeling_path, model_family=model_family)
        except Exception as exc:
            print(f"[load_llopa_model][warn] failed to load custom modeling from {modeling_path}: {exc!r}")
            custom_mod = None
    base = None
    custom_load_exc = None
    if custom_mod is not None:
        custom_cls_name = {
            "qwen3": "Qwen3ForCausalLM",
            "mistral": "MistralForCausalLM",
        }.get(model_family, "LlamaForCausalLM")
        try:
            custom_cls = getattr(custom_mod, custom_cls_name)
            base = custom_cls.from_pretrained(
                backbone_ref,
                **_dtype_kwargs(custom_cls.from_pretrained, torch_dtype),
                trust_remote_code=trust_remote_code,
                cache_dir=cache_dir,
                revision=revision,
                token=token,
                local_files_only=local_files_only,
                device_map=load_device_map,
                **config_kwargs,
            )
        except Exception as exc:
            custom_load_exc = exc
            base = None
    if base is None and custom_load_exc is not None:
        if force_custom_modeling:
            raise RuntimeError(
                f"Failed to load base model with custom LLoPA modeling from {backbone_ref}"
            ) from custom_load_exc
        print(
            f"[load_llopa_model][warn] custom LLoPA modeling failed to load ({custom_load_exc!r}); "
            "falling back to AutoModelForCausalLM."
        )
    if base is None:
        base = AutoModelForCausalLM.from_pretrained(
            backbone_ref,
            trust_remote_code=trust_remote_code,
            **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, torch_dtype),
            cache_dir=cache_dir,
            revision=revision,
            token=token,
            local_files_only=local_files_only,
            device_map=load_device_map,
            **config_kwargs,
        )
    ensure_mistral_special_token(tokenizer, base)
    load_embedding_layer(base, repo_path)
    if merge_on_cpu_active:
        try:
            base = base.to("cpu")
        except Exception:
            pass
    model = None
    if has_lora:
        num_lora = int(number_of_lora or 1)
        if num_lora == 2:
            # Two-adapter layout: merge "gen" into the base, keep "prefill" as a live adapter.
            gen_dir = lora_path / "gen"
            prefill_dir = lora_path / "prefill"
            if gen_dir.is_dir() and prefill_dir.is_dir():
                try:
                    from peft import PeftModel
                    peft_gen = PeftModel.from_pretrained(base, str(gen_dir), adapter_name="gen")
                    try:
                        peft_gen.set_adapter("gen")
                    except Exception:
                        pass
                    merged_base = peft_gen.merge_and_unload()
                    peft_prefill = PeftModel.from_pretrained(merged_base, str(prefill_dir), adapter_name="prefill")
                    model = peft_prefill
                    setattr(model, "_prefill_adapter_only", True)
                except Exception as exc:
                    print(f"[LoRA][warn] failed to load gen/prefill adapters from {lora_path}: {exc!r}")
                    model = base
            else:
                model = base
        if model is None or model is base:
            try:
                from peft import PeftModel
                peft = PeftModel.from_pretrained(base, str(lora_path))
                try:
                    model = peft.merge_and_unload()
                except Exception:
                    model = peft
            except Exception as exc:
                # Keep the best-effort fallback, but make running without LoRA visible.
                print(f"[LoRA][warn] failed to load adapter from {lora_path}: {exc!r}; using base model.")
                model = base
    else:
        model = base
    if num_specials > 0:
        if not load_llopa_specials(model, repo_path):
            print("[Warn] Failed to load LLOPA specials.")
    try:
        p0 = next(model.parameters())
        print(f"[load_llopa_model] model dtype={p0.dtype}, on={p0.device}")
    except Exception:
        pass
    if device_map is None:
        model = model.to(device).eval()
    else:
        model = model.eval()
    if no_upper_attn is not None:
        try:
            setattr(model, "_no_upper_attn", bool(no_upper_attn))
        except Exception:
            pass

    # ---- runtime attachments (each gated by its enable flag) ----
    if bool(runtime_prefill_lower):
        if _supports_prefill_lower_runtime(model):
            _attach_prefill_lower_generate(
                model,
                lower_k=int(runtime_prefill_layers or 0),
                prefill_attn=str(runtime_prefill_attn or "causal"),
                system_prefill=str(runtime_prefill_system_prefill or "no_bos_system"),
            )
            try:
                print(
                    f"[load_llopa_model] standard generate runtime enabled "
                    f"(prefill_lower_layers={int(runtime_prefill_layers or 0)}, "
                    f"prefill_attn={str(runtime_prefill_attn or 'causal')})"
                )
            except Exception:
                pass
            try:
                setattr(model, "_capsule_inference_path", "runtime_lower")
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] runtime_prefill_lower requested but TRI prefill-lower runtime is unavailable.")
    if bool(runtime_prefill_freeze):
        if _supports_prefill_lower_runtime(model):
            _attach_prefill_lower_freeze_generate(
                model,
                tokenizer=tokenizer,
                lower_k=int(runtime_prefill_freeze_layers or 0),
                prefill_attn=str(runtime_prefill_freeze_attn or "causal"),
                system_prefill=str(runtime_prefill_freeze_system_prefill or "no_bos_system"),
            )
            try:
                print(
                    f"[load_llopa_model] freeze-faithful generate runtime enabled "
                    f"(prefill_lower_layers={int(runtime_prefill_freeze_layers or 0)}, "
                    f"prefill_attn={str(runtime_prefill_freeze_attn or 'causal')}, "
                    f"system_prefill={str(runtime_prefill_freeze_system_prefill or 'no_bos_system')})"
                )
            except Exception:
                pass
            try:
                setattr(model, "_capsule_inference_path", "runtime_freeze")
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] runtime_prefill_freeze requested but TRI prefill-freeze runtime is unavailable.")
    if bool(runtime_prefill_solo):
        if _supports_prefill_lower_runtime(model):
            _attach_prefill_lower_solo_generate(
                model,
                tokenizer=tokenizer,
                lower_k=int(runtime_prefill_solo_layers or 0),
                prefill_attn=str(runtime_prefill_solo_attn or "causal"),
                system_prefill=str(runtime_prefill_solo_system_prefill or "no_bos_system"),
            )
            try:
                print(
                    f"[load_llopa_model] solo-attn generate runtime enabled "
                    f"(prefill_lower_layers={int(runtime_prefill_solo_layers or 0)}, "
                    f"prefill_attn={str(runtime_prefill_solo_attn or 'causal')}, "
                    f"system_prefill={str(runtime_prefill_solo_system_prefill or 'no_bos_system')})"
                )
            except Exception:
                pass
            try:
                setattr(model, "_capsule_inference_path", "runtime_solo")
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] runtime_prefill_solo requested but TRI prefill-solo runtime is unavailable.")
    if bool(runtime_prefill_solo_v2):
        if _supports_prefill_lower_runtime(model):
            _attach_prefill_lower_solo_v2_generate(
                model,
                tokenizer=tokenizer,
                lower_k=int(runtime_prefill_solo_v2_layers or 0),
                prefill_attn=str(runtime_prefill_solo_v2_attn or "causal"),
                system_prefill=str(runtime_prefill_solo_v2_system_prefill or "no_bos_system"),
                with_bos=bool(runtime_prefill_solo_v2_with_bos),
            )
            try:
                print(
                    f"[load_llopa_model] solo-attn-v2 generate runtime enabled "
                    f"(prefill_lower_layers={int(runtime_prefill_solo_v2_layers or 0)}, "
                    f"prefill_attn={str(runtime_prefill_solo_v2_attn or 'causal')}, "
                    f"system_prefill={str(runtime_prefill_solo_v2_system_prefill or 'no_bos_system')}, "
                    f"with_bos={int(bool(runtime_prefill_solo_v2_with_bos))})"
                )
            except Exception:
                pass
            try:
                setattr(model, "_capsule_inference_path", "runtime_solo_v2")
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] runtime_prefill_solo_v2 requested but TRI prefill-solo-v2 runtime is unavailable.")
    if bool(runtime_llopa_prefill):
        if _supports_runtime_llopa_prompt_prefill(model):
            header_ids = _assistant_header_ids(tokenizer, "cpu")
            if isinstance(header_ids, torch.Tensor) and header_ids.numel() > 0:
                _attach_runtime_llopa_generate(
                    model,
                    header_ids=header_ids,
                    lower_k=int(runtime_llopa_layers or 0),
                    prefill_attn=str(runtime_llopa_attn or "causal"),
                    no_upper_attn=bool(runtime_llopa_no_upper_attn),
                )
                if bool(runtime_llopa_fast_generate):
                    _attach_runtime_llopa_fast_generate(
                        model,
                        lower_k=int(runtime_llopa_layers or 0),
                        prefill_attn=str(runtime_llopa_attn or "causal"),
                        no_upper_attn=bool(runtime_llopa_no_upper_attn),
                    )
                try:
                    print(
                        f"[load_llopa_model] standard generate runtime enabled "
                        f"(llopa_prefill_layers={int(runtime_llopa_layers or 0)}, "
                        f"prefill_attn={str(runtime_llopa_attn or 'causal')})"
                    )
                except Exception:
                    pass
                try:
                    setattr(model, "_capsule_inference_path", "legacy_llopa")
                except Exception:
                    pass
                if bool(runtime_llopa_fast_generate):
                    try:
                        print("[load_llopa_model] runtime_llopa_fast_generate enabled")
                    except Exception:
                        pass
                    try:
                        setattr(model, "_capsule_inference_path", "legacy_llopa")
                    except Exception:
                        pass
            else:
                print("[load_llopa_model][warn] runtime_llopa_prefill requested but assistant header ids are unavailable.")
        else:
            print("[load_llopa_model][warn] runtime_llopa_prefill requested but TRI LLoPA runtime is unavailable.")
    elif bool(runtime_llopa_fast_generate):
        # The fast path is only attached on top of runtime_llopa_prefill; don't ignore it silently.
        print("[load_llopa_model][warn] runtime_llopa_fast_generate requires runtime_llopa_prefill; ignoring.")
    if bool(unified_llopa_generate):
        if _supports_direct_llopa_generate(model):
            _attach_unified_llopa_generate(
                model,
                tokenizer,
                lower_k=int(unified_llopa_layers or 0),
                prefill_attn=str(unified_llopa_attn or "causal"),
                system_prefill=str(unified_llopa_system_prefill or "full"),
                user_prefill=str(unified_llopa_user_prefill or "full"),
                no_upper_attn=bool(unified_llopa_no_upper_attn),
                see_past_assistant=bool(unified_llopa_see_past_assistant),
                replay_module=str(unified_llopa_replay_module or "none"),
                replay_per_layers=int(unified_llopa_replay_per_layers or -1),
            )
            try:
                print(
                    f"[load_llopa_model] unified_llopa generate enabled "
                    f"(llopa_prefill_layers={int(unified_llopa_layers or 0)}, "
                    f"prefill_attn={str(unified_llopa_attn or 'causal')}, "
                    f"replay_module={str(unified_llopa_replay_module or 'none')}, "
                    f"replay_per_layers={int(unified_llopa_replay_per_layers or -1)})"
                )
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] unified_llopa_generate requested but LLoPA direct prompt prefill is unavailable.")
    if bool(llopa_v2_batch_generate):
        if _supports_direct_llopa_generate(model):
            _attach_llopa_v2_batch_generate(
                model,
                tokenizer,
                lower_k=int(llopa_v2_layers or 0),
                prefill_attn=str(llopa_v2_attn or "causal"),
                system_prefill=str(llopa_v2_system_prefill or "full"),
                user_prefill=str(llopa_v2_user_prefill or "full"),
                no_upper_attn=bool(llopa_v2_no_upper_attn),
                see_past_assistant=bool(llopa_v2_see_past_assistant),
                replay_module=str(llopa_v2_replay_module or "none"),
                replay_per_layers=int(llopa_v2_replay_per_layers or -1),
                seed_mode=str(llopa_v2_seed_mode or "prefill_header"),
            )
            try:
                print(
                    f"[load_llopa_model] llopa_v2_batch generate enabled "
                    f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, "
                    f"prefill_attn={str(llopa_v2_attn or 'causal')}, "
                    f"replay_module={str(llopa_v2_replay_module or 'none')}, "
                    f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, "
                    f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')})"
                )
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] llopa_v2_batch_generate requested but LLoPA direct prompt prefill is unavailable.")
    if bool(llopa_v2_generate):
        if _supports_direct_llopa_generate(model):
            llopa_v3_active = bool(llopa_v3_generate)
            _attach_llopa_v2_generate(
                model,
                tokenizer,
                lower_k=int(llopa_v2_layers or 0),
                prefill_attn=str(llopa_v2_attn or "causal"),
                system_prefill=str(llopa_v2_system_prefill or "full"),
                user_prefill=str(llopa_v2_user_prefill or "full"),
                no_upper_attn=bool(llopa_v2_no_upper_attn),
                see_past_assistant=bool(llopa_v2_see_past_assistant),
                replay_module=str(llopa_v2_replay_module or "none"),
                replay_per_layers=int(llopa_v2_replay_per_layers or -1),
                seed_mode=str(llopa_v2_seed_mode or "prefill_header"),
                generation_mixin_decode=llopa_v3_active,
                capsule_inference_path="llopa_v3" if llopa_v3_active else "llopa_v2",
            )
            try:
                label = "llopa_v3" if llopa_v3_active else "llopa_v2"
                print(
                    f"[load_llopa_model] {label} generate enabled "
                    f"(llopa_prefill_layers={int(llopa_v2_layers or 0)}, "
                    f"prefill_attn={str(llopa_v2_attn or 'causal')}, "
                    f"replay_module={str(llopa_v2_replay_module or 'none')}, "
                    f"replay_per_layers={int(llopa_v2_replay_per_layers or -1)}, "
                    f"seed_mode={str(llopa_v2_seed_mode or 'prefill_header')}, "
                    f"generation_mixin_decode={int(llopa_v3_active)})"
                )
            except Exception:
                pass
        else:
            label = "llopa_v3_generate" if bool(llopa_v3_generate) else "llopa_v2_generate"
            print(f"[load_llopa_model][warn] {label} requested but LLoPA direct prompt prefill is unavailable.")
    if bool(optimized_llopa_generate):
        if _supports_direct_llopa_generate(model):
            _attach_optimized_llopa_generate(
                model,
                tokenizer,
                lower_k=int(optimized_llopa_layers or 0),
                prefill_attn=str(optimized_llopa_attn or "causal"),
                system_prefill=str(optimized_llopa_system_prefill or "full"),
                user_prefill=str(optimized_llopa_user_prefill or "full"),
                no_upper_attn=bool(optimized_llopa_no_upper_attn),
                see_past_assistant=bool(optimized_llopa_see_past_assistant),
                replay_module=str(optimized_llopa_replay_module or "none"),
                replay_per_layers=int(optimized_llopa_replay_per_layers or -1),
                optimized_variant=str(optimized_settings["variant"]),
                optimized_seed_mode=str(optimized_settings["seed_mode"]),
                optimized_upper_prepare_mode=str(optimized_settings["upper_prepare_mode"]),
                optimized_upper_bucket_multiple=int(optimized_settings["upper_bucket_multiple"]),
                optimized_seq_bucket_multiple=int(optimized_settings["seq_bucket_multiple"]),
            )
            try:
                print(
                    f"[load_llopa_model] optimized_llopa generate enabled "
                    f"(llopa_prefill_layers={int(optimized_llopa_layers or 0)}, "
                    f"prefill_attn={str(optimized_llopa_attn or 'causal')}, "
                    f"replay_module={str(optimized_llopa_replay_module or 'none')}, "
                    f"replay_per_layers={int(optimized_llopa_replay_per_layers or -1)}, "
                    f"variant={str(optimized_settings['variant'])})"
                )
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] optimized_llopa_generate requested but LLoPA direct prompt prefill is unavailable.")
    if bool(direct_llopa_generate):
        print(
            "[load_llopa_model][warn] direct_llopa_* is deprecated; "
            "use unified_llopa_* or INFERENCE_PATH=unified_llopa. Legacy users can keep existing envs unchanged."
        )
        if _supports_direct_llopa_generate(model):
            _attach_direct_llopa_generate(
                model,
                tokenizer,
                lower_k=int(direct_llopa_layers or 0),
                prefill_attn=str(direct_llopa_attn or "causal"),
                system_prefill=str(direct_llopa_system_prefill or "full"),
                user_prefill=str(direct_llopa_user_prefill or "full"),
                no_upper_attn=bool(direct_llopa_no_upper_attn),
            )
            try:
                print(
                    f"[load_llopa_model] direct_llopa generate enabled "
                    f"(llopa_prefill_layers={int(direct_llopa_layers or 0)}, "
                    f"prefill_attn={str(direct_llopa_attn or 'causal')})"
                )
            except Exception:
                pass
        else:
            print("[load_llopa_model][warn] direct_llopa_generate requested but LLoPA direct prompt prefill is unavailable.")
    if force_custom_modeling:
        try:
            has_llopa = _has_active_llopa_runtime(model)
        except Exception:
            has_llopa = False
        if not has_llopa:
            raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.")
    if attn_impl != "auto":
        # Mirror the attention implementation onto the outer and inner (decoder) configs.
        impl = attn_impl
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, impl)
                inner = getattr(model, "model", None) or getattr(model, "transformer", None)
                if inner is not None and hasattr(inner, "config"):
                    setattr(inner.config, k, impl)
            except Exception:
                pass
    _attach_llopa_generate(model)
    return model, tokenizer
# -----------------------------
# CLI
# -----------------------------
def main():
    """Command-line entry point: load a trained LLoPA/TRI checkpoint and generate one answer.

    Steps, in order:
      1. Parse CLI arguments and pick a device / dtype.
      2. Apply optional global numeric toggles (disable TF32, force the SDPA math kernel).
      3. Load the tokenizer from ``--tokenizer_name``, else ``best_dir`` (when it holds a
         ``tokenizer.json``), else ``--model_name``.
      4. Resolve where the base weights come from, in priority order: ``--backbone_dir``,
         the backbone reference read from ``best_dir`` (``read_backbone_ref``), a non-empty
         ``best_dir/base`` directory, and finally ``--model_name``.
      5. Read ``best_dir/tri_info.txt`` (``num_specials``, ``no_upper_attn``) when present.
      6. Load the base model, preferring the custom LLoPA modeling class for the inferred
         family and falling back to ``AutoModelForCausalLM``.
      7. Apply the saved embedding layer, the LoRA adapter (merged when possible) and the
         LLoPA special parameters.
      8. Run ``lopa_generate`` on the given system/document/question and print the text.

    Raises:
        RuntimeError: if ``--force_custom_modeling`` is set and the loaded model has no
            active LLoPA runtime.
    """
    ap = argparse.ArgumentParser("TRI inference helper")
    ap.add_argument("--best_dir", type=str, required=True, help="Path to best/ folder produced by training")
    ap.add_argument("--backbone_dir", type=str, default=None,
                    help="Optional backbone path/ID (overrides best_dir/backbone.json and --model_name).")
    ap.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct")
    ap.add_argument("--tokenizer_name", type=str, default="", help="Optional tokenizer name or path")
    ap.add_argument("--prefill_layers", type=int, default=4)
    ap.add_argument("--prefill_mode", type=str, choices=["lower", "periodic"], default="lower",
                    help="Prefill mode for user tokens: 'lower' uses first K layers, 'periodic' uses every K-th layer.")
    ap.add_argument("--prefill_attn", type=str, choices=["causal", "full"], default="causal",
                    help="Prefill attention for system/user tokens (training must match).")
    ap.add_argument("--system_prefill", type=str, choices=["full", "no_system", "no_bos_system"], default="full",
                    help="System prefill mode (must match training).")
    ap.add_argument("--user_prefill", type=str, choices=["full", "no_question"], default="full",
                    help="User prefill ablation: full=doc+question, no_question=doc-only (question runs in full layers).")
    ap.add_argument("--llopa_prefill", action="store_true",
                    help="Use single-forward LLOPA prefill (causal + lower only).")
    ap.add_argument("--no_upper_attn", action="store_true",
                    help="Skip upper-layer attention during decode (effective only with --llopa_prefill).")
    ap.add_argument("--lopa_modeling_path", type=str, default="tri_llama3_modeling.py",
                    help="Path to custom LLoPA modeling file used in training")
    ap.add_argument("--modeling_family", type=str, choices=["auto", "llama", "qwen3", "mistral"], default="auto",
                    help="Model family for custom modeling injection (auto detects from --model_name)")
    ap.add_argument("--force_custom_modeling", action="store_true",
                    help="Require custom LLoPA modeling to be active; error if not.")
    ap.add_argument("--system", type=str, default="You are a helpful assistant that answers questions based on the given document. ")
    ap.add_argument("--task", type=str, default="qa_doc",
                    help="Prompt template task: qa_doc | math | summary | code")
    ap.add_argument("--math_force_final_hash_rule", action="store_true",
                    help="If set and task=math, append the #### answer rule to the system prompt.")
    ap.add_argument("--document", type=str, required=True)
    ap.add_argument("--question", type=str, required=True)
    ap.add_argument("--max_new_tokens", type=int, default=256)
    ap.add_argument("--min_length", type=int, default=16)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--top_p", type=float, default=0.9)
    ap.add_argument("--top_k", type=int, default=None)
    ap.add_argument("--do_sample", action="store_true")
    # numeric controls for reproducibility
    ap.add_argument("--dtype", type=str, choices=["auto","bf16","fp16","fp32"], default="auto")
    ap.add_argument("--no_tf32", action="store_true")
    ap.add_argument("--sdpa_math_only", action="store_true")
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--attn_impl", type=str, choices=["sdpa", "eager", "auto"], default="sdpa",
                    help="Attention implementation override (auto keeps model default).")
    args = ap.parse_args()

    # ---- device / dtype -------------------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif device == "cuda":
        # "auto" on GPU: bf16 when the device supports it, otherwise fp16.
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        # "auto" on CPU falls back to fp32.
        dtype = torch.float32

    # ---- global numeric toggles (best-effort; older torch may lack these) --
    if args.no_tf32 and torch.cuda.is_available():
        try:
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
        except Exception:
            pass
    if args.sdpa_math_only and torch.cuda.is_available():
        try:
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
        except Exception:
            pass

    # ---- tokenizer --------------------------------------------------------
    best_dir = Path(args.best_dir)
    if args.tokenizer_name:
        tok_src = args.tokenizer_name
    else:
        tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else args.model_name
    tokenizer = AutoTokenizer.from_pretrained(
        tok_src,
        use_fast=True,
        **_tokenizer_kwargs(AutoTokenizer.from_pretrained),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # ---- base weights source ----------------------------------------------
    # Priority: --backbone_dir > backbone reference recorded in best_dir >
    # non-empty best_dir/base (saved base, e.g. with resized embeddings) > --model_name.
    if args.backbone_dir:
        backbone_ref = args.backbone_dir
    else:
        backbone_ref = read_backbone_ref(best_dir)
    base_path = best_dir / "base"
    if backbone_ref:
        base_load_src = backbone_ref
    elif base_path.is_dir() and any(base_path.iterdir()):
        # is_dir() (not exists()) so a stray file named "base" cannot crash iterdir().
        base_load_src = str(base_path)
    else:
        base_load_src = args.model_name

    # ---- training metadata (tri_info.txt) ---------------------------------
    num_specials = None
    tri_no_upper_attn = None
    tri_info = best_dir / "tri_info.txt"
    if tri_info.is_file():
        info = _read_kv_file(tri_info)
        raw_specials = info.get("num_specials")
        # A missing/empty key means "unspecified": keep None so the value can
        # fall back to the model config below instead of being forced to 0.
        if raw_specials not in (None, ""):
            try:
                num_specials = int(raw_specials)
            except (TypeError, ValueError):
                num_specials = None
        raw_no_upper = (info.get("no_upper_attn") or "").strip().lower()
        if raw_no_upper in {"1", "true", "yes", "on"}:
            tri_no_upper_attn = True
        elif raw_no_upper in {"0", "false", "no", "off"}:
            tri_no_upper_attn = False
    # An explicit --no_upper_attn on the CLI wins; otherwise follow training.
    if (not args.no_upper_attn) and (tri_no_upper_attn is not None):
        args.no_upper_attn = bool(tri_no_upper_attn)

    # ---- config -------------------------------------------------------------
    config = None
    try:
        config = AutoConfig.from_pretrained(base_load_src)
        if num_specials is not None:
            config.llopa_num_specials = int(num_specials)
    except Exception as exc:
        if args.debug:
            print(f"[debug] AutoConfig load failed for {base_load_src}: {exc}")
        config = None
    if num_specials is None:
        num_specials = int(getattr(config, "llopa_num_specials", 0) or 0) if config is not None else 0
    config_kwargs = {"config": config} if config is not None else {}

    # ---- base model (custom LLoPA class first, stock HF class as fallback) --
    model_family = infer_model_family(args.model_name, args.modeling_family)
    try:
        custom_mod = load_custom_modeling(args.lopa_modeling_path, model_family=model_family)
    except Exception as exc:
        print(f"[Warn] Could not load custom modeling from {args.lopa_modeling_path!r}: {exc}")
        custom_mod = None
    base = None
    if custom_mod is not None:
        cls_name = {
            "qwen3": "Qwen3ForCausalLM",
            "mistral": "MistralForCausalLM",
        }.get(model_family, "LlamaForCausalLM")
        try:
            # Use the explicit custom class so the LLoPA modeling is really in use.
            model_cls = getattr(custom_mod, cls_name)
            base = model_cls.from_pretrained(
                base_load_src,
                **_dtype_kwargs(model_cls.from_pretrained, dtype),
                **config_kwargs,
            )
        except Exception as exc:
            print(f"[Warn] Custom {cls_name} load failed ({exc}); falling back to AutoModelForCausalLM.")
            base = None
    if base is None:
        base = AutoModelForCausalLM.from_pretrained(
            base_load_src,
            trust_remote_code=False,
            **_dtype_kwargs(AutoModelForCausalLM.from_pretrained, dtype),
            **config_kwargs,
        )
    # Ensure special token availability for Mistral
    ensure_mistral_special_token(tokenizer, base)
    # Apply saved embedding layer (special tokens) if present.
    loaded_emb = load_embedding_layer(base, best_dir)

    # ---- LoRA adapter (merged into the base when possible) -----------------
    lora_path = best_dir / "lora"
    merged_lora = False
    if lora_path.is_dir() and any(lora_path.iterdir()):
        try:
            from peft import PeftModel
            peft = PeftModel.from_pretrained(base, str(lora_path))
            try:
                model = peft.merge_and_unload()
                merged_lora = True
            except Exception as exc:
                # fallback: keep PEFT wrapper without merge
                print(f"[Warn] LoRA merge failed ({exc}); using unmerged PEFT wrapper.")
                model = peft
        except Exception as exc:
            print(f"[Warn] Failed to load LoRA adapter from {lora_path} ({exc}); using base model only.")
            model = base
    else:
        model = base
    if num_specials > 0:
        if not load_llopa_specials(model, best_dir):
            print("[Warn] Failed to load LLOPA specials.")

    # device & eval
    model = model.to(device).eval()
    try:
        setattr(model, "_no_upper_attn", bool(args.no_upper_attn))
    except Exception:
        pass
    # Validate custom modeling presence if requested
    if args.force_custom_modeling:
        has_llopa = _has_active_llopa_runtime(model)
        if not has_llopa:
            raise RuntimeError("Custom LLoPA modeling not active at inference. Check --lopa_modeling_path.")
    # Allow sdpa/eager; avoid flash_attention_2 for LoPA masks.
    if args.attn_impl != "auto":
        impl = args.attn_impl
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, impl)
                inner = getattr(model, "model", None) or getattr(model, "transformer", None)
                if inner is not None and hasattr(inner, "config"):
                    setattr(inner.config, k, impl)
            except Exception:
                pass
        print(f"[infer] Forcing attn_implementation='{impl}' for all models.")
    else:
        print("[infer] Using model default attn_implementation (auto).")

    if args.debug:
        print(f"[debug] load base from: {base_load_src}")
        if loaded_emb:
            print("[debug] loaded embedding layer from best_dir")
        print(f"[debug] lora path: {lora_path} | merged={merged_lora}")
        tmpl = getattr(tokenizer, "chat_template", "") or ""
        print(f"[debug] template contains Llama3 header? {('<|start_header_id|>' in tmpl)} | Mistral? {('[INST]' in tmpl)}")
    # debug dir under best_dir
    dbg_dir = (best_dir / "debug_infer") if args.debug else None

    text = lopa_generate(
        model, tokenizer,
        system=args.system, document=args.document, question=args.question,
        task=str(args.task),
        K=int(args.prefill_layers),
        prefill_mode=str(args.prefill_mode),
        prefill_attn=str(args.prefill_attn),
        system_prefill=str(args.system_prefill),
        user_prefill=str(args.user_prefill),
        device=device,
        max_new_tokens=args.max_new_tokens, min_length=args.min_length,
        temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
        do_sample=bool(args.do_sample),
        math_force_final_hash_rule=bool(args.math_force_final_hash_rule),
        llopa_prefill=bool(args.llopa_prefill),
        no_upper_attn=bool(args.no_upper_attn),
        debug=bool(args.debug), debug_dir=dbg_dir,
    )
    print(text)


if __name__ == "__main__":
    main()