| import os |
| import streamlit as st |
| from PIL import Image as PILImage |
| from PIL import Image as pilImage |
| import base64 |
| import io |
| import chromadb |
| from initate import process_pdf |
| from utils.llm_ag import intiate_convo |
| from utils.doi import process_image_and_get_description |
|
|
# On-disk location of the multimodal Chroma vector store.
path = "mm_vdb2"
# Persistent client: collections written here survive Streamlit reruns and restarts.
client = chromadb.PersistentClient(path=path)
# NOTE(review): the two imports below duplicate lines at the top of the file
# (streamlit and PIL.Image are already imported); safe to delete.
import streamlit as st
from PIL import Image as PILImage
|
|
def add_background_image(image_path):
    """Paint *image_path* as a fixed, full-cover background for the whole app.

    The file is embedded directly into the page as a base64 ``data:`` URI
    inside an injected ``<style>`` block (media type is assumed to be PNG).

    Args:
        image_path: Path to the background image file on disk.
    """
    with open(image_path, "rb") as img_file:
        base64_image = base64.b64encode(img_file.read()).decode()
    st.markdown(
        f"""
    <style>
    .stApp {{
        background-image: url("data:image/png;base64,{base64_image}");
        background-size: cover;
        background-repeat: no-repeat;
        background-attachment: fixed;
    }}
    </style>
    """,
        unsafe_allow_html=True,
    )
|
|
| |
|
|
|
|
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """
    Display images in a Streamlit app based on a query.

    Args:
        image_collection: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for filtering.
        debug (bool, optional): Whether to print debug information.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )

    uris = results['uris'][0]
    distances = results['distances'][0]

    # Sort by URI (file path) for a stable grid order, not by relevance —
    # presumably intentional so reruns render images in the same positions.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    # Lay the hits out in a three-column grid.
    cols = st.columns(3)

    for i, (uri, distance) in enumerate(sorted_results):
        if max_distance is None or distance <= max_distance:
            if debug:
                # Fix: `debug` was documented but previously ignored; mirror
                # the debug output style of display_videos_streamlit.
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                img = PILImage.open(uri)
                with cols[i % 3]:
                    st.image(img, use_container_width=True)
            except Exception as e:
                # Best-effort rendering: a missing/corrupt file should not
                # abort the rest of the grid.
                st.error(f"Error loading image: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
|
|
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Render the videos whose frames best match *query_text*.

    Each distinct source video (taken from the hit's ``video_uri`` metadata)
    is shown at most once, in hit order.

    Args:
        video_collection: The video collection object for querying.
        query_text (str): The text query for videos.
        max_distance (float, optional): Maximum allowable distance for filtering.
        max_results (int, optional): Maximum number of results to display.
        debug (bool, optional): Whether to print debug information.
    """
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas']
    )

    # Track which videos are already on the page so duplicates are skipped.
    shown = set()
    hits = zip(results['uris'][0], results['distances'][0], results['metadatas'][0])

    for uri, distance, metadata in hits:
        video_uri = metadata['video_uri']
        within_cutoff = max_distance is None or distance <= max_distance

        if within_cutoff and video_uri not in shown:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)
            shown.add(video_uri)
        elif debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
| |
| |
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """Return URIs of images matching *query_text*.

    Args:
        image_collection: Collection to query for images.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits at or below this distance;
            ``None`` keeps everything.
        max_results (int, optional): Number of nearest hits requested.

    Returns:
        list[str]: URIs of the matching images, in hit order.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )
    pairs = zip(hits['uris'][0], hits['distances'][0])
    return [uri for uri, dist in pairs
            if max_distance is None or dist <= max_distance]
|
|
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """Return document chunks matching *query_text*.

    Args:
        text_collection: Collection to query for text documents.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits at or below this distance;
            ``None`` keeps everything.
        max_results (int, optional): Number of nearest hits requested.

    Returns:
        list[str]: Matching document texts, in hit order.
    """
    hits = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances']
    )
    pairs = zip(hits['documents'][0], hits['distances'][0])
    return [doc for doc, dist in pairs
            if max_distance is None or dist <= max_distance]
|
|
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """Return up to *max_results* frame URIs, at most one per containing folder.

    Frames of the same video live in the same directory, so deduplicating on
    ``os.path.dirname`` yields one representative frame per video.

    Args:
        video_collection: Collection to query for video frames.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits at or below this distance;
            ``None`` keeps everything.
        max_results (int, optional): Cap on both the query and the result list.

    Returns:
        list[str]: Frame URIs, one per folder, in hit order.
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    selected = []
    seen_folders = set()
    for uri, dist in zip(hits['uris'][0], hits['distances'][0]):
        if max_distance is not None and dist > max_distance:
            continue
        folder = os.path.dirname(uri)
        if folder in seen_folders:
            continue
        seen_folders.add(folder)
        selected.append(uri)
        if len(selected) == max_results:
            break
    return selected
