| # """ | |
| # Query functionality for RAGAnything | |
| # Contains all query-related methods for both text and multimodal queries | |
| # """ | |
| # import json | |
| # import hashlib | |
| # import re | |
| # from typing import Dict, List, Any | |
| # from pathlib import Path | |
| # from lightrag import QueryParam | |
| # from lightrag.utils import always_get_an_event_loop | |
| # from raganything.prompt import PROMPTS | |
| # from raganything.utils import ( | |
| # get_processor_for_type, | |
| # encode_image_to_base64, | |
| # validate_image_file, | |
| # ) | |
| # # Add these imports | |
| # from raganything.query_improvement import QueryImprovementMixin | |
| # from raganything.verification import DualLLMVerificationMixin | |
| # class QueryMixin(QueryImprovementMixin, DualLLMVerificationMixin): | |
| # """QueryMixin class containing query functionality for RAGAnything""" | |
| # def _generate_multimodal_cache_key( | |
| # self, query: str, multimodal_content: List[Dict[str, Any]], mode: str, **kwargs | |
| # ) -> str: | |
| # """ | |
| # Generate cache key for multimodal query | |
| # Args: | |
| # query: Base query text | |
| # multimodal_content: List of multimodal content | |
| # mode: Query mode | |
| # **kwargs: Additional parameters | |
| # Returns: | |
| # str: Cache key hash | |
| # """ | |
| # # Create a normalized representation of the query parameters | |
| # cache_data = { | |
| # "query": query.strip(), | |
| # "mode": mode, | |
| # } | |
| # # Normalize multimodal content for stable caching | |
| # normalized_content = [] | |
| # if multimodal_content: | |
| # for item in multimodal_content: | |
| # if isinstance(item, dict): | |
| # normalized_item = {} | |
| # for key, value in item.items(): | |
| # # For file paths, use basename to make cache more portable | |
| # if key in [ | |
| # "img_path", | |
| # "image_path", | |
| # "file_path", | |
| # ] and isinstance(value, str): | |
| # normalized_item[key] = Path(value).name | |
| # # For large content, create a hash instead of storing directly | |
| # elif ( | |
| # key in ["table_data", "table_body"] | |
| # and isinstance(value, str) | |
| # and len(value) > 200 | |
| # ): | |
| # normalized_item[f"{key}_hash"] = hashlib.md5( | |
| # value.encode() | |
| # ).hexdigest() | |
| # else: | |
| # normalized_item[key] = value | |
| # normalized_content.append(normalized_item) | |
| # else: | |
| # normalized_content.append(item) | |
| # cache_data["multimodal_content"] = normalized_content | |
| # # Add relevant kwargs to cache data | |
| # relevant_kwargs = { | |
| # k: v | |
| # for k, v in kwargs.items() | |
| # if k | |
| # in [ | |
| # "stream", | |
| # "response_type", | |
| # "top_k", | |
| # "max_tokens", | |
| # "temperature", | |
| # # "only_need_context", | |
| # # "only_need_prompt", | |
| # ] | |
| # } | |
| # cache_data.update(relevant_kwargs) | |
| # # Generate hash from the cache data | |
| # cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False) | |
| # cache_hash = hashlib.md5(cache_str.encode()).hexdigest() | |
| # return f"multimodal_query:{cache_hash}" | |
| # # async def aquery(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| # # """ | |
| # # Pure text query - directly calls LightRAG's query functionality | |
| # # Args: | |
| # # query: Query text | |
| # # mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| # # **kwargs: Other query parameters, will be passed to QueryParam | |
| # # - vlm_enhanced: bool, default True when vision_model_func is available. | |
| # # If True, will parse image paths in retrieved context and replace them | |
| # # with base64 encoded images for VLM processing. | |
| # # Returns: | |
| # # str: Query result | |
| # # """ | |
| # # if self.lightrag is None: | |
| # # raise ValueError( | |
| # # "No LightRAG instance available. Please process documents first or provide a pre-initialized LightRAG instance." | |
| # # ) | |
| # # # Check if VLM enhanced query should be used | |
| # # vlm_enhanced = kwargs.pop("vlm_enhanced", None) | |
| # # # Auto-determine VLM enhanced based on availability | |
| # # if vlm_enhanced is None: | |
| # # vlm_enhanced = ( | |
| # # hasattr(self, "vision_model_func") | |
| # # and self.vision_model_func is not None | |
| # # ) | |
| # # # Use VLM enhanced query if enabled and available | |
| # # if ( | |
| # # vlm_enhanced | |
| # # and hasattr(self, "vision_model_func") | |
| # # and self.vision_model_func | |
| # # ): | |
| # # return await self.aquery_vlm_enhanced(query, mode=mode, **kwargs) | |
| # # elif vlm_enhanced and ( | |
| # # not hasattr(self, "vision_model_func") or not self.vision_model_func | |
| # # ): | |
| # # self.logger.warning( | |
| # # "VLM enhanced query requested but vision_model_func is not available, falling back to normal query" | |
| # # ) | |
| # # # Create query parameters | |
| # # query_param = QueryParam(mode=mode, **kwargs) | |
| # # self.logger.info(f"Executing text query: {query[:100]}...") | |
| # # self.logger.info(f"Query mode: {mode}") | |
| # # # Call LightRAG's query method | |
| # # result = await self.lightrag.aquery(query, param=query_param) | |
| # # self.logger.info("Text query completed") | |
| # # return result | |
| # # async def aquery_with_multimodal( | |
| # # self, | |
| # # query: str, | |
| # # multimodal_content: List[Dict[str, Any]] = None, | |
| # # mode: str = "mix", | |
| # # **kwargs, | |
| # # ) -> str: | |
| # # """ | |
| # # Multimodal query - combines text and multimodal content for querying | |
| # # Args: | |
| # # query: Base query text | |
| # # multimodal_content: List of multimodal content, each element contains: | |
| # # - type: Content type ("image", "table", "equation", etc.) | |
| # # - Other fields depend on type (e.g., img_path, table_data, latex, etc.) | |
| # # mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| # # **kwargs: Other query parameters, will be passed to QueryParam | |
| # # Returns: | |
| # # str: Query result | |
| # # Examples: | |
| # # # Pure text query | |
| # # result = await rag.query_with_multimodal("What is machine learning?") | |
| # # # Image query | |
| # # result = await rag.query_with_multimodal( | |
| # # "Analyze the content in this image", | |
| # # multimodal_content=[{ | |
| # # "type": "image", | |
| # # "img_path": "./image.jpg" | |
| # # }] | |
| # # ) | |
| # # # Table query | |
| # # result = await rag.query_with_multimodal( | |
| # # "Analyze the data trends in this table", | |
| # # multimodal_content=[{ | |
| # # "type": "table", | |
| # # "table_data": "Name,Age\nAlice,25\nBob,30" | |
| # # }] | |
| # # ) | |
| # # """ | |
| # # # Ensure LightRAG is initialized | |
| # # await self._ensure_lightrag_initialized() | |
| # # self.logger.info(f"Executing multimodal query: {query[:100]}...") | |
| # # self.logger.info(f"Query mode: {mode}") | |
| # # # If no multimodal content, fallback to pure text query | |
| # # if not multimodal_content: | |
| # # self.logger.info("No multimodal content provided, executing text query") | |
| # # return await self.aquery(query, mode=mode, **kwargs) | |
| # # # Generate cache key for multimodal query | |
| # # cache_key = self._generate_multimodal_cache_key( | |
| # # query, multimodal_content, mode, **kwargs | |
| # # ) | |
| # # # Check cache if available and enabled | |
| # # cached_result = None | |
| # # if ( | |
| # # hasattr(self, "lightrag") | |
| # # and self.lightrag | |
| # # and hasattr(self.lightrag, "llm_response_cache") | |
| # # and self.lightrag.llm_response_cache | |
| # # ): | |
| # # if self.lightrag.llm_response_cache.global_config.get( | |
| # # "enable_llm_cache", True | |
| # # ): | |
| # # try: | |
| # # cached_result = await self.lightrag.llm_response_cache.get_by_id( | |
| # # cache_key | |
| # # ) | |
| # # if cached_result and isinstance(cached_result, dict): | |
| # # result_content = cached_result.get("return") | |
| # # if result_content: | |
| # # self.logger.info( | |
| # # f"Multimodal query cache hit: {cache_key[:16]}..." | |
| # # ) | |
| # # return result_content | |
| # # except Exception as e: | |
| # # self.logger.debug(f"Error accessing multimodal query cache: {e}") | |
| # # # Process multimodal content to generate enhanced query text | |
| # # enhanced_query = await self._process_multimodal_query_content( | |
| # # query, multimodal_content | |
| # # ) | |
| # # self.logger.info( | |
| # # f"Generated enhanced query length: {len(enhanced_query)} characters" | |
| # # ) | |
| # # # Execute enhanced query | |
| # # result = await self.aquery(enhanced_query, mode=mode, **kwargs) | |
| # # # Save to cache if available and enabled | |
| # # if ( | |
| # # hasattr(self, "lightrag") | |
| # # and self.lightrag | |
| # # and hasattr(self.lightrag, "llm_response_cache") | |
| # # and self.lightrag.llm_response_cache | |
| # # ): | |
| # # if self.lightrag.llm_response_cache.global_config.get( | |
| # # "enable_llm_cache", True | |
| # # ): | |
| # # try: | |
| # # # Create cache entry for multimodal query | |
| # # cache_entry = { | |
| # # "return": result, | |
| # # "cache_type": "multimodal_query", | |
| # # "original_query": query, | |
| # # "multimodal_content_count": len(multimodal_content), | |
| # # "mode": mode, | |
| # # } | |
| # # await self.lightrag.llm_response_cache.upsert( | |
| # # {cache_key: cache_entry} | |
| # # ) | |
| # # self.logger.info( | |
| # # f"Saved multimodal query result to cache: {cache_key[:16]}..." | |
| # # ) | |
| # # except Exception as e: | |
| # # self.logger.debug(f"Error saving multimodal query to cache: {e}") | |
| # # # Ensure cache is persisted to disk | |
| # # if ( | |
| # # hasattr(self, "lightrag") | |
| # # and self.lightrag | |
| # # and hasattr(self.lightrag, "llm_response_cache") | |
| # # and self.lightrag.llm_response_cache | |
| # # ): | |
| # # try: | |
| # # await self.lightrag.llm_response_cache.index_done_callback() | |
| # # except Exception as e: | |
| # # self.logger.debug(f"Error persisting multimodal query cache: {e}") | |
| # # self.logger.info("Multimodal query completed") | |
| # # return result | |
| # async def aquery(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| # """ | |
| # Pure text query with optional query improvement and verification | |
| # Args: | |
| # query: Query text | |
| # mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| # **kwargs: Other query parameters | |
| # - enable_query_improvement: bool, override config setting | |
| # - enable_verification: bool, override config setting | |
| # - return_verification_info: bool, return detailed verification info | |
| # Returns: | |
| # str: Query result (or dict if return_verification_info=True) | |
| # """ | |
| # if self.lightrag is None: | |
| # raise ValueError( | |
| # "No LightRAG instance available. Please process documents first or provide a pre-initialized LightRAG instance." | |
| # ) | |
| # # Check override flags | |
| # use_query_improvement = kwargs.pop('enable_query_improvement', | |
| # getattr(self.config, 'enable_query_improvement', False)) | |
| # use_verification = kwargs.pop('enable_verification', | |
| # getattr(self.config, 'enable_dual_llm_verification', False)) | |
| # return_verification_info = kwargs.pop('return_verification_info', False) | |
| # original_query = query | |
| # query_improvement_result = None | |
| # # Step 1: Apply query improvement if enabled | |
| # if use_query_improvement and hasattr(self, 'query_improver') and self.query_improver: | |
| # self.logger.info("Applying query improvement...") | |
| # query_improvement_result = await self._apply_query_improvement(query) | |
| # query = query_improvement_result["improved_query"] | |
| # self.logger.info(f"Query improved: '{original_query[:50]}...' -> '{query[:50]}...'") | |
| # # Step 2: Check VLM enhanced query | |
| # vlm_enhanced = kwargs.pop("vlm_enhanced", None) | |
| # if vlm_enhanced is None: | |
| # vlm_enhanced = ( | |
| # hasattr(self, "vision_model_func") and self.vision_model_func is not None | |
| # ) | |
| # # If using VLM enhanced or verification is disabled, use existing flow | |
| # if vlm_enhanced or not use_verification: | |
| # if vlm_enhanced and hasattr(self, "vision_model_func") and self.vision_model_func: | |
| # result = await self.aquery_vlm_enhanced(query, mode=mode, **kwargs) | |
| # else: | |
| # from lightrag import QueryParam | |
| # query_param = QueryParam(mode=mode, **kwargs) | |
| # result = await self.lightrag.aquery(query, param=query_param) | |
| # if return_verification_info: | |
| # return { | |
| # "answer": result, | |
| # "original_query": original_query, | |
| # "improved_query": query if query_improvement_result else original_query, | |
| # "query_improvement": query_improvement_result, | |
| # "verification_passed": True, | |
| # "verification_score": 10.0 | |
| # } | |
| # return result | |
| # # Step 3: Generate with verification | |
| # if use_verification and hasattr(self, 'answer_verifier') and self.answer_verifier: | |
| # self.logger.info("Using dual-LLM verification...") | |
| # # Get context without final answer | |
| # from lightrag import QueryParam | |
| # query_param = QueryParam(mode=mode, only_need_context=True, **kwargs) | |
| # context = await self.lightrag.aquery(query, param=query_param) | |
| # # Generate with verification | |
| # verification_result = await self._generate_with_verification( | |
| # query=query, | |
| # context=context, | |
| # original_query=original_query | |
| # ) | |
| # if return_verification_info: | |
| # return { | |
| # "answer": verification_result["answer"], | |
| # "original_query": original_query, | |
| # "improved_query": query if query_improvement_result else original_query, | |
| # "query_improvement": query_improvement_result, | |
| # "verification_passed": verification_result["verification_passed"], | |
| # "verification_score": verification_result["verification_score"], | |
| # "modification_attempts": verification_result["modification_attempts"], | |
| # "verification_history": verification_result.get("verification_history", []) | |
| # } | |
| # return verification_result["answer"] | |
| # # Fallback to normal query | |
| # from lightrag import QueryParam | |
| # query_param = QueryParam(mode=mode, **kwargs) | |
| # result = await self.lightrag.aquery(query, param=query_param) | |
| # if return_verification_info: | |
| # return { | |
| # "answer": result, | |
| # "original_query": original_query, | |
| # "improved_query": query if query_improvement_result else original_query, | |
| # "query_improvement": query_improvement_result | |
| # } | |
| # return result | |
| # async def aquery_vlm_enhanced(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| # """ | |
| # VLM enhanced query - replaces image paths in retrieved context with base64 encoded images for VLM processing | |
| # Args: | |
| # query: User query | |
| # mode: Underlying LightRAG query mode | |
| # **kwargs: Other query parameters | |
| # Returns: | |
| # str: VLM query result | |
| # """ | |
| # # Ensure VLM is available | |
| # if not hasattr(self, "vision_model_func") or not self.vision_model_func: | |
| # raise ValueError( | |
| # "VLM enhanced query requires vision_model_func. " | |
| # "Please provide a vision model function when initializing RAGAnything." | |
| # ) | |
| # # Ensure LightRAG is initialized | |
| # await self._ensure_lightrag_initialized() | |
| # self.logger.info(f"Executing VLM enhanced query: {query[:100]}...") | |
| # # Clear previous image cache | |
| # if hasattr(self, "_current_images_base64"): | |
| # delattr(self, "_current_images_base64") | |
| # # 1. Get original retrieval prompt (without generating final answer) | |
| # query_param = QueryParam(mode=mode, only_need_prompt=True, **kwargs) | |
| # raw_prompt = await self.lightrag.aquery(query, param=query_param) | |
| # self.logger.debug("Retrieved raw prompt from LightRAG") | |
| # # 2. Extract and process image paths | |
| # enhanced_prompt, images_found = await self._process_image_paths_for_vlm( | |
| # raw_prompt | |
| # ) | |
| # if not images_found: | |
| # self.logger.info("No valid images found, falling back to normal query") | |
| # # Fallback to normal query | |
| # query_param = QueryParam(mode=mode, **kwargs) | |
| # return await self.lightrag.aquery(query, param=query_param) | |
| # self.logger.info(f"Processed {images_found} images for VLM") | |
| # # 3. Build VLM message format | |
| # messages = self._build_vlm_messages_with_images(enhanced_prompt, query) | |
| # # 4. Call VLM for question answering | |
| # result = await self._call_vlm_with_multimodal_content(messages) | |
| # self.logger.info("VLM enhanced query completed") | |
| # return result | |
| # async def _process_multimodal_query_content( | |
| # self, base_query: str, multimodal_content: List[Dict[str, Any]] | |
| # ) -> str: | |
| # """ | |
| # Process multimodal query content to generate enhanced query text | |
| # Args: | |
| # base_query: Base query text | |
| # multimodal_content: List of multimodal content | |
| # Returns: | |
| # str: Enhanced query text | |
| # """ | |
| # self.logger.info("Starting multimodal query content processing...") | |
| # enhanced_parts = [f"User query: {base_query}"] | |
| # for i, content in enumerate(multimodal_content): | |
| # content_type = content.get("type", "unknown") | |
| # self.logger.info( | |
| # f"Processing {i+1}/{len(multimodal_content)} multimodal content: {content_type}" | |
| # ) | |
| # try: | |
| # # Get appropriate processor | |
| # processor = get_processor_for_type(self.modal_processors, content_type) | |
| # if processor: | |
| # # Generate content description | |
| # description = await self._generate_query_content_description( | |
| # processor, content, content_type | |
| # ) | |
| # enhanced_parts.append( | |
| # f"\nRelated {content_type} content: {description}" | |
| # ) | |
| # else: | |
| # # If no appropriate processor, use basic description | |
| # basic_desc = str(content)[:200] | |
| # enhanced_parts.append( | |
| # f"\nRelated {content_type} content: {basic_desc}" | |
| # ) | |
| # except Exception as e: | |
| # self.logger.error(f"Error processing multimodal content: {str(e)}") | |
| # # Continue processing other content | |
| # continue | |
| # enhanced_query = "\n".join(enhanced_parts) | |
| # enhanced_query += PROMPTS["QUERY_ENHANCEMENT_SUFFIX"] | |
| # self.logger.info("Multimodal query content processing completed") | |
| # return enhanced_query | |
| # async def _generate_query_content_description( | |
| # self, processor, content: Dict[str, Any], content_type: str | |
| # ) -> str: | |
| # """ | |
| # Generate content description for query | |
| # Args: | |
| # processor: Multimodal processor | |
| # content: Content data | |
| # content_type: Content type | |
| # Returns: | |
| # str: Content description | |
| # """ | |
| # try: | |
| # if content_type == "image": | |
| # return await self._describe_image_for_query(processor, content) | |
| # elif content_type == "table": | |
| # return await self._describe_table_for_query(processor, content) | |
| # elif content_type == "equation": | |
| # return await self._describe_equation_for_query(processor, content) | |
| # else: | |
| # return await self._describe_generic_for_query( | |
| # processor, content, content_type | |
| # ) | |
| # except Exception as e: | |
| # self.logger.error(f"Error generating {content_type} description: {str(e)}") | |
| # return f"{content_type} content: {str(content)[:100]}" | |
| # async def _describe_image_for_query( | |
| # self, processor, content: Dict[str, Any] | |
| # ) -> str: | |
| # """Generate image description for query""" | |
| # image_path = content.get("img_path") | |
| # captions = content.get("image_caption", content.get("img_caption", [])) | |
| # footnotes = content.get("image_footnote", content.get("img_footnote", [])) | |
| # if image_path and Path(image_path).exists(): | |
| # # If image exists, use vision model to generate description | |
| # image_base64 = processor._encode_image_to_base64(image_path) | |
| # if image_base64: | |
| # prompt = PROMPTS["QUERY_IMAGE_DESCRIPTION"] | |
| # description = await processor.modal_caption_func( | |
| # prompt, | |
| # image_data=image_base64, | |
| # system_prompt=PROMPTS["QUERY_IMAGE_ANALYST_SYSTEM"], | |
| # ) | |
| # return description | |
| # # If image doesn't exist or processing failed, use existing information | |
| # parts = [] | |
| # if image_path: | |
| # parts.append(f"Image path: {image_path}") | |
| # if captions: | |
| # parts.append(f"Image captions: {', '.join(captions)}") | |
| # if footnotes: | |
| # parts.append(f"Image footnotes: {', '.join(footnotes)}") | |
| # return "; ".join(parts) if parts else "Image content information incomplete" | |
| # async def _describe_table_for_query( | |
| # self, processor, content: Dict[str, Any] | |
| # ) -> str: | |
| # """Generate table description for query""" | |
| # table_data = content.get("table_data", "") | |
| # table_caption = content.get("table_caption", "") | |
| # prompt = PROMPTS["QUERY_TABLE_ANALYSIS"].format( | |
| # table_data=table_data, table_caption=table_caption | |
| # ) | |
| # description = await processor.modal_caption_func( | |
| # prompt, system_prompt=PROMPTS["QUERY_TABLE_ANALYST_SYSTEM"] | |
| # ) | |
| # return description | |
| # async def _describe_equation_for_query( | |
| # self, processor, content: Dict[str, Any] | |
| # ) -> str: | |
| # """Generate equation description for query""" | |
| # latex = content.get("latex", "") | |
| # equation_caption = content.get("equation_caption", "") | |
| # prompt = PROMPTS["QUERY_EQUATION_ANALYSIS"].format( | |
| # latex=latex, equation_caption=equation_caption | |
| # ) | |
| # description = await processor.modal_caption_func( | |
| # prompt, system_prompt=PROMPTS["QUERY_EQUATION_ANALYST_SYSTEM"] | |
| # ) | |
| # return description | |
| # async def _describe_generic_for_query( | |
| # self, processor, content: Dict[str, Any], content_type: str | |
| # ) -> str: | |
| # """Generate generic content description for query""" | |
| # content_str = str(content) | |
| # prompt = PROMPTS["QUERY_GENERIC_ANALYSIS"].format( | |
| # content_type=content_type, content_str=content_str | |
| # ) | |
| # description = await processor.modal_caption_func( | |
| # prompt, | |
| # system_prompt=PROMPTS["QUERY_GENERIC_ANALYST_SYSTEM"].format( | |
| # content_type=content_type | |
| # ), | |
| # ) | |
| # return description | |
| # async def _process_image_paths_for_vlm(self, prompt: str) -> tuple[str, int]: | |
| # """ | |
| # Process image paths in prompt, keeping original paths and adding VLM markers | |
| # Args: | |
| # prompt: Original prompt | |
| # Returns: | |
| # tuple: (processed prompt, image count) | |
| # """ | |
| # enhanced_prompt = prompt | |
| # images_processed = 0 | |
| # # Initialize image cache | |
| # self._current_images_base64 = [] | |
| # # Enhanced regex pattern for matching image paths | |
| # # Matches only the path ending with image file extensions | |
| # image_path_pattern = ( | |
| # r"Image Path:\s*([^\r\n]*?\.(?:jpg|jpeg|png|gif|bmp|webp|tiff|tif))" | |
| # ) | |
| # # First, let's see what matches we find | |
| # matches = re.findall(image_path_pattern, prompt) | |
| # self.logger.info(f"Found {len(matches)} image path matches in prompt") | |
| # def replace_image_path(match): | |
| # nonlocal images_processed | |
| # image_path = match.group(1).strip() | |
| # self.logger.debug(f"Processing image path: '{image_path}'") | |
| # # Validate path format (basic check) | |
| # if not image_path or len(image_path) < 3: | |
| # self.logger.warning(f"Invalid image path format: {image_path}") | |
| # return match.group(0) # Keep original | |
| # # Use utility function to validate image file | |
| # self.logger.debug(f"Calling validate_image_file for: {image_path}") | |
| # is_valid = validate_image_file(image_path) | |
| # self.logger.debug(f"Validation result for {image_path}: {is_valid}") | |
| # if not is_valid: | |
| # self.logger.warning(f"Image validation failed for: {image_path}") | |
| # return match.group(0) # Keep original if validation fails | |
| # try: | |
| # # Encode image to base64 using utility function | |
| # self.logger.debug(f"Attempting to encode image: {image_path}") | |
| # image_base64 = encode_image_to_base64(image_path) | |
| # if image_base64: | |
| # images_processed += 1 | |
| # # Save base64 to instance variable for later use | |
| # self._current_images_base64.append(image_base64) | |
| # # Keep original path info and add VLM marker | |
| # result = f"Image Path: {image_path}\n[VLM_IMAGE_{images_processed}]" | |
| # self.logger.debug( | |
| # f"Successfully processed image {images_processed}: {image_path}" | |
| # ) | |
| # return result | |
| # else: | |
| # self.logger.error(f"Failed to encode image: {image_path}") | |
| # return match.group(0) # Keep original if encoding failed | |
| # except Exception as e: | |
| # self.logger.error(f"Failed to process image {image_path}: {e}") | |
| # return match.group(0) # Keep original | |
| # # Execute replacement | |
| # enhanced_prompt = re.sub( | |
| # image_path_pattern, replace_image_path, enhanced_prompt | |
| # ) | |
| # return enhanced_prompt, images_processed | |
| # def _build_vlm_messages_with_images( | |
| # self, enhanced_prompt: str, user_query: str | |
| # ) -> List[Dict]: | |
| # """ | |
| # Build VLM message format, using markers to correspond images with text positions | |
| # Args: | |
| # enhanced_prompt: Enhanced prompt with image markers | |
| # user_query: User query | |
| # Returns: | |
| # List[Dict]: VLM message format | |
| # """ | |
| # images_base64 = getattr(self, "_current_images_base64", []) | |
| # if not images_base64: | |
| # # Pure text mode | |
| # return [ | |
| # { | |
| # "role": "user", | |
| # "content": f"Context:\n{enhanced_prompt}\n\nUser Question: {user_query}", | |
| # } | |
| # ] | |
| # # Build multimodal content | |
| # content_parts = [] | |
| # # Split text at image markers and insert images | |
| # text_parts = enhanced_prompt.split("[VLM_IMAGE_") | |
| # for i, text_part in enumerate(text_parts): | |
| # if i == 0: | |
| # # First text part | |
| # if text_part.strip(): | |
| # content_parts.append({"type": "text", "text": text_part}) | |
| # else: | |
| # # Find marker number and insert corresponding image | |
| # marker_match = re.match(r"(\d+)\](.*)", text_part, re.DOTALL) | |
| # if marker_match: | |
| # image_num = ( | |
| # int(marker_match.group(1)) - 1 | |
| # ) # Convert to 0-based index | |
| # remaining_text = marker_match.group(2) | |
| # # Insert corresponding image | |
| # if 0 <= image_num < len(images_base64): | |
| # content_parts.append( | |
| # { | |
| # "type": "image_url", | |
| # "image_url": { | |
| # "url": f"data:image/jpeg;base64,{images_base64[image_num]}" | |
| # }, | |
| # } | |
| # ) | |
| # # Insert remaining text | |
| # if remaining_text.strip(): | |
| # content_parts.append({"type": "text", "text": remaining_text}) | |
| # # Add user question | |
| # content_parts.append( | |
| # { | |
| # "type": "text", | |
| # "text": f"\n\nUser Question: {user_query}\n\nPlease answer based on the context and images provided.", | |
| # } | |
| # ) | |
| # return [ | |
| # { | |
| # "role": "system", | |
| # "content": "You are a helpful assistant that can analyze both text and image content to provide comprehensive answers.", | |
| # }, | |
| # {"role": "user", "content": content_parts}, | |
| # ] | |
| # async def _call_vlm_with_multimodal_content(self, messages: List[Dict]) -> str: | |
| # """ | |
| # Call VLM to process multimodal content | |
| # Args: | |
| # messages: VLM message format | |
| # Returns: | |
| # str: VLM response result | |
| # """ | |
| # try: | |
| # user_message = messages[1] | |
| # content = user_message["content"] | |
| # system_prompt = messages[0]["content"] | |
| # if isinstance(content, str): | |
| # # Pure text mode | |
| # result = await self.vision_model_func( | |
| # content, system_prompt=system_prompt | |
| # ) | |
| # else: | |
| # # Multimodal mode - pass complete messages directly to VLM | |
| # result = await self.vision_model_func( | |
| # "", # Empty prompt since we're using messages format | |
| # messages=messages, | |
| # ) | |
| # return result | |
| # except Exception as e: | |
| # self.logger.error(f"VLM call failed: {e}") | |
| # raise | |
| # # Synchronous versions of query methods | |
| # def query(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| # """ | |
| # Synchronous version of pure text query | |
| # Args: | |
| # query: Query text | |
| # mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| # **kwargs: Other query parameters, will be passed to QueryParam | |
| # - vlm_enhanced: bool, default True when vision_model_func is available. | |
| # If True, will parse image paths in retrieved context and replace them | |
| # with base64 encoded images for VLM processing. | |
| # Returns: | |
| # str: Query result | |
| # """ | |
| # loop = always_get_an_event_loop() | |
| # return loop.run_until_complete(self.aquery(query, mode=mode, **kwargs)) | |
| # def query_with_multimodal( | |
| # self, | |
| # query: str, | |
| # multimodal_content: List[Dict[str, Any]] = None, | |
| # mode: str = "mix", | |
| # **kwargs, | |
| # ) -> str: | |
| # """ | |
| # Synchronous version of multimodal query | |
| # Args: | |
| # query: Base query text | |
| # multimodal_content: List of multimodal content, each element contains: | |
| # - type: Content type ("image", "table", "equation", etc.) | |
| # - Other fields depend on type (e.g., img_path, table_data, latex, etc.) | |
| # mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| # **kwargs: Other query parameters, will be passed to QueryParam | |
| # Returns: | |
| # str: Query result | |
| # """ | |
| # loop = always_get_an_event_loop() | |
| # return loop.run_until_complete( | |
| # self.aquery_with_multimodal(query, multimodal_content, mode=mode, **kwargs) | |
| # ) | |
| """ | |
| Query functionality for RAGAnything - ENHANCED VERSION | |
| Contains all query-related methods for text and multimodal queries, | |
| plus query improvement and dual-LLM verification capabilities. | |
| """ | |
| import json | |
| import hashlib | |
| import re | |
| import asyncio | |
| from typing import Dict, List, Any | |
| from pathlib import Path | |
| from lightrag import QueryParam | |
| from lightrag.utils import always_get_an_event_loop | |
| from raganything.prompt import PROMPTS | |
| from raganything.utils import ( | |
| get_processor_for_type, | |
| encode_image_to_base64, | |
| validate_image_file, | |
| ) | |
| # Import new enhancement modules | |
| from raganything.query_improvement import QueryImprovementMixin | |
| from raganything.verification import DualLLMVerificationMixin | |
| from raganything.streaming import StreamingQueryMixin | |
| class QueryMixin(QueryImprovementMixin, DualLLMVerificationMixin, StreamingQueryMixin): | |
| """ | |
| QueryMixin class containing query functionality for RAGAnything | |
| Enhanced with: | |
| - Query improvement (rewriting, expansion, decomposition) | |
| - Dual-LLM verification system | |
| - Answer modification based on feedback | |
| - Real-time streaming with verification support | |
| """ | |
| def _generate_multimodal_cache_key( | |
| self, query: str, multimodal_content: List[Dict[str, Any]], mode: str, **kwargs | |
| ) -> str: | |
| """ | |
| Generate cache key for multimodal query | |
| Args: | |
| query: Base query text | |
| multimodal_content: List of multimodal content | |
| mode: Query mode | |
| **kwargs: Additional parameters | |
| Returns: | |
| str: Cache key hash | |
| """ | |
| # Create a normalized representation of the query parameters | |
| cache_data = { | |
| "query": query.strip(), | |
| "mode": mode, | |
| } | |
| # Normalize multimodal content for stable caching | |
| normalized_content = [] | |
| if multimodal_content: | |
| for item in multimodal_content: | |
| if isinstance(item, dict): | |
| normalized_item = {} | |
| for key, value in item.items(): | |
| # For file paths, use basename to make cache more portable | |
| if key in [ | |
| "img_path", | |
| "image_path", | |
| "file_path", | |
| ] and isinstance(value, str): | |
| normalized_item[key] = Path(value).name | |
| # For large content, create a hash instead of storing directly | |
| elif ( | |
| key in ["table_data", "table_body"] | |
| and isinstance(value, str) | |
| and len(value) > 200 | |
| ): | |
| normalized_item[f"{key}_hash"] = hashlib.md5( | |
| value.encode() | |
| ).hexdigest() | |
| else: | |
| normalized_item[key] = value | |
| normalized_content.append(normalized_item) | |
| else: | |
| normalized_content.append(item) | |
| cache_data["multimodal_content"] = normalized_content | |
| # Add relevant kwargs to cache data | |
| relevant_kwargs = { | |
| k: v | |
| for k, v in kwargs.items() | |
| if k | |
| in [ | |
| "stream", | |
| "response_type", | |
| "top_k", | |
| "max_tokens", | |
| "temperature", | |
| ] | |
| } | |
| cache_data.update(relevant_kwargs) | |
| # Generate hash from the cache data | |
| cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False) | |
| cache_hash = hashlib.md5(cache_str.encode()).hexdigest() | |
| return f"multimodal_query:{cache_hash}" | |
| async def aquery(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| """ | |
| Pure text query with optional query improvement and verification | |
| Args: | |
| query: Query text | |
| mode: Query mode ("local", "global", "hybrid", "naive", "mix", "bypass") | |
| **kwargs: Other query parameters | |
| - vlm_enhanced: bool, default True when vision_model_func is available | |
| - enable_query_improvement: bool, override config setting | |
| - enable_verification: bool, override config setting | |
| - return_verification_info: bool, return detailed verification info | |
| Returns: | |
| str: Query result (or dict if return_verification_info=True) | |
| """ | |
| if self.lightrag is None: | |
| raise ValueError( | |
| "No LightRAG instance available. Please process documents first or provide a pre-initialized LightRAG instance." | |
| ) | |
| # Check override flags | |
| use_query_improvement = kwargs.pop('enable_query_improvement', | |
| getattr(self.config, 'enable_query_improvement', False)) | |
| use_verification = kwargs.pop('enable_verification', | |
| getattr(self.config, 'enable_dual_llm_verification', False)) | |
| return_verification_info = kwargs.pop('return_verification_info', False) | |
| original_query = query | |
| query_improvement_result = None | |
| # Step 1: Apply query improvement if enabled | |
| if use_query_improvement and hasattr(self, 'query_improver') and self.query_improver: | |
| self.logger.info("Applying query improvement...") | |
| query_improvement_result = await self._apply_query_improvement(query) | |
| if not query_improvement_result["improved_query"]: | |
| self.logger.warning("Query improvement resulted in an empty query, using original query.") | |
| query = original_query | |
| else: | |
| query = query_improvement_result["improved_query"] | |
| self.logger.info(f"Query improved: '{original_query[:50]}...' -> '{query[:50]}...'") | |
| # Check if VLM enhanced query should be used | |
| vlm_enhanced = kwargs.pop("vlm_enhanced", None) | |
| # Auto-determine VLM enhanced based on availability | |
| if vlm_enhanced is None: | |
| vlm_enhanced = ( | |
| hasattr(self, "vision_model_func") | |
| and self.vision_model_func is not None | |
| ) | |
| # If using VLM enhanced or verification is disabled, use existing flow | |
| if vlm_enhanced or not use_verification: | |
| # Use VLM enhanced query if enabled and available | |
| if ( | |
| vlm_enhanced | |
| and hasattr(self, "vision_model_func") | |
| and self.vision_model_func | |
| ): | |
| result = await self.aquery_vlm_enhanced(query, mode=mode, **kwargs) | |
| elif vlm_enhanced and ( | |
| not hasattr(self, "vision_model_func") or not self.vision_model_func | |
| ): | |
| self.logger.warning( | |
| "VLM enhanced query requested but vision_model_func is not available, falling back to normal query" | |
| ) | |
| # Create query parameters | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| # Call LightRAG's query method | |
| result = await self.lightrag.aquery(query, param=query_param) | |
| else: | |
| # Create query parameters | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| # Call LightRAG's query method | |
| result = await self.lightrag.aquery(query, param=query_param) | |
| # Handle None result from LightRAG | |
| if result is None: | |
| result = "I couldn't find any relevant information in the knowledge base to answer your question." | |
| # Return with verification info if requested | |
| if return_verification_info: | |
| return { | |
| "answer": result, | |
| "original_query": original_query, | |
| "improved_query": query if query_improvement_result else original_query, | |
| "query_improvement": query_improvement_result, | |
| "verification_passed": True, | |
| "verification_score": 10.0, | |
| "modification_attempts": 0 | |
| } | |
| self.logger.info("Query completed") | |
| return result | |
| # Step 2: Generate with verification if enabled | |
| if use_verification and hasattr(self, 'answer_verifier') and self.answer_verifier: | |
| self.logger.info("Using dual-LLM verification...") | |
| # Get context without final answer | |
| query_param = QueryParam(mode=mode, only_need_context=True, **kwargs) | |
| context = await self.lightrag.aquery(query, param=query_param) | |
| # Check if context is None or empty | |
| if context is None or (isinstance(context, str) and not context.strip()): | |
| self.logger.warning("No context retrieved from knowledge base") | |
| no_context_answer = "I couldn't find any relevant information in the knowledge base to answer your question." | |
| if return_verification_info: | |
| return { | |
| "answer": no_context_answer, | |
| "original_query": original_query, | |
| "improved_query": query if query_improvement_result else original_query, | |
| "query_improvement": query_improvement_result, | |
| "verification_passed": False, | |
| "verification_score": 0.0, | |
| "modification_attempts": 0, | |
| "verification_history": [] | |
| } | |
| return no_context_answer | |
| # Generate with verification | |
| verification_result = await self._generate_with_verification( | |
| query=query, | |
| context=context, | |
| original_query=original_query | |
| ) | |
| if return_verification_info: | |
| return { | |
| "answer": verification_result["answer"], | |
| "original_query": original_query, | |
| "improved_query": query if query_improvement_result else original_query, | |
| "query_improvement": query_improvement_result, | |
| "verification_passed": verification_result["verification_passed"], | |
| "verification_score": verification_result["verification_score"], | |
| "modification_attempts": verification_result["modification_attempts"], | |
| "verification_history": verification_result.get("verification_history", []) | |
| } | |
| self.logger.info("Verified query completed") | |
| return verification_result["answer"] | |
| # Fallback to normal query | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| result = await self.lightrag.aquery(query, param=query_param) | |
| # Handle None result from LightRAG | |
| if result is None: | |
| result = "I couldn't find any relevant information in the knowledge base to answer your question." | |
| if return_verification_info: | |
| return { | |
| "answer": result, | |
| "original_query": original_query, | |
| "improved_query": query if query_improvement_result else original_query, | |
| "query_improvement": query_improvement_result, | |
| "verification_passed": True, | |
| "verification_score": 10.0, | |
| "modification_attempts": 0 | |
| } | |
| self.logger.info("Query completed") | |
| return result | |
| async def aquery_with_multimodal( | |
| self, | |
| query: str, | |
| multimodal_content: List[Dict[str, Any]] = None, | |
| mode: str = "mix", | |
| **kwargs, | |
| ) -> str: | |
| """ | |
| Multimodal query - combines text and multimodal content for querying | |
| Args: | |
| query: Base query text | |
| multimodal_content: List of multimodal content | |
| mode: Query mode | |
| **kwargs: Other query parameters | |
| Returns: | |
| str: Query result | |
| """ | |
| # Ensure LightRAG is initialized | |
| await self._ensure_lightrag_initialized() | |
| self.logger.info(f"Executing multimodal query: {query[:100]}...") | |
| self.logger.info(f"Query mode: {mode}") | |
| # If no multimodal content, fallback to pure text query | |
| if not multimodal_content: | |
| self.logger.info("No multimodal content provided, executing text query") | |
| return await self.aquery(query, mode=mode, **kwargs) | |
| # Generate cache key for multimodal query | |
| cache_key = self._generate_multimodal_cache_key( | |
| query, multimodal_content, mode, **kwargs | |
| ) | |
| # Check cache if available and enabled | |
| cached_result = None | |
| if ( | |
| hasattr(self, "lightrag") | |
| and self.lightrag | |
| and hasattr(self.lightrag, "llm_response_cache") | |
| and self.lightrag.llm_response_cache | |
| ): | |
| if self.lightrag.llm_response_cache.global_config.get( | |
| "enable_llm_cache", True | |
| ): | |
| try: | |
| cached_result = await self.lightrag.llm_response_cache.get_by_id( | |
| cache_key | |
| ) | |
| if cached_result and isinstance(cached_result, dict): | |
| result_content = cached_result.get("return") | |
| if result_content: | |
| self.logger.info( | |
| f"Multimodal query cache hit: {cache_key[:16]}..." | |
| ) | |
| return result_content | |
| except Exception as e: | |
| self.logger.debug(f"Error accessing multimodal query cache: {e}") | |
| # Process multimodal content to generate enhanced query text | |
| enhanced_query = await self._process_multimodal_query_content( | |
| query, multimodal_content | |
| ) | |
| self.logger.info( | |
| f"Generated enhanced query length: {len(enhanced_query)} characters" | |
| ) | |
| # Execute enhanced query | |
| result = await self.aquery(enhanced_query, mode=mode, **kwargs) | |
| # Save to cache if available and enabled | |
| if ( | |
| hasattr(self, "lightrag") | |
| and self.lightrag | |
| and hasattr(self.lightrag, "llm_response_cache") | |
| and self.lightrag.llm_response_cache | |
| ): | |
| if self.lightrag.llm_response_cache.global_config.get( | |
| "enable_llm_cache", True | |
| ): | |
| try: | |
| # Create cache entry for multimodal query | |
| cache_entry = { | |
| "return": result, | |
| "cache_type": "multimodal_query", | |
| "original_query": query, | |
| "multimodal_content_count": len(multimodal_content), | |
| "mode": mode, | |
| } | |
| await self.lightrag.llm_response_cache.upsert( | |
| {cache_key: cache_entry} | |
| ) | |
| self.logger.info( | |
| f"Saved multimodal query result to cache: {cache_key[:16]}..." | |
| ) | |
| except Exception as e: | |
| self.logger.debug(f"Error saving multimodal query to cache: {e}") | |
| # Ensure cache is persisted to disk | |
| if ( | |
| hasattr(self, "lightrag") | |
| and self.lightrag | |
| and hasattr(self.lightrag, "llm_response_cache") | |
| and self.lightrag.llm_response_cache | |
| ): | |
| try: | |
| await self.lightrag.llm_response_cache.index_done_callback() | |
| except Exception as e: | |
| self.logger.debug(f"Error persisting multimodal query cache: {e}") | |
| self.logger.info("Multimodal query completed") | |
| return result | |
| async def aquery_vlm_enhanced(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| """ | |
| VLM enhanced query - replaces image paths in retrieved context with base64 encoded images | |
| Args: | |
| query: User query | |
| mode: Underlying LightRAG query mode | |
| **kwargs: Other query parameters | |
| Returns: | |
| str: VLM query result | |
| """ | |
| # Ensure VLM is available | |
| if not hasattr(self, "vision_model_func") or not self.vision_model_func: | |
| raise ValueError( | |
| "VLM enhanced query requires vision_model_func. " | |
| "Please provide a vision model function when initializing RAGAnything." | |
| ) | |
| # Ensure LightRAG is initialized | |
| await self._ensure_lightrag_initialized() | |
| self.logger.info(f"Executing VLM enhanced query: {query[:100]}...") | |
| # Clear previous image cache | |
| if hasattr(self, "_current_images_base64"): | |
| delattr(self, "_current_images_base64") | |
| # 1. Get original retrieval prompt (without generating final answer) | |
| self.logger.info(f"Getting raw prompt for query: {query[:100]}...") | |
| query_param = QueryParam(mode=mode, only_need_prompt=True, **kwargs) | |
| try: | |
| raw_prompt = await self.lightrag.aquery(query, param=query_param) | |
| except Exception as e: | |
| self.logger.error(f"Error in self.lightrag.aquery: {e}", exc_info=True) | |
| raw_prompt = None | |
| self.logger.info(f"Retrieved raw prompt: {str(raw_prompt)[:200]}...") | |
| if raw_prompt is None: | |
| self.logger.warning("raw_prompt is None, falling back to normal query (single pass)") | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| return await self.lightrag.aquery(query, param=query_param) | |
| self.logger.debug("Retrieved raw prompt from LightRAG") | |
| # 2. Extract and process image paths | |
| enhanced_prompt, images_found = await self._process_image_paths_for_vlm( | |
| raw_prompt | |
| ) | |
| if not images_found: | |
| self.logger.info("No valid images found, falling back to normal query WITHOUT re-retrieval") | |
| # OPTIMIZATION: Reuse the already-retrieved context instead of querying again | |
| # The raw_prompt already contains the full RAG context, so we can use it directly | |
| # Try to use the existing model function if available | |
| if hasattr(self.lightrag, 'llm_model_func') and self.lightrag.llm_model_func: | |
| try: | |
| # Generate answer using the already-retrieved context | |
| self.logger.info("Generating answer from cached context (avoiding re-query)") | |
| # Call the LLM with the raw prompt directly | |
| if asyncio.iscoroutinefunction(self.lightrag.llm_model_func): | |
| result = await self.lightrag.llm_model_func(raw_prompt) | |
| else: | |
| result = self.lightrag.llm_model_func(raw_prompt) | |
| self.logger.info("Successfully generated answer from cached context (no re-query)") | |
| return result | |
| except Exception as e: | |
| self.logger.warning(f"Failed to use cached context, falling back to re-query: {e}") | |
| # Fall back to re-query if direct generation fails | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| return await self.lightrag.aquery(query, param=query_param) | |
| else: | |
| # No model_func available, must re-query (original behavior) | |
| # This maintains backward compatibility | |
| self.logger.debug("llm_model_func not available, using standard re-query") | |
| query_param = QueryParam(mode=mode, **kwargs) | |
| return await self.lightrag.aquery(query, param=query_param) | |
| self.logger.info(f"Processed {images_found} images for VLM") | |
| # 3. Build VLM message format | |
| messages = self._build_vlm_messages_with_images(enhanced_prompt, query) | |
| # 4. Call VLM for question answering | |
| result = await self._call_vlm_with_multimodal_content(messages) | |
| self.logger.info("VLM enhanced query completed") | |
| return result | |
| # ... (rest of the existing methods remain the same) ... | |
| async def _process_multimodal_query_content( | |
| self, base_query: str, multimodal_content: List[Dict[str, Any]] | |
| ) -> str: | |
| """Process multimodal query content to generate enhanced query text""" | |
| self.logger.info("Starting multimodal query content processing...") | |
| enhanced_parts = [f"User query: {base_query}"] | |
| for i, content in enumerate(multimodal_content): | |
| content_type = content.get("type", "unknown") | |
| self.logger.info( | |
| f"Processing {i+1}/{len(multimodal_content)} multimodal content: {content_type}" | |
| ) | |
| try: | |
| # Get appropriate processor | |
| processor = get_processor_for_type(self.modal_processors, content_type) | |
| if processor: | |
| # Generate content description | |
| description = await self._generate_query_content_description( | |
| processor, content, content_type | |
| ) | |
| enhanced_parts.append( | |
| f"\nRelated {content_type} content: {description}" | |
| ) | |
| else: | |
| # If no appropriate processor, use basic description | |
| basic_desc = str(content)[:200] | |
| enhanced_parts.append( | |
| f"\nRelated {content_type} content: {basic_desc}" | |
| ) | |
| except Exception as e: | |
| self.logger.error(f"Error processing multimodal content: {str(e)}") | |
| continue | |
| enhanced_query = "\n".join(enhanced_parts) | |
| enhanced_query += PROMPTS["QUERY_ENHANCEMENT_SUFFIX"] | |
| self.logger.info("Multimodal query content processing completed") | |
| return enhanced_query | |
| async def _generate_query_content_description( | |
| self, processor, content: Dict[str, Any], content_type: str | |
| ) -> str: | |
| """Generate content description for query""" | |
| try: | |
| if content_type == "image": | |
| return await self._describe_image_for_query(processor, content) | |
| elif content_type == "table": | |
| return await self._describe_table_for_query(processor, content) | |
| elif content_type == "equation": | |
| return await self._describe_equation_for_query(processor, content) | |
| else: | |
| return await self._describe_generic_for_query( | |
| processor, content, content_type | |
| ) | |
| except Exception as e: | |
| self.logger.error(f"Error generating {content_type} description: {str(e)}") | |
| return f"{content_type} content: {str(content)[:100]}" | |
| async def _describe_image_for_query( | |
| self, processor, content: Dict[str, Any] | |
| ) -> str: | |
| """Generate image description for query""" | |
| image_path = content.get("img_path") | |
| captions = content.get("image_caption", content.get("img_caption", [])) | |
| footnotes = content.get("image_footnote", content.get("img_footnote", [])) | |
| if image_path and Path(image_path).exists(): | |
| image_base64 = processor._encode_image_to_base64(image_path) | |
| if image_base64: | |
| prompt = PROMPTS["QUERY_IMAGE_DESCRIPTION"] | |
| description = await processor.modal_caption_func( | |
| prompt, | |
| image_data=image_base64, | |
| system_prompt=PROMPTS["QUERY_IMAGE_ANALYST_SYSTEM"], | |
| ) | |
| return description | |
| parts = [] | |
| if image_path: | |
| parts.append(f"Image path: {image_path}") | |
| if captions: | |
| parts.append(f"Image captions: {', '.join(captions)}") | |
| if footnotes: | |
| parts.append(f"Image footnotes: {', '.join(footnotes)}") | |
| return "; ".join(parts) if parts else "Image content information incomplete" | |
| async def _describe_table_for_query( | |
| self, processor, content: Dict[str, Any] | |
| ) -> str: | |
| """Generate table description for query""" | |
| table_data = content.get("table_data", "") | |
| table_caption = content.get("table_caption", "") | |
| prompt = PROMPTS["QUERY_TABLE_ANALYSIS"].format( | |
| table_data=table_data, table_caption=table_caption | |
| ) | |
| description = await processor.modal_caption_func( | |
| prompt, system_prompt=PROMPTS["QUERY_TABLE_ANALYST_SYSTEM"] | |
| ) | |
| return description | |
| async def _describe_equation_for_query( | |
| self, processor, content: Dict[str, Any] | |
| ) -> str: | |
| """Generate equation description for query""" | |
| latex = content.get("latex", "") | |
| equation_caption = content.get("equation_caption", "") | |
| prompt = PROMPTS["QUERY_EQUATION_ANALYSIS"].format( | |
| latex=latex, equation_caption=equation_caption | |
| ) | |
| description = await processor.modal_caption_func( | |
| prompt, system_prompt=PROMPTS["QUERY_EQUATION_ANALYST_SYSTEM"] | |
| ) | |
| return description | |
| async def _describe_generic_for_query( | |
| self, processor, content: Dict[str, Any], content_type: str | |
| ) -> str: | |
| """Generate generic content description for query""" | |
| content_str = str(content) | |
| prompt = PROMPTS["QUERY_GENERIC_ANALYSIS"].format( | |
| content_type=content_type, content_str=content_str | |
| ) | |
| description = await processor.modal_caption_func( | |
| prompt, | |
| system_prompt=PROMPTS["QUERY_GENERIC_ANALYST_SYSTEM"].format( | |
| content_type=content_type | |
| ), | |
| ) | |
| return description | |
| async def _process_image_paths_for_vlm(self, prompt: str) -> tuple[str, int]: | |
| """Process image paths in prompt, keeping original paths and adding VLM markers""" | |
| if prompt is None: | |
| self.logger.warning("prompt is None in _process_image_paths_for_vlm, returning as is") | |
| return prompt, 0 | |
| enhanced_prompt = prompt | |
| images_processed = 0 | |
| self._current_images_base64 = [] | |
| image_path_pattern = ( | |
| r"Image Path:\s*([^\r\n]*?\.(?:jpg|jpeg|png|gif|bmp|webp|tiff|tif))" | |
| ) | |
| matches = re.findall(image_path_pattern, prompt) | |
| self.logger.info(f"Found {len(matches)} image path matches in prompt") | |
| def replace_image_path(match): | |
| nonlocal images_processed | |
| image_path = match.group(1).strip() | |
| self.logger.debug(f"Processing image path: '{image_path}'") | |
| if not image_path or len(image_path) < 3: | |
| self.logger.warning(f"Invalid image path format: {image_path}") | |
| return match.group(0) | |
| self.logger.debug(f"Calling validate_image_file for: {image_path}") | |
| is_valid = validate_image_file(image_path) | |
| self.logger.debug(f"Validation result for {image_path}: {is_valid}") | |
| if not is_valid: | |
| self.logger.warning(f"Image validation failed for: {image_path}") | |
| return match.group(0) | |
| try: | |
| self.logger.debug(f"Attempting to encode image: {image_path}") | |
| image_base64 = encode_image_to_base64(image_path) | |
| if image_base64: | |
| images_processed += 1 | |
| self._current_images_base64.append(image_base64) | |
| result = f"Image Path: {image_path}\n[VLM_IMAGE_{images_processed}]" | |
| self.logger.debug( | |
| f"Successfully processed image {images_processed}: {image_path}" | |
| ) | |
| return result | |
| else: | |
| self.logger.error(f"Failed to encode image: {image_path}") | |
| return match.group(0) | |
| except Exception as e: | |
| self.logger.error(f"Failed to process image {image_path}: {e}") | |
| return match.group(0) | |
| enhanced_prompt = re.sub( | |
| image_path_pattern, replace_image_path, enhanced_prompt | |
| ) | |
| return enhanced_prompt, images_processed | |
| def _build_vlm_messages_with_images( | |
| self, enhanced_prompt: str, user_query: str | |
| ) -> List[Dict]: | |
| """Build VLM message format, using markers to correspond images with text positions""" | |
| images_base64 = getattr(self, "_current_images_base64", []) | |
| if not images_base64: | |
| return [ | |
| { | |
| "role": "user", | |
| "content": f"Context:\n{enhanced_prompt}\n\nUser Question: {user_query}", | |
| } | |
| ] | |
| content_parts = [] | |
| text_parts = enhanced_prompt.split("[VLM_IMAGE_") | |
| for i, text_part in enumerate(text_parts): | |
| if i == 0: | |
| if text_part.strip(): | |
| content_parts.append({"type": "text", "text": text_part}) | |
| else: | |
| marker_match = re.match(r"(\d+)\](.*)", text_part, re.DOTALL) | |
| if marker_match: | |
| image_num = int(marker_match.group(1)) - 1 | |
| remaining_text = marker_match.group(2) | |
| if 0 <= image_num < len(images_base64): | |
| content_parts.append( | |
| { | |
| "type": "image_url", | |
| "image_url": { | |
| "url": f"data:image/jpeg;base64,{images_base64[image_num]}" | |
| }, | |
| } | |
| ) | |
| if remaining_text.strip(): | |
| content_parts.append({"type": "text", "text": remaining_text}) | |
| content_parts.append( | |
| { | |
| "type": "text", | |
| "text": f"\n\nUser Question: {user_query}\n\nPlease answer based on the context and images provided.", | |
| } | |
| ) | |
| return [ | |
| { | |
| "role": "system", | |
| "content": "You are a helpful assistant that can analyze both text and image content to provide comprehensive answers.", | |
| }, | |
| {"role": "user", "content": content_parts}, | |
| ] | |
| async def _call_vlm_with_multimodal_content(self, messages: List[Dict]) -> str: | |
| """Call VLM to process multimodal content""" | |
| try: | |
| user_message = messages[1] | |
| content = user_message["content"] | |
| system_prompt = messages[0]["content"] | |
| if isinstance(content, str): | |
| result = await self.vision_model_func( | |
| content, system_prompt=system_prompt | |
| ) | |
| else: | |
| result = await self.vision_model_func( | |
| "", | |
| messages=messages, | |
| ) | |
| return result | |
| except Exception as e: | |
| self.logger.error(f"VLM call failed: {e}") | |
| raise | |
| # Synchronous versions of query methods | |
| def query(self, query: str, mode: str = "mix", **kwargs) -> str: | |
| """Synchronous version of pure text query""" | |
| loop = always_get_an_event_loop() | |
| return loop.run_until_complete(self.aquery(query, mode=mode, **kwargs)) | |
| def query_with_multimodal( | |
| self, | |
| query: str, | |
| multimodal_content: List[Dict[str, Any]] = None, | |
| mode: str = "mix", | |
| **kwargs, | |
| ) -> str: | |
| """Synchronous version of multimodal query""" | |
| loop = always_get_an_event_loop() | |
| return loop.run_until_complete( | |
| self.aquery_with_multimodal(query, multimodal_content, mode=mode, **kwargs) | |
| ) |