| """检索管理器 |
| |
| 协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口 |
| """ |
|
|
| import time |
| from dataclasses import dataclass |
|
|
| from astrbot import logger |
| from astrbot.core.db.vec_db.base import Result |
| from astrbot.core.db.vec_db.faiss_impl import FaissVecDB |
| from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase |
| from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion |
| from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever |
| from astrbot.core.provider.provider import RerankProvider |
|
|
| from ..kb_helper import KBHelper |
|
|
|
|
@dataclass
class RetrievalResult:
    """A single retrieved chunk, enriched with document and KB metadata."""

    chunk_id: str  # unique ID of the retrieved chunk
    doc_id: str  # ID of the source document
    doc_name: str  # display name of the source document
    kb_id: str  # ID of the knowledge base the chunk came from
    kb_name: str  # display name of the knowledge base
    content: str  # chunk text content
    score: float  # relevance score (fusion score; overwritten by rerank when enabled)
    metadata: dict  # extra info, e.g. {"chunk_index": ..., "char_count": ...}
|
|
|
class RetrievalManager:
    """Retrieval manager.

    Responsibilities:
    - Coordinate dense (vector) retrieval, sparse (BM25) retrieval and rerank.
    - Fuse and rank results across multiple knowledge bases.
    """

    def __init__(
        self,
        sparse_retriever: SparseRetriever,
        rank_fusion: RankFusion,
        kb_db: KBSQLiteDatabase,
    ) -> None:
        """Initialize the retrieval manager.

        Args:
            sparse_retriever: Sparse (BM25) retriever.
            rank_fusion: Result fusion component (RRF).
            kb_db: Knowledge-base metadata database instance.

        """
        self.sparse_retriever = sparse_retriever
        self.rank_fusion = rank_fusion
        self.kb_db = kb_db

    async def retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_id_helper_map: dict[str, KBHelper],
        top_k_fusion: int = 20,
        top_m_final: int = 5,
    ) -> list[RetrievalResult]:
        """Hybrid retrieval.

        Pipeline:
        1. Dense retrieval (vector similarity)
        2. Sparse retrieval (BM25)
        3. Rank fusion (RRF)
        4. Rerank (only when a KB's configured rerank provider matches the
           provider attached to its vector DB)

        Args:
            query: Query text.
            kb_ids: IDs of the knowledge bases to search.
            kb_id_helper_map: Mapping of kb_id to its KBHelper; KBs without a
                live helper are skipped with a warning.
            top_k_fusion: Number of fused candidates kept before reranking.
            top_m_final: Number of results ultimately returned.

        Returns:
            list[RetrievalResult]: Final results, best first.

        """
        if not kb_ids:
            return []

        # Gather per-KB retrieval parameters; drop KBs with no live helper.
        kb_options: dict = {}
        valid_kb_ids: list[str] = []
        for kb_id in kb_ids:
            kb_helper = kb_id_helper_map.get(kb_id)
            if kb_helper:
                kb = kb_helper.kb
                kb_options[kb_id] = {
                    "top_k_dense": kb.top_k_dense or 50,
                    "top_k_sparse": kb.top_k_sparse or 50,
                    "top_m_final": kb.top_m_final or 5,
                    "vec_db": kb_helper.vec_db,
                    "rerank_provider_id": kb.rerank_provider_id,
                }
                valid_kb_ids.append(kb_id)
            else:
                logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")

        kb_ids = valid_kb_ids
        if not kb_ids:
            # Every requested KB was unavailable; skip the whole pipeline.
            return []

        # 1. Dense retrieval
        time_start = time.time()
        dense_results = await self._dense_retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        time_end = time.time()
        logger.debug(
            f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
        )

        # 2. Sparse retrieval
        time_start = time.time()
        sparse_results = await self.sparse_retriever.retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        time_end = time.time()
        logger.debug(
            f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
        )

        # 3. Rank fusion
        time_start = time.time()
        fused_results = await self.rank_fusion.fuse(
            dense_results=dense_results,
            sparse_results=sparse_results,
            top_k=top_k_fusion,
        )
        time_end = time.time()
        logger.debug(
            f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
        )

        # Batch-load document/KB metadata to avoid one DB query per chunk.
        doc_ids = {fr.doc_id for fr in fused_results}
        metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids)

        retrieval_results = []
        for fr in fused_results:
            metadata_dict = metadata_map.get(fr.doc_id)
            # Chunks whose document metadata is missing are silently dropped.
            if metadata_dict:
                retrieval_results.append(
                    RetrievalResult(
                        chunk_id=fr.chunk_id,
                        doc_id=fr.doc_id,
                        doc_name=metadata_dict["document"].doc_name,
                        kb_id=fr.kb_id,
                        kb_name=metadata_dict["knowledge_base"].kb_name,
                        content=fr.content,
                        score=fr.score,
                        metadata={
                            "chunk_index": fr.chunk_index,
                            "char_count": len(fr.content),
                        },
                    ),
                )

        # 4. Rerank: use the first KB whose configured rerank provider ID
        # matches the provider attached to its vector DB.
        first_rerank = None
        for kb_id in kb_ids:
            vec_db = kb_options[kb_id]["vec_db"]
            if not isinstance(vec_db, FaissVecDB):
                logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
                continue

            rerank_pi = kb_options[kb_id]["rerank_provider_id"]
            if (
                vec_db.rerank_provider
                and rerank_pi
                and rerank_pi == vec_db.rerank_provider.meta().id
            ):
                first_rerank = vec_db.rerank_provider
                break
        if first_rerank and retrieval_results:
            retrieval_results = await self._rerank(
                query=query,
                results=retrieval_results,
                top_k=top_m_final,
                rerank_provider=first_rerank,
            )

        return retrieval_results[:top_m_final]

    async def _dense_retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_options: dict,
    ) -> list[Result]:
        """Dense retrieval (vector similarity).

        Queries each knowledge base's own vector database, then merges the
        results. Per-KB failures are logged and skipped so one broken KB does
        not abort the whole search.

        Args:
            query: Query text.
            kb_ids: Knowledge-base IDs to search.
            kb_options: Per-KB options ("vec_db" instance and "top_k_dense").

        Returns:
            list[Result]: Merged results, sorted by similarity descending.

        """
        all_results: list[Result] = []
        for kb_id in kb_ids:
            if kb_id not in kb_options:
                continue
            try:
                vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
                dense_k = int(kb_options[kb_id]["top_k_dense"])
                vec_results = await vec_db.retrieve(
                    query=query,
                    k=dense_k,
                    # Over-fetch so metadata filtering still yields k hits.
                    fetch_k=dense_k * 2,
                    rerank=False,
                    metadata_filters={"kb_id": kb_id},
                )
                all_results.extend(vec_results)
            except Exception as e:
                # Best-effort: a failing KB must not break the others.
                logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
                continue

        all_results.sort(key=lambda x: x.similarity, reverse=True)
        return all_results

    async def _rerank(
        self,
        query: str,
        results: list[RetrievalResult],
        top_k: int,
        rerank_provider: RerankProvider,
    ) -> list[RetrievalResult]:
        """Rerank candidate results with the given provider.

        Note: overwrites each kept result's ``score`` in place with the
        provider's relevance score.

        Args:
            query: Query text.
            results: Candidate results to rerank.
            top_k: Number of results to return.
            rerank_provider: Provider performing the rerank.

        Returns:
            list[RetrievalResult]: Reranked results, best first.

        """
        if not results:
            return []

        docs = [r.content for r in results]

        rerank_results = await rerank_provider.rerank(
            query=query,
            documents=docs,
        )

        reranked_list = []
        for rerank_result in rerank_results:
            idx = rerank_result.index
            # Guard against out-of-range indices returned by the provider.
            if idx < len(results):
                result = results[idx]
                result.score = rerank_result.relevance_score
                reranked_list.append(result)

        reranked_list.sort(key=lambda x: x.score, reverse=True)
        return reranked_list[:top_k]
|
|