| """ |
| Specialized processors for different modalities |
| |
| Includes: |
| - ContextExtractor: Universal context extraction for multimodal content |
| - ImageModalProcessor: Specialized processor for image content |
| - TableModalProcessor: Specialized processor for table content |
| - EquationModalProcessor: Specialized processor for equation content |
| - GenericModalProcessor: Processor for other modal content |
| """ |
|
|
| import re |
| import json |
| import time |
| import base64 |
| from typing import Dict, Any, Tuple, List |
| from pathlib import Path |
| from dataclasses import dataclass |
|
|
| from lightrag.utils import ( |
| logger, |
| compute_mdhash_id, |
| ) |
| from lightrag.lightrag import LightRAG |
| from dataclasses import asdict |
| from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock |
| from lightrag.operate import extract_entities, merge_nodes_and_edges |
|
|
| |
| from raganything.prompt import PROMPTS |
|
|
|
|
| @dataclass |
| class ContextConfig: |
| """Configuration for context extraction""" |
|
|
| context_window: int = 1 |
| context_mode: str = "page" |
| max_context_tokens: int = 2000 |
| include_headers: bool = True |
| include_captions: bool = True |
| filter_content_types: List[str] = None |
|
|
| def __post_init__(self): |
| if self.filter_content_types is None: |
| self.filter_content_types = ["text"] |
|
|
|
|
| class ContextExtractor: |
| """Universal context extractor supporting multiple content source formats""" |
|
|
| def __init__(self, config: ContextConfig = None, tokenizer=None): |
| """Initialize context extractor |
| |
| Args: |
| config: Context extraction configuration |
| tokenizer: Tokenizer for accurate token counting |
| """ |
| self.config = config or ContextConfig() |
| self.tokenizer = tokenizer |
|
|
| def extract_context( |
| self, |
| content_source: Any, |
| current_item_info: Dict[str, Any], |
| content_format: str = "auto", |
| ) -> str: |
| """Extract context for current item from content source |
| |
| Args: |
| content_source: Source content (list, dict, or other format) |
| current_item_info: Information about current item (page_idx, index, etc.) |
| content_format: Format hint for content source ("minerU", "text_chunks", "auto", etc.) |
| |
| Returns: |
| Extracted context text |
| """ |
| if not content_source and not self.config.context_window: |
| return "" |
|
|
| try: |
| |
| if content_format == "minerU" and isinstance(content_source, list): |
| return self._extract_from_content_list( |
| content_source, current_item_info |
| ) |
| elif content_format == "text_chunks" and isinstance(content_source, list): |
| return self._extract_from_text_chunks(content_source, current_item_info) |
| elif content_format == "text" and isinstance(content_source, str): |
| return self._extract_from_text_source(content_source, current_item_info) |
| else: |
| |
| if isinstance(content_source, list): |
| return self._extract_from_content_list( |
| content_source, current_item_info |
| ) |
| elif isinstance(content_source, dict): |
| return self._extract_from_dict_source( |
| content_source, current_item_info |
| ) |
| elif isinstance(content_source, str): |
| return self._extract_from_text_source( |
| content_source, current_item_info |
| ) |
| else: |
| logger.warning( |
| f"Unsupported content source type: {type(content_source)}" |
| ) |
| return "" |
| except Exception as e: |
| logger.error(f"Error extracting context: {e}") |
| return "" |
|
|
| def _extract_from_content_list( |
| self, content_list: List[Dict], current_item_info: Dict |
| ) -> str: |
| """Extract context from MinerU-style content list |
| |
| Args: |
| content_list: List of content items with page_idx and type info |
| current_item_info: Current item information |
| |
| Returns: |
| Context text from surrounding pages/chunks |
| """ |
| if self.config.context_mode == "page": |
| return self._extract_page_context(content_list, current_item_info) |
| elif self.config.context_mode == "chunk": |
| return self._extract_chunk_context(content_list, current_item_info) |
| else: |
| return self._extract_page_context(content_list, current_item_info) |
|
|
| def _extract_page_context( |
| self, content_list: List[Dict], current_item_info: Dict |
| ) -> str: |
| """Extract context based on page boundaries |
| |
| Args: |
| content_list: List of content items |
| current_item_info: Current item with page_idx |
| |
| Returns: |
| Context text from surrounding pages |
| """ |
| current_page = current_item_info.get("page_idx", 0) |
| window_size = self.config.context_window |
|
|
| start_page = max(0, current_page - window_size) |
| end_page = current_page + window_size + 1 |
|
|
| context_texts = [] |
|
|
| for item in content_list: |
| item_page = item.get("page_idx", 0) |
| item_type = item.get("type", "") |
|
|
| |
| if ( |
| start_page <= item_page < end_page |
| and item_type in self.config.filter_content_types |
| ): |
| text_content = self._extract_text_from_item(item) |
| if text_content and text_content.strip(): |
| |
| if item_page != current_page: |
| context_texts.append(f"[Page {item_page}] {text_content}") |
| else: |
| context_texts.append(text_content) |
|
|
| context = "\n".join(context_texts) |
| return self._truncate_context(context) |
|
|
| def _extract_chunk_context( |
| self, content_list: List[Dict], current_item_info: Dict |
| ) -> str: |
| """Extract context based on content chunks |
| |
| Args: |
| content_list: List of content items |
| current_item_info: Current item with index info |
| |
| Returns: |
| Context text from surrounding chunks |
| """ |
| current_index = current_item_info.get("index", 0) |
| window_size = self.config.context_window |
|
|
| start_idx = max(0, current_index - window_size) |
| end_idx = min(len(content_list), current_index + window_size + 1) |
|
|
| context_texts = [] |
|
|
| for i in range(start_idx, end_idx): |
| if i != current_index: |
| item = content_list[i] |
| item_type = item.get("type", "") |
|
|
| if item_type in self.config.filter_content_types: |
| text_content = self._extract_text_from_item(item) |
| if text_content and text_content.strip(): |
| context_texts.append(text_content) |
|
|
| context = "\n".join(context_texts) |
| return self._truncate_context(context) |
|
|
| def _extract_text_from_item(self, item: Dict) -> str: |
| """Extract text content from a content item |
| |
| Args: |
| item: Content item dictionary |
| |
| Returns: |
| Extracted text content |
| """ |
| item_type = item.get("type", "") |
|
|
| if item_type == "text": |
| text = item.get("text", "") |
| text_level = item.get("text_level", 0) |
|
|
| |
| if self.config.include_headers and text_level > 0: |
| return f"{'#' * text_level} {text}" |
| return text |
|
|
| elif item_type == "image" and self.config.include_captions: |
| captions = item.get("image_caption", item.get("img_caption", [])) |
| if captions: |
| return f"[Image: {', '.join(captions)}]" |
|
|
| elif item_type == "table" and self.config.include_captions: |
| captions = item.get("table_caption", []) |
| if captions: |
| return f"[Table: {', '.join(captions)}]" |
|
|
| return "" |
|
|
| def _extract_from_dict_source( |
| self, dict_source: Dict, current_item_info: Dict |
| ) -> str: |
| """Extract context from dictionary-based content source |
| |
| Args: |
| dict_source: Dictionary containing content |
| current_item_info: Current item information |
| |
| Returns: |
| Extracted context text |
| """ |
| |
| if "content" in dict_source: |
| context = str(dict_source["content"]) |
| elif "text" in dict_source: |
| context = str(dict_source["text"]) |
| else: |
| |
| text_parts = [] |
| for value in dict_source.values(): |
| if isinstance(value, str): |
| text_parts.append(value) |
| context = "\n".join(text_parts) |
|
|
| return self._truncate_context(context) |
|
|
| def _extract_from_text_source( |
| self, text_source: str, current_item_info: Dict |
| ) -> str: |
| """Extract context from plain text source |
| |
| Args: |
| text_source: Plain text content |
| current_item_info: Current item information |
| |
| Returns: |
| Truncated text context |
| """ |
| return self._truncate_context(text_source) |
|
|
| def _extract_from_text_chunks( |
| self, text_chunks: List[str], current_item_info: Dict |
| ) -> str: |
| """Extract context from simple text chunks list |
| |
| Args: |
| text_chunks: List of text strings |
| current_item_info: Current item information with index |
| |
| Returns: |
| Context text from surrounding chunks |
| """ |
| current_index = current_item_info.get("index", 0) |
| window_size = self.config.context_window |
|
|
| start_idx = max(0, current_index - window_size) |
| end_idx = min(len(text_chunks), current_index + window_size + 1) |
|
|
| context_texts = [] |
| for i in range(start_idx, end_idx): |
| if i != current_index: |
| if i < len(text_chunks): |
| chunk_text = str(text_chunks[i]).strip() |
| if chunk_text: |
| context_texts.append(chunk_text) |
|
|
| context = "\n".join(context_texts) |
| return self._truncate_context(context) |
|
|
| def _truncate_context(self, context: str) -> str: |
| """Truncate context to maximum token limit |
| |
| Args: |
| context: Context text to truncate |
| |
| Returns: |
| Truncated context text |
| """ |
| if not context: |
| return "" |
|
|
| |
| if self.tokenizer: |
| tokens = self.tokenizer.encode(context) |
| if len(tokens) <= self.config.max_context_tokens: |
| return context |
|
|
| |
| truncated_tokens = tokens[: self.config.max_context_tokens] |
| truncated_text = self.tokenizer.decode(truncated_tokens) |
|
|
| |
| last_period = truncated_text.rfind(".") |
| last_newline = truncated_text.rfind("\n") |
|
|
| if last_period > len(truncated_text) * 0.8: |
| return truncated_text[: last_period + 1] |
| elif last_newline > len(truncated_text) * 0.8: |
| return truncated_text[:last_newline] |
| else: |
| return truncated_text + "..." |
| else: |
| |
| if len(context) <= self.config.max_context_tokens: |
| return context |
|
|
| |
| truncated = context[: self.config.max_context_tokens] |
|
|
| |
| last_period = truncated.rfind(".") |
| last_newline = truncated.rfind("\n") |
|
|
| if last_period > len(truncated) * 0.8: |
| return truncated[: last_period + 1] |
| elif last_newline > len(truncated) * 0.8: |
| return truncated[:last_newline] |
| else: |
| return truncated + "..." |
|
|
|
|
| class BaseModalProcessor: |
| """Base class for modal processors""" |
|
|
| def __init__( |
| self, |
| lightrag: LightRAG, |
| modal_caption_func, |
| context_extractor: ContextExtractor = None, |
| ): |
| """Initialize base processor |
| |
| Args: |
| lightrag: LightRAG instance |
| modal_caption_func: Function for generating descriptions |
| context_extractor: Context extractor instance |
| """ |
| self.lightrag = lightrag |
| self.modal_caption_func = modal_caption_func |
|
|
| |
| self.text_chunks_db = lightrag.text_chunks |
| self.chunks_vdb = lightrag.chunks_vdb |
| self.entities_vdb = lightrag.entities_vdb |
| self.relationships_vdb = lightrag.relationships_vdb |
| self.knowledge_graph_inst = lightrag.chunk_entity_relation_graph |
|
|
| |
| self.embedding_func = lightrag.embedding_func |
| self.llm_model_func = lightrag.llm_model_func |
| self.global_config = asdict(lightrag) |
| self.hashing_kv = lightrag.llm_response_cache |
| self.tokenizer = lightrag.tokenizer |
|
|
| |
| if context_extractor is None: |
| self.context_extractor = ContextExtractor(tokenizer=self.tokenizer) |
| else: |
| self.context_extractor = context_extractor |
| |
| if self.context_extractor.tokenizer is None: |
| self.context_extractor.tokenizer = self.tokenizer |
|
|
| |
| self.content_source = None |
| self.content_format = "auto" |
|
|
| def set_content_source(self, content_source: Any, content_format: str = "auto"): |
| """Set content source for context extraction |
| |
| Args: |
| content_source: Source content for context extraction |
| content_format: Format of content source ("minerU", "text_chunks", "auto") |
| """ |
| self.content_source = content_source |
| self.content_format = content_format |
| logger.info(f"Content source set with format: {content_format}") |
|
|
| def _get_context_for_item(self, item_info: Dict[str, Any]) -> str: |
| """Get context for current processing item |
| |
| Args: |
| item_info: Information about current item (page_idx, index, etc.) |
| |
| Returns: |
| Context text for the item |
| """ |
| if not self.content_source: |
| return "" |
|
|
| try: |
| context = self.context_extractor.extract_context( |
| self.content_source, item_info, self.content_format |
| ) |
| if context: |
| logger.debug( |
| f"Extracted context of length {len(context)} for item: {item_info}" |
| ) |
| return context |
| except Exception as e: |
| logger.error(f"Error getting context for item {item_info}: {e}") |
| return "" |
|
|
| async def generate_description_only( |
| self, |
| modal_content, |
| content_type: str, |
| item_info: Dict[str, Any] = None, |
| entity_name: str = None, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """ |
| Generate text description and entity info only, without entity relation extraction. |
| Used for batch processing stage 1. |
| |
| Args: |
| modal_content: Modal content to process |
| content_type: Type of modal content |
| item_info: Item information for context extraction |
| entity_name: Optional predefined entity name |
| |
| Returns: |
| Tuple of (description, entity_info) |
| """ |
| |
| raise NotImplementedError("Subclasses must implement this method") |
|
|
| async def _create_entity_and_chunk( |
| self, |
| modal_chunk: str, |
| entity_info: Dict[str, Any], |
| file_path: str, |
| batch_mode: bool = False, |
| doc_id: str = None, |
| chunk_order_index: int = 0, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Create entity and text chunk""" |
| |
| chunk_id = compute_mdhash_id(str(modal_chunk), prefix="chunk-") |
| tokens = len(self.tokenizer.encode(modal_chunk)) |
|
|
| |
| actual_doc_id = doc_id if doc_id else chunk_id |
|
|
| chunk_data = { |
| "tokens": tokens, |
| "content": modal_chunk, |
| "chunk_order_index": chunk_order_index, |
| "full_doc_id": actual_doc_id, |
| "file_path": file_path, |
| } |
|
|
| |
| await self.text_chunks_db.upsert({chunk_id: chunk_data}) |
|
|
| |
| chunk_vdb_data = { |
| chunk_id: { |
| "content": modal_chunk, |
| "full_doc_id": actual_doc_id, |
| "tokens": tokens, |
| "chunk_order_index": chunk_order_index, |
| "file_path": file_path, |
| } |
| } |
| await self.chunks_vdb.upsert(chunk_vdb_data) |
|
|
| |
| node_data = { |
| "entity_id": entity_info["entity_name"], |
| "entity_type": entity_info["entity_type"], |
| "description": entity_info["summary"], |
| "source_id": chunk_id, |
| "file_path": file_path, |
| "created_at": int(time.time()), |
| } |
|
|
| await self.knowledge_graph_inst.upsert_node( |
| entity_info["entity_name"], node_data |
| ) |
|
|
| |
| entity_vdb_data = { |
| compute_mdhash_id(entity_info["entity_name"], prefix="ent-"): { |
| "entity_name": entity_info["entity_name"], |
| "entity_type": entity_info["entity_type"], |
| "content": f"{entity_info['entity_name']}\n{entity_info['summary']}", |
| "source_id": chunk_id, |
| "file_path": file_path, |
| } |
| } |
| await self.entities_vdb.upsert(entity_vdb_data) |
|
|
| |
| chunk_results = await self._process_chunk_for_extraction( |
| chunk_id, entity_info["entity_name"], batch_mode |
| ) |
|
|
| return ( |
| entity_info["summary"], |
| { |
| "entity_name": entity_info["entity_name"], |
| "entity_type": entity_info["entity_type"], |
| "description": entity_info["summary"], |
| "chunk_id": chunk_id, |
| }, |
| chunk_results, |
| ) |
|
|
| def _robust_json_parse(self, response: str) -> dict: |
| """Robust JSON parsing with multiple fallback strategies""" |
|
|
| |
| for json_candidate in self._extract_all_json_candidates(response): |
| result = self._try_parse_json(json_candidate) |
| if result: |
| return result |
|
|
| |
| for json_candidate in self._extract_all_json_candidates(response): |
| cleaned = self._basic_json_cleanup(json_candidate) |
| result = self._try_parse_json(cleaned) |
| if result: |
| return result |
|
|
| |
| for json_candidate in self._extract_all_json_candidates(response): |
| fixed = self._progressive_quote_fix(json_candidate) |
| result = self._try_parse_json(fixed) |
| if result: |
| return result |
|
|
| |
| return self._extract_fields_with_regex(response) |
|
|
| def _extract_all_json_candidates(self, response: str) -> list: |
| """Extract all possible JSON candidates from response""" |
| candidates = [] |
|
|
| |
| import re |
|
|
| json_blocks = re.findall(r"```(?:json)?\s*(\{.*?\})\s*```", response, re.DOTALL) |
| candidates.extend(json_blocks) |
|
|
| |
| brace_count = 0 |
| start_pos = -1 |
|
|
| for i, char in enumerate(response): |
| if char == "{": |
| if brace_count == 0: |
| start_pos = i |
| brace_count += 1 |
| elif char == "}": |
| brace_count -= 1 |
| if brace_count == 0 and start_pos != -1: |
| candidates.append(response[start_pos : i + 1]) |
|
|
| |
| simple_match = re.search(r"\{.*\}", response, re.DOTALL) |
| if simple_match: |
| candidates.append(simple_match.group(0)) |
|
|
| return candidates |
|
|
| def _try_parse_json(self, json_str: str) -> dict: |
| """Try to parse JSON string, return None if failed""" |
| if not json_str or not json_str.strip(): |
| return None |
|
|
| try: |
| return json.loads(json_str) |
| except (json.JSONDecodeError, ValueError): |
| return None |
|
|
| def _basic_json_cleanup(self, json_str: str) -> str: |
| """Basic cleanup for common JSON issues""" |
| |
| json_str = json_str.strip() |
|
|
| |
| json_str = json_str.replace('"', '"').replace('"', '"') |
| json_str = json_str.replace(""", "'").replace(""", "'") |
|
|
| |
| json_str = re.sub(r",(\s*[}\]])", r"\1", json_str) |
|
|
| return json_str |
|
|
| def _progressive_quote_fix(self, json_str: str) -> str: |
| """Progressive fixing of quote and escape issues""" |
| |
| json_str = re.sub(r'(?<!\\)\\(?=")', r"\\\\", json_str) |
|
|
| |
| def fix_string_content(match): |
| content = match.group(1) |
| |
| content = re.sub(r"\\(?=[a-zA-Z])", r"\\\\", content) |
| return f'"{content}"' |
|
|
| json_str = re.sub(r'"([^"]*(?:\\.[^"]*)*)"', fix_string_content, json_str) |
| return json_str |
|
|
| def _extract_fields_with_regex(self, response: str) -> dict: |
| """Extract required fields using regex as last resort""" |
| logger.warning("Using regex fallback for JSON parsing") |
|
|
| |
| desc_match = re.search( |
| r'"detailed_description":\s*"([^"]*(?:\\.[^"]*)*)"', response, re.DOTALL |
| ) |
| description = desc_match.group(1) if desc_match else "" |
|
|
| |
| name_match = re.search(r'"entity_name":\s*"([^"]*(?:\\.[^"]*)*)"', response) |
| entity_name = name_match.group(1) if name_match else "unknown_entity" |
|
|
| |
| type_match = re.search(r'"entity_type":\s*"([^"]*(?:\\.[^"]*)*)"', response) |
| entity_type = type_match.group(1) if type_match else "unknown" |
|
|
| |
| summary_match = re.search( |
| r'"summary":\s*"([^"]*(?:\\.[^"]*)*)"', response, re.DOTALL |
| ) |
| summary = summary_match.group(1) if summary_match else description[:100] |
|
|
| return { |
| "detailed_description": description, |
| "entity_info": { |
| "entity_name": entity_name, |
| "entity_type": entity_type, |
| "summary": summary, |
| }, |
| } |
|
|
| def _extract_json_from_response(self, response: str) -> str: |
| """Legacy method - now handled by _extract_all_json_candidates""" |
| candidates = self._extract_all_json_candidates(response) |
| return candidates[0] if candidates else None |
|
|
| def _fix_json_escapes(self, json_str: str) -> str: |
| """Legacy method - now handled by progressive strategies""" |
| return self._progressive_quote_fix(json_str) |
|
|
| async def _process_chunk_for_extraction( |
| self, chunk_id: str, modal_entity_name: str, batch_mode: bool = False |
| ): |
| """Process chunk for entity and relationship extraction""" |
| chunk_data = await self.text_chunks_db.get_by_id(chunk_id) |
| if not chunk_data: |
| logger.error(f"Chunk {chunk_id} not found") |
| return |
|
|
| |
| chunk_vdb_data = { |
| chunk_id: { |
| "content": chunk_data["content"], |
| "full_doc_id": chunk_id, |
| "tokens": chunk_data["tokens"], |
| "chunk_order_index": chunk_data["chunk_order_index"], |
| "file_path": chunk_data["file_path"], |
| } |
| } |
|
|
| await self.chunks_vdb.upsert(chunk_vdb_data) |
|
|
| pipeline_status = await get_namespace_data("pipeline_status") |
| pipeline_status_lock = get_pipeline_status_lock() |
|
|
| |
| chunks = {chunk_id: chunk_data} |
|
|
| |
| chunk_results = await extract_entities( |
| chunks=chunks, |
| global_config=self.global_config, |
| pipeline_status=pipeline_status, |
| pipeline_status_lock=pipeline_status_lock, |
| llm_response_cache=self.hashing_kv, |
| ) |
|
|
| |
| processed_chunk_results = [] |
| for maybe_nodes, maybe_edges in chunk_results: |
| for entity_name in maybe_nodes.keys(): |
| if entity_name != modal_entity_name: |
| |
| relation_data = { |
| "description": f"Entity {entity_name} belongs to {modal_entity_name}", |
| "keywords": "belongs_to,part_of,contained_in", |
| "source_id": chunk_id, |
| "weight": 10.0, |
| "file_path": chunk_data.get("file_path", "manual_creation"), |
| } |
| await self.knowledge_graph_inst.upsert_edge( |
| entity_name, modal_entity_name, relation_data |
| ) |
|
|
| relation_id = compute_mdhash_id( |
| entity_name + modal_entity_name, prefix="rel-" |
| ) |
| relation_vdb_data = { |
| relation_id: { |
| "src_id": entity_name, |
| "tgt_id": modal_entity_name, |
| "keywords": relation_data["keywords"], |
| "content": f"{relation_data['keywords']}\t{entity_name}\n{modal_entity_name}\n{relation_data['description']}", |
| "source_id": chunk_id, |
| "file_path": chunk_data.get("file_path", "manual_creation"), |
| } |
| } |
| await self.relationships_vdb.upsert(relation_vdb_data) |
|
|
| |
| maybe_edges[(entity_name, modal_entity_name)] = [relation_data] |
|
|
| processed_chunk_results.append((maybe_nodes, maybe_edges)) |
|
|
| if not batch_mode: |
| |
| file_path = chunk_data.get("file_path", "manual_creation") |
| await merge_nodes_and_edges( |
| chunk_results=chunk_results, |
| knowledge_graph_inst=self.knowledge_graph_inst, |
| entity_vdb=self.entities_vdb, |
| relationships_vdb=self.relationships_vdb, |
| global_config=self.global_config, |
| pipeline_status=pipeline_status, |
| pipeline_status_lock=pipeline_status_lock, |
| llm_response_cache=self.hashing_kv, |
| current_file_number=1, |
| total_files=1, |
| file_path=file_path, |
| ) |
|
|
| |
| await self.lightrag._insert_done() |
|
|
| return processed_chunk_results |
|
|
|
|
| class ImageModalProcessor(BaseModalProcessor): |
| """Processor specialized for image content""" |
|
|
| def __init__( |
| self, |
| lightrag: LightRAG, |
| modal_caption_func, |
| context_extractor: ContextExtractor = None, |
| ): |
| """Initialize image processor |
| |
| Args: |
| lightrag: LightRAG instance |
| modal_caption_func: Function for generating descriptions (supporting image understanding) |
| context_extractor: Context extractor instance |
| """ |
| super().__init__(lightrag, modal_caption_func, context_extractor) |
|
|
| def _encode_image_to_base64(self, image_path: str) -> str: |
| """Encode image to base64""" |
| try: |
| with open(image_path, "rb") as image_file: |
| encoded_string = base64.b64encode(image_file.read()).decode("utf-8") |
| return encoded_string |
| except Exception as e: |
| logger.error(f"Failed to encode image {image_path}: {e}") |
| return "" |
|
|
| async def generate_description_only( |
| self, |
| modal_content, |
| content_type: str, |
| item_info: Dict[str, Any] = None, |
| entity_name: str = None, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """ |
| Generate image description and entity info only, without entity relation extraction. |
| Used for batch processing stage 1. |
| |
| Args: |
| modal_content: Image content to process |
| content_type: Type of modal content ("image") |
| item_info: Item information for context extraction |
| entity_name: Optional predefined entity name |
| |
| Returns: |
| Tuple of (enhanced_caption, entity_info) |
| """ |
| try: |
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"description": modal_content} |
| else: |
| content_data = modal_content |
|
|
| image_path = content_data.get("img_path") |
| captions = content_data.get( |
| "image_caption", content_data.get("img_caption", []) |
| ) |
| footnotes = content_data.get( |
| "image_footnote", content_data.get("img_footnote", []) |
| ) |
|
|
| |
| if not image_path: |
| raise ValueError( |
| f"No image path provided in modal_content: {modal_content}" |
| ) |
|
|
| |
| image_path_obj = Path(image_path) |
| if not image_path_obj.exists(): |
| raise FileNotFoundError(f"Image file not found: {image_path}") |
|
|
| |
| context = "" |
| if item_info: |
| context = self._get_context_for_item(item_info) |
|
|
| |
| if context: |
| vision_prompt = PROMPTS.get( |
| "vision_prompt_with_context", PROMPTS["vision_prompt"] |
| ).format( |
| context=context, |
| entity_name=entity_name |
| if entity_name |
| else "unique descriptive name for this image", |
| image_path=image_path, |
| captions=captions if captions else "None", |
| footnotes=footnotes if footnotes else "None", |
| ) |
| else: |
| vision_prompt = PROMPTS["vision_prompt"].format( |
| entity_name=entity_name |
| if entity_name |
| else "unique descriptive name for this image", |
| image_path=image_path, |
| captions=captions if captions else "None", |
| footnotes=footnotes if footnotes else "None", |
| ) |
|
|
| |
| image_base64 = self._encode_image_to_base64(image_path) |
| if not image_base64: |
| raise RuntimeError(f"Failed to encode image to base64: {image_path}") |
|
|
| |
| response = await self.modal_caption_func( |
| vision_prompt, |
| image_data=image_base64, |
| system_prompt=PROMPTS["IMAGE_ANALYSIS_SYSTEM"], |
| ) |
|
|
| |
| enhanced_caption, entity_info = self._parse_response(response, entity_name) |
|
|
| return enhanced_caption, entity_info |
|
|
| except Exception as e: |
| logger.error(f"Error generating image description: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"image_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "image", |
| "summary": f"Image content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| async def process_multimodal_content( |
| self, |
| modal_content, |
| content_type: str, |
| file_path: str = "manual_creation", |
| entity_name: str = None, |
| item_info: Dict[str, Any] = None, |
| batch_mode: bool = False, |
| doc_id: str = None, |
| chunk_order_index: int = 0, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Process image content with context support""" |
| try: |
| |
| enhanced_caption, entity_info = await self.generate_description_only( |
| modal_content, content_type, item_info, entity_name |
| ) |
|
|
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"description": modal_content} |
| else: |
| content_data = modal_content |
|
|
| image_path = content_data.get("img_path", "") |
| captions = content_data.get( |
| "image_caption", content_data.get("img_caption", []) |
| ) |
| footnotes = content_data.get( |
| "image_footnote", content_data.get("img_footnote", []) |
| ) |
|
|
| modal_chunk = PROMPTS["image_chunk"].format( |
| image_path=image_path, |
| captions=", ".join(captions) if captions else "None", |
| footnotes=", ".join(footnotes) if footnotes else "None", |
| enhanced_caption=enhanced_caption, |
| ) |
|
|
| return await self._create_entity_and_chunk( |
| modal_chunk, |
| entity_info, |
| file_path, |
| batch_mode, |
| doc_id, |
| chunk_order_index, |
| ) |
|
|
| except Exception as e: |
| logger.error(f"Error processing image content: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"image_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "image", |
| "summary": f"Image content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| def _parse_response( |
| self, response: str, entity_name: str = None |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Parse model response""" |
| try: |
| response_data = self._robust_json_parse(response) |
|
|
| description = response_data.get("detailed_description", "") |
| entity_data = response_data.get("entity_info", {}) |
|
|
| if not description or not entity_data: |
| raise ValueError("Missing required fields in response") |
|
|
| if not all( |
| key in entity_data for key in ["entity_name", "entity_type", "summary"] |
| ): |
| raise ValueError("Missing required fields in entity_info") |
|
|
| entity_data["entity_name"] = ( |
| entity_data["entity_name"] + f" ({entity_data['entity_type']})" |
| ) |
| if entity_name: |
| entity_data["entity_name"] = entity_name |
|
|
| return description, entity_data |
|
|
| except (json.JSONDecodeError, AttributeError, ValueError) as e: |
| logger.error(f"Error parsing image analysis response: {e}") |
| logger.debug(f"Raw response: {response}") |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"image_{compute_mdhash_id(response)}", |
| "entity_type": "image", |
| "summary": response[:100] + "..." if len(response) > 100 else response, |
| } |
| return response, fallback_entity |
|
|
|
|
| class TableModalProcessor(BaseModalProcessor): |
| """Processor specialized for table content""" |
|
|
| async def generate_description_only( |
| self, |
| modal_content, |
| content_type: str, |
| item_info: Dict[str, Any] = None, |
| entity_name: str = None, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """ |
| Generate table description and entity info only, without entity relation extraction. |
| Used for batch processing stage 1. |
| |
| Args: |
| modal_content: Table content to process |
| content_type: Type of modal content ("table") |
| item_info: Item information for context extraction |
| entity_name: Optional predefined entity name |
| |
| Returns: |
| Tuple of (enhanced_caption, entity_info) |
| """ |
| try: |
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"table_body": modal_content} |
| else: |
| content_data = modal_content |
|
|
| table_img_path = content_data.get("img_path") |
| table_caption = content_data.get("table_caption", []) |
| table_body = content_data.get("table_body", "") |
| table_footnote = content_data.get("table_footnote", []) |
|
|
| |
| context = "" |
| if item_info: |
| context = self._get_context_for_item(item_info) |
|
|
| |
| if context: |
| table_prompt = PROMPTS.get( |
| "table_prompt_with_context", PROMPTS["table_prompt"] |
| ).format( |
| context=context, |
| entity_name=entity_name |
| if entity_name |
| else "descriptive name for this table", |
| table_img_path=table_img_path, |
| table_caption=table_caption if table_caption else "None", |
| table_body=table_body, |
| table_footnote=table_footnote if table_footnote else "None", |
| ) |
| else: |
| table_prompt = PROMPTS["table_prompt"].format( |
| entity_name=entity_name |
| if entity_name |
| else "descriptive name for this table", |
| table_img_path=table_img_path, |
| table_caption=table_caption if table_caption else "None", |
| table_body=table_body, |
| table_footnote=table_footnote if table_footnote else "None", |
| ) |
|
|
| |
| response = await self.modal_caption_func( |
| table_prompt, |
| system_prompt=PROMPTS["TABLE_ANALYSIS_SYSTEM"], |
| ) |
|
|
| |
| enhanced_caption, entity_info = self._parse_table_response( |
| response, entity_name |
| ) |
|
|
| return enhanced_caption, entity_info |
|
|
| except Exception as e: |
| logger.error(f"Error generating table description: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"table_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "table", |
| "summary": f"Table content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| async def process_multimodal_content( |
| self, |
| modal_content, |
| content_type: str, |
| file_path: str = "manual_creation", |
| entity_name: str = None, |
| item_info: Dict[str, Any] = None, |
| batch_mode: bool = False, |
| doc_id: str = None, |
| chunk_order_index: int = 0, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Process table content with context support""" |
| try: |
| |
| enhanced_caption, entity_info = await self.generate_description_only( |
| modal_content, content_type, item_info, entity_name |
| ) |
|
|
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"table_body": modal_content} |
| else: |
| content_data = modal_content |
|
|
| table_img_path = content_data.get("img_path") |
| table_caption = content_data.get("table_caption", []) |
| table_body = content_data.get("table_body", "") |
| table_footnote = content_data.get("table_footnote", []) |
|
|
| |
| modal_chunk = PROMPTS["table_chunk"].format( |
| table_img_path=table_img_path, |
| table_caption=", ".join(table_caption) if table_caption else "None", |
| table_body=table_body, |
| table_footnote=", ".join(table_footnote) if table_footnote else "None", |
| enhanced_caption=enhanced_caption, |
| ) |
|
|
| return await self._create_entity_and_chunk( |
| modal_chunk, |
| entity_info, |
| file_path, |
| batch_mode, |
| doc_id, |
| chunk_order_index, |
| ) |
|
|
| except Exception as e: |
| logger.error(f"Error processing table content: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"table_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "table", |
| "summary": f"Table content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| def _parse_table_response( |
| self, response: str, entity_name: str = None |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Parse table analysis response""" |
| try: |
| response_data = self._robust_json_parse(response) |
|
|
| description = response_data.get("detailed_description", "") |
| entity_data = response_data.get("entity_info", {}) |
|
|
| if not description or not entity_data: |
| raise ValueError("Missing required fields in response") |
|
|
| if not all( |
| key in entity_data for key in ["entity_name", "entity_type", "summary"] |
| ): |
| raise ValueError("Missing required fields in entity_info") |
|
|
| entity_data["entity_name"] = ( |
| entity_data["entity_name"] + f" ({entity_data['entity_type']})" |
| ) |
| if entity_name: |
| entity_data["entity_name"] = entity_name |
|
|
| return description, entity_data |
|
|
| except (json.JSONDecodeError, AttributeError, ValueError) as e: |
| logger.error(f"Error parsing table analysis response: {e}") |
| logger.debug(f"Raw response: {response}") |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"table_{compute_mdhash_id(response)}", |
| "entity_type": "table", |
| "summary": response[:100] + "..." if len(response) > 100 else response, |
| } |
| return response, fallback_entity |
|
|
|
|
| class EquationModalProcessor(BaseModalProcessor): |
| """Processor specialized for equation content""" |
|
|
| async def generate_description_only( |
| self, |
| modal_content, |
| content_type: str, |
| item_info: Dict[str, Any] = None, |
| entity_name: str = None, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """ |
| Generate equation description and entity info only, without entity relation extraction. |
| Used for batch processing stage 1. |
| |
| Args: |
| modal_content: Equation content to process |
| content_type: Type of modal content ("equation") |
| item_info: Item information for context extraction |
| entity_name: Optional predefined entity name |
| |
| Returns: |
| Tuple of (enhanced_caption, entity_info) |
| """ |
| try: |
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"equation": modal_content} |
| else: |
| content_data = modal_content |
|
|
| equation_text = content_data.get("text") |
| equation_format = content_data.get("text_format", "") |
|
|
| |
| context = "" |
| if item_info: |
| context = self._get_context_for_item(item_info) |
|
|
| |
| if context: |
| equation_prompt = PROMPTS.get( |
| "equation_prompt_with_context", PROMPTS["equation_prompt"] |
| ).format( |
| context=context, |
| equation_text=equation_text, |
| equation_format=equation_format, |
| entity_name=entity_name |
| if entity_name |
| else "descriptive name for this equation", |
| ) |
| else: |
| equation_prompt = PROMPTS["equation_prompt"].format( |
| equation_text=equation_text, |
| equation_format=equation_format, |
| entity_name=entity_name |
| if entity_name |
| else "descriptive name for this equation", |
| ) |
|
|
| |
| response = await self.modal_caption_func( |
| equation_prompt, |
| system_prompt=PROMPTS["EQUATION_ANALYSIS_SYSTEM"], |
| ) |
|
|
| |
| enhanced_caption, entity_info = self._parse_equation_response( |
| response, entity_name |
| ) |
|
|
| return enhanced_caption, entity_info |
|
|
| except Exception as e: |
| logger.error(f"Error generating equation description: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"equation_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "equation", |
| "summary": f"Equation content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| async def process_multimodal_content( |
| self, |
| modal_content, |
| content_type: str, |
| file_path: str = "manual_creation", |
| entity_name: str = None, |
| item_info: Dict[str, Any] = None, |
| batch_mode: bool = False, |
| doc_id: str = None, |
| chunk_order_index: int = 0, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Process equation content with context support""" |
| try: |
| |
| enhanced_caption, entity_info = await self.generate_description_only( |
| modal_content, content_type, item_info, entity_name |
| ) |
|
|
| |
| if isinstance(modal_content, str): |
| try: |
| content_data = json.loads(modal_content) |
| except json.JSONDecodeError: |
| content_data = {"equation": modal_content} |
| else: |
| content_data = modal_content |
|
|
| equation_text = content_data.get("text") |
| equation_format = content_data.get("text_format", "") |
|
|
| |
| modal_chunk = PROMPTS["equation_chunk"].format( |
| equation_text=equation_text, |
| equation_format=equation_format, |
| enhanced_caption=enhanced_caption, |
| ) |
|
|
| return await self._create_entity_and_chunk( |
| modal_chunk, |
| entity_info, |
| file_path, |
| batch_mode, |
| doc_id, |
| chunk_order_index, |
| ) |
|
|
| except Exception as e: |
| logger.error(f"Error processing equation content: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"equation_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": "equation", |
| "summary": f"Equation content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| def _parse_equation_response( |
| self, response: str, entity_name: str = None |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Parse equation analysis response with robust JSON handling""" |
| try: |
| response_data = self._robust_json_parse(response) |
|
|
| description = response_data.get("detailed_description", "") |
| entity_data = response_data.get("entity_info", {}) |
|
|
| if not description or not entity_data: |
| raise ValueError("Missing required fields in response") |
|
|
| if not all( |
| key in entity_data for key in ["entity_name", "entity_type", "summary"] |
| ): |
| raise ValueError("Missing required fields in entity_info") |
|
|
| entity_data["entity_name"] = ( |
| entity_data["entity_name"] + f" ({entity_data['entity_type']})" |
| ) |
| if entity_name: |
| entity_data["entity_name"] = entity_name |
|
|
| return description, entity_data |
|
|
| except (json.JSONDecodeError, AttributeError, ValueError) as e: |
| logger.error(f"Error parsing equation analysis response: {e}") |
| logger.debug(f"Raw response: {response}") |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"equation_{compute_mdhash_id(response)}", |
| "entity_type": "equation", |
| "summary": response[:100] + "..." if len(response) > 100 else response, |
| } |
| return response, fallback_entity |
|
|
|
|
| class GenericModalProcessor(BaseModalProcessor): |
| """Generic processor for other types of modal content""" |
|
|
| async def generate_description_only( |
| self, |
| modal_content, |
| content_type: str, |
| item_info: Dict[str, Any] = None, |
| entity_name: str = None, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """ |
| Generate generic modal description and entity info only, without entity relation extraction. |
| Used for batch processing stage 1. |
| |
| Args: |
| modal_content: Generic modal content to process |
| content_type: Type of modal content |
| item_info: Item information for context extraction |
| entity_name: Optional predefined entity name |
| |
| Returns: |
| Tuple of (enhanced_caption, entity_info) |
| """ |
| try: |
| |
| context = "" |
| if item_info: |
| context = self._get_context_for_item(item_info) |
|
|
| |
| if context: |
| generic_prompt = PROMPTS.get( |
| "generic_prompt_with_context", PROMPTS["generic_prompt"] |
| ).format( |
| context=context, |
| content_type=content_type, |
| entity_name=entity_name |
| if entity_name |
| else f"descriptive name for this {content_type}", |
| content=str(modal_content), |
| ) |
| else: |
| generic_prompt = PROMPTS["generic_prompt"].format( |
| content_type=content_type, |
| entity_name=entity_name |
| if entity_name |
| else f"descriptive name for this {content_type}", |
| content=str(modal_content), |
| ) |
|
|
| |
| response = await self.modal_caption_func( |
| generic_prompt, |
| system_prompt=PROMPTS["GENERIC_ANALYSIS_SYSTEM"].format( |
| content_type=content_type |
| ), |
| ) |
|
|
| |
| enhanced_caption, entity_info = self._parse_generic_response( |
| response, entity_name, content_type |
| ) |
|
|
| return enhanced_caption, entity_info |
|
|
| except Exception as e: |
| logger.error(f"Error generating {content_type} description: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"{content_type}_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": content_type, |
| "summary": f"{content_type} content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| async def process_multimodal_content( |
| self, |
| modal_content, |
| content_type: str, |
| file_path: str = "manual_creation", |
| entity_name: str = None, |
| item_info: Dict[str, Any] = None, |
| batch_mode: bool = False, |
| doc_id: str = None, |
| chunk_order_index: int = 0, |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Process generic modal content with context support""" |
| try: |
| |
| enhanced_caption, entity_info = await self.generate_description_only( |
| modal_content, content_type, item_info, entity_name |
| ) |
|
|
| |
| modal_chunk = PROMPTS["generic_chunk"].format( |
| content_type=content_type.title(), |
| content=str(modal_content), |
| enhanced_caption=enhanced_caption, |
| ) |
|
|
| return await self._create_entity_and_chunk( |
| modal_chunk, |
| entity_info, |
| file_path, |
| batch_mode, |
| doc_id, |
| chunk_order_index, |
| ) |
|
|
| except Exception as e: |
| logger.error(f"Error processing {content_type} content: {e}") |
| |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"{content_type}_{compute_mdhash_id(str(modal_content))}", |
| "entity_type": content_type, |
| "summary": f"{content_type} content: {str(modal_content)[:100]}", |
| } |
| return str(modal_content), fallback_entity |
|
|
| def _parse_generic_response( |
| self, response: str, entity_name: str = None, content_type: str = "content" |
| ) -> Tuple[str, Dict[str, Any]]: |
| """Parse generic analysis response""" |
| try: |
| response_data = self._robust_json_parse(response) |
|
|
| description = response_data.get("detailed_description", "") |
| entity_data = response_data.get("entity_info", {}) |
|
|
| if not description or not entity_data: |
| raise ValueError("Missing required fields in response") |
|
|
| if not all( |
| key in entity_data for key in ["entity_name", "entity_type", "summary"] |
| ): |
| raise ValueError("Missing required fields in entity_info") |
|
|
| entity_data["entity_name"] = ( |
| entity_data["entity_name"] + f" ({entity_data['entity_type']})" |
| ) |
| if entity_name: |
| entity_data["entity_name"] = entity_name |
|
|
| return description, entity_data |
|
|
| except (json.JSONDecodeError, AttributeError, ValueError) as e: |
| logger.error(f"Error parsing {content_type} analysis response: {e}") |
| logger.debug(f"Raw response: {response}") |
| fallback_entity = { |
| "entity_name": entity_name |
| if entity_name |
| else f"{content_type}_{compute_mdhash_id(response)}", |
| "entity_type": content_type, |
| "summary": response[:100] + "..." if len(response) > 100 else response, |
| } |
| return response, fallback_entity |
|
|