# VREyeSAM / app.py
# Dev Nagaich
# Restructure: Clean repository - remove duplicates, consolidate at root
# f74cf62
import io
import logging

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from model_server import get_predictor
# Page config
st.set_page_config(
    page_title="VREyeSAM - Non-frontal Iris Segmentation",
    # Eye emoji (U+1F441 U+FE0F). It must be a real emoji character, not
    # mis-decoded UTF-8 bytes, for Streamlit to use it as the favicon.
    page_icon="👁️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS: page padding, full-width green buttons, and a bordered
# `.result-box` container style.
st.markdown("""
<style>
    .main {
        padding: 2rem;
    }
    .stButton>button {
        width: 100%;
        background-color: #4CAF50;
        color: white;
        padding: 0.5rem;
        font-size: 16px;
    }
    .result-box {
        border: 2px solid #ddd;
        border-radius: 10px;
        padding: 1rem;
        margin: 1rem 0;
    }
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def _load_predictor_cached():
    """Build the predictor once per server process via ``get_predictor``.

    Exceptions are deliberately allowed to propagate. Streamlit only caches
    calls that return, so a failed load is retried on the next rerun instead
    of a ``None`` being remembered for the lifetime of the process.
    """
    return get_predictor()


def load_model():
    """Load the model securely through the protected model server.

    Returns:
        The predictor returned by ``model_server.get_predictor`` (cached
        across reruns), or ``None`` if loading failed. On failure a generic
        error is shown in the UI and the traceback is logged server-side.
    """
    try:
        return _load_predictor_cached()
    except Exception:
        # Keep the on-page message generic; the details go to the server log
        # only, so failures stay diagnosable without exposing them to users.
        logging.getLogger(__name__).exception("Failed to load VREyeSAM predictor")
        st.error("Error loading model")
        return None
def read_and_resize_image(image, max_side=1024):
    """Convert an uploaded image to an RGB array no larger than ``max_side``.

    Args:
        image: A PIL image, or a NumPy array of shape (H, W), (H, W, 3) or
            (H, W, 4).
        max_side: Upper bound, in pixels, on both width and height. Larger
            images are downscaled with their aspect ratio preserved; smaller
            ones are never upscaled.

    Returns:
        A NumPy array of shape (H', W', 3) in RGB channel order.
    """
    # The raw array of these PIL modes is not 2-D grey, RGB or RGBA (palette
    # indices, 1-bit bools, grey+alpha, CMYK, YCbCr), so the branches below
    # would misread it. Let PIL expand them to RGB first. NumPy arrays have
    # no `mode` attribute and skip this step.
    if getattr(image, "mode", None) in {"1", "P", "LA", "CMYK", "YCbCr"}:
        image = image.convert("RGB")
    img = np.array(image)
    if img.ndim == 2:  # Grayscale
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:  # RGBA
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    height, width = img.shape[:2]
    scale = min(max_side / width, max_side / height)
    if scale < 1:
        # cv2.resize takes (width, height). Keep both dimensions at least
        # 1 pixel so very elongated images don't produce a zero-sized target.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)
    return img
def segment_iris(predictor, image, num_samples=30):
    """Perform iris segmentation using the secure model server's predictor.

    Args:
        predictor: Object returned by ``load_model``; only its ``predict``
            method is used.
        image: RGB image array, as produced by ``read_and_resize_image``.
        num_samples: Forwarded unchanged to ``predictor.predict`` as a
            keyword argument (default 30).

    Returns:
        Whatever ``predictor.predict`` returns; the app unpacks it as
        ``(binary_mask, prob_mask)``.
    """
    return predictor.predict(image, num_samples=num_samples)
def overlay_mask_on_image(image, binary_mask, color=(0, 255, 0), alpha=0.5):
    """Overlay a binary mask on the original image by tinting masked pixels.

    Pixels where ``binary_mask > 0`` become
    ``(1 - alpha) * pixel + alpha * color``, rounded to the nearest integer
    and clipped to 0..255. All other pixels keep their original values.

    Args:
        image: (H, W, C) image array. Assumed uint8 RGB, as produced by
            ``read_and_resize_image``; TODO confirm no other dtype is passed.
        binary_mask: (H, W) array; any positive value marks the foreground.
        color: Per-channel tint, in the same channel order as ``image``.
        alpha: Tint opacity in [0, 1].

    Returns:
        A new array with the same shape and dtype as ``image``. The input is
        not modified.
    """
    foreground = np.asarray(binary_mask) > 0
    result = image.copy()
    # Blend only inside the mask. Blending the whole frame against a black
    # canvas would also darken every non-mask pixel by a factor (1 - alpha).
    tint = np.asarray(color, dtype=np.float32)
    blended = image[foreground].astype(np.float32) * (1.0 - alpha) + tint * alpha
    result[foreground] = np.clip(np.rint(blended), 0, 255).astype(image.dtype)
    return result
# Main App
def _png_bytes(array):
    """Encode a uint8 image array (grayscale or RGB) as PNG bytes for download."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format="PNG")
    return buffer.getvalue()


def _render_sidebar():
    """Render the sidebar and return the ``(show_overlay, show_probabilistic)`` toggles."""
    with st.sidebar:
        st.header("About VREyeSAM")
        # The blank lines matter. Without one before "**Model Performance:**",
        # Markdown folds that line into the preceding list item.
        st.markdown("""
        **VREyeSAM** is a robust non-frontal iris segmentation framework designed for images captured under:

        - Varying gaze directions
        - Partial occlusions
        - Inconsistent lighting conditions

        **Model Performance:**

        - Recall: 0.870
        - F1-Score: 0.806
        """)
        st.header("Settings")
        show_overlay = st.checkbox("Show Mask Overlay", value=True)
        show_probabilistic = st.checkbox("Show Probabilistic Mask", value=False)
    return show_overlay, show_probabilistic


def _render_results(img_array, binary_mask, prob_mask, mask_column,
                    show_overlay, show_probabilistic):
    """Show the masks, optional overlay/probability views, downloads and statistics.

    Args:
        img_array: RGB image the predictor was run on.
        binary_mask: Predicted mask; any value > 0 counts as iris.
        prob_mask: Per-pixel probabilities, presumably in [0, 1]; TODO
            confirm against model_server.
        mask_column: Streamlit column beside the original image, used for the
            binary mask.
        show_overlay: Sidebar toggle for the overlay view and its download.
        show_probabilistic: Sidebar toggle for the probabilistic mask view.
    """
    foreground = np.asarray(binary_mask) > 0
    # Derive the display mask from `> 0`, as the statistics below do. Using
    # `binary_mask * 255` would overflow if the mask were already 0/255.
    binary_mask_img = np.where(foreground, 255, 0).astype(np.uint8)
    with mask_column:
        st.subheader("🎯 Binary Mask")
        st.image(binary_mask_img, use_container_width=True)

    # Additional results
    st.markdown("---")
    st.subheader("📊 Segmentation Results")
    overlay = None
    result_cols = st.columns(2)
    with result_cols[0]:
        if show_overlay:
            st.markdown("**Overlay View**")
            overlay = overlay_mask_on_image(img_array, binary_mask)
            st.image(overlay, use_container_width=True)
    with result_cols[1]:
        if show_probabilistic:
            st.markdown("**Probabilistic Mask**")
            # Clip so values slightly outside [0, 1] don't wrap around in uint8.
            prob_mask_img = (np.clip(prob_mask, 0.0, 1.0) * 255).astype(np.uint8)
            st.image(prob_mask_img, use_container_width=True)

    # Download options
    st.markdown("---")
    st.subheader("💾 Download Results")
    download_cols = st.columns(2)
    with download_cols[0]:
        st.download_button(
            label="Download Binary Mask",
            data=_png_bytes(binary_mask_img),
            file_name="binary_mask.png",
            mime="image/png",
        )
    with download_cols[1]:
        if overlay is not None:
            # `overlay` is built from the RGB `img_array`, so it is encoded
            # as-is. A BGR->RGB conversion here would swap red and blue.
            st.download_button(
                label="Download Overlay",
                data=_png_bytes(overlay),
                file_name="overlay.png",
                mime="image/png",
            )

    # Statistics
    st.markdown("---")
    st.subheader("📈 Segmentation Statistics")
    stats_cols = st.columns(3)
    mask_area = int(np.count_nonzero(foreground))
    total_area = foreground.shape[0] * foreground.shape[1]
    coverage = mask_area / total_area * 100
    with stats_cols[0]:
        st.metric("Mask Coverage", f"{coverage:.2f}%")
    with stats_cols[1]:
        st.metric("Image Size", f"{img_array.shape[1]}x{img_array.shape[0]}")
    with stats_cols[2]:
        st.metric("Mask Area (pixels)", f"{mask_area:,}")


def _render_footer():
    """Render the project footer with GitHub and contact links."""
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center'>
    <p><strong>VREyeSAM</strong> - Virtual Reality Non-Frontal Iris Segmentation</p>
    <p>🔗 <a href='https://github.com/GeetanjaliGTZ/VREyeSAM'>GitHub</a> |
    📧 <a href='mailto:geetanjalisharma546@gmail.com'>Contact</a></p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Render the VREyeSAM page: intro, sidebar, model load, upload, segmentation, footer."""
    st.title("👁️ VREyeSAM: Non-Frontal Iris Segmentation")
    st.markdown("""
    Upload a non-frontal iris image captured in VR/AR environments, and VREyeSAM will segment the iris region
    using a fine-tuned SAM2 model with uncertainty-weighted loss.
    """)
    show_overlay, show_probabilistic = _render_sidebar()

    # Load model (None signals a failed load; load_model has already reported it)
    with st.spinner("Loading VREyeSAM model..."):
        predictor = load_model()
    if predictor is None:
        st.error("Failed to load model. Please check the setup.")
        return
    st.success("✅ Model loaded successfully!")

    # File uploader (accepted extensions only; no custom size limit is set here)
    uploaded_file = st.file_uploader(
        "Upload an iris image (JPG, PNG, JPEG)",
        type=["jpg", "png", "jpeg"],
        help="Upload a non-frontal iris image for segmentation",
    )
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("📷 Original Image")
                st.image(image, use_container_width=True)
            # NOTE(review): results are only drawn on the run where this
            # button returns True. Clicking a download button reruns the script,
            # so the results disappear; persisting them in st.session_state
            # would keep them visible.
            if st.button("🔍 Segment Iris", type="primary"):
                with st.spinner("Segmenting iris..."):
                    try:
                        img_array = read_and_resize_image(image)
                        binary_mask, prob_mask = segment_iris(predictor, img_array)
                        _render_results(img_array, binary_mask, prob_mask, col2,
                                        show_overlay, show_probabilistic)
                    except Exception as e:
                        st.error(f"❌ Error during segmentation: {e}")
        except Exception as e:
            st.error(f"❌ Error loading image: {e}")
            st.info("Please try uploading a different image or reducing the file size.")

    _render_footer()
if __name__ == "__main__":
    # `streamlit run app.py` executes this script with __name__ set to
    # "__main__", so the page is rendered on every (re)run.
    main()