Nunzio committed on
Commit
335b126
·
1 Parent(s): 9f7e1a8
Files changed (2) hide show
  1. app.py +16 -5
  2. model/modelLoading.py +0 -19
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from PIL import Image
4
 
5
  from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
6
- from model.modelLoading import loadModel
7
 
8
 
9
  ## %% CONSTANTS
@@ -12,7 +12,16 @@ city_image_dir = "./preloadedImages/cityScapes"
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
 
14
 
 
 
 
 
 
 
15
  def load_example(index:int=0, useGta:bool=True)-> gr.Image:
 
 
 
16
  example_img = loadPreloadedImages(gta_image_dir if useGta else city_image_dir)
17
  return gr.update(value=example_img[index if 0 <= index < len(example_img) else 0], visible=True)
18
 
@@ -44,10 +53,12 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
44
  if selected_model is None:
45
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
46
 
 
 
 
47
  try:
48
- model = loadModel(selected_model, device)
49
  image = hfImageToTensor(image, width=1024, height=512)
50
- prediction = predict(image, model)
51
  prediction = postprocessing(prediction)
52
  except Exception as e:
53
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
@@ -88,8 +99,8 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
88
  rows=1, height=256, allow_preview=False
89
  )
90
 
91
- gta_gallery.select(fn=load_example, inputs=[gta_gallery, True], outputs=[image_input])
92
- city_gallery.select(fn=load_example, inputs=[city_gallery, False], outputs=[image_input])
93
 
94
  submit_btn.click(
95
  fn=run_prediction,
 
3
  from PIL import Image
4
 
5
  from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
6
+ from model.modelLoading import loadBiSeNet, loadBiSeNetV2
7
 
8
 
9
  ## %% CONSTANTS
 
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
 
14
 
15
+ MODELS = {
16
+ "BiSeNet": loadBiSeNet(device),
17
+ "BiSeNetV2": loadBiSeNetV2(device)
18
+ }
19
+
20
+
21
  def load_example(index:int=0, useGta:bool=True)-> gr.Image:
22
+ print(type(index))
23
+ print(index)
24
+ print(useGta)
25
  example_img = loadPreloadedImages(gta_image_dir if useGta else city_image_dir)
26
  return gr.update(value=example_img[index if 0 <= index < len(example_img) else 0], visible=True)
27
 
 
53
  if selected_model is None:
54
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
55
 
56
+ if not isinstance(selected_model, str) or selected_model.strip().upper() not in MODELS:
57
+ return (gr.update(value=None, visible=False), gr.update(value=f"❌ Invalid model selected.", visible=True))
58
+
59
  try:
 
60
  image = hfImageToTensor(image, width=1024, height=512)
61
+ prediction = predict(image, MODELS[selected_model.strip().upper()])
62
  prediction = postprocessing(prediction)
63
  except Exception as e:
64
  return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
 
99
  rows=1, height=256, allow_preview=False
100
  )
101
 
102
+ gta_gallery.select(fn=lambda x: load_example(x, True), inputs=[gta_gallery], outputs=[image_input])
103
+ city_gallery.select(fn=lambda x: load_example(x, False), inputs=[city_gallery], outputs=[image_input])
104
 
105
  submit_btn.click(
106
  fn=run_prediction,
model/modelLoading.py CHANGED
@@ -3,25 +3,6 @@ import torch
3
  from model.BiSeNet.build_bisenet import BiSeNet
4
  from model.BiSeNetV2.model import BiSeNetV2
5
 
6
- # general loading function
7
- def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
8
- """
9
- Load the specified model and move it to the given device.
10
-
11
- Args:
12
- model (str): model to be loaded.
13
- device (str): Device to load the model onto ('cpu' or 'cuda').
14
-
15
- Returns:
16
- model (BiSeNet): The loaded BiSeNet model.
17
- """
18
- match model.lower() if isinstance(model, str) else model:
19
- case 'bisenet': model = loadBiSeNet(device)
20
- case 'bisenetv2': model = loadBiSeNetV2(device)
21
- case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
22
-
23
- return model
24
-
25
 
26
  # BiSeNet model loading function
27
  def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
 
3
  from model.BiSeNet.build_bisenet import BiSeNet
4
  from model.BiSeNetV2.model import BiSeNetV2
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # BiSeNet model loading function
8
  def loadBiSeNet(device: str = 'cpu') -> BiSeNet: