import sys
from typing import Literal

from transformers import AutoTokenizer
from langchain_text_splitters import RecursiveCharacterTextSplitter, NLTKTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
|
|
|
| class Splitter: |
| """ |
| Класс описывает функционал разделения текста на чанки тремя способами на выбор: |
| - рекурсивно разбивая чанки различными разделителями |
| в порядке возрастания "жесткости" их эффекта; |
| |
| - объединяя выделенные с помощью библиотеки NLTK предложения |
| в чанки определенного размера и с наложением; |
| |
| - разбивая текст на семантически связанные блоки |
| с помощью векторных представлений текстов; |
| """ |
|
|
| def __init__( |
| self, |
| mode: Literal["recursive", "nltk", "semantic"], |
| model_name: str = "deepvk/USER-bge-m3", |
| chunk_size: int = 256, |
| chunk_overlap: int = 64, |
| **splitter_kwargs, |
| ): |
| self.chunk_size = chunk_size |
| self.chunk_overlap = chunk_overlap |
|
|
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
| match mode: |
|
|
| case "recursive": |
| self.splitter = RecursiveCharacterTextSplitter( |
| separators=[ |
| "\n### ", "\n## ", "\n# ", |
| "\n\n", "\n", |
| "!", "?", ". ", ";", ",", ")", " ", "", |
| ], |
| keep_separator="end", |
| chunk_size=chunk_size, |
| chunk_overlap=chunk_overlap, |
| length_function=lambda x: len(self.tokenizer.encode(x, add_special_tokens=False)), |
| **splitter_kwargs, |
| ) |
| self.split_fn = self._recursive_split |
| |
| case "nltk": |
| self.splitter = NLTKTextSplitter( |
| language="russian", |
| **splitter_kwargs, |
| ) |
| self.split_fn = self._nltk_split |
| |
| case "semantic": |
| self.splitter = SemanticChunker( |
| HuggingFaceEmbeddings( |
| model_name=model_name, |
| encode_kwargs={"normalize_embeddings": True}, |
| ), |
| **splitter_kwargs, |
| ) |
| self.split_fn = self._semantic_split |
|
|
|
|
| def split_text(self, text: str) -> list[str]: |
| """ |
| Доступная пользователю функция разделения текста на чанки |
| """ |
| return self.split_fn(text) |
| |
|
|
| def _recursive_split(self, text: str) -> list[str]: |
| """ |
| Функция разделения текста на чанки при self.splitter == RecursiveCharacterTextSplitter |
| """ |
| return [ |
| chunk |
| for chunk in self.splitter.split_text(text) |
| if any(ch.isalpha() for ch in set(chunk)) |
| ] |
| |
| |
| def _nltk_split(self, text: str) -> list[str]: |
| """ |
| Функция разделения текста на чанки при self.splitter == NLTKTextSplitter |
| """ |
| sentences = self.splitter.split_text(text)[0].split("\n\n") |
| sent_sizes = [ |
| len(self.tokenizer.encode(sent, add_special_tokens=False)) |
| for sent in sentences |
| ] |
|
|
| chunks = [] |
| i, n = 0, len(sentences) |
| while i < n: |
| cur_len, cur_texts = 0, [] |
|
|
| |
| j = i |
| while (j < n) and (cur_len + sent_sizes[j] <= self.chunk_size): |
| cur_texts.append(sentences[j]) |
| cur_len += sent_sizes[j] |
| j += 1 |
|
|
| chunks.append(cur_texts) |
|
|
| |
| if j >= n: |
| break |
|
|
| |
| overlap_len, k = 0, j - 1 |
| while (k >= i) and (overlap_len + sent_sizes[k] <= self.chunk_overlap): |
| overlap_len += sent_sizes[k] |
| k -= 1 |
|
|
| |
| i = k + 1 |
|
|
| return chunks |
| |
|
|
| def _semantic_split(self, text: str) -> list[str]: |
| """ |
| Функция разделения текста на чанки при self.splitter == SemanticChunker |
| """ |
| return self.splitter.split_text(text) |
|
|