Nunzio commited on
Commit
9d4bb4b
·
1 Parent(s): 05e5639

fixed weight loading

Browse files
Files changed (1) hide show
  1. model/modelLoading.py +7 -6
model/modelLoading.py CHANGED
@@ -5,13 +5,13 @@ from model.BiSeNetV2.model import BiSeNetV2
5
 
6
 
7
  # BiSeNet model loading function
8
- def loadBiSeNet(device: str = 'cpu', weights:str='weightADV.pth') -> BiSeNet:
9
  """
10
  Load the BiSeNet model and move it to the specified device.
11
 
12
  Args:
13
- device (str): Device to load the model onto ('cpu' or 'cuda').
14
- weights (str): weights file to be loaded
15
 
16
  Returns:
17
  model (BiSeNet): The loaded BiSeNet model.
@@ -23,18 +23,19 @@ def loadBiSeNet(device: str = 'cpu', weights:str='weightADV.pth') -> BiSeNet:
23
  return model
24
 
25
 
26
- def loadBiSeNetV2(device: str = 'cpu') -> BiSeNetV2:
27
  """
28
  Load the BiSeNetV2 model and move it to the specified device.
29
 
30
  Args:
31
- device (str): Device to load the model onto ('cpu' or 'cuda').
 
32
 
33
  Returns:
34
  model (BiSeNetV2): The loaded BiSeNetV2 model.
35
  """
36
  model = BiSeNetV2(n_classes=19).to(device)
37
- model.load_state_dict(torch.load('./weights/BiSeNetV2/weightADV.pth', map_location=device)['model_state_dict'])
38
  model.eval()
39
 
40
  return model
 
5
 
6
 
7
  # BiSeNet model loading function
8
+ def loadBiSeNet(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNet:
9
  """
10
  Load the BiSeNet model and move it to the specified device.
11
 
12
  Args:
13
+ device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
14
+ weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
15
 
16
  Returns:
17
  model (BiSeNet): The loaded BiSeNet model.
 
23
  return model
24
 
25
 
26
+ def loadBiSeNetV2(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNetV2:
27
  """
28
  Load the BiSeNetV2 model and move it to the specified device.
29
 
30
  Args:
31
+ device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
32
+ weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
33
 
34
  Returns:
35
  model (BiSeNetV2): The loaded BiSeNetV2 model.
36
  """
37
  model = BiSeNetV2(n_classes=19).to(device)
38
+ model.load_state_dict(torch.load(f'./weights/BiSeNetV2/{weights}', map_location=device)['model_state_dict'])
39
  model.eval()
40
 
41
  return model