| import torch |
| import torch.nn as nn |
| import utils |
|
|
| from utils import trunc_normal_ |
|
|
class CSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm used as a *centering* layer.

    The returned output is always computed in eval mode, i.e. with the running
    statistics accumulated from previous calls rather than the current batch.
    Unless ``with_var`` is True, the running variance is reset to ones before
    every call. Normalization then reduces to subtracting the running mean
    (scaled by ``1 / sqrt(1 + eps)``) followed by the affine transform, if any.

    When the module is in training mode, a second, discarded forward pass in
    training mode updates the running statistics with the current batch. In
    eval mode the statistics are left untouched, and the module stays in
    eval mode.

    Args:
        *args, **kwargs: forwarded to ``nn.SyncBatchNorm``.
        with_var: if True, keep and use the running variance instead of
            resetting it to ones.
    """

    def __init__(self,
                 *args,
                 with_var=False,
                 **kwargs):
        super(CSyncBatchNorm, self).__init__(*args, **kwargs)
        self.with_var = with_var

    def forward(self, x):
        was_training = self.training
        try:
            # Normalize with the *running* statistics (eval-mode batch norm).
            self.training = False
            if not self.with_var:
                self.running_var = torch.ones_like(self.running_var)
            normed_x = super(CSyncBatchNorm, self).forward(x)

            # Only update the running statistics (the "center") while training;
            # the output of this pass is discarded.
            if was_training:
                self.training = True
                super(CSyncBatchNorm, self).forward(x)
        finally:
            # Previously the module was unconditionally left in training mode,
            # silently undoing model.eval().
            self.training = was_training
        return normed_x
|
|
class PSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm whose statistics are synchronized within "bunches" of ranks.

    The ranks ``0 .. world_size - 1`` are split into consecutive groups of
    ``min(bunch_size, world_size)`` ranks. Each process synchronizes batch
    statistics only with the members of its own group.

    Args:
        *args, **kwargs: forwarded to ``nn.SyncBatchNorm``. ``process_group``
            is supplied here, so callers must not pass it themselves.
        bunch_size: number of ranks per group (capped at the world size). The
            world size must be a multiple of the effective group size.
    """

    def __init__(self, *args, bunch_size, **kwargs):
        world_size = utils.get_world_size()
        group_size = min(bunch_size, world_size)
        # Every rank must land in exactly one complete group.
        assert world_size % group_size == 0

        all_ranks = list(range(world_size))
        print('---ALL RANKS----\n{}'.format(all_ranks))
        bunches = [all_ranks[start:start + group_size]
                   for start in range(0, world_size, group_size)]
        print('---RANK GROUPS----\n{}'.format(bunches))

        # Every group is created on every rank (torch.distributed.new_group is
        # documented as requiring all processes to call it); each rank then
        # keeps only its own group.
        groups = [torch.distributed.new_group(members) for members in bunches]
        my_group = groups[utils.get_rank() // group_size]
        print('---CURRENT GROUP----\n{}'.format(my_group))

        super().__init__(*args, process_group=my_group, **kwargs)
|
|
class CustomSequential(nn.Sequential):
    """``nn.Sequential`` that applies BatchNorm layers to channels-last inputs.

    BatchNorm modules normalize over dim 1. For inputs with more than two
    dimensions, the last axis is moved to dim 1 before a BatchNorm layer (any
    subclass of the types in ``bn_types``) and moved back afterwards. The
    feature axis of e.g. ``[B, T, C]`` tensors is thus normalized as channels.
    All other modules, and 2-D inputs, are called unchanged.
    """

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def forward(self, input):
        out = input
        for layer in self:
            channels_last_bn = isinstance(layer, self.bn_types) and len(out.shape) > 2
            if not channels_last_bn:
                out = layer(out)
                continue
            # [..., C] -> [B, C, ...] for the BN call, then back to channels-last.
            out = layer(out.movedim(-1, 1)).movedim(1, -1)
        return out
|
|
class DINOHead(nn.Module):
    """DINO-style projection head.

    Pipeline: ``mlp`` -> (when ``bottleneck_dim > 0``) L2-normalize and the
    weight-normalized ``last_layer`` -> (optional) ``last_norm``.

    Args:
        in_dim: input feature size (last dimension of the input).
        out_dim: output feature size.
        norm: norm type inserted after each hidden Linear: 'bn', 'syncbn',
            'csyncbn', 'psyncbn', 'ln' or None. Hidden norms are built without
            ``**kwargs``, so 'psyncbn' (which requires ``bunch_size``) can only
            be used as ``last_norm``.
        act: hidden activation, 'relu' or 'gelu'.
        last_norm: norm type applied to the final output (built with
            ``affine=False`` and ``**kwargs``), or None.
        nlayers: number of Linear layers in ``mlp``; values below 1 become 1.
        hidden_dim: width of the hidden layers.
        bottleneck_dim: output size of ``mlp`` when > 0. Otherwise ``mlp`` maps
            straight to ``out_dim`` and there is no ``last_layer``.
        norm_last_layer: freeze the weight-norm magnitude (``weight_g``) of
            ``last_layer`` at 1.
        **kwargs: forwarded to the constructor of ``last_norm``.
    """

    def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs):
        super().__init__()
        norm_type = norm
        # Built up front (as before) so an unknown `norm` string is rejected
        # even when nlayers == 1 and no hidden norm is needed.
        first_norm = self._build_norm(norm_type, hidden_dim)
        last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs)
        # ReLU / GELU modules carry no parameters, so one instance is reused.
        act = self._build_act(act)

        nlayers = max(nlayers, 1)
        # The MLP ends in the bottleneck if there is one, else in the output.
        mlp_out_dim = bottleneck_dim if bottleneck_dim > 0 else out_dim
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, mlp_out_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if first_norm is not None:
                layers.append(first_norm)
            layers.append(act)
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if first_norm is not None:
                    # A fresh norm per layer. Re-appending `first_norm` made every
                    # hidden layer share one set of weights and running stats.
                    layers.append(self._build_norm(norm_type, hidden_dim))
                layers.append(act)
            layers.append(nn.Linear(hidden_dim, mlp_out_dim))
            self.mlp = CustomSequential(*layers)
        # Runs before last_layer exists, so only the MLP Linears are re-initialized.
        self.apply(self._init_weights)

        if bottleneck_dim > 0:
            self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
            self.last_layer.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer.weight_g.requires_grad = False
        else:
            self.last_layer = None

        self.last_norm = last_norm

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for every nn.Linear."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` (features in the last dimension) to ``out_dim`` features."""
        x = self.mlp(x)
        if self.last_layer is not None:
            x = nn.functional.normalize(x, dim=-1, p=2)
            x = self.last_layer(x)
        if self.last_norm is not None:
            x = self.last_norm(x)
        return x

    def _build_norm(self, norm, hidden_dim, **kwargs):
        """Return a norm module for the type string ``norm`` (None -> None).

        ``kwargs`` are forwarded to the module constructor. An unknown type
        fails the assertion.
        """
        if norm == 'bn':
            norm = nn.BatchNorm1d(hidden_dim, **kwargs)
        elif norm == 'syncbn':
            norm = nn.SyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'csyncbn':
            norm = CSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'psyncbn':
            norm = PSyncBatchNorm(hidden_dim, **kwargs)
        elif norm == 'ln':
            norm = nn.LayerNorm(hidden_dim, **kwargs)
        else:
            assert norm is None, "unknown norm type {}".format(norm)
        return norm

    def _build_act(self, act):
        """Return the activation module for 'relu' or 'gelu'; anything else fails the assertion."""
        if act == 'relu':
            act = nn.ReLU()
        elif act == 'gelu':
            act = nn.GELU()
        else:
            assert False, "unknown act type {}".format(act)
        return act
