Lance / data /system_prompt_render.py
Nayefleb's picture
Upload folder using huggingface_hub
8b306b3 verified
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from jinja2 import Environment, BaseLoader
# —— 模板:用显式 \n 控制换行,并用 -%} / {%- 去掉多余空白 ——
JINJA_PROMPT_TMPL = (
"<|im_start|>system\n"
"{{ system_prompt }}<|im_end|>\n"
"{% for m in msgs -%}"
"<|im_start|>{{ m.role }}\n"
"{% if not (m.role == 'assistant' and not include_assistant_content) -%}"
"{{ m.content | render_mm_list }}"
"{% endif -%}"
"{% if (not (loop.last and m.role == 'assistant')) or include_assistant_content -%}"
"<|im_end|>\n"
"{% endif -%}"
"{% endfor -%}"
)
VS, VE = "<|vision_start|>", "<|vision_end|>"
VP, IP = "<|video_pad|>", "<|image_pad|>"
def expand_and_index_by_token_ids_new(
rendered_text: str,
tokens: List[int], # 遇到 VP/IP 的顺序逐个取 K
tokenizer, # HF tokenizer(需含 VP/IP/VE/VS 等special tokens)
target_text: str = "", # 如 "assistant\n"
search_text: str = "", # 如 ""
) -> Tuple[str, List[int], List[List[int]], List[int]]:
"""
返回:
new_rendered_text: 扩展后的文本
all_token_id : new_rendered_text 的 token ids
spans_index : 每个pad块在 all_token_id 中的索引列表(按出现顺序),如 [[100..199], [350..549], ...]
tgt_index : target_text 在 all_token_id 中的索引列表(找不到返回 [])
"""
vs_ids = tokenizer(VS, add_special_tokens=False)["input_ids"]
ve_ids = tokenizer(VE, add_special_tokens=False)["input_ids"]
vp_ids = tokenizer(VP, add_special_tokens=False)["input_ids"]
ip_ids = tokenizer(IP, add_special_tokens=False)["input_ids"]
enc = tokenizer(rendered_text, add_special_tokens=False)
base_ids = enc["input_ids"]
# ---------- 1) 扫描并按出现顺序扩展 VP/IP 为 K 次,占位信息入 pad_blocks ----------
# find all VS positions and pair them with nearest VE after each VS
all_ids: List[int] = []
spans_index: List[List[int]] = []
i = 0 # base_ids 扫描指针
tk_ptr = 0 # tokens(K) 指针
while True:
try:
vs_positions_ = base_ids[i:].index(vs_ids[0]) + i
except:
all_ids.extend(base_ids[i:])
break
all_ids.extend(base_ids[i: vs_positions_])
i = vs_positions_ + 3
# 进行序列扩展,插入占位信息入 pad_ids
pad_ids = base_ids[vs_positions_ + 1:vs_positions_ + 2]
K = int(tokens[tk_ptr])
start, end = len(all_ids) + 1, len(all_ids) + 1 + K
all_ids.extend(vs_ids + pad_ids * K + ve_ids)
tk_ptr += 1
# 获取 每个pad token 在 all_token_id 中的索引列表(按出现顺序),如 [[100..199], [350..549], ...]
#start, end = vs_positions_ + 1, vs_positions_ + 1 + K
spans_index.append(list(range(start, end)))
tgt_index: List[int] = []
if target_text:
tgt_ids_identify = tokenizer(target_text, add_special_tokens=False)["input_ids"]
i = 0 # base_ids 扫描指针
while i < len(all_ids):
tgt_positions_ = all_ids[i:].index(tgt_ids_identify[0]) + i
if all_ids[tgt_positions_+len(tgt_ids_identify)-1] == tgt_ids_identify[-1]:
tgt_index = list(range(tgt_positions_+len(tgt_ids_identify), len(all_ids)))
break
else:
i = tgt_positions_ + 1
search_index: List[int] = []
if search_text:
search_ids_identify = tokenizer(search_text, add_special_tokens=False)["input_ids"]
i = 0 # base_ids 扫描指针
while i < len(all_ids):
search_positions_ = all_ids[i:].index(search_ids_identify[0]) + i
if all_ids[search_positions_:search_positions_+len(search_ids_identify)] == search_ids_identify:
search_index = list(range(search_positions_, search_positions_+len(search_ids_identify)))
break
else:
i = search_positions_ + 1
return all_ids, spans_index, tgt_index, search_index
def _extract_system_prompt(messages: List[Dict[str, Any]], default_system: str) -> str:
for m in messages:
if m.get("role") == "system":
c = m.get("content", "")
if isinstance(c, str):
return c
if isinstance(c, list):
texts = [it.get("text", "") for it in c if isinstance(it, dict) and it.get("type") == "text"]
if texts:
return "".join(texts)
return default_system
def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
norm: List[Dict[str, Any]] = []
for m in messages:
role = m.get("role")
if role == "system":
continue
c = m.get("content", "")
if isinstance(c, str):
items = [{"type": "text", "text": c}]
elif isinstance(c, list):
items = c
else:
items = []
norm.append({"role": role, "content": items})
return norm
def render_qwenvl_prompt(
messages: List[Dict[str, Any]],
default_system: str = "You are a helpful assistant.",
include_assistant_content: bool = False, # 关键参数:是否渲染 assistant 文本
force_video_pad: bool = False,
) -> str:
system_prompt = _extract_system_prompt(messages, default_system)
msgs = _normalize_messages(messages)
def _render_mm_list(items: Any) -> str:
if isinstance(items, str):
return items
if not isinstance(items, list):
return ""
parts: List[str] = []
for it in items:
if not isinstance(it, dict):
continue
t = it.get("type")
if t == "text":
parts.append(it.get("text", ""))
elif t == "image":
if force_video_pad:
parts.append("<|vision_start|><|image_pad|><|vision_end|>")
else:
parts.append("<|vision_start|><|video_pad|><|vision_end|>")
elif t == "video":
parts.append("<|vision_start|><|video_pad|><|vision_end|>")
# 其他模态可在这里扩展
return "".join(parts)
env = Environment(
loader=BaseLoader(),
autoescape=False,
trim_blocks=True, # 去掉块结束后的换行
lstrip_blocks=True, # 去掉块起始前的空白
newline_sequence="\n",
keep_trailing_newline=False,
)
env.filters["render_mm_list"] = _render_mm_list
template = env.from_string(JINJA_PROMPT_TMPL)
return template.render(
system_prompt=system_prompt,
msgs=msgs,
include_assistant_content=include_assistant_content,
)