Nunzio commited on
Commit
5a17bb3
·
1 Parent(s): 94c4671
Files changed (2) hide show
  1. app.py +5 -41
  2. model/modelLoading.py +41 -0
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import os, torch
2
- from model.BiSeNet.build_bisenet import BiSeNet
3
  import gradio as gr
4
  from utils.imageHandling import hfImageToTensor, preprocessing
 
5
 
6
  # %% prediction on an image
7
 
8
- def predict(inputImage: torch.Tensor, model: BiSeNet) -> torch.Tensor:
9
  """
10
  Predict the segmentation mask for the input image using the provided model.
11
 
@@ -22,45 +22,6 @@ def predict(inputImage: torch.Tensor, model: BiSeNet) -> torch.Tensor:
22
  return output[0].argmax(dim=0, keepdim=True).cpu()
23
 
24
 
25
-
26
- # %% load model
27
-
28
- def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
29
- """
30
- Load the specified model and move it to the given device.
31
-
32
- Args:
33
- model (str): model to be loaded.
34
- device (str): Device to load the model onto ('cpu' or 'cuda').
35
-
36
- Returns:
37
- model (BiSeNet): The loaded BiSeNet model.
38
- """
39
- match model.lower() if isinstance(model, str) else model:
40
- case 'bisenet': model = loadBiSeNet(device)
41
- case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
42
-
43
- return model
44
-
45
-
46
- # BiSeNet model loading function
47
- def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
48
- """
49
- Load the BiSeNet model and move it to the specified device.
50
-
51
- Args:
52
- device (str): Device to load the model onto ('cpu' or 'cuda').
53
-
54
- Returns:
55
- model (BiSeNet): The loaded BiSeNet model.
56
- """
57
- model = BiSeNet(n_classes=19, context_path='resnet18').to(device)
58
- model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device))
59
- model.eval()
60
-
61
- return model
62
-
63
-
64
  # %% Gradio interface
65
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
66
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -71,6 +32,7 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
71
  with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
72
  gr.Markdown("## 🧠 Image Segmentation with BiSeNet and BiSeNetV2")
73
  gr.Markdown("Upload an image and choose your preferred model for segmentation.")
 
74
 
75
  with gr.Row():
76
  with gr.Column():
@@ -89,5 +51,7 @@ with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
89
  inputs=[image_input, model_selector],
90
  outputs=[result_display]
91
  )
 
 
92
 
93
  demo.launch()
 
1
  import os, torch
 
2
  import gradio as gr
3
  from utils.imageHandling import hfImageToTensor, preprocessing
4
+ from model.modelLoading import loadModel
5
 
6
  # %% prediction on an image
7
 
8
+ def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
9
  """
10
  Predict the segmentation mask for the input image using the provided model.
11
 
 
22
  return output[0].argmax(dim=0, keepdim=True).cpu()
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # %% Gradio interface
26
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
27
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
32
  with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
33
  gr.Markdown("## 🧠 Image Segmentation with BiSeNet and BiSeNetV2")
34
  gr.Markdown("Upload an image and choose your preferred model for segmentation.")
35
+ gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
36
 
37
  with gr.Row():
38
  with gr.Column():
 
51
  inputs=[image_input, model_selector],
52
  outputs=[result_display]
53
  )
54
+
55
+ gr.Markdown("Made by group 21 semantic segmentation project. ")
56
 
57
  demo.launch()
model/modelLoading.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from model.BiSeNet.build_bisenet import BiSeNet
4
+
5
+
6
+ # %% load model
7
+
8
+ def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
9
+ """
10
+ Load the specified model and move it to the given device.
11
+
12
+ Args:
13
+ model (str): model to be loaded.
14
+ device (str): Device to load the model onto ('cpu' or 'cuda').
15
+
16
+ Returns:
17
+ model (BiSeNet): The loaded BiSeNet model.
18
+ """
19
+ match model.lower() if isinstance(model, str) else model:
20
+ case 'bisenet': model = loadBiSeNet(device)
21
+ case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
22
+
23
+ return model
24
+
25
+
26
+ # BiSeNet model loading function
27
+ def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
28
+ """
29
+ Load the BiSeNet model and move it to the specified device.
30
+
31
+ Args:
32
+ device (str): Device to load the model onto ('cpu' or 'cuda').
33
+
34
+ Returns:
35
+ model (BiSeNet): The loaded BiSeNet model.
36
+ """
37
+ model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
38
+ model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device))
39
+ model.eval()
40
+
41
+ return model