# gemini_excel_rag.py — Gemini Flash 2.0 Excel RAG demo (Gradio + LangChain + Chroma)
import os
import gradio as gr
import torch
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from markitdown import MarkItDown
import pandas as pd
import numpy as np
import re
# --- Configuration -----------------------------------------------------------

# SECURITY: an API key was hard-coded here. Prefer supplying it through the
# GOOGLE_API_KEY environment variable; setdefault preserves the original
# fallback while letting an externally-set key take precedence.
# TODO(review): rotate and remove this committed key.
os.environ.setdefault("GOOGLE_API_KEY", "AIzaSyAouHZNUHVWoMHPTrTZKCES-OqiosAfEJY")

# Directory where the Chroma vector store is persisted between runs.
VECTOR_STORE_PATH = "./vector_store"

# Embedding model used to vectorize document chunks (Chinese-optimized BGE).
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")

# Splitter producing ~500-char chunks with 50-char overlap, breaking
# preferentially on paragraph/sentence boundaries (CJK and ASCII punctuation).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""],
)
# Vector-store access
def get_vectorstore():
    """Open the persistent Chroma store, seeding a new one on first use."""
    if not os.path.exists(VECTOR_STORE_PATH):
        # Nothing on disk yet: create the store with a placeholder document so
        # Chroma has content to persist.
        seed_docs = [Document(page_content="初始化文档", metadata={"source": "初始化"})]
        store = Chroma.from_documents(
            documents=seed_docs,
            embedding=embedding_model,
            persist_directory=VECTOR_STORE_PATH,
        )
        store.persist()
        return store
    # Store already exists: reattach using the same embedding function.
    return Chroma(persist_directory=VECTOR_STORE_PATH, embedding_function=embedding_model)
# LLM factory
def get_llm():
    """Build the Gemini Flash 2.0 chat model used by the RAG chain."""
    llm_kwargs = dict(
        model="gemini-2.0-flash-exp",
        temperature=0.7,
        convert_system_message_to_human=True,
        max_output_tokens=2048,
    )
    return ChatGoogleGenerativeAI(**llm_kwargs)
# RAG pipeline
def create_rag_chain():
    """Assemble the retrieval-augmented QA chain: retriever -> prompt -> LLM."""
    # Top-5 similarity retriever backed by the persistent Chroma store.
    retriever = get_vectorstore().as_retriever(search_kwargs={"k": 5})
    # Prompt (Chinese): instructs the model to answer only from retrieved data
    # and to admit ignorance rather than fabricate.
    template = """
你是一个专业的数据分析助手。请基于以下检索到的Excel表格数据回答用户的问题。
如果检索内容中没有相关信息,请诚实地告知用户你不知道,不要编造答案。
检索到的Excel表格数据:
{context}
用户问题: {question}
请提供详细、准确的回答,并在适当的情况下引用数据来源。
"""
    prompt = ChatPromptTemplate.from_template(template)
    # LCEL composition: retrieved docs become {context}, the raw question is
    # passed through as {question}, and the model reply is parsed to a string.
    return (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | get_llm()
        | StrOutputParser()
    )
# Ingestion backend 1: MarkItDown text extraction
def process_excel_with_markitdown(file_path):
    """Convert an Excel file to markdown text, chunk it, and index the chunks.

    Returns a human-readable status string (success or error) for the UI.
    """
    try:
        # Render the workbook as markdown/plain text.
        converter = MarkItDown(enable_plugins=False)
        extracted_text = converter.convert(file_path).text_content
        # Chunk the text; each chunk carries its source file as metadata.
        chunks = text_splitter.split_text(extracted_text)
        documents = [
            Document(page_content=chunk, metadata={"source": file_path})
            for chunk in chunks
        ]
        # Persist the new chunks into the shared vector store.
        store = get_vectorstore()
        store.add_documents(documents)
        store.persist()
        return f"成功处理Excel文件: {file_path},添加了{len(documents)}个文档块到向量数据库"
    except Exception as e:
        return f"处理Excel文件时出错: {str(e)}"
# Ingestion backend 2: pandas, one Document per row with per-column metadata
def process_excel_with_pandas(file_path):
    """Index an Excel sheet row by row, promoting each column to metadata.

    Each row becomes one Document: page_content is a "col: value" listing of
    the non-empty cells (used for embedding/retrieval), and metadata carries
    the source file, row index, sheet name, and one "col_<name>" entry per
    non-empty cell.

    Returns a human-readable status string (success or error) for the UI.
    """
    try:
        df = pd.read_excel(file_path)
        documents = []
        for idx, row in df.iterrows():
            # Retrieval text: one "col: value" line per non-empty cell.
            row_text = "\n".join(
                f"{col}: {val}" for col, val in row.items() if not pd.isna(val)
            )
            # Coerce the row label: numpy integers are not accepted as Chroma
            # metadata values, so normalize to a plain int (or str fallback).
            metadata = {
                "source": file_path,
                "row": int(idx) if isinstance(idx, (int, np.integer)) else str(idx),
                "sheet": "Sheet1",  # adjust here to support multiple sheets
            }
            for col, val in row.items():
                # NaN first so missing cells never reach the branches below
                # (the original tested it after the numeric branch).
                if pd.isna(val):
                    continue
                if isinstance(val, (int, float)):
                    metadata[f"col_{col}"] = val
                elif isinstance(val, str):
                    # Fix: empty/whitespace-only strings used to fall through
                    # to the else branch and be stored unstripped; skip them.
                    stripped = val.strip()
                    if stripped:
                        metadata[f"col_{col}"] = stripped
                else:
                    # Anything else (timestamps, etc.) is stored as a string.
                    metadata[f"col_{col}"] = str(val)
            documents.append(Document(page_content=row_text, metadata=metadata))
        # Persist the new row-documents into the shared vector store.
        store = get_vectorstore()
        store.add_documents(documents)
        store.persist()
        return f"成功处理Excel文件: {file_path},添加了{len(documents)}个行记录到向量数据库,每行包含{len(df.columns)}个字段"
    except Exception as e:
        return f"处理Excel文件时出错: {str(e)}"
