| |
| import argparse |
| import re |
| import pandas as pd |
| from transformers import AutoTokenizer |
|
|
| |
# Matches one *closed* ChatML block: an ``<|im_start|>`` header naming the role
# (system, user or assistant), optional spaces/tabs, a newline, then the
# shortest possible body up to the next ``<|im_end|>``.
# Group 1 is the role, group 2 is the body; DOTALL lets the body span lines.
CLOSED_PAT = re.compile(
    r"<\|im_start\|>(system|user|assistant)[ \t]*\n"
    r"(.*?)"
    r"<\|im_end\|>",
    re.DOTALL,
)
|
|
| |
# Matches an assistant block that is opened but never closed: the assistant
# header, optional spaces/tabs, a newline, and then everything up to the very
# end of the string (``\Z``). Group 1 is that remaining text.
OPEN_ASSIST_TAIL = re.compile(
    r"<\|im_start\|>assistant[ \t]*\n([\s\S]*)\Z",
    re.DOTALL,
)
|
|
def chatml_to_messages_and_tail(text: str):
    """Split ChatML text into closed messages plus an optional open assistant tail.

    Every block matched by ``CLOSED_PAT`` becomes a
    ``{"role": ..., "content": ...}`` dict, with leading and trailing newlines
    removed from the content. Whatever follows the last closed block is then
    searched with ``OPEN_ASSIST_TAIL`` for an assistant block that was never
    terminated.

    Args:
        text: Raw ChatML text. ``None`` yields ``([], None)``; any other value
            is converted with ``str()`` before parsing.

    Returns:
        A ``(messages, tail_assistant)`` tuple. ``tail_assistant`` is the
        unstripped text following the unclosed assistant header, or ``None``
        when no such block is found.
    """
    if text is None:
        return [], None

    source = str(text)
    closed_blocks = list(CLOSED_PAT.finditer(source))
    messages = [
        {"role": block.group(1), "content": block.group(2).strip("\n")}
        for block in closed_blocks
    ]

    # Only the text after the final closed block can hold an unfinished reply;
    # with no closed block at all, that is the whole input.
    remainder = source[closed_blocks[-1].end():] if closed_blocks else source
    if not remainder:
        return messages, None

    open_block = OPEN_ASSIST_TAIL.search(remainder)
    if open_block is None:
        return messages, None
    return messages, open_block.group(1)
|
|
def transform_one(raw_chatml: str, tok: AutoTokenizer) -> str:
    """Re-render one raw ChatML string through the tokenizer's chat template.

    The closed blocks parsed by ``chatml_to_messages_and_tail`` are rendered
    with ``tok.apply_chat_template(..., add_generation_prompt=False,
    tokenize=False)``. If the input ends with an unclosed assistant block,
    ``"<|im_start|>assistant\\n"`` followed by that block's text (trailing
    whitespace removed) is appended to the rendered string; nothing else is
    inserted in between.

    Args:
        raw_chatml: The ChatML text to convert.
        tok: Tokenizer object providing ``apply_chat_template``.

    Returns:
        The rendered conversation, optionally followed by the open assistant
        turn.
    """
    closed_messages, open_reply = chatml_to_messages_and_tail(raw_chatml)

    prefix = tok.apply_chat_template(
        closed_messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    if open_reply is None:
        return prefix

    # The unfinished assistant turn is re-attached by hand so that it stays
    # open (no end marker is added after it).
    return prefix + f"<|im_start|>assistant\n{open_reply.rstrip()}"
|
|
def main():
    """Command-line entry point: rewrite one ChatML column of a parquet file.

    Reads ``--input``, passes every value of ``--column`` through
    ``transform_one`` using the tokenizer loaded from ``--model``, stores the
    results in ``--out_column`` (or back into ``--column`` when that option is
    not given), and writes the frame to ``--output`` without the index.

    Raises:
        ValueError: If ``--column`` is not a column of the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="输入 parquet 路径")
    parser.add_argument("--output", required=True, help="输出 parquet 路径")
    parser.add_argument(
        "--model",
        default="deeppin/Qwen3-Reranker-8B-SequenceClassification",
        help="用于 apply_chat_template 的 tokenizer 模型名/路径",
    )
    parser.add_argument("--column", default="chosen_prompt", help="需要转换的列名")
    parser.add_argument(
        "--out_column",
        default=None,
        help="输出列名(不填则覆盖原列)",
    )
    cli = parser.parse_args()

    frame = pd.read_parquet(cli.input)
    if cli.column not in frame.columns:
        raise ValueError(f"找不到列:{cli.column}")

    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repository to run locally -- only pass --model values you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        cli.model,
        trust_remote_code=True,
        use_fast=False,
    )

    def convert(cell):
        # Bind the tokenizer so Series.apply can call this with one argument.
        return transform_one(cell, tokenizer)

    # An empty/missing --out_column means "overwrite the source column".
    target_column = cli.out_column or cli.column
    frame[target_column] = frame[cli.column].apply(convert)

    frame.to_parquet(cli.output, index=False)
    print(f"Done. Wrote: {cli.output} (transformed column: {target_column})")
|
|
# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|