| import json |
| import tiktoken |
| from tqdm import tqdm |
| from multiprocessing import Pool |
| import pandas as pd |
|
|
| |
def init_process():
    """Pool-worker initializer: build one tokenizer per process.

    tiktoken encoders are expensive to construct and are not picklable,
    so each worker creates its own and stores it in a module global.
    """
    global encoder
    encoder = tiktoken.get_encoding("cl100k_base")
|
|
def calculate_tokens(obj):
    """Count tokens for one batch-request object (called inside a worker).

    Collects the system prompt plus every user text segment, joins them
    with newlines, and encodes with the per-process ``encoder`` that
    ``init_process`` installed.

    Args:
        obj: one parsed JSONL record; expected to carry
            ``obj["body"]["messages"]`` in OpenAI chat format.

    Returns:
        int: token count, or 0 on any failure (errors are printed, not
        raised, so a single bad record cannot kill the pool).
    """
    global encoder
    parts = []

    try:
        messages = obj.get("body", {}).get("messages", [])
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            if role == "system":
                if content:
                    parts.append(content)

            elif role == "user":
                # Fix: plain-string user content is valid in the chat
                # format and was previously skipped (undercounted tokens).
                if isinstance(content, str):
                    if content:
                        parts.append(content)
                elif isinstance(content, list):
                    # Multimodal content: keep only the text segments.
                    for item in content:
                        if isinstance(item, dict) and item.get("type") == "text":
                            text = item.get("text", "")
                            if text:
                                parts.append(text)
                elif isinstance(content, dict) and content.get("type") == "text":
                    text = content.get("text", "")
                    if text:
                        parts.append(text)

        return len(encoder.encode("\n".join(parts)))

    except Exception as e:
        # Best-effort boundary: log and return 0 so the pool keeps running.
        print(f"处理错误: {e} | 数据: {obj.get('custom_id')}")
        return 0
|
|
def process_line(line):
    """Parse one JSONL line and return its id plus token count.

    Returns a ``{"custom_id": ..., "tokens": ...}`` dict on success, or
    ``None`` for malformed JSON / unexpected failures so the caller can
    simply drop bad rows.
    """
    try:
        record = json.loads(line)
        return {
            "custom_id": record.get("custom_id"),
            "tokens": calculate_tokens(record),
        }
    except json.JSONDecodeError:
        print(f"无效JSON: {line[:100]}...")
        return None
    except Exception as e:
        print(f"全局错误: {e}")
        return None
|
|
if __name__ == "__main__":
    INPUT_PATH = "/mnt/data/users/zys/proj/vlm_reasoning/request/vqa_batch_requests.jsonl"

    # Explicit encoding: the JSONL holds non-ASCII text and the platform
    # default codec is not guaranteed to be UTF-8.
    with open(INPUT_PATH, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Each worker builds its own tiktoken encoder via init_process
    # (encoders are costly to create and cannot be pickled).
    with Pool(processes=8, initializer=init_process) as pool:
        results = []
        with tqdm(total=len(lines), desc="处理进度") as pbar:
            # chunksize batches IPC traffic; imap still preserves order.
            for result in pool.imap(process_line, lines, chunksize=64):
                if result is not None:
                    results.append(result)
                pbar.update()

    df = pd.DataFrame(results)
    df.to_csv("token_results.csv", index=False)

    # Guard: if every line was invalid, the DataFrame is empty and has no
    # "tokens" column — the stats below would raise KeyError.
    if df.empty:
        print(f"统计报告:\n- 有效数据: 0/{len(lines)}")
    else:
        total_tokens = df["tokens"].sum()
        avg_tokens = df["tokens"].mean()
        print(f"统计报告:\n"
              f"- 总Token数: {total_tokens:,}\n"
              f"- 平均每条: {avg_tokens:.1f}\n"
              f"- 最大单条: {df['tokens'].max()}\n"
              f"- 有效数据: {len(df)}/{len(lines)}")