diff --git a/dinov2/__init__.py b/dinov2/__init__.py
deleted file mode 100644
index ae847e46898077fe3d8701b8a181d7b4e3d41cd9..0000000000000000000000000000000000000000
--- a/dinov2/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-__version__ = "0.0.1"
diff --git a/dinov2/configs/__init__.py b/dinov2/configs/__init__.py
deleted file mode 100644
index 68e0830c62ea19649b6cd2361995f6df309d7640..0000000000000000000000000000000000000000
--- a/dinov2/configs/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import pathlib
-
-from omegaconf import OmegaConf
-
-
-def load_config(config_name: str):
-    config_filename = config_name + ".yaml"
-    return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
-
-
-dinov2_default_config = load_config("ssl_default_config")
-
-
-def load_and_merge_config(config_name: str):
-    default_config = OmegaConf.create(dinov2_default_config)
-    loaded_config = load_config(config_name)
-    return OmegaConf.merge(default_config, loaded_config)
diff --git a/dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml b/dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml
deleted file mode 100644
index e32eb1772ccb5d59c8566987a06498cadf126b63..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-train:
-  batch_size_per_gpu: 32
-  OFFICIAL_EPOCH_LENGTH: 450
-  cell_augmentation: true
-  channel_adaptive: true
-student:
-  arch: vit_large
-  patch_size: 16
-  num_register_tokens: 0
-  interpolate_antialias: false
-  interpolate_offset: 0.1
-  drop_path_rate: 0.1
-  in_chans: 1
-  block_chunks: 4
-  channel_adaptive: true
-teacher:
-  momentum_teacher: 0.996
-  warmup_teacher_temp_epochs: 20
-  in_chans: 1
-  channel_adaptive: true
-crops:
-  global_crops_scale:
-  - 0.4
-  - 1.0
-  local_crops_number: 8
-  local_crops_scale:
-  - 0.005
-  - 0.4
-  global_crops_size: 224
-  local_crops_size: 96
-optim:
-  weight_decay_end: 0.2
-  base_lr: 5.0e-4
-  warmup_epochs: 20
-  epochs: 400
\ No newline at end of file
diff --git a/dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml b/dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml
deleted file mode 100644
index 0c31ee81c4c69a83a8753fdd06583843e0485fa3..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-student:
-  arch: vit_large
-  patch_size: 16
-  num_register_tokens: 0
-  interpolate_antialias: false
-  interpolate_offset: 0.1
-  drop_path_rate: 0.1
-  in_chans: 4
-  block_chunks: 4
-teacher:
-  in_chans: 4
-crops:
-  global_crops_size: 224
-  local_crops_size: 96
diff --git a/dinov2/configs/eval/vitb14_pretrain.yaml b/dinov2/configs/eval/vitb14_pretrain.yaml
deleted file mode 100644
index 117d0f027ca26cd8ce6c010bb78d5a8fac42c70e..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitb14_pretrain.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-student:
-  arch: vit_base
-  patch_size: 14
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
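The deleted dinov2/configs/__init__.py above builds every experiment config by layering a named YAML file over ssl_default_config.yaml via OmegaConf. A minimal usage sketch, assuming the package is still installed (the config name refers to one of the files deleted in this diff):

    from dinov2.configs import load_and_merge_config

    # "eval/cell_dino/vitl16_pretrain" resolves to the YAML file of the same
    # name under dinov2/configs/; OmegaConf.merge gives its keys priority
    # over the defaults from ssl_default_config.yaml.
    cfg = load_and_merge_config("eval/cell_dino/vitl16_pretrain")
    assert cfg.student.in_chans == 4     # overridden by the named file
    assert cfg.optim.min_lr == 1.0e-06   # inherited from the defaults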
diff --git a/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/dinov2/configs/eval/vitb14_reg4_pretrain.yaml
deleted file mode 100644
index d53edc04a0761b4b35c147d63e04d55c90092c8f..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitb14_reg4_pretrain.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-student:
-  arch: vit_base
-  patch_size: 14
-  num_register_tokens: 4
-  interpolate_antialias: true
-  interpolate_offset: 0.0
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vitg14_pretrain.yaml b/dinov2/configs/eval/vitg14_pretrain.yaml
deleted file mode 100644
index a96dd5b117b4d59ee210b65037821f1b3e3f16e3..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitg14_pretrain.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-student:
-  arch: vit_giant2
-  patch_size: 14
-  ffn_layer: swiglufused
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/dinov2/configs/eval/vitg14_reg4_pretrain.yaml
deleted file mode 100644
index 15948f8589ea0a6e04717453eb88c18388e7f1b2..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitg14_reg4_pretrain.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-student:
-  arch: vit_giant2
-  patch_size: 14
-  ffn_layer: swiglufused
-  num_register_tokens: 4
-  interpolate_antialias: true
-  interpolate_offset: 0.0
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vitl14_pretrain.yaml b/dinov2/configs/eval/vitl14_pretrain.yaml
deleted file mode 100644
index 7a984548bd034f762d455419d7193917fa462dd8..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitl14_pretrain.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-student:
-  arch: vit_large
-  patch_size: 14
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/dinov2/configs/eval/vitl14_reg4_pretrain.yaml
deleted file mode 100644
index 0e2bc4e7b24b1a64d0369a24927996d0f184e283..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vitl14_reg4_pretrain.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-student:
-  arch: vit_large
-  patch_size: 14
-  num_register_tokens: 4
-  interpolate_antialias: true
-  interpolate_offset: 0.0
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vits14_pretrain.yaml b/dinov2/configs/eval/vits14_pretrain.yaml
deleted file mode 100644
index afbdb4ba14f1c97130a25b579360f4d817cda495..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vits14_pretrain.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-student:
-  arch: vit_small
-  patch_size: 14
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/dinov2/configs/eval/vits14_reg4_pretrain.yaml
deleted file mode 100644
index d25fd638389bfba9220792302dc9dbf5d9a2406a..0000000000000000000000000000000000000000
--- a/dinov2/configs/eval/vits14_reg4_pretrain.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-student:
-  arch: vit_small
-  patch_size: 14
-  num_register_tokens: 4
-  interpolate_antialias: true
-  interpolate_offset: 0.0
-crops:
-  global_crops_size: 518 # this is to set up the position embeddings properly
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml
deleted file mode 100644
index cdbea43d8e497baec8ab5172b1f83d037f29a4d9..0000000000000000000000000000000000000000
--- a/dinov2/configs/ssl_default_config.yaml
+++ /dev/null
@@ -1,123 +0,0 @@
-MODEL:
-  WEIGHTS: ''
-compute_precision:
-  grad_scaler: true
-  teacher:
-    backbone:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp16
-        buffer_dtype: fp32
-    dino_head:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp16
-        buffer_dtype: fp32
-    ibot_head:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp16
-        buffer_dtype: fp32
-  student:
-    backbone:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp16
-        buffer_dtype: fp32
-    dino_head:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp32
-        buffer_dtype: fp32
-    ibot_head:
-      sharding_strategy: SHARD_GRAD_OP
-      mixed_precision:
-        param_dtype: fp16
-        reduce_dtype: fp32
-        buffer_dtype: fp32
-dino:
-  loss_weight: 1.0
-  head_n_prototypes: 65536
-  head_bottleneck_dim: 256
-  head_nlayers: 3
-  head_hidden_dim: 2048
-  koleo_loss_weight: 0.1
-ibot:
-  loss_weight: 1.0
-  mask_sample_probability: 0.5
-  mask_ratio_min_max:
-  - 0.1
-  - 0.5
-  separate_head: false
-  head_n_prototypes: 65536
-  head_bottleneck_dim: 256
-  head_nlayers: 3
-  head_hidden_dim: 2048
-train:
-  batch_size_per_gpu: 64
-  dataset_path: ImageNet:split=TRAIN
-  output_dir: .
-  saveckp_freq: 20
-  seed: 0
-  num_workers: 10
-  OFFICIAL_EPOCH_LENGTH: 1250
-  cache_dataset: true
-  centering: "centering" # or "sinkhorn_knopp"
-  cell_augmentation: false
-student:
-  arch: vit_large
-  patch_size: 16
-  drop_path_rate: 0.3
-  layerscale: 1.0e-05
-  drop_path_uniform: true
-  pretrained_weights: ''
-  ffn_layer: "mlp"
-  block_chunks: 0
-  qkv_bias: true
-  proj_bias: true
-  ffn_bias: true
-  num_register_tokens: 0
-  interpolate_antialias: false
-  interpolate_offset: 0.1
-  in_chans: 3
-  channel_adaptive: false
-teacher:
-  momentum_teacher: 0.992
-  final_momentum_teacher: 1
-  warmup_teacher_temp: 0.04
-  teacher_temp: 0.07
-  warmup_teacher_temp_epochs: 30
-  in_chans: 3
-  channel_adaptive: false
-optim:
-  epochs: 100
-  weight_decay: 0.04
-  weight_decay_end: 0.4
-  base_lr: 0.004 # learning rate for a batch size of 1024
-  lr: 0. # will be set after applying scaling rule
-  warmup_epochs: 10
-  min_lr: 1.0e-06
-  clip_grad: 3.0
-  freeze_last_layer_epochs: 1
-  scaling_rule: sqrt_wrt_1024
-  patch_embed_lr_mult: 0.2
-  layerwise_decay: 0.9
-  adamw_beta1: 0.9
-  adamw_beta2: 0.999
-crops:
-  global_crops_scale:
-  - 0.32
-  - 1.0
-  local_crops_number: 8
-  local_crops_scale:
-  - 0.05
-  - 0.32
-  global_crops_size: 224
-  local_crops_size: 96
-evaluation:
-  eval_period_iterations: 12500
diff --git a/dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml b/dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml
deleted file mode 100644
index 4520df31547be6078237811b3e2b3a29c40db899..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-train:
-  batch_size_per_gpu: 16
-  OFFICIAL_EPOCH_LENGTH: 450
-  cell_augmentation: true
-  channel_adaptive: true
-student:
-  arch: vit_large
-  patch_size: 16
-  in_chans: 1
-  drop_path_rate: 0.1
-  block_chunks: 4
-teacher:
-  momentum_teacher: 0.996
-  warmup_teacher_temp_epochs: 20
-  in_chans: 1
-crops:
-  global_crops_scale:
-  - 0.4
-  - 1.0
-  local_crops_number: 8
-  local_crops_scale:
-  - 0.005
-  - 0.4
-  global_crops_size: 224
-  local_crops_size: 96
-optim:
-  weight_decay_end: 0.2
-  base_lr: 5.0e-4
-  warmup_epochs: 20
-  epochs: 400
-
\ No newline at end of file
diff --git a/dinov2/configs/train/cell_dino/vitl16_hpafov.yaml b/dinov2/configs/train/cell_dino/vitl16_hpafov.yaml
deleted file mode 100644
index 59496f93dfdf1cb022976e99f0a1d9a52df54a0a..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/cell_dino/vitl16_hpafov.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-train:
-  batch_size_per_gpu: 16
-  OFFICIAL_EPOCH_LENGTH: 450
-  cell_augmentation: true
-student:
-  arch: vit_large
-  patch_size: 16
-  in_chans: 4
-  drop_path_rate: 0.1
-  block_chunks: 4
-teacher:
-  momentum_teacher: 0.996
-  warmup_teacher_temp_epochs: 20
-  in_chans: 4
-optim:
-  weight_decay_end: 0.2
-  base_lr: 5.0e-4
-  warmup_epochs: 20
-crops:
-  global_crops_scale:
-  - 0.4
-  - 1.0
-  local_crops_number: 8
-  local_crops_scale:
-  - 0.005
-  - 0.4
-  global_crops_size: 224
-  local_crops_size: 96
-evaluation:
-  eval_period_iterations: 9000
-
-
\ No newline at end of file
diff --git a/dinov2/configs/train/cell_dino/vitl16_hpaone.yaml b/dinov2/configs/train/cell_dino/vitl16_hpaone.yaml
deleted file mode 100644
index c6f76b1c2c9d0f02b1e377b03aba7043d03e5cee..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/cell_dino/vitl16_hpaone.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-train:
-  batch_size_per_gpu: 16
-  OFFICIAL_EPOCH_LENGTH: 1756
-  cell_augmentation: true
-student:
-  arch: vit_large
-  patch_size: 16
-  in_chans: 4
-  drop_path_rate: 0.1
-  block_chunks: 4
-teacher:
-  momentum_teacher: 0.996
-  warmup_teacher_temp_epochs: 20
-  in_chans: 4
-optim:
-  weight_decay_end: 0.2
-  base_lr: 5.0e-4
-  warmup_epochs: 20
-crops:
-  global_crops_scale:
-  - 0.4
-  - 1.0
-  local_crops_number: 8
-  local_crops_scale:
-  - 0.005
-  - 0.4
-  global_crops_size: 224
-  local_crops_size: 96
-evaluation:
-  eval_period_iterations: 9000
\ No newline at end of file
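In ssl_default_config.yaml above, optim.lr is deliberately left at 0 and derived from base_lr through the sqrt_wrt_1024 scaling rule at startup. A sketch of that rule as conventionally applied in DINOv2-style training (the function name here is illustrative, not from the codebase):

    import math

    def apply_sqrt_wrt_1024(base_lr: float, batch_size_per_gpu: int, world_size: int) -> float:
        # base_lr is defined for a global batch of 1024 (see the YAML comment);
        # scale it by the square root of the actual global batch size.
        global_batch_size = batch_size_per_gpu * world_size
        return base_lr * math.sqrt(global_batch_size / 1024.0)

    # e.g. 64 images/GPU on 16 GPUs is exactly a 1024 global batch,
    # so the configured base_lr of 0.004 is used unchanged.
    assert apply_sqrt_wrt_1024(0.004, 64, 16) == 0.004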
diff --git a/dinov2/configs/train/vitg14.yaml b/dinov2/configs/train/vitg14.yaml
deleted file mode 100644
index d05cf0d59e07ac6e4a2b0f9bdcb6131d7c508962..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/vitg14.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-dino:
-  head_n_prototypes: 131072
-  head_bottleneck_dim: 384
-ibot:
-  separate_head: true
-  head_n_prototypes: 131072
-train:
-  batch_size_per_gpu: 12
-  dataset_path: ImageNet22k
-  centering: sinkhorn_knopp
-student:
-  arch: vit_giant2
-  patch_size: 14
-  drop_path_rate: 0.4
-  ffn_layer: swiglufused
-  block_chunks: 4
-teacher:
-  momentum_teacher: 0.994
-optim:
-  epochs: 500
-  weight_decay_end: 0.2
-  base_lr: 2.0e-04 # learning rate for a batch size of 1024
-  warmup_epochs: 80
-  layerwise_decay: 1.0
-crops:
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/train/vitl14.yaml b/dinov2/configs/train/vitl14.yaml
deleted file mode 100644
index d9b491dcc6a522c71328fc2933dd0501123c8f6b..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/vitl14.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-dino:
-  head_n_prototypes: 131072
-  head_bottleneck_dim: 384
-ibot:
-  separate_head: true
-  head_n_prototypes: 131072
-train:
-  batch_size_per_gpu: 32
-  dataset_path: ImageNet22k
-  centering: sinkhorn_knopp
-student:
-  arch: vit_large
-  patch_size: 14
-  drop_path_rate: 0.4
-  ffn_layer: swiglufused
-  block_chunks: 4
-teacher:
-  momentum_teacher: 0.994
-optim:
-  epochs: 500
-  weight_decay_end: 0.2
-  base_lr: 2.0e-04 # learning rate for a batch size of 1024
-  warmup_epochs: 80
-  layerwise_decay: 1.0
-crops:
-  local_crops_size: 98
\ No newline at end of file
diff --git a/dinov2/configs/train/vitl16_short.yaml b/dinov2/configs/train/vitl16_short.yaml
deleted file mode 100644
index 3e7e72864c92175a1354142ac1d64da8070d1e5e..0000000000000000000000000000000000000000
--- a/dinov2/configs/train/vitl16_short.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# this corresponds to the default config
-train:
-  dataset_path: ImageNet:split=TRAIN
-  batch_size_per_gpu: 64
-student:
-  block_chunks: 4
diff --git a/dinov2/data/__init__.py b/dinov2/data/__init__.py
deleted file mode 100644
index ac440218caed7e1d398ae518361246090b91614b..0000000000000000000000000000000000000000
--- a/dinov2/data/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from .adapters import DatasetWithEnumeratedTargets
-from .loaders import make_data_loader, make_dataset, SamplerType
-from .collate import collate_data_and_cast
-from .masking import MaskingGenerator
-from .augmentations import DataAugmentationDINO
-from .cell_dino.augmentations import CellAugmentationDINO
-from .accumulators import NoOpAccumulator, ResultsAccumulator
diff --git a/dinov2/data/accumulators.py b/dinov2/data/accumulators.py
deleted file mode 100644
index a63bcf0a44605587423fa34a6ab51153040bbb40..0000000000000000000000000000000000000000
--- a/dinov2/data/accumulators.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from collections import defaultdict
-from typing import Dict, List, Optional, Any
-
-import torch
-from torch import Tensor
-from torch.nn import functional as F
-
-import torch.distributed as dist
-from dinov2.distributed import get_global_size
-
-
-def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
-    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
-    dist.all_gather(gathered_result, result, group)
-    return gathered_result
-
-
-def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
-    """
-    Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py
-    Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
-
-    Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
-    tensors are padded, gathered and then trimmed to secure equal workload for all processes.
-
-    Args:
-        result: the value to sync
-        group: the process group to gather results from. Defaults to all processes (world)
-
-    Return:
-        list with size equal to the process group where element i corresponds to result tensor from process i
-    """
-    # convert tensors to contiguous format
-    result = result.contiguous()
-
-    world_size = get_global_size()
-    dist.barrier(group=group)
-
-    # if the tensor is scalar, things are easy
-    if result.ndim == 0:
-        return _simple_gather_all_tensors(result, group, world_size)
-
-    # 1. Gather sizes of all tensors
-    local_size = torch.tensor(result.shape, device=result.device)
-    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
-    dist.all_gather(local_sizes, local_size, group=group)
-    max_size = torch.stack(local_sizes).max(dim=0).values
-    all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
-
-    # 2. If shapes are all the same, then do a simple gather:
-    if all_sizes_equal:
-        return _simple_gather_all_tensors(result, group, world_size)
-
-    # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
-    pad_dims = []
-    pad_by = (max_size - local_size).detach().cpu()
-    for val in reversed(pad_by):
-        pad_dims.append(0)
-        pad_dims.append(val.item())
-    result_padded = F.pad(result, pad_dims)
-    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
-    dist.all_gather(gathered_result, result_padded, group)
-    for idx, item_size in enumerate(local_sizes):
-        slice_param = [slice(dim_size) for dim_size in item_size]
-        gathered_result[idx] = gathered_result[idx][slice_param]
-    return gathered_result
-
-
-def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor:
-    local_cat = torch.cat(tensor_list)
-    return torch.cat(gather_all_tensors(local_cat))
-
-
-class Accumulator:
-    def __init__(self) -> None:
-        pass
-
-    def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
-        raise NotImplementedError
-
-    def accumulate(self) -> Optional[Dict[str, Tensor]]:
-        raise NotImplementedError
-
-
-class NoOpAccumulator(Accumulator):
-    def __init__(self) -> None:
-        pass
-
-    def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
-        pass
-
-    def accumulate(self) -> None:
-        return None
-
-
-class ResultsAccumulator(Accumulator):
-    """
-    Accumulate predictions and targets across processes
-    """
-
-    def __init__(self) -> None:
-        self._local_values: Dict[str, List[Tensor]] = defaultdict(list)
-        self._gathered_values: Dict[str, Tensor] = {}
-        self._gathered = False
-
-    def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
-        assert len(preds) == len(target) == len(index)
-        assert not self._gathered, "Tensors have already been gathered in this helper"
-        self._local_values["preds"].append(preds)
-        self._local_values["target"].append(target)
-        self._local_values["index"].append(index)
-        self._gathered = False
-
-    def _gather_tensors(self):
-        for k, tensor_list in self._local_values.items():
-            self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list)
-        self._gathered = True
-
-    def accumulate(self) -> Dict[str, Tensor]:
-        if not self._gathered:
-            self._gather_tensors()
-        preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]]
-        assert len(preds) == len(target) == len(index) and index.min() == 0
-        preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device)
-        preds_ordered[index] = preds
-        target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device)
-        target_ordered[index] = target
-        return {"preds": preds_ordered, "target": target_ordered}
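The pad-gather-trim branch of gather_all_tensors above is easiest to follow on concrete shapes. A single-process sketch of the same padding arithmetic, with the distributed calls left out:

    import torch
    import torch.nn.functional as F

    # Step 3 of gather_all_tensors: pad this rank's tensor up to the
    # per-dimension maximum across ranks, then slice back after the gather.
    local = torch.ones(3, 5)                      # this rank's shape
    max_size = torch.tensor([4, 8])               # elementwise max over all ranks
    pad_dims = []
    for val in reversed((max_size - torch.tensor(local.shape)).tolist()):
        pad_dims.extend([0, val])                 # F.pad pads the last dim first
    padded = F.pad(local, pad_dims)               # shape (4, 8)
    restored = padded[tuple(slice(d) for d in local.shape)]
    assert restored.shape == (3, 5) and torch.equal(restored, local)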
diff --git a/dinov2/data/adapters.py b/dinov2/data/adapters.py
deleted file mode 100644
index a5efe965f11c8c536ee31f90dc9b1f96be7c556f..0000000000000000000000000000000000000000
--- a/dinov2/data/adapters.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from typing import Any, Tuple, Optional
-
-from torch.utils.data import Dataset
-
-
-class DatasetWithEnumeratedTargets(Dataset):
-    """
-    If pad_dataset is set, pads based on torch's DistributedSampler implementation, which
-    with drop_last=False pads the last batch to be a multiple of the world size.
-    https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
-    """
-
-    def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
-        self._dataset = dataset
-        self._size = len(self._dataset)
-        self._padded_size = self._size
-        self._pad_dataset = pad_dataset
-        if self._pad_dataset:
-            assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
-            self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)
-
-    def get_image_relpath(self, index: int) -> str:
-        assert self._pad_dataset or index < self._size
-        return self._dataset.get_image_relpath(index % self._size)
-
-    def get_image_data(self, index: int) -> bytes:
-        assert self._pad_dataset or index < self._size
-        return self._dataset.get_image_data(index % self._size)
-
-    def get_target(self, index: int) -> Tuple[Any, int]:
-        target = self._dataset.get_target(index % self._size)
-        if index >= self._size:
-            assert self._pad_dataset
-            return (-1, target)
-        return (index, target)
-
-    def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
-        image, target = self._dataset[index % self._size]
-        if index >= self._size:
-            assert self._pad_dataset
-            return image, (-1, target)
-        target = index if target is None else target
-        return image, (index, target)
-
-    def __len__(self) -> int:
-        return self._padded_size
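DatasetWithEnumeratedTargets above rounds the dataset length up to a multiple of the replica count and tags wrap-around samples with index -1 so they can be dropped after gathering. The bookkeeping in isolation, for a hypothetical 10-sample dataset on 4 replicas:

    # Illustrative only: the padded-size formula from the deleted adapter.
    num_samples, num_replicas = 10, 4
    padded_size = num_replicas * ((num_samples + num_replicas - 1) // num_replicas)
    assert padded_size == 12  # two wrap-around duplicates

    # Padded indices reuse samples modulo the true size but carry index -1,
    # which is how downstream accumulators identify and discard them.
    enumerated = [(i if i < num_samples else -1, i % num_samples) for i in range(padded_size)]
    assert enumerated[-1] == (-1, 1)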
diff --git a/dinov2/data/augmentations.py b/dinov2/data/augmentations.py
deleted file mode 100644
index 05b1eaa942c14f75b88d9e14732e141e8909b0a1..0000000000000000000000000000000000000000
--- a/dinov2/data/augmentations.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import logging
-
-from torchvision import transforms
-
-from .transforms import (
-    GaussianBlur,
-    make_normalize_transform,
-)
-
-
-logger = logging.getLogger("dinov2")
-
-
-class DataAugmentationDINO(object):
-    def __init__(
-        self,
-        global_crops_scale,
-        local_crops_scale,
-        local_crops_number,
-        global_crops_size=224,
-        local_crops_size=96,
-    ):
-        self.global_crops_scale = global_crops_scale
-        self.local_crops_scale = local_crops_scale
-        self.local_crops_number = local_crops_number
-        self.global_crops_size = global_crops_size
-        self.local_crops_size = local_crops_size
-
-        logger.info("###################################")
-        logger.info("Using data augmentation parameters:")
-        logger.info(f"global_crops_scale: {global_crops_scale}")
-        logger.info(f"local_crops_scale: {local_crops_scale}")
-        logger.info(f"local_crops_number: {local_crops_number}")
-        logger.info(f"global_crops_size: {global_crops_size}")
-        logger.info(f"local_crops_size: {local_crops_size}")
-        logger.info("###################################")
-
-        # random resized crop and flip
-        self.geometric_augmentation_global = transforms.Compose(
-            [
-                transforms.RandomResizedCrop(
-                    global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
-                ),
-                transforms.RandomHorizontalFlip(p=0.5),
-            ]
-        )
-
-        self.geometric_augmentation_local = transforms.Compose(
-            [
-                transforms.RandomResizedCrop(
-                    local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
-                ),
-                transforms.RandomHorizontalFlip(p=0.5),
-            ]
-        )
-
-        # color distortions / blurring
-        color_jittering = transforms.Compose(
-            [
-                transforms.RandomApply(
-                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
-                    p=0.8,
-                ),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        )
-
-        global_transfo1_extra = GaussianBlur(p=1.0)
-
-        global_transfo2_extra = transforms.Compose(
-            [
-                GaussianBlur(p=0.1),
-                transforms.RandomSolarize(threshold=128, p=0.2),
-            ]
-        )
-
-        local_transfo_extra = GaussianBlur(p=0.5)
-
-        # normalization
-        self.normalize = transforms.Compose(
-            [
-                transforms.ToTensor(),
-                make_normalize_transform(),
-            ]
-        )
-
-        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
-        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
-        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
-
-    def __call__(self, image):
-        output = {}
-
-        # global crops:
-        im1_base = self.geometric_augmentation_global(image)
-        global_crop_1 = self.global_transfo1(im1_base)
-
-        im2_base = self.geometric_augmentation_global(image)
-        global_crop_2 = self.global_transfo2(im2_base)
-
-        output["global_crops"] = [global_crop_1, global_crop_2]
-
-        # global crops for teacher:
-        output["global_crops_teacher"] = [global_crop_1, global_crop_2]
-
-        # local crops:
-        local_crops = [
-            self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
-        ]
-        output["local_crops"] = local_crops
-        output["offsets"] = ()
-
-        return output
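DataAugmentationDINO.__call__ above returns a dict of crop lists rather than a single tensor. A sketch of what a consumer sees, assuming a PIL RGB input and the deleted module still importable:

    from PIL import Image
    from dinov2.data import DataAugmentationDINO

    aug = DataAugmentationDINO(
        global_crops_scale=(0.32, 1.0),
        local_crops_scale=(0.05, 0.32),
        local_crops_number=8,
    )
    out = aug(Image.new("RGB", (512, 512)))
    assert len(out["global_crops"]) == 2   # two differently blurred/solarized views
    assert len(out["local_crops"]) == 8    # each 3 x 96 x 96 after ToTensor
    # the teacher sees the exact same two global views (same tensor objects)
    assert out["global_crops_teacher"][0] is out["global_crops"][0]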
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the CC-by-NC licence,
-# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
-
-import logging
-import torchvision
-from torchvision import transforms
-
-from .transforms import (
-    RandomContrastProteinChannel,
-    RandomRemoveChannelExceptProtein,
-    RandomBrightness,
-    RandomContrast,
-    Div255,
-    SelfNormalizeNoDiv,
-)
-
-logger = logging.getLogger("dinov2")
-
-
-class CellAugmentationDINO(object):
-    def __init__(
-        self,
-        global_crops_scale,
-        local_crops_scale,
-        local_crops_number,
-        global_crops_size=224,
-        local_crops_size=96,
-    ):
-        self.global_crops_scale = global_crops_scale
-        self.local_crops_scale = local_crops_scale
-        self.local_crops_number = local_crops_number
-        self.global_crops_size = global_crops_size
-        self.local_crops_size = local_crops_size
-
-        logger.info("###################################")
-        logger.info("Using data augmentation parameters:")
-        logger.info(f"global_crops_scale: {global_crops_scale}")
-        logger.info(f"local_crops_scale: {local_crops_scale}")
-        logger.info(f"local_crops_number: {local_crops_number}")
-        logger.info(f"global_crops_size: {global_crops_size}")
-        logger.info(f"local_crops_size: {local_crops_size}")
-        logger.info("###################################")
-
-        additional_transforms_list = [
-            torchvision.transforms.RandomHorizontalFlip(),
-            torchvision.transforms.RandomVerticalFlip(),
-            RandomBrightness(),
-            RandomContrast(),
-            SelfNormalizeNoDiv(),
-        ]
-
-        first_transforms_list = [
-            Div255(),
-            RandomRemoveChannelExceptProtein(),
-            RandomContrastProteinChannel(),
-        ]
-
-        global_transforms_list = first_transforms_list.copy()
-        global_transforms_list.append(
-            torchvision.transforms.RandomResizedCrop(size=global_crops_size, scale=global_crops_scale)
-        )
-        global_transforms_list = global_transforms_list + additional_transforms_list
-
-        local_transforms_list = first_transforms_list
-        local_transforms_list.append(
-            torchvision.transforms.RandomResizedCrop(size=local_crops_size, scale=local_crops_scale)
-        )
-        local_transforms_list = local_transforms_list + additional_transforms_list
-
-        self.global_transform = transforms.Compose(global_transforms_list)
-        self.local_transform = transforms.Compose(local_transforms_list)
-
-    def __call__(self, image):
-        output = {}
-
-        global_crop1 = self.global_transform(image)
-        global_crop2 = self.global_transform(image)
-
-        output["global_crops"] = [global_crop1, global_crop2]
-
-        local_crops = []
-        for _ in range(self.local_crops_number):
-            local_crops.append(self.local_transform(image))
-
-        output["local_crops"] = local_crops
-        output["global_crops_teacher"] = [global_crop1, global_crop2]
-        output["offsets"] = ()
-
-        return output
diff --git a/dinov2/data/cell_dino/transforms.py b/dinov2/data/cell_dino/transforms.py
deleted file mode 100644
index 0f4752ce59606818b8fd1776d6207738499a98f5..0000000000000000000000000000000000000000
--- a/dinov2/data/cell_dino/transforms.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the CC-by-NC licence,
-# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
-
-import torch
-from torchvision import transforms
-import numpy as np
-from enum import Enum
-
-
-class NormalizationType(Enum):
-    SELF_NORM_AUG_DECODER = "self_norm_aug_decoder"
-    SELF_NORM_CENTER_CROP = "self_norm_center_crop"
-
-
-class Div255(torch.nn.Module):
-    def forward(self, x):
-        x = x / 255
-        return x
-
-
-class SelfNormalizeNoDiv(torch.nn.Module):
-    def forward(self, x):
-        m = x.mean((-2, -1), keepdim=True)
-        s = x.std((-2, -1), unbiased=False, keepdim=True)
-        x -= m
-        x /= s + 1e-7
-        return x
-
-
-class SelfNormalize(torch.nn.Module):
-    def forward(self, x):
-        x = x / 255
-        m = x.mean((-2, -1), keepdim=True)
-        s = x.std((-2, -1), unbiased=False, keepdim=True)
-        x -= m
-        x /= s + 1e-7
-        return x
-
-
-class RandomContrastProteinChannel(torch.nn.Module):
-    """
-    Random contrast rescaling of the protein channel only.
-    RescaleProtein function in Dino4cell codebase.
-    """
-
-    def __init__(self, p=0.2):
-        super().__init__()
-        self.p = p
-
-    def forward(self, img):
-        if img.max() == 0:
-            return img
-        if len(img) == 1:
-            return img
-        if np.random.rand() <= self.p:
-            random_factor = (np.random.rand() * 2) / img.max()  # scaling
-            img[1] = img[1] * random_factor
-            return img
-        else:
-            return img
-
-
-class RandomRemoveChannelExceptProtein(torch.nn.Module):
-    """
-    Drop a channel at random, except channel 1, which corresponds to the protein channel in HPA datasets.
-    """
-
-    def __init__(self, p=0.2):
-        super().__init__()
-        self.p = p
-
-    def forward(self, img):
-        img_size = np.array(img).shape
-        if img_size[0] < 4:
-            return img
-        if np.random.rand() <= self.p:
-            channel_to_blacken = np.random.choice(np.array([0, 2, 3]))
-            img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
-            return img
-        else:
-            return img
-
-
-class RandomRemoveChannel(torch.nn.Module):
-    """
-    Drop a channel at random.
-    """
-
-    def __init__(self, p=0.2):
-        super().__init__()
-        self.p = p
-
-    def forward(self, img):
-        img_size = np.array(img).shape
-        num_channels = img_size[0]
-        if num_channels < 4:
-            return img
-        if np.random.rand() <= self.p:
-            channel_to_blacken = np.random.choice(np.array(list(range(num_channels))))
-            img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
-            return img
-        else:
-            return img
-
-
-class RandomContrast(torch.nn.Module):
-    def __init__(self, p=0.2):
-        super().__init__()
-        self.p = p
-
-    def forward(self, img):
-        if img.max() == 0:
-            return img
-        n_channels = img.shape[0]
-        for ind in range(n_channels):
-            factor = max(np.random.normal(1, self.p), 0.5)
-            img[ind] = transforms.functional.adjust_contrast(img[ind][None, ...], factor)
-        return img
-
-
-class RandomBrightness(torch.nn.Module):
-    def __init__(self, p=0.2):
-        super().__init__()
-        self.p = p
-
-    def forward(self, img):
-        if img.max() == 0:
-            return img
-        n_channels = img.shape[0]
-        for ind in range(n_channels):
-            factor = max(np.random.normal(1, self.p), 0.5)
-            img[ind] = transforms.functional.adjust_brightness(img[ind], factor)
-        return img
-
-
-def make_classification_eval_cell_transform(
-    *,
-    resize_size: int = 0,
-    interpolation=transforms.InterpolationMode.BICUBIC,
-    crop_size: int = 384,
-    normalization_type: Enum = NormalizationType.SELF_NORM_CENTER_CROP,
-) -> transforms.Compose:
-
-    from .transforms import (
-        Div255,
-        SelfNormalizeNoDiv,
-    )
-
-    transforms_list = [Div255()]
-    if resize_size > 0:
-        transforms_list.append(transforms.Resize(resize_size, interpolation=interpolation))
-
-    if normalization_type == NormalizationType.SELF_NORM_AUG_DECODER:
-        transforms_list.extend(
-            [
-                transforms.RandomCrop(size=crop_size, pad_if_needed=True),
-                transforms.RandomHorizontalFlip(),
-                transforms.RandomVerticalFlip(),
-            ]
-        )
-    elif normalization_type == NormalizationType.SELF_NORM_CENTER_CROP:
-        transforms_list.append(transforms.CenterCrop(size=crop_size))
-    else:
-        raise ValueError(f"{normalization_type}: unknown NormalizationType")
-    transforms_list.append(SelfNormalizeNoDiv())
-
-    return transforms.Compose(transforms_list)
diff --git a/dinov2/data/collate.py b/dinov2/data/collate.py
deleted file mode 100644
index b3e32f357a76e6f32162cee14cb6ae1665a4827a..0000000000000000000000000000000000000000
--- a/dinov2/data/collate.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import torch
-import random
-
-
-def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
-    # dtype = torch.half  # TODO: Remove
-
-    n_global_crops = len(samples_list[0][0]["global_crops"])
-    n_local_crops = len(samples_list[0][0]["local_crops"])
-
-    collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
-
-    collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
-
-    B = len(collated_global_crops)
-    N = n_tokens
-    n_samples_masked = int(B * mask_probability)
-    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
-    upperbound = 0
-    masks_list = []
-    for i in range(0, n_samples_masked):
-        prob_min = probs[i]
-        prob_max = probs[i + 1]
-        masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
-        upperbound += int(N * prob_max)
-    for i in range(n_samples_masked, B):
-        masks_list.append(torch.BoolTensor(mask_generator(0)))
-
-    random.shuffle(masks_list)
-
-    collated_masks = torch.stack(masks_list).flatten(1)
-    mask_indices_list = collated_masks.flatten().nonzero().flatten()
-
-    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
-
-    return {
-        "collated_global_crops": collated_global_crops.to(dtype),
-        "collated_local_crops": collated_local_crops.to(dtype),
-        "collated_masks": collated_masks,
-        "mask_indices_list": mask_indices_list,
-        "masks_weight": masks_weight,
-        "upperbound": upperbound,
-        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
-    }
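collate_data_and_cast above draws a per-sample masking ratio from a ramp between the two mask_ratio_tuple endpoints and tracks upperbound as a worst-case masked-token count, which lets the iBOT head pre-allocate fixed-size buffers. The budgeting logic in isolation:

    import torch

    # Sketch of the mask budget computed by the deleted collate function.
    B, n_tokens = 8, 196                 # e.g. a 14 x 14 patch grid
    mask_ratio_tuple = (0.1, 0.5)
    mask_probability = 0.5

    n_samples_masked = int(B * mask_probability)           # 4 of 8 crops masked
    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
    # each masked crop i draws its ratio from [probs[i], probs[i + 1]];
    # summing the interval maxima bounds the total number of masked patches
    upperbound = sum(int(n_tokens * probs[i + 1]) for i in range(n_samples_masked))
    assert upperbound == sum(int(196 * p) for p in (0.2, 0.3, 0.4, 0.5))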
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from .image_net import ImageNet
-from .image_net_22k import ImageNet22k
-from .cell_dino.hpaone import HPAone
-from .cell_dino.hpafov import HPAFoV
-from .cell_dino.chammi_cp import CHAMMI_CP
-from .cell_dino.chammi_hpa import CHAMMI_HPA
-from .cell_dino.chammi_wtc import CHAMMI_WTC
diff --git a/dinov2/data/datasets/cell_dino/chammi_cp.py b/dinov2/data/datasets/cell_dino/chammi_cp.py
deleted file mode 100644
index 911d6b4ccf579ada69a111c3d5c5057652b3f05c..0000000000000000000000000000000000000000
--- a/dinov2/data/datasets/cell_dino/chammi_cp.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the CC-by-NC licence,
-# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
-
-import csv
-from enum import Enum
-import logging
-import os
-from typing import Any, Callable, Optional, Union
-
-import numpy as np
-
-from ..extended import ExtendedVisionDataset
-from ..decoders import DecoderType
-
-logger = logging.getLogger("dinov2")
-
-
-METADATA_FILE = "morphem70k_v2.csv"
-
-CLASS_LABELS = {
-    "BRD-A29260609": 0,
-    "BRD-K04185004": 1,
-    "BRD-K21680192": 2,
-    "DMSO": 3,
-    "BRD-K11129031": 4,  # labels only seen in TASK_FOUR
-    "BRD-K62310379": 5,
-    "BRD-K77947974": 6,
-}
-
-
-class _Split(Enum):
-    TRAIN = "Train"
-    TASK_ONE = "Task_one"
-    TASK_TWO = "Task_two"
-    TASK_THREE = "Task_three"
-    TASK_FOUR = "Task_four"
-
-
-def _load_file_names_and_targets(
-    root: str,
-    split: _Split,
-):
-    image_paths = []
-    labels = []
-    with open(os.path.join(root, METADATA_FILE)) as metadata:
-        metadata_reader = csv.DictReader(metadata)
-        for row in metadata_reader:
-            row_dataset = row["file_path"].split("/")[0]
-
-            if row["train_test_split"].upper() == split and row_dataset == "CP":
-                image_paths.append(row["file_path"])
-                labels.append(CLASS_LABELS[row["label"]])
-
-    return image_paths, labels
-
-
-class CHAMMI_CP(ExtendedVisionDataset):
-    """
-    Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset,
-    following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
-    Github code: https://github.com/chaudatascience/channel_adaptive_models
-    """
-
-    Split = Union[_Split]
-
-    def __init__(
-        self,
-        *,
-        split: "CHAMMI_CP.Split",
-        root: str,
-        transforms: Optional[Callable] = None,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            root,
-            transforms,
-            transform,
-            target_transform,
-            image_decoder_type=image_decoder_type,
-            **kwargs,
-        )
-        self.split = split
-        self.root = root
-        self.num_additional_labels_loo_eval = 3
-        self._image_paths, self._targets = _load_file_names_and_targets(
-            root,
-            split,
-        )
-
-    def get_image_relpath(self, index: int) -> str:
-        return self._image_paths[index]
-
-    def get_image_data(self, index: int) -> bytes:
-        image_relpath = self.get_image_relpath(index)
-        image_full_path = os.path.join(self.root, image_relpath)
-        with open(image_full_path, mode="rb") as f:
-            image_data = f.read()
-        return image_data
-
-    def get_target(self, index: int) -> Any:
-        return self._targets[index]
-
-    def get_targets(self) -> np.ndarray:
-        return np.array(self._targets)
-
-    def __len__(self) -> int:
-        return len(self._image_paths)
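Note that _load_file_names_and_targets above compares row["train_test_split"].upper() against split directly, so the CHAMMI datasets only select rows when split is passed as an uppercase string, not as a _Split member. A two-line illustration of that comparison:

    # The CSV stores values like "Train"; upper() yields "TRAIN", which only
    # equals the split argument if the caller passed the plain string.
    assert "Train".upper() == "TRAIN"
    # A str never compares equal to an Enum member, so passing _Split.TRAIN
    # itself would silently select zero rows.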
e7b31ef9a075c07d23c99e6981f4f3dc240639a3..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/cell_dino/chammi_hpa.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. - -import csv -from enum import Enum -import logging -import os -from typing import Any, Callable, Optional, Union - -import numpy as np - -from ..extended import ExtendedVisionDataset -from ..decoders import DecoderType - -logger = logging.getLogger("dinov2") - - -METADATA_FILE = "morphem70k_v2.csv" - -CLASS_LABELS = { - "golgi apparatus": 0, - "microtubules": 1, - "mitochondria": 2, - "nuclear speckles": 3, - "cytosol": 4, # labels only seen in TASK_THREE - "endoplasmic reticulum": 5, - "nucleoplasm": 6, -} - - -class _Split(Enum): - TRAIN = "Train" - TASK_ONE = "Task_one" - TASK_TWO = "Task_two" - TASK_THREE = "Task_three" - - -def _load_file_names_and_targets( - root: str, - split: _Split, -): - image_paths = [] - labels = [] - with open(os.path.join(root, METADATA_FILE)) as metadata: - metadata_reader = csv.DictReader(metadata) - for row in metadata_reader: - row_dataset = row["file_path"].split("/")[0] - if row["train_test_split"].upper() == split and row_dataset == "HPA": - image_paths.append(row["file_path"]) - labels.append(CLASS_LABELS[row["label"]]) - - return image_paths, labels - - -class CHAMMI_HPA(ExtendedVisionDataset): - """ - Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset, - following the CHAMMI paper: https://arxiv.org/pdf/2310.19224 - Github code: https://github.com/chaudatascience/channel_adaptive_models - """ - - Split = Union[_Split] - - def __init__( - self, - *, - split: "CHAMMI_HPA.Split", - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - image_decoder_type: DecoderType = DecoderType.XChannelsDecoder, - **kwargs: Any, - ) -> None: - super().__init__( - root, - transforms, - transform, - target_transform, - image_decoder_type=image_decoder_type, - **kwargs, - ) - self.split = split - self.root = root - self.num_additional_labels_loo_eval = 3 - - self._image_paths, self._targets = _load_file_names_and_targets( - root, - split, - ) - - def get_image_relpath(self, index: int) -> str: - return self._image_paths[index] - - def get_image_data(self, index: int) -> bytes: - image_relpath = self.get_image_relpath(index) - image_full_path = os.path.join(self.root, image_relpath) - with open(image_full_path, mode="rb") as f: - image_data = f.read() - return image_data - - def get_target(self, index: int) -> Any: - return self._targets[index] - - def get_targets(self) -> np.ndarray: - return np.array(self._targets) - - def __len__(self) -> int: - return len(self._image_paths) diff --git a/dinov2/data/datasets/cell_dino/chammi_wtc.py b/dinov2/data/datasets/cell_dino/chammi_wtc.py deleted file mode 100644 index 5a4005210cbaff60eccef7a5d1cba869dac067c0..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/cell_dino/chammi_wtc.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
- -import csv -from enum import Enum -import logging -import os -from typing import Any, Callable, Optional, Union - -import numpy as np - -from ..extended import ExtendedVisionDataset -from ..decoders import DecoderType - -logger = logging.getLogger("dinov2") - - -METADATA_FILE = "morphem70k_v2.csv" - -CLASS_LABELS = { - "M0": 0, - "M1M2": 1, - "M3": 2, - "M4M5": 3, - "M6M7_complete": 4, - "M6M7_single": 5, -} - - -class _Split(Enum): - TRAIN = "Train" - TASK_ONE = "Task_one" - TASK_TWO = "Task_two" - - -def _load_file_names_and_targets( - root: str, - split: _Split, -): - image_paths = [] - labels = [] - with open(os.path.join(root, METADATA_FILE)) as metadata: - metadata_reader = csv.DictReader(metadata) - for row in metadata_reader: - row_dataset = row["file_path"].split("/")[0] - if row["train_test_split"].upper() == split and row_dataset == "Allen": - image_paths.append(row["file_path"]) - labels.append(CLASS_LABELS[row["label"]]) - - return image_paths, labels - - -class CHAMMI_WTC(ExtendedVisionDataset): - """ - Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset, - following the CHAMMI paper: https://arxiv.org/pdf/2310.19224 - Github code: https://github.com/chaudatascience/channel_adaptive_models - """ - - Split = Union[_Split] - - def __init__( - self, - *, - split: "CHAMMI_WTC.Split", - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - image_decoder_type: DecoderType = DecoderType.XChannelsTIFFDecoder, - **kwargs: Any, - ) -> None: - super().__init__( - root, - transforms, - transform, - target_transform, - image_decoder_type=image_decoder_type, - **kwargs, - ) - self.split = split - self.root = root - - self._image_paths, self._targets = _load_file_names_and_targets( - root, - split, - ) - - def get_image_relpath(self, index: int) -> str: - return self._image_paths[index] - - def get_image_data(self, index: int) -> bytes: - image_relpath = self.get_image_relpath(index) - image_full_path = os.path.join(self.root, image_relpath) - with open(image_full_path, mode="rb") as f: - image_data = f.read() - return image_data - - def get_target(self, index: int) -> Any: - return self._targets[index] - - def get_targets(self) -> np.ndarray: - return np.array(self._targets) - - def __len__(self) -> int: - return len(self._image_paths) diff --git a/dinov2/data/datasets/cell_dino/hpafov.py b/dinov2/data/datasets/cell_dino/hpafov.py deleted file mode 100644 index 263d2f332f651e7d7bf8746bb53a9956957a686e..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/cell_dino/hpafov.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
- -import csv -from enum import Enum -import logging -import os -from typing import Any, Callable, List, Optional, Tuple, Union, Dict - -import numpy as np - -from ..extended import ExtendedVisionDataset -from ..decoders import DecoderType - -logger = logging.getLogger("dinov2") - -CELL_TYPE = [ - "BJ", # 1 - "LHCN-M2", - "RH-30", - "SH-SY5Y", - "U-2 OS", # 5 - "ASC TERT1", - "HaCaT", - "A-431", - "U-251 MG", - "HEK 293", # 10 - "A549", - "RT4", - "HeLa", - "MCF7", - "PC-3", # 15 - "hTERT-RPE1", - "SK-MEL-30", - "EFO-21", - "AF22", - "HEL", # 20 - "Hep G2", - "HUVEC TERT2", - "THP-1", - "CACO-2", - "JURKAT", # 25 - "RPTEC TERT1", - "SuSa", - "REH", - "HDLM-2", - "K-562", # 30 - "hTCEpi", - "NB-4", - "HAP1", - "OE19", - "SiHa", # 35 -] - -PROTEIN_LOCALIZATION = [ # matches https://www.kaggle.com/c/human-protein-atlas-image-classification/data - "nucleoplasm", - "nuclear membrane", - "nucleoli", - "nucleoli fibrillar center", - "nuclear speckles", # 5 - "nuclear bodies", - "endoplasmic reticulum", - "golgi apparatus", - "peroxisomes", - "endosomes", # 10 - "lysosomes", - "intermediate filaments", - "actin filaments", - "focal adhesion sites", - "microtubules", # 15 - "microtubule ends", - "cytokinetic bridge", - "mitotic spindle", - "microtubule organizing center", - "centrosome", # 20 - "lipid droplets", - "plasma membrane", - "cell junctions", - "mitochondria", - "aggresome", # 25 - "cytosol", - "cytoplasmic bodies", - "rods & rings", -] - - -class _Split(Enum): - TRAIN = "train" - VAL = "val" - SSL = "ssl" - - -def get_csv_fpath(split): - """ - Path to data relative to root - """ - if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN or split == "TRAIN": - return "whole_images_512_train.csv" - elif split == _Split.VAL.value.upper() or split == _Split.VAL or split == "VAL": - return "whole_images_512_test.csv" - - -class _WildCard(Enum): - NONE = "none" - SEPARATECHANNELS = "separate_channels" # each channel from each image is treated as an independent sample, overrides chosen channel configuration - - -class _Mode(Enum): - """ - Targets: - - ALL: tuple, (one hot encoding of multilabel protein localization, categorical encoding of cell type) - - PROTEIN_LOCALIZATION: one hot encoding of multilabel protein localization - - CELL_TYPE: categorical encoding of cell type - """ - - ALL = "all" - PROTEIN_LOCALIZATION = "protein_localization" - CELL_TYPE = "cell_type" - - @property - def nb_labels(self): - if self == _Mode.CELL_TYPE: - return len(CELL_TYPE) - elif self == _Mode.PROTEIN_LOCALIZATION: - return len(PROTEIN_LOCALIZATION) - else: - return None - - -def _list_images_from_csv(img_path, csv_path): - L = [] - with open(csv_path) as filename: - reader = csv.DictReader(filename) - for row in reader: - L.append(os.path.join(img_path, row["ID"] + ".png")) - return L - - -def _load_file_names_and_labels_ssl( - root: str, -) -> Tuple[List[str], List[Any]]: - - curr_img_path = os.path.join(root, "normalized_data") - csv_train_ssl = os.path.join(root, "whole_images_names.csv") - image_paths = _list_images_from_csv(curr_img_path, csv_train_ssl) - labels = [i for i in range(len(image_paths))] - return image_paths, labels - - -def _load_file_names_and_labels( - root: str, - split: _Split, - mode: _Mode, -) -> Tuple[List[str], List[Any], np.ndarray]: - - data_path = os.path.join(root, "512_whole_images") - csv_fpath = os.path.join(root, get_csv_fpath(split)) - - image_paths = [] - labels = [] - - with open(csv_fpath) as filename: - reader = csv.DictReader(filename) - for row in reader: - - 
add_sample = True - if mode != _Mode.PROTEIN_LOCALIZATION.value.upper(): - # categorical - if row["cell_type"] in CELL_TYPE: - cell_type = CELL_TYPE.index(row["cell_type"]) - else: - cell_type = np.nan - - if mode != _Mode.CELL_TYPE.value.upper(): - # one hot encoding - prot_loc = np.zeros(len(PROTEIN_LOCALIZATION), dtype=np.int_) - for k in range(len(PROTEIN_LOCALIZATION)): - if row[PROTEIN_LOCALIZATION[k]] == "True": - prot_loc[k] = 1 - if prot_loc.max() < 0.5: - add_sample = False - - if add_sample: - if mode == _Mode.PROTEIN_LOCALIZATION.value.upper(): - labels.append(prot_loc) - elif mode == _Mode.CELL_TYPE.value.upper(): - labels.append(cell_type) - else: - labels.append({"prot_loc": prot_loc, "cell_type": cell_type}) - - candidate_path = os.path.join(data_path, row["file"].split("/")[-1]) - if os.path.exists(candidate_path): - image_paths.append(candidate_path) - else: - candidate_path = os.path.join( - data_path, row["file"].split("/")[-1].split(".")[0] + ".tiff" - ) # _blue.png") # some images on the normalized_data folder have a _blue suffix on their names - if os.path.exists(candidate_path): - image_paths.append(candidate_path) - else: - raise FileNotFoundError(f"File {candidate_path} not found.") - - return image_paths, labels - - -class HPAFoV(ExtendedVisionDataset): - Split = Union[_Split] - Mode = Union[_Mode] - WildCard = Union[_WildCard] - - def __init__( - self, - *, - split: "HPAFoV.Split" = _Split.TRAIN, - mode: "HPAFoV.Mode" = _Mode.ALL, - wildcard: "HPAFoV.WildCard" = _WildCard.NONE, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - image_decoder_type: DecoderType = DecoderType.ChannelSelectDecoder, - image_decoder_params: Dict[str, Any] = {}, - **kwargs: Any, - ) -> None: - super().__init__( - root, - transforms, - transform, - target_transform, - image_decoder_type=image_decoder_type, - image_decoder_params={ - "select_channel": True - if wildcard == _WildCard.SEPARATECHANNELS or wildcard == "SEPARATE_CHANNELS" - else False - }, - **kwargs, - ) - self.mode = mode - self.split = split - self.root = root - self.wildcard = wildcard - self.channel_adaptive = True - if split == _Split.SSL.value.upper() or split == _Split.SSL or split == "SSL": - self._image_paths, self._labels = _load_file_names_and_labels_ssl(root) - else: - self._image_paths, self._labels = _load_file_names_and_labels(root, self.split, self.mode) - - self._channels = np.repeat(np.array([[0, 1, 2, 3]]), len(self._image_paths), axis=0).tolist() - - if self.wildcard == _WildCard.SEPARATECHANNELS.value.upper(): - image_paths, labels, channels = self._image_paths, self._labels, self._channels - channels = np.array(channels) - # separate and stack the columns of the channels array - C = channels.shape[1] - channels = np.concatenate([channels[:, i] for i in range(C)]) - self._channels = np.expand_dims(channels, 1).tolist() - self.image_paths = image_paths * C - self.labels = labels * C - - def get_image_relpath(self, index: int) -> str: - return self._image_paths[index] - - def get_image_data(self, index: int) -> bytes: - image_relpath = self.get_image_relpath(index) - image_full_path = os.path.join(self.root, image_relpath) - with open(image_full_path, mode="rb") as f: - image_data = f.read() - if self.channel_adaptive: - channels = self._channels[index] - return image_data + bytes(channels) + (len(channels)).to_bytes(1, byteorder="big") - else: - return image_data - - def get_target(self, index: int) -> Any: - 
return self._labels[index] - - def get_targets(self) -> np.ndarray: - return np.array(self._labels) - - def __len__(self) -> int: - return len(self._image_paths) diff --git a/dinov2/data/datasets/cell_dino/hpaone.py b/dinov2/data/datasets/cell_dino/hpaone.py deleted file mode 100644 index 0058455f1ab8af6d13610591dbfb7c7c919f10a5..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/cell_dino/hpaone.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. - -import csv -from enum import Enum -import logging -import os -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np - -from ..extended import ExtendedVisionDataset -from ..decoders import DecoderType - -logger = logging.getLogger("dinov2") - -PROTEIN_LOCALIZATION = [ - "actin filaments,focal adhesion sites", - "aggresome", - "centrosome,centriolar satellite", - "cytosol", - "endoplasmic reticulum", - "golgi apparatus", - "intermediate filaments", - "microtubules", - "mitochondria", - "mitotic spindle", - "no staining", - "nuclear bodies", - "nuclear membrane", - "nuclear speckles", - "nucleoli", - "nucleoli fibrillar center", - "nucleoplasm", - "plasma membrane,cell junctions", - "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies", -] # 19 - - -CELL_TYPE = [ - "A-431", # 0 - "A549", - "AF22", - "ASC TERT1", - "BJ", - "CACO-2", - "EFO-21", - "HAP1", - "HDLM-2", - "HEK 293", # 9 - "HEL", - "HUVEC TERT2", - "HaCaT", - "HeLa", - "Hep G2", - "JURKAT", - "K-562", - "MCF7", - "PC-3", - "REH", - "RH-30", # 20 - "RPTEC TERT1", - "RT4", - "SH-SY5Y", - "SK-MEL-30", - "SiHa", - "U-2 OS", - "U-251 MG", - "hTCEpi", # 28 -] # 29 cell types - - -class _Split(Enum): - VAL = "val" - TRAIN = "train" - ALL = "all" # images without labels, for encoder training - - -class _Mode(Enum): - PROTEIN_LOCALIZATION = "protein_localization" - CELL_TYPE = "cell_type" - - @property - def num_labels(self): - if self == _Mode.CELL_TYPE.value.upper(): - return len(CELL_TYPE) - return len(PROTEIN_LOCALIZATION) - - -def _simple_parse_csv(img_rootdir, csv_filepath: str): - samples = [] - with open(csv_filepath) as filename: - template = csv.DictReader(filename) - samples = [(os.path.join(img_rootdir, row["img_path"]), 0) for row in template] - return samples - - -def _parse_csv(img_rootdir, csv_labels_path: str): - nb_protein_location = len(PROTEIN_LOCALIZATION) - nb_cell_type = len(CELL_TYPE) - samples = [] - with open(csv_labels_path) as filename: - reader = csv.DictReader(filename) - for row in reader: - protein_location = np.zeros(nb_protein_location, dtype=np.int_) - for k in range(nb_protein_location): - if row[PROTEIN_LOCALIZATION[k]] == "True": - protein_location[k] = 1 - - cell_type = 0 - for k in range(nb_cell_type): - if row[CELL_TYPE[k]] == "True": - cell_type = k - - samples.append( - ( - img_rootdir + "/" + row["file"].rsplit("/", 1)[1], - protein_location, - cell_type, - ) - ) - return samples - - -def _load_file_names_and_labels_ssl( - root: str, -) -> Tuple[List[str], List[Any]]: - curr_dir_train = os.path.join(root, "varied_size_masked_single_cells_HPA") - csv_all_path = os.path.join(root, "varied_size_masked_single_cells_pretrain_20240507.csv") - samples = _simple_parse_csv(curr_dir_train, csv_all_path) - image_paths, fake_labels = zip(*samples) - lab = list(fake_labels) - return image_paths, lab - - -def 
_load_file_names_and_labels_train_or_test( - root: str, - split: _Split, - mode: _Mode, -) -> Tuple[List[str], List[Any]]: - if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN: - csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_pretrain_20240507.csv") - elif split == _Split.VAL.value.upper() or split == _Split.VAL: - csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_evaluation_20240507.csv") - else: - raise ValueError(f'unknown split: "{split}"') # csv_labels_path would otherwise be undefined below - curr_dir_val = os.path.join(root, "fixed_size_masked_single_cells_HPA") - - samples = _parse_csv(curr_dir_val, csv_labels_path) - image_paths, protein_location, cell_type = zip(*samples) - if mode == _Mode.PROTEIN_LOCALIZATION.value.upper(): - lab = protein_location - elif mode == _Mode.CELL_TYPE.value.upper(): - lab = cell_type - else: - lab = protein_location, cell_type - image_paths = list(image_paths) - return image_paths, lab - - -class HPAone(ExtendedVisionDataset): - Split = Union[_Split] - Mode = Union[_Mode] - - def __init__( - self, - *, - split: "HPAone.Split" = _Split.ALL, - mode: "HPAone.Mode" = None, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - image_decoder_type: DecoderType = DecoderType.XChannelsDecoder, - **kwargs: Any, - ) -> None: - super().__init__( - root, - transforms, - transform, - target_transform, - image_decoder_type=image_decoder_type, - **kwargs, - ) - self.mode = mode - self.split = split - self.root = root - - if ( - split in {_Split.TRAIN.value.upper(), _Split.VAL.value.upper()} - or split == _Split.TRAIN - or split == _Split.VAL - ): - ( - self._image_paths, - self._labels, - ) = _load_file_names_and_labels_train_or_test(root, split, mode) - elif split == _Split.ALL.value.upper() or split == _Split.ALL: - self._image_paths, self._labels = _load_file_names_and_labels_ssl(root) - else: - raise ValueError(f'unknown split: "{split}"') # the dataset would otherwise be left without _image_paths - - def get_image_relpath(self, index: int) -> str: - return self._image_paths[index] - - def get_image_data(self, index: int) -> bytes: - image_relpath = self.get_image_relpath(index) - image_full_path = os.path.join(self.root, image_relpath) - with open(image_full_path, mode="rb") as f: - image_data = f.read() - return image_data - - def get_target(self, index: int) -> Any: - return self._labels[index] - - def get_targets(self) -> np.ndarray: - return np.array(self._labels) - - def __len__(self) -> int: - return len(self._image_paths) diff --git a/dinov2/data/datasets/decoders.py b/dinov2/data/datasets/decoders.py deleted file mode 100644 index feb746885b5285b6009bfe07d32c44b2593f32af..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/decoders.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
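A note on the byte-packing convention used by HPAFoV.get_image_data above: the per-sample channel indices plus a one-byte channel count are appended to the raw image bytes, and the decoder on the other end must peel off exactly the same trailing bytes. The sketch below is a hypothetical symmetric pack/unpack pair (the helper names are illustrative, not part of this codebase); note that ChannelSelectDecoder below strips only the final byte, so the packing and decoding sides of this convention have to be kept in agreement.

def pack_channels(image_data: bytes, channels: list) -> bytes:
    # layout: <image bytes> | one byte per channel index | one count byte
    return image_data + bytes(channels) + len(channels).to_bytes(1, byteorder="big")


def unpack_channels(buffer: bytes):
    # Invert pack_channels: read the count byte, then peel off that many
    # channel-index bytes from the tail.
    count = buffer[-1]
    channels = list(buffer[-1 - count : -1])
    return buffer[: -1 - count], channels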
- -from io import BytesIO -from typing import Any, Type - -from PIL import Image -import numpy as np -import torch -from enum import Enum - -try: - import tifffile -except ImportError: - print("Could not import `tifffile`, TIFFImageDataDecoder will be disabled") - - -class Decoder: - def decode(self) -> Any: - raise NotImplementedError - - -class DecoderType(Enum): - ImageDataDecoder = "ImageDataDecoder" - XChannelsDecoder = "XChannelsDecoder" - XChannelsTIFFDecoder = "XChannelsTIFFDecoder" - ChannelSelectDecoder = "ChannelSelectDecoder" - - def get_class(self) -> Type[Decoder]: # noqa: C901 - if self == DecoderType.ImageDataDecoder: - return ImageDataDecoder - if self == DecoderType.XChannelsDecoder: - return XChannelsDecoder - if self == DecoderType.XChannelsTIFFDecoder: - return XChannelsTIFFDecoder - if self == DecoderType.ChannelSelectDecoder: - return ChannelSelectDecoder - - -class ImageDataDecoder(Decoder): - def __init__(self, image_data: bytes) -> None: - self._image_data = image_data - - def decode(self) -> Image: - f = BytesIO(self._image_data) - return Image.open(f).convert(mode="RGB") - - -class TargetDecoder(Decoder): - def __init__(self, target: Any): - self._target = target - - def decode(self) -> Any: - return self._target - - -class XChannelsDecoder(Decoder): - def __init__(self, image_data: bytes) -> None: - self._image_data = image_data - - def decode(self): - im = np.asarray(Image.open(BytesIO(self._image_data))) - if len(im.shape) == 2: - im = np.reshape(im, (im.shape[0], im.shape[0], -1), order="F") - return torch.Tensor(im).permute(2, 0, 1) - - -class XChannelsTIFFDecoder(Decoder): - def __init__(self, image_data: bytes, num_channels: int = 3) -> None: - self._image_data = image_data - self._num_channels = num_channels - - def decode(self): - numpy_array = tifffile.imread(BytesIO(self._image_data)) - numpy_array = np.reshape(numpy_array, (numpy_array.shape[0], -1, self._num_channels), order="F") - return torch.Tensor(numpy_array).permute(2, 0, 1) - - -class ChannelSelectDecoder(Decoder): - def __init__(self, image_data: bytes, select_channel: bool = False) -> None: - self.select_channel = select_channel - if select_channel: - self._image_data = image_data[:-1] - self._channel = image_data[-1] - else: - self._image_data = image_data - - def decode(self): - im = np.asarray(Image.open(BytesIO(self._image_data))) - if self.select_channel: - return torch.Tensor(im).permute(2, 0, 1)[[self._channel]] - return torch.Tensor(im).permute(2, 0, 1) diff --git a/dinov2/data/datasets/extended.py b/dinov2/data/datasets/extended.py deleted file mode 100644 index 32555b57cab97eddf47e20ba98212a5c628f331d..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/extended.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
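For orientation, a minimal usage sketch of the decoder dispatch defined above (the file path is illustrative, and tifffile must be installed for the TIFF decoder): DecoderType.get_class() maps an enum member to its decoder class, which is then instantiated per sample with the raw bytes and any decoder parameters.

from dinov2.data.datasets.decoders import DecoderType

with open("cell_image.tiff", "rb") as f:  # illustrative path
    raw = f.read()

decoder_cls = DecoderType.XChannelsTIFFDecoder.get_class()
image = decoder_cls(raw, num_channels=4).decode()  # torch.Tensor, channels-first (4, H, W)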
- -from typing import Any, Tuple - -from torchvision.datasets import VisionDataset - -from .decoders import DecoderType, TargetDecoder - - -class ExtendedVisionDataset(VisionDataset): - def __init__(self, *args, **kwargs) -> None: - image_decoder_type = kwargs.pop("image_decoder_type", DecoderType.ImageDataDecoder) - self._decoder_params = {} - self._image_decoder_class = image_decoder_type.get_class() - if "image_decoder_params" in kwargs: - self._decoder_params = kwargs.pop("image_decoder_params") - - super().__init__(*args, **kwargs) # type: ignore - - def get_image_data(self, index: int) -> bytes: - raise NotImplementedError - - def get_target(self, index: int) -> Any: - raise NotImplementedError - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - try: - image_data = self.get_image_data(index) - image = self._image_decoder_class(image_data, **self._decoder_params).decode() - except Exception as e: - raise RuntimeError(f"can not read image for sample {index}") from e - target = self.get_target(index) - target = TargetDecoder(target).decode() - - if self.transforms is not None: - image, target = self.transforms(image, target) - - return image, target - - def __len__(self) -> int: - raise NotImplementedError diff --git a/dinov2/data/datasets/image_net.py b/dinov2/data/datasets/image_net.py deleted file mode 100644 index 8d08446147986c58360163e468896e994197c657..0000000000000000000000000000000000000000 --- a/dinov2/data/datasets/image_net.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import csv -from enum import Enum -import logging -import os -from typing import Callable, List, Optional, Tuple, Union - -import numpy as np - -from .extended import ExtendedVisionDataset - - -logger = logging.getLogger("dinov2") -_Target = int - - -class _Split(Enum): - TRAIN = "train" - VAL = "val" - TEST = "test" # NOTE: torchvision does not support the test split - - @property - def length(self) -> int: - split_lengths = { - _Split.TRAIN: 1_281_167, - _Split.VAL: 50_000, - _Split.TEST: 100_000, - } - return split_lengths[self] - - def get_dirname(self, class_id: Optional[str] = None) -> str: - return self.value if class_id is None else os.path.join(self.value, class_id) - - def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: - dirname = self.get_dirname(class_id) - if self == _Split.TRAIN: - basename = f"{class_id}_{actual_index}" - else: # self in (_Split.VAL, _Split.TEST): - basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" - return os.path.join(dirname, basename + ".JPEG") - - def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: - assert self != _Split.TEST - dirname, filename = os.path.split(image_relpath) - class_id = os.path.split(dirname)[-1] - basename, _ = os.path.splitext(filename) - actual_index = int(basename.split("_")[-1]) - return class_id, actual_index - - -class ImageNet(ExtendedVisionDataset): - Target = Union[_Target] - Split = Union[_Split] - - def __init__( - self, - *, - split: "ImageNet.Split", - root: str, - extra: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__(root, transforms, transform, target_transform) - self._extra_root = extra - self._split = split - - self._entries = None - self._class_ids = None - 
self._class_names = None - - @property - def split(self) -> "ImageNet.Split": - return self._split - - def _get_extra_full_path(self, extra_path: str) -> str: - return os.path.join(self._extra_root, extra_path) - - def _load_extra(self, extra_path: str) -> np.ndarray: - extra_full_path = self._get_extra_full_path(extra_path) - return np.load(extra_full_path, mmap_mode="r") - - def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: - extra_full_path = self._get_extra_full_path(extra_path) - os.makedirs(self._extra_root, exist_ok=True) - np.save(extra_full_path, extra_array) - - @property - def _entries_path(self) -> str: - return f"entries-{self._split.value.upper()}.npy" - - @property - def _class_ids_path(self) -> str: - return f"class-ids-{self._split.value.upper()}.npy" - - @property - def _class_names_path(self) -> str: - return f"class-names-{self._split.value.upper()}.npy" - - def _get_entries(self) -> np.ndarray: - if self._entries is None: - self._entries = self._load_extra(self._entries_path) - assert self._entries is not None - return self._entries - - def _get_class_ids(self) -> np.ndarray: - if self._split == _Split.TEST: - assert False, "Class IDs are not available in TEST split" - if self._class_ids is None: - self._class_ids = self._load_extra(self._class_ids_path) - assert self._class_ids is not None - return self._class_ids - - def _get_class_names(self) -> np.ndarray: - if self._split == _Split.TEST: - assert False, "Class names are not available in TEST split" - if self._class_names is None: - self._class_names = self._load_extra(self._class_names_path) - assert self._class_names is not None - return self._class_names - - def find_class_id(self, class_index: int) -> str: - class_ids = self._get_class_ids() - return str(class_ids[class_index]) - - def find_class_name(self, class_index: int) -> str: - class_names = self._get_class_names() - return str(class_names[class_index]) - - def get_image_data(self, index: int) -> bytes: - entries = self._get_entries() - actual_index = entries[index]["actual_index"] - - class_id = self.get_class_id(index) - - image_relpath = self.split.get_image_relpath(actual_index, class_id) - image_full_path = os.path.join(self.root, image_relpath) - with open(image_full_path, mode="rb") as f: - image_data = f.read() - return image_data - - def get_target(self, index: int) -> Optional[Target]: - entries = self._get_entries() - class_index = entries[index]["class_index"] - return None if self.split == _Split.TEST else int(class_index) - - def get_targets(self) -> Optional[np.ndarray]: - entries = self._get_entries() - return None if self.split == _Split.TEST else entries["class_index"] - - def get_class_id(self, index: int) -> Optional[str]: - entries = self._get_entries() - class_id = entries[index]["class_id"] - return None if self.split == _Split.TEST else str(class_id) - - def get_class_name(self, index: int) -> Optional[str]: - entries = self._get_entries() - class_name = entries[index]["class_name"] - return None if self.split == _Split.TEST else str(class_name) - - def __len__(self) -> int: - entries = self._get_entries() - assert len(entries) == self.split.length - return len(entries) - - def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: - labels_full_path = os.path.join(self.root, labels_path) - labels = [] - - try: - with open(labels_full_path, "r") as f: - reader = csv.reader(f) - for row in reader: - class_id, class_name = row - labels.append((class_id, class_name)) - except OSError as e: - raise 
RuntimeError(f'can not read labels file "{labels_full_path}"') from e - - return labels - - def _dump_entries(self) -> None: - split = self.split - if split == ImageNet.Split.TEST: - dataset = None - sample_count = split.length - max_class_id_length, max_class_name_length = 0, 0 - else: - labels_path = "labels.txt" - logger.info(f'loading labels from "{labels_path}"') - labels = self._load_labels(labels_path) - - # NOTE: Using torchvision ImageFolder for consistency - from torchvision.datasets import ImageFolder - - dataset_root = os.path.join(self.root, split.get_dirname()) - dataset = ImageFolder(dataset_root) - sample_count = len(dataset) - max_class_id_length, max_class_name_length = -1, -1 - for sample in dataset.samples: - _, class_index = sample - class_id, class_name = labels[class_index] - max_class_id_length = max(len(class_id), max_class_id_length) - max_class_name_length = max(len(class_name), max_class_name_length) - - dtype = np.dtype( - [ - ("actual_index", "<u4"), - ("class_index", "<u4"), - ("class_id", f"U{max_class_id_length}"), - ("class_name", f"U{max_class_name_length}"), - ] - ) - entries_array = np.empty(sample_count, dtype=dtype) - - if split == ImageNet.Split.TEST: - old_percent = -1 - for index in range(sample_count): - percent = 100 * (index + 1) // sample_count - if percent > old_percent: - logger.info(f"creating entries: {percent}%") - old_percent = percent - - actual_index = index + 1 - class_index = np.uint32(-1) - class_id, class_name = "", "" - entries_array[index] = (actual_index, class_index, class_id, class_name) - else: - class_names = {class_id: class_name for class_id, class_name in labels} - - assert dataset - old_percent = -1 - for index in range(sample_count): - percent = 100 * (index + 1) // sample_count - if percent > old_percent: - logger.info(f"creating entries: {percent}%") - old_percent = percent - - image_full_path, class_index = dataset.samples[index] - image_relpath = os.path.relpath(image_full_path, self.root) - class_id, actual_index = split.parse_image_relpath(image_relpath) - class_name = class_names[class_id] - entries_array[index] = (actual_index, class_index, class_id, class_name) - - logger.info(f'saving entries to "{self._entries_path}"') - self._save_extra(entries_array, self._entries_path) - - def _dump_class_ids_and_names(self) -> None: - split = self.split - if split == ImageNet.Split.TEST: - return - - entries_array = self._load_extra(self._entries_path) - - max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 - for entry in entries_array: - class_index, class_id, class_name = ( - entry["class_index"], - entry["class_id"], - entry["class_name"], - ) - max_class_index = max(int(class_index), max_class_index) - max_class_id_length = max(len(str(class_id)), max_class_id_length) - max_class_name_length = max(len(str(class_name)), max_class_name_length) - - class_count = max_class_index + 1 - class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") - class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") - for entry in entries_array: - class_index, class_id, class_name = ( - entry["class_index"], - entry["class_id"], - entry["class_name"], - ) - class_ids_array[class_index] = class_id - class_names_array[class_index] = class_name - - logger.info(f'saving class IDs to "{self._class_ids_path}"') - self._save_extra(class_ids_array, self._class_ids_path) - - logger.info(f'saving class names to "{self._class_names_path}"') - self._save_extra(class_names_array, self._class_names_path) - - def dump_extra(self) -> None: - self._dump_entries() - self._dump_class_ids_and_names() diff --git a/dinov2/data/datasets/image_net_22k.py b/dinov2/data/datasets/image_net_22k.py deleted file mode 100644 index 52b36a2c664a7b72e30173b03b4e2aef1cd2fcd9..0000000000000000000000000000000000000000 --- 
a/dinov2/data/datasets/image_net_22k.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass -from enum import Enum -from functools import lru_cache -from gzip import GzipFile -from io import BytesIO -from mmap import ACCESS_READ, mmap -import os -from typing import Any, Callable, List, Optional, Set, Tuple -import warnings - -import numpy as np - -from .extended import ExtendedVisionDataset - - -_Labels = int - -_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors - - -@dataclass -class _ClassEntry: - block_offset: int - maybe_filename: Optional[str] = None - - -@dataclass -class _Entry: - class_index: int # noqa: E701 - start_offset: int - end_offset: int - filename: str - - -class _Split(Enum): - TRAIN = "train" - VAL = "val" - - @property - def length(self) -> int: - return { - _Split.TRAIN: 11_797_647, - _Split.VAL: 561_050, - }[self] - - def entries_path(self): - return f"imagenet21kp_{self.value}.txt" - - -def _get_tarball_path(class_id: str) -> str: - return f"{class_id}.tar" - - -def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): - @lru_cache(maxsize=mmap_cache_size) - def _mmap_tarball(class_id: str) -> mmap: - tarball_path = _get_tarball_path(class_id) - tarball_full_path = os.path.join(tarballs_root, tarball_path) - with open(tarball_full_path) as f: - return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) - - return _mmap_tarball - - -class ImageNet22k(ExtendedVisionDataset): - _GZIPPED_INDICES: Set[int] = { - 841_545, - 1_304_131, - 2_437_921, - 2_672_079, - 2_795_676, - 2_969_786, - 6_902_965, - 6_903_550, - 6_903_628, - 7_432_557, - 7_432_589, - 7_813_809, - 8_329_633, - 10_296_990, - 10_417_652, - 10_492_265, - 10_598_078, - 10_782_398, - 10_902_612, - 11_203_736, - 11_342_890, - 11_397_596, - 11_589_762, - 11_705_103, - 12_936_875, - 13_289_782, - } - Labels = _Labels - - def __init__( - self, - *, - root: str, - extra: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, - ) -> None: - super().__init__(root, transforms, transform, target_transform) - self._extra_root = extra - - entries_path = self._get_entries_path(root) - self._entries = self._load_extra(entries_path) - - class_ids_path = self._get_class_ids_path(root) - self._class_ids = self._load_extra(class_ids_path) - - self._gzipped_indices = ImageNet22k._GZIPPED_INDICES - self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) - - def _get_entries_path(self, root: Optional[str] = None) -> str: - return "entries.npy" - - def _get_class_ids_path(self, root: Optional[str] = None) -> str: - return "class-ids.npy" - - def _find_class_ids(self, path: str) -> List[str]: - class_ids = [] - - with os.scandir(path) as entries: - for entry in entries: - root, ext = os.path.splitext(entry.name) - if ext != ".tar": - continue - class_ids.append(root) - - return sorted(class_ids) - - def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: - root = self.get_root(root) - entries: List[_Entry] = [] - class_ids = self._find_class_ids(root) - - for class_index, class_id in enumerate(class_ids): - path = os.path.join(root, "blocks", f"{class_id}.log") - class_entries = [] - - try: - with 
open(path) as f: - for line in f: - line = line.rstrip() - block, filename = line.split(":") - block_offset = int(block[6:]) - filename = filename[1:] - - maybe_filename = None - if filename != "** Block of NULs **": - maybe_filename = filename - _, ext = os.path.splitext(filename) - # assert ext == ".JPEG" - - class_entry = _ClassEntry(block_offset, maybe_filename) - class_entries.append(class_entry) - except OSError as e: - raise RuntimeError(f'can not read blocks file "{path}"') from e - - assert class_entries[-1].maybe_filename is None - - for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): - assert class_entry1.block_offset <= class_entry2.block_offset - start_offset = 512 * class_entry1.block_offset - end_offset = 512 * class_entry2.block_offset - assert class_entry1.maybe_filename is not None - filename = class_entry1.maybe_filename - entry = _Entry(class_index, start_offset, end_offset, filename) - # Skip invalid image files (PIL throws UnidentifiedImageError) - if filename == "n06470073_47249.JPEG": - continue - entries.append(entry) - - return entries, class_ids - - def _load_extra(self, extra_path: str) -> np.ndarray: - extra_root = self._extra_root - extra_full_path = os.path.join(extra_root, extra_path) - return np.load(extra_full_path, mmap_mode="r") - - def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: - extra_root = self._extra_root - extra_full_path = os.path.join(extra_root, extra_path) - os.makedirs(extra_root, exist_ok=True) - np.save(extra_full_path, extra_array) - - @property - def _tarballs_root(self) -> str: - return self.root - - def find_class_id(self, class_index: int) -> str: - return str(self._class_ids[class_index]) - - def get_image_data(self, index: int) -> bytes: - entry = self._entries[index] - class_id = entry["class_id"] - class_mmap = self._mmap_tarball(class_id) - - start_offset, end_offset = entry["start_offset"], entry["end_offset"] - try: - mapped_data = class_mmap[start_offset:end_offset] - data = mapped_data[512:] # Skip entry header block - - if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): - assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" - with GzipFile(fileobj=BytesIO(data)) as g: - data = g.read() - except Exception as e: - raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e - - return data - - def get_target(self, index: int) -> Any: - return int(self._entries[index]["class_index"]) - - def get_targets(self) -> np.ndarray: - return self._entries["class_index"] - - def get_class_id(self, index: int) -> str: - return str(self._entries[index]["class_id"]) - - def get_class_ids(self) -> np.ndarray: - return self._entries["class_id"] - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return super().__getitem__(index) - - def __len__(self) -> int: - return len(self._entries) - - def _dump_entries(self, *args, **kwargs) -> None: - entries, class_ids = self._load_entries_class_ids(*args, **kwargs) - - max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 - for entry in entries: - class_id = class_ids[entry.class_index] - max_class_index = max(entry.class_index, max_class_index) - max_class_id_length = max(len(class_id), max_class_id_length) - max_filename_length = max(len(entry.filename), max_filename_length) - - dtype = np.dtype( - [ - ("class_index", "<u4"), - ("start_offset", "<u4"), - ("end_offset", "<u4"), - ("filename", f"U{max_filename_length}"), - ] - ) - sample_count = len(entries) - entries_array = np.empty(sample_count, dtype=dtype) - for i, entry in enumerate(entries): - class_index = entry.class_index - start_offset = entry.start_offset - end_offset = entry.end_offset - filename = entry.filename - entries_array[i] = (class_index, start_offset, end_offset, filename) - - entries_path = self._get_entries_path(*args, **kwargs) - self._save_extra(entries_array, entries_path) - - def _dump_class_ids(self, *args, **kwargs) -> None: - entries_path = self._get_entries_path(*args, **kwargs) - 
entries_array = self._load_extra(entries_path) - - max_class_id_length, max_class_index = -1, -1 - for entry in entries_array: - class_index, class_id = entry["class_index"], entry["class_id"] - max_class_index = max(int(class_index), max_class_index) - max_class_id_length = max(len(str(class_id)), max_class_id_length) - - class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") - for entry in entries_array: - class_index, class_id = entry["class_index"], entry["class_id"] - class_ids_array[class_index] = class_id - class_ids_path = self._get_class_ids_path(*args, **kwargs) - self._save_extra(class_ids_array, class_ids_path) - - def _dump_extra(self, *args, **kwargs) -> None: - self._dump_entries(*args, *kwargs) - self._dump_class_ids(*args, *kwargs) - - def dump_extra(self, root: Optional[str] = None) -> None: - return self._dump_extra(root) diff --git a/dinov2/data/loaders.py b/dinov2/data/loaders.py deleted file mode 100644 index fdf6709b87c439b10a0b4c34197755d0a1a4bc18..0000000000000000000000000000000000000000 --- a/dinov2/data/loaders.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging -from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar - -import torch -from torch.utils.data import Sampler - -from .datasets import ImageNet, ImageNet22k, HPAone, HPAFoV, CHAMMI_CP, CHAMMI_HPA, CHAMMI_WTC -from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler - - -logger = logging.getLogger("dinov2") - - -class SamplerType(Enum): - DISTRIBUTED = 0 - EPOCH = 1 - INFINITE = 2 - SHARDED_INFINITE = 3 - SHARDED_INFINITE_NEW = 4 - - -def _make_bool_str(b: bool) -> str: - return "yes" if b else "no" - - -def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): - def transform(sample): - image, target = sample - if image_transform is not None: - image = image_transform(image) - if target_transform is not None: - target = target_transform(target) - return image, target - - return transform - - -def _parse_dataset_str(dataset_str: str): - tokens = dataset_str.split(":") - - name = tokens[0] - kwargs = {} - - for token in tokens[1:]: - key, value = token.split("=") - assert key in ("root", "extra", "split", "mode", "wildcard") - kwargs[key] = value - - if name == "ImageNet": - class_ = ImageNet - if "split" in kwargs: - kwargs["split"] = ImageNet.Split[kwargs["split"]] - elif name == "ImageNet22k": - class_ = ImageNet22k - elif name == "HPAone": - class_ = HPAone - elif name == "HPAFoV": - class_ = HPAFoV - elif name == "CHAMMI_CP": - class_ = CHAMMI_CP - elif name == "CHAMMI_WTC": - class_ = CHAMMI_WTC - elif name == "CHAMMI_HPA": - class_ = CHAMMI_HPA - else: - raise ValueError(f'Unsupported dataset "{name}"') - - return class_, kwargs - - -def make_dataset( - *, - dataset_str: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, -): - """ - Creates a dataset with the specified parameters. - - Args: - dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). - transform: A transform to apply to images. - target_transform: A transform to apply to targets. - - Returns: - The created dataset. 
- """ - logger.info(f'using dataset: "{dataset_str}"') - - class_, kwargs = _parse_dataset_str(dataset_str) - dataset = class_(transform=transform, target_transform=target_transform, **kwargs) - - logger.info(f"# of dataset samples: {len(dataset):,d}") - - # Aggregated datasets do not expose (yet) these attributes, so add them. - if not hasattr(dataset, "transform"): - setattr(dataset, "transform", transform) - if not hasattr(dataset, "target_transform"): - setattr(dataset, "target_transform", target_transform) - - return dataset - - -def _make_sampler( - *, - dataset, - type: Optional[SamplerType] = None, - shuffle: bool = False, - seed: int = 0, - size: int = -1, - advance: int = 0, -) -> Optional[Sampler]: - sample_count = len(dataset) - - if type == SamplerType.INFINITE: - logger.info("sampler: infinite") - if size > 0: - raise ValueError("sampler size > 0 is invalid") - return InfiniteSampler( - sample_count=sample_count, - shuffle=shuffle, - seed=seed, - advance=advance, - ) - elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): - logger.info("sampler: sharded infinite") - if size > 0: - raise ValueError("sampler size > 0 is invalid") - # TODO: Remove support for old shuffling - use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW - return ShardedInfiniteSampler( - sample_count=sample_count, - shuffle=shuffle, - seed=seed, - advance=advance, - use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, - ) - elif type == SamplerType.EPOCH: - logger.info("sampler: epoch") - if advance > 0: - raise NotImplementedError("sampler advance > 0 is not supported") - size = size if size > 0 else sample_count - logger.info(f"# of samples / epoch: {size:,d}") - return EpochSampler( - size=size, - sample_count=sample_count, - shuffle=shuffle, - seed=seed, - ) - elif type == SamplerType.DISTRIBUTED: - logger.info("sampler: distributed") - if size > 0: - raise ValueError("sampler size > 0 is invalid") - if advance > 0: - raise ValueError("sampler advance > 0 is invalid") - return torch.utils.data.DistributedSampler( - dataset=dataset, - shuffle=shuffle, - seed=seed, - drop_last=False, - ) - - logger.info("sampler: none") - return None - - -T = TypeVar("T") - - -def make_data_loader( - *, - dataset, - batch_size: int, - num_workers: int, - shuffle: bool = True, - seed: int = 0, - sampler_type: Optional[SamplerType] = SamplerType.INFINITE, - sampler_size: int = -1, - sampler_advance: int = 0, - drop_last: bool = True, - persistent_workers: bool = False, - collate_fn: Optional[Callable[[List[T]], Any]] = None, -): - """ - Creates a data loader with the specified parameters. - - Args: - dataset: A dataset (third party, LaViDa or WebDataset). - batch_size: The size of batches to generate. - num_workers: The number of workers to use. - shuffle: Whether to shuffle samples. - seed: The random seed to use. - sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. - sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. - sampler_advance: How many samples to skip (when applicable). - drop_last: Whether the last non-full batch of data should be dropped. - persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. 
- collate_fn: Function that performs batch collation - """ - - sampler = _make_sampler( - dataset=dataset, - type=sampler_type, - shuffle=shuffle, - seed=seed, - size=sampler_size, - advance=sampler_advance, - ) - - logger.info("using PyTorch data loader") - data_loader = torch.utils.data.DataLoader( - dataset, - sampler=sampler, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=True, - drop_last=drop_last, - persistent_workers=persistent_workers, - collate_fn=collate_fn, - ) - - try: - logger.info(f"# of batches: {len(data_loader):,d}") - except TypeError: # data loader has no length - logger.info("infinite data loader") - return data_loader diff --git a/dinov2/data/masking.py b/dinov2/data/masking.py deleted file mode 100644 index ab12aa7bf138b916b16a9a2ed1a628a2759dbec6..0000000000000000000000000000000000000000 --- a/dinov2/data/masking.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import random -import math -import numpy as np - - -class MaskingGenerator: - def __init__( - self, - input_size, - num_masking_patches=None, - min_num_patches=4, - max_num_patches=None, - min_aspect=0.3, - max_aspect=None, - ): - if not isinstance(input_size, tuple): - input_size = (input_size,) * 2 - self.height, self.width = input_size - - self.num_patches = self.height * self.width - self.num_masking_patches = num_masking_patches - - self.min_num_patches = min_num_patches - self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches - - max_aspect = max_aspect or 1 / min_aspect - self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) - - def __repr__(self): - repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( - self.height, - self.width, - self.min_num_patches, - self.max_num_patches, - self.num_masking_patches, - self.log_aspect_ratio[0], - self.log_aspect_ratio[1], - ) - return repr_str - - def get_shape(self): - return self.height, self.width - - def _mask(self, mask, max_mask_patches): - delta = 0 - for _ in range(10): - target_area = random.uniform(self.min_num_patches, max_mask_patches) - aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < self.width and h < self.height: - top = random.randint(0, self.height - h) - left = random.randint(0, self.width - w) - - num_masked = mask[top : top + h, left : left + w].sum() - # Overlap - if 0 < h * w - num_masked <= max_mask_patches: - for i in range(top, top + h): - for j in range(left, left + w): - if mask[i, j] == 0: - mask[i, j] = 1 - delta += 1 - - if delta > 0: - break - return delta - - def __call__(self, num_masking_patches=0): - mask = np.zeros(shape=self.get_shape(), dtype=bool) - mask_count = 0 - while mask_count < num_masking_patches: - max_mask_patches = num_masking_patches - mask_count - max_mask_patches = min(max_mask_patches, self.max_num_patches) - - delta = self._mask(mask, max_mask_patches) - if delta == 0: - break - else: - mask_count += delta - - return mask diff --git a/dinov2/data/samplers.py b/dinov2/data/samplers.py deleted file mode 100644 index 6562197d94652bb9a75a5fc722fcb2c65ca161be..0000000000000000000000000000000000000000 --- a/dinov2/data/samplers.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import itertools -from typing import Any, Optional -import warnings - -import numpy as np -import torch -from torch.utils.data.sampler import Sampler - -import dinov2.distributed as distributed - - -class EpochSampler(Sampler): - def __init__( - self, - *, - size: int, - sample_count: int, - shuffle: bool = False, - seed: int = 0, - start: Optional[int] = None, - step: Optional[int] = None, - ): - self._size = size - self._sample_count = sample_count - self._shuffle = shuffle - self._seed = seed - self._start = distributed.get_global_rank() if start is None else start - self._step = distributed.get_global_size() if step is None else step - self._epoch = 0 - - def __iter__(self): - count = (self._size + self._sample_count - 1) // self._sample_count - tiled_indices = np.tile(np.arange(self._sample_count), count) - if self._shuffle: - seed = self._seed * self._epoch if self._seed != 0 else self._epoch - rng = np.random.default_rng(seed) - iterable = rng.choice(tiled_indices, self._size, replace=False) - else: - iterable = tiled_indices[: self._size] - - yield from itertools.islice(iterable, self._start, None, self._step) - - def __len__(self): - return (self._size - self._start + self._step - 1) // self._step - - def set_epoch(self, epoch): - self._epoch = epoch - - -def _get_numpy_dtype(size: int) -> Any: - return np.int32 if size <= 2**31 else np.int64 - - -def _get_torch_dtype(size: int) -> Any: - return torch.int32 if size <= 2**31 else torch.int64 - - -def _generate_randperm_indices(*, size: int, generator: torch.Generator): - """Generate the indices of a random permutation.""" - dtype = _get_torch_dtype(size) - # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 - perm = torch.arange(size, dtype=dtype) - for i in range(size): - j = torch.randint(i, size, size=(1,), generator=generator).item() - - # Always swap even if no-op - value = perm[j].item() - perm[j] = perm[i].item() - perm[i] = value - yield value - - -class InfiniteSampler(Sampler): - def __init__( - self, - *, - sample_count: int, - shuffle: bool = False, - seed: int = 0, - start: Optional[int] = None, - step: Optional[int] = None, - advance: int = 0, - ): - self._sample_count = sample_count - self._seed = seed - self._shuffle = shuffle - self._start = distributed.get_global_rank() if start is None else start - self._step = distributed.get_global_size() if step is None else step - self._advance = advance - - def __iter__(self): - if self._shuffle: - iterator = self._shuffled_iterator() - else: - iterator = self._iterator() - - yield from itertools.islice(iterator, self._advance, None) - - def _iterator(self): - assert not self._shuffle - - while True: - iterable = range(self._sample_count) - yield from itertools.islice(iterable, self._start, None, self._step) - - def _shuffled_iterator(self): - assert self._shuffle - - # Instantiate a generator here (rather than in the ctor) to keep the class - # picklable (requirement of mp.spawn) - generator = torch.Generator().manual_seed(self._seed) - - while True: - iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) - yield from itertools.islice(iterable, self._start, None, self._step) - - -# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, -# but avoids 
a full in-place random permutation generation. -def _shuffle_tensor_slice( - *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator -) -> np.ndarray: - stop = len(tensor) - count = stop // step - drop_count = stop - step * count - if drop_count: - warnings.warn(f"# of dropped samples: {drop_count}") - - dtype = _get_numpy_dtype(stop) - result = np.empty(count, dtype=dtype) - - for i in range(count): - j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 - - result[i] = result[j] - result[j] = tensor[start + i * step].item() - - return result - - -def _new_shuffle_tensor_slice( - *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator -) -> np.ndarray: - stop = len(tensor) - count = stop // step - dtype = torch.int64 # Needed for using randperm result as indices - drop_count = stop - step * count - if drop_count: - warnings.warn(f"# of dropped samples: {drop_count}") - indices = torch.randperm(count, dtype=dtype, generator=generator) - return tensor[start::step][indices].numpy() - - -def _make_seed(seed: int, start: int, iter_count: int) -> int: - # NOTE: Tried a few variants (including iter_count << 32), this one worked best. - return seed + start + (iter_count << 24) - - -class ShardedInfiniteSampler(Sampler): - def __init__( - self, - *, - sample_count: int, - shuffle: bool = False, - seed: int = 0, - start: Optional[int] = None, - step: Optional[int] = None, - advance: int = 0, - use_new_shuffle_tensor_slice: bool = False, - ): - self._sample_count = sample_count - self._seed = seed - self._shuffle = shuffle - self._start = distributed.get_global_rank() if start is None else start - self._step = distributed.get_global_size() if step is None else step - self._advance = advance - self._iter_count = 0 - self._shuffle_tensor_slice_fn = ( - _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice - ) - - def __iter__(self): - iter_count = self._advance // self._sample_count - if iter_count > 0: - self._advance -= iter_count * self._sample_count - self._iter_count += iter_count - - if self._shuffle: - iterator = self._shuffled_iterator() - else: - iterator = self._iterator() - - yield from itertools.islice(iterator, self._advance, None) - - def _iterator(self): - assert not self._shuffle - - while True: - iterable = range(self._sample_count) - yield from itertools.islice(iterable, self._start, None, self._step) - - def _shuffled_iterator(self): - assert self._shuffle - - # Instantiate a generator here (rather than in the ctor) to keep the class - # picklable (requirement of mp.spawn) - generator = torch.Generator() - - # Always shuffle everything first - generator.manual_seed(self._seed) - dtype = _get_torch_dtype(self._sample_count) - perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) - - while True: - # Re-seed on each iteration to allow skipping whole permutations - seed = _make_seed(self._seed, self._start, self._iter_count) - generator.manual_seed(seed) - - iterable = self._shuffle_tensor_slice_fn( - tensor=perm, start=self._start, step=self._step, generator=generator - ) - yield from iterable - self._iter_count += 1 diff --git a/dinov2/data/transforms.py b/dinov2/data/transforms.py deleted file mode 100644 index eb5f252b50c54d58f160528c9f2b00fad47103c7..0000000000000000000000000000000000000000 --- a/dinov2/data/transforms.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from typing import Sequence - -import torch -from torchvision import transforms - - -class GaussianBlur(transforms.RandomApply): - """ - Apply Gaussian Blur to the PIL image. - """ - - def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): - # NOTE: torchvision is applying 1 - probability to return the original image - keep_p = 1 - p - transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) - super().__init__(transforms=[transform], p=keep_p) - - -class MaybeToTensor(transforms.ToTensor): - """ - Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. - """ - - def __call__(self, pic): - """ - Args: - pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. - Returns: - Tensor: Converted image. - """ - if isinstance(pic, torch.Tensor): - return pic - return super().__call__(pic) - - -# Use timm's names -IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) - - -def make_normalize_transform( - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -) -> transforms.Normalize: - return transforms.Normalize(mean=mean, std=std) - - -# This roughly matches torchvision's preset for classification training: -# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 -def make_classification_train_transform( - *, - crop_size: int = 224, - interpolation=transforms.InterpolationMode.BICUBIC, - hflip_prob: float = 0.5, - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -): - transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] - if hflip_prob > 0.0: - transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) - transforms_list.extend( - [ - MaybeToTensor(), - make_normalize_transform(mean=mean, std=std), - ] - ) - return transforms.Compose(transforms_list) - - -# This matches (roughly) torchvision's preset for classification evaluation: -# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 -def make_classification_eval_transform( - *, - resize_size: int = 256, - interpolation=transforms.InterpolationMode.BICUBIC, - crop_size: int = 224, - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -) -> transforms.Compose: - transforms_list = [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - MaybeToTensor(), - make_normalize_transform(mean=mean, std=std), - ] - return transforms.Compose(transforms_list) diff --git a/dinov2/distributed/__init__.py b/dinov2/distributed/__init__.py deleted file mode 100644 index 23226f4536bf5acf4ffac242e9903d92863b246d..0000000000000000000000000000000000000000 --- a/dinov2/distributed/__init__.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
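For reference, a small usage sketch of the transform builders defined above, using only their documented defaults:

from dinov2.data.transforms import (
    make_classification_eval_transform,
    make_classification_train_transform,
)

train_transform = make_classification_train_transform(crop_size=224, hflip_prob=0.5)
eval_transform = make_classification_eval_transform(resize_size=256, crop_size=224)
# Both return a torchvision Compose that maps a PIL image (or an existing
# tensor, via MaybeToTensor) to a tensor normalized with ImageNet statistics.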
- -import os -import random -import re -import socket -from typing import Dict, List - -import torch -import torch.distributed as dist - -_LOCAL_RANK = -1 -_LOCAL_WORLD_SIZE = -1 - - -def is_enabled() -> bool: - """ - Returns: - True if distributed training is enabled - """ - return dist.is_available() and dist.is_initialized() - - -def get_global_size() -> int: - """ - Returns: - The number of processes in the process group - """ - return dist.get_world_size() if is_enabled() else 1 - - -def get_global_rank() -> int: - """ - Returns: - The rank of the current process within the global process group. - """ - return dist.get_rank() if is_enabled() else 0 - - -def get_local_rank() -> int: - """ - Returns: - The rank of the current process within the local (per-machine) process group. - """ - if not is_enabled(): - return 0 - assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE - return _LOCAL_RANK - - -def get_local_size() -> int: - """ - Returns: - The size of the per-machine process group, - i.e. the number of processes per machine. - """ - if not is_enabled(): - return 1 - assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE - return _LOCAL_WORLD_SIZE - - -def is_main_process() -> bool: - """ - Returns: - True if the current process is the main one. - """ - return get_global_rank() == 0 - - -def _restrict_print_to_main_process() -> None: - """ - This function disables printing when not in the main process - """ - import builtins as __builtin__ - - builtin_print = __builtin__.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - if is_main_process() or force: - builtin_print(*args, **kwargs) - - __builtin__.print = print - - -def _get_master_port(seed: int = 0) -> int: - MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) - - master_port_str = os.environ.get("MASTER_PORT") - if master_port_str is None: - rng = random.Random(seed) - return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) - - return int(master_port_str) - - -def _get_available_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # A "" host address means INADDR_ANY i.e. binding to all interfaces. - # Note this is not compatible with IPv6. 
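# Binding to port 0 asks the OS for a free ephemeral port, read back below via
# getsockname(); since the socket is closed on context exit, the port is only
# likely (not guaranteed) to still be free when it is reused later.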
- s.bind(("", 0)) - port = s.getsockname()[1] - return port - - -_TORCH_DISTRIBUTED_ENV_VARS = ( - "MASTER_ADDR", - "MASTER_PORT", - "RANK", - "WORLD_SIZE", - "LOCAL_RANK", - "LOCAL_WORLD_SIZE", -) - - -def _collect_env_vars() -> Dict[str, str]: - return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} - - -def _is_slurm_job_process() -> bool: - return "SLURM_JOB_ID" in os.environ - - -def _parse_slurm_node_list(s: str) -> List[str]: - nodes = [] - # Extract "hostname", "hostname[1-2,3,4-5]," substrings - p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") - for m in p.finditer(s): - prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] - for suffix in suffixes.split(","): - span = suffix.split("-") - if len(span) == 1: - nodes.append(prefix + suffix) - else: - width = len(span[0]) - start, end = int(span[0]), int(span[1]) + 1 - nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) - return nodes - - -def _check_env_variable(key: str, new_value: str): - # Only check for difference with preset environment variables - if key in os.environ and os.environ[key] != new_value: - raise RuntimeError(f"Cannot export environment variables as {key} is already set") - - -class _TorchDistributedEnvironment: - def __init__(self): - self.master_addr = "127.0.0.1" - self.master_port = 0 - self.rank = -1 - self.world_size = -1 - self.local_rank = -1 - self.local_world_size = -1 - - if _is_slurm_job_process(): - return self._set_from_slurm_env() - - env_vars = _collect_env_vars() - if not env_vars: - # Environment is not set - pass - elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): - # Environment is fully set - return self._set_from_preset_env() - else: - # Environment is partially set - collected_env_vars = ", ".join(env_vars.keys()) - raise RuntimeError(f"Partially set environment: {collected_env_vars}") - - if torch.cuda.device_count() > 0: - return self._set_from_local() - - raise RuntimeError("Can't initialize PyTorch distributed environment") - - # Slurm job created with sbatch, submitit, etc... - def _set_from_slurm_env(self): - # logger.info("Initialization from Slurm environment") - job_id = int(os.environ["SLURM_JOB_ID"]) - node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) - nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) - assert len(nodes) == node_count - - self.master_addr = nodes[0] - self.master_port = _get_master_port(seed=job_id) - self.rank = int(os.environ["SLURM_PROCID"]) - self.world_size = int(os.environ["SLURM_NTASKS"]) - assert self.rank < self.world_size - self.local_rank = int(os.environ["SLURM_LOCALID"]) - self.local_world_size = self.world_size // node_count - assert self.local_rank < self.local_world_size - - # Single node job with preset environment (i.e. torchrun) - def _set_from_preset_env(self): - # logger.info("Initialization from preset environment") - self.master_addr = os.environ["MASTER_ADDR"] - self.master_port = os.environ["MASTER_PORT"] - self.rank = int(os.environ["RANK"]) - self.world_size = int(os.environ["WORLD_SIZE"]) - assert self.rank < self.world_size - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - assert self.local_rank < self.local_world_size - - # Single node and GPU job (i.e. 
local script run) - def _set_from_local(self): - # logger.info("Initialization from local") - self.master_addr = "127.0.0.1" - self.master_port = _get_available_port() - self.rank = 0 - self.world_size = 1 - self.local_rank = 0 - self.local_world_size = 1 - - def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": - # See the "Environment variable initialization" section from - # https://pytorch.org/docs/stable/distributed.html for the complete list of - # environment variables required for the env:// initialization method. - env_vars = { - "MASTER_ADDR": self.master_addr, - "MASTER_PORT": str(self.master_port), - "RANK": str(self.rank), - "WORLD_SIZE": str(self.world_size), - "LOCAL_RANK": str(self.local_rank), - "LOCAL_WORLD_SIZE": str(self.local_world_size), - } - if not overwrite: - for k, v in env_vars.items(): - _check_env_variable(k, v) - - os.environ.update(env_vars) - return self - - -def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): - """Enable distributed mode - - Args: - set_cuda_current_device: If True, call torch.cuda.set_device() to set the - current PyTorch CUDA device to the one matching the local rank. - overwrite: If True, overwrites already set variables. Else fails. - """ - - global _LOCAL_RANK, _LOCAL_WORLD_SIZE - if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: - raise RuntimeError("Distributed mode has already been enabled") - torch_env = _TorchDistributedEnvironment() - torch_env.export(overwrite=overwrite) - - if set_cuda_current_device: - torch.cuda.set_device(torch_env.local_rank) - - if allow_nccl_timeout: - # This allows to use torch distributed timeout in a NCCL backend - key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" - if not overwrite: - _check_env_variable(key, value) - os.environ[key] = value - - dist.init_process_group(backend="nccl") - dist.barrier() - - # Finalize setup - _LOCAL_RANK = torch_env.local_rank - _LOCAL_WORLD_SIZE = torch_env.local_world_size - _restrict_print_to_main_process() diff --git a/dinov2/eval/__init__.py b/dinov2/eval/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/eval/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/eval/cell_dino/knn.py b/dinov2/eval/cell_dino/knn.py deleted file mode 100644 index d20d85700b7fc454f7e10e8d88fa660e954d5124..0000000000000000000000000000000000000000 --- a/dinov2/eval/cell_dino/knn.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
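A minimal entry-point sketch for the distributed helpers above (the script name and launch command are illustrative): enable() resolves the environment from Slurm first, then from preset torchrun variables, and otherwise falls back to a single local process.

# launched e.g. with: torchrun --nproc_per_node=8 script.py
import torch
import dinov2.distributed as distributed

def main():
    distributed.enable(overwrite=True)  # also silences print() on non-main ranks
    if distributed.is_main_process():
        print(f"world size: {distributed.get_global_size()}")
    device = torch.device("cuda", distributed.get_local_rank())
    # ... build model and data on this device ...

if __name__ == "__main__":
    main()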
 - -import argparse -from functools import partial -import json -import logging -import os -import sys -from typing import List, Optional, Any -import numpy as np - -import torch -import torch.backends.cudnn as cudnn -import pandas as pd -from sklearn.metrics import f1_score - -import dinov2.distributed as distributed -from dinov2.data import make_dataset, DatasetWithEnumeratedTargets, SamplerType, make_data_loader -from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform -from dinov2.eval.metrics import build_metric, MetricType -from dinov2.eval.setup import get_args_parser as get_setup_args_parser -from dinov2.eval.setup import setup_and_build_model - -from dinov2.data import ResultsAccumulator -from dinov2.eval.utils import ModelWithNormalize -from dinov2.eval.cell_dino.utils import ( - BagOfChannelsModelWithNormalize, - extract_features_cell_dino, - average_metrics, - create_train_dataset_dict, - get_num_classes, - extract_features_for_dataset_dict, - evaluate_with_accumulate, - KnnModule, -) -from dinov2.eval.knn import DictKeysModule -from torch.utils.data import Subset as SubsetEx -from torch.utils.data import ConcatDataset as ConcatDatasetEx - - -logger = logging.getLogger("dinov2") - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parents = parents or [] - setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) - parents = [setup_args_parser] - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--train-dataset", - dest="train_dataset_str", - type=str, - help="Training dataset", - ) - parser.add_argument( - "--val-dataset", - dest="val_dataset_str", - type=str, - help="Validation dataset", - ) - parser.add_argument( - "--nb_knn", - nargs="+", - type=int, - help="Number of nearest neighbors to use; 20 usually works best.", - ) - parser.add_argument( - "--temperature", - type=float, - help="Temperature used in the voting coefficient", - ) - parser.add_argument( - "--gather-on-cpu", - action="store_true", - help="Whether to gather the train features on cpu, slower " - "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", - ) - parser.add_argument( - "--batch-size", - type=int, - help="Batch size.", - ) - parser.add_argument( - "--n-per-class-list", - nargs="+", - type=int, - help="Number of samples to take per class", - ) - parser.add_argument( - "--n-tries", - type=int, - help="Number of tries", - ) - parser.add_argument( - "--leave-one-out-dataset", - type=str, - help="Path with indices for the leave-one-out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4", - ) - parser.add_argument( - "--bag-of-channels", - action="store_true", - help='Whether to use the "bag of channels" channel adaptive strategy', - ) - parser.add_argument( - "--crop-size", - type=int, - help="crop size for train and eval", - ) - parser.add_argument( - "--resize-size", - type=int, - help="resize size for image just before crop; 0: no resize", - ) - parser.add_argument( - "--metric-type", - type=MetricType, - choices=list(MetricType), - help="Validation metric", - ) - parser.add_argument( - "--avgpool", - action="store_true", - help="Whether to use average pooling of patch tokens in addition to CLS tokens", - ) - - parser.set_defaults( - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - nb_knn=[1], - temperature=0.07, - batch_size=256, - resize_size=0, - ) - return parser - - -class SequentialWithKwargs(torch.nn.Sequential): - def __init__(self, *args): - super().__init__(*args) - - def forward(self, input, **kwargs): - input = self[0](input, **kwargs) - for module in self[1:]: - input = module(input) - return input - - -def create_train_test_dataset_dict_leave_one_out( - train_dataset, - test_dataset, -) -> dict[int, Any]: - """ - This function implements a train dataset dictionary with the leave-one-out (LOO) method. - Specifically, given a train dataset and test dataset, it creates a train dataset for each - test dataset point, which is a combination of the train+test dataset except for this specific data point. - In the end, it contains len(test_dataset) key-value pairs. - - Format is {"nth-test-sample": dataset_without_test_sample} - """ - train_dataset_dict: dict[int, Any] = {} - test_size = len(test_dataset) - - for test_sample_index in range(test_size): - test_indices_bool = torch.ones(test_size, dtype=bool) - test_indices_bool[test_sample_index] = False - train_dataset_dict[test_sample_index] = ConcatDatasetEx( - [train_dataset, SubsetEx(test_dataset, test_indices_bool.nonzero().flatten())] - ) - - return train_dataset_dict - - -def eval_knn_with_leave_one_out( - model, leave_one_out_dataset, train_dataset, test_dataset, metric_type, nb_knn, temperature, batch_size, num_workers -): - num_classes = get_num_classes(test_dataset) - train_dataset_dict = create_train_dataset_dict(train_dataset) - test_dataset_dict = create_train_dataset_dict(test_dataset) - - logger.info("Extracting features for train set...") - train_data_dict = extract_features_for_dataset_dict( - model, train_dataset_dict, batch_size, num_workers, gather_on_cpu=True - ) - test_data_dict = extract_features_for_dataset_dict( - model, test_dataset_dict, batch_size, num_workers, gather_on_cpu=True - ) - - train_features = train_data_dict[0]["train_features"] - train_labels = train_data_dict[0]["train_labels"] - test_features = test_data_dict[0]["train_features"] - test_labels = test_data_dict[0]["train_labels"] - - metric_collection = build_metric(metric_type, num_classes=3) # only 3 classes are scored in the leave-one-out tasks (cf. labels [4, 5, 6] below) - - device = torch.cuda.current_device() - partial_knn_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) - - logger.info("Reading the leave-one-out label metadata.") - - leave_one_out_indices = {} - metadata = pd.read_csv(leave_one_out_dataset) - if "HPA" in leave_one_out_dataset: - metadata = metadata[metadata["Task_three"]].reset_index() - leave_one_out_label_type = "cell_type" - else: - metadata = metadata[metadata["Task_four"]].reset_index() - leave_one_out_label_type = "Plate" - leave_one_out_labels = metadata[leave_one_out_label_type].unique() - - for leave_one_out_label in leave_one_out_labels: - leave_one_out_indices[leave_one_out_label] = torch.tensor( - metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values - ) - - # ============ evaluation ... 
============ - logger.info("Start the k-NN classification.") - - eval_metrics_dict = {} - postprocessors, metrics = {k: DictKeysModule([k]) for k in nb_knn}, { - k: metric_collection.clone().to(device) for k in nb_knn - } - - accumulator_class = ResultsAccumulator - accumulators = {k: accumulator_class() for k in postprocessors.keys()} - all_preds = [] - all_target = [] - - for loo_label, loo_indices in leave_one_out_indices.items(): - logger.info(f"Evaluating on test sample {loo_label}") - loo_for_training_indices = torch.ones(test_features.shape[0], dtype=bool) - loo_for_training_indices[loo_indices] = False - train_features_sample = torch.cat([train_features, test_features[loo_for_training_indices]]) - train_labels_sample = torch.cat([train_labels, test_labels[loo_for_training_indices]]) - logger.info(f"Train shape {train_features_sample.shape}, Test shape {test_features[loo_indices].shape}") - logger.info( - f"Train values {train_labels_sample.unique(return_counts=True)}, Test values {test_labels[loo_indices].unique(return_counts=True)}" - ) - knn_module = partial_knn_module( - train_features=train_features_sample, train_labels=train_labels_sample, nb_knn=nb_knn - ) - - output = knn_module(test_features[loo_indices].to(device)) - all_preds.append(output[1]) - all_target.append(test_labels[loo_indices]) - # keep only the held-out classes (labels 4-6) and shift the targets accordingly - output[1] = output[1][:, 4:] - transformed_test_labels = test_labels[loo_indices] - 4 - for k, metric in metrics.items(): - metric_inputs = postprocessors[k](output, transformed_test_labels.to(device)) - metric.update(**metric_inputs) - accumulators[k].update( - preds=metric_inputs["preds"], target=metric_inputs["target"], index=loo_indices.to(device) - ) - - all_preds = torch.cat(all_preds).cpu().detach().numpy() - - all_preds = np.argmax(all_preds, axis=1) - all_target = torch.cat(all_target).cpu().detach().numpy() - - f1 = f1_score(all_target, all_preds, average="macro", labels=[4, 5, 6]) - logger.info(f"Real f1 score: {f1}") - eval_metrics = { - k: metric.compute() for k, metric in metrics.items() - } # the values computed here are replaced below by the macro F1 computed above - - for k in nb_knn: - if k not in eval_metrics_dict: - eval_metrics_dict[k] = {} - eval_metrics_dict[k] = {metric: f1 * 100.0 for metric in eval_metrics[k]} - - if len(train_data_dict) > 1: - return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()} - - return {k: eval_metrics_dict[k] for k in eval_metrics_dict.keys()} - - -def eval_knn_with_model( - model, - output_dir, - train_dataset_str, - val_dataset_str, - nb_knn=(10, 20, 100, 200), - temperature=0.07, - autocast_dtype=torch.float, - metric_type=MetricType.MEAN_ACCURACY, - transform=None, - resize_size=256, - crop_size=224, - batch_size=256, - num_workers=5, - leave_one_out_dataset="", - bag_of_channels=False, - avgpool=False, -): - autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) - if bag_of_channels: - model = BagOfChannelsModelWithNormalize(model, autocast_ctx, avgpool) - else: - model = ModelWithNormalize(model) - leave_one_out = bool(leave_one_out_dataset) - - cudnn.benchmark = True - # NOTE: the `transform` argument is ignored; the cell eval transform below is always used - transform = make_classification_eval_cell_transform( - normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, resize_size=resize_size, crop_size=crop_size - ) - - train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform)
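- # Editorial note (not in the original source): the same SELF_NORM_CENTER_CROP eval transform is applied to both the train and validation datasets, so train and test features are extracted under identical normalization before kNN voting.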
- results_dict = {} - test_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform) - - with torch.cuda.amp.autocast(dtype=autocast_dtype): - if leave_one_out: - results_dict_knn = eval_knn_with_leave_one_out( - model=model, - leave_one_out_dataset=leave_one_out_dataset, - train_dataset=train_dataset, - test_dataset=test_dataset, - metric_type=metric_type, - nb_knn=nb_knn, - temperature=temperature, - batch_size=batch_size, - num_workers=num_workers, - ) - else: - results_dict_knn = eval_knn( - model=model, - train_dataset=train_dataset, - test_dataset=test_dataset, - metric_type=metric_type, - nb_knn=nb_knn, - temperature=temperature, - batch_size=batch_size, - num_workers=num_workers, - ) - - for knn_ in results_dict_knn.keys(): - top1 = results_dict_knn[knn_]["top-1"] - results_dict[f"{val_dataset_str}_{knn_} Top 1"] = top1 - results_string = f"{val_dataset_str} {knn_} NN classifier result: Top1: {top1:.2f}" - if "top-5" in results_dict_knn[knn_]: - top5 = results_dict_knn[knn_]["top-5"] - results_dict[f"{val_dataset_str}_{knn_} Top 5"] = top5 - results_string += f", Top5: {top5:.2f}" - logger.info(results_string) - - metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") - with open(metrics_file_path, "a") as f: - for k, v in results_dict.items(): - f.write(json.dumps({k: v}) + "\n") - - if distributed.is_enabled(): - torch.distributed.barrier() - return results_dict - - -def eval_knn( - model, - train_dataset, - test_dataset, - metric_type, - nb_knn, - temperature, - batch_size, - num_workers, - few_shot_eval=False, - few_shot_k_or_percent=None, - few_shot_n_tries=1, -): - num_classes = get_num_classes(train_dataset) - train_dataset_dict = create_train_dataset_dict( - train_dataset, - few_shot_eval=few_shot_eval, - few_shot_k_or_percent=few_shot_k_or_percent, - few_shot_n_tries=few_shot_n_tries, - ) - - logger.info("Extracting features for train set...") - - train_data_dict: dict[int, dict[str, torch.Tensor]] = {} - for try_n, dataset in train_dataset_dict.items(): - features, labels = extract_features_cell_dino(model, dataset, batch_size, num_workers, gather_on_cpu=True) - train_data_dict[try_n] = {"train_features": features, "train_labels": labels} - - test_data_loader = make_data_loader( - dataset=DatasetWithEnumeratedTargets( - test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size() - ), - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=False, - persistent_workers=True, - collate_fn=None, - ) - metric_collection = build_metric(metric_type, num_classes=num_classes) - - device = torch.cuda.current_device() - partial_knn_module = partial( - KnnModule, - T=temperature, - device=device, - num_classes=num_classes, - ) - - # ============ evaluation ... 
============ - logger.info("Start the k-NN classification.") - eval_metrics_dict = {} - - for try_ in train_data_dict.keys(): - train_features, train_labels = train_data_dict[try_]["train_features"], train_data_dict[try_]["train_labels"] - k_list = sorted(set([el if el < len(train_features) else len(train_features) for el in nb_knn])) - knn_module = partial_knn_module(train_features=train_features, train_labels=train_labels, nb_knn=k_list) - postprocessors, metrics = {k: DictKeysModule([k]) for k in k_list}, { - k: metric_collection.clone() for k in k_list - } - _, eval_metrics, _ = evaluate_with_accumulate( - SequentialWithKwargs(model, knn_module), - test_data_loader, - postprocessors, - metrics, - device, - accumulate_results=False, - ) - for k in k_list: - if k not in eval_metrics_dict: - eval_metrics_dict[k] = {} - eval_metrics_dict[k][try_] = {metric: v.item() * 100.0 for metric, v in eval_metrics[k].items()} - - if len(train_data_dict) > 1: - return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()} - - return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()} - - -def main(args): - model, autocast_dtype = setup_and_build_model(args) - eval_knn_with_model( - model=model, - output_dir=args.output_dir, - train_dataset_str=args.train_dataset_str, - val_dataset_str=args.val_dataset_str, - nb_knn=args.nb_knn, - temperature=args.temperature, - autocast_dtype=autocast_dtype, - transform=None, - metric_type=args.metric_type, - batch_size=args.batch_size, - num_workers=5, - leave_one_out_dataset=args.leave_one_out_dataset, - resize_size=args.resize_size, - crop_size=args.crop_size, - avgpool=args.avgpool, - bag_of_channels=args.bag_of_channels, - ) - return 0 - - -if __name__ == "__main__": - description = "k-NN evaluation on models trained with bag of channel strategy or cell dino" - args_parser = get_args_parser(description=description) - args = args_parser.parse_args() - sys.exit(main(args)) diff --git a/dinov2/eval/cell_dino/linear.py b/dinov2/eval/cell_dino/linear.py deleted file mode 100644 index efbf838a87af6af2cc6a749219f53c2e0a65d080..0000000000000000000000000000000000000000 --- a/dinov2/eval/cell_dino/linear.py +++ /dev/null @@ -1,1048 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
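[Editor's note] The deleted knn.py above combines softmax-weighted kNN voting (KnnModule) with a per-group leave-one-out loop. The following stand-alone sketch illustrates that logic under stated assumptions: all names (knn_predict, the toy features and groups) are hypothetical, features are assumed L2-normalized as ModelWithNormalize produces, and the distributed gather logic is omitted.

import torch
from sklearn.metrics import f1_score

def knn_predict(train_x, train_y, test_x, k=5, T=0.07, num_classes=7):
    sims = test_x @ train_x.T                               # cosine similarity on L2-normalized features
    topk_sims, idx = sims.topk(k, dim=1)                    # k nearest train samples per test sample
    weights = (topk_sims / T).softmax(dim=1).unsqueeze(-1)  # temperature-scaled voting weights [B, k, 1]
    votes = torch.nn.functional.one_hot(train_y[idx], num_classes).float()
    return (votes * weights).sum(dim=1)                     # class scores [B, num_classes]

# toy leave-one-out loop over label groups, mimicking the per-plate / per-cell-type loop above
torch.manual_seed(0)
feats = torch.nn.functional.normalize(torch.randn(60, 16), dim=1)
labels = torch.randint(0, 7, (60,))
groups = {g: (labels % 3 == g).nonzero().flatten() for g in range(3)}  # stand-in for plates / cell types
preds, targets = [], []
for g, held_out in groups.items():
    keep = torch.ones(60, dtype=torch.bool)
    keep[held_out] = False                                  # train on everything except the held-out group
    scores = knn_predict(feats[keep], labels[keep], feats[held_out])
    preds.append(scores.argmax(dim=1))
    targets.append(labels[held_out])
print(f1_score(torch.cat(targets).numpy(), torch.cat(preds).numpy(), average="macro"))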
- -import argparse -from functools import partial -import json -import logging -import os -import sys -from typing import Any, Callable, Dict, Optional, Tuple, List -from enum import Enum -from dataclasses import dataclass - -from sklearn.metrics import f1_score -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from torch.utils.data import TensorDataset -from torch.nn.parallel import DistributedDataParallel - - -from dinov2.data import SamplerType, make_data_loader, make_dataset, DatasetWithEnumeratedTargets -from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform -import dinov2.distributed as distributed -from dinov2.eval.metrics import MetricType, build_metric -from dinov2.eval.setup import get_args_parser as get_setup_args_parser -from dinov2.eval.setup import setup_and_build_model -from dinov2.eval.cell_dino.utils import ( - evaluate_with_accumulate, - LossType, - average_metrics, - create_train_dataset_dict, - get_num_classes, - extract_features_for_dataset_dict, -) -from dinov2.eval.utils import ModelWithIntermediateLayers -from dinov2.logging import MetricLogger -from dinov2.utils.checkpoint import build_periodic_checkpointer, resume_or_load - -logger = logging.getLogger("dinov2") - -""" -List of changes with respect to the standard linear evaluation script: - -- "bag of channels" option (channel-adaptive strategy) -- AdamW optimizer instead of SGD -- scheduler: two options, OneCycleLR or CosineAnnealingLR -- different transforms/normalization, now calling make_classification_eval_cell_transform -- binary cross-entropy loss option for protein localization -- num_classes now derived using get_num_classes -- some default parameters changed (batch_size, epoch_length, epochs, learning rates) -- n_last_blocks option -- avgpool option -- leave-one-out strategy for the CHAMMI evaluation -- grid search for the optimal weight decay -""" - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parents = parents or [] - setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) - parents = [setup_args_parser] - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--train-dataset", - dest="train_dataset_str", - type=str, - help="Training dataset", - ) - parser.add_argument( - "--val-dataset", - dest="val_dataset_str", - type=str, - help="Validation dataset", - ) - parser.add_argument( - "--test-datasets", - dest="test_dataset_strs", - type=str, - nargs="+", - help="Test datasets, none to reuse the validation dataset", - ) - parser.add_argument( - "--epochs", - type=int, - help="Number of training epochs", - ) - parser.add_argument( - "--batch-size", - type=int, - help="Batch size (per GPU)", - ) - parser.add_argument( - "--num-workers", - type=int, - help="Number of workers", - ) - parser.add_argument( - "--epoch-length", - type=int, - help="Length of an epoch in number of iterations", - ) - parser.add_argument( - "--save-checkpoint-frequency", - type=int, - help="Number of epochs between two named checkpoint saves.", - ) - parser.add_argument( - "--eval-period-iterations", - type=int, - help="Number of iterations between two evaluations.", - ) - parser.add_argument( - "--learning-rates", - nargs="+", - type=float, - help="Learning rates to grid search.", - ) - parser.add_argument( - "--weight_decays", - nargs="+", - type=float, - 
help="Weight decays to grid search.", - ) - parser.add_argument( - "--n-last-blocks", - type=int, - help="number of backbone last blocks used for the linear classifier", - ) - parser.add_argument( - "--no-resume", - action="store_true", - help="Whether to not resume from existing checkpoints", - ) - parser.add_argument( - "--val-metric-type", - type=MetricType, - choices=list(MetricType), - help="Validation metric", - ) - parser.add_argument( - "--test-metric-types", - type=MetricType, - choices=list(MetricType), - nargs="+", - help="Evaluation metric", - ) - parser.add_argument( - "--classifier-fpath", - type=str, - help="Path to a file containing pretrained linear classifiers", - ) - parser.add_argument( - "--val-class-mapping-fpath", - type=str, - help="Path to a file containing a mapping to adjust classifier outputs", - ) - parser.add_argument( - "--test-class-mapping-fpaths", - nargs="+", - type=str, - help="Path to a file containing a mapping to adjust classifier outputs", - ) - parser.add_argument( - "--loss-type", - type=LossType, - help="Cross Entropy or Binary Cross Entropy, default cross entropy loss", - ) - parser.add_argument( - "--bag-of-channels", - action="store_true", - help='Whether to use the "bag of channels" channel adaptive strategy', - ) - parser.add_argument( - "--leave-one-out-dataset", - type=str, - help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4", - ) - parser.add_argument( - "--crop-size", - type=int, - help="crop size for train and eval", - ) - parser.add_argument( - "--resize-size", - type=int, - help="resize size for image just before crop. 0: no resize", - ) - parser.add_argument( - "--avgpool", - action="store_true", - help="Whether to use average pooling of path tokens in addition to CLS tokens", - ) - parser.add_argument( - "--scheduler", - type=SchedulerType, - help="Scheduler type", - ) - - parser.set_defaults( - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - test_dataset_strs=None, - epochs=30, - batch_size=64, - num_workers=8, - epoch_length=145, - save_checkpoint_frequency=1250, - eval_period_iterations=1250, - learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, 1.0], - weight_decays=[0.0, 0.0001, 1.0e-05], - val_metric_type=MetricType.MEAN_ACCURACY, - test_metric_types=None, - classifier_fpath=None, - val_class_mapping_fpath=None, - test_class_mapping_fpaths=[None], - loss_type=LossType.CROSS_ENTROPY, - crop_size=384, - resize_size=0, - n_last_blocks=4, - avgpool=False, - scheduler=SchedulerType.COSINE_ANNEALING, - ) - return parser - - -def has_ddp_wrapper(m: nn.Module) -> bool: - return isinstance(m, DistributedDataParallel) - - -def remove_ddp_wrapper(m: nn.Module) -> nn.Module: - return m.module if has_ddp_wrapper(m) else m - - -def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool, bag_of_channels): - intermediate_output = x_tokens_list[-use_n_blocks:] - output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) - if bag_of_channels: - if use_avgpool: - output = torch.cat( - ( - output, - torch.mean(intermediate_output[-1][0], dim=-2).reshape(intermediate_output[-1][0].shape[0], -1), - # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model - ), - dim=-1, - ) # concatenate average pooling of patch tokens to concatenated patch tokens - else: - if use_avgpool: - output = torch.cat( - ( - output, - 
torch.mean(intermediate_output[-1][0], dim=1), # patch tokens - ), - dim=-1, - ) - output = output.reshape(output.shape[0], -1) - return output.float() - - -class LinearClassifier(nn.Module): - """Linear layer to train on top of frozen features""" - - def __init__( - self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, bag_of_channels=False, leave_one_out=False - ): - super().__init__() - self.out_dim = out_dim - self.use_n_blocks = use_n_blocks - self.use_avgpool = use_avgpool - self.num_classes = num_classes - self.bag_of_channels = bag_of_channels - self.leave_one_out = leave_one_out - self.linear = nn.Linear(out_dim, num_classes) - self.linear.weight.data.normal_(mean=0.0, std=0.01) - self.linear.bias.data.zero_() - - def forward(self, x_tokens_list): - if self.leave_one_out: - return self.linear(x_tokens_list) - output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool, self.bag_of_channels) - return self.linear(output) - - -class AllClassifiers(nn.Module): - def __init__(self, classifiers_dict): - super().__init__() - self.classifiers_dict = nn.ModuleDict() - self.classifiers_dict.update(classifiers_dict) - - def forward(self, inputs): - return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} - - def __len__(self): - return len(self.classifiers_dict) - - -class LinearPostprocessor(nn.Module): - def __init__(self, linear_classifier, class_mapping=None): - super().__init__() - self.linear_classifier = linear_classifier - self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) - - def forward(self, samples, targets): - preds = self.linear_classifier(samples) - return { - "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, - "target": targets, - } - - -def scale_lr(learning_rates, batch_size): - return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 - - -def setup_linear_classifiers( - sample_output, - n_last_blocks_list, - learning_rates, - weight_decays, - batch_size, - num_classes=1000, - bag_of_channels=False, - leave_one_out=False, - avgpool=False, -): - linear_classifiers_dict = nn.ModuleDict() - avgpool_value = avgpool - optim_param_groups = [] - for n in n_last_blocks_list: - for avgpool in [avgpool_value]: - for _lr in learning_rates: - for wd in weight_decays: - lr = scale_lr(_lr, batch_size) - out_dim = create_linear_input( - sample_output, use_n_blocks=n, use_avgpool=avgpool, bag_of_channels=bag_of_channels - ).shape[1] - linear_classifier = LinearClassifier( - out_dim, - use_n_blocks=n, - use_avgpool=avgpool, - num_classes=num_classes, - bag_of_channels=bag_of_channels, - leave_one_out=leave_one_out, - ) - linear_classifier = linear_classifier.cuda() - linear_classifiers_dict[ - f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}_wd_{wd:.2E}".replace(".", "_") - ] = linear_classifier - optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr, "weight_decay": wd}) - - linear_classifiers = AllClassifiers(linear_classifiers_dict) - if distributed.is_enabled(): - linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) - - return linear_classifiers, optim_param_groups - - -def make_eval_data_loader( - *, - test_dataset_str_or_path_or_loo_dataset, - config, - batch_size, - num_workers, -): - if isinstance(test_dataset_str_or_path_or_loo_dataset, str): - logger.info(f"Loading dataset {test_dataset_str_or_path_or_loo_dataset}") - transform = make_classification_eval_cell_transform( - 
normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, - resize_size=config["resize_size"], - crop_size=config["crop_size"], - ) - test_dataset = make_dataset(dataset_str=test_dataset_str_or_path_or_loo_dataset, transform=transform) - collate_fn = None - else: - logger.info("Making data loader for feature dataset (typical in leave one out evaluation)") - test_dataset = test_dataset_str_or_path_or_loo_dataset - collate_fn = None - class_mapping = None - if hasattr(test_dataset, "get_imagenet_class_mapping"): - class_mapping = test_dataset.get_imagenet_class_mapping() - - test_data_loader = make_data_loader( - dataset=DatasetWithEnumeratedTargets( - test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size() - ), - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=False, - persistent_workers=False, - collate_fn=collate_fn, - ) - return test_data_loader, class_mapping - - -@dataclass -class Evaluator: - batch_size: int - num_workers: int - dataset_str_or_path: str - config: Dict - metric_type: MetricType - metrics_file_path: str - training_num_classes: int - save_results_func: Optional[Callable] - val_dataset_loo: Optional[TensorDataset] = None - - def __post_init__(self): - self.main_metric_name = f"{self.dataset_str_or_path}_accuracy" - - if self.val_dataset_loo is not None: - self.dataset_str_or_path = self.val_dataset_loo - - self.data_loader, self.class_mapping = make_eval_data_loader( - test_dataset_str_or_path_or_loo_dataset=self.dataset_str_or_path, - batch_size=self.batch_size, - num_workers=self.num_workers, - config=self.config, - ) - - @torch.no_grad() - def _evaluate_linear_classifiers( - self, - *, - feature_model, - linear_classifiers, - iteration, - prefixstring="", - best_classifier_on_val=None, - accumulate_results=False, - test_mode=False, - ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]: - logger.info("running validation !") - - num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes - metric = build_metric(self.metric_type, num_classes=num_classes) - postprocessors = { - k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items() - } - metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} - - _, results_dict_temp, accumulated_results = evaluate_with_accumulate( - feature_model, - self.data_loader, - postprocessors, - metrics, - torch.cuda.current_device(), - accumulate_results=accumulate_results, - leave_one_out=self.config["leave_one_out"], - test_mode=test_mode, - ) - - logger.info("") - results_dict = {} - max_accuracy = 0 - best_classifier = "" - for _, (classifier_string, metric) in enumerate(results_dict_temp.items()): - logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") - if ( - best_classifier_on_val is None and metric["top-1"].item() > max_accuracy - ) or classifier_string == best_classifier_on_val: - max_accuracy = metric["top-1"].item() - best_classifier = classifier_string - - results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} - - logger.info(f"best classifier: {results_dict['best_classifier']}") - - accumulated_best_results = None - if test_mode: - accumulated_best_results = accumulated_results - elif accumulated_results is not None: - accumulated_best_results = accumulated_results[best_classifier] - - if distributed.is_main_process(): - with open(self.metrics_file_path, "a") as f: - 
f.write(f"iter: {iteration}\n") - for k, v in results_dict.items(): - f.write(json.dumps({k: v}) + "\n") - f.write("\n") - - return results_dict, accumulated_best_results - - def evaluate_and_maybe_save( - self, - feature_model, - linear_classifiers, - iteration: int, - best_classifier_on_val: Optional[Any] = None, - save_filename_suffix: str = "", - prefixstring: str = "", - test_mode: bool = False, - ): - logger.info(f"Testing on {self.dataset_str_or_path}") - save_results = self.save_results_func is not None - full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers( - feature_model=feature_model, - linear_classifiers=remove_ddp_wrapper(linear_classifiers), - iteration=iteration, - prefixstring=prefixstring, - best_classifier_on_val=best_classifier_on_val, - accumulate_results=save_results, - test_mode=test_mode, - ) - if self.save_results_func is not None: - self.save_results_func( - filename_suffix=f"{self.dataset_str_or_path}{save_filename_suffix}", **accumulated_best_results - ) - - results_dict = { - self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"], - "best_classifier": full_results_dict["best_classifier"]["name"], - } - return results_dict, accumulated_best_results - - -def make_evaluators( - config: Dict, - val_metric_type: MetricType, - val_dataset: str, - metric_type: MetricType, - metrics_file_path: str, - training_num_classes: int, - save_results_func: Optional[Callable], - val_dataset_loo: Optional[TensorDataset] = None, -): - test_metric_types = config["test_metric_types"] - test_dataset_strs = config["test_datasets"] - if test_dataset_strs is None: - test_dataset_strs = (config["val_dataset"],) - if test_metric_types is None: - test_metric_types = (val_metric_type,) - else: - assert len(test_metric_types) == len(config["test_datasets"]) - - val_evaluator, *test_evaluators = [ - Evaluator( - dataset_str_or_path=dataset_str_or_path, - batch_size=config["batch_size"], - num_workers=config["num_workers"], - config=config, - metric_type=metric_type, - metrics_file_path=metrics_file_path, - training_num_classes=training_num_classes, - save_results_func=save_results_func, - val_dataset_loo=val_dataset_loo, - ) - for dataset_str_or_path, metric_type in zip( - (val_dataset,) + tuple(test_dataset_strs), - (val_metric_type,) + tuple(test_metric_types), - ) - ] - return val_evaluator, test_evaluators - - -class SchedulerType(Enum): - COSINE_ANNEALING = "cosine_annealing" - ONE_CYCLE = "one_cycle" - - def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter): - if self == SchedulerType.ONE_CYCLE: - lr_list = [optim_param_groups[i]["lr"] for i in range(len(optim_param_groups))] - scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer, max_lr=lr_list, steps_per_epoch=epoch_length, epochs=epochs - ) - else: - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) - print("CosineAnnealingLR scheduler") - return scheduler - - -def setup_linear_training( - *, - config: Dict, - sample_output: torch.Tensor, - training_num_classes: int, - checkpoint_output_dir: str, -): - linear_classifiers, optim_param_groups = setup_linear_classifiers( - sample_output, - config["n_last_blocks_list"], - config["learning_rates"], - config["weight_decays"], - config["batch_size"], - training_num_classes, - config["bag_of_channels"], - config["leave_one_out"], - config["avgpool"], - ) - max_iter = config["epochs"] * config["epoch_length"] - optimizer = torch.optim.AdamW(optim_param_groups, 
weight_decay=0) - - scheduler = config["scheduler"].get_scheduler( - optimizer=optimizer, - optim_param_groups=optim_param_groups, - epoch_length=config["epoch_length"], - epochs=config["epochs"], - max_iter=max_iter, - ) - checkpoint_period = config["save_checkpoint_iterations"] or config["epoch_length"] - periodic_checkpointer = build_periodic_checkpointer( - linear_classifiers, - checkpoint_output_dir, - optimizer=optimizer, - scheduler=scheduler, - period=checkpoint_period, - max_iter=max_iter, - max_to_keep=None, - ) - checkpoint = resume_or_load(periodic_checkpointer, config["classifier_fpath"] or "", resume=config["resume"]) - - start_iter = checkpoint.get("iteration", -1) + 1 - best_accuracy = checkpoint.get("best_accuracy", -1) - - if config["loss_type"] == LossType.BINARY_CROSS_ENTROPY: - criterion = nn.BCEWithLogitsLoss() - else: - criterion = nn.CrossEntropyLoss() - - return ( - linear_classifiers, - start_iter, - max_iter, - criterion, - optimizer, - scheduler, - periodic_checkpointer, - best_accuracy, - ) - - -def train_linear_classifiers( - *, - feature_model, - train_dataset, - train_config: Dict, - training_num_classes: int, - val_evaluator: Evaluator, - checkpoint_output_dir: str, - sample_output: Optional[torch.Tensor] = None, -): - - if train_config["leave_one_out"]: - assert sample_output is not None, "sample_output should be passed as argument when using leave_one_out." - else: - sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) - - ( - linear_classifiers, - start_iter, - max_iter, - criterion, - optimizer, - scheduler, - periodic_checkpointer, - best_accuracy, - ) = setup_linear_training( - config=train_config, - sample_output=sample_output, - training_num_classes=training_num_classes, - checkpoint_output_dir=checkpoint_output_dir, - ) - - sampler_type = SamplerType.INFINITE - train_data_loader = make_data_loader( - dataset=train_dataset, - batch_size=train_config["batch_size"], - num_workers=train_config["num_workers"], - shuffle=True, - seed=0, - sampler_type=sampler_type, - sampler_advance=start_iter, - drop_last=True, - persistent_workers=True, - ) - eval_period = train_config["eval_period_iterations"] or train_config["epoch_length"] - iteration = start_iter - logger.info("Starting training from iteration {}".format(start_iter)) - metric_logger = MetricLogger(delimiter=" ") - header = "Training" - - for data, labels in metric_logger.log_every( - train_data_loader, - 10, - header, - max_iter, - start_iter, - ): - data = data.cuda(non_blocking=True) - labels = labels.cuda(non_blocking=True) - - if not train_config["leave_one_out"]: - in_classifier = feature_model(data) - else: - in_classifier = data - - outputs = linear_classifiers(in_classifier) - - if len(labels.shape) > 1: - labels = labels.float() - losses = {f"loss_{k}": criterion(v, labels) for k, v in outputs.items()} - loss = sum(losses.values()) - - optimizer.zero_grad() - loss.backward() - - optimizer.step() - scheduler.step() - - if iteration % 10 == 0: - torch.cuda.synchronize() - metric_logger.update(loss=loss.item()) - metric_logger.update(lr=optimizer.param_groups[0]["lr"]) - - periodic_checkpointer.step(iteration=iteration, best_accuracy=best_accuracy) - - if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: - val_results_dict, _ = val_evaluator.evaluate_and_maybe_save( - feature_model=feature_model, - linear_classifiers=linear_classifiers, - prefixstring=f"ITER: {iteration}", - iteration=iteration, - ) - val_accuracy = 
val_results_dict[val_evaluator.main_metric_name] - if val_accuracy >= best_accuracy: - best_accuracy = val_accuracy - periodic_checkpointer.save_best(iteration=iteration, best_accuracy=best_accuracy) - torch.distributed.barrier() - - iteration = iteration + 1 - - return feature_model, linear_classifiers, iteration, periodic_checkpointer - - -def eval_linear_with_model( - model, - output_dir, - train_dataset_str, - val_dataset_str, - batch_size, - epochs, - epoch_length, - num_workers, - save_checkpoint_frequency, - eval_period_iterations, - learning_rates, - weight_decays, - autocast_dtype, - test_dataset_strs=None, - resume=True, - classifier_fpath=None, - val_metric_type=MetricType.MEAN_ACCURACY, - test_metric_types=None, - loss_type=LossType.CROSS_ENTROPY, - bag_of_channels=False, - leave_one_out_dataset="", - resize_size=0, - crop_size=384, - n_last_blocks=4, - avgpool=False, - scheduler=SchedulerType.COSINE_ANNEALING, -): - - if leave_one_out_dataset == "" or leave_one_out_dataset is None: - leave_one_out = False - else: - logger.info("Reading the leave-one-out label metadata.") - - leave_one_out_indices = {} - metadata = pd.read_csv(leave_one_out_dataset) - if "HPA" in leave_one_out_dataset: - metadata = metadata[metadata["Task_three"]].reset_index() - leave_one_out_label_type = "cell_type" - else: - metadata = metadata[metadata["Task_four"]].reset_index() - leave_one_out_label_type = "Plate" - leave_one_out_labels = metadata[leave_one_out_label_type].unique() - - for leave_one_out_label in leave_one_out_labels: - leave_one_out_indices[leave_one_out_label] = np.array( - metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values - ) - - leave_one_out = True - - train_transform = make_classification_eval_cell_transform( - normalization_type=NormalizationType.SELF_NORM_AUG_DECODER, crop_size=crop_size, resize_size=resize_size - ) - logger.info(f"train_transform: {train_transform}") - train_dataset = make_dataset( - dataset_str=train_dataset_str, - transform=train_transform, - ) - - training_num_classes = get_num_classes(train_dataset) - if leave_one_out: - training_num_classes += train_dataset.num_additional_labels_loo_eval - train_dataset_dict = create_train_dataset_dict(train_dataset) - n_last_blocks_list = [n_last_blocks] - n_last_blocks = max(n_last_blocks_list) - dataset_use_cache = True - autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) - feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) - - if bag_of_channels: - sample = train_dataset[0][0].unsqueeze(0) - sample_output = feature_model(sample.cuda()) - # NOTE: sample_output is only defined when bag_of_channels is set; the leave-one-out - # branch below assumes this, since it later passes sample_output to train_linear_classifiers - - if leave_one_out: - loo_dict = {} - train_data_dict = extract_features_for_dataset_dict( - feature_model, - train_dataset_dict, - batch_size, - num_workers, - gather_on_cpu=True, - avgpool=avgpool, - ) - val_dataset = make_dataset( - dataset_str=val_dataset_str, - transform=make_classification_eval_cell_transform( - normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, crop_size=crop_size, resize_size=resize_size - ), - ) - val_dataset_dict = create_train_dataset_dict(val_dataset) - val_data_dict = extract_features_for_dataset_dict( - feature_model, - val_dataset_dict, - batch_size, - num_workers, - gather_on_cpu=True, - avgpool=avgpool, - ) - - train_features = train_data_dict[0]["train_features"] - train_labels = train_data_dict[0]["train_labels"] - val_features = val_data_dict[0]["train_features"] - val_labels = val_data_dict[0]["train_labels"] - - for loo_label, loo_indices in 
leave_one_out_indices.items(): - loo_for_training_indices = torch.ones(val_features.shape[0], dtype=bool) - loo_for_training_indices[loo_indices] = False - loo_for_val_indices = torch.zeros(val_features.shape[0], dtype=bool) - loo_for_val_indices[loo_indices] = True - - loo_dict[loo_label] = { - "train_features": torch.cat([train_features, val_features[loo_for_training_indices]]), - "train_labels": torch.cat([train_labels, val_labels[loo_for_training_indices]]), - "val_features": val_features[loo_indices], - "val_labels": val_labels[loo_indices], - } - save_results_func = None - # if config.save_results: - # save_results_func = partial(default_save_results_func, output_dir=output_dir) - - metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") - periodic_checkpointers: list = [] - - train_config = { - "learning_rates": learning_rates, - "weight_decays": weight_decays, - "batch_size": batch_size, - "num_workers": num_workers, - "dataset_use_cache": dataset_use_cache, - "eval_period_iterations": eval_period_iterations, - "epoch_length": epoch_length, - "leave_one_out": leave_one_out, - "bag_of_channels": bag_of_channels, - "n_last_blocks_list": n_last_blocks_list, - "epochs": epochs, - "loss_type": loss_type, - "resume": resume, - "save_checkpoint_iterations": save_checkpoint_frequency, - "classifier_fpath": classifier_fpath, - "avgpool": avgpool, - "scheduler": scheduler, - } - config = { - "test_metric_types": test_metric_types, - "test_datasets": test_dataset_strs, - "val_metric_types": val_metric_type, - "val_dataset": val_dataset_str, - "batch_size": batch_size, - "num_workers": num_workers, - "leave_one_out": leave_one_out, - "crop_size": crop_size, - "resize_size": resize_size, - } - if not leave_one_out: - val_evaluator, test_evaluators = make_evaluators( - config=config, - val_metric_type=val_metric_type, - val_dataset=val_dataset_str, - metric_type=test_metric_types, - metrics_file_path=metrics_file_path, - training_num_classes=training_num_classes, - save_results_func=save_results_func, - ) - results_dict = {} - - for _try in train_dataset_dict.keys(): - if len(train_dataset_dict) > 1: - checkpoint_output_dir = os.path.join(output_dir, f"checkpoints_{_try}") - save_filename_suffix = f"_{_try}" - else: - checkpoint_output_dir, save_filename_suffix = output_dir, "" - os.makedirs(checkpoint_output_dir, exist_ok=True) - - feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers( - train_config=train_config, - feature_model=feature_model, - train_dataset=train_dataset_dict[_try], - training_num_classes=training_num_classes, - val_evaluator=val_evaluator, - checkpoint_output_dir=checkpoint_output_dir, - ) - periodic_checkpointers.append(periodic_checkpointer) - results_dict[_try], _ = val_evaluator.evaluate_and_maybe_save( - feature_model=feature_model, - linear_classifiers=linear_classifiers, - iteration=iteration, - save_filename_suffix=save_filename_suffix, - ) - for test_evaluator in test_evaluators: - eval_results_dict, _ = test_evaluator.evaluate_and_maybe_save( - feature_model=feature_model, - linear_classifiers=linear_classifiers, - iteration=iteration, - best_classifier_on_val=results_dict[_try]["best_classifier"], - save_filename_suffix=save_filename_suffix, - ) - results_dict[_try] = {**eval_results_dict, **results_dict[_try]} - if len(train_dataset_dict) > 1: - results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"]) - else: - results_dict = {**results_dict[_try]} - else: # if leave one out is True - 
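# Editorial note (not in the original source): each held-out group below trains its own set of linear classifiers on precomputed features; predictions are then concatenated across groups and scored with a single macro F1 over the held-out labels [4, 5, 6], mirroring the leave-one-out kNN path in knn.py. - 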
test_results_dict = {} - for loo_label in loo_dict.keys(): - - checkpoint_output_dir, save_filename_suffix = os.path.join(output_dir, f"checkpoints_{loo_label}"), "" - os.makedirs(checkpoint_output_dir, exist_ok=True) - - train_dataset_loo = TensorDataset( - loo_dict[loo_label]["train_features"], loo_dict[loo_label]["train_labels"] - ) - - logger.info(f"Creating leave_one_out evaluators. loo_label: {loo_label}") - val_dataset_loo = TensorDataset(loo_dict[loo_label]["val_features"], loo_dict[loo_label]["val_labels"]) - val_evaluators_loo, _ = make_evaluators( - config=config, - val_metric_type=val_metric_type, - val_dataset="loo", - metric_type=test_metric_types, - metrics_file_path=metrics_file_path, - training_num_classes=training_num_classes, - save_results_func=save_results_func, - val_dataset_loo=val_dataset_loo, - ) - feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers( - feature_model=feature_model, - train_dataset=train_dataset_loo, - train_config=train_config, - training_num_classes=training_num_classes, - val_evaluator=val_evaluators_loo, - checkpoint_output_dir=checkpoint_output_dir, - sample_output=sample_output, - ) - periodic_checkpointers.append(periodic_checkpointer) - _, test_results_dict[loo_label] = val_evaluators_loo.evaluate_and_maybe_save( - feature_model=feature_model, - linear_classifiers=linear_classifiers, - iteration=iteration, - save_filename_suffix=save_filename_suffix, - test_mode=True, - ) - classifier_names = test_results_dict[loo_label].keys() - results_dict = {k: [[], []] for k in classifier_names} - for ll in test_results_dict.keys(): - for k in classifier_names: - results_dict[k][0].append(test_results_dict[ll][k][0]) - results_dict[k][1].append(test_results_dict[ll][k][1]) - for k in classifier_names: - results_dict[k] = [ - np.argmax(torch.cat(results_dict[k][0]).cpu().detach().numpy(), axis=1), - torch.cat(results_dict[k][1]).cpu().detach().numpy(), - ] - results_dict[k] = f1_score(results_dict[k][1], results_dict[k][0], average="macro", labels=[4, 5, 6]) - logger.info( - f"Best performance is for {max(results_dict, key=results_dict.get)}, with F1-Score of {results_dict[max(results_dict, key=results_dict.get)]}" - ) - - logger.info("Test Results Dict " + str(results_dict)) - return results_dict - - -def main(args): - model, autocast_dtype = setup_and_build_model(args) - eval_linear_with_model( - model=model, - output_dir=args.output_dir, - train_dataset_str=args.train_dataset_str, - val_dataset_str=args.val_dataset_str, - test_dataset_strs=args.test_dataset_strs, - batch_size=args.batch_size, - epochs=args.epochs, - epoch_length=args.epoch_length, - num_workers=args.num_workers, - save_checkpoint_frequency=args.save_checkpoint_frequency, - eval_period_iterations=args.eval_period_iterations, - learning_rates=args.learning_rates, - weight_decays=args.weight_decays, - autocast_dtype=autocast_dtype, - resume=not args.no_resume, - classifier_fpath=args.classifier_fpath, - val_metric_type=args.val_metric_type, - test_metric_types=args.test_metric_types, - loss_type=args.loss_type, - bag_of_channels=args.bag_of_channels, - leave_one_out_dataset=args.leave_one_out_dataset, - crop_size=args.crop_size, - resize_size=args.resize_size, - n_last_blocks=args.n_last_blocks, - avgpool=args.avgpool, - scheduler=args.scheduler, - ) - return 0 - - -if __name__ == "__main__": - description = "DINOv2 linear_cell_dino evaluation" - args_parser = get_args_parser(description=description) - args = args_parser.parse_args() - 
sys.exit(main(args)) diff --git a/dinov2/eval/cell_dino/utils.py b/dinov2/eval/cell_dino/utils.py deleted file mode 100644 index f7ec81351814738164ad2bba5d26329545e863e1..0000000000000000000000000000000000000000 --- a/dinov2/eval/cell_dino/utils.py +++ /dev/null @@ -1,542 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. - -import logging -from typing import Callable, Dict, Optional, Any, List - -import torch -from torch import nn -from torchmetrics import MetricCollection - -from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader -from dinov2.data import NoOpAccumulator, ResultsAccumulator -import dinov2.distributed as distributed -from dinov2.logging import MetricLogger -from enum import Enum -from torch.utils.data import Subset -from torchvision.datasets.vision import StandardTransform -import numpy as np -from torch.nn.functional import one_hot, softmax - -logger = logging.getLogger("dinov2") - - -class LossType(Enum): - CROSS_ENTROPY = "cross_entropy" - BINARY_CROSS_ENTROPY = "binary_cross_entropy" - - -class BagOfChannelsModelWithNormalize(nn.Module): - def __init__(self, model, autocast_ctx, avgpool, n_last_blocks=1): - super().__init__() - self.model = model - self.autocast_ctx = autocast_ctx - self.n_last_blocks = n_last_blocks - self.avgpool = avgpool - - def forward(self, samples): - with self.autocast_ctx(): - features = self.model.get_intermediate_layers(samples, self.n_last_blocks, return_class_token=True) - output = create_linear_input(features, self.avgpool, use_n_blocks=self.n_last_blocks) - return nn.functional.normalize(output, dim=1, p=2) - - -@torch.inference_mode() -def evaluate_with_accumulate( - model: nn.Module, - data_loader, - postprocessors: Dict[str, nn.Module], - metrics: Dict[str, MetricCollection], - device: torch.device, - criterion: Optional[nn.Module] = None, - test_mode: bool = False, - accumulate_results: bool = False, - leave_one_out: bool = False, -): - model.eval() - - if test_mode: - output_tensor = {k: [] for k in postprocessors.keys()} - target_tensor = {k: [] for k in postprocessors.keys()} - - if criterion is not None: - criterion.eval() - - accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator - accumulators = {k: accumulator_class() for k in postprocessors.keys()} - - for metric in metrics.values(): - metric = metric.to(device) - - metric_logger = MetricLogger(delimiter=" ") - header = "Test:" - - for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): - if isinstance(targets, list): - index = targets[0] - targets = targets[1] - samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0] - if len(index) == 0: - continue - - outputs = samples.to(device) if leave_one_out else model(samples.to(device)) - targets = targets.to(device) - - if criterion is not None: - loss = criterion(outputs, targets) - metric_logger.update(loss=loss.item()) - - for k, metric in metrics.items(): - metric_inputs = postprocessors[k](outputs, targets) - metric.update(**metric_inputs) - if test_mode: - output_tensor[k].append(metric_inputs["preds"]) - target_tensor[k].append(metric_inputs["target"]) - accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index) - - metric_logger.synchronize_between_processes() - logger.info(f"Averaged stats: {metric_logger}") - - stats = 
{k: metric.compute() for k, metric in metrics.items()} - metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} - - # accumulator.accumulate() returns None for the NoOpAccumulator - accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()} - if test_mode: - for k in postprocessors.keys(): - output_tensor[k] = torch.cat(output_tensor[k]) - target_tensor[k] = torch.cat(target_tensor[k]) - accumulated_results = {k: [output_tensor[k], target_tensor[k]] for k in postprocessors.keys()} - - # all call sites unpack three values, so always return the accumulated results - # (the NoOpAccumulator yields None values when accumulate_results is False) - return metric_logger_stats, stats, accumulated_results - - -def all_gather_and_flatten(tensor_rank): - tensor_all_ranks = torch.empty( - distributed.get_global_size(), - *tensor_rank.shape, - dtype=tensor_rank.dtype, - device=tensor_rank.device, - ) - tensor_list = list(tensor_all_ranks.unbind(0)) - torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) - return tensor_all_ranks.flatten(end_dim=1) - - -def extract_features_cell_dino( - model, dataset, batch_size, num_workers, gather_on_cpu=False, shuffle=False, avgpool=False -): - dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) - sample_count = len(dataset_with_enumerated_targets) - data_loader = make_data_loader( - dataset=dataset_with_enumerated_targets, - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=shuffle, - ) - return extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu, avgpool=avgpool) - - -@torch.inference_mode() -def extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu=False, avgpool=False): - gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") - metric_logger = MetricLogger(delimiter=" ") - features, all_labels = None, None - for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): - samples = samples.cuda(non_blocking=True) - labels_rank = labels_rank.cuda(non_blocking=True) - index = index.cuda(non_blocking=True) - feat = model(samples) - if isinstance(samples, list) or isinstance(feat, tuple): - features_rank = create_linear_input(feat, avgpool=avgpool) - else: - features_rank = feat - - # init storage feature matrix - if features is None: - features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) - labels_shape = list(labels_rank.shape) - labels_shape[0] = sample_count - all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) - logger.info(f"Storing features into tensor of shape {features.shape}") - - # share indexes, features and labels between processes - index_all = all_gather_and_flatten(index).to(gather_device) - features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) - labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) - - # update storage feature matrix - if len(index_all) > 0: - features.index_copy_(0, index_all, features_all_ranks) - all_labels.index_copy_(0, index_all, labels_all_ranks) - - logger.info(f"Features shape: {tuple(features.shape)}") - logger.info(f"Labels shape: {tuple(all_labels.shape)}") - - assert torch.all(all_labels > -1) - - return features, all_labels - - -def create_linear_input(x_tokens_list, avgpool=False, use_n_blocks=1): - intermediate_output = x_tokens_list[-use_n_blocks:] - output = torch.cat( - [class_token for _, class_token in intermediate_output], dim=-1 - 
) # concatenate class tokens of the last n blocks - if avgpool: - output = torch.cat( - ( - output, - torch.mean(intermediate_output[-1][0], dim=-2).reshape( - intermediate_output[-1][0].shape[0], -1 - ), # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model - ), - dim=-1, - ) # concatenate average pooling of patch tokens to concatenated patch tokens - output = output.reshape(output.shape[0], -1) - - return output.float() - - -def get_target_transform(dataset) -> Optional[Callable]: - if hasattr(dataset, "transforms"): - if isinstance(dataset.transforms, StandardTransform): - return dataset.transforms.target_transform - raise ValueError("Dataset has a non-standard .transforms property") - if hasattr(dataset, "target_transform"): - return dataset.target_transform - return None - - -def get_labels(dataset) -> torch.Tensor: - """ - Get the labels of a classification dataset, as a Tensor, using the `get_targets` method - if it is present or loading the labels one by one with `get_target`, if it exists. - If the dataset has a target transform, iterate over the whole dataset to get the - transformed labels for each element, then stack them as a torch tensor. - """ - logger.info("Getting dataset labels ...") - if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"): - if hasattr(dataset, "get_targets"): # Returns a np.array - labels = dataset.get_targets() - elif hasattr(dataset, "get_target"): - labels = [dataset.get_target(i) for i in range(len(dataset))] - target_transform = get_target_transform(dataset) - if target_transform is not None: - labels = [target_transform(label) for label in labels] - else: - # Target transform is applied in this case - labels = [dataset[i][1] for i in range(len(dataset))] - return torch.stack([torch.tensor(label, dtype=int) for label in labels]) - - -def get_num_classes(dataset) -> int: - """ - Get the labels of a dataset and compute the number of classes - """ - labels = get_labels(dataset) - if len(labels.shape) > 1: - return int(labels.shape[1]) - return int(labels.max() + 1) - - -def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []): - """ - Function that computes the average and the std on a metrics dict. - A linear evaluation dictionary contains "best_classifier", - so this specific key is removed for computing aggregated metrics. - """ - output_metrics_dict = {} - metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys] - for metric in metrics: - stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()]) - output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item() - output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item() - - return output_metrics_dict - - -def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]: - """ - Efficiently creates a mapping between the labels and tensors containing - the indices of all the dataset elements that share this label. - In the case of multiple labels, it is not guaranteed that there - will be exactly the specified percentage of labels. 
- """ - if len(labels.shape) > 1: # labels are a one-hot encoding - assert len(labels.shape) == 2 - sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True) - else: - sorted_labels, indices = torch.sort(labels, stable=True) - unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True) - mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist()))) - return mapping - - -def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0): - """ - Shuffling a dataset by subsetting it with a random permutation of its indices - """ - random_generator = torch.Generator() - random_generator.manual_seed(seed) - random_indices = torch.randperm(len(dataset), generator=random_generator) - return Subset(dataset, random_indices) - - -def _subset_dataset_per_class( - class_indices_mapping: dict[int, torch.Tensor], - n_or_percent_per_class: float, - dataset_size: int, - seed: int = 0, - is_percent: bool = False, -) -> torch.Tensor: - """ - Helper function to select a percentage of a dataset, equally distributed across classes, - or to take the same number of elements from each class of the dataset. - Returns a boolean mask tensor being True at indices of selected elements - """ - - random_generator = torch.Generator() - random_generator.manual_seed(seed) - - final_indices_bool = torch.zeros(dataset_size, dtype=bool) - for class_indices in class_indices_mapping.values(): - # Select at least one element - n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class - assert isinstance(n_for_class, int) - filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class] - final_indices_bool[class_indices[filtered_index]] = True - return final_indices_bool - - -def _multilabel_rebalance_subset( - class_indices_mapping: dict[int, torch.Tensor], - n_or_percent_per_class: float, - labels: torch.Tensor, - indices_bool: torch.Tensor, - dataset_size: int, - seed: int = 0, -) -> torch.Tensor: - """ - Helper function to refine a subset of a multi-label dataset (indices_bool) - to better match a target percentage of labels. - Returns a boolean mask tensor being True at indices of selected elements. - """ - - # Compute the number of selected labels in indices_bool - num_total_labels = labels.sum() - num_wanted_labels = int(num_total_labels * n_or_percent_per_class) - num_selected_labels = (labels[indices_bool] > 0).sum() - logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") - - # Compute a new percentage and new set selecting less images, therefore less labels, to match approximatelly the exact percentage of labels selected - n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels) - final_indices_bool = _subset_dataset_per_class( - class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True - ) - - # Compute the number of labels finally used - num_selected_labels = (labels[final_indices_bool] > 0).sum() - logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") - - return final_indices_bool - - -def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True): - """ - Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class. - If `shuffle` is False, taking the first elements of each class as the validaton set. 
- """ - assert 0 < split_percentage < 1 - logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set") - if shuffle_train: - logger.info("Shuffling train dataset before splitting in train and validation sets") - train_dataset = _shuffle_dataset(train_dataset) - train_labels = get_labels(train_dataset) - class_indices_mapping = create_class_indices_mapping(train_labels) - val_mask = torch.zeros(len(train_labels), dtype=bool) - for class_indices in class_indices_mapping.values(): - # If there is only one element, it goes in the train set - n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0 - val_mask[class_indices[:n_for_val]] = True - - val_dataset = Subset(train_dataset, val_mask.nonzero().flatten()) - train_dataset = Subset(train_dataset, (~val_mask).nonzero().flatten()) - return train_dataset, val_dataset - - -def create_train_dataset_dict( - train_dataset, - few_shot_eval: bool = False, - few_shot_k_or_percent=None, - few_shot_n_tries: int = 1, -) -> dict[int, dict[int, Any]]: - """ - Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being - n elements or x% of a class. Produces a dict, which keys are number of random "tries" - and values are the dataset subset for this "try". - - Format is {"nth-try": dataset} - """ - if few_shot_eval is False: - assert few_shot_k_or_percent is None - assert few_shot_n_tries == 1 - return {0: train_dataset} - - assert few_shot_k_or_percent is not None - train_labels = get_labels(train_dataset) - class_indices_mapping = create_class_indices_mapping(train_labels) - train_dataset_dict: dict[int, Any] = {} - is_percent = few_shot_k_or_percent < 1 - if not is_percent: - few_shot_k_or_percent = int(few_shot_k_or_percent) - - for t in range(few_shot_n_tries): - t_subset_bool = _subset_dataset_per_class( - class_indices_mapping=class_indices_mapping, - n_or_percent_per_class=few_shot_k_or_percent, - dataset_size=len(train_labels), - is_percent=is_percent, - seed=t, - ) - if len(train_labels.shape) > 1 and is_percent: - t_subset_bool = _multilabel_rebalance_subset( - class_indices_mapping=class_indices_mapping, - n_or_percent_per_class=few_shot_k_or_percent, - dataset_size=len(train_labels), - labels=train_labels, - indices_bool=t_subset_bool, - seed=t, - ) - train_dataset_dict[t] = Subset(train_dataset, t_subset_bool.nonzero().flatten()) - return train_dataset_dict - - -def extract_features_for_dataset_dict( - model, - dataset_dict: dict[int, dict[int, Any]], - batch_size: int, - num_workers: int, - gather_on_cpu=False, - avgpool=False, -) -> dict[int, dict[str, torch.Tensor]]: - """ - Extract features for each subset of dataset in the context of few-shot evaluations - """ - few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {} - for try_n, dataset in dataset_dict.items(): - features, labels = extract_features_cell_dino( - model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu, avgpool=avgpool - ) - few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels} - return few_shot_data_dict - - -def pad_multilabel_and_collate(batch, pad_value=-1): - """ - This method pads and collates a batch of (image, (index, target)) tuples, coming from - DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes. - The targets are padded to the length of the longest target list in the batch. 
- """ - maxlen = max(len(targets) for _, (_, targets) in batch) - padded_batch = [ - (image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value))) - for image, (index, targets) in batch - ] - return torch.utils.data.default_collate(padded_batch) - - -class KnnModule(torch.nn.Module): - """ - Gets knn of test features from all processes on a chunk of the train features - - Each rank gets a chunk of the train features as well as a chunk of the test features. - In `compute_neighbors`, for each rank one after the other, its chunk of test features - is sent to all devices, partial knns are computed with each chunk of train features - then collated back on the original device. - """ - - def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): - super().__init__() - - self.global_rank = distributed.get_global_rank() - self.global_size = distributed.get_global_size() - - self.device = device - self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) - # Labels can either be integers, or in a one-hot format - self.candidates = train_labels.chunk(self.global_size)[self.global_rank].unsqueeze(0).to(self.device) - self.nb_knn = nb_knn - self.max_k = max(self.nb_knn) - self.T = T - self.num_classes = num_classes - - def _get_knn_sims_and_labels(self, similarity, train_labels): - topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) - if len(train_labels.shape) == 3: # If the labels are in one_hot format - indices = indices.unsqueeze(2).expand(-1, -1, self.num_classes) # Orignally [bs, max_k] - neighbors_labels = torch.gather(train_labels, 1, indices) - return topk_sims, neighbors_labels - - def _similarity_for_rank(self, features_rank, source_rank): - # Send the features from `source_rank` to all ranks - broadcast_shape = torch.tensor(features_rank.shape).to(self.device) - torch.distributed.broadcast(broadcast_shape, source_rank) - - broadcasted = features_rank - if self.global_rank != source_rank: - broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) - torch.distributed.broadcast(broadcasted, source_rank) - - # Compute the neighbors for `source_rank` among `train_features_rank_T` - similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) - candidate_labels = self.candidates.expand(len(similarity_rank), *self.candidates.shape[1:]) - return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) - - def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): - # Gather all neighbors for `target_rank` - topk_sims_rank = retrieved_rank = None - if self.global_rank == target_rank: - topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] - retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] - - torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) - torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) - - if self.global_rank == target_rank: - # Perform a second top-k on the k * global_size retrieved neighbors - topk_sims_rank = torch.cat(topk_sims_rank, dim=1) - retrieved_rank = torch.cat(retrieved_rank, dim=1) - results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) - return results - return None - - def compute_neighbors(self, features_rank): - for rank in range(self.global_size): - topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) - results = 
self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) - if results is not None: - topk_sims_rank, neighbors_labels_rank = results - return topk_sims_rank, neighbors_labels_rank - - def forward(self, features_rank): - """ - Compute the results for all values of `self.nb_knn`, using the full set of `self.max_k` retrieved neighbors - """ - assert all(k <= self.max_k for k in self.nb_knn) - - topk_sims, neighbors_labels = self.compute_neighbors(features_rank) - batch_size = neighbors_labels.shape[0] - topk_sims_transform = softmax(topk_sims / self.T, 1) - voting_coefficient = topk_sims_transform.view(batch_size, -1, 1) - if len(neighbors_labels.shape) == 2: # If the labels are not yet one hot - neighbors_labels = one_hot(neighbors_labels, num_classes=self.num_classes) - matmul = torch.mul(neighbors_labels, voting_coefficient) - probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} - return probas_for_k diff --git a/dinov2/eval/depth/__init__.py b/dinov2/eval/depth/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/eval/depth/models/__init__.py b/dinov2/eval/depth/models/__init__.py deleted file mode 100644 index 9a5825181dc2189424b5c58d245b36919cbc5b2e..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .backbones import * # noqa: F403 -from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss -from .decode_heads import * # noqa: F403 -from .depther import * # noqa: F403 -from .losses import * # noqa: F403 diff --git a/dinov2/eval/depth/models/backbones/__init__.py b/dinov2/eval/depth/models/backbones/__init__.py deleted file mode 100644 index 520d75bc6e064b9d64487293604ac1bda6e2b6f7..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/backbones/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .vision_transformer import DinoVisionTransformer diff --git a/dinov2/eval/depth/models/backbones/vision_transformer.py b/dinov2/eval/depth/models/backbones/vision_transformer.py deleted file mode 100644 index 69bda46fd69eb7dabb8f5b60e6fa459fdc21aeab..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/backbones/vision_transformer.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
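# A minimal, single-process sketch of the weighted k-NN voting performed by
# `KnnModule.forward` above, with the distributed broadcast/gather machinery
# stripped out. The feature sizes, seed, and values below are illustrative
# assumptions, not values taken from the original code.
import torch
from torch.nn.functional import normalize, one_hot, softmax

torch.manual_seed(0)
train_features = normalize(torch.randn(100, 16), dim=1)  # assumed L2-normalized
train_labels = torch.randint(0, 10, (100,))
test_features = normalize(torch.randn(4, 16), dim=1)
T, nb_knn, num_classes = 0.07, [5, 10, 20], 10
max_k = max(nb_knn)

# With L2-normalized features, the dot product is a cosine similarity;
# keep only the top max_k most similar train samples per test sample.
similarity = test_features @ train_features.T
topk_sims, indices = similarity.topk(max_k, largest=True, sorted=True)
neighbors_labels = train_labels[indices]  # [batch, max_k]

# A temperature-scaled softmax turns similarities into voting coefficients.
voting_coefficient = softmax(topk_sims / T, dim=1).unsqueeze(-1)
votes = one_hot(neighbors_labels, num_classes=num_classes) * voting_coefficient

# For each k in nb_knn, only the k nearest neighbors cast votes.
probas_for_k = {k: votes[:, :k, :].sum(dim=1) for k in nb_knn}
print({k: p.argmax(dim=1).tolist() for k, p in probas_for_k.items()})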
- -from mmcv.runner import BaseModule - -from ..builder import BACKBONES - - -@BACKBONES.register_module() -class DinoVisionTransformer(BaseModule): - """Vision Transformer.""" - - def __init__(self, *args, **kwargs): - super().__init__() diff --git a/dinov2/eval/depth/models/builder.py b/dinov2/eval/depth/models/builder.py deleted file mode 100644 index c152643435308afcff60b07cd68ea979fe1d90cb..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/builder.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import warnings - -from mmcv.cnn import MODELS as MMCV_MODELS -from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION -from mmcv.utils import Registry - -MODELS = Registry("models", parent=MMCV_MODELS) -ATTENTION = Registry("attention", parent=MMCV_ATTENTION) - - -BACKBONES = MODELS -NECKS = MODELS -HEADS = MODELS -LOSSES = MODELS -DEPTHER = MODELS - - -def build_backbone(cfg): - """Build backbone.""" - return BACKBONES.build(cfg) - - -def build_neck(cfg): - """Build neck.""" - return NECKS.build(cfg) - - -def build_head(cfg): - """Build head.""" - return HEADS.build(cfg) - - -def build_loss(cfg): - """Build loss.""" - return LOSSES.build(cfg) - - -def build_depther(cfg, train_cfg=None, test_cfg=None): - """Build depther.""" - if train_cfg is not None or test_cfg is not None: - warnings.warn("train_cfg and test_cfg are deprecated, " "please specify them in the model", UserWarning) - assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field " - assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field " - return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/dinov2/eval/depth/models/decode_heads/__init__.py b/dinov2/eval/depth/models/decode_heads/__init__.py deleted file mode 100644 index bd0f0754a5b01d7622c1f26bf3f60daea19da4e8..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/decode_heads/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .dpt_head import DPTHead -from .linear_head import BNHead diff --git a/dinov2/eval/depth/models/decode_heads/decode_head.py b/dinov2/eval/depth/models/decode_heads/decode_head.py deleted file mode 100644 index f8c867a3ec687090b280d90bb86aee435320acda..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/decode_heads/decode_head.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import copy -from abc import ABCMeta, abstractmethod - -import mmcv -import numpy as np -import torch -import torch.nn as nn -from mmcv.runner import BaseModule, auto_fp16, force_fp32 - -from ...ops import resize -from ..builder import build_loss - - -class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta): - """Base class for depth decode heads. - - Args: - in_channels (List): Input channels. - channels (int): Channels after modules, before conv_depth. - conv_cfg (dict|None): Config of conv layers. 
Default: None. - act_cfg (dict): Config of activation layers. - Default: dict(type='ReLU') - loss_decode (dict): Config of decode loss. - Default: dict(type='SigLoss'). - sampler (dict|None): The config of depth map sampler. - Default: None. - align_corners (bool): align_corners argument of F.interpolate. - Default: False. - min_depth (float): Min depth in dataset setting. - Default: 1e-3. - max_depth (float|None): Max depth in dataset setting. - Default: None. - norm_cfg (dict|None): Config of norm layers. - Default: None. - classify (bool): Whether to predict depth in a classification-regression manner. - Default: False. - n_bins (int): The number of bins used in the classification step. - Default: 256. - bins_strategy (str): The discretization strategy used in the classification step. - Default: 'UD'. - norm_strategy (str): The normalization strategy applied to the class probability - distribution. Default: 'linear' - scale_up (bool): Whether to predict depth in a scale-up manner. - Default: False. - """ - - def __init__( - self, - in_channels, - channels=96, - conv_cfg=None, - act_cfg=dict(type="ReLU"), - loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10), - sampler=None, - align_corners=False, - min_depth=1e-3, - max_depth=None, - norm_cfg=None, - classify=False, - n_bins=256, - bins_strategy="UD", - norm_strategy="linear", - scale_up=False, - ): - super(DepthBaseDecodeHead, self).__init__() - - self.in_channels = in_channels - self.channels = channels - self.conv_cfg = conv_cfg - self.act_cfg = act_cfg - if isinstance(loss_decode, dict): - self.loss_decode = build_loss(loss_decode) - elif isinstance(loss_decode, (list, tuple)): - self.loss_decode = nn.ModuleList() - for loss in loss_decode: - self.loss_decode.append(build_loss(loss)) - self.align_corners = align_corners - self.min_depth = min_depth - self.max_depth = max_depth - self.norm_cfg = norm_cfg - self.classify = classify - self.n_bins = n_bins - self.scale_up = scale_up - - if self.classify: - assert bins_strategy in ["UD", "SID"], "Supported bins_strategy: UD, SID" - assert norm_strategy in ["linear", "softmax", "sigmoid"], "Supported norm_strategy: linear, softmax, sigmoid" - - self.bins_strategy = bins_strategy - self.norm_strategy = norm_strategy - self.softmax = nn.Softmax(dim=1) - self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) - else: - self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) - - self.fp16_enabled = False - self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() - - def extra_repr(self): - """Extra repr.""" - s = f"align_corners={self.align_corners}" - return s - - @auto_fp16() - @abstractmethod - def forward(self, inputs, img_metas): - """Placeholder of forward function.""" - pass - - def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg): - """Forward function for training. - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - depth_gt (Tensor): GT depth - train_cfg (dict): The training config. 
- - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - depth_pred = self.forward(inputs, img_metas) - losses = self.losses(depth_pred, depth_gt) - - log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) - losses.update(**log_imgs) - - return losses - - def forward_test(self, inputs, img_metas, test_cfg): - """Forward function for testing. - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - test_cfg (dict): The testing config. - - Returns: - Tensor: Output depth map. - """ - return self.forward(inputs, img_metas) - - def depth_pred(self, feat): - """Predict depth for each pixel.""" - if self.classify: - logit = self.conv_depth(feat) - - if self.bins_strategy == "UD": - bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) - elif self.bins_strategy == "SID": - bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) - - # following AdaBins, default is linear - if self.norm_strategy == "linear": - logit = torch.relu(logit) - eps = 0.1 - logit = logit + eps - logit = logit / logit.sum(dim=1, keepdim=True) - elif self.norm_strategy == "softmax": - logit = torch.softmax(logit, dim=1) - elif self.norm_strategy == "sigmoid": - logit = torch.sigmoid(logit) - logit = logit / logit.sum(dim=1, keepdim=True) - - output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) - - else: - if self.scale_up: - output = self.sigmoid(self.conv_depth(feat)) * self.max_depth - else: - output = self.relu(self.conv_depth(feat)) + self.min_depth - return output - - @force_fp32(apply_to=("depth_pred",)) - def losses(self, depth_pred, depth_gt): - """Compute depth loss.""" - loss = dict() - depth_pred = resize( - input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False - ) - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - for loss_decode in losses_decode: - if loss_decode.loss_name not in loss: - loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) - else: - loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) - return loss - - def log_images(self, img_path, depth_pred, depth_gt, img_meta): - show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) - show_img = show_img.numpy().astype(np.float32) - show_img = mmcv.imdenormalize( - show_img, - img_meta["img_norm_cfg"]["mean"], - img_meta["img_norm_cfg"]["std"], - img_meta["img_norm_cfg"]["to_rgb"], - ) - show_img = np.clip(show_img, 0, 255) - show_img = show_img.astype(np.uint8) - show_img = show_img[:, :, ::-1] - show_img = show_img.transpose(0, 2, 1) - show_img = show_img.transpose(1, 0, 2) - - depth_pred = depth_pred / torch.max(depth_pred) - depth_gt = depth_gt / torch.max(depth_gt) - - depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) - depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) - - return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} diff --git a/dinov2/eval/depth/models/decode_heads/dpt_head.py b/dinov2/eval/depth/models/decode_heads/dpt_head.py deleted file mode 100644 index 
c6c6d9470d78e1d944cc505f97865f026a9458d3..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/decode_heads/dpt_head.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import math - -import torch -import torch.nn as nn -from mmcv.cnn import ConvModule, Linear, build_activation_layer -from mmcv.runner import BaseModule - -from ...ops import resize -from ..builder import HEADS -from .decode_head import DepthBaseDecodeHead - - -class Interpolate(nn.Module): - def __init__(self, scale_factor, mode, align_corners=False): - super(Interpolate, self).__init__() - self.interp = nn.functional.interpolate - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - - def forward(self, x): - x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - return x - - -class HeadDepth(nn.Module): - def __init__(self, features): - super(HeadDepth, self).__init__() - self.head = nn.Sequential( - nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), - Interpolate(scale_factor=2, mode="bilinear", align_corners=True), - nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), - nn.ReLU(), - nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), - ) - - def forward(self, x): - x = self.head(x) - return x - - -class ReassembleBlocks(BaseModule): - """ViTPostProcessBlock: processes the cls_token in the ViT backbone output and - rearranges the feature vectors into feature maps. - Args: - in_channels (int): ViT feature channels. Default: 768. - out_channels (List): output channels of each stage. - Default: [96, 192, 384, 768]. - readout_type (str): Type of readout operation. Default: 'ignore'. - patch_size (int): The patch size. Default: 16. - init_cfg (dict, optional): Initialization config dict. Default: None. 
- """ - - def __init__( - self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None - ): - super(ReassembleBlocks, self).__init__(init_cfg) - - assert readout_type in ["ignore", "add", "project"] - self.readout_type = readout_type - self.patch_size = patch_size - - self.projects = nn.ModuleList( - [ - ConvModule( - in_channels=in_channels, - out_channels=out_channel, - kernel_size=1, - act_cfg=None, - ) - for out_channel in out_channels - ] - ) - - self.resize_layers = nn.ModuleList( - [ - nn.ConvTranspose2d( - in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 - ), - nn.ConvTranspose2d( - in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 - ), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 - ), - ] - ) - if self.readout_type == "project": - self.readout_projects = nn.ModuleList() - for _ in range(len(self.projects)): - self.readout_projects.append( - nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU"))) - ) - - def forward(self, inputs): - assert isinstance(inputs, list) - out = [] - for i, x in enumerate(inputs): - assert len(x) == 2 - x, cls_token = x[0], x[1] - feature_shape = x.shape - if self.readout_type == "project": - x = x.flatten(2).permute((0, 2, 1)) - readout = cls_token.unsqueeze(1).expand_as(x) - x = self.readout_projects[i](torch.cat((x, readout), -1)) - x = x.permute(0, 2, 1).reshape(feature_shape) - elif self.readout_type == "add": - x = x.flatten(2) + cls_token.unsqueeze(-1) - x = x.reshape(feature_shape) - else: - pass - x = self.projects[i](x) - x = self.resize_layers[i](x) - out.append(x) - return out - - -class PreActResidualConvUnit(BaseModule): - """ResidualConvUnit, pre-activate residual unit. - Args: - in_channels (int): number of channels in the input feature map. - act_cfg (dict): dictionary to construct and config activation layer. - norm_cfg (dict): dictionary to construct and config norm layer. - stride (int): stride of the first block. Default: 1 - dilation (int): dilation rate for convs layers. Default: 1. - init_cfg (dict, optional): Initialization config dict. Default: None. - """ - - def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None): - super(PreActResidualConvUnit, self).__init__(init_cfg) - - self.conv1 = ConvModule( - in_channels, - in_channels, - 3, - stride=stride, - padding=dilation, - dilation=dilation, - norm_cfg=norm_cfg, - act_cfg=act_cfg, - bias=False, - order=("act", "conv", "norm"), - ) - - self.conv2 = ConvModule( - in_channels, - in_channels, - 3, - padding=1, - norm_cfg=norm_cfg, - act_cfg=act_cfg, - bias=False, - order=("act", "conv", "norm"), - ) - - def forward(self, inputs): - inputs_ = inputs.clone() - x = self.conv1(inputs) - x = self.conv2(x) - return x + inputs_ - - -class FeatureFusionBlock(BaseModule): - """FeatureFusionBlock, merge feature map from different stages. - Args: - in_channels (int): Input channels. - act_cfg (dict): The activation config for ResidualConvUnit. - norm_cfg (dict): Config dict for normalization layer. - expand (bool): Whether expand the channels in post process block. - Default: False. - align_corners (bool): align_corner setting for bilinear upsample. - Default: True. - init_cfg (dict, optional): Initialization config dict. Default: None. 
- """ - - def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None): - super(FeatureFusionBlock, self).__init__(init_cfg) - - self.in_channels = in_channels - self.expand = expand - self.align_corners = align_corners - - self.out_channels = in_channels - if self.expand: - self.out_channels = in_channels // 2 - - self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True) - - self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) - self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) - - def forward(self, *inputs): - x = inputs[0] - if len(inputs) == 2: - if x.shape != inputs[1].shape: - res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) - else: - res = inputs[1] - x = x + self.res_conv_unit1(res) - x = self.res_conv_unit2(x) - x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) - x = self.project(x) - return x - - -@HEADS.register_module() -class DPTHead(DepthBaseDecodeHead): - """Vision Transformers for Dense Prediction. - This head is implemented of `DPT `_. - Args: - embed_dims (int): The embed dimension of the ViT backbone. - Default: 768. - post_process_channels (List): Out channels of post process conv - layers. Default: [96, 192, 384, 768]. - readout_type (str): Type of readout operation. Default: 'ignore'. - patch_size (int): The patch size. Default: 16. - expand_channels (bool): Whether expand the channels in post process - block. Default: False. - """ - - def __init__( - self, - embed_dims=768, - post_process_channels=[96, 192, 384, 768], - readout_type="ignore", - patch_size=16, - expand_channels=False, - **kwargs - ): - super(DPTHead, self).__init__(**kwargs) - - self.in_channels = self.in_channels - self.expand_channels = expand_channels - self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) - - self.post_process_channels = [ - channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) - ] - self.convs = nn.ModuleList() - for channel in self.post_process_channels: - self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False)) - self.fusion_blocks = nn.ModuleList() - for _ in range(len(self.convs)): - self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) - self.fusion_blocks[0].res_conv_unit1 = None - self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg) - self.num_fusion_blocks = len(self.fusion_blocks) - self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) - self.num_post_process_channels = len(self.post_process_channels) - assert self.num_fusion_blocks == self.num_reassemble_blocks - assert self.num_reassemble_blocks == self.num_post_process_channels - self.conv_depth = HeadDepth(self.channels) - - def forward(self, inputs, img_metas): - assert len(inputs) == self.num_reassemble_blocks - x = [inp for inp in inputs] - x = self.reassemble_blocks(x) - x = [self.convs[i](feature) for i, feature in enumerate(x)] - out = self.fusion_blocks[0](x[-1]) - for i in range(1, len(self.fusion_blocks)): - out = self.fusion_blocks[i](out, x[-(i + 1)]) - out = self.project(out) - out = self.depth_pred(out) - return out diff --git 
a/dinov2/eval/depth/models/decode_heads/linear_head.py b/dinov2/eval/depth/models/decode_heads/linear_head.py deleted file mode 100644 index 3da1436f6a3f0bcc389d74ed86d44d455d2f7a87..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/decode_heads/linear_head.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from ...ops import resize -from ..builder import HEADS -from .decode_head import DepthBaseDecodeHead - - -@HEADS.register_module() -class BNHead(DepthBaseDecodeHead): - """Just a batchnorm.""" - - def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): - super().__init__(**kwargs) - self.input_transform = input_transform - self.in_index = in_index - self.upsample = upsample - # self.bn = nn.SyncBatchNorm(self.in_channels) - if self.classify: - self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) - else: - self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - Args: - inputs (list[Tensor]): List of multi-level img features. - Returns: - Tensor: The transformed inputs - """ - - if "concat" in self.input_transform: - inputs = [inputs[i] for i in self.in_index] - if "resize" in self.input_transform: - inputs = [ - resize( - input=x, - size=[s * self.upsample for s in inputs[0].shape[2:]], - mode="bilinear", - align_corners=self.align_corners, - ) - for x in inputs - ] - inputs = torch.cat(inputs, dim=1) - elif self.input_transform == "multiple_select": - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - def _forward_feature(self, inputs, img_metas=None, **kwargs): - """Forward function for feature maps before classifying each pixel with - ``self.cls_seg`` fc. - Args: - inputs (list[Tensor]): List of multi-level img features. - Returns: - feats (Tensor): A tensor of shape (batch_size, self.channels, - H, W) which is feature map for last layer of decoder head. - """ - # accept lists (for cls token) - inputs = list(inputs) - for i, x in enumerate(inputs): - if len(x) == 2: - x, cls_token = x[0], x[1] - if len(x.shape) == 2: - x = x[:, :, None, None] - cls_token = cls_token[:, :, None, None].expand_as(x) - inputs[i] = torch.cat((x, cls_token), 1) - else: - x = x[0] - if len(x.shape) == 2: - x = x[:, :, None, None] - inputs[i] = x - x = self._transform_inputs(inputs) - # feats = self.bn(x) - return x - - def forward(self, inputs, img_metas=None, **kwargs): - """Forward function.""" - output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) - output = self.depth_pred(output) - - return output diff --git a/dinov2/eval/depth/models/depther/__init__.py b/dinov2/eval/depth/models/depther/__init__.py deleted file mode 100644 index be99743bf6c773d05f2b74524116e368c0cfcba0..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/depther/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
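# A minimal sketch of the classification-regression depth prediction in
# `DepthBaseDecodeHead.depth_pred` earlier in this diff, for the "UD" bins
# strategy with "linear" normalization. The toy logits and depth range are
# illustrative assumptions, not values from the original configs.
import torch

min_depth, max_depth, n_bins = 1e-3, 10.0, 256
logit = torch.randn(2, n_bins, 8, 8)  # [batch, n_bins, H, W], e.g. conv_depth output

# "UD": uniformly discretize the depth range into n_bins bin centers.
bins = torch.linspace(min_depth, max_depth, n_bins)

# "linear": ReLU, add a small eps, then renormalize into a per-pixel
# probability distribution over the bins.
logit = torch.relu(logit) + 0.1
logit = logit / logit.sum(dim=1, keepdim=True)

# The predicted depth is the expectation of the bin centers under that
# distribution, so every prediction lands inside [min_depth, max_depth].
output = torch.einsum("ikmn,k->imn", logit, bins).unsqueeze(dim=1)
print(output.shape, float(output.min()), float(output.max()))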
- -from .base import BaseDepther -from .encoder_decoder import DepthEncoderDecoder diff --git a/dinov2/eval/depth/models/depther/base.py b/dinov2/eval/depth/models/depther/base.py deleted file mode 100644 index e133a825a888167f90d95d67803609d6cac7ff55..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/depther/base.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from abc import ABCMeta, abstractmethod -from collections import OrderedDict - -import torch -import torch.distributed as dist -from mmcv.runner import BaseModule, auto_fp16 - - -class BaseDepther(BaseModule, metaclass=ABCMeta): - """Base class for depther.""" - - def __init__(self, init_cfg=None): - super(BaseDepther, self).__init__(init_cfg) - self.fp16_enabled = False - - @property - def with_neck(self): - """bool: whether the depther has neck""" - return hasattr(self, "neck") and self.neck is not None - - @property - def with_auxiliary_head(self): - """bool: whether the depther has auxiliary head""" - return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None - - @property - def with_decode_head(self): - """bool: whether the depther has decode head""" - return hasattr(self, "decode_head") and self.decode_head is not None - - @abstractmethod - def extract_feat(self, imgs): - """Placeholder for extract features from images.""" - pass - - @abstractmethod - def encode_decode(self, img, img_metas): - """Placeholder for encode images with backbone and decode into a - semantic depth map of the same size as input.""" - pass - - @abstractmethod - def forward_train(self, imgs, img_metas, **kwargs): - """Placeholder for Forward function for training.""" - pass - - @abstractmethod - def simple_test(self, img, img_meta, **kwargs): - """Placeholder for single image test.""" - pass - - @abstractmethod - def aug_test(self, imgs, img_metas, **kwargs): - """Placeholder for augmentation test.""" - pass - - def forward_test(self, imgs, img_metas, **kwargs): - """ - Args: - imgs (List[Tensor]): the outer list indicates test-time - augmentations and inner Tensor should have a shape NxCxHxW, - which contains all images in the batch. - img_metas (List[List[dict]]): the outer list indicates test-time - augs (multiscale, flip, etc.) and the inner list indicates - images in a batch. 
- """ - for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: - if not isinstance(var, list): - raise TypeError(f"{name} must be a list, but got " f"{type(var)}") - num_augs = len(imgs) - if num_augs != len(img_metas): - raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") - # all images in the same aug batch all of the same ori_shape and pad - # shape - for img_meta in img_metas: - ori_shapes = [_["ori_shape"] for _ in img_meta] - assert all(shape == ori_shapes[0] for shape in ori_shapes) - img_shapes = [_["img_shape"] for _ in img_meta] - assert all(shape == img_shapes[0] for shape in img_shapes) - pad_shapes = [_["pad_shape"] for _ in img_meta] - assert all(shape == pad_shapes[0] for shape in pad_shapes) - - if num_augs == 1: - return self.simple_test(imgs[0], img_metas[0], **kwargs) - else: - return self.aug_test(imgs, img_metas, **kwargs) - - @auto_fp16(apply_to=("img",)) - def forward(self, img, img_metas, return_loss=True, **kwargs): - """Calls either :func:`forward_train` or :func:`forward_test` depending - on whether ``return_loss`` is ``True``. - - Note this setting will change the expected inputs. When - ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor - and List[dict]), and when ``resturn_loss=False``, img and img_meta - should be double nested (i.e. List[Tensor], List[List[dict]]), with - the outer list indicating test time augmentations. - """ - if return_loss: - return self.forward_train(img, img_metas, **kwargs) - else: - return self.forward_test(img, img_metas, **kwargs) - - def train_step(self, data_batch, optimizer, **kwargs): - """The iteration step during training. - - This method defines an iteration step during training, except for the - back propagation and optimizer updating, which are done in an optimizer - hook. Note that in some complicated cases or models, the whole process - including back propagation and optimizer updating is also defined in - this method, such as GAN. - - Args: - data (dict): The output of dataloader. - optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of - runner is passed to ``train_step()``. This argument is unused - and reserved. - - Returns: - dict: It should contain at least 3 keys: ``loss``, ``log_vars``, - ``num_samples``. - ``loss`` is a tensor for back propagation, which can be a - weighted sum of multiple losses. - ``log_vars`` contains all the variables to be sent to the - logger. - ``num_samples`` indicates the batch size (when the model is - DDP, it means the batch size on each GPU), which is used for - averaging the logs. - """ - losses = self(**data_batch) - - # split losses and images - real_losses = {} - log_imgs = {} - for k, v in losses.items(): - if "img" in k: - log_imgs[k] = v - else: - real_losses[k] = v - - loss, log_vars = self._parse_losses(real_losses) - - outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) - - return outputs - - def val_step(self, data_batch, **kwargs): - """The iteration step during validation. - - This method shares the same signature as :func:`train_step`, but used - during val epochs. Note that the evaluation after training epochs is - not implemented with this method, but an evaluation hook. - """ - output = self(**data_batch, **kwargs) - return output - - @staticmethod - def _parse_losses(losses): - """Parse the raw outputs (losses) of the network. 
- - Args: - losses (dict): Raw output of the network, which usually contains - losses and other necessary information. - - Returns: - tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor - which may be a weighted sum of all losses, log_vars contains - all the variables to be sent to the logger. - """ - log_vars = OrderedDict() - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - log_vars[loss_name] = loss_value.mean() - elif isinstance(loss_value, list): - log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) - else: - raise TypeError(f"{loss_name} is not a tensor or list of tensors") - - loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) - - log_vars["loss"] = loss - for loss_name, loss_value in log_vars.items(): - # reduce loss across ranks in distributed training - if dist.is_available() and dist.is_initialized(): - loss_value = loss_value.data.clone() - dist.all_reduce(loss_value.div_(dist.get_world_size())) - log_vars[loss_name] = loss_value.item() - - return loss, log_vars diff --git a/dinov2/eval/depth/models/depther/encoder_decoder.py b/dinov2/eval/depth/models/depther/encoder_decoder.py deleted file mode 100644 index 6b0ec2dd314fdf8ccf4414d81afb95326b7dc0c9..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/depther/encoder_decoder.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F - -from ...models import builder -from ...models.builder import DEPTHER -from ...ops import resize -from .base import BaseDepther - - -def add_prefix(inputs, prefix): - """Add prefix for dict. - - Args: - inputs (dict): The input dict with str keys. - prefix (str): The prefix to add. - - Returns: - - dict: The dict with keys updated with ``prefix``. - """ - - outputs = dict() - for name, value in inputs.items(): - outputs[f"{prefix}.{name}"] = value - - return outputs - - -@DEPTHER.register_module() -class DepthEncoderDecoder(BaseDepther): - """Encoder Decoder depther. - - EncoderDecoder typically consists of a backbone, an optional neck, and a decode_head. - """ - - def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None): - super(DepthEncoderDecoder, self).__init__(init_cfg) - if pretrained is not None: - assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight" - backbone.pretrained = pretrained - self.backbone = builder.build_backbone(backbone) - self._init_decode_head(decode_head) - - if neck is not None: - self.neck = builder.build_neck(neck) - - self.train_cfg = train_cfg - self.test_cfg = test_cfg - - assert self.with_decode_head - - def _init_decode_head(self, decode_head): - """Initialize ``decode_head``""" - self.decode_head = builder.build_head(decode_head) - self.align_corners = self.decode_head.align_corners - - def extract_feat(self, img): - """Extract features from images.""" - x = self.backbone(img) - if self.with_neck: - x = self.neck(x) - return x - - def encode_decode(self, img, img_metas, rescale=True, size=None): - """Encode images with backbone and decode into a depth estimation - map of the same size as input.""" - x = self.extract_feat(img) - out = self._decode_head_forward_test(x, img_metas) - # clamp the predicted depth to the valid range. 
- out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) - if rescale: - if size is None: - if img_metas is not None: - size = img_metas[0]["ori_shape"][:2] - else: - size = img.shape[2:] - out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) - return out - - def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): - """Run forward function and calculate loss for decode head in - training.""" - losses = dict() - loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs) - losses.update(add_prefix(loss_decode, "decode")) - return losses - - def _decode_head_forward_test(self, x, img_metas): - """Run forward function and calculate loss for decode head in - inference.""" - depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg) - return depth_pred - - def forward_dummy(self, img): - """Dummy forward function.""" - depth = self.encode_decode(img, None) - - return depth - - def forward_train(self, img, img_metas, depth_gt, **kwargs): - """Forward function for training. - - Args: - img (Tensor): Input images. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - depth_gt (Tensor): Depth gt - used if the architecture supports depth estimation task. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - - x = self.extract_feat(img) - - losses = dict() - - # the last of x saves the info from neck - loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) - - losses.update(loss_decode) - - return losses - - def whole_inference(self, img, img_meta, rescale, size=None): - """Inference with full image.""" - depth_pred = self.encode_decode(img, img_meta, rescale, size=size) - - return depth_pred - - def slide_inference(self, img, img_meta, rescale): - """Inference by sliding-window with overlap. - - If h_crop > h_img or w_crop > w_img, the small patch will be used to - decode without padding. - """ - - h_stride, w_stride = self.test_cfg.stride - h_crop, w_crop = self.test_cfg.crop_size - batch_size, _, h_img, w_img = img.size() - h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 - w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 - preds = img.new_zeros((batch_size, 1, h_img, w_img)) - count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) - for h_idx in range(h_grids): - for w_idx in range(w_grids): - y1 = h_idx * h_stride - x1 = w_idx * w_stride - y2 = min(y1 + h_crop, h_img) - x2 = min(x1 + w_crop, w_img) - y1 = max(y2 - h_crop, 0) - x1 = max(x2 - w_crop, 0) - crop_img = img[:, :, y1:y2, x1:x2] - depth_pred = self.encode_decode(crop_img, img_meta, rescale) - preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) - - count_mat[:, :, y1:y2, x1:x2] += 1 - assert (count_mat == 0).sum() == 0 - if torch.onnx.is_in_onnx_export(): - # cast count_mat to constant while exporting to ONNX - count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) - preds = preds / count_mat - return preds - - def inference(self, img, img_meta, rescale, size=None): - """Inference with slide/whole style. - - Args: - img (Tensor): The input image of shape (N, 3, H, W). 
- img_meta (dict): Image info dict where each dict has: 'img_shape', - 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - rescale (bool): Whether to rescale back to the original shape. - - Returns: - Tensor: The output depth map. - """ - - assert self.test_cfg.mode in ["slide", "whole"] - ori_shape = img_meta[0]["ori_shape"] - assert all(_["ori_shape"] == ori_shape for _ in img_meta) - if self.test_cfg.mode == "slide": - depth_pred = self.slide_inference(img, img_meta, rescale) - else: - depth_pred = self.whole_inference(img, img_meta, rescale, size=size) - output = depth_pred - flip = img_meta[0]["flip"] - if flip: - flip_direction = img_meta[0]["flip_direction"] - assert flip_direction in ["horizontal", "vertical"] - if flip_direction == "horizontal": - output = output.flip(dims=(3,)) - elif flip_direction == "vertical": - output = output.flip(dims=(2,)) - - return output - - def simple_test(self, img, img_meta, rescale=True): - """Simple test with single image.""" - depth_pred = self.inference(img, img_meta, rescale) - if torch.onnx.is_in_onnx_export(): - # our inference backend only supports 4D outputs - depth_pred = depth_pred.unsqueeze(0) - return depth_pred - depth_pred = depth_pred.cpu().numpy() - # unravel batch dim - depth_pred = list(depth_pred) - return depth_pred - - def aug_test(self, imgs, img_metas, rescale=True): - """Test with augmentations. - - Only rescale=True is supported. - """ - # aug_test rescales all imgs back to ori_shape for now - assert rescale - # to save memory, we accumulate the augmented depth predictions in place - depth_pred = self.inference(imgs[0], img_metas[0], rescale) - for i in range(1, len(imgs)): - cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) - depth_pred += cur_depth_pred - depth_pred /= len(imgs) - depth_pred = depth_pred.cpu().numpy() - # unravel batch dim - depth_pred = list(depth_pred) - return depth_pred diff --git a/dinov2/eval/depth/models/losses/__init__.py b/dinov2/eval/depth/models/losses/__init__.py deleted file mode 100644 index 2f86242e342776da2e0acc61150d15a8d58ff1e0..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/losses/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .gradientloss import GradientLoss -from .sigloss import SigLoss diff --git a/dinov2/eval/depth/models/losses/gradientloss.py b/dinov2/eval/depth/models/losses/gradientloss.py deleted file mode 100644 index 1599878a6b70cdff4f8467e1e875f0d13ea89eca..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/losses/gradientloss.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from ...models.builder import LOSSES - - -@LOSSES.register_module() -class GradientLoss(nn.Module): - """GradientLoss. - - Adapted from https://www.cs.cornell.edu/projects/megadepth/ - - Args: - valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True. - loss_weight (float): Weight of the loss. Default: 1.0. 
- max_depth (float|None): When filtering invalid gt, set a max threshold. Default: None. - """ - - def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"): - super(GradientLoss, self).__init__() - self.valid_mask = valid_mask - self.loss_weight = loss_weight - self.max_depth = max_depth - self.loss_name = loss_name - - self.eps = 0.001 # avoid gradient explosion - - def gradientloss(self, input, target): - input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)] - target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)] - - gradient_loss = 0 - for input, target in zip(input_downscaled, target_downscaled): - if self.valid_mask: - mask = target > 0 - if self.max_depth is not None: - mask = torch.logical_and(target > 0, target <= self.max_depth) - N = torch.sum(mask) - else: - mask = torch.ones_like(target) - N = input.numel() - input_log = torch.log(input + self.eps) - target_log = torch.log(target + self.eps) - log_d_diff = input_log - target_log - - log_d_diff = torch.mul(log_d_diff, mask) - - v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :]) - v_mask = torch.mul(mask[0:-2, :], mask[2:, :]) - v_gradient = torch.mul(v_gradient, v_mask) - - h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:]) - h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:]) - h_gradient = torch.mul(h_gradient, h_mask) - - gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N - - return gradient_loss - - def forward(self, depth_pred, depth_gt): - """Forward function.""" - - gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt) - return gradient_loss diff --git a/dinov2/eval/depth/models/losses/sigloss.py b/dinov2/eval/depth/models/losses/sigloss.py deleted file mode 100644 index e12fad3e6151e4b975dd055193fdaec0206d4a14..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/models/losses/sigloss.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from ...models.builder import LOSSES - - -@LOSSES.register_module() -class SigLoss(nn.Module): - """SigLoss. - - This follows `AdaBins `_. - - Args: - valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True. - loss_weight (float): Weight of the loss. Default: 1.0. - max_depth (float|None): When filtering invalid gt, set a max threshold. Default: None. - warm_up (bool): Whether to use a simple warm-up stage to help convergence. Default: False. - warm_iter (int): The number of warm-up iterations. Default: 100. 
- """ - - def __init__( - self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" - ): - super(SigLoss, self).__init__() - self.valid_mask = valid_mask - self.loss_weight = loss_weight - self.max_depth = max_depth - self.loss_name = loss_name - - self.eps = 0.001 # avoid grad explode - - # HACK: a hack implementation for warmup sigloss - self.warm_up = warm_up - self.warm_iter = warm_iter - self.warm_up_counter = 0 - - def sigloss(self, input, target): - if self.valid_mask: - valid_mask = target > 0 - if self.max_depth is not None: - valid_mask = torch.logical_and(target > 0, target <= self.max_depth) - input = input[valid_mask] - target = target[valid_mask] - - if self.warm_up: - if self.warm_up_counter < self.warm_iter: - g = torch.log(input + self.eps) - torch.log(target + self.eps) - g = 0.15 * torch.pow(torch.mean(g), 2) - self.warm_up_counter += 1 - return torch.sqrt(g) - - g = torch.log(input + self.eps) - torch.log(target + self.eps) - Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) - return torch.sqrt(Dg) - - def forward(self, depth_pred, depth_gt): - """Forward function.""" - - loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) - return loss_depth diff --git a/dinov2/eval/depth/ops/__init__.py b/dinov2/eval/depth/ops/__init__.py deleted file mode 100644 index 78181c29581a281b5f42cf12078636aaeb43b5a5..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/ops/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .wrappers import resize diff --git a/dinov2/eval/depth/ops/wrappers.py b/dinov2/eval/depth/ops/wrappers.py deleted file mode 100644 index 15880ee0cb7652d4b41c489b927bf6a156b40e5e..0000000000000000000000000000000000000000 --- a/dinov2/eval/depth/ops/wrappers.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import warnings - -import torch.nn.functional as F - - -def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): - if warning: - if size is not None and align_corners: - input_h, input_w = tuple(int(x) for x in input.shape[2:]) - output_h, output_w = tuple(int(x) for x in size) - if output_h > input_h or output_w > output_h: - if ( - (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) - and (output_h - 1) % (input_h - 1) - and (output_w - 1) % (input_w - 1) - ): - warnings.warn( - f"When align_corners={align_corners}, " - "the output would more aligned if " - f"input size {(input_h, input_w)} is `x+1` and " - f"out size {(output_h, output_w)} is `nx+1`" - ) - return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/dinov2/eval/knn.py b/dinov2/eval/knn.py deleted file mode 100644 index f3a4845da1313a6db6b8345bb9a98230fcd24acf..0000000000000000000000000000000000000000 --- a/dinov2/eval/knn.py +++ /dev/null @@ -1,404 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -import argparse -from functools import partial -import json -import logging -import os -import sys -from typing import List, Optional - -import torch -from torch.nn.functional import one_hot, softmax - -import dinov2.distributed as distributed -from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data.transforms import make_classification_eval_transform -from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric -from dinov2.eval.setup import get_args_parser as get_setup_args_parser -from dinov2.eval.setup import setup_and_build_model -from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features - - -logger = logging.getLogger("dinov2") - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parents = parents or [] - setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) - parents = [setup_args_parser] - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--train-dataset", - dest="train_dataset_str", - type=str, - help="Training dataset", - ) - parser.add_argument( - "--val-dataset", - dest="val_dataset_str", - type=str, - help="Validation dataset", - ) - parser.add_argument( - "--nb_knn", - nargs="+", - type=int, - help="Number of nearest neighbors to use. 20 usually works the best.", - ) - parser.add_argument( - "--temperature", - type=float, - help="Temperature used in the voting coefficient", - ) - parser.add_argument( - "--gather-on-cpu", - action="store_true", - help="Whether to gather the train features on cpu, slower " - "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", - ) - parser.add_argument( - "--batch-size", - type=int, - help="Batch size.", - ) - parser.add_argument( - "--n-per-class-list", - nargs="+", - type=int, - help="Number to take per class", - ) - parser.add_argument( - "--n-tries", - type=int, - help="Number of tries", - ) - parser.set_defaults( - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - nb_knn=[10, 20, 100, 200], - temperature=0.07, - batch_size=256, - n_per_class_list=[-1], - n_tries=1, - ) - return parser - - -class KnnModule(torch.nn.Module): - """ - Gets knn of test features from all processes on a chunk of the train features - - Each rank gets a chunk of the train features as well as a chunk of the test features. - In `compute_neighbors`, for each rank one after the other, its chunk of test features - is sent to all devices, partial knns are computed with each chunk of train features - then collated back on the original device. 
- """ - - def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): - super().__init__() - - self.global_rank = distributed.get_global_rank() - self.global_size = distributed.get_global_size() - - self.device = device - self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) - self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device) - - self.nb_knn = nb_knn - self.max_k = max(self.nb_knn) - self.T = T - self.num_classes = num_classes - - def _get_knn_sims_and_labels(self, similarity, train_labels): - topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) - neighbors_labels = torch.gather(train_labels, 1, indices) - return topk_sims, neighbors_labels - - def _similarity_for_rank(self, features_rank, source_rank): - # Send the features from `source_rank` to all ranks - broadcast_shape = torch.tensor(features_rank.shape).to(self.device) - torch.distributed.broadcast(broadcast_shape, source_rank) - - broadcasted = features_rank - if self.global_rank != source_rank: - broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) - torch.distributed.broadcast(broadcasted, source_rank) - - # Compute the neighbors for `source_rank` among `train_features_rank_T` - similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) - candidate_labels = self.candidates.expand(len(similarity_rank), -1) - return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) - - def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): - # Gather all neighbors for `target_rank` - topk_sims_rank = retrieved_rank = None - if self.global_rank == target_rank: - topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] - retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] - - torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) - torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) - - if self.global_rank == target_rank: - # Perform a second top-k on the k * global_size retrieved neighbors - topk_sims_rank = torch.cat(topk_sims_rank, dim=1) - retrieved_rank = torch.cat(retrieved_rank, dim=1) - results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) - return results - return None - - def compute_neighbors(self, features_rank): - for rank in range(self.global_size): - topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) - results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) - if results is not None: - topk_sims_rank, neighbors_labels_rank = results - return topk_sims_rank, neighbors_labels_rank - - def forward(self, features_rank): - """ - Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k` - """ - assert all(k <= self.max_k for k in self.nb_knn) - - topk_sims, neighbors_labels = self.compute_neighbors(features_rank) - batch_size = neighbors_labels.shape[0] - topk_sims_transform = softmax(topk_sims / self.T, 1) - matmul = torch.mul( - one_hot(neighbors_labels, num_classes=self.num_classes), - topk_sims_transform.view(batch_size, -1, 1), - ) - probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} - return probas_for_k - - -class DictKeysModule(torch.nn.Module): - def __init__(self, keys): - super().__init__() - self.keys = keys - - def forward(self, features_dict, targets): - for k in self.keys: - features_dict = 
features_dict[k] - return {"preds": features_dict, "target": targets} - - -def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels): - modules = {} - mapping = create_class_indices_mapping(train_labels) - for npc in n_per_class_list: - if npc < 0: # Only one try needed when using the full data - full_module = module( - train_features=train_features, - train_labels=train_labels, - nb_knn=nb_knn, - ) - modules["full"] = ModuleDictWithForward({"1": full_module}) - continue - all_tries = {} - for t in range(n_tries): - final_indices = filter_train(mapping, npc, seed=t) - k_list = list(set(nb_knn + [npc])) - k_list = sorted([el for el in k_list if el <= npc]) - all_tries[str(t)] = module( - train_features=train_features[final_indices], - train_labels=train_labels[final_indices], - nb_knn=k_list, - ) - modules[f"{npc} per class"] = ModuleDictWithForward(all_tries) - - return ModuleDictWithForward(modules) - - -def filter_train(mapping, n_per_class, seed): - torch.manual_seed(seed) - final_indices = [] - for k in mapping.keys(): - index = torch.randperm(len(mapping[k]))[:n_per_class] - final_indices.append(mapping[k][index]) - return torch.cat(final_indices).squeeze() - - -def create_class_indices_mapping(labels): - unique_labels, inverse = torch.unique(labels, return_inverse=True) - mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))} - return mapping - - -class ModuleDictWithForward(torch.nn.ModuleDict): - def forward(self, *args, **kwargs): - return {k: module(*args, **kwargs) for k, module in self._modules.items()} - - -def eval_knn( - model, - train_dataset, - val_dataset, - accuracy_averaging, - nb_knn, - temperature, - batch_size, - num_workers, - gather_on_cpu, - n_per_class_list=[-1], - n_tries=1, -): - model = ModelWithNormalize(model) - - logger.info("Extracting features for train set...") - train_features, train_labels = extract_features( - model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu - ) - logger.info(f"Train features created, shape {train_features.shape}.") - - val_dataloader = make_data_loader( - dataset=val_dataset, - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=False, - persistent_workers=True, - ) - num_classes = train_labels.max() + 1 - metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes) - - device = torch.cuda.current_device() - partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) - knn_module_dict = create_module_dict( - module=partial_module, - n_per_class_list=n_per_class_list, - n_tries=n_tries, - nb_knn=nb_knn, - train_features=train_features, - train_labels=train_labels, - ) - postprocessors, metrics = {}, {} - for n_per_class, knn_module in knn_module_dict.items(): - for t, knn_try in knn_module.items(): - postprocessors = { - **postprocessors, - **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn}, - } - metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}} - model_with_knn = torch.nn.Sequential(model, knn_module_dict) - - # ============ evaluation ... 
============ - logger.info("Start the k-NN classification.") - _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device) - - # Averaging the results over the n tries for each value of n_per_class - for n_per_class, knn_module in knn_module_dict.items(): - first_try = list(knn_module.keys())[0] - k_list = knn_module[first_try].nb_knn - for k in k_list: - keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5` - results_dict[(n_per_class, k)] = { - key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()])) - for key in keys - } - for t in knn_module.keys(): - del results_dict[(n_per_class, t, k)] - - return results_dict - - -def eval_knn_with_model( - model, - output_dir, - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - nb_knn=(10, 20, 100, 200), - temperature=0.07, - autocast_dtype=torch.float, - accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, - transform=None, - gather_on_cpu=False, - batch_size=256, - num_workers=5, - n_per_class_list=[-1], - n_tries=1, -): - transform = transform or make_classification_eval_transform() - - train_dataset = make_dataset( - dataset_str=train_dataset_str, - transform=transform, - ) - val_dataset = make_dataset( - dataset_str=val_dataset_str, - transform=transform, - ) - - with torch.cuda.amp.autocast(dtype=autocast_dtype): - results_dict_knn = eval_knn( - model=model, - train_dataset=train_dataset, - val_dataset=val_dataset, - accuracy_averaging=accuracy_averaging, - nb_knn=nb_knn, - temperature=temperature, - batch_size=batch_size, - num_workers=num_workers, - gather_on_cpu=gather_on_cpu, - n_per_class_list=n_per_class_list, - n_tries=n_tries, - ) - - results_dict = {} - if distributed.is_main_process(): - for knn_ in results_dict_knn.keys(): - top1 = results_dict_knn[knn_]["top-1"].item() * 100.0 - top5 = results_dict_knn[knn_]["top-5"].item() * 100.0 - results_dict[f"{knn_} Top 1"] = top1 - results_dict[f"{knn_} Top 5"] = top5 - logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}") - - metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") - with open(metrics_file_path, "a") as f: - for k, v in results_dict.items(): - f.write(json.dumps({k: v}) + "\n") - - if distributed.is_enabled(): - torch.distributed.barrier() - return results_dict - - -def main(args): - model, autocast_dtype = setup_and_build_model(args) - eval_knn_with_model( - model=model, - output_dir=args.output_dir, - train_dataset_str=args.train_dataset_str, - val_dataset_str=args.val_dataset_str, - nb_knn=args.nb_knn, - temperature=args.temperature, - autocast_dtype=autocast_dtype, - accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, - transform=None, - gather_on_cpu=args.gather_on_cpu, - batch_size=args.batch_size, - num_workers=5, - n_per_class_list=args.n_per_class_list, - n_tries=args.n_tries, - ) - return 0 - - -if __name__ == "__main__": - description = "DINOv2 k-NN evaluation" - args_parser = get_args_parser(description=description) - args = args_parser.parse_args() - sys.exit(main(args)) diff --git a/dinov2/eval/linear.py b/dinov2/eval/linear.py deleted file mode 100644 index 1bd4c5de5a041be8a188f007257d1e91b6d6921e..0000000000000000000000000000000000000000 --- a/dinov2/eval/linear.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import argparse -from functools import partial -import json -import logging -import os -import sys -from typing import List, Optional - -import numpy as np -import torch -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel -from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer - -from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform -import dinov2.distributed as distributed -from dinov2.eval.metrics import MetricType, build_metric -from dinov2.eval.setup import get_args_parser as get_setup_args_parser -from dinov2.eval.setup import setup_and_build_model -from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate -from dinov2.logging import MetricLogger - - -logger = logging.getLogger("dinov2") - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parents = parents or [] - setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) - parents = [setup_args_parser] - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--train-dataset", - dest="train_dataset_str", - type=str, - help="Training dataset", - ) - parser.add_argument( - "--val-dataset", - dest="val_dataset_str", - type=str, - help="Validation dataset", - ) - parser.add_argument( - "--test-datasets", - dest="test_dataset_strs", - type=str, - nargs="+", - help="Test datasets; if unset, the validation dataset is reused", - ) - parser.add_argument( - "--epochs", - type=int, - help="Number of training epochs", - ) - parser.add_argument( - "--batch-size", - type=int, - help="Batch size (per GPU)", - ) - parser.add_argument( - "--num-workers", - type=int, - help="Number of data loading workers", - ) - parser.add_argument( - "--epoch-length", - type=int, - help="Length of an epoch in number of iterations", - ) - parser.add_argument( - "--save-checkpoint-frequency", - type=int, - help="Number of epochs between two named checkpoint saves.", - ) - parser.add_argument( - "--eval-period-iterations", - type=int, - help="Number of iterations between two evaluations.", - ) - parser.add_argument( - "--learning-rates", - nargs="+", - type=float, - help="Learning rates to grid search.", - ) - parser.add_argument( - "--no-resume", - action="store_true", - help="Do not resume from existing checkpoints", - ) - parser.add_argument( - "--val-metric-type", - type=MetricType, - choices=list(MetricType), - help="Validation metric", - ) - parser.add_argument( - "--test-metric-types", - type=MetricType, - choices=list(MetricType), - nargs="+", - help="Evaluation metric", - ) - parser.add_argument( - "--classifier-fpath", - type=str, - help="Path to a file containing pretrained linear classifiers", - ) - parser.add_argument( - "--val-class-mapping-fpath", - type=str, - help="Path to a file containing a mapping to adjust classifier outputs", - ) - parser.add_argument( - "--test-class-mapping-fpaths", - nargs="+", - type=str, - help="Path to a file containing a mapping to adjust classifier outputs", - ) - parser.set_defaults( - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - test_dataset_strs=None, - epochs=10, - batch_size=128, -
num_workers=8, - epoch_length=1250, - save_checkpoint_frequency=20, - eval_period_iterations=1250, - learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], - val_metric_type=MetricType.MEAN_ACCURACY, - test_metric_types=None, - classifier_fpath=None, - val_class_mapping_fpath=None, - test_class_mapping_fpaths=[None], - ) - return parser - - -def has_ddp_wrapper(m: nn.Module) -> bool: - return isinstance(m, DistributedDataParallel) - - -def remove_ddp_wrapper(m: nn.Module) -> nn.Module: - return m.module if has_ddp_wrapper(m) else m - - -def _pad_and_collate(batch): - maxlen = max(len(targets) for image, targets in batch) - padded_batch = [ - (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch - ] - return torch.utils.data.default_collate(padded_batch) - - -def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): - intermediate_output = x_tokens_list[-use_n_blocks:] - output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) - if use_avgpool: - output = torch.cat( - ( - output, - torch.mean(intermediate_output[-1][0], dim=1), # patch tokens - ), - dim=-1, - ) - output = output.reshape(output.shape[0], -1) - return output.float() - - -class LinearClassifier(nn.Module): - """Linear layer to train on top of frozen features""" - - def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): - super().__init__() - self.out_dim = out_dim - self.use_n_blocks = use_n_blocks - self.use_avgpool = use_avgpool - self.num_classes = num_classes - self.linear = nn.Linear(out_dim, num_classes) - self.linear.weight.data.normal_(mean=0.0, std=0.01) - self.linear.bias.data.zero_() - - def forward(self, x_tokens_list): - output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) - return self.linear(output) - - -class AllClassifiers(nn.Module): - def __init__(self, classifiers_dict): - super().__init__() - self.classifiers_dict = nn.ModuleDict() - self.classifiers_dict.update(classifiers_dict) - - def forward(self, inputs): - return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} - - def __len__(self): - return len(self.classifiers_dict) - - -class LinearPostprocessor(nn.Module): - def __init__(self, linear_classifier, class_mapping=None): - super().__init__() - self.linear_classifier = linear_classifier - self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) - - def forward(self, samples, targets): - preds = self.linear_classifier(samples) - return { - "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, - "target": targets, - } - - -def scale_lr(learning_rates, batch_size): - return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 - - -def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): - linear_classifiers_dict = nn.ModuleDict() - optim_param_groups = [] - for n in n_last_blocks_list: - for avgpool in [False, True]: - for _lr in learning_rates: - lr = scale_lr(_lr, batch_size) - out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] - linear_classifier = LinearClassifier( - out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes - ) - linear_classifier = linear_classifier.cuda() - linear_classifiers_dict[ - f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") - ] = linear_classifier - 
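- # One optimizer param group per classifier head: all (n_blocks, avgpool, lr) combinations train jointly in a single pass over the frozen features, and the best head is selected on the validation set afterwards.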
optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) - - linear_classifiers = AllClassifiers(linear_classifiers_dict) - if distributed.is_enabled(): - linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) - - return linear_classifiers, optim_param_groups - - -@torch.no_grad() -def evaluate_linear_classifiers( - feature_model, - linear_classifiers, - data_loader, - metric_type, - metrics_file_path, - training_num_classes, - iteration, - prefixstring="", - class_mapping=None, - best_classifier_on_val=None, -): - logger.info("running validation !") - - num_classes = len(class_mapping) if class_mapping is not None else training_num_classes - metric = build_metric(metric_type, num_classes=num_classes) - postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} - metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} - - _, results_dict_temp = evaluate( - feature_model, - data_loader, - postprocessors, - metrics, - torch.cuda.current_device(), - ) - - logger.info("") - results_dict = {} - max_accuracy = 0 - best_classifier = "" - for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): - logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") - if ( - best_classifier_on_val is None and metric["top-1"].item() > max_accuracy - ) or classifier_string == best_classifier_on_val: - max_accuracy = metric["top-1"].item() - best_classifier = classifier_string - - results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} - - logger.info(f"best classifier: {results_dict['best_classifier']}") - - if distributed.is_main_process(): - with open(metrics_file_path, "a") as f: - f.write(f"iter: {iteration}\n") - for k, v in results_dict.items(): - f.write(json.dumps({k: v}) + "\n") - f.write("\n") - - return results_dict - - -def eval_linear( - *, - feature_model, - linear_classifiers, - train_data_loader, - val_data_loader, - metrics_file_path, - optimizer, - scheduler, - output_dir, - max_iter, - checkpoint_period, # In number of iter, creates a new file every period - running_checkpoint_period, # Period to update main checkpoint file - eval_period, - metric_type, - training_num_classes, - resume=True, - classifier_fpath=None, - val_class_mapping=None, -): - checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) - start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 - - periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) - iteration = start_iter - logger.info("Starting training from iteration {}".format(start_iter)) - metric_logger = MetricLogger(delimiter=" ") - header = "Training" - - for data, labels in metric_logger.log_every( - train_data_loader, - 10, - header, - max_iter, - start_iter, - ): - data = data.cuda(non_blocking=True) - labels = labels.cuda(non_blocking=True) - - features = feature_model(data) - outputs = linear_classifiers(features) - - losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} - loss = sum(losses.values()) - - # compute the gradients - optimizer.zero_grad() - loss.backward() - - # step - optimizer.step() - scheduler.step() - - # log - if iteration % 10 == 0: - torch.cuda.synchronize() - metric_logger.update(loss=loss.item()) - metric_logger.update(lr=optimizer.param_groups[0]["lr"]) - print("lr", 
optimizer.param_groups[0]["lr"]) - - if iteration - start_iter > 5: - if iteration % running_checkpoint_period == 0: - torch.cuda.synchronize() - if distributed.is_main_process(): - logger.info("Checkpointing running_checkpoint") - periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) - torch.cuda.synchronize() - periodic_checkpointer.step(iteration) - - if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: - _ = evaluate_linear_classifiers( - feature_model=feature_model, - linear_classifiers=remove_ddp_wrapper(linear_classifiers), - data_loader=val_data_loader, - metrics_file_path=metrics_file_path, - prefixstring=f"ITER: {iteration}", - metric_type=metric_type, - training_num_classes=training_num_classes, - iteration=iteration, - class_mapping=val_class_mapping, - ) - torch.cuda.synchronize() - - iteration = iteration + 1 - - val_results_dict = evaluate_linear_classifiers( - feature_model=feature_model, - linear_classifiers=remove_ddp_wrapper(linear_classifiers), - data_loader=val_data_loader, - metrics_file_path=metrics_file_path, - metric_type=metric_type, - training_num_classes=training_num_classes, - iteration=iteration, - class_mapping=val_class_mapping, - ) - return val_results_dict, feature_model, linear_classifiers, iteration - - -def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): - test_dataset = make_dataset( - dataset_str=test_dataset_str, - transform=make_classification_eval_transform(), - ) - test_data_loader = make_data_loader( - dataset=test_dataset, - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=False, - persistent_workers=False, - collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, - ) - return test_data_loader - - -def test_on_datasets( - feature_model, - linear_classifiers, - test_dataset_strs, - batch_size, - num_workers, - test_metric_types, - metrics_file_path, - training_num_classes, - iteration, - best_classifier_on_val, - prefixstring="", - test_class_mappings=[None], -): - results_dict = {} - for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): - logger.info(f"Testing on {test_dataset_str}") - test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) - dataset_results_dict = evaluate_linear_classifiers( - feature_model, - remove_ddp_wrapper(linear_classifiers), - test_data_loader, - metric_type, - metrics_file_path, - training_num_classes, - iteration, - prefixstring="", - class_mapping=class_mapping, - best_classifier_on_val=best_classifier_on_val, - ) - results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] - return results_dict - - -def run_eval_linear( - model, - output_dir, - train_dataset_str, - val_dataset_str, - batch_size, - epochs, - epoch_length, - num_workers, - save_checkpoint_frequency, - eval_period_iterations, - learning_rates, - autocast_dtype, - test_dataset_strs=None, - resume=True, - classifier_fpath=None, - val_class_mapping_fpath=None, - test_class_mapping_fpaths=[None], - val_metric_type=MetricType.MEAN_ACCURACY, - test_metric_types=None, -): - seed = 0 - - if test_dataset_strs is None: - test_dataset_strs = [val_dataset_str] - if test_metric_types is None: - test_metric_types = [val_metric_type] * len(test_dataset_strs) - else: - assert len(test_metric_types) == 
len(test_dataset_strs) - assert len(test_dataset_strs) == len(test_class_mapping_fpaths) - - train_transform = make_classification_train_transform() - train_dataset = make_dataset( - dataset_str=train_dataset_str, - transform=train_transform, - ) - training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) - sampler_type = SamplerType.SHARDED_INFINITE - # sampler_type = SamplerType.INFINITE - - n_last_blocks_list = [1, 4] - n_last_blocks = max(n_last_blocks_list) - autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) - feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) - sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) - - linear_classifiers, optim_param_groups = setup_linear_classifiers( - sample_output, - n_last_blocks_list, - learning_rates, - batch_size, - training_num_classes, - ) - - optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) - max_iter = epochs * epoch_length - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) - checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) - start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 - train_data_loader = make_data_loader( - dataset=train_dataset, - batch_size=batch_size, - num_workers=num_workers, - shuffle=True, - seed=seed, - sampler_type=sampler_type, - sampler_advance=start_iter, - drop_last=True, - persistent_workers=True, - ) - val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) - - checkpoint_period = save_checkpoint_frequency * epoch_length - - if val_class_mapping_fpath is not None: - logger.info(f"Using class mapping from {val_class_mapping_fpath}") - val_class_mapping = np.load(val_class_mapping_fpath) - else: - val_class_mapping = None - - test_class_mappings = [] - for class_mapping_fpath in test_class_mapping_fpaths: - if class_mapping_fpath is not None and class_mapping_fpath != "None": - logger.info(f"Using class mapping from {class_mapping_fpath}") - class_mapping = np.load(class_mapping_fpath) - else: - class_mapping = None - test_class_mappings.append(class_mapping) - - metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") - val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( - feature_model=feature_model, - linear_classifiers=linear_classifiers, - train_data_loader=train_data_loader, - val_data_loader=val_data_loader, - metrics_file_path=metrics_file_path, - optimizer=optimizer, - scheduler=scheduler, - output_dir=output_dir, - max_iter=max_iter, - checkpoint_period=checkpoint_period, - running_checkpoint_period=epoch_length, - eval_period=eval_period_iterations, - metric_type=val_metric_type, - training_num_classes=training_num_classes, - resume=resume, - val_class_mapping=val_class_mapping, - classifier_fpath=classifier_fpath, - ) - results_dict = {} - if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: - results_dict = test_on_datasets( - feature_model, - linear_classifiers, - test_dataset_strs, - batch_size, - 0, # num_workers, - test_metric_types, - metrics_file_path, - training_num_classes, - iteration, - val_results_dict["best_classifier"]["name"], - prefixstring="", - test_class_mappings=test_class_mappings, - ) - results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] - 
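- # Accuracies are stored as percentages keyed by dataset string; e.g. a run might return {"best_classifier": "classifier_4_blocks_avgpool_True_lr_0_00040", "ImageNet:split=VAL_accuracy": 86.3} (values are illustrative).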
results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] - logger.info("Test Results Dict " + str(results_dict)) - - return results_dict - - -def main(args): - model, autocast_dtype = setup_and_build_model(args) - run_eval_linear( - model=model, - output_dir=args.output_dir, - train_dataset_str=args.train_dataset_str, - val_dataset_str=args.val_dataset_str, - test_dataset_strs=args.test_dataset_strs, - batch_size=args.batch_size, - epochs=args.epochs, - epoch_length=args.epoch_length, - num_workers=args.num_workers, - save_checkpoint_frequency=args.save_checkpoint_frequency, - eval_period_iterations=args.eval_period_iterations, - learning_rates=args.learning_rates, - autocast_dtype=autocast_dtype, - resume=not args.no_resume, - classifier_fpath=args.classifier_fpath, - val_metric_type=args.val_metric_type, - test_metric_types=args.test_metric_types, - val_class_mapping_fpath=args.val_class_mapping_fpath, - test_class_mapping_fpaths=args.test_class_mapping_fpaths, - ) - return 0 - - -if __name__ == "__main__": - description = "DINOv2 linear evaluation" - args_parser = get_args_parser(description=description) - args = args_parser.parse_args() - sys.exit(main(args)) diff --git a/dinov2/eval/log_regression.py b/dinov2/eval/log_regression.py deleted file mode 100644 index 5f36ec134e0ce25697428a0b3f21cdc2f0145645..0000000000000000000000000000000000000000 --- a/dinov2/eval/log_regression.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import argparse -import gc -import logging -import sys -import time -from typing import List, Optional - -from cuml.linear_model import LogisticRegression -import torch -import torch.backends.cudnn as cudnn -import torch.distributed -from torch import nn -from torch.utils.data import TensorDataset -from torchmetrics import MetricTracker - -from dinov2.data import make_dataset -from dinov2.data.transforms import make_classification_eval_transform -from dinov2.distributed import get_global_rank, get_global_size -from dinov2.eval.metrics import MetricType, build_metric -from dinov2.eval.setup import get_args_parser as get_setup_args_parser -from dinov2.eval.setup import setup_and_build_model -from dinov2.eval.utils import evaluate, extract_features -from dinov2.utils.dtype import as_torch_dtype - - -logger = logging.getLogger("dinov2") - -DEFAULT_MAX_ITER = 1_000 -C_POWER_RANGE = torch.linspace(-6, 5, 45) -_CPU_DEVICE = torch.device("cpu") - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parents = parents or [] - setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) - parents = [setup_args_parser] - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--train-dataset", - dest="train_dataset_str", - type=str, - help="Training dataset", - ) - parser.add_argument( - "--val-dataset", - dest="val_dataset_str", - type=str, - help="Validation dataset", - ) - parser.add_argument( - "--finetune-dataset-str", - dest="finetune_dataset_str", - type=str, - help="Fine-tuning dataset", - ) - parser.add_argument( - "--finetune-on-val", - action="store_true", - help="If there is no finetune dataset, whether to choose the " - "hyperparameters on the val set 
instead of 10%% of the train dataset", - ) - parser.add_argument( - "--metric-type", - type=MetricType, - choices=list(MetricType), - help="Metric type", - ) - parser.add_argument( - "--train-features-device", - type=str, - help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s", - ) - parser.add_argument( - "--train-dtype", - type=str, - help="Data type to convert the train features to (default: %(default)s)", - ) - parser.add_argument( - "--max-train-iters", - type=int, - help="Maximum number of train iterations (default: %(default)s)", - ) - parser.set_defaults( - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - finetune_dataset_str=None, - metric_type=MetricType.MEAN_ACCURACY, - train_features_device="cpu", - train_dtype="float64", - max_train_iters=DEFAULT_MAX_ITER, - finetune_on_val=False, - ) - return parser - - -class LogRegModule(nn.Module): - def __init__( - self, - C, - max_iter=DEFAULT_MAX_ITER, - dtype=torch.float64, - device=_CPU_DEVICE, - ): - super().__init__() - self.dtype = dtype - self.device = device - self.estimator = LogisticRegression( - penalty="l2", - C=C, - max_iter=max_iter, - output_type="numpy", - tol=1e-12, - linesearch_max_iter=50, - ) - - def forward(self, samples, targets): - samples_device = samples.device - samples = samples.to(dtype=self.dtype, device=self.device) - if self.device == _CPU_DEVICE: - samples = samples.numpy() - probas = self.estimator.predict_proba(samples) - return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} - - def fit(self, train_features, train_labels): - train_features = train_features.to(dtype=self.dtype, device=self.device) - train_labels = train_labels.to(dtype=self.dtype, device=self.device) - if self.device == _CPU_DEVICE: - # both cuML and sklearn only work with numpy arrays on CPU - train_features = train_features.numpy() - train_labels = train_labels.numpy() - self.estimator.fit(train_features, train_labels) - - -def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device): - postprocessors = {"metrics": logreg_model} - metrics = {"metrics": logreg_metric} - return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device) - - -def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE): - logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device) - logreg_model.fit(train_features, train_labels) - return logreg_model - - -def train_and_evaluate( - *, - C, - max_iter, - train_features, - train_labels, - logreg_metric, - test_data_loader, - train_dtype=torch.float64, - train_features_device, - eval_device, -): - logreg_model = train_for_C( - C=C, - max_iter=max_iter, - train_features=train_features, - train_labels=train_labels, - dtype=train_dtype, - device=train_features_device, - ) - return evaluate_model( - logreg_model=logreg_model, - logreg_metric=logreg_metric, - test_data_loader=test_data_loader, - device=eval_device, - ) - - -def sweep_C_values( - *, - train_features, - train_labels, - test_data_loader, - metric_type, - num_classes, - train_dtype=torch.float64, - train_features_device=_CPU_DEVICE, - max_train_iters=DEFAULT_MAX_ITER, -): - if metric_type == MetricType.PER_CLASS_ACCURACY: - # If we want to output per-class accuracy, we select the hyperparameters with mean per class - metric_type = MetricType.MEAN_PER_CLASS_ACCURACY - logreg_metric = build_metric(metric_type, num_classes=num_classes) - metric_tracker = 
MetricTracker(logreg_metric, maximize=True) - ALL_C = 10**C_POWER_RANGE - logreg_models = {} - - train_features = train_features.to(dtype=train_dtype, device=train_features_device) - train_labels = train_labels.to(device=train_features_device) - - for i in range(get_global_rank(), len(ALL_C), get_global_size()): - C = ALL_C[i].item() - logger.info( - f"Training for C = {C:.5f}, dtype={train_dtype}, " - f"features: {train_features.shape}, {train_features.dtype}, " - f"labels: {train_labels.shape}, {train_labels.dtype}" - ) - logreg_models[C] = train_for_C( - C=C, - max_iter=max_train_iters, - train_features=train_features, - train_labels=train_labels, - dtype=train_dtype, - device=train_features_device, - ) - - gather_list = [None for _ in range(get_global_size())] - torch.distributed.all_gather_object(gather_list, logreg_models) - - logreg_models_gathered = {} - for logreg_dict in gather_list: - logreg_models_gathered.update(logreg_dict) - - for i in range(len(ALL_C)): - metric_tracker.increment() - C = ALL_C[i].item() - evals = evaluate_model( - logreg_model=logreg_models_gathered[C], - logreg_metric=metric_tracker, - test_data_loader=test_data_loader, - device=torch.cuda.current_device(), - ) - logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}") - - best_stats, which_epoch = metric_tracker.best_metric(return_step=True) - best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} - if which_epoch["top-1"] == i: - best_C = C - logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}") - - return best_stats, best_C - - -def eval_log_regression( - *, - model, - train_dataset, - val_dataset, - finetune_dataset, - metric_type, - batch_size, - num_workers, - finetune_on_val=False, - train_dtype=torch.float64, - train_features_device=_CPU_DEVICE, - max_train_iters=DEFAULT_MAX_ITER, -): - """ - Implements the "standard" process for log regression evaluation: - The value of C is chosen by training on train_dataset and evaluating on - finetune_dataset. Then, the final model is trained on a concatenation of - train_dataset and finetune_dataset, and is evaluated on val_dataset. 
- If there is no finetune_dataset, the value of C is the one that yields - the best results on a random 10% subset of the train dataset - """ - - start = time.time() - - train_features, train_labels = extract_features( - model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) - ) - val_features, val_labels = extract_features( - model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) - ) - val_data_loader = torch.utils.data.DataLoader( - TensorDataset(val_features, val_labels), - batch_size=batch_size, - drop_last=False, - num_workers=0, - persistent_workers=False, - ) - - if finetune_dataset is None and finetune_on_val: - logger.info("Choosing hyperparameters on the val dataset") - finetune_features, finetune_labels = val_features, val_labels - elif finetune_dataset is None and not finetune_on_val: - logger.info("Choosing hyperparameters on 10% of the train dataset") - torch.manual_seed(0) - indices = torch.randperm(len(train_features), device=train_features.device) - finetune_index = indices[: len(train_features) // 10] - train_index = indices[len(train_features) // 10 :] - finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index] - train_features, train_labels = train_features[train_index], train_labels[train_index] - else: - logger.info("Choosing hyperparameters on the finetune dataset") - finetune_features, finetune_labels = extract_features( - model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) - ) - # release the model - free GPU memory - del model - gc.collect() - torch.cuda.empty_cache() - finetune_data_loader = torch.utils.data.DataLoader( - TensorDataset(finetune_features, finetune_labels), - batch_size=batch_size, - drop_last=False, - ) - - if len(train_labels.shape) > 1: - num_classes = train_labels.shape[1] - else: - num_classes = train_labels.max() + 1 - - logger.info("Using cuML for logistic regression") - - best_stats, best_C = sweep_C_values( - train_features=train_features, - train_labels=train_labels, - test_data_loader=finetune_data_loader, - metric_type=metric_type, - num_classes=num_classes, - train_dtype=train_dtype, - train_features_device=train_features_device, - max_train_iters=max_train_iters, - ) - - if not finetune_on_val: - logger.info("Best parameter found, concatenating features") - train_features = torch.cat((train_features, finetune_features)) - train_labels = torch.cat((train_labels, finetune_labels)) - - logger.info("Training final model") - logreg_metric = build_metric(metric_type, num_classes=num_classes) - evals = train_and_evaluate( - C=best_C, - max_iter=max_train_iters, - train_features=train_features, - train_labels=train_labels, - logreg_metric=logreg_metric.clone(), - test_data_loader=val_data_loader, - eval_device=torch.cuda.current_device(), - train_dtype=train_dtype, - train_features_device=train_features_device, - ) - - best_stats = evals[1]["metrics"] - - best_stats["best_C"] = best_C - - logger.info(f"Log regression evaluation done in {int(time.time() - start)}s") - return best_stats - - -def eval_log_regression_with_model( - model, - train_dataset_str="ImageNet:split=TRAIN", - val_dataset_str="ImageNet:split=VAL", - finetune_dataset_str=None, - autocast_dtype=torch.float, - finetune_on_val=False, - metric_type=MetricType.MEAN_ACCURACY, - train_dtype=torch.float64, - train_features_device=_CPU_DEVICE, - max_train_iters=DEFAULT_MAX_ITER, -): - cudnn.benchmark = True 
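- # Eval inputs have a fixed shape, so cuDNN benchmark mode can safely cache the fastest kernels for feature extraction.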
- - transform = make_classification_eval_transform(resize_size=224) - target_transform = None - - train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform) - val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform) - if finetune_dataset_str is not None: - finetune_dataset = make_dataset( - dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform - ) - else: - finetune_dataset = None - - with torch.cuda.amp.autocast(dtype=autocast_dtype): - results_dict_logreg = eval_log_regression( - model=model, - train_dataset=train_dataset, - val_dataset=val_dataset, - finetune_dataset=finetune_dataset, - metric_type=metric_type, - batch_size=256, - num_workers=0, # 5, - finetune_on_val=finetune_on_val, - train_dtype=train_dtype, - train_features_device=train_features_device, - max_train_iters=max_train_iters, - ) - - results_dict = { - "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0, - "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0, - "best_C": results_dict_logreg["best_C"], - } - logger.info( - "\n".join( - [ - "Training of the supervised logistic regression on frozen features completed.\n" - "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]), - "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]), - "obtained for C = {c:.6f}".format(c=results_dict["best_C"]), - ] - ) - ) - - torch.distributed.barrier() - return results_dict - - -def main(args): - model, autocast_dtype = setup_and_build_model(args) - eval_log_regression_with_model( - model=model, - train_dataset_str=args.train_dataset_str, - val_dataset_str=args.val_dataset_str, - finetune_dataset_str=args.finetune_dataset_str, - autocast_dtype=autocast_dtype, - finetune_on_val=args.finetune_on_val, - metric_type=args.metric_type, - train_dtype=as_torch_dtype(args.train_dtype), - train_features_device=torch.device(args.train_features_device), - max_train_iters=args.max_train_iters, - ) - return 0 - - -if __name__ == "__main__": - description = "DINOv2 logistic regression evaluation" - args_parser = get_args_parser(description=description) - args = args_parser.parse_args() - sys.exit(main(args)) diff --git a/dinov2/eval/metrics.py b/dinov2/eval/metrics.py deleted file mode 100644 index c26db7b46a2e8162fb8d54813a78e272cfb8fded..0000000000000000000000000000000000000000 --- a/dinov2/eval/metrics.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
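- # Metric factories shared by the k-NN, linear, and logistic-regression evaluations: top-k accuracy under several averaging modes, ImageNet-ReaL accuracy, and macro-F1 variants.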
- -from enum import Enum -import logging -from typing import Any, Dict, Optional - -import torch -from torch import Tensor -from torchmetrics import Metric, MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MultilabelF1Score -from torchmetrics.utilities.data import dim_zero_cat, select_topk - - -logger = logging.getLogger("dinov2") - - -class MetricType(Enum): - MEAN_ACCURACY = "mean_accuracy" - MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" - PER_CLASS_ACCURACY = "per_class_accuracy" - IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" - MEAN_PER_CLASS_MULTICLASS_F1 = "mean_per_class_multiclass_f1" - MEAN_PER_CLASS_MULTILABEL_F1 = "mean_per_class_multilabel_f1" - - @property - def accuracy_averaging(self): - return getattr(AccuracyAveraging, self.name, None) - - def __str__(self): - return self.value - - -class AccuracyAveraging(Enum): - MEAN_ACCURACY = "micro" - MEAN_PER_CLASS_ACCURACY = "macro" - PER_CLASS_ACCURACY = "none" - - def __str__(self): - return self.value - - -def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): - if metric_type.accuracy_averaging is not None: - return build_topk_accuracy_metric( - average_type=metric_type.accuracy_averaging, - num_classes=num_classes, - ks=(1, 5) if ks is None else ks, - ) - elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: - return build_topk_imagenet_real_accuracy_metric( - num_classes=num_classes, - ks=(1, 5) if ks is None else ks, - ) - elif metric_type == MetricType.MEAN_PER_CLASS_MULTILABEL_F1: - return MetricCollection({"top-1": MultilabelF1Score(num_labels=int(num_classes), average="macro")}) - elif metric_type == MetricType.MEAN_PER_CLASS_MULTICLASS_F1: - return MetricCollection({"top-1": MulticlassF1Score(num_classes=int(num_classes), average="macro")}) - - raise ValueError(f"Unknown metric type {metric_type}") - - -def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): - metrics: Dict[str, Metric] = { - f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks - } - return MetricCollection(metrics) - - -def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): - metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} - return MetricCollection(metrics) - - -class ImageNetReaLAccuracy(Metric): - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - - def __init__( - self, - num_classes: int, - top_k: int = 1, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.num_classes = num_classes - self.top_k = top_k - self.add_state("tp", [], dist_reduce_fx="cat") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - # preds [B, D] - # target [B, A] - # preds_oh [B, D] with 0 and 1 - # select top K highest probabilities, use one hot representation - preds_oh = select_topk(preds, self.top_k) - # target_oh [B, D + 1] with 0 and 1 - target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) - target = target.long() - # for undefined targets (-1) use a fake value `num_classes` - target[target == -1] = self.num_classes - # fill targets, use one hot representation - target_oh.scatter_(1, target, 1) - # target_oh [B, D] (remove the fake target at index `num_classes`) - target_oh = target_oh[:, :-1] - # tp [B] with 0 and 1 - tp = 
(preds_oh * target_oh == 1).sum(dim=1) - # at least one match between prediction and target - tp.clip_(max=1) - # ignore instances where no targets are defined - mask = target_oh.sum(dim=1) > 0 - tp = tp[mask] - self.tp.append(tp) # type: ignore - - def compute(self) -> Tensor: - tp = dim_zero_cat(self.tp) # type: ignore - return tp.float().mean() diff --git a/dinov2/eval/segmentation/__init__.py b/dinov2/eval/segmentation/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/eval/segmentation/hooks/__init__.py b/dinov2/eval/segmentation/hooks/__init__.py deleted file mode 100644 index 738cc2d2069521ea0353acd0cb0a03e3ddf1fa51..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/hooks/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .optimizer import DistOptimizerHook diff --git a/dinov2/eval/segmentation/hooks/optimizer.py b/dinov2/eval/segmentation/hooks/optimizer.py deleted file mode 100644 index f593f26a84475bbf7ebda9607a4d10914b13a443..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/hooks/optimizer.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -try: - import apex -except ImportError: - print("apex is not installed") - -from mmcv.runner import OptimizerHook, HOOKS - - -@HOOKS.register_module() -class DistOptimizerHook(OptimizerHook): - """Optimizer hook for distributed training.""" - - def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): - self.grad_clip = grad_clip - self.coalesce = coalesce - self.bucket_size_mb = bucket_size_mb - self.update_interval = update_interval - self.use_fp16 = use_fp16 - - def before_run(self, runner): - runner.optimizer.zero_grad() - - def after_train_iter(self, runner): - runner.outputs["loss"] /= self.update_interval - if self.use_fp16: - # runner.outputs['loss'].backward() - with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss: - scaled_loss.backward() - else: - runner.outputs["loss"].backward() - if self.every_n_iters(runner, self.update_interval): - if self.grad_clip is not None: - self.clip_grads(runner.model.parameters()) - runner.optimizer.step() - runner.optimizer.zero_grad() diff --git a/dinov2/eval/segmentation/models/__init__.py b/dinov2/eval/segmentation/models/__init__.py deleted file mode 100644 index 88e4563d4c162d67e7900955a06bd9248d4c9a48..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/models/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
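- # Importing these submodules registers the DINOv2 backbone and the BN/linear decode head with the mmseg BACKBONES/HEADS registries as a side effect.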
- -from .backbones import * # noqa: F403 -from .decode_heads import * # noqa: F403 diff --git a/dinov2/eval/segmentation/models/backbones/__init__.py b/dinov2/eval/segmentation/models/backbones/__init__.py deleted file mode 100644 index 520d75bc6e064b9d64487293604ac1bda6e2b6f7..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/models/backbones/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .vision_transformer import DinoVisionTransformer diff --git a/dinov2/eval/segmentation/models/backbones/vision_transformer.py b/dinov2/eval/segmentation/models/backbones/vision_transformer.py deleted file mode 100644 index c3e9753ae92a36be52f100e3004cbeeff777d14a..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/models/backbones/vision_transformer.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from mmcv.runner import BaseModule -from mmseg.models.builder import BACKBONES - - -@BACKBONES.register_module() -class DinoVisionTransformer(BaseModule): - """Vision Transformer.""" - - def __init__( - self, - *args, - **kwargs, - ): - super().__init__() diff --git a/dinov2/eval/segmentation/models/decode_heads/__init__.py b/dinov2/eval/segmentation/models/decode_heads/__init__.py deleted file mode 100644 index c55317875262dadf8970c2b3882f016b8d4731ac..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/models/decode_heads/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .linear_head import BNHead diff --git a/dinov2/eval/segmentation/models/decode_heads/linear_head.py b/dinov2/eval/segmentation/models/decode_heads/linear_head.py deleted file mode 100644 index d1f39c68fb136f84d1aa5284da5b69581bb177cc..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/models/decode_heads/linear_head.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from mmseg.models.builder import HEADS -from mmseg.models.decode_heads.decode_head import BaseDecodeHead -from mmseg.ops import resize - - -@HEADS.register_module() -class BNHead(BaseDecodeHead): - """Just a batchnorm.""" - - def __init__(self, resize_factors=None, **kwargs): - super().__init__(**kwargs) - assert self.in_channels == self.channels - self.bn = nn.SyncBatchNorm(self.in_channels) - self.resize_factors = resize_factors - - def _forward_feature(self, inputs): - """Forward function for feature maps before classifying each pixel with - ``self.cls_seg`` fc. - - Args: - inputs (list[Tensor]): List of multi-level img features. - - Returns: - feats (Tensor): A tensor of shape (batch_size, self.channels, - H, W) which is feature map for last layer of decoder head. 
- """ - # print("inputs", [i.shape for i in inputs]) - x = self._transform_inputs(inputs) - # print("x", x.shape) - feats = self.bn(x) - # print("feats", feats.shape) - return feats - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - Args: - inputs (list[Tensor]): List of multi-level img features. - Returns: - Tensor: The transformed inputs - """ - - if self.input_transform == "resize_concat": - # accept lists (for cls token) - input_list = [] - for x in inputs: - if isinstance(x, list): - input_list.extend(x) - else: - input_list.append(x) - inputs = input_list - # an image descriptor can be a local descriptor with resolution 1x1 - for i, x in enumerate(inputs): - if len(x.shape) == 2: - inputs[i] = x[:, :, None, None] - # select indices - inputs = [inputs[i] for i in self.in_index] - # Resizing shenanigans - # print("before", *(x.shape for x in inputs)) - if self.resize_factors is not None: - assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs)) - inputs = [ - resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area") - for x, f in zip(inputs, self.resize_factors) - ] - # print("after", *(x.shape for x in inputs)) - upsampled_inputs = [ - resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) - for x in inputs - ] - inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == "multiple_select": - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - def forward(self, inputs): - """Forward function.""" - output = self._forward_feature(inputs) - output = self.cls_seg(output) - return output diff --git a/dinov2/eval/segmentation/utils/__init__.py b/dinov2/eval/segmentation/utils/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/eval/segmentation/utils/colormaps.py b/dinov2/eval/segmentation/utils/colormaps.py deleted file mode 100644 index e6ef604b2c75792e95e438abfd51ab03d40de340..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation/utils/colormaps.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -ADE20K_COLORMAP = [ - (0, 0, 0), - (120, 120, 120), - (180, 120, 120), - (6, 230, 230), - (80, 50, 50), - (4, 200, 3), - (120, 120, 80), - (140, 140, 140), - (204, 5, 255), - (230, 230, 230), - (4, 250, 7), - (224, 5, 255), - (235, 255, 7), - (150, 5, 61), - (120, 120, 70), - (8, 255, 51), - (255, 6, 82), - (143, 255, 140), - (204, 255, 4), - (255, 51, 7), - (204, 70, 3), - (0, 102, 200), - (61, 230, 250), - (255, 6, 51), - (11, 102, 255), - (255, 7, 71), - (255, 9, 224), - (9, 7, 230), - (220, 220, 220), - (255, 9, 92), - (112, 9, 255), - (8, 255, 214), - (7, 255, 224), - (255, 184, 6), - (10, 255, 71), - (255, 41, 10), - (7, 255, 255), - (224, 255, 8), - (102, 8, 255), - (255, 61, 6), - (255, 194, 7), - (255, 122, 8), - (0, 255, 20), - (255, 8, 41), - (255, 5, 153), - (6, 51, 255), - (235, 12, 255), - (160, 150, 20), - (0, 163, 255), - (140, 140, 140), - (250, 10, 15), - (20, 255, 0), - (31, 255, 0), - (255, 31, 0), - (255, 224, 0), - (153, 255, 0), - (0, 0, 255), - (255, 71, 0), - (0, 235, 255), - (0, 173, 255), - (31, 0, 255), - (11, 200, 200), - (255, 82, 0), - (0, 255, 245), - (0, 61, 255), - (0, 255, 112), - (0, 255, 133), - (255, 0, 0), - (255, 163, 0), - (255, 102, 0), - (194, 255, 0), - (0, 143, 255), - (51, 255, 0), - (0, 82, 255), - (0, 255, 41), - (0, 255, 173), - (10, 0, 255), - (173, 255, 0), - (0, 255, 153), - (255, 92, 0), - (255, 0, 255), - (255, 0, 245), - (255, 0, 102), - (255, 173, 0), - (255, 0, 20), - (255, 184, 184), - (0, 31, 255), - (0, 255, 61), - (0, 71, 255), - (255, 0, 204), - (0, 255, 194), - (0, 255, 82), - (0, 10, 255), - (0, 112, 255), - (51, 0, 255), - (0, 194, 255), - (0, 122, 255), - (0, 255, 163), - (255, 153, 0), - (0, 255, 10), - (255, 112, 0), - (143, 255, 0), - (82, 0, 255), - (163, 255, 0), - (255, 235, 0), - (8, 184, 170), - (133, 0, 255), - (0, 255, 92), - (184, 0, 255), - (255, 0, 31), - (0, 184, 255), - (0, 214, 255), - (255, 0, 112), - (92, 255, 0), - (0, 224, 255), - (112, 224, 255), - (70, 184, 160), - (163, 0, 255), - (153, 0, 255), - (71, 255, 0), - (255, 0, 163), - (255, 204, 0), - (255, 0, 143), - (0, 255, 235), - (133, 255, 0), - (255, 0, 235), - (245, 0, 255), - (255, 0, 122), - (255, 245, 0), - (10, 190, 212), - (214, 255, 0), - (0, 204, 255), - (20, 0, 255), - (255, 255, 0), - (0, 153, 255), - (0, 41, 255), - (0, 255, 204), - (41, 0, 255), - (41, 255, 0), - (173, 0, 255), - (0, 245, 255), - (71, 0, 255), - (122, 0, 255), - (0, 255, 184), - (0, 92, 255), - (184, 255, 0), - (0, 133, 255), - (255, 214, 0), - (25, 194, 194), - (102, 255, 0), - (92, 0, 255), -] - -ADE20K_CLASS_NAMES = [ - "", - "wall", - "building;edifice", - "sky", - "floor;flooring", - "tree", - "ceiling", - "road;route", - "bed", - "windowpane;window", - "grass", - "cabinet", - "sidewalk;pavement", - "person;individual;someone;somebody;mortal;soul", - "earth;ground", - "door;double;door", - "table", - "mountain;mount", - "plant;flora;plant;life", - "curtain;drape;drapery;mantle;pall", - "chair", - "car;auto;automobile;machine;motorcar", - "water", - "painting;picture", - "sofa;couch;lounge", - "shelf", - "house", - "sea", - "mirror", - "rug;carpet;carpeting", - "field", - "armchair", - "seat", - "fence;fencing", - "desk", - "rock;stone", - "wardrobe;closet;press", - "lamp", - "bathtub;bathing;tub;bath;tub", - "railing;rail", - "cushion", - "base;pedestal;stand", - "box", - "column;pillar", - "signboard;sign", - "chest;of;drawers;chest;bureau;dresser", - "counter", - "sand", - "sink", - "skyscraper", - "fireplace;hearth;open;fireplace", - "refrigerator;icebox", - 
"grandstand;covered;stand", - "path", - "stairs;steps", - "runway", - "case;display;case;showcase;vitrine", - "pool;table;billiard;table;snooker;table", - "pillow", - "screen;door;screen", - "stairway;staircase", - "river", - "bridge;span", - "bookcase", - "blind;screen", - "coffee;table;cocktail;table", - "toilet;can;commode;crapper;pot;potty;stool;throne", - "flower", - "book", - "hill", - "bench", - "countertop", - "stove;kitchen;stove;range;kitchen;range;cooking;stove", - "palm;palm;tree", - "kitchen;island", - "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", - "swivel;chair", - "boat", - "bar", - "arcade;machine", - "hovel;hut;hutch;shack;shanty", - "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle", - "towel", - "light;light;source", - "truck;motortruck", - "tower", - "chandelier;pendant;pendent", - "awning;sunshade;sunblind", - "streetlight;street;lamp", - "booth;cubicle;stall;kiosk", - "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box", - "airplane;aeroplane;plane", - "dirt;track", - "apparel;wearing;apparel;dress;clothes", - "pole", - "land;ground;soil", - "bannister;banister;balustrade;balusters;handrail", - "escalator;moving;staircase;moving;stairway", - "ottoman;pouf;pouffe;puff;hassock", - "bottle", - "buffet;counter;sideboard", - "poster;posting;placard;notice;bill;card", - "stage", - "van", - "ship", - "fountain", - "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter", - "canopy", - "washer;automatic;washer;washing;machine", - "plaything;toy", - "swimming;pool;swimming;bath;natatorium", - "stool", - "barrel;cask", - "basket;handbasket", - "waterfall;falls", - "tent;collapsible;shelter", - "bag", - "minibike;motorbike", - "cradle", - "oven", - "ball", - "food;solid;food", - "step;stair", - "tank;storage;tank", - "trade;name;brand;name;brand;marque", - "microwave;microwave;oven", - "pot;flowerpot", - "animal;animate;being;beast;brute;creature;fauna", - "bicycle;bike;wheel;cycle", - "lake", - "dishwasher;dish;washer;dishwashing;machine", - "screen;silver;screen;projection;screen", - "blanket;cover", - "sculpture", - "hood;exhaust;hood", - "sconce", - "vase", - "traffic;light;traffic;signal;stoplight", - "tray", - "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin", - "fan", - "pier;wharf;wharfage;dock", - "crt;screen", - "plate", - "monitor;monitoring;device", - "bulletin;board;notice;board", - "shower", - "radiator", - "glass;drinking;glass", - "clock", - "flag", -] - - -VOC2012_COLORMAP = [ - (0, 0, 0), - (128, 0, 0), - (0, 128, 0), - (128, 128, 0), - (0, 0, 128), - (128, 0, 128), - (0, 128, 128), - (128, 128, 128), - (64, 0, 0), - (192, 0, 0), - (64, 128, 0), - (192, 128, 0), - (64, 0, 128), - (192, 0, 128), - (64, 128, 128), - (192, 128, 128), - (0, 64, 0), - (128, 64, 0), - (0, 192, 0), - (128, 192, 0), - (0, 64, 128), -] - - -VOC2012_CLASS_NAMES = [ - "", - "aeroplane", - "bicycle", - "bird", - "boat", - "bottle", - "bus", - "car", - "cat", - "chair", - "cow", - "diningtable", - "dog", - "horse", - "motorbike", - "person", - "pottedplant", - "sheep", - "sofa", - "train", - "tvmonitor", -] diff --git a/dinov2/eval/segmentation_m2f/__init__.py b/dinov2/eval/segmentation_m2f/__init__.py deleted file mode 100644 index 6c678fdf8f1dee14d7cf9be70af14e6f9a1441c3..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# 
Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .core import * # noqa: F403 -from .models import * # noqa: F403 -from .ops import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/__init__.py b/dinov2/eval/segmentation_m2f/core/__init__.py deleted file mode 100644 index 92599806fbd221c1418d179892a0f46dc0b7d4db..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from mmseg.core.evaluation import * # noqa: F403 -from mmseg.core.seg import * # noqa: F403 - -from .anchor import * # noqa: F403 -from .box import * # noqa: F403 -from .utils import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/dinov2/eval/segmentation_m2f/core/anchor/__init__.py deleted file mode 100644 index e71ac4d6e01462221ae01aa16d0e1231cda7e2e7..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/anchor/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .point_generator import MlvlPointGenerator # noqa: F401 diff --git a/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/dinov2/eval/segmentation_m2f/core/anchor/builder.py deleted file mode 100644 index 6dba90e22de76d2f23a86d3c057f196d55a99690..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/anchor/builder.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import warnings - -from mmcv.utils import Registry, build_from_cfg - -PRIOR_GENERATORS = Registry("Generator for anchors and points") - -ANCHOR_GENERATORS = PRIOR_GENERATORS - - -def build_prior_generator(cfg, default_args=None): - return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) - - -def build_anchor_generator(cfg, default_args=None): - warnings.warn("``build_anchor_generator`` is deprecated, please use ``build_prior_generator`` instead") - return build_prior_generator(cfg, default_args=default_args) diff --git a/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py deleted file mode 100644 index 574d71939080e22284fe99087fb2e7336657bd97..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import numpy as np -import torch -from torch.nn.modules.utils import _pair - -from .builder import PRIOR_GENERATORS - - -@PRIOR_GENERATORS.register_module() -class MlvlPointGenerator: - """Standard points generator for multi-level (Mlvl) feature maps in 2D - points-based detectors.
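- One prior point is generated per feature-map location, at - ``((x + offset) * stride_w, (y + offset) * stride_h)`` in image coordinates.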
- - Args: - strides (list[int] | list[tuple[int, int]]): Strides of anchors - in multiple feature levels in order (w, h). - offset (float): The offset of points, the value is normalized with - corresponding stride. Defaults to 0.5. - """ - - def __init__(self, strides, offset=0.5): - self.strides = [_pair(stride) for stride in strides] - self.offset = offset - - @property - def num_levels(self): - """int: number of feature levels the generator is applied to""" - return len(self.strides) - - @property - def num_base_priors(self): - """list[int]: The number of priors (points) at a point - on the feature grid""" - return [1 for _ in range(len(self.strides))] - - def _meshgrid(self, x, y, row_major=True): - yy, xx = torch.meshgrid(y, x) - if row_major: - # warning: .flatten() would cause an error in ONNX exporting, - # so we have to use reshape here - return xx.reshape(-1), yy.reshape(-1) - - else: - return yy.reshape(-1), xx.reshape(-1) - - def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False): - """Generate grid points of multiple feature levels. - - Args: - featmap_sizes (list[tuple]): List of feature map sizes in - multiple feature levels, each size arranged - as (h, w). - dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. - device (str): The device the anchors will be put on. - with_stride (bool): Whether to concatenate the stride to - the last dimension of points. - - Returns: - list[torch.Tensor]: Points of multiple feature levels. - The size of each tensor should be (N, 2) when ``with_stride`` is - ``False``, where N = width * height, width and height - are the sizes of the corresponding feature level, - and the last dimension 2 represents (coord_x, coord_y); - otherwise the shape should be (N, 4), - and the last dimension 4 represents - (coord_x, coord_y, stride_w, stride_h). - """ - - assert self.num_levels == len(featmap_sizes) - multi_level_priors = [] - for i in range(self.num_levels): - priors = self.single_level_grid_priors( - featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride - ) - multi_level_priors.append(priors) - return multi_level_priors - - def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False): - """Generate grid points of a single level. - - Note: - This function is usually called by method ``self.grid_priors``. - - Args: - featmap_size (tuple[int]): Size of the feature map, arranged as - (h, w). - level_idx (int): The index of the corresponding feature map level. - dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. - device (str, optional): The device the tensor will be put on. - Defaults to 'cuda'. - with_stride (bool): Concatenate the stride to the last dimension - of points. - - Returns: - Tensor: Points of a single feature level. - The shape of the tensor should be (N, 2) when ``with_stride`` is - ``False``, where N = width * height, width and height - are the sizes of the corresponding feature level, - and the last dimension 2 represents (coord_x, coord_y); - otherwise the shape should be (N, 4), - and the last dimension 4 represents - (coord_x, coord_y, stride_w, stride_h).
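- - Example (a minimal sketch on CPU; the default ``device`` is ``cuda``): - >>> gen = MlvlPointGenerator(strides=[8]) - >>> pts = gen.single_level_grid_priors((2, 2), level_idx=0, device='cpu') - >>> pts.shape - torch.Size([4, 2]) - >>> pts[0].tolist() # centre of the top-left stride-8 cell - [4.0, 4.0]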
- """ - feat_h, feat_w = featmap_size - stride_w, stride_h = self.strides[level_idx] - shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w - # keep featmap_size as Tensor instead of int, so that we - # can convert to ONNX correctly - shift_x = shift_x.to(dtype) - - shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h - # keep featmap_size as Tensor instead of int, so that we - # can convert to ONNX correctly - shift_y = shift_y.to(dtype) - shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) - if not with_stride: - shifts = torch.stack([shift_xx, shift_yy], dim=-1) - else: - # use `shape[0]` instead of `len(shift_xx)` for ONNX export - stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype) - stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype) - shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1) - all_points = shifts.to(device) - return all_points - - def valid_flags(self, featmap_sizes, pad_shape, device="cuda"): - """Generate valid flags of points of multiple feature levels. - - Args: - featmap_sizes (list(tuple)): List of feature map sizes in - multiple feature levels, each size arranged - as (h, w). - pad_shape (tuple(int)): The padded shape of the image, - arranged as (h, w). - device (str): The device the anchors will be put on. - - Returns: - list(torch.Tensor): Valid flags of points of multiple levels. - """ - assert self.num_levels == len(featmap_sizes) - multi_level_flags = [] - for i in range(self.num_levels): - point_stride = self.strides[i] - feat_h, feat_w = featmap_sizes[i] - h, w = pad_shape[:2] - valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) - valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) - flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device) - multi_level_flags.append(flags) - return multi_level_flags - - def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"): - """Generate the valid flags of points of a single feature map. - - Args: - featmap_size (tuple[int]): The size of the feature map, arranged - as (h, w). - valid_size (tuple[int]): The valid size of the feature map, - arranged as (h, w). - device (str, optional): The device where the flags will be put on. - Defaults to 'cuda'. - - Returns: - torch.Tensor: The valid flags of each point in a single-level - feature map. - """ - feat_h, feat_w = featmap_size - valid_h, valid_w = valid_size - assert valid_h <= feat_h and valid_w <= feat_w - valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) - valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) - valid_x[:valid_w] = 1 - valid_y[:valid_h] = 1 - valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) - valid = valid_xx & valid_yy - return valid - - def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"): - """Generate sparse points according to the ``prior_idxs``. - - Args: - prior_idxs (Tensor): The indices of the corresponding anchors - in the feature map. - featmap_size (tuple[int]): Feature map size, arranged as (h, w). - level_idx (int): The level index of the corresponding feature - map. - dtype (obj:`torch.dtype`): Data type of points. Defaults to - ``torch.float32``. - device (obj:`torch.device`): The device where the points are - located. - Returns: - Tensor: Points with shape (N, 2), where N equals - the length of ``prior_idxs`` and the last dimension - 2 represents (coord_x, coord_y).
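- - Example (a minimal sketch on CPU): - >>> import torch - >>> gen = MlvlPointGenerator(strides=[16]) - >>> gen.sparse_priors(torch.tensor([0, 3]), (2, 2), level_idx=0, device='cpu').tolist() - [[8.0, 8.0], [24.0, 24.0]]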
- """ - height, width = featmap_size - x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] - y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1] - priors = torch.stack([x, y], 1).to(dtype) - priors = priors.to(device) - return priors diff --git a/dinov2/eval/segmentation_m2f/core/box/__init__.py b/dinov2/eval/segmentation_m2f/core/box/__init__.py deleted file mode 100644 index bf35a613f81acd77ecab2dfb75a722fa8e5c0787..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .builder import * # noqa: F403 -from .samplers import MaskPseudoSampler # noqa: F401 diff --git a/dinov2/eval/segmentation_m2f/core/box/builder.py b/dinov2/eval/segmentation_m2f/core/box/builder.py deleted file mode 100644 index 9538c0de3db682c2b111b085a8a1ce321c76a9ff..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/builder.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from mmcv.utils import Registry, build_from_cfg - -BBOX_SAMPLERS = Registry("bbox_sampler") -BBOX_CODERS = Registry("bbox_coder") - - -def build_sampler(cfg, **default_args): - """Builder of box sampler.""" - return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) - - -def build_bbox_coder(cfg, **default_args): - """Builder of box coder.""" - return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py deleted file mode 100644 index 19c363e3fabc365d92aeaf1e78189d710db279e9..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F401 diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py deleted file mode 100644 index c45cec3ed7af5b49bb54b92d6e6bcf59b06b4c99..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
- -from abc import ABCMeta, abstractmethod - -import torch - -from .sampling_result import SamplingResult - - -class BaseSampler(metaclass=ABCMeta): - """Base class of samplers.""" - - def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): - self.num = num - self.pos_fraction = pos_fraction - self.neg_pos_ub = neg_pos_ub - self.add_gt_as_proposals = add_gt_as_proposals - self.pos_sampler = self - self.neg_sampler = self - - @abstractmethod - def _sample_pos(self, assign_result, num_expected, **kwargs): - """Sample positive samples.""" - pass - - @abstractmethod - def _sample_neg(self, assign_result, num_expected, **kwargs): - """Sample negative samples.""" - pass - - def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): - """Sample positive and negative bboxes. - - This is a simple implementation of bbox sampling given candidates, - assigning results and ground truth bboxes. - - Args: - assign_result (:obj:`AssignResult`): Bbox assigning results. - bboxes (Tensor): Boxes to be sampled from. - gt_bboxes (Tensor): Ground truth bboxes. - gt_labels (Tensor, optional): Class labels of ground truth bboxes. - - Returns: - :obj:`SamplingResult`: Sampling result. - - Example: - >>> from mmdet.core.bbox import RandomSampler - >>> from mmdet.core.bbox import AssignResult - >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes - >>> rng = ensure_rng(None) - >>> assign_result = AssignResult.random(rng=rng) - >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) - >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) - >>> gt_labels = None - >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, - >>> add_gt_as_proposals=False) - >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) - """ - if len(bboxes.shape) < 2: - bboxes = bboxes[None, :] - - bboxes = bboxes[:, :4] - - gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) - if self.add_gt_as_proposals and len(gt_bboxes) > 0: - if gt_labels is None: - raise ValueError("gt_labels must be given when add_gt_as_proposals is True") - bboxes = torch.cat([gt_bboxes, bboxes], dim=0) - assign_result.add_gt_(gt_labels) - gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) - gt_flags = torch.cat([gt_ones, gt_flags]) - - num_expected_pos = int(self.num * self.pos_fraction) - pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) - # We found that sampled indices have duplicated items occasionally. - # (may be a bug of PyTorch) - pos_inds = pos_inds.unique() - num_sampled_pos = pos_inds.numel() - num_expected_neg = self.num - num_sampled_pos - if self.neg_pos_ub >= 0: - _pos = max(1, num_sampled_pos) - neg_upper_bound = int(self.neg_pos_ub * _pos) - if num_expected_neg > neg_upper_bound: - num_expected_neg = neg_upper_bound - neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) - neg_inds = neg_inds.unique() - - sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) - return sampling_result diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py deleted file mode 100644 index 3e67ea61ed0fd65cca0addde1893a3c1e176bf15..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py - -import torch - -from ..builder import BBOX_SAMPLERS -from .base_sampler import BaseSampler -from .mask_sampling_result import MaskSamplingResult - - -@BBOX_SAMPLERS.register_module() -class MaskPseudoSampler(BaseSampler): - """A pseudo sampler that performs no actual sampling.""" - - def __init__(self, **kwargs): - pass - - def _sample_pos(self, **kwargs): - """Sample positive samples.""" - raise NotImplementedError - - def _sample_neg(self, **kwargs): - """Sample negative samples.""" - raise NotImplementedError - - def sample(self, assign_result, masks, gt_masks, **kwargs): - """Directly returns the positive and negative indices of samples. - - Args: - assign_result (:obj:`AssignResult`): Assigned results - masks (torch.Tensor): Predicted masks - gt_masks (torch.Tensor): Ground truth masks - Returns: - :obj:`MaskSamplingResult`: Sampling result - """ - pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() - neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() - gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) - sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags) - return sampling_result diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py deleted file mode 100644 index 270ffd35a5f120dd0560a7fea7fe83ef0bab66bb..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
- -# References: -# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py - -import torch - -from .sampling_result import SamplingResult - - -class MaskSamplingResult(SamplingResult): - """Mask sampling result.""" - - def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): - self.pos_inds = pos_inds - self.neg_inds = neg_inds - self.pos_masks = masks[pos_inds] - self.neg_masks = masks[neg_inds] - self.pos_is_gt = gt_flags[pos_inds] - - self.num_gts = gt_masks.shape[0] - self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 - - if gt_masks.numel() == 0: - # hack for index error case - assert self.pos_assigned_gt_inds.numel() == 0 - self.pos_gt_masks = torch.empty_like(gt_masks) - else: - self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] - - if assign_result.labels is not None: - self.pos_gt_labels = assign_result.labels[pos_inds] - else: - self.pos_gt_labels = None - - @property - def masks(self): - """torch.Tensor: concatenated positive and negative masks""" - return torch.cat([self.pos_masks, self.neg_masks]) - - def __nice__(self): - data = self.info.copy() - data["pos_masks"] = data.pop("pos_masks").shape - data["neg_masks"] = data.pop("neg_masks").shape - parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] - body = " " + ",\n ".join(parts) - return "{\n" + body + "\n}" - - @property - def info(self): - """Returns a dictionary of info about the object.""" - return { - "pos_inds": self.pos_inds, - "neg_inds": self.neg_inds, - "pos_masks": self.pos_masks, - "neg_masks": self.neg_masks, - "pos_is_gt": self.pos_is_gt, - "num_gts": self.num_gts, - "pos_assigned_gt_inds": self.pos_assigned_gt_inds, - } diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py deleted file mode 100644 index aaee3fe55aeb8c6da7edefbbd382d94b67b6a6b4..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch - - -class SamplingResult: - """Bbox sampling result.
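- - Stores the indices of the sampled positive and negative boxes together with - the tensors they index into (boxes, ground-truth boxes and, when available, labels).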
- - Example: - >>> # xdoctest: +IGNORE_WANT - >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA - >>> self = SamplingResult.random(rng=10) - >>> print(f'self = {self}') - self = - """ - - def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): - self.pos_inds = pos_inds - self.neg_inds = neg_inds - self.pos_bboxes = bboxes[pos_inds] - self.neg_bboxes = bboxes[neg_inds] - self.pos_is_gt = gt_flags[pos_inds] - - self.num_gts = gt_bboxes.shape[0] - self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 - - if gt_bboxes.numel() == 0: - # hack for index error case - assert self.pos_assigned_gt_inds.numel() == 0 - self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) - else: - if len(gt_bboxes.shape) < 2: - gt_bboxes = gt_bboxes.view(-1, 4) - - self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] - - if assign_result.labels is not None: - self.pos_gt_labels = assign_result.labels[pos_inds] - else: - self.pos_gt_labels = None - - @property - def bboxes(self): - """torch.Tensor: concatenated positive and negative boxes""" - return torch.cat([self.pos_bboxes, self.neg_bboxes]) - - def to(self, device): - """Change the device of the data in place. - - Example: - >>> self = SamplingResult.random() - >>> print(f'self = {self.to(None)}') - >>> # xdoctest: +REQUIRES(--gpu) - >>> print(f'self = {self.to(0)}') - """ - _dict = self.__dict__ - for key, value in _dict.items(): - if isinstance(value, torch.Tensor): - _dict[key] = value.to(device) - return self - - def __nice__(self): - data = self.info.copy() - data["pos_bboxes"] = data.pop("pos_bboxes").shape - data["neg_bboxes"] = data.pop("neg_bboxes").shape - parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] - body = " " + ",\n ".join(parts) - return "{\n" + body + "\n}" - - @property - def info(self): - """Returns a dictionary of info about the object.""" - return { - "pos_inds": self.pos_inds, - "neg_inds": self.neg_inds, - "pos_bboxes": self.pos_bboxes, - "neg_bboxes": self.neg_bboxes, - "pos_is_gt": self.pos_is_gt, - "num_gts": self.num_gts, - "pos_assigned_gt_inds": self.pos_assigned_gt_inds, - } - - @classmethod - def random(cls, rng=None, **kwargs): - """ - Args: - rng (None | int | numpy.random.RandomState): seed or state. - kwargs (keyword arguments): - - num_preds: number of predicted boxes - - num_gts: number of true boxes - - p_ignore (float): probability of a predicted box assigned to \ - an ignored truth. - - p_assigned (float): probability of a predicted box not being \ - assigned. - - p_use_label (float | bool): with labels or not. - - Returns: - :obj:`SamplingResult`: Randomly generated sampling result. - - Example: - >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA - >>> self = SamplingResult.random() - >>> print(self.__dict__) - """ - from mmdet.core.bbox import demodata - from mmdet.core.bbox.assigners.assign_result import AssignResult - from mmdet.core.bbox.samplers.random_sampler import RandomSampler - - rng = demodata.ensure_rng(rng) - - # make probabilistic?
- num = 32 - pos_fraction = 0.5 - neg_pos_ub = -1 - - assign_result = AssignResult.random(rng=rng, **kwargs) - - # Note we could just compute an assignment - bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) - gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) - - if rng.rand() > 0.2: - # sometimes algorithms squeeze their data, be robust to that - gt_bboxes = gt_bboxes.squeeze() - bboxes = bboxes.squeeze() - - if assign_result.labels is None: - gt_labels = None - else: - # placeholder: labels are not generated here, so gt boxes are - # never added as proposals in this random demo - gt_labels = None - - if gt_labels is None: - add_gt_as_proposals = False - else: - add_gt_as_proposals = True # make probabilistic? - - sampler = RandomSampler( - num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng - ) - self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) - return self diff --git a/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/dinov2/eval/segmentation_m2f/core/utils/__init__.py deleted file mode 100644 index 6cdc9e19352f50bc2d5433c412ff71186c5df019..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .dist_utils import reduce_mean -from .misc import add_prefix, multi_apply diff --git a/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py deleted file mode 100644 index 7dfed42da821cd94e31b663d86b20b8f09799b30..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch.distributed as dist - - -def reduce_mean(tensor): - """Obtain the mean of a tensor across all GPUs.""" - if not (dist.is_available() and dist.is_initialized()): - return tensor - tensor = tensor.clone() - dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) - return tensor diff --git a/dinov2/eval/segmentation_m2f/core/utils/misc.py b/dinov2/eval/segmentation_m2f/core/utils/misc.py deleted file mode 100644 index e07579e7b182b62153e81fe637ffd0f3081ef2a3..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/core/utils/misc.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from functools import partial - - -def multi_apply(func, *args, **kwargs): - """Apply a function to a list of arguments. - - Note: - This function applies ``func`` to multiple inputs and maps - each of the multiple outputs of ``func`` into a separate list. - Each list contains the same kind of output for the - different inputs. - - Args: - func (Function): A function that will be applied to a list of - arguments - - Returns: - tuple(list): A tuple of lists; each list contains one kind of \ - result returned by the function - """ - pfunc = partial(func, **kwargs) if kwargs else func - map_results = map(pfunc, *args) - return tuple(map(list, zip(*map_results))) - - -def add_prefix(inputs, prefix): - """Add a prefix to the keys of a dict.
- - Args: - inputs (dict): The input dict with str keys. - prefix (str): The prefix to add. - - Returns: - - dict: The dict with keys updated with ``prefix``. - """ - - outputs = dict() - for name, value in inputs.items(): - outputs[f"{prefix}.{name}"] = value - - return outputs diff --git a/dinov2/eval/segmentation_m2f/models/__init__.py b/dinov2/eval/segmentation_m2f/models/__init__.py deleted file mode 100644 index ed89bb0064d82b4360af020798eab3d2f5a47937..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .backbones import * # noqa: F403 -from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost -from .decode_heads import * # noqa: F403 -from .losses import * # noqa: F403 -from .plugins import * # noqa: F403 -from .segmentors import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/dinov2/eval/segmentation_m2f/models/backbones/__init__.py deleted file mode 100644 index c4bf73bcbcee710676f81cb6517ae787f4d61cc6..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/backbones/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .vit_adapter import ViTAdapter diff --git a/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py deleted file mode 100644 index 26bfdf8f6ae6c107d22d61985cce34d4b5ce275f..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
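- -# Note: deform_inputs1 below feeds the Injector, whose queries are the ViT patch -# tokens (a single level at stride patch_size) cross-attending to the three -# SpatialPriorModule levels (strides 8, 16 and 32); deform_inputs2 feeds the -# Extractor, where the roles of the two token sets are reversed.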
- -from functools import partial - -import torch -import torch.nn as nn -import torch.utils.checkpoint as cp - -from ...ops.modules import MSDeformAttn -from .drop_path import DropPath - - -def get_reference_points(spatial_shapes, device): - reference_points_list = [] - for lvl, (H_, W_) in enumerate(spatial_shapes): - ref_y, ref_x = torch.meshgrid( - torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), - torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), - ) - ref_y = ref_y.reshape(-1)[None] / H_ - ref_x = ref_x.reshape(-1)[None] / W_ - ref = torch.stack((ref_x, ref_y), -1) - reference_points_list.append(ref) - reference_points = torch.cat(reference_points_list, 1) - reference_points = reference_points[:, :, None] - return reference_points - - -def deform_inputs(x, patch_size): - bs, c, h, w = x.shape - spatial_shapes = torch.as_tensor( - [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device - ) - level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) - deform_inputs1 = [reference_points, spatial_shapes, level_start_index] - - spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) - level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) - deform_inputs2 = [reference_points, spatial_shapes, level_start_index] - - return deform_inputs1, deform_inputs2 - - -class ConvFFN(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.dwconv = DWConv(hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x, H, W): - x = self.fc1(x) - x = self.dwconv(x, H, W) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class DWConv(nn.Module): - def __init__(self, dim=768): - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - - def forward(self, x, H, W): - B, N, C = x.shape - # The sequence concatenates three flattened feature maps at strides 8, 16 - # and 32; relative to the stride-16 resolution (H, W) they hold 16n, 4n and - # n tokens respectively, hence N = 21n. - n = N // 21 - x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() # stride-8 tokens - x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() # stride-16 tokens - x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() # stride-32 tokens - x1 = self.dwconv(x1).flatten(2).transpose(1, 2) - x2 = self.dwconv(x2).flatten(2).transpose(1, 2) - x3 = self.dwconv(x3).flatten(2).transpose(1, 2) - x = torch.cat([x1, x2, x3], dim=1) - return x - - -class Extractor(nn.Module): - def __init__( - self, - dim, - num_heads=6, - n_points=4, - n_levels=1, - deform_ratio=1.0, - with_cffn=True, - cffn_ratio=0.25, - drop=0.0, - drop_path=0.0, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - with_cp=False, - ): - super().__init__() - self.query_norm = norm_layer(dim) - self.feat_norm = norm_layer(dim) - self.attn = MSDeformAttn( - d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio - ) - self.with_cffn = with_cffn - self.with_cp = with_cp - if with_cffn: - self.ffn = ConvFFN(in_features=dim, 
hidden_features=int(dim * cffn_ratio), drop=drop) - self.ffn_norm = norm_layer(dim) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): - def _inner_forward(query, feat): - - attn = self.attn( - self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None - ) - query = query + attn - - if self.with_cffn: - query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) - return query - - if self.with_cp and query.requires_grad: - query = cp.checkpoint(_inner_forward, query, feat) - else: - query = _inner_forward(query, feat) - - return query - - -class Injector(nn.Module): - def __init__( - self, - dim, - num_heads=6, - n_points=4, - n_levels=1, - deform_ratio=1.0, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - init_values=0.0, - with_cp=False, - ): - super().__init__() - self.with_cp = with_cp - self.query_norm = norm_layer(dim) - self.feat_norm = norm_layer(dim) - self.attn = MSDeformAttn( - d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio - ) - self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - - def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): - def _inner_forward(query, feat): - - attn = self.attn( - self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None - ) - return query + self.gamma * attn - - if self.with_cp and query.requires_grad: - query = cp.checkpoint(_inner_forward, query, feat) - else: - query = _inner_forward(query, feat) - - return query - - -class InteractionBlock(nn.Module): - def __init__( - self, - dim, - num_heads=6, - n_points=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - drop=0.0, - drop_path=0.0, - with_cffn=True, - cffn_ratio=0.25, - init_values=0.0, - deform_ratio=1.0, - extra_extractor=False, - with_cp=False, - ): - super().__init__() - - self.injector = Injector( - dim=dim, - n_levels=3, - num_heads=num_heads, - init_values=init_values, - n_points=n_points, - norm_layer=norm_layer, - deform_ratio=deform_ratio, - with_cp=with_cp, - ) - self.extractor = Extractor( - dim=dim, - n_levels=1, - num_heads=num_heads, - n_points=n_points, - norm_layer=norm_layer, - deform_ratio=deform_ratio, - with_cffn=with_cffn, - cffn_ratio=cffn_ratio, - drop=drop, - drop_path=drop_path, - with_cp=with_cp, - ) - if extra_extractor: - self.extra_extractors = nn.Sequential( - *[ - Extractor( - dim=dim, - num_heads=num_heads, - n_points=n_points, - norm_layer=norm_layer, - with_cffn=with_cffn, - cffn_ratio=cffn_ratio, - deform_ratio=deform_ratio, - drop=drop, - drop_path=drop_path, - with_cp=with_cp, - ) - for _ in range(2) - ] - ) - else: - self.extra_extractors = None - - def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): - x = self.injector( - query=x, - reference_points=deform_inputs1[0], - feat=c, - spatial_shapes=deform_inputs1[1], - level_start_index=deform_inputs1[2], - ) - for idx, blk in enumerate(blocks): - x = blk(x, H_toks, W_toks) - c = self.extractor( - query=c, - reference_points=deform_inputs2[0], - feat=x, - spatial_shapes=deform_inputs2[1], - level_start_index=deform_inputs2[2], - H=H_c, - W=W_c, - ) - if self.extra_extractors is not None: - for extractor in self.extra_extractors: - c = extractor( - query=c, - reference_points=deform_inputs2[0], - feat=x, - spatial_shapes=deform_inputs2[1], - 
level_start_index=deform_inputs2[2], - H=H_c, - W=W_c, - ) - return x, c - - -class InteractionBlockWithCls(nn.Module): - def __init__( - self, - dim, - num_heads=6, - n_points=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - drop=0.0, - drop_path=0.0, - with_cffn=True, - cffn_ratio=0.25, - init_values=0.0, - deform_ratio=1.0, - extra_extractor=False, - with_cp=False, - ): - super().__init__() - - self.injector = Injector( - dim=dim, - n_levels=3, - num_heads=num_heads, - init_values=init_values, - n_points=n_points, - norm_layer=norm_layer, - deform_ratio=deform_ratio, - with_cp=with_cp, - ) - self.extractor = Extractor( - dim=dim, - n_levels=1, - num_heads=num_heads, - n_points=n_points, - norm_layer=norm_layer, - deform_ratio=deform_ratio, - with_cffn=with_cffn, - cffn_ratio=cffn_ratio, - drop=drop, - drop_path=drop_path, - with_cp=with_cp, - ) - if extra_extractor: - self.extra_extractors = nn.Sequential( - *[ - Extractor( - dim=dim, - num_heads=num_heads, - n_points=n_points, - norm_layer=norm_layer, - with_cffn=with_cffn, - cffn_ratio=cffn_ratio, - deform_ratio=deform_ratio, - drop=drop, - drop_path=drop_path, - with_cp=with_cp, - ) - for _ in range(2) - ] - ) - else: - self.extra_extractors = None - - def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): - x = self.injector( - query=x, - reference_points=deform_inputs1[0], - feat=c, - spatial_shapes=deform_inputs1[1], - level_start_index=deform_inputs1[2], - ) - x = torch.cat((cls, x), dim=1) - for idx, blk in enumerate(blocks): - x = blk(x, H_toks, W_toks) - cls, x = ( - x[ - :, - :1, - ], - x[ - :, - 1:, - ], - ) - c = self.extractor( - query=c, - reference_points=deform_inputs2[0], - feat=x, - spatial_shapes=deform_inputs2[1], - level_start_index=deform_inputs2[2], - H=H_c, - W=W_c, - ) - if self.extra_extractors is not None: - for extractor in self.extra_extractors: - c = extractor( - query=c, - reference_points=deform_inputs2[0], - feat=x, - spatial_shapes=deform_inputs2[1], - level_start_index=deform_inputs2[2], - H=H_c, - W=W_c, - ) - return x, c, cls - - -class SpatialPriorModule(nn.Module): - def __init__(self, inplanes=64, embed_dim=384, with_cp=False): - super().__init__() - self.with_cp = with_cp - - self.stem = nn.Sequential( - *[ - nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), - nn.SyncBatchNorm(inplanes), - nn.ReLU(inplace=True), - nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), - nn.SyncBatchNorm(inplanes), - nn.ReLU(inplace=True), - nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), - nn.SyncBatchNorm(inplanes), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - ] - ) - self.conv2 = nn.Sequential( - *[ - nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), - nn.SyncBatchNorm(2 * inplanes), - nn.ReLU(inplace=True), - ] - ) - self.conv3 = nn.Sequential( - *[ - nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), - nn.SyncBatchNorm(4 * inplanes), - nn.ReLU(inplace=True), - ] - ) - self.conv4 = nn.Sequential( - *[ - nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), - nn.SyncBatchNorm(4 * inplanes), - nn.ReLU(inplace=True), - ] - ) - self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) - self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) - self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, 
kernel_size=1, stride=1, padding=0, bias=True) - self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) - - def forward(self, x): - def _inner_forward(x): - c1 = self.stem(x) - c2 = self.conv2(c1) - c3 = self.conv3(c2) - c4 = self.conv4(c3) - c1 = self.fc1(c1) - c2 = self.fc2(c2) - c3 = self.fc3(c3) - c4 = self.fc4(c4) - - bs, dim, _, _ = c1.shape - # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s - c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s - c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s - c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s - - return c1, c2, c3, c4 - - if self.with_cp and x.requires_grad: - outs = cp.checkpoint(_inner_forward, x) - else: - outs = _inner_forward(x) - return outs diff --git a/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py deleted file mode 100644 index 864eb8738c44652d12b979fc811503f21cbb00dd..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py - -from torch import nn - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - return x * random_tensor - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: float = 0.0): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) diff --git a/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/dinov2/eval/segmentation_m2f/models/backbones/vit.py deleted file mode 100644 index 8a147570451bd2fbd016ddfafbbfa33035cbd4f8..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/backbones/vit.py +++ /dev/null @@ -1,552 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -"""Vision Transformer (ViT) in PyTorch. - -A PyTorch implementation of Vision Transformers as described in: - -'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - - https://arxiv.org/abs/2010.11929 - -`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - - https://arxiv.org/abs/2106.10270 - -The official jax code is released and available at https://github.com/google-research/vision_transformer - -DeiT model defs and weights from https://github.com/facebookresearch/deit, -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 - -Acknowledgments: -* The paper authors for releasing code and weights, thanks! -* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out -for some einops/einsum fun -* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT -* Bert reference code checks against Huggingface Transformers and Tensorflow Bert - -Hacked together by / Copyright 2021 Ross Wightman -""" -import logging -import math -from functools import partial -from itertools import repeat -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as cp -from mmcv.runner import BaseModule, load_checkpoint -from mmseg.ops import resize -from mmseg.utils import get_root_logger -from torch import Tensor - -from .drop_path import DropPath - - -def to_2tuple(x): - return tuple(repeat(x, 2)) - - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - swiglu_hidden_features = int(2 * hidden_features / 3) - align_as = 8 - swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as - self.w1 = nn.Linear(in_features, swiglu_hidden_features) - self.w2 = nn.Linear(in_features, swiglu_hidden_features) - self.w3 = nn.Linear(swiglu_hidden_features, out_features) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.w1(x) - x2 = self.w2(x) - hidden = F.silu(x1) * x2 - return self.w3(hidden) - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding.""" - - def __init__( - self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True - ): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - x = self.norm(x) - return x, H, W - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, 
H, W): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class MemEffAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - ) -> None: - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: Tensor, H, W) -> Tensor: - from xformers.ops import memory_efficient_attention, unbind - - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = unbind(qkv, 2) - - x = memory_efficient_attention(q, k, v) - x = x.reshape([B, N, C]) - - x = self.proj(x) - x = self.proj_drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowedAttention(nn.Module): - def __init__( - self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant" - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.window_size = window_size - self.pad_mode = pad_mode - - def forward(self, x, H, W): - B, N, C = x.shape - N_ = self.window_size * self.window_size - H_ = math.ceil(H / self.window_size) * self.window_size - W_ = math.ceil(W / self.window_size) * self.window_size - - qkv = self.qkv(x) # [B, N, C] - qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] - qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) - - qkv = F.unfold( - qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size) - ) - B, C_kw_kw, L = qkv.shape # L - the num of windows - qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] - qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - # q,k,v [B, L, num_head, N_, C/num_head] - attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] - # 
if self.mask: - # attn = attn * mask - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] - # attn @ v = [B, L, num_head, N_, C/num_head] - x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) - - x = F.fold( - x, - output_size=(H_, W_), - kernel_size=(self.window_size, self.window_size), - stride=(self.window_size, self.window_size), - ) # [B, C, H_, W_] - x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -# class WindowedAttention(nn.Module): -# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): -# super().__init__() -# self.num_heads = num_heads -# head_dim = dim // num_heads -# self.scale = head_dim ** -0.5 -# -# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) -# self.attn_drop = nn.Dropout(attn_drop) -# self.proj = nn.Linear(dim, dim) -# self.proj_drop = nn.Dropout(proj_drop) -# self.window_size = window_size -# self.pad_mode = pad_mode -# -# def forward(self, x, H, W): -# B, N, C = x.shape -# -# N_ = self.window_size * self.window_size -# H_ = math.ceil(H / self.window_size) * self.window_size -# W_ = math.ceil(W / self.window_size) * self.window_size -# x = x.view(B, H, W, C) -# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) -# -# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C -# x = x.view(-1, N_, C) -# -# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) -# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) -# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] -# attn = attn.softmax(dim=-1) -# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] -# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) -# -# x = window_reverse(x, self.window_size, H_, W_) -# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() -# x = self.proj(x) -# x = self.proj_drop(x) -# return x - - -class Block(nn.Module): - def __init__( - self, - dim, - num_heads, - mlp_ratio=4.0, - qkv_bias=False, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - windowed=False, - window_size=14, - pad_mode="constant", - layer_scale=False, - with_cp=False, - ffn_layer=Mlp, - memeff=False, - ): - super().__init__() - self.with_cp = with_cp - self.norm1 = norm_layer(dim) - if windowed: - self.attn = WindowedAttention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - attn_drop=attn_drop, - proj_drop=drop, - window_size=window_size, - pad_mode=pad_mode, - ) - elif memeff: - self.attn = MemEffAttention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop - ) - else: - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - self.layer_scale = layer_scale - if layer_scale: - self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) - self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) - - def forward(self, x, H, W): - def _inner_forward(x): - if self.layer_scale: - x = x + 
self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) - else: - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - if self.with_cp and x.requires_grad: - x = cp.checkpoint(_inner_forward, x) - else: - x = _inner_forward(x) - - return x - - -class TIMMVisionTransformer(BaseModule): - """Vision Transformer. - - A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - - Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` - - https://arxiv.org/abs/2012.12877 - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.0, - layer_scale=True, - embed_layer=PatchEmbed, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - window_attn=False, - window_size=14, - pretrained=None, - with_cp=False, - pre_norm=False, - ffn_type="mlp", - memeff=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (float): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - embed_layer (nn.Module): patch embedding layer - norm_layer (nn.Module): normalization layer - pretrained (str): path to pretrained weights - """ - super().__init__() - self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - self.norm_layer = norm_layer - self.act_layer = act_layer - self.pretrain_size = img_size - self.drop_path_rate = drop_path_rate - self.drop_rate = drop_rate - self.patch_size = patch_size - - window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn - window_size = [window_size] * depth if not isinstance(window_size, list) else window_size - logging.info("window attention: %s", window_attn) - logging.info("window size: %s", window_size) - logging.info("layer scale: %s", layer_scale) - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm - ) - num_patches = self.patch_embed.num_patches - - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN} - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.Sequential( - *[ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - windowed=window_attn[i], - window_size=window_size[i], - layer_scale=layer_scale, - with_cp=with_cp, - ffn_layer=ffn_types[ffn_type], - memeff=memeff, - ) - 
-                for i in range(depth)
-            ]
-        )
-
-        self.norm = norm_layer(embed_dim)  # used by forward_features below
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        # For CLIP
-        if pre_norm:
-            norm_pre = norm_layer(embed_dim)
-            self.norm_pre = norm_pre
-        else:
-            self.norm_pre = nn.Identity()
-        self.init_weights(pretrained)
-
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)
-
-    def forward_features(self, x):
-        x, H, W = self.patch_embed(x)
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        x = torch.cat((cls_token, x), dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-
-        # For CLIP
-        x = self.norm_pre(x)
-
-        for blk in self.blocks:
-            x = blk(x, H, W)
-        x = self.norm(x)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        return x
-
-    @staticmethod
-    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
-        """Resize pos_embed weights.
-
-        Resize pos_embed using bicubic interpolate method.
-        Args:
-            pos_embed (torch.Tensor): Position embedding weights.
-            input_shape (tuple): Tuple for (downsampled input image height,
-                downsampled input image width).
-            pos_shape (tuple): The resolution of the downsampled original training
-                image.
-            mode (str): Algorithm used for upsampling:
-                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
-                ``'trilinear'``. Default: ``'nearest'``
-        Return:
-            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
-        """
-        assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
-        pos_h, pos_w = pos_shape
-        # keep dim for easy deployment
-        cls_token_weight = pos_embed[:, 0:1]
-        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
-        pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed_weight = resize(pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
-        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
-        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
-        return pos_embed
diff --git a/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
deleted file mode 100644
index ebc4f0f65e04ed764464d141607b3b2073220f6b..0000000000000000000000000000000000000000
--- a/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
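# A minimal sketch of the bicubic resize performed by
# TIMMVisionTransformer.resize_pos_embed above (hypothetical helper, for
# illustration only; assumes a single leading cls token):
#
#   import torch
#   import torch.nn.functional as F
#
#   def resize_pos_embed_sketch(pos_embed, pos_hw, new_hw):
#       cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]    # split cls token off
#       h, w = pos_hw
#       grid = grid.reshape(1, h, w, -1).permute(0, 3, 1, 2)  # (1, C, h, w)
#       grid = F.interpolate(grid, size=new_hw, mode="bicubic", align_corners=False)
#       grid = grid.flatten(2).transpose(1, 2)                # (1, new_h*new_w, C)
#       return torch.cat((cls_tok, grid), dim=1)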
- -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmseg.models.builder import BACKBONES -from torch.nn.init import normal_ - -from ...ops.modules import MSDeformAttn -from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs -from .vit import TIMMVisionTransformer - - -@BACKBONES.register_module() -class ViTAdapter(TIMMVisionTransformer): - def __init__( - self, - pretrain_size=224, - num_heads=12, - conv_inplane=64, - n_points=4, - deform_num_heads=6, - init_values=0.0, - interaction_indexes=None, - with_cffn=True, - cffn_ratio=0.25, - deform_ratio=1.0, - add_vit_feature=True, - pretrained=None, - use_extra_extractor=True, - freeze_vit=False, - use_cls=True, - with_cp=False, - *args, - **kwargs - ): - - super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs) - if freeze_vit: - for param in self.parameters(): - param.requires_grad = False - - # self.num_classes = 80 - self.use_cls = use_cls - if not self.use_cls: - self.cls_token = None - self.num_block = len(self.blocks) - self.pretrain_size = (pretrain_size, pretrain_size) - self.interaction_indexes = interaction_indexes - self.add_vit_feature = add_vit_feature - embed_dim = self.embed_dim - - block_fn = InteractionBlockWithCls if use_cls else InteractionBlock - - self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) - self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) - self.interactions = nn.Sequential( - *[ - block_fn( - dim=embed_dim, - num_heads=deform_num_heads, - n_points=n_points, - init_values=init_values, - drop_path=self.drop_path_rate, - norm_layer=self.norm_layer, - with_cffn=with_cffn, - cffn_ratio=cffn_ratio, - deform_ratio=deform_ratio, - extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor), - with_cp=with_cp, - ) - for i in range(len(interaction_indexes)) - ] - ) - self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) - self.norm1 = nn.SyncBatchNorm(embed_dim) - self.norm2 = nn.SyncBatchNorm(embed_dim) - self.norm3 = nn.SyncBatchNorm(embed_dim) - self.norm4 = nn.SyncBatchNorm(embed_dim) - - self.up.apply(self._init_weights) - self.spm.apply(self._init_weights) - self.interactions.apply(self._init_weights) - self.apply(self._init_deform_weights) - normal_(self.level_embed) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - torch.nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def _get_pos_embed(self, pos_embed, H, W): - pos_embed = pos_embed.reshape( - 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 - ).permute(0, 3, 1, 2) - pos_embed = ( - F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) - .reshape(1, -1, H * W) - .permute(0, 2, 1) - ) - return pos_embed - - def _init_deform_weights(self, m): - if isinstance(m, MSDeformAttn): - m._reset_parameters() - - def _add_level_embed(self, c2, c3, c4): - c2 = c2 + self.level_embed[0] - c3 = c3 + self.level_embed[1] - 
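        # each of the three SPM scales receives its own learned level embedding,
        # so they stay distinguishable once concatenated into a single sequence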
c4 = c4 + self.level_embed[2] - return c2, c3, c4 - - def forward(self, x): - deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) - - # SPM forward - c1, c2, c3, c4 = self.spm(x) - c2, c3, c4 = self._add_level_embed(c2, c3, c4) - c = torch.cat([c2, c3, c4], dim=1) - - # Patch Embedding forward - H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 - x, H_toks, W_toks = self.patch_embed(x) - # print("H_toks, W_toks =", H_toks, W_toks) - bs, n, dim = x.shape - pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks) - if self.use_cls: - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_token, x), dim=1) - pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1) - x = self.pos_drop(x + pos_embed) - # For CLIP - x = self.norm_pre(x) - - # Interaction - if self.use_cls: - cls, x = ( - x[ - :, - :1, - ], - x[ - :, - 1:, - ], - ) - outs = list() - for i, layer in enumerate(self.interactions): - indexes = self.interaction_indexes[i] - if self.use_cls: - x, c, cls = layer( - x, - c, - cls, - self.blocks[indexes[0] : indexes[-1] + 1], - deform_inputs1, - deform_inputs2, - H_c, - W_c, - H_toks, - W_toks, - ) - else: - x, c = layer( - x, - c, - self.blocks[indexes[0] : indexes[-1] + 1], - deform_inputs1, - deform_inputs2, - H_c, - W_c, - H_toks, - W_toks, - ) - outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) - - # Split & Reshape - c2 = c[:, 0 : c2.size(1), :] - c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] - c4 = c[:, c2.size(1) + c3.size(1) :, :] - - c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() - c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() - c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() - c1 = self.up(c2) + c1 - - if self.add_vit_feature: - x1, x2, x3, x4 = outs - - x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) - x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) - x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) - x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) - # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks) - c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 - - # Final Norm - f1 = self.norm1(c1) - f2 = self.norm2(c2) - f3 = self.norm3(c3) - f4 = self.norm4(c4) - return [f1, f2, f3, f4] diff --git a/dinov2/eval/segmentation_m2f/models/builder.py b/dinov2/eval/segmentation_m2f/models/builder.py deleted file mode 100644 index d7cf7b919f6b0e8e00bde45bc244d9c29a36fed6..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/builder.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
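# A minimal sketch of how these mmcv registries are used downstream (assuming a
# cost class such as ClassificationCost registered under MATCH_COST, as done in
# losses/match_costs.py; the config values are hypothetical):
#
#   cost_cfg = dict(type="ClassificationCost", weight=2.0)
#   cls_cost = build_match_cost(cost_cfg)  # -> ClassificationCost(weight=2.0)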
- -from mmcv.utils import Registry - -TRANSFORMER = Registry("Transformer") -MASK_ASSIGNERS = Registry("mask_assigner") -MATCH_COST = Registry("match_cost") - - -def build_match_cost(cfg): - """Build Match Cost.""" - return MATCH_COST.build(cfg) - - -def build_assigner(cfg): - """Build Assigner.""" - return MASK_ASSIGNERS.build(cfg) - - -def build_transformer(cfg): - """Build Transformer.""" - return TRANSFORMER.build(cfg) diff --git a/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py deleted file mode 100644 index 01f08b88950750337781fc671adfea2a935ea8fe..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .mask2former_head import Mask2FormerHead diff --git a/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py deleted file mode 100644 index d1705fc444fa8d1583d88fca36d7fe1e060db9e7..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py +++ /dev/null @@ -1,544 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import copy - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init -from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence -from mmcv.ops import point_sample -from mmcv.runner import ModuleList, force_fp32 -from mmseg.models.builder import HEADS, build_loss -from mmseg.models.decode_heads.decode_head import BaseDecodeHead - -from ...core import build_sampler, multi_apply, reduce_mean -from ..builder import build_assigner -from ..utils import get_uncertain_point_coords_with_randomness - - -@HEADS.register_module() -class Mask2FormerHead(BaseDecodeHead): - """Implements the Mask2Former head. - - See `Masked-attention Mask Transformer for Universal Image - Segmentation `_ for details. - - Args: - in_channels (list[int]): Number of channels in the input feature map. - feat_channels (int): Number of channels for features. - out_channels (int): Number of channels for output. - num_things_classes (int): Number of things. - num_stuff_classes (int): Number of stuff. - num_queries (int): Number of query in Transformer decoder. - pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel - decoder. Defaults to None. - enforce_decoder_input_project (bool, optional): Whether to add - a layer to change the embed_dim of tranformer encoder in - pixel decoder to the embed_dim of transformer decoder. - Defaults to False. - transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for - transformer decoder. Defaults to None. - positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for - transformer decoder position encoding. Defaults to None. - loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification - loss. Defaults to None. - loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. - Defaults to None. - loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. 
- Defaults to None. - train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of - Mask2Former head. - test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of - Mask2Former head. - init_cfg (dict or list[dict], optional): Initialization config dict. - Defaults to None. - """ - - def __init__( - self, - in_channels, - feat_channels, - out_channels, - num_things_classes=80, - num_stuff_classes=53, - num_queries=100, - num_transformer_feat_level=3, - pixel_decoder=None, - enforce_decoder_input_project=False, - transformer_decoder=None, - positional_encoding=None, - loss_cls=None, - loss_mask=None, - loss_dice=None, - train_cfg=None, - test_cfg=None, - init_cfg=None, - **kwargs, - ): - super(Mask2FormerHead, self).__init__( - in_channels=in_channels, - channels=feat_channels, - num_classes=(num_things_classes + num_stuff_classes), - init_cfg=init_cfg, - input_transform="multiple_select", - **kwargs, - ) - self.num_things_classes = num_things_classes - self.num_stuff_classes = num_stuff_classes - self.num_classes = self.num_things_classes + self.num_stuff_classes - self.num_queries = num_queries - self.num_transformer_feat_level = num_transformer_feat_level - self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads - self.num_transformer_decoder_layers = transformer_decoder.num_layers - assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level - pixel_decoder_ = copy.deepcopy(pixel_decoder) - pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) - self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] - self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder) - self.decoder_embed_dims = self.transformer_decoder.embed_dims - - self.decoder_input_projs = ModuleList() - # from low resolution to high resolution - for _ in range(num_transformer_feat_level): - if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project: - self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)) - else: - self.decoder_input_projs.append(nn.Identity()) - self.decoder_positional_encoding = build_positional_encoding(positional_encoding) - self.query_embed = nn.Embedding(self.num_queries, feat_channels) - self.query_feat = nn.Embedding(self.num_queries, feat_channels) - # from low resolution to high resolution - self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) - - self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) - self.mask_embed = nn.Sequential( - nn.Linear(feat_channels, feat_channels), - nn.ReLU(inplace=True), - nn.Linear(feat_channels, feat_channels), - nn.ReLU(inplace=True), - nn.Linear(feat_channels, out_channels), - ) - self.conv_seg = None # fix a bug here (conv_seg is not used) - - self.test_cfg = test_cfg - self.train_cfg = train_cfg - if train_cfg: - self.assigner = build_assigner(self.train_cfg.assigner) - self.sampler = build_sampler(self.train_cfg.sampler, context=self) - self.num_points = self.train_cfg.get("num_points", 12544) - self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0) - self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75) - - self.class_weight = loss_cls.class_weight - self.loss_cls = build_loss(loss_cls) - self.loss_mask = build_loss(loss_mask) - self.loss_dice = build_loss(loss_dice) - - def init_weights(self): - for m in self.decoder_input_projs: - if isinstance(m, Conv2d): - caffe2_xavier_init(m, bias=0) - 
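        # sub-modules with their own init schemes initialize themselves; the
        # remaining multi-dimensional decoder parameters get Xavier init below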
-
-        self.pixel_decoder.init_weights()
-
-        for p in self.transformer_decoder.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_normal_(p)
-
-    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
-        """Compute classification and mask targets for all images for a decoder
-        layer.
-
-        Args:
-            cls_scores_list (list[Tensor]): Mask score logits from a single
-                decoder layer for all images. Each with shape [num_queries,
-                cls_out_channels].
-            mask_preds_list (list[Tensor]): Mask logits from a single decoder
-                layer for all images. Each with shape [num_queries, h, w].
-            gt_labels_list (list[Tensor]): Ground truth class indices for all
-                images. Each with shape (n, ), where n is the sum of the number
-                of stuff types and the number of instances in an image.
-            gt_masks_list (list[Tensor]): Ground truth mask for each image,
-                each with shape (n, h, w).
-            img_metas (list[dict]): List of image meta information.
-
-        Returns:
-            tuple[list[Tensor]]: a tuple containing the following targets.
-
-                - labels_list (list[Tensor]): Labels of all images.
-                    Each with shape [num_queries, ].
-                - label_weights_list (list[Tensor]): Label weights of all
-                    images. Each with shape [num_queries, ].
-                - mask_targets_list (list[Tensor]): Mask targets of all images.
-                    Each with shape [num_queries, h, w].
-                - mask_weights_list (list[Tensor]): Mask weights of all images.
-                    Each with shape [num_queries, ].
-                - num_total_pos (int): Number of positive samples in all
-                    images.
-                - num_total_neg (int): Number of negative samples in all
-                    images.
-        """
-        (
-            labels_list,
-            label_weights_list,
-            mask_targets_list,
-            mask_weights_list,
-            pos_inds_list,
-            neg_inds_list,
-        ) = multi_apply(
-            self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
-        )
-
-        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
-        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
-        return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
-
-    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
-        """Compute classification and mask targets for one image.
-
-        Args:
-            cls_score (Tensor): Mask score logits from a single decoder layer
-                for one image. Shape (num_queries, cls_out_channels).
-            mask_pred (Tensor): Mask logits for a single decoder layer for one
-                image. Shape (num_queries, h, w).
-            gt_labels (Tensor): Ground truth class indices for one image with
-                shape (num_gts, ).
-            gt_masks (Tensor): Ground truth mask for each image, each with
-                shape (num_gts, h, w).
-            img_metas (dict): Image information.
-
-        Returns:
-            tuple[Tensor]: A tuple containing the following for one image.
-
-                - labels (Tensor): Labels of each image. \
-                    shape (num_queries, ).
-                - label_weights (Tensor): Label weights of each image. \
-                    shape (num_queries, ).
-                - mask_targets (Tensor): Mask targets of each image. \
-                    shape (num_queries, h, w).
-                - mask_weights (Tensor): Mask weights of each image. \
-                    shape (num_queries, ).
-                - pos_inds (Tensor): Sampled positive indices for each \
-                    image.
-                - neg_inds (Tensor): Sampled negative indices for each \
-                    image.
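        Example (a minimal sketch of the point sampling used below, with
        hypothetical shapes; ``point_sample`` comes from ``mmcv.ops``)::

            >>> import torch
            >>> from mmcv.ops import point_sample
            >>> mask_pred = torch.rand(5, 32, 32)        # (num_queries, h, w)
            >>> coords = torch.rand(5, 128, 2)           # normalized (x, y) points
            >>> point_sample(mask_pred.unsqueeze(1), coords).squeeze(1).shape
            torch.Size([5, 128])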
- """ - # sample points - num_queries = cls_score.shape[0] - num_gts = gt_labels.shape[0] - - point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) - # shape (num_queries, num_points) - mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1) - # shape (num_gts, num_points) - gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1) - - # assign and sample - assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas) - sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks) - pos_inds = sampling_result.pos_inds - neg_inds = sampling_result.neg_inds - - # label target - labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long) - labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] - label_weights = gt_labels.new_ones((self.num_queries,)) - - # mask target - mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] - mask_weights = mask_pred.new_zeros((self.num_queries,)) - mask_weights[pos_inds] = 1.0 - - return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds) - - def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas): - """Loss function for outputs from a single decoder layer. - - Args: - cls_scores (Tensor): Mask score logits from a single decoder layer - for all images. Shape (batch_size, num_queries, - cls_out_channels). Note `cls_out_channels` should includes - background. - mask_preds (Tensor): Mask logits for a pixel decoder for all - images. Shape (batch_size, num_queries, h, w). - gt_labels_list (list[Tensor]): Ground truth class indices for each - image, each with shape (num_gts, ). - gt_masks_list (list[Tensor]): Ground truth mask for each image, - each with shape (num_gts, h, w). - img_metas (list[dict]): List of image meta information. - - Returns: - tuple[Tensor]: Loss components for outputs from a single \ - decoder layer. 
- """ - num_imgs = cls_scores.size(0) - cls_scores_list = [cls_scores[i] for i in range(num_imgs)] - mask_preds_list = [mask_preds[i] for i in range(num_imgs)] - ( - labels_list, - label_weights_list, - mask_targets_list, - mask_weights_list, - num_total_pos, - num_total_neg, - ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas) - # shape (batch_size, num_queries) - labels = torch.stack(labels_list, dim=0) - # shape (batch_size, num_queries) - label_weights = torch.stack(label_weights_list, dim=0) - # shape (num_total_gts, h, w) - mask_targets = torch.cat(mask_targets_list, dim=0) - # shape (batch_size, num_queries) - mask_weights = torch.stack(mask_weights_list, dim=0) - - # classfication loss - # shape (batch_size * num_queries, ) - cls_scores = cls_scores.flatten(0, 1) - labels = labels.flatten(0, 1) - label_weights = label_weights.flatten(0, 1) - - class_weight = cls_scores.new_tensor(self.class_weight) - loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum()) - - num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) - num_total_masks = max(num_total_masks, 1) - - # extract positive ones - # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) - mask_preds = mask_preds[mask_weights > 0] - - if mask_targets.shape[0] == 0: - # zero match - loss_dice = mask_preds.sum() - loss_mask = mask_preds.sum() - return loss_cls, loss_mask, loss_dice - - with torch.no_grad(): - points_coords = get_uncertain_point_coords_with_randomness( - mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio - ) - # shape (num_total_gts, h, w) -> (num_total_gts, num_points) - mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) - # shape (num_queries, h, w) -> (num_queries, num_points) - mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1) - - # dice loss - loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks) - - # mask loss - # shape (num_queries, num_points) -> (num_queries * num_points, ) - mask_point_preds = mask_point_preds.reshape(-1, 1) - # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) - mask_point_targets = mask_point_targets.reshape(-1) - loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points) - - return loss_cls, loss_mask, loss_dice - - @force_fp32(apply_to=("all_cls_scores", "all_mask_preds")) - def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas): - """Loss function. - - Args: - all_cls_scores (Tensor): Classification scores for all decoder - layers with shape [num_decoder, batch_size, num_queries, - cls_out_channels]. - all_mask_preds (Tensor): Mask scores for all decoder layers with - shape [num_decoder, batch_size, num_queries, h, w]. - gt_labels_list (list[Tensor]): Ground truth class indices for each - image with shape (n, ). n is the sum of number of stuff type - and number of instance in a image. - gt_masks_list (list[Tensor]): Ground truth mask for each image with - shape (n, h, w). - img_metas (list[dict]): List of image meta information. - - Returns: - dict[str, Tensor]: A dictionary of loss components. 
- """ - num_dec_layers = len(all_cls_scores) - all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] - all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] - img_metas_list = [img_metas for _ in range(num_dec_layers)] - losses_cls, losses_mask, losses_dice = multi_apply( - self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list - ) - - loss_dict = dict() - # loss from the last decoder layer - loss_dict["loss_cls"] = losses_cls[-1] - loss_dict["loss_mask"] = losses_mask[-1] - loss_dict["loss_dice"] = losses_dice[-1] - # loss from other decoder layers - num_dec_layer = 0 - for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): - loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i - loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i - loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i - num_dec_layer += 1 - return loss_dict - - def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): - """Forward for head part which is called after every decoder layer. - - Args: - decoder_out (Tensor): in shape (num_queries, batch_size, c). - mask_feature (Tensor): in shape (batch_size, c, h, w). - attn_mask_target_size (tuple[int, int]): target attention - mask size. - - Returns: - tuple: A tuple contain three elements. - - - cls_pred (Tensor): Classification scores in shape \ - (batch_size, num_queries, cls_out_channels). \ - Note `cls_out_channels` should includes background. - - mask_pred (Tensor): Mask scores in shape \ - (batch_size, num_queries,h, w). - - attn_mask (Tensor): Attention mask in shape \ - (batch_size * num_heads, num_queries, h, w). - """ - decoder_out = self.transformer_decoder.post_norm(decoder_out) - decoder_out = decoder_out.transpose(0, 1) - # shape (num_queries, batch_size, c) - cls_pred = self.cls_embed(decoder_out) - # shape (num_queries, batch_size, c) - mask_embed = self.mask_embed(decoder_out) - # shape (num_queries, batch_size, h, w) - mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature) - attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False) - # shape (num_queries, batch_size, h, w) -> - # (batch_size * num_head, num_queries, h, w) - attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1) - attn_mask = attn_mask.sigmoid() < 0.5 - attn_mask = attn_mask.detach() - - return cls_pred, mask_pred, attn_mask - - def forward(self, feats, img_metas): - """Forward function. - - Args: - feats (list[Tensor]): Multi scale Features from the - upstream network, each is a 4D-tensor. - img_metas (list[dict]): List of image information. - - Returns: - tuple: A tuple contains two elements. - - - cls_pred_list (list[Tensor)]: Classification logits \ - for each decoder layer. Each is a 3D-tensor with shape \ - (batch_size, num_queries, cls_out_channels). \ - Note `cls_out_channels` should includes background. - - mask_pred_list (list[Tensor]): Mask logits for each \ - decoder layer. Each with shape (batch_size, num_queries, \ - h, w). 
- """ - batch_size = len(img_metas) - mask_features, multi_scale_memorys = self.pixel_decoder(feats) - # multi_scale_memorys (from low resolution to high resolution) - decoder_inputs = [] - decoder_positional_encodings = [] - for i in range(self.num_transformer_feat_level): - decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) - # shape (batch_size, c, h, w) -> (h*w, batch_size, c) - decoder_input = decoder_input.flatten(2).permute(2, 0, 1) - level_embed = self.level_embed.weight[i].view(1, 1, -1) - decoder_input = decoder_input + level_embed - # shape (batch_size, c, h, w) -> (h*w, batch_size, c) - mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) - decoder_positional_encoding = self.decoder_positional_encoding(mask) - decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1) - decoder_inputs.append(decoder_input) - decoder_positional_encodings.append(decoder_positional_encoding) - # shape (num_queries, c) -> (num_queries, batch_size, c) - query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1)) - query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1)) - - cls_pred_list = [] - mask_pred_list = [] - cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) - cls_pred_list.append(cls_pred) - mask_pred_list.append(mask_pred) - - for i in range(self.num_transformer_decoder_layers): - level_idx = i % self.num_transformer_feat_level - # if a mask is all True(all background), then set it all False. - attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False - - # cross_attn + self_attn - layer = self.transformer_decoder.layers[i] - attn_masks = [attn_mask, None] - query_feat = layer( - query=query_feat, - key=decoder_inputs[level_idx], - value=decoder_inputs[level_idx], - query_pos=query_embed, - key_pos=decoder_positional_encodings[level_idx], - attn_masks=attn_masks, - query_key_padding_mask=None, - # here we do not apply masking on padded region - key_padding_mask=None, - ) - cls_pred, mask_pred, attn_mask = self.forward_head( - query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:] - ) - - cls_pred_list.append(cls_pred) - mask_pred_list.append(mask_pred) - - return cls_pred_list, mask_pred_list - - def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks): - """Forward function for training mode. - - Args: - x (list[Tensor]): Multi-level features from the upstream network, - each is a 4D-tensor. - img_metas (list[Dict]): List of image information. - gt_semantic_seg (list[tensor]):Each element is the ground truth - of semantic segmentation with the shape (N, H, W). - train_cfg (dict): The training config, which not been used in - maskformer. - gt_labels (list[Tensor]): Each element is ground truth labels of - each box, shape (num_gts,). - gt_masks (list[BitmapMasks]): Each element is masks of instances - of a image, shape (num_gts, h, w). - - Returns: - losses (dict[str, Tensor]): a dictionary of loss components - """ - - # forward - all_cls_scores, all_mask_preds = self(x, img_metas) - - # loss - losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas) - - return losses - - def forward_test(self, inputs, img_metas, test_cfg): - """Test segment without test-time aumengtation. - - Only the output of last decoder layers was used. 
- - Args: - inputs (list[Tensor]): Multi-level features from the - upstream network, each is a 4D-tensor. - img_metas (list[dict]): List of image information. - test_cfg (dict): Testing config. - - Returns: - seg_mask (Tensor): Predicted semantic segmentation logits. - """ - all_cls_scores, all_mask_preds = self(inputs, img_metas) - cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] - ori_h, ori_w, _ = img_metas[0]["ori_shape"] - - # semantic inference - cls_score = F.softmax(cls_score, dim=-1)[..., :-1] - mask_pred = mask_pred.sigmoid() - seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred) - return seg_mask diff --git a/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/dinov2/eval/segmentation_m2f/models/losses/__init__.py deleted file mode 100644 index 229a887817372f4991b32354180592cfb236d728..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/losses/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy -from .dice_loss import DiceLoss -from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost diff --git a/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py deleted file mode 100644 index 0a1f9dd4aa52ebe94cc527db36b1c7fa2f53813e..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmseg.models.builder import LOSSES -from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss - - -def cross_entropy( - pred, - label, - weight=None, - class_weight=None, - reduction="mean", - avg_factor=None, - ignore_index=-100, - avg_non_ignore=False, -): - """cross_entropy. The wrapper function for :func:`F.cross_entropy` - - Args: - pred (torch.Tensor): The prediction with shape (N, 1). - label (torch.Tensor): The learning label of the prediction. - weight (torch.Tensor, optional): Sample-wise loss weight. - Default: None. - class_weight (list[float], optional): The weight for each class. - Default: None. - reduction (str, optional): The method used to reduce the loss. - Options are 'none', 'mean' and 'sum'. Default: 'mean'. - avg_factor (int, optional): Average factor that is used to average - the loss. Default: None. - ignore_index (int): Specifies a target value that is ignored and - does not contribute to the input gradients. When - ``avg_non_ignore `` is ``True``, and the ``reduction`` is - ``''mean''``, the loss is averaged over non-ignored targets. - Defaults: -100. - avg_non_ignore (bool): The flag decides to whether the loss is - only averaged over non-ignored targets. Default: False. - `New in version 0.23.0.` - """ - - # class_weight is a manual rescaling weight given to each class. 
- # If given, has to be a Tensor of size C element-wise losses - loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) - - # apply weights and do the reduction - # average loss over non-ignored elements - # pytorch's official cross_entropy average loss over non-ignored elements - # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa - if (avg_factor is None) and avg_non_ignore and reduction == "mean": - avg_factor = label.numel() - (label == ignore_index).sum().item() - if weight is not None: - weight = weight.float() - loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) - - return loss - - -def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): - """Expand onehot labels to match the size of prediction.""" - bin_labels = labels.new_zeros(target_shape) - valid_mask = (labels >= 0) & (labels != ignore_index) - inds = torch.nonzero(valid_mask, as_tuple=True) - - if inds[0].numel() > 0: - if labels.dim() == 3: - bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 - else: - bin_labels[inds[0], labels[valid_mask]] = 1 - - valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() - - if label_weights is None: - bin_label_weights = valid_mask - else: - bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) - bin_label_weights = bin_label_weights * valid_mask - - return bin_labels, bin_label_weights, valid_mask - - -def binary_cross_entropy( - pred, - label, - weight=None, - reduction="mean", - avg_factor=None, - class_weight=None, - ignore_index=-100, - avg_non_ignore=False, - **kwargs, -): - """Calculate the binary CrossEntropy loss. - - Args: - pred (torch.Tensor): The prediction with shape (N, 1). - label (torch.Tensor): The learning label of the prediction. - Note: In bce loss, label < 0 is invalid. - weight (torch.Tensor, optional): Sample-wise loss weight. - reduction (str, optional): The method used to reduce the loss. - Options are "none", "mean" and "sum". - avg_factor (int, optional): Average factor that is used to average - the loss. Defaults to None. - class_weight (list[float], optional): The weight for each class. - ignore_index (int): The label index to be ignored. Default: -100. - avg_non_ignore (bool): The flag decides to whether the loss is - only averaged over non-ignored targets. Default: False. - `New in version 0.23.0.` - - Returns: - torch.Tensor: The calculated loss - """ - if pred.size(1) == 1: - # For binary class segmentation, the shape of pred is - # [N, 1, H, W] and that of label is [N, H, W]. 
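        # e.g. pred (N, 1, H, W) -> squeeze(1) -> (N, H, W), matching label (N, H, W)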
-        assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
-        pred = pred.squeeze(1)
-    if pred.dim() != label.dim():
-        assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
-            "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
-        )
-        # `weight` returned from `_expand_onehot_labels`
-        # has been treated for valid (non-ignore) pixels
-        label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
-    else:
-        # should mask out the ignored elements
-        valid_mask = ((label >= 0) & (label != ignore_index)).float()
-        if weight is not None:
-            weight = weight * valid_mask
-        else:
-            weight = valid_mask
-    # average loss over non-ignored and valid elements
-    if reduction == "mean" and avg_factor is None and avg_non_ignore:
-        avg_factor = valid_mask.sum().item()
-
-    loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
-    # do the reduction for the weighted loss
-    loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
-
-    return loss
-
-
-def mask_cross_entropy(
-    pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
-):
-    """Calculate the CrossEntropy loss for masks.
-
-    Args:
-        pred (torch.Tensor): The prediction with shape (N, C), C is the number
-            of classes.
-        target (torch.Tensor): The learning label of the prediction.
-        label (torch.Tensor): ``label`` indicates the class label of the mask's
-            corresponding object. It is used to select the mask prediction of
-            the class the object belongs to when the mask prediction is not
-            class-agnostic.
-        reduction (str, optional): The method used to reduce the loss.
-            Options are "none", "mean" and "sum".
-        avg_factor (int, optional): Average factor that is used to average
-            the loss. Defaults to None.
-        class_weight (list[float], optional): The weight for each class.
-        ignore_index (None): Placeholder, to be consistent with other loss.
-            Default: None.
-
-    Returns:
-        torch.Tensor: The calculated loss
-    """
-    assert ignore_index is None, "BCE loss does not support ignore_index"
-    assert reduction == "mean" and avg_factor is None
-    num_rois = pred.size()[0]
-    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
-    pred_slice = pred[inds, label].squeeze(1)
-    return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
-
-
-@LOSSES.register_module(force=True)
-class CrossEntropyLoss(nn.Module):
-    """CrossEntropyLoss.
-
-    Args:
-        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
-            or softmax. Defaults to False.
-        use_mask (bool, optional): Whether to use mask cross entropy loss.
-            Defaults to False.
-        reduction (str, optional): The method used to reduce the loss.
-            Options are "none", "mean" and "sum". Defaults to 'mean'.
-        class_weight (list[float] | str, optional): Weight of each class. If in
-            str format, read them from a file. Defaults to None.
-        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
-        loss_name (str, optional): Name of the loss item. If you want this loss
-            item to be included into the backward graph, `loss_` must be the
-            prefix of the name. Defaults to 'loss_ce'.
-        avg_non_ignore (bool): Whether the loss is only averaged over
-            non-ignored targets. Default: False.
- `New in version 0.23.0.` - """ - - def __init__( - self, - use_sigmoid=False, - use_mask=False, - reduction="mean", - class_weight=None, - loss_weight=1.0, - loss_name="loss_ce", - avg_non_ignore=False, - ): - super(CrossEntropyLoss, self).__init__() - assert (use_sigmoid is False) or (use_mask is False) - self.use_sigmoid = use_sigmoid - self.use_mask = use_mask - self.reduction = reduction - self.loss_weight = loss_weight - self.class_weight = get_class_weight(class_weight) - self.avg_non_ignore = avg_non_ignore - if not self.avg_non_ignore and self.reduction == "mean": - warnings.warn( - "Default ``avg_non_ignore`` is False, if you would like to " - "ignore the certain label and average loss over non-ignore " - "labels, which is the same with PyTorch official " - "cross_entropy, set ``avg_non_ignore=True``." - ) - - if self.use_sigmoid: - self.cls_criterion = binary_cross_entropy - elif self.use_mask: - self.cls_criterion = mask_cross_entropy - else: - self.cls_criterion = cross_entropy - self._loss_name = loss_name - - def extra_repr(self): - """Extra repr.""" - s = f"avg_non_ignore={self.avg_non_ignore}" - return s - - def forward( - self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs - ): - """Forward function.""" - assert reduction_override in (None, "none", "mean", "sum") - reduction = reduction_override if reduction_override else self.reduction - if self.class_weight is not None: - class_weight = cls_score.new_tensor(self.class_weight) - else: - class_weight = None - # Note: for BCE loss, label < 0 is invalid. - loss_cls = self.loss_weight * self.cls_criterion( - cls_score, - label, - weight, - class_weight=class_weight, - reduction=reduction, - avg_factor=avg_factor, - avg_non_ignore=self.avg_non_ignore, - ignore_index=ignore_index, - **kwargs, - ) - return loss_cls - - @property - def loss_name(self): - """Loss Name. - - This function must be implemented and will return the name of this - loss function. This name will be used to combine different loss items - by simple sum operation. In addition, if you want this loss item to be - included into the backward graph, `loss_` must be the prefix of the - name. - - Returns: - str: The name of this loss item. - """ - return self._loss_name diff --git a/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py deleted file mode 100644 index 1bc5ba893c502861032ed531283f225e183eb693..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from mmseg.models.builder import LOSSES -from mmseg.models.losses.utils import weight_reduce_loss - - -def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): - """Calculate dice loss, which is proposed in - `V-Net: Fully Convolutional Neural Networks for Volumetric - Medical Image Segmentation `_. - - Args: - pred (torch.Tensor): The prediction, has a shape (n, *) - target (torch.Tensor): The learning label of the prediction, - shape (n, *), same shape of pred. - weight (torch.Tensor, optional): The weight of loss for each - prediction, has a shape (n,). Defaults to None. - eps (float): Avoid dividing by zero. Default: 1e-3. 
- reduction (str, optional): The method used to reduce the loss into - a scalar. Defaults to 'mean'. - Options are "none", "mean" and "sum". - avg_factor (int, optional): Average factor that is used to average - the loss. Defaults to None. - """ - - input = pred.flatten(1) - target = target.flatten(1).float() - - a = torch.sum(input * target, 1) - b = torch.sum(input * input, 1) + eps - c = torch.sum(target * target, 1) + eps - d = (2 * a) / (b + c) - loss = 1 - d - if weight is not None: - assert weight.ndim == loss.ndim - assert len(weight) == len(pred) - loss = weight_reduce_loss(loss, weight, reduction, avg_factor) - return loss - - -def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): - """Calculate naive dice loss, the coefficient in the denominator is the - first power instead of the second power. - - Args: - pred (torch.Tensor): The prediction, has a shape (n, *) - target (torch.Tensor): The learning label of the prediction, - shape (n, *), same shape of pred. - weight (torch.Tensor, optional): The weight of loss for each - prediction, has a shape (n,). Defaults to None. - eps (float): Avoid dividing by zero. Default: 1e-3. - reduction (str, optional): The method used to reduce the loss into - a scalar. Defaults to 'mean'. - Options are "none", "mean" and "sum". - avg_factor (int, optional): Average factor that is used to average - the loss. Defaults to None. - """ - input = pred.flatten(1) - target = target.flatten(1).float() - - a = torch.sum(input * target, 1) - b = torch.sum(input, 1) - c = torch.sum(target, 1) - d = (2 * a + eps) / (b + c + eps) - loss = 1 - d - if weight is not None: - assert weight.ndim == loss.ndim - assert len(weight) == len(pred) - loss = weight_reduce_loss(loss, weight, reduction, avg_factor) - return loss - - -@LOSSES.register_module(force=True) -class DiceLoss(nn.Module): - def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3): - """Dice Loss, there are two forms of dice loss is supported: - - - the one proposed in `V-Net: Fully Convolutional Neural - Networks for Volumetric Medical Image Segmentation - `_. - - the dice loss in which the power of the number in the - denominator is the first power instead of the second - power. - - Args: - use_sigmoid (bool, optional): Whether to the prediction is - used for sigmoid or softmax. Defaults to True. - activate (bool): Whether to activate the predictions inside, - this will disable the inside sigmoid operation. - Defaults to True. - reduction (str, optional): The method used - to reduce the loss. Options are "none", - "mean" and "sum". Defaults to 'mean'. - naive_dice (bool, optional): If false, use the dice - loss defined in the V-Net paper, otherwise, use the - naive dice loss in which the power of the number in the - denominator is the first power instead of the second - power.Defaults to False. - loss_weight (float, optional): Weight of loss. Defaults to 1.0. - eps (float): Avoid dividing by zero. Defaults to 1e-3. - """ - - super(DiceLoss, self).__init__() - self.use_sigmoid = use_sigmoid - self.reduction = reduction - self.naive_dice = naive_dice - self.loss_weight = loss_weight - self.eps = eps - self.activate = activate - - def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): - """Forward function. - - Args: - pred (torch.Tensor): The prediction, has a shape (n, *). - target (torch.Tensor): The label of the prediction, - shape (n, *), same shape of pred. 
- weight (torch.Tensor, optional): The weight of loss for each - prediction, has a shape (n,). Defaults to None. - avg_factor (int, optional): Average factor that is used to average - the loss. Defaults to None. - reduction_override (str, optional): The reduction method used to - override the original reduction method of the loss. - Options are "none", "mean" and "sum". - - Returns: - torch.Tensor: The calculated loss - """ - - assert reduction_override in (None, "none", "mean", "sum") - reduction = reduction_override if reduction_override else self.reduction - - if self.activate: - if self.use_sigmoid: - pred = pred.sigmoid() - else: - raise NotImplementedError - - if self.naive_dice: - loss = self.loss_weight * naive_dice_loss( - pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor - ) - else: - loss = self.loss_weight * dice_loss( - pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor - ) - - return loss diff --git a/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/dinov2/eval/segmentation_m2f/models/losses/match_costs.py deleted file mode 100644 index 4917d2a939c01398dd49c0d90b06f4c37d283ce0..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/losses/match_costs.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F - -from ..builder import MATCH_COST - - -@MATCH_COST.register_module() -class ClassificationCost: - """ClsSoftmaxCost.Borrow from - mmdet.core.bbox.match_costs.match_cost.ClassificationCost. - - Args: - weight (int | float, optional): loss_weight - - Examples: - >>> import torch - >>> self = ClassificationCost() - >>> cls_pred = torch.rand(4, 3) - >>> gt_labels = torch.tensor([0, 1, 2]) - >>> factor = torch.tensor([10, 8, 10, 8]) - >>> self(cls_pred, gt_labels) - tensor([[-0.3430, -0.3525, -0.3045], - [-0.3077, -0.2931, -0.3992], - [-0.3664, -0.3455, -0.2881], - [-0.3343, -0.2701, -0.3956]]) - """ - - def __init__(self, weight=1.0): - self.weight = weight - - def __call__(self, cls_pred, gt_labels): - """ - Args: - cls_pred (Tensor): Predicted classification logits, shape - [num_query, num_class]. - gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). - - Returns: - torch.Tensor: cls_cost value with weight - """ - # Following the official DETR repo, contrary to the loss that - # NLL is used, we approximate it in 1 - cls_score[gt_label]. - # The 1 is a constant that doesn't change the matching, - # so it can be omitted. - cls_score = cls_pred.softmax(-1) - cls_cost = -cls_score[:, gt_labels] - return cls_cost * self.weight - - -@MATCH_COST.register_module() -class DiceCost: - """Cost of mask assignments based on dice losses. - - Args: - weight (int | float, optional): loss_weight. Defaults to 1. - pred_act (bool, optional): Whether to apply sigmoid to mask_pred. - Defaults to False. - eps (float, optional): default 1e-12. - """ - - def __init__(self, weight=1.0, pred_act=False, eps=1e-3): - self.weight = weight - self.pred_act = pred_act - self.eps = eps - - def binary_mask_dice_loss(self, mask_preds, gt_masks): - """ - Args: - mask_preds (Tensor): Mask prediction in shape (N1, H, W). - gt_masks (Tensor): Ground truth in shape (N2, H, W) - store 0 or 1, 0 for negative class and 1 for - positive class. 
- - Returns: - Tensor: Dice cost matrix in shape (N1, N2). - """ - mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) - gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() - numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks) - denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] - loss = 1 - (numerator + self.eps) / (denominator + self.eps) - return loss - - def __call__(self, mask_preds, gt_masks): - """ - Args: - mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). - gt_masks (Tensor): Ground truth in shape (N2, H, W). - - Returns: - Tensor: Dice cost matrix in shape (N1, N2). - """ - if self.pred_act: - mask_preds = mask_preds.sigmoid() - dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) - return dice_cost * self.weight - - -@MATCH_COST.register_module() -class CrossEntropyLossCost: - """CrossEntropyLossCost. - - Args: - weight (int | float, optional): loss weight. Defaults to 1. - use_sigmoid (bool, optional): Whether the prediction uses sigmoid - of softmax. Defaults to True. - """ - - def __init__(self, weight=1.0, use_sigmoid=True): - assert use_sigmoid, "use_sigmoid = False is not supported yet." - self.weight = weight - self.use_sigmoid = use_sigmoid - - def _binary_cross_entropy(self, cls_pred, gt_labels): - """ - Args: - cls_pred (Tensor): The prediction with shape (num_query, 1, *) or - (num_query, *). - gt_labels (Tensor): The learning label of prediction with - shape (num_gt, *). - Returns: - Tensor: Cross entropy cost matrix in shape (num_query, num_gt). - """ - cls_pred = cls_pred.flatten(1).float() - gt_labels = gt_labels.flatten(1).float() - n = cls_pred.shape[1] - pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none") - neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none") - cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels) - cls_cost = cls_cost / n - - return cls_cost - - def __call__(self, cls_pred, gt_labels): - """ - Args: - cls_pred (Tensor): Predicted classification logits. - gt_labels (Tensor): Labels. - Returns: - Tensor: Cross entropy cost matrix with weight in - shape (num_query, num_gt). - """ - if self.use_sigmoid: - cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) - else: - raise NotImplementedError - - return cls_cost * self.weight diff --git a/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/dinov2/eval/segmentation_m2f/models/plugins/__init__.py deleted file mode 100644 index 81a60db4de31238cb38e078683e5ca265839fe60..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/plugins/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder diff --git a/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py deleted file mode 100644 index db1947175917f73f3f24184cb09c78e092d46ef8..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init -from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence -from mmcv.runner import BaseModule, ModuleList - -from ...core.anchor import MlvlPointGenerator -from ..utils.transformer import MultiScaleDeformableAttention - - -@PLUGIN_LAYERS.register_module() -class MSDeformAttnPixelDecoder(BaseModule): - """Pixel decoder with multi-scale deformable attention. - - Args: - in_channels (list[int] | tuple[int]): Number of channels in the - input feature maps. - strides (list[int] | tuple[int]): Output strides of feature from - backbone. - feat_channels (int): Number of channels for feature. - out_channels (int): Number of channels for output. - num_outs (int): Number of output scales. - norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. - Defaults to dict(type='GN', num_groups=32). - act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. - Defaults to dict(type='ReLU'). - encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer - encoder. Defaults to `DetrTransformerEncoder`. - positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for - transformer encoder position encoding. Defaults to - dict(type='SinePositionalEncoding', num_feats=128, - normalize=True). - init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. - """ - - def __init__( - self, - in_channels=[256, 512, 1024, 2048], - strides=[4, 8, 16, 32], - feat_channels=256, - out_channels=256, - num_outs=3, - norm_cfg=dict(type="GN", num_groups=32), - act_cfg=dict(type="ReLU"), - encoder=dict( - type="DetrTransformerEncoder", - num_layers=6, - transformerlayers=dict( - type="BaseTransformerLayer", - attn_cfgs=dict( - type="MultiScaleDeformableAttention", - embed_dims=256, - num_heads=8, - num_levels=3, - num_points=4, - im2col_step=64, - dropout=0.0, - batch_first=False, - norm_cfg=None, - init_cfg=None, - ), - feedforward_channels=1024, - ffn_dropout=0.0, - operation_order=("self_attn", "norm", "ffn", "norm"), - ), - init_cfg=None, - ), - positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True), - init_cfg=None, - ): - super().__init__(init_cfg=init_cfg) - self.strides = strides - self.num_input_levels = len(in_channels) - self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels - assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one" - input_conv_list = [] - # from top to down (low to high resolution) - for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1): - input_conv = ConvModule( - in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True - ) - input_conv_list.append(input_conv) - self.input_convs = ModuleList(input_conv_list) - - self.encoder = build_transformer_layer_sequence(encoder) - self.postional_encoding = build_positional_encoding(positional_encoding) - # high resolution to low resolution - self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels) - - # fpn-like structure - self.lateral_convs = ModuleList() - self.output_convs = ModuleList() - self.use_bias = norm_cfg is None - # from top to down (low to high resolution) - # fpn for the rest features that 
are not passed to the encoder - for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): - lateral_conv = ConvModule( - in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None - ) - output_conv = ConvModule( - feat_channels, - feat_channels, - kernel_size=3, - stride=1, - padding=1, - bias=self.use_bias, - norm_cfg=norm_cfg, - act_cfg=act_cfg, - ) - self.lateral_convs.append(lateral_conv) - self.output_convs.append(output_conv) - - self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0) - - self.num_outs = num_outs - self.point_generator = MlvlPointGenerator(strides) - - def init_weights(self): - """Initialize weights.""" - for i in range(0, self.num_encoder_levels): - xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform") - - for i in range(0, self.num_input_levels - self.num_encoder_levels): - caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) - caffe2_xavier_init(self.output_convs[i].conv, bias=0) - - caffe2_xavier_init(self.mask_feature, bias=0) - - normal_init(self.level_encoding, mean=0, std=1) - for p in self.encoder.parameters(): - if p.dim() > 1: - nn.init.xavier_normal_(p) - - # init_weights defined in MultiScaleDeformableAttention - for layer in self.encoder.layers: - for attn in layer.attentions: - if isinstance(attn, MultiScaleDeformableAttention): - attn.init_weights() - - def forward(self, feats): - """ - Args: - feats (list[Tensor]): Feature maps of each level. Each has - shape of (batch_size, c, h, w). - - Returns: - tuple: A tuple containing the following: - - - mask_feature (Tensor): shape (batch_size, c, h, w). - - multi_scale_features (list[Tensor]): Multi scale \ features, each in shape (batch_size, c, h, w). 
- """ - # generate padding mask for each level, for each image - batch_size = feats[0].shape[0] - encoder_input_list = [] - padding_mask_list = [] - level_positional_encoding_list = [] - spatial_shapes = [] - reference_points_list = [] - for i in range(self.num_encoder_levels): - level_idx = self.num_input_levels - i - 1 - feat = feats[level_idx] - feat_projected = self.input_convs[i](feat) - h, w = feat.shape[-2:] - - # no padding - padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool) - pos_embed = self.postional_encoding(padding_mask_resized) - level_embed = self.level_encoding.weight[i] - level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed - # (h_i * w_i, 2) - reference_points = self.point_generator.single_level_grid_priors( - feat.shape[-2:], level_idx, device=feat.device - ) - # normalize - factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] - reference_points = reference_points / factor - - # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) - feat_projected = feat_projected.flatten(2).permute(2, 0, 1) - level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) - padding_mask_resized = padding_mask_resized.flatten(1) - - encoder_input_list.append(feat_projected) - padding_mask_list.append(padding_mask_resized) - level_positional_encoding_list.append(level_pos_embed) - spatial_shapes.append(feat.shape[-2:]) - reference_points_list.append(reference_points) - # shape (batch_size, total_num_query), - # total_num_query=sum([., h_i * w_i,.]) - padding_masks = torch.cat(padding_mask_list, dim=1) - # shape (total_num_query, batch_size, c) - encoder_inputs = torch.cat(encoder_input_list, dim=0) - level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0) - device = encoder_inputs.device - # shape (num_encoder_levels, 2), from low - # resolution to high resolution - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device) - # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) 
- level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - reference_points = torch.cat(reference_points_list, dim=0) - reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1) - valid_ratios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2)) - # shape (num_total_query, batch_size, c) - memory = self.encoder( - query=encoder_inputs, - key=None, - value=None, - query_pos=level_positional_encodings, - key_pos=None, - attn_masks=None, - key_padding_mask=None, - query_key_padding_mask=padding_masks, - spatial_shapes=spatial_shapes, - reference_points=reference_points, - level_start_index=level_start_index, - valid_ratios=valid_ratios, - ) - # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) - memory = memory.permute(1, 2, 0) - - # from low resolution to high resolution - num_query_per_level = [e[0] * e[1] for e in spatial_shapes] - outs = torch.split(memory, num_query_per_level, dim=-1) - outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)] - - for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): - x = feats[i] - cur_feat = self.lateral_convs[i](x) - y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False) - y = self.output_convs[i](y) - outs.append(y) - multi_scale_features = outs[: self.num_outs] - - mask_feature = self.mask_feature(outs[-1]) - return mask_feature, multi_scale_features diff --git a/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py deleted file mode 100644 index adf0062691e4889612e118f28ced853cd0bc33db..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .encoder_decoder_mask2former import EncoderDecoderMask2Former diff --git a/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py deleted file mode 100644 index cfe572c9d317303bff8d51b85217d144906ebfe7..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmseg.core import add_prefix -from mmseg.models import builder -from mmseg.models.builder import SEGMENTORS -from mmseg.models.segmentors.base import BaseSegmentor -from mmseg.ops import resize - - -@SEGMENTORS.register_module() -class EncoderDecoderMask2Former(BaseSegmentor): - """Encoder Decoder segmentors. - - EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. - Note that auxiliary_head is only used for deep supervision during training, - which can be discarded at inference time. 
- """ - - def __init__( - self, - backbone, - decode_head, - neck=None, - auxiliary_head=None, - train_cfg=None, - test_cfg=None, - pretrained=None, - init_cfg=None, - ): - super(EncoderDecoderMask2Former, self).__init__(init_cfg) - if pretrained is not None: - assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight" - backbone.pretrained = pretrained - self.backbone = builder.build_backbone(backbone) - if neck is not None: - self.neck = builder.build_neck(neck) - decode_head.update(train_cfg=train_cfg) - decode_head.update(test_cfg=test_cfg) - self._init_decode_head(decode_head) - self._init_auxiliary_head(auxiliary_head) - - self.train_cfg = train_cfg - self.test_cfg = test_cfg - - assert self.with_decode_head - - def _init_decode_head(self, decode_head): - """Initialize ``decode_head``""" - self.decode_head = builder.build_head(decode_head) - self.align_corners = self.decode_head.align_corners - self.num_classes = self.decode_head.num_classes - - def _init_auxiliary_head(self, auxiliary_head): - """Initialize ``auxiliary_head``""" - if auxiliary_head is not None: - if isinstance(auxiliary_head, list): - self.auxiliary_head = nn.ModuleList() - for head_cfg in auxiliary_head: - self.auxiliary_head.append(builder.build_head(head_cfg)) - else: - self.auxiliary_head = builder.build_head(auxiliary_head) - - def extract_feat(self, img): - """Extract features from images.""" - x = self.backbone(img) - if self.with_neck: - x = self.neck(x) - return x - - def encode_decode(self, img, img_metas): - """Encode images with backbone and decode into a semantic segmentation - map of the same size as input.""" - x = self.extract_feat(img) - out = self._decode_head_forward_test(x, img_metas) - out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners) - return out - - def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs): - """Run forward function and calculate loss for decode head in - training.""" - losses = dict() - loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs) - - losses.update(add_prefix(loss_decode, "decode")) - return losses - - def _decode_head_forward_test(self, x, img_metas): - """Run forward function and calculate loss for decode head in - inference.""" - seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) - return seg_logits - - def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): - """Run forward function and calculate loss for auxiliary head in - training.""" - losses = dict() - if isinstance(self.auxiliary_head, nn.ModuleList): - for idx, aux_head in enumerate(self.auxiliary_head): - loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) - losses.update(add_prefix(loss_aux, f"aux_{idx}")) - else: - loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) - losses.update(add_prefix(loss_aux, "aux")) - - return losses - - def forward_dummy(self, img): - """Dummy forward function.""" - seg_logit = self.encode_decode(img, None) - - return seg_logit - - def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): - """Forward function for training. - - Args: - img (Tensor): Input images. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 
- For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - - x = self.extract_feat(img) - - losses = dict() - - loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs) - losses.update(loss_decode) - - if self.with_auxiliary_head: - loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) - losses.update(loss_aux) - - return losses - - def slide_inference(self, img, img_meta, rescale): - """Inference by sliding-window with overlap. - - If h_crop > h_img or w_crop > w_img, the small patch will be used to - decode without padding. - """ - - h_stride, w_stride = self.test_cfg.stride - h_crop, w_crop = self.test_cfg.crop_size - batch_size, _, h_img, w_img = img.size() - num_classes = self.num_classes - h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 - w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 - preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) - count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) - for h_idx in range(h_grids): - for w_idx in range(w_grids): - y1 = h_idx * h_stride - x1 = w_idx * w_stride - y2 = min(y1 + h_crop, h_img) - x2 = min(x1 + w_crop, w_img) - y1 = max(y2 - h_crop, 0) - x1 = max(x2 - w_crop, 0) - crop_img = img[:, :, y1:y2, x1:x2] - crop_seg_logit = self.encode_decode(crop_img, img_meta) - preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) - - count_mat[:, :, y1:y2, x1:x2] += 1 - assert (count_mat == 0).sum() == 0 - if torch.onnx.is_in_onnx_export(): - # cast count_mat to constant while exporting to ONNX - count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) - preds = preds / count_mat - if rescale: - preds = resize( - preds, - size=img_meta[0]["ori_shape"][:2], - mode="bilinear", - align_corners=self.align_corners, - warning=False, - ) - return preds - - def whole_inference(self, img, img_meta, rescale): - """Inference with full image.""" - - seg_logit = self.encode_decode(img, img_meta) - if rescale: - # support dynamic shape for onnx - if torch.onnx.is_in_onnx_export(): - size = img.shape[2:] - else: - size = img_meta[0]["ori_shape"][:2] - seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False) - - return seg_logit - - def inference(self, img, img_meta, rescale): - """Inference with slide/whole style. - - Args: - img (Tensor): The input image of shape (N, 3, H, W). - img_meta (dict): Image info dict where each dict has: 'img_shape', - 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - rescale (bool): Whether rescale back to original shape. - - Returns: - Tensor: The output segmentation map. 
- """ - - assert self.test_cfg.mode in ["slide", "whole"] - ori_shape = img_meta[0]["ori_shape"] - assert all(_["ori_shape"] == ori_shape for _ in img_meta) - if self.test_cfg.mode == "slide": - seg_logit = self.slide_inference(img, img_meta, rescale) - else: - seg_logit = self.whole_inference(img, img_meta, rescale) - output = F.softmax(seg_logit, dim=1) - flip = img_meta[0]["flip"] - if flip: - flip_direction = img_meta[0]["flip_direction"] - assert flip_direction in ["horizontal", "vertical"] - if flip_direction == "horizontal": - output = output.flip(dims=(3,)) - elif flip_direction == "vertical": - output = output.flip(dims=(2,)) - - return output - - def simple_test(self, img, img_meta, rescale=True): - """Simple test with single image.""" - seg_logit = self.inference(img, img_meta, rescale) - seg_pred = seg_logit.argmax(dim=1) - if torch.onnx.is_in_onnx_export(): - # our inference backend only support 4D output - seg_pred = seg_pred.unsqueeze(0) - return seg_pred - seg_pred = seg_pred.cpu().numpy() - # unravel batch dim - seg_pred = list(seg_pred) - return seg_pred - - def aug_test(self, imgs, img_metas, rescale=True): - """Test with augmentations. - - Only rescale=True is supported. - """ - # aug_test rescale all imgs back to ori_shape for now - assert rescale - # to save memory, we get augmented seg logit inplace - seg_logit = self.inference(imgs[0], img_metas[0], rescale) - for i in range(1, len(imgs)): - cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) - seg_logit += cur_seg_logit - seg_logit /= len(imgs) - seg_pred = seg_logit.argmax(dim=1) - seg_pred = seg_pred.cpu().numpy() - # unravel batch dim - seg_pred = list(seg_pred) - return seg_pred diff --git a/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/dinov2/eval/segmentation_m2f/models/utils/__init__.py deleted file mode 100644 index e7fdc1668b1015c8feea8fa1a4691bc0ebdbd936..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .assigner import MaskHungarianAssigner -from .point_sample import get_uncertain_point_coords_with_randomness -from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding -from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer diff --git a/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/dinov2/eval/segmentation_m2f/models/utils/assigner.py deleted file mode 100644 index 3cb08fc1bb2e36336989b45a1d3850f260c05963..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/utils/assigner.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -from abc import ABCMeta, abstractmethod - -import torch - -from ..builder import MASK_ASSIGNERS, build_match_cost - -try: - from scipy.optimize import linear_sum_assignment -except ImportError: - linear_sum_assignment = None - - -class AssignResult(metaclass=ABCMeta): - """Collection of assign results.""" - - def __init__(self, num_gts, gt_inds, labels): - self.num_gts = num_gts - self.gt_inds = gt_inds - self.labels = labels - - @property - def info(self): - info = { - "num_gts": self.num_gts, - "gt_inds": self.gt_inds, - "labels": self.labels, - } - return info - - -class BaseAssigner(metaclass=ABCMeta): - """Base assigner that assigns boxes to ground truth boxes.""" - - @abstractmethod - def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): - """Assign boxes to either a ground truth boxes or a negative boxes.""" - pass - - -@MASK_ASSIGNERS.register_module() -class MaskHungarianAssigner(BaseAssigner): - """Computes one-to-one matching between predictions and ground truth for - masks. - - This class computes an assignment between the targets and the predictions - based on the costs. The costs are a weighted sum of three components: - classification cost, mask cost and dice cost. The - targets don't include the no_object, so generally there are more - predictions than targets. After the one-to-one matching, the un-matched - ones are treated as background. Thus each query prediction will be assigned - with `0` or a positive integer indicating the ground truth index: - - - 0: negative sample, no assigned gt - - positive integer: positive sample, index (1-based) of assigned gt - - Args: - cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. - mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. - dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. - """ - - def __init__( - self, - cls_cost=dict(type="ClassificationCost", weight=1.0), - dice_cost=dict(type="DiceCost", weight=1.0), - mask_cost=dict(type="MaskFocalCost", weight=1.0), - ): - self.cls_cost = build_match_cost(cls_cost) - self.dice_cost = build_match_cost(dice_cost) - self.mask_cost = build_match_cost(mask_cost) - - def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7): - """Computes one-to-one matching based on the weighted costs. - - This method assigns each query prediction to a ground truth or - background. The `assigned_gt_inds` with -1 means don't care, - 0 means negative sample, and a positive number is the index (1-based) - of the assigned gt. - The assignment is done in the following steps, and the order matters. - - 1. assign every prediction to -1 - 2. compute the weighted costs - 3. do Hungarian matching on CPU based on the costs - 4. assign all to 0 (background) first, then for each matched pair - between predictions and gts, treat this prediction as foreground - and assign the corresponding gt index (plus 1) to it. - - Args: - mask_pred (Tensor): Predicted mask, shape [num_query, h, w] - cls_pred (Tensor): Predicted classification logits, shape - [num_query, num_class]. - gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. - gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). - img_meta (dict): Meta information for current image. - gt_masks_ignore (Tensor, optional): Ground truth masks that are - labelled as `ignored`. Default None. - eps (int | float, optional): A value added to the denominator for - numerical stability. Default 1e-7. - - Returns: - :obj:`AssignResult`: The assigned result. 
- """ - assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported." - num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] - - # 1. assign -1 by default - assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long) - assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long) - if num_gts == 0 or num_queries == 0: - # No ground truth or boxes, return empty assignment - if num_gts == 0: - # No ground truth, assign all to background - assigned_gt_inds[:] = 0 - return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) - - # 2. compute the weighted costs - # classification and maskcost. - if self.cls_cost.weight != 0 and cls_pred is not None: - cls_cost = self.cls_cost(cls_pred, gt_labels) - else: - cls_cost = 0 - - if self.mask_cost.weight != 0: - # mask_pred shape = [nq, h, w] - # gt_mask shape = [ng, h, w] - # mask_cost shape = [nq, ng] - mask_cost = self.mask_cost(mask_pred, gt_masks) - else: - mask_cost = 0 - - if self.dice_cost.weight != 0: - dice_cost = self.dice_cost(mask_pred, gt_masks) - else: - dice_cost = 0 - cost = cls_cost + mask_cost + dice_cost - - # 3. do Hungarian matching on CPU using linear_sum_assignment - cost = cost.detach().cpu() - if linear_sum_assignment is None: - raise ImportError('Please run "pip install scipy" ' "to install scipy first.") - - matched_row_inds, matched_col_inds = linear_sum_assignment(cost) - matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device) - matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device) - - # 4. assign backgrounds and foregrounds - # assign all indices to backgrounds first - assigned_gt_inds[:] = 0 - # assign foregrounds based on matching results - assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 - assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] - return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/dinov2/eval/segmentation_m2f/models/utils/point_sample.py deleted file mode 100644 index 9f1134082bafb51432618a9632592db070f87284..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/utils/point_sample.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -from mmcv.ops import point_sample - - -def get_uncertainty(mask_pred, labels): - """Estimate uncertainty based on pred logits. - - We estimate uncertainty as L1 distance between 0.0 and the logits - prediction in 'mask_pred' for the foreground class in `classes`. - - Args: - mask_pred (Tensor): mask predication logits, shape (num_rois, - num_classes, mask_height, mask_width). - - labels (list[Tensor]): Either predicted or ground truth label for - each predicted mask, of length num_rois. 
- - Returns: - scores (Tensor): Uncertainty scores with the most uncertain - locations having the highest uncertainty score, - shape (num_rois, 1, mask_height, mask_width) - """ - if mask_pred.shape[1] == 1: - gt_class_logits = mask_pred.clone() - else: - inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) - gt_class_logits = mask_pred[inds, labels].unsqueeze(1) - return -torch.abs(gt_class_logits) - - -def get_uncertain_point_coords_with_randomness( - mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio -): - """Get ``num_points`` most uncertain points with random points during - training. - - Sample points in [0, 1] x [0, 1] coordinate space based on their - uncertainty. The uncertainties are calculated for each point using - the 'get_uncertainty()' function that takes a point's logit prediction as - input. - - Args: - mask_pred (Tensor): A tensor of shape (num_rois, num_classes, - mask_height, mask_width) for class-specific or class-agnostic - prediction. - labels (list): The ground truth class for each instance. - num_points (int): The number of points to sample. - oversample_ratio (int): Oversampling parameter. - importance_sample_ratio (float): Ratio of points that are sampled - via importance sampling. - - Returns: - point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) - that contains the coordinates of the sampled points. - """ - assert oversample_ratio >= 1 - assert 0 <= importance_sample_ratio <= 1 - batch_size = mask_pred.shape[0] - num_sampled = int(num_points * oversample_ratio) - point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device) - point_logits = point_sample(mask_pred, point_coords) - # It is crucial to calculate uncertainty based on the sampled - # prediction value for the points. Calculating uncertainties of the - # coarse predictions first and sampling them for points leads to - # incorrect results. To illustrate this: assume uncertainty func( - # logits)=-abs(logits), a sampled point between two coarse - # predictions with -1 and 1 logits has 0 logits, and therefore 0 - # uncertainty value. However, if we calculate uncertainties for the - # coarse predictions first, both will have -1 uncertainty, - # and the sampled point will get -1 uncertainty. - point_uncertainties = get_uncertainty(point_logits, labels) - num_uncertain_points = int(importance_sample_ratio * num_points) - num_random_points = num_points - num_uncertain_points - idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] - shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device) - idx += shift[:, None] - point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2) - if num_random_points > 0: - rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device) - point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) - return point_coords diff --git a/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py deleted file mode 100644 index bf5d6fabe946d06fe97cc799da47bae93758b34e..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
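Before moving on to the positional encodings, note that the oversample-then-select scheme in `get_uncertain_point_coords_with_randomness` above is easier to see stripped of the mmcv plumbing. Below is a minimal class-agnostic sketch (a single mask channel, so the `labels` indexing is not needed); it substitutes `F.grid_sample` for `mmcv.ops.point_sample`, inlining the [0, 1] to [-1, 1] coordinate conversion that `point_sample` performs internally, and the function name is illustrative.

```python
import torch
import torch.nn.functional as F


def sample_points_by_uncertainty(mask_logits, num_points, oversample_ratio,
                                 importance_sample_ratio):
    """Oversample random points, keep the most uncertain ones, and top the
    set up with fresh random points (uncertainty = -|logit|)."""
    n = mask_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    coords = torch.rand(n, num_sampled, 2, device=mask_logits.device)
    # point_sample denormalizes [0, 1] coords to grid_sample's [-1, 1] range
    grid = 2.0 * coords.unsqueeze(2) - 1.0               # (n, num_sampled, 1, 2)
    point_logits = F.grid_sample(mask_logits, grid,
                                 align_corners=False).squeeze(-1)  # (n, 1, num_sampled)
    # uncertainty is computed on the *sampled* logits: values near 0 are
    # the most ambiguous, hence the most uncertain
    uncertainty = -point_logits[:, 0, :].abs()
    num_uncertain = int(importance_sample_ratio * num_points)
    idx = uncertainty.topk(num_uncertain, dim=1).indices  # (n, num_uncertain)
    picked = torch.gather(coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    num_random = num_points - num_uncertain
    if num_random > 0:
        rand = torch.rand(n, num_random, 2, device=mask_logits.device)
        picked = torch.cat([picked, rand], dim=1)
    return picked  # (n, num_points, 2) in [0, 1] x [0, 1]


coords = sample_points_by_uncertainty(torch.randn(4, 1, 32, 32),
                                      num_points=64, oversample_ratio=3,
                                      importance_sample_ratio=0.75)
assert coords.shape == (4, 64, 2)
```

As the long comment in the original stresses, computing uncertainty on the logits sampled at the candidate points, rather than on the coarse map first, is what makes a point lying between logits of -1 and 1 score as maximally uncertain.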
- -import math - -import torch -import torch.nn as nn -from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING -from mmcv.runner import BaseModule - - -@POSITIONAL_ENCODING.register_module() -class SinePositionalEncoding(BaseModule): - """Position encoding with sine and cosine functions. - - See `End-to-End Object Detection with Transformers - `_ for details. - - Args: - num_feats (int): The feature dimension for each position - along x-axis or y-axis. Note the final returned dimension - for each position is 2 times of this value. - temperature (int, optional): The temperature used for scaling - the position embedding. Defaults to 10000. - normalize (bool, optional): Whether to normalize the position - embedding. Defaults to False. - scale (float, optional): A scale factor that scales the position - embedding. The scale will be used only when `normalize` is True. - Defaults to 2*pi. - eps (float, optional): A value added to the denominator for - numerical stability. Defaults to 1e-6. - offset (float): offset add to embed when do the normalization. - Defaults to 0. - init_cfg (dict or list[dict], optional): Initialization config dict. - Default: None - """ - - def __init__( - self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None - ): - super(SinePositionalEncoding, self).__init__(init_cfg) - if normalize: - assert isinstance(scale, (float, int)), ( - "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}" - ) - self.num_feats = num_feats - self.temperature = temperature - self.normalize = normalize - self.scale = scale - self.eps = eps - self.offset = offset - - def forward(self, mask): - """Forward function for `SinePositionalEncoding`. - - Args: - mask (Tensor): ByteTensor mask. Non-zero values representing - ignored positions, while zero values means valid positions - for this image. Shape [bs, h, w]. - - Returns: - pos (Tensor): Returned position embedding with shape - [bs, num_feats*2, h, w]. - """ - # For convenience of exporting to ONNX, it's required to convert - # `masks` from bool to int. 
- mask = mask.to(torch.int) - not_mask = 1 - mask # logical_not - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale - x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale - dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - # use `view` instead of `flatten` for dynamically exporting to ONNX - B, H, W = mask.size() - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - - def __repr__(self): - """str: a string that describes the module""" - repr_str = self.__class__.__name__ - repr_str += f"(num_feats={self.num_feats}, " - repr_str += f"temperature={self.temperature}, " - repr_str += f"normalize={self.normalize}, " - repr_str += f"scale={self.scale}, " - repr_str += f"eps={self.eps})" - return repr_str - - -@POSITIONAL_ENCODING.register_module() -class LearnedPositionalEncoding(BaseModule): - """Position embedding with learnable embedding weights. - - Args: - num_feats (int): The feature dimension for each position - along x-axis or y-axis. The final returned dimension for - each position is 2 times of this value. - row_num_embed (int, optional): The dictionary size of row embeddings. - Default 50. - col_num_embed (int, optional): The dictionary size of col embeddings. - Default 50. - init_cfg (dict or list[dict], optional): Initialization config dict. - """ - - def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")): - super(LearnedPositionalEncoding, self).__init__(init_cfg) - self.row_embed = nn.Embedding(row_num_embed, num_feats) - self.col_embed = nn.Embedding(col_num_embed, num_feats) - self.num_feats = num_feats - self.row_num_embed = row_num_embed - self.col_num_embed = col_num_embed - - def forward(self, mask): - """Forward function for `LearnedPositionalEncoding`. - - Args: - mask (Tensor): ByteTensor mask. Non-zero values representing - ignored positions, while zero values means valid positions - for this image. Shape [bs, h, w]. - - Returns: - pos (Tensor): Returned position embedding with shape - [bs, num_feats*2, h, w]. 
- """ - h, w = mask.shape[-2:] - x = torch.arange(w, device=mask.device) - y = torch.arange(h, device=mask.device) - x_embed = self.col_embed(x) - y_embed = self.row_embed(y) - pos = ( - torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1) - .permute(2, 0, 1) - .unsqueeze(0) - .repeat(mask.shape[0], 1, 1, 1) - ) - return pos - - def __repr__(self): - """str: a string that describes the module""" - repr_str = self.__class__.__name__ - repr_str += f"(num_feats={self.num_feats}, " - repr_str += f"row_num_embed={self.row_num_embed}, " - repr_str += f"col_num_embed={self.col_num_embed})" - return repr_str diff --git a/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/dinov2/eval/segmentation_m2f/models/utils/transformer.py deleted file mode 100644 index 8befe6011a34d5ccecb82c8b17b61e19f732f96b..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/models/utils/transformer.py +++ /dev/null @@ -1,989 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import math -import warnings -from typing import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as cp -from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init -from mmcv.cnn.bricks.drop import build_dropout -from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE -from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence -from mmcv.runner.base_module import BaseModule, Sequential -from mmcv.utils import deprecated_api_warning, to_2tuple -from torch.nn.init import normal_ - -from ..builder import TRANSFORMER - -try: - from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention - -except ImportError: - warnings.warn( - "`MultiScaleDeformableAttention` in MMCV has been moved to " - "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV" - ) - from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention - - -class AdaptivePadding(nn.Module): - """Applies padding to input (if needed) so that input can get fully covered - by filter you specified. It support two modes "same" and "corner". The - "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around - input. The "corner" mode would pad zero to bottom right. - - Args: - kernel_size (int | tuple): Size of the kernel: - stride (int | tuple): Stride of the filter. Default: 1: - dilation (int | tuple): Spacing between kernel elements. - Default: 1 - padding (str): Support "same" and "corner", "corner" mode - would pad zero to bottom right, and "same" mode would - pad zero around input. Default: "corner". 
- Example: - >>> kernel_size = 16 - >>> stride = 16 - >>> dilation = 1 - >>> input = torch.rand(1, 1, 15, 17) - >>> adap_pad = AdaptivePadding( - >>> kernel_size=kernel_size, - >>> stride=stride, - >>> dilation=dilation, - >>> padding="corner") - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - >>> input = torch.rand(1, 1, 16, 17) - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - """ - - def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): - - super(AdaptivePadding, self).__init__() - - assert padding in ("same", "corner") - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - padding = to_2tuple(padding) - dilation = to_2tuple(dilation) - - self.padding = padding - self.kernel_size = kernel_size - self.stride = stride - self.dilation = dilation - - def get_pad_shape(self, input_shape): - input_h, input_w = input_shape - kernel_h, kernel_w = self.kernel_size - stride_h, stride_w = self.stride - output_h = math.ceil(input_h / stride_h) - output_w = math.ceil(input_w / stride_w) - pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) - pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) - return pad_h, pad_w - - def forward(self, x): - pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) - if pad_h > 0 or pad_w > 0: - if self.padding == "corner": - x = F.pad(x, [0, pad_w, 0, pad_h]) - elif self.padding == "same": - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return x - - -class PatchMerging(BaseModule): - """Merge patch feature map. - - This layer groups feature map by kernel_size, and applies norm and linear - layers to the grouped feature map. Our implementation uses `nn.Unfold` to - merge patch, which is about 25% faster than original implementation. - Instead, we need to modify pretrained models for compatibility. - - Args: - in_channels (int): The num of input channels. - to gets fully covered by filter and stride you specified.. - Default: True. - out_channels (int): The num of output channels. - kernel_size (int | tuple, optional): the kernel size in the unfold - layer. Defaults to 2. - stride (int | tuple, optional): the stride of the sliding blocks in the - unfold layer. Default: None. (Would be set as `kernel_size`) - padding (int | tuple | string ): The padding length of - embedding conv. When it is a string, it means the mode - of adaptive padding, support "same" and "corner" now. - Default: "corner". - dilation (int | tuple, optional): dilation parameter in the unfold - layer. Default: 1. - bias (bool, optional): Whether to add bias in linear layer or not. - Defaults: False. - norm_cfg (dict, optional): Config dict for normalization layer. - Default: dict(type='LN'). - init_cfg (dict, optional): The extra config for initialization. - Default: None. 
- """ - - def __init__( - self, - in_channels, - out_channels, - kernel_size=2, - stride=None, - padding="corner", - dilation=1, - bias=False, - norm_cfg=dict(type="LN"), - init_cfg=None, - ): - super().__init__(init_cfg=init_cfg) - self.in_channels = in_channels - self.out_channels = out_channels - if stride: - stride = stride - else: - stride = kernel_size - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - if isinstance(padding, str): - self.adap_padding = AdaptivePadding( - kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding - ) - # disable the padding of unfold - padding = 0 - else: - self.adap_padding = None - - padding = to_2tuple(padding) - self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) - - sample_dim = kernel_size[0] * kernel_size[1] * in_channels - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, sample_dim)[1] - else: - self.norm = None - - self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) - - def forward(self, x, input_size): - """ - Args: - x (Tensor): Has shape (B, H*W, C_in). - input_size (tuple[int]): The spatial shape of x, arrange as (H, W). - Default: None. - - Returns: - tuple: Contains merged results and its spatial shape. - - - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) - - out_size (tuple[int]): Spatial shape of x, arrange as - (Merged_H, Merged_W). - """ - B, L, C = x.shape - assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" - - H, W = input_size - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W - # Use nn.Unfold to merge patch. About 25% faster than original method, - # but need to modify pretrained model for compatibility - - if self.adap_padding: - x = self.adap_padding(x) - H, W = x.shape[-2:] - - x = self.sampler(x) - # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) - - out_h = ( - H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 - ) // self.sampler.stride[0] + 1 - out_w = ( - W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 - ) // self.sampler.stride[1] + 1 - - output_size = (out_h, out_w) - x = x.transpose(1, 2) # B, H/2*W/2, 4*C - x = self.norm(x) if self.norm else x - x = self.reduction(x) - return x, output_size - - -def inverse_sigmoid(x, eps=1e-5): - """Inverse function of sigmoid. - - Args: - x (Tensor): The tensor to do the - inverse. - eps (float): EPS avoid numerical - overflow. Defaults 1e-5. - Returns: - Tensor: The x has passed the inverse - function of sigmoid, has same - shape with input. - """ - x = x.clamp(min=0, max=1) - x1 = x.clamp(min=eps) - x2 = (1 - x).clamp(min=eps) - return torch.log(x1 / x2) - - -@FEEDFORWARD_NETWORK.register_module(force=True) -class FFN(BaseModule): - """Implements feed-forward networks (FFNs) with identity connection. - Args: - embed_dims (int): The feature dimension. Same as - `MultiheadAttention`. Defaults: 256. - feedforward_channels (int): The hidden dimension of FFNs. - Defaults: 1024. - num_fcs (int, optional): The number of fully-connected layers in - FFNs. Default: 2. - act_cfg (dict, optional): The activation config for FFNs. - Default: dict(type='ReLU') - ffn_drop (float, optional): Probability of an element to be - zeroed in FFN. Default 0.0. 
- add_identity (bool, optional): Whether to add the - identity connection. Default: `True`. - dropout_layer (obj:`ConfigDict`): The dropout_layer used - when adding the shortcut. - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - """ - - @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") - def __init__( - self, - embed_dims=256, - feedforward_channels=1024, - num_fcs=2, - act_cfg=dict(type="ReLU", inplace=True), - ffn_drop=0.0, - dropout_layer=None, - add_identity=True, - init_cfg=None, - with_cp=False, - **kwargs, - ): - super().__init__(init_cfg) - assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}." - self.embed_dims = embed_dims - self.feedforward_channels = feedforward_channels - self.num_fcs = num_fcs - self.act_cfg = act_cfg - self.activate = build_activation_layer(act_cfg) - self.with_cp = with_cp - layers = [] - in_channels = embed_dims - for _ in range(num_fcs - 1): - layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop))) - in_channels = feedforward_channels - layers.append(Linear(feedforward_channels, embed_dims)) - layers.append(nn.Dropout(ffn_drop)) - self.layers = Sequential(*layers) - self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() - self.add_identity = add_identity - - @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") - def forward(self, x, identity=None): - """Forward function for `FFN`. - The function would add x to the output tensor if `identity` is None. - """ - - if self.with_cp and x.requires_grad: - out = cp.checkpoint(self.layers, x) - else: - out = self.layers(x) - - if not self.add_identity: - return self.dropout_layer(out) - if identity is None: - identity = x - return identity + self.dropout_layer(out) - - -@TRANSFORMER_LAYER.register_module() -class DetrTransformerDecoderLayer(BaseTransformerLayer): - """Implements decoder layer in DETR transformer. - - Args: - attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict): - Configs for self_attention or cross_attention, the order - should be consistent with it in `operation_order`. If it is - a dict, it would be expanded to the number of attentions in - `operation_order`. - feedforward_channels (int): The hidden dimension for FFNs. - ffn_dropout (float): Probability of an element to be zeroed - in ffn. Default 0.0. - operation_order (tuple[str]): The execution order of operation - in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). - Default: None. - act_cfg (dict): The activation config for FFNs. - Default: dict(type='ReLU', inplace=True). - norm_cfg (dict): Config dict for normalization layer. - Default: `LN`. - ffn_num_fcs (int): The number of fully-connected layers in FFNs. - Default: 2. - """ - - def __init__( - self, - attn_cfgs, - feedforward_channels, - ffn_dropout=0.0, - operation_order=None, - act_cfg=dict(type="ReLU", inplace=True), - norm_cfg=dict(type="LN"), - ffn_num_fcs=2, - **kwargs, - ): - super(DetrTransformerDecoderLayer, self).__init__( - attn_cfgs=attn_cfgs, - feedforward_channels=feedforward_channels, - ffn_dropout=ffn_dropout, - operation_order=operation_order, - act_cfg=act_cfg, - norm_cfg=norm_cfg, - ffn_num_fcs=ffn_num_fcs, - **kwargs, - ) - assert len(operation_order) == 6 - assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"]) - - -@TRANSFORMER_LAYER_SEQUENCE.register_module() -class DetrTransformerEncoder(TransformerLayerSequence): - """TransformerEncoder of DETR. 
- - Args: - post_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. Only used when `self.pre_norm` is `True` - """ - - def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs): - super(DetrTransformerEncoder, self).__init__(*args, **kwargs) - if post_norm_cfg is not None: - self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None - else: - assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg" - self.post_norm = None - - def forward(self, *args, **kwargs): - """Forward function for `TransformerCoder`. - - Returns: - Tensor: forwarded results with shape [num_query, bs, embed_dims]. - """ - x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) - if self.post_norm is not None: - x = self.post_norm(x) - return x - - -@TRANSFORMER_LAYER_SEQUENCE.register_module() -class DetrTransformerDecoder(TransformerLayerSequence): - """Implements the decoder in DETR transformer. - - Args: - return_intermediate (bool): Whether to return intermediate outputs. - post_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. - """ - - def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs): - - super(DetrTransformerDecoder, self).__init__(*args, **kwargs) - self.return_intermediate = return_intermediate - if post_norm_cfg is not None: - self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] - else: - self.post_norm = None - - def forward(self, query, *args, **kwargs): - """Forward function for `TransformerDecoder`. - - Args: - query (Tensor): Input query with shape - `(num_query, bs, embed_dims)`. - - Returns: - Tensor: Results with shape [1, num_query, bs, embed_dims] when - return_intermediate is `False`, otherwise it has shape - [num_layers, num_query, bs, embed_dims]. - """ - if not self.return_intermediate: - x = super().forward(query, *args, **kwargs) - if self.post_norm: - x = self.post_norm(x)[None] - return x - - intermediate = [] - for layer in self.layers: - query = layer(query, *args, **kwargs) - if self.return_intermediate: - if self.post_norm is not None: - intermediate.append(self.post_norm(query)) - else: - intermediate.append(query) - return torch.stack(intermediate) - - -@TRANSFORMER.register_module() -class Transformer(BaseModule): - """Implements the DETR transformer. - - Following the official DETR implementation, this module copy-paste - from torch.nn.Transformer with modifications: - - * positional encodings are passed in MultiheadAttention - * extra LN at the end of encoder is removed - * decoder returns a stack of activations from all decoding layers - - See `paper: End-to-End Object Detection with Transformers - `_ for details. - - Args: - encoder (`mmcv.ConfigDict` | Dict): Config of - TransformerEncoder. Defaults to None. - decoder ((`mmcv.ConfigDict` | Dict)): Config of - TransformerDecoder. Defaults to None - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Defaults to None. 
- """ - - def __init__(self, encoder=None, decoder=None, init_cfg=None): - super(Transformer, self).__init__(init_cfg=init_cfg) - self.encoder = build_transformer_layer_sequence(encoder) - self.decoder = build_transformer_layer_sequence(decoder) - self.embed_dims = self.encoder.embed_dims - - def init_weights(self): - # follow the official DETR to init parameters - for m in self.modules(): - if hasattr(m, "weight") and m.weight.dim() > 1: - xavier_init(m, distribution="uniform") - self._is_init = True - - def forward(self, x, mask, query_embed, pos_embed): - """Forward function for `Transformer`. - - Args: - x (Tensor): Input query with shape [bs, c, h, w] where - c = embed_dims. - mask (Tensor): The key_padding_mask used for encoder and decoder, - with shape [bs, h, w]. - query_embed (Tensor): The query embedding for decoder, with shape - [num_query, c]. - pos_embed (Tensor): The positional encoding for encoder and - decoder, with the same shape as `x`. - - Returns: - tuple[Tensor]: results of decoder containing the following tensor. - - - out_dec: Output from decoder. If return_intermediate_dec \ - is True output has shape [num_dec_layers, bs, - num_query, embed_dims], else has shape [1, bs, \ - num_query, embed_dims]. - - memory: Output results from encoder, with shape \ - [bs, embed_dims, h, w]. - """ - bs, c, h, w = x.shape - # use `view` instead of `flatten` for dynamically exporting to ONNX - x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] - pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] - mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] - memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) - target = torch.zeros_like(query_embed) - # out_dec: [num_layers, num_query, bs, dim] - out_dec = self.decoder( - query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask - ) - out_dec = out_dec.transpose(1, 2) - memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) - return out_dec, memory - - -@TRANSFORMER_LAYER_SEQUENCE.register_module() -class DeformableDetrTransformerDecoder(TransformerLayerSequence): - """Implements the decoder in DETR transformer. - - Args: - return_intermediate (bool): Whether to return intermediate outputs. - coder_norm_cfg (dict): Config of last normalization layer. Default: - `LN`. - """ - - def __init__(self, *args, return_intermediate=False, **kwargs): - - super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) - self.return_intermediate = return_intermediate - - def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): - """Forward function for `TransformerDecoder`. - - Args: - query (Tensor): Input query with shape - `(num_query, bs, embed_dims)`. - reference_points (Tensor): The reference - points of offset. has shape - (bs, num_query, 4) when as_two_stage, - otherwise has shape ((bs, num_query, 2). - valid_ratios (Tensor): The radios of valid - points on the feature map, has shape - (bs, num_levels, 2) - reg_branch: (obj:`nn.ModuleList`): Used for - refining the regression results. Only would - be passed when with_box_refine is True, - otherwise would be passed a `None`. - - Returns: - Tensor: Results with shape [1, num_query, bs, embed_dims] when - return_intermediate is `False`, otherwise it has shape - [num_layers, num_query, bs, embed_dims]. 
- """ - output = query - intermediate = [] - intermediate_reference_points = [] - for lid, layer in enumerate(self.layers): - if reference_points.shape[-1] == 4: - reference_points_input = ( - reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - ) - else: - assert reference_points.shape[-1] == 2 - reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] - output = layer(output, *args, reference_points=reference_points_input, **kwargs) - output = output.permute(1, 0, 2) - - if reg_branches is not None: - tmp = reg_branches[lid](output) - if reference_points.shape[-1] == 4: - new_reference_points = tmp + inverse_sigmoid(reference_points) - new_reference_points = new_reference_points.sigmoid() - else: - assert reference_points.shape[-1] == 2 - new_reference_points = tmp - new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) - new_reference_points = new_reference_points.sigmoid() - reference_points = new_reference_points.detach() - - output = output.permute(1, 0, 2) - if self.return_intermediate: - intermediate.append(output) - intermediate_reference_points.append(reference_points) - - if self.return_intermediate: - return torch.stack(intermediate), torch.stack(intermediate_reference_points) - - return output, reference_points - - -@TRANSFORMER.register_module() -class DeformableDetrTransformer(Transformer): - """Implements the DeformableDETR transformer. - - Args: - as_two_stage (bool): Generate query from encoder features. - Default: False. - num_feature_levels (int): Number of feature maps from FPN: - Default: 4. - two_stage_num_proposals (int): Number of proposals when set - `as_two_stage` as True. Default: 300. - """ - - def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): - super(DeformableDetrTransformer, self).__init__(**kwargs) - self.as_two_stage = as_two_stage - self.num_feature_levels = num_feature_levels - self.two_stage_num_proposals = two_stage_num_proposals - self.embed_dims = self.encoder.embed_dims - self.init_layers() - - def init_layers(self): - """Initialize layers of the DeformableDetrTransformer.""" - self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims)) - - if self.as_two_stage: - self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) - self.enc_output_norm = nn.LayerNorm(self.embed_dims) - self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) - self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) - else: - self.reference_points = nn.Linear(self.embed_dims, 2) - - def init_weights(self): - """Initialize the transformer weights.""" - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - for m in self.modules(): - if isinstance(m, MultiScaleDeformableAttention): - m.init_weights() - if not self.as_two_stage: - xavier_init(self.reference_points, distribution="uniform", bias=0.0) - normal_(self.level_embeds) - - def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): - """Generate proposals from encoded memory. - - Args: - memory (Tensor) : The output of encoder, - has shape (bs, num_key, embed_dim). num_key is - equal the number of points on feature map from - all level. - memory_padding_mask (Tensor): Padding mask for memory. - has shape (bs, num_key). - spatial_shapes (Tensor): The shape of all feature maps. - has shape (num_level, 2). - - Returns: - tuple: A tuple of feature map and bbox prediction. 
- - - output_memory (Tensor): The input of decoder, \ has shape (bs, num_key, embed_dim). num_key is \ equal to the number of points on the feature maps from \ all levels. - - output_proposals (Tensor): The normalized proposal \ after an inverse sigmoid, has shape \ (bs, num_keys, 4). - """ - - N, S, C = memory.shape - proposals = [] - _cur = 0 - for lvl, (H, W) in enumerate(spatial_shapes): - mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1) - valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) - valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) - - grid_y, grid_x = torch.meshgrid( - torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device), - torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device), - ) - grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) - - scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) - grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale - wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) - proposal = torch.cat((grid, wh), -1).view(N, -1, 4) - proposals.append(proposal) - _cur += H * W - output_proposals = torch.cat(proposals, 1) - output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) - output_proposals = torch.log(output_proposals / (1 - output_proposals)) - output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) - output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) - - output_memory = memory - output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) - output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) - output_memory = self.enc_output_norm(self.enc_output(output_memory)) - return output_memory, output_proposals - - @staticmethod - def get_reference_points(spatial_shapes, valid_ratios, device): - """Get the reference points used in decoder. - - Args: - spatial_shapes (Tensor): The shape of all - feature maps, has shape (num_level, 2). - valid_ratios (Tensor): The ratios of valid - points on the feature map, has shape - (bs, num_levels, 2) - device (obj:`device`): The device where - reference_points should be. - - Returns: - Tensor: reference points used in decoder, has \ shape (bs, num_keys, num_levels, 2). 
- """ - reference_points_list = [] - for lvl, (H, W) in enumerate(spatial_shapes): - ref_y, ref_x = torch.meshgrid( - torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), - torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device), - ) - ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H) - ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W) - ref = torch.stack((ref_x, ref_y), -1) - reference_points_list.append(ref) - reference_points = torch.cat(reference_points_list, 1) - reference_points = reference_points[:, :, None] * valid_ratios[:, None] - return reference_points - - def get_valid_ratio(self, mask): - """Get the valid radios of feature maps of all level.""" - _, H, W = mask.shape - valid_H = torch.sum(~mask[:, :, 0], 1) - valid_W = torch.sum(~mask[:, 0, :], 1) - valid_ratio_h = valid_H.float() / H - valid_ratio_w = valid_W.float() / W - valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) - return valid_ratio - - def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): - """Get the position embedding of proposal.""" - scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) - dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) - # N, L, 4 - proposals = proposals.sigmoid() * scale - # N, L, 4, 128 - pos = proposals[:, :, :, None] / dim_t - # N, L, 4, 64, 2 - pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) - return pos - - def forward( - self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs - ): - """Forward function for `Transformer`. - - Args: - mlvl_feats (list(Tensor)): Input queries from - different level. Each element has shape - [bs, embed_dims, h, w]. - mlvl_masks (list(Tensor)): The key_padding_mask from - different level used for encoder and decoder, - each element has shape [bs, h, w]. - query_embed (Tensor): The query embedding for decoder, - with shape [num_query, c]. - mlvl_pos_embeds (list(Tensor)): The positional encoding - of feats from different level, has the shape - [bs, embed_dims, h, w]. - reg_branches (obj:`nn.ModuleList`): Regression heads for - feature maps from each decoder layer. Only would - be passed when - `with_box_refine` is True. Default to None. - cls_branches (obj:`nn.ModuleList`): Classification heads - for feature maps from each decoder layer. Only would - be passed when `as_two_stage` - is True. Default to None. - - - Returns: - tuple[Tensor]: results of decoder containing the following tensor. - - - inter_states: Outputs from decoder. If - return_intermediate_dec is True output has shape \ - (num_dec_layers, bs, num_query, embed_dims), else has \ - shape (1, bs, num_query, embed_dims). - - init_reference_out: The initial value of reference \ - points, has shape (bs, num_queries, 4). - - inter_references_out: The internal value of reference \ - points in decoder, has shape \ - (num_dec_layers, bs,num_query, embed_dims) - - enc_outputs_class: The classification score of \ - proposals generated from \ - encoder's feature maps, has shape \ - (batch, h*w, num_classes). \ - Only would be returned when `as_two_stage` is True, \ - otherwise None. - - enc_outputs_coord_unact: The regression results \ - generated from encoder's feature maps., has shape \ - (batch, h*w, 4). Only would \ - be returned when `as_two_stage` is True, \ - otherwise None. 
- """ - assert self.as_two_stage or query_embed is not None - - feat_flatten = [] - mask_flatten = [] - lvl_pos_embed_flatten = [] - spatial_shapes = [] - for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): - bs, c, h, w = feat.shape - spatial_shape = (h, w) - spatial_shapes.append(spatial_shape) - feat = feat.flatten(2).transpose(1, 2) - mask = mask.flatten(1) - pos_embed = pos_embed.flatten(2).transpose(1, 2) - lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) - lvl_pos_embed_flatten.append(lvl_pos_embed) - feat_flatten.append(feat) - mask_flatten.append(mask) - feat_flatten = torch.cat(feat_flatten, 1) - mask_flatten = torch.cat(mask_flatten, 1) - lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) - level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) - - reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) - - feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) - lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) - memory = self.encoder( - query=feat_flatten, - key=None, - value=None, - query_pos=lvl_pos_embed_flatten, - query_key_padding_mask=mask_flatten, - spatial_shapes=spatial_shapes, - reference_points=reference_points, - level_start_index=level_start_index, - valid_ratios=valid_ratios, - **kwargs, - ) - - memory = memory.permute(1, 0, 2) - bs, _, c = memory.shape - if self.as_two_stage: - output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) - enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) - enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals - - topk = self.two_stage_num_proposals - topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] - topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) - topk_coords_unact = topk_coords_unact.detach() - reference_points = topk_coords_unact.sigmoid() - init_reference_out = reference_points - pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) - query_pos, query = torch.split(pos_trans_out, c, dim=2) - else: - query_pos, query = torch.split(query_embed, c, dim=1) - query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) - query = query.unsqueeze(0).expand(bs, -1, -1) - reference_points = self.reference_points(query_pos).sigmoid() - init_reference_out = reference_points - - # decoder - query = query.permute(1, 0, 2) - memory = memory.permute(1, 0, 2) - query_pos = query_pos.permute(1, 0, 2) - inter_states, inter_references = self.decoder( - query=query, - key=None, - value=memory, - query_pos=query_pos, - key_padding_mask=mask_flatten, - reference_points=reference_points, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - valid_ratios=valid_ratios, - reg_branches=reg_branches, - **kwargs, - ) - - inter_references_out = inter_references - if self.as_two_stage: - return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact - return inter_states, init_reference_out, inter_references_out, None, None - - -@TRANSFORMER.register_module() -class DynamicConv(BaseModule): 
- """Implements Dynamic Convolution. - - This module generate parameters for each sample and - use bmm to implement 1*1 convolution. Code is modified - from the `official github repo `_ . - - Args: - in_channels (int): The input feature channel. - Defaults to 256. - feat_channels (int): The inner feature channel. - Defaults to 64. - out_channels (int, optional): The output feature channel. - When not specified, it will be set to `in_channels` - by default - input_feat_shape (int): The shape of input feature. - Defaults to 7. - with_proj (bool): Project two-dimentional feature to - one-dimentional feature. Default to True. - act_cfg (dict): The activation config for DynamicConv. - norm_cfg (dict): Config dict for normalization layer. Default - layer normalization. - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - """ - - def __init__( - self, - in_channels=256, - feat_channels=64, - out_channels=None, - input_feat_shape=7, - with_proj=True, - act_cfg=dict(type="ReLU", inplace=True), - norm_cfg=dict(type="LN"), - init_cfg=None, - ): - super(DynamicConv, self).__init__(init_cfg) - self.in_channels = in_channels - self.feat_channels = feat_channels - self.out_channels_raw = out_channels - self.input_feat_shape = input_feat_shape - self.with_proj = with_proj - self.act_cfg = act_cfg - self.norm_cfg = norm_cfg - self.out_channels = out_channels if out_channels else in_channels - - self.num_params_in = self.in_channels * self.feat_channels - self.num_params_out = self.out_channels * self.feat_channels - self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out) - - self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] - self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] - - self.activation = build_activation_layer(act_cfg) - - num_output = self.out_channels * input_feat_shape**2 - if self.with_proj: - self.fc_layer = nn.Linear(num_output, self.out_channels) - self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] - - def forward(self, param_feature, input_feature): - """Forward function for `DynamicConv`. - - Args: - param_feature (Tensor): The feature can be used - to generate the parameter, has shape - (num_all_proposals, in_channels). - input_feature (Tensor): Feature that - interact with parameters, has shape - (num_all_proposals, in_channels, H, W). - - Returns: - Tensor: The output feature has shape - (num_all_proposals, out_channels). 
- """ - input_feature = input_feature.flatten(2).permute(2, 0, 1) - - input_feature = input_feature.permute(1, 0, 2) - parameters = self.dynamic_layer(param_feature) - - param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels) - param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels) - - # input_feature has shape (num_all_proposals, H*W, in_channels) - # param_in has shape (num_all_proposals, in_channels, feat_channels) - # feature has shape (num_all_proposals, H*W, feat_channels) - features = torch.bmm(input_feature, param_in) - features = self.norm_in(features) - features = self.activation(features) - - # param_out has shape (batch_size, feat_channels, out_channels) - features = torch.bmm(features, param_out) - features = self.norm_out(features) - features = self.activation(features) - - if self.with_proj: - features = features.flatten(1) - features = self.fc_layer(features) - features = self.fc_norm(features) - features = self.activation(features) - - return features diff --git a/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/dinov2/eval/segmentation_m2f/ops/modules/__init__.py deleted file mode 100644 index 49aa8fe612fd4c088e294707c5ee16bd1cb5b5e7..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/ops/modules/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules -# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 - -from .ms_deform_attn import MSDeformAttn diff --git a/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py deleted file mode 100644 index d8b4fa23712e87d1a2682b57e71ee37fe8524cff..0000000000000000000000000000000000000000 --- a/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -import math -import warnings - -import torch -import torch.nn.functional as F -from torch import nn -from torch.autograd import Function -from torch.cuda.amp import custom_fwd -from torch.nn.init import constant_, xavier_uniform_ - - -class MSDeformAttnFunction(Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward( - ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step - ): - output = ms_deform_attn_core_pytorch( - value, - value_spatial_shapes, - # value_level_start_index, - sampling_locations, - attention_weights, - ) - return output - - -def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): - # for debug and test only, - # need to use cuda version instead - N_, S_, M_, D_ = value.shape - _, Lq_, M_, L_, P_, _ = sampling_locations.shape - value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 - sampling_value_list = [] - for lid_, (H_, W_) in enumerate(value_spatial_shapes): - # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ - value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) - # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 - sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) - # N_*M_, D_, Lq_, P_ - sampling_value_l_ = F.grid_sample( - value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False - ) - sampling_value_list.append(sampling_value_l_) - # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) - attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) - output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) - return output.transpose(1, 2).contiguous() - - -def _is_power_of_2(n): - if (not isinstance(n, int)) or (n < 0): - raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) - return (n & (n - 1) == 0) and n != 0 - - -class MSDeformAttn(nn.Module): - def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): - """Multi-Scale Deformable Attention Module. - - :param d_model hidden dimension - :param n_levels number of feature levels - :param n_heads number of attention heads - :param n_points number of sampling points per attention head per feature level - """ - super().__init__() - if d_model % n_heads != 0: - raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads)) - _d_per_head = d_model // n_heads - # you'd better set _d_per_head to a power of 2 - # which is more efficient in our CUDA implementation - if not _is_power_of_2(_d_per_head): - warnings.warn( - "You'd better set d_model in MSDeformAttn to make " - "the dimension of each attention head a power of 2 " - "which is more efficient in our CUDA implementation." 
- ) - - self.im2col_step = 64 - - self.d_model = d_model - self.n_levels = n_levels - self.n_heads = n_heads - self.n_points = n_points - self.ratio = ratio - self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) - self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) - self.value_proj = nn.Linear(d_model, int(d_model * ratio)) - self.output_proj = nn.Linear(int(d_model * ratio), d_model) - - self._reset_parameters() - - def _reset_parameters(self): - constant_(self.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - constant_(self.attention_weights.weight.data, 0.0) - constant_(self.attention_weights.bias.data, 0.0) - xavier_uniform_(self.value_proj.weight.data) - constant_(self.value_proj.bias.data, 0.0) - xavier_uniform_(self.output_proj.weight.data) - constant_(self.output_proj.bias.data, 0.0) - - def forward( - self, - query, - reference_points, - input_flatten, - input_spatial_shapes, - input_level_start_index, - input_padding_mask=None, - ): - """ - :param query (N, Length_{query}, C) - :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area - or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes - :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) - :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] - :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] - :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements - - :return output (N, Length_{query}, C) - """ - # print(query.shape) - # print(reference_points.shape) - # print(input_flatten.shape) - # print(input_spatial_shapes.shape) - # print(input_level_start_index.shape) - # print(input_spatial_shapes) - # print(input_level_start_index) - - N, Len_q, _ = query.shape - N, Len_in, _ = input_flatten.shape - assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in - - value = self.value_proj(input_flatten) - if input_padding_mask is not None: - value = value.masked_fill(input_padding_mask[..., None], float(0)) - - value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads) - sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) - attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) - attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) - - if reference_points.shape[-1] == 2: - offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) - sampling_locations = ( - reference_points[:, :, None, :, None, :] - + sampling_offsets / offset_normalizer[None, None, None, :, None, :] - ) - elif reference_points.shape[-1] == 4: - sampling_locations = ( - reference_points[:, :, None, :, None, 
:2] - + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 - ) - else: - raise ValueError( - "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) - ) - output = MSDeformAttnFunction.apply( - value, - input_spatial_shapes, - input_level_start_index, - sampling_locations, - attention_weights, - self.im2col_step, - ) - output = self.output_proj(output) - return output diff --git a/dinov2/eval/setup.py b/dinov2/eval/setup.py deleted file mode 100644 index 959128c0673cc51036dbf17dcc4ee68a037988fb..0000000000000000000000000000000000000000 --- a/dinov2/eval/setup.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import argparse -from typing import Any, List, Optional, Tuple - -import torch -import torch.backends.cudnn as cudnn - -from dinov2.models import build_model_from_cfg -from dinov2.utils.config import setup -import dinov2.utils.utils as dinov2_utils - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -): - parser = argparse.ArgumentParser( - description=description, - parents=parents or [], - add_help=add_help, - ) - parser.add_argument( - "--config-file", - type=str, - help="Model configuration file", - ) - parser.add_argument( - "--pretrained-weights", - type=str, - help="Pretrained model weights", - ) - parser.add_argument( - "--output-dir", - default="", - type=str, - help="Output directory to write results and logs", - ) - parser.add_argument( - "--opts", - help="Extra configuration options", - default=[], - nargs="+", - ) - return parser - - -def get_autocast_dtype(config): - teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype - if teacher_dtype_str == "fp16": - return torch.half - elif teacher_dtype_str == "bf16": - return torch.bfloat16 - else: - return torch.float - - -def build_model_for_eval(config, pretrained_weights): - model, _ = build_model_from_cfg(config, only_teacher=True) - dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") - model.eval() - model.cuda() - return model - - -def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: - cudnn.benchmark = True - config = setup(args) - model = build_model_for_eval(config, args.pretrained_weights) - autocast_dtype = get_autocast_dtype(config) - return model, autocast_dtype diff --git a/dinov2/eval/utils.py b/dinov2/eval/utils.py deleted file mode 100644 index c50576b1940587ee64b7a422e2e96b475d60fd39..0000000000000000000000000000000000000000 --- a/dinov2/eval/utils.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
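
To make the shape gymnastics in MSDeformAttn.forward above concrete, here is a self-contained sketch of the 2-d reference-point branch with assumed toy sizes; it reproduces only the sampling-location arithmetic, not the full module:

    import torch

    N, Len_q, n_heads, n_levels, n_points = 2, 10, 8, 4, 4
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]])   # (H, W) per level
    reference_points = torch.rand(N, Len_q, n_levels, 2)                  # normalized to [0, 1]
    sampling_offsets = torch.randn(N, Len_q, n_heads, n_levels, n_points, 2)

    # offsets are expressed in pixels of each level, so divide by (W, H) per level
    offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    sampling_locations = (
        reference_points[:, :, None, :, None, :]
        + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
    )
    print(sampling_locations.shape)  # (2, 10, 8, 4, 4, 2), still roughly in [0, 1]

These normalized locations are then mapped to [-1, 1] (sampling_grids = 2 * locations - 1) before F.grid_sample, as in ms_deform_attn_core_pytorch above.
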
- -import logging -from typing import Dict, Optional - -import torch -from torch import nn -from torchmetrics import MetricCollection - -from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader -import dinov2.distributed as distributed -from dinov2.logging import MetricLogger - - -logger = logging.getLogger("dinov2") - - -class ModelWithNormalize(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, samples): - return nn.functional.normalize(self.model(samples), dim=1, p=2) - - -class ModelWithIntermediateLayers(nn.Module): - def __init__(self, feature_model, n_last_blocks, autocast_ctx): - super().__init__() - self.feature_model = feature_model - self.feature_model.eval() - self.n_last_blocks = n_last_blocks - self.autocast_ctx = autocast_ctx - - def forward(self, images): - with torch.inference_mode(): - with self.autocast_ctx(): - features = self.feature_model.get_intermediate_layers( - images, self.n_last_blocks, return_class_token=True - ) - return features - - -@torch.inference_mode() -def evaluate( - model: nn.Module, - data_loader, - postprocessors: Dict[str, nn.Module], - metrics: Dict[str, MetricCollection], - device: torch.device, - criterion: Optional[nn.Module] = None, -): - model.eval() - if criterion is not None: - criterion.eval() - - for metric in metrics.values(): - metric = metric.to(device) - - metric_logger = MetricLogger(delimiter=" ") - header = "Test:" - - for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): - outputs = model(samples.to(device)) - targets = targets.to(device) - - if criterion is not None: - loss = criterion(outputs, targets) - metric_logger.update(loss=loss.item()) - - for k, metric in metrics.items(): - metric_inputs = postprocessors[k](outputs, targets) - metric.update(**metric_inputs) - - metric_logger.synchronize_between_processes() - logger.info(f"Averaged stats: {metric_logger}") - - stats = {k: metric.compute() for k, metric in metrics.items()} - metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} - return metric_logger_stats, stats - - -def all_gather_and_flatten(tensor_rank): - tensor_all_ranks = torch.empty( - distributed.get_global_size(), - *tensor_rank.shape, - dtype=tensor_rank.dtype, - device=tensor_rank.device, - ) - tensor_list = list(tensor_all_ranks.unbind(0)) - torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) - return tensor_all_ranks.flatten(end_dim=1) - - -def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): - dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) - sample_count = len(dataset_with_enumerated_targets) - data_loader = make_data_loader( - dataset=dataset_with_enumerated_targets, - batch_size=batch_size, - num_workers=num_workers, - sampler_type=SamplerType.DISTRIBUTED, - drop_last=False, - shuffle=False, - ) - return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) - - -@torch.inference_mode() -def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): - gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") - metric_logger = MetricLogger(delimiter=" ") - features, all_labels = None, None - for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): - samples = samples.cuda(non_blocking=True) - labels_rank = labels_rank.cuda(non_blocking=True) - index = index.cuda(non_blocking=True) - 
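-        # each rank embeds its own shard here; features are cast to float32 before the cross-rank gather below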
features_rank = model(samples).float() - - # init storage feature matrix - if features is None: - features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) - labels_shape = list(labels_rank.shape) - labels_shape[0] = sample_count - all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) - logger.info(f"Storing features into tensor of shape {features.shape}") - - # share indexes, features and labels between processes - index_all = all_gather_and_flatten(index).to(gather_device) - features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) - labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) - - # update storage feature matrix - if len(index_all) > 0: - features.index_copy_(0, index_all, features_all_ranks) - all_labels.index_copy_(0, index_all, labels_all_ranks) - - logger.info(f"Features shape: {tuple(features.shape)}") - logger.info(f"Labels shape: {tuple(all_labels.shape)}") - - assert torch.all(all_labels > -1) - - return features, all_labels diff --git a/dinov2/fsdp/__init__.py b/dinov2/fsdp/__init__.py deleted file mode 100644 index ed454480e0b76e761d657cc40fd097bd339d15a2..0000000000000000000000000000000000000000 --- a/dinov2/fsdp/__init__.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import os -from typing import Any - -import torch -import dinov2.distributed as distributed -from functools import partial -from fvcore.common.checkpoint import Checkpointer -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.distributed.fsdp._runtime_utils import _reshard - - -def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): - sharding_strategy_dict = { - "NO_SHARD": ShardingStrategy.NO_SHARD, - "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, - "FULL_SHARD": ShardingStrategy.FULL_SHARD, - } - - dtype_dict = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, - } - - mixed_precision_config = MixedPrecision( - param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], - reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], - buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], - ) - - sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] - - local_rank = distributed.get_local_rank() - - fsdp_wrapper = partial( - FSDP, - sharding_strategy=sharding_strategy_config, - mixed_precision=mixed_precision_config, - device_id=local_rank, - sync_module_states=True, - use_orig_params=True, - auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), - ) - return fsdp_wrapper - - -def is_fsdp(x): - return isinstance(x, FSDP) - - -def is_sharded_fsdp(x): - return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD - - -def free_if_fsdp(x): - if is_sharded_fsdp(x): - handles = x._handles - true_list = [True for h in handles] - _reshard(x, handles, true_list) - - -def get_fsdp_modules(x): - return FSDP.fsdp_modules(x) - - -def reshard_fsdp_model(x): - for m in get_fsdp_modules(x): - free_if_fsdp(m) - - -def rankstr(): - return 
f"rank_{distributed.get_global_rank()}" - - -class FSDPCheckpointer(Checkpointer): - def save(self, name: str, **kwargs: Any) -> None: - """ - Dump model and checkpointables to a file. - - Args: - name (str): name of the file. - kwargs (dict): extra arbitrary data to save. - """ - if not self.save_dir or not self.save_to_disk: - return - - data = {} - with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): - data["model"] = self.model.state_dict() - - # data["model"] = self.model.state_dict() - for key, obj in self.checkpointables.items(): - data[key] = obj.state_dict() - data.update(kwargs) - - basename = f"{name}.{rankstr()}.pth" - save_file = os.path.join(self.save_dir, basename) - assert os.path.basename(save_file) == basename, basename - self.logger.info("Saving checkpoint to {}".format(save_file)) - with self.path_manager.open(save_file, "wb") as f: - torch.save(data, f) - self.tag_last_checkpoint(basename) - - def load(self, *args, **kwargs): - with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): - return super().load(*args, **kwargs) - - def has_checkpoint(self) -> bool: - """ - Returns: - bool: whether a checkpoint exists in the target directory. - """ - save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") - return self.path_manager.exists(save_file) - - def get_checkpoint_file(self) -> str: - """ - Returns: - str: The latest checkpoint file in target directory. - """ - save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") - try: - with self.path_manager.open(save_file, "r") as f: - last_saved = f.read().strip() - except IOError: - # if file doesn't exist, maybe because it has just been - # deleted by a separate process - return "" - # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got - # `Union[bytes, str]`. - return os.path.join(self.save_dir, last_saved) - - def tag_last_checkpoint(self, last_filename_basename: str) -> None: - """ - Tag the last checkpoint. - - Args: - last_filename_basename (str): the basename of the last filename. - """ - if distributed.is_enabled(): - torch.distributed.barrier() - save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") - with self.path_manager.open(save_file, "w") as f: - f.write(last_filename_basename) # pyre-ignore - - -ShardedGradScaler = ShardedGradScaler diff --git a/dinov2/hub/__init__.py b/dinov2/hub/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/hub/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py deleted file mode 100644 index 9204dc6296973046d2e42e6b33ac21ede31eb31b..0000000000000000000000000000000000000000 --- a/dinov2/hub/backbones.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -from enum import Enum -from pathlib import Path -from typing import Optional, Union -from urllib.parse import urlparse - -import torch - -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name - - -class Weights(Enum): - LVD142M = "LVD142M" - XRAY_DINO = "XRay-DINO" - - -def is_url(path: str) -> bool: - parsed = urlparse(path) - return parsed.scheme in ("https", "file") - - -def convert_path_or_url_to_url(path: str) -> str: - if is_url(path): - return path - return Path(path).expanduser().resolve().as_uri() - - -def _make_dinov2_model( - *, - arch_name: str = "vit_large", - img_size: int = 518, - patch_size: int = 14, - init_values: float = 1.0, - ffn_layer: str = "mlp", - block_chunks: int = 0, - num_register_tokens: int = 0, - interpolate_antialias: bool = False, - interpolate_offset: float = 0.1, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.LVD142M, - hash: Optional[str] = None, - check_hash: bool = False, - **kwargs, -): - from ..models import vision_transformer as vits - - model_base_name = _make_dinov2_model_name(arch_name, patch_size) - vit_kwargs = dict( - img_size=img_size, - patch_size=patch_size, - init_values=init_values, - ffn_layer=ffn_layer, - block_chunks=block_chunks, - num_register_tokens=num_register_tokens, - interpolate_antialias=interpolate_antialias, - interpolate_offset=interpolate_offset, - ) - vit_kwargs.update(**kwargs) - model = vits.__dict__[arch_name](**vit_kwargs) - - if pretrained: - if type(weights) is Weights and weights not in { - Weights.LVD142M, - Weights.XRAY_DINO, - }: - raise ValueError(f"Unsupported weights for the backbone: {weights}") - elif type(weights) is Weights: - model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) - url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" - else: - url = convert_path_or_url_to_url(weights) - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) - model.load_state_dict(state_dict, strict=True) - - return model - - -def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model( - arch_name="vit_giant2", - ffn_layer="swiglufused", - weights=weights, - pretrained=pretrained, - **kwargs, - ) - - -def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
- """ - return _make_dinov2_model( - arch_name="vit_small", - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model( - arch_name="vit_base", - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model( - arch_name="vit_large", - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): - """ - DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. - """ - return _make_dinov2_model( - arch_name="vit_giant2", - ffn_layer="swiglufused", - weights=weights, - pretrained=pretrained, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) diff --git a/dinov2/hub/cell_dino/backbones.py b/dinov2/hub/cell_dino/backbones.py deleted file mode 100644 index 0e6b90b0204cadcd20de1363d332d71d2a68e544..0000000000000000000000000000000000000000 --- a/dinov2/hub/cell_dino/backbones.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
-
-from enum import Enum
-from typing import Optional, Union
-
-import torch
-
-
-class Weights(Enum):
-    CELL_DINO = "CELL-DINO"
-
-
-def _make_cell_dino_model(
-    *,
-    arch_name: str = "vit_large",
-    img_size: int = 518,
-    patch_size: int = 14,
-    init_values: float = 1.0,
-    ffn_layer: str = "mlp",
-    block_chunks: int = 0,
-    num_register_tokens: int = 0,
-    interpolate_antialias: bool = False,
-    interpolate_offset: float = 0.1,
-    pretrained: bool = True,
-    channel_adaptive: bool = False,
-    weights: Union[Weights, str] = Weights.CELL_DINO,
-    pretrained_url: Optional[str] = None,
-    pretrained_path: Optional[str] = None,
-    **kwargs,
-):
-    from ...models import vision_transformer as vits
-
-    if isinstance(weights, str):
-        try:
-            weights = Weights[weights]
-        except KeyError:
-            raise AssertionError(f"Unsupported weights: {weights}")
-
-    vit_kwargs = dict(
-        img_size=img_size,
-        patch_size=patch_size,
-        init_values=init_values,
-        ffn_layer=ffn_layer,
-        block_chunks=block_chunks,
-        num_register_tokens=num_register_tokens,
-        interpolate_antialias=interpolate_antialias,
-        interpolate_offset=interpolate_offset,
-        channel_adaptive=channel_adaptive,
-    )
-    vit_kwargs.update(**kwargs)
-    model = vits.__dict__[arch_name](**vit_kwargs)
-
-    if pretrained:
-        if pretrained_path is not None:
-            state_dict = torch.load(pretrained_path, map_location="cpu")
-        else:
-            assert pretrained_url is not None
-            state_dict = torch.hub.load_state_dict_from_url(pretrained_url, map_location="cpu")
-        model.load_state_dict(state_dict, strict=True)
-
-    return model
-
-
-def cell_dino_hpa_vitl16(
-    *,
-    pretrained_url: Optional[str] = None,
-    pretrained_path: Optional[str] = None,
-    pretrained: bool = True,
-    weights: Union[Weights, str] = Weights.CELL_DINO,
-    in_channels: int = 4,
-    **kwargs,
-):
-    """
-    Cell-DINO ViT-L/16 model pretrained on the HPA dataset.
-    """
-    return _make_cell_dino_model(
-        arch_name="vit_large",
-        patch_size=16,
-        img_size=224,
-        num_register_tokens=0,
-        interpolate_antialias=False,
-        interpolate_offset=0.1,
-        block_chunks=4,
-        pretrained_url=pretrained_url,
-        pretrained_path=pretrained_path,
-        pretrained=pretrained,
-        weights=weights,
-        in_chans=in_channels,
-        **kwargs,
-    )
-
-
-def cell_dino_hpa_vitl14(
-    *,
-    pretrained_url: Optional[str] = None,
-    pretrained_path: Optional[str] = None,
-    pretrained: bool = True,
-    weights: Union[Weights, str] = Weights.CELL_DINO,
-    in_channels: int = 4,
-    **kwargs,
-):
-    """
-    Cell-DINO ViT-L/14 model pretrained on LVD, then on the HPA dataset.
-    """
-    return _make_cell_dino_model(
-        arch_name="vit_large",
-        patch_size=14,
-        img_size=518,
-        num_register_tokens=0,
-        interpolate_antialias=False,
-        interpolate_offset=0.1,
-        block_chunks=4,
-        pretrained_url=pretrained_url,
-        pretrained_path=pretrained_path,
-        pretrained=pretrained,
-        weights=weights,
-        in_chans=in_channels,
-        **kwargs,
-    )
-
-
-def cell_dino_cp_vits8(
-    *,
-    pretrained_url: Optional[str] = None,
-    pretrained_path: Optional[str] = None,
-    pretrained: bool = True,
-    weights: Union[Weights, str] = Weights.CELL_DINO,
-    in_channels: int = 5,
-    **kwargs,
-):
-    """
-    Cell-DINO ViT-S/8 model pretrained on the combined cell painting dataset.
- """ - return _make_cell_dino_model( - arch_name="vit_small", - patch_size=8, - img_size=128, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - block_chunks=4, - pretrained_url=pretrained_url, - pretrained_path=pretrained_path, - pretrained=pretrained, - weights=weights, - in_chans=in_channels, - **kwargs, - ) - - -def channel_adaptive_dino_vitl16( - *, - pretrained_url: Optional[str] = None, - pretrained_path: Optional[str] = None, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.CELL_DINO, - in_channels: int = 1, - channel_adaptive: bool = True, - **kwargs, -): - """ - Cell-DINO ViT-L/16 model dataset pretrained on HPA dataset. - """ - return _make_cell_dino_model( - arch_name="vit_large", - patch_size=16, - img_size=224, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - block_chunks=4, - pretrained_url=pretrained_url, - pretrained_path=pretrained_path, - pretrained=pretrained, - weights=weights, - in_chans=in_channels, - channel_adaptive=channel_adaptive, - **kwargs, - ) diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py deleted file mode 100644 index 3f0841efa80ab3d564cd320d61da254af182606b..0000000000000000000000000000000000000000 --- a/dinov2/hub/classifiers.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from enum import Enum -from typing import Union - -import torch -import torch.nn as nn - -from .backbones import _make_dinov2_model -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name - - -class Weights(Enum): - IMAGENET1K = "IMAGENET1K" - - -def _make_dinov2_linear_classification_head( - *, - arch_name: str = "vit_large", - patch_size: int = 14, - embed_dim: int = 1024, - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - num_register_tokens: int = 0, - **kwargs, -): - if layers not in (1, 4): - raise AssertionError(f"Unsupported number of layers: {layers}") - if isinstance(weights, str): - try: - weights = Weights[weights] - except KeyError: - raise AssertionError(f"Unsupported weights: {weights}") - - linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) - - if pretrained: - model_base_name = _make_dinov2_model_name(arch_name, patch_size) - model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) - layers_str = str(layers) if layers == 4 else "" - url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - linear_head.load_state_dict(state_dict, strict=True) - - return linear_head - - -class _LinearClassifierWrapper(nn.Module): - def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): - super().__init__() - self.backbone = backbone - self.linear_head = linear_head - self.layers = layers - - def forward(self, x): - if self.layers == 1: - x = self.backbone.forward_features(x) - cls_token = x["x_norm_clstoken"] - patch_tokens = x["x_norm_patchtokens"] - # fmt: off - linear_input = torch.cat([ - cls_token, - patch_tokens.mean(dim=1), - ], dim=1) - # fmt: on - elif self.layers == 4: - x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) - # fmt: off - linear_input = torch.cat([ - x[0][1], - x[1][1], - x[2][1], - x[3][1], - x[3][0].mean(dim=1), - ], 
dim=1) - # fmt: on - else: - assert False, f"Unsupported number of layers: {self.layers}" - return self.linear_head(linear_input) - - -def _make_dinov2_linear_classifier( - *, - arch_name: str = "vit_large", - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - num_register_tokens: int = 0, - interpolate_antialias: bool = False, - interpolate_offset: float = 0.1, - **kwargs, -): - backbone = _make_dinov2_model( - arch_name=arch_name, - pretrained=pretrained, - num_register_tokens=num_register_tokens, - interpolate_antialias=interpolate_antialias, - interpolate_offset=interpolate_offset, - **kwargs, - ) - - embed_dim = backbone.embed_dim - patch_size = backbone.patch_size - linear_head = _make_dinov2_linear_classification_head( - arch_name=arch_name, - patch_size=patch_size, - embed_dim=embed_dim, - layers=layers, - pretrained=pretrained, - weights=weights, - num_register_tokens=num_register_tokens, - ) - - return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) - - -def dinov2_vits14_lc( - *, - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - **kwargs, -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_small", - layers=layers, - pretrained=pretrained, - weights=weights, - **kwargs, - ) - - -def dinov2_vitb14_lc( - *, - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - **kwargs, -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_base", - layers=layers, - pretrained=pretrained, - weights=weights, - **kwargs, - ) - - -def dinov2_vitl14_lc( - *, - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - **kwargs, -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_large", - layers=layers, - pretrained=pretrained, - weights=weights, - **kwargs, - ) - - -def dinov2_vitg14_lc( - *, - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.IMAGENET1K, - **kwargs, -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_giant2", - layers=layers, - ffn_layer="swiglufused", - pretrained=pretrained, - weights=weights, - **kwargs, - ) - - -def dinov2_vits14_reg_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
- """ - return _make_dinov2_linear_classifier( - arch_name="vit_small", - layers=layers, - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitb14_reg_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_base", - layers=layers, - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitl14_reg_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_large", - layers=layers, - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) - - -def dinov2_vitg14_reg_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs -): - """ - Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. - """ - return _make_dinov2_linear_classifier( - arch_name="vit_giant2", - layers=layers, - ffn_layer="swiglufused", - pretrained=pretrained, - weights=weights, - num_register_tokens=4, - interpolate_antialias=True, - interpolate_offset=0.0, - **kwargs, - ) diff --git a/dinov2/hub/depth/__init__.py b/dinov2/hub/depth/__init__.py deleted file mode 100644 index 91716e58ab6158d814df8c653644d9af4c7be65c..0000000000000000000000000000000000000000 --- a/dinov2/hub/depth/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .decode_heads import BNHead, DPTHead -from .encoder_decoder import DepthEncoderDecoder diff --git a/dinov2/hub/depth/decode_heads.py b/dinov2/hub/depth/decode_heads.py deleted file mode 100644 index f455accad38fec6ecdd53460233a564c34f434da..0000000000000000000000000000000000000000 --- a/dinov2/hub/depth/decode_heads.py +++ /dev/null @@ -1,747 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import copy -from functools import partial -import math -import warnings - -import torch -import torch.nn as nn - -from .ops import resize - - -# XXX: (Untested) replacement for mmcv.imdenormalize() -def _imdenormalize(img, mean, std, to_bgr=True): - import numpy as np - - mean = mean.reshape(1, -1).astype(np.float64) - std = std.reshape(1, -1).astype(np.float64) - img = (img * std) + mean - if to_bgr: - img = img[::-1] - return img - - -class DepthBaseDecodeHead(nn.Module): - """Base class for BaseDecodeHead. - - Args: - in_channels (List): Input channels. - channels (int): Channels after modules, before conv_depth. 
-        conv_layer (nn.Module): Conv layers. Default: None.
-        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
-        loss_decode (dict): Config of decode loss.
-            Default: ().
-        sampler (dict|None): The config of depth map sampler.
-            Default: None.
-        align_corners (bool): align_corners argument of F.interpolate.
-            Default: False.
-        min_depth (float): Min depth in dataset setting.
-            Default: 1e-3.
-        max_depth (float): Max depth in dataset setting.
-            Default: None.
-        norm_layer (dict|None): Norm layers.
-            Default: None.
-        classify (bool): Whether predict depth in a cls.-reg. manner.
-            Default: False.
-        n_bins (int): The number of bins used in cls. step.
-            Default: 256.
-        bins_strategy (str): The discrete strategy used in cls. step.
-            Default: 'UD'.
-        norm_strategy (str): The norm strategy on cls. probability
-            distribution. Default: 'linear'
-        scale_up (bool): Whether predict depth in a scale-up manner.
-            Default: False.
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        conv_layer=None,
-        act_layer=nn.ReLU,
-        channels=96,
-        loss_decode=(),
-        sampler=None,
-        align_corners=False,
-        min_depth=1e-3,
-        max_depth=None,
-        norm_layer=None,
-        classify=False,
-        n_bins=256,
-        bins_strategy="UD",
-        norm_strategy="linear",
-        scale_up=False,
-    ):
-        super(DepthBaseDecodeHead, self).__init__()
-
-        self.in_channels = in_channels
-        self.channels = channels
-        self.conv_layer = conv_layer
-        self.act_layer = act_layer
-        self.loss_decode = loss_decode
-        self.align_corners = align_corners
-        self.min_depth = min_depth
-        self.max_depth = max_depth
-        self.norm_layer = norm_layer
-        self.classify = classify
-        self.n_bins = n_bins
-        self.scale_up = scale_up
-
-        if self.classify:
-            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
-            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
-
-            self.bins_strategy = bins_strategy
-            self.norm_strategy = norm_strategy
-            self.softmax = nn.Softmax(dim=1)
-            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
-        else:
-            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
-
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, inputs, img_metas):
-        """Placeholder of forward function."""
-        pass
-
-    def forward_train(self, img, inputs, img_metas, depth_gt):
-        """Forward function for training.
-        Args:
-            inputs (list[Tensor]): List of multi-level img features.
-            img_metas (list[dict]): List of image info dict where each dict
-                has: 'img_shape', 'scale_factor', 'flip', and may also contain
-                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                `depth/datasets/pipelines/formatting.py:Collect`.
-            depth_gt (Tensor): GT depth
-
-        Returns:
-            dict[str, Tensor]: a dictionary of loss components
-        """
-        depth_pred = self.forward(inputs, img_metas)
-        losses = self.losses(depth_pred, depth_gt)
-
-        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
-        losses.update(**log_imgs)
-
-        return losses
-
-    def forward_test(self, inputs, img_metas):
-        """Forward function for testing.
-        Args:
-            inputs (list[Tensor]): List of multi-level img features.
-            img_metas (list[dict]): List of image info dict where each dict
-                has: 'img_shape', 'scale_factor', 'flip', and may also contain
-                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                `depth/datasets/pipelines/formatting.py:Collect`.
-
-        Returns:
-            Tensor: Output depth map.
- """ - return self.forward(inputs, img_metas) - - def depth_pred(self, feat): - """Prediction each pixel.""" - if self.classify: - logit = self.conv_depth(feat) - - if self.bins_strategy == "UD": - bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) - elif self.bins_strategy == "SID": - bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) - - # following Adabins, default linear - if self.norm_strategy == "linear": - logit = torch.relu(logit) - eps = 0.1 - logit = logit + eps - logit = logit / logit.sum(dim=1, keepdim=True) - elif self.norm_strategy == "softmax": - logit = torch.softmax(logit, dim=1) - elif self.norm_strategy == "sigmoid": - logit = torch.sigmoid(logit) - logit = logit / logit.sum(dim=1, keepdim=True) - - output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) - - else: - if self.scale_up: - output = self.sigmoid(self.conv_depth(feat)) * self.max_depth - else: - output = self.relu(self.conv_depth(feat)) + self.min_depth - return output - - def losses(self, depth_pred, depth_gt): - """Compute depth loss.""" - loss = dict() - depth_pred = resize( - input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False - ) - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - for loss_decode in losses_decode: - if loss_decode.loss_name not in loss: - loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) - else: - loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) - return loss - - def log_images(self, img_path, depth_pred, depth_gt, img_meta): - import numpy as np - - show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) - show_img = show_img.numpy().astype(np.float32) - show_img = _imdenormalize( - show_img, - img_meta["img_norm_cfg"]["mean"], - img_meta["img_norm_cfg"]["std"], - img_meta["img_norm_cfg"]["to_rgb"], - ) - show_img = np.clip(show_img, 0, 255) - show_img = show_img.astype(np.uint8) - show_img = show_img[:, :, ::-1] - show_img = show_img.transpose(0, 2, 1) - show_img = show_img.transpose(1, 0, 2) - - depth_pred = depth_pred / torch.max(depth_pred) - depth_gt = depth_gt / torch.max(depth_gt) - - depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) - depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) - - return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} - - -class BNHead(DepthBaseDecodeHead): - """Just a batchnorm.""" - - def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): - super().__init__(**kwargs) - self.input_transform = input_transform - self.in_index = in_index - self.upsample = upsample - # self.bn = nn.SyncBatchNorm(self.in_channels) - if self.classify: - self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) - else: - self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) - - def _transform_inputs(self, inputs): - """Transform inputs for decoder. - Args: - inputs (list[Tensor]): List of multi-level img features. 
- Returns: - Tensor: The transformed inputs - """ - - if "concat" in self.input_transform: - inputs = [inputs[i] for i in self.in_index] - if "resize" in self.input_transform: - inputs = [ - resize( - input=x, - size=[s * self.upsample for s in inputs[0].shape[2:]], - mode="bilinear", - align_corners=self.align_corners, - ) - for x in inputs - ] - inputs = torch.cat(inputs, dim=1) - elif self.input_transform == "multiple_select": - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - def _forward_feature(self, inputs, img_metas=None, **kwargs): - """Forward function for feature maps before classifying each pixel with - ``self.cls_seg`` fc. - Args: - inputs (list[Tensor]): List of multi-level img features. - Returns: - feats (Tensor): A tensor of shape (batch_size, self.channels, - H, W) which is feature map for last layer of decoder head. - """ - # accept lists (for cls token) - inputs = list(inputs) - for i, x in enumerate(inputs): - if len(x) == 2: - x, cls_token = x[0], x[1] - if len(x.shape) == 2: - x = x[:, :, None, None] - cls_token = cls_token[:, :, None, None].expand_as(x) - inputs[i] = torch.cat((x, cls_token), 1) - else: - x = x[0] - if len(x.shape) == 2: - x = x[:, :, None, None] - inputs[i] = x - x = self._transform_inputs(inputs) - # feats = self.bn(x) - return x - - def forward(self, inputs, img_metas=None, **kwargs): - """Forward function.""" - output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) - output = self.depth_pred(output) - return output - - -class ConvModule(nn.Module): - """A conv block that bundles conv/norm/activation layers. - - This block simplifies the usage of convolution layers, which are commonly - used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). - It is based upon three build methods: `build_conv_layer()`, - `build_norm_layer()` and `build_activation_layer()`. - - Besides, we add some additional features in this module. - 1. Automatically set `bias` of the conv layer. - 2. Spectral norm is supported. - 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only - supports zero and circular padding, and we add "reflect" padding mode. - - Args: - in_channels (int): Number of channels in the input feature map. - Same as that in ``nn._ConvNd``. - out_channels (int): Number of channels produced by the convolution. - Same as that in ``nn._ConvNd``. - kernel_size (int | tuple[int]): Size of the convolving kernel. - Same as that in ``nn._ConvNd``. - stride (int | tuple[int]): Stride of the convolution. - Same as that in ``nn._ConvNd``. - padding (int | tuple[int]): Zero-padding added to both sides of - the input. Same as that in ``nn._ConvNd``. - dilation (int | tuple[int]): Spacing between kernel elements. - Same as that in ``nn._ConvNd``. - groups (int): Number of blocked connections from input channels to - output channels. Same as that in ``nn._ConvNd``. - bias (bool | str): If specified as `auto`, it will be decided by the - norm_layer. Bias will be set as True if `norm_layer` is None, otherwise - False. Default: "auto". - conv_layer (nn.Module): Convolution layer. Default: None, - which means using conv2d. - norm_layer (nn.Module): Normalization layer. Default: None. - act_layer (nn.Module): Activation layer. Default: nn.ReLU. - inplace (bool): Whether to use inplace mode for activation. - Default: True. - with_spectral_norm (bool): Whether use spectral norm in conv module. - Default: False. 
- padding_mode (str): Padding mode for the conv layer. This trimmed - implementation only handles 'zeros' ('circular' falls back to zero - padding because the mode is not forwarded to the conv layer, and - any other mode raises an AssertionError). Default: 'zeros'. - order (tuple[str]): The order of conv/norm/activation layers. It is a - sequence of "conv", "norm" and "act". Common examples are - ("conv", "norm", "act") and ("act", "conv", "norm"). - Default: ('conv', 'norm', 'act'). - """ - - _abbr_ = "conv_block" - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias="auto", - conv_layer=nn.Conv2d, - norm_layer=None, - act_layer=nn.ReLU, - inplace=True, - with_spectral_norm=False, - padding_mode="zeros", - order=("conv", "norm", "act"), - ): - super(ConvModule, self).__init__() - official_padding_mode = ["zeros", "circular"] - self.conv_layer = conv_layer - self.norm_layer = norm_layer - self.act_layer = act_layer - self.inplace = inplace - self.with_spectral_norm = with_spectral_norm - self.with_explicit_padding = padding_mode not in official_padding_mode - self.order = order - assert isinstance(self.order, tuple) and len(self.order) == 3 - assert set(order) == {"conv", "norm", "act"} - - self.with_norm = norm_layer is not None - self.with_activation = act_layer is not None - # if the conv layer is before a norm layer, bias is unnecessary. - if bias == "auto": - bias = not self.with_norm - self.with_bias = bias - - if self.with_explicit_padding: - if padding_mode == "zeros": - padding_layer = nn.ZeroPad2d - else: - raise AssertionError(f"Unsupported padding mode: {padding_mode}") - self.pad = padding_layer(padding) - - # reset padding to 0 for conv module - conv_padding = 0 if self.with_explicit_padding else padding - # build convolution layer - self.conv = self.conv_layer( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=conv_padding, - dilation=dilation, - groups=groups, - bias=bias, - ) - # export the attributes of self.conv to a higher level for convenience - self.in_channels = self.conv.in_channels - self.out_channels = self.conv.out_channels - self.kernel_size = self.conv.kernel_size - self.stride = self.conv.stride - self.padding = padding - self.dilation = self.conv.dilation - self.transposed = self.conv.transposed - self.output_padding = self.conv.output_padding - self.groups = self.conv.groups - - if self.with_spectral_norm: - self.conv = nn.utils.spectral_norm(self.conv) - - # build normalization layers - if self.with_norm: - # norm layer is after conv layer - if order.index("norm") > order.index("conv"): - norm_channels = out_channels - else: - norm_channels = in_channels - # instantiate the norm layer (a partial cannot be registered as a - # submodule) and register it under a name that does not shadow the - # `norm` property below - norm = norm_layer(num_features=norm_channels) - self.norm_name = "_norm" - self.add_module(self.norm_name, norm) - if self.with_bias: - from torch.nn.modules.batchnorm import _BatchNorm - from torch.nn.modules.instancenorm import _InstanceNorm - - if isinstance(norm, (_BatchNorm, _InstanceNorm)): - warnings.warn("Unnecessary conv bias before batch/instance norm") - else: - self.norm_name = None - - # build activation layer - if self.with_activation: - # nn.Tanh has no 'inplace' argument - # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) - # act_layer is a class here, so compare classes rather than instances - if act_layer not in (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU): - act_layer = partial(act_layer, inplace=inplace) - self.activate = act_layer() - - # Use msra init by default - self.init_weights() - - @property - def norm(self): - if self.norm_name: -
return getattr(self, self.norm_name) - else: - return None - - def init_weights(self): - # 1. It is mainly for customized conv layers with their own - # initialization manners by calling their own ``init_weights()``, - # and we do not want ConvModule to override the initialization. - # 2. For customized conv layers without their own initialization - # manners (that is, they don't have their own ``init_weights()``) - # and PyTorch's conv layers, they will be initialized by - # this method with default ``kaiming_init``. - # Note: For PyTorch's conv layers, they will be overwritten by our - # initialization implementation using default ``kaiming_init``. - if not hasattr(self.conv, "init_weights"): - # self.act_layer is a class, so compare classes (isinstance on the - # class would always be False) - if self.with_activation and self.act_layer is nn.LeakyReLU: - nonlinearity = "leaky_relu" - a = 0.01 # default negative_slope of nn.LeakyReLU - else: - nonlinearity = "relu" - a = 0 - if hasattr(self.conv, "weight") and self.conv.weight is not None: - nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) - if hasattr(self.conv, "bias") and self.conv.bias is not None: - nn.init.constant_(self.conv.bias, 0) - if self.with_norm: - if hasattr(self.norm, "weight") and self.norm.weight is not None: - nn.init.constant_(self.norm.weight, 1) - if hasattr(self.norm, "bias") and self.norm.bias is not None: - nn.init.constant_(self.norm.bias, 0) - - def forward(self, x, activate=True, norm=True): - for layer in self.order: - if layer == "conv": - if self.with_explicit_padding: - x = self.pad(x) - x = self.conv(x) - elif layer == "norm" and norm and self.with_norm: - x = self.norm(x) - elif layer == "act" and activate and self.with_activation: - x = self.activate(x) - return x - - -class Interpolate(nn.Module): - def __init__(self, scale_factor, mode, align_corners=False): - super(Interpolate, self).__init__() - self.interp = nn.functional.interpolate - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - - def forward(self, x): - x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - return x - - -class HeadDepth(nn.Module): - def __init__(self, features): - super(HeadDepth, self).__init__() - self.head = nn.Sequential( - nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), - Interpolate(scale_factor=2, mode="bilinear", align_corners=True), - nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), - nn.ReLU(), - nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), - ) - - def forward(self, x): - x = self.head(x) - return x - - -class ReassembleBlocks(nn.Module): - """ViTPostProcessBlock, process cls_token in ViT backbone output and - rearrange the feature vector to feature map. - Args: - in_channels (int): ViT feature channels. Default: 768. - out_channels (List): output channels of each stage. - Default: [96, 192, 384, 768]. - readout_type (str): Type of readout operation. Default: 'ignore'. - patch_size (int): The patch size. Default: 16.
- """ - - def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): - super(ReassembleBlocks, self).__init__() - - assert readout_type in ["ignore", "add", "project"] - self.readout_type = readout_type - self.patch_size = patch_size - - self.projects = nn.ModuleList( - [ - ConvModule( - in_channels=in_channels, - out_channels=out_channel, - kernel_size=1, - act_layer=None, - ) - for out_channel in out_channels - ] - ) - - self.resize_layers = nn.ModuleList( - [ - nn.ConvTranspose2d( - in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 - ), - nn.ConvTranspose2d( - in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 - ), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 - ), - ] - ) - if self.readout_type == "project": - self.readout_projects = nn.ModuleList() - for _ in range(len(self.projects)): - self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) - - def forward(self, inputs): - assert isinstance(inputs, list) - out = [] - for i, x in enumerate(inputs): - assert len(x) == 2 - x, cls_token = x[0], x[1] - feature_shape = x.shape - if self.readout_type == "project": - x = x.flatten(2).permute((0, 2, 1)) - readout = cls_token.unsqueeze(1).expand_as(x) - x = self.readout_projects[i](torch.cat((x, readout), -1)) - x = x.permute(0, 2, 1).reshape(feature_shape) - elif self.readout_type == "add": - x = x.flatten(2) + cls_token.unsqueeze(-1) - x = x.reshape(feature_shape) - else: - pass - x = self.projects[i](x) - x = self.resize_layers[i](x) - out.append(x) - return out - - -class PreActResidualConvUnit(nn.Module): - """ResidualConvUnit, pre-activate residual unit. - Args: - in_channels (int): number of channels in the input feature map. - act_layer (nn.Module): activation layer. - norm_layer (nn.Module): norm layer. - stride (int): stride of the first block. Default: 1 - dilation (int): dilation rate for convs layers. Default: 1. - """ - - def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): - super(PreActResidualConvUnit, self).__init__() - - self.conv1 = ConvModule( - in_channels, - in_channels, - 3, - stride=stride, - padding=dilation, - dilation=dilation, - norm_layer=norm_layer, - act_layer=act_layer, - bias=False, - order=("act", "conv", "norm"), - ) - - self.conv2 = ConvModule( - in_channels, - in_channels, - 3, - padding=1, - norm_layer=norm_layer, - act_layer=act_layer, - bias=False, - order=("act", "conv", "norm"), - ) - - def forward(self, inputs): - inputs_ = inputs.clone() - x = self.conv1(inputs) - x = self.conv2(x) - return x + inputs_ - - -class FeatureFusionBlock(nn.Module): - """FeatureFusionBlock, merge feature map from different stages. - Args: - in_channels (int): Input channels. - act_layer (nn.Module): activation layer for ResidualConvUnit. - norm_layer (nn.Module): normalization layer. - expand (bool): Whether expand the channels in post process block. - Default: False. - align_corners (bool): align_corner setting for bilinear upsample. - Default: True. 
- """ - - def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): - super(FeatureFusionBlock, self).__init__() - - self.in_channels = in_channels - self.expand = expand - self.align_corners = align_corners - - self.out_channels = in_channels - if self.expand: - self.out_channels = in_channels // 2 - - self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) - - self.res_conv_unit1 = PreActResidualConvUnit( - in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer - ) - self.res_conv_unit2 = PreActResidualConvUnit( - in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer - ) - - def forward(self, *inputs): - x = inputs[0] - if len(inputs) == 2: - if x.shape != inputs[1].shape: - res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) - else: - res = inputs[1] - x = x + self.res_conv_unit1(res) - x = self.res_conv_unit2(x) - x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) - x = self.project(x) - return x - - -class DPTHead(DepthBaseDecodeHead): - """Vision Transformers for Dense Prediction. - This head is implemented of `DPT `_. - Args: - embed_dims (int): The embed dimension of the ViT backbone. - Default: 768. - post_process_channels (List): Out channels of post process conv - layers. Default: [96, 192, 384, 768]. - readout_type (str): Type of readout operation. Default: 'ignore'. - patch_size (int): The patch size. Default: 16. - expand_channels (bool): Whether expand the channels in post process - block. Default: False. - """ - - def __init__( - self, - embed_dims=768, - post_process_channels=[96, 192, 384, 768], - readout_type="ignore", - patch_size=16, - expand_channels=False, - **kwargs, - ): - super(DPTHead, self).__init__(**kwargs) - - self.in_channels = self.in_channels - self.expand_channels = expand_channels - self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) - - self.post_process_channels = [ - channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) - ] - self.convs = nn.ModuleList() - for channel in self.post_process_channels: - self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) - self.fusion_blocks = nn.ModuleList() - for _ in range(len(self.convs)): - self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) - self.fusion_blocks[0].res_conv_unit1 = None - self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) - self.num_fusion_blocks = len(self.fusion_blocks) - self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) - self.num_post_process_channels = len(self.post_process_channels) - assert self.num_fusion_blocks == self.num_reassemble_blocks - assert self.num_reassemble_blocks == self.num_post_process_channels - self.conv_depth = HeadDepth(self.channels) - - def forward(self, inputs, img_metas): - assert len(inputs) == self.num_reassemble_blocks - x = [inp for inp in inputs] - x = self.reassemble_blocks(x) - x = [self.convs[i](feature) for i, feature in enumerate(x)] - out = self.fusion_blocks[0](x[-1]) - for i in range(1, len(self.fusion_blocks)): - out = self.fusion_blocks[i](out, x[-(i + 1)]) - out = self.project(out) - out = self.depth_pred(out) - return out diff --git a/dinov2/hub/depth/encoder_decoder.py 
b/dinov2/hub/depth/encoder_decoder.py deleted file mode 100644 index eb29ced67957a336e763b0e7c90c0eeaea36fea8..0000000000000000000000000000000000000000 --- a/dinov2/hub/depth/encoder_decoder.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .ops import resize - - -def add_prefix(inputs, prefix): - """Add prefix for dict. - - Args: - inputs (dict): The input dict with str keys. - prefix (str): The prefix to add. - - Returns: - - dict: The dict with keys updated with ``prefix``. - """ - - outputs = dict() - for name, value in inputs.items(): - outputs[f"{prefix}.{name}"] = value - - return outputs - - -class DepthEncoderDecoder(nn.Module): - """Encoder Decoder depther. - - EncoderDecoder typically consists of backbone and decode_head. - """ - - def __init__(self, backbone, decode_head): - super(DepthEncoderDecoder, self).__init__() - - self.backbone = backbone - self.decode_head = decode_head - self.align_corners = self.decode_head.align_corners - - def extract_feat(self, img): - """Extract features from images.""" - return self.backbone(img) - - def encode_decode(self, img, img_metas, rescale=True, size=None): - """Encode images with backbone and decode into a depth estimation - map of the same size as input.""" - x = self.extract_feat(img) - out = self._decode_head_forward_test(x, img_metas) - # crop the pred depth to the certain range. - out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) - if rescale: - if size is None: - if img_metas is not None: - size = img_metas[0]["ori_shape"][:2] - else: - size = img.shape[2:] - out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) - return out - - def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): - """Run forward function and calculate loss for decode head in - training.""" - losses = dict() - loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) - losses.update(add_prefix(loss_decode, "decode")) - return losses - - def _decode_head_forward_test(self, x, img_metas): - """Run forward function and calculate loss for decode head in - inference.""" - depth_pred = self.decode_head.forward_test(x, img_metas) - return depth_pred - - def forward_dummy(self, img): - """Dummy forward function.""" - depth = self.encode_decode(img, None) - - return depth - - def forward_train(self, img, img_metas, depth_gt, **kwargs): - """Forward function for training. - - Args: - img (Tensor): Input images. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - depth_gt (Tensor): Depth gt - used if the architecture supports depth estimation task. 
- - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - - x = self.extract_feat(img) - - losses = dict() - - # the last of x saves the info from neck - loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) - - losses.update(loss_decode) - - return losses - - def whole_inference(self, img, img_meta, rescale, size=None): - """Inference with full image.""" - return self.encode_decode(img, img_meta, rescale, size=size) - - def slide_inference(self, img, img_meta, rescale, stride, crop_size): - """Inference by sliding-window with overlap. - - If h_crop > h_img or w_crop > w_img, the small patch will be used to - decode without padding. - """ - - h_stride, w_stride = stride - h_crop, w_crop = crop_size - batch_size, _, h_img, w_img = img.size() - h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 - w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 - preds = img.new_zeros((batch_size, 1, h_img, w_img)) - count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) - for h_idx in range(h_grids): - for w_idx in range(w_grids): - y1 = h_idx * h_stride - x1 = w_idx * w_stride - y2 = min(y1 + h_crop, h_img) - x2 = min(x1 + w_crop, w_img) - y1 = max(y2 - h_crop, 0) - x1 = max(x2 - w_crop, 0) - crop_img = img[:, :, y1:y2, x1:x2] - depth_pred = self.encode_decode(crop_img, img_meta, rescale) - preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) - - count_mat[:, :, y1:y2, x1:x2] += 1 - assert (count_mat == 0).sum() == 0 - if torch.onnx.is_in_onnx_export(): - # cast count_mat to constant while exporting to ONNX - count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) - preds = preds / count_mat - return preds - - def inference(self, img, img_meta, rescale, size=None, mode="whole", stride=None, crop_size=None): - """Inference with slide/whole style. - - Args: - img (Tensor): The input image of shape (N, 3, H, W). - img_meta (dict): Image info dict where each dict has: 'img_shape', - 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `depth/datasets/pipelines/formatting.py:Collect`. - rescale (bool): Whether rescale back to original shape. - stride (tuple[int]|None): Sliding-window stride; required when - ``mode == 'slide'``. - crop_size (tuple[int]|None): Sliding-window crop size; required - when ``mode == 'slide'``. - - Returns: - Tensor: The output depth map. - """ - - assert mode in ["slide", "whole"] - ori_shape = img_meta[0]["ori_shape"] - assert all(_["ori_shape"] == ori_shape for _ in img_meta) - if mode == "slide": - depth_pred = self.slide_inference(img, img_meta, rescale, stride, crop_size) - else: - depth_pred = self.whole_inference(img, img_meta, rescale, size=size) - output = depth_pred - flip = img_meta[0]["flip"] - if flip: - flip_direction = img_meta[0]["flip_direction"] - assert flip_direction in ["horizontal", "vertical"] - if flip_direction == "horizontal": - output = output.flip(dims=(3,)) - elif flip_direction == "vertical": - output = output.flip(dims=(2,)) - - return output - - def simple_test(self, img, img_meta, rescale=True): - """Simple test with single image.""" - depth_pred = self.inference(img, img_meta, rescale) - if torch.onnx.is_in_onnx_export(): - # our inference backend only supports 4D output - depth_pred = depth_pred.unsqueeze(0) - return depth_pred - depth_pred = depth_pred.cpu().numpy() - # unravel batch dim - depth_pred = list(depth_pred) - return depth_pred - - def aug_test(self, imgs, img_metas, rescale=True): - """Test with augmentations. - - Only rescale=True is supported.
- """ - # aug_test rescale all imgs back to ori_shape for now - assert rescale - # to save memory, we get augmented depth logit inplace - depth_pred = self.inference(imgs[0], img_metas[0], rescale) - for i in range(1, len(imgs)): - cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) - depth_pred += cur_depth_pred - depth_pred /= len(imgs) - depth_pred = depth_pred.cpu().numpy() - # unravel batch dim - depth_pred = list(depth_pred) - return depth_pred - - def forward_test(self, imgs, img_metas, **kwargs): - """ - Args: - imgs (List[Tensor]): the outer list indicates test-time - augmentations and inner Tensor should have a shape NxCxHxW, - which contains all images in the batch. - img_metas (List[List[dict]]): the outer list indicates test-time - augs (multiscale, flip, etc.) and the inner list indicates - images in a batch. - """ - for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: - if not isinstance(var, list): - raise TypeError(f"{name} must be a list, but got " f"{type(var)}") - num_augs = len(imgs) - if num_augs != len(img_metas): - raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") - # all images in the same aug batch all of the same ori_shape and pad - # shape - for img_meta in img_metas: - ori_shapes = [_["ori_shape"] for _ in img_meta] - assert all(shape == ori_shapes[0] for shape in ori_shapes) - img_shapes = [_["img_shape"] for _ in img_meta] - assert all(shape == img_shapes[0] for shape in img_shapes) - pad_shapes = [_["pad_shape"] for _ in img_meta] - assert all(shape == pad_shapes[0] for shape in pad_shapes) - - if num_augs == 1: - return self.simple_test(imgs[0], img_metas[0], **kwargs) - else: - return self.aug_test(imgs, img_metas, **kwargs) - - def forward(self, img, img_metas, return_loss=True, **kwargs): - """Calls either :func:`forward_train` or :func:`forward_test` depending - on whether ``return_loss`` is ``True``. - - Note this setting will change the expected inputs. When - ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor - and List[dict]), and when ``resturn_loss=False``, img and img_meta - should be double nested (i.e. List[Tensor], List[List[dict]]), with - the outer list indicating test time augmentations. - """ - if return_loss: - return self.forward_train(img, img_metas, **kwargs) - else: - return self.forward_test(img, img_metas, **kwargs) - - def train_step(self, data_batch, optimizer, **kwargs): - """The iteration step during training. - - This method defines an iteration step during training, except for the - back propagation and optimizer updating, which are done in an optimizer - hook. Note that in some complicated cases or models, the whole process - including back propagation and optimizer updating is also defined in - this method, such as GAN. - - Args: - data (dict): The output of dataloader. - optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of - runner is passed to ``train_step()``. This argument is unused - and reserved. - - Returns: - dict: It should contain at least 3 keys: ``loss``, ``log_vars``, - ``num_samples``. - ``loss`` is a tensor for back propagation, which can be a - weighted sum of multiple losses. - ``log_vars`` contains all the variables to be sent to the - logger. - ``num_samples`` indicates the batch size (when the model is - DDP, it means the batch size on each GPU), which is used for - averaging the logs. 
- """ - losses = self(**data_batch) - - # split losses and images - real_losses = {} - log_imgs = {} - for k, v in losses.items(): - if "img" in k: - log_imgs[k] = v - else: - real_losses[k] = v - - loss, log_vars = self._parse_losses(real_losses) - - outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) - - return outputs - - def val_step(self, data_batch, **kwargs): - """The iteration step during validation. - - This method shares the same signature as :func:`train_step`, but used - during val epochs. Note that the evaluation after training epochs is - not implemented with this method, but an evaluation hook. - """ - output = self(**data_batch, **kwargs) - return output - - @staticmethod - def _parse_losses(losses): - import torch.distributed as dist - - """Parse the raw outputs (losses) of the network. - - Args: - losses (dict): Raw output of the network, which usually contain - losses and other necessary information. - - Returns: - tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor - which may be a weighted sum of all losses, log_vars contains - all the variables to be sent to the logger. - """ - log_vars = OrderedDict() - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - log_vars[loss_name] = loss_value.mean() - elif isinstance(loss_value, list): - log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) - else: - raise TypeError(f"{loss_name} is not a tensor or list of tensors") - - loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) - - log_vars["loss"] = loss - for loss_name, loss_value in log_vars.items(): - # reduce loss when distributed training - if dist.is_available() and dist.is_initialized(): - loss_value = loss_value.data.clone() - dist.all_reduce(loss_value.div_(dist.get_world_size())) - log_vars[loss_name] = loss_value.item() - - return loss, log_vars diff --git a/dinov2/hub/depth/ops.py b/dinov2/hub/depth/ops.py deleted file mode 100644 index 15880ee0cb7652d4b41c489b927bf6a156b40e5e..0000000000000000000000000000000000000000 --- a/dinov2/hub/depth/ops.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import warnings - -import torch.nn.functional as F - - -def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): - if warning: - if size is not None and align_corners: - input_h, input_w = tuple(int(x) for x in input.shape[2:]) - output_h, output_w = tuple(int(x) for x in size) - if output_h > input_h or output_w > output_h: - if ( - (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) - and (output_h - 1) % (input_h - 1) - and (output_w - 1) % (input_w - 1) - ): - warnings.warn( - f"When align_corners={align_corners}, " - "the output would more aligned if " - f"input size {(input_h, input_w)} is `x+1` and " - f"out size {(output_h, output_w)} is `nx+1`" - ) - return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/dinov2/hub/depthers.py b/dinov2/hub/depthers.py deleted file mode 100644 index f88b7e9a41056594e3b3e66107feee98bffab820..0000000000000000000000000000000000000000 --- a/dinov2/hub/depthers.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from enum import Enum -from functools import partial -from typing import Optional, Tuple, Union - -import torch - -from .backbones import _make_dinov2_model -from .depth import BNHead, DepthEncoderDecoder, DPTHead -from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding - - -class Weights(Enum): - NYU = "NYU" - KITTI = "KITTI" - - -def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: - if not pretrained: # Default - return (0.001, 10.0) - - # Pretrained, set according to the training dataset for the provided weights - if weights == Weights.KITTI: - return (0.001, 80.0) - - if weights == Weights.NYU: - return (0.001, 10.0) - - return (0.001, 10.0) - - -def _make_dinov2_linear_depth_head( - *, - embed_dim: int, - layers: int, - min_depth: float, - max_depth: float, - **kwargs, -): - if layers not in (1, 4): - raise AssertionError(f"Unsupported number of layers: {layers}") - - if layers == 1: - in_index = [0] - else: - assert layers == 4 - in_index = [0, 1, 2, 3] - - return BNHead( - classify=True, - n_bins=256, - bins_strategy="UD", - norm_strategy="linear", - upsample=4, - in_channels=[embed_dim] * len(in_index), - in_index=in_index, - input_transform="resize_concat", - channels=embed_dim * len(in_index) * 2, - align_corners=False, - min_depth=min_depth, - max_depth=max_depth, - loss_decode=(), - ) - - -def _make_dinov2_linear_depther( - *, - arch_name: str = "vit_large", - layers: int = 4, - pretrained: bool = True, - weights: Union[Weights, str] = Weights.NYU, - depth_range: Optional[Tuple[float, float]] = None, - **kwargs, -): - if layers not in (1, 4): - raise AssertionError(f"Unsupported number of layers: {layers}") - if isinstance(weights, str): - try: - weights = Weights[weights] - except KeyError: - raise AssertionError(f"Unsupported weights: {weights}") - - if depth_range is None: - depth_range = _get_depth_range(pretrained, weights) - min_depth, max_depth = depth_range - - backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) - - embed_dim = backbone.embed_dim - patch_size = backbone.patch_size - model_name = _make_dinov2_model_name(arch_name, patch_size) - linear_depth_head = _make_dinov2_linear_depth_head( - embed_dim=embed_dim, - layers=layers, - min_depth=min_depth, - max_depth=max_depth, - ) - - layer_count = { - "vit_small": 12, - "vit_base": 12, - "vit_large": 24, - "vit_giant2": 40, - }[arch_name] - - if layers == 4: - out_index = { - "vit_small": [2, 5, 8, 11], - "vit_base": [2, 5, 8, 11], - "vit_large": [4, 11, 17, 23], - "vit_giant2": [9, 19, 29, 39], - }[arch_name] - else: - assert layers == 1 - out_index = [layer_count - 1] - - model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) - model.backbone.forward = partial( - backbone.get_intermediate_layers, - n=out_index, - reshape=True, - return_class_token=True, - norm=False, - ) - model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) - - if pretrained: - layers_str = str(layers) if layers == 4 else "" - weights_str = weights.value.lower() - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" - checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") - state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint - model.load_state_dict(state_dict, strict=False) - - return model
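For orientation, here is a minimal usage sketch of the linear-depther entrypoints that follow. It is a hypothetical example, assuming `dinov2.hub.depthers` is importable and the pretrained NYU head is reachable; the random tensor stands in for an ImageNet-normalized RGB image.

```python
# Build the ViT-L/14 linear depther and run whole-image inference.
# 518 = 14 * 37, so the CenterPadding pre-hook is a no-op for this input.
import torch

from dinov2.hub.depthers import dinov2_vitl14_ld

depther = dinov2_vitl14_ld(layers=4, pretrained=True, weights="NYU").eval()

image = torch.randn(1, 3, 518, 518)  # stand-in for a normalized RGB image
with torch.no_grad():
    depth = depther.whole_inference(image, img_meta=None, rescale=True)
print(depth.shape)  # -> torch.Size([1, 1, 518, 518])
```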
- - -def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_linear_depther( - arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs - ) - - -def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_linear_depther( - arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs - ) - - -def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_linear_depther( - arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs - ) - - -def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_linear_depther( - arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs - ) - - -def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): - return DPTHead( - in_channels=[embed_dim] * 4, - channels=256, - embed_dims=embed_dim, - post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], - readout_type="project", - min_depth=min_depth, - max_depth=max_depth, - loss_decode=(), - ) - - -def _make_dinov2_dpt_depther( - *, - arch_name: str = "vit_large", - pretrained: bool = True, - weights: Union[Weights, str] = Weights.NYU, - depth_range: Optional[Tuple[float, float]] = None, - **kwargs, -): - if isinstance(weights, str): - try: - weights = Weights[weights] - except KeyError: - raise AssertionError(f"Unsupported weights: {weights}") - - if depth_range is None: - depth_range = _get_depth_range(pretrained, weights) - min_depth, max_depth = depth_range - - backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) - - model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) - dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) - - out_index = { - "vit_small": [2, 5, 8, 11], - "vit_base": [2, 5, 8, 11], - "vit_large": [4, 11, 17, 23], - "vit_giant2": [9, 19, 29, 39], - }[arch_name] - - model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) - model.backbone.forward = partial( - backbone.get_intermediate_layers, - n=out_index, - reshape=True, - return_class_token=True, - norm=False, - ) - model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) - - if pretrained: - weights_str = weights.value.lower() - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" - checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") - state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint - model.load_state_dict(state_dict, strict=False) - - return model - - -def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return
_make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) - - -def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): - return _make_dinov2_dpt_depther( - arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs - ) diff --git a/dinov2/hub/dinotxt.py b/dinov2/hub/dinotxt.py deleted file mode 100644 index 3578538ceacf08dcc86f4426da8368aba4223894..0000000000000000000000000000000000000000 --- a/dinov2/hub/dinotxt.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import math - -from .backbones import dinov2_vitl14_reg -from .utils import _DINOV2_BASE_URL - - -def dinov2_vitl14_reg4_dinotxt_tet1280d20h24l(): - from .text.dinotxt_model import DinoTxtConfig, DinoTxt - from .text.dinov2_wrapper import DINOv2Wrapper - from .text.text_transformer import TextTransformer - - dinotxt_config = DinoTxtConfig( - embed_dim=2048, - vision_model_freeze_backbone=True, - vision_model_train_img_size=224, - vision_model_use_class_token=True, - vision_model_use_patch_tokens=True, - vision_model_num_head_blocks=2, - vision_model_head_blocks_drop_path=0.3, - vision_model_use_linear_projection=False, - vision_model_patch_tokens_pooler_type="mean", - vision_model_patch_token_layer=1, # which layer to take patch tokens from - # 1 - last layer, 2 - second last layer, etc. - text_model_freeze_backbone=False, - text_model_num_head_blocks=0, - text_model_head_blocks_is_causal=False, - text_model_head_blocks_drop_prob=0.0, - text_model_tokens_pooler_type="argmax", - text_model_use_linear_projection=True, - init_logit_scale=math.log(1 / 0.07), - init_logit_bias=None, - freeze_logit_scale=False, - ) - vision_backbone = DINOv2Wrapper(dinov2_vitl14_reg()) - text_backbone = TextTransformer( - context_length=77, - vocab_size=49408, - dim=1280, - num_heads=20, - num_layers=24, - ffn_ratio=4, - is_causal=True, - ls_init_value=None, - dropout_prob=0.0, - ) - model = DinoTxt(dinotxt_config, vision_backbone, text_backbone) - model.init_weights() - model.visual_model.backbone = vision_backbone - model.eval() - - visual_model_head_state_dict = torch.hub.load_state_dict_from_url( - _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_vision_head.pth", - map_location="cpu", - ) - text_model_state_dict = torch.hub.load_state_dict_from_url( - _DINOV2_BASE_URL + "/dinov2_vitl14/dinov2_vitl14_reg4_dinotxt_tet1280d20h24l_text_encoder.pth", - map_location="cpu", - ) - model.visual_model.head.load_state_dict(visual_model_head_state_dict, strict=True) - model.text_model.load_state_dict(text_model_state_dict, strict=True) - return model - - -def get_tokenizer(): - from .text.tokenizer import Tokenizer - import requests - from io import BytesIO - - url = _DINOV2_BASE_URL + "/thirdparty/bpe_simple_vocab_16e6.txt.gz" - try: - response = requests.get(url) - response.raise_for_status() - file_buf = BytesIO(response.content) - return Tokenizer(vocab_path=file_buf) - except Exception as e: - raise FileNotFoundError(f"Failed to download file from url {url} with error last: {e}") diff --git a/dinov2/hub/text/dinotxt_model.py b/dinov2/hub/text/dinotxt_model.py deleted file mode 100644 index e561e1915485d30aee3cd04fb65da0e3db55dd36..0000000000000000000000000000000000000000 --- 
a/dinov2/hub/text/dinotxt_model.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import math -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import nn, Tensor - -from .vision_tower import VisionTower -from .text_tower import TextTower - - -@dataclass -class DinoTxtConfig: - embed_dim: int - vision_model_freeze_backbone: bool = True - vision_model_train_img_size: int = 224 - vision_model_use_class_token: bool = True - vision_model_use_patch_tokens: bool = False - vision_model_num_head_blocks: int = 0 - vision_model_head_blocks_drop_path: float = 0.3 - vision_model_use_linear_projection: bool = False - vision_model_patch_tokens_pooler_type: str = "mean" - vision_model_patch_token_layer: int = 1 # which layer to take patch tokens from - # 1 - last layer, 2 - second last layer, etc. - text_model_freeze_backbone: bool = False - text_model_num_head_blocks: int = 0 - text_model_head_blocks_is_causal: bool = False - text_model_head_blocks_drop_prob: float = 0.0 - text_model_tokens_pooler_type: str = "first" - text_model_use_linear_projection: bool = False - init_logit_scale: float = math.log(1 / 0.07) - init_logit_bias: Optional[float] = None - freeze_logit_scale: bool = False - - -class DinoTxt(nn.Module): - def __init__( - self, - model_config: DinoTxtConfig, - vision_backbone: nn.Module, - text_backbone: nn.Module, - ): - super().__init__() - self.model_config = model_config - self.visual_model = VisionTower( - vision_backbone, - model_config.vision_model_freeze_backbone, - model_config.embed_dim, - model_config.vision_model_num_head_blocks, - model_config.vision_model_head_blocks_drop_path, - model_config.vision_model_use_class_token, - model_config.vision_model_use_patch_tokens, - model_config.vision_model_patch_token_layer, - model_config.vision_model_patch_tokens_pooler_type, - model_config.vision_model_use_linear_projection, - ) - self.text_model = TextTower( - text_backbone, - model_config.text_model_freeze_backbone, - model_config.embed_dim, - model_config.text_model_num_head_blocks, - model_config.text_model_head_blocks_is_causal, - model_config.text_model_head_blocks_drop_prob, - model_config.text_model_tokens_pooler_type, - model_config.text_model_use_linear_projection, - ) - self.logit_scale = nn.Parameter(torch.ones(1) * model_config.init_logit_scale) - if model_config.freeze_logit_scale: - self.logit_scale.requires_grad = False - - def init_weights(self): - self.visual_model.init_weights() - self.text_model.init_weights() - - def get_visual_class_and_patch_tokens(self, image: Tensor) -> Tuple[Tensor, Tensor]: - return self.visual_model.get_class_and_patch_tokens(image) - - def encode_image( - self, - image: Tensor, - normalize: bool = False, - ) -> Tensor: - """ - Encode an image into a vector descriptor containing both global and local features. - - Args: - image (Tensor): Tensor of shape `(batch_size, rgb, height, width)`, normalized using ImageNet mean and std. - normalize (bool, optional): Whether to normalize the output vectors. Default is False. - Image features should always be normalized before comparing them with text features: - Returns: - Tensor: Tensor of shape `(batch_size, embed_dim)` containing the image features. 
- The first half of the vector corresponds to the global features (class token), - and the second half corresponds to the pooled patch features. - """ - features = self.visual_model(image) - return F.normalize(features, dim=-1) if normalize else features - - def encode_text(self, text: Tensor, normalize: bool = False) -> Tensor: - """ - Encode a text input into a vector descriptor. - - Args: - text (Tensor): Tensor of shape `(batch_size, seq_len)` containing token indices. - normalize (bool, optional): Whether to normalize the output vectors. Default is False. - Text features should be normalized before comparing them with image features: - Returns: - Tensor: Tensor of shape `(batch_size, embed_dim)` containing the text features. - As a consequence of the training procedure, assume that the first half of the tensor corresponds - to global image features and the second half to pooled patch features. - """ - features = self.text_model(text) - return F.normalize(features, dim=-1) if normalize else features - - def get_logits(self, image: Tensor, text: Tensor) -> Tuple[Tensor, Tensor]: - text_features = self.encode_text(text, normalize=True) - image_features = self.encode_image(image, normalize=True) - image_logits = self.logit_scale.exp() * image_features @ text_features.T - text_logits = image_logits.T - return image_logits, text_logits - - def forward( - self, - image: Tensor, - text: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - - text_features = self.encode_text(text, normalize=True) - image_features = self.encode_image(image, normalize=True) - return image_features, text_features, self.logit_scale.exp() diff --git a/dinov2/hub/text/dinov2_wrapper.py b/dinov2/hub/text/dinov2_wrapper.py deleted file mode 100644 index 2689dc2fdd73c5be0b32d7f60abb6c11b760e3c2..0000000000000000000000000000000000000000 --- a/dinov2/hub/text/dinov2_wrapper.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
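The tokenizer and the `get_logits` path above compose into a CLIP-style zero-shot scorer. A minimal sketch (hypothetical usage, assuming the hub modules are importable and the released checkpoints are reachable; the random image stands in for an ImageNet-normalized batch):

```python
import torch

from dinov2.hub.dinotxt import dinov2_vitl14_reg4_dinotxt_tet1280d20h24l, get_tokenizer

model = dinov2_vitl14_reg4_dinotxt_tet1280d20h24l().eval()
tokenizer = get_tokenizer()

image = torch.randn(1, 3, 224, 224)  # stand-in for a normalized RGB batch
texts = tokenizer.tokenize(["a photo of a dog", "a photo of a cat"])  # (2, 77) token ids

with torch.no_grad():
    image_logits, text_logits = model.get_logits(image, texts)
probs = image_logits.softmax(dim=-1)  # per-caption probabilities, shape (1, 2)
```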
- -from typing import Sequence - -import torch - - -class DINOv2Wrapper(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - self.embed_dim = model.embed_dim - self.num_heads = model.num_heads - self.num_register_tokens = model.num_register_tokens - - # Same as the original forward, but assert is_training and attach the patch-grid size (h, w) to the output dict - def forward(self, img, is_training: bool): - assert is_training - H, W = img.shape[-2:] - P = self.model.patch_size - x_dict = self.model(img, is_training=True) - x_dict["h"] = h = H // P - x_dict["w"] = w = W // P - assert x_dict["x_norm_patchtokens"].shape[-2] == h * w - return x_dict - - # Same as the original get_intermediate_layers, but allow returning extra tokens (registers) - def get_intermediate_layers( - self, - x: torch.Tensor, - n: int | Sequence[int] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - return_register_tokens: bool = False, - norm=True, - ) -> tuple[torch.Tensor] | tuple[tuple[torch.Tensor, ...], ...]: - if self.model.chunked_blocks: - outputs = self.model._get_intermediate_layers_chunked(x, n) - else: - outputs = self.model._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.model.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - register_tokens = [out[:, 1 : 1 + self.model.num_register_tokens] for out in outputs] - outputs = [out[:, 1 + self.model.num_register_tokens :] for out in outputs] - if reshape: - B, _, h, w = x.shape - outputs = [ - out.reshape(B, h // self.model.patch_size, w // self.model.patch_size, -1) - .permute(0, 3, 1, 2) - .contiguous() - for out in outputs - ] - - if not return_class_token and not return_register_tokens: - return tuple(outputs) - if return_class_token and not return_register_tokens: - return tuple(zip(outputs, class_tokens)) - if not return_class_token and return_register_tokens: - return tuple(zip(outputs, register_tokens)) - return tuple(zip(outputs, class_tokens, register_tokens)) diff --git a/dinov2/hub/text/text_tower.py b/dinov2/hub/text/text_tower.py deleted file mode 100644 index a74c15f0eb64092728e245bbce08958e75ddc82c..0000000000000000000000000000000000000000 --- a/dinov2/hub/text/text_tower.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
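The four return layouts of `get_intermediate_layers` above are easy to mix up. A sketch spelling them out, assuming the `DINOv2Wrapper` class above is in scope and a registers backbone is fetched via torch.hub:

```python
import torch

backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
wrapper = DINOv2Wrapper(backbone)
x = torch.randn(2, 3, 224, 224)  # 224 / 14 = 16, so the patch grid is 16x16

(feats,) = wrapper.get_intermediate_layers(x, n=1, reshape=True)  # (2, 1024, 16, 16)
feats, cls = wrapper.get_intermediate_layers(x, n=1, return_class_token=True)[0]
feats, cls, regs = wrapper.get_intermediate_layers(
    x, n=1, return_class_token=True, return_register_tokens=True
)[0]
print(regs.shape)  # -> torch.Size([2, 4, 1024]), one row per register token
```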
- -import torch -from torch import nn, Tensor - -from dinov2.layers import ( - CausalAttentionBlock, -) - - -class TextHead(nn.Module): - def __init__( - self, - input_dim: int, - embed_dim: int, - num_heads: int, - num_blocks: int, - block_drop_prob: float, - is_causal: bool, - use_linear_projection: bool, - ): - super().__init__() - block_list = [nn.Identity()] - self.ln_final = nn.Identity() - if num_blocks > 0: - block_list = [ - CausalAttentionBlock( - dim=input_dim, - num_heads=num_heads, - is_causal=is_causal, - dropout_prob=block_drop_prob, - ) - for _ in range(num_blocks) - ] - self.ln_final = nn.LayerNorm(input_dim) - self.block_list = nn.ModuleList(block_list) - self.num_blocks = num_blocks - self.linear_projection = nn.Identity() - if input_dim != embed_dim or use_linear_projection: - self.linear_projection = nn.Linear(input_dim, embed_dim, bias=False) - - def init_weights(self): - if self.num_blocks > 0: - for i in range(self.num_blocks): - self.block_list[i].init_weights() - self.ln_final.reset_parameters() - if isinstance(self.linear_projection, nn.Linear): - nn.init.normal_(self.linear_projection.weight, std=self.linear_projection.in_features**-0.5) - - def forward(self, text_tokens: Tensor) -> Tensor: - for block in self.block_list: - text_tokens = block(text_tokens) - text_tokens = self.ln_final(text_tokens) - return self.linear_projection(text_tokens) - - -class TextTower(nn.Module): - def __init__( - self, - backbone: nn.Module, - freeze_backbone: bool, - embed_dim: int, - num_head_blocks: int, - head_blocks_is_causal: bool, - head_blocks_drop_prob: float, - tokens_pooler_type: str, - use_linear_projection: bool, - ): - super().__init__() - self.backbone = backbone - self.freeze_backbone = freeze_backbone - backbone_out_dim = backbone.embed_dim - self.head = TextHead( - backbone_out_dim, - embed_dim, - self.backbone.num_heads, - num_head_blocks, - head_blocks_drop_prob, - head_blocks_is_causal, - use_linear_projection, - ) - self.tokens_pooler_type = tokens_pooler_type - - def init_weights(self): - self.backbone.init_weights() - self.head.init_weights() - - def forward(self, token_indices: Tensor) -> Tensor: - text_tokens = self.backbone(token_indices) - text_tokens = self.head(text_tokens) - if self.tokens_pooler_type == "first": - features = text_tokens[:, 0] - elif self.tokens_pooler_type == "last": - features = text_tokens[:, -1] - elif self.tokens_pooler_type == "argmax": - assert token_indices is not None - features = text_tokens[torch.arange(text_tokens.shape[0]), token_indices.argmax(dim=-1)] - else: - raise ValueError(f"Unknown text tokens pooler type: {self.tokens_pooler_type}") - return features diff --git a/dinov2/hub/text/text_transformer.py b/dinov2/hub/text/text_transformer.py deleted file mode 100644 index 7f836481b1a9fe98271d894423c4fd79644dd487..0000000000000000000000000000000000000000 --- a/dinov2/hub/text/text_transformer.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
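One non-obvious detail in `TextTower.forward` above: the "argmax" pooler works because CLIP-style BPE assigns the end-of-text token the largest id in the 49408-token vocabulary (49407), so `token_indices.argmax(dim=-1)` recovers each caption's EOT position. A self-contained sketch:

```python
import torch

token_indices = torch.tensor(
    [
        [49406, 320, 1125, 49407, 0, 0],  # <sot> ... <eot> <pad> <pad>
        [49406, 320, 49407, 0, 0, 0],
    ]
)
text_tokens = torch.randn(2, 6, 8)  # stand-in for per-token head outputs

eot_positions = token_indices.argmax(dim=-1)            # tensor([3, 2])
features = text_tokens[torch.arange(2), eot_positions]  # pooled features, shape (2, 8)
```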
- -from typing import Callable, Optional - -import torch -from torch import nn, Tensor - - -from dinov2.layers import CausalAttentionBlock - - -class TextTransformer(nn.Module): - def __init__( - self, - context_length: int, - vocab_size: int, - dim: int, - num_heads: int, - num_layers: int, - ffn_ratio: float, - is_causal: bool, - ls_init_value: Optional[float] = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - dropout_prob: float = 0.0, - ): - super().__init__() - self.vocab_size = vocab_size - self.embed_dim = dim - self.num_heads = num_heads - - self.token_embedding = nn.Embedding(vocab_size, dim) - self.positional_embedding = nn.Parameter(torch.empty(context_length, dim)) - self.dropout = nn.Dropout(dropout_prob) - self.num_layers = num_layers - block_list = [ - CausalAttentionBlock( - dim=dim, - num_heads=num_heads, - ffn_ratio=ffn_ratio, - ls_init_value=ls_init_value, - is_causal=is_causal, - act_layer=act_layer, - norm_layer=norm_layer, - dropout_prob=dropout_prob, - ) - for _ in range(num_layers) - ] - self.blocks = nn.ModuleList(block_list) - self.ln_final = norm_layer(dim) - - def init_weights(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - init_attn_std = self.embed_dim**-0.5 - init_proj_std = (self.embed_dim**-0.5) * ((2 * self.num_layers) ** -0.5) - init_fc_std = (2 * self.embed_dim) ** -0.5 - for block in self.blocks: - block.init_weights(init_attn_std, init_proj_std, init_fc_std) - self.ln_final.reset_parameters() - - def forward(self, token_indices: Tensor) -> Tensor: - _, N = token_indices.size() - x = self.token_embedding(token_indices) + self.positional_embedding[:N] - x = self.dropout(x) - for block in self.blocks: - x = block(x) - x = self.ln_final(x) - return x diff --git a/dinov2/hub/text/tokenizer.py b/dinov2/hub/text/tokenizer.py deleted file mode 100644 index 97a6b32eae76e5a1fcf63964d6ff9bf482dfeb0e..0000000000000000000000000000000000000000 --- a/dinov2/hub/text/tokenizer.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
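The depth-scaled initialization in `init_weights` above follows the GPT-2/CLIP convention: residual projection stds shrink with both width and depth. Plugging in the dinotxt text-encoder shape (dim=1280, 24 layers) gives:

```python
# Worked numbers for TextTransformer.init_weights with dim=1280, num_layers=24.
dim, num_layers = 1280, 24

init_attn_std = dim**-0.5                                 # ~0.0280
init_proj_std = (dim**-0.5) * ((2 * num_layers) ** -0.5)  # ~0.0040
init_fc_std = (2 * dim) ** -0.5                           # ~0.0198
```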
- -import torch -from typing import List, Union - - -from dinov2.thirdparty.CLIP.clip.simple_tokenizer import SimpleTokenizer - - -class Tokenizer(SimpleTokenizer): - def __init__(self, vocab_path: str): - SimpleTokenizer.__init__(self, bpe_path=vocab_path) - - def tokenize(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - sot_token = self.encoder["<|startoftext|>"] - eot_token = self.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token - result[i, : len(tokens)] = torch.tensor(tokens) - - return result diff --git a/dinov2/hub/text/vision_tower.py b/dinov2/hub/text/vision_tower.py deleted file mode 100644 index 8777aa9d27546da5c31cf2c995c6c1ed6a24ac37..0000000000000000000000000000000000000000 --- a/dinov2/hub/text/vision_tower.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from functools import partial -from typing import Callable, Tuple - -import torch -from torch import nn, Tensor - -from dinov2.layers import ( - LayerScale, - NestedTensorBlock as AttentionBlock, - SwiGLUFFNAligned as SwiGLUFFN, -) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - nn.init.trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - if isinstance(module, nn.LayerNorm): - module.reset_parameters() - if isinstance(module, LayerScale): - module.reset_parameters() - if isinstance(module, nn.Conv2d): - module.reset_parameters() - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply( - fn=fn, - module=child_module, - name=child_name, - depth_first=depth_first, - include_root=True, - ) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class VisionHead(nn.Module): - def __init__( - self, - input_dim: int, - embed_dim: int, - num_heads: int, - num_blocks: int, - blocks_drop_path: float, - use_class_token: bool, - use_patch_tokens: bool, - use_linear_projection: bool, - ): - super().__init__() - block_list = [nn.Identity()] - self.ln_final = nn.Identity() - if num_blocks > 0: - block_list = [ - AttentionBlock( - input_dim, - num_heads, - ffn_layer=partial(SwiGLUFFN, align_to=64), - init_values=1e-5, - drop_path=blocks_drop_path, - ) - for _ in range(num_blocks) - ] - self.ln_final = 
nn.LayerNorm(input_dim) - self.block_list = nn.ModuleList(block_list) - self.num_blocks = num_blocks - multiplier = 2 if use_class_token and use_patch_tokens else 1 - self.linear_projection = nn.Identity() - if multiplier * input_dim != embed_dim or use_linear_projection: - assert embed_dim % multiplier == 0, f"Expects {embed_dim} to be divisible by {multiplier}" - self.linear_projection = nn.Linear(input_dim, embed_dim // multiplier, bias=False) - - def init_weights(self): - if self.num_blocks > 0: - for i in range(self.num_blocks): - block = self.block_list[i] - named_apply(init_weights_vit_timm, block) - self.ln_final.reset_parameters() - if isinstance(self.linear_projection, nn.Linear): - nn.init.normal_(self.linear_projection.weight, std=self.linear_projection.in_features**-0.5) - - def forward(self, image_tokens: Tensor) -> Tensor: - for block in self.block_list: - image_tokens = block(image_tokens) - image_tokens = self.ln_final(image_tokens) - return self.linear_projection(image_tokens) - - -class VisionTower(nn.Module): - def __init__( - self, - backbone: nn.Module, - freeze_backbone: bool, - embed_dim: int, - num_head_blocks: int, - head_blocks_block_drop_path: float, - use_class_token: bool, - use_patch_tokens: bool, - patch_token_layer: int, - patch_tokens_pooler_type: str, - use_linear_projection: bool, - ): - super().__init__() - self.backbone = backbone - self.freeze_backbone = freeze_backbone - self.use_class_token = use_class_token - self.use_patch_tokens = use_patch_tokens - self.patch_token_layer = patch_token_layer - self.patch_tokens_pooler_type = patch_tokens_pooler_type - self.num_register_tokens = 0 - if hasattr(self.backbone, "num_register_tokens"): - self.num_register_tokens = self.backbone.num_register_tokens - elif hasattr(self.backbone, "n_storage_tokens"): - self.num_register_tokens = self.backbone.n_storage_tokens - backbone_out_dim = self.backbone.embed_dim - self.head = VisionHead( - backbone_out_dim, - embed_dim, - self.backbone.num_heads, - num_head_blocks, - head_blocks_block_drop_path, - use_class_token, - use_patch_tokens, - use_linear_projection, - ) - - def init_weights(self): - if not self.freeze_backbone: - self.backbone.init_weights() - self.head.init_weights() - - def get_backbone_features(self, images: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - tokens = self.backbone.get_intermediate_layers( - images, - n=self.patch_token_layer, - return_class_token=True, - return_register_tokens=True, - ) - class_token = tokens[-1][1] - patch_tokens = tokens[0][0] - register_tokens = tokens[0][2] - return class_token, patch_tokens, register_tokens - - def get_class_and_patch_tokens(self, images: Tensor) -> Tuple[Tensor, Tensor]: - class_token, patch_tokens, register_tokens = self.get_backbone_features(images) - image_tokens = self.head(torch.cat([class_token.unsqueeze(1), register_tokens, patch_tokens], dim=1)) - class_token, patch_tokens = image_tokens[:, 0], image_tokens[:, self.num_register_tokens + 1 :] - return class_token, patch_tokens - - def forward(self, images: Tensor) -> Tensor: - class_token, patch_tokens = self.get_class_and_patch_tokens(images) - features = [] - if self.use_class_token: - features.append(class_token) - if self.use_patch_tokens: - if self.patch_tokens_pooler_type == "mean": - features.append(torch.mean(patch_tokens, dim=1)) - elif self.patch_tokens_pooler_type == "max": - features.append(torch.max(patch_tokens, dim=1).values) - elif self.patch_tokens_pooler_type == "gem": - power = 3 - eps = 1e-6 - patch_tokens_power = 
patch_tokens.clamp(min=eps).pow(power) - features.append(torch.mean(patch_tokens_power, dim=1).pow(1 / power)) - else: - raise ValueError(f"Unknown patch tokens pooler type: {self.patch_tokens_pooler_type}") - return torch.cat(features, dim=-1) diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py deleted file mode 100644 index 9c6641404093652d5a2f19b4cf283d976ec39e64..0000000000000000000000000000000000000000 --- a/dinov2/hub/utils.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import itertools -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" - - -def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: - compact_arch_name = arch_name.replace("_", "")[:4] - registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" - return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" - - -class CenterPadding(nn.Module): - def __init__(self, multiple): - super().__init__() - self.multiple = multiple - - def _get_pad(self, size): - new_size = math.ceil(size / self.multiple) * self.multiple - pad_size = new_size - size - pad_size_left = pad_size // 2 - pad_size_right = pad_size - pad_size_left - return pad_size_left, pad_size_right - - @torch.inference_mode() - def forward(self, x): - pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) - output = F.pad(x, pads) - return output diff --git a/dinov2/hub/xray_dino/backbones.py b/dinov2/hub/xray_dino/backbones.py deleted file mode 100644 index f49b54be1c3aa5a0fbe17427b86f9277598aa925..0000000000000000000000000000000000000000 --- a/dinov2/hub/xray_dino/backbones.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the licence -# found in the LICENSE_XRAY_DINO_MODEL file in the root directory of this source tree. - -from typing import Union - -from ..backbones import Weights, _make_dinov2_model - - -def xray_dino_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.XRAY_DINO, **kwargs): - """ - XRay-DINO ViT-L/16 model (optionally) pretrained on the XRay-DINO dataset. - """ - return _make_dinov2_model( - arch_name="vit_large", - patch_size=16, - img_size=512, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - block_chunks=4, - pretrained=pretrained, - weights=weights, - hash="ad31c2b0", - check_hash=True, - **kwargs, - ) diff --git a/dinov2/layers/__init__.py b/dinov2/layers/__init__.py deleted file mode 100644 index d640d145fbfb8993d6e7ceec4eb22b5b0a3e62fa..0000000000000000000000000000000000000000 --- a/dinov2/layers/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
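The CenterPadding module above rounds every trailing spatial dimension up to the nearest multiple of `multiple` and splits the padding evenly, with the extra pixel going to the right/bottom. A standalone sketch of the same arithmetic (re-implemented here purely for illustration):

import math

import torch
import torch.nn.functional as F

def center_pad(x: torch.Tensor, multiple: int) -> torch.Tensor:
    pads = []
    for size in x.shape[:1:-1]:  # trailing dims in reverse, i.e. W then H
        new_size = math.ceil(size / multiple) * multiple
        pad = new_size - size
        pads += [pad // 2, pad - pad // 2]  # left/top gets the smaller half
    return F.pad(x, pads)

x = torch.randn(1, 3, 37, 50)
y = center_pad(x, 14)
assert y.shape[-2:] == (42, 56)  # 37 -> 42 and 50 -> 56, both multiples of 14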
- -from .dino_head import DINOHead -from .layer_scale import LayerScale -from .mlp import Mlp -from .patch_embed import PatchEmbed -from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused, SwiGLUFFNAligned -from .block import NestedTensorBlock, CausalAttentionBlock -from .attention import Attention, MemEffAttention diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py deleted file mode 100644 index f1d3dabf14ffbd5eb68c3c16edc56d0da19b1265..0000000000000000000000000000000000000000 --- a/dinov2/layers/attention.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -import logging -import os -import warnings - -import torch -from torch import nn, Tensor - - -logger = logging.getLogger("dinov2") - - -XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None -try: - if XFORMERS_ENABLED: - from xformers.ops import memory_efficient_attention, unbind - - XFORMERS_AVAILABLE = True - warnings.warn("xFormers is available (Attention)") - else: - warnings.warn("xFormers is disabled (Attention)") - raise ImportError -except ImportError: - XFORMERS_AVAILABLE = False - warnings.warn("xFormers is not available (Attention)") - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - proj_bias: bool = True, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - ) -> None: - super().__init__() - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = attn_drop - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - - def init_weights( - self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0 - ) -> None: - init_attn_std = init_attn_std or (self.dim**-0.5) - init_proj_std = init_proj_std or init_attn_std * factor - nn.init.normal_(self.qkv.weight, std=init_attn_std) - nn.init.normal_(self.proj.weight, std=init_proj_std) - if self.qkv.bias is not None: - nn.init.zeros_(self.qkv.bias) - if self.proj.bias is not None: - nn.init.zeros_(self.proj.bias) - - def forward(self, x: Tensor, is_causal: bool = False) -> Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - q, k, v = torch.unbind(qkv, 2) - q, k, v = [t.transpose(1, 2) for t in [q, k, v]] - x = nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal - ) - x = x.transpose(1, 2).contiguous().view(B, N, C) - x = self.proj_drop(self.proj(x)) - return x - - -class MemEffAttention(Attention): - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - if not XFORMERS_AVAILABLE: - if attn_bias is not None: - raise AssertionError("xFormers is required for using nested tensors") - return super().forward(x) - - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = unbind(qkv, 2) - - x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) - x = x.reshape([B, N, C]) - - x = self.proj(x) - x = self.proj_drop(x) - return x diff --git 
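The scaled_dot_product_attention path in Attention.forward above is the standard multi-head reshape dance; a shape walk-through with plain torch (dimensions are arbitrary example values):

import torch
from torch import nn

B, N, C, H = 2, 5, 64, 8                                   # batch, tokens, dim, heads
qkv = nn.Linear(C, 3 * C)
x = torch.randn(B, N, C)
q, k, v = qkv(x).reshape(B, N, 3, H, C // H).unbind(2)     # each (B, N, H, C//H)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))           # (B, H, N, C//H)
out = nn.functional.scaled_dot_product_attention(q, k, v)  # fused softmax(qk^T/sqrt(d))v
out = out.transpose(1, 2).reshape(B, N, C)                 # merge heads back
assert out.shape == x.shape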
a/dinov2/layers/block.py b/dinov2/layers/block.py deleted file mode 100644 index 7e83b71ccb428ca099d2d1d49933dc837faeecfa..0000000000000000000000000000000000000000 --- a/dinov2/layers/block.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - -import logging -import os -from typing import Callable, List, Any, Tuple, Dict, Optional -import warnings - -import torch -from torch import nn, Tensor - -from .attention import Attention, MemEffAttention -from .drop_path import DropPath -from .layer_scale import LayerScale -from .mlp import Mlp - - -logger = logging.getLogger("dinov2") - - -XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None -try: - if XFORMERS_ENABLED: - from xformers.ops import fmha, scaled_index_add, index_select_cat - - XFORMERS_AVAILABLE = True - warnings.warn("xFormers is available (Block)") - else: - warnings.warn("xFormers is disabled (Block)") - raise ImportError -except ImportError: - XFORMERS_AVAILABLE = False - - warnings.warn("xFormers is not available (Block)") - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = False, - proj_bias: bool = True, - ffn_bias: bool = True, - drop: float = 0.0, - attn_drop: float = 0.0, - init_values=None, - drop_path: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_class: Callable[..., nn.Module] = Attention, - ffn_layer: Callable[..., nn.Module] = Mlp, - ) -> None: - super().__init__() - # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") - self.norm1 = norm_layer(dim) - self.attn = attn_class( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - attn_drop=attn_drop, - proj_drop=drop, - ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - bias=ffn_bias, - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.sample_drop_ratio = drop_path - - def forward(self, x: Tensor) -> Tensor: - def attn_residual_func(x: Tensor) -> Tensor: - return self.ls1(self.attn(self.norm1(x))) - - def ffn_residual_func(x: Tensor) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - if self.training and self.sample_drop_ratio > 0.1: - # the overhead is compensated only for a drop path rate larger than 0.1 - x = drop_add_residual_stochastic_depth( - x, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - ) - x = drop_add_residual_stochastic_depth( - x, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - ) - elif self.training and self.sample_drop_ratio > 0.0: - x = x + self.drop_path1(attn_residual_func(x)) - x = x + self.drop_path2(ffn_residual_func(x)) - else: - x = 
x + ffn_residual_func(x) - return x - - -class CausalAttentionBlock(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - ffn_ratio: float = 4.0, - ls_init_value: Optional[float] = None, - is_causal: bool = True, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - dropout_prob: float = 0.0, - ): - super().__init__() - - self.dim = dim - self.is_causal = is_causal - self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() - self.attention_norm = norm_layer(dim) - self.attention = Attention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob) - - self.ffn_norm = norm_layer(dim) - ffn_hidden_dim = int(dim * ffn_ratio) - self.feed_forward = Mlp( - in_features=dim, - hidden_features=ffn_hidden_dim, - drop=dropout_prob, - act_layer=act_layer, - ) - - self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() - - def init_weights( - self, - init_attn_std: float | None = None, - init_proj_std: float | None = None, - init_fc_std: float | None = None, - factor: float = 1.0, - ) -> None: - init_attn_std = init_attn_std or (self.dim**-0.5) - init_proj_std = init_proj_std or init_attn_std * factor - init_fc_std = init_fc_std or (2 * self.dim) ** -0.5 - self.attention.init_weights(init_attn_std, init_proj_std) - self.attention_norm.reset_parameters() - nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std) - nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std) - self.ffn_norm.reset_parameters() - - def forward( - self, - x: torch.Tensor, - ): - x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal)) - x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn))) - return x_ffn - - -def drop_add_residual_stochastic_depth( - x: Tensor, - residual_func: Callable[[Tensor], Tensor], - sample_drop_ratio: float = 0.0, -) -> Tensor: - # 1) extract subset using permutation - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - x_subset = x[brange] - - # 2) apply residual_func to get residual - residual = residual_func(x_subset) - - x_flat = x.flatten(1) - residual = residual.flatten(1) - - residual_scale_factor = b / sample_subset_size - - # 3) add the residual - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - return x_plus_residual.view_as(x) - - -def get_branges_scales(x, sample_drop_ratio=0.0): - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - residual_scale_factor = b / sample_subset_size - return brange, residual_scale_factor - - -def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): - if scaling_vector is None: - x_flat = x.flatten(1) - residual = residual.flatten(1) - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - else: - x_plus_residual = scaled_index_add( - x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor - ) - return x_plus_residual - - -attn_bias_cache: Dict[Tuple, Any] = {} - - -def get_attn_bias_and_cat(x_list, branges=None): - """ - this will perform the index select, cat the tensors, and provide the attn_bias from cache - """ - batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] - all_shapes = 
tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) - if all_shapes not in attn_bias_cache.keys(): - seqlens = [] - for b, x in zip(batch_sizes, x_list): - for _ in range(b): - seqlens.append(x.shape[1]) - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) - attn_bias._batch_sizes = batch_sizes - attn_bias_cache[all_shapes] = attn_bias - - if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) - else: - tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) - cat_tensors = torch.cat(tensors_bs1, dim=1) - - return attn_bias_cache[all_shapes], cat_tensors - - -def drop_add_residual_stochastic_depth_list( - x_list: List[Tensor], - residual_func: Callable[[Tensor, Any], Tensor], - sample_drop_ratio: float = 0.0, - scaling_vector=None, -) -> List[Tensor]: - # 1) generate random set of indices for dropping samples in the batch - branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] - branges = [s[0] for s in branges_scales] - residual_scale_factors = [s[1] for s in branges_scales] - - # 2) get attention bias and index+concat the tensors - attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) - - # 3) apply residual_func to get residual, and split the result - residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore - - outputs = [] - for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): - outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) - return outputs - - -class NestedTensorBlock(Block): - def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: - """ - x_list contains a list of tensors to nest together and run - """ - assert isinstance(self.attn, MemEffAttention) - - if self.training and self.sample_drop_ratio > 0.0: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.attn(self.norm1(x), attn_bias=attn_bias) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.mlp(self.norm2(x)) - - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, - ) - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None, - ) - return x_list - else: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - attn_bias, x = get_attn_bias_and_cat(x_list) - x = x + attn_residual_func(x, attn_bias=attn_bias) - x = x + ffn_residual_func(x) - return attn_bias.split(x) - - def forward(self, x_or_x_list): - if isinstance(x_or_x_list, Tensor): - return super().forward(x_or_x_list) - elif isinstance(x_or_x_list, list): - if not XFORMERS_AVAILABLE: - raise AssertionError("xFormers is required for using nested tensors") - return self.forward_nested(x_or_x_list) - else: - raise AssertionError diff --git a/dinov2/layers/dino_head.py b/dinov2/layers/dino_head.py deleted file mode 100644 index 0ace8ffd6297a1dd480b19db407b662a6ea0f565..0000000000000000000000000000000000000000 --- 
a/dinov2/layers/dino_head.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from torch.nn.init import trunc_normal_ -from torch.nn.utils import weight_norm - - -class DINOHead(nn.Module): - def __init__( - self, - in_dim, - out_dim, - use_bn=False, - nlayers=3, - hidden_dim=2048, - bottleneck_dim=256, - mlp_bias=True, - ): - super().__init__() - nlayers = max(nlayers, 1) - self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) - self.apply(self._init_weights) - self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) - self.last_layer.weight_g.data.fill_(1) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.mlp(x) - eps = 1e-6 if x.dtype == torch.float16 else 1e-12 - x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) - x = self.last_layer(x) - return x - - -def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): - if nlayers == 1: - return nn.Linear(in_dim, bottleneck_dim, bias=bias) - else: - layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] - if use_bn: - layers.append(nn.BatchNorm1d(hidden_dim)) - layers.append(nn.GELU()) - for _ in range(nlayers - 2): - layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) - if use_bn: - layers.append(nn.BatchNorm1d(hidden_dim)) - layers.append(nn.GELU()) - layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) - return nn.Sequential(*layers) diff --git a/dinov2/layers/drop_path.py b/dinov2/layers/drop_path.py deleted file mode 100644 index 1d640e0b969b8dcba96260243473700b4e5b24b5..0000000000000000000000000000000000000000 --- a/dinov2/layers/drop_path.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py - - -from torch import nn - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - output = x * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) diff --git a/dinov2/layers/layer_scale.py b/dinov2/layers/layer_scale.py deleted file mode 100644 index 0b38971302b3c8fb3d4c05a5f0912fafe0e80816..0000000000000000000000000000000000000000 --- a/dinov2/layers/layer_scale.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
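The DropPath module above implements per-sample stochastic depth: each sample's residual branch is zeroed with probability drop_prob, and survivors are rescaled by 1/keep_prob so the expected output is unchanged. A quick numeric check of that invariant (standalone sketch):

import torch

drop_prob = 0.3
keep_prob = 1 - drop_prob
x = torch.ones(10_000, 4)
mask = x.new_empty((x.shape[0], 1)).bernoulli_(keep_prob).div_(keep_prob)
out = x * mask
# dropped rows are zero, kept rows are scaled by 1/keep_prob, so E[out] == x
assert abs(out.mean().item() - 1.0) < 0.05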
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 - -from typing import Optional, Union - -import torch -from torch import Tensor -from torch import nn - - -class LayerScale(nn.Module): - def __init__( - self, - dim: int, - init_values: Union[float, Tensor] = 1e-5, - inplace: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__() - self.inplace = inplace - self.init_values = init_values - self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) - self.reset_parameters() - - def reset_parameters(self): - nn.init.constant_(self.gamma, self.init_values) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/dinov2/layers/mlp.py b/dinov2/layers/mlp.py deleted file mode 100644 index bbf9432aae9258612caeae910a7bde17999e328e..0000000000000000000000000000000000000000 --- a/dinov2/layers/mlp.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py - - -from typing import Callable, Optional - -from torch import Tensor, nn - - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x diff --git a/dinov2/layers/patch_embed.py b/dinov2/layers/patch_embed.py deleted file mode 100644 index 8b7c0804784a42cf80c0297d110dcc68cc85b339..0000000000000000000000000000000000000000 --- a/dinov2/layers/patch_embed.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - -from typing import Callable, Optional, Tuple, Union - -from torch import Tensor -import torch.nn as nn - - -def make_2tuple(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - - assert isinstance(x, int) - return (x, x) - - -class PatchEmbed(nn.Module): - """ - 2D image to patch embedding: (B,C,H,W) -> (B,N,D) - - Args: - img_size: Image size. - patch_size: Patch token size. - in_chans: Number of input image channels. - embed_dim: Number of linear projection output channels. - norm_layer: Normalization layer. 
- """ - - def __init__( - self, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten_embedding: bool = True, - ) -> None: - super().__init__() - - image_HW = make_2tuple(img_size) - patch_HW = make_2tuple(patch_size) - patch_grid_size = ( - image_HW[0] // patch_HW[0], - image_HW[1] // patch_HW[1], - ) - - self.img_size = image_HW - self.patch_size = patch_HW - self.patches_resolution = patch_grid_size - self.num_patches = patch_grid_size[0] * patch_grid_size[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.flatten_embedding = flatten_embedding - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" - - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - x = self.norm(x) - if not self.flatten_embedding: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - def flops(self) -> float: - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops diff --git a/dinov2/layers/swiglu_ffn.py b/dinov2/layers/swiglu_ffn.py deleted file mode 100644 index 340cee356cb4ad7cb3c8bbefa121f39f7c4e5c6f..0000000000000000000000000000000000000000 --- a/dinov2/layers/swiglu_ffn.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -import os -from typing import Callable, Optional -import warnings - -from torch import Tensor, nn -import torch.nn.functional as F - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) - self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x12 = self.w12(x) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) - - -XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None -try: - if XFORMERS_ENABLED: - from xformers.ops import SwiGLU - - XFORMERS_AVAILABLE = True - warnings.warn("xFormers is available (SwiGLU)") - else: - warnings.warn("xFormers is disabled (SwiGLU)") - raise ImportError -except ImportError: - SwiGLU = SwiGLUFFN - XFORMERS_AVAILABLE = False - - warnings.warn("xFormers is not available (SwiGLU)") - - -class SwiGLUFFNFused(SwiGLU): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - out_features = out_features or in_features - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - out_features=out_features, - bias=bias, - ) - - -class SwiGLUFFNAligned(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - align_to: int = 8, - device=None, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - d = int(hidden_features * 2 / 3) - swiglu_hidden_features = d + (-d % align_to) - self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) - self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) - self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device) - - def forward(self, x: Tensor) -> Tensor: - x1 = self.w1(x) - x2 = self.w2(x) - hidden = F.silu(x1) * x2 - return self.w3(hidden) diff --git a/dinov2/logging/__init__.py b/dinov2/logging/__init__.py deleted file mode 100644 index 04a7f02204316d4d1ef38bf6080dae3d66241c25..0000000000000000000000000000000000000000 --- a/dinov2/logging/__init__.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import functools -import logging -import os -import sys -from typing import Optional - -import dinov2.distributed as distributed -from .helpers import MetricLogger, SmoothedValue - - -# So that calling _configure_logger multiple times won't add many handlers -@functools.lru_cache() -def _configure_logger( - name: Optional[str] = None, - *, - level: int = logging.DEBUG, - output: Optional[str] = None, -): - """ - Configure a logger. - - Adapted from Detectron2. 
- - Args: - name: The name of the logger to configure. - level: The logging level to use. - output: A file name or a directory to save log. If None, will not save log file. - If ends with ".txt" or ".log", assumed to be a file name. - Otherwise, logs will be saved to `output/logs/log.txt`. - - Returns: - The configured logger. - """ - - logger = logging.getLogger(name) - logger.setLevel(level) - logger.propagate = False - - # Loosely match Google glog format: - # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg - # but use a shorter timestamp and include the logger name: - # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg - fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " - fmt_message = "%(message)s" - fmt = fmt_prefix + fmt_message - datefmt = "%Y%m%d %H:%M:%S" - formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) - - # stdout logging for main worker only - if distributed.is_main_process(): - handler = logging.StreamHandler(stream=sys.stdout) - handler.setLevel(logging.DEBUG) - handler.setFormatter(formatter) - logger.addHandler(handler) - - # file logging for all workers - if output: - if os.path.splitext(output)[-1] in (".txt", ".log"): - filename = output - else: - filename = os.path.join(output, "logs", "log.txt") - - if not distributed.is_main_process(): - global_rank = distributed.get_global_rank() - filename = filename + ".rank{}".format(global_rank) - - os.makedirs(os.path.dirname(filename), exist_ok=True) - - handler = logging.StreamHandler(open(filename, "a")) - handler.setLevel(logging.DEBUG) - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger - - -def setup_logging( - output: Optional[str] = None, - *, - name: Optional[str] = None, - level: int = logging.DEBUG, - capture_warnings: bool = True, -) -> None: - """ - Setup logging. - - Args: - output: A file name or a directory to save log files. If None, log - files will not be saved. If output ends with ".txt" or ".log", it - is assumed to be a file name. - Otherwise, logs will be saved to `output/logs/log.txt`. - name: The name of the logger to configure, by default the root logger. - level: The logging level to use. - capture_warnings: Whether warnings should be captured as logs. - """ - logging.captureWarnings(capture_warnings) - _configure_logger(name, level=level, output=output) diff --git a/dinov2/logging/helpers.py b/dinov2/logging/helpers.py deleted file mode 100644 index c6e70bb15505cbbc4c4732b069ee919bf921a74f..0000000000000000000000000000000000000000 --- a/dinov2/logging/helpers.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree.
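Typical use of the setup above (a sketch; it assumes the dinov2 package and its distributed helpers are importable, and "./out" is an arbitrary example directory):

import logging

from dinov2.logging import setup_logging

setup_logging(output="./out")        # the main process also logs to stdout
logger = logging.getLogger("dinov2")
logger.info("training started")      # appended to ./out/logs/log.txt
# non-main ranks write to ./out/logs/log.txt.rank{global_rank} instead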
- -from collections import defaultdict, deque -import datetime -import json -import logging -import time - -import torch - -import dinov2.distributed as distributed - - -logger = logging.getLogger("dinov2") - - -class MetricLogger(object): - def __init__(self, delimiter="\t", output_file=None): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - self.output_file = output_file - - def update(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append("{}: {}".format(name, str(meter))) - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def dump_in_output_file(self, iteration, iter_time, data_time): - if self.output_file is None or not distributed.is_main_process(): - return - dict_to_dump = dict( - iteration=iteration, - iter_time=iter_time, - data_time=data_time, - ) - dict_to_dump.update({k: v.median for k, v in self.meters.items()}) - with open(self.output_file, "a") as f: - f.write(json.dumps(dict_to_dump) + "\n") - pass - - def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): - i = start_iteration - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.6f}") - data_time = SmoothedValue(fmt="{avg:.6f}") - - if n_iterations is None: - n_iterations = len(iterable) - - space_fmt = ":" + str(len(str(n_iterations))) + "d" - - log_list = [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - if torch.cuda.is_available(): - log_list += ["max mem: {memory:.0f}"] - - log_msg = self.delimiter.join(log_list) - MB = 1024.0 * 1024.0 - for obj in iterable: - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == n_iterations - 1: - self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) - eta_seconds = iter_time.global_avg * (n_iterations - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - logger.info( - log_msg.format( - i, - n_iterations, - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB, - ) - ) - else: - logger.info( - log_msg.format( - i, - n_iterations, - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - i += 1 - end = time.time() - if i >= n_iterations: - break - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average. 
- """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, num=1): - self.deque.append(value) - self.count += num - self.total += value * num - - def synchronize_between_processes(self): - """ - Distributed synchronization of the metric - Warning: does not synchronize the deque! - """ - if not distributed.is_enabled(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - torch.distributed.barrier() - torch.distributed.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value, - ) diff --git a/dinov2/loss/__init__.py b/dinov2/loss/__init__.py deleted file mode 100644 index d6b0115b74edbd74b324c9056a57fade363c58fd..0000000000000000000000000000000000000000 --- a/dinov2/loss/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .dino_clstoken_loss import DINOLoss -from .ibot_patch_loss import iBOTPatchLoss -from .koleo_loss import KoLeoLoss diff --git a/dinov2/loss/dino_clstoken_loss.py b/dinov2/loss/dino_clstoken_loss.py deleted file mode 100644 index c31808e36e6c38ee6dae13ba0443bf1946242117..0000000000000000000000000000000000000000 --- a/dinov2/loss/dino_clstoken_loss.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import nn - - -class DINOLoss(nn.Module): - def __init__( - self, - out_dim, - student_temp=0.1, - center_momentum=0.9, - ): - super().__init__() - self.student_temp = student_temp - self.center_momentum = center_momentum - self.register_buffer("center", torch.zeros(1, out_dim)) - self.updated = True - self.reduce_handle = None - self.len_teacher_output = None - self.async_batch_center = None - - @torch.no_grad() - def softmax_center_teacher(self, teacher_output, teacher_temp): - self.apply_center_update() - # teacher centering and sharpening - return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) - - @torch.no_grad() - def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): - teacher_output = teacher_output.float() - world_size = dist.get_world_size() if dist.is_initialized() else 1 - Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper - B = Q.shape[1] * world_size # number of samples to assign - K = Q.shape[0] # how many prototypes - - # make the matrix sums to 1 - sum_Q = torch.sum(Q) - if dist.is_initialized(): - dist.all_reduce(sum_Q) - Q /= sum_Q - - for it in range(n_iterations): - # normalize each row: total weight per prototype must be 1/K - sum_of_rows = torch.sum(Q, dim=1, keepdim=True) - if dist.is_initialized(): - dist.all_reduce(sum_of_rows) - Q /= sum_of_rows - Q /= K - - # normalize each column: total weight per sample must be 1/B - Q /= torch.sum(Q, dim=0, keepdim=True) - Q /= B - - Q *= B # the columns must sum to 1 so that Q is an assignment - return Q.t() - - def forward(self, student_output_list, teacher_out_softmaxed_centered_list): - """ - Cross-entropy between softmax outputs of the teacher and student networks. - """ - # TODO: Use cross_entropy_distribution here - total_loss = 0 - for s in student_output_list: - lsm = F.log_softmax(s / self.student_temp, dim=-1) - for t in teacher_out_softmaxed_centered_list: - loss = torch.sum(t * lsm, dim=-1) - total_loss -= loss.mean() - return total_loss - - @torch.no_grad() - def update_center(self, teacher_output): - self.reduce_center_update(teacher_output) - - @torch.no_grad() - def reduce_center_update(self, teacher_output): - self.updated = False - self.len_teacher_output = len(teacher_output) - self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) - if dist.is_initialized(): - self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) - - @torch.no_grad() - def apply_center_update(self): - if self.updated is False: - world_size = dist.get_world_size() if dist.is_initialized() else 1 - - if self.reduce_handle is not None: - self.reduce_handle.wait() - _t = self.async_batch_center / (self.len_teacher_output * world_size) - - self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) - - self.updated = True diff --git a/dinov2/loss/ibot_patch_loss.py b/dinov2/loss/ibot_patch_loss.py deleted file mode 100644 index 6732cda0c311c69f193669ebc950fc8665871442..0000000000000000000000000000000000000000 --- a/dinov2/loss/ibot_patch_loss.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
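The sinkhorn_knopp_teacher method above alternates row and column normalizations so prototypes are used equally (each row sums to 1/K) while each sample's assignment stays a distribution. A single-process sketch of the same iteration (temperature and sizes are arbitrary example values):

import torch

logits = torch.randn(8, 32)          # B samples, K prototypes
Q = torch.exp(logits / 0.1).t()      # K x B, sharpened
Q /= Q.sum()
K, B = Q.shape
for _ in range(3):
    Q /= Q.sum(dim=1, keepdim=True)  # rows: each prototype gets weight 1/K
    Q /= K
    Q /= Q.sum(dim=0, keepdim=True)  # columns: each sample gets weight 1/B
    Q /= B
Q *= B                               # columns now sum to 1: soft assignments
assert torch.allclose(Q.sum(dim=0), torch.ones(B), atol=1e-4)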
- -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import nn - -import logging - - -logger = logging.getLogger("dinov2") - - -try: - from xformers.ops import cross_entropy - - def lossfunc(t, s, temp): - s = s.float() - t = t.float() - if s.ndim == 2: - return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) - elif s.ndim == 3: - return -cross_entropy(s, t, temp, bw_inplace=True) - -except ImportError: - - def lossfunc(t, s, temp): - return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) - - -class iBOTPatchLoss(nn.Module): - def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): - super().__init__() - self.student_temp = student_temp - self.center_momentum = center_momentum - self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) - self.updated = True - self.reduce_handle = None - self.len_teacher_patch_tokens = None - self.async_batch_center = None - - @torch.no_grad() - def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): - self.apply_center_update() - # teacher centering and sharpening - # - # WARNING: - # as self.center is a float32, everything gets casted to float32 afterwards - # - # teacher_patch_tokens = teacher_patch_tokens.float() - # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) - - return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) - - # this is experimental, keep everything in float16 and let's see what happens: - # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) - - @torch.no_grad() - def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): - teacher_output = teacher_output.float() - # world_size = dist.get_world_size() if dist.is_initialized() else 1 - Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper - # B = Q.shape[1] * world_size # number of samples to assign - B = n_masked_patches_tensor - dist.all_reduce(B) - K = Q.shape[0] # how many prototypes - - # make the matrix sums to 1 - sum_Q = torch.sum(Q) - if dist.is_initialized(): - dist.all_reduce(sum_Q) - Q /= sum_Q - - for it in range(n_iterations): - # normalize each row: total weight per prototype must be 1/K - sum_of_rows = torch.sum(Q, dim=1, keepdim=True) - if dist.is_initialized(): - dist.all_reduce(sum_of_rows) - Q /= sum_of_rows - Q /= K - - # normalize each column: total weight per sample must be 1/B - Q /= torch.sum(Q, dim=0, keepdim=True) - Q /= B - - Q *= B # the columns must sum to 1 so that Q is an assignment - return Q.t() - - def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): - """ - Cross-entropy between softmax outputs of the teacher and student networks. 
- student_patch_tokens: (B, N, D) tensor - teacher_patch_tokens: (B, N, D) tensor - student_masks_flat: (B, N) tensor - """ - t = teacher_patch_tokens - s = student_patch_tokens - loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) - loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) - return -loss.mean() - - def forward_masked( - self, - student_patch_tokens_masked, - teacher_patch_tokens_masked, - student_masks_flat, - n_masked_patches=None, - masks_weight=None, - ): - t = teacher_patch_tokens_masked - s = student_patch_tokens_masked - # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) - loss = lossfunc(t, s, self.student_temp) - if masks_weight is None: - masks_weight = ( - (1 / student_masks_flat.sum(-1).clamp(min=1.0)) - .unsqueeze(-1) - .expand_as(student_masks_flat)[student_masks_flat] - ) - if n_masked_patches is not None: - loss = loss[:n_masked_patches] - loss = loss * masks_weight - return -loss.sum() / student_masks_flat.shape[0] - - @torch.no_grad() - def update_center(self, teacher_patch_tokens): - self.reduce_center_update(teacher_patch_tokens) - - @torch.no_grad() - def reduce_center_update(self, teacher_patch_tokens): - self.updated = False - self.len_teacher_patch_tokens = len(teacher_patch_tokens) - self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) - if dist.is_initialized(): - self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) - - @torch.no_grad() - def apply_center_update(self): - if self.updated is False: - world_size = dist.get_world_size() if dist.is_initialized() else 1 - - if self.reduce_handle is not None: - self.reduce_handle.wait() - _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) - - self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) - - self.updated = True diff --git a/dinov2/loss/koleo_loss.py b/dinov2/loss/koleo_loss.py deleted file mode 100644 index b5cbcd91e0fc0b857f477b0910f957f02a6c4335..0000000000000000000000000000000000000000 --- a/dinov2/loss/koleo_loss.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# import torch.distributed as dist - - -logger = logging.getLogger("dinov2") - - -class KoLeoLoss(nn.Module): - """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" - - def __init__(self): - super().__init__() - self.pdist = nn.PairwiseDistance(2, eps=1e-8) - - def pairwise_NNs_inner(self, x): - """ - Pairwise nearest neighbors for L2-normalized vectors. - Uses Torch rather than Faiss to remain on GPU. 
- """ - # parwise dot products (= inverse distance) - dots = torch.mm(x, x.t()) - n = x.shape[0] - dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 - # max inner prod -> min distance - _, I = torch.max(dots, dim=1) # noqa: E741 - return I - - def forward(self, student_output, eps=1e-8): - """ - Args: - student_output (BxD): backbone output of student - """ - with torch.cuda.amp.autocast(enabled=False): - student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) - I = self.pairwise_NNs_inner(student_output) # noqa: E741 - distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B - loss = -torch.log(distances + eps).mean() - return loss diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py deleted file mode 100644 index 817a63aeb07802ca02f75c9cef674b97355deb3f..0000000000000000000000000000000000000000 --- a/dinov2/models/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging - -from . import vision_transformer as vits - - -logger = logging.getLogger("dinov2") - - -def build_model(args, only_teacher=False, img_size=224): - args.arch = args.arch.removesuffix("_memeff") - if "vit" in args.arch: - vit_kwargs = dict( - img_size=img_size, - patch_size=args.patch_size, - init_values=args.layerscale, - ffn_layer=args.ffn_layer, - block_chunks=args.block_chunks, - qkv_bias=args.qkv_bias, - proj_bias=args.proj_bias, - ffn_bias=args.ffn_bias, - num_register_tokens=args.num_register_tokens, - interpolate_offset=args.interpolate_offset, - interpolate_antialias=args.interpolate_antialias, - in_chans=args.in_chans, - channel_adaptive=args.channel_adaptive, - ) - teacher = vits.__dict__[args.arch](**vit_kwargs) - if only_teacher: - return teacher, teacher.embed_dim - student = vits.__dict__[args.arch]( - **vit_kwargs, - drop_path_rate=args.drop_path_rate, - drop_path_uniform=args.drop_path_uniform, - ) - embed_dim = student.embed_dim - return student, teacher, embed_dim - - -def build_model_from_cfg(cfg, only_teacher=False): - return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py deleted file mode 100644 index 34694244ae4e6467a4aa315180a89b323336bf0b..0000000000000000000000000000000000000000 --- a/dinov2/models/vision_transformer.py +++ /dev/null @@ -1,428 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
- -# References: -# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -from functools import partial -import math -import logging -from typing import Sequence, Tuple, Union, Callable - -import numpy as np -import torch -import torch.nn as nn -import torch.utils.checkpoint -from torch.nn.init import trunc_normal_ - -from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - - -logger = logging.getLogger("dinov2") - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class BlockChunk(nn.ModuleList): - def forward(self, x): - for b in self: - x = b(x) - return x - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - init_values=None, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=Block, - ffn_layer="mlp", - block_chunks=1, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - channel_adaptive=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - num_register_tokens: (int) number of extra cls tokens (so-called "registers") - interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings - interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.num_register_tokens = num_register_tokens - self.interpolate_antialias = interpolate_antialias - self.interpolate_offset = interpolate_offset - self.bag_of_channels = channel_adaptive - - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, 
embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - assert num_register_tokens >= 0 - self.register_tokens = ( - nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None - ) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = np.linspace(0, drop_path_rate, depth).tolist() # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.pos_embed, std=0.02) - nn.init.normal_(self.cls_token, std=1e-6) - if self.register_tokens is not None: - nn.init.normal_(self.register_tokens, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - M = int(math.sqrt(N)) # Recover the number of patches in each dimension - assert N == M * M - kwargs = {} - if self.interpolate_offset: - # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 - # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors - sx = float(w0 + self.interpolate_offset) / M - sy = float(h0 + self.interpolate_offset) / M - kwargs["scale_factor"] = (sx, sy) - else: - # Simply specify an output size instead of a scale factor - kwargs["size"] = (w0, h0) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), - mode="bicubic", - antialias=self.interpolate_antialias, - **kwargs, - ) - assert (w0, h0) == patch_pos_embed.shape[-2:] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens_with_masks(self, x, 
masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - if self.register_tokens is not None: - x = torch.cat( - ( - x[:, :1], - self.register_tokens.expand(x.shape[0], -1, -1), - x[:, 1:], - ), - dim=1, - ) - - return x - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - for blk in self.blocks: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - x = self.prepare_tokens_with_masks(x, masks) - - for blk in self.blocks: - x = blk(x) - - x_norm = self.norm(x) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. 
If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - - if self.bag_of_channels: - B, C, H, W = x.shape - x = x.reshape(B * C, 1, H, W) # passing channels to batch dimension to get encodings for each channel - - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - - if self.bag_of_channels: - output = tuple(zip(outputs, class_tokens)) - output = list( - zip(*output) - ) # unzip the tuple: (list of patch_tokens per block, list of class tokens per block) - patch_tokens_per_block = output[0] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, N, D - cls_tokens_per_block = output[1] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, D - patch_tokens_per_block = [ - patch_tokens.reshape(B, C, patch_tokens.shape[-2], patch_tokens.shape[-1]) - for patch_tokens in patch_tokens_per_block - ] # [BLOCK1, BLOCK2, ...] 
where BLOCK1.shape: B, C, N, D - cls_tokens_per_block = [cls_tokens.reshape(B, -1) for cls_tokens in cls_tokens_per_block] - output = tuple(zip(patch_tokens_per_block, cls_tokens_per_block)) - return output - - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=False, **kwargs): - ret = self.forward_features(*args, **kwargs) - if is_training: - return ret - else: - return self.head(ret["x_norm_clstoken"]) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def vit_small(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - in_chans=in_chans, - channel_adaptive=channel_adaptive, - **kwargs, - ) - return model - - -def vit_base(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - in_chans=in_chans, - channel_adaptive=channel_adaptive, - **kwargs, - ) - return model - - -def vit_large(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - in_chans=in_chans, - channel_adaptive=channel_adaptive, - **kwargs, - ) - return model - - -def vit_giant2(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - in_chans=in_chans, - channel_adaptive=channel_adaptive, - **kwargs, - ) - return model diff --git a/dinov2/run/__init__.py b/dinov2/run/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/run/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/run/eval/cell_dino/knn.py b/dinov2/run/eval/cell_dino/knn.py deleted file mode 100644 index aa560bbe1aa9de4e3bbd51fd6aba8a741d57ef89..0000000000000000000000000000000000000000 --- a/dinov2/run/eval/cell_dino/knn.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. 
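
A note on the channel-adaptive branch of get_intermediate_layers deleted above: when bag_of_channels is set, each channel of a (B, C, H, W) batch is encoded as its own single-channel image, and the per-block outputs are regrouped afterwards so the channel axis reappears. A shape-level sketch of that bookkeeping, with all sizes illustrative:

    import torch

    B, C, H, W, patch_size, D = 2, 4, 224, 224, 16, 1024  # illustrative sizes (D ~ ViT-L embed_dim)
    x = torch.randn(B, C, H, W)

    # bag of channels: fold the channel axis into the batch axis
    x = x.reshape(B * C, 1, H, W)                # (B*C, 1, H, W) single-channel images
    N = (H // patch_size) * (W // patch_size)    # patch tokens per image

    # stand-ins for one block's backbone outputs
    patch_tokens = torch.randn(B * C, N, D)
    cls_tokens = torch.randn(B * C, D)

    # regroup so channels become an explicit axis again
    patch_tokens = patch_tokens.reshape(B, C, N, D)  # (B, C, N, D)
    cls_tokens = cls_tokens.reshape(B, -1)           # (B, C*D): per-channel CLS tokens concatenated

Downstream, each image therefore yields one CLS vector per channel rather than a single pooled one.
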
- -import logging -import os -import sys - -from dinov2.eval.cell_dino.knn import get_args_parser as get_knn_args_parser -from dinov2.logging import setup_logging -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Evaluator: - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.eval.cell_dino.knn import main as knn_main - - self._setup_args() - knn_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for k-NN Cell-DINO and Channel-Adaptive DINO evaluation" - knn_args_parser = get_knn_args_parser(add_help=False) - parents = [knn_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" - submit_jobs(Evaluator, args, name="dinov2:knn Cell-DINO") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/run/eval/cell_dino/linear.py b/dinov2/run/eval/cell_dino/linear.py deleted file mode 100644 index 89319bc4c412adc6587855715d7e7237f4aeaaa7..0000000000000000000000000000000000000000 --- a/dinov2/run/eval/cell_dino/linear.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. - -import logging -import os -import sys - -from dinov2.eval.cell_dino.linear import get_args_parser as get_linear_args_parser -from dinov2.logging import setup_logging -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Evaluator: - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.eval.cell_dino.linear import main as linear_main - - self._setup_args() - linear_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for DINOv2 linear Cell-DINO and Channel-Adaptive DINO evaluation" - linear_args_parser = get_linear_args_parser(add_help=False) - parents = [linear_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" 
- submit_jobs(Evaluator, args, name="dinov2:linear Cell-DINO") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/run/eval/knn.py b/dinov2/run/eval/knn.py deleted file mode 100644 index d11918445cdfe415fe58ac8b3ad0bf29702e3457..0000000000000000000000000000000000000000 --- a/dinov2/run/eval/knn.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging -import os -import sys - -from dinov2.eval.knn import get_args_parser as get_knn_args_parser -from dinov2.logging import setup_logging -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Evaluator: - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.eval.knn import main as knn_main - - self._setup_args() - knn_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for DINOv2 k-NN evaluation" - knn_args_parser = get_knn_args_parser(add_help=False) - parents = [knn_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" - submit_jobs(Evaluator, args, name="dinov2:knn") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/run/eval/linear.py b/dinov2/run/eval/linear.py deleted file mode 100644 index e1dc3293e88512a5cf885ab775dc08e01aed6724..0000000000000000000000000000000000000000 --- a/dinov2/run/eval/linear.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. 
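
All of these run/eval launchers share one submitit contract: the submitted object is a picklable callable, and its checkpoint() method is invoked on SLURM preemption or timeout to requeue a fresh copy of the job. A stripped-down sketch of that contract (folder, job name and parameters are illustrative):

    import submitit

    class Task:
        def __init__(self, args):
            self.args = args

        def __call__(self):
            # resolve the "%j" placeholder once the job id is known
            job_env = submitit.JobEnvironment()
            output_dir = self.args["output_dir"].replace("%j", str(job_env.job_id))
            print(f"running in {output_dir}")  # the real launchers call the eval/train main() here

        def checkpoint(self):
            # called on preemption/timeout: resubmit a fresh copy of this task
            return submitit.helpers.DelayedSubmission(type(self)(self.args))

    executor = submitit.AutoExecutor(folder="logs/%j")
    executor.update_parameters(name="demo", timeout_min=60)
    job = executor.submit(Task({"output_dir": "logs/%j"}))

Building a new instance inside checkpoint() mirrors the `empty = type(self)(self.args)` pattern above, which avoids pickling any state the task accumulated while running.
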
- -import logging -import os -import sys - -from dinov2.eval.linear import get_args_parser as get_linear_args_parser -from dinov2.logging import setup_logging -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Evaluator: - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.eval.linear import main as linear_main - - self._setup_args() - linear_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for DINOv2 linear evaluation" - linear_args_parser = get_linear_args_parser(add_help=False) - parents = [linear_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" - submit_jobs(Evaluator, args, name="dinov2:linear") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/run/eval/log_regression.py b/dinov2/run/eval/log_regression.py deleted file mode 100644 index cdf02181122de72cfa463ef38494967219df9cf3..0000000000000000000000000000000000000000 --- a/dinov2/run/eval/log_regression.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging -import os -import sys - -from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser -from dinov2.logging import setup_logging -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Evaluator: - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.eval.log_regression import main as log_regression_main - - self._setup_args() - log_regression_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for DINOv2 logistic evaluation" - log_regression_args_parser = get_log_regression_args_parser(add_help=False) - parents = [log_regression_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" 
- submit_jobs(Evaluator, args, name="dinov2:logreg") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/run/submit.py b/dinov2/run/submit.py deleted file mode 100644 index 4d1f718e704cf9a48913422404c25a7fcc50e738..0000000000000000000000000000000000000000 --- a/dinov2/run/submit.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import argparse -import logging -import os -from pathlib import Path -from typing import List, Optional - -import submitit - -from dinov2.utils.cluster import ( - get_slurm_executor_parameters, - get_slurm_partition, - get_user_checkpoint_path, -) - - -logger = logging.getLogger("dinov2") - - -def get_args_parser( - description: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, - add_help: bool = True, -) -> argparse.ArgumentParser: - parents = parents or [] - slurm_partition = get_slurm_partition() - parser = argparse.ArgumentParser( - description=description, - parents=parents, - add_help=add_help, - ) - parser.add_argument( - "--ngpus", - "--gpus", - "--gpus-per-node", - default=8, - type=int, - help="Number of GPUs to request on each node", - ) - parser.add_argument( - "--nodes", - "--nnodes", - default=1, - type=int, - help="Number of nodes to request", - ) - parser.add_argument( - "--timeout", - default=2800, - type=int, - help="Duration of the job", - ) - parser.add_argument( - "--partition", - default=slurm_partition, - type=str, - help="Partition where to submit", - ) - parser.add_argument( - "--use-volta32", - action="store_true", - help="Request V100-32GB GPUs", - ) - parser.add_argument( - "--comment", - default="", - type=str, - help="Comment to pass to scheduler, e.g. 
priority message", - ) - parser.add_argument( - "--exclude", - default="", - type=str, - help="Nodes to exclude", - ) - return parser - - -def get_shared_folder() -> Path: - user_checkpoint_path = get_user_checkpoint_path() - if user_checkpoint_path is None: - raise RuntimeError("Path to user checkpoint cannot be determined") - path = user_checkpoint_path / "experiments" - path.mkdir(exist_ok=True) - return path - - -def submit_jobs(task_class, args, name: str): - if not args.output_dir: - args.output_dir = str(get_shared_folder() / "%j") - - Path(args.output_dir).mkdir(parents=True, exist_ok=True) - executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) - - kwargs = {} - if args.use_volta32: - kwargs["slurm_constraint"] = "volta32gb" - if args.comment: - kwargs["slurm_comment"] = args.comment - if args.exclude: - kwargs["slurm_exclude"] = args.exclude - - executor_params = get_slurm_executor_parameters( - nodes=args.nodes, - num_gpus_per_node=args.ngpus, - timeout_min=args.timeout, # max is 60 * 72 - slurm_signal_delay_s=120, - slurm_partition=args.partition, - **kwargs, - ) - executor.update_parameters(name=name, **executor_params) - - task = task_class(args) - job = executor.submit(task) - - logger.info(f"Submitted job_id: {job.job_id}") - str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) - logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") diff --git a/dinov2/run/train/train.py b/dinov2/run/train/train.py deleted file mode 100644 index c2366e9bf79765e6abcd70dda6b43f31cb7093eb..0000000000000000000000000000000000000000 --- a/dinov2/run/train/train.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import logging -import os -import sys - -from dinov2.logging import setup_logging -from dinov2.train import get_args_parser as get_train_args_parser -from dinov2.run.submit import get_args_parser, submit_jobs - - -logger = logging.getLogger("dinov2") - - -class Trainer(object): - def __init__(self, args): - self.args = args - - def __call__(self): - from dinov2.train import main as train_main - - self._setup_args() - train_main(self.args) - - def checkpoint(self): - import submitit - - logger.info(f"Requeuing {self.args}") - empty = type(self)(self.args) - return submitit.helpers.DelayedSubmission(empty) - - def _setup_args(self): - import submitit - - job_env = submitit.JobEnvironment() - self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) - logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") - logger.info(f"Args: {self.args}") - - -def main(): - description = "Submitit launcher for DINOv2 training" - train_args_parser = get_train_args_parser(add_help=False) - parents = [train_args_parser] - args_parser = get_args_parser(description=description, parents=parents) - args = args_parser.parse_args() - - setup_logging() - - assert os.path.exists(args.config_file), "Configuration file does not exist!" 
- submit_jobs(Trainer, args, name="dinov2:train") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/dinov2/thirdparty/CLIP/LICENSE b/dinov2/thirdparty/CLIP/LICENSE deleted file mode 100644 index c123b69334717d178daa674c2d08e3383fe36134..0000000000000000000000000000000000000000 --- a/dinov2/thirdparty/CLIP/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 OpenAI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py b/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py deleted file mode 100644 index 22b171880f6f1ece11dc84642b19bf6421cef92b..0000000000000000000000000000000000000000 --- a/dinov2/thirdparty/CLIP/clip/simple_tokenizer.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns a list of utf-8 bytes and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - This also avoids mapping to whitespace/control characters that the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings).
- """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") - merges = merges[1 : 49152 - 256 - 2 + 1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + "" for v in vocab] - for merge in merges: - vocab.append("".join(merge)) - vocab.extend(["<|startoftext|>", "<|endoftext|>"]) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"} - self.pat = re.compile( - r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", - re.IGNORECASE, - ) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) - pairs = get_pairs(word) - - if not pairs: - return token + "" - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ") - return text diff --git a/dinov2/train/__init__.py b/dinov2/train/__init__.py deleted file mode 100644 index 5f1752922d04fff0112eb7796be28ff6b68c6073..0000000000000000000000000000000000000000 --- a/dinov2/train/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from .train import get_args_parser, main -from .ssl_meta_arch import SSLMetaArch diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py deleted file mode 100644 index 3ccf15e904ebeb6134dfb4f5c99da4fc8d41b8e4..0000000000000000000000000000000000000000 --- a/dinov2/train/ssl_meta_arch.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -from functools import partial -import logging - -import torch -from torch import nn - -from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss -from dinov2.models import build_model_from_cfg -from dinov2.layers import DINOHead -from dinov2.utils.utils import has_batchnorms -from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups -from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model - -from dinov2.models.vision_transformer import BlockChunk - - -try: - from xformers.ops import fmha -except ImportError: - raise AssertionError("xFormers is required for training") - - -logger = logging.getLogger("dinov2") - - -class SSLMetaArch(nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None - - student_model_dict = dict() - teacher_model_dict = dict() - - student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) - student_model_dict["backbone"] = student_backbone - teacher_model_dict["backbone"] = teacher_backbone - logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") - - if cfg.student.pretrained_weights: - chkpt = torch.load(cfg.student.pretrained_weights) - logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") - student_backbone.load_state_dict(chkpt["model"], strict=False) - - self.embed_dim = embed_dim - self.dino_out_dim = cfg.dino.head_n_prototypes - - self.do_dino = cfg.dino.loss_weight > 0 - self.do_koleo = cfg.dino.koleo_loss_weight > 0 - self.do_ibot = cfg.ibot.loss_weight > 0 - self.ibot_separate_head = cfg.ibot.separate_head - - logger.info("OPTIONS -- DINO") - if self.do_dino: - logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") - logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") - logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") - logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") - self.dino_loss_weight = cfg.dino.loss_weight - dino_head = partial( - DINOHead, - in_dim=embed_dim, - out_dim=cfg.dino.head_n_prototypes, - hidden_dim=cfg.dino.head_hidden_dim, - bottleneck_dim=cfg.dino.head_bottleneck_dim, - nlayers=cfg.dino.head_nlayers, - ) - self.dino_loss = DINOLoss(self.dino_out_dim) - if self.do_koleo: - logger.info("OPTIONS -- DINO -- applying KOLEO regularization") - self.koleo_loss = KoLeoLoss() - - else: - logger.info("OPTIONS -- DINO -- not using DINO") - - if self.do_dino or self.do_ibot: - student_model_dict["dino_head"] = dino_head() - teacher_model_dict["dino_head"] = dino_head() - - logger.info("OPTIONS -- IBOT") - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") - if self.do_ibot: - self.ibot_loss_weight = cfg.ibot.loss_weight - assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" - assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" - self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes - 
self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) - if self.ibot_separate_head: - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") - logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") - logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") - ibot_head = partial( - DINOHead, - in_dim=embed_dim, - out_dim=cfg.ibot.head_n_prototypes, - hidden_dim=cfg.ibot.head_hidden_dim, - bottleneck_dim=cfg.ibot.head_bottleneck_dim, - nlayers=cfg.ibot.head_nlayers, - ) - student_model_dict["ibot_head"] = ibot_head() - teacher_model_dict["ibot_head"] = ibot_head() - else: - logger.info("OPTIONS -- IBOT -- head shared with DINO") - - self.need_to_synchronize_fsdp_streams = True - - self.student = nn.ModuleDict(student_model_dict) - self.teacher = nn.ModuleDict(teacher_model_dict) - - # there is no backpropagation through the teacher, so no need for gradients - for p in self.teacher.parameters(): - p.requires_grad = False - logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} networks.") - - def forward(self, inputs): - raise NotImplementedError - - def backprop_loss(self, loss): - if self.fp16_scaler is not None: - self.fp16_scaler.scale(loss).backward() - else: - loss.backward() - - def forward_backward(self, images, teacher_temp): - n_global_crops = 2 - assert n_global_crops == 2 - n_local_crops = self.cfg.crops.local_crops_number - - global_crops = images["collated_global_crops"].cuda(non_blocking=True) - local_crops = images["collated_local_crops"].cuda(non_blocking=True) - - masks = images["collated_masks"].cuda(non_blocking=True) - mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) - n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) - n_masked_patches = mask_indices_list.shape[0] - upperbound = images["upperbound"] - masks_weight = images["masks_weight"].cuda(non_blocking=True) - - n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) - n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops - - do_dino = self.do_dino - do_ibot = self.do_ibot - - # loss scales - ibot_loss_scale = 1.0 / n_global_crops - - # teacher output - @torch.no_grad() - def get_teacher_output(): - x, n_global_crops_teacher = global_crops, n_global_crops - teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) - teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] - teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) - # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss - teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) - ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] - _dim = ibot_teacher_patch_tokens.shape[-1] - n_cls_tokens = teacher_cls_tokens.shape[0] - - if do_ibot and not self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) - buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) - torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), - dim=0, - index=mask_indices_list, - out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], - ) - tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) - teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] -
masked_teacher_patch_tokens_after_head = tokens_after_head[ - n_cls_tokens : n_cls_tokens + n_masked_patches - ] - elif do_ibot and self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) - torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), - dim=0, - index=mask_indices_list, - out=buffer_tensor_teacher[:n_masked_patches], - ) - teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ - :n_masked_patches - ] - else: - teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - masked_teacher_ibot_softmaxed_centered = None - - if self.cfg.train.centering == "centering": - teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( - teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) - self.dino_loss.update_center(teacher_cls_tokens_after_head) - if do_ibot: - masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) - masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( - masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp - ) - masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) - self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) - - elif self.cfg.train.centering == "sinkhorn_knopp": - teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( - teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) - - if do_ibot: - masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( - masked_teacher_patch_tokens_after_head, - teacher_temp=teacher_temp, - n_masked_patches_tensor=n_masked_patches_tensor, - ) - - else: - raise NotImplementedError - - return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered - - teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() - reshard_fsdp_model(self.teacher) - - loss_dict = {} - - loss_accumulator = 0 # for backprop - student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( - [global_crops, local_crops], masks=[masks, None], is_training=True - ) - - inputs_for_student_head_list = [] - - # 1a: local crops cls tokens - student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] - inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) - - # 1b: global crops cls tokens - student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] - inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) - - # 1c: global crops patch tokens - if do_ibot: - _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] - ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] - buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) - buffer_tensor_patch_tokens[:n_masked_patches].copy_( - torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) - ) - if not self.ibot_separate_head: - inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) - else: - 
student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ - :n_masked_patches - ] - - # 2: run - _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) - outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) - - # 3a: local crops cls tokens - student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) - - # 3b: global crops cls tokens - student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) - - # 3c: global crops patch tokens - if do_ibot and not self.ibot_separate_head: - student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] - - if n_local_crops > 0: - dino_local_crops_loss = self.dino_loss( - student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), - teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, - ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) - - # store for display - loss_dict["dino_local_crops_loss"] = dino_local_crops_loss - - # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_local_crops_loss - - # process global crops - loss_scales = 2 # this is here since we process global crops together - - if do_dino: - # compute loss - dino_global_crops_loss = ( - self.dino_loss( - student_output_list=[student_global_cls_tokens_after_head], - teacher_out_softmaxed_centered_list=[ - teacher_dino_softmaxed_centered_list.flatten(0, 1) - ], # these were chunked and stacked in reverse so A is matched to B - ) - * loss_scales - / (n_global_crops_loss_terms + n_local_crops_loss_terms) - ) - - loss_dict["dino_global_crops_loss"] = dino_global_crops_loss - - # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_global_crops_loss - - student_cls_tokens = student_global_cls_tokens - - if self.do_koleo: - koleo_loss = self.cfg.dino.koleo_loss_weight * sum( - self.koleo_loss(p) for p in student_cls_tokens.chunk(2) - ) # we don't apply koleo loss between cls tokens of a same image - loss_accumulator += koleo_loss - loss_dict["koleo_loss"] = ( - koleo_loss / loss_scales - ) # this is to display the same losses as before but we can remove eventually - - if do_ibot: - # compute loss - ibot_patch_loss = ( - self.ibot_patch_loss.forward_masked( - student_global_masked_patch_tokens_after_head, - masked_teacher_ibot_softmaxed_centered, - student_masks_flat=masks, - n_masked_patches=n_masked_patches, - masks_weight=masks_weight, - ) - * loss_scales - * ibot_loss_scale - ) - - # store for display - loss_dict["ibot_loss"] = ibot_patch_loss / 2 - - # accumulate loss - loss_accumulator += self.ibot_loss_weight * ibot_patch_loss - - self.backprop_loss(loss_accumulator) - - self.fsdp_synchronize_streams() - - return loss_dict - - def fsdp_synchronize_streams(self): - if self.need_to_synchronize_fsdp_streams: - torch.cuda.synchronize() - self.student.dino_head._streams = ( - self.teacher.dino_head._streams - ) = self.student.backbone._streams = self.teacher.backbone._streams - self.need_to_synchronize_fsdp_streams = False - - def update_teacher(self, m): - student_param_list = [] - teacher_param_list = [] - with torch.no_grad(): - for k in self.student.keys(): - for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): - student_param_list += ms.params - teacher_param_list += mt.params - torch._foreach_mul_(teacher_param_list, m) - torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) - - def train(self): - 
super().train() - self.teacher.eval() - - def get_maybe_fused_params_for_submodel(self, m): - params_groups = get_params_groups_with_decay( - model=m, - lr_decay_rate=self.cfg.optim.layerwise_decay, - patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, - ) - fused_params_groups = fuse_params_groups(params_groups) - logger.info("fusing param groups") - - for g in fused_params_groups: - g["foreach"] = True - return fused_params_groups - - def get_params_groups(self): - all_params_groups = [] - for m in self.student.values(): - all_params_groups += self.get_maybe_fused_params_for_submodel(m) - return all_params_groups - - def prepare_for_distributed_training(self): - logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") - if has_batchnorms(self.student): - raise NotImplementedError - # below will synchronize all student subnetworks across gpus: - for k, v in self.student.items(): - self.teacher[k].load_state_dict(self.student[k].state_dict()) - student_model_cfg = self.cfg.compute_precision.student[k] - self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) - teacher_model_cfg = self.cfg.compute_precision.teacher[k] - self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/dinov2/train/train.py b/dinov2/train/train.py deleted file mode 100644 index 4e86b8daaff01e32acd1f30edcfd65aed03df4db..0000000000000000000000000000000000000000 --- a/dinov2/train/train.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import argparse -import logging -import math -import os -from functools import partial - -from fvcore.common.checkpoint import PeriodicCheckpointer -import torch - -from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data import collate_data_and_cast, DataAugmentationDINO, CellAugmentationDINO, MaskingGenerator -import dinov2.distributed as distributed -from dinov2.fsdp import FSDPCheckpointer -from dinov2.logging import MetricLogger -from dinov2.utils.config import setup -from dinov2.utils.utils import CosineScheduler - -from dinov2.train.ssl_meta_arch import SSLMetaArch - - -torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default -logger = logging.getLogger("dinov2") - - -def get_args_parser(add_help: bool = True): - parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) - parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") - parser.add_argument( - "--no-resume", - action="store_true", - help="Whether to not attempt to resume from the checkpoint directory. ", - ) - parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") - parser.add_argument("--eval", type=str, default="", help="Eval type to perform") - parser.add_argument( - "opts", - help=""" -Modify config options at the end of the command. For Yacs configs, use -space-separated "PATH.KEY VALUE" pairs. -For python-based LazyConfig, use "path.key=value". 
- """.strip(), - default=None, - nargs=argparse.REMAINDER, - ) - parser.add_argument( - "--output-dir", - "--output_dir", - default="", - type=str, - help="Output directory to save logs and checkpoints", - ) - - return parser - - -def build_optimizer(cfg, params_groups): - return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) - - -def build_schedulers(cfg): - OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH - lr = dict( - base_value=cfg.optim["lr"], - final_value=cfg.optim["min_lr"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, - start_warmup_value=0, - ) - wd = dict( - base_value=cfg.optim["weight_decay"], - final_value=cfg.optim["weight_decay_end"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - ) - momentum = dict( - base_value=cfg.teacher["momentum_teacher"], - final_value=cfg.teacher["final_momentum_teacher"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - ) - teacher_temp = dict( - base_value=cfg.teacher["teacher_temp"], - final_value=cfg.teacher["teacher_temp"], - total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, - warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, - start_warmup_value=cfg.teacher["warmup_teacher_temp"], - ) - - lr_schedule = CosineScheduler(**lr) - wd_schedule = CosineScheduler(**wd) - momentum_schedule = CosineScheduler(**momentum) - teacher_temp_schedule = CosineScheduler(**teacher_temp) - last_layer_lr_schedule = CosineScheduler(**lr) - - last_layer_lr_schedule.schedule[ - : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH - ] = 0 # mimicking the original schedules - - logger.info("Schedulers ready.") - - return ( - lr_schedule, - wd_schedule, - momentum_schedule, - teacher_temp_schedule, - last_layer_lr_schedule, - ) - - -def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): - for param_group in optimizer.param_groups: - is_last_layer = param_group["is_last_layer"] - lr_multiplier = param_group["lr_multiplier"] - wd_multiplier = param_group["wd_multiplier"] - param_group["weight_decay"] = wd * wd_multiplier - param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier - - -def do_test(cfg, model, iteration): - new_state_dict = model.teacher.state_dict() - - if distributed.is_main_process(): - iterstring = str(iteration) - eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) - os.makedirs(eval_dir, exist_ok=True) - # save teacher checkpoint - teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") - torch.save({"teacher": new_state_dict}, teacher_ckp_path) - - -def do_train(cfg, model, resume=False): - model.train() - inputs_dtype = torch.half - fp16_scaler = model.fp16_scaler # for mixed precision training - - # setup optimizer - - optimizer = build_optimizer(cfg, model.get_params_groups()) - ( - lr_schedule, - wd_schedule, - momentum_schedule, - teacher_temp_schedule, - last_layer_lr_schedule, - ) = build_schedulers(cfg) - - # checkpointer - checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) - - start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 - - OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH - max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH - - periodic_checkpointer = PeriodicCheckpointer( - checkpointer, - period=3 * OFFICIAL_EPOCH_LENGTH, - max_iter=max_iter, - 
max_to_keep=3, - ) - - # setup data preprocessing - - img_size = cfg.crops.global_crops_size - patch_size = cfg.student.patch_size - n_tokens = (img_size // patch_size) ** 2 - mask_generator = MaskingGenerator( - input_size=(img_size // patch_size, img_size // patch_size), - max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, - ) - - if cfg.train.cell_augmentation: - data_transform = CellAugmentationDINO( - cfg.crops.global_crops_scale, - cfg.crops.local_crops_scale, - cfg.crops.local_crops_number, - global_crops_size=cfg.crops.global_crops_size, - local_crops_size=cfg.crops.local_crops_size, - ) - else: - data_transform = DataAugmentationDINO( - cfg.crops.global_crops_scale, - cfg.crops.local_crops_scale, - cfg.crops.local_crops_number, - global_crops_size=cfg.crops.global_crops_size, - local_crops_size=cfg.crops.local_crops_size, - ) - - collate_fn = partial( - collate_data_and_cast, - mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, - mask_probability=cfg.ibot.mask_sample_probability, - n_tokens=n_tokens, - mask_generator=mask_generator, - dtype=inputs_dtype, - ) - - # setup data loader - - dataset = make_dataset( - dataset_str=cfg.train.dataset_path, - transform=data_transform, - target_transform=lambda _: (), - ) - # sampler_type = SamplerType.INFINITE - sampler_type = SamplerType.SHARDED_INFINITE - data_loader = make_data_loader( - dataset=dataset, - batch_size=cfg.train.batch_size_per_gpu, - num_workers=cfg.train.num_workers, - shuffle=True, - seed=start_iter, # TODO: Fix this -- cfg.train.seed - sampler_type=sampler_type, - sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, - drop_last=True, - collate_fn=collate_fn, - ) - - # training loop - - iteration = start_iter - - logger.info("Starting training from iteration {}".format(start_iter)) - metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") - metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) - header = "Training" - - for data in metric_logger.log_every( - data_loader, - 10, - header, - max_iter, - start_iter, - ): - current_batch_size = data["collated_global_crops"].shape[0] / 2 - if iteration > max_iter: - return - - # apply schedules - - lr = lr_schedule[iteration] - wd = wd_schedule[iteration] - mom = momentum_schedule[iteration] - teacher_temp = teacher_temp_schedule[iteration] - last_layer_lr = last_layer_lr_schedule[iteration] - apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) - - # compute losses - - optimizer.zero_grad(set_to_none=True) - loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) - - # clip gradients - - if fp16_scaler is not None: - if cfg.optim.clip_grad: - fp16_scaler.unscale_(optimizer) - for v in model.student.values(): - v.clip_grad_norm_(cfg.optim.clip_grad) - fp16_scaler.step(optimizer) - fp16_scaler.update() - else: - if cfg.optim.clip_grad: - for v in model.student.values(): - v.clip_grad_norm_(cfg.optim.clip_grad) - optimizer.step() - - # perform teacher EMA update - - model.update_teacher(mom) - - # logging - - if distributed.get_global_size() > 1: - for v in loss_dict.values(): - torch.distributed.all_reduce(v) - loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} - - if math.isnan(sum(loss_dict_reduced.values())): - logger.info("NaN detected") - raise AssertionError - losses_reduced = sum(loss for loss in loss_dict_reduced.values()) - - metric_logger.update(lr=lr) - metric_logger.update(wd=wd) - metric_logger.update(mom=mom) - 
metric_logger.update(last_layer_lr=last_layer_lr) - metric_logger.update(current_batch_size=current_batch_size) - metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) - - # checkpointing and testing - - if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: - do_test(cfg, model, f"training_{iteration}") - torch.cuda.synchronize() - periodic_checkpointer.step(iteration) - - iteration = iteration + 1 - metric_logger.synchronize_between_processes() - return {k: meter.global_avg for k, meter in metric_logger.meters.items()} - - -def main(args): - cfg = setup(args) - - model = SSLMetaArch(cfg).to(torch.device("cuda")) - model.prepare_for_distributed_training() - - logger.info("Model:\n{}".format(model)) - if args.eval_only: - iteration = ( - FSDPCheckpointer(model, save_dir=cfg.train.output_dir) - .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) - .get("iteration", -1) - + 1 - ) - return do_test(cfg, model, f"manual_{iteration}") - - do_train(cfg, model, resume=not args.no_resume) - - -if __name__ == "__main__": - args = get_args_parser(add_help=True).parse_args() - main(args) diff --git a/dinov2/utils/__init__.py b/dinov2/utils/__init__.py deleted file mode 100644 index b88da6bf80be92af00b72dfdb0a806fa64a7a2d9..0000000000000000000000000000000000000000 --- a/dinov2/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. diff --git a/dinov2/utils/checkpoint.py b/dinov2/utils/checkpoint.py deleted file mode 100644 index 91fbcbaf051a1181a30e2e7258b2df1e66ccff7c..0000000000000000000000000000000000000000 --- a/dinov2/utils/checkpoint.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the CC-by-NC licence, -# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree. - -from typing import Any - -from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer -from torch import nn - -import dinov2.distributed as dist - - -class PeriodicCheckpointerWithCleanup(PeriodicCheckpointer): - @property - def does_write(self) -> bool: - """See https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L114""" - return self.checkpointer.save_dir and self.checkpointer.save_to_disk - - def save_best(self, **kwargs: Any) -> None: - """Same argument as `Checkpointer.save`, to save a model named like `model_best.pth`""" - self.checkpointer.save(f"{self.file_prefix}_best", **kwargs) - - def has_checkpoint(self) -> bool: - return self.checkpointer.has_checkpoint() - - def get_checkpoint_file(self) -> str: # returns "" if the file does not exist - return self.checkpointer.get_checkpoint_file() - - def load(self, path: str, checkpointables=None) -> dict[str, Any]: - return self.checkpointer.load(path=path, checkpointables=checkpointables) - - def step(self, iteration: int, **kwargs: Any) -> None: - if not self.does_write: # step also removes files, so should be deactivated when object does not write - return - super().step(iteration=iteration, **kwargs) - - -def resume_or_load(checkpointer: Checkpointer, path: str, *, resume: bool = True) -> dict[str, Any]: - """ - If `resume` is True, this method attempts to resume from the last - checkpoint, if exists. Otherwise, load checkpoint from the given path. 
diff --git a/dinov2/utils/cluster.py b/dinov2/utils/cluster.py
deleted file mode 100644
index 855a5268b226422b49ae6edfed80dca1ee7d9cb1..0000000000000000000000000000000000000000
--- a/dinov2/utils/cluster.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from enum import Enum
-import os
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-
-class ClusterType(Enum):
-    AWS = "aws"
-    FAIR = "fair"
-    RSC = "rsc"
-
-
-def _guess_cluster_type() -> ClusterType:
-    uname = os.uname()
-    if uname.sysname == "Linux":
-        if uname.release.endswith("-aws"):
-            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
-            return ClusterType.AWS
-        elif uname.nodename.startswith("rsc"):
-            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
-            return ClusterType.RSC
-
-    return ClusterType.FAIR
-
-
-def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
-    if cluster_type is None:
-        return _guess_cluster_type()
-
-    return cluster_type
-
-
-def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
-    cluster_type = get_cluster_type(cluster_type)
-    if cluster_type is None:
-        return None
-
-    CHECKPOINT_DIRNAMES = {
-        ClusterType.AWS: "checkpoints",
-        ClusterType.FAIR: "checkpoint",
-        ClusterType.RSC: "checkpoint/dino",
-    }
-    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
-
-
-def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
-    checkpoint_path = get_checkpoint_path(cluster_type)
-    if checkpoint_path is None:
-        return None
-
-    username = os.environ.get("USER")
-    assert username is not None
-    return checkpoint_path / username
-
-
-def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
-    cluster_type = get_cluster_type(cluster_type)
-    if cluster_type is None:
-        return None
-
-    SLURM_PARTITIONS = {
-        ClusterType.AWS: "learnaccel",
-        ClusterType.FAIR: "learnaccel",
-        ClusterType.RSC: "learn",
-    }
-    return SLURM_PARTITIONS[cluster_type]
-
-
-def get_slurm_executor_parameters(
-    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
-) -> Dict[str, Any]:
-    # create default parameters
-    params = {
-        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
-        "gpus_per_node": num_gpus_per_node,
-        "tasks_per_node": num_gpus_per_node,  # one task per GPU
-        "cpus_per_task": 10,
-        "nodes": nodes,
-        "slurm_partition": get_slurm_partition(cluster_type),
-    }
-    # apply cluster-specific adjustments
-    cluster_type = get_cluster_type(cluster_type)
-    if cluster_type == ClusterType.AWS:
-        params["cpus_per_task"] = 12
-        del params["mem_gb"]
-    elif cluster_type == ClusterType.RSC:
-        params["cpus_per_task"] = 12
-    # set additional parameters / apply overrides
-    params.update(kwargs)
-    return params
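A worked example of the defaults above, with hypothetical node counts:

params = get_slurm_executor_parameters(nodes=2, num_gpus_per_node=8)
# On a FAIR cluster this returns:
#   {"mem_gb": 0, "gpus_per_node": 8, "tasks_per_node": 8,
#    "cpus_per_task": 10, "nodes": 2, "slurm_partition": "learnaccel"}
# On AWS, cpus_per_task becomes 12 and the mem_gb key is deleted outright;
# on RSC, only cpus_per_task changes (to 12) and the partition is "learn".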
diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py
deleted file mode 100644
index c9de578787bbcb376f8bd5a782206d0eb7ec1f52..0000000000000000000000000000000000000000
--- a/dinov2/utils/config.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import math
-import logging
-import os
-
-from omegaconf import OmegaConf
-
-import dinov2.distributed as distributed
-from dinov2.logging import setup_logging
-from dinov2.utils import utils
-from dinov2.configs import dinov2_default_config
-
-
-logger = logging.getLogger("dinov2")
-
-
-def apply_scaling_rules_to_cfg(cfg):  # to fix
-    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
-        base_lr = cfg.optim.base_lr
-        cfg.optim.lr = base_lr
-        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
-        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
-    else:
-        raise NotImplementedError
-    return cfg
-
-
-def write_config(cfg, output_dir, name="config.yaml"):
-    logger.info(OmegaConf.to_yaml(cfg))
-    saved_cfg_path = os.path.join(output_dir, name)
-    with open(saved_cfg_path, "w") as f:
-        OmegaConf.save(config=cfg, f=f)
-    return saved_cfg_path
-
-
-def get_cfg_from_args(args):
-    args.output_dir = os.path.abspath(args.output_dir)
-    args.opts += [f"train.output_dir={args.output_dir}"]
-    default_cfg = OmegaConf.create(dinov2_default_config)
-    cfg = OmegaConf.load(args.config_file)
-    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
-    return cfg
-
-
-def default_setup(args):
-    distributed.enable(overwrite=True)
-    seed = getattr(args, "seed", 0)
-    rank = distributed.get_global_rank()
-
-    global logger
-    setup_logging(output=args.output_dir, level=logging.INFO)
-    logger = logging.getLogger("dinov2")
-
-    utils.fix_random_seeds(seed + rank)
-    logger.info("git:\n {}\n".format(utils.get_sha()))
-    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
-
-
-def setup(args):
-    """
-    Create configs and perform basic setups.
-    """
-    cfg = get_cfg_from_args(args)
-    os.makedirs(args.output_dir, exist_ok=True)
-    default_setup(args)
-    apply_scaling_rules_to_cfg(cfg)
-    write_config(cfg, args.output_dir)
-    return cfg
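The sqrt_wrt_1024 rule scales the learning rate with the square root of the global batch size, normalized to a reference batch of 1024. A worked example with hypothetical numbers:

# base_lr = 1.0e-3, 32 images per GPU on 8 GPUs -> global batch = 256
# lr = 1.0e-3 * sqrt(256 / 1024) = 1.0e-3 * 0.5 = 5.0e-4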
diff --git a/dinov2/utils/dtype.py b/dinov2/utils/dtype.py
deleted file mode 100644
index 80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8..0000000000000000000000000000000000000000
--- a/dinov2/utils/dtype.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-
-from typing import Dict, Union
-
-import numpy as np
-import torch
-
-
-TypeSpec = Union[str, np.dtype, torch.dtype]
-
-
-_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
-    np.dtype("bool"): torch.bool,
-    np.dtype("uint8"): torch.uint8,
-    np.dtype("int8"): torch.int8,
-    np.dtype("int16"): torch.int16,
-    np.dtype("int32"): torch.int32,
-    np.dtype("int64"): torch.int64,
-    np.dtype("float16"): torch.float16,
-    np.dtype("float32"): torch.float32,
-    np.dtype("float64"): torch.float64,
-    np.dtype("complex64"): torch.complex64,
-    np.dtype("complex128"): torch.complex128,
-}
-
-
-def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
-    if isinstance(dtype, torch.dtype):
-        return dtype
-    if isinstance(dtype, str):
-        dtype = np.dtype(dtype)
-    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
-    return _NUMPY_TO_TORCH_DTYPE[dtype]
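Usage of the conversion helper above, covering the three accepted input specs:

import numpy as np
import torch
from dinov2.utils.dtype import as_torch_dtype

as_torch_dtype("float16")          # str -> np.dtype -> torch.float16
as_torch_dtype(np.dtype("int32"))  # np.dtype -> torch.int32
as_torch_dtype(torch.bfloat16)     # torch.dtype passes through unchanged,
                                   # even without a numpy equivalent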
diff --git a/dinov2/utils/param_groups.py b/dinov2/utils/param_groups.py
deleted file mode 100644
index 9a5d2ff627cddadc222e5f836864ee39c865208f..0000000000000000000000000000000000000000
--- a/dinov2/utils/param_groups.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from collections import defaultdict
-import logging
-
-
-logger = logging.getLogger("dinov2")
-
-
-def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
-    """
-    Calculate the lr decay rate for different ViT blocks.
-    Args:
-        name (string): parameter name.
-        lr_decay_rate (float): base lr decay rate.
-        num_layers (int): number of ViT blocks.
-        force_is_backbone (bool): treat the parameter as a backbone parameter even without a "backbone" prefix.
-        chunked_blocks (bool): whether the blocks are chunked (FSDP block chunking).
-    Returns:
-        lr decay rate for the given parameter.
-    """
-    layer_id = num_layers + 1
-    if name.startswith("backbone") or force_is_backbone:
-        if (
-            ".pos_embed" in name
-            or ".patch_embed" in name
-            or ".mask_token" in name
-            or ".cls_token" in name
-            or ".register_tokens" in name
-        ):
-            layer_id = 0
-        elif force_is_backbone and (
-            "pos_embed" in name
-            or "patch_embed" in name
-            or "mask_token" in name
-            or "cls_token" in name
-            or "register_tokens" in name
-        ):
-            layer_id = 0
-        elif ".blocks." in name and ".residual." not in name:
-            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
-        elif chunked_blocks and "blocks." in name and "residual." not in name:
-            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
-        elif "blocks." in name and "residual." not in name:
-            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
-
-    return lr_decay_rate ** (num_layers + 1 - layer_id)
-
-
-def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
-    chunked_blocks = False
-    if hasattr(model, "n_blocks"):
-        logger.info("chunked fsdp")
-        n_blocks = model.n_blocks
-        chunked_blocks = model.chunked_blocks
-    elif hasattr(model, "blocks"):
-        logger.info("first code branch")
-        n_blocks = len(model.blocks)
-    elif hasattr(model, "backbone"):
-        logger.info("second code branch")
-        n_blocks = len(model.backbone.blocks)
-    else:
-        logger.info("else code branch")
-        n_blocks = 0
-    all_param_groups = []
-
-    for name, param in model.named_parameters():
-        name = name.replace("_fsdp_wrapped_module.", "")
-        if not param.requires_grad:
-            continue
-        decay_rate = get_vit_lr_decay_rate(
-            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
-        )
-        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
-
-        if "last_layer" in name:
-            d.update({"is_last_layer": True})
-
-        if name.endswith(".bias") or "norm" in name or "gamma" in name:
-            d.update({"wd_multiplier": 0.0})
-
-        if "patch_embed" in name:
-            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
-
-        all_param_groups.append(d)
-        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
-
-    return all_param_groups
-
-
-def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
-    fused_params_groups = defaultdict(lambda: {"params": []})
-    for d in all_params_groups:
-        identifier = ""
-        for k in keys:
-            identifier += k + str(d[k]) + "_"
-
-        for k in keys:
-            fused_params_groups[identifier][k] = d[k]
-        fused_params_groups[identifier]["params"].append(d["params"])
-
-    return fused_params_groups.values()
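To make the layer-wise decay concrete, a few hypothetical parameter names for a 24-block ViT-L with lr_decay_rate=0.9 (the decay is 0.9 ** (25 - layer_id)):

# "backbone.patch_embed.proj.weight"   -> layer_id 0  -> 0.9 ** 25  (strongest decay)
# "backbone.blocks.10.mlp.fc1.weight"  -> layer_id 11 -> 0.9 ** 14
# "dino_head.last_layer.weight"        -> no blocks index, so layer_id keeps its
#                                         default of num_layers + 1 = 25 -> 0.9 ** 0 = 1.0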
- """ - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - - -def get_sha(): - cwd = os.path.dirname(os.path.abspath(__file__)) - - def _run(command): - return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() - - sha = "N/A" - diff = "clean" - branch = "N/A" - try: - sha = _run(["git", "rev-parse", "HEAD"]) - subprocess.check_output(["git", "diff"], cwd=cwd) - diff = _run(["git", "diff-index", "HEAD"]) - diff = "has uncommitted changes" if diff else "clean" - branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - except Exception: - pass - message = f"sha: {sha}, status: {diff}, branch: {branch}" - return message - - -class CosineScheduler(object): - def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): - super().__init__() - self.final_value = final_value - self.total_iters = total_iters - - freeze_schedule = np.zeros((freeze_iters)) - - warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) - - iters = np.arange(total_iters - warmup_iters - freeze_iters) - schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) - self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) - - assert len(self.schedule) == self.total_iters - - def __getitem__(self, it): - if it >= self.total_iters: - return self.final_value - else: - return self.schedule[it] - - -def has_batchnorms(model): - bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) - for name, module in model.named_modules(): - if isinstance(module, bn_types): - return True - return False