import os
import time
from typing import Dict, List, Optional, Tuple

import requests
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
|
|
|
# Load environment variables from a local .env file; the Retriever below reads
# ELASTICSEARCH_URL, API_KEY and BASE_URL from the environment.
load_dotenv()
|
|
|
class Retriever:
    """Hybrid document retriever backed by Elasticsearch.

    Combines BM25 full-text matching with dense-vector cosine similarity
    (via a ``script_score`` query) over indices whose names start with
    ``rag_``. Query embeddings are fetched from a SiliconFlow-compatible
    embeddings API.
    """

    def __init__(self):
        # Connection and credentials come from the environment
        # (loaded by dotenv at module import time).
        self.es = Elasticsearch(
            os.getenv("ELASTICSEARCH_URL", "http://localhost:9200"),
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Call the SiliconFlow embeddings API and return the vector for *text*.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            Exception: If the API responds with a non-200 status.
            requests.exceptions.RequestException: On network failure or timeout.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # FIX: a timeout prevents the caller from hanging forever on a
        # stalled embedding server (the original call had none).
        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": "BAAI/bge-m3",
                "input": text,
            },
            timeout=30,
        )

        if response.status_code == 200:
            return response.json()["data"][0]["embedding"]
        raise Exception(f"Error getting embedding: {response.text}")

    def get_all_indices(self) -> List[str]:
        """Return all RAG-related indices (names starting with ``rag_``)."""
        indices = self.es.indices.get_alias().keys()
        return [idx for idx in indices if idx.startswith('rag_')]

    def retrieve(self, query: str, top_k: int = 10, specific_index: Optional[str] = None) -> Tuple[List[Dict], str]:
        """Hybrid retrieval combining BM25 and vector search.

        Args:
            query: The natural-language query text.
            top_k: Maximum number of results to return (and per-index fetch size).
            specific_index: If given, search only this index (when it exists);
                otherwise search every ``rag_``-prefixed index.

        Returns:
            A tuple ``(results, most_relevant_index)`` where ``results`` is a
            score-descending list of dicts with keys ``id``, ``content``,
            ``score``, ``metadata`` and ``index``, and ``most_relevant_index``
            is the index of the top hit (or a fallback index name).

        Raises:
            Exception: If no usable index can be found.
        """
        if specific_index:
            # Silently fall through to the "no indices" error if the
            # requested index does not exist.
            indices = [specific_index] if self.es.indices.exists(index=specific_index) else []
        else:
            indices = self.get_all_indices()

        if not indices:
            raise Exception("没有找到可用的文档索引!")

        query_vector = self.get_embedding(query)

        all_results = []
        for index in indices:
            # BM25 match narrows candidates; the script re-scores them by
            # cosine similarity against the stored 'vector' field. "+ 1.0"
            # keeps the script score non-negative as Elasticsearch requires.
            script_query = {
                "script_score": {
                    "query": {
                        "match": {
                            "content": query
                        }
                    },
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                        "params": {"query_vector": query_vector}
                    }
                }
            }

            response = self.es.search(
                index=index,
                body={
                    "query": script_query,
                    "size": top_k
                }
            )

            for hit in response['hits']['hits']:
                all_results.append({
                    'id': hit['_id'],
                    'content': hit['_source']['content'],
                    'score': hit['_score'],
                    # FIX: documents indexed without a metadata field used to
                    # raise KeyError here; default to an empty dict instead.
                    'metadata': hit['_source'].get('metadata', {}),
                    'index': index,
                })

        # NOTE(review): scores from different indices are compared directly;
        # BM25 components are not normalized across indices — confirm this
        # ranking is acceptable when multiple indices are searched.
        all_results.sort(key=lambda x: x['score'], reverse=True)
        top_results = all_results[:top_k]

        if top_results:
            most_relevant_index = top_results[0]['index']
        else:
            most_relevant_index = indices[0] if indices else ""

        return top_results, most_relevant_index