from PIL import Image, ImageDraw

from unet.unet_model import UNet

import streamlit as st
import plotly.express as px
import pandas as pd
import numpy as np
import torchvision.transforms as T

import torch
import pathlib
import io
import cv2
import tempfile

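# Map WindowsPath to PosixPath so a checkpoint pickled on Windows can be un-pickled on this platform.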
pathlib.WindowsPath = pathlib.PosixPath

st.title("Smart City Rubbish Detection Web Application")


def yolo():
    st.markdown(
        "<h1 style='text-align: center; font-size: 36px;'>YOLO Object Detection</h1>",
        unsafe_allow_html=True
    )
    st.markdown(
        "<h2 style='text-align: center; font-size: 30px;'>Using YOLOv5</h2>",
        unsafe_allow_html=True
    )

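    # Classes offered to the user in addition to the default "rubbish" class.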
    default_sub_classes = [
        "container",
        "waste-paper",
        "plant",
        "transportation",
        "kitchenware",
        "rubbish bag",
        "chair",
        "wood",
        "electronics good",
        "sofa",
        "scrap metal",
        "carton",
        "bag",
        "tarpaulin",
        "accessory",
        "rubble",
        "table",
        "board",
        "mattress",
        "beverage",
        "tyre",
        "nylon",
        "rack",
        "styrofoam",
        "clothes",
        "toy",
        "furniture",
        "trolley",
        "carpet",
        "plastic cup"
    ]

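    # Session state keeps the processed video and its detection summary across Streamlit reruns.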
    if 'video_processed' not in st.session_state:
        st.session_state.video_processed = False
        st.session_state.output_video_path = None
        st.session_state.detections_summary = None

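    # Load the custom YOLOv5 weights from the local ./yolov5 clone; the model is cached across reruns.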
    @st.cache_resource
    def load_model():
        model = torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt', source='local', force_reload=False)
        return model

    model = load_model()

    model_class_names = model.names

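    # Map user-entered class names to the model's class indices (case-insensitive);
    # names the model does not recognise are returned separately.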
    def get_class_indices(class_list):
        indices = []
        not_found = []
        for cls in class_list:
            found = False
            for index, name in model_class_names.items():
                if name.lower() == cls.lower():
                    indices.append(index)
                    found = True
                    break
            if not found:
                not_found.append(cls)
        return indices, not_found

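    # results.render() draws boxes and labels onto the result images in place;
    # return the first (and only) annotated image.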
    def annotate_image(frame, results):
        results.render()
        annotated_frame = results.ims[0]
        return annotated_frame

    st.markdown("### Available Classes:")
    st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")

    st.info("By default, the application will detect **rubbish** only.")

    custom_classes_input = st.text_input(
        "Enter classes (comma-separated) or type 'all' to detect everything:",
        ""
    )

    all_model_classes = list(model_class_names.values())

    if custom_classes_input.strip() == "":
        selected_classes = ['rubbish']
        st.info("No classes entered. Using default class: **rubbish**.")
    elif custom_classes_input.strip().lower() == "all":
        selected_classes = all_model_classes
        st.info("Detecting **all** available classes.")
    else:
        input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
        if 'rubbish' not in [cls.lower() for cls in input_classes]:
            selected_classes = input_classes + ['rubbish']
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (including **rubbish**)")
        else:
            selected_classes = input_classes
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")

    selected_class_indices, not_found_classes = get_class_indices(selected_classes)

    if not_found_classes:
        st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")

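    # Restrict YOLOv5 inference to the selected class indices.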
    if selected_class_indices:
        model.classes = selected_class_indices

    st.header("Image Object Detection")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert('RGB')
            st.image(image, caption="Uploaded Image", use_column_width=True)
            st.write("Processing...")

            results = model(image)

            results_df = results.pandas().xyxy[0]

            filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]

            if filtered_results.empty:
                st.warning("No objects detected for the selected classes.")
            else:
                st.write("### Detection Results")
                st.dataframe(filtered_results)

                annotated_image = annotate_image(np.array(image), results)

                annotated_pil = Image.fromarray(annotated_image)

                st.image(annotated_pil, caption="Annotated Image", use_column_width=True)

                img_byte_arr = io.BytesIO()
                annotated_pil.save(img_byte_arr, format='PNG')
                img_byte_arr = img_byte_arr.getvalue()

                st.download_button(
                    label="Download Annotated Image",
                    data=img_byte_arr,
                    file_name='annotated_image.png',
                    mime='image/png'
                )
        except Exception as e:
            st.error(f"An error occurred during image processing: {e}")

    st.header("Video Object Detection")

    uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")

    if uploaded_video is not None:
        if st.session_state.get("uploaded_video_name") is None:
            st.session_state.uploaded_video_name = uploaded_video.name
            print("First time uploaded video: " + st.session_state.uploaded_video_name)
        elif st.session_state.uploaded_video_name != uploaded_video.name:
            st.session_state.uploaded_video_name = uploaded_video.name
            print("Another time uploaded video: " + st.session_state.uploaded_video_name)
            st.session_state.video_processed = False
            st.session_state.output_video_path = None
            st.session_state.detections_summary = None
            print("New uploaded video")

    if uploaded_video is None and st.session_state.video_processed:
        st.session_state.video_processed = False
        st.session_state.output_video_path = None
        st.session_state.detections_summary = None
        st.warning("Video upload has been cleared. You can upload a new video for processing.")

    if uploaded_video:
        if not st.session_state.video_processed:
            try:
                with st.spinner("Processing video..."):
                    tfile = tempfile.NamedTemporaryFile(delete=False)
                    tfile.write(uploaded_video.read())
                    tfile.close()

                    video_cap = cv2.VideoCapture(tfile.name)
                    stframe = st.empty()

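                    # Write annotated frames to a temporary .mp4 with the source video's fps and resolution.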
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    fps = video_cap.get(cv2.CAP_PROP_FPS)
                    width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
                    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

                    frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    progress_bar = st.progress(0)

                    all_detections = []

                    for frame_num in range(frame_count):
                        ret, frame = video_cap.read()
                        if not ret:
                            break

                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                        results = model(frame_rgb)

                        results_df = results.pandas().xyxy[0]
                        results_df['frame_num'] = frame_num

                        if not results_df.empty:
                            all_detections.append(results_df)

                        annotated_frame = annotate_image(frame_rgb, results)

                        annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)

                        out.write(annotated_bgr)

                        stframe.image(annotated_frame, channels="RGB", use_column_width=True)

                        progress_percent = (frame_num + 1) / frame_count
                        progress_bar.progress(progress_percent)

                    video_cap.release()
                    out.release()

                    st.session_state.output_video_path = output_video_path

                    if all_detections:
                        detections_df = pd.concat(all_detections, ignore_index=True)

                        detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
                        st.session_state.detections_summary = detections_summary
                    else:
                        st.session_state.detections_summary = None

                    st.session_state.video_processed = True

                st.success("Video processing complete!")

            except Exception as e:
                st.error(f"An error occurred during video processing: {e}")

        if st.session_state.video_processed:
            try:
                with open(st.session_state.output_video_path, "rb") as video_file:
                    st.download_button(
                        label="Download Annotated Video",
                        data=video_file,
                        file_name="annotated_video.mp4",
                        mime="video/mp4"
                    )

                if st.session_state.detections_summary is not None:
                    detections_summary = st.session_state.detections_summary

                    st.write("### Detection Summary")
                    st.dataframe(detections_summary)
                else:
                    st.warning("No objects detected in the video for the selected classes.")
            except Exception as e:
                st.error(f"An error occurred while preparing the download: {e}")

    if custom_classes_input.strip().lower() == "all":
        st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")


IMG_SIZE = 128


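# UNet loader for the segmentation page (the YOLO loader above is local to yolo()).
# The checkpoint is loaded on the CPU; strict=False tolerates missing/unexpected keys.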
@st.cache_resource
def load_model():
    model = UNet(n_channels=3, n_classes=32)
    model.load_state_dict(
        torch.load("./model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True),
        strict=False
    )
    model.eval()
    return model


def preprocess_image(image):
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor


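# Threshold the predicted probability map at 0.5 and scale to a 0/255 uint8 mask for display.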
def postprocess_mask(mask):
    mask_np = mask.squeeze().cpu().numpy()
    mask_np = (mask_np > 0.5).astype(np.uint8) * 255
    return mask_np


def unet():
    try:
        model = load_model()

        st.markdown(
            "<h1 style='text-align: center; font-size: 36px;'>UNet Object Detection</h1>",
            unsafe_allow_html=True
        )
        st.markdown(
            "<h2 style='text-align: center; font-size: 30px;'>Using UNet - PyTorch</h2>",
            unsafe_allow_html=True
        )

        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            st.write("Processing...")
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", use_column_width=True)

            input_tensor = preprocess_image(image)

            with torch.no_grad():
                output = model(input_tensor)
                prediction = torch.sigmoid(output)

            mask = postprocess_mask(prediction[0, 0])

            st.image(mask, caption="Segmentation Mask", use_column_width=True)
    except Exception as e:
        st.error(f"An error occurred in Unet: {e}")


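# Remember the chosen model across Streamlit reruns.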
if 'model_selected' not in st.session_state:
    st.session_state.model_selected = None


def main():
    option = st.radio("Select Model:", ("Unet", "YOLO"))

    if st.button("Choose"):
        st.session_state.model_selected = option
        st.success(f"Selected Model: {st.session_state.model_selected}")

    if st.session_state.model_selected == "Unet":
        unet()
    elif st.session_state.model_selected == "YOLO":
        yolo()


if __name__ == "__main__":
    main()