| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Created in September 2022 |
| @author: fabrizio.guillaro |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
|
|
|
|
class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss for segmentation logits.

    Upsamples the prediction to the target's spatial resolution when the
    two differ, then applies ``nn.CrossEntropyLoss``.

    Args:
        ignore_label: target value excluded from the loss (default -1).
        weight: optional per-class rescaling weights forwarded to
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """score: B x C x H' x W' logits; target: B x H x W class indices."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated (removed in recent PyTorch);
            # F.interpolate with align_corners=False matches its modern default.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        loss = self.criterion(score, target)
        return loss
|
|
| |
| |
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss.

    Softmax-normalizes the logits, one-hot encodes the target, and averages
    a per-class binary Dice loss over the classes, skipping ``ignore_label``.

    NOTE: the accumulated loss is divided by the *total* number of classes
    even when the ignored class is skipped — kept for backward compatibility
    (this matches mmsegmentation's DiceLoss).

    Args:
        ignore_label: class index excluded from the Dice average (default -1).
        smooth: additive smoothing term for the numerator/denominator.
        exponent: power applied to pred/target in the denominator.
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """pred: B C H W probabilities; target: B H W C one-hot; valid_mask: B H W."""
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask,)
                total_loss += dice_loss
        # Divides by the total class count (ignored class included) on purpose.
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Soft Dice on a single class channel, restricted to valid pixels."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # max(smooth, 1e-5) keeps the denominator strictly positive even if
        # smooth was configured as 0.
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B C H' W' logits; target: B H W integer class labels."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the supported API.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so that ignore_label values (e.g. -1) stay inside one_hot's
        # valid range; such pixels are zeroed out via valid_mask anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss
| |
| |
|
|
class BinaryDiceLoss(nn.Module):
    """Soft Dice loss on the foreground channel (class 1) only.

    Expects a two-(or more-)channel logit map; the loss is computed between
    the softmax probability of channel 1 and the one-hot foreground target.

    Args:
        smooth: additive smoothing term.
        exponent: power applied to pred/target in the denominator.
        ignore_label: target value excluded via the validity mask (default -1).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """Soft Dice on flattened per-sample tensors, restricted to valid pixels."""
        assert pred.shape[0] == target.shape[0]
        # NOTE: leftover debug print() calls were removed here — they ran on
        # every forward pass and flooded the training log.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # max(smooth, 1e-5) keeps the denominator strictly positive.
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B C H' W' logits; target: B 1 H W binary labels."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(2), target.size(3)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the supported API.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        # Only the foreground channel (index 1) contributes to the loss.
        loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
        return loss
| |
def create_target_from_mask_and_label(mask, data_label):
    """
    Convert binary mask to class-labeled target based on data_label.

    Args:
        mask: B H W (or B 1 H W) with values 0 (black/background) or
            1 (white/foreground)
        data_label: B×1 tensor or B tensor with values [0, 1, 2, 3]
            - 0: background (no edit)
            - 1: physical edit (Photoshop)
            - 2: synthetic AI edit
            - 3: other edit type

    Returns:
        target: B H W long tensor with values [0, 1, 2, 3]
            - 0: unedited pixels (mask == 0)
            - 1, 2, 3: edited pixels with their respective class labels
    """
    # Drop the channel dimension if the mask arrives as B 1 H W.
    if mask.dim() == 4:
        mask = mask.squeeze(1)

    # Collapse a B×1 label tensor to shape (B,).
    if data_label.dim() > 1:
        data_label = data_label.squeeze()

    # Vectorized replacement of the per-sample Python loop: broadcasting a
    # (B, 1, 1) label tensor over the (B, H, W) binary mask assigns each
    # sample's class label to its foreground pixels in one op.
    # reshape(-1) also handles a 0-dim label (broadcast to the whole batch),
    # matching the original scalar-label path.
    labels = data_label.long().reshape(-1, 1, 1).to(mask.device)
    target = mask.long() * labels

    return target
