# code_SAS_VLM2Vec / src / loss_loc.py
# Author: MgGladys — added via upload-large-folder tool (commit 0a937d7, verified)
import torch
import torch.nn.functional as F
from torch import nn
from typing import List
class DiceLoss(nn.Module):
    """Soft Dice loss over a list of per-sample mask logits.

    Each entry in ``logits`` is scored against the matching entry in ``gts``
    after flattening and sigmoid activation; the per-sample (1 - dice)
    values are averaged over ``len(logits)``.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        # Smoothing term: keeps the ratio finite when both prediction
        # and ground truth are all-zero.
        self.epsilon = epsilon

    def forward(self, logits: List[torch.Tensor], gts: List[torch.Tensor]):
        # Empty batch: return a zero scalar, placed on the gt device if one exists.
        if not logits:
            device = gts[0].device if len(gts) else "cpu"
            return torch.tensor(0.0, device=device)
        per_sample = []
        for prediction, target in zip(logits, gts):
            prob = torch.sigmoid(prediction.flatten())
            mask = target.flatten().to(prob.device, dtype=torch.float)
            overlap = torch.sum(prob * mask)
            support = torch.sum(prob) + torch.sum(mask)
            score = (2 * overlap + self.epsilon) / (support + self.epsilon)
            per_sample.append(1 - score)
        # Divide by len(logits) (not len(per_sample)) to match the batch size.
        return sum(per_sample) / len(logits)
class BCELoss(nn.Module):
    """Binary cross-entropy (with logits) over a list of per-sample masks.

    Each prediction/target pair is flattened, cast to float, and scored with
    ``F.binary_cross_entropy_with_logits``; results are averaged over
    ``len(logits)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: List[torch.Tensor], gts: List[torch.Tensor]):
        # Empty batch: return a zero scalar, placed on the gt device if one exists.
        if not logits:
            device = gts[0].device if len(gts) else "cpu"
            return torch.tensor(0.0, device=device)
        per_sample = [
            F.binary_cross_entropy_with_logits(
                prediction.flatten().float(),
                target.flatten().to(prediction.device).float(),
            )
            for prediction, target in zip(logits, gts)
        ]
        return sum(per_sample) / len(logits)
class MaskLoss(nn.Module):
    """Weighted combination of Dice and BCE losses for mask prediction.

    Total loss = dice_weight * DiceLoss + bce_weight * BCELoss, each computed
    over the same lists of per-sample logits and ground-truth masks.
    """

    def __init__(self, dice_weight=1.0, bce_weight=0.1, epsilon=1e-6):
        super().__init__()
        # epsilon is forwarded only to the Dice term's smoothing constant.
        self.dice = DiceLoss(epsilon)
        self.bce = BCELoss()
        self.dw = dice_weight
        self.bw = bce_weight

    def forward(self, logits, gts):
        dice_term = self.dice(logits, gts)
        bce_term = self.bce(logits, gts)
        return self.dw * dice_term + self.bw * bce_term