Vedant Jigarbhai Mehta commited on
Commit
c95b5c2
·
1 Parent(s): 9b0eb35

Fix UNet++ decoder call for newer SMP API (pass list not unpacked args)

Browse files
Files changed (1) hide show
  1. models/unet_pp.py +2 -2
models/unet_pp.py CHANGED
@@ -62,8 +62,8 @@ class UNetPPChangeDetection(nn.Module):
62
  # Compute absolute difference at each scale
63
  diff_features = [torch.abs(f1 - f2) for f1, f2 in zip(features_1, features_2)]
64
 
65
- # Decode
66
- decoder_output = self.decoder(*diff_features)
67
  out = self.segmentation_head(decoder_output)
68
  return out
69
 
 
62
  # Compute absolute difference at each scale
63
  diff_features = [torch.abs(f1 - f2) for f1, f2 in zip(features_1, features_2)]
64
 
65
+ # Decode (SMP decoder expects a list of features, not unpacked args)
66
+ decoder_output = self.decoder(diff_features)
67
  out = self.segmentation_head(decoder_output)
68
  return out
69