import torch
import gradio as gr
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling
from model.modelLoading import loadModel
## %% CONSTANTS
# Directories containing the sample images shown in the preloaded-image gallery.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
turin_image_dir = "./preloadedImages/turin"
# Run inference on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pre-instantiated segmentation networks, keyed by the upper-cased names that
# run_prediction derives from the UI radio choices via .strip().upper().
# NOTE(review): the "Base" variants are commented out — presumably to cut
# startup time/memory on the Space; confirm before re-enabling.
MODELS = {
# "BISENET-BASE": loadModel('bisenet', device, 'weight_Base.pth'),
"BISENET-BEST": loadModel('bisenet', device, 'weight_Best.pth'),
# "BISENETV2-BASE": loadModel('bisenetv2', device, 'weight_Base.pth'),
"BISENETV2-BEST": loadModel('bisenetv2', device, 'weight_Best.pth')
}
# Demo images gathered from all three source directories, in display order.
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
# %% prediction on an image
def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Predict the segmentation mask for the input image using the provided model.

    Args:
        inputImage (torch.Tensor): The input image tensor. It is cloned before
            preprocessing so the caller's tensor is never mutated.
        model: A segmentation network (e.g. BiSeNet) that returns either a
            logits tensor or a tuple/list whose first element is the main
            logits output (auxiliary training heads are discarded).

    Returns:
        torch.Tensor: The predicted class-index mask for the first batch
        element, shape (1, H, W), already on `device`.
    """
    with torch.no_grad():
        output = model(preprocessing(inputImage.clone()).to(device))
        # BiSeNet-style models may return (main, aux1, aux2); keep the main head.
        if isinstance(output, (tuple, list)):
            output = output[0]
        # argmax over the class dimension yields the per-pixel label mask.
        # The model output already lives on `device`, so the original trailing
        # .to(device) was a redundant no-op transfer and has been dropped.
        return output[0].argmax(dim=0, keepdim=True)
# %% Gradio interface
def run_prediction(image: gr.Image, selected_model: str) -> tuple:
    """
    Gradio callback: validate the inputs, run segmentation and build UI updates.

    Args:
        image (gr.Image): The uploaded (PIL) image, or None if nothing was given.
        selected_model (str): The display name chosen in the radio selector;
            mapped to a MODELS key via .strip().upper().

    Returns:
        tuple: (update for the result image, update for the error markdown).
        Exactly one of the two components is made visible.
    """
    def _error(message: str) -> tuple:
        # Hide the result image and surface the error message instead.
        return (gr.update(value=None, visible=False),
                gr.update(value=message, visible=True))

    if image is None:
        return _error("❌ No image provided for prediction.")
    if selected_model is None:
        return _error("❌ No model selected for prediction.")
    # Normalize once; the original recomputed .strip().upper() on every use.
    model_key = selected_model.strip().upper() if isinstance(selected_model, str) else None
    if model_key not in MODELS:
        return _error("❌ Invalid model selected.")
    try:
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, MODELS[model_key]))
    except Exception as e:
        # UI boundary: report any failure to the user rather than crashing the app.
        return _error(f"❌ {e}.")
    return (gr.update(value=prediction, visible=True), gr.update(value="", visible=False))
# Gradio UI
# Top-level layout: header text, an input/output row, a class-color legend,
# and a clickable gallery of preloaded test images.
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    # --- Header / project description ---
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using Cityscapes-like predictions and real-time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    gr.Markdown("The full code for the project is available on [GitHub](https://github.com/Nuzz23/MLDL_SemanticSegmentation).")
    # --- Left column: inputs; right column: prediction output + error text ---
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            # The displayed choices map to MODELS keys inside run_prediction
            # via .strip().upper(); "Base" variants are currently disabled.
            model_selector = gr.Radio(
                choices=[ #"BiSeNet-Base",
                    "BiSeNet-Best",
                    # "BiSeNetV2-Base",
                    "BiSeNetV2-Best"],
                value="BiSeNet-Best",
                label="Select the real time segmentation model"
            )
            submit_btn = gr.Button("Run prediction")
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            # Hidden until run_prediction reports a validation/runtime error.
            error_text = gr.Markdown("", visible=False)
    # --- Class/color legend, rendered two entries per row ---
    gr.Markdown("The legend of the classes is the following (format **name** -> **color**)")
    with gr.Row():
        # NOTE(review): each legend entry appears to be indexed as
        # entry[1] = class name and entry[3] = (r, g, b) — confirm against
        # legendHandling's return format.
        legend = legendHandling()
        for i in range(0, len(legend), 2):
            with gr.Row():
                with gr.Column(scale=1):
                    # Inline colored square built from the class RGB triple.
                    color_box0 = f"""<span style='display:inline-block; width:15px; height:15px;
                    background-color:rgb({legend[i][3][0]},{legend[i][3][1]},{legend[i][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                    gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i][1]}</b> → {color_box0}</div>")
                with gr.Column(scale=1):
                    if i + 1 < len(legend):
                        color_box1 = f"""<span style='display:inline-block; width:15px; height:15px;
                        background-color:rgb({legend[i+1][3][0]},{legend[i+1][3][1]},{legend[i+1][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                        gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i+1][1]}</b> → {color_box1}</div>")
                    else:
                        # Odd number of classes: keep the two-column grid aligned.
                        gr.Markdown("")
    # --- Gallery of preloaded test images (click one to use it as input) ---
    with gr.Row():
        gr.Markdown("## Preloaded images to be used for testing the model")
        gr.Markdown("""You can use images taken from the Grand Theft Auto V video game, the Cityscapes dataset or
        even the city of Turin to be used as input for the model without need for manual upload.""")
    # Show the gallery as rows of 5 images each.
    for i in range(0, len(image_list), 5):
        with gr.Row():
            for img in image_list[i:i+5]:
                img_comp = gr.Image(value=img, interactive=False, show_label=False, show_download_button=False, height=180, width=256,
                    show_fullscreen_button=False, show_share_button=False, mirror_webcam=False)
                # Clicking a preloaded image copies it into the upload widget.
                img_comp.select(fn=lambda x:x, inputs=img_comp, outputs=image_input)
    # Wire the button: run_prediction fills either the result image or the error text.
    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )
    gr.Markdown("Made by group 21 semantic segmentation project at Politecnico di Torino 2024/2025")
# Start the Gradio server (blocking call).
demo.launch()