Fix Gemma4 `response_schema` regex to capture content after tool calls

#19, opened by qgallouedec (HF Staff)

The Gemma4 chat template renders an assistant message's tool calls before its content text:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-4-31B-it")

messages = [
    {"role": "user", "content": "What is 3*4?"},
    {"role": "assistant", "content": "Let's call the tool.", "tool_calls": [
        {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
    ]},
]
print(tok.apply_chat_template(messages, tokenize=False))

Output:

<bos><|turn>user
What is 3*4?<turn|>
<|turn>model
<|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>

But the `response_schema` regex expects content before tool calls:

(?P<content>...)?(?P<tool_calls>...)?

So when `parse_response` runs, the `content` group matches nothing (the model turn starts with `<|tool_call>`, where the group's negative lookahead stops it), the `tool_calls` group captures the tool call, and the trailing text `Let's call the tool.` is never captured:

# response_ids: the generated token IDs for the model turn shown above
tok.parse_response(response_ids)
# {'role': 'assistant', 'tool_calls': [...]}
# content is missing!
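The failure can be reproduced without loading the tokenizer by running the original content-first fragment directly with Python's `re` module on the rendered model turn. This is a minimal sketch that uses only that fragment (the "Before" pattern of this PR), not the full `response_schema`, which also handles the thinking channel and the trailing `<turn|>`:

```python
import re

# Content-first fragment of the original pattern (the "Before" line of this PR)
before = r'(?P<content>(?:(?!\<\|tool_call\>)(?!\<turn\|\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?'

text = "<|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>"
m = re.match(before, text, re.DOTALL)

print(m.group("content"))     # None: content cannot start at <|tool_call>
print(m.group("tool_calls"))  # <|tool_call>call:multiply{a:3,b:4}<tool_call|>
print(text[m.end():])         # Let's call the tool.<turn|>  (left over, never captured)
```

The match itself succeeds, which is why the bug is silent: the trailing text simply falls outside every named group.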

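This PR fixes it by swapping the order of the groups so that `tool_calls` is matched first and `content` second. The `tool_calls` group also becomes a repeated, non-greedy match of `<|tool_call>...<tool_call|>` blocks: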

# Before (content first, then tool_calls)
(?P<content>(?:(?!\<\|tool_call\>)(?!\<turn\|\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?

# After (tool_calls first, then content)
(?P<tool_calls>(?:\<\|tool_call\>.*?\<tool_call\|\>)+)?(?P<content>(?:(?!\<turn\|\>).)*)?
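
Checking the full fixed pattern (which also handles the optional thinking channel and the trailing `<turn|>`) on the three cases:

import re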

fixed = r'(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<tool_calls>(?:\<\|tool_call\>.*?\<tool_call\|\>)+)?(?P<content>(?:(?!\<turn\|\>).)*)?(?:\<turn\|\>)?'

# Tool calls + content
text = "<|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>"
m = re.match(fixed, text, re.DOTALL)
print(m.group("tool_calls"))  # '<|tool_call>call:multiply{a:3,b:4}<tool_call|>'
print(m.group("content"))     # "Let's call the tool."

# Tool calls only
text = "<|tool_call>call:multiply{a:3,b:4}<tool_call|><turn|>"
m = re.match(fixed, text, re.DOTALL)
print(m.group("tool_calls"))  # '<|tool_call>call:multiply{a:3,b:4}<tool_call|>'
print(m.group("content"))     # ''

# Content only
text = "The answer is 12.<turn|>"
m = re.match(fixed, text, re.DOTALL)
print(m.group("tool_calls"))  # None
print(m.group("content"))     # 'The answer is 12.'
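Two more cases can be checked the same way: several consecutive tool calls, which the repeated non-greedy `tool_calls` group should capture together, and a thinking block in the `<|channel>thought\n...<channel|>` form that the pattern's optional first group expects. The inputs below are illustrative assumptions: the second tool name `add` is hypothetical, and consecutive calls are assumed to be emitted back to back with nothing in between. Splitting the captured calls with `re.findall` is only a sketch of how the group could be consumed, not how `parse_response` does it internally:

```python
import re

fixed = r'(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<tool_calls>(?:\<\|tool_call\>.*?\<tool_call\|\>)+)?(?P<content>(?:(?!\<turn\|\>).)*)?(?:\<turn\|\>)?'

# Two consecutive tool calls (hypothetical "add" tool), then content
text = "<|tool_call>call:multiply{a:3,b:4}<tool_call|><|tool_call>call:add{a:1,b:2}<tool_call|>Done.<turn|>"
m = re.match(fixed, text, re.DOTALL)
# Illustrative only: split the captured group into individual call bodies
calls = re.findall(r'\<\|tool_call\>(.*?)\<tool_call\|\>', m.group("tool_calls"), re.DOTALL)
print(calls)                  # ['call:multiply{a:3,b:4}', 'call:add{a:1,b:2}']
print(m.group("content"))     # Done.

# Thinking block, then a tool call, then content
text = "<|channel>thought\nNeed to multiply.<channel|><|tool_call>call:multiply{a:3,b:4}<tool_call|>Let's call the tool.<turn|>"
m = re.match(fixed, text, re.DOTALL)
print(m.group("thinking"))    # Need to multiply.
print(m.group("tool_calls"))  # <|tool_call>call:multiply{a:3,b:4}<tool_call|>
print(m.group("content"))     # Let's call the tool.
```

Because each repetition of the `tool_calls` group is non-greedy, a `<tool_call|>` marker is never swallowed across into the following content.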