Cheeeeeeeeky committed on
Commit
4be43b8
·
verified ·
1 Parent(s): 49d8fcd

upload model

Browse files
Files changed (1) hide show
  1. encoding/test_encoding_dsv32.py +56 -0
encoding/test_encoding_dsv32.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Regression script for the ``encoding_dsv32`` prompt encoder and parser.

Run from the directory that holds the fixture files. The script:

1. encodes the conversation in ``test_input.json`` and compares the result
   with the golden prompt in ``test_output.txt``;
2. round-trips two individual messages (one carrying ``tool_calls`` and one
   picked as the "thinking" sample) through ``encode_messages`` and
   ``parse_message_from_completion_text``;
3. encodes the two search fixtures (without / with date) and compares them
   with their golden prompts.

Any mismatch raises ``AssertionError``.
"""
import json
import copy

from encoding_dsv32 import encode_messages, parse_message_from_completion_text

# All fixture files are opened with an explicit UTF-8 encoding so the
# comparison against the golden prompts does not depend on the platform's
# default locale encoding.
with open("test_input.json", "r", encoding="utf-8") as f:
    test_dict = json.load(f)
    messages = test_dict["messages"]
    # The tool definitions live at the top level of the fixture; attach them
    # to the first message before encoding.
    messages[0]["tools"] = test_dict["tools"]

with open("test_output.txt", "r", encoding="utf-8") as f:
    gold_prompt = f.read().strip()

print(messages)
print("=" * 60)

# The same encoder options are reused for every encode_messages call below.
encode_config = dict(thinking_mode="thinking", drop_thinking=True, add_default_bos_token=True)
prompt = encode_messages(messages, **encode_config)
print(prompt)
assert prompt == gold_prompt, "encoded prompt differs from test_output.txt"
print("=" * 60)

# Round trip 1: encode only the message at index 4 (with the preceding
# messages as context), parse the text back, and compare.
tool_call_message = messages[4]
tool_call_prompt = encode_messages([tool_call_message], context=messages[:4], **encode_config)
# The comparison is made without the tool-call ids, so strip them from a deep
# copy and leave the original fixture message untouched.
tool_call_message_wo_id = copy.deepcopy(tool_call_message)
for tool_call in tool_call_message_wo_id["tool_calls"]:
    tool_call.pop("id")
parsed_tool_call_message = parse_message_from_completion_text(tool_call_prompt, thinking_mode="thinking")
# "content" is dropped from the parsed message before comparing.
parsed_tool_call_message.pop("content")
assert tool_call_message_wo_id == parsed_tool_call_message

# Round trip 2: same idea for the sixth message from the end.
thinking_message = messages[-6]
thinking_prompt = encode_messages([thinking_message], context=messages[:-6], **encode_config)
parsed_thinking_message = parse_message_from_completion_text(thinking_prompt, thinking_mode="thinking")
# "tool_calls" is dropped from the parsed message before comparing.
parsed_thinking_message.pop("tool_calls")
assert thinking_message == parsed_thinking_message

# Search fixture without a date.
with open("test_input_search_wo_date.json", "r", encoding="utf-8") as f:
    search_messages = json.load(f)["messages"]

with open("test_output_search_wo_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt = f.read().strip()

search_prompt = encode_messages(search_messages, **encode_config)
assert search_prompt == search_gold_prompt, (
    "encoded prompt differs from test_output_search_wo_date.txt"
)

# Search fixture with a date.
with open("test_input_search_w_date.json", "r", encoding="utf-8") as f:
    search_messages_w_date = json.load(f)["messages"]

with open("test_output_search_w_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt_w_date = f.read().strip()

search_prompt_w_date = encode_messages(search_messages_w_date, **encode_config)
# The actual output is written to disk before the assertion, so it is still
# available for inspection when the comparison below fails.
with open("test_output_search_w_date_2.txt", "w", encoding="utf-8") as f:
    f.write(search_prompt_w_date)
assert search_prompt_w_date == search_gold_prompt_w_date, (
    "encoded prompt differs from test_output_search_w_date.txt "
    "(actual output written to test_output_search_w_date_2.txt)"
)