Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| from model_server import get_predictor | |
# Page config
# NOTE: st.set_page_config must be the first Streamlit command executed by the
# script, so this call stays ahead of every other st.* call in the file.
st.set_page_config(
    page_title="VREyeSAM - Non-frontal Iris Segmentation",
    # The icon was stored as mis-decoded bytes; it has to be a real emoji
    # (or an image) for Streamlit to use it as the page icon.
    page_icon="👁️",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS
# Stylesheet injected into the page: pads the main container, restyles
# Streamlit buttons, and defines a bordered ``.result-box`` class.
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
background-color: #4CAF50;
color: white;
padding: 0.5rem;
font-size: 16px;
}
.result-box {
border: 2px solid #ddd;
border-radius: 10px;
padding: 1rem;
margin: 1rem 0;
}
</style>
"""
# unsafe_allow_html is required, otherwise the <style> tag would be escaped.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def load_model():
    """Load the segmentation predictor via ``model_server.get_predictor``.

    Returns:
        Whatever ``get_predictor()`` returns, or ``None`` if it raises.

    On failure the UI only shows a generic message (unchanged from before),
    while the full traceback is written to the application log so the
    failure can actually be diagnosed.
    """
    import logging

    try:
        return get_predictor()
    except Exception:
        # Top-level UI boundary: record the cause instead of discarding it.
        logging.getLogger(__name__).exception("Failed to load the VREyeSAM predictor")
        st.error("Error loading model")
        return None
def read_and_resize_image(image, max_side=1024):
    """Convert an image to a 3-channel array no larger than ``max_side``.

    Args:
        image: A PIL image or anything ``np.array`` accepts as an
            ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)`` array.
        max_side: Upper bound, in pixels, for both width and height.
            Larger images are shrunk with the aspect ratio preserved;
            smaller images are never enlarged.

    Returns:
        An ``(H, W, 3)`` NumPy array.
    """
    # PIL modes other than L/RGB/RGBA (palette "P", bilevel "1", "LA", "CMYK",
    # 16/32-bit "I;16"/"I", ...) do not map onto the channel layouts handled
    # below -- e.g. "P" yields palette indices and "1" yields booleans -- so
    # let PIL expand them to RGB first.
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGB")

    img = np.array(image)
    if img.ndim == 2:  # Grayscale: replicate the single channel
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 1:  # Grayscale with an explicit channel axis
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:  # RGBA: drop alpha
        # Contiguous copy so downstream OpenCV calls get a plain buffer.
        img = np.ascontiguousarray(img[..., :3])

    # Resize if needed
    height, width = img.shape[:2]
    scale = min(max_side / width, max_side / height)
    if scale < 1:
        # int() truncates; clamp to 1 so extremely elongated images cannot
        # end up with a zero-sized side, which cv2.resize rejects.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)
    return img
def segment_iris(predictor, image, num_samples=30):
    """Run iris segmentation through the predictor.

    Args:
        predictor: Object exposing ``predict(image, num_samples=...)``.
        image: Image array handed to the predictor unchanged.
        num_samples: Forwarded as the predictor's ``num_samples`` argument;
            defaults to 30, the value that used to be hard-coded.

    Returns:
        Whatever ``predictor.predict`` returns; the caller in this file
        unpacks it as ``(binary_mask, prob_mask)``.
    """
    return predictor.predict(image, num_samples=num_samples)
def overlay_mask_on_image(image, binary_mask, color=(0, 255, 0), alpha=0.5):
    """Tint the masked region of ``image`` with ``color``.

    Args:
        image: ``(H, W, C)`` image array; it is not modified.
        binary_mask: ``(H, W)`` array; pixels with a value ``> 0`` are tinted.
        color: Per-channel tint, in the same channel order as ``image``.
        alpha: Tint opacity: 0 leaves the image untouched, 1 paints the
            masked pixels with the solid colour.

    Returns:
        A new array with the same shape and dtype as ``image``.

    Only pixels inside the mask are blended. Blending the whole frame against
    an all-zero colour layer (the previous approach) also darkened every
    pixel outside the mask by a factor of ``1 - alpha``.
    """
    selected = np.asarray(binary_mask) > 0
    result = image.copy()
    tint = np.asarray(color, dtype=np.float32)
    # Blend
    blended = result[selected].astype(np.float32) * (1.0 - alpha) + tint * alpha
    # Round and clamp to the 8-bit range before casting back
    # (assumes 8-bit image data, which is what the app passes in).
    result[selected] = np.clip(np.rint(blended), 0, 255).astype(result.dtype)
    return result
# Main App
# Session-state slot holding the most recent segmentation. Streamlit reruns the
# whole script on every widget interaction and st.button is only True for the
# single run right after the click, so results have to be stored to survive
# e.g. a click on one of the download buttons.
_RESULT_STATE_KEY = "vreyesam_segmentation_result"


def _png_bytes(array):
    """Return ``array`` encoded as PNG bytes."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format="PNG")
    return buffer.getvalue()


def _render_sidebar():
    """Draw the sidebar and return ``(show_overlay, show_probabilistic)``."""
    with st.sidebar:
        st.header("About VREyeSAM")
        st.markdown("""
        **VREyeSAM** is a robust non-frontal iris segmentation framework designed for images captured under:
        - Varying gaze directions
        - Partial occlusions
        - Inconsistent lighting conditions

        **Model Performance:**
        - Recall: 0.870
        - F1-Score: 0.806
        """)
        st.header("Settings")
        show_overlay = st.checkbox("Show Mask Overlay", value=True)
        show_probabilistic = st.checkbox("Show Probabilistic Mask", value=False)
    return show_overlay, show_probabilistic


