upload model
Browse files
encoding/test_encoding_dsv32.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import copy
|
| 3 |
+
|
| 4 |
+
from encoding_dsv32 import encode_messages, parse_message_from_completion_text
|
| 5 |
+
|
| 6 |
+
with open("test_input.json", "r") as f:
|
| 7 |
+
test_dict = json.load(f)
|
| 8 |
+
messages = test_dict["messages"]
|
| 9 |
+
messages[0]["tools"] = test_dict["tools"]
|
| 10 |
+
|
| 11 |
+
with open("test_output.txt", "r") as f:
|
| 12 |
+
gold_prompt = f.read().strip()
|
| 13 |
+
|
| 14 |
+
print(messages)
|
| 15 |
+
print("=" * 60)
|
| 16 |
+
|
| 17 |
+
encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True)
|
| 18 |
+
prompt = encode_messages(messages, **encode_config)
|
| 19 |
+
print(prompt)
|
| 20 |
+
assert prompt == gold_prompt
|
| 21 |
+
print("=" * 60)
|
| 22 |
+
|
| 23 |
+
tool_call_message = messages[4]
|
| 24 |
+
tool_call_prompt = encode_messages([tool_call_message], context=messages[:4], **encode_config)
|
| 25 |
+
tool_call_message_wo_id = copy.deepcopy(tool_call_message)
|
| 26 |
+
for tool_call in tool_call_message_wo_id["tool_calls"]:
|
| 27 |
+
tool_call.pop("id")
|
| 28 |
+
parsed_tool_call_message = parse_message_from_completion_text(tool_call_prompt, thinking_mode="thinking")
|
| 29 |
+
parsed_tool_call_message.pop("content")
|
| 30 |
+
assert tool_call_message_wo_id == parsed_tool_call_message
|
| 31 |
+
|
| 32 |
+
thinking_message = messages[-6]
|
| 33 |
+
thinking_prompt = encode_messages([thinking_message], context=messages[:-6], **encode_config)
|
| 34 |
+
parsed_thinking_message = parse_message_from_completion_text(thinking_prompt, thinking_mode="thinking")
|
| 35 |
+
parsed_thinking_message.pop("tool_calls")
|
| 36 |
+
assert thinking_message == parsed_thinking_message
|
| 37 |
+
|
| 38 |
+
with open("test_input_search_wo_date.json", "r") as f:
|
| 39 |
+
search_messages = json.load(f)["messages"]
|
| 40 |
+
|
| 41 |
+
with open("test_output_search_wo_date.txt", "r") as f:
|
| 42 |
+
search_gold_prompt = f.read().strip()
|
| 43 |
+
|
| 44 |
+
search_prompt = encode_messages(search_messages, **encode_config)
|
| 45 |
+
assert search_prompt == search_gold_prompt
|
| 46 |
+
|
| 47 |
+
with open("test_input_search_w_date.json", "r") as f:
|
| 48 |
+
search_messages_w_date = json.load(f)["messages"]
|
| 49 |
+
|
| 50 |
+
with open("test_output_search_w_date.txt", "r") as f:
|
| 51 |
+
search_gold_prompt_w_date = f.read().strip()
|
| 52 |
+
|
| 53 |
+
search_prompt_w_date = encode_messages(search_messages_w_date, **encode_config)
|
| 54 |
+
with open("test_output_search_w_date_2.txt", "w") as f:
|
| 55 |
+
f.write(search_prompt_w_date)
|
| 56 |
+
assert search_prompt_w_date == search_gold_prompt_w_date
|