| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
|
|
def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    """Return the pairwise image-text similarity matrix.

    Entry (i, j) is the dot product of image embedding i with text
    embedding j, i.e. an (N_images, N_texts) logit matrix.
    """
    return torch.matmul(image_features, text_features.T)
|
|
|
|
def contrastive_loss(logits, dim):
    """Mean negative log-likelihood of the diagonal of a logit matrix.

    Matched pairs are assumed to sit on the diagonal; `dim` selects the
    softmax normalization direction (0 = over columns, 1 = over rows).
    """
    log_probs = F.log_softmax(logits, dim=dim)
    # Positive pairs are the diagonal entries of the matrix.
    return -log_probs.diagonal().mean()
|
|
|
|
def contrastive_sigmoid_loss(logits):
    """SigLIP-style pairwise sigmoid loss over a logit matrix.

    Each entry is treated as an independent binary classification: the
    diagonal (matched pairs) is the positive class, everything else is
    negative. Returns the mean BCE-with-logits over all entries.

    Fixes vs. the previous version: the identity target is now created on
    the same device and with the same dtype as `logits` (the old code built
    it on the CPU with the default dtype, crashing for CUDA or float64
    inputs), and non-square logit matrices are supported via a rectangular
    identity target.
    """
    target = torch.eye(
        logits.size(0), logits.size(1), device=logits.device, dtype=logits.dtype
    )
    return F.binary_cross_entropy_with_logits(logits, target, reduction="mean")
|
|
|
|
class CLIPLoss(nn.Module):
    """Symmetric contrastive (InfoNCE) loss over an image-text similarity matrix.

    The temperature is a learnable logit squashed through a sigmoid, so the
    effective temperature always lies in (0, 1).
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Pre-sigmoid learnable temperature; sigmoid(-1.0) ~= 0.27 at init.
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        # Scale once; the same scaled logits feed both directions.
        scaled = similarity_matrix / self.logit_temperature.sigmoid()
        # Column-wise (text -> image) and row-wise (image -> text) terms.
        caption_loss = contrastive_loss(scaled, dim=0)
        image_loss = contrastive_loss(scaled, dim=1)
        return 0.5 * (caption_loss + image_loss)
|
|
|
|
class CyCLIPLoss(nn.Module):
    """CLIP loss plus CyCLIP-style consistency regularizers.

    On top of the symmetric contrastive term, adds (1) a symmetry penalty
    pulling the similarity matrix towards its transpose and (2) a penalty
    matching the image-image and text-text similarity structures.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Weights for the two consistency regularizers.
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        # Temperature-scale once for both contrastive directions.
        scaled = similarity_matrix / self.logit_temperature.sigmoid()
        caption_loss = contrastive_loss(scaled, dim=0)
        image_loss = contrastive_loss(scaled, dim=1)

        # Cross-modal consistency: S should match its transpose.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # In-modal consistency: image-image vs text-text similarity structure.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
|
|
|
|
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss (SigLIP) with learnable temperature.

    The temperature is a learnable logit passed through a sigmoid, keeping
    the effective temperature in (0, 1).
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        scaled = similarity_matrix / self.logit_temperature.sigmoid()
        return contrastive_sigmoid_loss(scaled)
|
|
|
|
class CySigLIPLoss(nn.Module):
    """SigLIP loss augmented with CyCLIP-style consistency regularizers."""

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Weights for the two consistency regularizers.
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        scaled = similarity_matrix / self.logit_temperature.sigmoid()
        sigmoid_term = contrastive_sigmoid_loss(scaled)

        # Cross-modal consistency: S should match its transpose.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # In-modal consistency: image-image vs text-text similarity structure.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            sigmoid_term
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
|
|
|
|
def get_loss(loss_type: str) -> nn.Module:
    """Instantiate the loss module named by `loss_type`.

    Valid values: "clip", "cyclip", "sigmoid", "cyclic_sigmoid".

    Returns:
        A freshly constructed loss module of the requested kind.

    Raises:
        ValueError: if `loss_type` is not a known loss name.

    The previous version eagerly constructed all four modules (each holding
    a learnable `nn.Parameter`) on every call just to return one; now only
    the requested class is instantiated, and the error names the bad value.
    """
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    try:
        return loss_classes[loss_type]()
    except KeyError:
        raise ValueError(f"Invalid loss type: {loss_type!r}") from None
|
|