Spaces:
Sleeping
Sleeping
Nunzio commited on
Commit ·
335b126
1
Parent(s): 9f7e1a8
lambdas
Browse files- app.py +16 -5
- model/modelLoading.py +0 -19
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import gradio as gr
|
|
| 3 |
from PIL import Image
|
| 4 |
|
| 5 |
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
|
| 6 |
-
from model.modelLoading import
|
| 7 |
|
| 8 |
|
| 9 |
## %% CONSTANTS
|
|
@@ -12,7 +12,16 @@ city_image_dir = "./preloadedImages/cityScapes"
|
|
| 12 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def load_example(index:int=0, useGta:bool=True)-> gr.Image:
|
|
|
|
|
|
|
|
|
|
| 16 |
example_img = loadPreloadedImages(gta_image_dir if useGta else city_image_dir)
|
| 17 |
return gr.update(value=example_img[index if 0 <= index < len(example_img) else 0], visible=True)
|
| 18 |
|
|
@@ -44,10 +53,12 @@ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
|
|
| 44 |
if selected_model is None:
|
| 45 |
return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
try:
|
| 48 |
-
model = loadModel(selected_model, device)
|
| 49 |
image = hfImageToTensor(image, width=1024, height=512)
|
| 50 |
-
prediction = predict(image,
|
| 51 |
prediction = postprocessing(prediction)
|
| 52 |
except Exception as e:
|
| 53 |
return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
|
|
@@ -88,8 +99,8 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
|
|
| 88 |
rows=1, height=256, allow_preview=False
|
| 89 |
)
|
| 90 |
|
| 91 |
-
gta_gallery.select(fn=load_example, inputs=[gta_gallery
|
| 92 |
-
city_gallery.select(fn=load_example, inputs=[city_gallery
|
| 93 |
|
| 94 |
submit_btn.click(
|
| 95 |
fn=run_prediction,
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
|
| 5 |
from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
|
| 6 |
+
from model.modelLoading import loadBiSeNet, loadBiSeNetV2
|
| 7 |
|
| 8 |
|
| 9 |
## %% CONSTANTS
|
|
|
|
| 12 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 13 |
|
| 14 |
|
| 15 |
+
MODELS = {
|
| 16 |
+
"BiSeNet": loadBiSeNet(device),
|
| 17 |
+
"BiSeNetV2": loadBiSeNetV2(device)
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
def load_example(index:int=0, useGta:bool=True)-> gr.Image:
|
| 22 |
+
print(type(index))
|
| 23 |
+
print(index)
|
| 24 |
+
print(useGta)
|
| 25 |
example_img = loadPreloadedImages(gta_image_dir if useGta else city_image_dir)
|
| 26 |
return gr.update(value=example_img[index if 0 <= index < len(example_img) else 0], visible=True)
|
| 27 |
|
|
|
|
| 53 |
if selected_model is None:
|
| 54 |
return (gr.update(value=None, visible=False), gr.update(value=f"❌ No model selected for prediction.", visible=True))
|
| 55 |
|
| 56 |
+
if not isinstance(selected_model, str) or selected_model.strip().upper() not in MODELS:
|
| 57 |
+
return (gr.update(value=None, visible=False), gr.update(value=f"❌ Invalid model selected.", visible=True))
|
| 58 |
+
|
| 59 |
try:
|
|
|
|
| 60 |
image = hfImageToTensor(image, width=1024, height=512)
|
| 61 |
+
prediction = predict(image, MODELS[selected_model.strip().upper()])
|
| 62 |
prediction = postprocessing(prediction)
|
| 63 |
except Exception as e:
|
| 64 |
return (gr.update(value=None, visible=False), gr.update(value=f"❌ {str(e)}.", visible=True))
|
|
|
|
| 99 |
rows=1, height=256, allow_preview=False
|
| 100 |
)
|
| 101 |
|
| 102 |
+
gta_gallery.select(fn=lambda x: load_example(x, True), inputs=[gta_gallery], outputs=[image_input])
|
| 103 |
+
city_gallery.select(fn=lambda x: load_example(x, False), inputs=[city_gallery], outputs=[image_input])
|
| 104 |
|
| 105 |
submit_btn.click(
|
| 106 |
fn=run_prediction,
|
model/modelLoading.py
CHANGED
|
@@ -3,25 +3,6 @@ import torch
|
|
| 3 |
from model.BiSeNet.build_bisenet import BiSeNet
|
| 4 |
from model.BiSeNetV2.model import BiSeNetV2
|
| 5 |
|
| 6 |
-
# general loading function
|
| 7 |
-
def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
|
| 8 |
-
"""
|
| 9 |
-
Load the specified model and move it to the given device.
|
| 10 |
-
|
| 11 |
-
Args:
|
| 12 |
-
model (str): model to be loaded.
|
| 13 |
-
device (str): Device to load the model onto ('cpu' or 'cuda').
|
| 14 |
-
|
| 15 |
-
Returns:
|
| 16 |
-
model (BiSeNet): The loaded BiSeNet model.
|
| 17 |
-
"""
|
| 18 |
-
match model.lower() if isinstance(model, str) else model:
|
| 19 |
-
case 'bisenet': model = loadBiSeNet(device)
|
| 20 |
-
case 'bisenetv2': model = loadBiSeNetV2(device)
|
| 21 |
-
case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
|
| 22 |
-
|
| 23 |
-
return model
|
| 24 |
-
|
| 25 |
|
| 26 |
# BiSeNet model loading function
|
| 27 |
def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
|
|
|
|
| 3 |
from model.BiSeNet.build_bisenet import BiSeNet
|
| 4 |
from model.BiSeNetV2.model import BiSeNetV2
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# BiSeNet model loading function
|
| 8 |
def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
|