File size: 6,134 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from collections.abc import Callable

from .base import BaseChunker


class RecursiveCharacterChunker(BaseChunker):
    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
        is_separator_regex: bool = False,
        separators: list[str] | None = None,
    ) -> None:
        """初始化递归字符文本分割器

        Args:
            chunk_size: 每个文本块的最大大小
            chunk_overlap: 每个文本块之间的重叠部分大小
            length_function: 计算文本长度的函数
            is_separator_regex: 分隔符是否为正则表达式
            separators: 用于分割文本的分隔符列表,按优先级排序

        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function
        self.is_separator_regex = is_separator_regex

        # 默认分隔符列表,按优先级从高到低
        self.separators = separators or [
            "\n\n",  # 段落
            "\n",  # 换行
            "。",  # 中文句子
            ",",  # 中文逗号
            ". ",  # 句子
            ", ",  # 逗号分隔
            " ",  # 单词
            "",  # 字符
        ]

    async def chunk(self, text: str, **kwargs) -> list[str]:
        """递归地将文本分割成块

        Args:
            text: 要分割的文本
            chunk_size: 每个文本块的最大大小
            chunk_overlap: 每个文本块之间的重叠部分大小

        Returns:
            分割后的文本块列表

        """
        if not text:
            return []

        overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
        chunk_size = kwargs.get("chunk_size", self.chunk_size)

        text_length = self.length_function(text)
        if text_length <= chunk_size:
            return [text]

        for separator in self.separators:
            if separator == "":
                return self._split_by_character(text, chunk_size, overlap)

            if separator in text:
                splits = text.split(separator)
                # 重新添加分隔符(除了最后一个片段)
                splits = [s + separator for s in splits[:-1]] + [splits[-1]]
                splits = [s for s in splits if s]
                if len(splits) == 1:
                    continue

                # 递归合并分割后的文本块
                final_chunks = []
                current_chunk = []
                current_chunk_length = 0

                for split in splits:
                    split_length = self.length_function(split)

                    # 如果单个分割部分已经超过了chunk_size,需要递归分割
                    if split_length > chunk_size:
                        # 先处理当前积累的块
                        if current_chunk:
                            combined_text = "".join(current_chunk)
                            final_chunks.extend(
                                await self.chunk(
                                    combined_text,
                                    chunk_size=chunk_size,
                                    chunk_overlap=overlap,
                                ),
                            )
                            current_chunk = []
                            current_chunk_length = 0

                        # 递归分割过大的部分
                        final_chunks.extend(
                            await self.chunk(
                                split,
                                chunk_size=chunk_size,
                                chunk_overlap=overlap,
                            ),
                        )
                    # 如果添加这部分会使当前块超过chunk_size
                    elif current_chunk_length + split_length > chunk_size:
                        # 合并当前块并添加到结果中
                        combined_text = "".join(current_chunk)
                        final_chunks.append(combined_text)

                        # 处理重叠部分
                        overlap_start = max(0, len(combined_text) - overlap)
                        if overlap_start > 0:
                            overlap_text = combined_text[overlap_start:]
                            current_chunk = [overlap_text, split]
                            current_chunk_length = (
                                self.length_function(overlap_text) + split_length
                            )
                        else:
                            current_chunk = [split]
                            current_chunk_length = split_length
                    else:
                        # 添加到当前块
                        current_chunk.append(split)
                        current_chunk_length += split_length

                # 处理剩余的块
                if current_chunk:
                    final_chunks.append("".join(current_chunk))

                return final_chunks

        return [text]

    def _split_by_character(
        self,
        text: str,
        chunk_size: int | None = None,
        overlap: int | None = None,
    ) -> list[str]:
        """按字符级别分割文本

        Args:
            text: 要分割的文本

        Returns:
            分割后的文本块列表

        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if overlap is None:
            overlap = self.chunk_overlap
        if chunk_size <= 0:
            raise ValueError("chunk_size must be greater than 0")
        if overlap < 0:
            raise ValueError("chunk_overlap must be non-negative")
        if overlap >= chunk_size:
            raise ValueError("chunk_overlap must be less than chunk_size")
        result = []
        for i in range(0, len(text), chunk_size - overlap):
            end = min(i + chunk_size, len(text))
            result.append(text[i:end])
            if end == len(text):
                break

        return result