Nunzio commited on
Commit
dcb04c4
·
1 Parent(s): c00c6a4

added post processing

Browse files
Files changed (2) hide show
  1. app.py +18 -3
  2. utils/imageHandling.py +13 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os, torch
2
  import gradio as gr
3
- from utils.imageHandling import hfImageToTensor, preprocessing
4
  from model.modelLoading import loadModel
5
 
6
  # %% prediction on an image
@@ -24,8 +24,20 @@ def predict(inputImage: torch.Tensor, model) -> torch.Tensor:
24
 
25
  # %% Gradio interface
26
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
27
  image = hfImageToTensor(image, width=1024, height=512)
28
- return predict(image, loadModel(selected_model, device))
 
29
 
30
  # Gradio UI
31
  with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
@@ -44,7 +56,10 @@ with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
44
  submit_btn = gr.Button("Run prediction")
45
  with gr.Column():
46
  result_display = gr.Image(label="Model prediction")
47
-
 
 
 
48
  submit_btn.click(
49
  fn=run_prediction,
50
  inputs=[image_input, model_selector],
 
1
  import os, torch
2
  import gradio as gr
3
+ from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
4
  from model.modelLoading import loadModel
5
 
6
  # %% prediction on an image
 
24
 
25
  # %% Gradio interface
26
  def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
27
+ if image is None:
28
+ raise ValueError("No image provided for prediction.")
29
+
30
+ if selected_model is None:
31
+ raise ValueError("No model selected for prediction.")
32
+
33
+ try:
34
+ model = loadModel(selected_model, device)
35
+ except NotImplementedError as e:
36
+ raise ValueError(f"Model loading failed: {e}")
37
+
38
  image = hfImageToTensor(image, width=1024, height=512)
39
+ prediction = predict(image, model)
40
+ return postprocessing(prediction)
41
 
42
  # Gradio UI
43
  with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
 
56
  submit_btn = gr.Button("Run prediction")
57
  with gr.Column():
58
  result_display = gr.Image(label="Model prediction")
59
+
60
+ result_display = gr.Image(label="Output")
61
+ error_text = gr.Textbox(label="Messaggio", interactive=False)
62
+
63
  submit_btn.click(
64
  fn=run_prediction,
65
  inputs=[image_input, model_selector],
utils/imageHandling.py CHANGED
@@ -31,3 +31,16 @@ def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
31
  return torchvision.transforms.functional.normalize(
32
  image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
  ).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return torchvision.transforms.functional.normalize(
32
  image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
  ).unsqueeze(0)
34
+
35
+ # %% postprocessing
36
+ def postprocessing(pred: torch.Tensor) -> torch.Tensor:
37
+ """
38
+ Convert the model's output tensor to a format suitable for visualization.
39
+
40
+ Args:
41
+ pred (torch.Tensor): Model output tensor of shape (1, H, W).
42
+
43
+ Returns:
44
+ torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
45
+ """
46
+ return torchvision.transforms.functional.to_pil_image(pred.squeeze(0).cpu().clamp(0, 1))