from collections.abc import Callable from .base import BaseChunker class RecursiveCharacterChunker(BaseChunker): def __init__( self, chunk_size: int = 500, chunk_overlap: int = 100, length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, ) -> None: """初始化递归字符文本分割器 Args: chunk_size: 每个文本块的最大大小 chunk_overlap: 每个文本块之间的重叠部分大小 length_function: 计算文本长度的函数 is_separator_regex: 分隔符是否为正则表达式 separators: 用于分割文本的分隔符列表,按优先级排序 """ self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.length_function = length_function self.is_separator_regex = is_separator_regex # 默认分隔符列表,按优先级从高到低 self.separators = separators or [ "\n\n", # 段落 "\n", # 换行 "。", # 中文句子 ",", # 中文逗号 ". ", # 句子 ", ", # 逗号分隔 " ", # 单词 "", # 字符 ] async def chunk(self, text: str, **kwargs) -> list[str]: """递归地将文本分割成块 Args: text: 要分割的文本 chunk_size: 每个文本块的最大大小 chunk_overlap: 每个文本块之间的重叠部分大小 Returns: 分割后的文本块列表 """ if not text: return [] overlap = kwargs.get("chunk_overlap", self.chunk_overlap) chunk_size = kwargs.get("chunk_size", self.chunk_size) text_length = self.length_function(text) if text_length <= chunk_size: return [text] for separator in self.separators: if separator == "": return self._split_by_character(text, chunk_size, overlap) if separator in text: splits = text.split(separator) # 重新添加分隔符(除了最后一个片段) splits = [s + separator for s in splits[:-1]] + [splits[-1]] splits = [s for s in splits if s] if len(splits) == 1: continue # 递归合并分割后的文本块 final_chunks = [] current_chunk = [] current_chunk_length = 0 for split in splits: split_length = self.length_function(split) # 如果单个分割部分已经超过了chunk_size,需要递归分割 if split_length > chunk_size: # 先处理当前积累的块 if current_chunk: combined_text = "".join(current_chunk) final_chunks.extend( await self.chunk( combined_text, chunk_size=chunk_size, chunk_overlap=overlap, ), ) current_chunk = [] current_chunk_length = 0 # 递归分割过大的部分 final_chunks.extend( await self.chunk( split, chunk_size=chunk_size, chunk_overlap=overlap, ), ) # 如果添加这部分会使当前块超过chunk_size elif current_chunk_length + split_length > chunk_size: # 合并当前块并添加到结果中 combined_text = "".join(current_chunk) final_chunks.append(combined_text) # 处理重叠部分 overlap_start = max(0, len(combined_text) - overlap) if overlap_start > 0: overlap_text = combined_text[overlap_start:] current_chunk = [overlap_text, split] current_chunk_length = ( self.length_function(overlap_text) + split_length ) else: current_chunk = [split] current_chunk_length = split_length else: # 添加到当前块 current_chunk.append(split) current_chunk_length += split_length # 处理剩余的块 if current_chunk: final_chunks.append("".join(current_chunk)) return final_chunks return [text] def _split_by_character( self, text: str, chunk_size: int | None = None, overlap: int | None = None, ) -> list[str]: """按字符级别分割文本 Args: text: 要分割的文本 Returns: 分割后的文本块列表 """ if chunk_size is None: chunk_size = self.chunk_size if overlap is None: overlap = self.chunk_overlap if chunk_size <= 0: raise ValueError("chunk_size must be greater than 0") if overlap < 0: raise ValueError("chunk_overlap must be non-negative") if overlap >= chunk_size: raise ValueError("chunk_overlap must be less than chunk_size") result = [] for i in range(0, len(text), chunk_size - overlap): end = min(i + chunk_size, len(text)) result.append(text[i:end]) if end == len(text): break return result