|
|
|
|
def debug_target_creation(target, data_label, batch_size=4):
    """
    Debug function to print data_label and target mapping before and after conversion.

    Calls ``create_target_from_mask_and_label`` and sanity-checks, per sample,
    that the resulting class distribution matches the sample's label.

    Args:
        target: Binary mask B×H×W or B×1×H×W with values 0 or 1
        data_label: B tensor with class labels [0, 1, 2, 3]
        batch_size: maximum number of samples to print (default 4)

    Returns:
        The converted target produced by ``create_target_from_mask_and_label``.
    """

    print("="*80)
    print("DEBUGGING TARGET CREATION")
    print("="*80)

    # Raw inputs, before any conversion.
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")

    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")

    # Per-sample edited-pixel counts on the binary mask.
    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    if target.dim() == 4:
        target_2d = target.squeeze(1)
    else:
        target_2d = target

    B = target_2d.shape[0]
    for b in range(min(B, batch_size)):
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        # data_label may be 0-dim when the batch label was squeezed to a scalar.
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        print(f" Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")

    # The actual conversion under test.
    target_converted = create_target_from_mask_and_label(target, data_label)

    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")

    # Verify, per sample, that foreground pixels carry the expected class id.
    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(min(B, batch_size)):
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()

        # Histogram of the four possible class ids in this sample.
        class_counts = {}
        for class_id in range(4):
            count = (target_converted[b] == class_id).sum().item()
            class_counts[class_id] = count

        print(f" Sample {b}:")
        print(f" Label (expected): {label}")
        print(f" Class distribution: {class_counts}")

        if label == 0:
            # Background sample: every pixel must remain class 0.
            if class_counts[0] == target_converted[b].numel():
                print(f" ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f" ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Edited sample: at least one pixel must carry the sample's label.
            if class_counts[label] > 0:
                print(f" ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
            else:
                print(f" ✗ ERROR: Expected class {label} pixels but found none")

    print("\n" + "="*80)

    return target_converted
|
|
class MultiClassDiceEntropyLoss(nn.Module):
    """Weighted sum of soft Dice loss and cross-entropy for multi-class
    segmentation (classes 0 = background, then 1, 2, 3).

    total = dice_weight * dice + ce_weight * cross_entropy
    """

    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
                 ignore_index=-1):
        super(MultiClassDiceEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        # Cross-entropy skips ignore_index pixels on its own.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def dice_loss(self, pred, target, valid_mask=None):
        """Mean of the per-class soft Dice losses.

        pred: B C H W softmax probabilities
        target: B H W integer class indices
        valid_mask: optional B H W mask (1 = valid pixel, 0 = ignored)
        """
        per_class = []
        flat_mask = None if valid_mask is None else valid_mask.reshape(-1)

        for cls in range(self.num_classes):
            probs = pred[:, cls, :, :].reshape(-1)
            truth = (target == cls).float().reshape(-1)

            # Zero out ignored pixels on both sides before the overlap sums.
            if flat_mask is not None:
                probs = probs * flat_mask
                truth = truth * flat_mask

            overlap = torch.sum(probs * truth)
            denom = torch.sum(probs) + torch.sum(truth)
            score = (2 * overlap + self.smooth) / (denom + self.smooth)
            per_class.append(1 - score)

        return torch.mean(torch.stack(per_class))

    def forward(self, score, target, data_label):
        """score: B C H' W' raw logits; target: B H W (or B 1 H W) labels.

        ``data_label`` is accepted for interface compatibility with the
        training loop but is not used by this loss.
        """
        if target.dim() == 4:
            target = target.squeeze(1)
        target = target.long()

        # Bring the logits to the target resolution before either loss term.
        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:], mode='bilinear', align_corners=False)

        probs = F.softmax(score, dim=1)

        # Cross-entropy works on the raw logits, Dice on the probabilities.
        ce_term = self.ce_loss(score, target)
        keep = (target != self.ignore_index).float()
        dice_term = self.dice_loss(probs, target, keep)

        return self.dice_weight * dice_term + self.ce_weight * ce_term
|
|
|
|
class DiceEntropyLoss(nn.Module):
    """Combined loss: 0.3 * cross-entropy + 0.7 * Dice.

    The target is binarized (clamped to {0, 1}) in ``forward``, so this is
    effectively a two-class edited/unedited loss; the per-class loop below
    keeps the original multi-class structure but, after the clamp, only
    class 1 can ever be present.

    Args:
        smooth: additive smoothing term for the Dice numerator.
        exponent: power applied to pred/target in the Dice denominator.
        ignore_label: target value excluded from both loss terms (default -1).
        weight: optional per-class weights for the cross-entropy term.
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Soft Dice on one class channel, restricted to valid pixels."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # max(smooth, 1e-5) keeps the denominator strictly positive.
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        """score: B C H' W' logits; target: B 1 H W labels (binarized here)."""
        target = target.squeeze(1).long()

        # Collapse every edit class (1/2/3) to foreground = 1.
        target = torch.clamp(target, min=0, max=1)
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the supported API.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        CE_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()

        # BUG FIX: the original body referenced undefined names (`dice`,
        # `pred`, `target_onehot`) and raised NameError whenever a foreground
        # class was present. It now calls self.binary_dice_loss on the
        # softmax probabilities and one-hot target computed above. The
        # presence guard keeps absent classes (and out-of-range channels)
        # from contributing, exactly as the original intended.
        number_of_present_classes = 4
        dice_loss = 0
        for class_id in [1, 2, 3]:
            if (target == class_id).sum() > 0:
                dice_loss = dice_loss + self.binary_dice_loss(
                    score[:, class_id],
                    one_hot_target[..., class_id],
                    valid_mask)
        dice_loss = dice_loss / number_of_present_classes

        return 0.3*CE_loss + 0.7*dice_loss
|
|
|
|
|
|
| |
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) built on per-pixel cross-entropy.

    FL = alpha * (1 - pt)^gamma * CE, where pt = exp(-CE) is the model's
    probability for the true class, so well-classified pixels are
    down-weighted and hard pixels dominate the gradient.

    Args:
        alpha: global scaling factor (default 0.25).
        gamma: focusing exponent (default 2).
        ignore_label: target value excluded from the loss (default -1).
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps the per-pixel CE so it can be re-weighted.
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        """score: B C H' W' logits; target: B H W class indices."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the supported API.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)
        f_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
        return f_loss.mean()
| |
| |