| from collections.abc import Callable |
|
|
| from .base import BaseChunker |
|
|
|
|
class RecursiveCharacterChunker(BaseChunker):
    """Split text into size-bounded chunks, trying separators from coarse to fine.

    The text is cut on the first separator (in priority order) that actually
    divides it. The resulting pieces are merged back together greedily up to
    ``chunk_size``, and any piece that is still too large is chunked again
    recursively. The empty-string separator switches to a fixed-width
    character split.
    """

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
        is_separator_regex: bool = False,
        separators: list[str] | None = None,
    ) -> None:
        """Initialize the recursive character chunker.

        Args:
            chunk_size: Maximum size of each chunk, as measured by
                ``length_function``.
            chunk_overlap: Number of trailing characters of an emitted chunk
                that are repeated at the start of the following chunk.
            length_function: Function used to measure the length of a text.
            is_separator_regex: When true, every non-empty separator is
                interpreted as a regular expression instead of a literal.
            separators: Separators to try, in priority order. A falsy value
                (``None`` or an empty list) selects the built-in defaults.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function
        self.is_separator_regex = is_separator_regex

        # Ordered from the coarsest boundary (paragraph) down to single
        # characters; "" is the last-resort character-level split.
        self.separators = separators or [
            "\n\n",
            "\n",
            "。",
            ",",
            ". ",
            ", ",
            " ",
            "",
        ]

    async def chunk(self, text: str, **kwargs) -> list[str]:
        """Recursively split ``text`` into chunks.

        Args:
            text: The text to split.
            **kwargs: Optional per-call overrides. ``chunk_size`` and
                ``chunk_overlap`` replace the values configured on the
                instance for this call only.

        Returns:
            The list of chunks. Empty input yields an empty list; text that
            already fits in ``chunk_size`` is returned as a single chunk. If
            no separator divides the text and the separator list has no ``""``
            entry, the text is returned unsplit even though it is too large.

        Raises:
            ValueError: If the character-level fallback is reached with an
                invalid ``chunk_size`` / ``chunk_overlap`` combination.
            re.error: If ``is_separator_regex`` is set and a separator is not
                a valid regular expression.
        """
        if not text:
            return []

        chunk_size = kwargs.get("chunk_size", self.chunk_size)
        overlap = kwargs.get("chunk_overlap", self.chunk_overlap)

        if self.length_function(text) <= chunk_size:
            return [text]

        for separator in self.separators:
            if separator == "":
                return self._split_by_character(text, chunk_size, overlap)

            splits = self._split_keep_separator(text, separator)
            if len(splits) < 2:
                # Separator absent (or it did not actually divide the text):
                # fall through to the next, finer separator.
                continue

            return await self._merge_splits(splits, chunk_size, overlap)

        return [text]

    def _split_keep_separator(self, text: str, separator: str) -> list[str]:
        """Cut ``text`` on ``separator``, keeping each separator on the piece before it.

        Args:
            text: The text to cut.
            separator: A literal string, or a regular expression when
                ``is_separator_regex`` is set.

        Returns:
            The non-empty pieces, in order. Concatenating them reproduces
            ``text``.
        """
        if self.is_separator_regex:
            import re

            pieces = []
            cursor = 0
            for match in re.finditer(separator, text):
                if match.end() == match.start():
                    # Zero-width matches do not delimit anything.
                    continue
                pieces.append(text[cursor:match.end()])
                cursor = match.end()
            pieces.append(text[cursor:])
        else:
            parts = text.split(separator)
            pieces = [part + separator for part in parts[:-1]] + [parts[-1]]

        return [piece for piece in pieces if piece]

    async def _merge_splits(
        self,
        splits: list[str],
        chunk_size: int,
        overlap: int,
    ) -> list[str]:
        """Greedily merge ``splits`` into chunks no larger than ``chunk_size``.

        Pieces that are individually larger than ``chunk_size`` are handed
        back to :meth:`chunk` so that a finer separator can divide them.
        """
        final_chunks: list[str] = []
        current_chunk: list[str] = []
        current_chunk_length = 0

        for split in splits:
            split_length = self.length_function(split)

            if split_length > chunk_size:
                # Flush what has been accumulated so far, then let the
                # recursion break the oversized piece down further.
                if current_chunk:
                    final_chunks.extend(
                        await self.chunk(
                            "".join(current_chunk),
                            chunk_size=chunk_size,
                            chunk_overlap=overlap,
                        ),
                    )
                    current_chunk = []
                    current_chunk_length = 0

                final_chunks.extend(
                    await self.chunk(
                        split,
                        chunk_size=chunk_size,
                        chunk_overlap=overlap,
                    ),
                )
            elif current_chunk_length + split_length > chunk_size:
                combined_text = "".join(current_chunk)
                final_chunks.append(combined_text)

                # Carry over only as much overlap as still fits next to the
                # incoming piece; otherwise the next chunk would exceed
                # chunk_size.
                overlap_text = self._overlap_tail(
                    combined_text,
                    overlap,
                    chunk_size - split_length,
                )
                if overlap_text:
                    current_chunk = [overlap_text, split]
                    current_chunk_length = (
                        self.length_function(overlap_text) + split_length
                    )
                else:
                    current_chunk = [split]
                    current_chunk_length = split_length
            else:
                current_chunk.append(split)
                current_chunk_length += split_length

        if current_chunk:
            final_chunks.append("".join(current_chunk))

        return final_chunks

    def _overlap_tail(self, emitted: str, overlap: int, budget: int) -> str:
        """Return the tail of ``emitted`` to repeat at the start of the next chunk.

        Args:
            emitted: The chunk that was just emitted.
            overlap: Requested overlap, in characters.
            budget: Maximum measured length the tail may have so that the
                tail plus the next piece still fits in the chunk size.

        Returns:
            The overlap text, or ``""`` when no overlap should be carried:
            the overlap or budget is not positive, or the emitted chunk is
            not longer than the overlap (carrying it would duplicate the
            whole chunk).
        """
        if overlap <= 0 or budget <= 0 or len(emitted) <= overlap:
            return ""

        tail = emitted[-overlap:]
        excess = self.length_function(tail) - budget
        # Drop leading characters until the measured length fits the budget.
        # With ``len`` this takes a single step; the loop covers length
        # functions that are not one unit per character.
        while tail and excess > 0:
            tail = tail[max(1, int(excess)):]
            excess = self.length_function(tail) - budget
        return tail

    def _split_by_character(
        self,
        text: str,
        chunk_size: int | None = None,
        overlap: int | None = None,
    ) -> list[str]:
        """Split ``text`` into fixed-width character windows.

        Consecutive windows start ``chunk_size - overlap`` characters apart,
        so each window repeats the last ``overlap`` characters of the
        previous one.

        Args:
            text: The text to split.
            chunk_size: Window width in characters; defaults to the
                instance's ``chunk_size``.
            overlap: Characters shared by consecutive windows; defaults to
                the instance's ``chunk_overlap``.

        Returns:
            The list of windows (empty for empty ``text``).

        Raises:
            ValueError: If ``chunk_size`` is not positive, ``overlap`` is
                negative, or ``overlap`` is not smaller than ``chunk_size``.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if overlap is None:
            overlap = self.chunk_overlap
        if chunk_size <= 0:
            raise ValueError("chunk_size must be greater than 0")
        if overlap < 0:
            raise ValueError("chunk_overlap must be non-negative")
        if overlap >= chunk_size:
            raise ValueError("chunk_overlap must be less than chunk_size")

        total = len(text)
        result = []
        for start in range(0, total, chunk_size - overlap):
            end = min(start + chunk_size, total)
            result.append(text[start:end])
            if end == total:
                # The window reached the end of the text; a further step
                # would only emit a redundant suffix.
                break

        return result
|
|