GeeeekExplorer commited on
Commit
24d73e8
·
verified ·
1 Parent(s): 2b88d47

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference/encoding_dsv4.py +744 -0
  2. inference/generate.py +4 -0
inference/encoding_dsv4.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek-V4 Encoding
3
+
4
+ A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
5
+ with tool calling, thinking mode, and quick instruction task support.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Union, Optional, Tuple
9
+ import copy
10
+ import json
11
+ import re
12
+
13
+ # ============================================================
14
+ # Special Tokens
15
+ # ============================================================
16
+
17
+ bos_token: str = "<|begin▁of▁sentence|>"
18
+ eos_token: str = "<|end▁of▁sentence|>"
19
+ thinking_start_token: str = "<think>"
20
+ thinking_end_token: str = "</think>"
21
+ dsml_token: str = "|DSML|"
22
+
23
+ USER_SP_TOKEN = "<|User|>"
24
+ ASSISTANT_SP_TOKEN = "<|Assistant|>"
25
+ LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
26
+
27
+ # Task special tokens for internal classification tasks
28
+ DS_TASK_SP_TOKENS = {
29
+ "action": "<|action|>",
30
+ "query": "<|query|>",
31
+ "authority": "<|authority|>",
32
+ "domain": "<|domain|>",
33
+ "title": "<|title|>",
34
+ "read_url": "<|read_url|>",
35
+ }
36
+ VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
37
+
38
+ # ============================================================
39
+ # Templates
40
+ # ============================================================
41
+
42
+ system_msg_template: str = "{content}"
43
+ user_msg_template: str = "{content}"
44
+ latest_reminder_msg_template: str = "{content}"
45
+ assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
46
+ assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
47
+ thinking_template: str = "{reasoning_content}"
48
+
49
+ response_format_template: str = (
50
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
51
+ )
52
+ tool_call_template: str = (
53
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
54
+ )
55
+ tool_calls_template = (
56
+ "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
57
+ )
58
+ tool_calls_block_name: str = "tool_calls"
59
+
60
+ tool_output_template: str = (
61
+ "<tool_result>{content}</tool_result>"
62
+ )
63
+
64
+ REASONING_EFFORT_MAX = (
65
+ "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
66
+ "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
67
+ "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
68
+ )
69
+
70
+ TOOLS_TEMPLATE = """## Tools
71
+
72
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
73
+
74
+ <{dsml_token}tool_calls>
75
+ <{dsml_token}invoke name="$TOOL_NAME">
76
+ <{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
77
+ ...
78
+ </{dsml_token}invoke>
79
+ <{dsml_token}invoke name="$TOOL_NAME2">
80
+ ...
81
+ </{dsml_token}invoke>
82
+ </{dsml_token}tool_calls>
83
+
84
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
85
+
86
+ If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
87
+
88
+ Otherwise, output directly after {thinking_end_token} with tool calls or final response.
89
+
90
+ ### Available Tool Schemas
91
+
92
+ {tool_schemas}
93
+
94
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
95
+ """
96
+
97
+ # ============================================================
98
+ # Utility Functions
99
+ # ============================================================
100
+
101
+ def to_json(value: Any) -> str:
102
+ """Serialize a value to JSON string."""
103
+ try:
104
+ return json.dumps(value, ensure_ascii=False)
105
+ except:
106
+ return json.dumps(value, ensure_ascii=True)
107
+
108
+
109
+ def tools_from_openai_format(tools):
110
+ """Extract function definitions from OpenAI-format tool list."""
111
+ return [tool["function"] for tool in tools]
112
+
113
+
114
+ def tool_calls_from_openai_format(tool_calls):
115
+ """Convert OpenAI-format tool calls to internal format."""
116
+ return [
117
+ {
118
+ "name": tool_call["function"]["name"],
119
+ "arguments": tool_call["function"]["arguments"],
120
+ }
121
+ for tool_call in tool_calls
122
+ ]
123
+
124
+
125
+ def tool_calls_to_openai_format(tool_calls):
126
+ """Convert internal tool calls to OpenAI format."""
127
+ return [
128
+ {
129
+ "type": "function",
130
+ "function": {
131
+ "name": tool_call["name"],
132
+ "arguments": tool_call["arguments"],
133
+ }
134
+ }
135
+ for tool_call in tool_calls
136
+ ]
137
+
138
+
139
+ def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
140
+ """
141
+ Encode tool call arguments into DSML parameter format.
142
+
143
+ Args:
144
+ tool_call: Dict with "name" and "arguments" (JSON string) keys.
145
+
146
+ Returns:
147
+ DSML-formatted parameter string.
148
+ """
149
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
150
+ P_dsml_strs = []
151
+
152
+ try:
153
+ arguments = json.loads(tool_call["arguments"])
154
+ except Exception as err:
155
+ arguments = {"arguments": tool_call["arguments"]}
156
+
157
+ for k, v in arguments.items():
158
+ p_dsml_str = p_dsml_template.format(
159
+ dsml_token=dsml_token,
160
+ key=k,
161
+ is_str="true" if isinstance(v, str) else "false",
162
+ value=v if isinstance(v, str) else to_json(v),
163
+ )
164
+ P_dsml_strs.append(p_dsml_str)
165
+
166
+ return "\n".join(P_dsml_strs)
167
+
168
+
169
+ def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
170
+ """
171
+ Decode DSML parameters back to a tool call dict.
172
+
173
+ Args:
174
+ tool_name: Name of the tool.
175
+ tool_args: Dict mapping param_name -> (value, is_string_flag).
176
+
177
+ Returns:
178
+ Dict with "name" and "arguments" (JSON string) keys.
179
+ """
180
+ def _decode_value(key: str, value: str, string: str):
181
+ if string == "true":
182
+ value = to_json(value)
183
+ return f"{to_json(key)}: {value}"
184
+
185
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
186
+ return dict(name=tool_name, arguments=tool_args_json)
187
+
188
+
189
+ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
190
+ """
191
+ Render tool schemas into the system prompt format.
192
+
193
+ Args:
194
+ tools: List of tool schema dicts (each with name, description, parameters).
195
+
196
+ Returns:
197
+ Formatted tools section string.
198
+ """
199
+ tools_json = [to_json(t) for t in tools]
200
+
201
+ return TOOLS_TEMPLATE.format(
202
+ tool_schemas="\n".join(tools_json),
203
+ dsml_token=dsml_token,
204
+ thinking_start_token=thinking_start_token,
205
+ thinking_end_token=thinking_end_token,
206
+ )
207
+
208
+
209
+ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
210
+ """Find the index of the last user/developer message."""
211
+ last_user_index = -1
212
+ for idx in range(len(messages) - 1, -1, -1):
213
+ if messages[idx].get("role") in ["user", "developer"]:
214
+ last_user_index = idx
215
+ break
216
+ return last_user_index
217
+
218
+
219
+ # ============================================================
220
+ # Message Rendering
221
+ # ============================================================
222
+
223
+ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
224
+ """
225
+ Render a single message at the given index into its encoded string form.
226
+
227
+ This is the core function that converts each message in the conversation
228
+ into the DeepSeek-V4 format.
229
+
230
+ Args:
231
+ index: Index of the message to render.
232
+ messages: Full list of messages in the conversation.
233
+ thinking_mode: Either "chat" or "thinking".
234
+ drop_thinking: Whether to drop reasoning content from earlier turns.
235
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
236
+
237
+ Returns:
238
+ Encoded string for this message.
239
+ """
240
+ assert 0 <= index < len(messages)
241
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
242
+
243
+ prompt = ""
244
+ msg = messages[index]
245
+ last_user_idx = find_last_user_index(messages)
246
+
247
+ role = msg.get("role")
248
+ content = msg.get("content")
249
+ tools = msg.get("tools")
250
+ response_format = msg.get("response_format")
251
+ tool_calls = msg.get("tool_calls")
252
+ reasoning_content = msg.get("reasoning_content")
253
+ wo_eos = msg.get("wo_eos", False)
254
+
255
+ if tools:
256
+ tools = tools_from_openai_format(tools)
257
+ if tool_calls:
258
+ tool_calls = tool_calls_from_openai_format(tool_calls)
259
+
260
+ # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
261
+ assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
262
+ if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
263
+ prompt += REASONING_EFFORT_MAX
264
+
265
+ if role == "system":
266
+ prompt += system_msg_template.format(content=content or "")
267
+ if tools:
268
+ prompt += "\n\n" + render_tools(tools)
269
+ if response_format:
270
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
271
+
272
+ elif role == "developer":
273
+ assert content, f"Invalid message for role `{role}`: {msg}"
274
+
275
+ content_developer = USER_SP_TOKEN
276
+ content_developer += content
277
+
278
+ if tools:
279
+ content_developer += "\n\n" + render_tools(tools)
280
+ if response_format:
281
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
282
+
283
+ prompt += user_msg_template.format(content=content_developer)
284
+
285
+ elif role == "user":
286
+ prompt += USER_SP_TOKEN
287
+
288
+ # Handle content blocks (tool results mixed with text)
289
+ content_blocks = msg.get("content_blocks")
290
+ if content_blocks:
291
+ parts = []
292
+ for block in content_blocks:
293
+ block_type = block.get("type")
294
+ if block_type == "text":
295
+ parts.append(block.get("text", ""))
296
+ elif block_type == "tool_result":
297
+ tool_content = block.get("content", "")
298
+ if isinstance(tool_content, list):
299
+ text_parts = []
300
+ for b in tool_content:
301
+ if b.get("type") == "text":
302
+ text_parts.append(b.get("text", ""))
303
+ else:
304
+ text_parts.append(f"[Unsupported {b.get('type')}]")
305
+ tool_content = "\n\n".join(text_parts)
306
+ parts.append(tool_output_template.format(content=tool_content))
307
+ else:
308
+ parts.append(f"[Unsupported {block_type}]")
309
+ prompt += "\n\n".join(parts)
310
+ else:
311
+ prompt += content or ""
312
+
313
+ elif role == "latest_reminder":
314
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
315
+
316
+ elif role == "tool":
317
+ raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
318
+
319
+ elif role == "assistant":
320
+ thinking_part = ""
321
+ tc_content = ""
322
+
323
+ if tool_calls:
324
+ tc_list = [
325
+ tool_call_template.format(
326
+ dsml_token=dsml_token,
327
+ name=tc.get("name"),
328
+ arguments=encode_arguments_to_dsml(tc)
329
+ )
330
+ for tc in tool_calls
331
+ ]
332
+ tc_content += '\n\n' + tool_calls_template.format(
333
+ dsml_token=dsml_token,
334
+ tool_calls="\n".join(tc_list),
335
+ tc_block_name=tool_calls_block_name,
336
+ )
337
+
338
+ summary_content = content or ""
339
+ rc = reasoning_content or ""
340
+
341
+ # Check if previous message has a task - if so, this is a task output (no thinking)
342
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
343
+
344
+ if thinking_mode == "thinking" and not prev_has_task:
345
+ if not drop_thinking or index > last_user_idx:
346
+ thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
347
+ else:
348
+ thinking_part = ""
349
+
350
+ if wo_eos:
351
+ prompt += assistant_msg_wo_eos_template.format(
352
+ reasoning=thinking_part,
353
+ content=summary_content,
354
+ tool_calls=tc_content,
355
+ )
356
+ else:
357
+ prompt += assistant_msg_template.format(
358
+ reasoning=thinking_part,
359
+ content=summary_content,
360
+ tool_calls=tc_content,
361
+ )
362
+ else:
363
+ raise NotImplementedError(f"Unknown role: {role}")
364
+
365
+ # Append transition tokens based on what follows
366
+ if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
367
+ return prompt
368
+
369
+ task = messages[index].get("task")
370
+ if task is not None:
371
+ # Task special token for internal classification tasks
372
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
373
+ task_sp_token = DS_TASK_SP_TOKENS[task]
374
+
375
+ if task != "action":
376
+ # Non-action tasks: append task sp token directly after the message
377
+ prompt += task_sp_token
378
+ else:
379
+ # Action task: append Assistant + thinking token + action sp token
380
+ prompt += ASSISTANT_SP_TOKEN
381
+ prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
382
+ prompt += task_sp_token
383
+
384
+ elif messages[index].get("role") in ["user", "developer"]:
385
+ # Normal generation: append Assistant + thinking token
386
+ prompt += ASSISTANT_SP_TOKEN
387
+ if not drop_thinking and thinking_mode == "thinking":
388
+ prompt += thinking_start_token
389
+ elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
390
+ prompt += thinking_start_token
391
+ else:
392
+ prompt += thinking_end_token
393
+
394
+ return prompt
395
+
396
+
397
+ # ============================================================
398
+ # Preprocessing
399
+ # ============================================================
400
+
401
+ def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
402
+ """
403
+ Merge tool messages into the preceding user message using content_blocks format.
404
+
405
+ DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
406
+ are encoded as <tool_result> blocks within user messages.
407
+
408
+ This function converts a standard OpenAI-format conversation (with separate
409
+ "tool" role messages) into V4 format where tool results are merged into
410
+ user messages.
411
+
412
+ Args:
413
+ messages: List of message dicts in OpenAI format.
414
+
415
+ Returns:
416
+ Processed message list with tool messages merged into user messages.
417
+ """
418
+ merged: List[Dict[str, Any]] = []
419
+
420
+ for msg in messages:
421
+ msg = copy.deepcopy(msg)
422
+ role = msg.get("role")
423
+
424
+ if role == "tool":
425
+ # Convert tool message to a user message with tool_result block
426
+ tool_block = {
427
+ "type": "tool_result",
428
+ "tool_use_id": msg.get("tool_call_id", ""),
429
+ "content": msg.get("content", ""),
430
+ }
431
+ # Merge into previous message if it's already a user (merged tool)
432
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
433
+ merged[-1]["content_blocks"].append(tool_block)
434
+ else:
435
+ merged.append({
436
+ "role": "user",
437
+ "content_blocks": [tool_block],
438
+ })
439
+ elif role == "user":
440
+ text_block = {"type": "text", "text": msg.get("content", "")}
441
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
442
+ merged[-1]["content_blocks"].append(text_block)
443
+ else:
444
+ new_msg = {
445
+ "role": "user",
446
+ "content": msg.get("content", ""),
447
+ "content_blocks": [text_block],
448
+ }
449
+ # Preserve extra fields (task, wo_eos, mask, etc.)
450
+ for key in ("task", "wo_eos", "mask"):
451
+ if key in msg:
452
+ new_msg[key] = msg[key]
453
+ merged.append(new_msg)
454
+ else:
455
+ merged.append(msg)
456
+
457
+ return merged
458
+
459
+
460
+ def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
461
+ """
462
+ Sort tool_result blocks within user messages by the order of tool_calls
463
+ in the preceding assistant message.
464
+
465
+ Args:
466
+ messages: Preprocessed message list (after merge_tool_messages).
467
+
468
+ Returns:
469
+ Message list with sorted tool result blocks.
470
+ """
471
+ last_tool_call_order: Dict[str, int] = {}
472
+
473
+ for msg in messages:
474
+ role = msg.get("role")
475
+ if role == "assistant" and msg.get("tool_calls"):
476
+ last_tool_call_order = {}
477
+ for idx, tc in enumerate(msg["tool_calls"]):
478
+ tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
479
+ if tc_id:
480
+ last_tool_call_order[tc_id] = idx
481
+
482
+ elif role == "user" and msg.get("content_blocks"):
483
+ tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
484
+ if len(tool_blocks) > 1 and last_tool_call_order:
485
+ sorted_blocks = sorted(
486
+ tool_blocks,
487
+ key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
488
+ )
489
+ sorted_idx = 0
490
+ new_blocks = []
491
+ for block in msg["content_blocks"]:
492
+ if block.get("type") == "tool_result":
493
+ new_blocks.append(sorted_blocks[sorted_idx])
494
+ sorted_idx += 1
495
+ else:
496
+ new_blocks.append(block)
497
+ msg["content_blocks"] = new_blocks
498
+
499
+ return messages
500
+
501
+
502
+ # ============================================================
503
+ # Main Encoding Function
504
+ # ============================================================
505
+
506
+ def encode_messages(
507
+ messages: List[Dict[str, Any]],
508
+ thinking_mode: str,
509
+ context: Optional[List[Dict[str, Any]]] = None,
510
+ drop_thinking: bool = True,
511
+ add_default_bos_token: bool = True,
512
+ reasoning_effort: Optional[str] = None,
513
+ ) -> str:
514
+ """
515
+ Encode a list of messages into the DeepSeek-V4 prompt format.
516
+
517
+ This is the main entry point for encoding conversations. It handles:
518
+ - BOS token insertion
519
+ - Thinking mode with optional reasoning content dropping
520
+ - Tool message merging into user messages
521
+ - Multi-turn conversation context
522
+
523
+ Args:
524
+ messages: List of message dicts to encode.
525
+ thinking_mode: Either "chat" or "thinking".
526
+ context: Optional preceding context messages (already encoded prefix).
527
+ drop_thinking: If True, drop reasoning_content from earlier assistant turns
528
+ (only keep reasoning for messages after the last user message).
529
+ add_default_bos_token: Whether to prepend BOS token at conversation start.
530
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
531
+
532
+ Returns:
533
+ The encoded prompt string.
534
+ """
535
+ context = context if context else []
536
+
537
+ # Preprocess: merge tool messages and sort tool results
538
+ messages = merge_tool_messages(messages)
539
+ messages = sort_tool_results_by_call_order(context + messages)[len(context):]
540
+ if context:
541
+ context = merge_tool_messages(context)
542
+ context = sort_tool_results_by_call_order(context)
543
+
544
+ full_messages = context + messages
545
+
546
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
547
+
548
+ # Resolve drop_thinking: if any message has tools defined, don't drop thinking
549
+ effective_drop_thinking = drop_thinking
550
+ if any(m.get("tools") for m in full_messages):
551
+ effective_drop_thinking = False
552
+
553
+ if thinking_mode == "thinking" and effective_drop_thinking:
554
+ full_messages = _drop_thinking_messages(full_messages)
555
+ # After dropping, recalculate how many messages to render
556
+ # (context may have shrunk too)
557
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
558
+ context_len = len(full_messages) - num_to_render
559
+ else:
560
+ num_to_render = len(messages)
561
+ context_len = len(context)
562
+
563
+ for idx in range(num_to_render):
564
+ prompt += render_message(
565
+ idx + context_len,
566
+ full_messages,
567
+ thinking_mode=thinking_mode,
568
+ drop_thinking=effective_drop_thinking,
569
+ reasoning_effort=reasoning_effort,
570
+ )
571
+
572
+ return prompt
573
+
574
+
575
+ def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
576
+ """
577
+ Drop reasoning_content and non-essential messages before the last user message.
578
+
579
+ Behavior:
580
+ - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
581
+ - Messages at or after the last user index are always kept.
582
+ - Assistant messages before the last user get reasoning_content removed.
583
+ - Developer messages before the last user are dropped entirely.
584
+ """
585
+ last_user_idx = find_last_user_index(messages)
586
+ result = []
587
+ keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
588
+
589
+ for idx, msg in enumerate(messages):
590
+ role = msg.get("role")
591
+ if role in keep_roles or idx >= last_user_idx:
592
+ result.append(msg)
593
+ elif role == "assistant":
594
+ msg = copy.copy(msg)
595
+ msg.pop("reasoning_content", None)
596
+ result.append(msg)
597
+ # developer and other roles before last_user_idx are dropped
598
+
599
+ return result
600
+
601
+
602
+ # ============================================================
603
+ # Parsing (Decoding model output)
604
+ # ============================================================
605
+
606
+ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
607
+ """
608
+ Read text from index until one of the stop strings is found.
609
+
610
+ Returns:
611
+ Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
612
+ """
613
+ min_pos = len(text)
614
+ matched_stop = None
615
+
616
+ for s in stop:
617
+ pos = text.find(s, index)
618
+ if pos != -1 and pos < min_pos:
619
+ min_pos = pos
620
+ matched_stop = s
621
+
622
+ if matched_stop:
623
+ content = text[index:min_pos]
624
+ return min_pos + len(matched_stop), content, matched_stop
625
+ else:
626
+ content = text[index:]
627
+ return len(text), content, None
628
+
629
+
630
+ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
631
+ """
632
+ Parse DSML tool calls from text starting at the given index.
633
+
634
+ Args:
635
+ index: Starting position in text.
636
+ text: The full text to parse.
637
+
638
+ Returns:
639
+ Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
640
+ Each tool call dict has "name" and "arguments" keys.
641
+ """
642
+ tool_calls: List[Dict[str, Any]] = []
643
+ stop_token = None
644
+ tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
645
+
646
+ while index < len(text):
647
+ index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
648
+ if _ != ">\n":
649
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
650
+
651
+ if stop_token == tool_calls_end_token:
652
+ break
653
+
654
+ if stop_token is None:
655
+ raise ValueError("Missing special token in tool calls")
656
+
657
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
658
+
659
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
660
+ if len(p_tool_name) != 1:
661
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
662
+ tool_name = p_tool_name[0]
663
+
664
+ tool_args: Dict[str, Tuple[str, str]] = {}
665
+ while stop_token == f"<{dsml_token}parameter":
666
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
667
+
668
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
669
+ if len(param_kv) != 1:
670
+ raise ValueError(f"Parameter format error: '{param_content}'")
671
+ param_name, string, param_value = param_kv[0]
672
+
673
+ if param_name in tool_args:
674
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
675
+ tool_args[param_name] = (param_value, string)
676
+
677
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
678
+ if content != ">\n":
679
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
680
+
681
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
682
+ tool_calls.append(tool_call)
683
+
684
+ return index, stop_token, tool_calls
685
+
686
+
687
+ def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
688
+ """
689
+ Parse a model completion text into a structured assistant message.
690
+
691
+ This function takes the raw text output from the model (a single assistant turn)
692
+ and extracts:
693
+ - reasoning_content (thinking block)
694
+ - content (summary/response)
695
+ - tool_calls (if any)
696
+
697
+ NOTE: This function is designed to parse only correctly formatted strings and
698
+ will raise ValueError for malformed output.
699
+
700
+ Args:
701
+ text: The raw completion text (including EOS token).
702
+ thinking_mode: Either "chat" or "thinking".
703
+
704
+ Returns:
705
+ Dict with keys: "role", "content", "reasoning_content", "tool_calls".
706
+ tool_calls are in OpenAI format.
707
+ """
708
+ summary_content, reasoning_content, tool_calls = "", "", []
709
+ index, stop_token = 0, None
710
+ tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
711
+
712
+ is_thinking = thinking_mode == "thinking"
713
+ is_tool_calling = False
714
+
715
+ if is_thinking:
716
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
717
+ reasoning_content = content_delta
718
+ assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
719
+
720
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
721
+ summary_content = content_delta
722
+ if stop_token == tool_calls_start_token:
723
+ is_tool_calling = True
724
+ else:
725
+ assert stop_token == eos_token, "Invalid format: missing EOS token"
726
+
727
+ if is_tool_calling:
728
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
729
+
730
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
731
+ assert not tool_ends_text, "Unexpected content after tool calls"
732
+
733
+ assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
734
+
735
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
736
+ assert sp_token not in summary_content and sp_token not in reasoning_content, \
737
+ f"Unexpected special token '{sp_token}' in content"
738
+
739
+ return {
740
+ "role": "assistant",
741
+ "content": summary_content,
742
+ "reasoning_content": reasoning_content,
743
+ "tool_calls": tool_calls_to_openai_format(tool_calls)
744
+ }
inference/generate.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import json
 
3
  from argparse import ArgumentParser
4
  from typing import List
5
 
@@ -9,6 +10,9 @@ from transformers import AutoTokenizer
9
  from safetensors.torch import load_model
10
 
11
  from model import Transformer, ModelArgs
 
 
 
12
  from encoding_dsv4 import encode_messages, parse_message_from_completion_text
13
 
14
 
 
1
  import os
2
  import json
3
+ import sys
4
  from argparse import ArgumentParser
5
  from typing import List
6
 
 
10
  from safetensors.torch import load_model
11
 
12
  from model import Transformer, ModelArgs
13
+ current_dir = os.path.dirname(os.path.abspath(__file__))
14
+ encoding_dir = os.path.join(current_dir, '../encoding_dir')
15
+ sys.path.insert(0, os.path.abspath(encoding_dir))
16
  from encoding_dsv4 import encode_messages, parse_message_from_completion_text
17
 
18