# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Any, Dict, List, Tuple from jinja2 import Environment, BaseLoader # —— 模板:用显式 \n 控制换行,并用 -%} / {%- 去掉多余空白 —— JINJA_PROMPT_TMPL = ( "<|im_start|>system\n" "{{ system_prompt }}<|im_end|>\n" "{% for m in msgs -%}" "<|im_start|>{{ m.role }}\n" "{% if not (m.role == 'assistant' and not include_assistant_content) -%}" "{{ m.content | render_mm_list }}" "{% endif -%}" "{% if (not (loop.last and m.role == 'assistant')) or include_assistant_content -%}" "<|im_end|>\n" "{% endif -%}" "{% endfor -%}" ) VS, VE = "<|vision_start|>", "<|vision_end|>" VP, IP = "<|video_pad|>", "<|image_pad|>" def expand_and_index_by_token_ids_new( rendered_text: str, tokens: List[int], # 遇到 VP/IP 的顺序逐个取 K tokenizer, # HF tokenizer(需含 VP/IP/VE/VS 等special tokens) target_text: str = "", # 如 "assistant\n" search_text: str = "", # 如 "" ) -> Tuple[str, List[int], List[List[int]], List[int]]: """ 返回: new_rendered_text: 扩展后的文本 all_token_id : new_rendered_text 的 token ids spans_index : 每个pad块在 all_token_id 中的索引列表(按出现顺序),如 [[100..199], [350..549], ...] tgt_index : target_text 在 all_token_id 中的索引列表(找不到返回 []) """ vs_ids = tokenizer(VS, add_special_tokens=False)["input_ids"] ve_ids = tokenizer(VE, add_special_tokens=False)["input_ids"] vp_ids = tokenizer(VP, add_special_tokens=False)["input_ids"] ip_ids = tokenizer(IP, add_special_tokens=False)["input_ids"] enc = tokenizer(rendered_text, add_special_tokens=False) base_ids = enc["input_ids"] # ---------- 1) 扫描并按出现顺序扩展 VP/IP 为 K 次,占位信息入 pad_blocks ---------- # find all VS positions and pair them with nearest VE after each VS all_ids: List[int] = [] spans_index: List[List[int]] = [] i = 0 # base_ids 扫描指针 tk_ptr = 0 # tokens(K) 指针 while True: try: vs_positions_ = base_ids[i:].index(vs_ids[0]) + i except: all_ids.extend(base_ids[i:]) break all_ids.extend(base_ids[i: vs_positions_]) i = vs_positions_ + 3 # 进行序列扩展,插入占位信息入 pad_ids pad_ids = base_ids[vs_positions_ + 1:vs_positions_ + 2] K = int(tokens[tk_ptr]) start, end = len(all_ids) + 1, len(all_ids) + 1 + K all_ids.extend(vs_ids + pad_ids * K + ve_ids) tk_ptr += 1 # 获取 每个pad token 在 all_token_id 中的索引列表(按出现顺序),如 [[100..199], [350..549], ...] #start, end = vs_positions_ + 1, vs_positions_ + 1 + K spans_index.append(list(range(start, end))) tgt_index: List[int] = [] if target_text: tgt_ids_identify = tokenizer(target_text, add_special_tokens=False)["input_ids"] i = 0 # base_ids 扫描指针 while i < len(all_ids): tgt_positions_ = all_ids[i:].index(tgt_ids_identify[0]) + i if all_ids[tgt_positions_+len(tgt_ids_identify)-1] == tgt_ids_identify[-1]: tgt_index = list(range(tgt_positions_+len(tgt_ids_identify), len(all_ids))) break else: i = tgt_positions_ + 1 search_index: List[int] = [] if search_text: search_ids_identify = tokenizer(search_text, add_special_tokens=False)["input_ids"] i = 0 # base_ids 扫描指针 while i < len(all_ids): search_positions_ = all_ids[i:].index(search_ids_identify[0]) + i if all_ids[search_positions_:search_positions_+len(search_ids_identify)] == search_ids_identify: search_index = list(range(search_positions_, search_positions_+len(search_ids_identify))) break else: i = search_positions_ + 1 return all_ids, spans_index, tgt_index, search_index def _extract_system_prompt(messages: List[Dict[str, Any]], default_system: str) -> str: for m in messages: if m.get("role") == "system": c = m.get("content", "") if isinstance(c, str): return c if isinstance(c, list): texts = [it.get("text", "") for it in c if isinstance(it, dict) and it.get("type") == "text"] if texts: return "".join(texts) return default_system def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: norm: List[Dict[str, Any]] = [] for m in messages: role = m.get("role") if role == "system": continue c = m.get("content", "") if isinstance(c, str): items = [{"type": "text", "text": c}] elif isinstance(c, list): items = c else: items = [] norm.append({"role": role, "content": items}) return norm def render_qwenvl_prompt( messages: List[Dict[str, Any]], default_system: str = "You are a helpful assistant.", include_assistant_content: bool = False, # 关键参数:是否渲染 assistant 文本 force_video_pad: bool = False, ) -> str: system_prompt = _extract_system_prompt(messages, default_system) msgs = _normalize_messages(messages) def _render_mm_list(items: Any) -> str: if isinstance(items, str): return items if not isinstance(items, list): return "" parts: List[str] = [] for it in items: if not isinstance(it, dict): continue t = it.get("type") if t == "text": parts.append(it.get("text", "")) elif t == "image": if force_video_pad: parts.append("<|vision_start|><|image_pad|><|vision_end|>") else: parts.append("<|vision_start|><|video_pad|><|vision_end|>") elif t == "video": parts.append("<|vision_start|><|video_pad|><|vision_end|>") # 其他模态可在这里扩展 return "".join(parts) env = Environment( loader=BaseLoader(), autoescape=False, trim_blocks=True, # 去掉块结束后的换行 lstrip_blocks=True, # 去掉块起始前的空白 newline_sequence="\n", keep_trailing_newline=False, ) env.filters["render_mm_list"] = _render_mm_list template = env.from_string(JINJA_PROMPT_TMPL) return template.render( system_prompt=system_prompt, msgs=msgs, include_assistant_content=include_assistant_content, )