from langchain_text_splitters import (
    MarkdownHeaderTextSplitter,
    RecursiveCharacterTextSplitter,
)
from langchain_core.documents import Document
from langchain_core.messages.base import BaseMessage
from langchain_core.messages import ToolMessage
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

from config.settings import config
import json
import tiktoken


def parse_mark_down(data: str) -> list[Document]:
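    """Split markdown text into header-scoped Documents along H1/H2 boundaries."""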
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
    ]
    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    md_header_splits = markdown_splitter.split_text(data)
    return md_header_splits


class OversizedContentHandler:
    """Main handler for content that exceeds context limits."""

    def __init__(self,
                 model_name: str = "gpt-4.1",
                 max_context_tokens: int = config.MAX_CONTEXT_TOKENS,
                 reserved_tokens: int = 2000):
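        """Configure the tokenizer and the token budget for retrieved content."""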
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Older tiktoken releases may not know newer model names.
            self.encoding = tiktoken.get_encoding("cl100k_base")
        self.max_context_tokens = max_context_tokens
        self.reserved_tokens = reserved_tokens
        # Budget left after reserving room for the prompt and the response.
        self.max_chunk_tokens = max_context_tokens - reserved_tokens

    def count_tokens(self, text: str) -> int:
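        """Return the number of tokens `text` occupies under the configured encoding."""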
        return len(self.encoding.encode(text))

    def extract_relevant_chunks(self, content: str, query: str) -> list[dict]:
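        """Embed `content` chunk by chunk and return the pieces most relevant to `query`."""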
        md_chunks = parse_mark_down(content)

        # chunk_size / chunk_overlap are measured in characters, not tokens.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=15000, chunk_overlap=500)
        final_chunks = text_splitter.split_documents(md_chunks)

        # Embed the chunks into an in-memory FAISS index for similarity search.
        embeddings = OpenAIEmbeddings()
        vector_db = FAISS.from_documents(final_chunks, embeddings)

        relevant_chunks = vector_db.similarity_search(query, k=3)

        # MarkdownHeaderTextSplitter stores section headers (e.g. "Header 1")
        # in each split's metadata; these splits carry no "source" key.
        context_with_metadata = [
            {"text": doc.page_content, "headers": doc.metadata}
            for doc in relevant_chunks
        ]
        return context_with_metadata

    def process_oversized_message(self, message: BaseMessage, query: str) -> bool:
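        """Replace an oversized tavily_extract payload with relevant chunks; return True if chunked."""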
        chunked = False
        if isinstance(message, ToolMessage) and message.name == "tavily_extract":
            json_content = json.loads(message.content)
            result = json_content['results'][0]
            raw_content = result['raw_content']

            content_size = self.count_tokens(raw_content)
            # Use the instance limit so constructor arguments take effect.
            if content_size > self.max_context_tokens:
                print(f"Chunking message {message.id}: {content_size} tokens exceeds the {self.max_context_tokens}-token limit")
                chunked = True
                result['raw_content'] = self.extract_relevant_chunks(raw_content, query=query)
                message.content = json.dumps(json_content)
        return chunked
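

# Example usage: a minimal sketch. It assumes OPENAI_API_KEY is set in the
# environment; the sample ToolMessage payload below is hypothetical, mirroring
# the tavily_extract result shape this handler expects.
if __name__ == "__main__":
    sample_payload = json.dumps({
        "results": [{"raw_content": "# Title\n\nSome very long markdown body..."}]
    })
    msg = ToolMessage(content=sample_payload, name="tavily_extract", tool_call_id="call_1")
    handler = OversizedContentHandler(max_context_tokens=8000, reserved_tokens=2000)
    # Small payloads pass through untouched; oversized ones are replaced with
    # the top FAISS matches for the query (that path needs the API key).
    was_chunked = handler.process_oversized_message(msg, query="What is the page about?")
    print(f"chunked: {was_chunked}")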