| from __future__ import annotations |
| import weakref |
|
|
| import asyncio |
| import html |
| import csv |
| import json |
| import logging |
| import logging.handlers |
| import os |
| import re |
| import time |
| import uuid |
| from dataclasses import dataclass |
| from datetime import datetime |
| from functools import wraps |
| from hashlib import md5 |
| from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional |
| import numpy as np |
| from dotenv import load_dotenv |
|
|
| from lightrag.constants import ( |
| DEFAULT_LOG_MAX_BYTES, |
| DEFAULT_LOG_BACKUP_COUNT, |
| DEFAULT_LOG_FILENAME, |
| GRAPH_FIELD_SEP, |
| DEFAULT_MAX_TOTAL_TOKENS, |
| DEFAULT_MAX_FILE_PATH_LENGTH, |
| ) |
|
|
| |
| logger = logging.getLogger("lightrag") |
| logger.propagate = False |
| logger.setLevel(logging.INFO) |
|
|
| |
| if not logger.handlers: |
| console_handler = logging.StreamHandler() |
| console_handler.setLevel(logging.INFO) |
| formatter = logging.Formatter("%(levelname)s: %(message)s") |
| console_handler.setFormatter(formatter) |
| logger.addHandler(console_handler) |
|
|
| |
| logging.getLogger("httpx").setLevel(logging.WARNING) |
|
|
| |
| try: |
| import pypinyin |
|
|
| _PYPINYIN_AVAILABLE = True |
| |
| except ImportError: |
| pypinyin = None |
| _PYPINYIN_AVAILABLE = False |
| logger.warning( |
| "pypinyin is not installed. Chinese pinyin sorting will use simple string sorting." |
| ) |
|
|
|
|
| async def safe_vdb_operation_with_exception( |
| operation: Callable, |
| operation_name: str, |
| entity_name: str = "", |
| max_retries: int = 3, |
| retry_delay: float = 0.2, |
| logger_func: Optional[Callable] = None, |
| ) -> None: |
| """ |
| Safely execute vector database operations with retry mechanism and exception handling. |
| |
| This function ensures that VDB operations are executed with proper error handling |
| and retry logic. If all retries fail, it raises an exception to maintain data consistency. |
| |
| Args: |
| operation: The async operation to execute |
| operation_name: Operation name for logging purposes |
| entity_name: Entity name for logging purposes |
| max_retries: Maximum number of retry attempts |
| retry_delay: Delay between retries in seconds |
| logger_func: Logger function to use for error messages |
| |
| Raises: |
| Exception: When operation fails after all retry attempts |
| """ |
| log_func = logger_func or logger.warning |
|
|
| for attempt in range(max_retries): |
| try: |
| await operation() |
| return |
| except Exception as e: |
| if attempt >= max_retries - 1: |
| error_msg = f"VDB {operation_name} failed for {entity_name} after {max_retries} attempts: {e}" |
| log_func(error_msg) |
| raise Exception(error_msg) from e |
| else: |
| log_func( |
| f"VDB {operation_name} attempt {attempt + 1} failed for {entity_name}: {e}, retrying..." |
| ) |
| if retry_delay > 0: |
| await asyncio.sleep(retry_delay) |
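|
| # Example (illustrative sketch; `entities_vdb` and its payload are
| # hypothetical): retry an async upsert up to three times, then raise.
| #
| #     await safe_vdb_operation_with_exception(
| #         operation=lambda: entities_vdb.upsert({"ent-abc": {"content": "..."}}),
| #         operation_name="upsert",
| #         entity_name="ent-abc",
| #     )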
|
|
|
|
| def get_env_value( |
| env_key: str, default: Any, value_type: type = str, special_none: bool = False |
| ) -> Any: |
| """ |
| Get value from environment variable with type conversion |
| |
| Args: |
| env_key (str): Environment variable key |
| default (Any): Default value if env variable is not set |
| value_type (type): Type to convert the value to |
| special_none (bool): If True, return None when value is "None" |
| |
| Returns: |
| Any: Converted value from environment or default |
| """ |
| value = os.getenv(env_key) |
| if value is None: |
| return default |
|
|
| |
| if special_none and value == "None": |
| return None |
|
|
| if value_type is bool: |
| return value.lower() in ("true", "1", "yes", "t", "on") |
|
|
| |
| if value_type is list: |
| try: |
| # json is already imported at module level |
| parsed_value = json.loads(value) |
| |
| if isinstance(parsed_value, list): |
| return parsed_value |
| else: |
| logger.warning( |
| f"Environment variable {env_key} is not a valid JSON list, using default" |
| ) |
| return default |
| except (json.JSONDecodeError, ValueError) as e: |
| logger.warning( |
| f"Failed to parse {env_key} as JSON list: {e}, using default" |
| ) |
| return default |
|
|
| try: |
| return value_type(value) |
| except (ValueError, TypeError): |
| return default |
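|
| # Example (hedged sketch; the variable names are illustrative):
| #
| #     os.environ["MAX_ASYNC"] = "8"
| #     get_env_value("MAX_ASYNC", 4, int)                          # -> 8
| #     get_env_value("ENABLE_CACHE", True, bool)                   # unset -> True
| #     get_env_value("RERANK_MODEL", "bge-m3", special_none=True)  # "None" -> None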
|
|
|
|
| |
| if TYPE_CHECKING: |
| from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam |
|
|
| |
| |
| |
| load_dotenv(dotenv_path=".env", override=False) |
|
|
| VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true" |
|
|
|
|
| def verbose_debug(msg: str, *args, **kwargs): |
| """Function for outputting detailed debug information. |
| When VERBOSE_DEBUG=True, outputs the complete message. |
| When VERBOSE_DEBUG=False, outputs only the first 150 characters (with repeated newlines collapsed). |
| |
| Args: |
| msg: The message format string |
| *args: Arguments to be formatted into the message |
| **kwargs: Keyword arguments passed to logger.debug() |
| """ |
| if VERBOSE_DEBUG: |
| logger.debug(msg, *args, **kwargs) |
| else: |
| |
| if args: |
| formatted_msg = msg % args |
| else: |
| formatted_msg = msg |
| |
| truncated_msg = ( |
| formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg |
| ) |
| |
| truncated_msg = re.sub(r"\n+", "\n", truncated_msg) |
| logger.debug(truncated_msg, **kwargs) |
|
|
|
|
| def set_verbose_debug(enabled: bool): |
| """Enable or disable verbose debug output""" |
| global VERBOSE_DEBUG |
| VERBOSE_DEBUG = enabled |
|
|
|
|
| statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} |
|
|
|
|
| class LightragPathFilter(logging.Filter): |
| """Filter for lightrag logger to filter out frequent path access logs""" |
|
|
| def __init__(self): |
| super().__init__() |
| |
| self.filtered_paths = [ |
| "/documents", |
| "/documents/paginated", |
| "/health", |
| "/webui/", |
| "/documents/pipeline_status", |
| ] |
| |
|
|
| def filter(self, record): |
| try: |
| |
| if not hasattr(record, "args") or not isinstance(record.args, tuple): |
| return True |
| if len(record.args) < 5: |
| return True |
|
|
| |
| method = record.args[1] |
| path = record.args[2] |
| status = record.args[4] |
|
|
| |
| if ( |
| (method == "GET" or method == "POST") |
| and (status == 200 or status == 304) |
| and path in self.filtered_paths |
| ): |
| return False |
|
|
| return True |
| except Exception: |
| |
| return True |
|
|
|
|
| def setup_logger( |
| logger_name: str, |
| level: str = "INFO", |
| add_filter: bool = False, |
| log_file_path: str | None = None, |
| enable_file_logging: bool = True, |
| ): |
| """Set up a logger with console and optionally file handlers |
| |
| Args: |
| logger_name: Name of the logger to set up |
| level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| add_filter: Whether to add LightragPathFilter to the logger |
| log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd |
| enable_file_logging: Whether to enable logging to a file (defaults to True) |
| """ |
| |
| detailed_formatter = logging.Formatter( |
| "%(asctime)s - %(name)s - %(levelname)s - %(message)s" |
| ) |
| simple_formatter = logging.Formatter("%(levelname)s: %(message)s") |
|
|
| logger_instance = logging.getLogger(logger_name) |
| logger_instance.setLevel(level) |
| logger_instance.handlers = [] |
| logger_instance.propagate = False |
|
|
| |
| console_handler = logging.StreamHandler() |
| console_handler.setFormatter(simple_formatter) |
| console_handler.setLevel(level) |
| logger_instance.addHandler(console_handler) |
|
|
| |
| if enable_file_logging: |
| |
| if log_file_path is None: |
| log_dir = os.getenv("LOG_DIR", os.getcwd()) |
| log_file_path = os.path.abspath(os.path.join(log_dir, DEFAULT_LOG_FILENAME)) |
|
|
| |
| os.makedirs(os.path.dirname(log_file_path), exist_ok=True) |
|
|
| |
| log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int) |
| log_backup_count = get_env_value( |
| "LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int |
| ) |
|
|
| try: |
| |
| file_handler = logging.handlers.RotatingFileHandler( |
| filename=log_file_path, |
| maxBytes=log_max_bytes, |
| backupCount=log_backup_count, |
| encoding="utf-8", |
| ) |
| file_handler.setFormatter(detailed_formatter) |
| file_handler.setLevel(level) |
| logger_instance.addHandler(file_handler) |
| except PermissionError as e: |
| logger.warning(f"Could not create log file at {log_file_path}: {str(e)}") |
| logger.warning("Continuing with console logging only") |
|
|
| |
| if add_filter: |
| path_filter = LightragPathFilter() |
| logger_instance.addFilter(path_filter) |
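|
| # Example (sketch): configure a named logger with console plus rotating
| # file output; `add_filter=True` attaches LightragPathFilter to suppress
| # noisy access-log records for the paths listed above.
| #
| #     setup_logger("lightrag", level="DEBUG", add_filter=True,
| #                  log_file_path="/tmp/lightrag.log")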
|
|
|
|
| class UnlimitedSemaphore: |
| """A context manager that allows unlimited access.""" |
|
|
| async def __aenter__(self): |
| pass |
|
|
| async def __aexit__(self, exc_type, exc, tb): |
| pass |
|
|
|
|
| @dataclass |
| class TaskState: |
| """Task state tracking for priority queue management""" |
|
|
| future: asyncio.Future |
| start_time: float |
| execution_start_time: float | None = None |
| worker_started: bool = False |
| cancellation_requested: bool = False |
| cleanup_done: bool = False |
|
|
|
|
| @dataclass |
| class EmbeddingFunc: |
| embedding_dim: int |
| func: Callable |
| max_token_size: int | None = None |
|
|
| async def __call__(self, *args, **kwargs) -> np.ndarray: |
| return await self.func(*args, **kwargs) |
|
|
|
|
| def compute_args_hash(*args: Any) -> str: |
| """Compute a hash for the given arguments with safe Unicode handling. |
| |
| Args: |
| *args: Arguments to hash |
| Returns: |
| str: Hash string |
| """ |
| |
| args_str = "".join([str(arg) for arg in args]) |
|
|
| |
| |
| try: |
| return md5(args_str.encode("utf-8")).hexdigest() |
| except UnicodeEncodeError: |
| |
| safe_bytes = args_str.encode("utf-8", errors="replace") |
| return md5(safe_bytes).hexdigest() |
|
|
|
|
| def compute_mdhash_id(content: str, prefix: str = "") -> str: |
| """ |
| Compute a unique ID for a given content string. |
| |
| The ID is a combination of the given prefix and the MD5 hash of the content string. |
| """ |
| return prefix + compute_args_hash(content) |
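|
| # Example: IDs are deterministic, so identical content always yields the
| # same key; prefixes like "ent-" / "rel-" (used by the export helpers
| # below) namespace the hash.
| #
| #     compute_mdhash_id("Apple Inc.", prefix="ent-")
| #     # -> "ent-" + 32-character md5 hex digest of "Apple Inc."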
|
|
|
|
| def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str: |
| """Generate a flattened cache key in the format {mode}:{cache_type}:{hash} |
| |
| Args: |
| mode: Cache mode (e.g., 'default', 'local', 'global') |
| cache_type: Type of cache (e.g., 'extract', 'query', 'keywords') |
| hash_value: Hash value from compute_args_hash |
| |
| Returns: |
| str: Flattened cache key |
| """ |
| return f"{mode}:{cache_type}:{hash_value}" |
|
|
|
|
| def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: |
| """Parse a flattened cache key back into its components |
| |
| Args: |
| cache_key: Flattened cache key in format {mode}:{cache_type}:{hash} |
| |
| Returns: |
| tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format |
| """ |
| parts = cache_key.split(":", 2) |
| if len(parts) == 3: |
| return parts[0], parts[1], parts[2] |
| return None |
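|
| # Example round trip between the two helpers (the hash value is an
| # arbitrary placeholder):
| #
| #     key = generate_cache_key("default", "extract", "9a0364b9e99bb480dd25e1f0284c8555")
| #     parse_cache_key(key)          # -> ("default", "extract", "9a0364...")
| #     parse_cache_key("no-colons")  # -> None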
|
|
|
|
| |
| class QueueFullError(Exception): |
| """Raised when the queue is full and the wait times out""" |
|
|
| pass |
|
|
|
|
| class WorkerTimeoutError(Exception): |
| """Worker-level timeout exception with specific timeout information""" |
|
|
| def __init__(self, timeout_value: float, timeout_type: str = "execution"): |
| self.timeout_value = timeout_value |
| self.timeout_type = timeout_type |
| super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s") |
|
|
|
|
| class HealthCheckTimeoutError(Exception): |
| """Health Check-level timeout exception""" |
|
|
| def __init__(self, timeout_value: float, execution_duration: float): |
| self.timeout_value = timeout_value |
| self.execution_duration = execution_duration |
| super().__init__( |
| f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)" |
| ) |
|
|
|
|
| def priority_limit_async_func_call( |
| max_size: int, |
| llm_timeout: float | None = None, |
| max_execution_timeout: float | None = None, |
| max_task_duration: float | None = None, |
| max_queue_size: int = 1000, |
| cleanup_timeout: float = 2.0, |
| queue_name: str = "limit_async", |
| ): |
| """ |
| Enhanced priority-limited asynchronous function call decorator with robust timeout handling |
| |
| This decorator provides a comprehensive solution for managing concurrent LLM requests with: |
| - Multi-layer timeout protection (LLM -> Worker -> Health Check -> User) |
| - Task state tracking to prevent race conditions |
| - Enhanced health check system with stuck task detection |
| - Proper resource cleanup and error recovery |
| |
| Args: |
| max_size: Maximum number of concurrent calls |
| max_queue_size: Maximum queue capacity to prevent memory overflow |
| llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts |
| max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout * 2) |
| max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout * 2 + 15s) |
| cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s) |
| queue_name: Optional queue name for logging identification (defaults to "limit_async") |
| |
| Returns: |
| Decorator function |
| """ |
|
|
| def final_decro(func): |
| |
| if not callable(func): |
| raise TypeError(f"Expected a callable object, got {type(func)}") |
|
|
| |
| if llm_timeout is not None: |
| nonlocal max_execution_timeout, max_task_duration |
| if max_execution_timeout is None: |
| max_execution_timeout = llm_timeout * 2  # worker timeout: 2x the LLM timeout |
| if max_task_duration is None: |
| max_task_duration = llm_timeout * 2 + 15  # health check intervenes 15s after worker timeout |
|
|
| queue = asyncio.PriorityQueue(maxsize=max_queue_size) |
| tasks = set() |
| initialization_lock = asyncio.Lock() |
| counter = 0 |
| shutdown_event = asyncio.Event() |
| initialized = False |
| worker_health_check_task = None |
|
|
| |
| task_states = {} |
| task_states_lock = asyncio.Lock() |
| active_futures = weakref.WeakSet() |
| reinit_count = 0 |
|
|
| async def worker(): |
| """Enhanced worker that processes tasks with proper timeout and state management""" |
| try: |
| while not shutdown_event.is_set(): |
| try: |
| |
| try: |
| ( |
| priority, |
| count, |
| task_id, |
| args, |
| kwargs, |
| ) = await asyncio.wait_for(queue.get(), timeout=1.0) |
| except asyncio.TimeoutError: |
| continue |
|
|
| |
| async with task_states_lock: |
| if task_id not in task_states: |
| queue.task_done() |
| continue |
| task_state = task_states[task_id] |
| task_state.worker_started = True |
| |
| task_state.execution_start_time = ( |
| asyncio.get_event_loop().time() |
| ) |
|
|
| |
| if ( |
| task_state.cancellation_requested |
| or task_state.future.cancelled() |
| ): |
| async with task_states_lock: |
| task_states.pop(task_id, None) |
| queue.task_done() |
| continue |
|
|
| try: |
| |
| if max_execution_timeout is not None: |
| result = await asyncio.wait_for( |
| func(*args, **kwargs), timeout=max_execution_timeout |
| ) |
| else: |
| result = await func(*args, **kwargs) |
|
|
| |
| if not task_state.future.done(): |
| task_state.future.set_result(result) |
|
|
| except asyncio.TimeoutError: |
| |
| logger.warning( |
| f"{queue_name}: Worker timeout for task {task_id} after {max_execution_timeout}s" |
| ) |
| if not task_state.future.done(): |
| task_state.future.set_exception( |
| WorkerTimeoutError( |
| max_execution_timeout, "execution" |
| ) |
| ) |
| except asyncio.CancelledError: |
| |
| if not task_state.future.done(): |
| task_state.future.cancel() |
| logger.debug( |
| f"{queue_name}: Task {task_id} cancelled during execution" |
| ) |
| except Exception as e: |
| |
| logger.error( |
| f"{queue_name}: Error in decorated function for task {task_id}: {str(e)}" |
| ) |
| if not task_state.future.done(): |
| task_state.future.set_exception(e) |
| finally: |
| |
| async with task_states_lock: |
| task_states.pop(task_id, None) |
| queue.task_done() |
|
|
| except Exception as e: |
| |
| logger.error( |
| f"{queue_name}: Critical error in worker: {str(e)}" |
| ) |
| await asyncio.sleep(0.1) |
| finally: |
| logger.debug(f"{queue_name}: Worker exiting") |
|
|
| async def enhanced_health_check(): |
| """Enhanced health check with stuck task detection and recovery""" |
| nonlocal initialized |
| try: |
| while not shutdown_event.is_set(): |
| await asyncio.sleep(5) |
|
|
| current_time = asyncio.get_event_loop().time() |
|
|
| |
| if max_task_duration is not None: |
| stuck_tasks = [] |
| async with task_states_lock: |
| for task_id, task_state in list(task_states.items()): |
| |
| if ( |
| task_state.worker_started |
| and task_state.execution_start_time is not None |
| and current_time - task_state.execution_start_time |
| > max_task_duration |
| ): |
| stuck_tasks.append( |
| ( |
| task_id, |
| current_time |
| - task_state.execution_start_time, |
| ) |
| ) |
|
|
| |
| for task_id, execution_duration in stuck_tasks: |
| logger.warning( |
| f"{queue_name}: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup" |
| ) |
| async with task_states_lock: |
| if task_id in task_states: |
| task_state = task_states[task_id] |
| if not task_state.future.done(): |
| task_state.future.set_exception( |
| HealthCheckTimeoutError( |
| max_task_duration, execution_duration |
| ) |
| ) |
| task_states.pop(task_id, None) |
|
|
| |
| current_tasks = set(tasks) |
| done_tasks = {t for t in current_tasks if t.done()} |
| tasks.difference_update(done_tasks) |
|
|
| active_tasks_count = len(tasks) |
| workers_needed = max_size - active_tasks_count |
|
|
| if workers_needed > 0: |
| logger.info( |
| f"{queue_name}: Creating {workers_needed} new workers" |
| ) |
| new_tasks = set() |
| for _ in range(workers_needed): |
| task = asyncio.create_task(worker()) |
| new_tasks.add(task) |
| task.add_done_callback(tasks.discard) |
| tasks.update(new_tasks) |
|
|
| except Exception as e: |
| logger.error(f"{queue_name}: Error in enhanced health check: {str(e)}") |
| finally: |
| logger.debug(f"{queue_name}: Enhanced health check task exiting") |
| initialized = False |
|
|
| async def ensure_workers(): |
| """Ensure worker system is initialized with enhanced error handling""" |
| nonlocal initialized, worker_health_check_task, tasks, reinit_count |
|
|
| if initialized: |
| return |
|
|
| async with initialization_lock: |
| if initialized: |
| return |
|
|
| if reinit_count > 0: |
| reinit_count += 1 |
| logger.warning( |
| f"{queue_name}: Reinitializing system (count: {reinit_count})" |
| ) |
| else: |
| reinit_count = 1 |
|
|
| |
| current_tasks = set(tasks) |
| done_tasks = {t for t in current_tasks if t.done()} |
| tasks.difference_update(done_tasks) |
|
|
| active_tasks_count = len(tasks) |
| if active_tasks_count > 0 and reinit_count > 1: |
| logger.warning( |
| f"{queue_name}: {active_tasks_count} tasks still running during reinitialization" |
| ) |
|
|
| |
| workers_needed = max_size - active_tasks_count |
| for _ in range(workers_needed): |
| task = asyncio.create_task(worker()) |
| tasks.add(task) |
| task.add_done_callback(tasks.discard) |
|
|
| |
| worker_health_check_task = asyncio.create_task(enhanced_health_check()) |
|
|
| initialized = True |
| |
| timeout_info = [] |
| if llm_timeout is not None: |
| timeout_info.append(f"Func: {llm_timeout}s") |
| if max_execution_timeout is not None: |
| timeout_info.append(f"Worker: {max_execution_timeout}s") |
| if max_task_duration is not None: |
| timeout_info.append(f"Health Check: {max_task_duration}s") |
|
|
| timeout_str = ( |
| f"(Timeouts: {', '.join(timeout_info)})" if timeout_info else "" |
| ) |
| logger.info( |
| f"{queue_name}: {workers_needed} new workers initialized {timeout_str}" |
| ) |
|
|
| async def shutdown(): |
| """Gracefully shut down all workers and cleanup resources""" |
| logger.info(f"{queue_name}: Shutting down priority queue workers") |
|
|
| shutdown_event.set() |
|
|
| |
| for future in list(active_futures): |
| if not future.done(): |
| future.cancel() |
|
|
| |
| async with task_states_lock: |
| for task_id, task_state in list(task_states.items()): |
| if not task_state.future.done(): |
| task_state.future.cancel() |
| task_states.clear() |
|
|
| |
| try: |
| await asyncio.wait_for(queue.join(), timeout=5.0) |
| except asyncio.TimeoutError: |
| logger.warning( |
| f"{queue_name}: Timeout waiting for queue to empty during shutdown" |
| ) |
|
|
| |
| for task in list(tasks): |
| if not task.done(): |
| task.cancel() |
|
|
| |
| if tasks: |
| await asyncio.gather(*tasks, return_exceptions=True) |
|
|
| |
| if worker_health_check_task and not worker_health_check_task.done(): |
| worker_health_check_task.cancel() |
| try: |
| await worker_health_check_task |
| except asyncio.CancelledError: |
| pass |
|
|
| logger.info(f"{queue_name}: Priority queue workers shutdown complete") |
|
|
| @wraps(func) |
| async def wait_func( |
| *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs |
| ): |
| """ |
| Execute function with enhanced priority-based concurrency control and timeout handling |
| |
| Args: |
| *args: Positional arguments passed to the function |
| _priority: Call priority (lower values have higher priority) |
| _timeout: Maximum time to wait for completion (in seconds; None means it is determined by the queue's max_execution_timeout) |
| _queue_timeout: Maximum time to wait for entering the queue (in seconds) |
| **kwargs: Keyword arguments passed to the function |
| |
| Returns: |
| The result of the function call |
| |
| Raises: |
| TimeoutError: If the function call times out at any level |
| QueueFullError: If the queue is full and waiting times out |
| Any exception raised by the decorated function |
| """ |
| await ensure_workers() |
|
|
| |
| task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}" |
| future = asyncio.Future() |
|
|
| |
| task_state = TaskState( |
| future=future, start_time=asyncio.get_event_loop().time() |
| ) |
|
|
| try: |
| |
| async with task_states_lock: |
| task_states[task_id] = task_state |
|
|
| active_futures.add(future) |
|
|
| |
| nonlocal counter |
| async with initialization_lock: |
| current_count = counter |
| counter += 1 |
|
|
| |
| try: |
| if _queue_timeout is not None: |
| await asyncio.wait_for( |
| queue.put( |
| (_priority, current_count, task_id, args, kwargs) |
| ), |
| timeout=_queue_timeout, |
| ) |
| else: |
| await queue.put( |
| (_priority, current_count, task_id, args, kwargs) |
| ) |
| except asyncio.TimeoutError: |
| raise QueueFullError( |
| f"{queue_name}: Queue full, timeout after {_queue_timeout} seconds" |
| ) |
| except Exception as e: |
| |
| if not future.done(): |
| future.set_exception(e) |
| raise |
|
|
| |
| try: |
| if _timeout is not None: |
| return await asyncio.wait_for(future, _timeout) |
| else: |
| return await future |
| except asyncio.TimeoutError: |
| |
| |
| async with task_states_lock: |
| if task_id in task_states: |
| task_states[task_id].cancellation_requested = True |
|
|
| |
| if not future.done(): |
| future.cancel() |
|
|
| |
| cleanup_start = asyncio.get_event_loop().time() |
| while ( |
| task_id in task_states |
| and asyncio.get_event_loop().time() - cleanup_start |
| < cleanup_timeout |
| ): |
| await asyncio.sleep(0.1) |
|
|
| raise TimeoutError( |
| f"{queue_name}: User timeout after {_timeout} seconds" |
| ) |
| except WorkerTimeoutError as e: |
| |
| raise TimeoutError(f"{queue_name}: {str(e)}") |
| except HealthCheckTimeoutError as e: |
| |
| raise TimeoutError(f"{queue_name}: {str(e)}") |
|
|
| finally: |
| |
| active_futures.discard(future) |
| async with task_states_lock: |
| task_states.pop(task_id, None) |
|
|
| |
| wait_func.shutdown = shutdown |
|
|
| return wait_func |
|
|
| return final_decro |
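|
| # Example (hedged sketch; `my_llm_complete` is a hypothetical async
| # function): cap concurrency at 4 workers and submit a high-priority call.
| #
| #     limited = priority_limit_async_func_call(4, llm_timeout=180)(my_llm_complete)
| #     result = await limited("Hello", _priority=0)  # lower value = higher priority
| #     await limited.shutdown()                      # graceful worker teardown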
|
|
|
|
| def wrap_embedding_func_with_attrs(**kwargs): |
| """Wrap a function with attributes""" |
|
|
| def final_decro(func) -> EmbeddingFunc: |
| new_func = EmbeddingFunc(**kwargs, func=func) |
| return new_func |
|
|
| return final_decro |
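|
| # Example (sketch; the dimension and token limit are illustrative):
| #
| #     @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
| #     async def my_embed(texts: list[str]) -> np.ndarray:
| #         ...  # call an embedding API, return an (n, 1536) array
| #
| # `my_embed` is now an EmbeddingFunc: metadata such as `my_embed.embedding_dim`
| # rides along, and `await my_embed(["hi"])` still invokes the wrapped function.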
|
|
|
|
| def load_json(file_name): |
| if not os.path.exists(file_name): |
| return None |
| with open(file_name, encoding="utf-8-sig") as f: |
| return json.load(f) |
|
|
|
|
| def write_json(json_obj, file_name): |
| with open(file_name, "w", encoding="utf-8") as f: |
| json.dump(json_obj, f, indent=2, ensure_ascii=False) |
|
|
|
|
| class TokenizerInterface(Protocol): |
| """ |
| Defines the interface for a tokenizer, requiring encode and decode methods. |
| """ |
|
|
| def encode(self, content: str) -> List[int]: |
| """Encodes a string into a list of tokens.""" |
| ... |
|
|
| def decode(self, tokens: List[int]) -> str: |
| """Decodes a list of tokens into a string.""" |
| ... |
|
|
|
|
| class Tokenizer: |
| """ |
| A wrapper around a tokenizer to provide a consistent interface for encoding and decoding. |
| """ |
|
|
| def __init__(self, model_name: str, tokenizer: TokenizerInterface): |
| """ |
| Initializes the Tokenizer with a tokenizer model name and a tokenizer instance. |
| |
| Args: |
| model_name: The associated model name for the tokenizer. |
| tokenizer: An instance of a class implementing the TokenizerInterface. |
| """ |
| self.model_name: str = model_name |
| self.tokenizer: TokenizerInterface = tokenizer |
|
|
| def encode(self, content: str) -> List[int]: |
| """ |
| Encodes a string into a list of tokens using the underlying tokenizer. |
| |
| Args: |
| content: The string to encode. |
| |
| Returns: |
| A list of integer tokens. |
| """ |
| return self.tokenizer.encode(content) |
|
|
| def decode(self, tokens: List[int]) -> str: |
| """ |
| Decodes a list of tokens into a string using the underlying tokenizer. |
| |
| Args: |
| tokens: A list of integer tokens to decode. |
| |
| Returns: |
| The decoded string. |
| """ |
| return self.tokenizer.decode(tokens) |
|
|
|
|
| class TiktokenTokenizer(Tokenizer): |
| """ |
| A Tokenizer implementation using the tiktoken library. |
| """ |
|
|
| def __init__(self, model_name: str = "gpt-4o-mini"): |
| """ |
| Initializes the TiktokenTokenizer with a specified model name. |
| |
| Args: |
| model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini". |
| |
| Raises: |
| ImportError: If tiktoken is not installed. |
| ValueError: If the model_name is invalid. |
| """ |
| try: |
| import tiktoken |
| except ImportError: |
| raise ImportError( |
| "tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`." |
| ) |
|
|
| try: |
| tokenizer = tiktoken.encoding_for_model(model_name) |
| super().__init__(model_name=model_name, tokenizer=tokenizer) |
| except KeyError: |
| raise ValueError(f"Invalid model_name: {model_name}.") |
|
|
|
|
| def pack_user_ass_to_openai_messages(*args: str): |
| """Pack alternating user/assistant contents into an OpenAI-style message list.""" |
| roles = ["user", "assistant"] |
| return [ |
| {"role": roles[i % 2], "content": content} for i, content in enumerate(args) |
| ] |
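|
| # Example: positional arguments alternate between user and assistant roles.
| #
| #     pack_user_ass_to_openai_messages("Hi", "Hello!", "How are you?")
| #     # -> [{"role": "user", "content": "Hi"},
| #     #     {"role": "assistant", "content": "Hello!"},
| #     #     {"role": "user", "content": "How are you?"}]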
|
|
|
|
| def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: |
| """Split a string by multiple markers""" |
| if not markers: |
| return [content] |
| content = content if content is not None else "" |
| results = re.split("|".join(re.escape(marker) for marker in markers), content) |
| return [r.strip() for r in results if r.strip()] |
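|
| # Example: markers are regex-escaped, so metacharacters are safe; empty
| # segments are dropped and the rest stripped.
| #
| #     split_string_by_multi_markers("a<SEP>b;c", ["<SEP>", ";"])  # -> ["a", "b", "c"]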
|
|
|
|
| def is_float_regex(value: str) -> bool: |
| return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) |
|
|
|
|
| def truncate_list_by_token_size( |
| list_data: list[Any], |
| key: Callable[[Any], str], |
| max_token_size: int, |
| tokenizer: Tokenizer, |
| ) -> list[Any]: |
| """Truncate a list of data by token size""" |
| if max_token_size <= 0: |
| return [] |
| tokens = 0 |
| for i, data in enumerate(list_data): |
| tokens += len(tokenizer.encode(key(data))) |
| if tokens > max_token_size: |
| return list_data[:i] |
| return list_data |
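|
| # Example (sketch; `tk` is a Tokenizer instance as defined above and
| # `chunks` a hypothetical list of dicts): keep leading items until their
| # combined token count exceeds the budget.
| #
| #     kept = truncate_list_by_token_size(
| #         chunks, key=lambda c: c["content"], max_token_size=4000, tokenizer=tk
| #     )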
|
|
|
|
| def cosine_similarity(v1, v2): |
| """Calculate cosine similarity between two vectors""" |
| dot_product = np.dot(v1, v2) |
| norm1 = np.linalg.norm(v1) |
| norm2 = np.linalg.norm(v2) |
| return dot_product / (norm1 * norm2) |
|
|
|
|
| async def handle_cache( |
| hashing_kv, |
| args_hash, |
| prompt, |
| mode="default", |
| cache_type="unknown", |
| ) -> tuple[str, int] | None: |
| """Generic cache handling function with flattened cache keys |
| |
| Returns: |
| tuple[str, int] | None: (content, create_time) if cache hit, None if cache miss |
| """ |
| if hashing_kv is None: |
| return None |
|
|
| if mode != "default": |
| if not hashing_kv.global_config.get("enable_llm_cache"): |
| return None |
| else: |
| if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"): |
| return None |
|
|
| |
| flattened_key = generate_cache_key(mode, cache_type, args_hash) |
| cache_entry = await hashing_kv.get_by_id(flattened_key) |
| if cache_entry: |
| logger.debug(f"Flattened cache hit(key:{flattened_key})") |
| content = cache_entry["return"] |
| timestamp = cache_entry.get("create_time", 0) |
| return content, timestamp |
|
|
| logger.debug(f"Cache missed(mode:{mode} type:{cache_type})") |
| return None |
|
|
|
|
| @dataclass |
| class CacheData: |
| args_hash: str |
| content: str |
| prompt: str |
| mode: str = "default" |
| cache_type: str = "query" |
| chunk_id: str | None = None |
| queryparam: dict | None = None |
|
|
|
|
| async def save_to_cache(hashing_kv, cache_data: CacheData): |
| """Save data to cache using flattened key structure. |
| |
| Args: |
| hashing_kv: The key-value storage for caching |
| cache_data: The cache data to save |
| """ |
| |
| if hashing_kv is None or not cache_data.content: |
| return |
|
|
| |
| if hasattr(cache_data.content, "__aiter__"): |
| logger.debug("Streaming response detected, skipping cache") |
| return |
|
|
| |
| flattened_key = generate_cache_key( |
| cache_data.mode, cache_data.cache_type, cache_data.args_hash |
| ) |
|
|
| |
| existing_cache = await hashing_kv.get_by_id(flattened_key) |
| if existing_cache: |
| existing_content = existing_cache.get("return") |
| if existing_content == cache_data.content: |
| logger.warning( |
| f"Cache duplication detected for {flattened_key}, skipping update" |
| ) |
| return |
|
|
| |
| cache_entry = { |
| "return": cache_data.content, |
| "cache_type": cache_data.cache_type, |
| "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, |
| "original_prompt": cache_data.prompt, |
| "queryparam": cache_data.queryparam |
| if cache_data.queryparam is not None |
| else None, |
| } |
|
|
| logger.info(f" == LLM cache == saving: {flattened_key}") |
|
|
| |
| await hashing_kv.upsert({flattened_key: cache_entry}) |
|
|
|
|
| def safe_unicode_decode(content): |
| """Decode UTF-8 bytes and replace literal \\uXXXX escape sequences with their characters.""" |
| unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})") |
|
|
| |
| def replace_unicode_escape(match): |
| |
| return chr(int(match.group(1), 16)) |
|
|
| |
| decoded_content = unicode_escape_pattern.sub( |
| replace_unicode_escape, content.decode("utf-8") |
| ) |
|
|
| return decoded_content |
|
|
|
|
| def exists_func(obj, func_name: str) -> bool: |
| """Check whether an object exposes a callable attribute with the given name. |
| :param obj: Object to inspect |
| :param func_name: Attribute name to look up |
| :return: True / False |
| """ |
| return callable(getattr(obj, func_name, None)) |
|
|
|
|
| def always_get_an_event_loop() -> asyncio.AbstractEventLoop: |
| """ |
| Ensure that there is always an event loop available. |
| |
| This function tries to get the current event loop. If the current event loop is closed or does not exist, |
| it creates a new event loop and sets it as the current event loop. |
| |
| Returns: |
| asyncio.AbstractEventLoop: The current or newly created event loop. |
| """ |
| try: |
| |
| current_loop = asyncio.get_event_loop() |
| if current_loop.is_closed(): |
| raise RuntimeError("Event loop is closed.") |
| return current_loop |
|
|
| except RuntimeError: |
| |
| logger.info("Creating a new event loop in main thread.") |
| new_loop = asyncio.new_event_loop() |
| asyncio.set_event_loop(new_loop) |
| return new_loop |
|
|
|
|
| async def aexport_data( |
| chunk_entity_relation_graph, |
| entities_vdb, |
| relationships_vdb, |
| output_path: str, |
| file_format: str = "csv", |
| include_vector_data: bool = False, |
| ) -> None: |
| """ |
| Asynchronously exports all entities, relations, and relationships to various formats. |
| |
| Args: |
| chunk_entity_relation_graph: Graph storage instance for entities and relations |
| entities_vdb: Vector database storage for entities |
| relationships_vdb: Vector database storage for relationships |
| output_path: The path to the output file (including extension). |
| file_format: Output format - "csv", "excel", "md", "txt". |
| - csv: Comma-separated values file |
| - excel: Microsoft Excel file with multiple sheets |
| - md: Markdown tables |
| - txt: Plain text formatted output |
| include_vector_data: Whether to include data from the vector database. |
| """ |
| |
| entities_data = [] |
| relations_data = [] |
| relationships_data = [] |
|
|
| |
| all_entities = await chunk_entity_relation_graph.get_all_labels() |
| for entity_name in all_entities: |
| |
| node_data = await chunk_entity_relation_graph.get_node(entity_name) |
| source_id = node_data.get("source_id") if node_data else None |
|
|
| entity_info = { |
| "graph_data": node_data, |
| "source_id": source_id, |
| } |
|
|
| |
| if include_vector_data: |
| entity_id = compute_mdhash_id(entity_name, prefix="ent-") |
| vector_data = await entities_vdb.get_by_id(entity_id) |
| entity_info["vector_data"] = vector_data |
|
|
| entity_row = { |
| "entity_name": entity_name, |
| "source_id": source_id, |
| "graph_data": str( |
| entity_info["graph_data"] |
| ), |
| } |
| if include_vector_data and "vector_data" in entity_info: |
| entity_row["vector_data"] = str(entity_info["vector_data"]) |
| entities_data.append(entity_row) |
|
|
| |
| for src_entity in all_entities: |
| for tgt_entity in all_entities: |
| if src_entity == tgt_entity: |
| continue |
|
|
| edge_exists = await chunk_entity_relation_graph.has_edge( |
| src_entity, tgt_entity |
| ) |
| if edge_exists: |
| |
| edge_data = await chunk_entity_relation_graph.get_edge( |
| src_entity, tgt_entity |
| ) |
| source_id = edge_data.get("source_id") if edge_data else None |
|
|
| relation_info = { |
| "graph_data": edge_data, |
| "source_id": source_id, |
| } |
|
|
| |
| if include_vector_data: |
| rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-") |
| vector_data = await relationships_vdb.get_by_id(rel_id) |
| relation_info["vector_data"] = vector_data |
|
|
| relation_row = { |
| "src_entity": src_entity, |
| "tgt_entity": tgt_entity, |
| "source_id": relation_info["source_id"], |
| "graph_data": str(relation_info["graph_data"]), |
| } |
| if include_vector_data and "vector_data" in relation_info: |
| relation_row["vector_data"] = str(relation_info["vector_data"]) |
| relations_data.append(relation_row) |
|
|
| |
| all_relationships = await relationships_vdb.client_storage |
| for rel in all_relationships["data"]: |
| relationships_data.append( |
| { |
| "relationship_id": rel["__id__"], |
| "data": str(rel), |
| } |
| ) |
|
|
| |
| if file_format == "csv": |
| |
| with open(output_path, "w", newline="", encoding="utf-8") as csvfile: |
| |
| if entities_data: |
| csvfile.write("# ENTITIES\n") |
| writer = csv.DictWriter(csvfile, fieldnames=entities_data[0].keys()) |
| writer.writeheader() |
| writer.writerows(entities_data) |
| csvfile.write("\n\n") |
|
|
| |
| if relations_data: |
| csvfile.write("# RELATIONS\n") |
| writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys()) |
| writer.writeheader() |
| writer.writerows(relations_data) |
| csvfile.write("\n\n") |
|
|
| |
| if relationships_data: |
| csvfile.write("# RELATIONSHIPS\n") |
| writer = csv.DictWriter( |
| csvfile, fieldnames=relationships_data[0].keys() |
| ) |
| writer.writeheader() |
| writer.writerows(relationships_data) |
|
|
| elif file_format == "excel": |
| |
| import pandas as pd |
|
|
| entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame() |
| relations_df = ( |
| pd.DataFrame(relations_data) if relations_data else pd.DataFrame() |
| ) |
| relationships_df = ( |
| pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame() |
| ) |
|
|
| with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer: |
| if not entities_df.empty: |
| entities_df.to_excel(writer, sheet_name="Entities", index=False) |
| if not relations_df.empty: |
| relations_df.to_excel(writer, sheet_name="Relations", index=False) |
| if not relationships_df.empty: |
| relationships_df.to_excel( |
| writer, sheet_name="Relationships", index=False |
| ) |
|
|
| elif file_format == "md": |
| |
| with open(output_path, "w", encoding="utf-8") as mdfile: |
| mdfile.write("# LightRAG Data Export\n\n") |
|
|
| |
| mdfile.write("## Entities\n\n") |
| if entities_data: |
| |
| mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n") |
| mdfile.write( |
| "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n" |
| ) |
|
|
| |
| for entity in entities_data: |
| mdfile.write( |
| "| " + " | ".join(str(v) for v in entity.values()) + " |\n" |
| ) |
| mdfile.write("\n\n") |
| else: |
| mdfile.write("*No entity data available*\n\n") |
|
|
| |
| mdfile.write("## Relations\n\n") |
| if relations_data: |
| |
| mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n") |
| mdfile.write( |
| "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n" |
| ) |
|
|
| |
| for relation in relations_data: |
| mdfile.write( |
| "| " + " | ".join(str(v) for v in relation.values()) + " |\n" |
| ) |
| mdfile.write("\n\n") |
| else: |
| mdfile.write("*No relation data available*\n\n") |
|
|
| |
| mdfile.write("## Relationships\n\n") |
| if relationships_data: |
| |
| mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n") |
| mdfile.write( |
| "| " |
| + " | ".join(["---"] * len(relationships_data[0].keys())) |
| + " |\n" |
| ) |
|
|
| |
| for relationship in relationships_data: |
| mdfile.write( |
| "| " |
| + " | ".join(str(v) for v in relationship.values()) |
| + " |\n" |
| ) |
| else: |
| mdfile.write("*No relationship data available*\n\n") |
|
|
| elif file_format == "txt": |
| |
| with open(output_path, "w", encoding="utf-8") as txtfile: |
| txtfile.write("LIGHTRAG DATA EXPORT\n") |
| txtfile.write("=" * 80 + "\n\n") |
|
|
| |
| txtfile.write("ENTITIES\n") |
| txtfile.write("-" * 80 + "\n") |
| if entities_data: |
| |
| col_widths = { |
| k: max(len(k), max(len(str(e[k])) for e in entities_data)) |
| for k in entities_data[0] |
| } |
| header = " ".join(k.ljust(col_widths[k]) for k in entities_data[0]) |
| txtfile.write(header + "\n") |
| txtfile.write("-" * len(header) + "\n") |
|
|
| |
| for entity in entities_data: |
| row = " ".join( |
| str(v).ljust(col_widths[k]) for k, v in entity.items() |
| ) |
| txtfile.write(row + "\n") |
| txtfile.write("\n\n") |
| else: |
| txtfile.write("No entity data available\n\n") |
|
|
| |
| txtfile.write("RELATIONS\n") |
| txtfile.write("-" * 80 + "\n") |
| if relations_data: |
| |
| col_widths = { |
| k: max(len(k), max(len(str(r[k])) for r in relations_data)) |
| for k in relations_data[0] |
| } |
| header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0]) |
| txtfile.write(header + "\n") |
| txtfile.write("-" * len(header) + "\n") |
|
|
| |
| for relation in relations_data: |
| row = " ".join( |
| str(v).ljust(col_widths[k]) for k, v in relation.items() |
| ) |
| txtfile.write(row + "\n") |
| txtfile.write("\n\n") |
| else: |
| txtfile.write("No relation data available\n\n") |
|
|
| |
| txtfile.write("RELATIONSHIPS\n") |
| txtfile.write("-" * 80 + "\n") |
| if relationships_data: |
| |
| col_widths = { |
| k: max(len(k), max(len(str(r[k])) for r in relationships_data)) |
| for k in relationships_data[0] |
| } |
| header = " ".join( |
| k.ljust(col_widths[k]) for k in relationships_data[0] |
| ) |
| txtfile.write(header + "\n") |
| txtfile.write("-" * len(header) + "\n") |
|
|
| |
| for relationship in relationships_data: |
| row = " ".join( |
| str(v).ljust(col_widths[k]) for k, v in relationship.items() |
| ) |
| txtfile.write(row + "\n") |
| else: |
| txtfile.write("No relationship data available\n\n") |
|
|
| else: |
| raise ValueError( |
| f"Unsupported file format: {file_format}. " |
| f"Choose from: csv, excel, md, txt" |
| ) |
| if file_format is not None: |
| print(f"Data exported to: {output_path} with format: {file_format}") |
| else: |
| print("Data displayed as table format") |
|
|
|
|
| def export_data( |
| chunk_entity_relation_graph, |
| entities_vdb, |
| relationships_vdb, |
| output_path: str, |
| file_format: str = "csv", |
| include_vector_data: bool = False, |
| ) -> None: |
| """ |
| Synchronously exports all entities, relations, and relationships to various formats. |
| |
| Args: |
| chunk_entity_relation_graph: Graph storage instance for entities and relations |
| entities_vdb: Vector database storage for entities |
| relationships_vdb: Vector database storage for relationships |
| output_path: The path to the output file (including extension). |
| file_format: Output format - "csv", "excel", "md", "txt". |
| - csv: Comma-separated values file |
| - excel: Microsoft Excel file with multiple sheets |
| - md: Markdown tables |
| - txt: Plain text formatted output |
| include_vector_data: Whether to include data from the vector database. |
| """ |
| try: |
| loop = asyncio.get_event_loop() |
| except RuntimeError: |
| loop = asyncio.new_event_loop() |
| asyncio.set_event_loop(loop) |
|
|
| loop.run_until_complete( |
| aexport_data( |
| chunk_entity_relation_graph, |
| entities_vdb, |
| relationships_vdb, |
| output_path, |
| file_format, |
| include_vector_data, |
| ) |
| ) |
|
|
|
|
| def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: |
| """Lazily import a class from an external module based on the package of the caller.""" |
| |
| import inspect |
|
|
| caller_frame = inspect.currentframe().f_back |
| module = inspect.getmodule(caller_frame) |
| package = module.__package__ if module else None |
|
|
| def import_class(*args: Any, **kwargs: Any): |
| import importlib |
|
|
| module = importlib.import_module(module_name, package=package) |
| cls = getattr(module, class_name) |
| return cls(*args, **kwargs) |
|
|
| return import_class |
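|
| # Example (sketch; the module/class pair mirrors how storage backends are
| # bound, but the exact names here are illustrative):
| #
| #     NetworkXStorage = lazy_external_import(".kg.networkx_impl", "NetworkXStorage")
| #     storage = NetworkXStorage(namespace="chunks")  # import happens here, lazily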
|
|
|
|
| async def update_chunk_cache_list( |
| chunk_id: str, |
| text_chunks_storage: "BaseKVStorage", |
| cache_keys: list[str], |
| cache_scenario: str = "batch_update", |
| ) -> None: |
| """Update chunk's llm_cache_list with the given cache keys |
| |
| Args: |
| chunk_id: Chunk identifier |
| text_chunks_storage: Text chunks storage instance |
| cache_keys: List of cache keys to add to the list |
| cache_scenario: Description of the cache scenario for logging |
| """ |
| if not cache_keys: |
| return |
|
|
| try: |
| chunk_data = await text_chunks_storage.get_by_id(chunk_id) |
| if chunk_data: |
| |
| if "llm_cache_list" not in chunk_data: |
| chunk_data["llm_cache_list"] = [] |
|
|
| |
| existing_keys = set(chunk_data["llm_cache_list"]) |
| new_keys = [key for key in cache_keys if key not in existing_keys] |
|
|
| if new_keys: |
| chunk_data["llm_cache_list"].extend(new_keys) |
|
|
| |
| await text_chunks_storage.upsert({chunk_id: chunk_data}) |
| logger.debug( |
| f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})" |
| ) |
| except Exception as e: |
| logger.warning( |
| f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}" |
| ) |
|
|
|
|
| def remove_think_tags(text: str) -> str: |
| """Remove <think>...</think> tags from the text |
| Remove orphon ...</think> tags from the text also""" |
| return re.sub( |
| r"^(<think>.*?</think>|.*</think>)", "", text, flags=re.DOTALL |
| ).strip() |
|
|
|
|
| async def use_llm_func_with_cache( |
| user_prompt: str, |
| use_llm_func: callable, |
| llm_response_cache: "BaseKVStorage | None" = None, |
| system_prompt: str | None = None, |
| max_tokens: int | None = None, |
| history_messages: list[dict[str, str]] | None = None, |
| cache_type: str = "extract", |
| chunk_id: str | None = None, |
| cache_keys_collector: list | None = None, |
| ) -> tuple[str, int]: |
| """Call LLM function with cache support and text sanitization |
| |
| If cache is available and enabled (determined by handle_cache based on mode), |
| retrieve result from cache; otherwise call LLM function and save result to cache. |
| |
| This function applies text sanitization to prevent UTF-8 encoding errors for all LLM providers. |
| |
| Args: |
| user_prompt: User prompt to send to the LLM |
| use_llm_func: LLM function to call |
| llm_response_cache: Cache storage instance |
| system_prompt: Optional system prompt |
| max_tokens: Maximum tokens for generation |
| history_messages: History messages list |
| cache_type: Type of cache |
| chunk_id: Chunk identifier to store in cache |
| cache_keys_collector: Optional list to collect cache keys for batch processing |
| |
| Returns: |
| tuple[str, int]: (LLM response text, timestamp) |
| - For cache hits: (content, cache_create_time) |
| - For cache misses: (content, current_timestamp) |
| """ |
| |
| safe_user_prompt = sanitize_text_for_encoding(user_prompt) |
| safe_system_prompt = ( |
| sanitize_text_for_encoding(system_prompt) if system_prompt else None |
| ) |
|
|
| |
| safe_history_messages = None |
| if history_messages: |
| safe_history_messages = [] |
| for i, msg in enumerate(history_messages): |
| safe_msg = msg.copy() |
| if "content" in safe_msg: |
| safe_msg["content"] = sanitize_text_for_encoding(safe_msg["content"]) |
| safe_history_messages.append(safe_msg) |
| history = json.dumps(safe_history_messages, ensure_ascii=False) |
| else: |
| history = None |
|
|
| if llm_response_cache: |
| prompt_parts = [] |
| if safe_user_prompt: |
| prompt_parts.append(safe_user_prompt) |
| if safe_system_prompt: |
| prompt_parts.append(safe_system_prompt) |
| if history: |
| prompt_parts.append(history) |
| _prompt = "\n".join(prompt_parts) |
|
|
| arg_hash = compute_args_hash(_prompt) |
| |
| cache_key = generate_cache_key("default", cache_type, arg_hash) |
|
|
| cached_result = await handle_cache( |
| llm_response_cache, |
| arg_hash, |
| _prompt, |
| "default", |
| cache_type=cache_type, |
| ) |
| if cached_result: |
| content, timestamp = cached_result |
| logger.debug(f"Found cache for {arg_hash}") |
| statistic_data["llm_cache"] += 1 |
|
|
| |
| if cache_keys_collector is not None: |
| cache_keys_collector.append(cache_key) |
|
|
| return content, timestamp |
| statistic_data["llm_call"] += 1 |
|
|
| |
| kwargs = {} |
| if safe_history_messages: |
| kwargs["history_messages"] = safe_history_messages |
| if max_tokens is not None: |
| kwargs["max_tokens"] = max_tokens |
|
|
| res: str = await use_llm_func( |
| safe_user_prompt, system_prompt=safe_system_prompt, **kwargs |
| ) |
|
|
| res = remove_think_tags(res) |
|
|
| |
| current_timestamp = int(time.time()) |
|
|
| if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"): |
| await save_to_cache( |
| llm_response_cache, |
| CacheData( |
| args_hash=arg_hash, |
| content=res, |
| prompt=_prompt, |
| cache_type=cache_type, |
| chunk_id=chunk_id, |
| ), |
| ) |
|
|
| |
| if cache_keys_collector is not None: |
| cache_keys_collector.append(cache_key) |
|
|
| return res, current_timestamp |
|
|
| |
| kwargs = {} |
| if safe_history_messages: |
| kwargs["history_messages"] = safe_history_messages |
| if max_tokens is not None: |
| kwargs["max_tokens"] = max_tokens |
|
|
| try: |
| res = await use_llm_func( |
| safe_user_prompt, system_prompt=safe_system_prompt, **kwargs |
| ) |
| except Exception as e: |
| |
| error_msg = f"[LLM func] {str(e)}" |
| |
| raise type(e)(error_msg) from e |
|
|
| |
| current_timestamp = int(time.time()) |
| return remove_think_tags(res), current_timestamp |
|
|
|
|
| def get_content_summary(content: str, max_length: int = 250) -> str: |
| """Get summary of document content |
| |
| Args: |
| content: Original document content |
| max_length: Maximum length of summary |
| |
| Returns: |
| Truncated content with ellipsis if needed |
| """ |
| content = content.strip() |
| if len(content) <= max_length: |
| return content |
| return content[:max_length] + "..." |
|
|
|
|
| def sanitize_and_normalize_extracted_text( |
| input_text: str, remove_inner_quotes=False |
| ) -> str: |
| """Santitize and normalize extracted text |
| Args: |
| input_text: text string to be processed |
| is_name: whether the input text is a entity or relation name |
| |
| Returns: |
| Santitized and normalized text string |
| """ |
| safe_input_text = sanitize_text_for_encoding(input_text) |
| if safe_input_text: |
| normalized_text = normalize_extracted_info( |
| safe_input_text, remove_inner_quotes=remove_inner_quotes |
| ) |
| return normalized_text |
| return "" |
|
|
|
|
| def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str: |
| """Normalize entity/relation names and description with the following rules: |
| - Clean HTML tags (paragraph and line break tags) |
| - Convert Chinese symbols to English symbols |
| - Remove spaces between Chinese characters |
| - Remove spaces between Chinese characters and English letters/numbers |
| - Preserve spaces within English text and numbers |
| - Replace Chinese parentheses with English parentheses |
| - Replace Chinese dash with English dash |
| - Remove English quotation marks wrapping the whole text |
| - Remove Chinese quotation marks and book-title marks (《》) wrapping the whole text |
| - Filter out short numeric-only text (length < 3 and only digits) |
| - Filter out short dotted-number text (length < 6, only digits and dots) |
| - When remove_inner_quotes is True, additionally: |
| remove Chinese quotes anywhere in the text |
| remove English quotes in and around Chinese characters |
| convert non-breaking spaces to regular spaces |
| convert narrow non-breaking spaces after non-digits to regular spaces |
| |
| Args: |
| name: Entity name to normalize |
| remove_inner_quotes: Whether to also strip inner quotes (used for entity/relation names) |
| |
| Returns: |
| Normalized entity name |
| """ |
| |
| name = re.sub(r"</p\s*>|<p\s*>|<p/>", "", name, flags=re.IGNORECASE) |
| name = re.sub(r"</br\s*>|<br\s*>|<br/>", "", name, flags=re.IGNORECASE) |
|
|
| # Convert fullwidth letters to halfwidth |
| name = name.translate( |
| str.maketrans( |
| "ＡＢＣＤＥＦＧＨＩＪＫＬＭＮＯＰＱＲＳＴＵＶＷＸＹＺａｂｃｄｅｆｇｈｉｊｋｌｍｎｏｐｑｒｓｔｕｖｗｘｙｚ", |
| "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", |
| ) |
| ) |
|
|
| # Convert fullwidth digits to halfwidth |
| name = name.translate(str.maketrans("０１２３４５６７８９", "0123456789")) |
|
|
| # Convert fullwidth symbols to halfwidth |
| name = name.replace("－", "-") |
| name = name.replace("＋", "+") |
| name = name.replace("／", "/") |
| name = name.replace("＊", "*") |
|
|
| # Replace Chinese parentheses with English parentheses |
| name = name.replace("（", "(").replace("）", ")") |
|
|
| # Replace Chinese dash with English dash |
| name = name.replace("—", "-").replace("－", "-") |
|
|
| # Replace fullwidth space with regular space |
| name = name.replace("　", " ") |
|
|
| |
| |
| |
| |
| |
| name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])", "", name) |
|
|
| |
| name = re.sub( |
| r"(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])", "", name |
| ) |
| name = re.sub( |
| r"(?<=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])\s+(?=[\u4e00-\u9fa5])", "", name |
| ) |
|
|
| |
| if len(name) >= 2: |
| |
| if name.startswith('"') and name.endswith('"'): |
| inner_content = name[1:-1] |
| if '"' not in inner_content: |
| name = inner_content |
|
|
| |
| if name.startswith("'") and name.endswith("'"): |
| inner_content = name[1:-1] |
| if "'" not in inner_content: |
| name = inner_content |
|
|
| |
| if name.startswith("“") and name.endswith("”"): |
| inner_content = name[1:-1] |
| if "“" not in inner_content and "”" not in inner_content: |
| name = inner_content |
| if name.startswith("‘") and name.endswith("’"): |
| inner_content = name[1:-1] |
| if "‘" not in inner_content and "’" not in inner_content: |
| name = inner_content |
|
|
| |
| if name.startswith("《") and name.endswith("》"): |
| inner_content = name[1:-1] |
| if "《" not in inner_content and "》" not in inner_content: |
| name = inner_content |
|
|
| if remove_inner_quotes: |
| |
| name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "") |
| |
| name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name) |
| name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name) |
| |
| name = name.replace("\u00a0", " ") |
| |
| name = re.sub(r"(?<=[^\d])\u202F", " ", name) |
|
|
| |
| name = name.strip() |
|
|
| |
| if len(name) < 3 and re.match(r"^[0-9]+$", name): |
| return "" |
|
|
| def should_filter_by_dots(text): |
| """ |
| Check if the string consists only of dots and digits, with at least one dot |
| Filter cases include: 1.2.3, 12.3, .123, 123., 12.3., .1.23 etc. |
| """ |
| return all(c.isdigit() or c == "." for c in text) and "." in text |
|
|
| if len(name) < 6 and should_filter_by_dots(name): |
| # Filter out short dotted-number strings such as 1.2.3 or 12.3 |
| return "" |
|
|
| return name |
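|
| # Examples of the rules above:
| #
| #     normalize_extracted_info('"Apple Inc."')  # -> 'Apple Inc.' (outer quotes stripped)
| #     normalize_extracted_info("《红楼梦》")      # -> "红楼梦" (book-title marks stripped)
| #     normalize_extracted_info("1.2.3")         # -> "" (short dotted-number filter)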
|
|
|
|
| def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str: |
| """Sanitize text to ensure safe UTF-8 encoding by removing or replacing problematic characters. |
| |
| This function handles: |
| - Surrogate characters (the main cause of encoding errors) |
| - Other invalid Unicode sequences |
| - Control characters that might cause issues |
| - Unescape HTML escapes |
| - Remove control characters |
| - Whitespace trimming |
| |
| Args: |
| text: Input text to sanitize |
| replacement_char: Character to use for replacing invalid sequences |
| |
| Returns: |
| Sanitized text that can be safely encoded as UTF-8 |
| |
| Raises: |
| ValueError: When text contains uncleanable encoding issues that cannot be safely processed |
| """ |
| if not text: |
| return text |
|
|
| try: |
| |
| text = text.strip() |
|
|
| |
| if not text: |
| return text |
|
|
| |
| text.encode("utf-8") |
|
|
| |
| |
| sanitized = "" |
| for char in text: |
| code_point = ord(char) |
| |
| if 0xD800 <= code_point <= 0xDFFF: |
| |
| sanitized += replacement_char |
| continue |
| |
| elif code_point == 0xFFFE or code_point == 0xFFFF: |
| |
| sanitized += replacement_char |
| continue |
| else: |
| sanitized += char |
|
|
| |
| |
| sanitized = re.sub( |
| r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized |
| ) |
|
|
| |
| sanitized.encode("utf-8") |
|
|
| |
| sanitized = html.unescape(sanitized) |
|
|
| |
| sanitized = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", "", sanitized) |
|
|
| return sanitized.strip() |
|
|
| except UnicodeEncodeError as e: |
| |
| error_msg = f"Text contains uncleanable UTF-8 encoding issues: {str(e)[:100]}" |
| logger.error(f"Text sanitization failed: {error_msg}") |
| raise ValueError(error_msg) from e |
|
|
| except Exception as e: |
| logger.error(f"Text sanitization: Unexpected error: {str(e)}") |
| |
| try: |
| text.encode("utf-8") |
| return text |
| except UnicodeEncodeError: |
| raise ValueError( |
| f"Text sanitization failed with unexpected error: {str(e)}" |
| ) from e |
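|
| # Example: control characters are stripped and HTML entities unescaped;
| # surrogates and noncharacters (if any survive upstream decoding) are
| # replaced.
| #
| #     sanitize_text_for_encoding("a\x01b &amp; c")  # -> "ab & c"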
|
|
|
|
| def check_storage_env_vars(storage_name: str) -> None: |
| """Check if all required environment variables for storage implementation exist |
| |
| Args: |
| storage_name: Storage implementation name |
| |
| Raises: |
| ValueError: If required environment variables are missing |
| """ |
| from lightrag.kg import STORAGE_ENV_REQUIREMENTS |
|
|
| required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) |
| missing_vars = [var for var in required_vars if var not in os.environ] |
|
|
| if missing_vars: |
| raise ValueError( |
| f"Storage implementation '{storage_name}' requires the following " |
| f"environment variables: {', '.join(missing_vars)}" |
| ) |
|
|
|
|
| def pick_by_weighted_polling( |
| entities_or_relations: list[dict], |
| max_related_chunks: int, |
| min_related_chunks: int = 1, |
| ) -> list[str]: |
| """ |
| Linear gradient weighted polling algorithm for text chunk selection. |
| |
| This algorithm ensures that entities/relations with higher importance get more text chunks, |
| forming a linear decreasing allocation pattern. |
| |
| Args: |
| entities_or_relations: List of entities or relations sorted by importance (high to low) |
| max_related_chunks: Expected number of text chunks for the highest importance entity/relation |
| min_related_chunks: Expected number of text chunks for the lowest importance entity/relation |
| |
| Returns: |
| List of selected text chunk IDs |
| """ |
| if not entities_or_relations: |
| return [] |
|
|
| n = len(entities_or_relations) |
| if n == 1: |
| |
| entity_chunks = entities_or_relations[0].get("sorted_chunks", []) |
| return entity_chunks[:max_related_chunks] |
|
|
| |
| expected_counts = [] |
| for i in range(n): |
| |
| ratio = i / (n - 1) if n > 1 else 0 |
| expected = max_related_chunks - ratio * ( |
| max_related_chunks - min_related_chunks |
| ) |
| expected_counts.append(int(round(expected))) |
|
|
| # First pass: take each item's expected share, tracking shortfalls |
| selected_chunks = [] |
| used_counts = [] |
| total_remaining = 0 |
|
|
| for i, entity_rel in enumerate(entities_or_relations): |
| entity_chunks = entity_rel.get("sorted_chunks", []) |
| expected = expected_counts[i] |
|
|
| # An item may have fewer chunks than expected |
| actual = min(expected, len(entity_chunks)) |
| selected_chunks.extend(entity_chunks[:actual]) |
| used_counts.append(actual) |
|
|
| # Track the shortfall for redistribution in the second pass |
| remaining = expected - actual |
| if remaining > 0: |
| total_remaining += remaining |
|
|
| # Second pass: redistribute the shortfall to items that still have unused |
| # chunks, in importance order |
| for _ in range(total_remaining): |
| allocated = False |
|
|
| |
| for i, entity_rel in enumerate(entities_or_relations): |
| entity_chunks = entity_rel.get("sorted_chunks", []) |
|
|
| |
| if used_counts[i] < len(entity_chunks): |
| |
| selected_chunks.append(entity_chunks[used_counts[i]]) |
| used_counts[i] += 1 |
| allocated = True |
| break |
|
|
| # Stop when every item's chunks are exhausted |
| if not allocated: |
| break |
|
|
| return selected_chunks |
|
|
|
|
| async def pick_by_vector_similarity( |
| query: str, |
| text_chunks_storage: "BaseKVStorage", |
| chunks_vdb: "BaseVectorStorage", |
| num_of_chunks: int, |
| entity_info: list[dict[str, Any]], |
| embedding_func: Callable, |
| query_embedding=None, |
| ) -> list[str]: |
| """ |
| Vector similarity-based text chunk selection algorithm. |
| |
| This algorithm selects text chunks based on cosine similarity between |
| the query embedding and text chunk embeddings. |
| |
| Args: |
| query: User's original query string |
| text_chunks_storage: Text chunks storage instance |
| chunks_vdb: Vector database storage for chunks |
| num_of_chunks: Number of chunks to select |
| entity_info: List of entity information containing chunk IDs |
| embedding_func: Embedding function to compute query embedding |
| query_embedding: Optional pre-computed query embedding; computed from query when None |
|
| Returns: |
| List of selected text chunk IDs sorted by similarity (highest first) |
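|
| Example (sketch; storages, entity_info and the embedding function are |
| assumed to be initialized by the caller): |
| selected = await pick_by_vector_similarity( |
| query, text_chunks_storage, chunks_vdb, |
| num_of_chunks=10, entity_info=entities, embedding_func=embed) |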
| """ |
| logger.debug( |
| f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}" |
| ) |
|
|
| if not entity_info or num_of_chunks <= 0: |
| return [] |
|
|
| |
| all_chunk_ids = set() |
| for i, entity in enumerate(entity_info): |
| chunk_ids = entity.get("sorted_chunks", []) |
| all_chunk_ids.update(chunk_ids) |
|
|
| if not all_chunk_ids: |
| logger.warning( |
| "Vector similarity chunk selection: no chunk IDs found in entity_info" |
| ) |
| return [] |
|
|
| logger.debug( |
| f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected" |
| ) |
|
|
| all_chunk_ids = list(all_chunk_ids) |
|
|
| try: |
| |
| if query_embedding is None: |
| query_embedding = await embedding_func([query]) |
| query_embedding = query_embedding[0] |
| logger.debug( |
| "Computed query embedding for vector similarity chunk selection" |
| ) |
| else: |
| logger.debug( |
| "Using pre-computed query embedding for vector similarity chunk selection" |
| ) |
|
|
| # Retrieve embeddings for all candidate chunks in one batch |
| chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids) |
| logger.debug( |
| f"Vector similarity chunk selection: retrieved {len(chunk_vectors)} chunk vectors" |
| ) |
|
|
| if not chunk_vectors or len(chunk_vectors) != len(all_chunk_ids): |
| if not chunk_vectors: |
| logger.warning( |
| "Vector similarity chunk selection: no vectors retrieved from chunks_vdb" |
| ) |
| else: |
| logger.warning( |
| f"Vector similarity chunk selection: retrieved {len(chunk_vectors)} vectors but expected {len(all_chunk_ids)}" |
| ) |
| return [] |
|
|
| # Compute cosine similarity between the query and each chunk embedding |
| similarities = [] |
| valid_vectors = 0 |
| for chunk_id in all_chunk_ids: |
| if chunk_id in chunk_vectors: |
| chunk_embedding = chunk_vectors[chunk_id] |
| try: |
| |
| similarity = cosine_similarity(query_embedding, chunk_embedding) |
| similarities.append((chunk_id, similarity)) |
| valid_vectors += 1 |
| except Exception as e: |
| logger.warning( |
| f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}" |
| ) |
| else: |
| logger.warning( |
| f"Vector similarity chunk selection: no vector found for chunk {chunk_id}" |
| ) |
|
|
| # Sort by similarity (highest first) and keep the top num_of_chunks |
| similarities.sort(key=lambda x: x[1], reverse=True) |
| selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]] |
|
|
| logger.debug( |
| f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates" |
| ) |
|
|
| return selected_chunks |
|
|
| except Exception as e: |
| logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}") |
| import traceback |
|
|
| logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}") |
| # Fall back to simple truncation of the candidate list |
| logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation") |
| return all_chunk_ids[:num_of_chunks] |
|
|
|
|
| class TokenTracker: |
| """Track token usage for LLM calls.""" |
|
|
| def __init__(self): |
| self.reset() |
|
|
| def __enter__(self): |
| self.reset() |
| return self |
|
|
| def __exit__(self, exc_type, exc_val, exc_tb): |
| # Print the usage summary when the context exits |
| print(self) |
|
|
| def reset(self): |
| self.prompt_tokens = 0 |
| self.completion_tokens = 0 |
| self.total_tokens = 0 |
| self.call_count = 0 |
|
|
| def add_usage(self, token_counts): |
| """Add token usage from one LLM call. |
| |
| Args: |
| token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens |
| """ |
| self.prompt_tokens += token_counts.get("prompt_tokens", 0) |
| self.completion_tokens += token_counts.get("completion_tokens", 0) |
|
|
| # Prefer the provider-reported total; otherwise derive it from the parts |
| if "total_tokens" in token_counts: |
| self.total_tokens += token_counts["total_tokens"] |
| else: |
| self.total_tokens += token_counts.get( |
| "prompt_tokens", 0 |
| ) + token_counts.get("completion_tokens", 0) |
|
|
| self.call_count += 1 |
|
|
| def get_usage(self): |
| """Get current usage statistics.""" |
| return { |
| "prompt_tokens": self.prompt_tokens, |
| "completion_tokens": self.completion_tokens, |
| "total_tokens": self.total_tokens, |
| "call_count": self.call_count, |
| } |
|
|
| def __str__(self): |
| usage = self.get_usage() |
| return ( |
| f"LLM call count: {usage['call_count']}, " |
| f"Prompt tokens: {usage['prompt_tokens']}, " |
| f"Completion tokens: {usage['completion_tokens']}, " |
| f"Total tokens: {usage['total_tokens']}" |
| ) |
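|
| # Usage sketch (token counts illustrative): entering the context resets the |
| # tracker; leaving it prints the summary. |
| #   with TokenTracker() as tracker: |
| #       tracker.add_usage({"prompt_tokens": 120, "completion_tokens": 30}) |
| #   -> "LLM call count: 1, Prompt tokens: 120, Completion tokens: 30, Total tokens: 150" |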
|
|
|
|
| async def apply_rerank_if_enabled( |
| query: str, |
| retrieved_docs: list[dict], |
| global_config: dict, |
| enable_rerank: bool = True, |
| top_n: int | None = None, |
| ) -> list[dict]: |
| """ |
| Apply reranking to retrieved documents if rerank is enabled. |
| |
| Args: |
| query: The search query |
| retrieved_docs: List of retrieved documents |
| global_config: Global configuration containing rerank settings |
| enable_rerank: Whether to enable reranking from query parameter |
| top_n: Number of top documents to return after reranking |
| |
| Returns: |
| Reranked documents if rerank is enabled, otherwise original documents |
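|
| Expected rerank result format (standard form; scores illustrative; a |
| legacy already-reranked document list is also accepted): |
| [{"index": 2, "relevance_score": 0.91}, {"index": 0, "relevance_score": 0.55}, ...] |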
| """ |
| if not enable_rerank or not retrieved_docs: |
| return retrieved_docs |
|
|
| rerank_func = global_config.get("rerank_model_func") |
| if not rerank_func: |
| logger.warning( |
| "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters." |
| ) |
| return retrieved_docs |
|
|
| try: |
| # Extract text content from documents, trying common field names |
| document_texts = [] |
| for doc in retrieved_docs: |
| |
| content = ( |
| doc.get("content") |
| or doc.get("text") |
| or doc.get("chunk_content") |
| or doc.get("document") |
| or str(doc) |
| ) |
| document_texts.append(content) |
|
|
| # Apply the configured rerank function |
| rerank_results = await rerank_func( |
| query=query, |
| documents=document_texts, |
| top_n=top_n, |
| ) |
|
|
| if rerank_results: |
| # Standard format: list of {"index", "relevance_score"} entries |
| if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]: |
| |
| reranked_docs = [] |
| for result in rerank_results: |
| index = result["index"] |
| relevance_score = result["relevance_score"] |
|
|
| # Guard against out-of-range indices from the reranker |
| if 0 <= index < len(retrieved_docs): |
| doc = retrieved_docs[index].copy() |
| doc["rerank_score"] = relevance_score |
| reranked_docs.append(doc) |
|
|
| logger.info( |
| f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks" |
| ) |
| return reranked_docs |
| else: |
| |
| logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks") |
| return rerank_results[:top_n] if top_n else rerank_results |
| else: |
| logger.warning("Rerank returned empty results, using original chunks") |
| return retrieved_docs |
|
|
| except Exception as e: |
| logger.error(f"Error during reranking: {e}, using original chunks") |
| return retrieved_docs |
|
|
|
|
| async def process_chunks_unified( |
| query: str, |
| unique_chunks: list[dict], |
| query_param: "QueryParam", |
| global_config: dict, |
| source_type: str = "mixed", |
| chunk_token_limit: int | None = None, |
| ) -> list[dict]: |
| """ |
| Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. |
| |
| Args: |
| query: Search query for reranking |
| unique_chunks: List of text chunks to process |
| query_param: Query parameters containing configuration |
| global_config: Global configuration dictionary |
| source_type: Source type for logging ("vector", "entity", "relationship", "mixed") |
| chunk_token_limit: Dynamic token limit for chunks (if None, uses default) |
| |
| Returns: |
| Processed and filtered list of text chunks |
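|
| Example (sketch; query_param and global_config come from the calling |
| query pipeline): |
| chunks = await process_chunks_unified( |
| query, candidate_chunks, query_param, global_config, source_type="entity") |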
| """ |
| if not unique_chunks: |
| return [] |
|
|
| origin_count = len(unique_chunks) |
|
|
| # Step 1: apply reranking if enabled |
| if query_param.enable_rerank and query and unique_chunks: |
| rerank_top_k = query_param.chunk_top_k or len(unique_chunks) |
| unique_chunks = await apply_rerank_if_enabled( |
| query=query, |
| retrieved_docs=unique_chunks, |
| global_config=global_config, |
| enable_rerank=query_param.enable_rerank, |
| top_n=rerank_top_k, |
| ) |
|
|
| # Step 2: filter out chunks below the minimum rerank score |
| if query_param.enable_rerank and unique_chunks: |
| min_rerank_score = global_config.get("min_rerank_score", 0.5) |
| if min_rerank_score > 0.0: |
| original_count = len(unique_chunks) |
|
|
| |
| filtered_chunks = [] |
| for chunk in unique_chunks: |
| # Default 1.0 keeps chunks that were never scored by the reranker |
| rerank_score = chunk.get("rerank_score", 1.0) |
| if rerank_score >= min_rerank_score: |
| filtered_chunks.append(chunk) |
|
|
| unique_chunks = filtered_chunks |
| filtered_count = original_count - len(unique_chunks) |
|
|
| if filtered_count > 0: |
| logger.info( |
| f"Rerank filtering: {len(unique_chunks)} chunks remained (min rerank score: {min_rerank_score})" |
| ) |
| if not unique_chunks: |
| return [] |
|
|
| # Step 3: cut to chunk_top_k after reranking |
| if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: |
| if len(unique_chunks) > query_param.chunk_top_k: |
| unique_chunks = unique_chunks[: query_param.chunk_top_k] |
| logger.debug( |
| f"Kept chunk_top_k: {len(unique_chunks)} chunks (deduplicated original: {origin_count})" |
| ) |
|
|
| # Step 4: token-based truncation within the chunk token budget |
| tokenizer = global_config.get("tokenizer") |
| if tokenizer and unique_chunks: |
| |
| if chunk_token_limit is None: |
| |
| chunk_token_limit = getattr( |
| query_param, |
| "max_total_tokens", |
| global_config.get("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS), |
| ) |
|
|
| original_count = len(unique_chunks) |
|
|
| # Serialize each chunk to JSON to measure its token footprint |
| unique_chunks = truncate_list_by_token_size( |
| unique_chunks, |
| key=lambda x: json.dumps(x, ensure_ascii=False), |
| max_token_size=chunk_token_limit, |
| tokenizer=tokenizer, |
| ) |
|
|
| logger.debug( |
| f"Token truncation: {len(unique_chunks)} chunks from {original_count} " |
| f"(chunk available tokens: {chunk_token_limit}, source: {source_type})" |
| ) |
|
|
| # Step 5: assign sequential display IDs ("DC1", "DC2", ...) |
| final_chunks = [] |
| for i, chunk in enumerate(unique_chunks): |
| chunk_with_id = chunk.copy() |
| chunk_with_id["id"] = f"DC{i + 1}" |
| final_chunks.append(chunk_with_id) |
|
|
| return final_chunks |
|
|
|
|
| def build_file_path(already_file_paths, data_list, target): |
| """Build file path string with UTF-8 byte length limit and deduplication |
| |
| Args: |
| already_file_paths: List of existing file paths |
| data_list: List of data items containing file_path |
| target: Target name for logging warnings |
| |
| Returns: |
| str: Combined file paths separated by GRAPH_FIELD_SEP |
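|
| Example (illustrative paths): with GRAPH_FIELD_SEP as the separator, |
| build_file_path(["a.txt"], [{"file_path": "b.txt"}, {"file_path": "a.txt"}], "entity") |
| returns "a.txt" + GRAPH_FIELD_SEP + "b.txt"; the duplicate "a.txt" is skipped. |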
| """ |
| # Deduplicate existing file paths, ignoring empty entries |
| file_paths_set = {fp for fp in already_file_paths if fp} |
|
|
| # Start from the existing paths joined by the field separator |
| file_paths = GRAPH_FIELD_SEP.join(fp for fp in already_file_paths if fp) |
|
|
| # Warn if the existing paths already exceed the byte budget |
| if len(file_paths.encode("utf-8")) >= DEFAULT_MAX_FILE_PATH_LENGTH: |
| logger.warning( |
| f"Initial file_paths already exceeds {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, " |
| f"current size: {len(file_paths.encode('utf-8'))} bytes" |
| ) |
|
|
| # Collect paths that do not fit within the byte budget |
| file_paths_ignore = "" |
|
| for dp in data_list: |
| cur_file_path = dp.get("file_path") |
| |
| if not cur_file_path: |
| continue |
|
|
| # Skip duplicates |
| if cur_file_path in file_paths_set: |
| continue |
| file_paths_set.add(cur_file_path) |
|
|
| # Append only while the combined byte length stays under the limit |
| new_addition = GRAPH_FIELD_SEP + cur_file_path if file_paths else cur_file_path |
| if ( |
| len(file_paths.encode("utf-8")) + len(new_addition.encode("utf-8")) |
| < DEFAULT_MAX_FILE_PATH_LENGTH - 5 |
| ): |
| |
| file_paths += new_addition |
| else: |
| |
| file_paths_ignore += GRAPH_FIELD_SEP + cur_file_path |
|
|
| if file_paths_ignore: |
| logger.warning( |
| f"File paths exceed {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, " |
| f"ignoring file path: {file_paths_ignore}" |
| ) |
| return file_paths |
|
|
|
|
| def generate_track_id(prefix: str = "upload") -> str: |
| """Generate a unique tracking ID with timestamp and UUID |
| |
| Args: |
| prefix: Prefix for the track ID (e.g., 'upload', 'insert') |
| |
| Returns: |
| str: Unique tracking ID in format: {prefix}_{timestamp}_{uuid} |
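| e.g. "upload_20240101_153045_1a2b3c4d" (timestamp and UUID segment vary) |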
| """ |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| unique_id = str(uuid.uuid4())[:8] |
| return f"{prefix}_{timestamp}_{unique_id}" |
|
|
|
|
| def get_pinyin_sort_key(text: str) -> str: |
| """Generate sort key for Chinese pinyin sorting |
| |
| This function uses pypinyin for true Chinese pinyin sorting. |
| If pypinyin is not available, it falls back to simple lowercase string sorting. |
| |
| Args: |
| text: Text to generate sort key for |
| |
| Returns: |
| str: Sort key that can be used for comparison and sorting |
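|
| Example: sorted(names, key=get_pinyin_sort_key) orders "北京" under |
| "beijing" when pypinyin is available. |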
| """ |
| if not text: |
| return "" |
|
|
| if _PYPINYIN_AVAILABLE: |
| try: |
| # Convert each character to pinyin without tone marks and join |
| pinyin_list = pypinyin.lazy_pinyin(text, style=pypinyin.Style.NORMAL) |
| return "".join(pinyin_list).lower() |
| except Exception: |
| # Fall back to lowercase comparison if pinyin conversion fails |
| return text.lower() |
| else: |
| # pypinyin unavailable: fall back to lowercase string sorting |
| return text.lower() |
|
|
|
|
| def fix_tuple_delimiter_corruption( |
| record: str, delimiter_core: str, tuple_delimiter: str |
| ) -> str: |
| """ |
| Fix various forms of tuple_delimiter corruption from LLM output. |
| |
| This function handles missing or replaced characters around the core delimiter. |
| It fixes common corruption patterns where the LLM output doesn't match the expected |
| tuple_delimiter format. |
| |
| Args: |
| record: The text record to fix |
| delimiter_core: The core delimiter (e.g., "#" from "<|#|>") |
| tuple_delimiter: The complete tuple delimiter (e.g., "<|#|>") |
| |
| Returns: |
| The corrected record with proper tuple_delimiter format |
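|
| Example (with tuple_delimiter "<|#|>"): |
| >>> fix_tuple_delimiter_corruption("entity<|#>name", "#", "<|#|>") |
| 'entity<|#|>name' |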
| """ |
| if not record or not delimiter_core or not tuple_delimiter: |
| return record |
|
|
| # Escape the core for safe use inside regular expressions |
| escaped_delimiter_core = re.escape(delimiter_core) |
|
|
| # Duplicated core, e.g. "<|##|>" or "<|#|#|>" |
| record = re.sub( |
| rf"<\|{escaped_delimiter_core}\|*?{escaped_delimiter_core}\|>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Backslash before the core, e.g. "<|\#|>" |
| record = re.sub( |
| rf"<\|\\{escaped_delimiter_core}\|>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Core missing entirely, e.g. "<|>" or "<||>" |
| record = re.sub( |
| r"<\|+>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Corrupted leading character or missing trailing pipe, e.g. "<x|#|>" or "<|#>" |
| record = re.sub( |
| rf"<.?\|{escaped_delimiter_core}\|*?>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Missing pipe(s), e.g. "<#>", "<#|>" or "<|#>" |
| record = re.sub( |
| rf"<\|?{escaped_delimiter_core}\|?>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # A pipe replaced by another character, e.g. "<x#|>" or "<|#x>" |
| record = re.sub( |
| rf"<[^|]{escaped_delimiter_core}\|>|<\|{escaped_delimiter_core}[^|]>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Missing closing ">", e.g. "<|#|" or "<|#||" mid-record |
| record = re.sub( |
| rf"<\|{escaped_delimiter_core}\|+(?!>)", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Colon in place of the closing pipe, e.g. "<|#:" |
| record = re.sub( |
| rf"<\|{escaped_delimiter_core}:(?!>)", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Core dropped between pipes, e.g. "<||" not followed by ">" |
| record = re.sub( |
| r"<\|\|(?!>)", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Missing leading "<", e.g. "|#|>" |
| record = re.sub( |
| rf"(?<!<)\|{escaped_delimiter_core}\|>", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Extra trailing pipe after a correct delimiter, e.g. "<|#|>|" |
| record = re.sub( |
| rf"<\|{escaped_delimiter_core}\|>\|", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| # Double pipes without angle brackets, e.g. "||#||" |
| record = re.sub( |
| rf"\|\|{escaped_delimiter_core}\|\|", |
| tuple_delimiter, |
| record, |
| ) |
|
|
| return record |
|
|
|
|
| def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exception: |
| """ |
| Safely create a prefixed exception that adapts to all error types. |
| |
| Args: |
| original_exception: The original exception. |
| prefix: The prefix to add. |
| |
| Returns: |
| A new exception with the prefix, maintaining the original exception type if possible. |
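|
| Example: |
| >>> err = create_prefixed_exception(ValueError("bad input"), "stage-2") |
| >>> type(err).__name__, str(err) |
| ('ValueError', 'stage-2: bad input') |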
| """ |
| try: |
| # Prefix the first string argument when the exception carries args |
| if hasattr(original_exception, "args") and original_exception.args: |
| args = list(original_exception.args) |
| # Prefix the first string argument found |
| found_str = False |
| for i, arg in enumerate(args): |
| if isinstance(arg, str): |
| args[i] = f"{prefix}: {arg}" |
| found_str = True |
| break |
|
|
| # No string argument found: stringify and prefix the first argument |
| if not found_str: |
| args[0] = f"{prefix}: {args[0]}" |
|
|
| return type(original_exception)(*args) |
| else: |
| # No args at all: build a same-type exception from the string form |
| return type(original_exception)(f"{prefix}: {str(original_exception)}") |
| except (TypeError, ValueError, AttributeError) as construct_error: |
| # Some exception types cannot be re-constructed with modified args |
| # (e.g. they require specific positional arguments); fall back to a |
| # RuntimeError that preserves the original type name and message |
| return RuntimeError( |
| f"{prefix}: {type(original_exception).__name__}: {str(original_exception)} " |
| f"(Original exception could not be reconstructed: {construct_error})" |
| ) |
|
|
|
|
| def convert_to_user_format( |
| entities_context: list[dict], |
| relations_context: list[dict], |
| chunks: list[dict], |
| references: list[dict], |
| query_mode: str, |
| entity_id_to_original: dict | None = None, |
| relation_id_to_original: dict | None = None, |
| ) -> dict[str, Any]: |
| """Convert internal data format to user-friendly format using original database data""" |
|
|
| # Format entities, preferring original database records |
| formatted_entities = [] |
| for entity in entities_context: |
| entity_name = entity.get("entity", "") |
|
|
| # Look up the original entity record if a mapping was provided |
| original_entity = None |
| if entity_id_to_original and entity_name in entity_id_to_original: |
| original_entity = entity_id_to_original[entity_name] |
|
|
| if original_entity: |
| # Use complete data from the original record |
| formatted_entities.append( |
| { |
| "entity_name": original_entity.get("entity_name", entity_name), |
| "entity_type": original_entity.get("entity_type", "UNKNOWN"), |
| "description": original_entity.get("description", ""), |
| "source_id": original_entity.get("source_id", ""), |
| "file_path": original_entity.get("file_path", "unknown_source"), |
| "created_at": original_entity.get("created_at", ""), |
| } |
| ) |
| else: |
| # Fall back to the internal context entry |
| formatted_entities.append( |
| { |
| "entity_name": entity_name, |
| "entity_type": entity.get("type", "UNKNOWN"), |
| "description": entity.get("description", ""), |
| "source_id": entity.get("source_id", ""), |
| "file_path": entity.get("file_path", "unknown_source"), |
| "created_at": entity.get("created_at", ""), |
| } |
| ) |
|
|
| # Format relationships, preferring original database records |
| formatted_relationships = [] |
| for relation in relations_context: |
| entity1 = relation.get("entity1", "") |
| entity2 = relation.get("entity2", "") |
| relation_key = (entity1, entity2) |
|
|
| |
| original_relation = None |
| if relation_id_to_original and relation_key in relation_id_to_original: |
| original_relation = relation_id_to_original[relation_key] |
|
|
| if original_relation: |
| |
| formatted_relationships.append( |
| { |
| "src_id": original_relation.get("src_id", entity1), |
| "tgt_id": original_relation.get("tgt_id", entity2), |
| "description": original_relation.get("description", ""), |
| "keywords": original_relation.get("keywords", ""), |
| "weight": original_relation.get("weight", 1.0), |
| "source_id": original_relation.get("source_id", ""), |
| "file_path": original_relation.get("file_path", "unknown_source"), |
| "created_at": original_relation.get("created_at", ""), |
| } |
| ) |
| else: |
| |
| formatted_relationships.append( |
| { |
| "src_id": entity1, |
| "tgt_id": entity2, |
| "description": relation.get("description", ""), |
| "keywords": relation.get("keywords", ""), |
| "weight": relation.get("weight", 1.0), |
| "source_id": relation.get("source_id", ""), |
| "file_path": relation.get("file_path", "unknown_source"), |
| "created_at": relation.get("created_at", ""), |
| } |
| ) |
|
|
| # Format chunks, carrying over reference IDs assigned earlier |
| formatted_chunks = [] |
| for chunk in chunks: |
| chunk_data = { |
| "reference_id": chunk.get("reference_id", ""), |
| "content": chunk.get("content", ""), |
| "file_path": chunk.get("file_path", "unknown_source"), |
| "chunk_id": chunk.get("chunk_id", ""), |
| } |
| formatted_chunks.append(chunk_data) |
|
|
| logger.debug( |
| f"[convert_to_user_format] Formatted {len(formatted_chunks)}/{len(chunks)} chunks" |
| ) |
|
|
| # Build response metadata with empty keyword placeholders |
| metadata = { |
| "query_mode": query_mode, |
| "keywords": { |
| "high_level": [], |
| "low_level": [], |
| }, |
| } |
|
|
| return { |
| "status": "success", |
| "message": "Query processed successfully", |
| "data": { |
| "entities": formatted_entities, |
| "relationships": formatted_relationships, |
| "chunks": formatted_chunks, |
| "references": references, |
| }, |
| "metadata": metadata, |
| } |
|
|
|
|
| def generate_reference_list_from_chunks( |
| chunks: list[dict], |
| ) -> tuple[list[dict], list[dict]]: |
| """ |
| Generate reference list from chunks, prioritizing by occurrence frequency. |
| |
| This function extracts file_paths from chunks, counts their occurrences, |
| sorts by frequency and first appearance order, creates reference_id mappings, |
| and builds a reference_list structure. |
| |
| Args: |
| chunks: List of chunk dictionaries with file_path information |
| |
| Returns: |
| tuple: (reference_list, updated_chunks_with_reference_ids) |
| - reference_list: List of dicts with reference_id and file_path |
| - updated_chunks_with_reference_ids: Original chunks with reference_id field added |
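|
| Example (illustrative file paths): |
| >>> refs, updated = generate_reference_list_from_chunks( |
| ...     [{"file_path": "a.md"}, {"file_path": "b.md"}, {"file_path": "b.md"}]) |
| >>> refs |
| [{'reference_id': '1', 'file_path': 'b.md'}, {'reference_id': '2', 'file_path': 'a.md'}] |
| >>> [c["reference_id"] for c in updated] |
| ['2', '1', '1'] |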
| """ |
| if not chunks: |
| return [], [] |
|
|
| # Count occurrences of each file path |
| file_path_counts = {} |
| for chunk in chunks: |
| file_path = chunk.get("file_path", "") |
| if file_path and file_path != "unknown_source": |
| file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1 |
|
|
| # Record each path's first-appearance index to break frequency ties |
| file_path_with_indices = [] |
| seen_paths = set() |
| for i, chunk in enumerate(chunks): |
| file_path = chunk.get("file_path", "") |
| if file_path and file_path != "unknown_source" and file_path not in seen_paths: |
| file_path_with_indices.append((file_path, file_path_counts[file_path], i)) |
| seen_paths.add(file_path) |
|
|
| # Sort by frequency (descending), then by first appearance order |
| sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2])) |
| unique_file_paths = [item[0] for item in sorted_file_paths] |
|
|
| # Assign 1-based reference IDs in sorted order |
| file_path_to_ref_id = {} |
| for i, file_path in enumerate(unique_file_paths): |
| file_path_to_ref_id[file_path] = str(i + 1) |
|
|
| # Attach reference IDs to chunk copies |
| updated_chunks = [] |
| for chunk in chunks: |
| chunk_copy = chunk.copy() |
| file_path = chunk_copy.get("file_path", "") |
| if file_path and file_path != "unknown_source": |
| chunk_copy["reference_id"] = file_path_to_ref_id[file_path] |
| else: |
| chunk_copy["reference_id"] = "" |
| updated_chunks.append(chunk_copy) |
|
|
| # Build the final reference list |
| reference_list = [] |
| for i, file_path in enumerate(unique_file_paths): |
| reference_list.append({"reference_id": str(i + 1), "file_path": file_path}) |
|
|
| return reference_list, updated_chunks |
|
|