import os
from typing import Dict, List

import requests
import urllib3
from dotenv import load_dotenv
from elasticsearch import Elasticsearch

load_dotenv()

# Suppress TLS warnings (e.g. when Elasticsearch runs with a self-signed certificate).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class VectorStore:
    """Stores documents and their embedding vectors in Elasticsearch."""

    def __init__(self):
        self.es = Elasticsearch(
            os.getenv("ELASTICSEARCH_URL", "http://localhost:9200"),
            request_timeout=30,
            max_retries=3,
            retry_on_timeout=True,
        )
        # Credentials and base URL for the embedding API (see get_embedding).
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def test_connection(self):
        """Check that the Elasticsearch cluster is reachable."""
        try:
            self.es.info()
            return True
        except Exception as e:
            print(f"Elasticsearch connection test failed: {str(e)}")
            return False

    def get_embedding(self, text: str) -> List[float]:
        """Call SiliconFlow's embedding API and return the vector for `text`."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": "BAAI/bge-m3",
                "input": text
            },
            timeout=30,
        )

        if response.status_code == 200:
            return response.json()["data"][0]["embedding"]
        else:
            raise RuntimeError(f"Error getting embedding: {response.text}")

    def store(self, documents: List[Dict], index_name: str) -> None:
        """Embed each document and store it in Elasticsearch via the bulk API."""
        if not self.es.indices.exists(index=index_name):
            self.create_index(index_name)

        # Continue numbering from the current document count. Note that
        # count-based IDs can collide if documents were deleted earlier.
        try:
            response = self.es.count(index=index_name)
            last_id = response['count'] - 1
        except Exception as e:
            print(f"Error getting document count, assuming -1: {str(e)}")
            last_id = -1

        # The bulk API expects an action line followed by the document source.
        bulk_data = []
        for i, doc in enumerate(documents, start=last_id + 1):
            vector = self.get_embedding(doc['content'])

            bulk_data.append({
                "index": {
                    "_index": index_name,
                    "_id": f"doc_{i}"
                }
            })
            doc_data = {
                "content": doc['content'],
                "vector": vector,
                "metadata": {
                    "file_name": doc['metadata'].get('file_name', 'unknown file'),
                    "source": doc['metadata'].get('source', ''),
                    "page": doc['metadata'].get('page', ''),
                    "img_url": doc['metadata'].get('img_url', '')
                }
            }
            bulk_data.append(doc_data)

        if bulk_data:
            response = self.es.bulk(operations=bulk_data, refresh=True)
            if response.get('errors'):
                print("Errors occurred during bulk indexing:", response)

    def get_files_in_index(self, index_name: str) -> List[str]:
        """Return the distinct file names stored in the index."""
        try:
            response = self.es.search(
                index=index_name,
                body={
                    "size": 0,
                    "aggs": {
                        "unique_files": {
                            "terms": {
                                "field": "metadata.file_name",
                                "size": 1000
                            }
                        }
                    }
                }
            )
            files = [bucket['key'] for bucket in response['aggregations']['unique_files']['buckets']]
            return sorted(files)
        except Exception as e:
            print(f"Error getting file list: {str(e)}")
            return []

    def create_index(self, index_name: str):
        """Create the Elasticsearch index, dropping any existing index of the same name."""
        settings = {
            "mappings": {
                "properties": {
                    "content": {"type": "text"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": 1024  # BAAI/bge-m3 produces 1024-dimensional vectors
                    },
                    "metadata": {
                        "properties": {
                            "file_name": {
                                "type": "keyword",
                                "ignore_above": 256
                            },
                            "source": {
                                "type": "keyword"
                            },
                            "page": {
                                "type": "keyword"
                            },
                            "img_url": {
                                "type": "keyword",
                                "ignore_above": 2048
                            }
                        }
                    }
                }
            }
        }

        # Destructive: recreating the index discards all previously stored documents.
        if self.es.indices.exists(index=index_name):
            self.es.indices.delete(index=index_name)

        self.es.indices.create(index=index_name, body=settings)

    def delete_index(self, index_id: str) -> bool:
        """Delete an index; returns True if it existed and was removed."""
        try:
            if self.es.indices.exists(index=index_id):
                self.es.indices.delete(index=index_id)
                return True
            return False
        except Exception as e:
            print(f"Error deleting index: {str(e)}")
            return False

    def delete_document(self, index_id: str, file_name: str) -> bool:
        """Delete all documents whose metadata.file_name matches `file_name`."""
        try:
            self.es.delete_by_query(
                index=index_id,
                body={
                    "query": {
                        "term": {
                            "metadata.file_name": file_name
                        }
                    }
                },
                refresh=True
            )
            return True
        except Exception as e:
            print(f"Error deleting documents: {str(e)}")
            return False
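

# Minimal usage sketch (not part of the class). It assumes a reachable
# Elasticsearch instance and valid API_KEY / BASE_URL values in the
# environment or a .env file; the index name "demo_docs" and the sample
# document below are hypothetical.
if __name__ == "__main__":
    vs = VectorStore()
    if not vs.test_connection():
        raise SystemExit("Elasticsearch is not reachable")

    sample_docs = [
        {
            "content": "Elasticsearch can store dense vectors for similarity search.",
            "metadata": {"file_name": "demo.txt", "source": "local", "page": "1"},
        }
    ]
    vs.store(sample_docs, index_name="demo_docs")
    print(vs.get_files_in_index("demo_docs"))  # expected: ['demo.txt']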