| import numpy as np |
| from llama_index.core import VectorStoreIndex, Settings |
| from llama_index.core.query_engine import RetrieverQueryEngine |
| from llama_index.core.retrievers import VectorIndexRetriever, BaseRetriever |
| from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode |
| from llama_index.core.prompts import PromptTemplate |
| from llama_index.retrievers.bm25 import BM25Retriever |
| from llama_index.core.retrievers import QueryFusionRetriever |
| from llama_index.core.schema import NodeWithScore, QueryBundle |
| from typing import List, Optional, Dict, Tuple |
| from logger.my_logging import log_message |
| from config import CUSTOM_PROMPT, DEFAULT_RETRIEVAL_PARAMS |
|
|
| |
| class LogWrapperRetriever(BaseRetriever): |
| """ |
| Обертка для ретривера, которая логирует найденные чанки и их скоры |
| перед тем, как вернуть их. |
| """ |
| def __init__(self, retriever: BaseRetriever, name: str): |
| self._retriever = retriever |
| self._name = name |
| super().__init__() |
|
|
| def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: |
| |
| nodes = self._retriever.retrieve(query_bundle) |
| |
| |
| log_message(f"\n--- 🔎 {self._name} RETRIEVAL (Top {len(nodes)}) ---") |
| for i, node in enumerate(nodes): |
| score = node.score if node.score is not None else 0.0 |
| doc_id = node.metadata.get('document_id', 'N/A') |
| text_preview = node.text.replace('\n', ' ') |
| log_message(f"[{i+1}] Score: {score:.4f} | Doc: {doc_id} | Text: {text_preview}...") |
| |
| return nodes |
|
|
| |
|
|
| def create_vector_index(documents: List) -> VectorStoreIndex: |
| """ |
| Создает векторный индекс из списка документов. |
| |
| Args: |
| documents: Список документов для индексации |
| |
| Returns: |
| VectorStoreIndex: Созданный векторный индекс |
| """ |
| log_message("Инициализация построения векторного индекса") |
| |
| connection_type_sources: Dict[str, List[str]] = {} |
| table_count = 0 |
| |
| for doc in documents: |
| doc_type = doc.metadata.get('type', 'text') |
| |
| if doc_type == 'table': |
| table_count += 1 |
|
|
| conn_type = doc.metadata.get('connection_type', '') |
| if conn_type: |
| table_id = (f"{doc.metadata.get('document_id', 'unknown')} " |
| f"Table {doc.metadata.get('table_number', 'N/A')}") |
| |
| if conn_type not in connection_type_sources: |
| connection_type_sources[conn_type] = [] |
| connection_type_sources[conn_type].append(table_id) |
| |
| log_message(f"📊 Статистика: Всего документов {len(documents)}, из них таблиц {table_count}") |
| |
| return VectorStoreIndex.from_documents(documents) |
|
|
|
|
| def rerank_nodes( |
| query: str, |
| nodes: List, |
| reranker: Optional[object], |
| top_k: int = DEFAULT_RETRIEVAL_PARAMS['rerank_top_k'], |
| rerank_threshold: float = DEFAULT_RETRIEVAL_PARAMS['rerank_threshold'] |
| ) -> List: |
| """ |
| Переранжирует узлы с использованием модели reranker для улучшения релевантности. |
| |
| Args: |
| query: Поисковый запрос |
| nodes: Список узлов для переранжировки |
| reranker: Модель для переранжировки (может быть None) |
| top_k: Количество топовых узлов для возврата |
| rerank_threshold: Минимальный порог оценки релевантности |
| |
| Returns: |
| List: Отсортированный список наиболее релевантных узлов |
| """ |
| |
| if not nodes or not reranker: |
| log_message(f"Переранжировка пропущена. Возвращаю первые {top_k} узлов") |
| return nodes[:top_k] |
| |
| try: |
| log_message(f"Начинаю переранжировку {len(nodes)} узлов с порогом {rerank_threshold}") |
| |
| |
| pairs = [[query, node.text] for node in nodes] |
| |
| |
| raw_scores = reranker.predict(pairs) |
|
|
| |
| scores = 1 / (1 + np.exp(-raw_scores)) |
|
|
| if isinstance(scores, np.ndarray): |
| scores = scores.tolist() |
| |
| |
| scored_nodes: List[Tuple] = list(zip(nodes, scores)) |
| |
| |
| scored_nodes.sort(key=lambda x: x[1], reverse=True) |
| |
| |
| filtered_nodes = [ |
| (node, score) for node, score in scored_nodes |
| if score >= rerank_threshold |
| ] |
| |
| |
| if not filtered_nodes: |
| log_message(f"Ни один узел не прошел порог {rerank_threshold}. " |
| f"Возвращаю топ-{top_k} без фильтрации") |
| filtered_nodes = scored_nodes[:top_k] |
| |
| result_count = min(len(filtered_nodes), top_k) |
| log_message(f"Переранжировка завершена. Выбрано узлов: {result_count}") |
| |
| final_nodes = [] |
| for node, score in filtered_nodes[:top_k]: |
| node.score = float(score) |
| final_nodes.append(node) |
| |
| return final_nodes |
| |
| except Exception as e: |
| log_message(f"Ошибка при переранжировке: {str(e)}. Возвращаю исходные узлы") |
| return nodes[:top_k] |
|
|
|
|
| def create_query_engine( |
| vector_index: VectorStoreIndex, |
| vector_top_k: int = DEFAULT_RETRIEVAL_PARAMS['vector_top_k'], |
| bm25_top_k: int = DEFAULT_RETRIEVAL_PARAMS['bm25_top_k'], |
| similarity_cutoff: float = DEFAULT_RETRIEVAL_PARAMS['similarity_cutoff'], |
| hybrid_top_k: int = DEFAULT_RETRIEVAL_PARAMS['hybrid_top_k'] |
| ) -> RetrieverQueryEngine: |
| """ |
| Создает гибридный query engine с комбинацией векторного и BM25 поиска. |
| |
| Args: |
| vector_index: Векторный индекс для поиска |
| vector_top_k: Количество топовых результатов для векторного поиска |
| bm25_top_k: Количество топовых результатов для BM25 поиска |
| similarity_cutoff: Порог схожести для векторного поиска (0-1) |
| hybrid_top_k: Итоговое количество результатов после слияния |
| |
| Returns: |
| RetrieverQueryEngine: Настроенный query engine |
| |
| Raises: |
| Exception: При ошибке создания query engine |
| """ |
| try: |
| log_message("Инициализация создания query engine") |
| |
| |
| bm25_retriever = BM25Retriever.from_defaults( |
| docstore=vector_index.docstore, |
| similarity_top_k=bm25_top_k |
| ) |
| |
| |
| vector_retriever = VectorIndexRetriever( |
| index=vector_index, |
| similarity_top_k=vector_top_k, |
| similarity_cutoff=similarity_cutoff |
| ) |
| |
| |
| bm25_logged = LogWrapperRetriever(bm25_retriever, "BM25 (Keywords)") |
| vector_logged = LogWrapperRetriever(vector_retriever, "VECTOR (Semantic)") |
| |
| |
| hybrid_retriever = QueryFusionRetriever( |
| retrievers=[vector_logged, bm25_logged], |
| similarity_top_k=hybrid_top_k, |
| num_queries=1 |
| ) |
| |
| |
| custom_prompt_template = PromptTemplate(CUSTOM_PROMPT) |
| |
| |
| response_synthesizer = get_response_synthesizer( |
| response_mode=ResponseMode.TREE_SUMMARIZE, |
| text_qa_template=custom_prompt_template |
| ) |
| |
| |
| query_engine = RetrieverQueryEngine( |
| retriever=hybrid_retriever, |
| response_synthesizer=response_synthesizer |
| ) |
| |
| log_message( |
| f"Query engine успешно создан с параметрами: " |
| f"vector_top_k={vector_top_k}, bm25_top_k={bm25_top_k}, " |
| f"similarity_cutoff={similarity_cutoff}, hybrid_top_k={hybrid_top_k}" |
| ) |
| |
| return query_engine |
| |
| except Exception as e: |
| log_message(f"Критическая ошибка при создании query engine: {str(e)}") |
| raise |