|
|
class iBOTHead(DINOHead):
    """DINOHead with an extra output branch for patch tokens.

    A 2-D input is handled exactly like ``DINOHead``. For a 3-D input, token 0
    goes through the class-token branch (``last_layer``, or ``mlp[-1]`` when
    there is no bottleneck). Tokens ``1:`` go through the patch branch
    (``last_layer2``, or ``mlp2``). Each branch is optionally followed by its
    own norm (``last_norm`` / ``last_norm2``).

    Args:
        patch_out_dim: output size of the patch branch.
        shared_head: reuse the class-token output layer and norm for patches
            instead of creating separate ones.
        Other arguments: as for ``DINOHead``. ``**kwargs`` are also forwarded
        to the constructor of ``last_norm2``.
    """

    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
                 shared_head=False, **kwargs):
        super().__init__(*args, norm=norm, act=act, last_norm=last_norm, nlayers=nlayers,
                         hidden_dim=hidden_dim, bottleneck_dim=bottleneck_dim,
                         norm_last_layer=norm_last_layer, **kwargs)

        has_bottleneck = bottleneck_dim > 0

        if shared_head:
            # Patch tokens reuse the class-token output modules.
            if has_bottleneck:
                self.last_layer2 = self.last_layer
            else:
                self.mlp2 = self.mlp[-1]
                self.last_layer2 = None
            self.last_norm2 = self.last_norm
            return

        if has_bottleneck:
            patch_proj = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
            self.last_layer2 = patch_proj
            self.last_layer2.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer2.weight_g.requires_grad = False
        else:
            self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
            self.last_layer2 = None
        self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)

    def forward(self, x):
        """Return DINOHead output for 2-D ``x``; otherwise ``(cls_out, patch_out)``.

        For non-2-D input, ``x[:, 0]`` is treated as the class token and
        ``x[:, 1:]`` as patch tokens (assumes a [B, N, C] layout — inferred
        from the indexing, confirm with callers).
        """
        if len(x.shape) == 2:
            return super().forward(x)

        if self.last_layer is None:
            feats = self.mlp[:-1](x)
            cls_out = self.mlp[-1](feats[:, 0])
            patch_out = self.mlp2(feats[:, 1:])
        else:
            feats = nn.functional.normalize(self.mlp(x), dim=-1, p=2)
            cls_out = self.last_layer(feats[:, 0])
            patch_out = self.last_layer2(feats[:, 1:])

        if self.last_norm is not None:
            cls_out = self.last_norm(cls_out)
            patch_out = self.last_norm2(patch_out)

        return cls_out, patch_out
|
|
|
|
|
|
class TemporalSideContext(nn.Module):
    """Transformer encoder that mixes per-time-step features along the T axis.

    Args:
        D: feature size of each time step (``d_model``). The feed-forward
            width is ``4 * D``, and ``D`` must be divisible by ``n_head``.
        max_len: longest sequence accepted when ``use_pos_emb`` is True.
            Ignored otherwise (kept for backward compatibility).
        n_layers: number of encoder layers.
        n_head: attention heads per layer.
        dropout: dropout probability inside each encoder layer.
        use_pos_emb: add a learnable positional embedding of shape
            ``(1, max_len, D)`` to the input before encoding. Defaults to
            False, which keeps the original behaviour and state_dict. In that
            case the encoder receives no information about time-step order.
    """

    def __init__(self, D, max_len=64, n_layers=6, n_head=8, dropout=0.1, use_pos_emb=False):
        super().__init__()
        layer = nn.TransformerEncoderLayer(D, n_head, 4 * D,
                                           dropout=dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, n_layers)
        # Created after the encoder so the default (disabled) path consumes
        # exactly the same random numbers as before.
        if use_pos_emb:
            self.pos_emb = nn.Parameter(torch.zeros(1, max_len, D))
            nn.init.trunc_normal_(self.pos_emb, std=0.02)
        else:
            self.pos_emb = None

    def forward(self, x):
        """Encode ``x`` of shape [B, T, D]; returns a tensor of the same shape.

        Raises:
            ValueError: if ``x`` is not 3-D, or if ``T > max_len`` while the
                positional embedding is enabled.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected a 3-D input [B, T, D], got shape {}".format(tuple(x.shape)))
        if self.pos_emb is not None:
            seq_len = x.shape[1]
            if seq_len > self.pos_emb.shape[1]:
                raise ValueError("sequence length {} exceeds max_len {}".format(
                    seq_len, self.pos_emb.shape[1]))
            x = x + self.pos_emb[:, :seq_len]
        return self.enc(x)
|
|
|
|
|
|
class TemporalHead(nn.Module):
    """
    Converts backbone features [B,T,D] → logits [B,T,1] for Plackett–Luce.

    Pipeline: Linear(D -> hidden) + GELU, then a ``TemporalSideContext``
    transformer over the T axis, then a two-layer MLP scorer
    (hidden -> hidden // 2 -> 1) giving one logit per time step.

    Args:
        backbone_dim: feature size D of the incoming backbone features.
        hidden_mul: hidden width is ``int(backbone_dim * hidden_mul)``. It must
            be divisible by ``n_head`` (multi-head attention requirement).
        max_len: forwarded to ``TemporalSideContext``.
        n_layers, n_head, dropout: forwarded to ``TemporalSideContext``. The
            defaults equal that class's own defaults, so omitting them
            reproduces the previous fixed configuration.
    """

    def __init__(self, backbone_dim: int, hidden_mul: float = 0.5, max_len: int = 64,
                 n_layers: int = 6, n_head: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden_dim = int(backbone_dim * hidden_mul)

        self.reduce = nn.Sequential(
            nn.Linear(backbone_dim, hidden_dim),
            nn.GELU(),
        )
        self.temporal = TemporalSideContext(hidden_dim, max_len=max_len, n_layers=n_layers,
                                            n_head=n_head, dropout=dropout)
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [B, T, backbone_dim] to per-step logits [B, T, 1]."""
        x = self.reduce(x)
        x = self.temporal(x)
        return self.scorer(x)
|
|
|
|
|
|
|
|