| import asyncio |
| import json |
| import re |
| import time |
| import uuid |
| from pathlib import Path |
|
|
| import aiofiles |
|
|
| from astrbot.core import logger |
| from astrbot.core.db.vec_db.base import BaseVecDB |
| from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB |
| from astrbot.core.provider.manager import ProviderManager |
from astrbot.core.provider.provider import (
    EmbeddingProvider,
    Provider as LLMProvider,
    RerankProvider,
)
|
|
| from .chunking.base import BaseChunker |
| from .chunking.recursive import RecursiveCharacterChunker |
| from .kb_db_sqlite import KBSQLiteDatabase |
| from .models import KBDocument, KBMedia, KnowledgeBase |
| from .parsers.url_parser import extract_text_from_url |
| from .parsers.util import select_parser |
| from .prompts import TEXT_REPAIR_SYSTEM_PROMPT |
|
|
|
|
| class RateLimiter: |
| """一个简单的速率限制器""" |
|
|
    def __init__(self, max_rpm: int) -> None:
        self.max_per_minute = max_rpm
        self.interval = 60.0 / max_rpm if max_rpm > 0 else 0.0
        self.last_call_time = 0.0
        # Serialize schedule checks so concurrent tasks cannot all pass at once.
        self._lock = asyncio.Lock()
|
|
    async def __aenter__(self):
        if self.interval == 0:
            return

        # Only one task at a time may check and update the schedule; without the lock,
        # concurrent callers could all pass the elapsed-time check simultaneously.
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self.last_call_time

            if elapsed < self.interval:
                await asyncio.sleep(self.interval - elapsed)

            self.last_call_time = time.monotonic()
|
|
| async def __aexit__(self, exc_type, exc_val, exc_tb): |
| pass |
|
|
|
|
| async def _repair_and_translate_chunk_with_retry( |
| chunk: str, |
| repair_llm_service: LLMProvider, |
| rate_limiter: RateLimiter, |
| max_retries: int = 2, |
| ) -> list[str]: |
| """ |
| Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting. |
| """ |
| |
| user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided. |
| |
| Text chunk to process: |
| --- |
| {chunk} |
| --- |
| """ |
| for attempt in range(max_retries + 1): |
| try: |
| async with rate_limiter: |
| response = await repair_llm_service.text_chat( |
| prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT |
| ) |
|
|
| llm_output = response.completion_text |
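                # The model may signal that a chunk should be dropped via a <discard_chunk /> tag.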
|
|
| if "<discard_chunk />" in llm_output: |
| return [] |
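                # Collect every <repaired_text>...</repaired_text> block; the model may split a chunk into several.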
|
|
| |
| matches = re.findall( |
| r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>", |
| llm_output, |
| re.DOTALL, |
| ) |
|
|
| if matches: |
| |
| return [m.strip() for m in matches if m.strip()] |
| else: |
| |
| return [] |
| except Exception as e: |
| logger.warning( |
| f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}" |
| ) |
|
|
| logger.error( |
| f" - Failed to process chunk after {max_retries + 1} attempts. Using original text." |
| ) |
| return [chunk] |
|
|
|
|
| class KBHelper: |
| vec_db: BaseVecDB |
| kb: KnowledgeBase |
|
|
| def __init__( |
| self, |
| kb_db: KBSQLiteDatabase, |
| kb: KnowledgeBase, |
| provider_manager: ProviderManager, |
| kb_root_dir: str, |
| chunker: BaseChunker, |
| ) -> None: |
| self.kb_db = kb_db |
| self.kb = kb |
| self.prov_mgr = provider_manager |
| self.kb_root_dir = kb_root_dir |
| self.chunker = chunker |
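        # Per-KB storage layout: <kb_root>/<kb_id>/ with medias/ and files/ subdirectories.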
|
|
| self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id |
        self.kb_medias_dir = self.kb_dir / "medias" / self.kb.kb_id
        self.kb_files_dir = self.kb_dir / "files" / self.kb.kb_id
|
|
| self.kb_medias_dir.mkdir(parents=True, exist_ok=True) |
| self.kb_files_dir.mkdir(parents=True, exist_ok=True) |
|
|
| async def initialize(self) -> None: |
| await self._ensure_vec_db() |
|
|
| async def get_ep(self) -> EmbeddingProvider: |
| if not self.kb.embedding_provider_id: |
| raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") |
| ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id( |
| self.kb.embedding_provider_id, |
| ) |
| if not ep: |
| raise ValueError( |
| f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider", |
| ) |
| return ep |
|
|
| async def get_rp(self) -> RerankProvider | None: |
| if not self.kb.rerank_provider_id: |
| return None |
| rp: RerankProvider = await self.prov_mgr.get_provider_by_id( |
| self.kb.rerank_provider_id, |
| ) |
| if not rp: |
| raise ValueError( |
| f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider", |
| ) |
| return rp |
|
|
| async def _ensure_vec_db(self) -> FaissVecDB: |
| if not self.kb.embedding_provider_id: |
| raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") |
|
|
| ep = await self.get_ep() |
| rp = await self.get_rp() |
|
|
| vec_db = FaissVecDB( |
| doc_store_path=str(self.kb_dir / "doc.db"), |
| index_store_path=str(self.kb_dir / "index.faiss"), |
| embedding_provider=ep, |
| rerank_provider=rp, |
| ) |
| await vec_db.initialize() |
| self.vec_db = vec_db |
| return vec_db |
|
|
| async def delete_vec_db(self) -> None: |
| """删除知识库的向量数据库和所有相关文件""" |
| import shutil |
|
|
| await self.terminate() |
| if self.kb_dir.exists(): |
| shutil.rmtree(self.kb_dir) |
|
|
    async def terminate(self) -> None:
        # vec_db is only assigned in _ensure_vec_db(); guard against termination before initialization.
        if getattr(self, "vec_db", None):
            await self.vec_db.close()
|
|
| async def upload_document( |
| self, |
| file_name: str, |
| file_content: bytes | None, |
| file_type: str, |
| chunk_size: int = 512, |
| chunk_overlap: int = 50, |
| batch_size: int = 32, |
| tasks_limit: int = 3, |
| max_retries: int = 3, |
| progress_callback=None, |
| pre_chunked_text: list[str] | None = None, |
| ) -> KBDocument: |
| """上传并处理文档(带原子性保证和失败清理) |
| |
| 流程: |
| 1. 保存原始文件 |
| 2. 解析文档内容 |
| 3. 提取多媒体资源 |
| 4. 分块处理 |
| 5. 生成向量并存储 |
| 6. 保存元数据(事务) |
| 7. 更新统计 |
| |
| Args: |
| progress_callback: 进度回调函数,接收参数 (stage, current, total) |
| - stage: 当前阶段 ('parsing', 'chunking', 'embedding') |
| - current: 当前进度 |
| - total: 总数 |
| |
| """ |
| await self._ensure_vec_db() |
| doc_id = str(uuid.uuid4()) |
| media_paths: list[Path] = [] |
| file_size = 0 |
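        # Track media files written to disk so they can be removed if the upload fails.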
|
|
| |
| |
| |
|
|
| try: |
| chunks_text = [] |
| saved_media = [] |
|
|
| if pre_chunked_text is not None: |
| |
| chunks_text = pre_chunked_text |
| file_size = sum(len(chunk) for chunk in chunks_text) |
| logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。") |
| else: |
| |
| if file_content is None: |
| raise ValueError( |
| "当未提供 pre_chunked_text 时,file_content 不能为空。" |
| ) |
|
|
| file_size = len(file_content) |
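                # Parse the raw file into text plus any embedded media.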
|
|
| |
| if progress_callback: |
| await progress_callback("parsing", 0, 100) |
|
|
| parser = await select_parser(f".{file_type}") |
| parse_result = await parser.parse(file_content, file_name) |
| text_content = parse_result.text |
| media_items = parse_result.media |
|
|
| if progress_callback: |
| await progress_callback("parsing", 100, 100) |
|
|
| |
| for media_item in media_items: |
| media = await self._save_media( |
| doc_id=doc_id, |
| media_type=media_item.media_type, |
| file_name=media_item.file_name, |
| content=media_item.content, |
| mime_type=media_item.mime_type, |
| ) |
| saved_media.append(media) |
| media_paths.append(Path(media.file_path)) |
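                # Split the parsed text into overlapping chunks.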
|
|
| |
| if progress_callback: |
| await progress_callback("chunking", 0, 100) |
|
|
| chunks_text = await self.chunker.chunk( |
| text_content, |
| chunk_size=chunk_size, |
| chunk_overlap=chunk_overlap, |
| ) |
| contents = [] |
| metadatas = [] |
| for idx, chunk_text in enumerate(chunks_text): |
| contents.append(chunk_text) |
| metadatas.append( |
| { |
| "kb_id": self.kb.kb_id, |
| "kb_doc_id": doc_id, |
| "chunk_index": idx, |
| }, |
| ) |
|
|
| if progress_callback: |
| await progress_callback("chunking", 100, 100) |
|
|
| |
| async def embedding_progress_callback(current, total) -> None: |
| if progress_callback: |
| await progress_callback("embedding", current, total) |
|
|
| await self.vec_db.insert_batch( |
| contents=contents, |
| metadatas=metadatas, |
| batch_size=batch_size, |
| tasks_limit=tasks_limit, |
| max_retries=max_retries, |
| progress_callback=embedding_progress_callback, |
| ) |
|
|
| |
| doc = KBDocument( |
| doc_id=doc_id, |
| kb_id=self.kb.kb_id, |
| doc_name=file_name, |
| file_type=file_type, |
| file_size=file_size, |
| |
| file_path="", |
| chunk_count=len(chunks_text), |
| media_count=0, |
| ) |
| async with self.kb_db.get_db() as session: |
| async with session.begin(): |
| session.add(doc) |
| for media in saved_media: |
| session.add(media) |
| await session.commit() |
|
|
| await session.refresh(doc) |
|
|
| vec_db: FaissVecDB = self.vec_db |
| await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db) |
| await self.refresh_kb() |
| await self.refresh_document(doc_id) |
| return doc |
| except Exception as e: |
| logger.error(f"上传文档失败: {e}") |
| |
| |
|
|
| for media_path in media_paths: |
| try: |
| if media_path.exists(): |
| media_path.unlink() |
| except Exception as me: |
| logger.warning(f"清理多媒体文件失败 {media_path}: {me}") |
|
|
            raise
|
|
| async def list_documents( |
| self, |
| offset: int = 0, |
| limit: int = 100, |
| ) -> list[KBDocument]: |
| """列出知识库的所有文档""" |
| docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit) |
| return docs |
|
|
| async def get_document(self, doc_id: str) -> KBDocument | None: |
| """获取单个文档""" |
| doc = await self.kb_db.get_document_by_id(doc_id) |
| return doc |
|
|
| async def delete_document(self, doc_id: str) -> None: |
| """删除单个文档及其相关数据""" |
| await self.kb_db.delete_document_by_id( |
| doc_id=doc_id, |
| vec_db=self.vec_db, |
| ) |
| await self.kb_db.update_kb_stats( |
| kb_id=self.kb.kb_id, |
| vec_db=self.vec_db, |
| ) |
| await self.refresh_kb() |
|
|
| async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: |
| """删除单个文本块及其相关数据""" |
| vec_db: FaissVecDB = self.vec_db |
| await vec_db.delete(chunk_id) |
| await self.kb_db.update_kb_stats( |
| kb_id=self.kb.kb_id, |
| vec_db=self.vec_db, |
| ) |
| await self.refresh_kb() |
| await self.refresh_document(doc_id) |
|
|
| async def refresh_kb(self) -> None: |
| if self.kb: |
| kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) |
| if kb: |
| self.kb = kb |
|
|
| async def refresh_document(self, doc_id: str) -> None: |
| """更新文档的元数据""" |
| doc = await self.get_document(doc_id) |
| if not doc: |
| raise ValueError(f"无法找到 ID 为 {doc_id} 的文档") |
| chunk_count = await self.get_chunk_count_by_doc_id(doc_id) |
| doc.chunk_count = chunk_count |
| async with self.kb_db.get_db() as session: |
| async with session.begin(): |
| session.add(doc) |
| await session.commit() |
| await session.refresh(doc) |
|
|
| async def get_chunks_by_doc_id( |
| self, |
| doc_id: str, |
| offset: int = 0, |
| limit: int = 100, |
| ) -> list[dict]: |
| """获取文档的所有块及其元数据""" |
| vec_db: FaissVecDB = self.vec_db |
| chunks = await vec_db.document_storage.get_documents( |
| metadata_filters={"kb_doc_id": doc_id}, |
| offset=offset, |
| limit=limit, |
| ) |
| result = [] |
| for chunk in chunks: |
| chunk_md = json.loads(chunk["metadata"]) |
| result.append( |
| { |
| "chunk_id": chunk["doc_id"], |
| "doc_id": chunk_md["kb_doc_id"], |
| "kb_id": chunk_md["kb_id"], |
| "chunk_index": chunk_md["chunk_index"], |
| "content": chunk["text"], |
| "char_count": len(chunk["text"]), |
| }, |
| ) |
| return result |
|
|
| async def get_chunk_count_by_doc_id(self, doc_id: str) -> int: |
| """获取文档的块数量""" |
| vec_db: FaissVecDB = self.vec_db |
| count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id}) |
| return count |
|
|
| async def _save_media( |
| self, |
| doc_id: str, |
| media_type: str, |
| file_name: str, |
| content: bytes, |
| mime_type: str, |
| ) -> KBMedia: |
| """保存多媒体资源""" |
| media_id = str(uuid.uuid4()) |
| ext = Path(file_name).suffix |
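        # Media files are stored under <kb_medias_dir>/<doc_id>/; the KBMedia record is persisted by the caller.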
|
|
| |
| file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}" |
| file_path.parent.mkdir(parents=True, exist_ok=True) |
| async with aiofiles.open(file_path, "wb") as f: |
| await f.write(content) |
|
|
| media = KBMedia( |
| media_id=media_id, |
| doc_id=doc_id, |
| kb_id=self.kb.kb_id, |
| media_type=media_type, |
| file_name=file_name, |
| file_path=str(file_path), |
| file_size=len(content), |
| mime_type=mime_type, |
| ) |
|
|
| return media |
|
|
| async def upload_from_url( |
| self, |
| url: str, |
| chunk_size: int = 512, |
| chunk_overlap: int = 50, |
| batch_size: int = 32, |
| tasks_limit: int = 3, |
| max_retries: int = 3, |
| progress_callback=None, |
| enable_cleaning: bool = False, |
| cleaning_provider_id: str | None = None, |
| ) -> KBDocument: |
| """从 URL 上传并处理文档(带原子性保证和失败清理) |
| Args: |
| url: 要提取内容的网页 URL |
| chunk_size: 文本块大小 |
| chunk_overlap: 文本块重叠大小 |
| batch_size: 批处理大小 |
| tasks_limit: 并发任务限制 |
| max_retries: 最大重试次数 |
| progress_callback: 进度回调函数,接收参数 (stage, current, total) |
| - stage: 当前阶段 ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding') |
| - current: 当前进度 |
| - total: 总数 |
| Returns: |
| KBDocument: 上传的文档对象 |
| Raises: |
| ValueError: 如果 URL 为空或无法提取内容 |
| IOError: 如果网络请求失败 |
| """ |
| |
| config = self.prov_mgr.acm.default_conf |
| tavily_keys = config.get("provider_settings", {}).get( |
| "websearch_tavily_key", [] |
| ) |
| if not tavily_keys: |
| raise ValueError( |
| "Error: Tavily API key is not configured in provider_settings." |
| ) |
|
|
| |
| if progress_callback: |
| await progress_callback("extracting", 0, 100) |
|
|
| try: |
| text_content = await extract_text_from_url(url, tavily_keys) |
| except Exception as e: |
| logger.error(f"Failed to extract content from URL {url}: {e}") |
| raise OSError(f"Failed to extract content from URL {url}: {e}") from e |
|
|
| if not text_content: |
| raise ValueError(f"No content extracted from URL: {url}") |
|
|
| if progress_callback: |
| await progress_callback("extracting", 100, 100) |
|
|
| |
| final_chunks = await self._clean_and_rechunk_content( |
| content=text_content, |
| url=url, |
| progress_callback=progress_callback, |
| enable_cleaning=enable_cleaning, |
| cleaning_provider_id=cleaning_provider_id, |
| chunk_size=chunk_size, |
| chunk_overlap=chunk_overlap, |
| ) |
|
|
| if enable_cleaning and not final_chunks: |
| raise ValueError( |
| "内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。" |
| ) |
|
|
| |
| file_name = url.split("/")[-1] or f"document_from_{url}" |
| if not Path(file_name).suffix: |
| file_name += ".url" |
|
|
| |
| return await self.upload_document( |
| file_name=file_name, |
| file_content=None, |
| file_type="url", |
| chunk_size=chunk_size, |
| chunk_overlap=chunk_overlap, |
| batch_size=batch_size, |
| tasks_limit=tasks_limit, |
| max_retries=max_retries, |
| progress_callback=progress_callback, |
| pre_chunked_text=final_chunks, |
| ) |
|
|
| async def _clean_and_rechunk_content( |
| self, |
| content: str, |
| url: str, |
| progress_callback=None, |
| enable_cleaning: bool = False, |
| cleaning_provider_id: str | None = None, |
| repair_max_rpm: int = 60, |
| chunk_size: int = 512, |
| chunk_overlap: int = 50, |
| ) -> list[str]: |
| """ |
| 对从 URL 获取的内容进行清洗、修复、翻译和重新分块。 |
| """ |
| if not enable_cleaning: |
| |
| logger.info( |
| f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}" |
| ) |
| return await self.chunker.chunk( |
| content, chunk_size=chunk_size, chunk_overlap=chunk_overlap |
| ) |
|
|
| if not cleaning_provider_id: |
            logger.warning(
                "Content cleaning is enabled but no cleaning_provider_id was given; skipping cleaning and using default chunking."
            )
            return await self.chunker.chunk(
                content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
|
|
| if progress_callback: |
| await progress_callback("cleaning", 0, 100) |
|
|
| try: |
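            # Resolve the LLM provider that performs the cleaning.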
| |
| llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id) |
| if not llm_provider or not isinstance(llm_provider, LLMProvider): |
| raise ValueError( |
| f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确" |
| ) |
|
|
| |
| |
| text_splitter = RecursiveCharacterChunker( |
| chunk_size=chunk_size, |
| chunk_overlap=chunk_overlap, |
| separators=["\n\n", "\n", " "], |
| ) |
| initial_chunks = await text_splitter.chunk(content) |
| logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。") |
|
|
| |
| rate_limiter = RateLimiter(repair_max_rpm) |
| tasks = [ |
| _repair_and_translate_chunk_with_retry( |
| chunk, llm_provider, rate_limiter |
| ) |
| for chunk in initial_chunks |
| ] |
|
|
| repaired_results = await asyncio.gather(*tasks, return_exceptions=True) |
|
|
| final_chunks = [] |
| for i, result in enumerate(repaired_results): |
| if isinstance(result, Exception): |
| logger.warning(f"块 {i} 处理异常: {str(result)}. 回退到原始块。") |
| final_chunks.append(initial_chunks[i]) |
| elif isinstance(result, list): |
| final_chunks.extend(result) |
|
|
| logger.info( |
| f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。" |
| ) |
|
|
| if progress_callback: |
| await progress_callback("cleaning", 100, 100) |
|
|
| return final_chunks |
|
|
| except Exception as e: |
| logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}") |
| |
| return await self.chunker.chunk(content) |
|
|