| import torch |
| import torch.nn.functional as F |
|
|
|
|
| def fixed_cross_entropy( |
| source, |
| target, |
| num_items_in_batch: int | None = None, |
| ignore_index: int = -100, |
| weight=None, |
| **kwargs, |
| ): |
| reduction = "sum" if num_items_in_batch is not None else "mean" |
| loss = F.cross_entropy( |
| source, |
| target, |
| ignore_index=ignore_index, |
| reduction=reduction, |
| weight=weight, |
| ) |
| if reduction == "sum": |
| loss = loss / num_items_in_batch |
| return loss |
|
|
|
|
def WeightedCausalLMLoss(
    logits,
    labels,
    image_vocab_size: int,
    image_loss_weight: float = 1.0,
    image_token_ratio: float = 2.4,
    num_items_in_batch: int | None = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Causal-LM loss with a per-class weight applied to image-token logits.

    Label alignment follows the HF convention: rather than shifting the
    logits left, the labels are right-padded with ``ignore_index`` and
    shifted, so ``logits[..., t, :]`` is scored against token ``t + 1``.

    NOTE(review): assumes the image tokens occupy the LAST
    ``image_vocab_size`` rows of the vocabulary — confirm against the
    tokenizer/vocab layout.
    NOTE(review): the final ``(1 + r) / (1 + r * w)`` rescale looks intended
    to keep the loss magnitude comparable to the unweighted case, taking
    ``image_token_ratio`` as the assumed image:text token ratio. With
    ``reduction="mean"`` PyTorch already normalizes by the weight sum, so
    the two reduction paths interact with this rescale differently — verify
    which path callers rely on.

    Args:
        logits: prediction scores, ``(..., seq, vocab)``; cast to float32.
        labels: target token ids, same leading shape as ``logits``.
        image_vocab_size: number of trailing vocab entries treated as image
            tokens.
        image_loss_weight: class weight applied to those entries; ``1.0``
            disables weighting entirely.
        image_token_ratio: assumed image-to-text token ratio used by the
            magnitude rescale.
        num_items_in_batch: forwarded to ``fixed_cross_entropy``.
        ignore_index: label value excluded from the loss.
        **kwargs: forwarded to ``fixed_cross_entropy``.

    Returns:
        Scalar loss tensor.
    """
    logits = logits.float()
    labels = labels.to(logits.device)

    # Pad-then-shift so label t lines up with the prediction at position t-1.
    shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous()

    vocab = logits.size(-1)
    if image_loss_weight == 1.0:
        weight = None
    else:
        # 1.0 for text classes, image_loss_weight for the trailing image block.
        weight = torch.ones(vocab, device=logits.device)
        weight[-image_vocab_size:] = image_loss_weight

    # Flatten (batch, seq) into one token axis for cross-entropy.
    flat_logits = logits.view(-1, vocab)
    flat_labels = shift_labels.view(-1).to(flat_logits.device)

    loss = fixed_cross_entropy(
        flat_logits,
        flat_labels,
        num_items_in_batch,
        ignore_index,
        weight=weight,
        **kwargs,
    )

    if image_loss_weight != 1.0:
        # Compensate for the weight inflating the loss magnitude, assuming
        # image_token_ratio image tokens per text token.
        scale = (1.0 + image_token_ratio) / (1.0 + (image_token_ratio * image_loss_weight))
        loss = scale * loss

    return loss
|
|