shenli commited on
Commit
8a17806
·
1 Parent(s): e45316c

Add GraphDatabase module with Neo4j + Redis caching

Browse files
.gitignore CHANGED
@@ -1 +1,8 @@
 
 
1
  .env
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
  .env
4
+ *.db
5
+ *.db.lock
6
+ .DS_Store
7
+ pdf_output/
8
+ data/
GraphDatabase/main.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, HTTPException, Depends
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import uvicorn
5
+ from contextlib import asynccontextmanager
6
+ from openai import OpenAI
7
+ from dotenv import load_dotenv
8
+ from models import NL2CypherRequest, CypherResponse, ValidationRequest, ValidationResponse
9
+ from schemas import EXAMPLE_SCHEMA
10
+ from prompts import create_system_prompt, create_validation_prompt
11
+ from validators import CypherValidator, RuleBasedValidator
12
+
13
+ # 加载环境变量
14
+ load_dotenv()
15
+
16
+ # 获取 OpenAI 的 api key
17
+ openai_api_key = os.getenv("OPENAI_API_KEY")
18
+
19
+
20
+ # 生命周期管理
21
+ @asynccontextmanager
22
+ async def lifespan(app: FastAPI):
23
+ # 启动时初始化
24
+ neo4j_uri = os.getenv("NEO4J_URI")
25
+ neo4j_user = os.getenv("NEO4J_USER")
26
+ neo4j_password = os.getenv("NEO4J_PASSWORD")
27
+
28
+ if all([neo4j_uri, neo4j_user, neo4j_password]):
29
+ app.state.validator = CypherValidator(neo4j_uri, neo4j_user, neo4j_password)
30
+ else:
31
+ app.state.validator = RuleBasedValidator()
32
+
33
+ yield
34
+
35
+ # 关闭时清理
36
+ if hasattr(app.state.validator, 'close'):
37
+ app.state.validator.close()
38
+
39
+
40
+ # 创建FastAPI应用
41
+ app = FastAPI(title="NL2Cypher API", lifespan=lifespan)
42
+
43
+ # 初始化 OpenAI 模型
44
+ client = OpenAI(
45
+ api_key=openai_api_key, # 你的 OpenAI API 密钥
46
+ base_url="https://api.openai.com/v1", # OpenAI 的 API 端点
47
+ )
48
+
49
+ # 添加CORS中间件
50
+ app.add_middleware(
51
+ CORSMiddleware,
52
+ allow_origins=["*"],
53
+ allow_credentials=True,
54
+ allow_methods=["*"],
55
+ allow_headers=["*"],
56
+ )
57
+
58
+
59
+ def clean_cypher_output(raw_output: str) -> str:
60
+ """清洗 LLM 返回的 Cypher 查询, 去掉多余的包装文本"""
61
+ import re
62
+ text = raw_output.strip()
63
+
64
+ # 去掉 markdown 代码块: ```cypher ... ``` 或 ``` ... ```
65
+ text = re.sub(r'```(?:cypher)?\s*', '', text)
66
+ text = text.strip('`')
67
+
68
+ # 去掉 Cypher: "..." 包装
69
+ match = re.match(r'^[Cc]ypher:\s*["\']?(.*?)["\']?\s*$', text, re.DOTALL)
70
+ if match:
71
+ text = match.group(1).strip()
72
+
73
+ # 去掉首尾引号
74
+ if (text.startswith('"') and text.endswith('"')) or \
75
+ (text.startswith("'") and text.endswith("'")):
76
+ text = text[1:-1].strip()
77
+
78
+ return text
79
+
80
+
81
+ def generate_cypher_query(natural_language: str, query_type: str = None) -> str:
82
+ """使用 OpenAI 生成 Cypher 查询"""
83
+ system_prompt = create_system_prompt(str(EXAMPLE_SCHEMA.model_dump()))
84
+
85
+ user_prompt = natural_language
86
+ if query_type:
87
+ user_prompt = f"{query_type}查询: {natural_language}"
88
+
89
+ try:
90
+ response = client.chat.completions.create(
91
+ model="gpt-4o",
92
+ messages=[
93
+ {"role": "system", "content": system_prompt},
94
+ {"role": "user", "content": user_prompt}
95
+ ],
96
+ temperature=0.1,
97
+ max_tokens=2048,
98
+ stream=False
99
+ )
100
+ raw_output = response.choices[0].message.content.strip()
101
+ return clean_cypher_output(raw_output)
102
+ except Exception as e:
103
+ raise HTTPException(status_code=500, detail=f"OpenAI API错误: {str(e)}")
104
+
105
+
106
+ def explain_cypher_query(cypher_query: str) -> str:
107
+ """解释Cypher查询"""
108
+ try:
109
+ response = client.chat.completions.create(
110
+ model="gpt-4o",
111
+ messages=[
112
+ {"role": "system", "content": "你是一个Neo4j专家, 请用简单明了的语言解释Cypher查询."},
113
+ {"role": "user", "content": f"请解释以下Cypher查询: {cypher_query}"}
114
+ ],
115
+ temperature=0.1,
116
+ max_tokens=1024,
117
+ stream=False
118
+ )
119
+ return response.choices[0].message.content.strip()
120
+ except Exception as e:
121
+ return f"无法生成解释: {str(e)}"
122
+
123
+
124
+ @app.post("/generate", response_model=CypherResponse)
125
+ async def generate_cypher(request: NL2CypherRequest):
126
+ """生成Cypher查询端点"""
127
+ # 利用 OpenAI 生成 Cypher 查询
128
+ cypher_query = generate_cypher_query(
129
+ request.natural_language_query,
130
+ request.query_type.value if request.query_type else None
131
+ )
132
+
133
+ # 利用 OpenAI 生成解释
134
+ explanation = explain_cypher_query(cypher_query)
135
+
136
+ # 验证查询
137
+ is_valid, errors = app.state.validator.validate_against_schema(cypher_query, EXAMPLE_SCHEMA)
138
+
139
+ # 计算置信度, 将基础置信度设置为0.9
140
+ confidence = 0.9
141
+
142
+ # 如果有潜在错误, 重新计算置信度 confidence
143
+ if errors:
144
+ confidence = max(0.3, confidence - len(errors) * 0.1)
145
+
146
+ return CypherResponse(
147
+ cypher_query=cypher_query,
148
+ explanation=explanation,
149
+ confidence=confidence,
150
+ validated=is_valid,
151
+ validation_errors=errors
152
+ )
153
+
154
+
155
+ @app.post("/validate", response_model=ValidationResponse)
156
+ async def validate_cypher(request: ValidationRequest):
157
+ """验证Cypher查询端点"""
158
+ is_valid, errors = app.state.validator.validate_against_schema(request.cypher_query, EXAMPLE_SCHEMA)
159
+
160
+ # 生成改进建议
161
+ suggestions = []
162
+ if errors:
163
+ try:
164
+ response = client.chat.completions.create(
165
+ model="gpt-4o",
166
+ messages=[
167
+ {"role": "system", "content": "你是一个Neo4j专家, 请提供Cypher查询的改进建议."},
168
+ {"role": "user", "content": create_validation_prompt(request.cypher_query)}
169
+ ],
170
+ temperature=0.1,
171
+ max_tokens=1024,
172
+ stream=False
173
+ )
174
+ suggestions = [response.choices[0].message.content.strip()]
175
+ except:
176
+ suggestions = ["无法生成建议"]
177
+
178
+ return ValidationResponse(
179
+ is_valid=is_valid,
180
+ errors=errors,
181
+ suggestions=suggestions
182
+ )
183
+
184
+
185
+ @app.get("/schema")
186
+ async def get_schema():
187
+ """获取图模式端点"""
188
+ return EXAMPLE_SCHEMA.model_dump()
189
+
190
+
191
+ if __name__ == "__main__":
192
+ # 因为项目中的主服务Agent启动在8103端口, 所以这个neo4j的服务端口另选一个8101即可
193
+ uvicorn.run(app, host="0.0.0.0", port=8101)
GraphDatabase/models.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pydantic import BaseModel, Field
3
+ from typing import Optional, List, Dict, Any
4
+ from enum import Enum
5
+
6
+
7
+ class QueryType(str, Enum):
8
+ MATCH = "MATCH"
9
+ CREATE = "CREATE"
10
+ MERGE = "MERGE"
11
+ DELETE = "DELETE"
12
+ SET = "SET"
13
+ REMOVE = "REMOVE"
14
+
15
+
16
+ class NL2CypherRequest(BaseModel):
17
+ natural_language_query: str = Field(
18
+ description="自然语言描述的需求",
19
+ examples=["查找'心血管和血栓栓塞综合征'建议服用什么药物?"]
20
+ )
21
+ query_type: Optional[QueryType] = Field(
22
+ default=None,
23
+ description="指定查询类型,如果不指定则由模型推断"
24
+ )
25
+ limit: Optional[int] = Field(
26
+ default=10,
27
+ description="结果限制数量",
28
+ ge=1,
29
+ le=1000
30
+ )
31
+
32
+
33
+ class CypherResponse(BaseModel):
34
+ cypher_query: str = Field(
35
+ ...,
36
+ description="生成的Cypher查询语句"
37
+ )
38
+ explanation: str = Field(
39
+ ...,
40
+ description="对生成的Cypher查询的解释"
41
+ )
42
+ confidence: float = Field(
43
+ ...,
44
+ description="模型对生成查询的信心度(0-1)",
45
+ ge=0,
46
+ le=1
47
+ )
48
+ validated: bool = Field(
49
+ default=False,
50
+ description="查询是否通过验证"
51
+ )
52
+ validation_errors: List[str] = Field(
53
+ default_factory=list,
54
+ description="验证过程中发现的错误"
55
+ )
56
+
57
+
58
+ class ValidationRequest(BaseModel):
59
+ cypher_query: str = Field(
60
+ ...,
61
+ description="需要验证的Cypher查询"
62
+ )
63
+
64
+
65
+ class ValidationResponse(BaseModel):
66
+ is_valid: bool = Field(
67
+ ...,
68
+ description="查询是否有效"
69
+ )
70
+ errors: List[str] = Field(
71
+ default_factory=list,
72
+ description="发现的错误列表"
73
+ )
74
+ suggestions: List[str] = Field(
75
+ default_factory=list,
76
+ description="改进建议"
77
+ )
GraphDatabase/prompts.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from schemas import EXAMPLE_SCHEMA
3
+
4
+
5
+ def create_system_prompt(schema: str) -> str:
6
+ return f"""
7
+ 你是一个专业的Neo4j Cypher查询生成器, 你的任务是将自然语言描述转换为准确, 高效的Cypher查询.
8
+
9
+ # 图数据库模式
10
+ {schema}
11
+
12
+ # 重要规则
13
+ 1. 始终使用参数化查询风格, 对字符串值使用单引号
14
+ 2. 确保节点标签和关系类型使用正确的大小写
15
+ 3. 对于模糊查询, 使用 CONTAINS 或 STARTS WITH 而不是 "="
16
+ 4. 对于可选模式, 使用 OPTIONAL MATCH
17
+ 5. 始终考虑查询性能, 使用适当的索引和约束
18
+ 6. 对于需要返回多个实体的查询, 使用 RETURN 子句明确指定要返回的内容
19
+ 7. 避免使用可能导致性能问题的查询模式
20
+
21
+ # 示例如下
22
+ 自然语言: "查找心血管和血栓栓塞综合征建议服用什么药物?"
23
+ Cypher: "match (p:Disease)-[r:recommand_drug]-(d:Drug) where p.name='心血管和血栓栓塞综合征' return d.name"
24
+
25
+ 自然语言: "查找嗜铬细胞瘤这种疾病有哪些临床症状?"
26
+ Cypher: "match (p:Disease)-[r:has_symptom]-(s:Symptom) where p.name='嗜铬细胞瘤' return s.name"
27
+
28
+ 自然语言: "查找小儿先天性巨结肠推荐哪些饮食有利康复?"
29
+ Cypher: "match (p:Disease)-[r:recommand_eat]-(f:Food) where p.name='小儿先天性巨结肠' return f.name"
30
+
31
+ 自然语言: "查找糖尿病需要做哪些检查项目?"
32
+ Cypher: "match (p:Disease)-[r:need_check]-(c:Check) where p.name='糖尿病' return c.name"
33
+
34
+ 自然语言: "查找高血压属于哪个科室?"
35
+ Cypher: "match (p:Disease)-[r:belongs_to]-(d:Department) where p.name='高血压' return d.name"
36
+
37
+ 自然语言: "查找感冒的常用药物有哪些?"
38
+ Cypher: "match (p:Disease)-[r:common_drug]-(d:Drug) where p.name='感冒' return d.name"
39
+
40
+ 自然语言: "查找肺炎患者不能吃什么食物?"
41
+ Cypher: "match (p:Disease)-[r:no_eat]-(f:Food) where p.name='肺炎' return f.name"
42
+
43
+ 自然语言: "查找胃炎患者适合吃什么食物?"
44
+ Cypher: "match (p:Disease)-[r:do_eat]-(f:Food) where p.name='胃炎' return f.name"
45
+
46
+ 自然语言: "查找冠心病容易并发哪些疾病?"
47
+ Cypher: "match (p:Disease)-[r:acompany_with]-(d:Disease) where p.name='冠心病' return d.name"
48
+
49
+ 自然语言: "查找阿莫西林是哪个厂家生产的?"
50
+ Cypher: "match (p:Producer)-[r:drugs_of]-(d:Drug) where d.name='阿莫西林' return p.name"
51
+
52
+ 现在请根据以下自然语言描述生成Cypher查询:
53
+ """
54
+
55
+
56
+ def create_validation_prompt(cypher_query: str) -> str:
57
+ return f"""
58
+ 请分析以下Cypher查询, 指出其中的任何错误或潜在问题, 并提供改进建议:
59
+
60
+ {cypher_query}
61
+
62
+ 请按以下格式回答:
63
+ 错误: [列出所有错误]
64
+ 建议: [提供改进建议]
65
+ """
GraphDatabase/schemas.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pydantic import BaseModel
3
+ from typing import Dict, List, Optional
4
+
5
+
6
+ class NodeSchema(BaseModel):
7
+ label: str
8
+ properties: Dict[str, str] # 属性名: 类型
9
+
10
+
11
+ class RelationshipSchema(BaseModel):
12
+ type: str
13
+ from_node: str # 起始节点标签
14
+ to_node: str # 目标节点标签
15
+ properties: Dict[str, str] # 属性名: 类型
16
+
17
+
18
+ class GraphSchema(BaseModel):
19
+ nodes: List[NodeSchema]
20
+ relationships: List[RelationshipSchema]
21
+
22
+
23
+ # 示例图模式 (按照neo4j数据库中的定义schema来填充)
24
+ EXAMPLE_SCHEMA = GraphSchema(
25
+ # 节点的名称一定要严格保持跟neo4j一致
26
+ nodes=[
27
+ # --- PDF 原有的 4 个节点 ---
28
+ NodeSchema(label="Disease", properties={"name": "string", "desc": "string", "cause": "string", "prevent": "string", "cure_lasttime": "string", "cure_department": "string", "cure_way": "string", "cure_prob": "string", "easy_get": "string"}),
29
+ NodeSchema(label="Drug", properties={"name": "string"}),
30
+ NodeSchema(label="Food", properties={"name": "string"}),
31
+ NodeSchema(label="Symptom", properties={"name": "string"}),
32
+ # --- 基于 Neo4j 截图新增的 3 个节点 ---
33
+ NodeSchema(label="Check", properties={"name": "string"}),
34
+ NodeSchema(label="Department", properties={"name": "string"}),
35
+ NodeSchema(label="Producer", properties={"name": "string"}),
36
+ ],
37
+ # 关系的相关字段一定要严格保持跟neo4j一致, 大小写都不能错
38
+ relationships=[
39
+ # --- PDF 原有的 3 个关系 ---
40
+ RelationshipSchema(
41
+ type="has_symptom",
42
+ from_node="Disease",
43
+ to_node="Symptom",
44
+ properties={}
45
+ ),
46
+ RelationshipSchema(
47
+ type="recommand_drug",
48
+ from_node="Disease",
49
+ to_node="Drug",
50
+ properties={}
51
+ ),
52
+ RelationshipSchema(
53
+ type="recommand_eat",
54
+ from_node="Disease",
55
+ to_node="Food",
56
+ properties={}
57
+ ),
58
+ # --- 基于 Neo4j 截图新增的关系 ---
59
+ # Disease 需要做的检查项目
60
+ RelationshipSchema(
61
+ type="need_check",
62
+ from_node="Disease",
63
+ to_node="Check",
64
+ properties={}
65
+ ),
66
+ # Disease 所属的科室
67
+ RelationshipSchema(
68
+ type="belongs_to",
69
+ from_node="Disease",
70
+ to_node="Department",
71
+ properties={}
72
+ ),
73
+ # Disease 的常用药物
74
+ RelationshipSchema(
75
+ type="common_drug",
76
+ from_node="Disease",
77
+ to_node="Drug",
78
+ properties={}
79
+ ),
80
+ # Disease 宜吃的食物
81
+ RelationshipSchema(
82
+ type="do_eat",
83
+ from_node="Disease",
84
+ to_node="Food",
85
+ properties={}
86
+ ),
87
+ # Disease 忌吃的食物
88
+ RelationshipSchema(
89
+ type="no_eat",
90
+ from_node="Disease",
91
+ to_node="Food",
92
+ properties={}
93
+ ),
94
+ # Disease 的并发症
95
+ RelationshipSchema(
96
+ type="acompany_with",
97
+ from_node="Disease",
98
+ to_node="Disease",
99
+ properties={}
100
+ ),
101
+ # Drug 的生产商
102
+ RelationshipSchema(
103
+ type="drugs_of",
104
+ from_node="Producer",
105
+ to_node="Drug",
106
+ properties={}
107
+ ),
108
+ ]
109
+ )
110
+
111
+
112
+ if __name__ == '__main__':
113
+ res = str(EXAMPLE_SCHEMA.model_dump())
114
+ print(res)
GraphDatabase/validators.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Tuple
3
+ from neo4j import GraphDatabase
4
+ import os
5
+
6
+
7
+ class CypherValidator:
8
+ def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
9
+ self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
10
+
11
+ def validate_syntax(self, cypher_query: str) -> Tuple[bool, List[str]]:
12
+ """验证Cypher查询的语法"""
13
+ errors = []
14
+
15
+ # 基本语法检查
16
+ if not cypher_query.strip().upper().startswith(('MATCH', 'CREATE', 'MERGE', 'CALL')):
17
+ errors.append("查询必须以MATCH, CREATE, MERGE 或 CALL开头!!!")
18
+
19
+ # 检查是否有潜在的注入风险
20
+ if any(keyword in cypher_query.upper() for keyword in ['DROP', 'DELETE', 'DETACH', 'REMOVE']):
21
+ if not any(keyword in cypher_query.upper() for keyword in ['DELETE', 'DETACH']):
22
+ errors.append("查询包含可能危险的操作符")
23
+
24
+ # 检查RETURN语句是否存在 (对于MATCH查询)
25
+ if cypher_query.upper().startswith('MATCH') and 'RETURN' not in cypher_query.upper():
26
+ errors.append("MATCH查询必须包含RETURN语句!!!")
27
+
28
+ # 使用Neo4j解释计划验证查询
29
+ try:
30
+ with self.driver.session() as session:
31
+ result = session.run(f"EXPLAIN {cypher_query}")
32
+ # 如果解释成功, 语法基本正确
33
+ return True, errors
34
+ except Exception as e:
35
+ errors.append(f"语法错误: {str(e)}")
36
+ return False, errors
37
+
38
+ def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
39
+ """根据模式验证查询"""
40
+ errors = []
41
+
42
+ # 提取所有节点标签
43
+ node_labels = [node.label for node in schema.nodes]
44
+ node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
45
+ matches = re.findall(node_pattern, cypher_query)
46
+
47
+ for match in matches:
48
+ if match[1] and match[1] not in node_labels:
49
+ errors.append(f"使用了不存在的节点标签: {match[1]}")
50
+
51
+ # 提取所有关系类型
52
+ rel_types = [rel.type for rel in schema.relationships]
53
+ rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
54
+ rel_matches = re.findall(rel_pattern, cypher_query)
55
+
56
+ for match in rel_matches:
57
+ if match[1] and match[1] not in rel_types:
58
+ errors.append(f"使用了不存在的关系类型: {match[1]}")
59
+
60
+ return len(errors) == 0, errors
61
+
62
+ def close(self):
63
+ self.driver.close()
64
+
65
+
66
+
67
+ # 简单的基于规则的验证器 (当无法连接Neo4j时使用)
68
+ class RuleBasedValidator:
69
+ def validate(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
70
+ errors = []
71
+
72
+ # 检查基本结构
73
+ if not cypher_query.strip():
74
+ errors.append("查询不能为空!!!")
75
+ return False, errors
76
+
77
+ # 检查是否包含潜在危险操作
78
+ dangerous_patterns = [
79
+ (r'(?i)drop\s+', "DROP操作可能危险"),
80
+ (r'(?i)delete\s+', "DELETE操作需要谨慎"),
81
+ (r'(?i)detach\s+delete', "DETACH DELETE操作非常危险!!"),
82
+ (r'(?i)remove\s+', "REMOVE操作需要谨慎"),
83
+ ]
84
+
85
+ for pattern, message in dangerous_patterns:
86
+ if re.search(pattern, cypher_query):
87
+ errors.append(message)
88
+
89
+ # 检查MATCH查询是否包含RETURN
90
+ if re.match(r'(?i)match', cypher_query) and not re.search(r'(?i)return', cypher_query):
91
+ errors.append("MATCH查询必须包含RETURN子句")
92
+
93
+ # 检查CREATE查询是否合理
94
+ if re.match(r'(?i)create', cypher_query) and not re.search(r'(?i)(node|relationship|label|index)', cypher_query):
95
+ errors.append("CREATE查询应该明确创建节点或关系")
96
+
97
+ return len(errors) == 0, errors
98
+
99
+ def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
100
+ """兼容CypherValidator的接口, 先做规则验证再做schema验证"""
101
+ is_valid, errors = self.validate(cypher_query, schema)
102
+
103
+ # 额外进行schema验证
104
+ node_labels = [node.label for node in schema.nodes]
105
+ node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
106
+ matches = re.findall(node_pattern, cypher_query)
107
+
108
+ for match in matches:
109
+ if match[1] and match[1] not in node_labels:
110
+ errors.append(f"使用了不存在的节点标签: {match[1]}")
111
+
112
+ rel_types = [rel.type for rel in schema.relationships]
113
+ rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
114
+ rel_matches = re.findall(rel_pattern, cypher_query)
115
+
116
+ for match in rel_matches:
117
+ if match[1] and match[1] not in rel_types:
118
+ errors.append(f"使用了不存在的关系类型: {match[1]}")
119
+
120
+ return len(errors) == 0, errors
agent3.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uvicorn
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ import json
6
+ import requests
7
+ import datetime
8
+ from openai import OpenAI
9
+ from neo4j import GraphDatabase
10
+ from langchain_milvus import Milvus, BM25BuiltInFunction
11
+ from vector import OpenAIEmbeddings, get_redis_client, cache_set, cache_get
12
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
13
+ from langchain_core.stores import InMemoryStore
14
+ from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
15
+ from dotenv import load_dotenv
16
+
17
+ # 加载 .env 文件中的环境变量, 隐藏 API Keys
18
+ load_dotenv()
19
+
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ app = FastAPI()
22
+
23
+
24
+ # ============================================================
25
+ # OpenAI LLM 客户端封装 (替代讲义中的 DeepSeek)
26
+ # ============================================================
27
+
28
+ def create_openai_client():
29
+ """创建 OpenAI 客户端"""
30
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
31
+ return client
32
+
33
+
34
+ def generate_openai_answer(client, prompt):
35
+ """使用 OpenAI 生成回复"""
36
+ response = client.chat.completions.create(
37
+ model="gpt-4o-mini",
38
+ messages=[
39
+ {"role": "user", "content": prompt}
40
+ ],
41
+ temperature=0.7,
42
+ )
43
+ return response.choices[0].message.content
44
+
45
+
46
+ # 允许所有域的请求
47
+ app.add_middleware(
48
+ CORSMiddleware,
49
+ allow_origins=["*"],
50
+ allow_credentials=True,
51
+ allow_methods=["*"],
52
+ allow_headers=["*"],
53
+ )
54
+
55
+ # 创建 Embedding 模型
56
+ embedding_model = OpenAIEmbeddings()
57
+ print("创建 Embedding 模型成功......")
58
+
59
+ # 设置默认的 Milvus 数据库文件路径
60
+ URI = "./milvus_agent.db"
61
+ URI1 = "./pdf_agent.db"
62
+
63
+ # 创建 Milvus 连接
64
+ milvus_vectorstore = Milvus(
65
+ embedding_function=embedding_model,
66
+ builtin_function=BM25BuiltInFunction(),
67
+ vector_field=["dense", "sparse"],
68
+ index_params=[
69
+ {
70
+ "metric_type": "IP",
71
+ "index_type": "IVF_FLAT",
72
+ },
73
+ {
74
+ "metric_type": "BM25",
75
+ "index_type": "SPARSE_INVERTED_INDEX"
76
+ }
77
+ ],
78
+ connection_args={"uri": URI},
79
+ )
80
+
81
+ retriever = milvus_vectorstore.as_retriever()
82
+ print("创建 Milvus 连接成功......")
83
+
84
+
85
+ docstore = InMemoryStore()
86
+
87
+ # 文本分割器
88
+ child_splitter = RecursiveCharacterTextSplitter(
89
+ chunk_size=200,
90
+ chunk_overlap=50,
91
+ length_function=len,
92
+ separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
93
+ )
94
+
95
+ parent_splitter = RecursiveCharacterTextSplitter(
96
+ chunk_size=1000,
97
+ chunk_overlap=200
98
+ )
99
+
100
+ pdf_vectorstore = Milvus(
101
+ embedding_function=embedding_model,
102
+ builtin_function=BM25BuiltInFunction(),
103
+ vector_field=["dense", "sparse"],
104
+ index_params=[
105
+ {
106
+ "metric_type": "IP",
107
+ "index_type": "IVF_FLAT",
108
+ },
109
+ {
110
+ "metric_type": "BM25",
111
+ "index_type": "SPARSE_INVERTED_INDEX"
112
+ }
113
+ ],
114
+ connection_args={"uri": URI1},
115
+ consistency_level="Bounded",
116
+ drop_old=False,
117
+ )
118
+
119
+ # 设置父子文档检索器
120
+ parent_retriever = ParentDocumentRetriever(
121
+ vectorstore=pdf_vectorstore,
122
+ docstore=docstore,
123
+ child_splitter=child_splitter,
124
+ parent_splitter=parent_splitter,
125
+ )
126
+
127
+ print("创建 Parent Milvus 连接成功......")
128
+
129
+ # 获取 neo4j 图数据库的连接
130
+ neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
131
+ neo4j_user = os.getenv("NEO4J_USER", "neo4j")
132
+ neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
133
+ driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
134
+ print("创建 Neo4j 连接成功......")
135
+
136
+ # 创建大语言模型, 采用 OpenAI
137
+ client_llm = create_openai_client()
138
+ print("创建 OpenAI LLM 成功......")
139
+
140
+ # 获取 Redis 连接
141
+ client_redis = get_redis_client()
142
+ print("创建 Redis 连接成功......")
143
+
144
+
145
+ def format_docs(docs):
146
+ return "\n\n".join(doc.page_content for doc in docs)
147
+
148
+
149
+ @app.post("/")
150
+ async def chatbot(request: Request):
151
+ global milvus_vectorstore, retriever
152
+
153
+ json_post_raw = await request.json()
154
+ json_post = json.dumps(json_post_raw)
155
+ json_post_list = json.loads(json_post)
156
+
157
+ query = json_post_list.get('question')
158
+
159
+ # ============================================================
160
+ # 1: 先查 Redis 缓存, 如果缓存命中, 直接返回结果
161
+ # ============================================================
162
+ response_redis = cache_get(client_redis, query)
163
+
164
+ if response_redis is not None:
165
+ # redis 返回的字符串是以十六进制显示的, 需要按 utf-8 解码
166
+ response = response_redis.decode('utf-8')
167
+
168
+ now = datetime.datetime.now()
169
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
170
+ answer = {
171
+ "response": response,
172
+ "status": 200,
173
+ "time": time
174
+ }
175
+ print('REDIS HIT !!!')
176
+ return answer
177
+
178
+ # ============================================================
179
+ # 2: 向量数据库 Milvus 模糊召回 & 重排序
180
+ # ============================================================
181
+ # 在集合中搜索问题并检索语义 top-10 匹配项, 而且已经配置了 reranker 的处理, 采用RRF算法
182
+ recall_rerank_milvus = milvus_vectorstore.similarity_search(
183
+ query,
184
+ k=10,
185
+ ranker_type="rrf",
186
+ ranker_params={"k": 100}
187
+ )
188
+
189
+ if recall_rerank_milvus:
190
+ # 检索结果存放在列表中
191
+ context = format_docs(recall_rerank_milvus)
192
+ else:
193
+ context = ""
194
+
195
+ # ============================================================
196
+ # 2.5: PDF 文档的 Milvus 召回 (父子文档检索器)
197
+ # ============================================================
198
+ pdf_res = ""
199
+ retrieved_docs = parent_retriever.invoke(query)
200
+
201
+ if retrieved_docs is not None and len(retrieved_docs) >= 1:
202
+ pdf_res = retrieved_docs[0].page_content
203
+ print("PDF res: ", pdf_res)
204
+
205
+ context = context + "\n" + pdf_res
206
+
207
+ # ============================================================
208
+ # 3: 图数据库 neo4j 精准召回
209
+ # ============================================================
210
+ # 访问 neo4j API 服务, 生成 Cypher 命令
211
+ neo4j_res = ""
212
+ data = {"natural_language_query": query}
213
+ data_json = json.dumps(data)
214
+
215
+ try:
216
+ cypher_response = requests.post("http://0.0.0.0:8101/generate", data_json)
217
+
218
+ if cypher_response.status_code == 200:
219
+ cypher_response_data = cypher_response.json()
220
+
221
+ cypher_query = cypher_response_data["cypher_query"]
222
+ confidence = cypher_response_data["confidence"]
223
+ is_valid = cypher_response_data["validated"]
224
+
225
+ if cypher_query is not None and float(confidence) >= 0.9 and is_valid == True:
226
+ print("neo4j Cypher 初步生成成功 !!!")
227
+
228
+ # 验证 neo4j 生成的 Cypher 命令完全正确
229
+ data = {"cypher_query": cypher_query}
230
+ data_json = json.dumps(data)
231
+ cypher_valid = requests.post("http://0.0.0.0:8101/validate", data_json)
232
+
233
+ if cypher_valid.status_code == 200:
234
+ cypher_valid_data = cypher_valid.json()
235
+ if cypher_valid_data["is_valid"] == True:
236
+ with driver.session() as session:
237
+ try:
238
+ record = session.run(cypher_query)
239
+ result = list(map(lambda x: x[0], record))
240
+ neo4j_res = ','.join(result)
241
+ except Exception as e:
242
+ print(e)
243
+ print("neo4j查询失败 !!")
244
+ neo4j_res = ""
245
+ else:
246
+ print("生成Cypher查询失败 !!")
247
+ except Exception as e:
248
+ print(f"neo4j API 服务不可用: {e}")
249
+
250
+ # 合并 Milvus、PDF 和 neo4j 的召回结果, 共同作为 LLM 的输入 prompt
251
+ context = context + "\n" + neo4j_res
252
+
253
+ # ============================================================
254
+ # 4: 为LLM定义系统和用户提示
255
+ # ============================================================
256
+ SYSTEM_PROMPT = """
257
+ System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
258
+ """
259
+
260
+ USER_PROMPT = f"""
261
+ User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
262
+ <context>
263
+ {context}
264
+ </context>
265
+
266
+ <question>
267
+ {query}
268
+ </question>
269
+ """
270
+
271
+ # ============================================================
272
+ # 5: 使用 OpenAI 最新版本模型, 根据提示生成回复
273
+ # ============================================================
274
+ response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT.format(context, query))
275
+
276
+ # ============================================================
277
+ # 6: 写入缓存
278
+ # ============================================================
279
+ cache_set(client_redis, query, response)
280
+
281
+ # ============================================================
282
+ # 7: 组装服务返回数据
283
+ # ============================================================
284
+ now = datetime.datetime.now()
285
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
286
+
287
+ answer = {
288
+ "response": response,
289
+ "status": 200,
290
+ "time": time
291
+ }
292
+
293
+ return answer
294
+
295
+
296
+ if __name__ == '__main__':
297
+ # 主函数中直接启动fastapi服务
298
+ uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
milvus_agent.db DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3fb6f3a55a098a6eac5d6b916bb55ac65827941084f862007ded0669e2671f8e
3
- size 28672
 
 
 
 
test.py CHANGED
@@ -3,8 +3,9 @@ import time
3
  import json
