| from typing import List, Dict, Callable, Optional
|
| from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| from langchain_community.document_loaders import (
|
| DirectoryLoader,
|
| UnstructuredMarkdownLoader,
|
| PyPDFLoader,
|
| TextLoader
|
| )
|
| import os
|
| import requests
|
| import base64
|
| from PIL import Image
|
| import io
|
|
|
class DocumentLoader:
    """Generic document loader for markdown, PDF, plain-text and image files.

    Text-like formats are delegated to LangChain community loaders with a
    UTF-8 → GBK encoding fallback; images are turned into a textual
    description via a SiliconFlow-hosted vision-language model so they can
    be indexed alongside regular documents.
    """

    def __init__(self, file_path: str, original_filename: Optional[str] = None):
        """
        Args:
            file_path: Path of the (possibly temporary) file on disk.
            original_filename: User-facing filename (may contain non-ASCII
                characters). Defaults to the basename of ``file_path``.
                Format detection uses this name's extension, not the path's,
                because uploads are often stored under randomized temp names.
        """
        self.file_path = file_path
        self.original_filename = original_filename or os.path.basename(file_path)
        # Lower-case the extension so the format dispatch below is
        # case-insensitive (".PDF" == ".pdf").
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        # VLM endpoint credentials come from the environment, never hard-coded.
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image using the SiliconFlow VLM model.

        Args:
            image_path: Path to the image file to describe.

        Returns:
            The model-generated description, or the literal fallback string
            "图片处理失败" on any failure. Best-effort by design: errors are
            logged, never raised, so one bad image cannot abort a batch load.
        """
        try:
            with open(image_path, 'rb') as image_file:
                image_data = image_file.read()
                base64_image = base64.b64encode(image_data).decode('utf-8')

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            # OpenAI-compatible chat-completions call with an inline
            # base64 data URL (no need to host the image anywhere).
            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/jpeg;base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                }
            )

            if response.status_code != 200:
                raise Exception(f"图片处理API调用失败: {response.text}")

            description = response.json()["choices"][0]["message"]["content"]
            return description

        except Exception as e:
            print(f"处理图片时出错: {str(e)}")
            return "图片处理失败"

    def load(self):
        """Load the file and return a list of LangChain ``Document`` objects.

        Dispatches on ``self.extension``: markdown / PDF / txt go through the
        corresponding LangChain loader (with a GBK retry on decode errors);
        supported image types are replaced by a VLM-generated description.

        Raises:
            ValueError: If the extension is not supported.
            Exception: Loader errors are logged (with traceback) and re-raised.
        """
        try:
            print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")

            if self.extension == '.md':
                # UTF-8 first, GBK as a legacy-encoding fallback.
                try:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                try:
                    loader = TextLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Images become a single Document whose content is the
                # model description; metadata keeps a pointer to the file.
                description = self.process_image(self.file_path)

                from langchain.schema import Document
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'file_name': self.original_filename,
                        'img_url': os.path.abspath(self.file_path)
                    }
                )
                return [doc]
            else:
                print(f"不支持的文件扩展名: {self.extension}")
                raise ValueError(f"不支持的文件格式: {self.extension}")

        except UnicodeDecodeError:
            # Last-resort GBK retry for decode errors that escaped the
            # per-format fallbacks above.
            print(f"文件编码错误,尝试其他编码: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"尝试GBK编码也失败: {str(e)}")
                    raise
            # BUG FIX: previously this handler fell through for other
            # extensions and the method implicitly returned None, which
            # crashed callers later. Re-raise so the failure is visible.
            raise
        except Exception as e:
            print(f"加载文件 {self.file_path} 时出错: {str(e)}")
            import traceback
            traceback.print_exc()
            raise
|
|
|
class DocumentProcessor:
    """Loads documents (file or directory) and splits them into overlapping
    chunks ready for vector indexing."""

    def __init__(self):
        # 1000-char chunks with 200-char overlap so context spanning a
        # chunk boundary is retrievable from either side.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name from a path: ``"rag_"`` + lower-cased base name.

        For a directory the directory name is used; for a file the extension
        is stripped first.
        """
        base = os.path.basename(path)
        if os.path.isdir(path):
            return f"rag_{base.lower()}"
        return f"rag_{os.path.splitext(base)[0].lower()}"

    def _load_directory(self, path: str, progress_callback: Optional[Callable]):
        """Recursively load every file under *path*.

        Files that fail to load are skipped with a warning so one bad file
        does not abort the batch. Returns ``(documents, processed_size)``
        where ``processed_size`` is the cumulative byte count reported to
        the progress callback (0 when no callback is given).
        """
        documents = []
        total_files = sum(len(files) for _, _, files in os.walk(path))
        processed_files = 0
        processed_size = 0

        for root, _, files in os.walk(path):
            for file in files:
                file_path = os.path.join(root, file)
                try:
                    if progress_callback:
                        processed_size += os.path.getsize(file_path)
                        processed_files += 1
                        progress_callback(processed_size, f"处理文件 {processed_files}/{total_files}: {file}")

                    loader = DocumentLoader(file_path, original_filename=file)
                    docs = loader.load()
                    # Stamp each document with its short file name for
                    # display/filtering downstream.
                    for doc in docs:
                        doc.metadata['file_name'] = file
                    documents.extend(docs)
                except Exception as e:
                    print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
                    continue
        return documents, processed_size

    def _load_file(self, path: str, progress_callback: Optional[Callable],
                   original_filename: Optional[str]):
        """Load a single file, reporting ~30%/60% progress around the load.

        Raises on failure (a single-file load has no skip semantics).
        """
        try:
            if progress_callback:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.3, f"加载文件: {original_filename or os.path.basename(path)}")

            loader = DocumentLoader(path, original_filename=original_filename)
            documents = loader.load()

            if progress_callback:
                progress_callback(file_size * 0.6, "处理文件内容...")

            file_name = original_filename or os.path.basename(path)
            for doc in documents:
                doc.metadata['file_name'] = file_name
            return documents
        except Exception as e:
            print(f"加载文件时出错: {str(e)}")
            raise

    def process(self, path: str, progress_callback: Optional[Callable] = None,
                original_filename: Optional[str] = None) -> List[Dict]:
        """Load and process documents from a directory or a single file.

        Args:
            path: Document path (file or directory).
            progress_callback: Optional ``callback(progress, message)`` used
                to report processing progress.
            original_filename: Original (possibly non-ASCII) filename for a
                single-file *path*; ignored for directories.

        Returns:
            A list of ``{'id', 'content', 'metadata'}`` dicts, one per chunk.
        """
        if os.path.isdir(path):
            documents, processed_size = self._load_directory(path, progress_callback)
        else:
            documents = self._load_file(path, progress_callback, original_filename)

        chunks = self.text_splitter.split_documents(documents)

        if progress_callback:
            if os.path.isdir(path):
                progress_callback(processed_size, f"文档分块完成,共{len(chunks)}个文档片段")
            else:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.9, f"文档分块完成,共{len(chunks)}个文档片段")

        # Sequential chunk ids keep ordering stable across runs.
        return [
            {
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata,
            }
            for i, chunk in enumerate(chunks)
        ]