# Direct vector-store search (used by the "browse" tab)
def query_vectorstore(query, k=5):
    """Run a raw similarity search against the vector store.

    Args:
        query: free-text search string.
        k: number of nearest documents to return.

    Returns:
        The top-k matching Documents.
    """
    return get_vectorstore().similarity_search(query, k=k)
# RAG question answering (used by the Q&A tab)
def answer_question(question):
    """Answer a user question through the RAG chain; returns the reply text."""
    return create_rag_chain().invoke(question)
# --- Optional frp reverse-proxy tunnel ---------------------------------------
import subprocess

port = 7860
use_frpc = True
frpconfigfile = "7680.ini"


def install_Frpc(port, frpconfigfile, use_frpc):
    """Launch the bundled frpc client so the local Gradio port is reachable.

    Args:
        port: local port being tunnelled (informational, shown in the log).
        frpconfigfile: path to the frp client configuration file.
        use_frpc: when False, do nothing.

    No-ops when frpc is disabled or the ./frpc binary is missing (the original
    crashed at import time with CalledProcessError in that case).
    """
    if not use_frpc:
        return
    if not os.path.exists("./frpc"):
        # frpc binary absent: skip instead of crashing the whole app.
        print("未找到 ./frpc,跳过 frp 启动")
        return
    subprocess.run(["chmod", "+x", "./frpc"], check=True)
    print(f"正在启动frp ,端口{port}")
    # Fire-and-forget: frpc keeps running alongside the Gradio server.
    subprocess.Popen(["./frpc", "-c", frpconfigfile])


# Bug fix: the original passed the string literal '7860' instead of `port`.
install_Frpc(port, frpconfigfile, use_frpc)
def create_interface():
    """Build the three-tab Gradio UI: Excel import, RAG Q&A, and store browsing."""
    with gr.Blocks(title="Gemini Flash 2.0 Excel RAG") as demo:
        gr.HTML("<h1 style='text-align: center'>Gemini Flash 2.0 Excel RAG 系统</h1>")
        # Tab 1: upload an Excel file and ingest it into the vector store.
        with gr.Tab("导入Excel数据"):
            with gr.Row():
                excel_file = gr.File(label="上传Excel文件")
                process_method = gr.Radio(["使用MarkItDown处理", "使用Pandas处理"], label="处理方法", value="使用Pandas处理")
            process_btn = gr.Button("处理并导入到向量数据库")
            output_msg = gr.Textbox(label="处理结果")
            def process_excel(file_path, method):
                # Dispatch to the ingestion backend chosen in the radio group.
                if method == "使用MarkItDown处理":
                    return process_excel_with_markitdown(file_path)
                else:
                    return process_excel_with_pandas(file_path)
            process_btn.click(
                process_excel,
                inputs=[excel_file, process_method],
                outputs=[output_msg]
            )
        # Tab 2: ask a free-text question answered via the RAG chain.
        with gr.Tab("查询问答"):
            with gr.Row():
                question_input = gr.Textbox(label="输入问题", placeholder="请输入您的问题...")
                submit_btn = gr.Button("提交")
            answer_output = gr.Textbox(label="回答", lines=10)
            submit_btn.click(
                answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )
        # Tab 3: raw similarity search to inspect vector-store contents.
        with gr.Tab("查看向量库内容"):
            with gr.Row():
                search_input = gr.Textbox(label="搜索关键词")
                search_btn = gr.Button("搜索")
            k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="返回结果数量")
            search_output = gr.JSON(label="搜索结果")
            def format_search_results(query, k):
                # Render Documents as plain dicts for the JSON widget.
                results = query_vectorstore(query, k=int(k))
                formatted_results = []
                for doc in results:
                    formatted_results.append({
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                        # NOTE(review): similarity_search does not appear to put
                        # a "score" key in metadata, so this is likely always
                        # "N/A" — confirm against the Chroma wrapper used.
                        "score": doc.metadata.get("score", "N/A")
                    })
                return formatted_results
            search_btn.click(
                format_search_results,
                inputs=[search_input, k_slider],
                outputs=[search_output]
            )
    return demo
def main():
    """Entry point: build the Gradio UI and serve it with a public share link."""
    create_interface().launch(share=True)


if __name__ == "__main__":
    main()