Fix Gemma4 `response_schema` regex to capture content after tool calls
#19
by qgallouedec HF Staff - opened
The Gemma4 chat template renders tool calls before content text:
from transformers import AutoTokenizer

# Load the Gemma4 tokenizer whose chat template is being examined.
tok = AutoTokenizer.from_pretrained("google/gemma-4-31B-it")

# The assistant turn below carries BOTH text content and a tool call, which is
# the combination that exposes the ordering problem.
multiply_call = {
    "type": "function",
    "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}},
}
messages = [
    {"role": "user", "content": "What is 3*4?"},
    {"role": "assistant", "content": "Let's call the tool.", "tool_calls": [multiply_call]},
]

# Render as text (no tokenization) to see where the template places the tool call.
print(tok.apply_chat_template(messages, tokenize=False))
Output:
<bos><|turn>user
What is 3*4?<turn|>
<|turn>model
<|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>
But the `response_schema` regex expects content before tool calls:
(?P<content>...)?(?P<tool_calls>...)?
So when `parse_response` runs, the `content` group matches nothing (it stops at `<|tool_call>`), the `tool_calls` group captures the tool call, and the trailing text "Let's call the tool." is lost:
tok.parse_response(response_ids)
# {'role': 'assistant', 'tool_calls': [...]}
# content is missing!
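This PR fixes it by swapping the order in the regex to match `tool_calls` first, then `content`. The `tool_calls` group now matches one or more non-greedy `<|tool_call>...<tool_call|>` blocks, and `content` captures everything after them up to `<turn|>`: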
# Before (content first, then tool_calls)
(?P<content>(?:(?!\<\|tool_call\>)(?!\<turn\|\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?
# After (tool_calls first, then content)
(?P<tool_calls>(?:\<\|tool_call\>.*?\<tool_call\|\>)+)?(?P<content>(?:(?!\<turn\|\>).)*)?
import re

# Patched response pattern, split into its four consecutive optional parts.
fixed = (
    # 1. Thinking channel: "<|channel>thought\n ... <channel|>"
    r'(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?'
    # 2. One or more "<|tool_call> ... <tool_call|>" blocks (non-greedy bodies)
    r'(?P<tool_calls>(?:\<\|tool_call\>.*?\<tool_call\|\>)+)?'
    # 3. Content: everything up to (not including) "<turn|>", possibly empty
    r'(?P<content>(?:(?!\<turn\|\>).)*)?'
    # 4. End-of-turn marker
    r'(?:\<turn\|\>)?'
)

# Sample model outputs, each followed by the values the loop below prints.
samples = (
    # Tool calls + content
    #   tool_calls -> '<|tool_call>call:multiply{a:3,b:4}<tool_call|>'
    #   content    -> "Let's call the tool."
    "<|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>",
    # Tool calls only
    #   tool_calls -> '<|tool_call>call:multiply{a:3,b:4}<tool_call|>'
    #   content    -> ''
    "<|tool_call>call:multiply{a:3,b:4}<tool_call|><turn|>",
    # Content only
    #   tool_calls -> None
    #   content    -> 'The answer is 12.'
    "The answer is 12.<turn|>",
)

for text in samples:
    # DOTALL lets "." span newlines inside any of the captured sections.
    m = re.match(fixed, text, re.DOTALL)
    for group_name in ("tool_calls", "content"):
        print(m.group(group_name))