import hashlib
import io
import logging

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from model_server import get_predictor

logger = logging.getLogger(__name__)

# Longest side, in pixels, an image is shrunk to before segmentation.
MAX_IMAGE_SIDE = 1024

# Session-state key holding the most recent segmentation result. Streamlit
# reruns the whole script on every interaction (download clicks, sidebar
# toggles) and ``st.button`` is True only on the run right after the click,
# so results are kept here to stay visible across those reruns.
_RESULT_STATE_KEY = "vreyesam_result"

# Page config
st.set_page_config(
    page_title="VREyeSAM - Non-frontal Iris Segmentation",
    page_icon="👁️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS
# NOTE(review): the style payload is empty in this copy of the file; restore
# the intended <style>...</style> rules or drop this call.
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def _load_predictor():
    """Create the predictor once per server process via the model server.

    Exceptions are allowed to propagate: ``st.cache_resource`` only stores
    successful return values, so a failed load is retried on the next rerun
    instead of a ``None`` being pinned in the cache for the process lifetime.
    """
    return get_predictor()


def load_model():
    """Load model securely through protected server.

    Returns:
        The shared predictor from ``model_server.get_predictor`` (memoized by
        ``st.cache_resource``), or ``None`` if loading failed. On failure a
        generic error is shown in the UI and the traceback is logged
        server-side.
    """
    try:
        return _load_predictor()
    except Exception:  # UI boundary: report any loader failure, don't crash.
        # Keep details out of the UI; record them in the server log instead.
        logger.exception("Error loading model")
        st.error("Error loading model")
        return None


def read_and_resize_image(image, max_side=MAX_IMAGE_SIDE):
    """Convert an image to an RGB array no larger than ``max_side`` pixels.

    Args:
        image: A PIL image or an array-like. PIL images are first converted
            with ``Image.convert("RGB")`` so palette ("P"), bilevel ("1"),
            "LA" and "CMYK" images get correct colors; arrays are used as-is.
        max_side: Upper bound on width and height. Larger images are shrunk
            with their aspect ratio preserved; smaller ones are not upscaled.

    Returns:
        A 3-channel ``np.ndarray`` of shape (H, W, 3).
    """
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    img = np.array(image)
    if img.ndim == 2:  # Grayscale array input
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:  # RGBA array input
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # Resize if needed. Default (bilinear) interpolation is kept on purpose so
    # preprocessing does not drift from what the model was run with before.
    height, width = img.shape[:2]
    scale = min(max_side / width, max_side / height)
    if scale < 1:
        # Clamp to >= 1 so very elongated images cannot collapse to a
        # zero-sized dimension, which cv2.resize rejects.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)
    return img


def segment_iris(predictor, image, num_samples=30):
    """Perform iris segmentation using secure model server.

    ``num_samples`` is forwarded unchanged to ``predictor.predict``. The
    return value is whatever the predictor returns; ``main`` unpacks it as
    ``(binary_mask, prob_mask)``.
    """
    return predictor.predict(image, num_samples=num_samples)


def overlay_mask_on_image(image, binary_mask, color=(0, 255, 0), alpha=0.5):
    """Tint the masked region of ``image`` with ``color``.

    Pixels where ``binary_mask > 0`` become
    ``(1 - alpha) * pixel + alpha * color``; every other pixel keeps its
    original value (the previous version also darkened the unmasked area).

    NOTE(review): assumes ``binary_mask`` has the same height/width as
    ``image`` — confirm against the predictor's output.
    """
    mask = np.asarray(binary_mask) > 0
    mask_colored = np.zeros_like(image)
    mask_colored[mask] = color
    # Blend
    blended = cv2.addWeighted(image, 1 - alpha, mask_colored, alpha, 0)
    result = image.copy()
    result[mask] = blended[mask]
    return result


def _binary_mask_to_uint8(binary_mask):
    """Map a mask to a 0/255 uint8 image using the same ``> 0`` test as the
    overlay and statistics (safe for bool, 0/1 and 0/255 masks)."""
    return np.where(np.asarray(binary_mask) > 0, 255, 0).astype(np.uint8)


def _prob_mask_to_uint8(prob_mask):
    """Scale a [0, 1] probability map to uint8, clipping so out-of-range
    values saturate instead of wrapping around."""
    return (np.clip(prob_mask, 0.0, 1.0) * 255).astype(np.uint8)


def _png_bytes(array):
    """Encode a uint8 RGB or single-channel array as PNG bytes."""
    buf = io.BytesIO()
    Image.fromarray(array).save(buf, format="PNG")
    return buf.getvalue()


def _upload_key(uploaded_file):
    """Return a content hash identifying an upload across reruns."""
    return hashlib.sha256(uploaded_file.getvalue()).hexdigest()


def _render_sidebar():
    """Draw the sidebar and return ``(show_overlay, show_probabilistic)``."""
    with st.sidebar:
        st.header("About VREyeSAM")
        st.markdown("""
        **VREyeSAM** is a robust non-frontal iris segmentation framework designed for images captured under:
        - Varying gaze directions
        - Partial occlusions
        - Inconsistent lighting conditions

        **Model Performance:**
        - Recall: 0.870
        - F1-Score: 0.806
        """)

        st.header("Settings")
        show_overlay = st.checkbox("Show Mask Overlay", value=True)
        show_probabilistic = st.checkbox("Show Probabilistic Mask", value=False)
    return show_overlay, show_probabilistic


def _render_results(mask_column, result, show_overlay, show_probabilistic):
    """Render masks, overlay, downloads and statistics for a stored result."""
    img_array = result["image"]
    binary_mask = result["binary_mask"]
    prob_mask = result["prob_mask"]

    binary_mask_img = _binary_mask_to_uint8(binary_mask)
    with mask_column:
        st.subheader("🎯 Binary Mask")
        st.image(binary_mask_img, use_container_width=True)

    # Additional results
    st.markdown("---")
    st.subheader("📊 Segmentation Results")
    overlay = overlay_mask_on_image(img_array, binary_mask) if show_overlay else None
    result_cols = st.columns(2)
    with result_cols[0]:
        if overlay is not None:
            st.markdown("**Overlay View**")
            st.image(overlay, use_container_width=True)
    with result_cols[1]:
        if show_probabilistic:
            st.markdown("**Probabilistic Mask**")
            st.image(_prob_mask_to_uint8(prob_mask), use_container_width=True)

    # Download options
    st.markdown("---")
    st.subheader("💾 Download Results")
    download_cols = st.columns(2)
    with download_cols[0]:
        st.download_button(
            label="Download Binary Mask",
            data=_png_bytes(binary_mask_img),
            file_name="binary_mask.png",
            mime="image/png",
        )
    with download_cols[1]:
        if overlay is not None:
            # ``overlay`` is already RGB (built from a PIL image), so it is
            # saved as-is; a BGR->RGB swap here would exchange red and blue.
            st.download_button(
                label="Download Overlay",
                data=_png_bytes(overlay),
                file_name="overlay.png",
                mime="image/png",
            )

    # Statistics
    st.markdown("---")
    st.subheader("📈 Segmentation Statistics")
    mask_area = int(np.count_nonzero(np.asarray(binary_mask) > 0))
    mask_height, mask_width = np.asarray(binary_mask).shape[:2]
    total_area = mask_height * mask_width
    coverage = (mask_area / total_area) * 100 if total_area else 0.0
    stats_cols = st.columns(3)
    with stats_cols[0]:
        st.metric("Mask Coverage", f"{coverage:.2f}%")
    with stats_cols[1]:
        st.metric("Image Size", f"{img_array.shape[1]}x{img_array.shape[0]}")
    with stats_cols[2]:
        st.metric("Mask Area (pixels)", f"{mask_area:,}")


def _process_upload(predictor, uploaded_file, show_overlay, show_probabilistic):
    """Show the uploaded image, segment it on demand and render the results."""
    try:
        # Display original image
        image = Image.open(uploaded_file)
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("📷 Original Image")
            st.image(image, use_container_width=True)
    except Exception as e:  # UI boundary: any decode/display failure
        st.error(f"❌ Error loading image: {str(e)}")
        st.info("Please try uploading a different image or reducing the file size.")
        return

    upload_key = _upload_key(uploaded_file)

    # Process button
    if st.button("🔍 Segment Iris", type="primary"):
        with st.spinner("Segmenting iris..."):
            try:
                img_array = read_and_resize_image(image)
                binary_mask, prob_mask = segment_iris(predictor, img_array)
            except Exception as e:  # UI boundary: report, keep app alive
                # Drop any earlier result so stale output isn't shown
                # beneath the error.
                st.session_state.pop(_RESULT_STATE_KEY, None)
                st.error(f"❌ Error during segmentation: {str(e)}")
                return
        st.session_state[_RESULT_STATE_KEY] = {
            "key": upload_key,
            "image": img_array,
            "binary_mask": binary_mask,
            "prob_mask": prob_mask,
        }

    stored = st.session_state.get(_RESULT_STATE_KEY)
    if stored is None or stored["key"] != upload_key:
        return  # nothing segmented yet for this particular upload

    try:
        _render_results(col2, stored, show_overlay, show_probabilistic)
    except Exception as e:  # UI boundary: report, keep app alive
        st.error(f"❌ Error during segmentation: {str(e)}")


def _render_footer():
    """Render the page footer."""
    # Footer
    # NOTE(review): "GitHub" and "Contact" carry no links in this copy of the
    # file; restore the intended anchors if they were lost.
    st.markdown("---")
    st.markdown("""

VREyeSAM - Virtual Reality Non-Frontal Iris Segmentation

🔗 GitHub | 📧 Contact

""", unsafe_allow_html=True)


# Main App
def main():
    """Entry point: build the page, load the model and handle uploads."""
    st.title("👁️ VREyeSAM: Non-Frontal Iris Segmentation")
    st.markdown("""
    Upload a non-frontal iris image captured in VR/AR environments, and VREyeSAM will
    segment the iris region using a fine-tuned SAM2 model with uncertainty-weighted loss.
    """)

    show_overlay, show_probabilistic = _render_sidebar()

    # Load model
    with st.spinner("Loading VREyeSAM model..."):
        predictor = load_model()
    if predictor is None:
        st.error("Failed to load model. Please check the setup.")
        return
    st.success("✅ Model loaded successfully!")

    # File uploader (the maximum upload size is governed by Streamlit's
    # server.maxUploadSize setting, not by this call).
    uploaded_file = st.file_uploader(
        "Upload an iris image (JPG, PNG, JPEG)",
        type=["jpg", "png", "jpeg"],
        help="Upload a non-frontal iris image for segmentation",
    )
    if uploaded_file is not None:
        _process_upload(predictor, uploaded_file, show_overlay, show_probabilistic)

    _render_footer()


if __name__ == "__main__":
    main()