File size: 6,172 Bytes
66bf19a
6a0b93e
79d4472
e749740
025a944
6a0b93e
79d4472
 
 
 
e5d95e7
79d4472
 
 
0930734
22974e8
0930734
22974e8
0930734
 
335b126
e5d95e7
4912b0e
79d4472
6a0b93e
 
5a17bb3
6a0b93e
 
 
 
 
 
 
 
 
 
 
049f834
6a0b93e
c00c6a4
6a0b93e
 
94c4671
 
dcb04c4
37570fa
dcb04c4
 
37570fa
60fd570
335b126
 
 
60fd570
 
335b126
60fd570
 
 
37570fa
94c4671
 
2eadc64
79d4472
ef2b3d9
ff83735
9fe5dbe
94c4671
 
831e26e
94c4671
22974e8
 
 
 
9213c5c
79d4472
831e26e
049f834
94c4671
ff83735
2eadc64
87e9613
e749740
c924f71
 
c952334
 
05e5639
43be5ef
0930734
c952334
 
05e5639
43be5ef
0930734
c952334
7ba72fb
c952334
ef2b3d9
79d4472
e5d95e7
53e85ab
4990abe
c1b49d2
 
d7e37f8
 
66bf19a
 
71771bf
66bf19a
c1b49d2
94c4671
 
 
37570fa
94c4671
2eadc64
e661224
94c4671
 
79d4472
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import gradio as gr

from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling
from model.modelLoading import loadModel


# %% CONSTANTS
# Folders holding the example images offered in the "preloaded images" gallery.
gta_image_dir = "./preloadedImages/GTAV"
city_image_dir = "./preloadedImages/cityScapes"
turin_image_dir = "./preloadedImages/turin"

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# (MODELS key, architecture name, weight file) for every model exposed in the UI.
# Keys are the upper-cased labels of the model radio button, because
# run_prediction looks them up via `selected_model.strip().upper()`.
# The "-BASE" variants (weight_Base.pth) are currently disabled; to enable one,
# add it here and to the radio-button choices of the UI.
_ENABLED_MODELS = (
    ("BISENET-BEST", "bisenet", "weight_Best.pth"),
    ("BISENETV2-BEST", "bisenetv2", "weight_Best.pth"),
)
MODELS = {key: loadModel(arch, device, weights) for key, arch, weights in _ENABLED_MODELS}

# Images rendered as clickable thumbnails below the main controls.
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)


# %% prediction on an image

def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
    """
    Run a segmentation model on one image and return its per-pixel class indices.

    The input is cloned before preprocessing, so any in-place changes made by
    ``preprocessing`` do not affect the caller's tensor. Preprocessing and the
    forward pass both run under ``torch.no_grad()``.

    Args:
        inputImage (torch.Tensor): Image tensor to segment. It is presumably the
            output of ``hfImageToTensor``; its expected shape is not checked here.
        model: Segmentation network (e.g. BiSeNet or BiSeNetV2). It may return
            either a tensor or a tuple/list whose first element holds the scores.

    Returns:
        torch.Tensor: ``argmax`` over dimension 0 of the first element of the
        model output (presumably the first batch item), with that dimension kept
        as size 1, moved to ``device``.
    """
    with torch.no_grad():
        batch = preprocessing(inputImage.clone())
        scores = model(batch.to(device))
    if isinstance(scores, (tuple, list)):
        # Multi-output networks: only the first output is used (presumably the
        # main segmentation head; confirm against the model definitions).
        scores = scores[0]
    first_item = scores[0]
    labels = first_item.argmax(dim=0, keepdim=True)
    return labels.to(device)


# %% Gradio interface
def _error_outputs(message: str) -> tuple[dict, dict]:
    """Return the (result, error) update pair that hides the result image and shows *message*."""
    return gr.update(value=None, visible=False), gr.update(value=message, visible=True)


def run_prediction(image, selected_model: str) -> tuple[dict, dict]:
    """
    Gradio callback for the "Run prediction" button.

    Args:
        image: Image delivered by the ``gr.Image`` input (configured with
            ``type="pil"``), or ``None`` when no image is present.
        selected_model: Label chosen in the model radio button. It is matched
            against the ``MODELS`` keys after stripping whitespace and upper-casing.

    Returns:
        A pair of ``gr.update`` dicts for ``(result_display, error_text)``.
        On success, the post-processed prediction is shown and the error text is
        hidden. On any failure, the result is hidden and an error message is shown.
    """
    if image is None:
        return _error_outputs("❌ No image provided for prediction.")

    if selected_model is None:
        return _error_outputs("❌ No model selected for prediction.")

    if not isinstance(selected_model, str):
        return _error_outputs("❌ Invalid model selected.")

    # Normalize once: radio labels are mixed-case, MODELS keys are upper-case.
    model_key = selected_model.strip().upper()
    if model_key not in MODELS:
        return _error_outputs("❌ Invalid model selected.")

    try:
        # Fixed 1024x512 (width x height) input size, presumably the resolution
        # the models were trained at; confirm before changing.
        tensor = hfImageToTensor(image, width=1024, height=512)
        prediction = postprocessing(predict(tensor, MODELS[model_key]))
    except Exception as e:
        # UI boundary: report any inference failure to the user instead of
        # letting the callback raise.
        return _error_outputs(f"❌ {e}.")
    return gr.update(value=prediction, visible=True), gr.update(value="", visible=False)

