import einops
import numpy as np
import torch
import torch.nn as nn


class Normalize(nn.Module):
    """L2-normalize the input along a given dimension (unit-norm features)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
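

# Example (sketch, illustrative shapes only): unit-normalizing embeddings so
# that a plain dot product becomes cosine similarity.
#
#   norm = Normalize(dim=-1)
#   feats = norm(torch.randn(8, 512))  # each row now has L2 norm 1
#   sims = feats @ feats.T             # cosine similarities in [-1, 1]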


class LearnableLogitScaling(nn.Module):
    """Scale logits by a CLIP-style temperature, optionally learnable.

    The scale is stored in log space and clipped to ``max_logit_scale`` in the
    forward pass to keep training stable.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x):
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self):
        return (
            f"logit_scale_init={self.logit_scale_init}, "
            f"learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
        )
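

# Example (sketch, names are illustrative): applying the learnable temperature
# to similarity logits before a contrastive (InfoNCE-style) loss.
#
#   scale = LearnableLogitScaling(learnable=True)
#   logits = scale(image_feats @ text_feats.T)  # temperature-scaled similarities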


class EinOpsRearrange(nn.Module):
    """Apply a fixed ``einops.rearrange`` expression as an ``nn.Module``."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        assert isinstance(x, torch.Tensor)
        return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
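

# Example (sketch): using EinOpsRearrange inside nn.Sequential, where a bare
# einops call would not fit.
#
#   to_tokens = EinOpsRearrange("b c h w -> b (h w) c")
#   tokens = to_tokens(torch.randn(2, 3, 4, 4))  # shape (2, 16, 3)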


class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        return (
            "(" + name + "): tensor("
            + str(tuple(tensor.shape))
            + ", requires_grad="
            + str(tensor.requires_grad)
            + ")\n"
        )

    def extra_repr(self) -> str:
        # Top-level submodule names; tensors living inside a submodule are
        # already shown by that submodule's own repr, so we skip them here.
        named_modules = {name for name, _ in self.named_modules()}

        string_repr = ""
        for name, param in self.named_parameters():
            top_level_name = name.split(".")[0]
            if top_level_name not in named_modules:
                string_repr += self.get_readable_tensor_repr(top_level_name, param)

        for name, buf in self.named_buffers():
            top_level_name = name.split(".")[0]
            string_repr += self.get_readable_tensor_repr(top_level_name, buf)

        return string_repr
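

# Example (sketch): a subclass whose directly registered tensors show up in repr().
#
#   class Pos(VerboseNNModule):
#       def __init__(self):
#           super().__init__()
#           self.pos_embed = nn.Parameter(torch.zeros(1, 16, 8))
#   print(Pos())  # includes "(pos_embed): tensor((1, 16, 8), requires_grad=True)"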


def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast ``tensor`` to ``tgt_dtype`` if it currently has ``src_dtype``.

    Returns the (possibly cast) tensor and a flag saying whether a cast
    happened, so callers can restore the original dtype afterwards.
    """
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated
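

# Example (sketch): temporarily upcasting a half-precision feature map to fp32
# around a numerically sensitive op, then restoring the original dtype.
#
#   x, upcast = cast_if_src_dtype(x, torch.float16, torch.float32)
#   x = torch.nn.functional.interpolate(x, scale_factor=2, mode="bicubic")
#   if upcast:
#       x, _ = cast_if_src_dtype(x, torch.float32, torch.float16)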


class QuickGELU(nn.Module):
    """Fast GELU approximation, x * sigmoid(1.702 * x), as used in CLIP."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
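

# Example (sketch): QuickGELU as the activation inside a transformer MLP block.
#
#   mlp = nn.Sequential(nn.Linear(512, 2048), QuickGELU(), nn.Linear(2048, 512))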


class SelectElement(nn.Module):
    """Select a single element along the sequence (token) dimension."""

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x):
        assert x.ndim >= 3
        return x[:, self.index, ...]
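

# Example (sketch): pooling a transformer output by taking the CLS token,
# assuming it sits at position 0 of the sequence dimension.
#
#   pool = SelectElement(index=0)
#   cls_feats = pool(torch.randn(4, 197, 768))  # shape (4, 768)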


class SelectEOSAndProject(nn.Module):
    """
    Text pooling used in OpenCLIP: take each sequence's EOS token features
    and project them.
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        assert x.ndim == 3
        # Pick the feature at each sequence's EOS position, then project.
        x = x[torch.arange(x.shape[0]), seq_len]
        x = self.proj(x)
        return x
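

# Example (sketch, illustrative shapes): projecting the EOS token of each
# sequence in a batch to a shared embedding width.
#
#   pool = SelectEOSAndProject(proj=nn.Linear(512, 256))
#   tokens = torch.randn(4, 77, 512)
#   eos_idx = torch.tensor([9, 12, 7, 30])  # EOS position per sequence
#   out = pool(tokens, eos_idx)             # shape (4, 256)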