File size: 8,685 Bytes
8ede856 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 | """检索管理器
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import time
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from ..kb_helper import KBHelper
@dataclass
class RetrievalResult:
"""检索结果"""
chunk_id: str
doc_id: str
doc_name: str
kb_id: str
kb_name: str
content: str
score: float
metadata: dict
class RetrievalManager:
"""检索管理器
职责:
- 协调稠密检索、稀疏检索和 Rerank
- 结果融合和排序
"""
def __init__(
self,
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
) -> None:
"""初始化检索管理器
Args:
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> list[RetrievalResult]:
"""混合检索
流程:
1. 稠密检索 (向量相似度)
2. 稀疏检索 (BM25)
3. 结果融合 (RRF)
4. Rerank 重排序
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
"rerank_provider_id": kb.rerank_provider_id,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
time_start = time.time()
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
)
# 2. 稀疏检索
time_start = time.time()
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
)
# 3. 结果融合
time_start = time.time()
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_k_fusion,
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
)
# 4. 转换为 RetrievalResult (批量获取元数据)
doc_ids = {fr.doc_id for fr in fused_results}
metadata_map = await self.kb_db.get_documents_with_metadata_batch(doc_ids)
retrieval_results = []
for fr in fused_results:
metadata_dict = metadata_map.get(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
chunk_id=fr.chunk_id,
doc_id=fr.doc_id,
doc_name=metadata_dict["document"].doc_name,
kb_id=fr.kb_id,
kb_name=metadata_dict["knowledge_base"].kb_name,
content=fr.content,
score=fr.score,
metadata={
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
),
)
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
Returns:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
# return all_results[: len(all_results) // len(kb_ids)]
return all_results
async def _rerank(
self,
query: str,
results: list[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> list[RetrievalResult]:
"""Rerank 重排序
Args:
query: 查询文本
results: 检索结果列表
top_k: 返回结果数量
Returns:
List[RetrievalResult]: 重排序后的结果列表
"""
if not results:
return []
# 准备文档列表
docs = [r.content for r in results]
# 调用 Rerank Provider
rerank_results = await rerank_provider.rerank(
query=query,
documents=docs,
)
# 更新分数并重新排序
reranked_list = []
for rerank_result in rerank_results:
idx = rerank_result.index
if idx < len(results):
result = results[idx]
result.score = rerank_result.relevance_score
reranked_list.append(result)
reranked_list.sort(key=lambda x: x.score, reverse=True)
return reranked_list[:top_k]
|