4
 
5
  url = "http://0.0.0.0:8103/"
6
- data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}
7
- #data = {"question": "听说用酸枣仁泡水喝能养生,是真的吗?"}
 
8
 
9
  start_time = time.time()
10
 
 
3
  import json
4
 
5
  url = "http://0.0.0.0:8103/"
6
+ #data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}
7
+ data = {"question": "听说用酸枣仁泡水喝能养生,是真的吗?"}
8
+ #data = {"question": "糖尿病有什么症状?"}
9
 
10
  start_time = time.time()
11
 
vector.py CHANGED
@@ -4,6 +4,7 @@ from tqdm import tqdm
4
  import json
5
  import uuid
6
  import time
 
7
  import pandas as pd
8
  from openai import OpenAI
9
  from langchain.embeddings.base import Embeddings
@@ -18,6 +19,36 @@ from dotenv import load_dotenv
18
  load_dotenv()
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ============================================================
22
  # 嵌入模型, 采用 OpenAI text-embedding-3-small
23
  # ============================================================
@@ -247,7 +278,7 @@ if __name__ == "__main__":
247
  vectorstore = milvus_vectorstore.create_vector_store(docs)
248
  print("全部初始化完成, 可以开始问答了......")
249
  '''
250
-
251
  # 将 PDF 后处理文档中的数据, 封装成Document
252
  docs = prepare_pdf_document()
253
  print("预处理 PDF 文档数据成功......")
@@ -259,4 +290,9 @@ if __name__ == "__main__":
259
  retriever = pdf_vectorstore.create_pdf_vector_store(docs)
260
  print("创建基于 Milvus 数据库的父子文档检索器成功......")
261
  print(retriever)
 
 
 
 
 
262
  print("全部初始化完成, 可以开始问答了......")
 
4
  import json
5
  import uuid
6
  import time
7
+ import redis
8
  import pandas as pd
9
  from openai import OpenAI
10
  from langchain.embeddings.base import Embeddings
 
19
  load_dotenv()
20
 
21
 
22
+ # ============================================================
23
+ # Redis 缓存处理模块
24
+ # ============================================================
25
+
26
+ def get_redis_client():
27
+ # 创建Redis连接, 使用连接池 (推荐用于生产环境)
28
+ pool = redis.ConnectionPool(host='0.0.0.0', port=6379, db=0, password=None, max_connections=10)
29
+ r = redis.StrictRedis(connection_pool=pool)
30
+
31
+ # 测试连接
32
+ try:
33
+ r.ping()
34
+ print("成功连接到 Redis !")
35
+ except redis.ConnectionError:
36
+ print("无法连接到 Redis !")
37
+
38
+ return r
39
+
40
+
41
+ # 将 (question, answer) 问答对, 存入 redis
42
+ def cache_set(r, question: str, answer: str):
43
+ r.hset("qa", question, answer)
44
+ r.expire("qa", 3600)
45
+
46
+
47
+ # 通过 question, 读取存在 redis 中的 answer
48
+ def cache_get(r, question: str):
49
+ return r.hget("qa", question)
50
+
51
+
52
  # ============================================================
53
  # 嵌入模型, 采用 OpenAI text-embedding-3-small
54
  # ============================================================
 
278
  vectorstore = milvus_vectorstore.create_vector_store(docs)
279
  print("全部初始化完成, 可以开始问答了......")
280
  '''
281
+ ''''
282
  # 将 PDF 后处理文档中的数据, 封装成Document
283
  docs = prepare_pdf_document()
284
  print("预处理 PDF 文档数据成功......")
 
290
  retriever = pdf_vectorstore.create_pdf_vector_store(docs)
291
  print("创建基于 Milvus 数据库的父子文档检索器成功......")
292
  print(retriever)
293
+ '''
294
+ r = get_redis_client()
295
+ print("创建Redis连接成功......")
296
+ print(r)
297
+
298
  print("全部初始化完成, 可以开始问答了......")