| from __future__ import annotations |
| from functools import partial |
|
|
| import asyncio |
| import json |
| import json_repair |
| from typing import Any, AsyncIterator, overload, Literal |
| from collections import Counter, defaultdict |
|
|
| from .utils import ( |
| logger, |
| compute_mdhash_id, |
| Tokenizer, |
| is_float_regex, |
| sanitize_and_normalize_extracted_text, |
| pack_user_ass_to_openai_messages, |
| split_string_by_multi_markers, |
| truncate_list_by_token_size, |
| compute_args_hash, |
| handle_cache, |
| save_to_cache, |
| CacheData, |
| use_llm_func_with_cache, |
| update_chunk_cache_list, |
| remove_think_tags, |
| pick_by_weighted_polling, |
| pick_by_vector_similarity, |
| process_chunks_unified, |
| build_file_path, |
| safe_vdb_operation_with_exception, |
| create_prefixed_exception, |
| fix_tuple_delimiter_corruption, |
| convert_to_user_format, |
| generate_reference_list_from_chunks, |
| ) |
| from .base import ( |
| BaseGraphStorage, |
| BaseKVStorage, |
| BaseVectorStorage, |
| TextChunkSchema, |
| QueryParam, |
| QueryResult, |
| QueryContextResult, |
| ) |
| from .prompt import PROMPTS |
| from .constants import ( |
| GRAPH_FIELD_SEP, |
| DEFAULT_MAX_ENTITY_TOKENS, |
| DEFAULT_MAX_RELATION_TOKENS, |
| DEFAULT_MAX_TOTAL_TOKENS, |
| DEFAULT_RELATED_CHUNK_NUMBER, |
| DEFAULT_KG_CHUNK_PICK_METHOD, |
| DEFAULT_ENTITY_TYPES, |
| DEFAULT_SUMMARY_LANGUAGE, |
| ) |
| from .kg.shared_storage import get_storage_keyed_lock |
| import time |
| from dotenv import load_dotenv |
|
|
| |
| |
| |
| load_dotenv(dotenv_path=".env", override=False) |
|
|
|
|
| def chunking_by_token_size( |
| tokenizer: Tokenizer, |
| content: str, |
| split_by_character: str | None = None, |
| split_by_character_only: bool = False, |
| overlap_token_size: int = 128, |
| max_token_size: int = 1024, |
| ) -> list[dict[str, Any]]: |
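| """Split content into token-based chunks with overlap. |
| |
| If split_by_character is provided, the content is split on that character first; |
| unless split_by_character_only is True, pieces longer than max_token_size are |
| re-chunked by token count with overlap_token_size tokens of overlap. Each result |
| dict carries "tokens", "content" and "chunk_order_index". |
| |
| Illustrative usage (assumes a Tokenizer instance named tokenizer is available): |
| chunks = chunking_by_token_size(tokenizer, text, max_token_size=1024, overlap_token_size=128) |
| """ |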
| tokens = tokenizer.encode(content) |
| results: list[dict[str, Any]] = [] |
| if split_by_character: |
| raw_chunks = content.split(split_by_character) |
| new_chunks = [] |
| if split_by_character_only: |
| for chunk in raw_chunks: |
| _tokens = tokenizer.encode(chunk) |
| new_chunks.append((len(_tokens), chunk)) |
| else: |
| for chunk in raw_chunks: |
| _tokens = tokenizer.encode(chunk) |
| if len(_tokens) > max_token_size: |
| for start in range( |
| 0, len(_tokens), max_token_size - overlap_token_size |
| ): |
| chunk_content = tokenizer.decode( |
| _tokens[start : start + max_token_size] |
| ) |
| new_chunks.append( |
| (min(max_token_size, len(_tokens) - start), chunk_content) |
| ) |
| else: |
| new_chunks.append((len(_tokens), chunk)) |
| for index, (_len, chunk) in enumerate(new_chunks): |
| results.append( |
| { |
| "tokens": _len, |
| "content": chunk.strip(), |
| "chunk_order_index": index, |
| } |
| ) |
| else: |
| for index, start in enumerate( |
| range(0, len(tokens), max_token_size - overlap_token_size) |
| ): |
| chunk_content = tokenizer.decode(tokens[start : start + max_token_size]) |
| results.append( |
| { |
| "tokens": min(max_token_size, len(tokens) - start), |
| "content": chunk_content.strip(), |
| "chunk_order_index": index, |
| } |
| ) |
| return results |
|
|
|
|
| async def _handle_entity_relation_summary( |
| description_type: str, |
| entity_or_relation_name: str, |
| description_list: list[str], |
| separator: str, |
| global_config: dict, |
| llm_response_cache: BaseKVStorage | None = None, |
| ) -> tuple[str, bool]: |
| """Handle entity relation description summary using map-reduce approach. |
| |
| This function summarizes a list of descriptions using a map-reduce strategy: |
| 1. If total tokens < summary_context_size and len(description_list) < force_llm_summary_on_merge, no need to summarize |
| 2. If total tokens < summary_max_tokens, summarize with LLM directly |
| 3. Otherwise, split descriptions into chunks that fit within token limits |
| 4. Summarize each chunk, then recursively process the summaries |
| 5. Continue until the final summary fits within token limits or the number of descriptions falls below force_llm_summary_on_merge |
| |
| Args: |
| description_type: Label used in the summary prompt (e.g. "Entity" or "Relation") |
| entity_or_relation_name: Name of the entity or relation being summarized |
| description_list: List of description strings to summarize |
| separator: String used to join descriptions when no LLM summary is needed |
| global_config: Global configuration containing tokenizer and limits |
| llm_response_cache: Optional cache for LLM responses |
| |
| Returns: |
| Tuple of (final_summarized_description_string, llm_was_used_boolean) |
| """ |
| |
| if not description_list: |
| return "", False |
|
|
| |
| if len(description_list) == 1: |
| return description_list[0], False |
|
|
| |
| tokenizer: Tokenizer = global_config["tokenizer"] |
| summary_context_size = global_config["summary_context_size"] |
| summary_max_tokens = global_config["summary_max_tokens"] |
| force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] |
|
|
| current_list = description_list[:] |
| llm_was_used = False |
|
|
| |
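| # Map-reduce loop: keep grouping and summarizing descriptions until the remainder fits |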
| while True: |
| |
| total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list) |
|
|
| |
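| # Reduce: when everything fits the context window (or only two descriptions remain), join them directly or produce one final LLM summary |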
| if total_tokens <= summary_context_size or len(current_list) <= 2: |
| if ( |
| len(current_list) < force_llm_summary_on_merge |
| and total_tokens < summary_max_tokens |
| ): |
| |
| final_description = separator.join(current_list) |
| return final_description if final_description else "", llm_was_used |
| else: |
| if total_tokens > summary_context_size and len(current_list) <= 2: |
| logger.warning( |
| f"Summarizing {entity_or_relation_name}: Oversize descpriton found" |
| ) |
| |
| final_summary = await _summarize_descriptions( |
| description_type, |
| entity_or_relation_name, |
| current_list, |
| global_config, |
| llm_response_cache, |
| ) |
| return final_summary, True |
|
|
| |
| |
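| # Map: greedily pack descriptions into groups bounded by summary_context_size |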
| chunks = [] |
| current_chunk = [] |
| current_tokens = 0 |
|
|
| |
| for i, desc in enumerate(current_list): |
| desc_tokens = len(tokenizer.encode(desc)) |
|
|
| |
| if current_tokens + desc_tokens > summary_context_size and current_chunk: |
| |
| if len(current_chunk) == 1: |
| |
| current_chunk.append(desc) |
| chunks.append(current_chunk) |
| logger.warning( |
| f"Summarizing {entity_or_relation_name}: Oversize descpriton found" |
| ) |
| current_chunk = [] |
| current_tokens = 0 |
| else: |
| chunks.append(current_chunk) |
| current_chunk = [desc] |
| current_tokens = desc_tokens |
| else: |
| current_chunk.append(desc) |
| current_tokens += desc_tokens |
|
|
| |
| if current_chunk: |
| chunks.append(current_chunk) |
|
|
| logger.info( |
| f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups" |
| ) |
|
|
| |
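| # Summarize each group with the LLM; single-description groups pass through unchanged |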
| new_summaries = [] |
| for chunk in chunks: |
| if len(chunk) == 1: |
| |
| new_summaries.append(chunk[0]) |
| else: |
| |
| summary = await _summarize_descriptions( |
| description_type, |
| entity_or_relation_name, |
| chunk, |
| global_config, |
| llm_response_cache, |
| ) |
| new_summaries.append(summary) |
| llm_was_used = True |
|
|
| |
| current_list = new_summaries |
|
|
|
|
| async def _summarize_descriptions( |
| description_type: str, |
| description_name: str, |
| description_list: list[str], |
| global_config: dict, |
| llm_response_cache: BaseKVStorage | None = None, |
| ) -> str: |
| """Helper function to summarize a list of descriptions using LLM. |
| |
| Args: |
| description_type: Label used in the summary prompt (e.g. "Entity" or "Relation") |
| description_name: Name of the entity or relation being summarized |
| description_list: List of description strings to summarize |
| global_config: Global configuration containing LLM function and settings |
| llm_response_cache: Optional cache for LLM responses |
| |
| Returns: |
| Summarized description string |
| """ |
| use_llm_func = global_config["llm_model_func"] |
| if not isinstance(use_llm_func, partial): |
| |
| use_llm_func = partial(use_llm_func, _priority=8) |
|
|
| language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) |
|
|
| summary_length_recommended = global_config["summary_length_recommended"] |
|
|
| prompt_template = PROMPTS["summarize_entity_descriptions"] |
|
|
| |
| tokenizer = global_config["tokenizer"] |
| summary_context_size = global_config["summary_context_size"] |
|
|
| |
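| # Wrap each description as a JSON object and truncate the list to fit the summary context window |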
| json_descriptions = [{"Description": desc} for desc in description_list] |
|
|
| |
| truncated_json_descriptions = truncate_list_by_token_size( |
| json_descriptions, |
| key=lambda x: json.dumps(x, ensure_ascii=False), |
| max_token_size=summary_context_size, |
| tokenizer=tokenizer, |
| ) |
|
|
| |
| joined_descriptions = "\n".join( |
| json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions |
| ) |
|
|
| |
| context_base = dict( |
| description_type=description_type, |
| description_name=description_name, |
| description_list=joined_descriptions, |
| summary_length=summary_length_recommended, |
| language=language, |
| ) |
| use_prompt = prompt_template.format(**context_base) |
|
|
| |
| summary, _ = await use_llm_func_with_cache( |
| use_prompt, |
| use_llm_func, |
| llm_response_cache=llm_response_cache, |
| cache_type="summary", |
| ) |
| return summary |
|
|
|
|
| async def _handle_single_entity_extraction( |
| record_attributes: list[str], |
| chunk_key: str, |
| timestamp: int, |
| file_path: str = "unknown_source", |
| ): |
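| """Parse one delimited record into an entity dict (name, type, description, source metadata), or return None if the record is malformed.""" |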
| if len(record_attributes) != 4 or "entity" not in record_attributes[0]: |
| if len(record_attributes) > 1 and "entity" in record_attributes[0]: |
| logger.warning( |
| f"{chunk_key}: LLM output format error; found {len(record_attributes)}/4 feilds on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" |
| ) |
| logger.debug(record_attributes) |
| return None |
|
|
| try: |
| entity_name = sanitize_and_normalize_extracted_text( |
| record_attributes[1], remove_inner_quotes=True |
| ) |
|
|
| |
| if not entity_name or not entity_name.strip(): |
| logger.warning( |
| f"Entity extraction error: entity name became empty after cleaning. Original: '{record_attributes[1]}'" |
| ) |
| return None |
|
|
| |
| entity_type = sanitize_and_normalize_extracted_text( |
| record_attributes[2], remove_inner_quotes=True |
| ) |
|
|
| if not entity_type.strip() or any( |
| char in entity_type for char in ["'", "(", ")", "<", ">", "|", "/", "\\"] |
| ): |
| logger.warning( |
| f"Entity extraction error: invalid entity type in: {record_attributes}" |
| ) |
| return None |
|
|
| |
| entity_type = entity_type.replace(" ", "").lower() |
|
|
| |
| entity_description = sanitize_and_normalize_extracted_text(record_attributes[3]) |
|
|
| if not entity_description.strip(): |
| logger.warning( |
| f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'" |
| ) |
| return None |
|
|
| return dict( |
| entity_name=entity_name, |
| entity_type=entity_type, |
| description=entity_description, |
| source_id=chunk_key, |
| file_path=file_path, |
| timestamp=timestamp, |
| ) |
|
|
| except ValueError as e: |
| logger.error( |
| f"Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}" |
| ) |
| return None |
| except Exception as e: |
| logger.error( |
| f"Entity extraction failed with unexpected error in chunk {chunk_key}: {e}" |
| ) |
| return None |
|
|
|
|
| async def _handle_single_relationship_extraction( |
| record_attributes: list[str], |
| chunk_key: str, |
| timestamp: int, |
| file_path: str = "unknown_source", |
| ): |
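| """Parse one delimited record into a relationship dict (src, tgt, weight, keywords, description, source metadata), or return None if the record is malformed.""" |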
| if ( |
| len(record_attributes) != 5 or "relation" not in record_attributes[0] |
| ): |
| if len(record_attributes) > 1 and "relation" in record_attributes[0]: |
| logger.warning( |
| f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`" |
| ) |
| logger.debug(record_attributes) |
| return None |
|
|
| try: |
| source = sanitize_and_normalize_extracted_text( |
| record_attributes[1], remove_inner_quotes=True |
| ) |
| target = sanitize_and_normalize_extracted_text( |
| record_attributes[2], remove_inner_quotes=True |
| ) |
|
|
| |
| if not source: |
| logger.warning( |
| f"Relationship extraction error: source entity became empty after cleaning. Original: '{record_attributes[1]}'" |
| ) |
| return None |
|
|
| if not target: |
| logger.warning( |
| f"Relationship extraction error: target entity became empty after cleaning. Original: '{record_attributes[2]}'" |
| ) |
| return None |
|
|
| if source == target: |
| logger.debug( |
| f"Relationship source and target are the same in: {record_attributes}" |
| ) |
| return None |
|
|
| |
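| # Clean relationship keywords and normalize full-width commas to ASCII commas |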
| edge_keywords = sanitize_and_normalize_extracted_text( |
| record_attributes[3], remove_inner_quotes=True |
| ) |
| edge_keywords = edge_keywords.replace("，", ",")  # normalize full-width commas to ASCII |
|
|
| |
| edge_description = sanitize_and_normalize_extracted_text(record_attributes[4]) |
|
|
| edge_source_id = chunk_key |
| weight = ( |
| float(record_attributes[-1].strip('"').strip("'")) |
| if is_float_regex(record_attributes[-1].strip('"').strip("'")) |
| else 1.0 |
| ) |
|
|
| return dict( |
| src_id=source, |
| tgt_id=target, |
| weight=weight, |
| description=edge_description, |
| keywords=edge_keywords, |
| source_id=edge_source_id, |
| file_path=file_path, |
| timestamp=timestamp, |
| ) |
|
|
| except ValueError as e: |
| logger.warning( |
| f"Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}" |
| ) |
| return None |
| except Exception as e: |
| logger.warning( |
| f"Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}" |
| ) |
| return None |
|
|
|
|
| async def _rebuild_knowledge_from_chunks( |
| entities_to_rebuild: dict[str, set[str]], |
| relationships_to_rebuild: dict[tuple[str, str], set[str]], |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| relationships_vdb: BaseVectorStorage, |
| text_chunks_storage: BaseKVStorage, |
| llm_response_cache: BaseKVStorage, |
| global_config: dict[str, str], |
| pipeline_status: dict | None = None, |
| pipeline_status_lock=None, |
| ) -> None: |
| """Rebuild entity and relationship descriptions from cached extraction results with parallel processing |
| |
| This method uses cached LLM extraction results instead of calling LLM again, |
| following the same approach as the insert process. Now with parallel processing |
| controlled by llm_model_max_async and using get_storage_keyed_lock for data consistency. |
| |
| Args: |
| entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids |
| relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids |
| knowledge_graph_inst: Knowledge graph storage |
| entities_vdb: Entity vector database |
| relationships_vdb: Relationship vector database |
| text_chunks_storage: Text chunks storage |
| llm_response_cache: LLM response cache |
| global_config: Global configuration containing llm_model_max_async |
| pipeline_status: Pipeline status dictionary |
| pipeline_status_lock: Lock for pipeline status |
| """ |
| if not entities_to_rebuild and not relationships_to_rebuild: |
| return |
|
|
| |
| all_referenced_chunk_ids = set() |
| for chunk_ids in entities_to_rebuild.values(): |
| all_referenced_chunk_ids.update(chunk_ids) |
| for chunk_ids in relationships_to_rebuild.values(): |
| all_referenced_chunk_ids.update(chunk_ids) |
|
|
| status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)" |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
|
|
| |
| |
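| # Collect cached LLM extraction results for every chunk still referenced by the affected entities and relations |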
| cached_results = await _get_cached_extraction_results( |
| llm_response_cache, |
| all_referenced_chunk_ids, |
| text_chunks_storage=text_chunks_storage, |
| ) |
|
|
| if not cached_results: |
| status_message = "No cached extraction results found, cannot rebuild" |
| logger.warning(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| return |
|
|
| |
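| # Group extraction results by chunk; when the same entity or relation appears more than once, keep the version with the longer description |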
| chunk_entities = {} |
| chunk_relationships = {} |
|
|
| for chunk_id, results in cached_results.items(): |
| try: |
| |
| chunk_entities[chunk_id] = defaultdict(list) |
| chunk_relationships[chunk_id] = defaultdict(list) |
|
|
| |
| for result in results: |
| entities, relationships = await _rebuild_from_extraction_result( |
| text_chunks_storage=text_chunks_storage, |
| chunk_id=chunk_id, |
| extraction_result=result[0], |
| timestamp=result[1], |
| ) |
|
|
| |
| |
| for entity_name, entity_list in entities.items(): |
| if entity_name not in chunk_entities[chunk_id]: |
| |
| chunk_entities[chunk_id][entity_name].extend(entity_list) |
| elif len(chunk_entities[chunk_id][entity_name]) == 0: |
| |
| chunk_entities[chunk_id][entity_name].extend(entity_list) |
| else: |
| |
| existing_desc_len = len( |
| chunk_entities[chunk_id][entity_name][0].get( |
| "description", "" |
| ) |
| or "" |
| ) |
| new_desc_len = len(entity_list[0].get("description", "") or "") |
|
|
| if new_desc_len > existing_desc_len: |
| |
| chunk_entities[chunk_id][entity_name] = list(entity_list) |
| |
|
|
| |
| for rel_key, rel_list in relationships.items(): |
| if rel_key not in chunk_relationships[chunk_id]: |
| |
| chunk_relationships[chunk_id][rel_key].extend(rel_list) |
| elif len(chunk_relationships[chunk_id][rel_key]) == 0: |
| |
| chunk_relationships[chunk_id][rel_key].extend(rel_list) |
| else: |
| |
| existing_desc_len = len( |
| chunk_relationships[chunk_id][rel_key][0].get( |
| "description", "" |
| ) |
| or "" |
| ) |
| new_desc_len = len(rel_list[0].get("description", "") or "") |
|
|
| if new_desc_len > existing_desc_len: |
| |
| chunk_relationships[chunk_id][rel_key] = list(rel_list) |
| |
|
|
| except Exception as e: |
| status_message = ( |
| f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" |
| ) |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| continue |
|
|
| |
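| # Bound rebuild concurrency at twice the configured LLM async limit |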
| graph_max_async = global_config.get("llm_model_max_async", 4) * 2 |
| semaphore = asyncio.Semaphore(graph_max_async) |
|
|
| |
| rebuilt_entities_count = 0 |
| rebuilt_relationships_count = 0 |
| failed_entities_count = 0 |
| failed_relationships_count = 0 |
|
|
| async def _locked_rebuild_entity(entity_name, chunk_ids): |
| nonlocal rebuilt_entities_count, failed_entities_count |
| async with semaphore: |
| workspace = global_config.get("workspace", "") |
| namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" |
| async with get_storage_keyed_lock( |
| [entity_name], namespace=namespace, enable_logging=False |
| ): |
| try: |
| await _rebuild_single_entity( |
| knowledge_graph_inst=knowledge_graph_inst, |
| entities_vdb=entities_vdb, |
| entity_name=entity_name, |
| chunk_ids=chunk_ids, |
| chunk_entities=chunk_entities, |
| llm_response_cache=llm_response_cache, |
| global_config=global_config, |
| ) |
| rebuilt_entities_count += 1 |
| status_message = ( |
| f"Rebuilt `{entity_name}` from {len(chunk_ids)} chunks" |
| ) |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| except Exception as e: |
| failed_entities_count += 1 |
| status_message = f"Failed to rebuild `{entity_name}`: {e}" |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
|
|
| async def _locked_rebuild_relationship(src, tgt, chunk_ids): |
| nonlocal rebuilt_relationships_count, failed_relationships_count |
| async with semaphore: |
| workspace = global_config.get("workspace", "") |
| namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" |
| |
| sorted_key_parts = sorted([src, tgt]) |
| async with get_storage_keyed_lock( |
| sorted_key_parts, |
| namespace=namespace, |
| enable_logging=False, |
| ): |
| try: |
| await _rebuild_single_relationship( |
| knowledge_graph_inst=knowledge_graph_inst, |
| relationships_vdb=relationships_vdb, |
| src=src, |
| tgt=tgt, |
| chunk_ids=chunk_ids, |
| chunk_relationships=chunk_relationships, |
| llm_response_cache=llm_response_cache, |
| global_config=global_config, |
| ) |
| rebuilt_relationships_count += 1 |
| status_message = ( |
| f"Rebuilt `{src} - {tgt}` from {len(chunk_ids)} chunks" |
| ) |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| except Exception as e: |
| failed_relationships_count += 1 |
| status_message = f"Failed to rebuild `{src} - {tgt}`: {e}" |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
|
|
| |
| tasks = [] |
|
|
| |
| for entity_name, chunk_ids in entities_to_rebuild.items(): |
| task = asyncio.create_task(_locked_rebuild_entity(entity_name, chunk_ids)) |
| tasks.append(task) |
|
|
| |
| for (src, tgt), chunk_ids in relationships_to_rebuild.items(): |
| task = asyncio.create_task(_locked_rebuild_relationship(src, tgt, chunk_ids)) |
| tasks.append(task) |
|
|
| |
| status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})" |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
|
|
| |
| done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) |
|
|
| |
| first_exception = None |
|
|
| for task in done: |
| try: |
| exception = task.exception() |
| if exception is not None: |
| if first_exception is None: |
| first_exception = exception |
| else: |
| |
| task.result() |
| except Exception as e: |
| if first_exception is None: |
| first_exception = e |
|
|
| |
| if first_exception is not None: |
| |
| for pending_task in pending: |
| pending_task.cancel() |
|
|
| |
| if pending: |
| await asyncio.wait(pending) |
|
|
| |
| raise first_exception |
|
|
| |
| status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." |
| if failed_entities_count > 0 or failed_relationships_count > 0: |
| status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships." |
|
|
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
|
|
|
|
| async def _get_cached_extraction_results( |
| llm_response_cache: BaseKVStorage, |
| chunk_ids: set[str], |
| text_chunks_storage: BaseKVStorage, |
| ) -> dict[str, list[tuple[str, int]]]: |
| """Get cached extraction results for specific chunk IDs |
| |
| This function retrieves cached LLM extraction results for the given chunk IDs and returns |
| them sorted by creation time. The results are sorted at two levels: |
| 1. Individual extraction results within each chunk are sorted by create_time (earliest first) |
| 2. Chunks themselves are sorted by the create_time of their earliest extraction result |
| |
| Args: |
| llm_response_cache: LLM response cache storage |
| chunk_ids: Set of chunk IDs to get cached results for |
| text_chunks_storage: Text chunks storage for retrieving chunk data and LLM cache references |
| |
| Returns: |
| Dict mapping chunk_id -> list of (extraction_result_text, create_time) tuples, where: |
| - Keys (chunk_ids) are ordered by the create_time of their first extraction result |
| - Values (extraction results) are ordered by create_time within each chunk |
| """ |
| cached_results = {} |
|
|
| |
| all_cache_ids = set() |
|
|
| |
| chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) |
| for chunk_data in chunk_data_list: |
| if chunk_data and isinstance(chunk_data, dict): |
| llm_cache_list = chunk_data.get("llm_cache_list", []) |
| if llm_cache_list: |
| all_cache_ids.update(llm_cache_list) |
| else: |
| logger.warning(f"Chunk data is invalid or None: {chunk_data}") |
|
|
| if not all_cache_ids: |
| logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") |
| return cached_results |
|
|
| |
| cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) |
|
|
| |
| valid_entries = 0 |
| for cache_entry in cache_data_list: |
| if ( |
| cache_entry is not None |
| and isinstance(cache_entry, dict) |
| and cache_entry.get("cache_type") == "extract" |
| and cache_entry.get("chunk_id") in chunk_ids |
| ): |
| chunk_id = cache_entry["chunk_id"] |
| extraction_result = cache_entry["return"] |
| create_time = cache_entry.get( |
| "create_time", 0 |
| ) |
| valid_entries += 1 |
|
|
| |
| if chunk_id not in cached_results: |
| cached_results[chunk_id] = [] |
| |
| cached_results[chunk_id].append((extraction_result, create_time)) |
|
|
| |
| chunk_earliest_times = {} |
| for chunk_id in cached_results: |
| |
| cached_results[chunk_id].sort(key=lambda x: x[1]) |
| |
| chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1] |
|
|
| |
| sorted_chunk_ids = sorted( |
| chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id] |
| ) |
|
|
| |
| sorted_cached_results = {} |
| for chunk_id in sorted_chunk_ids: |
| sorted_cached_results[chunk_id] = cached_results[chunk_id] |
|
|
| logger.info( |
| f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results" |
| ) |
| return sorted_cached_results |
|
|
|
|
| async def _process_extraction_result( |
| result: str, |
| chunk_key: str, |
| timestamp: int, |
| file_path: str = "unknown_source", |
| tuple_delimiter: str = "<|#|>", |
| completion_delimiter: str = "<|COMPLETE|>", |
| ) -> tuple[dict, dict]: |
| """Process a single extraction result (either initial or gleaning) |
| Args: |
| result (str): The extraction result to process |
| chunk_key (str): The chunk key for source tracking |
| timestamp (int): Extraction timestamp attached to each extracted entity/relationship |
| file_path (str): The file path for citation |
| tuple_delimiter (str): Delimiter for tuple fields |
| completion_delimiter (str): Delimiter marking the end of the extraction output |
| Returns: |
| tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships |
| """ |
| maybe_nodes = defaultdict(list) |
| maybe_edges = defaultdict(list) |
|
|
| if completion_delimiter not in result: |
| logger.warning( |
| f"{chunk_key}: Complete delimiter can not be found in extraction result" |
| ) |
|
|
| |
| records = split_string_by_multi_markers( |
| result, |
| ["\n", completion_delimiter, completion_delimiter.lower()], |
| ) |
|
|
| |
| fixed_records = [] |
| for record in records: |
| record = record.strip() |
| if not record: |
| continue |
| entity_records = split_string_by_multi_markers( |
| record, [f"{tuple_delimiter}entity{tuple_delimiter}"] |
| ) |
| for entity_record in entity_records: |
| if not entity_record.startswith("entity") and not entity_record.startswith( |
| "relation" |
| ): |
| entity_record = f"entity<|{entity_record}" |
| entity_relation_records = split_string_by_multi_markers( |
| |
| entity_record, |
| [ |
| f"{tuple_delimiter}relationship{tuple_delimiter}", |
| f"{tuple_delimiter}relation{tuple_delimiter}", |
| ], |
| ) |
| for entity_relation_record in entity_relation_records: |
| if not entity_relation_record.startswith( |
| "entity" |
| ) and not entity_relation_record.startswith("relation"): |
| entity_relation_record = ( |
| f"relation{tuple_delimiter}{entity_relation_record}" |
| ) |
| fixed_records.append(entity_relation_record) |
|
|
| if len(fixed_records) != len(records): |
| logger.warning( |
| f"{chunk_key}: LLM output format error; find LLM use {tuple_delimiter} as record seperators instead new-line" |
| ) |
|
|
| for record in fixed_records: |
| record = record.strip() |
| if not record: |
| continue |
|
|
| |
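| # Repair records where the LLM emitted a corrupted form of the tuple delimiter core |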
| delimiter_core = tuple_delimiter[2:-2] |
| record = fix_tuple_delimiter_corruption(record, delimiter_core, tuple_delimiter) |
| if delimiter_core != delimiter_core.lower(): |
| |
| delimiter_core = delimiter_core.lower() |
| record = fix_tuple_delimiter_corruption( |
| record, delimiter_core, tuple_delimiter |
| ) |
|
|
| record_attributes = split_string_by_multi_markers(record, [tuple_delimiter]) |
|
|
| |
| entity_data = await _handle_single_entity_extraction( |
| record_attributes, chunk_key, timestamp, file_path |
| ) |
| if entity_data is not None: |
| maybe_nodes[entity_data["entity_name"]].append(entity_data) |
| continue |
|
|
| |
| relationship_data = await _handle_single_relationship_extraction( |
| record_attributes, chunk_key, timestamp, file_path |
| ) |
| if relationship_data is not None: |
| maybe_edges[ |
| (relationship_data["src_id"], relationship_data["tgt_id"]) |
| ].append(relationship_data) |
|
|
| return dict(maybe_nodes), dict(maybe_edges) |
|
|
|
|
| async def _rebuild_from_extraction_result( |
| text_chunks_storage: BaseKVStorage, |
| extraction_result: str, |
| chunk_id: str, |
| timestamp: int, |
| ) -> tuple[dict, dict]: |
| """Parse cached extraction result using the same logic as extract_entities |
| |
| Args: |
| text_chunks_storage: Text chunks storage to get chunk data |
| extraction_result: The cached LLM extraction result |
| chunk_id: The chunk ID for source tracking |
| timestamp: Extraction timestamp propagated to rebuilt entities and relationships |
| |
| Returns: |
| Tuple of (entities_dict, relationships_dict) |
| """ |
|
|
| |
| chunk_data = await text_chunks_storage.get_by_id(chunk_id) |
| file_path = ( |
| chunk_data.get("file_path", "unknown_source") |
| if chunk_data |
| else "unknown_source" |
| ) |
|
|
| |
| return await _process_extraction_result( |
| extraction_result, |
| chunk_id, |
| timestamp, |
| file_path, |
| tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], |
| completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], |
| ) |
|
|
|
|
| async def _rebuild_single_entity( |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| entity_name: str, |
| chunk_ids: set[str], |
| chunk_entities: dict, |
| llm_response_cache: BaseKVStorage, |
| global_config: dict[str, str], |
| ) -> None: |
| """Rebuild a single entity from cached extraction results""" |
|
|
| |
| current_entity = await knowledge_graph_inst.get_node(entity_name) |
| if not current_entity: |
| return |
|
|
| |
| async def _update_entity_storage( |
| final_description: str, entity_type: str, file_paths: set[str] |
| ): |
| try: |
| |
| updated_entity_data = { |
| **current_entity, |
| "description": final_description, |
| "entity_type": entity_type, |
| "source_id": GRAPH_FIELD_SEP.join(chunk_ids), |
| "file_path": GRAPH_FIELD_SEP.join(file_paths) |
| if file_paths |
| else current_entity.get("file_path", "unknown_source"), |
| } |
| await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data) |
|
|
| |
| entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-") |
| entity_content = f"{entity_name}\n{final_description}" |
|
|
| vdb_data = { |
| entity_vdb_id: { |
| "content": entity_content, |
| "entity_name": entity_name, |
| "source_id": updated_entity_data["source_id"], |
| "description": final_description, |
| "entity_type": entity_type, |
| "file_path": updated_entity_data["file_path"], |
| } |
| } |
|
|
| |
| await safe_vdb_operation_with_exception( |
| operation=lambda: entities_vdb.upsert(vdb_data), |
| operation_name="rebuild_entity_upsert", |
| entity_name=entity_name, |
| max_retries=3, |
| retry_delay=0.1, |
| ) |
|
|
| except Exception as e: |
| error_msg = f"Failed to update entity storage for `{entity_name}`: {e}" |
| logger.error(error_msg) |
| raise |
|
|
| |
| all_entity_data = [] |
| for chunk_id in chunk_ids: |
| if chunk_id in chunk_entities and entity_name in chunk_entities[chunk_id]: |
| all_entity_data.extend(chunk_entities[chunk_id][entity_name]) |
|
|
| if not all_entity_data: |
| logger.warning( |
| f"No entity data found for `{entity_name}`, trying to rebuild from relationships" |
| ) |
|
|
| |
| edges = await knowledge_graph_inst.get_node_edges(entity_name) |
| if not edges: |
| logger.warning(f"No relations attached to entity `{entity_name}`") |
| return |
|
|
| |
| relationship_descriptions = [] |
| file_paths = set() |
|
|
| |
| for src_id, tgt_id in edges: |
| edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id) |
| if edge_data: |
| if edge_data.get("description"): |
| relationship_descriptions.append(edge_data["description"]) |
|
|
| if edge_data.get("file_path"): |
| edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) |
| file_paths.update(edge_file_paths) |
|
|
| |
| description_list = list(dict.fromkeys(relationship_descriptions)) |
|
|
| |
| if description_list: |
| final_description, _ = await _handle_entity_relation_summary( |
| "Entity", |
| entity_name, |
| description_list, |
| GRAPH_FIELD_SEP, |
| global_config, |
| llm_response_cache=llm_response_cache, |
| ) |
| else: |
| final_description = current_entity.get("description", "") |
|
|
| entity_type = current_entity.get("entity_type", "UNKNOWN") |
| await _update_entity_storage(final_description, entity_type, file_paths) |
| return |
|
|
| |
| descriptions = [] |
| entity_types = [] |
| file_paths = set() |
|
|
| for entity_data in all_entity_data: |
| if entity_data.get("description"): |
| descriptions.append(entity_data["description"]) |
| if entity_data.get("entity_type"): |
| entity_types.append(entity_data["entity_type"]) |
| if entity_data.get("file_path"): |
| file_paths.add(entity_data["file_path"]) |
|
|
| |
| description_list = list(dict.fromkeys(descriptions)) |
| entity_types = list(dict.fromkeys(entity_types)) |
|
|
| |
| entity_type = ( |
| max(set(entity_types), key=entity_types.count) |
| if entity_types |
| else current_entity.get("entity_type", "UNKNOWN") |
| ) |
|
|
| |
| if description_list: |
| final_description, _ = await _handle_entity_relation_summary( |
| "Entity", |
| entity_name, |
| description_list, |
| GRAPH_FIELD_SEP, |
| global_config, |
| llm_response_cache=llm_response_cache, |
| ) |
| else: |
| final_description = current_entity.get("description", "") |
|
|
| await _update_entity_storage(final_description, entity_type, file_paths) |
|
|
|
|
| async def _rebuild_single_relationship( |
| knowledge_graph_inst: BaseGraphStorage, |
| relationships_vdb: BaseVectorStorage, |
| src: str, |
| tgt: str, |
| chunk_ids: set[str], |
| chunk_relationships: dict, |
| llm_response_cache: BaseKVStorage, |
| global_config: dict[str, str], |
| ) -> None: |
| """Rebuild a single relationship from cached extraction results |
| |
| Note: This function assumes the caller has already acquired the appropriate |
| keyed lock for the relationship pair to ensure thread safety. |
| """ |
|
|
| |
| current_relationship = await knowledge_graph_inst.get_edge(src, tgt) |
| if not current_relationship: |
| return |
|
|
| |
| all_relationship_data = [] |
| for chunk_id in chunk_ids: |
| if chunk_id in chunk_relationships: |
| |
| for edge_key in [(src, tgt), (tgt, src)]: |
| if edge_key in chunk_relationships[chunk_id]: |
| all_relationship_data.extend( |
| chunk_relationships[chunk_id][edge_key] |
| ) |
|
|
| if not all_relationship_data: |
| logger.warning(f"No relation data found for `{src}-{tgt}`") |
| return |
|
|
| |
| descriptions = [] |
| keywords = [] |
| weights = [] |
| file_paths = set() |
|
|
| for rel_data in all_relationship_data: |
| if rel_data.get("description"): |
| descriptions.append(rel_data["description"]) |
| if rel_data.get("keywords"): |
| keywords.append(rel_data["keywords"]) |
| if rel_data.get("weight"): |
| weights.append(rel_data["weight"]) |
| if rel_data.get("file_path"): |
| file_paths.add(rel_data["file_path"]) |
|
|
| |
| description_list = list(dict.fromkeys(descriptions)) |
| keywords = list(dict.fromkeys(keywords)) |
|
|
| combined_keywords = ( |
| ", ".join(set(keywords)) |
| if keywords |
| else current_relationship.get("keywords", "") |
| ) |
|
|
| weight = sum(weights) if weights else current_relationship.get("weight", 1.0) |
|
|
| |
| if description_list: |
| final_description, _ = await _handle_entity_relation_summary( |
| "Relation", |
| f"{src}-{tgt}", |
| description_list, |
| GRAPH_FIELD_SEP, |
| global_config, |
| llm_response_cache=llm_response_cache, |
| ) |
| else: |
| |
| final_description = current_relationship.get("description", "") |
|
|
| |
| updated_relationship_data = { |
| **current_relationship, |
| "description": final_description |
| if final_description |
| else current_relationship.get("description", ""), |
| "keywords": combined_keywords, |
| "weight": weight, |
| "source_id": GRAPH_FIELD_SEP.join(chunk_ids), |
| "file_path": GRAPH_FIELD_SEP.join([fp for fp in file_paths if fp]) |
| if file_paths |
| else current_relationship.get("file_path", "unknown_source"), |
| } |
| await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data) |
|
|
| |
| try: |
| rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-") |
| rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-") |
|
|
| |
| try: |
| await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) |
| except Exception as e: |
| logger.debug( |
| f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" |
| ) |
|
|
| |
| rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}" |
| vdb_data = { |
| rel_vdb_id: { |
| "src_id": src, |
| "tgt_id": tgt, |
| "source_id": updated_relationship_data["source_id"], |
| "content": rel_content, |
| "keywords": combined_keywords, |
| "description": final_description, |
| "weight": weight, |
| "file_path": updated_relationship_data["file_path"], |
| } |
| } |
|
|
| |
| await safe_vdb_operation_with_exception( |
| operation=lambda: relationships_vdb.upsert(vdb_data), |
| operation_name="rebuild_relationship_upsert", |
| entity_name=f"{src}-{tgt}", |
| max_retries=3, |
| retry_delay=0.2, |
| ) |
|
|
| except Exception as e: |
| error_msg = f"Failed to rebuild relationship storage for `{src}-{tgt}`: {e}" |
| logger.error(error_msg) |
| raise |
|
|
|
|
| async def _merge_nodes_then_upsert( |
| entity_name: str, |
| nodes_data: list[dict], |
| knowledge_graph_inst: BaseGraphStorage, |
| global_config: dict, |
| pipeline_status: dict = None, |
| pipeline_status_lock=None, |
| llm_response_cache: BaseKVStorage | None = None, |
| ): |
| """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" |
| already_entity_types = [] |
| already_source_ids = [] |
| already_description = [] |
| already_file_paths = [] |
|
|
| already_node = await knowledge_graph_inst.get_node(entity_name) |
| if already_node: |
| already_entity_types.append(already_node["entity_type"]) |
| already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP)) |
| already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP)) |
| already_description.extend(already_node["description"].split(GRAPH_FIELD_SEP)) |
|
|
| entity_type = sorted( |
| Counter( |
| [dp["entity_type"] for dp in nodes_data] + already_entity_types |
| ).items(), |
| key=lambda x: x[1], |
| reverse=True, |
| )[0][0] |
|
|
| |
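| # Deduplicate incoming node fragments by description |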
| unique_nodes = {} |
| for dp in nodes_data: |
| desc = dp["description"] |
| if desc not in unique_nodes: |
| unique_nodes[desc] = dp |
|
|
| |
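| # Sort fragments by timestamp (oldest first), breaking ties by preferring the longer description |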
| sorted_nodes = sorted( |
| unique_nodes.values(), |
| key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), |
| ) |
| sorted_descriptions = [dp["description"] for dp in sorted_nodes] |
|
|
| |
| description_list = already_description + sorted_descriptions |
|
|
| num_fragment = len(description_list) |
| already_fragment = len(already_description) |
| deduplicated_num = already_fragment + len(nodes_data) - num_fragment |
| if deduplicated_num > 0: |
| dd_message = f"(dd:{deduplicated_num})" |
| else: |
| dd_message = "" |
| if num_fragment > 0: |
| |
| description, llm_was_used = await _handle_entity_relation_summary( |
| "Entity", |
| entity_name, |
| description_list, |
| GRAPH_FIELD_SEP, |
| global_config, |
| llm_response_cache, |
| ) |
|
|
| |
| if llm_was_used: |
| status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}" |
| else: |
| status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}" |
|
|
| if already_fragment > 0 or llm_was_used: |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| else: |
| logger.debug(status_message) |
|
|
| else: |
| logger.error(f"Entity {entity_name} has no description") |
| description = "(no description)" |
|
|
| source_id = GRAPH_FIELD_SEP.join( |
| set([dp["source_id"] for dp in nodes_data] + already_source_ids) |
| ) |
| file_path = build_file_path(already_file_paths, nodes_data, entity_name) |
|
|
| node_data = dict( |
| entity_id=entity_name, |
| entity_type=entity_type, |
| description=description, |
| source_id=source_id, |
| file_path=file_path, |
| created_at=int(time.time()), |
| ) |
| await knowledge_graph_inst.upsert_node( |
| entity_name, |
| node_data=node_data, |
| ) |
| node_data["entity_name"] = entity_name |
| return node_data |
|
|
|
|
| async def _merge_edges_then_upsert( |
| src_id: str, |
| tgt_id: str, |
| edges_data: list[dict], |
| knowledge_graph_inst: BaseGraphStorage, |
| global_config: dict, |
| pipeline_status: dict = None, |
| pipeline_status_lock=None, |
| llm_response_cache: BaseKVStorage | None = None, |
| added_entities: list = None, |
| ): |
| if src_id == tgt_id: |
| return None |
|
|
| already_weights = [] |
| already_source_ids = [] |
| already_description = [] |
| already_keywords = [] |
| already_file_paths = [] |
|
|
| if await knowledge_graph_inst.has_edge(src_id, tgt_id): |
| already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) |
| |
| if already_edge: |
| |
| already_weights.append(already_edge.get("weight", 1.0)) |
|
|
| |
| if already_edge.get("source_id") is not None: |
| already_source_ids.extend( |
| already_edge["source_id"].split(GRAPH_FIELD_SEP) |
| ) |
|
|
| |
| if already_edge.get("file_path") is not None: |
| already_file_paths.extend( |
| already_edge["file_path"].split(GRAPH_FIELD_SEP) |
| ) |
|
|
| |
| if already_edge.get("description") is not None: |
| already_description.extend( |
| already_edge["description"].split(GRAPH_FIELD_SEP) |
| ) |
|
|
| |
| if already_edge.get("keywords") is not None: |
| already_keywords.extend( |
| split_string_by_multi_markers( |
| already_edge["keywords"], [GRAPH_FIELD_SEP] |
| ) |
| ) |
|
|
| |
| weight = sum([dp["weight"] for dp in edges_data] + already_weights) |
|
|
| |
| unique_edges = {} |
| for dp in edges_data: |
| if dp.get("description"): |
| desc = dp["description"] |
| if desc not in unique_edges: |
| unique_edges[desc] = dp |
|
|
| |
| sorted_edges = sorted( |
| unique_edges.values(), |
| key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), |
| ) |
| sorted_descriptions = [dp["description"] for dp in sorted_edges] |
|
|
| |
| description_list = already_description + sorted_descriptions |
|
|
| num_fragment = len(description_list) |
| already_fragment = len(already_description) |
| deduplicated_num = already_fragment + len(edges_data) - num_fragment |
| if deduplicated_num > 0: |
| dd_message = f"(dd:{deduplicated_num})" |
| else: |
| dd_message = "" |
| if num_fragment > 0: |
| |
| description, llm_was_used = await _handle_entity_relation_summary( |
| "Relation", |
| f"({src_id}, {tgt_id})", |
| description_list, |
| GRAPH_FIELD_SEP, |
| global_config, |
| llm_response_cache, |
| ) |
|
|
| |
| if llm_was_used: |
| status_message = f"LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}" |
| else: |
| status_message = f"Merged: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}" |
|
|
| if already_fragment > 0 or llm_was_used: |
| logger.info(status_message) |
| if pipeline_status is not None and pipeline_status_lock is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = status_message |
| pipeline_status["history_messages"].append(status_message) |
| else: |
| logger.debug(status_message) |
|
|
| else: |
| logger.error(f"Edge {src_id} - {tgt_id} has no description") |
| description = "(no description)" |
|
|
| |
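| # Merge keywords from the existing edge and all new fragments into a deduplicated, sorted list |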
| all_keywords = set() |
| |
| for keyword_str in already_keywords: |
| if keyword_str: |
| all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip()) |
| |
| for edge in edges_data: |
| if edge.get("keywords"): |
| all_keywords.update( |
| k.strip() for k in edge["keywords"].split(",") if k.strip() |
| ) |
| |
| keywords = ",".join(sorted(all_keywords)) |
|
|
| source_id = GRAPH_FIELD_SEP.join( |
| set( |
| [dp["source_id"] for dp in edges_data if dp.get("source_id")] |
| + already_source_ids |
| ) |
| ) |
| file_path = build_file_path(already_file_paths, edges_data, f"{src_id}-{tgt_id}") |
|
|
| for need_insert_id in [src_id, tgt_id]: |
| if not (await knowledge_graph_inst.has_node(need_insert_id)): |
| node_data = { |
| "entity_id": need_insert_id, |
| "source_id": source_id, |
| "description": description, |
| "entity_type": "UNKNOWN", |
| "file_path": file_path, |
| "created_at": int(time.time()), |
| } |
| await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) |
|
|
| |
| if added_entities is not None: |
| entity_data = { |
| "entity_name": need_insert_id, |
| "entity_type": "UNKNOWN", |
| "description": description, |
| "source_id": source_id, |
| "file_path": file_path, |
| "created_at": int(time.time()), |
| } |
| added_entities.append(entity_data) |
|
|
| await knowledge_graph_inst.upsert_edge( |
| src_id, |
| tgt_id, |
| edge_data=dict( |
| weight=weight, |
| description=description, |
| keywords=keywords, |
| source_id=source_id, |
| file_path=file_path, |
| created_at=int(time.time()), |
| ), |
| ) |
|
|
| edge_data = dict( |
| src_id=src_id, |
| tgt_id=tgt_id, |
| description=description, |
| keywords=keywords, |
| source_id=source_id, |
| file_path=file_path, |
| created_at=int(time.time()), |
| ) |
|
|
| return edge_data |
|
|
|
|
| async def merge_nodes_and_edges( |
| chunk_results: list, |
| knowledge_graph_inst: BaseGraphStorage, |
| entity_vdb: BaseVectorStorage, |
| relationships_vdb: BaseVectorStorage, |
| global_config: dict[str, str], |
| full_entities_storage: BaseKVStorage = None, |
| full_relations_storage: BaseKVStorage = None, |
| doc_id: str = None, |
| pipeline_status: dict = None, |
| pipeline_status_lock=None, |
| llm_response_cache: BaseKVStorage | None = None, |
| current_file_number: int = 0, |
| total_files: int = 0, |
| file_path: str = "unknown_source", |
| ) -> None: |
| """Two-phase merge: process all entities first, then all relationships |
| |
| This approach ensures data consistency by: |
| 1. Phase 1: Process all entities concurrently |
| 2. Phase 2: Process all relationships concurrently (may add missing entities) |
| 3. Phase 3: Update full_entities and full_relations storage with final results |
| |
| Args: |
| chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships |
| knowledge_graph_inst: Knowledge graph storage |
| entity_vdb: Entity vector database |
| relationships_vdb: Relationship vector database |
| global_config: Global configuration |
| full_entities_storage: Storage for document entity lists |
| full_relations_storage: Storage for document relation lists |
| doc_id: Document ID for storage indexing |
| pipeline_status: Pipeline status dictionary |
| pipeline_status_lock: Lock for pipeline status |
| llm_response_cache: LLM response cache |
| current_file_number: Current file number for logging |
| total_files: Total files for logging |
| file_path: File path for logging |
| """ |
|
|
| |
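| # Collect all entity and relationship fragments across chunk results, normalizing edge keys to undirected (sorted) pairs |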
| all_nodes = defaultdict(list) |
| all_edges = defaultdict(list) |
|
|
| for maybe_nodes, maybe_edges in chunk_results: |
| |
| for entity_name, entities in maybe_nodes.items(): |
| all_nodes[entity_name].extend(entities) |
|
|
| |
| for edge_key, edges in maybe_edges.items(): |
| sorted_edge_key = tuple(sorted(edge_key)) |
| all_edges[sorted_edge_key].extend(edges) |
|
|
| total_entities_count = len(all_nodes) |
| total_relations_count = len(all_edges) |
|
|
| log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}" |
| logger.info(log_message) |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
| |
| graph_max_async = global_config.get("llm_model_max_async", 4) * 2 |
| semaphore = asyncio.Semaphore(graph_max_async) |
|
|
| |
| log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})" |
| logger.info(log_message) |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
| async def _locked_process_entity_name(entity_name, entities): |
| async with semaphore: |
| workspace = global_config.get("workspace", "") |
| namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" |
| async with get_storage_keyed_lock( |
| [entity_name], namespace=namespace, enable_logging=False |
| ): |
| try: |
| |
| entity_data = await _merge_nodes_then_upsert( |
| entity_name, |
| entities, |
| knowledge_graph_inst, |
| global_config, |
| pipeline_status, |
| pipeline_status_lock, |
| llm_response_cache, |
| ) |
|
|
| |
| if entity_vdb is not None and entity_data: |
| data_for_vdb = { |
| compute_mdhash_id( |
| entity_data["entity_name"], prefix="ent-" |
| ): { |
| "entity_name": entity_data["entity_name"], |
| "entity_type": entity_data["entity_type"], |
| "content": f"{entity_data['entity_name']}\n{entity_data['description']}", |
| "source_id": entity_data["source_id"], |
| "file_path": entity_data.get( |
| "file_path", "unknown_source" |
| ), |
| } |
| } |
|
|
| |
| await safe_vdb_operation_with_exception( |
| operation=lambda: entity_vdb.upsert(data_for_vdb), |
| operation_name="entity_upsert", |
| entity_name=entity_name, |
| max_retries=3, |
| retry_delay=0.1, |
| ) |
|
|
| return entity_data |
|
|
| except Exception as e: |
| |
| error_msg = ( |
| f"Critical error in entity processing for `{entity_name}`: {e}" |
| ) |
| logger.error(error_msg) |
|
|
| |
| try: |
| if ( |
| pipeline_status is not None |
| and pipeline_status_lock is not None |
| ): |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = error_msg |
| pipeline_status["history_messages"].append(error_msg) |
| except Exception as status_error: |
| logger.error( |
| f"Failed to update pipeline status: {status_error}" |
| ) |
|
|
| |
| prefixed_exception = create_prefixed_exception( |
| e, f"`{entity_name}`" |
| ) |
| raise prefixed_exception from e |
|
|
| |
| entity_tasks = [] |
| for entity_name, entities in all_nodes.items(): |
| task = asyncio.create_task(_locked_process_entity_name(entity_name, entities)) |
| entity_tasks.append(task) |
|
|
| |
| processed_entities = [] |
| if entity_tasks: |
| done, pending = await asyncio.wait( |
| entity_tasks, return_when=asyncio.FIRST_EXCEPTION |
| ) |
|
|
| |
| first_exception = None |
| successful_results = [] |
|
|
| for task in done: |
| try: |
| exception = task.exception() |
| if exception is not None: |
| if first_exception is None: |
| first_exception = exception |
| else: |
| successful_results.append(task.result()) |
| except Exception as e: |
| if first_exception is None: |
| first_exception = e |
|
|
| |
| if first_exception is not None: |
| |
| for pending_task in pending: |
| pending_task.cancel() |
| |
| if pending: |
| await asyncio.wait(pending) |
| |
| raise first_exception |
|
|
| |
| processed_entities = [task.result() for task in entity_tasks] |
|
|
| |
| log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" |
| logger.info(log_message) |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
| async def _locked_process_edges(edge_key, edges): |
| async with semaphore: |
| workspace = global_config.get("workspace", "") |
| namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" |
| sorted_edge_key = sorted([edge_key[0], edge_key[1]]) |
|
|
| async with get_storage_keyed_lock( |
| sorted_edge_key, |
| namespace=namespace, |
| enable_logging=False, |
| ): |
| try: |
| added_entities = [] |
|
|
| |
| edge_data = await _merge_edges_then_upsert( |
| edge_key[0], |
| edge_key[1], |
| edges, |
| knowledge_graph_inst, |
| global_config, |
| pipeline_status, |
| pipeline_status_lock, |
| llm_response_cache, |
| added_entities, |
| ) |
|
|
| if edge_data is None: |
| return None, [] |
|
|
| |
| if relationships_vdb is not None: |
| data_for_vdb = { |
| compute_mdhash_id( |
| edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-" |
| ): { |
| "src_id": edge_data["src_id"], |
| "tgt_id": edge_data["tgt_id"], |
| "keywords": edge_data["keywords"], |
| "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}", |
| "source_id": edge_data["source_id"], |
| "file_path": edge_data.get( |
| "file_path", "unknown_source" |
| ), |
| "weight": edge_data.get("weight", 1.0), |
| } |
| } |
|
|
| |
| await safe_vdb_operation_with_exception( |
| operation=lambda: relationships_vdb.upsert(data_for_vdb), |
| operation_name="relationship_upsert", |
| entity_name=f"{edge_data['src_id']}-{edge_data['tgt_id']}", |
| max_retries=3, |
| retry_delay=0.1, |
| ) |
|
|
| |
| if added_entities and entity_vdb is not None: |
| for entity_data in added_entities: |
| entity_vdb_id = compute_mdhash_id( |
| entity_data["entity_name"], prefix="ent-" |
| ) |
| entity_content = f"{entity_data['entity_name']}\n{entity_data['description']}" |
|
|
| vdb_data = { |
| entity_vdb_id: { |
| "content": entity_content, |
| "entity_name": entity_data["entity_name"], |
| "source_id": entity_data["source_id"], |
| "entity_type": entity_data["entity_type"], |
| "file_path": entity_data.get( |
| "file_path", "unknown_source" |
| ), |
| } |
| } |
|
|
| |
| await safe_vdb_operation_with_exception( |
| operation=lambda data=vdb_data: entity_vdb.upsert(data), |
| operation_name="added_entity_upsert", |
| entity_name=entity_data["entity_name"], |
| max_retries=3, |
| retry_delay=0.1, |
| ) |
|
|
| return edge_data, added_entities |
|
|
| except Exception as e: |
| |
| error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}" |
| logger.error(error_msg) |
|
|
| |
| try: |
| if ( |
| pipeline_status is not None |
| and pipeline_status_lock is not None |
| ): |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = error_msg |
| pipeline_status["history_messages"].append(error_msg) |
| except Exception as status_error: |
| logger.error( |
| f"Failed to update pipeline status: {status_error}" |
| ) |
|
|
| |
| prefixed_exception = create_prefixed_exception( |
| e, f"{sorted_edge_key}" |
| ) |
| raise prefixed_exception from e |
|
|
| |
| edge_tasks = [] |
| for edge_key, edges in all_edges.items(): |
| task = asyncio.create_task(_locked_process_edges(edge_key, edges)) |
| edge_tasks.append(task) |
|
|
| |
| processed_edges = [] |
| all_added_entities = [] |
|
|
| if edge_tasks: |
| done, pending = await asyncio.wait( |
| edge_tasks, return_when=asyncio.FIRST_EXCEPTION |
| ) |
|
|
| |
| first_exception = None |
| successful_results = [] |
|
|
| for task in done: |
| try: |
| exception = task.exception() |
| if exception is not None: |
| if first_exception is None: |
| first_exception = exception |
| else: |
| successful_results.append(task.result()) |
| except Exception as e: |
| if first_exception is None: |
| first_exception = e |
|
|
| |
| if first_exception is not None: |
| |
| for pending_task in pending: |
| pending_task.cancel() |
| |
| if pending: |
| await asyncio.wait(pending) |
| |
| raise first_exception |
|
|
| |
| for task in edge_tasks: |
| edge_data, added_entities = task.result() |
| if edge_data is not None: |
| processed_edges.append(edge_data) |
| all_added_entities.extend(added_entities) |
|
|
| |
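| # Phase 3: persist a per-document index of the final entity names and relation
| # pairs, keyed by doc_id.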
| if full_entities_storage and full_relations_storage and doc_id: |
| try: |
| |
| final_entity_names = set() |
|
|
| |
| for entity_data in processed_entities: |
| if entity_data and entity_data.get("entity_name"): |
| final_entity_names.add(entity_data["entity_name"]) |
|
|
| |
| for added_entity in all_added_entities: |
| if added_entity and added_entity.get("entity_name"): |
| final_entity_names.add(added_entity["entity_name"]) |
|
|
| |
| final_relation_pairs = set() |
| for edge_data in processed_edges: |
| if edge_data: |
| src_id = edge_data.get("src_id") |
| tgt_id = edge_data.get("tgt_id") |
| if src_id and tgt_id: |
| relation_pair = tuple(sorted([src_id, tgt_id])) |
| final_relation_pairs.add(relation_pair) |
|
|
| log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}" |
| logger.info(log_message) |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
| |
| if final_entity_names: |
| await full_entities_storage.upsert( |
| { |
| doc_id: { |
| "entity_names": list(final_entity_names), |
| "count": len(final_entity_names), |
| } |
| } |
| ) |
|
|
| if final_relation_pairs: |
| await full_relations_storage.upsert( |
| { |
| doc_id: { |
| "relation_pairs": [ |
| list(pair) for pair in final_relation_pairs |
| ], |
| "count": len(final_relation_pairs), |
| } |
| } |
| ) |
|
|
| logger.debug( |
| f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations" |
| ) |
|
|
| except Exception as e: |
| logger.error( |
| f"Failed to update entity-relation index for document {doc_id}: {e}" |
| ) |
| |
|
|
| log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations" |
| logger.info(log_message) |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
|
|
| async def extract_entities( |
| chunks: dict[str, TextChunkSchema], |
| global_config: dict[str, str], |
| pipeline_status: dict = None, |
| pipeline_status_lock=None, |
| llm_response_cache: BaseKVStorage | None = None, |
| text_chunks_storage: BaseKVStorage | None = None, |
| ) -> list: |
| use_llm_func: callable = global_config["llm_model_func"] |
| entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] |
|
|
| ordered_chunks = list(chunks.items()) |
| |
| language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) |
| entity_types = global_config["addon_params"].get( |
| "entity_types", DEFAULT_ENTITY_TYPES |
| ) |
|
|
| examples = "\n".join(PROMPTS["entity_extraction_examples"]) |
|
|
| example_context_base = dict( |
| tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], |
| completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], |
| entity_types=", ".join(entity_types), |
| language=language, |
| ) |
| |
| examples = examples.format(**example_context_base) |
|
|
| context_base = dict( |
| tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], |
| completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], |
| entity_types=",".join(entity_types), |
| examples=examples, |
| language=language, |
| ) |
|
|
| processed_chunks = 0 |
| total_chunks = len(ordered_chunks) |
|
|
| async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): |
| """Process a single chunk |
| Args: |
| chunk_key_dp (tuple[str, TextChunkSchema]): |
| ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) |
| Returns: |
| tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships |
| """ |
| nonlocal processed_chunks |
| chunk_key = chunk_key_dp[0] |
| chunk_dp = chunk_key_dp[1] |
| content = chunk_dp["content"] |
| |
| file_path = chunk_dp.get("file_path", "unknown_source") |
|
|
| |
| cache_keys_collector = [] |
|
|
| |
| entity_extraction_system_prompt = PROMPTS[ |
| "entity_extraction_system_prompt" |
| ].format(**{**context_base, "input_text": content}) |
| entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format( |
| **{**context_base, "input_text": content} |
| ) |
| entity_continue_extraction_user_prompt = PROMPTS[ |
| "entity_continue_extraction_user_prompt" |
| ].format(**{**context_base, "input_text": content}) |
|
|
| final_result, timestamp = await use_llm_func_with_cache( |
| entity_extraction_user_prompt, |
| use_llm_func, |
| system_prompt=entity_extraction_system_prompt, |
| llm_response_cache=llm_response_cache, |
| cache_type="extract", |
| chunk_id=chunk_key, |
| cache_keys_collector=cache_keys_collector, |
| ) |
|
|
| history = pack_user_ass_to_openai_messages( |
| entity_extraction_user_prompt, final_result |
| ) |
|
|
| |
| maybe_nodes, maybe_edges = await _process_extraction_result( |
| final_result, |
| chunk_key, |
| timestamp, |
| file_path, |
| tuple_delimiter=context_base["tuple_delimiter"], |
| completion_delimiter=context_base["completion_delimiter"], |
| ) |
|
|
| |
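| # Optional gleaning pass: re-prompt the LLM (with the first answer as history)
| # to pick up entities and relations missed on the first extraction.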
| if entity_extract_max_gleaning > 0: |
| glean_result, timestamp = await use_llm_func_with_cache( |
| entity_continue_extraction_user_prompt, |
| use_llm_func, |
| system_prompt=entity_extraction_system_prompt, |
| llm_response_cache=llm_response_cache, |
| history_messages=history, |
| cache_type="extract", |
| chunk_id=chunk_key, |
| cache_keys_collector=cache_keys_collector, |
| ) |
|
|
| |
| glean_nodes, glean_edges = await _process_extraction_result( |
| glean_result, |
| chunk_key, |
| timestamp, |
| file_path, |
| tuple_delimiter=context_base["tuple_delimiter"], |
| completion_delimiter=context_base["completion_delimiter"], |
| ) |
|
|
| |
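| # Merge gleaning output with the first pass: for items seen in both, keep the
| # record with the longer description; items seen only in the gleaning pass are added.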
| for entity_name, glean_entities in glean_nodes.items(): |
| if entity_name in maybe_nodes: |
| |
| original_desc_len = len( |
| maybe_nodes[entity_name][0].get("description", "") or "" |
| ) |
| glean_desc_len = len(glean_entities[0].get("description", "") or "") |
|
|
| if glean_desc_len > original_desc_len: |
| maybe_nodes[entity_name] = list(glean_entities) |
| |
| else: |
| |
| maybe_nodes[entity_name] = list(glean_entities) |
|
|
| for edge_key, glean_edge_list in glean_edges.items():
| if edge_key in maybe_edges:
| 
| original_desc_len = len(
| maybe_edges[edge_key][0].get("description", "") or ""
| )
| glean_desc_len = len(glean_edge_list[0].get("description", "") or "")
|
|
| if glean_desc_len > original_desc_len:
| maybe_edges[edge_key] = list(glean_edge_list)
| 
| else:
| 
| maybe_edges[edge_key] = list(glean_edge_list)
|
|
| |
| if cache_keys_collector and text_chunks_storage: |
| await update_chunk_cache_list( |
| chunk_key, |
| text_chunks_storage, |
| cache_keys_collector, |
| "entity_extraction", |
| ) |
|
|
| processed_chunks += 1 |
| entities_count = len(maybe_nodes) |
| relations_count = len(maybe_edges) |
| log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}" |
| logger.info(log_message) |
| if pipeline_status is not None: |
| async with pipeline_status_lock: |
| pipeline_status["latest_message"] = log_message |
| pipeline_status["history_messages"].append(log_message) |
|
|
| |
| return maybe_nodes, maybe_edges |
|
|
| |
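| # Bound chunk-level extraction concurrency by the LLM client's async limit.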
| chunk_max_async = global_config.get("llm_model_max_async", 4) |
| semaphore = asyncio.Semaphore(chunk_max_async) |
|
|
| async def _process_with_semaphore(chunk): |
| async with semaphore: |
| try: |
| return await _process_single_content(chunk) |
| except Exception as e: |
| chunk_id = chunk[0] |
| prefixed_exception = create_prefixed_exception(e, chunk_id) |
| raise prefixed_exception from e |
|
|
| tasks = [] |
| for c in ordered_chunks: |
| task = asyncio.create_task(_process_with_semaphore(c)) |
| tasks.append(task) |
|
|
| |
| |
| done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) |
|
|
| |
| first_exception = None |
| chunk_results = [] |
|
|
| for task in done: |
| try: |
| exception = task.exception() |
| if exception is not None: |
| if first_exception is None: |
| first_exception = exception |
| else: |
| chunk_results.append(task.result()) |
| except Exception as e: |
| if first_exception is None: |
| first_exception = e |
|
|
| |
| if first_exception is not None: |
| |
| for pending_task in pending: |
| pending_task.cancel() |
|
|
| |
| if pending: |
| await asyncio.wait(pending) |
|
|
| |
| progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]" |
|
|
| |
| prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) |
| raise prefixed_exception from first_exception |
|
|
| |
| |
| return chunk_results |
|
|
|
|
| async def kg_query( |
| query: str, |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| relationships_vdb: BaseVectorStorage, |
| text_chunks_db: BaseKVStorage, |
| query_param: QueryParam, |
| global_config: dict[str, str], |
| hashing_kv: BaseKVStorage | None = None, |
| system_prompt: str | None = None, |
| chunks_vdb: BaseVectorStorage = None, |
| ) -> QueryResult: |
| logger.info("in kg_query") |
| """ |
| Execute knowledge graph query and return unified QueryResult object. |
| |
| Args: |
| query: Query string |
| knowledge_graph_inst: Knowledge graph storage instance |
| entities_vdb: Entity vector database |
| relationships_vdb: Relationship vector database |
| text_chunks_db: Text chunks storage |
| query_param: Query parameters |
| global_config: Global configuration |
| hashing_kv: Cache storage |
| system_prompt: System prompt |
| chunks_vdb: Document chunks vector database |
| |
| Returns: |
| QueryResult: Unified query result object containing: |
| - content: Non-streaming response text content |
| - response_iterator: Streaming response iterator |
| - raw_data: Complete structured data (including references and metadata) |
| - is_streaming: Whether this is a streaming result |
| |
| Based on different query_param settings, different fields will be populated: |
| - only_need_context=True: content contains context string |
| - only_need_prompt=True: content contains complete prompt |
| - stream=True: response_iterator contains streaming response, raw_data contains complete data |
| - default: content contains LLM response text, raw_data contains complete data |
| """ |
| if not query: |
| return QueryResult(content=PROMPTS["fail_response"]) |
|
|
| if query_param.model_func: |
| use_model_func = query_param.model_func |
| else: |
| use_model_func = global_config["llm_model_func"] |
| logger.info(f"use_model_func: {use_model_func}") |
| logger.info(f"dir(use_model_func): {dir(use_model_func)}") |
| if not isinstance(use_model_func, partial): |
| |
| use_model_func = partial(use_model_func, _priority=5) |
|
|
| hl_keywords, ll_keywords = await get_keywords_from_query( |
| query, query_param, global_config, hashing_kv |
| ) |
|
|
| logger.debug(f"High-level keywords: {hl_keywords}") |
| logger.debug(f"Low-level keywords: {ll_keywords}") |
|
|
| |
| if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]: |
| logger.warning("low_level_keywords is empty") |
| if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]: |
| logger.warning("high_level_keywords is empty") |
| if hl_keywords == [] and ll_keywords == []: |
| if len(query) < 50: |
| logger.warning(f"Forced low_level_keywords to origin query: {query}") |
| ll_keywords = [query] |
| else: |
| return QueryResult(content=PROMPTS["fail_response"]) |
|
|
| ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" |
| hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" |
|
|
| |
| context_result = await _build_query_context( |
| query, |
| ll_keywords_str, |
| hl_keywords_str, |
| knowledge_graph_inst, |
| entities_vdb, |
| relationships_vdb, |
| text_chunks_db, |
| query_param, |
| chunks_vdb, |
| ) |
|
|
| if context_result is None: |
| return QueryResult(content=PROMPTS["fail_response"]) |
|
|
| |
| if query_param.only_need_context and not query_param.only_need_prompt: |
| return QueryResult( |
| content=context_result.context, raw_data=context_result.raw_data |
| ) |
|
|
| user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" |
| response_type = ( |
| query_param.response_type |
| if query_param.response_type |
| else "Multiple Paragraphs" |
| ) |
|
|
| |
| sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"] |
| sys_prompt = sys_prompt_temp.format( |
| response_type=response_type, |
| user_prompt=user_prompt, |
| context_data=context_result.context, |
| ) |
|
|
| user_query = query |
|
|
| if query_param.only_need_prompt: |
| prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) |
| return QueryResult(content=prompt_content, raw_data=context_result.raw_data) |
|
|
| |
| tokenizer: Tokenizer = global_config["tokenizer"] |
| len_of_prompts = len(tokenizer.encode(query + sys_prompt)) |
| logger.debug( |
| f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})" |
| ) |
|
|
| |
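| # Query-level LLM cache: the key hashes the query together with every
| # parameter that can change the generated answer.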
| args_hash = compute_args_hash( |
| query_param.mode, |
| query, |
| query_param.response_type, |
| query_param.top_k, |
| query_param.chunk_top_k, |
| query_param.max_entity_tokens, |
| query_param.max_relation_tokens, |
| query_param.max_total_tokens, |
| hl_keywords_str, |
| ll_keywords_str, |
| query_param.user_prompt or "", |
| query_param.enable_rerank, |
| ) |
|
|
| cached_result = await handle_cache( |
| hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" |
| ) |
|
|
| if cached_result is not None: |
| cached_response, _ = cached_result |
| logger.info( |
| " == LLM cache == Query cache hit, using cached response as query result" |
| ) |
| response = cached_response |
| else: |
| response = await use_model_func( |
| user_query, |
| system_prompt=sys_prompt, |
| history_messages=query_param.conversation_history, |
| enable_cot=True, |
| stream=query_param.stream, |
| ) |
|
|
| if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): |
| queryparam_dict = { |
| "mode": query_param.mode, |
| "response_type": query_param.response_type, |
| "top_k": query_param.top_k, |
| "chunk_top_k": query_param.chunk_top_k, |
| "max_entity_tokens": query_param.max_entity_tokens, |
| "max_relation_tokens": query_param.max_relation_tokens, |
| "max_total_tokens": query_param.max_total_tokens, |
| "hl_keywords": hl_keywords_str, |
| "ll_keywords": ll_keywords_str, |
| "user_prompt": query_param.user_prompt or "", |
| "enable_rerank": query_param.enable_rerank, |
| } |
| await save_to_cache( |
| hashing_kv, |
| CacheData( |
| args_hash=args_hash, |
| content=response, |
| prompt=query, |
| mode=query_param.mode, |
| cache_type="query", |
| queryparam=queryparam_dict, |
| ), |
| ) |
|
|
| |
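| # Non-streaming responses are plain strings; strip any echoed prompt or markup before returning.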
| if isinstance(response, str): |
| |
| if len(response) > len(sys_prompt): |
| response = ( |
| response.replace(sys_prompt, "") |
| .replace("user", "") |
| .replace("model", "") |
| .replace(query, "") |
| .replace("<system>", "") |
| .replace("</system>", "") |
| .strip() |
| ) |
|
|
| return QueryResult(content=response, raw_data=context_result.raw_data) |
| else: |
| |
| return QueryResult( |
| response_iterator=response, |
| raw_data=context_result.raw_data, |
| is_streaming=True, |
| ) |
|
|
|
|
| async def get_keywords_from_query( |
| query: str, |
| query_param: QueryParam, |
| global_config: dict[str, str], |
| hashing_kv: BaseKVStorage | None = None, |
| ) -> tuple[list[str], list[str]]: |
| """ |
| Retrieves high-level and low-level keywords for RAG operations. |
| |
| This function checks if keywords are already provided in query parameters, |
| and if not, extracts them from the query text using LLM. |
| |
| Args: |
| query: The user's query text |
| query_param: Query parameters that may contain pre-defined keywords |
| global_config: Global configuration dictionary |
| hashing_kv: Optional key-value storage for caching results |
| |
| Returns: |
| A tuple containing (high_level_keywords, low_level_keywords) |
| """ |
| |
| if query_param.hl_keywords or query_param.ll_keywords: |
| return query_param.hl_keywords, query_param.ll_keywords |
|
|
| |
| hl_keywords, ll_keywords = await extract_keywords_only( |
| query, query_param, global_config, hashing_kv |
| ) |
| return hl_keywords, ll_keywords |
|
|
|
|
| async def extract_keywords_only( |
| text: str, |
| param: QueryParam, |
| global_config: dict[str, str], |
| hashing_kv: BaseKVStorage | None = None, |
| ) -> tuple[list[str], list[str]]: |
| """ |
| Extract high-level and low-level keywords from the given 'text' using the LLM. |
| This method does NOT build the final RAG context or provide a final answer. |
| It ONLY extracts keywords (hl_keywords, ll_keywords). |
| """ |
|
|
| |
| args_hash = compute_args_hash( |
| param.mode, |
| text, |
| ) |
| cached_result = await handle_cache( |
| hashing_kv, args_hash, text, param.mode, cache_type="keywords" |
| ) |
| if cached_result is not None: |
| cached_response, _ = cached_result |
| try: |
| keywords_data = json_repair.loads(cached_response) |
| return keywords_data.get("high_level_keywords", []), keywords_data.get( |
| "low_level_keywords", [] |
| ) |
| except (json.JSONDecodeError, KeyError): |
| logger.warning( |
| "Invalid cache format for keywords, proceeding with extraction" |
| ) |
|
|
| |
| examples = "\n".join(PROMPTS["keywords_extraction_examples"]) |
|
|
| language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) |
|
|
| |
| kw_prompt = PROMPTS["keywords_extraction"].format( |
| query=text, |
| examples=examples, |
| language=language, |
| ) |
|
|
| tokenizer: Tokenizer = global_config["tokenizer"] |
| len_of_prompts = len(tokenizer.encode(kw_prompt)) |
| logger.debug( |
| f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})" |
| ) |
|
|
| |
| if param.model_func: |
| use_model_func = param.model_func |
| else: |
| use_model_func = global_config["llm_model_func"] |
| |
| use_model_func = partial(use_model_func, _priority=5) |
|
|
| result = await use_model_func(kw_prompt, keyword_extraction=True) |
|
|
| |
| result = remove_think_tags(result) |
| try: |
| keywords_data = json_repair.loads(result) |
| if not keywords_data: |
| logger.error("No JSON-like structure found in the LLM respond.") |
| return [], [] |
| except json.JSONDecodeError as e: |
| logger.error(f"JSON parsing error: {e}") |
| logger.error(f"LLM respond: {result}") |
| return [], [] |
|
|
| hl_keywords = keywords_data.get("high_level_keywords", []) |
| ll_keywords = keywords_data.get("low_level_keywords", []) |
|
|
| |
| if hl_keywords or ll_keywords: |
| cache_data = { |
| "high_level_keywords": hl_keywords, |
| "low_level_keywords": ll_keywords, |
| } |
| if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
| |
| queryparam_dict = { |
| "mode": param.mode, |
| "response_type": param.response_type, |
| "top_k": param.top_k, |
| "chunk_top_k": param.chunk_top_k, |
| "max_entity_tokens": param.max_entity_tokens, |
| "max_relation_tokens": param.max_relation_tokens, |
| "max_total_tokens": param.max_total_tokens, |
| "user_prompt": param.user_prompt or "", |
| "enable_rerank": param.enable_rerank, |
| } |
| await save_to_cache( |
| hashing_kv, |
| CacheData( |
| args_hash=args_hash, |
| content=json.dumps(cache_data), |
| prompt=text, |
| mode=param.mode, |
| cache_type="keywords", |
| queryparam=queryparam_dict, |
| ), |
| ) |
|
|
| return hl_keywords, ll_keywords |
|
|
|
|
| async def _get_vector_context( |
| query: str, |
| chunks_vdb: BaseVectorStorage, |
| query_param: QueryParam, |
| query_embedding: list[float] = None, |
| ) -> list[dict]: |
| """ |
| Retrieve text chunks from the vector database without reranking or truncation. |
| |
| This function performs vector search to find relevant text chunks for a query. |
| Reranking and truncation will be handled later in the unified processing. |
| |
| Args: |
| query: The query string to search for |
| chunks_vdb: Vector database containing document chunks |
| query_param: Query parameters including chunk_top_k and ids |
| query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls |
| |
| Returns: |
| List of text chunks with metadata |
| """ |
| try: |
| |
| search_top_k = query_param.chunk_top_k or query_param.top_k |
| cosine_threshold = chunks_vdb.cosine_better_than_threshold |
|
|
| results = await chunks_vdb.query( |
| query, top_k=search_top_k, query_embedding=query_embedding |
| ) |
| if not results: |
| logger.info( |
| f"Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" |
| ) |
| return [] |
|
|
| valid_chunks = [] |
| for result in results: |
| if "content" in result: |
| chunk_with_metadata = { |
| "content": result["content"], |
| "created_at": result.get("created_at", None), |
| "file_path": result.get("file_path", "unknown_source"), |
| "source_type": "vector", |
| "chunk_id": result.get("id"), |
| } |
| valid_chunks.append(chunk_with_metadata) |
|
|
| logger.info( |
| f"Naive query: {len(valid_chunks)} chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" |
| ) |
| return valid_chunks |
|
|
| except Exception as e: |
| logger.error(f"Error in _get_vector_context: {e}") |
| return [] |
|
|
|
|
| async def _perform_kg_search( |
| query: str, |
| ll_keywords: str, |
| hl_keywords: str, |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| relationships_vdb: BaseVectorStorage, |
| text_chunks_db: BaseKVStorage, |
| query_param: QueryParam, |
| chunks_vdb: BaseVectorStorage = None, |
| ) -> dict[str, Any]: |
| """ |
| Pure search logic that retrieves raw entities, relations, and vector chunks. |
| No token truncation or formatting - just raw search results. |
| """ |
|
|
| |
| local_entities = [] |
| local_relations = [] |
| global_entities = [] |
| global_relations = [] |
| vector_chunks = [] |
| chunk_tracking = {} |
|
|
| |
|
|
| |
|
|
| |
| kg_chunk_pick_method = text_chunks_db.global_config.get( |
| "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD |
| ) |
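| # Pre-compute the query embedding once so VECTOR-based chunk picking and the
| # naive chunk search can reuse it instead of re-embedding the query.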
| query_embedding = None |
| if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): |
| embedding_func_config = text_chunks_db.embedding_func |
| if embedding_func_config: |
| try: |
| query_embedding = await embedding_func_config([query]) |
| query_embedding = query_embedding[0]
| logger.debug("Pre-computed query embedding for all vector operations") |
| except Exception as e: |
| logger.warning(f"Failed to pre-compute query embedding: {e}") |
| query_embedding = None |
|
|
| |
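| # Mode dispatch: "local" searches entities via low-level keywords, "global"
| # searches relations via high-level keywords; other modes (hybrid/mix) run both.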
| if query_param.mode == "local" and len(ll_keywords) > 0: |
| local_entities, local_relations = await _get_node_data( |
| ll_keywords, |
| knowledge_graph_inst, |
| entities_vdb, |
| query_param, |
| ) |
|
|
| elif query_param.mode == "global" and len(hl_keywords) > 0: |
| global_relations, global_entities = await _get_edge_data( |
| hl_keywords, |
| knowledge_graph_inst, |
| relationships_vdb, |
| query_param, |
| ) |
|
|
| else: |
| if len(ll_keywords) > 0: |
| local_entities, local_relations = await _get_node_data( |
| ll_keywords, |
| knowledge_graph_inst, |
| entities_vdb, |
| query_param, |
| ) |
| if len(hl_keywords) > 0: |
| global_relations, global_entities = await _get_edge_data( |
| hl_keywords, |
| knowledge_graph_inst, |
| relationships_vdb, |
| query_param, |
| ) |
|
|
| |
| if query_param.mode == "mix" and chunks_vdb: |
| vector_chunks = await _get_vector_context( |
| query, |
| chunks_vdb, |
| query_param, |
| query_embedding, |
| ) |
| |
| for i, chunk in enumerate(vector_chunks): |
| chunk_id = chunk.get("chunk_id") or chunk.get("id") |
| if chunk_id: |
| chunk_tracking[chunk_id] = { |
| "source": "C", |
| "frequency": 1, |
| "order": i + 1, |
| } |
| else: |
| logger.warning(f"Vector chunk missing chunk_id: {chunk}") |
|
|
| |
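| # Interleave local and global results round-robin, de-duplicating by entity
| # name so neither keyword set dominates the final ordering.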
| final_entities = [] |
| seen_entities = set() |
| max_len = max(len(local_entities), len(global_entities)) |
| for i in range(max_len): |
| |
| if i < len(local_entities): |
| entity = local_entities[i] |
| entity_name = entity.get("entity_name") |
| if entity_name and entity_name not in seen_entities: |
| final_entities.append(entity) |
| seen_entities.add(entity_name) |
|
|
| |
| if i < len(global_entities): |
| entity = global_entities[i] |
| entity_name = entity.get("entity_name") |
| if entity_name and entity_name not in seen_entities: |
| final_entities.append(entity) |
| seen_entities.add(entity_name) |
|
|
| |
| final_relations = [] |
| seen_relations = set() |
| max_len = max(len(local_relations), len(global_relations)) |
| for i in range(max_len): |
| |
| if i < len(local_relations): |
| relation = local_relations[i] |
| |
| if "src_tgt" in relation: |
| rel_key = tuple(sorted(relation["src_tgt"])) |
| else: |
| rel_key = tuple( |
| sorted([relation.get("src_id"), relation.get("tgt_id")]) |
| ) |
|
|
| if rel_key not in seen_relations: |
| final_relations.append(relation) |
| seen_relations.add(rel_key) |
|
|
| |
| if i < len(global_relations): |
| relation = global_relations[i] |
| |
| if "src_tgt" in relation: |
| rel_key = tuple(sorted(relation["src_tgt"])) |
| else: |
| rel_key = tuple( |
| sorted([relation.get("src_id"), relation.get("tgt_id")]) |
| ) |
|
|
| if rel_key not in seen_relations: |
| final_relations.append(relation) |
| seen_relations.add(rel_key) |
|
|
| logger.info( |
| f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks" |
| ) |
|
|
| return { |
| "final_entities": final_entities, |
| "final_relations": final_relations, |
| "vector_chunks": vector_chunks, |
| "chunk_tracking": chunk_tracking, |
| "query_embedding": query_embedding, |
| } |
|
|
|
|
| async def _apply_token_truncation( |
| search_result: dict[str, Any], |
| query_param: QueryParam, |
| global_config: dict[str, str], |
| ) -> dict[str, Any]: |
| """ |
| Apply token-based truncation to entities and relations for LLM efficiency. |
| """ |
| tokenizer = global_config.get("tokenizer") |
| if not tokenizer: |
| logger.warning("No tokenizer found, skipping truncation") |
| return { |
| "entities_context": [], |
| "relations_context": [], |
| "filtered_entities": search_result["final_entities"], |
| "filtered_relations": search_result["final_relations"], |
| "entity_id_to_original": {}, |
| "relation_id_to_original": {}, |
| } |
|
|
| |
| max_entity_tokens = getattr( |
| query_param, |
| "max_entity_tokens", |
| global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS), |
| ) |
| max_relation_tokens = getattr( |
| query_param, |
| "max_relation_tokens", |
| global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS), |
| ) |
|
|
| final_entities = search_result["final_entities"] |
| final_relations = search_result["final_relations"] |
|
|
| |
| entity_id_to_original = {} |
| relation_id_to_original = {} |
|
|
| |
| entities_context = [] |
| for i, entity in enumerate(final_entities): |
| entity_name = entity["entity_name"] |
| created_at = entity.get("created_at", "UNKNOWN") |
| if isinstance(created_at, (int, float)): |
| created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) |
|
|
| |
| entity_id_to_original[entity_name] = entity |
|
|
| entities_context.append( |
| { |
| "entity": entity_name, |
| "type": entity.get("entity_type", "UNKNOWN"), |
| "description": entity.get("description", "UNKNOWN"), |
| "created_at": created_at, |
| "file_path": entity.get("file_path", "unknown_source"), |
| } |
| ) |
|
|
| |
| relations_context = [] |
| for i, relation in enumerate(final_relations): |
| created_at = relation.get("created_at", "UNKNOWN") |
| if isinstance(created_at, (int, float)): |
| created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) |
|
|
| |
| if "src_tgt" in relation: |
| entity1, entity2 = relation["src_tgt"] |
| else: |
| entity1, entity2 = relation.get("src_id"), relation.get("tgt_id") |
|
|
| |
| relation_key = (entity1, entity2) |
| relation_id_to_original[relation_key] = relation |
|
|
| relations_context.append( |
| { |
| "entity1": entity1, |
| "entity2": entity2, |
| "description": relation.get("description", "UNKNOWN"), |
| "created_at": created_at, |
| "file_path": relation.get("file_path", "unknown_source"), |
| } |
| ) |
|
|
| logger.debug( |
| f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations" |
| ) |
|
|
| |
| if entities_context: |
| |
| entities_context_for_truncation = [] |
| for entity in entities_context: |
| entity_copy = entity.copy() |
| entity_copy.pop("file_path", None) |
| entity_copy.pop("created_at", None) |
| entities_context_for_truncation.append(entity_copy) |
|
|
| entities_context = truncate_list_by_token_size( |
| entities_context_for_truncation, |
| key=lambda x: "\n".join( |
| json.dumps(item, ensure_ascii=False) for item in [x] |
| ), |
| max_token_size=max_entity_tokens, |
| tokenizer=tokenizer, |
| ) |
|
|
| if relations_context: |
| |
| relations_context_for_truncation = [] |
| for relation in relations_context: |
| relation_copy = relation.copy() |
| relation_copy.pop("file_path", None) |
| relation_copy.pop("created_at", None) |
| relations_context_for_truncation.append(relation_copy) |
|
|
| relations_context = truncate_list_by_token_size( |
| relations_context_for_truncation, |
| key=lambda x: "\n".join( |
| json.dumps(item, ensure_ascii=False) for item in [x] |
| ), |
| max_token_size=max_relation_tokens, |
| tokenizer=tokenizer, |
| ) |
|
|
| logger.info( |
| f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations" |
| ) |
|
|
| |
| filtered_entities = [] |
| filtered_entity_id_to_original = {} |
| if entities_context: |
| final_entity_names = {e["entity"] for e in entities_context} |
| seen_nodes = set() |
| for entity in final_entities: |
| name = entity.get("entity_name") |
| if name in final_entity_names and name not in seen_nodes: |
| filtered_entities.append(entity) |
| filtered_entity_id_to_original[name] = entity |
| seen_nodes.add(name) |
|
|
| filtered_relations = [] |
| filtered_relation_id_to_original = {} |
| if relations_context: |
| final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context} |
| seen_edges = set() |
| for relation in final_relations: |
| src, tgt = relation.get("src_id"), relation.get("tgt_id") |
| if src is None or tgt is None: |
| src, tgt = relation.get("src_tgt", (None, None)) |
|
|
| pair = (src, tgt) |
| if pair in final_relation_pairs and pair not in seen_edges: |
| filtered_relations.append(relation) |
| filtered_relation_id_to_original[pair] = relation |
| seen_edges.add(pair) |
|
|
| return { |
| "entities_context": entities_context, |
| "relations_context": relations_context, |
| "filtered_entities": filtered_entities, |
| "filtered_relations": filtered_relations, |
| "entity_id_to_original": filtered_entity_id_to_original, |
| "relation_id_to_original": filtered_relation_id_to_original, |
| } |
|
|
|
|
| async def _merge_all_chunks( |
| filtered_entities: list[dict], |
| filtered_relations: list[dict], |
| vector_chunks: list[dict], |
| query: str = "", |
| knowledge_graph_inst: BaseGraphStorage = None, |
| text_chunks_db: BaseKVStorage = None, |
| query_param: QueryParam = None, |
| chunks_vdb: BaseVectorStorage = None, |
| chunk_tracking: dict = None, |
| query_embedding: list[float] = None, |
| ) -> list[dict]: |
| """ |
| Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks. |
| """ |
| if chunk_tracking is None: |
| chunk_tracking = {} |
|
|
| |
| entity_chunks = [] |
| if filtered_entities and text_chunks_db: |
| entity_chunks = await _find_related_text_unit_from_entities( |
| filtered_entities, |
| query_param, |
| text_chunks_db, |
| knowledge_graph_inst, |
| query, |
| chunks_vdb, |
| chunk_tracking=chunk_tracking, |
| query_embedding=query_embedding, |
| ) |
|
|
| |
| relation_chunks = [] |
| if filtered_relations and text_chunks_db: |
| relation_chunks = await _find_related_text_unit_from_relations( |
| filtered_relations, |
| query_param, |
| text_chunks_db, |
| entity_chunks, |
| query, |
| chunks_vdb, |
| chunk_tracking=chunk_tracking, |
| query_embedding=query_embedding, |
| ) |
|
|
| |
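| # Round-robin merge of vector, entity and relation chunks, de-duplicated by
| # chunk_id, so each source contributes its top-ranked chunks first.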
| merged_chunks = [] |
| seen_chunk_ids = set() |
| max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks)) |
| origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks) |
|
|
| for i in range(max_len): |
| |
| if i < len(vector_chunks): |
| chunk = vector_chunks[i] |
| chunk_id = chunk.get("chunk_id") or chunk.get("id") |
| if chunk_id and chunk_id not in seen_chunk_ids: |
| seen_chunk_ids.add(chunk_id) |
| merged_chunks.append( |
| { |
| "content": chunk["content"], |
| "file_path": chunk.get("file_path", "unknown_source"), |
| "chunk_id": chunk_id, |
| } |
| ) |
|
|
| |
| if i < len(entity_chunks): |
| chunk = entity_chunks[i] |
| chunk_id = chunk.get("chunk_id") or chunk.get("id") |
| if chunk_id and chunk_id not in seen_chunk_ids: |
| seen_chunk_ids.add(chunk_id) |
| merged_chunks.append( |
| { |
| "content": chunk["content"], |
| "file_path": chunk.get("file_path", "unknown_source"), |
| "chunk_id": chunk_id, |
| } |
| ) |
|
|
| |
| if i < len(relation_chunks): |
| chunk = relation_chunks[i] |
| chunk_id = chunk.get("chunk_id") or chunk.get("id") |
| if chunk_id and chunk_id not in seen_chunk_ids: |
| seen_chunk_ids.add(chunk_id) |
| merged_chunks.append( |
| { |
| "content": chunk["content"], |
| "file_path": chunk.get("file_path", "unknown_source"), |
| "chunk_id": chunk_id, |
| } |
| ) |
|
|
| logger.info( |
| f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})" |
| ) |
|
|
| return merged_chunks |
|
|
|
|
| async def _build_llm_context( |
| entities_context: list[dict], |
| relations_context: list[dict], |
| merged_chunks: list[dict], |
| query: str, |
| query_param: QueryParam, |
| global_config: dict[str, str], |
| chunk_tracking: dict = None, |
| entity_id_to_original: dict = None, |
| relation_id_to_original: dict = None, |
| ) -> tuple[str, dict[str, Any]]: |
| """ |
| Build the final LLM context string with token processing. |
| This includes dynamic token calculation and final chunk truncation. |
| """ |
| tokenizer = global_config.get("tokenizer") |
| if not tokenizer: |
| logger.error("Missing tokenizer, cannot build LLM context") |
| |
| empty_raw_data = convert_to_user_format( |
| [], |
| [], |
| [], |
| [], |
| query_param.mode, |
| ) |
| empty_raw_data["status"] = "failure" |
| empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context." |
| return "", empty_raw_data |
|
|
| |
| max_total_tokens = getattr( |
| query_param, |
| "max_total_tokens", |
| global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), |
| ) |
|
|
| |
| sys_prompt_template = global_config.get( |
| "system_prompt_template", PROMPTS["rag_response"] |
| ) |
|
|
| kg_context_template = PROMPTS["kg_query_context"] |
| user_prompt = query_param.user_prompt if query_param.user_prompt else "" |
| response_type = ( |
| query_param.response_type |
| if query_param.response_type |
| else "Multiple Paragraphs" |
| ) |
|
|
| entities_str = "\n".join( |
| json.dumps(entity, ensure_ascii=False) for entity in entities_context |
| ) |
| relations_str = "\n".join( |
| json.dumps(relation, ensure_ascii=False) for relation in relations_context |
| ) |
|
|
| |
| pre_kg_context = kg_context_template.format( |
| entities_str=entities_str, |
| relations_str=relations_str, |
| text_chunks_str="", |
| reference_list_str="", |
| ) |
| kg_context_tokens = len(tokenizer.encode(pre_kg_context)) |
|
|
| |
| pre_sys_prompt = sys_prompt_template.format( |
| context_data="", |
| response_type=response_type, |
| user_prompt=user_prompt, |
| ) |
| sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) |
|
|
| |
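| # Chunk token budget = total budget - (system prompt + KG context + query + safety buffer).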
| query_tokens = len(tokenizer.encode(query)) |
| buffer_tokens = 200 |
| available_chunk_tokens = max_total_tokens - ( |
| sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens |
| ) |
|
|
| logger.debug( |
| f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" |
| ) |
|
|
| |
| truncated_chunks = await process_chunks_unified( |
| query=query, |
| unique_chunks=merged_chunks, |
| query_param=query_param, |
| global_config=global_config, |
| source_type=query_param.mode, |
| chunk_token_limit=available_chunk_tokens, |
| ) |
|
|
| |
| reference_list, truncated_chunks = generate_reference_list_from_chunks( |
| truncated_chunks |
| ) |
|
|
| |
| |
| text_units_context = [] |
| for i, chunk in enumerate(truncated_chunks): |
| text_units_context.append( |
| { |
| "reference_id": chunk["reference_id"], |
| "content": chunk["content"], |
| } |
| ) |
|
|
| logger.info( |
| f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks" |
| ) |
|
|
| |
| if not entities_context and not relations_context: |
| |
| empty_raw_data = convert_to_user_format( |
| [], |
| [], |
| [], |
| [], |
| query_param.mode, |
| ) |
| empty_raw_data["status"] = "failure" |
| empty_raw_data["message"] = "Query returned empty dataset." |
| return "", empty_raw_data |
|
|
| |
| |
| if truncated_chunks and chunk_tracking: |
| chunk_tracking_log = [] |
| for chunk in truncated_chunks: |
| chunk_id = chunk.get("chunk_id") |
| if chunk_id and chunk_id in chunk_tracking: |
| tracking_info = chunk_tracking[chunk_id] |
| source = tracking_info["source"] |
| frequency = tracking_info["frequency"] |
| order = tracking_info["order"] |
| chunk_tracking_log.append(f"{source}{frequency}/{order}") |
| else: |
| chunk_tracking_log.append("?0/0") |
|
|
| if chunk_tracking_log: |
| logger.info(f"chunks S+F/O: {' '.join(chunk_tracking_log)}") |
|
|
| text_units_str = "\n".join( |
| json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context |
| ) |
| reference_list_str = "\n".join( |
| f"[{ref['reference_id']}] {ref['file_path']}" |
| for ref in reference_list |
| if ref["reference_id"] |
| ) |
|
|
| result = kg_context_template.format( |
| entities_str=entities_str, |
| relations_str=relations_str, |
| text_chunks_str=text_units_str, |
| reference_list_str=reference_list_str, |
| ) |
|
|
| |
| logger.debug( |
| f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks" |
| ) |
| final_data = convert_to_user_format( |
| entities_context, |
| relations_context, |
| truncated_chunks, |
| reference_list, |
| query_param.mode, |
| entity_id_to_original, |
| relation_id_to_original, |
| ) |
| logger.debug( |
| f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks" |
| ) |
| return result, final_data |
|
|
|
|
| |
| async def _build_query_context( |
| query: str, |
| ll_keywords: str, |
| hl_keywords: str, |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| relationships_vdb: BaseVectorStorage, |
| text_chunks_db: BaseKVStorage, |
| query_param: QueryParam, |
| chunks_vdb: BaseVectorStorage = None, |
| ) -> QueryContextResult | None: |
| """ |
| Main query context building function using the new 4-stage architecture: |
| 1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context |
| |
| Returns unified QueryContextResult containing both context and raw_data. |
| """ |
|
|
| if not query: |
| logger.warning("Query is empty, skipping context building") |
| return None |
|
|
| |
| search_result = await _perform_kg_search( |
| query, |
| ll_keywords, |
| hl_keywords, |
| knowledge_graph_inst, |
| entities_vdb, |
| relationships_vdb, |
| text_chunks_db, |
| query_param, |
| chunks_vdb, |
| ) |
|
|
| if not search_result["final_entities"] and not search_result["final_relations"]: |
| if query_param.mode != "mix": |
| return None |
| else: |
| if not search_result["chunk_tracking"]: |
| return None |
|
|
| |
| truncation_result = await _apply_token_truncation( |
| search_result, |
| query_param, |
| text_chunks_db.global_config, |
| ) |
|
|
| |
| merged_chunks = await _merge_all_chunks( |
| filtered_entities=truncation_result["filtered_entities"], |
| filtered_relations=truncation_result["filtered_relations"], |
| vector_chunks=search_result["vector_chunks"], |
| query=query, |
| knowledge_graph_inst=knowledge_graph_inst, |
| text_chunks_db=text_chunks_db, |
| query_param=query_param, |
| chunks_vdb=chunks_vdb, |
| chunk_tracking=search_result["chunk_tracking"], |
| query_embedding=search_result["query_embedding"], |
| ) |
|
|
| if not merged_chunks: |
| return None |
|
|
| |
| |
| context, raw_data = await _build_llm_context( |
| entities_context=truncation_result["entities_context"], |
| relations_context=truncation_result["relations_context"], |
| merged_chunks=merged_chunks, |
| query=query, |
| query_param=query_param, |
| global_config=text_chunks_db.global_config, |
| chunk_tracking=search_result["chunk_tracking"], |
| entity_id_to_original=truncation_result["entity_id_to_original"], |
| relation_id_to_original=truncation_result["relation_id_to_original"], |
| ) |
|
|
| |
| hl_keywords_list = hl_keywords.split(", ") if hl_keywords else [] |
| ll_keywords_list = ll_keywords.split(", ") if ll_keywords else [] |
|
|
| |
| if "metadata" not in raw_data: |
| raw_data["metadata"] = {} |
|
|
| |
| raw_data["metadata"]["keywords"] = { |
| "high_level": hl_keywords_list, |
| "low_level": ll_keywords_list, |
| } |
| raw_data["metadata"]["processing_info"] = { |
| "total_entities_found": len(search_result.get("final_entities", [])), |
| "total_relations_found": len(search_result.get("final_relations", [])), |
| "entities_after_truncation": len( |
| truncation_result.get("filtered_entities", []) |
| ), |
| "relations_after_truncation": len( |
| truncation_result.get("filtered_relations", []) |
| ), |
| "merged_chunks_count": len(merged_chunks), |
| "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])), |
| } |
|
|
| logger.debug( |
| f"[_build_query_context] Context length: {len(context) if context else 0}" |
| ) |
| logger.debug( |
| f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}" |
| ) |
|
|
| return QueryContextResult(context=context, raw_data=raw_data) |
|
|
|
|
| async def _get_node_data( |
| query: str, |
| knowledge_graph_inst: BaseGraphStorage, |
| entities_vdb: BaseVectorStorage, |
| query_param: QueryParam, |
| ): |
| |
| logger.info( |
| f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" |
| ) |
|
|
| results = await entities_vdb.query(query, top_k=query_param.top_k) |
|
|
| if not len(results): |
| return [], [] |
|
|
| |
| node_ids = [r["entity_name"] for r in results] |
|
|
| |
| nodes_dict, degrees_dict = await asyncio.gather( |
| knowledge_graph_inst.get_nodes_batch(node_ids), |
| knowledge_graph_inst.node_degrees_batch(node_ids), |
| ) |
|
|
| |
| node_datas = [nodes_dict.get(nid) for nid in node_ids] |
| node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids] |
|
|
| if not all([n is not None for n in node_datas]): |
| logger.warning("Some nodes are missing, maybe the storage is damaged") |
|
|
| node_datas = [ |
| { |
| **n, |
| "entity_name": k["entity_name"], |
| "rank": d, |
| "created_at": k.get("created_at"), |
| } |
| for k, n, d in zip(results, node_datas, node_degrees) |
| if n is not None |
| ] |
|
|
| use_relations = await _find_most_related_edges_from_entities( |
| node_datas, |
| query_param, |
| knowledge_graph_inst, |
| ) |
|
|
| logger.info( |
| f"Local query: {len(node_datas)} entites, {len(use_relations)} relations" |
| ) |
|
|
| |
| |
| return node_datas, use_relations |
|
|
|
|
| async def _find_most_related_edges_from_entities( |
| node_datas: list[dict], |
| query_param: QueryParam, |
| knowledge_graph_inst: BaseGraphStorage, |
| ): |
| node_names = [dp["entity_name"] for dp in node_datas] |
| batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) |
|
|
| all_edges = [] |
| seen = set() |
|
|
| for node_name in node_names: |
| this_edges = batch_edges_dict.get(node_name, []) |
| for e in this_edges: |
| sorted_edge = tuple(sorted(e)) |
| if sorted_edge not in seen: |
| seen.add(sorted_edge) |
| all_edges.append(sorted_edge) |
|
|
| |
| |
| edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] |
| |
| edge_pairs_tuples = list(all_edges) |
|
|
| |
| edge_data_dict, edge_degrees_dict = await asyncio.gather( |
| knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), |
| knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), |
| ) |
|
|
| |
| all_edges_data = [] |
| for pair in all_edges: |
| edge_props = edge_data_dict.get(pair) |
| if edge_props is not None: |
| if "weight" not in edge_props: |
| logger.warning( |
| f"Edge {pair} missing 'weight' attribute, using default value 1.0" |
| ) |
| edge_props["weight"] = 1.0 |
|
|
| combined = { |
| "src_tgt": pair, |
| "rank": edge_degrees_dict.get(pair, 0), |
| **edge_props, |
| } |
| all_edges_data.append(combined) |
|
|
| all_edges_data = sorted( |
| all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True |
| ) |
|
|
| return all_edges_data |
|
|
|
|
| async def _find_related_text_unit_from_entities( |
| node_datas: list[dict], |
| query_param: QueryParam, |
| text_chunks_db: BaseKVStorage, |
| knowledge_graph_inst: BaseGraphStorage, |
| query: str = None, |
| chunks_vdb: BaseVectorStorage = None, |
| chunk_tracking: dict = None, |
| query_embedding=None, |
| ): |
| """ |
| Find text chunks related to entities using configurable chunk selection method. |
| |
| This function supports two chunk selection strategies: |
| 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count |
| 2. VECTOR: Vector similarity-based selection using embedding cosine similarity |
| """ |
| logger.debug(f"Finding text chunks from {len(node_datas)} entities") |
|
|
| if not node_datas: |
| return [] |
|
|
| |
| entities_with_chunks = [] |
| for entity in node_datas: |
| if entity.get("source_id"): |
| chunks = split_string_by_multi_markers( |
| entity["source_id"], [GRAPH_FIELD_SEP] |
| ) |
| if chunks: |
| entities_with_chunks.append( |
| { |
| "entity_name": entity["entity_name"], |
| "chunks": chunks, |
| "entity_data": entity, |
| } |
| ) |
|
|
| if not entities_with_chunks: |
| logger.warning("No entities with text chunks found") |
| return [] |
|
|
| kg_chunk_pick_method = text_chunks_db.global_config.get( |
| "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD |
| ) |
| max_related_chunks = text_chunks_db.global_config.get( |
| "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER |
| ) |
|
|
| |
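| # Count how many selected entities reference each chunk; the count drives both
| # de-duplication and the per-entity sort order used by weighted polling.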
| chunk_occurrence_count = {} |
| for entity_info in entities_with_chunks: |
| deduplicated_chunks = [] |
| for chunk_id in entity_info["chunks"]: |
| chunk_occurrence_count[chunk_id] = ( |
| chunk_occurrence_count.get(chunk_id, 0) + 1 |
| ) |
|
|
| |
| if chunk_occurrence_count[chunk_id] == 1: |
| deduplicated_chunks.append(chunk_id) |
| |
|
|
| |
| entity_info["chunks"] = deduplicated_chunks |
|
|
| |
| total_entity_chunks = 0 |
| for entity_info in entities_with_chunks: |
| sorted_chunks = sorted( |
| entity_info["chunks"], |
| key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), |
| reverse=True, |
| ) |
| entity_info["sorted_chunks"] = sorted_chunks |
| total_entity_chunks += len(sorted_chunks) |
|
|
| selected_chunk_ids = [] |
|
|
| |
| |
| |
| |
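| # VECTOR strategy: rank candidate chunks by embedding similarity to the query;
| # on any failure fall back to the WEIGHT (weighted polling) strategy.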
| if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: |
| num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2) |
|
|
| |
| embedding_func_config = text_chunks_db.embedding_func |
| if not embedding_func_config: |
| logger.warning("No embedding function found, falling back to WEIGHT method") |
| kg_chunk_pick_method = "WEIGHT" |
| else: |
| try: |
| actual_embedding_func = embedding_func_config |
|
|
| selected_chunk_ids = None |
| if actual_embedding_func: |
| selected_chunk_ids = await pick_by_vector_similarity( |
| query=query, |
| text_chunks_storage=text_chunks_db, |
| chunks_vdb=chunks_vdb, |
| num_of_chunks=num_of_chunks, |
| entity_info=entities_with_chunks, |
| embedding_func=actual_embedding_func, |
| query_embedding=query_embedding, |
| ) |
|
|
| if selected_chunk_ids == []: |
| kg_chunk_pick_method = "WEIGHT" |
| logger.warning( |
| "No entity-related chunks selected by vector similarity, falling back to WEIGHT method" |
| ) |
| else: |
| logger.info( |
| f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity" |
| ) |
|
|
| except Exception as e: |
| logger.error( |
| f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" |
| ) |
| kg_chunk_pick_method = "WEIGHT" |
|
|
| if kg_chunk_pick_method == "WEIGHT": |
| |
| |
| selected_chunk_ids = pick_by_weighted_polling( |
| entities_with_chunks, max_related_chunks, min_related_chunks=1 |
| ) |
|
|
| logger.info( |
| f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling" |
| ) |
|
|
| if not selected_chunk_ids: |
| return [] |
|
|
| |
| unique_chunk_ids = list( |
| dict.fromkeys(selected_chunk_ids) |
| ) |
| chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) |
|
|
| |
| result_chunks = [] |
| for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): |
| if chunk_data is not None and "content" in chunk_data: |
| chunk_data_copy = chunk_data.copy() |
| chunk_data_copy["source_type"] = "entity" |
| chunk_data_copy["chunk_id"] = chunk_id |
| result_chunks.append(chunk_data_copy) |
|
|
| |
| if chunk_tracking is not None: |
| chunk_tracking[chunk_id] = { |
| "source": "E", |
| "frequency": chunk_occurrence_count.get(chunk_id, 1), |
| "order": i + 1, |
| } |
|
|
| return result_chunks |
|
|
|
|
| async def _get_edge_data( |
| keywords, |
| knowledge_graph_inst: BaseGraphStorage, |
| relationships_vdb: BaseVectorStorage, |
| query_param: QueryParam, |
| ): |
| logger.info( |
| f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" |
| ) |
|
|
| results = await relationships_vdb.query(keywords, top_k=query_param.top_k) |
|
|
| if not len(results): |
| return [], [] |
|
|
| |
| |
| edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] |
| edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts) |
|
|
| |
| edge_datas = [] |
| for k in results: |
| pair = (k["src_id"], k["tgt_id"]) |
| edge_props = edge_data_dict.get(pair) |
| if edge_props is not None: |
| if "weight" not in edge_props: |
| logger.warning( |
| f"Edge {pair} missing 'weight' attribute, using default value 1.0" |
| ) |
| edge_props["weight"] = 1.0 |
|
|
| |
| combined = { |
| "src_id": k["src_id"], |
| "tgt_id": k["tgt_id"], |
| "created_at": k.get("created_at", None), |
| **edge_props, |
| } |
| edge_datas.append(combined) |
|
|
| |
|
|
| use_entities = await _find_most_related_entities_from_relationships( |
| edge_datas, |
| query_param, |
| knowledge_graph_inst, |
| ) |
|
|
| logger.info( |
| f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations" |
| ) |
|
|
| return edge_datas, use_entities |
|
|
|
|
| async def _find_most_related_entities_from_relationships( |
| edge_datas: list[dict], |
| query_param: QueryParam, |
| knowledge_graph_inst: BaseGraphStorage, |
| ): |
| entity_names = [] |
| seen = set() |
|
|
| for e in edge_datas: |
| if e["src_id"] not in seen: |
| entity_names.append(e["src_id"]) |
| seen.add(e["src_id"]) |
| if e["tgt_id"] not in seen: |
| entity_names.append(e["tgt_id"]) |
| seen.add(e["tgt_id"]) |
|
|
| |
| nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names) |
|
|
| |
| node_datas = [] |
| for entity_name in entity_names: |
| node = nodes_dict.get(entity_name) |
| if node is None: |
| logger.warning(f"Node '{entity_name}' not found in batch retrieval.") |
| continue |
| |
| combined = {**node, "entity_name": entity_name} |
| node_datas.append(combined) |
|
|
| return node_datas |
|
|
|
|
| async def _find_related_text_unit_from_relations( |
| edge_datas: list[dict], |
| query_param: QueryParam, |
| text_chunks_db: BaseKVStorage, |
| entity_chunks: list[dict] = None, |
| query: str = None, |
| chunks_vdb: BaseVectorStorage = None, |
| chunk_tracking: dict = None, |
| query_embedding=None, |
| ): |
| """ |
| Find text chunks related to relationships using configurable chunk selection method. |
| |
| This function supports two chunk selection strategies: |
| 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count |
| 2. VECTOR: Vector similarity-based selection using embedding cosine similarity |
| """ |
| logger.debug(f"Finding text chunks from {len(edge_datas)} relations") |
|
|
| if not edge_datas: |
| return [] |
|
|
| |
| relations_with_chunks = [] |
| for relation in edge_datas: |
| if relation.get("source_id"): |
| chunks = split_string_by_multi_markers( |
| relation["source_id"], [GRAPH_FIELD_SEP] |
| ) |
| if chunks: |
| |
| if "src_tgt" in relation: |
| rel_key = tuple(sorted(relation["src_tgt"])) |
| else: |
| rel_key = tuple( |
| sorted([relation.get("src_id"), relation.get("tgt_id")]) |
| ) |
|
|
| relations_with_chunks.append( |
| { |
| "relation_key": rel_key, |
| "chunks": chunks, |
| "relation_data": relation, |
| } |
| ) |
|
|
| if not relations_with_chunks: |
| logger.warning("No relation-related chunks found") |
| return [] |
|
|
| kg_chunk_pick_method = text_chunks_db.global_config.get( |
| "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD |
| ) |
| max_related_chunks = text_chunks_db.global_config.get( |
| "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER |
| ) |
|
|
| |
| |
|
|
| |
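| # Chunks already selected via entities are excluded here so relation chunks
| # only contribute new evidence.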
| entity_chunk_ids = set() |
| if entity_chunks: |
| for chunk in entity_chunks: |
| chunk_id = chunk.get("chunk_id") |
| if chunk_id: |
| entity_chunk_ids.add(chunk_id) |
|
|
| chunk_occurrence_count = {} |
| |
| removed_entity_chunk_ids = set() |
|
|
| for relation_info in relations_with_chunks: |
| deduplicated_chunks = [] |
| for chunk_id in relation_info["chunks"]: |
| |
| if chunk_id in entity_chunk_ids: |
| |
| removed_entity_chunk_ids.add(chunk_id) |
| continue |
|
|
| chunk_occurrence_count[chunk_id] = ( |
| chunk_occurrence_count.get(chunk_id, 0) + 1 |
| ) |
|
|
| |
| if chunk_occurrence_count[chunk_id] == 1: |
| deduplicated_chunks.append(chunk_id) |
| |
|
|
| |
| relation_info["chunks"] = deduplicated_chunks |
|
|
| |
| relations_with_chunks = [ |
| relation_info |
| for relation_info in relations_with_chunks |
| if relation_info["chunks"] |
| ] |
|
|
| if not relations_with_chunks: |
| logger.info( |
| f"Find no additional relations-related chunks from {len(edge_datas)} relations" |
| ) |
| return [] |
|
|
| |
| total_relation_chunks = 0 |
| for relation_info in relations_with_chunks: |
| sorted_chunks = sorted( |
| relation_info["chunks"], |
| key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), |
| reverse=True, |
| ) |
| relation_info["sorted_chunks"] = sorted_chunks |
| total_relation_chunks += len(sorted_chunks) |
|
|
| logger.info( |
| f"Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})" |
| ) |
|
|
| |
| selected_chunk_ids = [] |
|
|
| if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: |
| num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) |
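| # e.g. with max_related_chunks = 5 and 4 surviving relations, the vector path |
| # considers up to int(5 * 4 / 2) = 10 candidate chunks (illustrative numbers). |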
|
|
| |
| embedding_func_config = text_chunks_db.embedding_func |
| if not embedding_func_config: |
| logger.warning("No embedding function found, falling back to WEIGHT method") |
| kg_chunk_pick_method = "WEIGHT" |
| else: |
| try: |
| actual_embedding_func = embedding_func_config |
|
|
| if actual_embedding_func: |
| selected_chunk_ids = await pick_by_vector_similarity( |
| query=query, |
| text_chunks_storage=text_chunks_db, |
| chunks_vdb=chunks_vdb, |
| num_of_chunks=num_of_chunks, |
| entity_info=relations_with_chunks, |
| embedding_func=actual_embedding_func, |
| query_embedding=query_embedding, |
| ) |
|
|
| if not selected_chunk_ids: |
| kg_chunk_pick_method = "WEIGHT" |
| logger.warning( |
| "No relation-related chunks selected by vector similarity, falling back to WEIGHT method" |
| ) |
| else: |
| logger.info( |
| f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity" |
| ) |
|
|
| except Exception as e: |
| logger.error( |
| f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" |
| ) |
| kg_chunk_pick_method = "WEIGHT" |
|
|
| if kg_chunk_pick_method == "WEIGHT": |
| |
| selected_chunk_ids = pick_by_weighted_polling( |
| relations_with_chunks, max_related_chunks, min_related_chunks=1 |
| ) |
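| # min_related_chunks=1 is read here as a per-relation floor: every surviving relation |
| # should contribute at least one chunk before the remaining budget is filled |
| # (see pick_by_weighted_polling in utils for the exact behaviour). |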
|
|
| logger.info( |
| f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling" |
| ) |
|
|
| logger.debug( |
| f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations" |
| ) |
|
|
| if not selected_chunk_ids: |
| return [] |
|
|
| |
| unique_chunk_ids = list( |
| dict.fromkeys(selected_chunk_ids) |
| ) |
| chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) |
|
|
| |
| result_chunks = [] |
| for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): |
| if chunk_data is not None and "content" in chunk_data: |
| chunk_data_copy = chunk_data.copy() |
| chunk_data_copy["source_type"] = "relationship" |
| chunk_data_copy["chunk_id"] = chunk_id |
| result_chunks.append(chunk_data_copy) |
|
|
| |
| if chunk_tracking is not None: |
| chunk_tracking[chunk_id] = { |
| "source": "R", |
| "frequency": chunk_occurrence_count.get(chunk_id, 1), |
| "order": i + 1, |
| } |
|
|
| return result_chunks |
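|
| # Illustrative call (not executed; the variable names below are hypothetical): |
| # relation_chunks = await _find_related_text_unit_from_relations( |
| #     edge_datas=edges, |
| #     query_param=param, |
| #     text_chunks_db=text_chunks_kv, |
| #     entity_chunks=entity_chunks, |
| #     query=user_query, |
| #     chunks_vdb=chunks_vdb, |
| #     chunk_tracking=chunk_tracking, |
| # ) |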
|
|
|
|
|
|
|
|
| async def naive_query( |
| query: str, |
| chunks_vdb: BaseVectorStorage, |
| query_param: QueryParam, |
| global_config: dict[str, str], |
| hashing_kv: BaseKVStorage | None = None, |
| system_prompt: str | None = None, |
| ) -> QueryResult: |
| logger.info("in naive_query") |
| """ |
| Execute naive query and return unified QueryResult object. |
| |
| Args: |
| query: Query string |
| chunks_vdb: Document chunks vector database |
| query_param: Query parameters |
| global_config: Global configuration |
| hashing_kv: Cache storage |
| system_prompt: System prompt |
| |
| Returns: |
| QueryResult: Unified query result object containing: |
| - content: Non-streaming response text content |
| - response_iterator: Streaming response iterator |
| - raw_data: Complete structured data (including references and metadata) |
| - is_streaming: Whether this is a streaming result |
| """ |
|
|
| if not query: |
| return QueryResult(content=PROMPTS["fail_response"]) |
|
|
| if query_param.model_func: |
| use_model_func = query_param.model_func |
| else: |
| use_model_func = global_config["llm_model_func"] |
| logger.info(f"use_model_func: {use_model_func}") |
| logger.info(f"dir(use_model_func): {dir(use_model_func)}") |
| if not isinstance(use_model_func, partial): |
| |
| use_model_func = partial(use_model_func, _priority=5) |
| logger.info(f"type of use_model_func: {type(use_model_func)}") |
|
|
| tokenizer: Tokenizer = global_config["tokenizer"] |
| if not tokenizer: |
| logger.error("Tokenizer not found in global configuration.") |
| return QueryResult(content=PROMPTS["fail_response"]) |
|
|
| chunks = await _get_vector_context(query, chunks_vdb, query_param, None) |
|
|
| if chunks is None or len(chunks) == 0: |
| |
| empty_raw_data = convert_to_user_format( |
| [], |
| [], |
| [], |
| [], |
| "naive", |
| ) |
| empty_raw_data["message"] = "No relevant document chunks found." |
| return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data) |
|
|
| |
| max_total_tokens = getattr( |
| query_param, |
| "max_total_tokens", |
| global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), |
| ) |
|
|
| |
| user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" |
| response_type = ( |
| query_param.response_type |
| if query_param.response_type |
| else "Multiple Paragraphs" |
| ) |
|
|
| |
| sys_prompt_template = ( |
| system_prompt if system_prompt else PROMPTS["naive_rag_response"] |
| ) |
|
|
| |
| pre_sys_prompt = sys_prompt_template.format( |
| response_type=response_type, |
| user_prompt=user_prompt, |
| content_data="", |
| ) |
|
|
| |
| sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) |
| query_tokens = len(tokenizer.encode(query)) |
| buffer_tokens = 200 |
| available_chunk_tokens = max_total_tokens - ( |
| sys_prompt_tokens + query_tokens + buffer_tokens |
| ) |
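| # e.g. with max_total_tokens = 32000, a 1200-token system prompt and a 50-token query, |
| # chunks receive 32000 - (1200 + 50 + 200) = 30550 tokens (illustrative numbers). |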
|
|
| logger.debug( |
| f"Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" |
| ) |
|
|
| |
| processed_chunks = await process_chunks_unified( |
| query=query, |
| unique_chunks=chunks, |
| query_param=query_param, |
| global_config=global_config, |
| source_type="vector", |
| chunk_token_limit=available_chunk_tokens, |
| ) |
|
|
| |
| reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks( |
| processed_chunks |
| ) |
|
|
| logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks") |
|
|
| |
| raw_data = convert_to_user_format( |
| [], |
| [], |
| processed_chunks_with_ref_ids, |
| reference_list, |
| "naive", |
| ) |
|
|
| |
| if "metadata" not in raw_data: |
| raw_data["metadata"] = {} |
| raw_data["metadata"]["keywords"] = { |
| "high_level": [], |
| "low_level": [], |
| } |
| raw_data["metadata"]["processing_info"] = { |
| "total_chunks_found": len(chunks), |
| "final_chunks_count": len(processed_chunks_with_ref_ids), |
| } |
|
|
| |
| text_units_context = [] |
| for chunk in processed_chunks_with_ref_ids: |
| text_units_context.append( |
| { |
| "reference_id": chunk["reference_id"], |
| "content": chunk["content"], |
| } |
| ) |
|
|
| text_units_str = "\n".join( |
| json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context |
| ) |
| reference_list_str = "\n".join( |
| f"[{ref['reference_id']}] {ref['file_path']}" |
| for ref in reference_list |
| if ref["reference_id"] |
| ) |
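| # Each reference renders as e.g. "[1] docs/annual_report.pdf" (path is illustrative). |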
|
|
| naive_context_template = PROMPTS["naive_query_context"] |
| context_content = naive_context_template.format( |
| text_chunks_str=text_units_str, |
| reference_list_str=reference_list_str, |
| ) |
|
|
| if query_param.only_need_context and not query_param.only_need_prompt: |
| return QueryResult(content=context_content, raw_data=raw_data) |
|
|
| sys_prompt = sys_prompt_template.format( |
| response_type=response_type, |
| user_prompt=user_prompt, |
| content_data=context_content, |
| ) |
|
|
| user_query = query |
|
|
| if query_param.only_need_prompt: |
| prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) |
| return QueryResult(content=prompt_content, raw_data=raw_data) |
|
|
| |
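| # Build the query-cache key from every parameter that can change the final answer; |
| # two queries share a cached response only when all of these values match. |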
| args_hash = compute_args_hash( |
| query_param.mode, |
| query, |
| query_param.response_type, |
| query_param.top_k, |
| query_param.chunk_top_k, |
| query_param.max_entity_tokens, |
| query_param.max_relation_tokens, |
| query_param.max_total_tokens, |
| query_param.user_prompt or "", |
| query_param.enable_rerank, |
| ) |
| cached_result = await handle_cache( |
| hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" |
| ) |
| if cached_result is not None: |
| cached_response, _ = cached_result |
| logger.info( |
| " == LLM cache == Query cache hit, using cached response as query result" |
| ) |
| response = cached_response |
| else: |
| response = await use_model_func( |
| user_query, |
| system_prompt=sys_prompt, |
| history_messages=query_param.conversation_history, |
| enable_cot=True, |
| stream=query_param.stream, |
| ) |
|
|
| if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): |
| queryparam_dict = { |
| "mode": query_param.mode, |
| "response_type": query_param.response_type, |
| "top_k": query_param.top_k, |
| "chunk_top_k": query_param.chunk_top_k, |
| "max_entity_tokens": query_param.max_entity_tokens, |
| "max_relation_tokens": query_param.max_relation_tokens, |
| "max_total_tokens": query_param.max_total_tokens, |
| "user_prompt": query_param.user_prompt or "", |
| "enable_rerank": query_param.enable_rerank, |
| } |
| await save_to_cache( |
| hashing_kv, |
| CacheData( |
| args_hash=args_hash, |
| content=response, |
| prompt=query, |
| mode=query_param.mode, |
| cache_type="query", |
| queryparam=queryparam_dict, |
| ), |
| ) |
|
|
| |
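| # Defensive cleanup: some models echo the prompt, so strip any leaked system prompt, |
| # role markers, or query text from a plain-text (non-streaming) response. |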
| if isinstance(response, str): |
| |
| if len(response) > len(sys_prompt): |
| response = ( |
| response[len(sys_prompt) :] |
| .replace(sys_prompt, "") |
| .replace("user", "") |
| .replace("model", "") |
| .replace(query, "") |
| .replace("<system>", "") |
| .replace("</system>", "") |
| .strip() |
| ) |
|
|
| return QueryResult(content=response, raw_data=raw_data) |
| else: |
| |
| return QueryResult( |
| response_iterator=response, raw_data=raw_data, is_streaming=True |
| ) |
|
|