Nunzio commited on
Commit
049f834
Β·
1 Parent(s): c06c582
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -17,22 +17,21 @@ def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
17
  prediction (torch.Tensor): The predicted segmentation mask.
18
  """
19
  with torch.no_grad():
20
- output = model(preprocessing(inputImage.clone()).to(model.device))
21
  output = output[0] if isinstance(output, (tuple, list)) else output
22
  return output[0].argmax(dim=0, keepdim=True).cpu()
23
 
24
 
25
  # %% Gradio interface
26
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
27
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
28
  image = hfImageToTensor(image, width=1024, height=512)
29
- return image, predict(image, loadModel(selected_model, device))
30
 
31
  # Gradio UI
32
  with gr.Blocks(title="πŸ”€ BiSeNet | BiSeNetV2 Predictor") as demo:
33
- gr.Markdown("## 🧠 Image Segmentation with BiSeNet and BiSeNetV2")
34
- gr.Markdown("Upload an image and choose your preferred model for segmentation.")
35
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
 
36
 
37
  with gr.Row():
38
  with gr.Column():
@@ -42,7 +41,7 @@ with gr.Blocks(title="πŸ”€ BiSeNet | BiSeNetV2 Predictor") as demo:
42
  label="Select model"
43
  )
44
  image_input = gr.Image(type="pil", label="Upload image")
45
- submit_btn = gr.Button("πŸ§ͺ Run prediction")
46
  with gr.Column():
47
  result_display = gr.Image(label="Model prediction")
48
 
@@ -54,4 +53,5 @@ with gr.Blocks(title="πŸ”€ BiSeNet | BiSeNetV2 Predictor") as demo:
54
 
55
  gr.Markdown("Made by group 21 semantic segmentation project. ")
56
 
 
57
  demo.launch()
 
17
  prediction (torch.Tensor): The predicted segmentation mask.
18
  """
19
  with torch.no_grad():
20
+ output = model(preprocessing(inputImage.clone()).to(device))
21
  output = output[0] if isinstance(output, (tuple, list)) else output
22
  return output[0].argmax(dim=0, keepdim=True).cpu()
23
 
24
 
25
  # %% Gradio interface
26
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
 
27
  image = hfImageToTensor(image, width=1024, height=512)
28
+ return predict(image, loadModel(selected_model, device))
29
 
30
  # Gradio UI
31
  with gr.Blocks(title="πŸ”€ BiSeNet | BiSeNetV2 Predictor") as demo:
32
+ gr.Markdown("## Semantic Segmentation with Real-Time Networks")
 
33
  gr.Markdown('A small user interface created to run semantic segmentation on images using city scapes like predictions and real time segmentation networks.')
34
+ gr.Markdown("Upload an image and choose your preferred model for segmentation.")
35
 
36
  with gr.Row():
37
  with gr.Column():
 
41
  label="Select model"
42
  )
43
  image_input = gr.Image(type="pil", label="Upload image")
44
+ submit_btn = gr.Button("Run prediction")
45
  with gr.Column():
46
  result_display = gr.Image(label="Model prediction")
47
 
 
53
 
54
  gr.Markdown("Made by group 21 semantic segmentation project. ")
55
 
56
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
57
  demo.launch()