Spaces:
Sleeping
Sleeping
Nunzio commited on
Commit ·
9d4bb4b
1
Parent(s): 05e5639
fixed weight loading
Browse files- model/modelLoading.py +7 -6
model/modelLoading.py
CHANGED
|
@@ -5,13 +5,13 @@ from model.BiSeNetV2.model import BiSeNetV2
|
|
| 5 |
|
| 6 |
|
| 7 |
# BiSeNet model loading function
|
| 8 |
-
def loadBiSeNet(device: str = 'cpu', weights:str='
|
| 9 |
"""
|
| 10 |
Load the BiSeNet model and move it to the specified device.
|
| 11 |
|
| 12 |
Args:
|
| 13 |
-
device (str): Device to load the model onto ('cpu' or 'cuda').
|
| 14 |
-
weights (str): weights file to be loaded
|
| 15 |
|
| 16 |
Returns:
|
| 17 |
model (BiSeNet): The loaded BiSeNet model.
|
|
@@ -23,18 +23,19 @@ def loadBiSeNet(device: str = 'cpu', weights:str='weightADV.pth') -> BiSeNet:
|
|
| 23 |
return model
|
| 24 |
|
| 25 |
|
| 26 |
-
def loadBiSeNetV2(device: str = 'cpu') -> BiSeNetV2:
|
| 27 |
"""
|
| 28 |
Load the BiSeNetV2 model and move it to the specified device.
|
| 29 |
|
| 30 |
Args:
|
| 31 |
-
device (str): Device to load the model onto ('cpu' or 'cuda').
|
|
|
|
| 32 |
|
| 33 |
Returns:
|
| 34 |
model (BiSeNetV2): The loaded BiSeNetV2 model.
|
| 35 |
"""
|
| 36 |
model = BiSeNetV2(n_classes=19).to(device)
|
| 37 |
-
model.load_state_dict(torch.load('./weights/BiSeNetV2/
|
| 38 |
model.eval()
|
| 39 |
|
| 40 |
return model
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
# BiSeNet model loading function
|
| 8 |
+
def loadBiSeNet(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNet:
|
| 9 |
"""
|
| 10 |
Load the BiSeNet model and move it to the specified device.
|
| 11 |
|
| 12 |
Args:
|
| 13 |
+
device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
|
| 14 |
+
weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
|
| 15 |
|
| 16 |
Returns:
|
| 17 |
model (BiSeNet): The loaded BiSeNet model.
|
|
|
|
| 23 |
return model
|
| 24 |
|
| 25 |
|
| 26 |
+
def loadBiSeNetV2(device: str = 'cpu', weights:str='weight_Base.pth') -> BiSeNetV2:
|
| 27 |
"""
|
| 28 |
Load the BiSeNetV2 model and move it to the specified device.
|
| 29 |
|
| 30 |
Args:
|
| 31 |
+
device (str): Device to load the model onto ('cpu' or 'cuda'). Default is 'cpu'.
|
| 32 |
+
weights (str): weights file to be loaded. Default is 'weight_Base.pth'.
|
| 33 |
|
| 34 |
Returns:
|
| 35 |
model (BiSeNetV2): The loaded BiSeNetV2 model.
|
| 36 |
"""
|
| 37 |
model = BiSeNetV2(n_classes=19).to(device)
|
| 38 |
+
model.load_state_dict(torch.load(f'./weights/BiSeNetV2/{weights}', map_location=device)['model_state_dict'])
|
| 39 |
model.eval()
|
| 40 |
|
| 41 |
return model
|