Delete dinov2
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- dinov2/__init__.py +0 -6
- dinov2/configs/__init__.py +0 -22
- dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml +0 -35
- dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml +0 -14
- dinov2/configs/eval/vitb14_pretrain.yaml +0 -6
- dinov2/configs/eval/vitb14_reg4_pretrain.yaml +0 -9
- dinov2/configs/eval/vitg14_pretrain.yaml +0 -7
- dinov2/configs/eval/vitg14_reg4_pretrain.yaml +0 -10
- dinov2/configs/eval/vitl14_pretrain.yaml +0 -6
- dinov2/configs/eval/vitl14_reg4_pretrain.yaml +0 -9
- dinov2/configs/eval/vits14_pretrain.yaml +0 -6
- dinov2/configs/eval/vits14_reg4_pretrain.yaml +0 -9
- dinov2/configs/ssl_default_config.yaml +0 -123
- dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml +0 -31
- dinov2/configs/train/cell_dino/vitl16_hpafov.yaml +0 -32
- dinov2/configs/train/cell_dino/vitl16_hpaone.yaml +0 -30
- dinov2/configs/train/vitg14.yaml +0 -26
- dinov2/configs/train/vitl14.yaml +0 -26
- dinov2/configs/train/vitl16_short.yaml +0 -6
- dinov2/data/__init__.py +0 -12
- dinov2/data/accumulators.py +0 -133
- dinov2/data/adapters.py +0 -51
- dinov2/data/augmentations.py +0 -118
- dinov2/data/cell_dino/augmentations.py +0 -91
- dinov2/data/cell_dino/transforms.py +0 -169
- dinov2/data/collate.py +0 -49
- dinov2/data/datasets/__init__.py +0 -12
- dinov2/data/datasets/cell_dino/chammi_cp.py +0 -112
- dinov2/data/datasets/cell_dino/chammi_hpa.py +0 -111
- dinov2/data/datasets/cell_dino/chammi_wtc.py +0 -108
- dinov2/data/datasets/cell_dino/hpafov.py +0 -283
- dinov2/data/datasets/cell_dino/hpaone.py +0 -223
- dinov2/data/datasets/decoders.py +0 -94
- dinov2/data/datasets/extended.py +0 -44
- dinov2/data/datasets/image_net.py +0 -290
- dinov2/data/datasets/image_net_22k.py +0 -302
- dinov2/data/loaders.py +0 -232
- dinov2/data/masking.py +0 -86
- dinov2/data/samplers.py +0 -229
- dinov2/data/transforms.py +0 -91
- dinov2/distributed/__init__.py +0 -270
- dinov2/eval/__init__.py +0 -4
- dinov2/eval/cell_dino/knn.py +0 -479
- dinov2/eval/cell_dino/linear.py +0 -1048
- dinov2/eval/cell_dino/utils.py +0 -542
- dinov2/eval/depth/__init__.py +0 -4
- dinov2/eval/depth/models/__init__.py +0 -10
- dinov2/eval/depth/models/backbones/__init__.py +0 -6
- dinov2/eval/depth/models/backbones/vision_transformer.py +0 -16
- dinov2/eval/depth/models/builder.py +0 -49
dinov2/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
__version__ = "0.0.1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/__init__.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import pathlib
|
| 7 |
-
|
| 8 |
-
from omegaconf import OmegaConf
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def load_config(config_name: str):
|
| 12 |
-
config_filename = config_name + ".yaml"
|
| 13 |
-
return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
dinov2_default_config = load_config("ssl_default_config")
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def load_and_merge_config(config_name: str):
|
| 20 |
-
default_config = OmegaConf.create(dinov2_default_config)
|
| 21 |
-
loaded_config = load_config(config_name)
|
| 22 |
-
return OmegaConf.merge(default_config, loaded_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
train:
|
| 2 |
-
batch_size_per_gpu: 32
|
| 3 |
-
OFFICIAL_EPOCH_LENGTH: 450
|
| 4 |
-
cell_augmentation: true
|
| 5 |
-
channel_adaptive: true
|
| 6 |
-
student:
|
| 7 |
-
arch: vit_large
|
| 8 |
-
patch_size: 16
|
| 9 |
-
num_register_tokens: 0
|
| 10 |
-
interpolate_antialias: false
|
| 11 |
-
interpolate_offset: 0.1
|
| 12 |
-
drop_path_rate: 0.1
|
| 13 |
-
in_chans: 1
|
| 14 |
-
block_chunks: 4
|
| 15 |
-
channel_adaptive: true
|
| 16 |
-
teacher:
|
| 17 |
-
momentum_teacher: 0.996
|
| 18 |
-
warmup_teacher_temp_epochs: 20
|
| 19 |
-
in_chans: 1
|
| 20 |
-
channel_adaptive: true
|
| 21 |
-
crops:
|
| 22 |
-
global_crops_scale:
|
| 23 |
-
- 0.4
|
| 24 |
-
- 1.0
|
| 25 |
-
local_crops_number: 8
|
| 26 |
-
local_crops_scale:
|
| 27 |
-
- 0.005
|
| 28 |
-
- 0.4
|
| 29 |
-
global_crops_size: 224
|
| 30 |
-
local_crops_size: 96
|
| 31 |
-
optim:
|
| 32 |
-
weight_decay_end: 0.2
|
| 33 |
-
base_lr: 5.0e-4
|
| 34 |
-
warmup_epochs: 20
|
| 35 |
-
epochs: 400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_large
|
| 3 |
-
patch_size: 16
|
| 4 |
-
num_register_tokens: 0
|
| 5 |
-
interpolate_antialias: false
|
| 6 |
-
interpolate_offset: 0.1
|
| 7 |
-
drop_path_rate: 0.1
|
| 8 |
-
in_chans: 4
|
| 9 |
-
block_chunks: 4
|
| 10 |
-
teacher:
|
| 11 |
-
in_chans: 4
|
| 12 |
-
crops:
|
| 13 |
-
global_crops_size: 224
|
| 14 |
-
local_crops_size: 96
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitb14_pretrain.yaml
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_base
|
| 3 |
-
patch_size: 14
|
| 4 |
-
crops:
|
| 5 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 6 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitb14_reg4_pretrain.yaml
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_base
|
| 3 |
-
patch_size: 14
|
| 4 |
-
num_register_tokens: 4
|
| 5 |
-
interpolate_antialias: true
|
| 6 |
-
interpolate_offset: 0.0
|
| 7 |
-
crops:
|
| 8 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 9 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitg14_pretrain.yaml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_giant2
|
| 3 |
-
patch_size: 14
|
| 4 |
-
ffn_layer: swiglufused
|
| 5 |
-
crops:
|
| 6 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 7 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitg14_reg4_pretrain.yaml
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_giant2
|
| 3 |
-
patch_size: 14
|
| 4 |
-
ffn_layer: swiglufused
|
| 5 |
-
num_register_tokens: 4
|
| 6 |
-
interpolate_antialias: true
|
| 7 |
-
interpolate_offset: 0.0
|
| 8 |
-
crops:
|
| 9 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 10 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitl14_pretrain.yaml
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_large
|
| 3 |
-
patch_size: 14
|
| 4 |
-
crops:
|
| 5 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 6 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vitl14_reg4_pretrain.yaml
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_large
|
| 3 |
-
patch_size: 14
|
| 4 |
-
num_register_tokens: 4
|
| 5 |
-
interpolate_antialias: true
|
| 6 |
-
interpolate_offset: 0.0
|
| 7 |
-
crops:
|
| 8 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 9 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vits14_pretrain.yaml
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_small
|
| 3 |
-
patch_size: 14
|
| 4 |
-
crops:
|
| 5 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 6 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/eval/vits14_reg4_pretrain.yaml
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
student:
|
| 2 |
-
arch: vit_small
|
| 3 |
-
patch_size: 14
|
| 4 |
-
num_register_tokens: 4
|
| 5 |
-
interpolate_antialias: true
|
| 6 |
-
interpolate_offset: 0.0
|
| 7 |
-
crops:
|
| 8 |
-
global_crops_size: 518 # this is to set up the position embeddings properly
|
| 9 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/ssl_default_config.yaml
DELETED
|
@@ -1,123 +0,0 @@
|
|
| 1 |
-
MODEL:
|
| 2 |
-
WEIGHTS: ''
|
| 3 |
-
compute_precision:
|
| 4 |
-
grad_scaler: true
|
| 5 |
-
teacher:
|
| 6 |
-
backbone:
|
| 7 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 8 |
-
mixed_precision:
|
| 9 |
-
param_dtype: fp16
|
| 10 |
-
reduce_dtype: fp16
|
| 11 |
-
buffer_dtype: fp32
|
| 12 |
-
dino_head:
|
| 13 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 14 |
-
mixed_precision:
|
| 15 |
-
param_dtype: fp16
|
| 16 |
-
reduce_dtype: fp16
|
| 17 |
-
buffer_dtype: fp32
|
| 18 |
-
ibot_head:
|
| 19 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 20 |
-
mixed_precision:
|
| 21 |
-
param_dtype: fp16
|
| 22 |
-
reduce_dtype: fp16
|
| 23 |
-
buffer_dtype: fp32
|
| 24 |
-
student:
|
| 25 |
-
backbone:
|
| 26 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 27 |
-
mixed_precision:
|
| 28 |
-
param_dtype: fp16
|
| 29 |
-
reduce_dtype: fp16
|
| 30 |
-
buffer_dtype: fp32
|
| 31 |
-
dino_head:
|
| 32 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 33 |
-
mixed_precision:
|
| 34 |
-
param_dtype: fp16
|
| 35 |
-
reduce_dtype: fp32
|
| 36 |
-
buffer_dtype: fp32
|
| 37 |
-
ibot_head:
|
| 38 |
-
sharding_strategy: SHARD_GRAD_OP
|
| 39 |
-
mixed_precision:
|
| 40 |
-
param_dtype: fp16
|
| 41 |
-
reduce_dtype: fp32
|
| 42 |
-
buffer_dtype: fp32
|
| 43 |
-
dino:
|
| 44 |
-
loss_weight: 1.0
|
| 45 |
-
head_n_prototypes: 65536
|
| 46 |
-
head_bottleneck_dim: 256
|
| 47 |
-
head_nlayers: 3
|
| 48 |
-
head_hidden_dim: 2048
|
| 49 |
-
koleo_loss_weight: 0.1
|
| 50 |
-
ibot:
|
| 51 |
-
loss_weight: 1.0
|
| 52 |
-
mask_sample_probability: 0.5
|
| 53 |
-
mask_ratio_min_max:
|
| 54 |
-
- 0.1
|
| 55 |
-
- 0.5
|
| 56 |
-
separate_head: false
|
| 57 |
-
head_n_prototypes: 65536
|
| 58 |
-
head_bottleneck_dim: 256
|
| 59 |
-
head_nlayers: 3
|
| 60 |
-
head_hidden_dim: 2048
|
| 61 |
-
train:
|
| 62 |
-
batch_size_per_gpu: 64
|
| 63 |
-
dataset_path: ImageNet:split=TRAIN
|
| 64 |
-
output_dir: .
|
| 65 |
-
saveckp_freq: 20
|
| 66 |
-
seed: 0
|
| 67 |
-
num_workers: 10
|
| 68 |
-
OFFICIAL_EPOCH_LENGTH: 1250
|
| 69 |
-
cache_dataset: true
|
| 70 |
-
centering: "centering" # or "sinkhorn_knopp"
|
| 71 |
-
cell_augmentation: false
|
| 72 |
-
student:
|
| 73 |
-
arch: vit_large
|
| 74 |
-
patch_size: 16
|
| 75 |
-
drop_path_rate: 0.3
|
| 76 |
-
layerscale: 1.0e-05
|
| 77 |
-
drop_path_uniform: true
|
| 78 |
-
pretrained_weights: ''
|
| 79 |
-
ffn_layer: "mlp"
|
| 80 |
-
block_chunks: 0
|
| 81 |
-
qkv_bias: true
|
| 82 |
-
proj_bias: true
|
| 83 |
-
ffn_bias: true
|
| 84 |
-
num_register_tokens: 0
|
| 85 |
-
interpolate_antialias: false
|
| 86 |
-
interpolate_offset: 0.1
|
| 87 |
-
in_chans: 3
|
| 88 |
-
channel_adaptive: false
|
| 89 |
-
teacher:
|
| 90 |
-
momentum_teacher: 0.992
|
| 91 |
-
final_momentum_teacher: 1
|
| 92 |
-
warmup_teacher_temp: 0.04
|
| 93 |
-
teacher_temp: 0.07
|
| 94 |
-
warmup_teacher_temp_epochs: 30
|
| 95 |
-
in_chans: 3
|
| 96 |
-
channel_adaptive: false
|
| 97 |
-
optim:
|
| 98 |
-
epochs: 100
|
| 99 |
-
weight_decay: 0.04
|
| 100 |
-
weight_decay_end: 0.4
|
| 101 |
-
base_lr: 0.004 # learning rate for a batch size of 1024
|
| 102 |
-
lr: 0. # will be set after applying scaling rule
|
| 103 |
-
warmup_epochs: 10
|
| 104 |
-
min_lr: 1.0e-06
|
| 105 |
-
clip_grad: 3.0
|
| 106 |
-
freeze_last_layer_epochs: 1
|
| 107 |
-
scaling_rule: sqrt_wrt_1024
|
| 108 |
-
patch_embed_lr_mult: 0.2
|
| 109 |
-
layerwise_decay: 0.9
|
| 110 |
-
adamw_beta1: 0.9
|
| 111 |
-
adamw_beta2: 0.999
|
| 112 |
-
crops:
|
| 113 |
-
global_crops_scale:
|
| 114 |
-
- 0.32
|
| 115 |
-
- 1.0
|
| 116 |
-
local_crops_number: 8
|
| 117 |
-
local_crops_scale:
|
| 118 |
-
- 0.05
|
| 119 |
-
- 0.32
|
| 120 |
-
global_crops_size: 224
|
| 121 |
-
local_crops_size: 96
|
| 122 |
-
evaluation:
|
| 123 |
-
eval_period_iterations: 12500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml
DELETED
|
@@ -1,31 +0,0 @@
|
|
| 1 |
-
train:
|
| 2 |
-
batch_size_per_gpu: 16
|
| 3 |
-
OFFICIAL_EPOCH_LENGTH: 450
|
| 4 |
-
cell_augmentation: true
|
| 5 |
-
channel_adaptive: true
|
| 6 |
-
student:
|
| 7 |
-
arch: vit_large
|
| 8 |
-
patch_size: 16
|
| 9 |
-
in_chans: 1
|
| 10 |
-
drop_path_rate: 0.1
|
| 11 |
-
block_chunks: 4
|
| 12 |
-
teacher:
|
| 13 |
-
momentum_teacher: 0.996
|
| 14 |
-
warmup_teacher_temp_epochs: 20
|
| 15 |
-
in_chans: 1
|
| 16 |
-
crops:
|
| 17 |
-
global_crops_scale:
|
| 18 |
-
- 0.4
|
| 19 |
-
- 1.0
|
| 20 |
-
local_crops_number: 8
|
| 21 |
-
local_crops_scale:
|
| 22 |
-
- 0.005
|
| 23 |
-
- 0.4
|
| 24 |
-
global_crops_size: 224
|
| 25 |
-
local_crops_size: 96
|
| 26 |
-
optim:
|
| 27 |
-
weight_decay_end: 0.2
|
| 28 |
-
base_lr: 5.0e-4
|
| 29 |
-
warmup_epochs: 20
|
| 30 |
-
epochs: 400
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/cell_dino/vitl16_hpafov.yaml
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
train:
|
| 2 |
-
batch_size_per_gpu: 16
|
| 3 |
-
OFFICIAL_EPOCH_LENGTH: 450
|
| 4 |
-
cell_augmentation: true
|
| 5 |
-
student:
|
| 6 |
-
arch: vit_large
|
| 7 |
-
patch_size: 16
|
| 8 |
-
in_chans: 4
|
| 9 |
-
drop_path_rate: 0.1
|
| 10 |
-
block_chunks: 4
|
| 11 |
-
teacher:
|
| 12 |
-
momentum_teacher: 0.996
|
| 13 |
-
warmup_teacher_temp_epochs: 20
|
| 14 |
-
in_chans: 4
|
| 15 |
-
optim:
|
| 16 |
-
weight_decay_end: 0.2
|
| 17 |
-
base_lr: 5.0e-4
|
| 18 |
-
warmup_epochs: 20
|
| 19 |
-
crops:
|
| 20 |
-
global_crops_scale:
|
| 21 |
-
- 0.4
|
| 22 |
-
- 1.0
|
| 23 |
-
local_crops_number: 8
|
| 24 |
-
local_crops_scale:
|
| 25 |
-
- 0.005
|
| 26 |
-
- 0.4
|
| 27 |
-
global_crops_size: 224
|
| 28 |
-
local_crops_size: 96
|
| 29 |
-
evaluation:
|
| 30 |
-
eval_period_iterations: 9000
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/cell_dino/vitl16_hpaone.yaml
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
train:
|
| 2 |
-
batch_size_per_gpu: 16
|
| 3 |
-
OFFICIAL_EPOCH_LENGTH: 1756
|
| 4 |
-
cell_augmentation: true
|
| 5 |
-
student:
|
| 6 |
-
arch: vit_large
|
| 7 |
-
patch_size: 16
|
| 8 |
-
in_chans: 4
|
| 9 |
-
drop_path_rate: 0.1
|
| 10 |
-
block_chunks: 4
|
| 11 |
-
teacher:
|
| 12 |
-
momentum_teacher: 0.996
|
| 13 |
-
warmup_teacher_temp_epochs: 20
|
| 14 |
-
in_chans: 4
|
| 15 |
-
optim:
|
| 16 |
-
weight_decay_end: 0.2
|
| 17 |
-
base_lr: 5.0e-4
|
| 18 |
-
warmup_epochs: 20
|
| 19 |
-
crops:
|
| 20 |
-
global_crops_scale:
|
| 21 |
-
- 0.4
|
| 22 |
-
- 1.0
|
| 23 |
-
local_crops_number: 8
|
| 24 |
-
local_crops_scale:
|
| 25 |
-
- 0.005
|
| 26 |
-
- 0.4
|
| 27 |
-
global_crops_size: 224
|
| 28 |
-
local_crops_size: 96
|
| 29 |
-
evaluation:
|
| 30 |
-
eval_period_iterations: 9000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/vitg14.yaml
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
dino:
|
| 2 |
-
head_n_prototypes: 131072
|
| 3 |
-
head_bottleneck_dim: 384
|
| 4 |
-
ibot:
|
| 5 |
-
separate_head: true
|
| 6 |
-
head_n_prototypes: 131072
|
| 7 |
-
train:
|
| 8 |
-
batch_size_per_gpu: 12
|
| 9 |
-
dataset_path: ImageNet22k
|
| 10 |
-
centering: sinkhorn_knopp
|
| 11 |
-
student:
|
| 12 |
-
arch: vit_giant2
|
| 13 |
-
patch_size: 14
|
| 14 |
-
drop_path_rate: 0.4
|
| 15 |
-
ffn_layer: swiglufused
|
| 16 |
-
block_chunks: 4
|
| 17 |
-
teacher:
|
| 18 |
-
momentum_teacher: 0.994
|
| 19 |
-
optim:
|
| 20 |
-
epochs: 500
|
| 21 |
-
weight_decay_end: 0.2
|
| 22 |
-
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
| 23 |
-
warmup_epochs: 80
|
| 24 |
-
layerwise_decay: 1.0
|
| 25 |
-
crops:
|
| 26 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/vitl14.yaml
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
dino:
|
| 2 |
-
head_n_prototypes: 131072
|
| 3 |
-
head_bottleneck_dim: 384
|
| 4 |
-
ibot:
|
| 5 |
-
separate_head: true
|
| 6 |
-
head_n_prototypes: 131072
|
| 7 |
-
train:
|
| 8 |
-
batch_size_per_gpu: 32
|
| 9 |
-
dataset_path: ImageNet22k
|
| 10 |
-
centering: sinkhorn_knopp
|
| 11 |
-
student:
|
| 12 |
-
arch: vit_large
|
| 13 |
-
patch_size: 14
|
| 14 |
-
drop_path_rate: 0.4
|
| 15 |
-
ffn_layer: swiglufused
|
| 16 |
-
block_chunks: 4
|
| 17 |
-
teacher:
|
| 18 |
-
momentum_teacher: 0.994
|
| 19 |
-
optim:
|
| 20 |
-
epochs: 500
|
| 21 |
-
weight_decay_end: 0.2
|
| 22 |
-
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
| 23 |
-
warmup_epochs: 80
|
| 24 |
-
layerwise_decay: 1.0
|
| 25 |
-
crops:
|
| 26 |
-
local_crops_size: 98
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/configs/train/vitl16_short.yaml
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
# this corresponds to the default config
|
| 2 |
-
train:
|
| 3 |
-
dataset_path: ImageNet:split=TRAIN
|
| 4 |
-
batch_size_per_gpu: 64
|
| 5 |
-
student:
|
| 6 |
-
block_chunks: 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/__init__.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from .adapters import DatasetWithEnumeratedTargets
|
| 7 |
-
from .loaders import make_data_loader, make_dataset, SamplerType
|
| 8 |
-
from .collate import collate_data_and_cast
|
| 9 |
-
from .masking import MaskingGenerator
|
| 10 |
-
from .augmentations import DataAugmentationDINO
|
| 11 |
-
from .cell_dino.augmentations import CellAugmentationDINO
|
| 12 |
-
from .accumulators import NoOpAccumulator, ResultsAccumulator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/accumulators.py
DELETED
|
@@ -1,133 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from collections import defaultdict
|
| 7 |
-
from typing import Dict, List, Optional, Any
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch import Tensor
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
|
| 13 |
-
import torch.distributed as dist
|
| 14 |
-
from dinov2.distributed import get_global_size
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
|
| 18 |
-
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
|
| 19 |
-
dist.all_gather(gathered_result, result, group)
|
| 20 |
-
return gathered_result
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
|
| 24 |
-
"""
|
| 25 |
-
Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py
|
| 26 |
-
Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
|
| 27 |
-
|
| 28 |
-
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
|
| 29 |
-
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
result: the value to sync
|
| 33 |
-
group: the process group to gather results from. Defaults to all processes (world)
|
| 34 |
-
|
| 35 |
-
Return:
|
| 36 |
-
list with size equal to the process group where element i corresponds to result tensor from process i
|
| 37 |
-
"""
|
| 38 |
-
# convert tensors to contiguous format
|
| 39 |
-
result = result.contiguous()
|
| 40 |
-
|
| 41 |
-
world_size = get_global_size()
|
| 42 |
-
dist.barrier(group=group)
|
| 43 |
-
|
| 44 |
-
# if the tensor is scalar, things are easy
|
| 45 |
-
if result.ndim == 0:
|
| 46 |
-
return _simple_gather_all_tensors(result, group, world_size)
|
| 47 |
-
|
| 48 |
-
# 1. Gather sizes of all tensors
|
| 49 |
-
local_size = torch.tensor(result.shape, device=result.device)
|
| 50 |
-
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
|
| 51 |
-
dist.all_gather(local_sizes, local_size, group=group)
|
| 52 |
-
max_size = torch.stack(local_sizes).max(dim=0).values
|
| 53 |
-
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
|
| 54 |
-
|
| 55 |
-
# 2. If shapes are all the same, then do a simple gather:
|
| 56 |
-
if all_sizes_equal:
|
| 57 |
-
return _simple_gather_all_tensors(result, group, world_size)
|
| 58 |
-
|
| 59 |
-
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
|
| 60 |
-
pad_dims = []
|
| 61 |
-
pad_by = (max_size - local_size).detach().cpu()
|
| 62 |
-
for val in reversed(pad_by):
|
| 63 |
-
pad_dims.append(0)
|
| 64 |
-
pad_dims.append(val.item())
|
| 65 |
-
result_padded = F.pad(result, pad_dims)
|
| 66 |
-
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
|
| 67 |
-
dist.all_gather(gathered_result, result_padded, group)
|
| 68 |
-
for idx, item_size in enumerate(local_sizes):
|
| 69 |
-
slice_param = [slice(dim_size) for dim_size in item_size]
|
| 70 |
-
gathered_result[idx] = gathered_result[idx][slice_param]
|
| 71 |
-
return gathered_result
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor:
|
| 75 |
-
local_cat = torch.cat(tensor_list)
|
| 76 |
-
return torch.cat(gather_all_tensors(local_cat))
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
class Accumulator:
|
| 80 |
-
def __init__(self) -> None:
|
| 81 |
-
pass
|
| 82 |
-
|
| 83 |
-
def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
|
| 84 |
-
raise NotImplementedError
|
| 85 |
-
|
| 86 |
-
def accumulate(self) -> Optional[Dict[str, Tensor]]:
|
| 87 |
-
raise NotImplementedError
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class NoOpAccumulator(Accumulator):
|
| 91 |
-
def __init__(self) -> None:
|
| 92 |
-
pass
|
| 93 |
-
|
| 94 |
-
def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
|
| 95 |
-
pass
|
| 96 |
-
|
| 97 |
-
def accumulate(self) -> None:
|
| 98 |
-
return None
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class ResultsAccumulator(Accumulator):
|
| 102 |
-
"""
|
| 103 |
-
Accumulate predictions and targets across processes
|
| 104 |
-
"""
|
| 105 |
-
|
| 106 |
-
def __init__(self) -> None:
|
| 107 |
-
self._local_values: Dict[str, List[Tensor]] = defaultdict(list)
|
| 108 |
-
self._gathered_values: Dict[str, Tensor] = {}
|
| 109 |
-
self._gathered = False
|
| 110 |
-
|
| 111 |
-
def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
|
| 112 |
-
assert len(preds) == len(target) == len(index)
|
| 113 |
-
assert not self._gathered, "Tensors have already been gathered in this helper"
|
| 114 |
-
self._local_values["preds"].append(preds)
|
| 115 |
-
self._local_values["target"].append(target)
|
| 116 |
-
self._local_values["index"].append(index)
|
| 117 |
-
self._gathered = False
|
| 118 |
-
|
| 119 |
-
def _gather_tensors(self):
|
| 120 |
-
for k, tensor_list in self._local_values.items():
|
| 121 |
-
self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list)
|
| 122 |
-
self._gathered = True
|
| 123 |
-
|
| 124 |
-
def accumulate(self) -> Dict[str, Tensor]:
|
| 125 |
-
if not self._gathered:
|
| 126 |
-
self._gather_tensors()
|
| 127 |
-
preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]]
|
| 128 |
-
assert len(preds) == len(target) == len(index) and index.min() == 0
|
| 129 |
-
preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device)
|
| 130 |
-
preds_ordered[index] = preds
|
| 131 |
-
target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device)
|
| 132 |
-
target_ordered[index] = target
|
| 133 |
-
return {"preds": preds_ordered, "target": target_ordered}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/adapters.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from typing import Any, Tuple, Optional
|
| 7 |
-
|
| 8 |
-
from torch.utils.data import Dataset
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class DatasetWithEnumeratedTargets(Dataset):
|
| 12 |
-
"""
|
| 13 |
-
If pad_dataset is set, pads based on torch's DistributedSampler implementation, which
|
| 14 |
-
with drop_last=False pads the last batch to be a multiple of the world size.
|
| 15 |
-
https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
|
| 19 |
-
self._dataset = dataset
|
| 20 |
-
self._size = len(self._dataset)
|
| 21 |
-
self._padded_size = self._size
|
| 22 |
-
self._pad_dataset = pad_dataset
|
| 23 |
-
if self._pad_dataset:
|
| 24 |
-
assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
|
| 25 |
-
self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)
|
| 26 |
-
|
| 27 |
-
def get_image_relpath(self, index: int) -> str:
|
| 28 |
-
assert self._pad_dataset or index < self._size
|
| 29 |
-
return self._dataset.get_image_relpath(index % self._size)
|
| 30 |
-
|
| 31 |
-
def get_image_data(self, index: int) -> bytes:
|
| 32 |
-
assert self._pad_dataset or index < self._size
|
| 33 |
-
return self._dataset.get_image_data(index % self._size)
|
| 34 |
-
|
| 35 |
-
def get_target(self, index: int) -> Tuple[Any, int]:
|
| 36 |
-
target = self._dataset.get_target(index % self._size)
|
| 37 |
-
if index >= self._size:
|
| 38 |
-
assert self._pad_dataset
|
| 39 |
-
return (-1, target)
|
| 40 |
-
return (index, target)
|
| 41 |
-
|
| 42 |
-
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
|
| 43 |
-
image, target = self._dataset[index % self._size]
|
| 44 |
-
if index >= self._size:
|
| 45 |
-
assert self._pad_dataset
|
| 46 |
-
return image, (-1, target)
|
| 47 |
-
target = index if target is None else target
|
| 48 |
-
return image, (index, target)
|
| 49 |
-
|
| 50 |
-
def __len__(self) -> int:
|
| 51 |
-
return self._padded_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/augmentations.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import logging
|
| 7 |
-
|
| 8 |
-
from torchvision import transforms
|
| 9 |
-
|
| 10 |
-
from .transforms import (
|
| 11 |
-
GaussianBlur,
|
| 12 |
-
make_normalize_transform,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger("dinov2")
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class DataAugmentationDINO(object):
|
| 20 |
-
def __init__(
|
| 21 |
-
self,
|
| 22 |
-
global_crops_scale,
|
| 23 |
-
local_crops_scale,
|
| 24 |
-
local_crops_number,
|
| 25 |
-
global_crops_size=224,
|
| 26 |
-
local_crops_size=96,
|
| 27 |
-
):
|
| 28 |
-
self.global_crops_scale = global_crops_scale
|
| 29 |
-
self.local_crops_scale = local_crops_scale
|
| 30 |
-
self.local_crops_number = local_crops_number
|
| 31 |
-
self.global_crops_size = global_crops_size
|
| 32 |
-
self.local_crops_size = local_crops_size
|
| 33 |
-
|
| 34 |
-
logger.info("###################################")
|
| 35 |
-
logger.info("Using data augmentation parameters:")
|
| 36 |
-
logger.info(f"global_crops_scale: {global_crops_scale}")
|
| 37 |
-
logger.info(f"local_crops_scale: {local_crops_scale}")
|
| 38 |
-
logger.info(f"local_crops_number: {local_crops_number}")
|
| 39 |
-
logger.info(f"global_crops_size: {global_crops_size}")
|
| 40 |
-
logger.info(f"local_crops_size: {local_crops_size}")
|
| 41 |
-
logger.info("###################################")
|
| 42 |
-
|
| 43 |
-
# random resized crop and flip
|
| 44 |
-
self.geometric_augmentation_global = transforms.Compose(
|
| 45 |
-
[
|
| 46 |
-
transforms.RandomResizedCrop(
|
| 47 |
-
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
| 48 |
-
),
|
| 49 |
-
transforms.RandomHorizontalFlip(p=0.5),
|
| 50 |
-
]
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
self.geometric_augmentation_local = transforms.Compose(
|
| 54 |
-
[
|
| 55 |
-
transforms.RandomResizedCrop(
|
| 56 |
-
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
| 57 |
-
),
|
| 58 |
-
transforms.RandomHorizontalFlip(p=0.5),
|
| 59 |
-
]
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
# color distorsions / blurring
|
| 63 |
-
color_jittering = transforms.Compose(
|
| 64 |
-
[
|
| 65 |
-
transforms.RandomApply(
|
| 66 |
-
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
|
| 67 |
-
p=0.8,
|
| 68 |
-
),
|
| 69 |
-
transforms.RandomGrayscale(p=0.2),
|
| 70 |
-
]
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
global_transfo1_extra = GaussianBlur(p=1.0)
|
| 74 |
-
|
| 75 |
-
global_transfo2_extra = transforms.Compose(
|
| 76 |
-
[
|
| 77 |
-
GaussianBlur(p=0.1),
|
| 78 |
-
transforms.RandomSolarize(threshold=128, p=0.2),
|
| 79 |
-
]
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
local_transfo_extra = GaussianBlur(p=0.5)
|
| 83 |
-
|
| 84 |
-
# normalization
|
| 85 |
-
self.normalize = transforms.Compose(
|
| 86 |
-
[
|
| 87 |
-
transforms.ToTensor(),
|
| 88 |
-
make_normalize_transform(),
|
| 89 |
-
]
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
|
| 93 |
-
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
|
| 94 |
-
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
|
| 95 |
-
|
| 96 |
-
def __call__(self, image):
|
| 97 |
-
output = {}
|
| 98 |
-
|
| 99 |
-
# global crops:
|
| 100 |
-
im1_base = self.geometric_augmentation_global(image)
|
| 101 |
-
global_crop_1 = self.global_transfo1(im1_base)
|
| 102 |
-
|
| 103 |
-
im2_base = self.geometric_augmentation_global(image)
|
| 104 |
-
global_crop_2 = self.global_transfo2(im2_base)
|
| 105 |
-
|
| 106 |
-
output["global_crops"] = [global_crop_1, global_crop_2]
|
| 107 |
-
|
| 108 |
-
# global crops for teacher:
|
| 109 |
-
output["global_crops_teacher"] = [global_crop_1, global_crop_2]
|
| 110 |
-
|
| 111 |
-
# local crops:
|
| 112 |
-
local_crops = [
|
| 113 |
-
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
|
| 114 |
-
]
|
| 115 |
-
output["local_crops"] = local_crops
|
| 116 |
-
output["offsets"] = ()
|
| 117 |
-
|
| 118 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/cell_dino/augmentations.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import logging
|
| 7 |
-
import torchvision
|
| 8 |
-
from torchvision import transforms
|
| 9 |
-
|
| 10 |
-
from .transforms import (
|
| 11 |
-
RandomContrastProteinChannel,
|
| 12 |
-
RandomRemoveChannelExceptProtein,
|
| 13 |
-
RandomBrightness,
|
| 14 |
-
RandomContrast,
|
| 15 |
-
Div255,
|
| 16 |
-
SelfNormalizeNoDiv,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
logger = logging.getLogger("dinov2")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class CellAugmentationDINO(object):
|
| 23 |
-
def __init__(
|
| 24 |
-
self,
|
| 25 |
-
global_crops_scale,
|
| 26 |
-
local_crops_scale,
|
| 27 |
-
local_crops_number,
|
| 28 |
-
global_crops_size=224,
|
| 29 |
-
local_crops_size=96,
|
| 30 |
-
):
|
| 31 |
-
self.global_crops_scale = global_crops_scale
|
| 32 |
-
self.local_crops_scale = local_crops_scale
|
| 33 |
-
self.local_crops_number = local_crops_number
|
| 34 |
-
self.global_crops_size = global_crops_size
|
| 35 |
-
self.local_crops_size = local_crops_size
|
| 36 |
-
|
| 37 |
-
logger.info("###################################")
|
| 38 |
-
logger.info("Using data augmentation parameters:")
|
| 39 |
-
logger.info(f"global_crops_scale: {global_crops_scale}")
|
| 40 |
-
logger.info(f"local_crops_scale: {local_crops_scale}")
|
| 41 |
-
logger.info(f"local_crops_number: {local_crops_number}")
|
| 42 |
-
logger.info(f"global_crops_size: {global_crops_size}")
|
| 43 |
-
logger.info(f"local_crops_size: {local_crops_size}")
|
| 44 |
-
logger.info("###################################")
|
| 45 |
-
|
| 46 |
-
additional_transforms_list = [
|
| 47 |
-
torchvision.transforms.RandomHorizontalFlip(),
|
| 48 |
-
torchvision.transforms.RandomVerticalFlip(),
|
| 49 |
-
RandomBrightness(),
|
| 50 |
-
RandomContrast(),
|
| 51 |
-
SelfNormalizeNoDiv(),
|
| 52 |
-
]
|
| 53 |
-
|
| 54 |
-
first_transforms_list = [
|
| 55 |
-
Div255(),
|
| 56 |
-
RandomRemoveChannelExceptProtein(),
|
| 57 |
-
RandomContrastProteinChannel(),
|
| 58 |
-
]
|
| 59 |
-
|
| 60 |
-
global_transforms_list = first_transforms_list.copy()
|
| 61 |
-
global_transforms_list.append(
|
| 62 |
-
torchvision.transforms.RandomResizedCrop(size=global_crops_size, scale=global_crops_scale)
|
| 63 |
-
)
|
| 64 |
-
global_transforms_list = global_transforms_list + additional_transforms_list
|
| 65 |
-
|
| 66 |
-
local_transforms_list = first_transforms_list
|
| 67 |
-
local_transforms_list.append(
|
| 68 |
-
torchvision.transforms.RandomResizedCrop(size=local_crops_size, scale=local_crops_scale)
|
| 69 |
-
)
|
| 70 |
-
local_transforms_list = local_transforms_list + additional_transforms_list
|
| 71 |
-
|
| 72 |
-
self.global_transform = transforms.Compose(global_transforms_list)
|
| 73 |
-
self.local_transform = transforms.Compose(local_transforms_list)
|
| 74 |
-
|
| 75 |
-
def __call__(self, image):
|
| 76 |
-
output = {}
|
| 77 |
-
|
| 78 |
-
global_crop1 = self.global_transform(image)
|
| 79 |
-
global_crop2 = self.global_transform(image)
|
| 80 |
-
|
| 81 |
-
output["global_crops"] = [global_crop1, global_crop2]
|
| 82 |
-
|
| 83 |
-
local_crops = []
|
| 84 |
-
for _ in range(self.local_crops_number):
|
| 85 |
-
local_crops.append(self.local_transform(image))
|
| 86 |
-
|
| 87 |
-
output["local_crops"] = local_crops
|
| 88 |
-
output["global_crops_teacher"] = [global_crop1, global_crop2]
|
| 89 |
-
output["offsets"] = ()
|
| 90 |
-
|
| 91 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/cell_dino/transforms.py
DELETED
|
@@ -1,169 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from torchvision import transforms
|
| 8 |
-
import numpy as np
|
| 9 |
-
from enum import Enum
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class NormalizationType(Enum):
|
| 13 |
-
SELF_NORM_AUG_DECODER = "self_norm_aug_decoder"
|
| 14 |
-
SELF_NORM_CENTER_CROP = "self_norm_center_crop"
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class Div255(torch.nn.Module):
|
| 18 |
-
def forward(self, x):
|
| 19 |
-
x = x / 255
|
| 20 |
-
return x
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class SelfNormalizeNoDiv(torch.nn.Module):
|
| 24 |
-
def forward(self, x):
|
| 25 |
-
m = x.mean((-2, -1), keepdim=True)
|
| 26 |
-
s = x.std((-2, -1), unbiased=False, keepdim=True)
|
| 27 |
-
x -= m
|
| 28 |
-
x /= s + 1e-7
|
| 29 |
-
return x
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class SelfNormalize(torch.nn.Module):
|
| 33 |
-
def forward(self, x):
|
| 34 |
-
x = x / 255
|
| 35 |
-
m = x.mean((-2, -1), keepdim=True)
|
| 36 |
-
s = x.std((-2, -1), unbiased=False, keepdim=True)
|
| 37 |
-
x -= m
|
| 38 |
-
x /= s + 1e-7
|
| 39 |
-
return x
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class RandomContrastProteinChannel(torch.nn.Module):
|
| 43 |
-
"""
|
| 44 |
-
Random constrast rescaling of the protein channel only.
|
| 45 |
-
RescaleProtein function in Dino4cell codebase.
|
| 46 |
-
"""
|
| 47 |
-
|
| 48 |
-
def __init__(self, p=0.2):
|
| 49 |
-
super().__init__()
|
| 50 |
-
self.p = p
|
| 51 |
-
|
| 52 |
-
def forward(self, img):
|
| 53 |
-
if img.max() == 0:
|
| 54 |
-
return img
|
| 55 |
-
if len(img) == 1:
|
| 56 |
-
return img
|
| 57 |
-
if np.random.rand() <= self.p:
|
| 58 |
-
random_factor = (np.random.rand() * 2) / img.max() # scaling
|
| 59 |
-
img[1] = img[1] * random_factor
|
| 60 |
-
return img
|
| 61 |
-
else:
|
| 62 |
-
return img
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class RandomRemoveChannelExceptProtein(torch.nn.Module):
|
| 66 |
-
"""
|
| 67 |
-
dropping a channel at random except the channel 1, corresponding to proteins in HPA datasets.
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
def __init__(self, p=0.2):
|
| 71 |
-
super().__init__()
|
| 72 |
-
self.p = p
|
| 73 |
-
|
| 74 |
-
def forward(self, img):
|
| 75 |
-
img_size = np.array(img).shape
|
| 76 |
-
if img_size[0] < 4:
|
| 77 |
-
return img
|
| 78 |
-
if np.random.rand() <= self.p:
|
| 79 |
-
channel_to_blacken = np.random.choice(np.array([0, 2, 3]))
|
| 80 |
-
img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
|
| 81 |
-
return img
|
| 82 |
-
else:
|
| 83 |
-
return img
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class RandomRemoveChannel(torch.nn.Module):
|
| 87 |
-
"""
|
| 88 |
-
dropping a channel at random
|
| 89 |
-
"""
|
| 90 |
-
|
| 91 |
-
def __init__(self, p=0.2):
|
| 92 |
-
super().__init__()
|
| 93 |
-
self.p = p
|
| 94 |
-
|
| 95 |
-
def forward(self, img):
|
| 96 |
-
img_size = np.array(img).shape
|
| 97 |
-
num_channels = img_size[0]
|
| 98 |
-
if num_channels < 4:
|
| 99 |
-
return img
|
| 100 |
-
if np.random.rand() <= self.p:
|
| 101 |
-
channel_to_blacken = np.random.choice(np.array(list(range(num_channels))))
|
| 102 |
-
img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
|
| 103 |
-
return img
|
| 104 |
-
else:
|
| 105 |
-
return img
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
class RandomContrast(torch.nn.Module):
|
| 109 |
-
def __init__(self, p=0.2):
|
| 110 |
-
super().__init__()
|
| 111 |
-
self.p = p
|
| 112 |
-
|
| 113 |
-
def forward(self, img):
|
| 114 |
-
if img.max() == 0:
|
| 115 |
-
return img
|
| 116 |
-
n_channels = img.shape[0]
|
| 117 |
-
for ind in range(n_channels):
|
| 118 |
-
factor = max(np.random.normal(1, self.p), 0.5)
|
| 119 |
-
img[ind] = transforms.functional.adjust_contrast(img[ind][None, ...], factor)
|
| 120 |
-
return img
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
class RandomBrightness(torch.nn.Module):
|
| 124 |
-
def __init__(self, p=0.2):
|
| 125 |
-
super().__init__()
|
| 126 |
-
self.p = p
|
| 127 |
-
|
| 128 |
-
def forward(self, img):
|
| 129 |
-
if img.max() == 0:
|
| 130 |
-
return img
|
| 131 |
-
n_channels = img.shape[0]
|
| 132 |
-
for ind in range(n_channels):
|
| 133 |
-
factor = max(np.random.normal(1, self.p), 0.5)
|
| 134 |
-
img[ind] = transforms.functional.adjust_brightness(img[ind], factor)
|
| 135 |
-
return img
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def make_classification_eval_cell_transform(
|
| 139 |
-
*,
|
| 140 |
-
resize_size: int = 0,
|
| 141 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 142 |
-
crop_size: int = 384,
|
| 143 |
-
normalization_type: Enum = NormalizationType.SELF_NORM_CENTER_CROP,
|
| 144 |
-
) -> transforms.Compose:
|
| 145 |
-
|
| 146 |
-
from .transforms import (
|
| 147 |
-
Div255,
|
| 148 |
-
SelfNormalizeNoDiv,
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
transforms_list = [Div255()]
|
| 152 |
-
if resize_size > 0:
|
| 153 |
-
transforms_list.append(transforms.Resize(resize_size, interpolation=interpolation))
|
| 154 |
-
|
| 155 |
-
if normalization_type == NormalizationType.SELF_NORM_AUG_DECODER:
|
| 156 |
-
transforms_list.extend(
|
| 157 |
-
[
|
| 158 |
-
transforms.RandomCrop(size=crop_size, pad_if_needed=True),
|
| 159 |
-
transforms.RandomHorizontalFlip(),
|
| 160 |
-
transforms.RandomVerticalFlip(),
|
| 161 |
-
]
|
| 162 |
-
)
|
| 163 |
-
elif normalization_type == NormalizationType.SELF_NORM_CENTER_CROP:
|
| 164 |
-
transforms_list.append(transforms.CenterCrop(size=crop_size))
|
| 165 |
-
else:
|
| 166 |
-
raise ValueError("f{normalization_type}: unknown NormalizationType")
|
| 167 |
-
transforms_list.append(SelfNormalizeNoDiv())
|
| 168 |
-
|
| 169 |
-
return transforms.Compose(transforms_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/collate.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import random
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
|
| 11 |
-
# dtype = torch.half # TODO: Remove
|
| 12 |
-
|
| 13 |
-
n_global_crops = len(samples_list[0][0]["global_crops"])
|
| 14 |
-
n_local_crops = len(samples_list[0][0]["local_crops"])
|
| 15 |
-
|
| 16 |
-
collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
|
| 17 |
-
|
| 18 |
-
collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
|
| 19 |
-
|
| 20 |
-
B = len(collated_global_crops)
|
| 21 |
-
N = n_tokens
|
| 22 |
-
n_samples_masked = int(B * mask_probability)
|
| 23 |
-
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
|
| 24 |
-
upperbound = 0
|
| 25 |
-
masks_list = []
|
| 26 |
-
for i in range(0, n_samples_masked):
|
| 27 |
-
prob_min = probs[i]
|
| 28 |
-
prob_max = probs[i + 1]
|
| 29 |
-
masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
|
| 30 |
-
upperbound += int(N * prob_max)
|
| 31 |
-
for i in range(n_samples_masked, B):
|
| 32 |
-
masks_list.append(torch.BoolTensor(mask_generator(0)))
|
| 33 |
-
|
| 34 |
-
random.shuffle(masks_list)
|
| 35 |
-
|
| 36 |
-
collated_masks = torch.stack(masks_list).flatten(1)
|
| 37 |
-
mask_indices_list = collated_masks.flatten().nonzero().flatten()
|
| 38 |
-
|
| 39 |
-
masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
|
| 40 |
-
|
| 41 |
-
return {
|
| 42 |
-
"collated_global_crops": collated_global_crops.to(dtype),
|
| 43 |
-
"collated_local_crops": collated_local_crops.to(dtype),
|
| 44 |
-
"collated_masks": collated_masks,
|
| 45 |
-
"mask_indices_list": mask_indices_list,
|
| 46 |
-
"masks_weight": masks_weight,
|
| 47 |
-
"upperbound": upperbound,
|
| 48 |
-
"n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
|
| 49 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/__init__.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from .image_net import ImageNet
|
| 7 |
-
from .image_net_22k import ImageNet22k
|
| 8 |
-
from .cell_dino.hpaone import HPAone
|
| 9 |
-
from .cell_dino.hpafov import HPAFoV
|
| 10 |
-
from .cell_dino.chammi_cp import CHAMMI_CP
|
| 11 |
-
from .cell_dino.chammi_hpa import CHAMMI_HPA
|
| 12 |
-
from .cell_dino.chammi_wtc import CHAMMI_WTC
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/cell_dino/chammi_cp.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Any, Callable, Optional, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from ..extended import ExtendedVisionDataset
|
| 15 |
-
from ..decoders import DecoderType
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
METADATA_FILE = "morphem70k_v2.csv"
|
| 21 |
-
|
| 22 |
-
CLASS_LABELS = {
|
| 23 |
-
"BRD-A29260609": 0,
|
| 24 |
-
"BRD-K04185004": 1,
|
| 25 |
-
"BRD-K21680192": 2,
|
| 26 |
-
"DMSO": 3,
|
| 27 |
-
"BRD-K11129031": 4, # labels only seen in TASK_FOUR
|
| 28 |
-
"BRD-K62310379": 5,
|
| 29 |
-
"BRD-K77947974": 6,
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class _Split(Enum):
|
| 34 |
-
TRAIN = "Train"
|
| 35 |
-
TASK_ONE = "Task_one"
|
| 36 |
-
TASK_TWO = "Task_two"
|
| 37 |
-
TASK_THREE = "Task_three"
|
| 38 |
-
TASK_FOUR = "Task_four"
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def _load_file_names_and_targets(
|
| 42 |
-
root: str,
|
| 43 |
-
split: _Split,
|
| 44 |
-
):
|
| 45 |
-
image_paths = []
|
| 46 |
-
labels = []
|
| 47 |
-
with open(os.path.join(root, METADATA_FILE)) as metadata:
|
| 48 |
-
metadata_reader = csv.DictReader(metadata)
|
| 49 |
-
for row in metadata_reader:
|
| 50 |
-
row_dataset = row["file_path"].split("/")[0]
|
| 51 |
-
|
| 52 |
-
if row["train_test_split"].upper() == split and row_dataset == "CP":
|
| 53 |
-
image_paths.append(row["file_path"])
|
| 54 |
-
labels.append(CLASS_LABELS[row["label"]])
|
| 55 |
-
|
| 56 |
-
return image_paths, labels # to debug
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class CHAMMI_CP(ExtendedVisionDataset):
|
| 60 |
-
"""
|
| 61 |
-
Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset,
|
| 62 |
-
following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
|
| 63 |
-
Github code: https://github.com/chaudatascience/channel_adaptive_models
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
Split = Union[_Split]
|
| 67 |
-
|
| 68 |
-
def __init__(
|
| 69 |
-
self,
|
| 70 |
-
*,
|
| 71 |
-
split: "CHAMMI_CP.Split",
|
| 72 |
-
root: str,
|
| 73 |
-
transforms: Optional[Callable] = None,
|
| 74 |
-
transform: Optional[Callable] = None,
|
| 75 |
-
target_transform: Optional[Callable] = None,
|
| 76 |
-
image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
|
| 77 |
-
**kwargs: Any,
|
| 78 |
-
) -> None:
|
| 79 |
-
super().__init__(
|
| 80 |
-
root,
|
| 81 |
-
transforms,
|
| 82 |
-
transform,
|
| 83 |
-
target_transform,
|
| 84 |
-
image_decoder_type=image_decoder_type,
|
| 85 |
-
**kwargs,
|
| 86 |
-
)
|
| 87 |
-
self.split = split
|
| 88 |
-
self.root = root
|
| 89 |
-
self.num_additional_labels_loo_eval = 3
|
| 90 |
-
self._image_paths, self._targets = _load_file_names_and_targets(
|
| 91 |
-
root,
|
| 92 |
-
split,
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
def get_image_relpath(self, index: int) -> str:
|
| 96 |
-
return self._image_paths[index]
|
| 97 |
-
|
| 98 |
-
def get_image_data(self, index: int) -> bytes:
|
| 99 |
-
image_relpath = self.get_image_relpath(index)
|
| 100 |
-
image_full_path = os.path.join(self.root, image_relpath)
|
| 101 |
-
with open(image_full_path, mode="rb") as f:
|
| 102 |
-
image_data = f.read()
|
| 103 |
-
return image_data
|
| 104 |
-
|
| 105 |
-
def get_target(self, index: int) -> Any:
|
| 106 |
-
return self._targets[index]
|
| 107 |
-
|
| 108 |
-
def get_targets(self) -> np.ndarray:
|
| 109 |
-
return np.array(self._targets)
|
| 110 |
-
|
| 111 |
-
def __len__(self) -> int:
|
| 112 |
-
return len(self._image_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/cell_dino/chammi_hpa.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Any, Callable, Optional, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from ..extended import ExtendedVisionDataset
|
| 15 |
-
from ..decoders import DecoderType
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
METADATA_FILE = "morphem70k_v2.csv"
|
| 21 |
-
|
| 22 |
-
CLASS_LABELS = {
|
| 23 |
-
"golgi apparatus": 0,
|
| 24 |
-
"microtubules": 1,
|
| 25 |
-
"mitochondria": 2,
|
| 26 |
-
"nuclear speckles": 3,
|
| 27 |
-
"cytosol": 4, # labels only seen in TASK_THREE
|
| 28 |
-
"endoplasmic reticulum": 5,
|
| 29 |
-
"nucleoplasm": 6,
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class _Split(Enum):
|
| 34 |
-
TRAIN = "Train"
|
| 35 |
-
TASK_ONE = "Task_one"
|
| 36 |
-
TASK_TWO = "Task_two"
|
| 37 |
-
TASK_THREE = "Task_three"
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _load_file_names_and_targets(
|
| 41 |
-
root: str,
|
| 42 |
-
split: _Split,
|
| 43 |
-
):
|
| 44 |
-
image_paths = []
|
| 45 |
-
labels = []
|
| 46 |
-
with open(os.path.join(root, METADATA_FILE)) as metadata:
|
| 47 |
-
metadata_reader = csv.DictReader(metadata)
|
| 48 |
-
for row in metadata_reader:
|
| 49 |
-
row_dataset = row["file_path"].split("/")[0]
|
| 50 |
-
if row["train_test_split"].upper() == split and row_dataset == "HPA":
|
| 51 |
-
image_paths.append(row["file_path"])
|
| 52 |
-
labels.append(CLASS_LABELS[row["label"]])
|
| 53 |
-
|
| 54 |
-
return image_paths, labels
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class CHAMMI_HPA(ExtendedVisionDataset):
|
| 58 |
-
"""
|
| 59 |
-
Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset,
|
| 60 |
-
following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
|
| 61 |
-
Github code: https://github.com/chaudatascience/channel_adaptive_models
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
-
Split = Union[_Split]
|
| 65 |
-
|
| 66 |
-
def __init__(
|
| 67 |
-
self,
|
| 68 |
-
*,
|
| 69 |
-
split: "CHAMMI_HPA.Split",
|
| 70 |
-
root: str,
|
| 71 |
-
transforms: Optional[Callable] = None,
|
| 72 |
-
transform: Optional[Callable] = None,
|
| 73 |
-
target_transform: Optional[Callable] = None,
|
| 74 |
-
image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
|
| 75 |
-
**kwargs: Any,
|
| 76 |
-
) -> None:
|
| 77 |
-
super().__init__(
|
| 78 |
-
root,
|
| 79 |
-
transforms,
|
| 80 |
-
transform,
|
| 81 |
-
target_transform,
|
| 82 |
-
image_decoder_type=image_decoder_type,
|
| 83 |
-
**kwargs,
|
| 84 |
-
)
|
| 85 |
-
self.split = split
|
| 86 |
-
self.root = root
|
| 87 |
-
self.num_additional_labels_loo_eval = 3
|
| 88 |
-
|
| 89 |
-
self._image_paths, self._targets = _load_file_names_and_targets(
|
| 90 |
-
root,
|
| 91 |
-
split,
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
def get_image_relpath(self, index: int) -> str:
|
| 95 |
-
return self._image_paths[index]
|
| 96 |
-
|
| 97 |
-
def get_image_data(self, index: int) -> bytes:
|
| 98 |
-
image_relpath = self.get_image_relpath(index)
|
| 99 |
-
image_full_path = os.path.join(self.root, image_relpath)
|
| 100 |
-
with open(image_full_path, mode="rb") as f:
|
| 101 |
-
image_data = f.read()
|
| 102 |
-
return image_data
|
| 103 |
-
|
| 104 |
-
def get_target(self, index: int) -> Any:
|
| 105 |
-
return self._targets[index]
|
| 106 |
-
|
| 107 |
-
def get_targets(self) -> np.ndarray:
|
| 108 |
-
return np.array(self._targets)
|
| 109 |
-
|
| 110 |
-
def __len__(self) -> int:
|
| 111 |
-
return len(self._image_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/cell_dino/chammi_wtc.py
DELETED
|
@@ -1,108 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Any, Callable, Optional, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from ..extended import ExtendedVisionDataset
|
| 15 |
-
from ..decoders import DecoderType
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
METADATA_FILE = "morphem70k_v2.csv"
|
| 21 |
-
|
| 22 |
-
CLASS_LABELS = {
|
| 23 |
-
"M0": 0,
|
| 24 |
-
"M1M2": 1,
|
| 25 |
-
"M3": 2,
|
| 26 |
-
"M4M5": 3,
|
| 27 |
-
"M6M7_complete": 4,
|
| 28 |
-
"M6M7_single": 5,
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class _Split(Enum):
|
| 33 |
-
TRAIN = "Train"
|
| 34 |
-
TASK_ONE = "Task_one"
|
| 35 |
-
TASK_TWO = "Task_two"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def _load_file_names_and_targets(
|
| 39 |
-
root: str,
|
| 40 |
-
split: _Split,
|
| 41 |
-
):
|
| 42 |
-
image_paths = []
|
| 43 |
-
labels = []
|
| 44 |
-
with open(os.path.join(root, METADATA_FILE)) as metadata:
|
| 45 |
-
metadata_reader = csv.DictReader(metadata)
|
| 46 |
-
for row in metadata_reader:
|
| 47 |
-
row_dataset = row["file_path"].split("/")[0]
|
| 48 |
-
if row["train_test_split"].upper() == split and row_dataset == "Allen":
|
| 49 |
-
image_paths.append(row["file_path"])
|
| 50 |
-
labels.append(CLASS_LABELS[row["label"]])
|
| 51 |
-
|
| 52 |
-
return image_paths, labels
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class CHAMMI_WTC(ExtendedVisionDataset):
|
| 56 |
-
"""
|
| 57 |
-
Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset,
|
| 58 |
-
following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
|
| 59 |
-
Github code: https://github.com/chaudatascience/channel_adaptive_models
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
Split = Union[_Split]
|
| 63 |
-
|
| 64 |
-
def __init__(
|
| 65 |
-
self,
|
| 66 |
-
*,
|
| 67 |
-
split: "CHAMMI_WTC.Split",
|
| 68 |
-
root: str,
|
| 69 |
-
transforms: Optional[Callable] = None,
|
| 70 |
-
transform: Optional[Callable] = None,
|
| 71 |
-
target_transform: Optional[Callable] = None,
|
| 72 |
-
image_decoder_type: DecoderType = DecoderType.XChannelsTIFFDecoder,
|
| 73 |
-
**kwargs: Any,
|
| 74 |
-
) -> None:
|
| 75 |
-
super().__init__(
|
| 76 |
-
root,
|
| 77 |
-
transforms,
|
| 78 |
-
transform,
|
| 79 |
-
target_transform,
|
| 80 |
-
image_decoder_type=image_decoder_type,
|
| 81 |
-
**kwargs,
|
| 82 |
-
)
|
| 83 |
-
self.split = split
|
| 84 |
-
self.root = root
|
| 85 |
-
|
| 86 |
-
self._image_paths, self._targets = _load_file_names_and_targets(
|
| 87 |
-
root,
|
| 88 |
-
split,
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
def get_image_relpath(self, index: int) -> str:
|
| 92 |
-
return self._image_paths[index]
|
| 93 |
-
|
| 94 |
-
def get_image_data(self, index: int) -> bytes:
|
| 95 |
-
image_relpath = self.get_image_relpath(index)
|
| 96 |
-
image_full_path = os.path.join(self.root, image_relpath)
|
| 97 |
-
with open(image_full_path, mode="rb") as f:
|
| 98 |
-
image_data = f.read()
|
| 99 |
-
return image_data
|
| 100 |
-
|
| 101 |
-
def get_target(self, index: int) -> Any:
|
| 102 |
-
return self._targets[index]
|
| 103 |
-
|
| 104 |
-
def get_targets(self) -> np.ndarray:
|
| 105 |
-
return np.array(self._targets)
|
| 106 |
-
|
| 107 |
-
def __len__(self) -> int:
|
| 108 |
-
return len(self._image_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/cell_dino/hpafov.py
DELETED
|
@@ -1,283 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from ..extended import ExtendedVisionDataset
|
| 15 |
-
from ..decoders import DecoderType
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
|
| 19 |
-
CELL_TYPE = [
|
| 20 |
-
"BJ", # 1
|
| 21 |
-
"LHCN-M2",
|
| 22 |
-
"RH-30",
|
| 23 |
-
"SH-SY5Y",
|
| 24 |
-
"U-2 OS", # 5
|
| 25 |
-
"ASC TERT1",
|
| 26 |
-
"HaCaT",
|
| 27 |
-
"A-431",
|
| 28 |
-
"U-251 MG",
|
| 29 |
-
"HEK 293", # 10
|
| 30 |
-
"A549",
|
| 31 |
-
"RT4",
|
| 32 |
-
"HeLa",
|
| 33 |
-
"MCF7",
|
| 34 |
-
"PC-3", # 15
|
| 35 |
-
"hTERT-RPE1",
|
| 36 |
-
"SK-MEL-30",
|
| 37 |
-
"EFO-21",
|
| 38 |
-
"AF22",
|
| 39 |
-
"HEL", # 20
|
| 40 |
-
"Hep G2",
|
| 41 |
-
"HUVEC TERT2",
|
| 42 |
-
"THP-1",
|
| 43 |
-
"CACO-2",
|
| 44 |
-
"JURKAT", # 25
|
| 45 |
-
"RPTEC TERT1",
|
| 46 |
-
"SuSa",
|
| 47 |
-
"REH",
|
| 48 |
-
"HDLM-2",
|
| 49 |
-
"K-562", # 30
|
| 50 |
-
"hTCEpi",
|
| 51 |
-
"NB-4",
|
| 52 |
-
"HAP1",
|
| 53 |
-
"OE19",
|
| 54 |
-
"SiHa", # 35
|
| 55 |
-
]
|
| 56 |
-
|
| 57 |
-
PROTEIN_LOCALIZATION = [ # matches https://www.kaggle.com/c/human-protein-atlas-image-classification/data
|
| 58 |
-
"nucleoplasm",
|
| 59 |
-
"nuclear membrane",
|
| 60 |
-
"nucleoli",
|
| 61 |
-
"nucleoli fibrillar center",
|
| 62 |
-
"nuclear speckles", # 5
|
| 63 |
-
"nuclear bodies",
|
| 64 |
-
"endoplasmic reticulum",
|
| 65 |
-
"golgi apparatus",
|
| 66 |
-
"peroxisomes",
|
| 67 |
-
"endosomes", # 10
|
| 68 |
-
"lysosomes",
|
| 69 |
-
"intermediate filaments",
|
| 70 |
-
"actin filaments",
|
| 71 |
-
"focal adhesion sites",
|
| 72 |
-
"microtubules", # 15
|
| 73 |
-
"microtubule ends",
|
| 74 |
-
"cytokinetic bridge",
|
| 75 |
-
"mitotic spindle",
|
| 76 |
-
"microtubule organizing center",
|
| 77 |
-
"centrosome", # 20
|
| 78 |
-
"lipid droplets",
|
| 79 |
-
"plasma membrane",
|
| 80 |
-
"cell junctions",
|
| 81 |
-
"mitochondria",
|
| 82 |
-
"aggresome", # 25
|
| 83 |
-
"cytosol",
|
| 84 |
-
"cytoplasmic bodies",
|
| 85 |
-
"rods & rings",
|
| 86 |
-
]
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class _Split(Enum):
    """Dataset splits: labeled train/val CSVs, plus an unlabeled SSL split."""

    TRAIN = "train"
    VAL = "val"
    SSL = "ssl"  # self-supervised pretraining split (no labels)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def get_csv_fpath(split):
    """Return the CSV filename (relative to the dataset root) for a labeled split.

    Accepts a ``_Split`` member or the strings "TRAIN"/"VAL" (the mixed
    enum/string convention used throughout this module).

    Raises:
        ValueError: for the SSL split or any unrecognized value. The original
            implementation fell through and returned None, which surfaced
            later as an obscure ``os.path.join`` TypeError in the caller.
    """
    if split in (_Split.TRAIN, _Split.TRAIN.value.upper(), "TRAIN"):
        return "whole_images_512_train.csv"
    if split in (_Split.VAL, _Split.VAL.value.upper(), "VAL"):
        return "whole_images_512_test.csv"
    raise ValueError(f"no labeled CSV for split {split!r}")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
class _WildCard(Enum):
    """Optional dataset-shaping modifiers applied on top of the chosen split."""

    NONE = "none"
    # Each channel of each image is treated as an independent sample; this
    # overrides the chosen channel configuration.
    SEPARATECHANNELS = "separate_channels"
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class _Mode(Enum):
    """
    Targets:
    - ALL: tuple, (one hot encoding of multilabel protein localization, categorical encoding of cell type)
    - PROTEIN_LOCALIZATION: one hot encoding of multilabel protein localization
    - CELL_TYPE: categorical encoding of cell type
    """

    ALL = "all"
    PROTEIN_LOCALIZATION = "protein_localization"
    CELL_TYPE = "cell_type"

    @property
    def nb_labels(self) -> Optional[int]:
        # Number of classes for the single-target modes; ALL combines two
        # target spaces, so it has no single label count and yields None.
        if self == _Mode.CELL_TYPE:
            return len(CELL_TYPE)
        elif self == _Mode.PROTEIN_LOCALIZATION:
            return len(PROTEIN_LOCALIZATION)
        else:
            return None
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def _list_images_from_csv(img_path, csv_path):
|
| 133 |
-
L = []
|
| 134 |
-
with open(csv_path) as filename:
|
| 135 |
-
reader = csv.DictReader(filename)
|
| 136 |
-
for row in reader:
|
| 137 |
-
L.append(os.path.join(img_path, row["ID"] + ".png"))
|
| 138 |
-
return L
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def _load_file_names_and_labels_ssl(
    root: str,
) -> Tuple[List[str], List[Any]]:
    """Collect all SSL-split image paths, pairing each with its index as a dummy label."""
    image_dir = os.path.join(root, "normalized_data")
    ssl_csv = os.path.join(root, "whole_images_names.csv")
    paths = _list_images_from_csv(image_dir, ssl_csv)
    # SSL training has no real targets; the sample index stands in as a label.
    return paths, list(range(len(paths)))
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def _load_file_names_and_labels(
    root: str,
    split: _Split,
    mode: _Mode,
) -> Tuple[List[str], List[Any]]:
    """Load image paths and labels for a labeled split.

    Depending on ``mode``, each label is a multi-hot protein-localization
    vector, a categorical cell-type index, or a dict with both. Samples with
    an all-zero protein-localization vector are dropped (when that vector is
    computed at all for the given mode).

    NOTE(review): the mode comparisons below match against the UPPER-CASE
    string form (e.g. "CELL_TYPE"), not the enum member — an enum-valued
    ``mode`` therefore always takes the "both targets" path. Confirm callers
    pass upper-case strings here.
    """

    data_path = os.path.join(root, "512_whole_images")
    csv_fpath = os.path.join(root, get_csv_fpath(split))

    image_paths = []
    labels = []

    with open(csv_fpath) as filename:
        reader = csv.DictReader(filename)
        for row in reader:

            add_sample = True
            if mode != _Mode.PROTEIN_LOCALIZATION.value.upper():
                # categorical cell-type index; unknown cell types become NaN
                if row["cell_type"] in CELL_TYPE:
                    cell_type = CELL_TYPE.index(row["cell_type"])
                else:
                    cell_type = np.nan

            if mode != _Mode.CELL_TYPE.value.upper():
                # one hot encoding over the PROTEIN_LOCALIZATION columns
                prot_loc = np.zeros(len(PROTEIN_LOCALIZATION), dtype=np.int_)
                for k in range(len(PROTEIN_LOCALIZATION)):
                    if row[PROTEIN_LOCALIZATION[k]] == "True":
                        prot_loc[k] = 1
                # drop samples with no positive localization label
                if prot_loc.max() < 0.5:
                    add_sample = False

            if add_sample:
                if mode == _Mode.PROTEIN_LOCALIZATION.value.upper():
                    labels.append(prot_loc)
                elif mode == _Mode.CELL_TYPE.value.upper():
                    labels.append(cell_type)
                else:
                    labels.append({"prot_loc": prot_loc, "cell_type": cell_type})

                # Prefer the .png under data_path; fall back to a .tiff with
                # the same stem before giving up.
                candidate_path = os.path.join(data_path, row["file"].split("/")[-1])
                if os.path.exists(candidate_path):
                    image_paths.append(candidate_path)
                else:
                    candidate_path = os.path.join(
                        data_path, row["file"].split("/")[-1].split(".")[0] + ".tiff"
                    )  # some images in the normalized_data folder have a _blue suffix on their names
                    if os.path.exists(candidate_path):
                        image_paths.append(candidate_path)
                    else:
                        raise FileNotFoundError(f"File {candidate_path} not found.")

    return image_paths, labels
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
class HPAFoV(ExtendedVisionDataset):
    """HPA field-of-view dataset over 512px whole images.

    Labels depend on ``mode``: a multi-hot protein-localization vector, a
    categorical cell-type index, or a dict containing both. The SSL split
    reads a separate CSV and uses sample indices as dummy labels. With the
    SEPARATECHANNELS wildcard, each of the 4 image channels becomes an
    independent single-channel sample.
    """

    Split = Union[_Split]
    Mode = Union[_Mode]
    WildCard = Union[_WildCard]

    def __init__(
        self,
        *,
        split: "HPAFoV.Split" = _Split.TRAIN,
        mode: "HPAFoV.Mode" = _Mode.ALL,
        wildcard: "HPAFoV.WildCard" = _WildCard.NONE,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.ChannelSelectDecoder,
        image_decoder_params: Dict[str, Any] = {},  # kept for interface compatibility; always overridden below
        **kwargs: Any,
    ) -> None:
        # Normalize the wildcard once, accepting either the enum member or
        # its upper-case string form. The original accepted both when
        # configuring the decoder but compared only against the string later,
        # so enum callers got single-channel decoding WITHOUT the matching
        # per-channel dataset expansion.
        separate_channels = wildcard in (
            _WildCard.SEPARATECHANNELS,
            _WildCard.SEPARATECHANNELS.value.upper(),  # "SEPARATE_CHANNELS"
        )
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            image_decoder_params={"select_channel": separate_channels},
            **kwargs,
        )
        self.mode = mode
        self.split = split
        self.root = root
        self.wildcard = wildcard
        self.channel_adaptive = True
        if split == _Split.SSL.value.upper() or split == _Split.SSL or split == "SSL":
            self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
        else:
            self._image_paths, self._labels = _load_file_names_and_labels(root, self.split, self.mode)

        # By default every sample exposes all four channels.
        self._channels = np.repeat(np.array([[0, 1, 2, 3]]), len(self._image_paths), axis=0).tolist()

        if separate_channels:
            # Expand the dataset C-fold: sample i*N+j is channel i of image j.
            channels = np.array(self._channels)
            C = channels.shape[1]
            channels = np.concatenate([channels[:, i] for i in range(C)])
            self._channels = np.expand_dims(channels, 1).tolist()
            # BUG FIX: the original wrote to `self.image_paths` / `self.labels`
            # (no leading underscore), so __len__, get_image_relpath and
            # get_target never saw the expansion and went out of sync with
            # self._channels.
            self._image_paths = self._image_paths * C
            self._labels = self._labels * C

    def get_image_relpath(self, index: int) -> str:
        # Paths produced by the loaders are already rooted at `root`;
        # the name is kept for API parity with sibling datasets.
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return raw image bytes for sample ``index``.

        In channel-adaptive mode the sample's channel indices plus a one-byte
        channel count are appended to the payload so the decoder can recover
        the channel configuration.
        """
        image_relpath = self.get_image_relpath(index)
        # NOTE(review): the loaders already join under `root`; joining again is
        # a no-op for absolute roots but doubles a relative root — confirm.
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        if self.channel_adaptive:
            channels = self._channels[index]
            return image_data + bytes(channels) + (len(channels)).to_bytes(1, byteorder="big")
        else:
            return image_data

    def get_target(self, index: int) -> Any:
        return self._labels[index]

    def get_targets(self) -> np.ndarray:
        return np.array(self._labels)

    def __len__(self) -> int:
        return len(self._image_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/cell_dino/hpaone.py
DELETED
|
@@ -1,223 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from ..extended import ExtendedVisionDataset
|
| 15 |
-
from ..decoders import DecoderType
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")

# Merged protein-localization classes for single-cell HPA; a sample's label is
# a multi-hot vector indexed by position in this list. Several related
# compartments are collapsed into comma-joined composite classes.
PROTEIN_LOCALIZATION = [
    "actin filaments,focal adhesion sites",
    "aggresome",
    "centrosome,centriolar satellite",
    "cytosol",
    "endoplasmic reticulum",
    "golgi apparatus",
    "intermediate filaments",
    "microtubules",
    "mitochondria",
    "mitotic spindle",
    "no staining",
    "nuclear bodies",
    "nuclear membrane",
    "nuclear speckles",
    "nucleoli",
    "nucleoli fibrillar center",
    "nucleoplasm",
    "plasma membrane,cell junctions",
    "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies",
]  # 19


# Cell lines for single-cell HPA; a sample's cell type is encoded as its
# index into this (alphabetically ordered) list.
CELL_TYPE = [
    "A-431",  # 0
    "A549",
    "AF22",
    "ASC TERT1",
    "BJ",
    "CACO-2",
    "EFO-21",
    "HAP1",
    "HDLM-2",
    "HEK 293",  # 9
    "HEL",
    "HUVEC TERT2",
    "HaCaT",
    "HeLa",
    "Hep G2",
    "JURKAT",
    "K-562",
    "MCF7",
    "PC-3",
    "REH",
    "RH-30",  # 20
    "RPTEC TERT1",
    "RT4",
    "SH-SY5Y",
    "SK-MEL-30",
    "SiHa",
    "U-2 OS",
    "U-251 MG",
    "hTCEpi",  # 28
]  # 29 cell types
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class _Split(Enum):
    """Dataset splits for single-cell HPA."""

    VAL = "val"
    TRAIN = "train"
    ALL = "all"  # images without labels, for encoder training
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class _Mode(Enum):
    """Target selection: multi-hot protein localization or categorical cell type."""

    PROTEIN_LOCALIZATION = "protein_localization"
    CELL_TYPE = "cell_type"

    @property
    def num_labels(self) -> int:
        """Number of classes for this target mode."""
        # BUG FIX: the original compared the enum member against the string
        # "CELL_TYPE" (`self == _Mode.CELL_TYPE.value.upper()`), which is
        # never true for an Enum instance, so num_labels always returned
        # len(PROTEIN_LOCALIZATION) regardless of mode.
        if self == _Mode.CELL_TYPE:
            return len(CELL_TYPE)
        return len(PROTEIN_LOCALIZATION)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def _simple_parse_csv(img_rootdir, csv_filepath: str):
|
| 93 |
-
samples = []
|
| 94 |
-
with open(csv_filepath) as filename:
|
| 95 |
-
template = csv.DictReader(filename)
|
| 96 |
-
samples = [(os.path.join(img_rootdir, row["img_path"]), 0) for row in template]
|
| 97 |
-
return samples
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def _parse_csv(img_rootdir, csv_labels_path: str):
    """Parse a labeled single-cell CSV into (path, protein_location, cell_type) triples.

    ``protein_location`` is a multi-hot int vector over PROTEIN_LOCALIZATION
    columns; ``cell_type`` is the index of the CELL_TYPE column marked "True".
    """
    nb_protein_location = len(PROTEIN_LOCALIZATION)
    nb_cell_type = len(CELL_TYPE)
    samples = []
    with open(csv_labels_path) as filename:
        reader = csv.DictReader(filename)
        for row in reader:
            # multi-hot vector over the localization columns
            protein_location = np.zeros(nb_protein_location, dtype=np.int_)
            for k in range(nb_protein_location):
                if row[PROTEIN_LOCALIZATION[k]] == "True":
                    protein_location[k] = 1

            # Defaults to class 0 if no column is "True"; if several are,
            # the last "True" column wins (NOTE(review): confirm the data
            # guarantees exactly one cell-type column per row).
            cell_type = 0
            for k in range(nb_cell_type):
                if row[CELL_TYPE[k]] == "True":
                    cell_type = k

            samples.append(
                (
                    # keep only the basename of the recorded file path
                    img_rootdir + "/" + row["file"].rsplit("/", 1)[1],
                    protein_location,
                    cell_type,
                )
            )
    return samples
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def _load_file_names_and_labels_ssl(
    root: str,
) -> Tuple[List[str], List[Any]]:
    """Load every pretraining image path with a dummy label of 0 per sample."""
    image_dir = os.path.join(root, "varied_size_masked_single_cells_HPA")
    pretrain_csv = os.path.join(root, "varied_size_masked_single_cells_pretrain_20240507.csv")
    pairs = _simple_parse_csv(image_dir, pretrain_csv)
    # zip(*pairs) transposes the (path, 0) pairs; paths stay a tuple,
    # labels are materialized as a list, matching the original contract.
    image_paths, dummy_labels = zip(*pairs)
    return image_paths, list(dummy_labels)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def _load_file_names_and_labels_train_or_test(
    root: str,
    split: _Split,
    mode: _Mode,
) -> Tuple[List[str], List[Any]]:
    """Load image paths and labels for the labeled TRAIN or VAL split.

    ``mode`` (compared in its upper-case string form, per module convention)
    selects protein-localization labels, cell-type labels, or both.

    Raises:
        ValueError: for any split other than TRAIN/VAL. The original printed
            "wrong split name" and then crashed with a NameError on the
            unbound ``csv_labels_path``.
    """
    if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN:
        csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_pretrain_20240507.csv")
    elif split == _Split.VAL.value.upper() or split == _Split.VAL:
        csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_evaluation_20240507.csv")
    else:
        raise ValueError(f"unsupported split for labeled loading: {split!r}")
    curr_dir_val = os.path.join(root, "fixed_size_masked_single_cells_HPA")

    samples = _parse_csv(curr_dir_val, csv_labels_path)
    image_paths, protein_location, cell_type = zip(*samples)
    if mode == _Mode.PROTEIN_LOCALIZATION.value.upper():
        lab = protein_location
    elif mode == _Mode.CELL_TYPE.value.upper():
        lab = cell_type
    else:
        # any other mode yields both target tuples
        lab = protein_location, cell_type
    image_paths = list(image_paths)
    return image_paths, lab
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
class HPAone(ExtendedVisionDataset):
    """Single-cell HPA dataset (masked single-cell crops).

    TRAIN/VAL splits carry protein-localization and/or cell-type labels
    selected by ``mode``; the ALL split is unlabeled imagery for encoder
    pretraining.
    """

    Split = Union[_Split]
    Mode = Union[_Mode]

    def __init__(
        self,
        *,
        split: "HPAone.Split" = _Split.ALL,
        mode: "HPAone.Mode" = None,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            **kwargs,
        )
        self.mode = mode
        self.split = split
        self.root = root

        # Splits may arrive as enum members or upper-case strings.
        if (
            split in {_Split.TRAIN.value.upper(), _Split.VAL.value.upper()}
            or split == _Split.TRAIN
            or split == _Split.VAL
        ):
            (
                self._image_paths,
                self._labels,
            ) = _load_file_names_and_labels_train_or_test(root, split, mode)
        elif split == _Split.ALL.value.upper() or split == _Split.ALL:
            self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
        else:
            # BUG FIX: the original only logged here and left _image_paths /
            # _labels unset, deferring the failure to an opaque AttributeError
            # on first access. Fail fast with a clear message instead.
            raise ValueError(f"unknown split: {split!r} (expected TRAIN, VAL or {_Split.ALL.value.upper()})")

    def get_image_relpath(self, index: int) -> str:
        # Paths from the loaders are already rooted; name kept for API parity.
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw encoded bytes of sample ``index``."""
        image_relpath = self.get_image_relpath(index)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Any:
        return self._labels[index]

    def get_targets(self) -> np.ndarray:
        return np.array(self._labels)

    def __len__(self) -> int:
        return len(self._image_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/decoders.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from io import BytesIO
|
| 7 |
-
from typing import Any, Type
|
| 8 |
-
|
| 9 |
-
from PIL import Image
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
from enum import Enum
|
| 13 |
-
|
| 14 |
-
try:
|
| 15 |
-
import tifffile
|
| 16 |
-
except ImportError:
|
| 17 |
-
print("Could not import `tifffile`, TIFFImageDataDecoder will be disabled")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class Decoder:
    """Abstract base for payload decoders; subclasses implement decode()."""

    def decode(self) -> Any:
        raise NotImplementedError
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class DecoderType(Enum):
    """Names of the available image decoders, used for config-driven selection."""

    ImageDataDecoder = "ImageDataDecoder"
    XChannelsDecoder = "XChannelsDecoder"
    XChannelsTIFFDecoder = "XChannelsTIFFDecoder"
    ChannelSelectDecoder = "ChannelSelectDecoder"

    def get_class(self) -> Type[Decoder]:  # noqa: C901
        # Map each member to its decoder class. Note: falls off the end
        # (implicitly returning None) if a new member is added without a
        # matching branch here.
        if self == DecoderType.ImageDataDecoder:
            return ImageDataDecoder
        if self == DecoderType.XChannelsDecoder:
            return XChannelsDecoder
        if self == DecoderType.XChannelsTIFFDecoder:
            return XChannelsTIFFDecoder
        if self == DecoderType.ChannelSelectDecoder:
            return ChannelSelectDecoder
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class ImageDataDecoder(Decoder):
    """Decode raw encoded image bytes into an RGB PIL image."""

    def __init__(self, image_data: bytes) -> None:
        self._image_data = image_data

    def decode(self) -> Image:
        buffer = BytesIO(self._image_data)
        image = Image.open(buffer)
        return image.convert(mode="RGB")
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class TargetDecoder(Decoder):
    """Trivial decoder: returns the wrapped target value unchanged."""

    def __init__(self, target: Any):
        self._target = target

    def decode(self) -> Any:
        # Targets need no transformation; this class exists so targets flow
        # through the same decode() interface as images.
        return self._target
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class XChannelsDecoder(Decoder):
    """Decode image bytes into a channels-first (C, H, W) float tensor."""

    def __init__(self, image_data: bytes) -> None:
        self._image_data = image_data

    def decode(self):
        im = np.asarray(Image.open(BytesIO(self._image_data)))
        if len(im.shape) == 2:
            # Unfold a flat 2-D frame into (H, H, C) with column-major order.
            # NOTE(review): both spatial dims use im.shape[0], which assumes a
            # square frame with channels tiled along the width — confirm this
            # holds for non-square inputs.
            im = np.reshape(im, (im.shape[0], im.shape[0], -1), order="F")
        return torch.Tensor(im).permute(2, 0, 1)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class XChannelsTIFFDecoder(Decoder):
    """Decode TIFF bytes into a (num_channels, H, W) float tensor.

    Requires the optional ``tifffile`` dependency (a NameError is raised at
    decode time if the import at module top failed).
    """

    def __init__(self, image_data: bytes, num_channels: int = 3) -> None:
        self._image_data = image_data
        self._num_channels = num_channels

    def decode(self):
        numpy_array = tifffile.imread(BytesIO(self._image_data))
        # Reshape to (H, W, C) column-major before moving channels first.
        numpy_array = np.reshape(numpy_array, (numpy_array.shape[0], -1, self._num_channels), order="F")
        return torch.Tensor(numpy_array).permute(2, 0, 1)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class ChannelSelectDecoder(Decoder):
    """Decode image bytes, optionally keeping a single channel of the result."""

    def __init__(self, image_data: bytes, select_channel: bool = False) -> None:
        self.select_channel = select_channel
        if select_channel:
            # The producer appends channel metadata to the payload; the final
            # byte is stripped and used as the channel index.
            # NOTE(review): HPAFoV.get_image_data appends the channel index
            # bytes AND a trailing one-byte count — with that producer the
            # byte read here is the count, not a channel index, and the
            # channel bytes stay glued to the image payload. Confirm the
            # intended wire format.
            self._image_data = image_data[:-1]
            self._channel = image_data[-1]
        else:
            self._image_data = image_data

    def decode(self):
        im = np.asarray(Image.open(BytesIO(self._image_data)))
        if self.select_channel:
            # Fancy-index with a one-element list so the result keeps a
            # leading channel dimension of size 1.
            return torch.Tensor(im).permute(2, 0, 1)[[self._channel]]
        return torch.Tensor(im).permute(2, 0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/extended.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from typing import Any, Tuple
|
| 7 |
-
|
| 8 |
-
from torchvision.datasets import VisionDataset
|
| 9 |
-
|
| 10 |
-
from .decoders import DecoderType, TargetDecoder
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class ExtendedVisionDataset(VisionDataset):
    """VisionDataset with pluggable byte-level image decoding.

    Subclasses provide raw bytes via ``get_image_data`` and a raw target via
    ``get_target``; decoding is delegated to the class named by the
    ``image_decoder_type`` kwarg, configured by ``image_decoder_params``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pop decoder-related kwargs before forwarding the rest to VisionDataset.
        image_decoder_type = kwargs.pop("image_decoder_type", DecoderType.ImageDataDecoder)
        self._decoder_params = {}
        self._image_decoder_class = image_decoder_type.get_class()
        if "image_decoder_params" in kwargs:
            self._decoder_params = kwargs.pop("image_decoder_params")

        super().__init__(*args, **kwargs)  # type: ignore

    def get_image_data(self, index: int) -> bytes:
        # Subclass hook: return the raw (encoded) bytes of sample `index`.
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        # Subclass hook: return the raw target of sample `index`.
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        try:
            image_data = self.get_image_data(index)
            image = self._image_decoder_class(image_data, **self._decoder_params).decode()
        except Exception as e:
            # Re-raise read/decode failures with the offending index attached.
            raise RuntimeError(f"can not read image for sample {index}") from e
        target = self.get_target(index)
        target = TargetDecoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/image_net.py
DELETED
|
@@ -1,290 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import csv
|
| 7 |
-
from enum import Enum
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
from typing import Callable, List, Optional, Tuple, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from .extended import ExtendedVisionDataset
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
_Target = int
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class _Split(Enum):
|
| 22 |
-
TRAIN = "train"
|
| 23 |
-
VAL = "val"
|
| 24 |
-
TEST = "test" # NOTE: torchvision does not support the test split
|
| 25 |
-
|
| 26 |
-
@property
|
| 27 |
-
def length(self) -> int:
|
| 28 |
-
split_lengths = {
|
| 29 |
-
_Split.TRAIN: 1_281_167,
|
| 30 |
-
_Split.VAL: 50_000,
|
| 31 |
-
_Split.TEST: 100_000,
|
| 32 |
-
}
|
| 33 |
-
return split_lengths[self]
|
| 34 |
-
|
| 35 |
-
def get_dirname(self, class_id: Optional[str] = None) -> str:
|
| 36 |
-
return self.value if class_id is None else os.path.join(self.value, class_id)
|
| 37 |
-
|
| 38 |
-
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
|
| 39 |
-
dirname = self.get_dirname(class_id)
|
| 40 |
-
if self == _Split.TRAIN:
|
| 41 |
-
basename = f"{class_id}_{actual_index}"
|
| 42 |
-
else: # self in (_Split.VAL, _Split.TEST):
|
| 43 |
-
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
|
| 44 |
-
return os.path.join(dirname, basename + ".JPEG")
|
| 45 |
-
|
| 46 |
-
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
|
| 47 |
-
assert self != _Split.TEST
|
| 48 |
-
dirname, filename = os.path.split(image_relpath)
|
| 49 |
-
class_id = os.path.split(dirname)[-1]
|
| 50 |
-
basename, _ = os.path.splitext(filename)
|
| 51 |
-
actual_index = int(basename.split("_")[-1])
|
| 52 |
-
return class_id, actual_index
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class ImageNet(ExtendedVisionDataset):
|
| 56 |
-
Target = Union[_Target]
|
| 57 |
-
Split = Union[_Split]
|
| 58 |
-
|
| 59 |
-
def __init__(
|
| 60 |
-
self,
|
| 61 |
-
*,
|
| 62 |
-
split: "ImageNet.Split",
|
| 63 |
-
root: str,
|
| 64 |
-
extra: str,
|
| 65 |
-
transforms: Optional[Callable] = None,
|
| 66 |
-
transform: Optional[Callable] = None,
|
| 67 |
-
target_transform: Optional[Callable] = None,
|
| 68 |
-
) -> None:
|
| 69 |
-
super().__init__(root, transforms, transform, target_transform)
|
| 70 |
-
self._extra_root = extra
|
| 71 |
-
self._split = split
|
| 72 |
-
|
| 73 |
-
self._entries = None
|
| 74 |
-
self._class_ids = None
|
| 75 |
-
self._class_names = None
|
| 76 |
-
|
| 77 |
-
@property
|
| 78 |
-
def split(self) -> "ImageNet.Split":
|
| 79 |
-
return self._split
|
| 80 |
-
|
| 81 |
-
def _get_extra_full_path(self, extra_path: str) -> str:
|
| 82 |
-
return os.path.join(self._extra_root, extra_path)
|
| 83 |
-
|
| 84 |
-
def _load_extra(self, extra_path: str) -> np.ndarray:
|
| 85 |
-
extra_full_path = self._get_extra_full_path(extra_path)
|
| 86 |
-
return np.load(extra_full_path, mmap_mode="r")
|
| 87 |
-
|
| 88 |
-
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
| 89 |
-
extra_full_path = self._get_extra_full_path(extra_path)
|
| 90 |
-
os.makedirs(self._extra_root, exist_ok=True)
|
| 91 |
-
np.save(extra_full_path, extra_array)
|
| 92 |
-
|
| 93 |
-
@property
|
| 94 |
-
def _entries_path(self) -> str:
|
| 95 |
-
return f"entries-{self._split.value.upper()}.npy"
|
| 96 |
-
|
| 97 |
-
@property
|
| 98 |
-
def _class_ids_path(self) -> str:
|
| 99 |
-
return f"class-ids-{self._split.value.upper()}.npy"
|
| 100 |
-
|
| 101 |
-
@property
|
| 102 |
-
def _class_names_path(self) -> str:
|
| 103 |
-
return f"class-names-{self._split.value.upper()}.npy"
|
| 104 |
-
|
| 105 |
-
def _get_entries(self) -> np.ndarray:
|
| 106 |
-
if self._entries is None:
|
| 107 |
-
self._entries = self._load_extra(self._entries_path)
|
| 108 |
-
assert self._entries is not None
|
| 109 |
-
return self._entries
|
| 110 |
-
|
| 111 |
-
def _get_class_ids(self) -> np.ndarray:
|
| 112 |
-
if self._split == _Split.TEST:
|
| 113 |
-
assert False, "Class IDs are not available in TEST split"
|
| 114 |
-
if self._class_ids is None:
|
| 115 |
-
self._class_ids = self._load_extra(self._class_ids_path)
|
| 116 |
-
assert self._class_ids is not None
|
| 117 |
-
return self._class_ids
|
| 118 |
-
|
| 119 |
-
def _get_class_names(self) -> np.ndarray:
|
| 120 |
-
if self._split == _Split.TEST:
|
| 121 |
-
assert False, "Class names are not available in TEST split"
|
| 122 |
-
if self._class_names is None:
|
| 123 |
-
self._class_names = self._load_extra(self._class_names_path)
|
| 124 |
-
assert self._class_names is not None
|
| 125 |
-
return self._class_names
|
| 126 |
-
|
| 127 |
-
def find_class_id(self, class_index: int) -> str:
|
| 128 |
-
class_ids = self._get_class_ids()
|
| 129 |
-
return str(class_ids[class_index])
|
| 130 |
-
|
| 131 |
-
def find_class_name(self, class_index: int) -> str:
|
| 132 |
-
class_names = self._get_class_names()
|
| 133 |
-
return str(class_names[class_index])
|
| 134 |
-
|
| 135 |
-
def get_image_data(self, index: int) -> bytes:
|
| 136 |
-
entries = self._get_entries()
|
| 137 |
-
actual_index = entries[index]["actual_index"]
|
| 138 |
-
|
| 139 |
-
class_id = self.get_class_id(index)
|
| 140 |
-
|
| 141 |
-
image_relpath = self.split.get_image_relpath(actual_index, class_id)
|
| 142 |
-
image_full_path = os.path.join(self.root, image_relpath)
|
| 143 |
-
with open(image_full_path, mode="rb") as f:
|
| 144 |
-
image_data = f.read()
|
| 145 |
-
return image_data
|
| 146 |
-
|
| 147 |
-
def get_target(self, index: int) -> Optional[Target]:
|
| 148 |
-
entries = self._get_entries()
|
| 149 |
-
class_index = entries[index]["class_index"]
|
| 150 |
-
return None if self.split == _Split.TEST else int(class_index)
|
| 151 |
-
|
| 152 |
-
def get_targets(self) -> Optional[np.ndarray]:
|
| 153 |
-
entries = self._get_entries()
|
| 154 |
-
return None if self.split == _Split.TEST else entries["class_index"]
|
| 155 |
-
|
| 156 |
-
def get_class_id(self, index: int) -> Optional[str]:
|
| 157 |
-
entries = self._get_entries()
|
| 158 |
-
class_id = entries[index]["class_id"]
|
| 159 |
-
return None if self.split == _Split.TEST else str(class_id)
|
| 160 |
-
|
| 161 |
-
def get_class_name(self, index: int) -> Optional[str]:
|
| 162 |
-
entries = self._get_entries()
|
| 163 |
-
class_name = entries[index]["class_name"]
|
| 164 |
-
return None if self.split == _Split.TEST else str(class_name)
|
| 165 |
-
|
| 166 |
-
def __len__(self) -> int:
|
| 167 |
-
entries = self._get_entries()
|
| 168 |
-
assert len(entries) == self.split.length
|
| 169 |
-
return len(entries)
|
| 170 |
-
|
| 171 |
-
def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
    """Parse the labels CSV into a list of (class_id, class_name) pairs.

    Raises:
        RuntimeError: if the labels file cannot be opened or read.
    """
    labels_full_path = os.path.join(self.root, labels_path)

    try:
        with open(labels_full_path, "r") as f:
            # Each row unpacks as exactly (class_id, class_name).
            return [(class_id, class_name) for class_id, class_name in csv.reader(f)]
    except OSError as e:
        raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
|
| 185 |
-
|
| 186 |
-
def _dump_entries(self) -> None:
    """Build and save the structured "entries" array for this split.

    For the TEST split, entries carry only a 1-based actual index (labels are
    unknown); for the other splits, labels.txt plus a torchvision ImageFolder
    scan supply the class index, class ID and class name for every sample.
    """
    split = self.split
    is_test = split == ImageNet.Split.TEST

    if is_test:
        dataset = None
        sample_count = split.length
        id_width, name_width = 0, 0
    else:
        labels_path = "labels.txt"
        logger.info(f'loading labels from "{labels_path}"')
        labels = self._load_labels(labels_path)

        # NOTE: Using torchvision ImageFolder for consistency
        from torchvision.datasets import ImageFolder

        dataset = ImageFolder(os.path.join(self.root, split.get_dirname()))
        sample_count = len(dataset)
        # Widest ID/name determine the fixed-width unicode dtype below.
        id_width, name_width = -1, -1
        for _, class_index in dataset.samples:
            class_id, class_name = labels[class_index]
            id_width = max(len(class_id), id_width)
            name_width = max(len(class_name), name_width)

    dtype = np.dtype(
        [
            ("actual_index", "<u4"),
            ("class_index", "<u4"),
            ("class_id", f"U{id_width}"),
            ("class_name", f"U{name_width}"),
        ]
    )
    entries_array = np.empty(sample_count, dtype=dtype)

    old_percent = -1

    def _log_progress(index: int) -> None:
        # Log at most once per whole percent of processed samples.
        nonlocal old_percent
        percent = 100 * (index + 1) // sample_count
        if percent > old_percent:
            logger.info(f"creating entries: {percent}%")
            old_percent = percent

    if is_test:
        for index in range(sample_count):
            _log_progress(index)
            # TEST samples are 1-indexed on disk and carry no label data.
            entries_array[index] = (index + 1, np.uint32(-1), "", "")
    else:
        class_names = dict(labels)

        assert dataset
        for index in range(sample_count):
            _log_progress(index)
            image_full_path, class_index = dataset.samples[index]
            image_relpath = os.path.relpath(image_full_path, self.root)
            class_id, actual_index = split.parse_image_relpath(image_relpath)
            entries_array[index] = (actual_index, class_index, class_id, class_names[class_id])

    logger.info(f'saving entries to "{self._entries_path}"')
    self._save_extra(entries_array, self._entries_path)
|
| 251 |
-
|
| 252 |
-
def _dump_class_ids_and_names(self) -> None:
    """Derive and save per-class ID/name lookup arrays from the saved entries.

    No-op for the TEST split, which carries no label information.
    """
    if self.split == ImageNet.Split.TEST:
        return

    entries_array = self._load_extra(self._entries_path)

    # First pass: determine the class count and the widest ID/name strings,
    # which size the fixed-width unicode output arrays.
    id_width = name_width = max_class_index = -1
    for entry in entries_array:
        max_class_index = max(int(entry["class_index"]), max_class_index)
        id_width = max(len(str(entry["class_id"])), id_width)
        name_width = max(len(str(entry["class_name"])), name_width)

    # Second pass: scatter IDs and names into index-addressed arrays.
    class_count = max_class_index + 1
    class_ids_array = np.empty(class_count, dtype=f"U{id_width}")
    class_names_array = np.empty(class_count, dtype=f"U{name_width}")
    for entry in entries_array:
        class_index = entry["class_index"]
        class_ids_array[class_index] = entry["class_id"]
        class_names_array[class_index] = entry["class_name"]

    logger.info(f'saving class IDs to "{self._class_ids_path}"')
    self._save_extra(class_ids_array, self._class_ids_path)

    logger.info(f'saving class names to "{self._class_names_path}"')
    self._save_extra(class_names_array, self._class_names_path)
|
| 287 |
-
|
| 288 |
-
def dump_extra(self) -> None:
    """Generate and persist all extra metadata: entries, then class lookups."""
    self._dump_entries()
    self._dump_class_ids_and_names()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/datasets/image_net_22k.py
DELETED
|
@@ -1,302 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from enum import Enum
|
| 8 |
-
from functools import lru_cache
|
| 9 |
-
from gzip import GzipFile
|
| 10 |
-
from io import BytesIO
|
| 11 |
-
from mmap import ACCESS_READ, mmap
|
| 12 |
-
import os
|
| 13 |
-
from typing import Any, Callable, List, Optional, Set, Tuple
|
| 14 |
-
import warnings
|
| 15 |
-
|
| 16 |
-
import numpy as np
|
| 17 |
-
|
| 18 |
-
from .extended import ExtendedVisionDataset
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
_Labels = int
|
| 22 |
-
|
| 23 |
-
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
@dataclass
|
| 27 |
-
class _ClassEntry:
|
| 28 |
-
block_offset: int
|
| 29 |
-
maybe_filename: Optional[str] = None
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
@dataclass
|
| 33 |
-
class _Entry:
|
| 34 |
-
class_index: int # noqa: E701
|
| 35 |
-
start_offset: int
|
| 36 |
-
end_offset: int
|
| 37 |
-
filename: str
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class _Split(Enum):
|
| 41 |
-
TRAIN = "train"
|
| 42 |
-
VAL = "val"
|
| 43 |
-
|
| 44 |
-
@property
|
| 45 |
-
def length(self) -> int:
|
| 46 |
-
return {
|
| 47 |
-
_Split.TRAIN: 11_797_647,
|
| 48 |
-
_Split.VAL: 561_050,
|
| 49 |
-
}[self]
|
| 50 |
-
|
| 51 |
-
def entries_path(self):
|
| 52 |
-
return f"imagenet21kp_{self.value}.txt"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _get_tarball_path(class_id: str) -> str:
|
| 56 |
-
return f"{class_id}.tar"
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
|
| 60 |
-
@lru_cache(maxsize=mmap_cache_size)
|
| 61 |
-
def _mmap_tarball(class_id: str) -> mmap:
|
| 62 |
-
tarball_path = _get_tarball_path(class_id)
|
| 63 |
-
tarball_full_path = os.path.join(tarballs_root, tarball_path)
|
| 64 |
-
with open(tarball_full_path) as f:
|
| 65 |
-
return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
|
| 66 |
-
|
| 67 |
-
return _mmap_tarball
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class ImageNet22k(ExtendedVisionDataset):
    """ImageNet-22k dataset served directly out of per-class tarballs.

    Sample locations come from a precomputed structured "entries" array
    (class index/ID plus byte offsets), and image bytes are sliced out of
    memory-mapped tarballs, so nothing is ever extracted to disk.
    """

    # Sample indices whose payload is gzip-compressed inside the tarball;
    # get_image_data() decompresses these transparently.
    _GZIPPED_INDICES: Set[int] = {
        841_545,
        1_304_131,
        2_437_921,
        2_672_079,
        2_795_676,
        2_969_786,
        6_902_965,
        6_903_550,
        6_903_628,
        7_432_557,
        7_432_589,
        7_813_809,
        8_329_633,
        10_296_990,
        10_417_652,
        10_492_265,
        10_598_078,
        10_782_398,
        10_902_612,
        11_203_736,
        11_342_890,
        11_397_596,
        11_589_762,
        11_705_103,
        12_936_875,
        13_289_782,
    }
    Labels = _Labels

    def __init__(
        self,
        *,
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        """
        Args:
            root: Directory holding the per-class ``<class_id>.tar`` files.
            extra: Directory holding the precomputed metadata arrays.
            transforms: Joint (image, target) transform.
            transform: Image-only transform.
            target_transform: Target-only transform.
            mmap_cache_size: Maximum number of tarball mmaps kept open at once
                (each holds a file descriptor).
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra

        entries_path = self._get_entries_path(root)
        self._entries = self._load_extra(entries_path)

        class_ids_path = self._get_class_ids_path(root)
        self._class_ids = self._load_extra(class_ids_path)

        self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    def _get_entries_path(self, root: Optional[str] = None) -> str:
        # Fixed name inside the extra root; `root` is accepted for API symmetry.
        return "entries.npy"

    def _get_class_ids_path(self, root: Optional[str] = None) -> str:
        # Fixed name inside the extra root; `root` is accepted for API symmetry.
        return "class-ids.npy"

    def _find_class_ids(self, path: str) -> List[str]:
        """Return the sorted class IDs inferred from the ``*.tar`` files in ``path``."""
        class_ids = []

        with os.scandir(path) as entries:
            for entry in entries:
                root, ext = os.path.splitext(entry.name)
                if ext != ".tar":
                    continue
                class_ids.append(root)

        return sorted(class_ids)

    def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
        """Parse the per-class tar block logs into sample entries.

        Each ``blocks/<class_id>.log`` line has the form ``block N: filename``;
        consecutive block offsets delimit one tar member (512-byte blocks).

        Raises:
            RuntimeError: if a blocks file cannot be read.
        """
        root = self.get_root(root)
        entries: List[_Entry] = []
        class_ids = self._find_class_ids(root)

        for class_index, class_id in enumerate(class_ids):
            path = os.path.join(root, "blocks", f"{class_id}.log")
            class_entries = []

            try:
                with open(path) as f:
                    for line in f:
                        line = line.rstrip()
                        block, filename = line.split(":")
                        block_offset = int(block[6:])  # strip the "block " prefix
                        filename = filename[1:]  # strip the leading space

                        maybe_filename = None
                        if filename != "** Block of NULs **":
                            maybe_filename = filename
                            _, ext = os.path.splitext(filename)
                            # assert ext == ".JPEG"

                        class_entry = _ClassEntry(block_offset, maybe_filename)
                        class_entries.append(class_entry)
            except OSError as e:
                raise RuntimeError(f'can not read blocks file "{path}"') from e

            # The log must end with the tar end-of-archive block of NULs.
            assert class_entries[-1].maybe_filename is None

            for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
                assert class_entry1.block_offset <= class_entry2.block_offset
                start_offset = 512 * class_entry1.block_offset
                end_offset = 512 * class_entry2.block_offset
                assert class_entry1.maybe_filename is not None
                filename = class_entry1.maybe_filename
                entry = _Entry(class_index, start_offset, end_offset, filename)
                # Skip invalid image files (PIL throws UnidentifiedImageError)
                if filename == "n06470073_47249.JPEG":
                    continue
                entries.append(entry)

        return entries, class_ids

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Load a metadata array from the extra root, memory-mapped read-only."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save a metadata array under the extra root, creating it if needed."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        os.makedirs(extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _tarballs_root(self) -> str:
        return self.root

    def find_class_id(self, class_index: int) -> str:
        """Return the class ID string for the given class index."""
        return str(self._class_ids[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Slice a sample's raw image bytes out of its memory-mapped tarball.

        Skips the 512-byte tar member header and gunzips the few samples
        known to be stored compressed.

        Raises:
            RuntimeError: if the bytes cannot be retrieved or decompressed.
        """
        entry = self._entries[index]
        class_id = entry["class_id"]
        class_mmap = self._mmap_tarball(class_id)

        start_offset, end_offset = entry["start_offset"], entry["end_offset"]
        try:
            mapped_data = class_mmap[start_offset:end_offset]
            data = mapped_data[512:]  # Skip entry header block

            if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
                # Gzip magic bytes: only whitelisted samples may be compressed.
                assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
                with GzipFile(fileobj=BytesIO(data)) as g:
                    data = g.read()
        except Exception as e:
            raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e

        return data

    def get_target(self, index: int) -> Any:
        """Return the integer class index for a sample."""
        return int(self._entries[index]["class_index"])

    def get_targets(self) -> np.ndarray:
        """Return the class indices of all samples."""
        return self._entries["class_index"]

    def get_class_id(self, index: int) -> str:
        """Return the class ID string for a sample."""
        return str(self._entries[index]["class_id"])

    def get_class_ids(self) -> np.ndarray:
        """Return the class IDs of all samples."""
        return self._entries["class_id"]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Silence PIL/decoder warnings raised while decoding individual samples.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super().__getitem__(index)

    def __len__(self) -> int:
        return len(self._entries)

    def _dump_entries(self, *args, **kwargs) -> None:
        """Build and save the structured entries array from the block logs."""
        entries, class_ids = self._load_entries_class_ids(*args, **kwargs)

        # First pass: size the fixed-width unicode fields.
        max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
        for entry in entries:
            class_id = class_ids[entry.class_index]
            max_class_index = max(entry.class_index, max_class_index)
            max_class_id_length = max(len(class_id), max_class_id_length)
            max_filename_length = max(len(entry.filename), max_filename_length)

        dtype = np.dtype(
            [
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("start_offset", "<u4"),
                ("end_offset", "<u4"),
                ("filename", f"U{max_filename_length}"),
            ]
        )
        sample_count = len(entries)
        entries_array = np.empty(sample_count, dtype=dtype)
        for i, entry in enumerate(entries):
            class_index = entry.class_index
            class_id = class_ids[class_index]
            start_offset = entry.start_offset
            end_offset = entry.end_offset
            filename = entry.filename
            entries_array[i] = (
                class_index,
                class_id,
                start_offset,
                end_offset,
                filename,
            )

        entries_path = self._get_entries_path(*args, **kwargs)
        self._save_extra(entries_array, entries_path)

    def _dump_class_ids(self, *args, **kwargs) -> None:
        """Derive and save the index-addressed class-ID lookup array."""
        entries_path = self._get_entries_path(*args, **kwargs)
        entries_array = self._load_extra(entries_path)

        max_class_id_length, max_class_index = -1, -1
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)

        class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            class_ids_array[class_index] = class_id
        class_ids_path = self._get_class_ids_path(*args, **kwargs)
        self._save_extra(class_ids_array, class_ids_path)

    def _dump_extra(self, *args, **kwargs) -> None:
        # BUG FIX: was `self._dump_entries(*args, *kwargs)` (and likewise for
        # _dump_class_ids), which unpacked the kwargs *dict* positionally —
        # forwarding its keys as extra positional arguments instead of
        # forwarding the keyword arguments.
        self._dump_entries(*args, **kwargs)
        self._dump_class_ids(*args, **kwargs)

    def dump_extra(self, root: Optional[str] = None) -> None:
        """Generate and persist all extra metadata (entries + class IDs)."""
        return self._dump_extra(root)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/loaders.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import logging
|
| 7 |
-
from enum import Enum
|
| 8 |
-
from typing import Any, Callable, List, Optional, TypeVar
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
from torch.utils.data import Sampler
|
| 12 |
-
|
| 13 |
-
from .datasets import ImageNet, ImageNet22k, HPAone, HPAFoV, CHAMMI_CP, CHAMMI_HPA, CHAMMI_WTC
|
| 14 |
-
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger("dinov2")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class SamplerType(Enum):
    """Identifies which sampler implementation ``_make_sampler`` should build."""

    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _make_bool_str(b: bool) -> str:
|
| 29 |
-
return "yes" if b else "no"
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
|
| 33 |
-
def transform(sample):
|
| 34 |
-
image, target = sample
|
| 35 |
-
if image_transform is not None:
|
| 36 |
-
image = image_transform(image)
|
| 37 |
-
if target_transform is not None:
|
| 38 |
-
target = target_transform(target)
|
| 39 |
-
return image, target
|
| 40 |
-
|
| 41 |
-
return transform
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _parse_dataset_str(dataset_str: str):
    """Parse a "Name:key=value:..." description into (dataset class, kwargs).

    Example: "ImageNet:split=TRAIN:root=/data" -> (ImageNet, {...}).

    Raises:
        ValueError: if the dataset name is not recognized.
    """
    name, *options = dataset_str.split(":")

    kwargs = {}
    for option in options:
        key, value = option.split("=")
        assert key in ("root", "extra", "split", "mode", "wildcard")
        kwargs[key] = value

    known_classes = {
        "ImageNet": ImageNet,
        "ImageNet22k": ImageNet22k,
        "HPAone": HPAone,
        "HPAFoV": HPAFoV,
        "CHAMMI_CP": CHAMMI_CP,
        "CHAMMI_WTC": CHAMMI_WTC,
        "CHAMMI_HPA": CHAMMI_HPA,
    }
    if name not in known_classes:
        raise ValueError(f'Unsupported dataset "{name}"')
    class_ = known_classes[name]

    # ImageNet takes its split as an enum member rather than a raw string.
    if class_ is ImageNet and "split" in kwargs:
        kwargs["split"] = ImageNet.Split[kwargs["split"]]

    return class_, kwargs
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Creates a dataset with the specified parameters.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    class_, kwargs = _parse_dataset_str(dataset_str)
    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not expose (yet) these attributes, so add them.
    for attr_name, attr_value in (("transform", transform), ("target_transform", target_transform)):
        if not hasattr(dataset, attr_name):
            setattr(dataset, attr_name, attr_value)

    return dataset
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    """Instantiate the sampler described by ``type``, or None for no sampler.

    Raises on option combinations a sampler type does not support (e.g. a
    fixed ``size`` with an infinite sampler, or ``advance`` with EPOCH).
    """
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )

    if type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=(type == SamplerType.SHARDED_INFINITE_NEW),
        )

    if type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        epoch_size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {epoch_size:,d}")
        return EpochSampler(
            size=epoch_size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )

    if type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
T = TypeVar("T")
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
        collate_fn: Function that performs batch collation
    """
    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    # Infinite samplers yield loaders with no defined length.
    try:
        logger.info(f"# of batches: {len(loader):,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/masking.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import random
|
| 7 |
-
import math
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class MaskingGenerator:
    """Generates random block-wise boolean masks over a 2-D patch grid.

    Repeatedly places random rectangles (with bounded area and aspect ratio)
    until the requested number of patches is masked or no further progress
    can be made.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        """
        Args:
            input_size: Grid size as (height, width), or a single int for a
                square grid.
            num_masking_patches: Nominal total number of patches to mask; also
                the default per-rectangle cap. NOTE(review): if this stays None
                and ``max_num_patches`` is not given, calling the generator
                with a positive count fails — confirm callers always set one.
            min_num_patches: Minimum area of one placed rectangle.
            max_num_patches: Maximum area of one placed rectangle; defaults to
                ``num_masking_patches``.
            min_aspect: Minimum rectangle aspect ratio.
            max_aspect: Maximum aspect ratio; defaults to ``1 / min_aspect``.
        """
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        # Aspect ratios are sampled uniformly in log space.
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        # BUG FIX: the original used "%d" for max_num_patches and
        # num_masking_patches, which raises TypeError when they are None
        # (their default); "%s" renders ints identically and handles None.
        repr_str = "Generator(%d, %d -> [%d ~ %s], max = %s, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        """Return the (height, width) of the patch grid."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try to place one random rectangle into ``mask`` (in place).

        Makes up to 10 attempts; returns the number of newly masked patches
        (0 if no attempt made progress).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Only accept rectangles that add some, but not too many,
                # newly masked patches (partial overlap with existing mask ok).
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

            if delta > 0:
                break
        return delta

    def __call__(self, num_masking_patches=0):
        """Return a boolean (height, width) mask with at most
        ``num_masking_patches`` True entries (fewer if placement stalls)."""
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No progress possible; return the partial mask.
                break
            else:
                mask_count += delta

        return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/samplers.py
DELETED
|
@@ -1,229 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import itertools
|
| 7 |
-
from typing import Any, Optional
|
| 8 |
-
import warnings
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
from torch.utils.data.sampler import Sampler
|
| 13 |
-
|
| 14 |
-
import dinov2.distributed as distributed
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class EpochSampler(Sampler):
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
*,
|
| 21 |
-
size: int,
|
| 22 |
-
sample_count: int,
|
| 23 |
-
shuffle: bool = False,
|
| 24 |
-
seed: int = 0,
|
| 25 |
-
start: Optional[int] = None,
|
| 26 |
-
step: Optional[int] = None,
|
| 27 |
-
):
|
| 28 |
-
self._size = size
|
| 29 |
-
self._sample_count = sample_count
|
| 30 |
-
self._shuffle = shuffle
|
| 31 |
-
self._seed = seed
|
| 32 |
-
self._start = distributed.get_global_rank() if start is None else start
|
| 33 |
-
self._step = distributed.get_global_size() if step is None else step
|
| 34 |
-
self._epoch = 0
|
| 35 |
-
|
| 36 |
-
def __iter__(self):
|
| 37 |
-
count = (self._size + self._sample_count - 1) // self._sample_count
|
| 38 |
-
tiled_indices = np.tile(np.arange(self._sample_count), count)
|
| 39 |
-
if self._shuffle:
|
| 40 |
-
seed = self._seed * self._epoch if self._seed != 0 else self._epoch
|
| 41 |
-
rng = np.random.default_rng(seed)
|
| 42 |
-
iterable = rng.choice(tiled_indices, self._size, replace=False)
|
| 43 |
-
else:
|
| 44 |
-
iterable = tiled_indices[: self._size]
|
| 45 |
-
|
| 46 |
-
yield from itertools.islice(iterable, self._start, None, self._step)
|
| 47 |
-
|
| 48 |
-
def __len__(self):
|
| 49 |
-
return (self._size - self._start + self._step - 1) // self._step
|
| 50 |
-
|
| 51 |
-
def set_epoch(self, epoch):
|
| 52 |
-
self._epoch = epoch
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _get_numpy_dtype(size: int) -> Any:
|
| 56 |
-
return np.int32 if size <= 2**31 else np.int64
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _get_torch_dtype(size: int) -> Any:
|
| 60 |
-
return torch.int32 if size <= 2**31 else torch.int64
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
|
| 64 |
-
"""Generate the indices of a random permutation."""
|
| 65 |
-
dtype = _get_torch_dtype(size)
|
| 66 |
-
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
|
| 67 |
-
perm = torch.arange(size, dtype=dtype)
|
| 68 |
-
for i in range(size):
|
| 69 |
-
j = torch.randint(i, size, size=(1,), generator=generator).item()
|
| 70 |
-
|
| 71 |
-
# Always swap even if no-op
|
| 72 |
-
value = perm[j].item()
|
| 73 |
-
perm[j] = perm[i].item()
|
| 74 |
-
perm[i] = value
|
| 75 |
-
yield value
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class InfiniteSampler(Sampler):
|
| 79 |
-
def __init__(
|
| 80 |
-
self,
|
| 81 |
-
*,
|
| 82 |
-
sample_count: int,
|
| 83 |
-
shuffle: bool = False,
|
| 84 |
-
seed: int = 0,
|
| 85 |
-
start: Optional[int] = None,
|
| 86 |
-
step: Optional[int] = None,
|
| 87 |
-
advance: int = 0,
|
| 88 |
-
):
|
| 89 |
-
self._sample_count = sample_count
|
| 90 |
-
self._seed = seed
|
| 91 |
-
self._shuffle = shuffle
|
| 92 |
-
self._start = distributed.get_global_rank() if start is None else start
|
| 93 |
-
self._step = distributed.get_global_size() if step is None else step
|
| 94 |
-
self._advance = advance
|
| 95 |
-
|
| 96 |
-
def __iter__(self):
|
| 97 |
-
if self._shuffle:
|
| 98 |
-
iterator = self._shuffled_iterator()
|
| 99 |
-
else:
|
| 100 |
-
iterator = self._iterator()
|
| 101 |
-
|
| 102 |
-
yield from itertools.islice(iterator, self._advance, None)
|
| 103 |
-
|
| 104 |
-
def _iterator(self):
|
| 105 |
-
assert not self._shuffle
|
| 106 |
-
|
| 107 |
-
while True:
|
| 108 |
-
iterable = range(self._sample_count)
|
| 109 |
-
yield from itertools.islice(iterable, self._start, None, self._step)
|
| 110 |
-
|
| 111 |
-
def _shuffled_iterator(self):
|
| 112 |
-
assert self._shuffle
|
| 113 |
-
|
| 114 |
-
# Instantiate a generator here (rather than in the ctor) to keep the class
|
| 115 |
-
# picklable (requirement of mp.spawn)
|
| 116 |
-
generator = torch.Generator().manual_seed(self._seed)
|
| 117 |
-
|
| 118 |
-
while True:
|
| 119 |
-
iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
|
| 120 |
-
yield from itertools.islice(iterable, self._start, None, self._step)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
|
| 124 |
-
# but avoids a full in-place random permutation generation.
|
| 125 |
-
def _shuffle_tensor_slice(
|
| 126 |
-
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
| 127 |
-
) -> np.ndarray:
|
| 128 |
-
stop = len(tensor)
|
| 129 |
-
count = stop // step
|
| 130 |
-
drop_count = stop - step * count
|
| 131 |
-
if drop_count:
|
| 132 |
-
warnings.warn(f"# of dropped samples: {drop_count}")
|
| 133 |
-
|
| 134 |
-
dtype = _get_numpy_dtype(stop)
|
| 135 |
-
result = np.empty(count, dtype=dtype)
|
| 136 |
-
|
| 137 |
-
for i in range(count):
|
| 138 |
-
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
|
| 139 |
-
|
| 140 |
-
result[i] = result[j]
|
| 141 |
-
result[j] = tensor[start + i * step].item()
|
| 142 |
-
|
| 143 |
-
return result
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def _new_shuffle_tensor_slice(
|
| 147 |
-
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
| 148 |
-
) -> np.ndarray:
|
| 149 |
-
stop = len(tensor)
|
| 150 |
-
count = stop // step
|
| 151 |
-
dtype = torch.int64 # Needed for using randperm result as indices
|
| 152 |
-
count = stop // step
|
| 153 |
-
drop_count = stop - step * count
|
| 154 |
-
if drop_count:
|
| 155 |
-
warnings.warn(f"# of dropped samples: {drop_count}")
|
| 156 |
-
indices = torch.randperm(count, dtype=dtype, generator=generator)
|
| 157 |
-
return tensor[start::step][indices].numpy()
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def _make_seed(seed: int, start: int, iter_count: int) -> int:
|
| 161 |
-
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
|
| 162 |
-
return seed + start + (iter_count << 24)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class ShardedInfiniteSampler(Sampler):
|
| 166 |
-
def __init__(
|
| 167 |
-
self,
|
| 168 |
-
*,
|
| 169 |
-
sample_count: int,
|
| 170 |
-
shuffle: bool = False,
|
| 171 |
-
seed: int = 0,
|
| 172 |
-
start: Optional[int] = None,
|
| 173 |
-
step: Optional[int] = None,
|
| 174 |
-
advance: int = 0,
|
| 175 |
-
use_new_shuffle_tensor_slice: bool = False,
|
| 176 |
-
):
|
| 177 |
-
self._sample_count = sample_count
|
| 178 |
-
self._seed = seed
|
| 179 |
-
self._shuffle = shuffle
|
| 180 |
-
self._start = distributed.get_global_rank() if start is None else start
|
| 181 |
-
self._step = distributed.get_global_size() if step is None else step
|
| 182 |
-
self._advance = advance
|
| 183 |
-
self._iter_count = 0
|
| 184 |
-
self._shuffle_tensor_slice_fn = (
|
| 185 |
-
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
def __iter__(self):
|
| 189 |
-
iter_count = self._advance // self._sample_count
|
| 190 |
-
if iter_count > 0:
|
| 191 |
-
self._advance -= iter_count * self._sample_count
|
| 192 |
-
self._iter_count += iter_count
|
| 193 |
-
|
| 194 |
-
if self._shuffle:
|
| 195 |
-
iterator = self._shuffled_iterator()
|
| 196 |
-
else:
|
| 197 |
-
iterator = self._iterator()
|
| 198 |
-
|
| 199 |
-
yield from itertools.islice(iterator, self._advance, None)
|
| 200 |
-
|
| 201 |
-
def _iterator(self):
|
| 202 |
-
assert not self._shuffle
|
| 203 |
-
|
| 204 |
-
while True:
|
| 205 |
-
iterable = range(self._sample_count)
|
| 206 |
-
yield from itertools.islice(iterable, self._start, None, self._step)
|
| 207 |
-
|
| 208 |
-
def _shuffled_iterator(self):
|
| 209 |
-
assert self._shuffle
|
| 210 |
-
|
| 211 |
-
# Instantiate a generator here (rather than in the ctor) to be keep the class
|
| 212 |
-
# picklable (requirement of mp.spawn)
|
| 213 |
-
generator = torch.Generator()
|
| 214 |
-
|
| 215 |
-
# Always shuffle everything first
|
| 216 |
-
generator.manual_seed(self._seed)
|
| 217 |
-
dtype = _get_torch_dtype(self._sample_count)
|
| 218 |
-
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
|
| 219 |
-
|
| 220 |
-
while True:
|
| 221 |
-
# Re-seed on each iteration to allow skipping whole permutations
|
| 222 |
-
seed = _make_seed(self._seed, self._start, self._iter_count)
|
| 223 |
-
generator.manual_seed(seed)
|
| 224 |
-
|
| 225 |
-
iterable = self._shuffle_tensor_slice_fn(
|
| 226 |
-
tensor=perm, start=self._start, step=self._step, generator=generator
|
| 227 |
-
)
|
| 228 |
-
yield from iterable
|
| 229 |
-
self._iter_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/data/transforms.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from typing import Sequence
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from torchvision import transforms
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class GaussianBlur(transforms.RandomApply):
|
| 13 |
-
"""
|
| 14 |
-
Apply Gaussian Blur to the PIL image.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
|
| 18 |
-
# NOTE: torchvision is applying 1 - probability to return the original image
|
| 19 |
-
keep_p = 1 - p
|
| 20 |
-
transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
|
| 21 |
-
super().__init__(transforms=[transform], p=keep_p)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class MaybeToTensor(transforms.ToTensor):
|
| 25 |
-
"""
|
| 26 |
-
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
def __call__(self, pic):
|
| 30 |
-
"""
|
| 31 |
-
Args:
|
| 32 |
-
pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
|
| 33 |
-
Returns:
|
| 34 |
-
Tensor: Converted image.
|
| 35 |
-
"""
|
| 36 |
-
if isinstance(pic, torch.Tensor):
|
| 37 |
-
return pic
|
| 38 |
-
return super().__call__(pic)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# Use timm's names
|
| 42 |
-
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 43 |
-
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def make_normalize_transform(
|
| 47 |
-
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
| 48 |
-
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
| 49 |
-
) -> transforms.Normalize:
|
| 50 |
-
return transforms.Normalize(mean=mean, std=std)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# This roughly matches torchvision's preset for classification training:
|
| 54 |
-
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
|
| 55 |
-
def make_classification_train_transform(
|
| 56 |
-
*,
|
| 57 |
-
crop_size: int = 224,
|
| 58 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 59 |
-
hflip_prob: float = 0.5,
|
| 60 |
-
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
| 61 |
-
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
| 62 |
-
):
|
| 63 |
-
transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
| 64 |
-
if hflip_prob > 0.0:
|
| 65 |
-
transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
|
| 66 |
-
transforms_list.extend(
|
| 67 |
-
[
|
| 68 |
-
MaybeToTensor(),
|
| 69 |
-
make_normalize_transform(mean=mean, std=std),
|
| 70 |
-
]
|
| 71 |
-
)
|
| 72 |
-
return transforms.Compose(transforms_list)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# This matches (roughly) torchvision's preset for classification evaluation:
|
| 76 |
-
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
|
| 77 |
-
def make_classification_eval_transform(
|
| 78 |
-
*,
|
| 79 |
-
resize_size: int = 256,
|
| 80 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 81 |
-
crop_size: int = 224,
|
| 82 |
-
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
| 83 |
-
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
| 84 |
-
) -> transforms.Compose:
|
| 85 |
-
transforms_list = [
|
| 86 |
-
transforms.Resize(resize_size, interpolation=interpolation),
|
| 87 |
-
transforms.CenterCrop(crop_size),
|
| 88 |
-
MaybeToTensor(),
|
| 89 |
-
make_normalize_transform(mean=mean, std=std),
|
| 90 |
-
]
|
| 91 |
-
return transforms.Compose(transforms_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/distributed/__init__.py
DELETED
|
@@ -1,270 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import random
|
| 8 |
-
import re
|
| 9 |
-
import socket
|
| 10 |
-
from typing import Dict, List
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
import torch.distributed as dist
|
| 14 |
-
|
| 15 |
-
_LOCAL_RANK = -1
|
| 16 |
-
_LOCAL_WORLD_SIZE = -1
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def is_enabled() -> bool:
|
| 20 |
-
"""
|
| 21 |
-
Returns:
|
| 22 |
-
True if distributed training is enabled
|
| 23 |
-
"""
|
| 24 |
-
return dist.is_available() and dist.is_initialized()
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def get_global_size() -> int:
|
| 28 |
-
"""
|
| 29 |
-
Returns:
|
| 30 |
-
The number of processes in the process group
|
| 31 |
-
"""
|
| 32 |
-
return dist.get_world_size() if is_enabled() else 1
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def get_global_rank() -> int:
|
| 36 |
-
"""
|
| 37 |
-
Returns:
|
| 38 |
-
The rank of the current process within the global process group.
|
| 39 |
-
"""
|
| 40 |
-
return dist.get_rank() if is_enabled() else 0
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def get_local_rank() -> int:
|
| 44 |
-
"""
|
| 45 |
-
Returns:
|
| 46 |
-
The rank of the current process within the local (per-machine) process group.
|
| 47 |
-
"""
|
| 48 |
-
if not is_enabled():
|
| 49 |
-
return 0
|
| 50 |
-
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
|
| 51 |
-
return _LOCAL_RANK
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def get_local_size() -> int:
|
| 55 |
-
"""
|
| 56 |
-
Returns:
|
| 57 |
-
The size of the per-machine process group,
|
| 58 |
-
i.e. the number of processes per machine.
|
| 59 |
-
"""
|
| 60 |
-
if not is_enabled():
|
| 61 |
-
return 1
|
| 62 |
-
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
|
| 63 |
-
return _LOCAL_WORLD_SIZE
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def is_main_process() -> bool:
|
| 67 |
-
"""
|
| 68 |
-
Returns:
|
| 69 |
-
True if the current process is the main one.
|
| 70 |
-
"""
|
| 71 |
-
return get_global_rank() == 0
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _restrict_print_to_main_process() -> None:
|
| 75 |
-
"""
|
| 76 |
-
This function disables printing when not in the main process
|
| 77 |
-
"""
|
| 78 |
-
import builtins as __builtin__
|
| 79 |
-
|
| 80 |
-
builtin_print = __builtin__.print
|
| 81 |
-
|
| 82 |
-
def print(*args, **kwargs):
|
| 83 |
-
force = kwargs.pop("force", False)
|
| 84 |
-
if is_main_process() or force:
|
| 85 |
-
builtin_print(*args, **kwargs)
|
| 86 |
-
|
| 87 |
-
__builtin__.print = print
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _get_master_port(seed: int = 0) -> int:
|
| 91 |
-
MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
|
| 92 |
-
|
| 93 |
-
master_port_str = os.environ.get("MASTER_PORT")
|
| 94 |
-
if master_port_str is None:
|
| 95 |
-
rng = random.Random(seed)
|
| 96 |
-
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
|
| 97 |
-
|
| 98 |
-
return int(master_port_str)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def _get_available_port() -> int:
|
| 102 |
-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 103 |
-
# A "" host address means INADDR_ANY i.e. binding to all interfaces.
|
| 104 |
-
# Note this is not compatible with IPv6.
|
| 105 |
-
s.bind(("", 0))
|
| 106 |
-
port = s.getsockname()[1]
|
| 107 |
-
return port
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
_TORCH_DISTRIBUTED_ENV_VARS = (
|
| 111 |
-
"MASTER_ADDR",
|
| 112 |
-
"MASTER_PORT",
|
| 113 |
-
"RANK",
|
| 114 |
-
"WORLD_SIZE",
|
| 115 |
-
"LOCAL_RANK",
|
| 116 |
-
"LOCAL_WORLD_SIZE",
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def _collect_env_vars() -> Dict[str, str]:
|
| 121 |
-
return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def _is_slurm_job_process() -> bool:
|
| 125 |
-
return "SLURM_JOB_ID" in os.environ
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def _parse_slurm_node_list(s: str) -> List[str]:
|
| 129 |
-
nodes = []
|
| 130 |
-
# Extract "hostname", "hostname[1-2,3,4-5]," substrings
|
| 131 |
-
p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
|
| 132 |
-
for m in p.finditer(s):
|
| 133 |
-
prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
|
| 134 |
-
for suffix in suffixes.split(","):
|
| 135 |
-
span = suffix.split("-")
|
| 136 |
-
if len(span) == 1:
|
| 137 |
-
nodes.append(prefix + suffix)
|
| 138 |
-
else:
|
| 139 |
-
width = len(span[0])
|
| 140 |
-
start, end = int(span[0]), int(span[1]) + 1
|
| 141 |
-
nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
|
| 142 |
-
return nodes
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def _check_env_variable(key: str, new_value: str):
|
| 146 |
-
# Only check for difference with preset environment variables
|
| 147 |
-
if key in os.environ and os.environ[key] != new_value:
|
| 148 |
-
raise RuntimeError(f"Cannot export environment variables as {key} is already set")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
class _TorchDistributedEnvironment:
|
| 152 |
-
def __init__(self):
|
| 153 |
-
self.master_addr = "127.0.0.1"
|
| 154 |
-
self.master_port = 0
|
| 155 |
-
self.rank = -1
|
| 156 |
-
self.world_size = -1
|
| 157 |
-
self.local_rank = -1
|
| 158 |
-
self.local_world_size = -1
|
| 159 |
-
|
| 160 |
-
if _is_slurm_job_process():
|
| 161 |
-
return self._set_from_slurm_env()
|
| 162 |
-
|
| 163 |
-
env_vars = _collect_env_vars()
|
| 164 |
-
if not env_vars:
|
| 165 |
-
# Environment is not set
|
| 166 |
-
pass
|
| 167 |
-
elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
|
| 168 |
-
# Environment is fully set
|
| 169 |
-
return self._set_from_preset_env()
|
| 170 |
-
else:
|
| 171 |
-
# Environment is partially set
|
| 172 |
-
collected_env_vars = ", ".join(env_vars.keys())
|
| 173 |
-
raise RuntimeError(f"Partially set environment: {collected_env_vars}")
|
| 174 |
-
|
| 175 |
-
if torch.cuda.device_count() > 0:
|
| 176 |
-
return self._set_from_local()
|
| 177 |
-
|
| 178 |
-
raise RuntimeError("Can't initialize PyTorch distributed environment")
|
| 179 |
-
|
| 180 |
-
# Slurm job created with sbatch, submitit, etc...
|
| 181 |
-
def _set_from_slurm_env(self):
|
| 182 |
-
# logger.info("Initialization from Slurm environment")
|
| 183 |
-
job_id = int(os.environ["SLURM_JOB_ID"])
|
| 184 |
-
node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
|
| 185 |
-
nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
|
| 186 |
-
assert len(nodes) == node_count
|
| 187 |
-
|
| 188 |
-
self.master_addr = nodes[0]
|
| 189 |
-
self.master_port = _get_master_port(seed=job_id)
|
| 190 |
-
self.rank = int(os.environ["SLURM_PROCID"])
|
| 191 |
-
self.world_size = int(os.environ["SLURM_NTASKS"])
|
| 192 |
-
assert self.rank < self.world_size
|
| 193 |
-
self.local_rank = int(os.environ["SLURM_LOCALID"])
|
| 194 |
-
self.local_world_size = self.world_size // node_count
|
| 195 |
-
assert self.local_rank < self.local_world_size
|
| 196 |
-
|
| 197 |
-
# Single node job with preset environment (i.e. torchrun)
|
| 198 |
-
def _set_from_preset_env(self):
|
| 199 |
-
# logger.info("Initialization from preset environment")
|
| 200 |
-
self.master_addr = os.environ["MASTER_ADDR"]
|
| 201 |
-
self.master_port = os.environ["MASTER_PORT"]
|
| 202 |
-
self.rank = int(os.environ["RANK"])
|
| 203 |
-
self.world_size = int(os.environ["WORLD_SIZE"])
|
| 204 |
-
assert self.rank < self.world_size
|
| 205 |
-
self.local_rank = int(os.environ["LOCAL_RANK"])
|
| 206 |
-
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
| 207 |
-
assert self.local_rank < self.local_world_size
|
| 208 |
-
|
| 209 |
-
# Single node and GPU job (i.e. local script run)
|
| 210 |
-
def _set_from_local(self):
|
| 211 |
-
# logger.info("Initialization from local")
|
| 212 |
-
self.master_addr = "127.0.0.1"
|
| 213 |
-
self.master_port = _get_available_port()
|
| 214 |
-
self.rank = 0
|
| 215 |
-
self.world_size = 1
|
| 216 |
-
self.local_rank = 0
|
| 217 |
-
self.local_world_size = 1
|
| 218 |
-
|
| 219 |
-
def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
|
| 220 |
-
# See the "Environment variable initialization" section from
|
| 221 |
-
# https://pytorch.org/docs/stable/distributed.html for the complete list of
|
| 222 |
-
# environment variables required for the env:// initialization method.
|
| 223 |
-
env_vars = {
|
| 224 |
-
"MASTER_ADDR": self.master_addr,
|
| 225 |
-
"MASTER_PORT": str(self.master_port),
|
| 226 |
-
"RANK": str(self.rank),
|
| 227 |
-
"WORLD_SIZE": str(self.world_size),
|
| 228 |
-
"LOCAL_RANK": str(self.local_rank),
|
| 229 |
-
"LOCAL_WORLD_SIZE": str(self.local_world_size),
|
| 230 |
-
}
|
| 231 |
-
if not overwrite:
|
| 232 |
-
for k, v in env_vars.items():
|
| 233 |
-
_check_env_variable(k, v)
|
| 234 |
-
|
| 235 |
-
os.environ.update(env_vars)
|
| 236 |
-
return self
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
|
| 240 |
-
"""Enable distributed mode
|
| 241 |
-
|
| 242 |
-
Args:
|
| 243 |
-
set_cuda_current_device: If True, call torch.cuda.set_device() to set the
|
| 244 |
-
current PyTorch CUDA device to the one matching the local rank.
|
| 245 |
-
overwrite: If True, overwrites already set variables. Else fails.
|
| 246 |
-
"""
|
| 247 |
-
|
| 248 |
-
global _LOCAL_RANK, _LOCAL_WORLD_SIZE
|
| 249 |
-
if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
|
| 250 |
-
raise RuntimeError("Distributed mode has already been enabled")
|
| 251 |
-
torch_env = _TorchDistributedEnvironment()
|
| 252 |
-
torch_env.export(overwrite=overwrite)
|
| 253 |
-
|
| 254 |
-
if set_cuda_current_device:
|
| 255 |
-
torch.cuda.set_device(torch_env.local_rank)
|
| 256 |
-
|
| 257 |
-
if allow_nccl_timeout:
|
| 258 |
-
# This allows to use torch distributed timeout in a NCCL backend
|
| 259 |
-
key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
|
| 260 |
-
if not overwrite:
|
| 261 |
-
_check_env_variable(key, value)
|
| 262 |
-
os.environ[key] = value
|
| 263 |
-
|
| 264 |
-
dist.init_process_group(backend="nccl")
|
| 265 |
-
dist.barrier()
|
| 266 |
-
|
| 267 |
-
# Finalize setup
|
| 268 |
-
_LOCAL_RANK = torch_env.local_rank
|
| 269 |
-
_LOCAL_WORLD_SIZE = torch_env.local_world_size
|
| 270 |
-
_restrict_print_to_main_process()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/cell_dino/knn.py
DELETED
|
@@ -1,479 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import argparse
|
| 7 |
-
from functools import partial
|
| 8 |
-
import json
|
| 9 |
-
import logging
|
| 10 |
-
import os
|
| 11 |
-
import sys
|
| 12 |
-
from typing import List, Optional, Any
|
| 13 |
-
import numpy as np
|
| 14 |
-
|
| 15 |
-
import torch
|
| 16 |
-
import torch.backends.cudnn as cudnn
|
| 17 |
-
import pandas as pd
|
| 18 |
-
from sklearn.metrics import f1_score
|
| 19 |
-
|
| 20 |
-
import dinov2.distributed as distributed
|
| 21 |
-
from dinov2.data import make_dataset, DatasetWithEnumeratedTargets, SamplerType, make_data_loader
|
| 22 |
-
from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
|
| 23 |
-
from dinov2.eval.metrics import build_metric, MetricType
|
| 24 |
-
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
| 25 |
-
from dinov2.eval.setup import setup_and_build_model
|
| 26 |
-
|
| 27 |
-
from dinov2.data import ResultsAccumulator
|
| 28 |
-
from dinov2.eval.utils import ModelWithNormalize
|
| 29 |
-
from dinov2.eval.cell_dino.utils import (
|
| 30 |
-
BagOfChannelsModelWithNormalize,
|
| 31 |
-
extract_features_cell_dino,
|
| 32 |
-
average_metrics,
|
| 33 |
-
create_train_dataset_dict,
|
| 34 |
-
get_num_classes,
|
| 35 |
-
extract_features_for_dataset_dict,
|
| 36 |
-
evaluate_with_accumulate,
|
| 37 |
-
KnnModule,
|
| 38 |
-
)
|
| 39 |
-
from dinov2.eval.knn import DictKeysModule
|
| 40 |
-
from torch.utils.data import Subset as SubsetEx
|
| 41 |
-
from torch.utils.data import ConcatDataset as ConcatDatasetEx
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
logger = logging.getLogger("dinov2")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def get_args_parser(
|
| 48 |
-
description: Optional[str] = None,
|
| 49 |
-
parents: Optional[List[argparse.ArgumentParser]] = None,
|
| 50 |
-
add_help: bool = True,
|
| 51 |
-
):
|
| 52 |
-
parents = parents or []
|
| 53 |
-
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
|
| 54 |
-
parents = [setup_args_parser]
|
| 55 |
-
parser = argparse.ArgumentParser(
|
| 56 |
-
description=description,
|
| 57 |
-
parents=parents,
|
| 58 |
-
add_help=add_help,
|
| 59 |
-
)
|
| 60 |
-
parser.add_argument(
|
| 61 |
-
"--train-dataset",
|
| 62 |
-
dest="train_dataset_str",
|
| 63 |
-
type=str,
|
| 64 |
-
help="Training dataset",
|
| 65 |
-
)
|
| 66 |
-
parser.add_argument(
|
| 67 |
-
"--val-dataset",
|
| 68 |
-
dest="val_dataset_str",
|
| 69 |
-
type=str,
|
| 70 |
-
help="Validation dataset",
|
| 71 |
-
)
|
| 72 |
-
parser.add_argument(
|
| 73 |
-
"--nb_knn",
|
| 74 |
-
nargs="+",
|
| 75 |
-
type=int,
|
| 76 |
-
help="Number of NN to use. 20 is usually working the best.",
|
| 77 |
-
)
|
| 78 |
-
parser.add_argument(
|
| 79 |
-
"--temperature",
|
| 80 |
-
type=float,
|
| 81 |
-
help="Temperature used in the voting coefficient",
|
| 82 |
-
)
|
| 83 |
-
parser.add_argument(
|
| 84 |
-
"--gather-on-cpu",
|
| 85 |
-
action="store_true",
|
| 86 |
-
help="Whether to gather the train features on cpu, slower"
|
| 87 |
-
"but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
|
| 88 |
-
)
|
| 89 |
-
parser.add_argument(
|
| 90 |
-
"--batch-size",
|
| 91 |
-
type=int,
|
| 92 |
-
help="Batch size.",
|
| 93 |
-
)
|
| 94 |
-
parser.add_argument(
|
| 95 |
-
"--n-per-class-list",
|
| 96 |
-
nargs="+",
|
| 97 |
-
type=int,
|
| 98 |
-
help="Number to take per class",
|
| 99 |
-
)
|
| 100 |
-
parser.add_argument(
|
| 101 |
-
"--n-tries",
|
| 102 |
-
type=int,
|
| 103 |
-
help="Number of tries",
|
| 104 |
-
)
|
| 105 |
-
parser.add_argument(
|
| 106 |
-
"--leave-one-out-dataset",
|
| 107 |
-
type=str,
|
| 108 |
-
help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
|
| 109 |
-
)
|
| 110 |
-
parser.add_argument(
|
| 111 |
-
"--bag-of-channels",
|
| 112 |
-
action="store_true",
|
| 113 |
-
help='Whether to use the "bag of channels" channel adaptive strategy',
|
| 114 |
-
)
|
| 115 |
-
parser.add_argument(
|
| 116 |
-
"--crop-size",
|
| 117 |
-
type=int,
|
| 118 |
-
help="crop size for train and eval",
|
| 119 |
-
)
|
| 120 |
-
parser.add_argument(
|
| 121 |
-
"--resize-size",
|
| 122 |
-
type=int,
|
| 123 |
-
help="resize size for image just before crop. 0: no resize",
|
| 124 |
-
)
|
| 125 |
-
parser.add_argument(
|
| 126 |
-
"--metric-type",
|
| 127 |
-
type=MetricType,
|
| 128 |
-
choices=list(MetricType),
|
| 129 |
-
help="Validation metric",
|
| 130 |
-
)
|
| 131 |
-
parser.add_argument(
|
| 132 |
-
"--avgpool",
|
| 133 |
-
action="store_true",
|
| 134 |
-
help="Whether to use average pooling of path tokens in addition to CLS tokens",
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
parser.set_defaults(
|
| 138 |
-
train_dataset_str="ImageNet:split=TRAIN",
|
| 139 |
-
val_dataset_str="ImageNet:split=VAL",
|
| 140 |
-
nb_knn=[1],
|
| 141 |
-
temperature=0.07,
|
| 142 |
-
batch_size=256,
|
| 143 |
-
resize_size=0,
|
| 144 |
-
)
|
| 145 |
-
return parser
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
class SequentialWithKwargs(torch.nn.Sequential):
|
| 149 |
-
def __init__(self, *args):
|
| 150 |
-
super().__init__(*args)
|
| 151 |
-
|
| 152 |
-
def forward(self, input, **kwargs):
|
| 153 |
-
|
| 154 |
-
input = self[0](input, **kwargs)
|
| 155 |
-
for module in self[1:]:
|
| 156 |
-
input = module(input)
|
| 157 |
-
return input
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def create_train_test_dataset_dict_leave_one_out(
|
| 161 |
-
train_dataset,
|
| 162 |
-
test_dataset,
|
| 163 |
-
) -> dict[int, dict[int, Any]]:
|
| 164 |
-
"""
|
| 165 |
-
This function implements a train dataset dictionary with the leave-one-out (LOO) method.
|
| 166 |
-
Specifically, given a train dataset and test dataset, it creates a train dataset for each
|
| 167 |
-
test dataset point, which is a combination of train+test dataset except for this specific data point.
|
| 168 |
-
At the end, it contains len(test_dataset) key and value pairs.
|
| 169 |
-
|
| 170 |
-
Format is {"nth-test-sample": dataset_without_test_sample}
|
| 171 |
-
"""
|
| 172 |
-
train_dataset_dict: dict[int, Any] = {}
|
| 173 |
-
test_size = len(test_dataset)
|
| 174 |
-
|
| 175 |
-
for test_sample_index in range(test_size):
|
| 176 |
-
test_indices_bool = torch.ones(test_size, dtype=bool)
|
| 177 |
-
test_indices_bool[test_sample_index] = False
|
| 178 |
-
train_dataset_dict[test_sample_index] = ConcatDatasetEx(
|
| 179 |
-
[train_dataset, SubsetEx(test_dataset, test_indices_bool.nonzero().flatten())]
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
return train_dataset_dict
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def eval_knn_with_leave_one_out(
|
| 186 |
-
model, leave_one_out_dataset, train_dataset, test_dataset, metric_type, nb_knn, temperature, batch_size, num_workers
|
| 187 |
-
):
|
| 188 |
-
num_classes = get_num_classes(test_dataset)
|
| 189 |
-
train_dataset_dict = create_train_dataset_dict(train_dataset)
|
| 190 |
-
test_dataset_dict = create_train_dataset_dict(test_dataset)
|
| 191 |
-
|
| 192 |
-
logger.info("Extracting features for train set...")
|
| 193 |
-
train_data_dict = extract_features_for_dataset_dict(
|
| 194 |
-
model, train_dataset_dict, batch_size, num_workers, gather_on_cpu=True
|
| 195 |
-
)
|
| 196 |
-
test_data_dict = extract_features_for_dataset_dict(
|
| 197 |
-
model, test_dataset_dict, batch_size, num_workers, gather_on_cpu=True
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
train_features = train_data_dict[0]["train_features"]
|
| 201 |
-
train_labels = train_data_dict[0]["train_labels"]
|
| 202 |
-
test_features = test_data_dict[0]["train_features"]
|
| 203 |
-
test_labels = test_data_dict[0]["train_labels"]
|
| 204 |
-
|
| 205 |
-
metric_collection = build_metric(metric_type, num_classes=3)
|
| 206 |
-
|
| 207 |
-
device = torch.cuda.current_device()
|
| 208 |
-
partial_knn_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
|
| 209 |
-
|
| 210 |
-
logger.info("Reading the leave-one-out label metadata.")
|
| 211 |
-
|
| 212 |
-
leave_one_out_indices = {}
|
| 213 |
-
metadata = pd.read_csv(leave_one_out_dataset)
|
| 214 |
-
if "HPA" in leave_one_out_dataset:
|
| 215 |
-
metadata = metadata[metadata["Task_three"]].reset_index()
|
| 216 |
-
leave_one_out_label_type = "cell_type"
|
| 217 |
-
else:
|
| 218 |
-
metadata = metadata[metadata["Task_four"]].reset_index()
|
| 219 |
-
leave_one_out_label_type = "Plate"
|
| 220 |
-
leave_one_out_labels = metadata[leave_one_out_label_type].unique()
|
| 221 |
-
|
| 222 |
-
for leave_one_out_label in leave_one_out_labels:
|
| 223 |
-
leave_one_out_indices[leave_one_out_label] = torch.tensor(
|
| 224 |
-
metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
# ============ evaluation ... ============
|
| 228 |
-
logger.info("Start the k-NN classification.")
|
| 229 |
-
|
| 230 |
-
eval_metrics_dict = {}
|
| 231 |
-
postprocessors, metrics = {k: DictKeysModule([k]) for k in nb_knn}, {
|
| 232 |
-
k: metric_collection.clone().to(device) for k in nb_knn
|
| 233 |
-
}
|
| 234 |
-
for metric_key in metrics.keys():
|
| 235 |
-
metrics[metric_key] = metrics[metric_key].to(device)
|
| 236 |
-
|
| 237 |
-
accumulator_class = ResultsAccumulator
|
| 238 |
-
accumulators = {k: accumulator_class() for k in postprocessors.keys()}
|
| 239 |
-
all_preds = []
|
| 240 |
-
all_target = []
|
| 241 |
-
|
| 242 |
-
for loo_label, loo_indices in leave_one_out_indices.items():
|
| 243 |
-
logger.info(f"Evaluating on test sample {loo_label}")
|
| 244 |
-
loo_for_training_indices = torch.ones(test_features.shape[0], dtype=bool)
|
| 245 |
-
loo_for_training_indices[loo_indices] = False
|
| 246 |
-
train_features_sample = torch.cat([train_features, test_features[loo_for_training_indices]])
|
| 247 |
-
train_labels_sample = torch.cat([train_labels, test_labels[loo_for_training_indices]])
|
| 248 |
-
logger.info(f"Train shape {train_features_sample.shape}, Test shape {test_features[loo_indices].shape}")
|
| 249 |
-
logger.info(
|
| 250 |
-
f"Train values {train_labels_sample.unique(return_counts=True)}, Test shape {test_labels[loo_indices].unique(return_counts=True)}"
|
| 251 |
-
)
|
| 252 |
-
knn_module = partial_knn_module(
|
| 253 |
-
train_features=train_features_sample, train_labels=train_labels_sample, nb_knn=nb_knn
|
| 254 |
-
)
|
| 255 |
-
|
| 256 |
-
output = knn_module(test_features[loo_indices].to(device))
|
| 257 |
-
all_preds.append(output[1])
|
| 258 |
-
all_target.append(test_labels[loo_indices])
|
| 259 |
-
output[1] = output[1][:, 4:]
|
| 260 |
-
transformed_test_labels = test_labels[loo_indices] - 4
|
| 261 |
-
for k, metric in metrics.items():
|
| 262 |
-
metric_inputs = postprocessors[k](output, transformed_test_labels.to(device))
|
| 263 |
-
metric.update(**metric_inputs)
|
| 264 |
-
accumulators[k].update(
|
| 265 |
-
preds=metric_inputs["preds"], target=metric_inputs["target"], index=loo_indices.to(device)
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
all_preds = torch.cat(all_preds).cpu().detach().numpy()
|
| 269 |
-
|
| 270 |
-
all_preds = np.argmax(all_preds, axis=1)
|
| 271 |
-
all_target = torch.cat(all_target).cpu().detach().numpy()
|
| 272 |
-
|
| 273 |
-
f1 = f1_score(all_target, all_preds, average="macro", labels=[4, 5, 6])
|
| 274 |
-
logger.info(f"Real f1 score: {f1}")
|
| 275 |
-
eval_metrics = {
|
| 276 |
-
k: metric.compute() for k, metric in metrics.items()
|
| 277 |
-
} # next erased by the real f1 score computed above
|
| 278 |
-
|
| 279 |
-
for k in nb_knn:
|
| 280 |
-
if k not in eval_metrics_dict:
|
| 281 |
-
eval_metrics_dict[k] = {}
|
| 282 |
-
eval_metrics_dict[k] = {metric: f1 * 100.0 for metric, v in eval_metrics[k].items()}
|
| 283 |
-
|
| 284 |
-
if len(train_data_dict) > 1:
|
| 285 |
-
return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}
|
| 286 |
-
|
| 287 |
-
return {k: eval_metrics_dict[k] for k in eval_metrics_dict.keys()}
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
def eval_knn_with_model(
|
| 291 |
-
model,
|
| 292 |
-
output_dir,
|
| 293 |
-
train_dataset_str,
|
| 294 |
-
val_dataset_str,
|
| 295 |
-
nb_knn=(10, 20, 100, 200),
|
| 296 |
-
temperature=0.07,
|
| 297 |
-
autocast_dtype=torch.float,
|
| 298 |
-
metric_type=MetricType.MEAN_ACCURACY,
|
| 299 |
-
transform=None,
|
| 300 |
-
resize_size=256,
|
| 301 |
-
crop_size=224,
|
| 302 |
-
batch_size=256,
|
| 303 |
-
num_workers=5,
|
| 304 |
-
leave_one_out_dataset="",
|
| 305 |
-
bag_of_channels=False,
|
| 306 |
-
avgpool=False,
|
| 307 |
-
):
|
| 308 |
-
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
|
| 309 |
-
if bag_of_channels:
|
| 310 |
-
model = BagOfChannelsModelWithNormalize(model, autocast_ctx, avgpool)
|
| 311 |
-
else:
|
| 312 |
-
model = ModelWithNormalize(model)
|
| 313 |
-
if leave_one_out_dataset == "" or leave_one_out_dataset is None:
|
| 314 |
-
leave_one_out = False
|
| 315 |
-
else:
|
| 316 |
-
leave_one_out = True
|
| 317 |
-
|
| 318 |
-
cudnn.benchmark = True
|
| 319 |
-
transform = make_classification_eval_cell_transform(
|
| 320 |
-
normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, resize_size=resize_size, crop_size=crop_size
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform)
|
| 324 |
-
results_dict = {}
|
| 325 |
-
test_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform)
|
| 326 |
-
|
| 327 |
-
with torch.cuda.amp.autocast(dtype=autocast_dtype):
|
| 328 |
-
if leave_one_out:
|
| 329 |
-
results_dict_knn = eval_knn_with_leave_one_out(
|
| 330 |
-
model=model,
|
| 331 |
-
leave_one_out_dataset=leave_one_out_dataset,
|
| 332 |
-
train_dataset=train_dataset,
|
| 333 |
-
test_dataset=test_dataset,
|
| 334 |
-
metric_type=metric_type,
|
| 335 |
-
nb_knn=nb_knn,
|
| 336 |
-
temperature=temperature,
|
| 337 |
-
batch_size=batch_size,
|
| 338 |
-
num_workers=num_workers,
|
| 339 |
-
)
|
| 340 |
-
else:
|
| 341 |
-
results_dict_knn = eval_knn(
|
| 342 |
-
model=model,
|
| 343 |
-
train_dataset=train_dataset,
|
| 344 |
-
test_dataset=test_dataset,
|
| 345 |
-
metric_type=metric_type,
|
| 346 |
-
nb_knn=nb_knn,
|
| 347 |
-
temperature=temperature,
|
| 348 |
-
batch_size=batch_size,
|
| 349 |
-
num_workers=num_workers,
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
-
for knn_ in results_dict_knn.keys():
|
| 353 |
-
top1 = results_dict_knn[knn_]["top-1"]
|
| 354 |
-
results_dict[f"{val_dataset_str}_{knn_} Top 1"] = top1
|
| 355 |
-
results_string = f"{val_dataset_str} {knn_} NN classifier result: Top1: {top1:.2f}"
|
| 356 |
-
if "top-5" in results_dict_knn[knn_]:
|
| 357 |
-
top5 = results_dict_knn[knn_]["top-5"]
|
| 358 |
-
results_dict[f"{val_dataset_str}_{knn_} Top 5"] = top5
|
| 359 |
-
results_string += f"Top5: {top5:.2f}"
|
| 360 |
-
logger.info(results_string)
|
| 361 |
-
|
| 362 |
-
metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
|
| 363 |
-
with open(metrics_file_path, "a") as f:
|
| 364 |
-
for k, v in results_dict.items():
|
| 365 |
-
f.write(json.dumps({k: v}) + "\n")
|
| 366 |
-
|
| 367 |
-
if distributed.is_enabled():
|
| 368 |
-
torch.distributed.barrier()
|
| 369 |
-
return results_dict
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
def eval_knn(
|
| 373 |
-
model,
|
| 374 |
-
train_dataset,
|
| 375 |
-
test_dataset,
|
| 376 |
-
metric_type,
|
| 377 |
-
nb_knn,
|
| 378 |
-
temperature,
|
| 379 |
-
batch_size,
|
| 380 |
-
num_workers,
|
| 381 |
-
few_shot_eval=False,
|
| 382 |
-
few_shot_k_or_percent=None,
|
| 383 |
-
few_shot_n_tries=1,
|
| 384 |
-
):
|
| 385 |
-
num_classes = get_num_classes(train_dataset)
|
| 386 |
-
train_dataset_dict = create_train_dataset_dict(
|
| 387 |
-
train_dataset,
|
| 388 |
-
few_shot_eval=few_shot_eval,
|
| 389 |
-
few_shot_k_or_percent=few_shot_k_or_percent,
|
| 390 |
-
few_shot_n_tries=few_shot_n_tries,
|
| 391 |
-
)
|
| 392 |
-
|
| 393 |
-
logger.info("Extracting features for train set...")
|
| 394 |
-
|
| 395 |
-
train_data_dict: dict[int, dict[str, torch.Tensor]] = {}
|
| 396 |
-
for try_n, dataset in train_dataset_dict.items():
|
| 397 |
-
features, labels = extract_features_cell_dino(model, dataset, batch_size, num_workers, gather_on_cpu=True)
|
| 398 |
-
train_data_dict[try_n] = {"train_features": features, "train_labels": labels}
|
| 399 |
-
|
| 400 |
-
test_data_loader = make_data_loader(
|
| 401 |
-
dataset=DatasetWithEnumeratedTargets(
|
| 402 |
-
test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
|
| 403 |
-
),
|
| 404 |
-
batch_size=batch_size,
|
| 405 |
-
num_workers=num_workers,
|
| 406 |
-
sampler_type=SamplerType.DISTRIBUTED,
|
| 407 |
-
drop_last=False,
|
| 408 |
-
shuffle=False,
|
| 409 |
-
persistent_workers=True,
|
| 410 |
-
collate_fn=None,
|
| 411 |
-
)
|
| 412 |
-
metric_collection = build_metric(metric_type, num_classes=num_classes)
|
| 413 |
-
|
| 414 |
-
device = torch.cuda.current_device()
|
| 415 |
-
partial_knn_module = partial(
|
| 416 |
-
KnnModule,
|
| 417 |
-
T=temperature,
|
| 418 |
-
device=device,
|
| 419 |
-
num_classes=num_classes,
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
# ============ evaluation ... ============
|
| 423 |
-
logger.info("Start the k-NN classification.")
|
| 424 |
-
eval_metrics_dict = {}
|
| 425 |
-
|
| 426 |
-
for try_ in train_data_dict.keys():
|
| 427 |
-
train_features, train_labels = train_data_dict[try_]["train_features"], train_data_dict[try_]["train_labels"]
|
| 428 |
-
k_list = sorted(set([el if el < len(train_features) else len(train_features) for el in nb_knn]))
|
| 429 |
-
knn_module = partial_knn_module(train_features=train_features, train_labels=train_labels, nb_knn=k_list)
|
| 430 |
-
postprocessors, metrics = {k: DictKeysModule([k]) for k in k_list}, {
|
| 431 |
-
k: metric_collection.clone() for k in k_list
|
| 432 |
-
}
|
| 433 |
-
_, eval_metrics, _ = evaluate_with_accumulate(
|
| 434 |
-
SequentialWithKwargs(model, knn_module),
|
| 435 |
-
test_data_loader,
|
| 436 |
-
postprocessors,
|
| 437 |
-
metrics,
|
| 438 |
-
device,
|
| 439 |
-
accumulate_results=False,
|
| 440 |
-
)
|
| 441 |
-
for k in k_list:
|
| 442 |
-
if k not in eval_metrics_dict:
|
| 443 |
-
eval_metrics_dict[k] = {}
|
| 444 |
-
eval_metrics_dict[k][try_] = {metric: v.item() * 100.0 for metric, v in eval_metrics[k].items()}
|
| 445 |
-
|
| 446 |
-
if len(train_data_dict) > 1:
|
| 447 |
-
return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}
|
| 448 |
-
|
| 449 |
-
return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()}
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
def main(args):
|
| 453 |
-
model, autocast_dtype = setup_and_build_model(args)
|
| 454 |
-
eval_knn_with_model(
|
| 455 |
-
model=model,
|
| 456 |
-
output_dir=args.output_dir,
|
| 457 |
-
train_dataset_str=args.train_dataset_str,
|
| 458 |
-
val_dataset_str=args.val_dataset_str,
|
| 459 |
-
nb_knn=args.nb_knn,
|
| 460 |
-
temperature=args.temperature,
|
| 461 |
-
autocast_dtype=autocast_dtype,
|
| 462 |
-
transform=None,
|
| 463 |
-
metric_type=args.metric_type,
|
| 464 |
-
batch_size=args.batch_size,
|
| 465 |
-
num_workers=5,
|
| 466 |
-
leave_one_out_dataset=args.leave_one_out_dataset,
|
| 467 |
-
resize_size=args.resize_size,
|
| 468 |
-
crop_size=args.crop_size,
|
| 469 |
-
avgpool=args.avgpool,
|
| 470 |
-
bag_of_channels=args.bag_of_channels,
|
| 471 |
-
)
|
| 472 |
-
return 0
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
if __name__ == "__main__":
|
| 476 |
-
description = "k-NN evaluation on models trained with bag of channel strategy or cell dino"
|
| 477 |
-
args_parser = get_args_parser(description=description)
|
| 478 |
-
args = args_parser.parse_args()
|
| 479 |
-
sys.exit(main(args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/cell_dino/linear.py
DELETED
|
@@ -1,1048 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import argparse
|
| 7 |
-
from functools import partial
|
| 8 |
-
import json
|
| 9 |
-
import logging
|
| 10 |
-
import os
|
| 11 |
-
import sys
|
| 12 |
-
from typing import Any, Callable, Dict, Optional, Tuple, List
|
| 13 |
-
from enum import Enum
|
| 14 |
-
from dataclasses import dataclass
|
| 15 |
-
|
| 16 |
-
from sklearn.metrics import f1_score
|
| 17 |
-
import numpy as np
|
| 18 |
-
import pandas as pd
|
| 19 |
-
import torch
|
| 20 |
-
import torch.nn as nn
|
| 21 |
-
from torch.utils.data import TensorDataset
|
| 22 |
-
from torch.nn.parallel import DistributedDataParallel
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
from dinov2.data import SamplerType, make_data_loader, make_dataset, DatasetWithEnumeratedTargets
|
| 26 |
-
from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
|
| 27 |
-
import dinov2.distributed as distributed
|
| 28 |
-
from dinov2.eval.metrics import MetricType, build_metric
|
| 29 |
-
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
| 30 |
-
from dinov2.eval.setup import setup_and_build_model
|
| 31 |
-
from dinov2.eval.cell_dino.utils import (
|
| 32 |
-
evaluate_with_accumulate,
|
| 33 |
-
LossType,
|
| 34 |
-
average_metrics,
|
| 35 |
-
create_train_dataset_dict,
|
| 36 |
-
get_num_classes,
|
| 37 |
-
extract_features_for_dataset_dict,
|
| 38 |
-
)
|
| 39 |
-
from dinov2.eval.utils import ModelWithIntermediateLayers
|
| 40 |
-
from dinov2.logging import MetricLogger
|
| 41 |
-
from dinov2.utils.checkpoint import build_periodic_checkpointer, resume_or_load
|
| 42 |
-
|
| 43 |
-
logger = logging.getLogger("dinov2")
|
| 44 |
-
|
| 45 |
-
"""
|
| 46 |
-
List of changes with respect to the standard linear evaluation script:
|
| 47 |
-
|
| 48 |
-
bag of channel option : SCALE ADAPTIVE STRATEGY
|
| 49 |
-
|
| 50 |
-
Adam optimizer instead of SGD
|
| 51 |
-
Scheduler : two options : onecycleLR or CosineAnnealingLR
|
| 52 |
-
the transforms/normalization are different, now calling make_classification_eval_cell_transform
|
| 53 |
-
add binary cross entropy loss option for protein localization
|
| 54 |
-
change the definition of the num_classes using get_num_classes
|
| 55 |
-
change of some default parameters (batch_size, epoch_length, epochs, lrs)
|
| 56 |
-
defined n_last_blocks option
|
| 57 |
-
avgpool option
|
| 58 |
-
leave one out strategy for CHAMMI evaluation
|
| 59 |
-
grid search for optimal weight decay
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def get_args_parser(
|
| 64 |
-
description: Optional[str] = None,
|
| 65 |
-
parents: Optional[List[argparse.ArgumentParser]] = None,
|
| 66 |
-
add_help: bool = True,
|
| 67 |
-
):
|
| 68 |
-
parents = parents or []
|
| 69 |
-
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
|
| 70 |
-
parents = [setup_args_parser]
|
| 71 |
-
parser = argparse.ArgumentParser(
|
| 72 |
-
description=description,
|
| 73 |
-
parents=parents,
|
| 74 |
-
add_help=add_help,
|
| 75 |
-
)
|
| 76 |
-
parser.add_argument(
|
| 77 |
-
"--train-dataset",
|
| 78 |
-
dest="train_dataset_str",
|
| 79 |
-
type=str,
|
| 80 |
-
help="Training dataset",
|
| 81 |
-
)
|
| 82 |
-
parser.add_argument(
|
| 83 |
-
"--val-dataset",
|
| 84 |
-
dest="val_dataset_str",
|
| 85 |
-
type=str,
|
| 86 |
-
help="Validation dataset",
|
| 87 |
-
)
|
| 88 |
-
parser.add_argument(
|
| 89 |
-
"--test-datasets",
|
| 90 |
-
dest="test_dataset_strs",
|
| 91 |
-
type=str,
|
| 92 |
-
nargs="+",
|
| 93 |
-
help="Test datasets, none to reuse the validation dataset",
|
| 94 |
-
)
|
| 95 |
-
parser.add_argument(
|
| 96 |
-
"--epochs",
|
| 97 |
-
type=int,
|
| 98 |
-
help="Number of training epochs",
|
| 99 |
-
)
|
| 100 |
-
parser.add_argument(
|
| 101 |
-
"--batch-size",
|
| 102 |
-
type=int,
|
| 103 |
-
help="Batch Size (per GPU)",
|
| 104 |
-
)
|
| 105 |
-
parser.add_argument(
|
| 106 |
-
"--num-workers",
|
| 107 |
-
type=int,
|
| 108 |
-
help="Number de Workers",
|
| 109 |
-
)
|
| 110 |
-
parser.add_argument(
|
| 111 |
-
"--epoch-length",
|
| 112 |
-
type=int,
|
| 113 |
-
help="Length of an epoch in number of iterations",
|
| 114 |
-
)
|
| 115 |
-
parser.add_argument(
|
| 116 |
-
"--save-checkpoint-frequency",
|
| 117 |
-
type=int,
|
| 118 |
-
help="Number of epochs between two named checkpoint saves.",
|
| 119 |
-
)
|
| 120 |
-
parser.add_argument(
|
| 121 |
-
"--eval-period-iterations",
|
| 122 |
-
type=int,
|
| 123 |
-
help="Number of iterations between two evaluations.",
|
| 124 |
-
)
|
| 125 |
-
parser.add_argument(
|
| 126 |
-
"--learning-rates",
|
| 127 |
-
nargs="+",
|
| 128 |
-
type=float,
|
| 129 |
-
help="Learning rates to grid search.",
|
| 130 |
-
)
|
| 131 |
-
parser.add_argument(
|
| 132 |
-
"--weight_decays",
|
| 133 |
-
nargs="+",
|
| 134 |
-
type=float,
|
| 135 |
-
help="Weight decays to grid search.",
|
| 136 |
-
)
|
| 137 |
-
parser.add_argument(
|
| 138 |
-
"--n-last-blocks",
|
| 139 |
-
type=int,
|
| 140 |
-
help="number of backbone last blocks used for the linear classifier",
|
| 141 |
-
)
|
| 142 |
-
parser.add_argument(
|
| 143 |
-
"--no-resume",
|
| 144 |
-
action="store_true",
|
| 145 |
-
help="Whether to not resume from existing checkpoints",
|
| 146 |
-
)
|
| 147 |
-
parser.add_argument(
|
| 148 |
-
"--val-metric-type",
|
| 149 |
-
type=MetricType,
|
| 150 |
-
choices=list(MetricType),
|
| 151 |
-
help="Validation metric",
|
| 152 |
-
)
|
| 153 |
-
parser.add_argument(
|
| 154 |
-
"--test-metric-types",
|
| 155 |
-
type=MetricType,
|
| 156 |
-
choices=list(MetricType),
|
| 157 |
-
nargs="+",
|
| 158 |
-
help="Evaluation metric",
|
| 159 |
-
)
|
| 160 |
-
parser.add_argument(
|
| 161 |
-
"--classifier-fpath",
|
| 162 |
-
type=str,
|
| 163 |
-
help="Path to a file containing pretrained linear classifiers",
|
| 164 |
-
)
|
| 165 |
-
parser.add_argument(
|
| 166 |
-
"--val-class-mapping-fpath",
|
| 167 |
-
type=str,
|
| 168 |
-
help="Path to a file containing a mapping to adjust classifier outputs",
|
| 169 |
-
)
|
| 170 |
-
parser.add_argument(
|
| 171 |
-
"--test-class-mapping-fpaths",
|
| 172 |
-
nargs="+",
|
| 173 |
-
type=str,
|
| 174 |
-
help="Path to a file containing a mapping to adjust classifier outputs",
|
| 175 |
-
)
|
| 176 |
-
parser.add_argument(
|
| 177 |
-
"--loss-type",
|
| 178 |
-
type=LossType,
|
| 179 |
-
help="Cross Entropy or Binary Cross Entropy, default cross entropy loss",
|
| 180 |
-
)
|
| 181 |
-
parser.add_argument(
|
| 182 |
-
"--bag-of-channels",
|
| 183 |
-
action="store_true",
|
| 184 |
-
help='Whether to use the "bag of channels" channel adaptive strategy',
|
| 185 |
-
)
|
| 186 |
-
parser.add_argument(
|
| 187 |
-
"--leave-one-out-dataset",
|
| 188 |
-
type=str,
|
| 189 |
-
help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
|
| 190 |
-
)
|
| 191 |
-
parser.add_argument(
|
| 192 |
-
"--crop-size",
|
| 193 |
-
type=int,
|
| 194 |
-
help="crop size for train and eval",
|
| 195 |
-
)
|
| 196 |
-
parser.add_argument(
|
| 197 |
-
"--resize-size",
|
| 198 |
-
type=int,
|
| 199 |
-
help="resize size for image just before crop. 0: no resize",
|
| 200 |
-
)
|
| 201 |
-
parser.add_argument(
|
| 202 |
-
"--avgpool",
|
| 203 |
-
action="store_true",
|
| 204 |
-
help="Whether to use average pooling of path tokens in addition to CLS tokens",
|
| 205 |
-
)
|
| 206 |
-
parser.add_argument(
|
| 207 |
-
"--scheduler",
|
| 208 |
-
type=SchedulerType,
|
| 209 |
-
help="Scheduler type",
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
parser.set_defaults(
|
| 213 |
-
train_dataset_str="ImageNet:split=TRAIN",
|
| 214 |
-
val_dataset_str="ImageNet:split=VAL",
|
| 215 |
-
test_dataset_strs=None,
|
| 216 |
-
epochs=30,
|
| 217 |
-
batch_size=64,
|
| 218 |
-
num_workers=8,
|
| 219 |
-
epoch_length=145,
|
| 220 |
-
save_checkpoint_frequency=1250,
|
| 221 |
-
eval_period_iterations=1250,
|
| 222 |
-
learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, 1.0],
|
| 223 |
-
weight_decays=[0.0, 0.0001, 1.0e-05],
|
| 224 |
-
val_metric_type=MetricType.MEAN_ACCURACY,
|
| 225 |
-
test_metric_types=None,
|
| 226 |
-
classifier_fpath=None,
|
| 227 |
-
val_class_mapping_fpath=None,
|
| 228 |
-
test_class_mapping_fpaths=[None],
|
| 229 |
-
loss_type=LossType.CROSS_ENTROPY,
|
| 230 |
-
crop_size=384,
|
| 231 |
-
resize_size=0,
|
| 232 |
-
n_last_blocks=4,
|
| 233 |
-
avgpool=False,
|
| 234 |
-
scheduler=SchedulerType.COSINE_ANNEALING,
|
| 235 |
-
)
|
| 236 |
-
return parser
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def has_ddp_wrapper(m: nn.Module) -> bool:
|
| 240 |
-
return isinstance(m, DistributedDataParallel)
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
|
| 244 |
-
return m.module if has_ddp_wrapper(m) else m
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool, bag_of_channels):
|
| 248 |
-
intermediate_output = x_tokens_list[-use_n_blocks:]
|
| 249 |
-
output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1)
|
| 250 |
-
if bag_of_channels:
|
| 251 |
-
if use_avgpool:
|
| 252 |
-
output = torch.cat(
|
| 253 |
-
(
|
| 254 |
-
output,
|
| 255 |
-
torch.mean(intermediate_output[-1][0], dim=-2).reshape(intermediate_output[-1][0].shape[0], -1),
|
| 256 |
-
# average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
|
| 257 |
-
),
|
| 258 |
-
dim=-1,
|
| 259 |
-
) # concatenate average pooling of patch tokens to concatenated patch tokens
|
| 260 |
-
else:
|
| 261 |
-
if use_avgpool:
|
| 262 |
-
output = torch.cat(
|
| 263 |
-
(
|
| 264 |
-
output,
|
| 265 |
-
torch.mean(intermediate_output[-1][0], dim=1), # patch tokens
|
| 266 |
-
),
|
| 267 |
-
dim=-1,
|
| 268 |
-
)
|
| 269 |
-
output = output.reshape(output.shape[0], -1)
|
| 270 |
-
return output.float()
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
class LinearClassifier(nn.Module):
|
| 274 |
-
"""Linear layer to train on top of frozen features"""
|
| 275 |
-
|
| 276 |
-
def __init__(
|
| 277 |
-
self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, bag_of_channels=False, leave_one_out=False
|
| 278 |
-
):
|
| 279 |
-
super().__init__()
|
| 280 |
-
self.out_dim = out_dim
|
| 281 |
-
self.use_n_blocks = use_n_blocks
|
| 282 |
-
self.use_avgpool = use_avgpool
|
| 283 |
-
self.num_classes = num_classes
|
| 284 |
-
self.bag_of_channels = bag_of_channels
|
| 285 |
-
self.leave_one_out = leave_one_out
|
| 286 |
-
self.linear = nn.Linear(out_dim, num_classes)
|
| 287 |
-
self.linear.weight.data.normal_(mean=0.0, std=0.01)
|
| 288 |
-
self.linear.bias.data.zero_()
|
| 289 |
-
|
| 290 |
-
def forward(self, x_tokens_list):
|
| 291 |
-
if self.leave_one_out:
|
| 292 |
-
return self.linear(x_tokens_list)
|
| 293 |
-
output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool, self.bag_of_channels)
|
| 294 |
-
return self.linear(output)
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
class AllClassifiers(nn.Module):
    """Container module that fans one input out to every registered classifier."""

    def __init__(self, classifiers_dict):
        super().__init__()
        self.classifiers_dict = nn.ModuleDict()
        self.classifiers_dict.update(classifiers_dict)

    def forward(self, inputs):
        # One prediction tensor per classifier name.
        outputs = {}
        for name, classifier in self.classifiers_dict.items():
            outputs[name] = classifier.forward(inputs)
        return outputs

    def __len__(self):
        return len(self.classifiers_dict)
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
class LinearPostprocessor(nn.Module):
    """Pairs classifier predictions with targets, optionally remapping class columns."""

    def __init__(self, linear_classifier, class_mapping=None):
        super().__init__()
        self.linear_classifier = linear_classifier
        # Registered as a buffer so it follows the module across devices.
        mapping_tensor = None if class_mapping is None else torch.LongTensor(class_mapping)
        self.register_buffer("class_mapping", mapping_tensor)

    def forward(self, samples, targets):
        preds = self.linear_classifier(samples)
        if self.class_mapping is not None:
            # Restrict/reorder prediction columns to the requested class subset.
            preds = preds[:, self.class_mapping]
        return {
            "preds": preds,
            "target": targets,
        }
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def scale_lr(learning_rates, batch_size):
    """Linearly scale learning rates with the global batch size (reference batch: 256)."""
    global_batch_size = batch_size * distributed.get_global_size()
    return learning_rates * global_batch_size / 256.0
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
def setup_linear_classifiers(
    sample_output,
    n_last_blocks_list,
    learning_rates,
    weight_decays,
    batch_size,
    num_classes=1000,
    bag_of_channels=False,
    leave_one_out=False,
    avgpool=False,
):
    """Build one LinearClassifier per (n_blocks, lr, wd) combination.

    Each classifier gets its own optimizer parameter group carrying its scaled
    learning rate and weight decay. Returns (AllClassifiers, optim_param_groups);
    the container is DDP-wrapped when distributed training is enabled.
    """
    # NOTE: the original code iterated `for avgpool in [avgpool_value]`, a
    # degenerate single-element loop that shadowed the parameter; the fixed
    # version uses `avgpool` directly with identical behavior.
    linear_classifiers_dict = nn.ModuleDict()
    optim_param_groups = []
    for n in n_last_blocks_list:
        for _lr in learning_rates:
            for wd in weight_decays:
                lr = scale_lr(_lr, batch_size)
                # Probe the feature tensor once to find the classifier input width.
                out_dim = create_linear_input(
                    sample_output, use_n_blocks=n, use_avgpool=avgpool, bag_of_channels=bag_of_channels
                ).shape[1]
                linear_classifier = LinearClassifier(
                    out_dim,
                    use_n_blocks=n,
                    use_avgpool=avgpool,
                    num_classes=num_classes,
                    bag_of_channels=bag_of_channels,
                    leave_one_out=leave_one_out,
                )
                linear_classifier = linear_classifier.cuda()
                linear_classifiers_dict[
                    f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}_wd_{wd:.2E}".replace(".", "_")
                ] = linear_classifier
                optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr, "weight_decay": wd})

    linear_classifiers = AllClassifiers(linear_classifiers_dict)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
def make_eval_data_loader(
    *,
    test_dataset_str_or_path_or_loo_dataset,
    config,
    batch_size,
    num_workers,
):
    """Build a distributed, padded, non-shuffled evaluation data loader.

    Accepts either a dataset spec string (loaded with the eval cell transform)
    or an already-built feature dataset (leave-one-out evaluation).
    Returns (data_loader, class_mapping); class_mapping is only set when the
    dataset exposes `get_imagenet_class_mapping`.
    """
    collate_fn = None
    if isinstance(test_dataset_str_or_path_or_loo_dataset, str):
        logger.info(f"Loading dataset {test_dataset_str_or_path_or_loo_dataset}")
        transform = make_classification_eval_cell_transform(
            normalization_type=NormalizationType.SELF_NORM_CENTER_CROP,
            resize_size=config["resize_size"],
            crop_size=config["crop_size"],
        )
        test_dataset = make_dataset(dataset_str=test_dataset_str_or_path_or_loo_dataset, transform=transform)
    else:
        logger.info("Making data loader for feature dataset (typical in leave one out evaluation)")
        test_dataset = test_dataset_str_or_path_or_loo_dataset

    class_mapping = None
    if hasattr(test_dataset, "get_imagenet_class_mapping"):
        class_mapping = test_dataset.get_imagenet_class_mapping()

    # Pad so every replica sees the same number of samples.
    padded_dataset = DatasetWithEnumeratedTargets(
        test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
    )
    test_data_loader = make_data_loader(
        dataset=padded_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=False,
        collate_fn=collate_fn,
    )
    return test_data_loader, class_mapping
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
@dataclass
class Evaluator:
    """Evaluates all linear classifiers on one dataset and tracks the best one.

    Wraps a single evaluation dataset (spec string or, for leave-one-out, a
    pre-extracted TensorDataset) together with its data loader, metric, and a
    metrics log file.
    """

    batch_size: int
    num_workers: int
    # Dataset spec string; replaced by `val_dataset_loo` in __post_init__ when set.
    dataset_str_or_path: str
    config: Dict
    metric_type: MetricType
    metrics_file_path: str
    training_num_classes: int
    # Optional callback to persist accumulated predictions/targets.
    save_results_func: Optional[Callable]
    val_dataset_loo: Optional[TensorDataset] = None

    def __post_init__(self):
        # NOTE: the metric name is derived from the ORIGINAL dataset string,
        # deliberately before the leave-one-out dataset may replace it below.
        self.main_metric_name = f"{self.dataset_str_or_path}_accuracy"

        if self.val_dataset_loo is not None:
            self.dataset_str_or_path = self.val_dataset_loo

        self.data_loader, self.class_mapping = make_eval_data_loader(
            test_dataset_str_or_path_or_loo_dataset=self.dataset_str_or_path,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            config=self.config,
        )

    @torch.no_grad()
    def _evaluate_linear_classifiers(
        self,
        *,
        feature_model,
        linear_classifiers,
        iteration,
        prefixstring="",
        best_classifier_on_val=None,
        accumulate_results=False,
        test_mode=False,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]:
        """Run all classifiers over the data loader; return metrics and (optionally)
        accumulated predictions for the best classifier.

        When `best_classifier_on_val` is given, that classifier is selected
        regardless of its accuracy here; otherwise the top-1 maximizer wins.
        """
        logger.info("running validation !")

        num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes
        metric = build_metric(self.metric_type, num_classes=num_classes)
        postprocessors = {
            k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items()
        }
        metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}

        _, results_dict_temp, accumulated_results = evaluate_with_accumulate(
            feature_model,
            self.data_loader,
            postprocessors,
            metrics,
            torch.cuda.current_device(),
            accumulate_results=accumulate_results,
            leave_one_out=self.config["leave_one_out"],
            test_mode=test_mode,
        )

        logger.info("")
        results_dict = {}
        max_accuracy = 0
        best_classifier = ""
        for _, (classifier_string, metric) in enumerate(results_dict_temp.items()):
            logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
            if (
                best_classifier_on_val is None and metric["top-1"].item() > max_accuracy
            ) or classifier_string == best_classifier_on_val:
                max_accuracy = metric["top-1"].item()
                best_classifier = classifier_string

        results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}

        logger.info(f"best classifier: {results_dict['best_classifier']}")

        accumulated_best_results = None
        if test_mode:
            # In test mode the caller gets ALL classifiers' accumulated results.
            accumulated_best_results = accumulated_results
        elif accumulated_results is not None:
            accumulated_best_results = accumulated_results[best_classifier]

        # Only rank 0 appends to the shared metrics file.
        if distributed.is_main_process():
            with open(self.metrics_file_path, "a") as f:
                f.write(f"iter: {iteration}\n")
                for k, v in results_dict.items():
                    f.write(json.dumps({k: v}) + "\n")
                f.write("\n")

        return results_dict, accumulated_best_results

    def evaluate_and_maybe_save(
        self,
        feature_model,
        linear_classifiers,
        iteration: int,
        best_classifier_on_val: Optional[Any] = None,
        save_filename_suffix: str = "",
        prefixstring: str = "",
        test_mode: bool = False,
    ):
        """Evaluate, optionally persist accumulated results, and return a
        summary dict {main_metric_name: accuracy %, "best_classifier": name}."""
        logger.info(f"Testing on {self.dataset_str_or_path}")
        save_results = self.save_results_func is not None
        full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers(
            feature_model=feature_model,
            linear_classifiers=remove_ddp_wrapper(linear_classifiers),
            iteration=iteration,
            prefixstring=prefixstring,
            best_classifier_on_val=best_classifier_on_val,
            accumulate_results=save_results,
            test_mode=test_mode,
        )
        if self.save_results_func is not None:
            self.save_results_func(
                filename_suffix=f"{self.dataset_str_or_path}{save_filename_suffix}", **accumulated_best_results
            )

        results_dict = {
            # Accuracy reported as a percentage.
            self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"],
            "best_classifier": full_results_dict["best_classifier"]["name"],
        }
        return results_dict, accumulated_best_results
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
def make_evaluators(
    config: Dict,
    val_metric_type: MetricType,
    val_dataset: str,
    metric_type: MetricType,
    metrics_file_path: str,
    training_num_classes: int,
    save_results_func: Optional[Callable],
    val_dataset_loo: Optional[TensorDataset] = None,
):
    """Build the validation Evaluator plus one Evaluator per test dataset.

    Returns (val_evaluator, test_evaluators). When `config["test_datasets"]`
    is None, the validation dataset doubles as the only test dataset.
    """
    test_metric_types = config["test_metric_types"]
    test_dataset_strs = config["test_datasets"]
    if test_dataset_strs is None:
        # Fall back to evaluating on the validation dataset itself.
        test_dataset_strs = (config["val_dataset"],)
    if test_metric_types is None:
        test_metric_types = (val_metric_type,)
    else:
        # NOTE(review): this reads config["test_datasets"], which may be None if
        # metric types were given without test datasets — confirm callers never
        # hit that combination.
        assert len(test_metric_types) == len(config["test_datasets"])

    # NOTE(review): the comprehension's `metric_type` loop variable shadows the
    # `metric_type` parameter — each Evaluator receives its per-dataset metric
    # type, and the parameter is effectively unused here.
    val_evaluator, *test_evaluators = [
        Evaluator(
            dataset_str_or_path=dataset_str_or_path,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            config=config,
            metric_type=metric_type,
            metrics_file_path=metrics_file_path,
            training_num_classes=training_num_classes,
            save_results_func=save_results_func,
            val_dataset_loo=val_dataset_loo,
        )
        for dataset_str_or_path, metric_type in zip(
            (val_dataset,) + tuple(test_dataset_strs),
            (val_metric_type,) + tuple(test_metric_types),
        )
    ]
    return val_evaluator, test_evaluators
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
class SchedulerType(Enum):
    """Learning-rate schedule flavors for linear probing."""

    COSINE_ANNEALING = "cosine_annealing"
    ONE_CYCLE = "one_cycle"

    def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter):
        """Instantiate the torch LR scheduler matching this enum member."""
        if self is SchedulerType.ONE_CYCLE:
            # One peak learning rate per optimizer parameter group.
            max_lrs = [group["lr"] for group in optim_param_groups]
            return torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=max_lrs, steps_per_epoch=epoch_length, epochs=epochs
            )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
        print("CosineAnnealingLR scheduler")
        return cosine
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
def setup_linear_training(
    *,
    config: Dict,
    sample_output: torch.Tensor,
    training_num_classes: int,
    checkpoint_output_dir: str,
):
    """Assemble everything needed to train the linear heads.

    Builds the classifier grid, AdamW optimizer, LR scheduler, periodic
    checkpointer, and loss; resumes from a checkpoint when available.
    Returns (linear_classifiers, start_iter, max_iter, criterion, optimizer,
    scheduler, periodic_checkpointer, best_accuracy).
    """
    linear_classifiers, optim_param_groups = setup_linear_classifiers(
        sample_output,
        config["n_last_blocks_list"],
        config["learning_rates"],
        config["weight_decays"],
        config["batch_size"],
        training_num_classes,
        config["bag_of_channels"],
        config["leave_one_out"],
        config["avgpool"],
    )
    max_iter = config["epochs"] * config["epoch_length"]
    # Weight decay is carried per parameter group; the optimizer default stays 0.
    optimizer = torch.optim.AdamW(optim_param_groups, weight_decay=0)

    scheduler = config["scheduler"].get_scheduler(
        optimizer=optimizer,
        optim_param_groups=optim_param_groups,
        epoch_length=config["epoch_length"],
        epochs=config["epochs"],
        max_iter=max_iter,
    )
    # Default to checkpointing once per epoch when no explicit period is set.
    checkpoint_period = config["save_checkpoint_iterations"] or config["epoch_length"]
    periodic_checkpointer = build_periodic_checkpointer(
        linear_classifiers,
        checkpoint_output_dir,
        optimizer=optimizer,
        scheduler=scheduler,
        period=checkpoint_period,
        max_iter=max_iter,
        max_to_keep=None,
    )
    checkpoint = resume_or_load(periodic_checkpointer, config["classifier_fpath"] or "", resume=config["resume"])

    # A fresh run yields iteration -1 in the checkpoint dict, i.e. start at 0.
    start_iter = checkpoint.get("iteration", -1) + 1
    best_accuracy = checkpoint.get("best_accuracy", -1)

    if config["loss_type"] == LossType.BINARY_CROSS_ENTROPY:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    return (
        linear_classifiers,
        start_iter,
        max_iter,
        criterion,
        optimizer,
        scheduler,
        periodic_checkpointer,
        best_accuracy,
    )
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
def train_linear_classifiers(
    *,
    feature_model,
    train_dataset,
    train_config: Dict,
    training_num_classes: int,
    val_evaluator: Evaluator,
    checkpoint_output_dir: str,
    sample_output: Optional[torch.Tensor] = None,
):
    """Jointly train the whole grid of linear classifiers on frozen features.

    Runs an infinite-sampler loop for `epochs * epoch_length` iterations,
    periodically checkpointing and evaluating on `val_evaluator`, keeping the
    best checkpoint by validation accuracy.
    Returns (feature_model, linear_classifiers, iteration, periodic_checkpointer).
    """

    if train_config["leave_one_out"]:
        assert sample_output is not None, "sample_output should be passed as argument when using leave_one_out."
    else:
        # Probe the backbone once to size the classifier heads.
        sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())

    (
        linear_classifiers,
        start_iter,
        max_iter,
        criterion,
        optimizer,
        scheduler,
        periodic_checkpointer,
        best_accuracy,
    ) = setup_linear_training(
        config=train_config,
        sample_output=sample_output,
        training_num_classes=training_num_classes,
        checkpoint_output_dir=checkpoint_output_dir,
    )

    sampler_type = SamplerType.INFINITE
    # sampler_advance skips samples already consumed before a resume.
    train_data_loader = make_data_loader(
        dataset=train_dataset,
        batch_size=train_config["batch_size"],
        num_workers=train_config["num_workers"],
        shuffle=True,
        seed=0,
        sampler_type=sampler_type,
        sampler_advance=start_iter,
        drop_last=True,
        persistent_workers=True,
    )
    eval_period = train_config["eval_period_iterations"] or train_config["epoch_length"]
    iteration = start_iter
    logger.info("Starting training from iteration {}".format(start_iter))
    metric_logger = MetricLogger(delimiter=" ")
    header = "Training"

    for data, labels in metric_logger.log_every(
        train_data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        if not train_config["leave_one_out"]:
            in_classifier = feature_model(data)
        else:
            # Leave-one-out inputs are already extracted features.
            in_classifier = data

        outputs = linear_classifiers(in_classifier)

        if len(labels.shape) > 1:
            # Multi-label targets must be float for BCEWithLogitsLoss.
            labels = labels.float()
        # Each classifier's loss is summed; gradients stay independent because
        # the classifiers share no parameters.
        losses = {f"loss_{k}": criterion(v, labels) for k, v in outputs.items()}
        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        scheduler.step()

        if iteration % 10 == 0:
            # Synchronize so the logged loss reflects completed GPU work.
            torch.cuda.synchronize()
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        periodic_checkpointer.step(iteration=iteration, best_accuracy=best_accuracy)

        # Evaluate periodically, but skip the final iteration (evaluated by the caller).
        if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
            val_results_dict, _ = val_evaluator.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                prefixstring=f"ITER: {iteration}",
                iteration=iteration,
            )
            val_accuracy = val_results_dict[val_evaluator.main_metric_name]
            if val_accuracy >= best_accuracy:
                best_accuracy = val_accuracy
                periodic_checkpointer.save_best(iteration=iteration, best_accuracy=best_accuracy)
            torch.distributed.barrier()

        iteration = iteration + 1

    return feature_model, linear_classifiers, iteration, periodic_checkpointer
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
def eval_linear_with_model(
    model,
    output_dir,
    train_dataset_str,
    val_dataset_str,
    batch_size,
    epochs,
    epoch_length,
    num_workers,
    save_checkpoint_frequency,
    eval_period_iterations,
    learning_rates,
    weight_decays,
    autocast_dtype,
    test_dataset_strs=None,
    resume=True,
    classifier_fpath=None,
    val_metric_type=MetricType.MEAN_ACCURACY,
    test_metric_types=None,
    loss_type=LossType.CROSS_ENTROPY,
    bag_of_channels=False,
    leave_one_out_dataset="",
    resize_size=0,
    crop_size=384,
    n_last_blocks=4,
    avgpool=False,
    scheduler=SchedulerType.COSINE_ANNEALING,
):
    """Run the full linear-probing pipeline on a frozen backbone.

    Two modes:
      * standard: train/validate/test linear classifiers per training split;
      * leave-one-out (enabled by `leave_one_out_dataset` CSV): features are
        pre-extracted, one classifier run per held-out label group, scored at
        the end with a macro F1 over classes [4, 5, 6].
    Returns the final results dict.
    """

    if leave_one_out_dataset == "" or leave_one_out_dataset is None:
        leave_one_out = False
    else:
        logger.info("Reading the leave-one-out label metadata.")

        leave_one_out_indices = {}
        metadata = pd.read_csv(leave_one_out_dataset)
        # NOTE(review): the CSV is assumed to carry boolean "Task_three"/"Task_four"
        # columns and either a "cell_type" (HPA) or "Plate" column — confirm schema.
        if "HPA" in leave_one_out_dataset:
            metadata = metadata[metadata["Task_three"]].reset_index()
            leave_one_out_label_type = "cell_type"
        else:
            metadata = metadata[metadata["Task_four"]].reset_index()
            leave_one_out_label_type = "Plate"
        leave_one_out_labels = metadata[leave_one_out_label_type].unique()

        # Map each held-out label to the row indices belonging to it.
        for leave_one_out_label in leave_one_out_labels:
            leave_one_out_indices[leave_one_out_label] = np.array(
                metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values
            )

        leave_one_out = True

    train_transform = make_classification_eval_cell_transform(
        normalization_type=NormalizationType.SELF_NORM_AUG_DECODER, crop_size=crop_size, resize_size=resize_size
    )
    print("train_transform", train_transform)
    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=train_transform,
    )

    training_num_classes = get_num_classes(train_dataset)
    if leave_one_out:
        # Extra output classes are needed for the leave-one-out labels.
        training_num_classes += train_dataset.num_additional_labels_loo_eval
    train_dataset_dict = create_train_dataset_dict(train_dataset)
    n_last_blocks_list = [n_last_blocks]
    n_last_blocks = max(n_last_blocks_list)
    dataset_use_cache = True
    autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
    feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)

    if bag_of_channels:
        # NOTE(review): sample_output is only defined on this path but is later
        # used unconditionally in the leave-one-out branch — confirm that
        # leave_one_out implies bag_of_channels in all callers.
        sample = train_dataset[0][0].unsqueeze(0)
        sample_output = feature_model(sample.cuda())

    if leave_one_out:
        # Pre-extract features once; training then operates on TensorDatasets.
        loo_dict = {}
        train_data_dict = extract_features_for_dataset_dict(
            feature_model,
            train_dataset_dict,
            batch_size,
            num_workers,
            gather_on_cpu=True,
            avgpool=avgpool,
        )
        val_dataset = make_dataset(
            dataset_str=val_dataset_str,
            transform=make_classification_eval_cell_transform(
                normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, crop_size=crop_size, resize_size=resize_size
            ),
        )
        val_dataset_dict = create_train_dataset_dict(val_dataset)
        val_data_dict = extract_features_for_dataset_dict(
            feature_model,
            val_dataset_dict,
            batch_size,
            num_workers,
            gather_on_cpu=True,
            avgpool=avgpool,
        )

        train_features = train_data_dict[0]["train_features"]
        train_labels = train_data_dict[0]["train_labels"]
        val_features = val_data_dict[0]["train_features"]
        val_labels = val_data_dict[0]["train_labels"]

        # For each held-out label: train on train + (val minus held-out rows),
        # validate on the held-out rows only.
        for loo_label, loo_indices in leave_one_out_indices.items():
            loo_for_training_indices = torch.ones(val_features.shape[0], dtype=bool)
            loo_for_training_indices[loo_indices] = False
            loo_for_val_indices = torch.zeros(val_features.shape[0], dtype=bool)
            loo_for_val_indices[loo_indices] = True

            loo_dict[loo_label] = {
                "train_features": torch.cat([train_features, val_features[loo_for_training_indices]]),
                "train_labels": torch.cat([train_labels, val_labels[loo_for_training_indices]]),
                "val_features": val_features[loo_indices],
                "val_labels": val_labels[loo_indices],
            }
    save_results_func = None
    # if config.save_results:
    #     save_results_func = partial(default_save_results_func, output_dir=output_dir)

    metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
    periodic_checkpointers: list = []

    train_config = {
        "learning_rates": learning_rates,
        "weight_decays": weight_decays,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "dataset_use_cache": dataset_use_cache,
        "eval_period_iterations": eval_period_iterations,
        "epoch_length": epoch_length,
        "leave_one_out": leave_one_out,
        "bag_of_channels": bag_of_channels,
        "n_last_blocks_list": n_last_blocks_list,
        "epochs": epochs,
        "loss_type": loss_type,
        "resume": resume,
        "save_checkpoint_iterations": save_checkpoint_frequency,
        "classifier_fpath": classifier_fpath,
        "avgpool": avgpool,
        "scheduler": scheduler,
    }
    config = {
        "test_metric_types": test_metric_types,
        "test_datasets": test_dataset_strs,
        "val_metric_types": val_metric_type,
        "val_dataset": val_dataset_str,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "leave_one_out": leave_one_out,
        "crop_size": crop_size,
        "resize_size": resize_size,
    }
    if not leave_one_out:
        val_evaluator, test_evaluators = make_evaluators(
            config=config,
            val_metric_type=val_metric_type,
            val_dataset=val_dataset_str,
            metric_type=test_metric_types,
            metrics_file_path=metrics_file_path,
            training_num_classes=training_num_classes,
            save_results_func=save_results_func,
        )
        results_dict = {}

        # One training run per split ("try"); results averaged if several.
        for _try in train_dataset_dict.keys():
            if len(train_dataset_dict) > 1:
                checkpoint_output_dir = os.path.join(output_dir, f"checkpoints_{_try}")
                save_filename_suffix = f"_{_try}"
            else:
                checkpoint_output_dir, save_filename_suffix = output_dir, ""
            os.makedirs(checkpoint_output_dir, exist_ok=True)

            feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
                train_config=train_config,
                feature_model=feature_model,
                train_dataset=train_dataset_dict[_try],
                training_num_classes=training_num_classes,
                val_evaluator=val_evaluator,
                checkpoint_output_dir=checkpoint_output_dir,
            )
            periodic_checkpointers.append(periodic_checkpointer)
            results_dict[_try], _ = val_evaluator.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                iteration=iteration,
                save_filename_suffix=save_filename_suffix,
            )
            # Test sets reuse the classifier that was best on validation.
            for test_evaluator in test_evaluators:
                eval_results_dict, _ = test_evaluator.evaluate_and_maybe_save(
                    feature_model=feature_model,
                    linear_classifiers=linear_classifiers,
                    iteration=iteration,
                    best_classifier_on_val=results_dict[_try]["best_classifier"],
                    save_filename_suffix=save_filename_suffix,
                )
                results_dict[_try] = {**eval_results_dict, **results_dict[_try]}
        if len(train_dataset_dict) > 1:
            results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"])
        else:
            results_dict = {**results_dict[_try]}
    else:  # if leave one out is True
        test_results_dict = {}
        for loo_label in loo_dict.keys():

            checkpoint_output_dir, save_filename_suffix = os.path.join(output_dir, f"checkpoints_{loo_label}"), ""
            os.makedirs(checkpoint_output_dir, exist_ok=True)

            train_dataset_loo = TensorDataset(
                loo_dict[loo_label]["train_features"], loo_dict[loo_label]["train_labels"]
            )

            logger.info(f"Creating leave_one_out evaluators. loo_label: {loo_label}")
            val_dataset_loo = TensorDataset(loo_dict[loo_label]["val_features"], loo_dict[loo_label]["val_labels"])
            val_evaluators_loo, _ = make_evaluators(
                config=config,
                val_metric_type=val_metric_type,
                val_dataset="loo",
                metric_type=test_metric_types,
                metrics_file_path=metrics_file_path,
                training_num_classes=training_num_classes,
                save_results_func=save_results_func,
                val_dataset_loo=val_dataset_loo,
            )
            feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
                feature_model=feature_model,
                train_dataset=train_dataset_loo,
                train_config=train_config,
                training_num_classes=training_num_classes,
                val_evaluator=val_evaluators_loo,
                checkpoint_output_dir=checkpoint_output_dir,
                sample_output=sample_output,
            )
            periodic_checkpointers.append(periodic_checkpointer)
            # test_mode=True: keep accumulated (preds, targets) for every classifier.
            _, test_results_dict[loo_label] = val_evaluators_loo.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                iteration=iteration,
                save_filename_suffix=save_filename_suffix,
                test_mode=True,
            )
        # Concatenate per-label predictions/targets, then score each classifier
        # with a macro F1 restricted to classes [4, 5, 6].
        classifier_names = test_results_dict[loo_label].keys()
        results_dict = {k: [[], []] for k in classifier_names}
        for ll in test_results_dict.keys():
            for k in classifier_names:
                results_dict[k][0].append(test_results_dict[ll][k][0])
                results_dict[k][1].append(test_results_dict[ll][k][1])
        for k in classifier_names:
            results_dict[k] = [
                np.argmax(torch.cat(results_dict[k][0]).cpu().detach().numpy(), axis=1),
                torch.cat(results_dict[k][1]).cpu().detach().numpy(),
            ]
            results_dict[k] = f1_score(results_dict[k][1], results_dict[k][0], average="macro", labels=[4, 5, 6])
        logger.info(
            f"Best performance is for {max(results_dict, key=results_dict.get)}, with F1-Score of {results_dict[max(results_dict, key=results_dict.get)]}"
        )

    logger.info("Test Results Dict " + str(results_dict))
    return results_dict
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
def main(args):
|
| 1012 |
-
model, autocast_dtype = setup_and_build_model(args)
|
| 1013 |
-
eval_linear_with_model(
|
| 1014 |
-
model=model,
|
| 1015 |
-
output_dir=args.output_dir,
|
| 1016 |
-
train_dataset_str=args.train_dataset_str,
|
| 1017 |
-
val_dataset_str=args.val_dataset_str,
|
| 1018 |
-
test_dataset_strs=args.test_dataset_strs,
|
| 1019 |
-
batch_size=args.batch_size,
|
| 1020 |
-
epochs=args.epochs,
|
| 1021 |
-
epoch_length=args.epoch_length,
|
| 1022 |
-
num_workers=args.num_workers,
|
| 1023 |
-
save_checkpoint_frequency=args.save_checkpoint_frequency,
|
| 1024 |
-
eval_period_iterations=args.eval_period_iterations,
|
| 1025 |
-
learning_rates=args.learning_rates,
|
| 1026 |
-
weight_decays=args.weight_decays,
|
| 1027 |
-
autocast_dtype=autocast_dtype,
|
| 1028 |
-
resume=not args.no_resume,
|
| 1029 |
-
classifier_fpath=args.classifier_fpath,
|
| 1030 |
-
val_metric_type=args.val_metric_type,
|
| 1031 |
-
test_metric_types=args.test_metric_types,
|
| 1032 |
-
loss_type=args.loss_type,
|
| 1033 |
-
bag_of_channels=args.bag_of_channels,
|
| 1034 |
-
leave_one_out_dataset=args.leave_one_out_dataset,
|
| 1035 |
-
crop_size=args.crop_size,
|
| 1036 |
-
resize_size=args.resize_size,
|
| 1037 |
-
n_last_blocks=args.n_last_blocks,
|
| 1038 |
-
avgpool=args.avgpool,
|
| 1039 |
-
scheduler=args.scheduler,
|
| 1040 |
-
)
|
| 1041 |
-
return 0
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
if __name__ == "__main__":
|
| 1045 |
-
description = "DINOv2 linear_cell_dino evaluation"
|
| 1046 |
-
args_parser = get_args_parser(description=description)
|
| 1047 |
-
args = args_parser.parse_args()
|
| 1048 |
-
sys.exit(main(args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/cell_dino/utils.py
DELETED
|
@@ -1,542 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the CC-by-NC licence,
|
| 4 |
-
# found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import logging
|
| 7 |
-
from typing import Callable, Dict, Optional, Any, List
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch import nn
|
| 11 |
-
from torchmetrics import MetricCollection
|
| 12 |
-
|
| 13 |
-
from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
|
| 14 |
-
from dinov2.data import NoOpAccumulator, ResultsAccumulator
|
| 15 |
-
import dinov2.distributed as distributed
|
| 16 |
-
from dinov2.logging import MetricLogger
|
| 17 |
-
from enum import Enum
|
| 18 |
-
from torch.utils.data import Subset
|
| 19 |
-
from torchvision.datasets.vision import StandardTransform
|
| 20 |
-
import numpy as np
|
| 21 |
-
from torch.nn.functional import one_hot, softmax
|
| 22 |
-
|
| 23 |
-
logger = logging.getLogger("dinov2")
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class LossType(Enum):
|
| 27 |
-
CROSS_ENTROPY = "cross_entropy"
|
| 28 |
-
BINARY_CROSS_ENTROPY = "binary_cross_entropy"
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class BagOfChannelsModelWithNormalize(nn.Module):
|
| 32 |
-
def __init__(self, model, autocast_ctx, avgpool, n_last_blocks=1):
|
| 33 |
-
super().__init__()
|
| 34 |
-
self.model = model
|
| 35 |
-
self.autocast_ctx = autocast_ctx
|
| 36 |
-
self.n_last_blocks = n_last_blocks
|
| 37 |
-
self.avgpool = avgpool
|
| 38 |
-
|
| 39 |
-
def forward(self, samples):
|
| 40 |
-
with self.autocast_ctx():
|
| 41 |
-
features = self.model.get_intermediate_layers(samples, self.n_last_blocks, return_class_token=True)
|
| 42 |
-
output = create_linear_input(features, self.avgpool, use_n_blocks=self.n_last_blocks)
|
| 43 |
-
return nn.functional.normalize(output, dim=1, p=2)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@torch.inference_mode()
|
| 47 |
-
def evaluate_with_accumulate(
|
| 48 |
-
model: nn.Module,
|
| 49 |
-
data_loader,
|
| 50 |
-
postprocessors: Dict[str, nn.Module],
|
| 51 |
-
metrics: Dict[str, MetricCollection],
|
| 52 |
-
device: torch.device,
|
| 53 |
-
criterion: Optional[nn.Module] = None,
|
| 54 |
-
test_mode: bool = False,
|
| 55 |
-
accumulate_results: bool = False,
|
| 56 |
-
leave_one_out: bool = False,
|
| 57 |
-
):
|
| 58 |
-
model.eval()
|
| 59 |
-
|
| 60 |
-
if test_mode:
|
| 61 |
-
output_tensor = {k: [] for k in postprocessors.keys()}
|
| 62 |
-
target_tensor = {k: [] for k in postprocessors.keys()}
|
| 63 |
-
|
| 64 |
-
if criterion is not None:
|
| 65 |
-
criterion.eval()
|
| 66 |
-
|
| 67 |
-
accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator
|
| 68 |
-
accumulators = {k: accumulator_class() for k in postprocessors.keys()}
|
| 69 |
-
|
| 70 |
-
for metric in metrics.values():
|
| 71 |
-
metric = metric.to(device)
|
| 72 |
-
|
| 73 |
-
metric_logger = MetricLogger(delimiter=" ")
|
| 74 |
-
header = "Test:"
|
| 75 |
-
|
| 76 |
-
for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
|
| 77 |
-
if isinstance(targets, list):
|
| 78 |
-
index = targets[0]
|
| 79 |
-
targets = targets[1]
|
| 80 |
-
samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0]
|
| 81 |
-
if len(index) == 0:
|
| 82 |
-
continue
|
| 83 |
-
|
| 84 |
-
outputs = samples.to(device) if leave_one_out else model(samples.to(device))
|
| 85 |
-
targets = targets.to(device)
|
| 86 |
-
|
| 87 |
-
if criterion is not None:
|
| 88 |
-
loss = criterion(outputs, targets)
|
| 89 |
-
metric_logger.update(loss=loss.item())
|
| 90 |
-
|
| 91 |
-
for k, metric in metrics.items():
|
| 92 |
-
metric_inputs = postprocessors[k](outputs, targets)
|
| 93 |
-
metric.update(**metric_inputs)
|
| 94 |
-
if test_mode:
|
| 95 |
-
output_tensor[k].append(metric_inputs["preds"])
|
| 96 |
-
target_tensor[k].append(metric_inputs["target"])
|
| 97 |
-
accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index)
|
| 98 |
-
|
| 99 |
-
metric_logger.synchronize_between_processes()
|
| 100 |
-
logger.info(f"Averaged stats: {metric_logger}")
|
| 101 |
-
|
| 102 |
-
stats = {k: metric.compute() for k, metric in metrics.items()}
|
| 103 |
-
metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 104 |
-
|
| 105 |
-
# accumulator.accumulate() returns None for the NoOpAccumulator
|
| 106 |
-
accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()}
|
| 107 |
-
if test_mode:
|
| 108 |
-
for k in postprocessors.keys():
|
| 109 |
-
output_tensor[k] = torch.cat(output_tensor[k])
|
| 110 |
-
target_tensor[k] = torch.cat(target_tensor[k])
|
| 111 |
-
accumulated_results = {k: [output_tensor[k], target_tensor[k]] for k in postprocessors.keys()}
|
| 112 |
-
|
| 113 |
-
if accumulate_results:
|
| 114 |
-
return metric_logger_stats, stats
|
| 115 |
-
return metric_logger_stats, stats, accumulated_results
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def all_gather_and_flatten(tensor_rank):
|
| 119 |
-
tensor_all_ranks = torch.empty(
|
| 120 |
-
distributed.get_global_size(),
|
| 121 |
-
*tensor_rank.shape,
|
| 122 |
-
dtype=tensor_rank.dtype,
|
| 123 |
-
device=tensor_rank.device,
|
| 124 |
-
)
|
| 125 |
-
tensor_list = list(tensor_all_ranks.unbind(0))
|
| 126 |
-
torch.distributed.all_gather(tensor_list, tensor_rank.contiguous())
|
| 127 |
-
return tensor_all_ranks.flatten(end_dim=1)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def extract_features_cell_dino(
|
| 131 |
-
model, dataset, batch_size, num_workers, gather_on_cpu=False, shuffle=False, avgpool=False
|
| 132 |
-
):
|
| 133 |
-
dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
|
| 134 |
-
sample_count = len(dataset_with_enumerated_targets)
|
| 135 |
-
data_loader = make_data_loader(
|
| 136 |
-
dataset=dataset_with_enumerated_targets,
|
| 137 |
-
batch_size=batch_size,
|
| 138 |
-
num_workers=num_workers,
|
| 139 |
-
sampler_type=SamplerType.DISTRIBUTED,
|
| 140 |
-
drop_last=False,
|
| 141 |
-
shuffle=shuffle,
|
| 142 |
-
)
|
| 143 |
-
return extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu, avgpool=avgpool)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
@torch.inference_mode()
|
| 147 |
-
def extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu=False, avgpool=False):
|
| 148 |
-
gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
|
| 149 |
-
metric_logger = MetricLogger(delimiter=" ")
|
| 150 |
-
features, all_labels = None, None
|
| 151 |
-
for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
|
| 152 |
-
samples = samples.cuda(non_blocking=True)
|
| 153 |
-
labels_rank = labels_rank.cuda(non_blocking=True)
|
| 154 |
-
index = index.cuda(non_blocking=True)
|
| 155 |
-
feat = model(samples)
|
| 156 |
-
if isinstance(samples, list) or isinstance(feat, tuple):
|
| 157 |
-
features_rank = create_linear_input(feat, avgpool=avgpool)
|
| 158 |
-
else:
|
| 159 |
-
features_rank = feat
|
| 160 |
-
|
| 161 |
-
# init storage feature matrix
|
| 162 |
-
if features is None:
|
| 163 |
-
features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
|
| 164 |
-
labels_shape = list(labels_rank.shape)
|
| 165 |
-
labels_shape[0] = sample_count
|
| 166 |
-
all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
|
| 167 |
-
logger.info(f"Storing features into tensor of shape {features.shape}")
|
| 168 |
-
|
| 169 |
-
# share indexes, features and labels between processes
|
| 170 |
-
index_all = all_gather_and_flatten(index).to(gather_device)
|
| 171 |
-
features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
|
| 172 |
-
labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)
|
| 173 |
-
|
| 174 |
-
# update storage feature matrix
|
| 175 |
-
if len(index_all) > 0:
|
| 176 |
-
features.index_copy_(0, index_all, features_all_ranks)
|
| 177 |
-
all_labels.index_copy_(0, index_all, labels_all_ranks)
|
| 178 |
-
|
| 179 |
-
logger.info(f"Features shape: {tuple(features.shape)}")
|
| 180 |
-
logger.info(f"Labels shape: {tuple(all_labels.shape)}")
|
| 181 |
-
|
| 182 |
-
assert torch.all(all_labels > -1)
|
| 183 |
-
|
| 184 |
-
return features, all_labels
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def create_linear_input(x_tokens_list, avgpool=False, use_n_blocks=1):
|
| 188 |
-
intermediate_output = x_tokens_list[-use_n_blocks:]
|
| 189 |
-
output = torch.cat(
|
| 190 |
-
[class_token for _, class_token in intermediate_output], dim=-1
|
| 191 |
-
) # concatenate class tokens of the last n blocks
|
| 192 |
-
if avgpool:
|
| 193 |
-
output = torch.cat(
|
| 194 |
-
(
|
| 195 |
-
output,
|
| 196 |
-
torch.mean(intermediate_output[-1][0], dim=-2).reshape(
|
| 197 |
-
intermediate_output[-1][0].shape[0], -1
|
| 198 |
-
), # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
|
| 199 |
-
),
|
| 200 |
-
dim=-1,
|
| 201 |
-
) # concatenate average pooling of patch tokens to concatenated patch tokens
|
| 202 |
-
output = output.reshape(output.shape[0], -1)
|
| 203 |
-
|
| 204 |
-
return output.float()
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def get_target_transform(dataset) -> Optional[Callable]:
|
| 208 |
-
if hasattr(dataset, "transforms"):
|
| 209 |
-
if isinstance(dataset.transforms, StandardTransform):
|
| 210 |
-
return dataset.transforms.target_transform
|
| 211 |
-
raise ValueError("Dataset has a non-standard .transforms property")
|
| 212 |
-
if hasattr(dataset, "target_transform"):
|
| 213 |
-
return dataset.target_transform
|
| 214 |
-
return None
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def get_labels(dataset) -> torch.Tensor:
|
| 218 |
-
"""
|
| 219 |
-
Get the labels of a classification dataset, as a Tensor, using the `get_targets` method
|
| 220 |
-
if it is present or loading the labels one by one with `get_target`, if it exists.
|
| 221 |
-
If the dataset has a target transform, iterate over the whole dataset to get the
|
| 222 |
-
transformed labels for each element, then stack them as a torch tensor.
|
| 223 |
-
"""
|
| 224 |
-
logger.info("Getting dataset labels ...")
|
| 225 |
-
if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"):
|
| 226 |
-
if hasattr(dataset, "get_targets"): # Returns a np.array
|
| 227 |
-
labels = dataset.get_targets()
|
| 228 |
-
elif hasattr(dataset, "get_target"):
|
| 229 |
-
labels = [dataset.get_target(i) for i in range(len(dataset))]
|
| 230 |
-
target_transform = get_target_transform(dataset)
|
| 231 |
-
if target_transform is not None:
|
| 232 |
-
labels = [target_transform(label) for label in labels]
|
| 233 |
-
else:
|
| 234 |
-
# Target transform is applied in this case
|
| 235 |
-
labels = [dataset[i][1] for i in range(len(dataset))]
|
| 236 |
-
return torch.stack([torch.tensor(label, dtype=int) for label in labels])
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def get_num_classes(dataset) -> int:
|
| 240 |
-
"""
|
| 241 |
-
Get the labels of a dataset and compute the number of classes
|
| 242 |
-
"""
|
| 243 |
-
labels = get_labels(dataset)
|
| 244 |
-
if len(labels.shape) > 1:
|
| 245 |
-
return int(labels.shape[1])
|
| 246 |
-
return int(labels.max() + 1)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []):
|
| 250 |
-
"""
|
| 251 |
-
Function that computes the average and the std on a metrics dict.
|
| 252 |
-
A linear evaluation dictionary contains "best_classifier",
|
| 253 |
-
so this specific key is removed for computing aggregated metrics.
|
| 254 |
-
"""
|
| 255 |
-
output_metrics_dict = {}
|
| 256 |
-
metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys]
|
| 257 |
-
for metric in metrics:
|
| 258 |
-
stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()])
|
| 259 |
-
output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item()
|
| 260 |
-
output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item()
|
| 261 |
-
|
| 262 |
-
return output_metrics_dict
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]:
|
| 266 |
-
"""
|
| 267 |
-
Efficiently creates a mapping between the labels and tensors containing
|
| 268 |
-
the indices of all the dataset elements that share this label.
|
| 269 |
-
In the case of multiple labels, it is not guaranteed that there
|
| 270 |
-
will be exactly the specified percentage of labels.
|
| 271 |
-
"""
|
| 272 |
-
if len(labels.shape) > 1: # labels are a one-hot encoding
|
| 273 |
-
assert len(labels.shape) == 2
|
| 274 |
-
sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True)
|
| 275 |
-
else:
|
| 276 |
-
sorted_labels, indices = torch.sort(labels, stable=True)
|
| 277 |
-
unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True)
|
| 278 |
-
mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist())))
|
| 279 |
-
return mapping
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0):
|
| 283 |
-
"""
|
| 284 |
-
Shuffling a dataset by subsetting it with a random permutation of its indices
|
| 285 |
-
"""
|
| 286 |
-
random_generator = torch.Generator()
|
| 287 |
-
random_generator.manual_seed(seed)
|
| 288 |
-
random_indices = torch.randperm(len(dataset), generator=random_generator)
|
| 289 |
-
return Subset(dataset, random_indices)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
def _subset_dataset_per_class(
|
| 293 |
-
class_indices_mapping: dict[int, torch.Tensor],
|
| 294 |
-
n_or_percent_per_class: float,
|
| 295 |
-
dataset_size: int,
|
| 296 |
-
seed: int = 0,
|
| 297 |
-
is_percent: bool = False,
|
| 298 |
-
) -> torch.Tensor:
|
| 299 |
-
"""
|
| 300 |
-
Helper function to select a percentage of a dataset, equally distributed across classes,
|
| 301 |
-
or to take the same number of elements from each class of the dataset.
|
| 302 |
-
Returns a boolean mask tensor being True at indices of selected elements
|
| 303 |
-
"""
|
| 304 |
-
|
| 305 |
-
random_generator = torch.Generator()
|
| 306 |
-
random_generator.manual_seed(seed)
|
| 307 |
-
|
| 308 |
-
final_indices_bool = torch.zeros(dataset_size, dtype=bool)
|
| 309 |
-
for class_indices in class_indices_mapping.values():
|
| 310 |
-
# Select at least one element
|
| 311 |
-
n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class
|
| 312 |
-
assert isinstance(n_for_class, int)
|
| 313 |
-
filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class]
|
| 314 |
-
final_indices_bool[class_indices[filtered_index]] = True
|
| 315 |
-
return final_indices_bool
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
def _multilabel_rebalance_subset(
|
| 319 |
-
class_indices_mapping: dict[int, torch.Tensor],
|
| 320 |
-
n_or_percent_per_class: float,
|
| 321 |
-
labels: torch.Tensor,
|
| 322 |
-
indices_bool: torch.Tensor,
|
| 323 |
-
dataset_size: int,
|
| 324 |
-
seed: int = 0,
|
| 325 |
-
) -> torch.Tensor:
|
| 326 |
-
"""
|
| 327 |
-
Helper function to refine a subset of a multi-label dataset (indices_bool)
|
| 328 |
-
to better match a target percentage of labels.
|
| 329 |
-
Returns a boolean mask tensor being True at indices of selected elements.
|
| 330 |
-
"""
|
| 331 |
-
|
| 332 |
-
# Compute the number of selected labels in indices_bool
|
| 333 |
-
num_total_labels = labels.sum()
|
| 334 |
-
num_wanted_labels = int(num_total_labels * n_or_percent_per_class)
|
| 335 |
-
num_selected_labels = (labels[indices_bool] > 0).sum()
|
| 336 |
-
logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
|
| 337 |
-
|
| 338 |
-
# Compute a new percentage and new set selecting less images, therefore less labels, to match approximatelly the exact percentage of labels selected
|
| 339 |
-
n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels)
|
| 340 |
-
final_indices_bool = _subset_dataset_per_class(
|
| 341 |
-
class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
# Compute the number of labels finally used
|
| 345 |
-
num_selected_labels = (labels[final_indices_bool] > 0).sum()
|
| 346 |
-
logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
|
| 347 |
-
|
| 348 |
-
return final_indices_bool
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True):
|
| 352 |
-
"""
|
| 353 |
-
Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class.
|
| 354 |
-
If `shuffle` is False, taking the first elements of each class as the validaton set.
|
| 355 |
-
"""
|
| 356 |
-
assert 0 < split_percentage < 1
|
| 357 |
-
logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set")
|
| 358 |
-
if shuffle_train:
|
| 359 |
-
logger.info("Shuffling train dataset before splitting in train and validation sets")
|
| 360 |
-
train_dataset = _shuffle_dataset(train_dataset)
|
| 361 |
-
train_labels = get_labels(train_dataset)
|
| 362 |
-
class_indices_mapping = create_class_indices_mapping(train_labels)
|
| 363 |
-
val_mask = torch.zeros(len(train_labels), dtype=bool)
|
| 364 |
-
for class_indices in class_indices_mapping.values():
|
| 365 |
-
# If there is only one element, it goes in the train set
|
| 366 |
-
n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0
|
| 367 |
-
val_mask[class_indices[:n_for_val]] = True
|
| 368 |
-
|
| 369 |
-
val_dataset = Subset(train_dataset, val_mask.nonzero().flatten())
|
| 370 |
-
train_dataset = Subset(train_dataset, (~val_mask).nonzero().flatten())
|
| 371 |
-
return train_dataset, val_dataset
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
def create_train_dataset_dict(
|
| 375 |
-
train_dataset,
|
| 376 |
-
few_shot_eval: bool = False,
|
| 377 |
-
few_shot_k_or_percent=None,
|
| 378 |
-
few_shot_n_tries: int = 1,
|
| 379 |
-
) -> dict[int, dict[int, Any]]:
|
| 380 |
-
"""
|
| 381 |
-
Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being
|
| 382 |
-
n elements or x% of a class. Produces a dict, which keys are number of random "tries"
|
| 383 |
-
and values are the dataset subset for this "try".
|
| 384 |
-
|
| 385 |
-
Format is {"nth-try": dataset}
|
| 386 |
-
"""
|
| 387 |
-
if few_shot_eval is False:
|
| 388 |
-
assert few_shot_k_or_percent is None
|
| 389 |
-
assert few_shot_n_tries == 1
|
| 390 |
-
return {0: train_dataset}
|
| 391 |
-
|
| 392 |
-
assert few_shot_k_or_percent is not None
|
| 393 |
-
train_labels = get_labels(train_dataset)
|
| 394 |
-
class_indices_mapping = create_class_indices_mapping(train_labels)
|
| 395 |
-
train_dataset_dict: dict[int, Any] = {}
|
| 396 |
-
is_percent = few_shot_k_or_percent < 1
|
| 397 |
-
if not is_percent:
|
| 398 |
-
few_shot_k_or_percent = int(few_shot_k_or_percent)
|
| 399 |
-
|
| 400 |
-
for t in range(few_shot_n_tries):
|
| 401 |
-
t_subset_bool = _subset_dataset_per_class(
|
| 402 |
-
class_indices_mapping=class_indices_mapping,
|
| 403 |
-
n_or_percent_per_class=few_shot_k_or_percent,
|
| 404 |
-
dataset_size=len(train_labels),
|
| 405 |
-
is_percent=is_percent,
|
| 406 |
-
seed=t,
|
| 407 |
-
)
|
| 408 |
-
if len(train_labels.shape) > 1 and is_percent:
|
| 409 |
-
t_subset_bool = _multilabel_rebalance_subset(
|
| 410 |
-
class_indices_mapping=class_indices_mapping,
|
| 411 |
-
n_or_percent_per_class=few_shot_k_or_percent,
|
| 412 |
-
dataset_size=len(train_labels),
|
| 413 |
-
labels=train_labels,
|
| 414 |
-
indices_bool=t_subset_bool,
|
| 415 |
-
seed=t,
|
| 416 |
-
)
|
| 417 |
-
train_dataset_dict[t] = Subset(train_dataset, t_subset_bool.nonzero().flatten())
|
| 418 |
-
return train_dataset_dict
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
def extract_features_for_dataset_dict(
|
| 422 |
-
model,
|
| 423 |
-
dataset_dict: dict[int, dict[int, Any]],
|
| 424 |
-
batch_size: int,
|
| 425 |
-
num_workers: int,
|
| 426 |
-
gather_on_cpu=False,
|
| 427 |
-
avgpool=False,
|
| 428 |
-
) -> dict[int, dict[str, torch.Tensor]]:
|
| 429 |
-
"""
|
| 430 |
-
Extract features for each subset of dataset in the context of few-shot evaluations
|
| 431 |
-
"""
|
| 432 |
-
few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {}
|
| 433 |
-
for try_n, dataset in dataset_dict.items():
|
| 434 |
-
features, labels = extract_features_cell_dino(
|
| 435 |
-
model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu, avgpool=avgpool
|
| 436 |
-
)
|
| 437 |
-
few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels}
|
| 438 |
-
return few_shot_data_dict
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
def pad_multilabel_and_collate(batch, pad_value=-1):
|
| 442 |
-
"""
|
| 443 |
-
This method pads and collates a batch of (image, (index, target)) tuples, coming from
|
| 444 |
-
DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes.
|
| 445 |
-
The targets are padded to the length of the longest target list in the batch.
|
| 446 |
-
"""
|
| 447 |
-
maxlen = max(len(targets) for _, (_, targets) in batch)
|
| 448 |
-
padded_batch = [
|
| 449 |
-
(image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value)))
|
| 450 |
-
for image, (index, targets) in batch
|
| 451 |
-
]
|
| 452 |
-
return torch.utils.data.default_collate(padded_batch)
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
class KnnModule(torch.nn.Module):
|
| 456 |
-
"""
|
| 457 |
-
Gets knn of test features from all processes on a chunk of the train features
|
| 458 |
-
|
| 459 |
-
Each rank gets a chunk of the train features as well as a chunk of the test features.
|
| 460 |
-
In `compute_neighbors`, for each rank one after the other, its chunk of test features
|
| 461 |
-
is sent to all devices, partial knns are computed with each chunk of train features
|
| 462 |
-
then collated back on the original device.
|
| 463 |
-
"""
|
| 464 |
-
|
| 465 |
-
def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
|
| 466 |
-
super().__init__()
|
| 467 |
-
|
| 468 |
-
self.global_rank = distributed.get_global_rank()
|
| 469 |
-
self.global_size = distributed.get_global_size()
|
| 470 |
-
|
| 471 |
-
self.device = device
|
| 472 |
-
self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
|
| 473 |
-
# Labels can either be integers, or in a one-hot format
|
| 474 |
-
self.candidates = train_labels.chunk(self.global_size)[self.global_rank].unsqueeze(0).to(self.device)
|
| 475 |
-
self.nb_knn = nb_knn
|
| 476 |
-
self.max_k = max(self.nb_knn)
|
| 477 |
-
self.T = T
|
| 478 |
-
self.num_classes = num_classes
|
| 479 |
-
|
| 480 |
-
def _get_knn_sims_and_labels(self, similarity, train_labels):
|
| 481 |
-
topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
|
| 482 |
-
if len(train_labels.shape) == 3: # If the labels are in one_hot format
|
| 483 |
-
indices = indices.unsqueeze(2).expand(-1, -1, self.num_classes) # Orignally [bs, max_k]
|
| 484 |
-
neighbors_labels = torch.gather(train_labels, 1, indices)
|
| 485 |
-
return topk_sims, neighbors_labels
|
| 486 |
-
|
| 487 |
-
def _similarity_for_rank(self, features_rank, source_rank):
|
| 488 |
-
# Send the features from `source_rank` to all ranks
|
| 489 |
-
broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
|
| 490 |
-
torch.distributed.broadcast(broadcast_shape, source_rank)
|
| 491 |
-
|
| 492 |
-
broadcasted = features_rank
|
| 493 |
-
if self.global_rank != source_rank:
|
| 494 |
-
broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
|
| 495 |
-
torch.distributed.broadcast(broadcasted, source_rank)
|
| 496 |
-
|
| 497 |
-
# Compute the neighbors for `source_rank` among `train_features_rank_T`
|
| 498 |
-
similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
|
| 499 |
-
candidate_labels = self.candidates.expand(len(similarity_rank), *self.candidates.shape[1:])
|
| 500 |
-
return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
|
| 501 |
-
|
| 502 |
-
def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
|
| 503 |
-
# Gather all neighbors for `target_rank`
|
| 504 |
-
topk_sims_rank = retrieved_rank = None
|
| 505 |
-
if self.global_rank == target_rank:
|
| 506 |
-
topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
|
| 507 |
-
retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
|
| 508 |
-
|
| 509 |
-
torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
|
| 510 |
-
torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
|
| 511 |
-
|
| 512 |
-
if self.global_rank == target_rank:
|
| 513 |
-
# Perform a second top-k on the k * global_size retrieved neighbors
|
| 514 |
-
topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
|
| 515 |
-
retrieved_rank = torch.cat(retrieved_rank, dim=1)
|
| 516 |
-
results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
|
| 517 |
-
return results
|
| 518 |
-
return None
|
| 519 |
-
|
| 520 |
-
def compute_neighbors(self, features_rank):
|
| 521 |
-
for rank in range(self.global_size):
|
| 522 |
-
topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
|
| 523 |
-
results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
|
| 524 |
-
if results is not None:
|
| 525 |
-
topk_sims_rank, neighbors_labels_rank = results
|
| 526 |
-
return topk_sims_rank, neighbors_labels_rank
|
| 527 |
-
|
| 528 |
-
def forward(self, features_rank):
|
| 529 |
-
"""
|
| 530 |
-
Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
|
| 531 |
-
"""
|
| 532 |
-
assert all(k <= self.max_k for k in self.nb_knn)
|
| 533 |
-
|
| 534 |
-
topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
|
| 535 |
-
batch_size = neighbors_labels.shape[0]
|
| 536 |
-
topk_sims_transform = softmax(topk_sims / self.T, 1)
|
| 537 |
-
voting_coefficient = topk_sims_transform.view(batch_size, -1, 1)
|
| 538 |
-
if len(neighbors_labels.shape) == 2: # If the labels are not yet one hot
|
| 539 |
-
neighbors_labels = one_hot(neighbors_labels, num_classes=self.num_classes)
|
| 540 |
-
matmul = torch.mul(neighbors_labels, voting_coefficient)
|
| 541 |
-
probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
|
| 542 |
-
return probas_for_k
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/depth/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/depth/models/__init__.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from .backbones import * # noqa: F403
|
| 7 |
-
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
|
| 8 |
-
from .decode_heads import * # noqa: F403
|
| 9 |
-
from .depther import * # noqa: F403
|
| 10 |
-
from .losses import * # noqa: F403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/depth/models/backbones/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from .vision_transformer import DinoVisionTransformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/depth/models/backbones/vision_transformer.py
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
from mmcv.runner import BaseModule
|
| 7 |
-
|
| 8 |
-
from ..builder import BACKBONES
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@BACKBONES.register_module()
|
| 12 |
-
class DinoVisionTransformer(BaseModule):
|
| 13 |
-
"""Vision Transformer."""
|
| 14 |
-
|
| 15 |
-
def __init__(self, *args, **kwargs):
|
| 16 |
-
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dinov2/eval/depth/models/builder.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
-
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import warnings
|
| 7 |
-
|
| 8 |
-
from mmcv.cnn import MODELS as MMCV_MODELS
|
| 9 |
-
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
|
| 10 |
-
from mmcv.utils import Registry
|
| 11 |
-
|
| 12 |
-
MODELS = Registry("models", parent=MMCV_MODELS)
|
| 13 |
-
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
BACKBONES = MODELS
|
| 17 |
-
NECKS = MODELS
|
| 18 |
-
HEADS = MODELS
|
| 19 |
-
LOSSES = MODELS
|
| 20 |
-
DEPTHER = MODELS
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def build_backbone(cfg):
|
| 24 |
-
"""Build backbone."""
|
| 25 |
-
return BACKBONES.build(cfg)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def build_neck(cfg):
|
| 29 |
-
"""Build neck."""
|
| 30 |
-
return NECKS.build(cfg)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def build_head(cfg):
|
| 34 |
-
"""Build head."""
|
| 35 |
-
return HEADS.build(cfg)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def build_loss(cfg):
|
| 39 |
-
"""Build loss."""
|
| 40 |
-
return LOSSES.build(cfg)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def build_depther(cfg, train_cfg=None, test_cfg=None):
|
| 44 |
-
"""Build depther."""
|
| 45 |
-
if train_cfg is not None or test_cfg is not None:
|
| 46 |
-
warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
|
| 47 |
-
assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
|
| 48 |
-
assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
|
| 49 |
-
return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|