RAG_AIEXP / index_retriever.py
Pimnk's picture
Upload 9 files
13d0427 verified
import numpy as np
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever, BaseRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle
from typing import List, Optional, Dict, Tuple
from logger.my_logging import log_message
from config import CUSTOM_PROMPT, DEFAULT_RETRIEVAL_PARAMS
# --- НОВЫЙ КЛАСС ДЛЯ ЛОГИРОВАНИЯ ---
class LogWrapperRetriever(BaseRetriever):
"""
Обертка для ретривера, которая логирует найденные чанки и их скоры
перед тем, как вернуть их.
"""
def __init__(self, retriever: BaseRetriever, name: str):
self._retriever = retriever
self._name = name
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
# Выполняем реальный поиск
nodes = self._retriever.retrieve(query_bundle)
# Логируем результаты
log_message(f"\n--- 🔎 {self._name} RETRIEVAL (Top {len(nodes)}) ---")
for i, node in enumerate(nodes):
score = node.score if node.score is not None else 0.0
doc_id = node.metadata.get('document_id', 'N/A')
text_preview = node.text.replace('\n', ' ')
log_message(f"[{i+1}] Score: {score:.4f} | Doc: {doc_id} | Text: {text_preview}...")
return nodes
# -----------------------------------
def create_vector_index(documents: List) -> VectorStoreIndex:
"""
Создает векторный индекс из списка документов.
Args:
documents: Список документов для индексации
Returns:
VectorStoreIndex: Созданный векторный индекс
"""
log_message("Инициализация построения векторного индекса")
connection_type_sources: Dict[str, List[str]] = {}
table_count = 0
for doc in documents:
doc_type = doc.metadata.get('type', 'text')
if doc_type == 'table':
table_count += 1
conn_type = doc.metadata.get('connection_type', '')
if conn_type:
table_id = (f"{doc.metadata.get('document_id', 'unknown')} "
f"Table {doc.metadata.get('table_number', 'N/A')}")
if conn_type not in connection_type_sources:
connection_type_sources[conn_type] = []
connection_type_sources[conn_type].append(table_id)
log_message(f"📊 Статистика: Всего документов {len(documents)}, из них таблиц {table_count}")
return VectorStoreIndex.from_documents(documents)
def rerank_nodes(
query: str,
nodes: List,
reranker: Optional[object],
top_k: int = DEFAULT_RETRIEVAL_PARAMS['rerank_top_k'],
rerank_threshold: float = DEFAULT_RETRIEVAL_PARAMS['rerank_threshold']
) -> List:
"""
Переранжирует узлы с использованием модели reranker для улучшения релевантности.
Args:
query: Поисковый запрос
nodes: Список узлов для переранжировки
reranker: Модель для переранжировки (может быть None)
top_k: Количество топовых узлов для возврата
rerank_threshold: Минимальный порог оценки релевантности
Returns:
List: Отсортированный список наиболее релевантных узлов
"""
# Если нет узлов или reranker не предоставлен, возвращаем топ-k узлов как есть
if not nodes or not reranker:
log_message(f"Переранжировка пропущена. Возвращаю первые {top_k} узлов")
return nodes[:top_k]
try:
log_message(f"Начинаю переранжировку {len(nodes)} узлов с порогом {rerank_threshold}")
# Формируем пары [запрос, текст узла] для переранжировки
pairs = [[query, node.text] for node in nodes]
# Получаем оценки релевантности от модели
raw_scores = reranker.predict(pairs)
# Формула: 1 / (1 + e^-x) превращает любое число (5.1, -2.0) в диапазон 0..1
scores = 1 / (1 + np.exp(-raw_scores))
if isinstance(scores, np.ndarray):
scores = scores.tolist()
# Связываем узлы с их оценками
scored_nodes: List[Tuple] = list(zip(nodes, scores))
# Сортируем по убыванию оценки релевантности
scored_nodes.sort(key=lambda x: x[1], reverse=True)
# Фильтруем по минимальному порогу
filtered_nodes = [
(node, score) for node, score in scored_nodes
if score >= rerank_threshold
]
# Если после фильтрации не осталось узлов, берем топ-k без фильтрации
if not filtered_nodes:
log_message(f"Ни один узел не прошел порог {rerank_threshold}. "
f"Возвращаю топ-{top_k} без фильтрации")
filtered_nodes = scored_nodes[:top_k]
result_count = min(len(filtered_nodes), top_k)
log_message(f"Переранжировка завершена. Выбрано узлов: {result_count}")
final_nodes = []
for node, score in filtered_nodes[:top_k]:
node.score = float(score)
final_nodes.append(node)
return final_nodes
except Exception as e:
log_message(f"Ошибка при переранжировке: {str(e)}. Возвращаю исходные узлы")
return nodes[:top_k]
def create_query_engine(
vector_index: VectorStoreIndex,
vector_top_k: int = DEFAULT_RETRIEVAL_PARAMS['vector_top_k'],
bm25_top_k: int = DEFAULT_RETRIEVAL_PARAMS['bm25_top_k'],
similarity_cutoff: float = DEFAULT_RETRIEVAL_PARAMS['similarity_cutoff'],
hybrid_top_k: int = DEFAULT_RETRIEVAL_PARAMS['hybrid_top_k']
) -> RetrieverQueryEngine:
"""
Создает гибридный query engine с комбинацией векторного и BM25 поиска.
Args:
vector_index: Векторный индекс для поиска
vector_top_k: Количество топовых результатов для векторного поиска
bm25_top_k: Количество топовых результатов для BM25 поиска
similarity_cutoff: Порог схожести для векторного поиска (0-1)
hybrid_top_k: Итоговое количество результатов после слияния
Returns:
RetrieverQueryEngine: Настроенный query engine
Raises:
Exception: При ошибке создания query engine
"""
try:
log_message("Инициализация создания query engine")
# Создаем BM25 retriever для лексического поиска
bm25_retriever = BM25Retriever.from_defaults(
docstore=vector_index.docstore,
similarity_top_k=bm25_top_k
)
# Создаем векторный retriever для семантического поиска
vector_retriever = VectorIndexRetriever(
index=vector_index,
similarity_top_k=vector_top_k,
similarity_cutoff=similarity_cutoff
)
# Создаем гибридный retriever, объединяющий оба подхода
bm25_logged = LogWrapperRetriever(bm25_retriever, "BM25 (Keywords)")
vector_logged = LogWrapperRetriever(vector_retriever, "VECTOR (Semantic)")
# 3. Создаем гибридный retriever, используя уже обернутые ретриверы
hybrid_retriever = QueryFusionRetriever(
retrievers=[vector_logged, bm25_logged],
similarity_top_k=hybrid_top_k,
num_queries=1
)
# Настраиваем кастомный промпт для генерации ответа
custom_prompt_template = PromptTemplate(CUSTOM_PROMPT)
# Создаем синтезатор ответов с режимом древовидного суммирования
response_synthesizer = get_response_synthesizer(
response_mode=ResponseMode.TREE_SUMMARIZE,
text_qa_template=custom_prompt_template
)
# Собираем финальный query engine
query_engine = RetrieverQueryEngine(
retriever=hybrid_retriever,
response_synthesizer=response_synthesizer
)
log_message(
f"Query engine успешно создан с параметрами: "
f"vector_top_k={vector_top_k}, bm25_top_k={bm25_top_k}, "
f"similarity_cutoff={similarity_cutoff}, hybrid_top_k={hybrid_top_k}"
)
return query_engine
except Exception as e:
log_message(f"Критическая ошибка при создании query engine: {str(e)}")
raise