# Gradio UI
# Page layout. Gradio renders components in the order they are created inside
# each `with` context, so the statement order below determines the layout.
with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
    # Header text and link to the project repository.
    gr.Markdown("# Semantic Segmentation with Real-Time Networks")
    gr.Markdown('A small user interface created to run semantic segmentation on images using Cityscapes-like predictions and real-time segmentation networks.')
    gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
    gr.Markdown("The full code for the project is available on [GitHub](https://github.com/Nuzz23/MLDL_SemanticSegmentation).")
    with gr.Row():
        # Left column: input image, model choice and the submit button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload image")
            # run_prediction upper-cases the chosen label to index MODELS, so every
            # choice must equal a MODELS key apart from letter case.
            model_selector = gr.Radio(
                choices=[   #"BiSeNet-Base", 
                            "BiSeNet-Best", 
                            # "BiSeNetV2-Base", 
                            "BiSeNetV2-Best"],
                value="BiSeNet-Best",
                label="Select the real time segmentation model"
                )
            submit_btn = gr.Button("Run prediction")
        # Right column: prediction output, error message and class legend.
        with gr.Column():
            result_display = gr.Image(label="Model prediction", visible=True)
            # Hidden until run_prediction returns an error message.
            error_text = gr.Markdown("", visible=False)
            gr.Markdown("The legend of the classes is the following (format **name** -> **color**)")
            with gr.Row():
                # Fields used below for each legend entry: index 1 is the class name,
                # index 3 an (R, G, B) colour. Other fields are not used here.
                legend = legendHandling()
                # Show the legend two entries per row.
                for i in range(0, len(legend), 2):
                    with gr.Row():
                        with gr.Column(scale=1):
                            # Small square swatch filled with the class colour.
                            color_box0 = f"""<span style='display:inline-block; width:15px; height:15px; 
                                background-color:rgb({legend[i][3][0]},{legend[i][3][1]},{legend[i][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                            gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i][1]}</b> → {color_box0}</div>")
                        with gr.Column(scale=1):
                            if i + 1 < len(legend):
                                color_box1 = f"""<span style='display:inline-block; width:15px; height:15px; 
                                    background-color:rgb({legend[i+1][3][0]},{legend[i+1][3][1]},{legend[i+1][3][2]}); margin-left:6px; border:1px solid #000;'></span>"""
                                gr.HTML(f"<div style='display:flex; align-items:center; margin-bottom:-10px; margin-top:-5px;'><b>{legend[i+1][1]}</b> → {color_box1}</div>")
                            else:
                                # Odd number of classes: fill the last row's right column with an empty placeholder.
                                gr.Markdown("") 


    with gr.Row():
        gr.Markdown("## Preloaded images to be used for testing the model")
        gr.Markdown("""You can use images taken from the Grand Theft Auto V video game, the Cityscapes dataset or 
                    even the city of Turin to be used as input for the model without need for manual upload.""")

    # Show the preloaded images as a grid with 5 images per row.
    for i in range(0, len(image_list), 5):
        with gr.Row():
            for img in image_list[i:i+5]:
                img_comp = gr.Image(value=img, interactive=False, show_label=False, show_download_button=False, height=180, width=256,
                                    show_fullscreen_button=False, show_share_button=False, mirror_webcam=False)
                # Clicking a thumbnail copies its image into the main input (identity function).
                img_comp.select(fn=lambda x:x, inputs=img_comp, outputs=image_input)
    
    # run_prediction returns two updates, applied to result_display and error_text in that order.
    submit_btn.click(
        fn=run_prediction,
        inputs=[image_input, model_selector],
        outputs=[result_display, error_text],
    )

    gr.Markdown("Made by group 21 semantic segmentation project at Politecnico di Torino 2024/2025")

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `predict` or `run_prediction`) does not
# block on launching the app.
if __name__ == "__main__":
    demo.launch()