Nunzio committed on
Commit
79d4472
·
1 Parent(s): 60fd570

added preloaded images

Browse files
Files changed (2) hide show
  1. app.py +42 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,8 +1,18 @@
1
  import os, torch
2
  import gradio as gr
 
 
3
  from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
4
  from model.modelLoading import loadModel
5
 
 
 
 
 
 
 
 
 
6
  # %% prediction on an image
7
 
8
  def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
@@ -41,7 +51,7 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
41
 
42
  # Gradio UI
43
  with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
44
- gr.Markdown("## Semantic Segmentation with Real-Time Networks")
45
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
46
  gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
47
 
@@ -50,14 +60,35 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
50
  model_selector = gr.Radio(
51
  choices=["BiSeNet", "BiSeNetV2"],
52
  value="BiSeNet",
53
- label="Select model"
54
  )
55
  image_input = gr.Image(type="pil", label="Upload image")
56
  submit_btn = gr.Button("Run prediction")
57
  with gr.Column():
58
  result_display = gr.Image(label="Model prediction", visible=True)
59
  error_text = gr.Markdown("", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
61
  submit_btn.click(
62
  fn=run_prediction,
63
  inputs=[image_input, model_selector],
@@ -66,5 +97,13 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
66
 
67
  gr.Markdown("Made by group 21 semantic segmentation project. ")
68
 
69
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
70
  demo.launch()
 
 
 
1
  import os, torch
2
  import gradio as gr
3
+ from PIL import Image
4
+
5
  from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
6
  from model.modelLoading import loadModel
7
 
8
+
9
+ ## %% CONSTANTS
10
+ gta_image_dir = "./preloadedImages/GTAV"
11
+ city_image_dir = "./preloadedImages/cityScapes"
12
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+
14
+
15
+
16
  # %% prediction on an image
17
 
18
  def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
 
51
 
52
  # Gradio UI
53
  with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
54
+ gr.Markdown("# Semantic Segmentation with Real-Time Networks")
55
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
56
  gr.Markdown("Upload an image and choose your preferred model for segmentation, or otherwise use one of the preloaded images.")
57
 
 
60
  model_selector = gr.Radio(
61
  choices=["BiSeNet", "BiSeNetV2"],
62
  value="BiSeNet",
63
+ label="Select the real time segmentation model"
64
  )
65
  image_input = gr.Image(type="pil", label="Upload image")
66
  submit_btn = gr.Button("Run prediction")
67
  with gr.Column():
68
  result_display = gr.Image(label="Model prediction", visible=True)
69
  error_text = gr.Markdown("", visible=False)
70
+
71
+ with gr.Row():
72
+ gr.Markdown("## Preloaded GTA V images to be used for testing the model")
73
+ with gr.Row():
74
+ gta_gallery = gr.Gallery(
75
+ value=sorted([Image.open(os.path.join(gta_image_dir, f)).convert("RGB") for f in os.listdir(gta_image_dir) if f.endswith(".png")]),
76
+ label="GTA V Examples",
77
+ show_label=False,
78
+ columns=5,
79
+ rows=1,
80
+ height=200,
81
+ type="pil"
82
+ )
83
 
84
+ with gr.Row():
85
+ gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
86
+ with gr.Row():
87
+ city_gallery = gr.Gallery(value=sorted([Image.open(os.path.join(city_image_dir, f)).convert("RGB") for f in os.listdir(city_image_dir) if f.endswith(".png")]),
88
+ label="Cityscapes Examples", show_label=False, columns=5, rows=1,
89
+ height=256, type="pil"
90
+ )
91
+
92
  submit_btn.click(
93
  fn=run_prediction,
94
  inputs=[image_input, model_selector],
 
97
 
98
  gr.Markdown("Made by group 21 semantic segmentation project. ")
99
 
100
+ def load_example(example_img):
101
+ return gr.update(value=example_img)
102
+
103
+ # On click: update image_input with selected example
104
+ gta_gallery.select(fn=load_example, inputs=[gta_gallery], outputs=[image_input])
105
+ city_gallery.select(fn=load_example, inputs=[city_gallery], outputs=[image_input])
106
+
107
  demo.launch()
108
+
109
+
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
  torchvision
3
- gradio
 
 
1
  torch
2
  torchvision
3
+ gradio
4
+ pillow