import base64
import os
import shutil
import subprocess
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import httpx
import streamlit as st
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

from app.config import app_settings
from app.qdrant_db import MyQdrantClient
from app.vdr_utils import (
    get_text_embedding,
    get_image_embedding,
    pdf_folder_to_images,
    scale_image,
    pil_image_to_base64,
    load_images,
)
|
|
class VDRSession:
    """Per-user state for a visual-document-retrieval (VDR) chat session.

    Holds the OpenAI-compatible client, an on-disk scratch directory for
    uploaded PDFs and extracted page images, an embedded Qdrant collection,
    and the ordered list of indexed page images (search hits are positional
    indices into that list).
    """

    def __init__(self):
        # OpenAI-compatible client; created lazily once a valid key is set.
        self.client = None
        self.api_key = None
        self.base_url = app_settings.GLOBAL_API_BASE
        # Scratch directory for this session's PDFs / page images.
        self.SAVE_DIR = None
        self.db_collection = None
        # Short random id to namespace this session's files and collection.
        self.session_id = str(uuid.uuid4())[:5]
        # PIL images in upsert order; vector-db hits index into this list.
        self.indexed_images = []
        # Cached ids of chat-capable vision models.
        self.model_name_list = []
        self.vector_db_client = None

    def set_api_key(self, api_key: str) -> bool:
        """Validate *api_key* against the backend and keep a client on success.

        Returns True when the key can list models; otherwise resets the
        client to None and returns False.
        """
        if api_key is not None and len(api_key.strip()) > 10:
            try:
                api_key = api_key.strip()
                client = OpenAI(api_key=api_key, base_url=self.base_url)
                # Cheap authenticated call to prove the key actually works.
                models = client.models.list()
                if models:
                    self.api_key = api_key
                    self.client = client
                    return True
            except Exception as e:
                logger.debug(f'Incorrect API Key: {e}')

        self.client = None
        return False

    def set_context(self, embed_model: str) -> bool:
        """Prepare session storage and the vector-db collection for *embed_model*.

        Idempotent: directories, the Qdrant client and the collection are
        created only on first call.
        """
        self.embed_model = embed_model

        if not self.SAVE_DIR:
            self.SAVE_DIR = os.path.join('./temp_data', self.session_id)
            self.SAVE_IMAGE_DIR = os.path.join(self.SAVE_DIR, 'images')
            # Creating the nested image dir also creates SAVE_DIR.
            os.makedirs(self.SAVE_IMAGE_DIR, exist_ok=True)
            logger.debug(f'Created folder: {self.SAVE_DIR} and {self.SAVE_IMAGE_DIR}')

        if not self.vector_db_client:
            # Embedded (file-backed) Qdrant instance scoped to this session dir.
            self.vector_db_client = MyQdrantClient(path=self.SAVE_DIR)

        if not self.db_collection:
            self.db_collection = f"qd-{embed_model}-{self.session_id}"
            try:
                if self.embed_model == "tsi-embedding-colqwen2-2b-v1":
                    # ColBERT-style multivector embeddings, 128-dim per token.
                    self.vector_db_client.create_collection(self.db_collection, vector_dim=128, vector_type="colbert")
                elif self.embed_model == "jina-embedding-clip-v1":
                    # Single dense CLIP vector per image, 768-dim.
                    self.vector_db_client.create_collection(self.db_collection, vector_dim=768, vector_type="dense")
                else:
                    raise ValueError(f"Embedding model {self.embed_model} not supported")
            except Exception as e:
                # Best-effort: keep going (collection may already exist).
                logger.error(f"Error while creating collection: {e}")
        return True

    def get_available_vlms(self) -> List[str]:
        """Return (and cache) ids of chat-capable vision models on the backend."""
        assert self.client is not None

        if self.model_name_list:
            return self.model_name_list
        # Model families the UI supports for multimodal chat.
        supported = ['gemini-2.0', 'claude', 'Qwen2.5-VL-72B-Instruct']
        try:
            models = self.client.models.list()
            for model in models.data:
                if any(substring in model.id for substring in supported):
                    self.model_name_list.append(model.id)
        except Exception as e:
            logger.error(f"Error while query all models: {e}")
            raise e

        return self.model_name_list

    def get_available_image_embeds(self) -> List[str]:
        """Return ids of image-embedding models available on the backend."""
        assert self.client is not None
        model_name_list = []
        # Embedding model families the indexer supports.
        supported = ['tsi-embedding', 'clip']
        try:
            models = self.client.models.list()
            for model in models.data:
                if any(substring in model.id for substring in supported):
                    model_name_list.append(model.id)
        except Exception as e:
            logger.error(f"Error while query all models: {e}")
            raise e

        return model_name_list

    def search_images(self, text: str, top_k: int = 5) -> list[str]:
        """Retrieve the *top_k* indexed page images most similar to *text*.

        Returns a list of ``data:image/png;base64,...`` URLs. Returns False
        (kept for caller compatibility) when the stripped query is shorter
        than two characters. Raises when nothing has been indexed yet.
        """
        assert self.client is not None
        assert self.vector_db_client is not None
        try:
            if not self.indexed_images:
                raise Exception("No indexed images found. You need to click on 'Add selected context' button to index images.")
            text = text.strip()
            if len(text) < 2:
                # Query too short to embed meaningfully.
                return False

            # Single query string -> take the first (only) embedding.
            embeddings = get_text_embedding(
                texts=text,
                openai_client=self.client,
                model=self.embed_model
            )[0]

            index_results = self.vector_db_client.query_multivector(
                multivector_input=embeddings,
                collection_name=self.db_collection,
                top_k=top_k
            )
            # Hits are positional indices into self.indexed_images.
            images = []
            for i in index_results:
                encoded = pil_image_to_base64(self.indexed_images[i])
                images.append(f"data:image/png;base64,{encoded}")
            return images
        except Exception as e:
            logger.error(f"Error while generating image: {e}")
            raise e

    def ask(self, query: str, model: str, prompt_template: str, retrieved_context: Any, modality: str = "image", stream: bool = False) -> str:
        """Answer *query* with *model*, using retrieved page images as context.

        Returns the full answer string when ``stream`` is False, otherwise a
        generator yielding text chunks.

        Bug fix: the original mixed ``return`` and ``yield`` in one function
        body, which made it ALWAYS return a generator — the non-stream path
        never produced the answer (it was swallowed as the StopIteration
        value). Streaming is delegated to a helper generator so the
        non-stream call really returns a str.
        """
        assert self.client is not None
        assert query is not None
        assert prompt_template is not None
        assert retrieved_context is not None

        try:
            prompt = prompt_template.format(user_question=query)
            content = [
                {
                    "type": "text",
                    "text": prompt
                }
            ]
            if modality == "image":
                # Attach each retrieved base64 data-URL as an image_url part.
                content += [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": base64_image
                        }
                    } for base64_image in retrieved_context
                ]

            messages = [
                {
                    "role": "user",
                    "content": content,
                }
            ]

            chat_response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                stream=stream,
            )
            if not stream:
                return chat_response.choices[0].message.content
            return self._stream_chunks(chat_response)
        except Exception as e:
            logger.error(f"Error while asking: {e}")
            raise e

    @staticmethod
    def _stream_chunks(chat_response):
        """Yield non-empty text deltas from a streaming chat response."""
        for chunk in chat_response:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content

    def indexing(self, uploaded_files: list, embed_model: str, indexing_bar: Optional[Any] = None) -> bool:
        """Persist uploaded PDF files, rasterize them to images and index them.

        ``uploaded_files`` are Streamlit UploadedFile objects (need ``.name``
        and ``.getvalue()``); ``indexing_bar`` is an optional Streamlit
        progress bar created via ``st.progress``.
        """
        self.set_context(embed_model)

        assert self.client is not None
        assert self.db_collection is not None
        assert self.SAVE_DIR is not None
        assert self.embed_model is not None
        assert len(uploaded_files) > 0

        for file in uploaded_files:
            path = os.path.join(self.SAVE_DIR, file.name)
            if os.path.exists(path):
                # Same filename already uploaded in this session.
                print("File existed, skip")
                continue
            with open(path, "wb") as f:
                f.write(file.getvalue())

        image_path_list = pdf_folder_to_images(pdf_folder=self.SAVE_DIR, output_folder=self.SAVE_IMAGE_DIR)
        logger.debug(f"Extracted {len(image_path_list)} images from {len(uploaded_files)} files.")

        indexed_images = self.index_from_images(image_path_list, indexing_bar=indexing_bar)
        logger.debug(f"Indexed {len(indexed_images)} images.")

        self.indexed_images.extend(indexed_images)
        return True

    def clear_context(self) -> bool:
        """Drop indexed images, the vector-db collection and scratch files.

        Safe to call repeatedly and from ``__del__``: every step is guarded so
        a half-initialized session tears down without raising.
        """
        self.indexed_images = []
        # Bug fix: the original dereferenced vector_db_client unconditionally
        # and crashed with AttributeError when no context was ever set.
        if self.vector_db_client is not None and self.db_collection is not None:
            try:
                self.vector_db_client.delete_collection(self.db_collection)
            except Exception as e:
                logger.debug(f'Could not delete collection {self.db_collection}: {e}')
        self.db_collection = None
        self.vector_db_client = None

        if self.SAVE_DIR:
            # shutil.rmtree replaces the original `rm -rf` subprocess call:
            # portable, no shell-tool dependency; ignore_errors also covers a
            # directory that was already removed.
            shutil.rmtree(self.SAVE_DIR, ignore_errors=True)
            logger.debug(f'Removed folder: {self.SAVE_DIR}')
            self.SAVE_DIR = None
        return True

    def __del__(self):
        # Best-effort cleanup; never let an exception escape a destructor.
        try:
            self.clear_context()
            logger.debug('VDR session is cleaned up.')
        except Exception:
            pass

    def index_from_images(self,
                          images_path_list: list,
                          batch_size: int = 5,
                          indexing_bar: Optional[Any] = None
                          ):
        """Embed page images in batches and upsert them into the collection.

        Returns the (scaled) PIL images that were successfully indexed, in
        upsert order — search hits index into this list. Failing batches are
        logged and skipped so one bad page does not abort the whole run.
        """
        try:
            indexed_images = []
            total_len = len(images_path_list)
            with tqdm(total=total_len, desc="Indexing Progress") as pbar:
                for i in range(0, total_len, batch_size):
                    try:
                        batch = images_path_list[i:i + batch_size]
                        # Downscale to a bounded edge of 768px to cap cost.
                        batch = [scale_image(x, 768) for x in batch]

                        embeddings = get_image_embedding(
                            image_list=batch,
                            openai_client=self.client,
                            model=self.embed_model
                        )
                        # index=i keeps vector ids aligned with list position.
                        self.vector_db_client.upsert_multivector(
                            index=i,
                            multivector_input_list=embeddings,
                            collection_name=self.db_collection
                        )

                        indexed_images.extend(batch)
                        # Bug fix: advance by the real batch length — the last
                        # batch may be shorter than batch_size.
                        pbar.update(len(batch))
                        # Bug fix: the UI progress bar is optional; original
                        # crashed with AttributeError when it was None.
                        if indexing_bar is not None:
                            done = min(i + batch_size, total_len)
                            indexing_bar.progress(done / total_len, text=f"Indexing {done}/{total_len}")
                    except Exception as e:
                        logger.exception(f"Error during indexing: {e}")
                        continue

            # Bug fix: this log line sat after `return` and was unreachable.
            logger.debug("Indexing complete!")
            return indexed_images
        except Exception as e:
            raise Exception(f"Error during indexing: {e}")
|
| |