| import os |
| import gradio as gr |
| import torch |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain.prompts import ChatPromptTemplate |
| from langchain.schema.runnable import RunnablePassthrough |
| from langchain.schema.output_parser import StrOutputParser |
| from langchain_community.vectorstores import Chroma |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_core.documents import Document |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from markitdown import MarkItDown |
| import pandas as pd |
| import numpy as np |
| import re |
|
|
| |
# SECURITY NOTE(review): hard-coded API key committed to source — this key is
# now exposed and should be revoked and rotated. Load it from the environment
# or a secrets manager instead of assigning it here.
os.environ["GOOGLE_API_KEY"] = "AIzaSyAouHZNUHVWoMHPTrTZKCES-OqiosAfEJY"


# Directory where the Chroma vector store is persisted on disk.
VECTOR_STORE_PATH = "./vector_store"


# Embedding model shared by indexing and querying. Instantiated at import
# time (downloads/loads the model on first run).
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")


# Chunker applied to converted Excel text before indexing. The separator
# list mixes CJK and ASCII sentence punctuation so both Chinese and English
# content split on sentence boundaries before falling back to characters.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
)
|
|
| |
def get_vectorstore():
    """Return the persistent Chroma store, creating a seeded one if absent.

    A brand-new store is seeded with a single placeholder document because
    Chroma.from_documents requires at least one document to initialize.
    """
    if not os.path.exists(VECTOR_STORE_PATH):
        seed = Document(page_content="初始化文档", metadata={"source": "初始化"})
        store = Chroma.from_documents(
            documents=[seed],
            embedding=embedding_model,
            persist_directory=VECTOR_STORE_PATH,
        )
        store.persist()
        return store
    return Chroma(persist_directory=VECTOR_STORE_PATH, embedding_function=embedding_model)
|
|
| |
def get_llm():
    """Construct the Gemini 2.0 Flash chat-model client used by the RAG chain."""
    settings = dict(
        model="gemini-2.0-flash-exp",
        temperature=0.7,
        convert_system_message_to_human=True,
        max_output_tokens=2048,
    )
    return ChatGoogleGenerativeAI(**settings)
|
|
| |
def create_rag_chain():
    """Assemble the RAG pipeline: retrieve top-5 chunks, fill the prompt,
    call Gemini, and parse the reply to plain text."""
    retriever = get_vectorstore().as_retriever(search_kwargs={"k": 5})

    template = """
    你是一个专业的数据分析助手。请基于以下检索到的Excel表格数据回答用户的问题。
    如果检索内容中没有相关信息,请诚实地告知用户你不知道,不要编造答案。
    
    检索到的Excel表格数据:
    {context}
    
    用户问题: {question}
    
    请提供详细、准确的回答,并在适当的情况下引用数据来源。
    """
    prompt = ChatPromptTemplate.from_template(template)

    # LCEL pipeline: the retriever fills {context}, the raw question passes
    # through to {question}.
    return (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | get_llm()
        | StrOutputParser()
    )
|
|
| |
def process_excel_with_markitdown(file_path):
    """Convert an Excel file to text with MarkItDown, chunk it, and index it.

    Returns a human-readable status string (success or error, in Chinese).
    """
    try:
        converter = MarkItDown(enable_plugins=False)
        text = converter.convert(file_path).text_content

        chunks = text_splitter.split_text(text)
        documents = [
            Document(page_content=chunk, metadata={"source": file_path})
            for chunk in chunks
        ]

        store = get_vectorstore()
        store.add_documents(documents)
        store.persist()

        return f"成功处理Excel文件: {file_path},添加了{len(documents)}个文档块到向量数据库"
    except Exception as e:
        return f"处理Excel文件时出错: {str(e)}"
|
|
| |
def _row_metadata(row, file_path, idx):
    """Build a Chroma-compatible metadata dict for one DataFrame row.

    Missing values are skipped; numpy scalars are normalized to native
    Python types (Chroma metadata accepts only str/int/float/bool).
    """
    metadata = {
        "source": file_path,
        "row": int(idx),
        "sheet": "Sheet1",
    }
    for col, val in row.items():
        if pd.isna(val):
            # Skip missing cells entirely rather than storing "nan".
            continue
        # BUG FIX: np.int64 (integer Excel columns) is NOT a subclass of
        # Python int, so the original isinstance(val, (int, float)) check
        # silently stringified all integer metadata. Unwrap numpy scalars
        # to native Python values first.
        if isinstance(val, np.generic):
            val = val.item()
        if isinstance(val, (bool, int, float)):
            metadata[f"col_{col}"] = val
        elif isinstance(val, str):
            stripped = val.strip()
            # Blank/whitespace-only strings used to be stored verbatim via
            # the fallback branch; now they are skipped.
            if stripped:
                metadata[f"col_{col}"] = stripped
        else:
            # Dates, timestamps and other objects: store a string form.
            metadata[f"col_{col}"] = str(val)
    return metadata


