shenli commited on
Commit ·
8a17806
1
Parent(s): e45316c
Add GraphDatabase module with Neo4j + Redis caching
Browse files- .gitignore +7 -0
- GraphDatabase/main.py +193 -0
- GraphDatabase/models.py +77 -0
- GraphDatabase/prompts.py +65 -0
- GraphDatabase/schemas.py +114 -0
- GraphDatabase/validators.py +120 -0
- agent3.py +298 -0
- milvus_agent.db +0 -3
- test.py +3 -2
- vector.py +37 -1
.gitignore
CHANGED
|
@@ -1 +1,8 @@
|
|
|
|
|
|
|
|
| 1 |
.env
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
.env
|
| 4 |
+
*.db
|
| 5 |
+
*.db.lock
|
| 6 |
+
.DS_Store
|
| 7 |
+
pdf_output/
|
| 8 |
+
data/
|
GraphDatabase/main.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from fastapi import FastAPI, HTTPException, Depends
|
| 3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
+
import uvicorn
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from models import NL2CypherRequest, CypherResponse, ValidationRequest, ValidationResponse
|
| 9 |
+
from schemas import EXAMPLE_SCHEMA
|
| 10 |
+
from prompts import create_system_prompt, create_validation_prompt
|
| 11 |
+
from validators import CypherValidator, RuleBasedValidator
|
| 12 |
+
|
| 13 |
+
# 加载环境变量
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
# 获取 OpenAI 的 api key
|
| 17 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# 生命周期管理
|
| 21 |
+
@asynccontextmanager
|
| 22 |
+
async def lifespan(app: FastAPI):
|
| 23 |
+
# 启动时初始化
|
| 24 |
+
neo4j_uri = os.getenv("NEO4J_URI")
|
| 25 |
+
neo4j_user = os.getenv("NEO4J_USER")
|
| 26 |
+
neo4j_password = os.getenv("NEO4J_PASSWORD")
|
| 27 |
+
|
| 28 |
+
if all([neo4j_uri, neo4j_user, neo4j_password]):
|
| 29 |
+
app.state.validator = CypherValidator(neo4j_uri, neo4j_user, neo4j_password)
|
| 30 |
+
else:
|
| 31 |
+
app.state.validator = RuleBasedValidator()
|
| 32 |
+
|
| 33 |
+
yield
|
| 34 |
+
|
| 35 |
+
# 关闭时清理
|
| 36 |
+
if hasattr(app.state.validator, 'close'):
|
| 37 |
+
app.state.validator.close()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# 创建FastAPI应用
|
| 41 |
+
app = FastAPI(title="NL2Cypher API", lifespan=lifespan)
|
| 42 |
+
|
| 43 |
+
# 初始化 OpenAI 模型
|
| 44 |
+
client = OpenAI(
|
| 45 |
+
api_key=openai_api_key, # 你的 OpenAI API 密钥
|
| 46 |
+
base_url="https://api.openai.com/v1", # OpenAI 的 API 端点
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# 添加CORS中间件
|
| 50 |
+
app.add_middleware(
|
| 51 |
+
CORSMiddleware,
|
| 52 |
+
allow_origins=["*"],
|
| 53 |
+
allow_credentials=True,
|
| 54 |
+
allow_methods=["*"],
|
| 55 |
+
allow_headers=["*"],
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def clean_cypher_output(raw_output: str) -> str:
|
| 60 |
+
"""清洗 LLM 返回的 Cypher 查询, 去掉多余的包装文本"""
|
| 61 |
+
import re
|
| 62 |
+
text = raw_output.strip()
|
| 63 |
+
|
| 64 |
+
# 去掉 markdown 代码块: ```cypher ... ``` 或 ``` ... ```
|
| 65 |
+
text = re.sub(r'```(?:cypher)?\s*', '', text)
|
| 66 |
+
text = text.strip('`')
|
| 67 |
+
|
| 68 |
+
# 去掉 Cypher: "..." 包装
|
| 69 |
+
match = re.match(r'^[Cc]ypher:\s*["\']?(.*?)["\']?\s*$', text, re.DOTALL)
|
| 70 |
+
if match:
|
| 71 |
+
text = match.group(1).strip()
|
| 72 |
+
|
| 73 |
+
# 去掉首尾引号
|
| 74 |
+
if (text.startswith('"') and text.endswith('"')) or \
|
| 75 |
+
(text.startswith("'") and text.endswith("'")):
|
| 76 |
+
text = text[1:-1].strip()
|
| 77 |
+
|
| 78 |
+
return text
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def generate_cypher_query(natural_language: str, query_type: str = None) -> str:
|
| 82 |
+
"""使用 OpenAI 生成 Cypher 查询"""
|
| 83 |
+
system_prompt = create_system_prompt(str(EXAMPLE_SCHEMA.model_dump()))
|
| 84 |
+
|
| 85 |
+
user_prompt = natural_language
|
| 86 |
+
if query_type:
|
| 87 |
+
user_prompt = f"{query_type}查询: {natural_language}"
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
response = client.chat.completions.create(
|
| 91 |
+
model="gpt-4o",
|
| 92 |
+
messages=[
|
| 93 |
+
{"role": "system", "content": system_prompt},
|
| 94 |
+
{"role": "user", "content": user_prompt}
|
| 95 |
+
],
|
| 96 |
+
temperature=0.1,
|
| 97 |
+
max_tokens=2048,
|
| 98 |
+
stream=False
|
| 99 |
+
)
|
| 100 |
+
raw_output = response.choices[0].message.content.strip()
|
| 101 |
+
return clean_cypher_output(raw_output)
|
| 102 |
+
except Exception as e:
|
| 103 |
+
raise HTTPException(status_code=500, detail=f"OpenAI API错误: {str(e)}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def explain_cypher_query(cypher_query: str) -> str:
|
| 107 |
+
"""解释Cypher查询"""
|
| 108 |
+
try:
|
| 109 |
+
response = client.chat.completions.create(
|
| 110 |
+
model="gpt-4o",
|
| 111 |
+
messages=[
|
| 112 |
+
{"role": "system", "content": "你是一个Neo4j专家, 请用简单明了的语言解释Cypher查询."},
|
| 113 |
+
{"role": "user", "content": f"请解释以下Cypher查询: {cypher_query}"}
|
| 114 |
+
],
|
| 115 |
+
temperature=0.1,
|
| 116 |
+
max_tokens=1024,
|
| 117 |
+
stream=False
|
| 118 |
+
)
|
| 119 |
+
return response.choices[0].message.content.strip()
|
| 120 |
+
except Exception as e:
|
| 121 |
+
return f"无法生成解释: {str(e)}"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@app.post("/generate", response_model=CypherResponse)
|
| 125 |
+
async def generate_cypher(request: NL2CypherRequest):
|
| 126 |
+
"""生成Cypher查询端点"""
|
| 127 |
+
# 利用 OpenAI 生成 Cypher 查询
|
| 128 |
+
cypher_query = generate_cypher_query(
|
| 129 |
+
request.natural_language_query,
|
| 130 |
+
request.query_type.value if request.query_type else None
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# 利用 OpenAI 生成解释
|
| 134 |
+
explanation = explain_cypher_query(cypher_query)
|
| 135 |
+
|
| 136 |
+
# 验证查询
|
| 137 |
+
is_valid, errors = app.state.validator.validate_against_schema(cypher_query, EXAMPLE_SCHEMA)
|
| 138 |
+
|
| 139 |
+
# 计算置信度, 将基础置信度设置为0.9
|
| 140 |
+
confidence = 0.9
|
| 141 |
+
|
| 142 |
+
# 如果有潜在错误, 重新计算置信度 confidence
|
| 143 |
+
if errors:
|
| 144 |
+
confidence = max(0.3, confidence - len(errors) * 0.1)
|
| 145 |
+
|
| 146 |
+
return CypherResponse(
|
| 147 |
+
cypher_query=cypher_query,
|
| 148 |
+
explanation=explanation,
|
| 149 |
+
confidence=confidence,
|
| 150 |
+
validated=is_valid,
|
| 151 |
+
validation_errors=errors
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@app.post("/validate", response_model=ValidationResponse)
|
| 156 |
+
async def validate_cypher(request: ValidationRequest):
|
| 157 |
+
"""验证Cypher查询端点"""
|
| 158 |
+
is_valid, errors = app.state.validator.validate_against_schema(request.cypher_query, EXAMPLE_SCHEMA)
|
| 159 |
+
|
| 160 |
+
# 生成改进建议
|
| 161 |
+
suggestions = []
|
| 162 |
+
if errors:
|
| 163 |
+
try:
|
| 164 |
+
response = client.chat.completions.create(
|
| 165 |
+
model="gpt-4o",
|
| 166 |
+
messages=[
|
| 167 |
+
{"role": "system", "content": "你是一个Neo4j专家, 请提供Cypher查询的改进建议."},
|
| 168 |
+
{"role": "user", "content": create_validation_prompt(request.cypher_query)}
|
| 169 |
+
],
|
| 170 |
+
temperature=0.1,
|
| 171 |
+
max_tokens=1024,
|
| 172 |
+
stream=False
|
| 173 |
+
)
|
| 174 |
+
suggestions = [response.choices[0].message.content.strip()]
|
| 175 |
+
except:
|
| 176 |
+
suggestions = ["无法生成建议"]
|
| 177 |
+
|
| 178 |
+
return ValidationResponse(
|
| 179 |
+
is_valid=is_valid,
|
| 180 |
+
errors=errors,
|
| 181 |
+
suggestions=suggestions
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@app.get("/schema")
|
| 186 |
+
async def get_schema():
|
| 187 |
+
"""获取图模式端点"""
|
| 188 |
+
return EXAMPLE_SCHEMA.model_dump()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
# 因为项目中的主服务Agent启动在8103端口, 所以这个neo4j的服务端口另选一个8101即可
|
| 193 |
+
uvicorn.run(app, host="0.0.0.0", port=8101)
|
GraphDatabase/models.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import Optional, List, Dict, Any
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class QueryType(str, Enum):
|
| 8 |
+
MATCH = "MATCH"
|
| 9 |
+
CREATE = "CREATE"
|
| 10 |
+
MERGE = "MERGE"
|
| 11 |
+
DELETE = "DELETE"
|
| 12 |
+
SET = "SET"
|
| 13 |
+
REMOVE = "REMOVE"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NL2CypherRequest(BaseModel):
|
| 17 |
+
natural_language_query: str = Field(
|
| 18 |
+
description="自然语言描述的需求",
|
| 19 |
+
examples=["查找'心血管和血栓栓塞综合征'建议服用什么药物?"]
|
| 20 |
+
)
|
| 21 |
+
query_type: Optional[QueryType] = Field(
|
| 22 |
+
default=None,
|
| 23 |
+
description="指定查询类型,如果不指定则由模型推断"
|
| 24 |
+
)
|
| 25 |
+
limit: Optional[int] = Field(
|
| 26 |
+
default=10,
|
| 27 |
+
description="结果限制数量",
|
| 28 |
+
ge=1,
|
| 29 |
+
le=1000
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CypherResponse(BaseModel):
|
| 34 |
+
cypher_query: str = Field(
|
| 35 |
+
...,
|
| 36 |
+
description="生成的Cypher查询语句"
|
| 37 |
+
)
|
| 38 |
+
explanation: str = Field(
|
| 39 |
+
...,
|
| 40 |
+
description="对生成的Cypher查询的解释"
|
| 41 |
+
)
|
| 42 |
+
confidence: float = Field(
|
| 43 |
+
...,
|
| 44 |
+
description="模型对生成查询的信心度(0-1)",
|
| 45 |
+
ge=0,
|
| 46 |
+
le=1
|
| 47 |
+
)
|
| 48 |
+
validated: bool = Field(
|
| 49 |
+
default=False,
|
| 50 |
+
description="查询是否通过验证"
|
| 51 |
+
)
|
| 52 |
+
validation_errors: List[str] = Field(
|
| 53 |
+
default_factory=list,
|
| 54 |
+
description="验证过程中发现的错误"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ValidationRequest(BaseModel):
|
| 59 |
+
cypher_query: str = Field(
|
| 60 |
+
...,
|
| 61 |
+
description="需要验证的Cypher查询"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ValidationResponse(BaseModel):
|
| 66 |
+
is_valid: bool = Field(
|
| 67 |
+
...,
|
| 68 |
+
description="查询是否有效"
|
| 69 |
+
)
|
| 70 |
+
errors: List[str] = Field(
|
| 71 |
+
default_factory=list,
|
| 72 |
+
description="发现的错误列表"
|
| 73 |
+
)
|
| 74 |
+
suggestions: List[str] = Field(
|
| 75 |
+
default_factory=list,
|
| 76 |
+
description="改进建议"
|
| 77 |
+
)
|
GraphDatabase/prompts.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from schemas import EXAMPLE_SCHEMA
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def create_system_prompt(schema: str) -> str:
|
| 6 |
+
return f"""
|
| 7 |
+
你是一个专业的Neo4j Cypher查询生成器, 你的任务是将自然语言描述转换为准确, 高效的Cypher查询.
|
| 8 |
+
|
| 9 |
+
# 图数据库模式
|
| 10 |
+
{schema}
|
| 11 |
+
|
| 12 |
+
# 重要规则
|
| 13 |
+
1. 始终使用参数化查询风格, 对字符串值使用单引号
|
| 14 |
+
2. 确保节点标签和关系类型使用正确的大小写
|
| 15 |
+
3. 对于模糊查询, 使用 CONTAINS 或 STARTS WITH 而不是 "="
|
| 16 |
+
4. 对于可选模式, 使用 OPTIONAL MATCH
|
| 17 |
+
5. 始终考虑查询性能, 使用适当的索引和约束
|
| 18 |
+
6. 对于需要返回多个实体的查询, 使用 RETURN 子句明确指定要返回的内容
|
| 19 |
+
7. 避免使用可能导致性能问题的查询模式
|
| 20 |
+
|
| 21 |
+
# 示例如下
|
| 22 |
+
自然语言: "查找心血管和血栓栓塞综合征建议服用什么药物?"
|
| 23 |
+
Cypher: "match (p:Disease)-[r:recommand_drug]-(d:Drug) where p.name='心血管和血栓栓塞综合征' return d.name"
|
| 24 |
+
|
| 25 |
+
自然语言: "查找嗜铬细胞瘤这种疾病有哪些临床症状?"
|
| 26 |
+
Cypher: "match (p:Disease)-[r:has_symptom]-(s:Symptom) where p.name='嗜铬细胞瘤' return s.name"
|
| 27 |
+
|
| 28 |
+
自然语言: "查找小儿先天性巨结肠推荐哪些饮食有利康复?"
|
| 29 |
+
Cypher: "match (p:Disease)-[r:recommand_eat]-(f:Food) where p.name='小儿先天性巨结肠' return f.name"
|
| 30 |
+
|
| 31 |
+
自然语言: "查找糖尿病需要做哪些检查项目?"
|
| 32 |
+
Cypher: "match (p:Disease)-[r:need_check]-(c:Check) where p.name='糖尿病' return c.name"
|
| 33 |
+
|
| 34 |
+
自然语言: "查找高血压属于哪个科室?"
|
| 35 |
+
Cypher: "match (p:Disease)-[r:belongs_to]-(d:Department) where p.name='高血压' return d.name"
|
| 36 |
+
|
| 37 |
+
自然语言: "查找感冒的常用药物有哪些?"
|
| 38 |
+
Cypher: "match (p:Disease)-[r:common_drug]-(d:Drug) where p.name='感冒' return d.name"
|
| 39 |
+
|
| 40 |
+
自然语言: "查找肺炎患者不能吃什么食物?"
|
| 41 |
+
Cypher: "match (p:Disease)-[r:no_eat]-(f:Food) where p.name='肺炎' return f.name"
|
| 42 |
+
|
| 43 |
+
自然语言: "查找胃炎患者适合吃什么食物?"
|
| 44 |
+
Cypher: "match (p:Disease)-[r:do_eat]-(f:Food) where p.name='胃炎' return f.name"
|
| 45 |
+
|
| 46 |
+
自然语言: "查找冠心病容易并发哪些疾病?"
|
| 47 |
+
Cypher: "match (p:Disease)-[r:acompany_with]-(d:Disease) where p.name='冠心病' return d.name"
|
| 48 |
+
|
| 49 |
+
自然语言: "查找阿莫西林是哪个厂家生产的?"
|
| 50 |
+
Cypher: "match (p:Producer)-[r:drugs_of]-(d:Drug) where d.name='阿莫西林' return p.name"
|
| 51 |
+
|
| 52 |
+
现在请根据以下自然语言描述生成Cypher查询:
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def create_validation_prompt(cypher_query: str) -> str:
|
| 57 |
+
return f"""
|
| 58 |
+
请分析以下Cypher查询, 指出其中的任何错误或潜在问题, 并提供改进建议:
|
| 59 |
+
|
| 60 |
+
{cypher_query}
|
| 61 |
+
|
| 62 |
+
请按以下格式回答:
|
| 63 |
+
错误: [列出所有错误]
|
| 64 |
+
建议: [提供改进建议]
|
| 65 |
+
"""
|
GraphDatabase/schemas.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NodeSchema(BaseModel):
|
| 7 |
+
label: str
|
| 8 |
+
properties: Dict[str, str] # 属性名: 类型
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RelationshipSchema(BaseModel):
|
| 12 |
+
type: str
|
| 13 |
+
from_node: str # 起始节点标签
|
| 14 |
+
to_node: str # 目标节点标签
|
| 15 |
+
properties: Dict[str, str] # 属性名: 类型
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GraphSchema(BaseModel):
|
| 19 |
+
nodes: List[NodeSchema]
|
| 20 |
+
relationships: List[RelationshipSchema]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# 示例图模式 (按照neo4j数据库中的定义schema来填充)
|
| 24 |
+
EXAMPLE_SCHEMA = GraphSchema(
|
| 25 |
+
# 节点的名称一定要严格保持跟neo4j一致
|
| 26 |
+
nodes=[
|
| 27 |
+
# --- PDF 原有的 4 个节点 ---
|
| 28 |
+
NodeSchema(label="Disease", properties={"name": "string", "desc": "string", "cause": "string", "prevent": "string", "cure_lasttime": "string", "cure_department": "string", "cure_way": "string", "cure_prob": "string", "easy_get": "string"}),
|
| 29 |
+
NodeSchema(label="Drug", properties={"name": "string"}),
|
| 30 |
+
NodeSchema(label="Food", properties={"name": "string"}),
|
| 31 |
+
NodeSchema(label="Symptom", properties={"name": "string"}),
|
| 32 |
+
# --- 基于 Neo4j 截图新增的 3 个节点 ---
|
| 33 |
+
NodeSchema(label="Check", properties={"name": "string"}),
|
| 34 |
+
NodeSchema(label="Department", properties={"name": "string"}),
|
| 35 |
+
NodeSchema(label="Producer", properties={"name": "string"}),
|
| 36 |
+
],
|
| 37 |
+
# 关系的相关字段一定要严格保持跟neo4j一致, 大小写都不能错
|
| 38 |
+
relationships=[
|
| 39 |
+
# --- PDF 原有的 3 个关系 ---
|
| 40 |
+
RelationshipSchema(
|
| 41 |
+
type="has_symptom",
|
| 42 |
+
from_node="Disease",
|
| 43 |
+
to_node="Symptom",
|
| 44 |
+
properties={}
|
| 45 |
+
),
|
| 46 |
+
RelationshipSchema(
|
| 47 |
+
type="recommand_drug",
|
| 48 |
+
from_node="Disease",
|
| 49 |
+
to_node="Drug",
|
| 50 |
+
properties={}
|
| 51 |
+
),
|
| 52 |
+
RelationshipSchema(
|
| 53 |
+
type="recommand_eat",
|
| 54 |
+
from_node="Disease",
|
| 55 |
+
to_node="Food",
|
| 56 |
+
properties={}
|
| 57 |
+
),
|
| 58 |
+
# --- 基于 Neo4j 截图新增的关系 ---
|
| 59 |
+
# Disease 需要做的检查项目
|
| 60 |
+
RelationshipSchema(
|
| 61 |
+
type="need_check",
|
| 62 |
+
from_node="Disease",
|
| 63 |
+
to_node="Check",
|
| 64 |
+
properties={}
|
| 65 |
+
),
|
| 66 |
+
# Disease 所属的科室
|
| 67 |
+
RelationshipSchema(
|
| 68 |
+
type="belongs_to",
|
| 69 |
+
from_node="Disease",
|
| 70 |
+
to_node="Department",
|
| 71 |
+
properties={}
|
| 72 |
+
),
|
| 73 |
+
# Disease 的常用药物
|
| 74 |
+
RelationshipSchema(
|
| 75 |
+
type="common_drug",
|
| 76 |
+
from_node="Disease",
|
| 77 |
+
to_node="Drug",
|
| 78 |
+
properties={}
|
| 79 |
+
),
|
| 80 |
+
# Disease 宜吃的食物
|
| 81 |
+
RelationshipSchema(
|
| 82 |
+
type="do_eat",
|
| 83 |
+
from_node="Disease",
|
| 84 |
+
to_node="Food",
|
| 85 |
+
properties={}
|
| 86 |
+
),
|
| 87 |
+
# Disease 忌吃的食物
|
| 88 |
+
RelationshipSchema(
|
| 89 |
+
type="no_eat",
|
| 90 |
+
from_node="Disease",
|
| 91 |
+
to_node="Food",
|
| 92 |
+
properties={}
|
| 93 |
+
),
|
| 94 |
+
# Disease 的并发症
|
| 95 |
+
RelationshipSchema(
|
| 96 |
+
type="acompany_with",
|
| 97 |
+
from_node="Disease",
|
| 98 |
+
to_node="Disease",
|
| 99 |
+
properties={}
|
| 100 |
+
),
|
| 101 |
+
# Drug 的生产商
|
| 102 |
+
RelationshipSchema(
|
| 103 |
+
type="drugs_of",
|
| 104 |
+
from_node="Producer",
|
| 105 |
+
to_node="Drug",
|
| 106 |
+
properties={}
|
| 107 |
+
),
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == '__main__':
|
| 113 |
+
res = str(EXAMPLE_SCHEMA.model_dump())
|
| 114 |
+
print(res)
|
GraphDatabase/validators.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
from neo4j import GraphDatabase
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CypherValidator:
|
| 8 |
+
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
|
| 9 |
+
self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
|
| 10 |
+
|
| 11 |
+
def validate_syntax(self, cypher_query: str) -> Tuple[bool, List[str]]:
|
| 12 |
+
"""验证Cypher查询的语法"""
|
| 13 |
+
errors = []
|
| 14 |
+
|
| 15 |
+
# 基本语法检查
|
| 16 |
+
if not cypher_query.strip().upper().startswith(('MATCH', 'CREATE', 'MERGE', 'CALL')):
|
| 17 |
+
errors.append("查询必须以MATCH, CREATE, MERGE 或 CALL开头!!!")
|
| 18 |
+
|
| 19 |
+
# 检查是否有潜在的注入风险
|
| 20 |
+
if any(keyword in cypher_query.upper() for keyword in ['DROP', 'DELETE', 'DETACH', 'REMOVE']):
|
| 21 |
+
if not any(keyword in cypher_query.upper() for keyword in ['DELETE', 'DETACH']):
|
| 22 |
+
errors.append("查询包含可能危险的操作符")
|
| 23 |
+
|
| 24 |
+
# 检查RETURN语句是否存在 (对于MATCH查询)
|
| 25 |
+
if cypher_query.upper().startswith('MATCH') and 'RETURN' not in cypher_query.upper():
|
| 26 |
+
errors.append("MATCH查询必须包含RETURN语句!!!")
|
| 27 |
+
|
| 28 |
+
# 使用Neo4j解释计划验证查询
|
| 29 |
+
try:
|
| 30 |
+
with self.driver.session() as session:
|
| 31 |
+
result = session.run(f"EXPLAIN {cypher_query}")
|
| 32 |
+
# 如果解释成功, 语法基本正确
|
| 33 |
+
return True, errors
|
| 34 |
+
except Exception as e:
|
| 35 |
+
errors.append(f"语法错误: {str(e)}")
|
| 36 |
+
return False, errors
|
| 37 |
+
|
| 38 |
+
def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
|
| 39 |
+
"""根据模式验证查询"""
|
| 40 |
+
errors = []
|
| 41 |
+
|
| 42 |
+
# 提取所有节点标签
|
| 43 |
+
node_labels = [node.label for node in schema.nodes]
|
| 44 |
+
node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
|
| 45 |
+
matches = re.findall(node_pattern, cypher_query)
|
| 46 |
+
|
| 47 |
+
for match in matches:
|
| 48 |
+
if match[1] and match[1] not in node_labels:
|
| 49 |
+
errors.append(f"使用了不存在的节点标签: {match[1]}")
|
| 50 |
+
|
| 51 |
+
# 提取所有关系类型
|
| 52 |
+
rel_types = [rel.type for rel in schema.relationships]
|
| 53 |
+
rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
|
| 54 |
+
rel_matches = re.findall(rel_pattern, cypher_query)
|
| 55 |
+
|
| 56 |
+
for match in rel_matches:
|
| 57 |
+
if match[1] and match[1] not in rel_types:
|
| 58 |
+
errors.append(f"使用了不存在的关系类型: {match[1]}")
|
| 59 |
+
|
| 60 |
+
return len(errors) == 0, errors
|
| 61 |
+
|
| 62 |
+
def close(self):
|
| 63 |
+
self.driver.close()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# 简单的基于规则的验证器 (当无法连接Neo4j时使用)
|
| 68 |
+
class RuleBasedValidator:
|
| 69 |
+
def validate(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
|
| 70 |
+
errors = []
|
| 71 |
+
|
| 72 |
+
# 检查基本结构
|
| 73 |
+
if not cypher_query.strip():
|
| 74 |
+
errors.append("查询不能为空!!!")
|
| 75 |
+
return False, errors
|
| 76 |
+
|
| 77 |
+
# 检查是否包含潜在危险操作
|
| 78 |
+
dangerous_patterns = [
|
| 79 |
+
(r'(?i)drop\s+', "DROP操作可能危险"),
|
| 80 |
+
(r'(?i)delete\s+', "DELETE操作需要谨慎"),
|
| 81 |
+
(r'(?i)detach\s+delete', "DETACH DELETE操作非常危险!!"),
|
| 82 |
+
(r'(?i)remove\s+', "REMOVE操作需要谨慎"),
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
for pattern, message in dangerous_patterns:
|
| 86 |
+
if re.search(pattern, cypher_query):
|
| 87 |
+
errors.append(message)
|
| 88 |
+
|
| 89 |
+
# 检查MATCH查询是否包含RETURN
|
| 90 |
+
if re.match(r'(?i)match', cypher_query) and not re.search(r'(?i)return', cypher_query):
|
| 91 |
+
errors.append("MATCH查询必须包含RETURN子句")
|
| 92 |
+
|
| 93 |
+
# 检查CREATE查询是否合理
|
| 94 |
+
if re.match(r'(?i)create', cypher_query) and not re.search(r'(?i)(node|relationship|label|index)', cypher_query):
|
| 95 |
+
errors.append("CREATE查询应该明确创建节点或关系")
|
| 96 |
+
|
| 97 |
+
return len(errors) == 0, errors
|
| 98 |
+
|
| 99 |
+
def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
|
| 100 |
+
"""兼容CypherValidator的接口, 先做规则验证再做schema验证"""
|
| 101 |
+
is_valid, errors = self.validate(cypher_query, schema)
|
| 102 |
+
|
| 103 |
+
# 额外进行schema验证
|
| 104 |
+
node_labels = [node.label for node in schema.nodes]
|
| 105 |
+
node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
|
| 106 |
+
matches = re.findall(node_pattern, cypher_query)
|
| 107 |
+
|
| 108 |
+
for match in matches:
|
| 109 |
+
if match[1] and match[1] not in node_labels:
|
| 110 |
+
errors.append(f"使用了不存在的节点标签: {match[1]}")
|
| 111 |
+
|
| 112 |
+
rel_types = [rel.type for rel in schema.relationships]
|
| 113 |
+
rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
|
| 114 |
+
rel_matches = re.findall(rel_pattern, cypher_query)
|
| 115 |
+
|
| 116 |
+
for match in rel_matches:
|
| 117 |
+
if match[1] and match[1] not in rel_types:
|
| 118 |
+
errors.append(f"使用了不存在的关系类型: {match[1]}")
|
| 119 |
+
|
| 120 |
+
return len(errors) == 0, errors
|
agent3.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uvicorn
|
| 3 |
+
from fastapi import FastAPI, Request
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
import json
|
| 6 |
+
import requests
|
| 7 |
+
import datetime
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
from neo4j import GraphDatabase
|
| 10 |
+
from langchain_milvus import Milvus, BM25BuiltInFunction
|
| 11 |
+
from vector import OpenAIEmbeddings, get_redis_client, cache_set, cache_get
|
| 12 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 13 |
+
from langchain_core.stores import InMemoryStore
|
| 14 |
+
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
# 加载 .env 文件中的环境变量, 隐藏 API Keys
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
+
app = FastAPI()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ============================================================
|
| 25 |
+
# OpenAI LLM 客户端封装 (替代讲义中的 DeepSeek)
|
| 26 |
+
# ============================================================
|
| 27 |
+
|
| 28 |
+
def create_openai_client():
|
| 29 |
+
"""创建 OpenAI 客户端"""
|
| 30 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 31 |
+
return client
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def generate_openai_answer(client, prompt):
|
| 35 |
+
"""使用 OpenAI 生成回复"""
|
| 36 |
+
response = client.chat.completions.create(
|
| 37 |
+
model="gpt-4o-mini",
|
| 38 |
+
messages=[
|
| 39 |
+
{"role": "user", "content": prompt}
|
| 40 |
+
],
|
| 41 |
+
temperature=0.7,
|
| 42 |
+
)
|
| 43 |
+
return response.choices[0].message.content
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# 允许所有域的请求
|
| 47 |
+
app.add_middleware(
|
| 48 |
+
CORSMiddleware,
|
| 49 |
+
allow_origins=["*"],
|
| 50 |
+
allow_credentials=True,
|
| 51 |
+
allow_methods=["*"],
|
| 52 |
+
allow_headers=["*"],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# 创建 Embedding 模型
|
| 56 |
+
embedding_model = OpenAIEmbeddings()
|
| 57 |
+
print("创建 Embedding 模型成功......")
|
| 58 |
+
|
| 59 |
+
# 设置默认的 Milvus 数据库文件路径
|
| 60 |
+
URI = "./milvus_agent.db"
|
| 61 |
+
URI1 = "./pdf_agent.db"
|
| 62 |
+
|
| 63 |
+
# 创建 Milvus 连接
|
| 64 |
+
milvus_vectorstore = Milvus(
|
| 65 |
+
embedding_function=embedding_model,
|
| 66 |
+
builtin_function=BM25BuiltInFunction(),
|
| 67 |
+
vector_field=["dense", "sparse"],
|
| 68 |
+
index_params=[
|
| 69 |
+
{
|
| 70 |
+
"metric_type": "IP",
|
| 71 |
+
"index_type": "IVF_FLAT",
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"metric_type": "BM25",
|
| 75 |
+
"index_type": "SPARSE_INVERTED_INDEX"
|
| 76 |
+
}
|
| 77 |
+
],
|
| 78 |
+
connection_args={"uri": URI},
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
retriever = milvus_vectorstore.as_retriever()
|
| 82 |
+
print("创建 Milvus 连接成功......")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
docstore = InMemoryStore()
|
| 86 |
+
|
| 87 |
+
# 文本分割器
|
| 88 |
+
child_splitter = RecursiveCharacterTextSplitter(
|
| 89 |
+
chunk_size=200,
|
| 90 |
+
chunk_overlap=50,
|
| 91 |
+
length_function=len,
|
| 92 |
+
separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
parent_splitter = RecursiveCharacterTextSplitter(
|
| 96 |
+
chunk_size=1000,
|
| 97 |
+
chunk_overlap=200
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
pdf_vectorstore = Milvus(
|
| 101 |
+
embedding_function=embedding_model,
|
| 102 |
+
builtin_function=BM25BuiltInFunction(),
|
| 103 |
+
vector_field=["dense", "sparse"],
|
| 104 |
+
index_params=[
|
| 105 |
+
{
|
| 106 |
+
"metric_type": "IP",
|
| 107 |
+
"index_type": "IVF_FLAT",
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"metric_type": "BM25",
|
| 111 |
+
"index_type": "SPARSE_INVERTED_INDEX"
|
| 112 |
+
}
|
| 113 |
+
],
|
| 114 |
+
connection_args={"uri": URI1},
|
| 115 |
+
consistency_level="Bounded",
|
| 116 |
+
drop_old=False,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# 设置父子文档检索器
|
| 120 |
+
parent_retriever = ParentDocumentRetriever(
|
| 121 |
+
vectorstore=pdf_vectorstore,
|
| 122 |
+
docstore=docstore,
|
| 123 |
+
child_splitter=child_splitter,
|
| 124 |
+
parent_splitter=parent_splitter,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
print("创建 Parent Milvus 连接成功......")
|
| 128 |
+
|
| 129 |
+
# 获取 neo4j 图数据库的连接
|
| 130 |
+
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
|
| 131 |
+
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
|
| 132 |
+
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
|
| 133 |
+
driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
|
| 134 |
+
print("创建 Neo4j 连接成功......")
|
| 135 |
+
|
| 136 |
+
# 创建大语言模型, 采用 OpenAI
|
| 137 |
+
client_llm = create_openai_client()
|
| 138 |
+
print("创建 OpenAI LLM 成功......")
|
| 139 |
+
|
| 140 |
+
# 获取 Redis 连接
|
| 141 |
+
client_redis = get_redis_client()
|
| 142 |
+
print("创建 Redis 连接成功......")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def format_docs(docs):
|
| 146 |
+
return "\n\n".join(doc.page_content for doc in docs)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@app.post("/")
|
| 150 |
+
async def chatbot(request: Request):
|
| 151 |
+
global milvus_vectorstore, retriever
|
| 152 |
+
|
| 153 |
+
json_post_raw = await request.json()
|
| 154 |
+
json_post = json.dumps(json_post_raw)
|
| 155 |
+
json_post_list = json.loads(json_post)
|
| 156 |
+
|
| 157 |
+
query = json_post_list.get('question')
|
| 158 |
+
|
| 159 |
+
# ============================================================
|
| 160 |
+
# 1: 先查 Redis 缓存, 如果缓存命中, 直接返回结果
|
| 161 |
+
# ============================================================
|
| 162 |
+
response_redis = cache_get(client_redis, query)
|
| 163 |
+
|
| 164 |
+
if response_redis is not None:
|
| 165 |
+
# redis 返回的字符串是以十六进制显示的, 需要按 utf-8 解码
|
| 166 |
+
response = response_redis.decode('utf-8')
|
| 167 |
+
|
| 168 |
+
now = datetime.datetime.now()
|
| 169 |
+
time = now.strftime("%Y-%m-%d %H:%M:%S")
|
| 170 |
+
answer = {
|
| 171 |
+
"response": response,
|
| 172 |
+
"status": 200,
|
| 173 |
+
"time": time
|
| 174 |
+
}
|
| 175 |
+
print('REDIS HIT !!!')
|
| 176 |
+
return answer
|
| 177 |
+
|
| 178 |
+
# ============================================================
|
| 179 |
+
# 2: 向量数据库 Milvus 模糊召回 & 重排序
|
| 180 |
+
# ============================================================
|
| 181 |
+
# 在集合中搜索问题并检索语义 top-10 匹配项, 而且已经配置了 reranker 的处理, 采用RRF算法
|
| 182 |
+
recall_rerank_milvus = milvus_vectorstore.similarity_search(
|
| 183 |
+
query,
|
| 184 |
+
k=10,
|
| 185 |
+
ranker_type="rrf",
|
| 186 |
+
ranker_params={"k": 100}
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if recall_rerank_milvus:
|
| 190 |
+
# 检索结果存放在列表中
|
| 191 |
+
context = format_docs(recall_rerank_milvus)
|
| 192 |
+
else:
|
| 193 |
+
context = ""
|
| 194 |
+
|
| 195 |
+
# ============================================================
|
| 196 |
+
# 2.5: PDF 文档的 Milvus 召回 (父子文档检索器)
|
| 197 |
+
# ============================================================
|
| 198 |
+
pdf_res = ""
|
| 199 |
+
retrieved_docs = parent_retriever.invoke(query)
|
| 200 |
+
|
| 201 |
+
if retrieved_docs is not None and len(retrieved_docs) >= 1:
|
| 202 |
+
pdf_res = retrieved_docs[0].page_content
|
| 203 |
+
print("PDF res: ", pdf_res)
|
| 204 |
+
|
| 205 |
+
context = context + "\n" + pdf_res
|
| 206 |
+
|
| 207 |
+
# ============================================================
|
| 208 |
+
# 3: 图数据库 neo4j 精准召回
|
| 209 |
+
# ============================================================
|
| 210 |
+
# 访问 neo4j API 服务, 生成 Cypher 命令
|
| 211 |
+
neo4j_res = ""
|
| 212 |
+
data = {"natural_language_query": query}
|
| 213 |
+
data_json = json.dumps(data)
|
| 214 |
+
|
| 215 |
+
try:
|
| 216 |
+
cypher_response = requests.post("http://0.0.0.0:8101/generate", data_json)
|
| 217 |
+
|
| 218 |
+
if cypher_response.status_code == 200:
|
| 219 |
+
cypher_response_data = cypher_response.json()
|
| 220 |
+
|
| 221 |
+
cypher_query = cypher_response_data["cypher_query"]
|
| 222 |
+
confidence = cypher_response_data["confidence"]
|
| 223 |
+
is_valid = cypher_response_data["validated"]
|
| 224 |
+
|
| 225 |
+
if cypher_query is not None and float(confidence) >= 0.9 and is_valid == True:
|
| 226 |
+
print("neo4j Cypher 初步生成成功 !!!")
|
| 227 |
+
|
| 228 |
+
# 验证 neo4j 生成的 Cypher 命令完全正确
|
| 229 |
+
data = {"cypher_query": cypher_query}
|
| 230 |
+
data_json = json.dumps(data)
|
| 231 |
+
cypher_valid = requests.post("http://0.0.0.0:8101/validate", data_json)
|
| 232 |
+
|
| 233 |
+
if cypher_valid.status_code == 200:
|
| 234 |
+
cypher_valid_data = cypher_valid.json()
|
| 235 |
+
if cypher_valid_data["is_valid"] == True:
|
| 236 |
+
with driver.session() as session:
|
| 237 |
+
try:
|
| 238 |
+
record = session.run(cypher_query)
|
| 239 |
+
result = list(map(lambda x: x[0], record))
|
| 240 |
+
neo4j_res = ','.join(result)
|
| 241 |
+
except Exception as e:
|
| 242 |
+
print(e)
|
| 243 |
+
print("neo4j查询失败 !!")
|
| 244 |
+
neo4j_res = ""
|
| 245 |
+
else:
|
| 246 |
+
print("生成Cypher查询失败 !!")
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(f"neo4j API 服务不可用: {e}")
|
| 249 |
+
|
| 250 |
+
# 合并 Milvus、PDF 和 neo4j 的召回结果, 共同作为 LLM 的输入 prompt
|
| 251 |
+
context = context + "\n" + neo4j_res
|
| 252 |
+
|
| 253 |
+
# ============================================================
|
| 254 |
+
# 4: 为LLM定义系统和用户提示
|
| 255 |
+
# ============================================================
|
| 256 |
+
SYSTEM_PROMPT = """
|
| 257 |
+
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
USER_PROMPT = f"""
|
| 261 |
+
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
|
| 262 |
+
<context>
|
| 263 |
+
{context}
|
| 264 |
+
</context>
|
| 265 |
+
|
| 266 |
+
<question>
|
| 267 |
+
{query}
|
| 268 |
+
</question>
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
# ============================================================
|
| 272 |
+
# 5: 使用 OpenAI 最新版本模型, 根据提示生成回复
|
| 273 |
+
# ============================================================
|
| 274 |
+
response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT.format(context, query))
|
| 275 |
+
|
| 276 |
+
# ============================================================
|
| 277 |
+
# 6: 写入缓存
|
| 278 |
+
# ============================================================
|
| 279 |
+
cache_set(client_redis, query, response)
|
| 280 |
+
|
| 281 |
+
# ============================================================
|
| 282 |
+
# 7: 组装服务返回数据
|
| 283 |
+
# ============================================================
|
| 284 |
+
now = datetime.datetime.now()
|
| 285 |
+
time = now.strftime("%Y-%m-%d %H:%M:%S")
|
| 286 |
+
|
| 287 |
+
answer = {
|
| 288 |
+
"response": response,
|
| 289 |
+
"status": 200,
|
| 290 |
+
"time": time
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
return answer
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
if __name__ == '__main__':
|
| 297 |
+
# 主函数中直接启动fastapi服务
|
| 298 |
+
uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
|
milvus_agent.db
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3fb6f3a55a098a6eac5d6b916bb55ac65827941084f862007ded0669e2671f8e
|
| 3 |
-
size 28672
|
|
|
|
|
|
|
|
|
|
|
|
test.py
CHANGED
|
@@ -3,8 +3,9 @@ import time
|
|
| 3 |
import json
|
| 4 |
|
| 5 |
url = "http://0.0.0.0:8103/"
|
| 6 |
-
data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
start_time = time.time()
|
| 10 |
|
|
|
|
| 3 |
import json
|
| 4 |
|
| 5 |
url = "http://0.0.0.0:8103/"
|
| 6 |
+
#data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}
|
| 7 |
+
data = {"question": "听说用酸枣仁泡水喝能养生,是真的吗?"}
|
| 8 |
+
#data = {"question": "糖尿病有什么症状?"}
|
| 9 |
|
| 10 |
start_time = time.time()
|
| 11 |
|
vector.py
CHANGED
|
@@ -4,6 +4,7 @@ from tqdm import tqdm
|
|
| 4 |
import json
|
| 5 |
import uuid
|
| 6 |
import time
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
from openai import OpenAI
|
| 9 |
from langchain.embeddings.base import Embeddings
|
|
@@ -18,6 +19,36 @@ from dotenv import load_dotenv
|
|
| 18 |
load_dotenv()
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ============================================================
|
| 22 |
# 嵌入模型, 采用 OpenAI text-embedding-3-small
|
| 23 |
# ============================================================
|
|
@@ -247,7 +278,7 @@ if __name__ == "__main__":
|
|
| 247 |
vectorstore = milvus_vectorstore.create_vector_store(docs)
|
| 248 |
print("全部初始化完成, 可以开始问答了......")
|
| 249 |
'''
|
| 250 |
-
|
| 251 |
# 将 PDF 后处理文档中的数据, 封装成Document
|
| 252 |
docs = prepare_pdf_document()
|
| 253 |
print("预处理 PDF 文档数据成功......")
|
|
@@ -259,4 +290,9 @@ if __name__ == "__main__":
|
|
| 259 |
retriever = pdf_vectorstore.create_pdf_vector_store(docs)
|
| 260 |
print("创建基于 Milvus 数据库的父子文档检索器成功......")
|
| 261 |
print(retriever)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
print("全部初始化完成, 可以开始问答了......")
|
|
|
|
| 4 |
import json
|
| 5 |
import uuid
|
| 6 |
import time
|
| 7 |
+
import redis
|
| 8 |
import pandas as pd
|
| 9 |
from openai import OpenAI
|
| 10 |
from langchain.embeddings.base import Embeddings
|
|
|
|
| 19 |
load_dotenv()
|
| 20 |
|
| 21 |
|
| 22 |
+
# ============================================================
|
| 23 |
+
# Redis 缓存处理模块
|
| 24 |
+
# ============================================================
|
| 25 |
+
|
| 26 |
+
def get_redis_client():
|
| 27 |
+
# 创建Redis连接, 使用连接池 (推荐用于生产环境)
|
| 28 |
+
pool = redis.ConnectionPool(host='0.0.0.0', port=6379, db=0, password=None, max_connections=10)
|
| 29 |
+
r = redis.StrictRedis(connection_pool=pool)
|
| 30 |
+
|
| 31 |
+
# 测试连接
|
| 32 |
+
try:
|
| 33 |
+
r.ping()
|
| 34 |
+
print("成功连接到 Redis !")
|
| 35 |
+
except redis.ConnectionError:
|
| 36 |
+
print("无法连接到 Redis !")
|
| 37 |
+
|
| 38 |
+
return r
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# 将 (question, answer) 问答对, 存入 redis
|
| 42 |
+
def cache_set(r, question: str, answer: str):
|
| 43 |
+
r.hset("qa", question, answer)
|
| 44 |
+
r.expire("qa", 3600)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# 通过 question, 读取存在 redis 中的 answer
|
| 48 |
+
def cache_get(r, question: str):
|
| 49 |
+
return r.hget("qa", question)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
# ============================================================
|
| 53 |
# 嵌入模型, 采用 OpenAI text-embedding-3-small
|
| 54 |
# ============================================================
|
|
|
|
| 278 |
vectorstore = milvus_vectorstore.create_vector_store(docs)
|
| 279 |
print("全部初始化完成, 可以开始问答了......")
|
| 280 |
'''
|
| 281 |
+
''''
|
| 282 |
# 将 PDF 后处理文档中的数据, 封装成Document
|
| 283 |
docs = prepare_pdf_document()
|
| 284 |
print("预处理 PDF 文档数据成功......")
|
|
|
|
| 290 |
retriever = pdf_vectorstore.create_pdf_vector_store(docs)
|
| 291 |
print("创建基于 Milvus 数据库的父子文档检索器成功......")
|
| 292 |
print(retriever)
|
| 293 |
+
'''
|
| 294 |
+
r = get_redis_client()
|
| 295 |
+
print("创建Redis连接成功......")
|
| 296 |
+
print(r)
|
| 297 |
+
|
| 298 |
print("全部初始化完成, 可以开始问答了......")
|