| from collections.abc import Sequence |
| from typing import Optional, Union |
|
|
| import regex as re |
| from transformers import PreTrainedTokenizerBase |
|
|
| from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, |
| DeltaMessage) |
| from vllm.logger import init_logger |
| from vllm.reasoning import ReasoningParser, ReasoningParserManager |
|
|
| logger = init_logger(__name__) |
|
|
|
|
@ReasoningParserManager.register_module("greenmind_14b_r1")
class GreenMind14bR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the GreenMind-14B-R1 model.

    The model wraps its chain-of-thought in ``<think>\\n ... \\n</think>\\n``
    and its final answer in ``<answer>\\n ... </answer>``. This parser
    extracts the text inside the think tags as reasoning content and the
    text inside the answer tags as response content, for both complete
    outputs (``extract_reasoning_content``) and token-by-token streaming
    (``extract_reasoning_content_streaming``).

    Token-id reference (as produced by the model's tokenizer):
        think start:  "<think>\\n"               -> [13708, 766, 397]
        think end:    "\\n</think>\\n<answer>\\n" -> [198, 522, 26865, 397, 27, 9217, 397]
        response end: "</answer>"                -> [522, 9217, 29]
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Text markers used by the non-streaming regex extraction.
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>"
        self.response_end_expr = r"</answer>"

        # Matches "<think>{reasoning}\n</think>\n<answer>{response}</answer>".
        # The think section is optional so answer-only outputs still match.
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}",
            re.DOTALL)

        # Fallback for outputs where the closing "</answer>" is missing
        # (e.g. generation stopped early).
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)

        # Token-id sequences for the streaming state machine. The "_fast"
        # variants are alternative tokenizations of the same text markers,
        # matched side-by-side with the primary sequences.
        self.think_start_ids = [13708, 766, 397]  # "<think>\n"
        self.think_start_ids_fast = [27, 26865, 397]
        self.response_start_ids = [522, 26865, 397, 27, 9217, 397]
        self.response_start_ids_fast = [522, 26865, 397, 27, 9217, 29]
        self.response_end_ids = [522, 9217, 29]  # "</answer>"
        self.fast_think_ids = [
            13708, 766, 1339, 522, 26865, 397, 27, 9217, 397
        ]

        # Streaming state machine: idle -> think -> response -> idle.
        # "idle": waiting for "<think>\n"; "think": inside reasoning,
        # waiting for the think-end/answer-start marker; "response": inside
        # the answer, waiting for "</answer>".
        self.current_state = "idle"
        self.all_states = ["idle", "think", "response"]
        self.expected_sequence = self.think_start_ids
        self.expected_sequence_side = self.think_start_ids_fast
        # Position within the marker sequence currently being matched.
        self.sequence_index = 0
        # Tokens/text held back while a marker prefix is being matched;
        # flushed as ordinary content if the match falls through.
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the streaming state machine has left the
        reasoning phase (i.e. the think-end marker was fully matched)."""
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids of the non-reasoning content.

        Content extraction is handled by the text-based methods, so this
        intentionally returns an empty list.
        """
        return []

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        # Preferred path: a complete "<think>...</think><answer>...</answer>"
        # (or a bare "...</answer>") output.
        re_match = self.full_match_reasoning_regex.findall(model_output)
        if re_match:
            reasoning_content, response_content = re_match[0]
            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None
            return reasoning_content, response_content

        # Fallback: the closing "</answer>" is missing; take everything after
        # the answer-start marker as the response.
        fallback_regex = self.half_match_reasoning_regex
        fallback_match = fallback_regex.findall(model_output)
        if fallback_match:
            reasoning_content, response_content = fallback_match[0]

            # Strip a trailing "</answer>" if present (defensive; the full
            # regex normally handles this case).
            if response_content.endswith(self.response_end_expr):
                response_content = response_content[:-len(self.
                                                          response_end_expr)]

            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None

            return reasoning_content, response_content

        # No recognizable markers: treat the whole output as plain content.
        return None, model_output

    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
                                          sequence: Sequence[int]) -> bool:
        """Return True if ``subsequence`` occurs in ``sequence`` in order
        (not necessarily contiguously). An empty subsequence yields False.

        NOTE(review): despite the name, this checks ordered containment, not
        numeric monotonicity.
        """
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Extract content using a token-id sequence state machine.

        Marker tokens are buffered while a prefix of the expected marker
        sequence is being matched; a completed match advances the state
        (idle -> think -> response -> idle) and emits nothing, while a
        failed match flushes the buffered text as ordinary output.
        """
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids

        # Streaming deltas are expected one token at a time.
        assert len(delta_token_ids) == 1

        token = delta_token_ids[0]

        def check_token_with_sequence(token):
            # In idle/think, the marker may arrive in either tokenization,
            # so match against both the primary and the "_fast" sequence.
            if self.current_state == "idle" or self.current_state == "think":
                return (token == self.expected_sequence[self.sequence_index]
                        or token == \
                        self.expected_sequence_side[self.sequence_index])
            else:
                return token == self.expected_sequence[self.sequence_index]

        def check_last_token(token):
            # Decide completion against whichever sequence the latest token
            # actually belongs to (primary vs "_fast" variant).
            if self.current_state == "idle" or self.current_state == "think":
                if (self.sequence_index - 1 < len(self.expected_sequence_side)
                        and token
                        == self.expected_sequence_side[self.sequence_index -
                                                       1]):
                    return self.sequence_index == len(
                        self.expected_sequence_side)
                else:
                    return self.sequence_index == len(self.expected_sequence)
            else:
                return self.sequence_index == len(self.expected_sequence)

        token_in_state_seq = check_token_with_sequence(token)

        if token_in_state_seq:
            # Token extends the current marker prefix: buffer it and advance.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1

            if check_last_token(token):
                # Full marker matched: advance the state machine and swap in
                # the next expected marker sequence.
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast

                # Marker tokens are consumed silently.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

        else:
            if self.token_buffer:
                # Partial marker match failed: flush the held-back text plus
                # the current token as ordinary output.
                # NOTE(review): the flushed token is not re-tested as the
                # start of a new marker, so a marker immediately following a
                # near-miss prefix would be missed — confirm acceptable.
                buffered_content = self.text_buffer + delta_text

                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=buffered_content,
                                        content=None)
                else:
                    return DeltaMessage(reasoning_content=None,
                                        content=buffered_content)
            else:
                # Ordinary token: route by current state.
                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=delta_text,
                                        content=None)
                else:
                    return DeltaMessage(reasoning_content=None,
                                        content=delta_text)

        # Token was part of a (possibly still partial) marker: emit nothing.
        return None
|
|
|
|