def process_excel_with_pandas(file_path):
    """Index an Excel file row-by-row with pandas.

    Each row becomes one Document whose page_content is "col: value" lines
    for every non-missing cell and whose metadata carries the source path,
    row index and per-column values (usable for metadata filtering).

    Args:
        file_path: Path to the .xlsx/.xls file to ingest.

    Returns:
        A human-readable status string (success or error, in Chinese).
    """
    try:
        df = pd.read_excel(file_path)

        documents = []
        for idx, row in df.iterrows():
            # One "col: value" line per non-missing cell.
            row_text = "\n".join(
                f"{col}: {val}" for col, val in row.items() if not pd.isna(val)
            )
            documents.append(
                Document(
                    page_content=row_text,
                    metadata=_row_metadata(row, file_path, idx),
                )
            )

        vectorstore = get_vectorstore()
        vectorstore.add_documents(documents)
        vectorstore.persist()

        return f"成功处理Excel文件: {file_path},添加了{len(documents)}个行记录到向量数据库,每行包含{len(df.columns)}个字段"
    except Exception as e:
        return f"处理Excel文件时出错: {str(e)}"
|
|
| |
def query_vectorstore(query, k=5):
    """Run a plain similarity search against the vector store.

    Returns the top-k matching Documents for *query*.
    """
    store = get_vectorstore()
    return store.similarity_search(query, k=k)
|
|
| |
def answer_question(question):
    """Answer a user question through the RAG chain.

    Note: the chain (and vector store handle) is rebuilt on every call.
    """
    return create_rag_chain().invoke(question)
|
|
| |
# frp reverse-proxy settings used to expose the local Gradio port.
port=7860
use_frpc=True
# NOTE(review): filename says "7680" while the port is 7860 — possible typo; confirm.
frpconfigfile="7680.ini"
# NOTE(review): mid-file import; convention is to keep imports at the top of the file.
import subprocess
|
|
def install_Frpc(port, frpconfigfile, use_frpc):
    """Start the frpc client in the background to tunnel the local port.

    Args:
        port: Local port being forwarded (used only in the log message).
        frpconfigfile: Path to the frpc .ini configuration file.
        use_frpc: If falsy, do nothing.

    The frpc process is launched with Popen and not waited on. BUG FIX:
    previously a missing ./frpc binary made chmod fail and, with
    check=True, raised CalledProcessError and crashed the whole app at
    startup; now failures are reported and the app keeps launching.
    """
    if not use_frpc:
        return
    frpc_path = "./frpc"
    if not os.path.exists(frpc_path):
        print(f"未找到frpc可执行文件: {frpc_path},跳过frp启动")
        return
    try:
        subprocess.run(["chmod", "+x", frpc_path], check=True)
        print(f"正在启动frp ,端口{port}")
        subprocess.Popen([frpc_path, "-c", frpconfigfile])
    except (OSError, subprocess.CalledProcessError) as e:
        print(f"启动frp失败: {e}")
|
|
| install_Frpc('7860',frpconfigfile,use_frpc) |
|
|
def create_interface():
    """Build the three-tab Gradio UI: Excel import, RAG Q&A, and vector-store browsing.

    Gradio lays widgets out in the order they are instantiated inside each
    context manager, so component-creation order here determines the layout.
    """
    with gr.Blocks(title="Gemini Flash 2.0 Excel RAG") as demo:
        gr.HTML("<h1 style='text-align: center'>Gemini Flash 2.0 Excel RAG 系统</h1>")

        # Tab 1: upload an Excel file and index it into the vector store.
        with gr.Tab("导入Excel数据"):
            with gr.Row():
                excel_file = gr.File(label="上传Excel文件")
                process_method = gr.Radio(["使用MarkItDown处理", "使用Pandas处理"], label="处理方法", value="使用Pandas处理")
            process_btn = gr.Button("处理并导入到向量数据库")
            output_msg = gr.Textbox(label="处理结果")

            def process_excel(file_path, method):
                """Dispatch to the ingestion backend chosen in the radio button."""
                if method == "使用MarkItDown处理":
                    return process_excel_with_markitdown(file_path)
                else:
                    return process_excel_with_pandas(file_path)

            process_btn.click(
                process_excel,
                inputs=[excel_file, process_method],
                outputs=[output_msg]
            )

        # Tab 2: free-text question answering through the RAG chain.
        with gr.Tab("查询问答"):
            with gr.Row():
                question_input = gr.Textbox(label="输入问题", placeholder="请输入您的问题...")
            submit_btn = gr.Button("提交")
            answer_output = gr.Textbox(label="回答", lines=10)

            submit_btn.click(
                answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )

        # Tab 3: raw similarity search for inspecting what is stored.
        with gr.Tab("查看向量库内容"):
            with gr.Row():
                search_input = gr.Textbox(label="搜索关键词")
                search_btn = gr.Button("搜索")
            k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="返回结果数量")
            search_output = gr.JSON(label="搜索结果")

            def format_search_results(query, k):
                """Run a similarity search and shape each hit for the JSON viewer.

                NOTE(review): similarity_search does not attach a "score" key to
                doc.metadata, so "score" is effectively always "N/A" here —
                confirm whether similarity_search_with_score was intended.
                """
                results = query_vectorstore(query, k=int(k))
                formatted_results = []
                for doc in results:
                    formatted_results.append({
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                        "score": doc.metadata.get("score", "N/A")
                    })
                return formatted_results

            search_btn.click(
                format_search_results,
                inputs=[search_input, k_slider],
                outputs=[search_output]
            )

    return demo
|
|
def main():
    """Entry point: build the Gradio UI and launch it with a public share link."""
    create_interface().launch(share=True)
|
|
# Run the app only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()