mmarquezsa commited on
Commit
ffa8f7b
·
verified ·
1 Parent(s): 9a9c395

revert: remove post-processing, use raw model output — works best for majority of images

Browse files
Files changed (1) hide show
  1. src/segmentation.py +0 -1
src/segmentation.py CHANGED
@@ -252,7 +252,6 @@ def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta
252
  probs_np = probs[0].cpu().numpy()
253
  probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
254
  classmap = probs_resized.argmax(axis=0).astype(np.uint8)
255
- classmap = postprocess_segmentation(classmap, img_bgr)
256
  masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
257
  return {"classmap": classmap, "masks": masks, "probs": probs_resized}
258
 
 
252
  probs_np = probs[0].cpu().numpy()
253
  probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
254
  classmap = probs_resized.argmax(axis=0).astype(np.uint8)
 
255
  masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
256
  return {"classmap": classmap, "masks": masks, "probs": probs_resized}
257