Toadied committed on
Commit
edf63e7
·
verified ·
1 Parent(s): 5d751eb

Upload 16 files

Browse files
.env ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OPENAI API 访问密钥配置
2
+ OPENAI_API_KEY = ""
3
+ # 文心 API 访问密钥配置
4
+ # 方式1. 使用应用 AK/SK 鉴权
5
+ # 创建的应用的 API Key
6
+ QIANFAN_AK = ""
7
+ # 创建的应用的 Secret Key
8
+ QIANFAN_SK = ""
9
+ # 方式2. 使用安全认证 AK/SK 鉴权
10
+ # 安全认证方式获取的 Access Key
11
+ QIANFAN_ACCESS_KEY = ""
12
+ # 安全认证方式获取的 Secret Key
13
+ QIANFAN_SECRET_KEY = ""
14
+
15
+ # Ernie SDK 文心 API 访问密钥配置
16
+ EB_ACCESS_TOKEN = ""
17
+
18
+ # 控制台中获取的 APPID 信息
19
+ IFLYTEK_SPARK_APP_ID = ""
20
+ # 控制台中获取的 APIKey 信息
21
+ IFLYTEK_SPARK_API_KEY = ""
22
+ # 控制台中获取的 APISecret 信息
23
+ IFLYTEK_SPARK_API_SECRET = ""
24
+
25
+ # 智谱 API 访问密钥配置
26
+ # SECURITY(review): a live API key was committed here — revoke/rotate it immediately and never commit real secrets to version control.
+ ZHIPUAI_API_KEY = ""
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data_base/vector_db/chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ */.idea/
161
+ */.DS_Store
162
+ */*/.DS_Store
163
+ .idea/
164
+ .DS_Store
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "editor.autoIndentOnPaste": true
3
+ }
__pycache__/zhipuEmbedding.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
__pycache__/zhipuLLM.cpython-310.pyc ADDED
Binary file (4.92 kB). View file
 
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CHROMA_TELEMETRY_DISABLED"] = "true"
3
+ from dotenv import load_dotenv, find_dotenv
4
+ from zhipuLLM import ZhipuaiLLM
5
+ from zhipuEmbedding import ZhipuAiEmbeddings
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.runnables import RunnableBranch
11
+
12
+ #ui
13
+ import gradio as gr
14
+
15
+
16
_ = load_dotenv(find_dotenv())  # load variables from the nearest .env into the process environment
api_key=os.environ["ZHIPUAI_API_KEY"]  # Zhipu AI key; raises KeyError if the .env/env var is missing
18
+
19
def combine_docs(docs):
    """Merge the retrieved documents' text into a single context string.

    Args:
        docs: mapping whose "context" key holds Document-like objects,
            each exposing a ``page_content`` attribute.

    Returns:
        The page contents joined with blank lines between them.
    """
    pieces = [document.page_content for document in docs["context"]]
    return "\n\n".join(pieces)
21
+
22
+
23
def show_switch_status(switch_state):
    """Echo the RAG toggle's state back to the UI unchanged."""
    current_state = switch_state
    return current_state
25
+
26
# Conversion helper used by the chat handler below.
def format_chat_history(chatbot):
    """Convert Gradio Chatbot pairs into LangChain-style chat history.

    Args:
        chatbot: list of (human_message, ai_message) tuples.

    Returns:
        Flat list alternating ("human", msg) and ("ai", msg) tuples,
        preserving conversation order.
    """
    return [
        tagged
        for human_msg, ai_msg in chatbot
        for tagged in (("human", human_msg), ("ai", ai_msg))
    ]
34
+
35
def chatbot_response(input, chatbot, isUseRAG):
    """Answer one user turn, with or without RAG depending on the checkbox.

    Args:
        input: the user's question (from the Textbox; shadows the builtin).
        chatbot: Gradio chat history as a list of (user, assistant) tuples.
        isUseRAG: when True, answer via retrieval over the local Chroma store.

    Returns:
        [updated chatbot history, input]
        NOTE(review): returning `input` as the second output means the
        textbox is NOT cleared after sending; presumably "" was intended —
        confirm against the UI wiring.
    """

    llm = ZhipuaiLLM(model_name="glm-4-plus", temperature=0.1, api_key=api_key)
    if isUseRAG:

        # System prompt for the question-answering chain.
        system_prompt = (
            "你是一个问答任务的助手。 "
            "请使用检索到的上下文片段回答这个问题。 "
            "如果你不知道答案就说不知道。 "
            "请使用简洁的话语回答用户。"
            "\n\n"
            "{context}"
        )
        # Build the QA prompt template.
        qa_prompt = ChatPromptTemplate(
            [
                ("system", system_prompt),
                ("placeholder", "{chat_history}"),
                ("human", "{input}"),
            ]
        )
        # QA chain: fill {context} from the retrieved docs, run the prompt
        # through the LLM, and parse the reply down to a plain string.
        qa_chain = (
            RunnablePassthrough.assign(context=combine_docs)  # merge retrieved docs into {context}
            | qa_prompt  # QA prompt template
            | llm
            | StrOutputParser()  # force str output
        )

        # Open the persisted vector store from disk.
        vectordb = Chroma(
            persist_directory='data_base/vector_db/chroma',  # on-disk Chroma directory
            embedding_function=ZhipuAiEmbeddings()
        )
        # Retriever returning only the single closest chunk (k=1).
        retriever = vectordb.as_retriever(search_kwargs={"k": 1})

        # System prompt for condensing the latest question using chat history.
        condense_question_system_template = (
            "请根据聊天记录完善用户最新的问题,"
            "如果用户最新的问题不需要完善则返回用户的问题。"
        )
        # Prompt template for the condense-question step.
        condense_question_prompt = ChatPromptTemplate([
            ("system", condense_question_system_template),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ])

        retrieve_docs = RunnableBranch(
            # Branch 1: no chat_history — query the vector store with the raw question.
            (lambda x: not x.get("chat_history", False), (lambda x: x["input"]) | retriever, ),
            # Branch 2: chat_history present — let the LLM refine the question first, then retrieve.
            condense_question_prompt | llm | StrOutputParser() | retriever,
        )

        # History-aware QA chain: stash retrieval results under `context`,
        # then compute the `answer` with the QA chain above.
        qa_history_chain = RunnablePassthrough.assign(
            context = (lambda x: x) | retrieve_docs  # store retrieval result as context
        ).assign(answer=qa_chain)

        result = qa_history_chain.invoke({
            "input": input,
            "chat_history": format_chat_history(chatbot)
        })

        print(result)  # debug: full chain output
        chatbot.append((input,result["answer"]))
        return [chatbot,input]
    else:
        # Plain LLM call — no retrieval, no history injection.
        result = llm.invoke(input)
        print(chatbot)  # debug
        chatbot.append((input,result.content))
        return [chatbot,input]
113
+
114
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🤖 RAG 智能聊天机器人
    支持直接调用大模型或结合本地知识库(RAG)回答问题
    """)

    chatbot = gr.Chatbot(
        label="对话历史",
        height=500,  # chat panel height
        avatar_images=(None, "https://gradio.s3-us-west-2.amazonaws.com/guides/robot.png")  # (optional) user/bot avatars
    )

    with gr.Row():
        chebox = gr.Checkbox(
            label="RAG",
            value=False
        )

    with gr.Row():
        input = gr.Textbox(
            label="输入你的问题",
            placeholder="例如:",
            lines=2,
            container=False
        )
        submit_btn = gr.Button("发送", variant="primary", icon="📤")

    submit_btn.click(
        fn=chatbot_response,
        inputs=[input, chatbot, chebox],  # inputs: user message + history + RAG toggle
        outputs=[chatbot, input]  # outputs: updated history + textbox value (NOTE: handler echoes input back, so the box is not cleared)
    )


if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,  # surface errors in the UI (debugging aid)
    )
data_base/data/rag.md ADDED
@@ -0,0 +1 @@
 
 
1
+ 2025年乒乓球项目恭喜樊振东获得冠军,马龙为亚军
data_base/vector_db/chroma/81183b78-851d-4e82-8886-717c11558d9c/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5707b4304f81e825ab1c96b0955b9fcbce912c03d1b0e55d9f3b70d0d68046b
3
+ size 8332000
data_base/vector_db/chroma/81183b78-851d-4e82-8886-717c11558d9c/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de65dd7dc719eee86a1e11054bd45ee9d541ad62e7e654ea3a1c5b7d61da6baa
3
+ size 100
data_base/vector_db/chroma/81183b78-851d-4e82-8886-717c11558d9c/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a24c0f603727245b4a01a14a3ee703614fed0d5fe14e19b71f01ac4099b3a433
3
+ size 4000
data_base/vector_db/chroma/81183b78-851d-4e82-8886-717c11558d9c/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
3
+ size 0
data_base/vector_db/chroma/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e47310452bc5312a566c477b2af85270b441763a7a127b499fd59462e4b92b89
3
+ size 167936
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain==0.3.0
2
+ langchain-community==0.3.0
3
+ langchain-text-splitters==0.3.0
4
+ langchain-core==0.3.0
5
+ langchain-openai==0.2.0
6
+ langchain-chroma==0.1.4
7
+ python-dotenv==1.0.1
8
+ zhipuai==2.1.5.20250106
9
+ qianfan==0.4.12.3
10
+ unstructured==0.16.23
11
+ pymupdf==1.25.3
12
+ markdown==3.7
13
+ streamlit==1.43.0
14
+ jieba==0.42.1
15
+ pydantic==2.10.6
16
+ gradio==4.44.1
zhipuEmbedding.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import os
3
+ from langchain_core.embeddings import Embeddings
4
+ from zhipuai import ZhipuAI
5
+
6
class ZhipuAiEmbeddings(Embeddings):
    """LangChain Embeddings adapter for Zhipu AI's "embedding-3" model.

    The Zhipu embeddings endpoint caps how many texts a single request may
    carry, so documents are embedded in batches of ``batch_size`` and the
    results are stitched back together in input order.
    """

    def __init__(self):
        # The zhipuai SDK reads ZHIPUAI_API_KEY from the environment
        # (loaded from .env elsewhere) when no key is passed explicitly.
        self.client = ZhipuAI()
        self.batch_size = 64  # max texts per embeddings API request

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: document contents to embed.

        Returns:
            One embedding vector (list of floats) per input text, in order.
        """
        # BUG FIX: the previous version sent all texts in one request,
        # ignoring self.batch_size; a half-finished batching attempt was left
        # commented out (and would have returned None via list.extend).
        # This batches correctly and preserves ordering.
        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            response = self.client.embeddings.create(
                model="embedding-3",
                input=batch,
            )
            all_embeddings.extend(item.embedding for item in response.data)
        return all_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.embed_documents([text])[0]
zhipuEmbeddingsData.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["USER_AGENT"] = "MyRAGApp/1.0 (https://myapp.example.com; myemail@example.com)"
3
+ os.environ["CHROMA_TELEMETRY_DISABLED"] = "true"
4
+ from zhipuai import ZhipuAI
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from langchain_community.document_loaders import (
7
+ TextLoader, PythonLoader, CSVLoader, JSONLoader,
8
+ Docx2txtLoader, UnstructuredPowerPointLoader,
9
+ PyMuPDFLoader, UnstructuredMarkdownLoader,
10
+ UnstructuredImageLoader, WebBaseLoader
11
+ )
12
_ = load_dotenv(find_dotenv())  # load ZHIPUAI_API_KEY (and friends) from .env
# NOTE(review): `client` is not referenced anywhere in this module's visible
# code — confirm whether it is still needed.
client = ZhipuAI(api_key=os.environ["ZHIPUAI_API_KEY"])
14
+
15
+ #数据存入向量库
16
+ from zhipuEmbedding import ZhipuAiEmbeddings
17
+ from langchain_community.vectorstores import Chroma
18
+
19
def dataLoadToVectordb(texts):
    """Embed the given documents and persist them into the local Chroma store.

    Args:
        texts: LangChain Document objects to index.
    """
    store = Chroma.from_documents(
        documents=texts,
        embedding=ZhipuAiEmbeddings(),
        persist_directory='data_base/vector_db/chroma',
    )
    print(f"向量库中存储的数量:{store._collection.count()}")
    return
29
+
30
def get_file_paths(folder_path):
    """Index every supported file under *folder_path* into the vector store.

    Pipeline: enumerate files recursively -> load & parse each file ->
    split into chunks -> embed the chunks into Chroma in batches of 64.

    Args:
        folder_path: directory to scan recursively.
    """
    current_dir = os.getcwd()
    abs_folder_path = os.path.abspath(folder_path)
    # Diagnostics: make path mistakes obvious when run from the wrong CWD.
    print(f"当前工作目录:{current_dir}")
    print(f"目标文件夹绝对路径:{abs_folder_path}")
    print(f"目标路径是否存在:{os.path.exists(abs_folder_path)}")
    print(f"目标路径是否是文件夹:{os.path.isdir(abs_folder_path)}")

    # 1. Collect every file path beneath the folder.
    file_paths = [
        os.path.join(root, file)
        for root, dirs, files in os.walk(folder_path)
        for file in files
    ]
    print(file_paths[:3])  # debug: sample of discovered paths

    # Load every file's contents into `texts` (a list of Documents).
    texts = []
    for file_path in file_paths:
        splitDocuments(file_path, texts)

    # 2. Data cleaning (stripping extra newlines/symbols/spaces) is still TODO.

    # 3. Split the loaded documents into chunks.
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Maximum length of a single chunk in the knowledge base.
    CHUNK_SIZE = 500
    # Overlap between neighbouring chunks.
    OVERLAP_SIZE = 0

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE
    )
    docs = text_splitter.split_documents(texts)
    # BUG FIX: the message advertises a count, so print len(docs) instead of
    # dumping the entire document list.
    print(f"切分后的文件数量:{len(docs)}")

    # Feed chunks to the vector store in batches of 64 (embedding API limit).
    for i in range(0, len(docs), 64):
        dataLoadToVectordb(docs[i : i + 64])
75
+
76
def splitDocuments(file_path, texts):
    """Load one file with the loader matching its extension and append the
    resulting Documents to *texts*.

    Unsupported formats (xlsx/xls, epub, anything unknown) are reported to
    stdout and skipped.

    Args:
        file_path: path of the file to load.
        texts: list that receives the loaded Document objects (mutated in place).
    """
    file_type = file_path.split('.')[-1].lower()

    # Dispatch table: extension -> loader factory.
    loader_factories = {
        'pdf': lambda p: PyMuPDFLoader(p),                      # preferred for PDFs (fast, stable)
        'md': lambda p: UnstructuredMarkdownLoader(p),          # Markdown
        'txt': lambda p: TextLoader(p, encoding="utf-8"),       # plain text
        'py': lambda p: PythonLoader(p),                        # Python source
        'csv': lambda p: CSVLoader(p, encoding="utf-8"),        # tabular data
        'json': lambda p: JSONLoader(p, jq_schema=".content", text_content=False),
        'docx': lambda p: Docx2txtLoader(p),                    # Word
        'pptx': lambda p: UnstructuredPowerPointLoader(p),      # PowerPoint (new)
        'ppt': lambda p: UnstructuredPowerPointLoader(p),       # PowerPoint (legacy)
        'png': lambda p: UnstructuredImageLoader(p),            # image (OCR)
        'jpg': lambda p: UnstructuredImageLoader(p),
        'jpeg': lambda p: UnstructuredImageLoader(p),
        'url': lambda p: WebBaseLoader(p),                      # plain web page
    }

    factory = loader_factories.get(file_type)
    if factory is None:
        # Covers xlsx/xls, epub and every other unhandled extension.
        print(f"不支持的文件格式:{file_type} | 文件路径:{file_path}")
        return

    loader = factory(file_path)
    if loader is not None:
        texts.extend(loader.load())
114
+
115
+
116
+
117
# Script entry point: index everything under data_base/data into Chroma.
if __name__ == "__main__":
    get_file_paths("data_base/data")
zhipuLLM.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Iterator, List, Optional
2
+ from zhipuai import ZhipuAI
3
+ from langchain_core.callbacks import (
4
+ CallbackManagerForLLMRun,
5
+ )
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import (
8
+ AIMessage,
9
+ AIMessageChunk,
10
+ BaseMessage,
11
+ SystemMessage,
12
+ ChatMessage,
13
+ HumanMessage
14
+ )
15
+ from langchain_core.messages.ai import UsageMetadata
16
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
17
+ import time
18
+
19
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message into the dict shape Zhipu's API expects.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary with "content", "role" and (optionally) "name" keys.

    Raises:
        TypeError: if the message type has no Zhipu role mapping.
    """
    message_dict: Dict[str, Any] = {"content": message.content}

    # Prefer the message's own name; fall back to additional_kwargs.
    name = message.name or message.additional_kwargs.get("name")
    if name is not None:
        message_dict["name"] = name

    # Map the LangChain message class onto Zhipu's role vocabulary.
    # ChatMessage is checked first because it carries an explicit role.
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    elif isinstance(message, SystemMessage):
        role = "system"
    else:
        raise TypeError(f"Got unknown type {message}")
    message_dict["role"] = role
    return message_dict
42
+
43
class ZhipuaiLLM(BaseChatModel):
    """Custom LangChain chat model backed by the Zhipu AI API."""

    # Model configuration. When api_key is None, the zhipuai SDK falls back
    # to the ZHIPUAI_API_KEY environment variable.
    model_name: Optional[str] = None  # FIX: was annotated `str` with default None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 3
    api_key: str | None = None

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Answer the prompt by calling the Zhipu chat-completions API.

        Args:
            messages: the prompt, as a list of LangChain messages.
            stop: strings that stop generation when they appear in the reply.
            run_manager: run manager providing callbacks for the LLM.
        """
        # Convert each message into the dict format Zhipu expects.
        messages = [_convert_message_to_dict(message) for message in messages]
        # Start of inference, for response-time metadata.
        start_time = time.time()
        # Call Zhipu to process the messages.
        response = ZhipuAI(api_key=self.api_key).chat.completions.create(
            model=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            timeout=self.timeout,
            stop=stop,
            messages=messages
        )
        # Wall-clock latency of the API call.
        time_in_seconds = time.time() - start_time
        # Wrap the response in LangChain's message/result types.
        message = AIMessage(
            content=response.choices[0].message.content,  # generated reply
            additional_kwargs={},  # extra info (none)
            response_metadata={
                "time_in_seconds": round(time_in_seconds, 3),  # latency; other metadata may be added here
            },
            # Tokens consumed by this call.
            usage_metadata={
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the answer chunk by chunk from the Zhipu API.

        Args:
            messages: the prompt, as a list of LangChain messages.
            stop: strings that stop generation when they appear in the reply.
            run_manager: run manager providing callbacks for the LLM.
        """
        messages = [_convert_message_to_dict(message) for message in messages]
        # BUG FIX: pass self.api_key — previously ZhipuAI() ignored the
        # configured key here, inconsistent with _generate.
        response = ZhipuAI(api_key=self.api_key).chat.completions.create(
            model=self.model_name,
            stream=True,  # stream=True returns an iterator of chunks
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            timeout=self.timeout,
            stop=stop,
            messages=messages
        )
        start_time = time.time()
        # BUG FIX: initialize up front so the final chunk below cannot hit
        # an UnboundLocalError when the API never reports usage.
        usage_metadata = None
        # Iterate over the streamed chunks.
        for res in response:
            if res.usage:  # record token usage when a chunk carries it
                usage_metadata = UsageMetadata(
                    {
                        "input_tokens": res.usage.prompt_tokens,
                        "output_tokens": res.usage.completion_tokens,
                        "total_tokens": res.usage.total_tokens,
                    }
                )
            # Wrap each streamed delta as a generation chunk.
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=res.choices[0].delta.content)
            )

            if run_manager:
                # Optional in newer LangChain versions — on_llm_new_token is
                # invoked automatically there.
                run_manager.on_llm_new_token(res.choices[0].delta.content, chunk=chunk)
            # Yield makes this a generator the caller can iterate.
            yield chunk
        time_in_sec = time.time() - start_time
        # Final (empty-content) chunk carries latency metadata and the last
        # observed token usage.
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(content="", response_metadata={"time_in_sec": round(time_in_sec, 3)}, usage_metadata=usage_metadata)
        )
        if run_manager:
            # Optional in newer LangChain versions (see above).
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Type label for this chat model (the underlying model name)."""
        return self.model_name

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dict of identifying parameters.

        Used by the LangChain callback system for tracing/monitoring.
        """
        return {
            "model_name": self.model_name,
        }