revert: remove post-processing, use raw model output — works best for majority of images
Browse files- src/segmentation.py +0 -1
src/segmentation.py
CHANGED
|
@@ -252,7 +252,6 @@ def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta
|
|
| 252 |
probs_np = probs[0].cpu().numpy()
|
| 253 |
probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
|
| 254 |
classmap = probs_resized.argmax(axis=0).astype(np.uint8)
|
| 255 |
-
classmap = postprocess_segmentation(classmap, img_bgr)
|
| 256 |
masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
|
| 257 |
return {"classmap": classmap, "masks": masks, "probs": probs_resized}
|
| 258 |
|
|
|
|
| 252 |
probs_np = probs[0].cpu().numpy()
|
| 253 |
probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
|
| 254 |
classmap = probs_resized.argmax(axis=0).astype(np.uint8)
|
|
|
|
| 255 |
masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
|
| 256 |
return {"classmap": classmap, "masks": masks, "probs": probs_resized}
|
| 257 |
|