import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference
|
|
# Example diagrams used by the app: local filename -> public CDN URL.
IMAGE_TO_URL = dict(
    [
        ('factory_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png'),
        ('plant_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png'),
        ('processing_pid.png', 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png'),
        ('prediction_visual.png', 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png'),
    ]
)
|
|
# Page-wide settings; set_page_config must be the first Streamlit call.
st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)


st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
# Author badge linking to LinkedIn. The anchor is self-contained (no stray
# closing tags), and rel="noopener noreferrer" keeps the page opened in the new
# tab from getting a handle on this one.
st.markdown(
    "<a href='https://cl.linkedin.com/in/daniel-cerda-escobar' target='_blank' "
    "rel='noopener noreferrer'>"
    '<img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>',
    unsafe_allow_html=True,
)
|
|
@st.cache_resource(show_spinner=False, max_entries=4)
def get_model(postprocess_match_threshold):
    """Download the P&ID YOLOv8 weights and build a SAHI detection model on CPU.

    Streamlit caches one model per distinct ``postprocess_match_threshold``.
    The threshold comes from a slider with 21 possible values. Without
    ``max_entries``, every value ever selected would keep its own model
    resident for the lifetime of the server, so the cache is bounded.

    Args:
        postprocess_match_threshold: Minimum detection score; forwarded to the
            model as ``confidence_threshold``.

    Returns:
        The detection model built by ``AutoDetectionModel.from_pretrained``.
    """
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
    return detection_model
@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Fetch the sample input diagram and its precomputed prediction locally.

    Both files are saved under their ``IMAGE_TO_URL`` key in the working
    directory. They are used to populate the result tabs before the user runs
    any inference. The URLs are looked up in ``IMAGE_TO_URL`` so each file's
    source is defined in one place. ``st.cache_data`` memoizes the call across
    script reruns, so the downloads are not repeated on every interaction.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)


download_comparison_images()
# Placeholder results displayed until the user runs a prediction:
# one row per detection (category, confidence score) ...
coco_df = pd.DataFrame(
    {
        'category': ['centrifugal-pump'] * 2 + ['gate-valve'] * 9,
        'score': [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76],
    }
)
# ... and a per-class summary (detection count and its share in percent).
output_df = pd.DataFrame(
    [
        ('ball-valve', 0, 0),
        ('butterfly-valve', 0, 0),
        ('centrifugal-pump', 2, 18.2),
        ('check-valve', 0, 0),
        ('gate-valve', 9, 81.8),
    ],
    columns=['category', 'count', 'percentage'],
)
# Seed the result tabs with the bundled example so the page is populated
# before the first prediction. A key is only set when missing, so results a
# prediction stored earlier in this session survive script reruns.
if "output_1" not in st.session_state:
    # `with` closes the file as soon as the resized (independent) copy exists.
    with Image.open('plant_pid.png') as img_1:
        st.session_state["output_1"] = img_1.resize((4960, 3508))

if "output_2" not in st.session_state:
    with Image.open('prediction_visual.png') as img_2:
        st.session_state["output_2"] = img_2.resize((4960, 3508))

if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df

if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df
|
|
# Usage instructions in the left third of the page; the other two columns are
# only spacing.
col1, _, _ = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            '1) Upload or select any example diagram 👇🏻\n'
            '2) Set model parameters ⚙️\n'
            '3) Press to perform inference 🚀\n'
            '4) Visualize model predictions 🔎\n'
        )
|
|
# Vertical spacer between the instructions row and the input row.
st.write('##')

col1, col2, col3 = st.columns(3, gap='large')


def radio_func(option):
    """Map an example filename to the short label shown on the radio button.

    Raises KeyError for anything other than the three bundled examples.
    """
    return {
        'factory_pid.png': 'A',
        'plant_pid.png': 'B',
        'processing_pid.png': 'C',
    }[option]


