Nunzio commited on
Commit
3f820ff
·
1 Parent(s): 9d4bb4b

simpler model loading

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. model/modelLoading.py +19 -20
app.py CHANGED
@@ -13,10 +13,10 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
 
14
 
15
  MODELS = {
16
- "BISENET-BASE": loadBiSeNet(device, 'weight_Base.pth'),
17
- "BISENET-BEST": loadBiSeNet(device, 'weight_Best.pth'),
18
- "BISENETV2-BASE": loadBiSeNetV2(device, 'weight_Base.pth'),
19
- "BISENETV2-BEST": loadBiSeNetV2(device, 'weight_Best.pth')
20
  }
21
 
22
  image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
 
13
 
14
 
15
  MODELS = {
16
+ "BISENET-BASE": loadBiSeNet('bisenet', device, 'weight_Base.pth'),
17
+ "BISENET-BEST": loadBiSeNet('bisenet', device, 'weight_Best.pth'),
18
+ "BISENETV2-BASE": loadBiSeNetV2('bisenetv2', device, 'weight_Base.pth'),
19
+ "BISENETV2-BEST": loadBiSeNetV2('bisenetv2', device, 'weight_Best.pth')
20
  }
21
 
22
  image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
model/modelLoading.py CHANGED
@@ -4,38 +4,37 @@ from model.BiSeNet.build_bisenet import BiSeNet
4
  from model.BiSeNetV2.model import BiSeNetV2
5
 
6
 
7
- # BiSeNet model loading function
8
- def loadBiSeNet(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNet:
9
  """
10
- Load the BiSeNet model and move it to the specified device.
 
 
 
11
 
12
  Args:
 
 
13
  device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
14
  weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
15
 
16
  Returns:
17
- model (BiSeNet): The loaded BiSeNet model.
18
  """
19
- model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
20
- model.load_state_dict(torch.load(f'./weights/BiSeNet/{weights}', map_location=device)['model_state_dict'])
21
- model.eval()
22
 
23
- return model
 
 
 
 
24
 
 
 
 
 
25
 
26
- def loadBiSeNetV2(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNetV2:
27
- """
28
- Load the BiSeNetV2 model and move it to the specified device.
29
-
30
- Args:
31
- device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
32
- weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
33
 
34
- Returns:
35
- model (BiSeNetV2): The loaded BiSeNetV2 model.
36
- """
37
- model = BiSeNetV2(n_classes=19).to(device)
38
- model.load_state_dict(torch.load(f'./weights/BiSeNetV2/{weights}', map_location=device)['model_state_dict'])
39
  model.eval()
40
 
41
  return model
 
4
  from model.BiSeNetV2.model import BiSeNetV2
5
 
6
 
7
+ # Model loading function
8
+ def loadModel(model:str = 'bisenet', device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNet | BiSeNetV2:
9
  """
10
+ Load the BiSeNet or BiSeNetV2 model and move it to the specified device.
11
+ This function supports loading different versions of the model based on the provided `model` argument.
12
+ The model weights are loaded from the specified `weights` file.
13
+ The model is set to evaluation mode after loading.
14
 
15
  Args:
16
+ model (str): The type of model to load. Options are 'bisenet', 'bisenet_base', 'bisenet_best', 'bisenetv2', 'bisenetv2_base', 'bisenetv2_best'.
17
+ Default is 'bisenet'.
18
  device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
19
  weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
20
 
21
  Returns:
22
+ model (BiSeNet | BiSeNetV2): The loaded BiSeNet or BiSeNetV2 model.
23
  """
 
 
 
24
 
25
+ match model.lower() if isinstance(model, str) else model:
26
+ case 'bisenet' | 'bisenet_base' | 'bisenet_best':
27
+ model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
28
+ modelStateDict = torch.load(f'./weights/BiSeNet/{weights}', map_location=device)
29
+ model.load_state_dict(modelStateDict['model_state_dict'] if 'model_state_dict' in modelStateDict else modelStateDict)
30
 
31
+ case 'bisenetv2' | 'bisenetv2_base' | 'bisenetv2_best':
32
+ model = BiSeNetV2(n_classes=19).to(device)
33
+ modelStateDict = torch.load(f'./weights/BiSeNetV2/{weights}', map_location=device)
34
+ model.load_state_dict(modelStateDict['model_state_dict'] if 'model_state_dict' in modelStateDict else modelStateDict)
35
 
36
+ case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' or 'bisenetv2'.")
 
 
 
 
 
 
37
 
 
 
 
 
 
38
  model.eval()
39
 
40
  return model