|
|
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """Return URIs of images from a second image collection matching *query_text*.

    NOTE(review): byte-for-byte duplicate logic of ``image_uris``; consider
    consolidating the two.

    Args:
        image_collection2: Collection to query for images.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits at or below this distance;
            ``None`` keeps everything.
        max_results (int, optional): Number of nearest hits requested.

    Returns:
        list[str]: URIs of the matching images, in hit order.
    """
    hits = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )
    kept = []
    for uri, dist in zip(hits['uris'][0], hits['distances'][0]):
        if max_distance is None or dist <= max_distance:
            kept.append(uri)
    return kept
|
|
|
|
def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """Assemble the retrieval context for a single user query.

    Returns a dict with keys:
        query: the raw user query.
        texts: matching text chunks (distance <= 1.3).
        frame: URI of the best-matching video frame, or "" when none.
        image_data_1: base64-encoded JPEG of the best-matching image
            (downscaled, greyscale), or "" when none.
    """
    inputs = {
        "query": user_query,
        "texts": text_uris(text_collection, user_query, max_distance=1.3),
    }

    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    inputs["frame"] = frame_candidates[0] if frame_candidates else ""

    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    if not image_candidates:
        inputs["image_data_1"] = ""
        return inputs

    # Shrink (1/6 per side) and greyscale the top image so the encoded
    # payload sent downstream stays small.
    with PILImage.open(image_candidates[0]) as img:
        img = img.resize((img.width // 6, img.height // 6))
        img = img.convert("L")
        with io.BytesIO() as buffer:
            img.save(buffer, format="JPEG", quality=60)
            compressed_image_data = buffer.getvalue()

    inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    return inputs
|
|
| import time |
|
|
def page_1():
    """Admin page: upload a PDF and index its content into session-state collections."""
    add_background_image("bg3.jpg")

    st.markdown("""
    <svg width="600" height="100">
    <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
    stroke-width="0.3" stroke-linejoin="round">ADMIN - UPLOAD
    </text>
    </svg>
    """, unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if uploaded_file:
        # Fix: basename() strips any directory components smuggled into the
        # client-supplied filename (e.g. "../../etc/passwd"), keeping the
        # write inside /tmp.
        pdf_path = f"/tmp/{os.path.basename(uploaded_file.name)}"
        with open(pdf_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        with st.spinner("Processing PDF... Please wait."):
            try:
                # Brief pause so the spinner/status text is visible to the user.
                time.sleep(1)

                st.text("Extracting content from PDF...")
                image_collection, text_collection, video_collection = process_pdf(pdf_path)
                # Persist the collections so page_2 can query them on later reruns.
                st.session_state.image_collection = image_collection
                st.session_state.text_collection = text_collection
                st.session_state.video_collection = video_collection

                time.sleep(1)

                st.success("PDF processed successfully! Collections saved to session state.")
            except Exception as e:
                # Surface any processing failure to the admin instead of crashing the app.
                st.error(f"Error processing PDF: {e}")
|
|
def page_2():
    """Assistant page: answer a user query from the collections built on page 1.

    Requires image/text/video collections in ``st.session_state`` (set by
    ``page_1``); shows an error prompting the user to process a PDF first
    when they are absent.
    """
    add_background_image("bg3.jpg")
    st.markdown("""
    <div style="text-align: left;">
    <svg width="600" height="100">
    <text x="0" y="50%" font-family="San serif" font-size="42px" fill="Black" stroke="white"
    stroke-width="0.1" stroke-linejoin="round">Poss Assistant
    </text>
    </svg>
    </div>
    """, unsafe_allow_html=True)

    if "image_collection" in st.session_state and "text_collection" in st.session_state and "video_collection" in st.session_state:
        image_collection = st.session_state.image_collection
        text_collection = st.session_state.text_collection
        video_collection = st.session_state.video_collection
        st.success("Collections loaded successfully.")

        query = st.text_input("Enter your query", value="Example Query")
        if query:
            # Gather retrieval context (texts, best frame, best image) for the LLM.
            inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
            texts = inputs["texts"]
            image_data_1 = inputs["image_data_1"]

            if image_data_1:
                # Swap the raw base64 image for a textual description before
                # handing it to the conversation helper.
                image_data_1 = process_image_and_get_description(image_data_1)

            response = intiate_convo(query, image_data_1, texts)

            st.markdown("### Assistant's Response")
            st.markdown(response)

            st.markdown("### Images")
            display_images(image_collection, query, max_distance=1.55, debug=False)

            st.markdown("### Videos")
            frame = inputs["frame"]
            if frame:
                # Assumes frame URIs look like "<root>/<video_dir>/<frame>";
                # the second path segment names the source video — TODO confirm
                # against the layout produced by process_pdf.
                directory_name = frame.split('/')[1]
                video_path = f"videos_flattened/{directory_name}.mp4"
                if os.path.exists(video_path):
                    st.video(video_path)
            else:
                st.write("No related videos found.")
    else:
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
|
|
| |
|
|
# Sidebar navigation: maps the page label shown to the user to the function
# that renders that page.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2
}

selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES.keys()))

# Render whichever page the user selected (runs on every Streamlit rerun).
PAGES[selected_page]()