with col1:
    st.markdown('##### Set Input Image')

    # When a file is uploaded it is used instead of the selected example.
    image_file = st.file_uploader('Upload your P&ID', type=['jpg', 'jpeg', 'png'])

    radio = st.radio(
        'Select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )
@st.cache_data(show_spinner=False)
def _load_example_image(url):
    """Fetch an example diagram from *url*, memoized across script reruns.

    Streamlit reruns the whole script on every widget change. Without this
    cache, each slider move re-downloaded the image. ``st.cache_data`` hands
    every caller its own copy, so downstream processing cannot alter the
    cached image.
    """
    return sahi.utils.cv.read_image_as_pil(url)


def _to_rgb(img):
    """Return *img* as a 3-channel RGB image.

    Palette, greyscale and alpha-channel uploads are normalized here.
    Transparency is composited onto white rather than simply dropped, because
    dropping the alpha channel can turn transparent backgrounds black.
    """
    if img.mode == 'P':
        img = img.convert('RGBA')
    if img.mode in ('RGBA', 'LA'):
        flattened = Image.new('RGB', img.size, (255, 255, 255))
        flattened.paste(img, mask=img.getchannel('A'))
        return flattened
    return img.convert('RGB')


with col2:
    if image_file is not None:
        # NOTE(review): normalized because the detector presumably expects
        # 3-channel input; confirm against utils.sahi_yolov8m_inference.
        image = _to_rgb(Image.open(image_file))
    else:
        image = _load_example_image(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border=True):
        st.image(image, use_column_width=True)
with col3:
    st.markdown('##### Set model parameters')

    # Tiles per image; the choices are perfect squares because the slice side
    # is later derived from the square root of this value.
    slice_number = st.select_slider('Slices per Image', options=['1', '4', '16', '64'], value='4')

    # Overlap between neighbouring tiles, applied to both height and width.
    overlap_ratio = st.slider(
        'Slicing Overlap Ratio', min_value=0.0, max_value=0.5, value=0.1, step=0.1
    )

    # Minimum detection score, handed to get_model() as the model's
    # confidence threshold.
    postprocess_match_threshold = st.slider(
        'Confidence Threshold', min_value=0.0, max_value=1.0, value=0.85, step=0.05
    )
|
|
# Vertical spacer before the centered action button.
st.write('##')


# Narrow middle column centers the button; the work runs under `if submit:`.
col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("🚀 Perform Prediction")
if submit:
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)

    # Square slices whose side is the image width divided by
    # sqrt(slice_number), so roughly `slice_number` tiles cover the image.
    # The side uses the actual image width, not a fixed 4960 px, so uploads of
    # any size get the tile count the user picked. max() guards tiny images.
    slice_size = max(1, int(image.width / (float(slice_number) ** 0.5)))
    # NOTE(review): kept at 4960 as before; its meaning is defined by
    # utils.sahi_yolov8m_inference — confirm whether it should follow the
    # actual image size as well.
    image_size = 4960

    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

    # Store results in the session so the tabs below keep showing them
    # across reruns.
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df
|
|
# Vertical spacer before the results panel.
st.write('##')


# Results panel: four tabs reading whatever is currently in session state
# (the bundled example, or the latest prediction).
_, col2, _ = st.columns([1, 5, 1], gap='small')
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        tab1, tab2, tab3, tab4 = st.tabs(['Original Image', 'Inference Prediction', 'Data', 'Insights'])
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            # Distinct name so the enclosing `col2` is not rebound inside its
            # own `with` block.
            _, data_col, _ = st.columns([1, 2, 1])
            with data_col:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config={
                        'category': 'Predicted Category',
                        'score': 'Confidence',
                    },
                    use_container_width=True,
                    hide_index=True,
                )
        with tab4:
            _, chart_col, _ = st.columns([1, 5, 1])
            with chart_col:
                chart_data = st.session_state["output_4"]
                fig = px.bar(chart_data, x='category', y='count', color='category')
                # Integer ticks on the y axis: the bars are detection counts.
                fig.update_layout(
                    title='Objects Detected',
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    yaxis=dict(tick0=0, dtick=1),
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')
|
|