def _render_results(mask_column, img_array, binary_mask, prob_mask,
                    show_overlay, show_probabilistic):
    """Render masks, download buttons and statistics for one segmentation."""
    binary_mask_img = (binary_mask * 255).astype(np.uint8)
    with mask_column:
        st.subheader("🎯 Binary Mask")
        st.image(binary_mask_img, use_container_width=True)

    # Additional results
    st.markdown("---")
    st.subheader("📊 Segmentation Results")
    overlay = overlay_mask_on_image(img_array, binary_mask) if show_overlay else None
    result_cols = st.columns(2)
    with result_cols[0]:
        if overlay is not None:
            st.markdown("**Overlay View**")
            st.image(overlay, use_container_width=True)
    with result_cols[1]:
        if show_probabilistic:
            st.markdown("**Probabilistic Mask**")
            prob_mask_img = (prob_mask * 255).astype(np.uint8)
            st.image(prob_mask_img, use_container_width=True)

    # Download options
    st.markdown("---")
    st.subheader("💾 Download Results")
    download_cols = st.columns(2)
    with download_cols[0]:
        st.download_button(
            label="Download Binary Mask",
            data=_png_bytes(binary_mask_img),
            file_name="binary_mask.png",
            mime="image/png",
        )
    with download_cols[1]:
        if overlay is not None:
            # The overlay is already RGB (it derives from the PIL image), so
            # it is saved as-is; a BGR->RGB conversion here would swap the
            # red and blue channels in the downloaded file.
            st.download_button(
                label="Download Overlay",
                data=_png_bytes(overlay),
                file_name="overlay.png",
                mime="image/png",
            )

    # Statistics
    st.markdown("---")
    st.subheader("📈 Segmentation Statistics")
    mask_area = int(np.count_nonzero(binary_mask > 0))
    total_area = binary_mask.shape[0] * binary_mask.shape[1]
    coverage = (mask_area / total_area) * 100
    stats_cols = st.columns(3)
    with stats_cols[0]:
        st.metric("Mask Coverage", f"{coverage:.2f}%")
    with stats_cols[1]:
        st.metric("Image Size", f"{img_array.shape[1]}x{img_array.shape[0]}")
    with stats_cols[2]:
        st.metric("Mask Area (pixels)", f"{mask_area:,}")


def main():
    """Build the page: intro, sidebar, model loading, upload, results, footer."""
    st.title("👁️ VREyeSAM: Non-Frontal Iris Segmentation")
    st.markdown("""
    Upload a non-frontal iris image captured in VR/AR environments, and VREyeSAM will segment the iris region
    using a fine-tuned SAM2 model with uncertainty-weighted loss.
    """)

    # Sidebar
    show_overlay, show_probabilistic = _render_sidebar()

    # Load model
    with st.spinner("Loading VREyeSAM model..."):
        predictor = load_model()
    if predictor is None:
        st.error("Failed to load model. Please check the setup.")
        return
    st.success("✅ Model loaded successfully!")

    # File uploader (no size limit is configured here; NOTE(review): any
    # raised upload limit presumably lives in the Streamlit server config)
    uploaded_file = st.file_uploader(
        "Upload an iris image (JPG, PNG, JPEG)",
        type=["jpg", "png", "jpeg"],
        help="Upload a non-frontal iris image for segmentation",
    )
    if uploaded_file is not None:
        try:
            # Display original image
            image = Image.open(uploaded_file)
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("📷 Original Image")
                st.image(image, use_container_width=True)

            # Identifies the upload a stored result belongs to, so a result
            # is never shown next to a different image.
            upload_id = (uploaded_file.name, uploaded_file.size)
            try:
                # Process button
                if st.button("🔍 Segment Iris", type="primary"):
                    with st.spinner("Segmenting iris..."):
                        # Prepare image
                        img_array = read_and_resize_image(image)
                        # Perform segmentation
                        binary_mask, prob_mask = segment_iris(predictor, img_array)
                    st.session_state[_RESULT_STATE_KEY] = {
                        "upload_id": upload_id,
                        "img_array": img_array,
                        "binary_mask": binary_mask,
                        "prob_mask": prob_mask,
                    }

                # Render from session state rather than from inside the
                # button branch, so the results stay visible across reruns.
                saved = st.session_state.get(_RESULT_STATE_KEY)
                if saved is not None and saved["upload_id"] == upload_id:
                    _render_results(
                        col2,
                        saved["img_array"],
                        saved["binary_mask"],
                        saved["prob_mask"],
                        show_overlay,
                        show_probabilistic,
                    )
            except Exception as e:
                st.error(f"❌ Error during segmentation: {e}")
        except Exception as e:
            st.error(f"❌ Error loading image: {e}")
            st.info("Please try uploading a different image or reducing the file size.")

    # Footer
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center'>
    <p><strong>VREyeSAM</strong> - Virtual Reality Non-Frontal Iris Segmentation</p>
    <p>🔗 <a href='https://github.com/GeetanjaliGTZ/VREyeSAM'>GitHub</a>
    📧 <a href='mailto:geetanjalisharma546@gmail.com'>Contact</a></p>
    </div>
    """, unsafe_allow_html=True)
# Entry point: build the UI when this file is run as the main script
# (e.g. via `streamlit run`), but not when it is merely imported.
if __name__ == "__main__":
    main()