Addax-Data-Science committed on
Commit c8ae7e8 · verified · 1 Parent(s): 269a3ed

Upload 174 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. dinov2/__init__.py +6 -0
  2. dinov2/configs/__init__.py +22 -0
  3. dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml +35 -0
  4. dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml +14 -0
  5. dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
  6. dinov2/configs/eval/vitb14_reg4_pretrain.yaml +9 -0
  7. dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
  8. dinov2/configs/eval/vitg14_reg4_pretrain.yaml +10 -0
  9. dinov2/configs/eval/vitl14_pretrain.yaml +6 -0
  10. dinov2/configs/eval/vitl14_reg4_pretrain.yaml +9 -0
  11. dinov2/configs/eval/vits14_pretrain.yaml +6 -0
  12. dinov2/configs/eval/vits14_reg4_pretrain.yaml +9 -0
  13. dinov2/configs/ssl_default_config.yaml +123 -0
  14. dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml +31 -0
  15. dinov2/configs/train/cell_dino/vitl16_hpafov.yaml +32 -0
  16. dinov2/configs/train/cell_dino/vitl16_hpaone.yaml +30 -0
  17. dinov2/configs/train/vitg14.yaml +26 -0
  18. dinov2/configs/train/vitl14.yaml +26 -0
  19. dinov2/configs/train/vitl16_short.yaml +6 -0
  20. dinov2/data/__init__.py +12 -0
  21. dinov2/data/accumulators.py +133 -0
  22. dinov2/data/adapters.py +51 -0
  23. dinov2/data/augmentations.py +118 -0
  24. dinov2/data/cell_dino/augmentations.py +91 -0
  25. dinov2/data/cell_dino/transforms.py +169 -0
  26. dinov2/data/collate.py +49 -0
  27. dinov2/data/datasets/__init__.py +12 -0
  28. dinov2/data/datasets/cell_dino/chammi_cp.py +112 -0
  29. dinov2/data/datasets/cell_dino/chammi_hpa.py +111 -0
  30. dinov2/data/datasets/cell_dino/chammi_wtc.py +108 -0
  31. dinov2/data/datasets/cell_dino/hpafov.py +283 -0
  32. dinov2/data/datasets/cell_dino/hpaone.py +223 -0
  33. dinov2/data/datasets/decoders.py +94 -0
  34. dinov2/data/datasets/extended.py +44 -0
  35. dinov2/data/datasets/image_net.py +290 -0
  36. dinov2/data/datasets/image_net_22k.py +302 -0
  37. dinov2/data/loaders.py +232 -0
  38. dinov2/data/masking.py +86 -0
  39. dinov2/data/samplers.py +229 -0
  40. dinov2/data/transforms.py +91 -0
  41. dinov2/distributed/__init__.py +270 -0
  42. dinov2/eval/__init__.py +4 -0
  43. dinov2/eval/cell_dino/knn.py +479 -0
  44. dinov2/eval/cell_dino/linear.py +1048 -0
  45. dinov2/eval/cell_dino/utils.py +542 -0
  46. dinov2/eval/depth/__init__.py +4 -0
  47. dinov2/eval/depth/models/__init__.py +10 -0
  48. dinov2/eval/depth/models/backbones/__init__.py +6 -0
  49. dinov2/eval/depth/models/backbones/vision_transformer.py +16 -0
  50. dinov2/eval/depth/models/builder.py +49 -0
dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ __version__ = "0.0.1"
dinov2/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import pathlib
+
+ from omegaconf import OmegaConf
+
+
+ def load_config(config_name: str):
+     config_filename = config_name + ".yaml"
+     return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
+
+
+ dinov2_default_config = load_config("ssl_default_config")
+
+
+ def load_and_merge_config(config_name: str):
+     default_config = OmegaConf.create(dinov2_default_config)
+     loaded_config = load_config(config_name)
+     return OmegaConf.merge(default_config, loaded_config)
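
Review note: `load_and_merge_config` layers a named config on top of `ssl_default_config.yaml`, with the named config's keys taking precedence via `OmegaConf.merge`. A minimal usage sketch, assuming the package and its bundled YAML files are importable:

from dinov2.configs import load_and_merge_config

# Overlay eval/cell_dino/vitl16_pretrain.yaml onto the defaults;
# OmegaConf.merge lets the overlay override matching keys.
cfg = load_and_merge_config("eval/cell_dino/vitl16_pretrain")
print(cfg.student.arch)   # "vit_large" (set by the overlay)
print(cfg.optim.base_lr)  # 0.004 (inherited from ssl_default_config.yaml)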
dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml ADDED
@@ -0,0 +1,35 @@
+ train:
+   batch_size_per_gpu: 32
+   OFFICIAL_EPOCH_LENGTH: 450
+   cell_augmentation: true
+   channel_adaptive: true
+ student:
+   arch: vit_large
+   patch_size: 16
+   num_register_tokens: 0
+   interpolate_antialias: false
+   interpolate_offset: 0.1
+   drop_path_rate: 0.1
+   in_chans: 1
+   block_chunks: 4
+   channel_adaptive: true
+ teacher:
+   momentum_teacher: 0.996
+   warmup_teacher_temp_epochs: 20
+   in_chans: 1
+   channel_adaptive: true
+ crops:
+   global_crops_scale:
+   - 0.4
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.005
+   - 0.4
+   global_crops_size: 224
+   local_crops_size: 96
+ optim:
+   weight_decay_end: 0.2
+   base_lr: 5.0e-4
+   warmup_epochs: 20
+   epochs: 400
dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml ADDED
@@ -0,0 +1,14 @@
+ student:
+   arch: vit_large
+   patch_size: 16
+   num_register_tokens: 0
+   interpolate_antialias: false
+   interpolate_offset: 0.1
+   drop_path_rate: 0.1
+   in_chans: 4
+   block_chunks: 4
+ teacher:
+   in_chans: 4
+ crops:
+   global_crops_size: 224
+   local_crops_size: 96
dinov2/configs/eval/vitb14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_base
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitb14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+ student:
+   arch: vit_base
+   patch_size: 14
+   num_register_tokens: 4
+   interpolate_antialias: true
+   interpolate_offset: 0.0
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitg14_pretrain.yaml ADDED
@@ -0,0 +1,7 @@
+ student:
+   arch: vit_giant2
+   patch_size: 14
+   ffn_layer: swiglufused
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitg14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,10 @@
+ student:
+   arch: vit_giant2
+   patch_size: 14
+   ffn_layer: swiglufused
+   num_register_tokens: 4
+   interpolate_antialias: true
+   interpolate_offset: 0.0
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitl14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_large
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitl14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+ student:
+   arch: vit_large
+   patch_size: 14
+   num_register_tokens: 4
+   interpolate_antialias: true
+   interpolate_offset: 0.0
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vits14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_small
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vits14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
+ student:
+   arch: vit_small
+   patch_size: 14
+   num_register_tokens: 4
+   interpolate_antialias: true
+   interpolate_offset: 0.0
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/ssl_default_config.yaml ADDED
@@ -0,0 +1,123 @@
+ MODEL:
+   WEIGHTS: ''
+ compute_precision:
+   grad_scaler: true
+   teacher:
+     backbone:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     dino_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     ibot_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+   student:
+     backbone:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     dino_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp32
+         buffer_dtype: fp32
+     ibot_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp32
+         buffer_dtype: fp32
+ dino:
+   loss_weight: 1.0
+   head_n_prototypes: 65536
+   head_bottleneck_dim: 256
+   head_nlayers: 3
+   head_hidden_dim: 2048
+   koleo_loss_weight: 0.1
+ ibot:
+   loss_weight: 1.0
+   mask_sample_probability: 0.5
+   mask_ratio_min_max:
+   - 0.1
+   - 0.5
+   separate_head: false
+   head_n_prototypes: 65536
+   head_bottleneck_dim: 256
+   head_nlayers: 3
+   head_hidden_dim: 2048
+ train:
+   batch_size_per_gpu: 64
+   dataset_path: ImageNet:split=TRAIN
+   output_dir: .
+   saveckp_freq: 20
+   seed: 0
+   num_workers: 10
+   OFFICIAL_EPOCH_LENGTH: 1250
+   cache_dataset: true
+   centering: "centering" # or "sinkhorn_knopp"
+   cell_augmentation: false
+ student:
+   arch: vit_large
+   patch_size: 16
+   drop_path_rate: 0.3
+   layerscale: 1.0e-05
+   drop_path_uniform: true
+   pretrained_weights: ''
+   ffn_layer: "mlp"
+   block_chunks: 0
+   qkv_bias: true
+   proj_bias: true
+   ffn_bias: true
+   num_register_tokens: 0
+   interpolate_antialias: false
+   interpolate_offset: 0.1
+   in_chans: 3
+   channel_adaptive: false
+ teacher:
+   momentum_teacher: 0.992
+   final_momentum_teacher: 1
+   warmup_teacher_temp: 0.04
+   teacher_temp: 0.07
+   warmup_teacher_temp_epochs: 30
+   in_chans: 3
+   channel_adaptive: false
+ optim:
+   epochs: 100
+   weight_decay: 0.04
+   weight_decay_end: 0.4
+   base_lr: 0.004 # learning rate for a batch size of 1024
+   lr: 0. # will be set after applying scaling rule
+   warmup_epochs: 10
+   min_lr: 1.0e-06
+   clip_grad: 3.0
+   freeze_last_layer_epochs: 1
+   scaling_rule: sqrt_wrt_1024
+   patch_embed_lr_mult: 0.2
+   layerwise_decay: 0.9
+   adamw_beta1: 0.9
+   adamw_beta2: 0.999
+ crops:
+   global_crops_scale:
+   - 0.32
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.05
+   - 0.32
+   global_crops_size: 224
+   local_crops_size: 96
+ evaluation:
+   eval_period_iterations: 12500
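
Review note: `lr` is left at 0 and derived from `base_lr` at startup via `scaling_rule: sqrt_wrt_1024`. The helper that applies it lives in the training entry point, which is outside this 50-file view, so the sketch below is an assumption based on the upstream DINOv2 rule (square-root scaling relative to the reference global batch size of 1024 that the `base_lr` comment refers to):

import math

def scaled_lr(base_lr: float, batch_size_per_gpu: int, num_gpus: int) -> float:
    # sqrt scaling with respect to a reference global batch size of 1024
    global_batch_size = batch_size_per_gpu * num_gpus
    return base_lr * math.sqrt(global_batch_size / 1024.0)

# e.g. the defaults above on 16 GPUs: 64 * 16 = 1024, so lr == base_lr
print(scaled_lr(0.004, batch_size_per_gpu=64, num_gpus=16))  # 0.004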
dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml ADDED
@@ -0,0 +1,31 @@
+ train:
+   batch_size_per_gpu: 16
+   OFFICIAL_EPOCH_LENGTH: 450
+   cell_augmentation: true
+   channel_adaptive: true
+ student:
+   arch: vit_large
+   patch_size: 16
+   in_chans: 1
+   drop_path_rate: 0.1
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.996
+   warmup_teacher_temp_epochs: 20
+   in_chans: 1
+ crops:
+   global_crops_scale:
+   - 0.4
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.005
+   - 0.4
+   global_crops_size: 224
+   local_crops_size: 96
+ optim:
+   weight_decay_end: 0.2
+   base_lr: 5.0e-4
+   warmup_epochs: 20
+   epochs: 400
+
dinov2/configs/train/cell_dino/vitl16_hpafov.yaml ADDED
@@ -0,0 +1,32 @@
+ train:
+   batch_size_per_gpu: 16
+   OFFICIAL_EPOCH_LENGTH: 450
+   cell_augmentation: true
+ student:
+   arch: vit_large
+   patch_size: 16
+   in_chans: 4
+   drop_path_rate: 0.1
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.996
+   warmup_teacher_temp_epochs: 20
+   in_chans: 4
+ optim:
+   weight_decay_end: 0.2
+   base_lr: 5.0e-4
+   warmup_epochs: 20
+ crops:
+   global_crops_scale:
+   - 0.4
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.005
+   - 0.4
+   global_crops_size: 224
+   local_crops_size: 96
+ evaluation:
+   eval_period_iterations: 9000
+
+
dinov2/configs/train/cell_dino/vitl16_hpaone.yaml ADDED
@@ -0,0 +1,30 @@
+ train:
+   batch_size_per_gpu: 16
+   OFFICIAL_EPOCH_LENGTH: 1756
+   cell_augmentation: true
+ student:
+   arch: vit_large
+   patch_size: 16
+   in_chans: 4
+   drop_path_rate: 0.1
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.996
+   warmup_teacher_temp_epochs: 20
+   in_chans: 4
+ optim:
+   weight_decay_end: 0.2
+   base_lr: 5.0e-4
+   warmup_epochs: 20
+ crops:
+   global_crops_scale:
+   - 0.4
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.005
+   - 0.4
+   global_crops_size: 224
+   local_crops_size: 96
+ evaluation:
+   eval_period_iterations: 9000
dinov2/configs/train/vitg14.yaml ADDED
@@ -0,0 +1,26 @@
+ dino:
+   head_n_prototypes: 131072
+   head_bottleneck_dim: 384
+ ibot:
+   separate_head: true
+   head_n_prototypes: 131072
+ train:
+   batch_size_per_gpu: 12
+   dataset_path: ImageNet22k
+   centering: sinkhorn_knopp
+ student:
+   arch: vit_giant2
+   patch_size: 14
+   drop_path_rate: 0.4
+   ffn_layer: swiglufused
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.994
+ optim:
+   epochs: 500
+   weight_decay_end: 0.2
+   base_lr: 2.0e-04 # learning rate for a batch size of 1024
+   warmup_epochs: 80
+   layerwise_decay: 1.0
+ crops:
+   local_crops_size: 98
dinov2/configs/train/vitl14.yaml ADDED
@@ -0,0 +1,26 @@
+ dino:
+   head_n_prototypes: 131072
+   head_bottleneck_dim: 384
+ ibot:
+   separate_head: true
+   head_n_prototypes: 131072
+ train:
+   batch_size_per_gpu: 32
+   dataset_path: ImageNet22k
+   centering: sinkhorn_knopp
+ student:
+   arch: vit_large
+   patch_size: 14
+   drop_path_rate: 0.4
+   ffn_layer: swiglufused
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.994
+ optim:
+   epochs: 500
+   weight_decay_end: 0.2
+   base_lr: 2.0e-04 # learning rate for a batch size of 1024
+   warmup_epochs: 80
+   layerwise_decay: 1.0
+ crops:
+   local_crops_size: 98
dinov2/configs/train/vitl16_short.yaml ADDED
@@ -0,0 +1,6 @@
+ # this corresponds to the default config
+ train:
+   dataset_path: ImageNet:split=TRAIN
+   batch_size_per_gpu: 64
+ student:
+   block_chunks: 4
dinov2/data/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from .adapters import DatasetWithEnumeratedTargets
+ from .loaders import make_data_loader, make_dataset, SamplerType
+ from .collate import collate_data_and_cast
+ from .masking import MaskingGenerator
+ from .augmentations import DataAugmentationDINO
+ from .cell_dino.augmentations import CellAugmentationDINO
+ from .accumulators import NoOpAccumulator, ResultsAccumulator
dinov2/data/accumulators.py ADDED
@@ -0,0 +1,133 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Any
+
+ import torch
+ from torch import Tensor
+ from torch.nn import functional as F
+
+ import torch.distributed as dist
+ from dinov2.distributed import get_global_size
+
+
+ def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
+     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
+     dist.all_gather(gathered_result, result, group)
+     return gathered_result
+
+
+ def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
+     """
+     Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py
+     Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
+
+     Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
+     tensors are padded, gathered and then trimmed to secure equal workload for all processes.
+
+     Args:
+         result: the value to sync
+         group: the process group to gather results from. Defaults to all processes (world)
+
+     Return:
+         list with size equal to the process group where element i corresponds to result tensor from process i
+     """
+     # convert tensors to contiguous format
+     result = result.contiguous()
+
+     world_size = get_global_size()
+     dist.barrier(group=group)
+
+     # if the tensor is scalar, things are easy
+     if result.ndim == 0:
+         return _simple_gather_all_tensors(result, group, world_size)
+
+     # 1. Gather sizes of all tensors
+     local_size = torch.tensor(result.shape, device=result.device)
+     local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
+     dist.all_gather(local_sizes, local_size, group=group)
+     max_size = torch.stack(local_sizes).max(dim=0).values
+     all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
+
+     # 2. If shapes are all the same, then do a simple gather:
+     if all_sizes_equal:
+         return _simple_gather_all_tensors(result, group, world_size)
+
+     # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
+     pad_dims = []
+     pad_by = (max_size - local_size).detach().cpu()
+     for val in reversed(pad_by):
+         pad_dims.append(0)
+         pad_dims.append(val.item())
+     result_padded = F.pad(result, pad_dims)
+     gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
+     dist.all_gather(gathered_result, result_padded, group)
+     for idx, item_size in enumerate(local_sizes):
+         slice_param = [slice(dim_size) for dim_size in item_size]
+         gathered_result[idx] = gathered_result[idx][slice_param]
+     return gathered_result
+
+
+ def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor:
+     local_cat = torch.cat(tensor_list)
+     return torch.cat(gather_all_tensors(local_cat))
+
+
+ class Accumulator:
+     def __init__(self) -> None:
+         pass
+
+     def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
+         raise NotImplementedError
+
+     def accumulate(self) -> Optional[Dict[str, Tensor]]:
+         raise NotImplementedError
+
+
+ class NoOpAccumulator(Accumulator):
+     def __init__(self) -> None:
+         pass
+
+     def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
+         pass
+
+     def accumulate(self) -> None:
+         return None
+
+
+ class ResultsAccumulator(Accumulator):
+     """
+     Accumulate predictions and targets across processes
+     """
+
+     def __init__(self) -> None:
+         self._local_values: Dict[str, List[Tensor]] = defaultdict(list)
+         self._gathered_values: Dict[str, Tensor] = {}
+         self._gathered = False
+
+     def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
+         assert len(preds) == len(target) == len(index)
+         assert not self._gathered, "Tensors have already been gathered in this helper"
+         self._local_values["preds"].append(preds)
+         self._local_values["target"].append(target)
+         self._local_values["index"].append(index)
+         self._gathered = False
+
+     def _gather_tensors(self):
+         for k, tensor_list in self._local_values.items():
+             self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list)
+         self._gathered = True
+
+     def accumulate(self) -> Dict[str, Tensor]:
+         if not self._gathered:
+             self._gather_tensors()
+         preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]]
+         assert len(preds) == len(target) == len(index) and index.min() == 0
+         preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device)
+         preds_ordered[index] = preds
+         target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device)
+         target_ordered[index] = target
+         return {"preds": preds_ordered, "target": target_ordered}
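
Review note: the subtle step in `ResultsAccumulator.accumulate` is the index-based reordering — batches gathered from different ranks arrive interleaved, and the enumerated sample index is used to scatter rows back into dataset order. A single-process sketch of just that step:

import torch

preds = torch.tensor([[0.9], [0.1], [0.5]])  # gathered predictions, rank-interleaved
index = torch.tensor([2, 0, 1])              # enumerated dataset index per row
preds_ordered = torch.zeros((int(index.max()) + 1, *preds.shape[1:]), dtype=preds.dtype)
preds_ordered[index] = preds                 # scatter rows back into dataset order
print(preds_ordered)                         # [[0.1], [0.5], [0.9]]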
dinov2/data/adapters.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from typing import Any, Tuple, Optional
+
+ from torch.utils.data import Dataset
+
+
+ class DatasetWithEnumeratedTargets(Dataset):
+     """
+     If pad_dataset is set, pads based on torch's DistributedSampler implementation, which
+     with drop_last=False pads the last batch to be a multiple of the world size.
+     https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
+     """
+
+     def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
+         self._dataset = dataset
+         self._size = len(self._dataset)
+         self._padded_size = self._size
+         self._pad_dataset = pad_dataset
+         if self._pad_dataset:
+             assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
+             self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)
+
+     def get_image_relpath(self, index: int) -> str:
+         assert self._pad_dataset or index < self._size
+         return self._dataset.get_image_relpath(index % self._size)
+
+     def get_image_data(self, index: int) -> bytes:
+         assert self._pad_dataset or index < self._size
+         return self._dataset.get_image_data(index % self._size)
+
+     def get_target(self, index: int) -> Tuple[Any, int]:
+         target = self._dataset.get_target(index % self._size)
+         if index >= self._size:
+             assert self._pad_dataset
+             return (-1, target)
+         return (index, target)
+
+     def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
+         image, target = self._dataset[index % self._size]
+         if index >= self._size:
+             assert self._pad_dataset
+             return image, (-1, target)
+         target = index if target is None else target
+         return image, (index, target)
+
+     def __len__(self) -> int:
+         return self._padded_size
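
Review note: the padded size is a ceiling division to the next multiple of the world size, matching DistributedSampler with drop_last=False. A quick worked example:

num_replicas = 4   # world size
dataset_len = 10
padded_size = num_replicas * ((dataset_len + num_replicas - 1) // num_replicas)
print(padded_size)  # 12; indices 10 and 11 wrap around and are tagged with index -1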
dinov2/data/augmentations.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ from torchvision import transforms
+
+ from .transforms import (
+     GaussianBlur,
+     make_normalize_transform,
+ )
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ class DataAugmentationDINO(object):
+     def __init__(
+         self,
+         global_crops_scale,
+         local_crops_scale,
+         local_crops_number,
+         global_crops_size=224,
+         local_crops_size=96,
+     ):
+         self.global_crops_scale = global_crops_scale
+         self.local_crops_scale = local_crops_scale
+         self.local_crops_number = local_crops_number
+         self.global_crops_size = global_crops_size
+         self.local_crops_size = local_crops_size
+
+         logger.info("###################################")
+         logger.info("Using data augmentation parameters:")
+         logger.info(f"global_crops_scale: {global_crops_scale}")
+         logger.info(f"local_crops_scale: {local_crops_scale}")
+         logger.info(f"local_crops_number: {local_crops_number}")
+         logger.info(f"global_crops_size: {global_crops_size}")
+         logger.info(f"local_crops_size: {local_crops_size}")
+         logger.info("###################################")
+
+         # random resized crop and flip
+         self.geometric_augmentation_global = transforms.Compose(
+             [
+                 transforms.RandomResizedCrop(
+                     global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                 ),
+                 transforms.RandomHorizontalFlip(p=0.5),
+             ]
+         )
+
+         self.geometric_augmentation_local = transforms.Compose(
+             [
+                 transforms.RandomResizedCrop(
+                     local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                 ),
+                 transforms.RandomHorizontalFlip(p=0.5),
+             ]
+         )
+
+         # color distortions / blurring
+         color_jittering = transforms.Compose(
+             [
+                 transforms.RandomApply(
+                     [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+                     p=0.8,
+                 ),
+                 transforms.RandomGrayscale(p=0.2),
+             ]
+         )
+
+         global_transfo1_extra = GaussianBlur(p=1.0)
+
+         global_transfo2_extra = transforms.Compose(
+             [
+                 GaussianBlur(p=0.1),
+                 transforms.RandomSolarize(threshold=128, p=0.2),
+             ]
+         )
+
+         local_transfo_extra = GaussianBlur(p=0.5)
+
+         # normalization
+         self.normalize = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 make_normalize_transform(),
+             ]
+         )
+
+         self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
+         self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
+         self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
+
+     def __call__(self, image):
+         output = {}
+
+         # global crops:
+         im1_base = self.geometric_augmentation_global(image)
+         global_crop_1 = self.global_transfo1(im1_base)
+
+         im2_base = self.geometric_augmentation_global(image)
+         global_crop_2 = self.global_transfo2(im2_base)
+
+         output["global_crops"] = [global_crop_1, global_crop_2]
+
+         # global crops for teacher:
+         output["global_crops_teacher"] = [global_crop_1, global_crop_2]
+
+         # local crops:
+         local_crops = [
+             self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
+         ]
+         output["local_crops"] = local_crops
+         output["offsets"] = ()
+
+         return output
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import torchvision
8
+ from torchvision import transforms
9
+
10
+ from .transforms import (
11
+ RandomContrastProteinChannel,
12
+ RandomRemoveChannelExceptProtein,
13
+ RandomBrightness,
14
+ RandomContrast,
15
+ Div255,
16
+ SelfNormalizeNoDiv,
17
+ )
18
+
19
+ logger = logging.getLogger("dinov2")
20
+
21
+
22
+ class CellAugmentationDINO(object):
23
+ def __init__(
24
+ self,
25
+ global_crops_scale,
26
+ local_crops_scale,
27
+ local_crops_number,
28
+ global_crops_size=224,
29
+ local_crops_size=96,
30
+ ):
31
+ self.global_crops_scale = global_crops_scale
32
+ self.local_crops_scale = local_crops_scale
33
+ self.local_crops_number = local_crops_number
34
+ self.global_crops_size = global_crops_size
35
+ self.local_crops_size = local_crops_size
36
+
37
+ logger.info("###################################")
38
+ logger.info("Using data augmentation parameters:")
39
+ logger.info(f"global_crops_scale: {global_crops_scale}")
40
+ logger.info(f"local_crops_scale: {local_crops_scale}")
41
+ logger.info(f"local_crops_number: {local_crops_number}")
42
+ logger.info(f"global_crops_size: {global_crops_size}")
43
+ logger.info(f"local_crops_size: {local_crops_size}")
44
+ logger.info("###################################")
45
+
46
+ additional_transforms_list = [
47
+ torchvision.transforms.RandomHorizontalFlip(),
48
+ torchvision.transforms.RandomVerticalFlip(),
49
+ RandomBrightness(),
50
+ RandomContrast(),
51
+ SelfNormalizeNoDiv(),
52
+ ]
53
+
54
+ first_transforms_list = [
55
+ Div255(),
56
+ RandomRemoveChannelExceptProtein(),
57
+ RandomContrastProteinChannel(),
58
+ ]
59
+
60
+ global_transforms_list = first_transforms_list.copy()
61
+ global_transforms_list.append(
62
+ torchvision.transforms.RandomResizedCrop(size=global_crops_size, scale=global_crops_scale)
63
+ )
64
+ global_transforms_list = global_transforms_list + additional_transforms_list
65
+
66
+ local_transforms_list = first_transforms_list
67
+ local_transforms_list.append(
68
+ torchvision.transforms.RandomResizedCrop(size=local_crops_size, scale=local_crops_scale)
69
+ )
70
+ local_transforms_list = local_transforms_list + additional_transforms_list
71
+
72
+ self.global_transform = transforms.Compose(global_transforms_list)
73
+ self.local_transform = transforms.Compose(local_transforms_list)
74
+
75
+ def __call__(self, image):
76
+ output = {}
77
+
78
+ global_crop1 = self.global_transform(image)
79
+ global_crop2 = self.global_transform(image)
80
+
81
+ output["global_crops"] = [global_crop1, global_crop2]
82
+
83
+ local_crops = []
84
+ for _ in range(self.local_crops_number):
85
+ local_crops.append(self.local_transform(image))
86
+
87
+ output["local_crops"] = local_crops
88
+ output["global_crops_teacher"] = [global_crop1, global_crop2]
89
+ output["offsets"] = ()
90
+
91
+ return output
dinov2/data/cell_dino/transforms.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+ import numpy as np
9
+ from enum import Enum
10
+
11
+
12
+ class NormalizationType(Enum):
13
+ SELF_NORM_AUG_DECODER = "self_norm_aug_decoder"
14
+ SELF_NORM_CENTER_CROP = "self_norm_center_crop"
15
+
16
+
17
+ class Div255(torch.nn.Module):
18
+ def forward(self, x):
19
+ x = x / 255
20
+ return x
21
+
22
+
23
+ class SelfNormalizeNoDiv(torch.nn.Module):
24
+ def forward(self, x):
25
+ m = x.mean((-2, -1), keepdim=True)
26
+ s = x.std((-2, -1), unbiased=False, keepdim=True)
27
+ x -= m
28
+ x /= s + 1e-7
29
+ return x
30
+
31
+
32
+ class SelfNormalize(torch.nn.Module):
33
+ def forward(self, x):
34
+ x = x / 255
35
+ m = x.mean((-2, -1), keepdim=True)
36
+ s = x.std((-2, -1), unbiased=False, keepdim=True)
37
+ x -= m
38
+ x /= s + 1e-7
39
+ return x
40
+
41
+
42
+ class RandomContrastProteinChannel(torch.nn.Module):
43
+ """
44
+ Random constrast rescaling of the protein channel only.
45
+ RescaleProtein function in Dino4cell codebase.
46
+ """
47
+
48
+ def __init__(self, p=0.2):
49
+ super().__init__()
50
+ self.p = p
51
+
52
+ def forward(self, img):
53
+ if img.max() == 0:
54
+ return img
55
+ if len(img) == 1:
56
+ return img
57
+ if np.random.rand() <= self.p:
58
+ random_factor = (np.random.rand() * 2) / img.max() # scaling
59
+ img[1] = img[1] * random_factor
60
+ return img
61
+ else:
62
+ return img
63
+
64
+
65
+ class RandomRemoveChannelExceptProtein(torch.nn.Module):
66
+ """
67
+ dropping a channel at random except the channel 1, corresponding to proteins in HPA datasets.
68
+ """
69
+
70
+ def __init__(self, p=0.2):
71
+ super().__init__()
72
+ self.p = p
73
+
74
+ def forward(self, img):
75
+ img_size = np.array(img).shape
76
+ if img_size[0] < 4:
77
+ return img
78
+ if np.random.rand() <= self.p:
79
+ channel_to_blacken = np.random.choice(np.array([0, 2, 3]))
80
+ img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
81
+ return img
82
+ else:
83
+ return img
84
+
85
+
86
+ class RandomRemoveChannel(torch.nn.Module):
87
+ """
88
+ dropping a channel at random
89
+ """
90
+
91
+ def __init__(self, p=0.2):
92
+ super().__init__()
93
+ self.p = p
94
+
95
+ def forward(self, img):
96
+ img_size = np.array(img).shape
97
+ num_channels = img_size[0]
98
+ if num_channels < 4:
99
+ return img
100
+ if np.random.rand() <= self.p:
101
+ channel_to_blacken = np.random.choice(np.array(list(range(num_channels))))
102
+ img[channel_to_blacken] = torch.zeros(1, *img.shape[1:])
103
+ return img
104
+ else:
105
+ return img
106
+
107
+
108
+ class RandomContrast(torch.nn.Module):
109
+ def __init__(self, p=0.2):
110
+ super().__init__()
111
+ self.p = p
112
+
113
+ def forward(self, img):
114
+ if img.max() == 0:
115
+ return img
116
+ n_channels = img.shape[0]
117
+ for ind in range(n_channels):
118
+ factor = max(np.random.normal(1, self.p), 0.5)
119
+ img[ind] = transforms.functional.adjust_contrast(img[ind][None, ...], factor)
120
+ return img
121
+
122
+
123
+ class RandomBrightness(torch.nn.Module):
124
+ def __init__(self, p=0.2):
125
+ super().__init__()
126
+ self.p = p
127
+
128
+ def forward(self, img):
129
+ if img.max() == 0:
130
+ return img
131
+ n_channels = img.shape[0]
132
+ for ind in range(n_channels):
133
+ factor = max(np.random.normal(1, self.p), 0.5)
134
+ img[ind] = transforms.functional.adjust_brightness(img[ind], factor)
135
+ return img
136
+
137
+
138
+ def make_classification_eval_cell_transform(
139
+ *,
140
+ resize_size: int = 0,
141
+ interpolation=transforms.InterpolationMode.BICUBIC,
142
+ crop_size: int = 384,
143
+ normalization_type: Enum = NormalizationType.SELF_NORM_CENTER_CROP,
144
+ ) -> transforms.Compose:
145
+
146
+ from .transforms import (
147
+ Div255,
148
+ SelfNormalizeNoDiv,
149
+ )
150
+
151
+ transforms_list = [Div255()]
152
+ if resize_size > 0:
153
+ transforms_list.append(transforms.Resize(resize_size, interpolation=interpolation))
154
+
155
+ if normalization_type == NormalizationType.SELF_NORM_AUG_DECODER:
156
+ transforms_list.extend(
157
+ [
158
+ transforms.RandomCrop(size=crop_size, pad_if_needed=True),
159
+ transforms.RandomHorizontalFlip(),
160
+ transforms.RandomVerticalFlip(),
161
+ ]
162
+ )
163
+ elif normalization_type == NormalizationType.SELF_NORM_CENTER_CROP:
164
+ transforms_list.append(transforms.CenterCrop(size=crop_size))
165
+ else:
166
+ raise ValueError("f{normalization_type}: unknown NormalizationType")
167
+ transforms_list.append(SelfNormalizeNoDiv())
168
+
169
+ return transforms.Compose(transforms_list)
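
Review note: `SelfNormalizeNoDiv` standardizes each channel independently over its spatial dimensions, and it modifies its input in place (`x -= m`). A quick check of both properties, assuming the class is importable:

import torch

x = torch.rand(4, 32, 32) * 255
y = SelfNormalizeNoDiv()(x.clone())  # clone because the transform is in-place
print(y.mean(dim=(-2, -1)))                 # ~0 for every channel
print(y.std(dim=(-2, -1), unbiased=False))  # ~1 for every channel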
dinov2/data/collate.py ADDED
@@ -0,0 +1,49 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import random
+
+
+ def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
+     # dtype = torch.half  # TODO: Remove
+
+     n_global_crops = len(samples_list[0][0]["global_crops"])
+     n_local_crops = len(samples_list[0][0]["local_crops"])
+
+     collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
+
+     collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
+
+     B = len(collated_global_crops)
+     N = n_tokens
+     n_samples_masked = int(B * mask_probability)
+     probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
+     upperbound = 0
+     masks_list = []
+     for i in range(0, n_samples_masked):
+         prob_min = probs[i]
+         prob_max = probs[i + 1]
+         masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
+         upperbound += int(N * prob_max)
+     for i in range(n_samples_masked, B):
+         masks_list.append(torch.BoolTensor(mask_generator(0)))
+
+     random.shuffle(masks_list)
+
+     collated_masks = torch.stack(masks_list).flatten(1)
+     mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+     masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+
+     return {
+         "collated_global_crops": collated_global_crops.to(dtype),
+         "collated_local_crops": collated_local_crops.to(dtype),
+         "collated_masks": collated_masks,
+         "mask_indices_list": mask_indices_list,
+         "masks_weight": masks_weight,
+         "upperbound": upperbound,
+         "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+     }
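
Review note: with a `mask_probability` fraction of the crops masked, the per-sample mask ratio is drawn from consecutive bins of a linspace over `mask_ratio_tuple`, so the masked fraction is spread evenly across the batch rather than drawn i.i.d. A sketch of that schedule in isolation:

import random
import torch

B, N = 8, 196                    # crops in the batch, tokens per crop
n_samples_masked = int(B * 0.5)  # mask_probability = 0.5 -> 4 masked samples
probs = torch.linspace(0.1, 0.5, n_samples_masked + 1)  # mask_ratio_min_max bins
for i in range(n_samples_masked):
    lo, hi = float(probs[i]), float(probs[i + 1])
    n_masked = int(N * random.uniform(lo, hi))
    print(f"sample {i}: ~{n_masked}/{N} tokens masked")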
dinov2/data/datasets/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from .image_net import ImageNet
+ from .image_net_22k import ImageNet22k
+ from .cell_dino.hpaone import HPAone
+ from .cell_dino.hpafov import HPAFoV
+ from .cell_dino.chammi_cp import CHAMMI_CP
+ from .cell_dino.chammi_hpa import CHAMMI_HPA
+ from .cell_dino.chammi_wtc import CHAMMI_WTC
dinov2/data/datasets/cell_dino/chammi_cp.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the CC-by-NC licence,
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
+
+ import csv
+ from enum import Enum
+ import logging
+ import os
+ from typing import Any, Callable, Optional, Union
+
+ import numpy as np
+
+ from ..extended import ExtendedVisionDataset
+ from ..decoders import DecoderType
+
+ logger = logging.getLogger("dinov2")
+
+
+ METADATA_FILE = "morphem70k_v2.csv"
+
+ CLASS_LABELS = {
+     "BRD-A29260609": 0,
+     "BRD-K04185004": 1,
+     "BRD-K21680192": 2,
+     "DMSO": 3,
+     "BRD-K11129031": 4,  # labels only seen in TASK_FOUR
+     "BRD-K62310379": 5,
+     "BRD-K77947974": 6,
+ }
+
+
+ class _Split(Enum):
+     TRAIN = "Train"
+     TASK_ONE = "Task_one"
+     TASK_TWO = "Task_two"
+     TASK_THREE = "Task_three"
+     TASK_FOUR = "Task_four"
+
+
+ def _load_file_names_and_targets(
+     root: str,
+     split: _Split,
+ ):
+     image_paths = []
+     labels = []
+     with open(os.path.join(root, METADATA_FILE)) as metadata:
+         metadata_reader = csv.DictReader(metadata)
+         for row in metadata_reader:
+             row_dataset = row["file_path"].split("/")[0]
+
+             if row["train_test_split"].upper() == split and row_dataset == "CP":
+                 image_paths.append(row["file_path"])
+                 labels.append(CLASS_LABELS[row["label"]])
+
+     return image_paths, labels
+
+
+ class CHAMMI_CP(ExtendedVisionDataset):
+     """
+     Implementation of the CP (Cell Painting) subset of the CHAMMI benchmark dataset,
+     following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
+     Github code: https://github.com/chaudatascience/channel_adaptive_models
+     """
+
+     Split = Union[_Split]
+
+     def __init__(
+         self,
+         *,
+         split: "CHAMMI_CP.Split",
+         root: str,
+         transforms: Optional[Callable] = None,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             root,
+             transforms,
+             transform,
+             target_transform,
+             image_decoder_type=image_decoder_type,
+             **kwargs,
+         )
+         self.split = split
+         self.root = root
+         self.num_additional_labels_loo_eval = 3
+         self._image_paths, self._targets = _load_file_names_and_targets(
+             root,
+             split,
+         )
+
+     def get_image_relpath(self, index: int) -> str:
+         return self._image_paths[index]
+
+     def get_image_data(self, index: int) -> bytes:
+         image_relpath = self.get_image_relpath(index)
+         image_full_path = os.path.join(self.root, image_relpath)
+         with open(image_full_path, mode="rb") as f:
+             image_data = f.read()
+         return image_data
+
+     def get_target(self, index: int) -> Any:
+         return self._targets[index]
+
+     def get_targets(self) -> np.ndarray:
+         return np.array(self._targets)
+
+     def __len__(self) -> int:
+         return len(self._image_paths)
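
Review note: construction is driven entirely by `morphem70k_v2.csv` under `root`; only rows whose `file_path` starts with "CP" and whose split column matches are kept. A hypothetical instantiation — the root path is a placeholder, and the split is passed as the uppercase string the CSV comparison expects:

from dinov2.data.datasets import CHAMMI_CP

dataset = CHAMMI_CP(split="TRAIN", root="/data/chammi")  # placeholder path
print(len(dataset), dataset.get_target(0))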
dinov2/data/datasets/cell_dino/chammi_hpa.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the CC-by-NC licence,
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
+
+ import csv
+ from enum import Enum
+ import logging
+ import os
+ from typing import Any, Callable, Optional, Union
+
+ import numpy as np
+
+ from ..extended import ExtendedVisionDataset
+ from ..decoders import DecoderType
+
+ logger = logging.getLogger("dinov2")
+
+
+ METADATA_FILE = "morphem70k_v2.csv"
+
+ CLASS_LABELS = {
+     "golgi apparatus": 0,
+     "microtubules": 1,
+     "mitochondria": 2,
+     "nuclear speckles": 3,
+     "cytosol": 4,  # labels only seen in TASK_THREE
+     "endoplasmic reticulum": 5,
+     "nucleoplasm": 6,
+ }
+
+
+ class _Split(Enum):
+     TRAIN = "Train"
+     TASK_ONE = "Task_one"
+     TASK_TWO = "Task_two"
+     TASK_THREE = "Task_three"
+
+
+ def _load_file_names_and_targets(
+     root: str,
+     split: _Split,
+ ):
+     image_paths = []
+     labels = []
+     with open(os.path.join(root, METADATA_FILE)) as metadata:
+         metadata_reader = csv.DictReader(metadata)
+         for row in metadata_reader:
+             row_dataset = row["file_path"].split("/")[0]
+             if row["train_test_split"].upper() == split and row_dataset == "HPA":
+                 image_paths.append(row["file_path"])
+                 labels.append(CLASS_LABELS[row["label"]])
+
+     return image_paths, labels
+
+
+ class CHAMMI_HPA(ExtendedVisionDataset):
+     """
+     Implementation of the HPA (Human Protein Atlas) subset of the CHAMMI benchmark dataset,
+     following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
+     Github code: https://github.com/chaudatascience/channel_adaptive_models
+     """
+
+     Split = Union[_Split]
+
+     def __init__(
+         self,
+         *,
+         split: "CHAMMI_HPA.Split",
+         root: str,
+         transforms: Optional[Callable] = None,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             root,
+             transforms,
+             transform,
+             target_transform,
+             image_decoder_type=image_decoder_type,
+             **kwargs,
+         )
+         self.split = split
+         self.root = root
+         self.num_additional_labels_loo_eval = 3
+
+         self._image_paths, self._targets = _load_file_names_and_targets(
+             root,
+             split,
+         )
+
+     def get_image_relpath(self, index: int) -> str:
+         return self._image_paths[index]
+
+     def get_image_data(self, index: int) -> bytes:
+         image_relpath = self.get_image_relpath(index)
+         image_full_path = os.path.join(self.root, image_relpath)
+         with open(image_full_path, mode="rb") as f:
+             image_data = f.read()
+         return image_data
+
+     def get_target(self, index: int) -> Any:
+         return self._targets[index]
+
+     def get_targets(self) -> np.ndarray:
+         return np.array(self._targets)
+
+     def __len__(self) -> int:
+         return len(self._image_paths)
dinov2/data/datasets/cell_dino/chammi_wtc.py ADDED
@@ -0,0 +1,108 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the CC-by-NC licence,
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
+
+ import csv
+ from enum import Enum
+ import logging
+ import os
+ from typing import Any, Callable, Optional, Union
+
+ import numpy as np
+
+ from ..extended import ExtendedVisionDataset
+ from ..decoders import DecoderType
+
+ logger = logging.getLogger("dinov2")
+
+
+ METADATA_FILE = "morphem70k_v2.csv"
+
+ CLASS_LABELS = {
+     "M0": 0,
+     "M1M2": 1,
+     "M3": 2,
+     "M4M5": 3,
+     "M6M7_complete": 4,
+     "M6M7_single": 5,
+ }
+
+
+ class _Split(Enum):
+     TRAIN = "Train"
+     TASK_ONE = "Task_one"
+     TASK_TWO = "Task_two"
+
+
+ def _load_file_names_and_targets(
+     root: str,
+     split: _Split,
+ ):
+     image_paths = []
+     labels = []
+     with open(os.path.join(root, METADATA_FILE)) as metadata:
+         metadata_reader = csv.DictReader(metadata)
+         for row in metadata_reader:
+             row_dataset = row["file_path"].split("/")[0]
+             if row["train_test_split"].upper() == split and row_dataset == "Allen":
+                 image_paths.append(row["file_path"])
+                 labels.append(CLASS_LABELS[row["label"]])
+
+     return image_paths, labels
+
+
+ class CHAMMI_WTC(ExtendedVisionDataset):
+     """
+     Implementation of the WTC (Allen Institute WTC-11) subset of the CHAMMI benchmark dataset,
+     following the CHAMMI paper: https://arxiv.org/pdf/2310.19224
+     Github code: https://github.com/chaudatascience/channel_adaptive_models
+     """
+
+     Split = Union[_Split]
+
+     def __init__(
+         self,
+         *,
+         split: "CHAMMI_WTC.Split",
+         root: str,
+         transforms: Optional[Callable] = None,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         image_decoder_type: DecoderType = DecoderType.XChannelsTIFFDecoder,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             root,
+             transforms,
+             transform,
+             target_transform,
+             image_decoder_type=image_decoder_type,
+             **kwargs,
+         )
+         self.split = split
+         self.root = root
+
+         self._image_paths, self._targets = _load_file_names_and_targets(
+             root,
+             split,
+         )
+
+     def get_image_relpath(self, index: int) -> str:
+         return self._image_paths[index]
+
+     def get_image_data(self, index: int) -> bytes:
+         image_relpath = self.get_image_relpath(index)
+         image_full_path = os.path.join(self.root, image_relpath)
+         with open(image_full_path, mode="rb") as f:
+             image_data = f.read()
+         return image_data
+
+     def get_target(self, index: int) -> Any:
+         return self._targets[index]
+
+     def get_targets(self) -> np.ndarray:
+         return np.array(self._targets)
+
+     def __len__(self) -> int:
+         return len(self._image_paths)
dinov2/data/datasets/cell_dino/hpafov.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from enum import Enum
8
+ import logging
9
+ import os
10
+ from typing import Any, Callable, List, Optional, Tuple, Union, Dict
11
+
12
+ import numpy as np
13
+
14
+ from ..extended import ExtendedVisionDataset
15
+ from ..decoders import DecoderType
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+ CELL_TYPE = [
20
+ "BJ", # 1
21
+ "LHCN-M2",
22
+ "RH-30",
23
+ "SH-SY5Y",
24
+ "U-2 OS", # 5
25
+ "ASC TERT1",
26
+ "HaCaT",
27
+ "A-431",
28
+ "U-251 MG",
29
+ "HEK 293", # 10
30
+ "A549",
31
+ "RT4",
32
+ "HeLa",
33
+ "MCF7",
34
+ "PC-3", # 15
35
+ "hTERT-RPE1",
36
+ "SK-MEL-30",
37
+ "EFO-21",
38
+ "AF22",
39
+ "HEL", # 20
40
+ "Hep G2",
41
+ "HUVEC TERT2",
42
+ "THP-1",
43
+ "CACO-2",
44
+ "JURKAT", # 25
45
+ "RPTEC TERT1",
46
+ "SuSa",
47
+ "REH",
48
+ "HDLM-2",
49
+ "K-562", # 30
50
+ "hTCEpi",
51
+ "NB-4",
52
+ "HAP1",
53
+ "OE19",
54
+ "SiHa", # 35
55
+ ]
56
+
57
+ PROTEIN_LOCALIZATION = [ # matches https://www.kaggle.com/c/human-protein-atlas-image-classification/data
58
+ "nucleoplasm",
59
+ "nuclear membrane",
60
+ "nucleoli",
61
+ "nucleoli fibrillar center",
62
+ "nuclear speckles", # 5
63
+ "nuclear bodies",
64
+ "endoplasmic reticulum",
65
+ "golgi apparatus",
66
+ "peroxisomes",
67
+ "endosomes", # 10
68
+ "lysosomes",
69
+ "intermediate filaments",
70
+ "actin filaments",
71
+ "focal adhesion sites",
72
+ "microtubules", # 15
73
+ "microtubule ends",
74
+ "cytokinetic bridge",
75
+ "mitotic spindle",
76
+ "microtubule organizing center",
77
+ "centrosome", # 20
78
+ "lipid droplets",
79
+ "plasma membrane",
80
+ "cell junctions",
81
+ "mitochondria",
82
+ "aggresome", # 25
83
+ "cytosol",
84
+ "cytoplasmic bodies",
85
+ "rods & rings",
86
+ ]
87
+
88
+
89
+ class _Split(Enum):
90
+ TRAIN = "train"
91
+ VAL = "val"
92
+ SSL = "ssl"
93
+
94
+
95
+ def get_csv_fpath(split):
96
+ """
97
+ Path to data relative to root
98
+ """
99
+ if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN or split == "TRAIN":
100
+ return "whole_images_512_train.csv"
101
+ elif split == _Split.VAL.value.upper() or split == _Split.VAL or split == "VAL":
102
+ return "whole_images_512_test.csv"
103
+
104
+
105
+ class _WildCard(Enum):
106
+ NONE = "none"
107
+ SEPARATECHANNELS = "separate_channels" # each channel from each image is treated as an independent sample, overrides chosen channel configuration
108
+
109
+
110
+ class _Mode(Enum):
111
+ """
112
+ Targets:
113
+ - ALL: tuple, (one hot encoding of multilabel protein localization, categorical encoding of cell type)
114
+ - PROTEIN_LOCALIZATION: one hot encoding of multilabel protein localization
115
+ - CELL_TYPE: categorical encoding of cell type
116
+ """
117
+
118
+ ALL = "all"
119
+ PROTEIN_LOCALIZATION = "protein_localization"
120
+ CELL_TYPE = "cell_type"
121
+
122
+ @property
123
+ def nb_labels(self):
124
+ if self == _Mode.CELL_TYPE:
125
+ return len(CELL_TYPE)
126
+ elif self == _Mode.PROTEIN_LOCALIZATION:
127
+ return len(PROTEIN_LOCALIZATION)
128
+ else:
129
+ return None
130
+
131
+
132
+ def _list_images_from_csv(img_path, csv_path):
133
+ L = []
134
+ with open(csv_path) as filename:
135
+ reader = csv.DictReader(filename)
136
+ for row in reader:
137
+ L.append(os.path.join(img_path, row["ID"] + ".png"))
138
+ return L
139
+
140
+
141
+ def _load_file_names_and_labels_ssl(
142
+ root: str,
143
+ ) -> Tuple[List[str], List[Any]]:
144
+
145
+ curr_img_path = os.path.join(root, "normalized_data")
146
+ csv_train_ssl = os.path.join(root, "whole_images_names.csv")
147
+ image_paths = _list_images_from_csv(curr_img_path, csv_train_ssl)
148
+ labels = [i for i in range(len(image_paths))]
149
+ return image_paths, labels
150
+
151
+
152
+ def _load_file_names_and_labels(
153
+ root: str,
154
+ split: _Split,
155
+ mode: _Mode,
156
+ ) -> Tuple[List[str], List[Any], np.ndarray]:
157
+
158
+ data_path = os.path.join(root, "512_whole_images")
159
+ csv_fpath = os.path.join(root, get_csv_fpath(split))
160
+
161
+ image_paths = []
162
+ labels = []
163
+
164
+ with open(csv_fpath) as filename:
165
+ reader = csv.DictReader(filename)
166
+ for row in reader:
167
+
168
+ add_sample = True
169
+ if mode != _Mode.PROTEIN_LOCALIZATION.value.upper():
170
+ # categorical
171
+ if row["cell_type"] in CELL_TYPE:
172
+ cell_type = CELL_TYPE.index(row["cell_type"])
173
+ else:
174
+ cell_type = np.nan
175
+
176
+ if mode != _Mode.CELL_TYPE.value.upper():
177
+ # one hot encoding
178
+ prot_loc = np.zeros(len(PROTEIN_LOCALIZATION), dtype=np.int_)
179
+ for k in range(len(PROTEIN_LOCALIZATION)):
180
+ if row[PROTEIN_LOCALIZATION[k]] == "True":
181
+ prot_loc[k] = 1
182
+ if prot_loc.max() < 0.5:
183
+ add_sample = False
184
+
185
+ if add_sample:
186
+ if mode == _Mode.PROTEIN_LOCALIZATION.value.upper():
187
+ labels.append(prot_loc)
188
+ elif mode == _Mode.CELL_TYPE.value.upper():
189
+ labels.append(cell_type)
190
+ else:
191
+ labels.append({"prot_loc": prot_loc, "cell_type": cell_type})
192
+
193
+ candidate_path = os.path.join(data_path, row["file"].split("/")[-1])
194
+ if os.path.exists(candidate_path):
195
+ image_paths.append(candidate_path)
196
+ else:
197
+ candidate_path = os.path.join(
198
+ data_path, row["file"].split("/")[-1].split(".")[0] + ".tiff"
199
+ ) # _blue.png") # some images on the normalized_data folder have a _blue suffix on their names
200
+ if os.path.exists(candidate_path):
201
+ image_paths.append(candidate_path)
202
+ else:
203
+ raise FileNotFoundError(f"File {candidate_path} not found.")
204
+
205
+ return image_paths, labels
206
+
207
+
208
+ class HPAFoV(ExtendedVisionDataset):
209
+ Split = Union[_Split]
210
+ Mode = Union[_Mode]
211
+ WildCard = Union[_WildCard]
212
+
213
+ def __init__(
214
+ self,
215
+ *,
216
+ split: "HPAFoV.Split" = _Split.TRAIN,
217
+ mode: "HPAFoV.Mode" = _Mode.ALL,
218
+ wildcard: "HPAFoV.WildCard" = _WildCard.NONE,
219
+ root: str,
220
+ transforms: Optional[Callable] = None,
221
+ transform: Optional[Callable] = None,
222
+ target_transform: Optional[Callable] = None,
223
+ image_decoder_type: DecoderType = DecoderType.ChannelSelectDecoder,
224
+ image_decoder_params: Dict[str, Any] = {},
225
+ **kwargs: Any,
226
+ ) -> None:
227
+ super().__init__(
228
+ root,
229
+ transforms,
230
+ transform,
231
+ target_transform,
232
+ image_decoder_type=image_decoder_type,
233
+ image_decoder_params={
234
+ "select_channel": True
235
+ if wildcard == _WildCard.SEPARATECHANNELS or wildcard == "SEPARATE_CHANNELS"
236
+ else False
237
+ },
238
+ **kwargs,
239
+ )
240
+ self.mode = mode
241
+ self.split = split
242
+ self.root = root
243
+ self.wildcard = wildcard
244
+ self.channel_adaptive = True
245
+ if split == _Split.SSL.value.upper() or split == _Split.SSL or split == "SSL":
246
+ self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
247
+ else:
248
+ self._image_paths, self._labels = _load_file_names_and_labels(root, self.split, self.mode)
249
+
250
+ self._channels = np.repeat(np.array([[0, 1, 2, 3]]), len(self._image_paths), axis=0).tolist()
251
+
252
+ if self.wildcard == _WildCard.SEPARATECHANNELS.value.upper():
253
+ image_paths, labels, channels = self._image_paths, self._labels, self._channels
254
+ channels = np.array(channels)
255
+ # separate and stack the columns of the channels array
256
+ C = channels.shape[1]
257
+ channels = np.concatenate([channels[:, i] for i in range(C)])
258
+ self._channels = np.expand_dims(channels, 1).tolist()
259
+ self._image_paths = image_paths * C # keep paths/labels aligned with the expanded per-channel list
260
+ self._labels = labels * C
261
+
262
+ def get_image_relpath(self, index: int) -> str:
263
+ return self._image_paths[index]
264
+
265
+ def get_image_data(self, index: int) -> bytes:
266
+ image_relpath = self.get_image_relpath(index)
267
+ image_full_path = os.path.join(self.root, image_relpath)
268
+ with open(image_full_path, mode="rb") as f:
269
+ image_data = f.read()
270
+ if self.channel_adaptive:
271
+ channels = self._channels[index]
272
+ return image_data + bytes(channels) + (len(channels)).to_bytes(1, byteorder="big")
273
+ else:
274
+ return image_data
275
+
276
+ def get_target(self, index: int) -> Any:
277
+ return self._labels[index]
278
+
279
+ def get_targets(self) -> np.ndarray:
280
+ return np.array(self._labels)
281
+
282
+ def __len__(self) -> int:
283
+ return len(self._image_paths)
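A round-trip sketch of the byte "trailer" that HPAFoV.get_image_data appends above: the channel indices, then one big-endian byte holding the channel count. The unpack_channels helper is hypothetical, written only to illustrate the layout; it is not the decoder the repo uses (ChannelSelectDecoder in decoders.py reads just the final byte):

from typing import List, Tuple

def pack_channels(image_data: bytes, channels: List[int]) -> bytes:
    # Mirrors HPAFoV.get_image_data: raw image bytes + channel indices + count byte.
    return image_data + bytes(channels) + len(channels).to_bytes(1, byteorder="big")

def unpack_channels(packed: bytes) -> Tuple[bytes, List[int]]:
    # Hypothetical inverse: the last byte is the count, the bytes before it
    # are the channel indices.
    count = packed[-1]
    return packed[: -1 - count], list(packed[-1 - count : -1])

raw = b"\x89PNG..."  # stand-in for real image bytes
packed = pack_channels(raw, [0, 1, 2, 3])
assert unpack_channels(packed) == (raw, [0, 1, 2, 3])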
dinov2/data/datasets/cell_dino/hpaone.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from enum import Enum
8
+ import logging
9
+ import os
10
+ from typing import Any, Callable, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+
14
+ from ..extended import ExtendedVisionDataset
15
+ from ..decoders import DecoderType
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+ PROTEIN_LOCALIZATION = [
20
+ "actin filaments,focal adhesion sites",
21
+ "aggresome",
22
+ "centrosome,centriolar satellite",
23
+ "cytosol",
24
+ "endoplasmic reticulum",
25
+ "golgi apparatus",
26
+ "intermediate filaments",
27
+ "microtubules",
28
+ "mitochondria",
29
+ "mitotic spindle",
30
+ "no staining",
31
+ "nuclear bodies",
32
+ "nuclear membrane",
33
+ "nuclear speckles",
34
+ "nucleoli",
35
+ "nucleoli fibrillar center",
36
+ "nucleoplasm",
37
+ "plasma membrane,cell junctions",
38
+ "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies",
39
+ ] # 19
40
+
41
+
42
+ CELL_TYPE = [
43
+ "A-431", # 0
44
+ "A549",
45
+ "AF22",
46
+ "ASC TERT1",
47
+ "BJ",
48
+ "CACO-2",
49
+ "EFO-21",
50
+ "HAP1",
51
+ "HDLM-2",
52
+ "HEK 293", # 9
53
+ "HEL",
54
+ "HUVEC TERT2",
55
+ "HaCaT",
56
+ "HeLa",
57
+ "Hep G2",
58
+ "JURKAT",
59
+ "K-562",
60
+ "MCF7",
61
+ "PC-3",
62
+ "REH",
63
+ "RH-30", # 20
64
+ "RPTEC TERT1",
65
+ "RT4",
66
+ "SH-SY5Y",
67
+ "SK-MEL-30",
68
+ "SiHa",
69
+ "U-2 OS",
70
+ "U-251 MG",
71
+ "hTCEpi", # 28
72
+ ] # 29 cell types
73
+
74
+
75
+ class _Split(Enum):
76
+ VAL = "val"
77
+ TRAIN = "train"
78
+ ALL = "all" # images without labels, for encoder training
79
+
80
+
81
+ class _Mode(Enum):
82
+ PROTEIN_LOCALIZATION = "protein_localization"
83
+ CELL_TYPE = "cell_type"
84
+
85
+ @property
86
+ def num_labels(self):
87
+ if self == _Mode.CELL_TYPE:
88
+ return len(CELL_TYPE)
89
+ return len(PROTEIN_LOCALIZATION)
90
+
91
+
92
+ def _simple_parse_csv(img_rootdir, csv_filepath: str):
93
+ samples = []
94
+ with open(csv_filepath) as filename:
95
+ reader = csv.DictReader(filename)
96
+ samples = [(os.path.join(img_rootdir, row["img_path"]), 0) for row in reader]
97
+ return samples
98
+
99
+
100
+ def _parse_csv(img_rootdir, csv_labels_path: str):
101
+ nb_protein_location = len(PROTEIN_LOCALIZATION)
102
+ nb_cell_type = len(CELL_TYPE)
103
+ samples = []
104
+ with open(csv_labels_path) as filename:
105
+ reader = csv.DictReader(filename)
106
+ for row in reader:
107
+ protein_location = np.zeros(nb_protein_location, dtype=np.int_)
108
+ for k in range(nb_protein_location):
109
+ if row[PROTEIN_LOCALIZATION[k]] == "True":
110
+ protein_location[k] = 1
111
+
112
+ cell_type = 0
113
+ for k in range(nb_cell_type):
114
+ if row[CELL_TYPE[k]] == "True":
115
+ cell_type = k
116
+
117
+ samples.append(
118
+ (
119
+ img_rootdir + "/" + row["file"].rsplit("/", 1)[1],
120
+ protein_location,
121
+ cell_type,
122
+ )
123
+ )
124
+ return samples
125
+
126
+
127
+ def _load_file_names_and_labels_ssl(
128
+ root: str,
129
+ ) -> Tuple[List[str], List[Any]]:
130
+ curr_dir_train = os.path.join(root, "varied_size_masked_single_cells_HPA")
131
+ csv_all_path = os.path.join(root, "varied_size_masked_single_cells_pretrain_20240507.csv")
132
+ samples = _simple_parse_csv(curr_dir_train, csv_all_path)
133
+ image_paths, fake_labels = zip(*samples)
134
+ lab = list(fake_labels)
135
+ return list(image_paths), lab
136
+
137
+
138
+ def _load_file_names_and_labels_train_or_test(
139
+ root: str,
140
+ split: _Split,
141
+ mode: _Mode,
142
+ ) -> Tuple[List[str], List[Any]]:
143
+
144
+ if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN:
145
+ csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_pretrain_20240507.csv")
146
+ elif split == _Split.VAL.value.upper() or split == _Split.VAL:
147
+ csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_evaluation_20240507.csv")
148
+ else:
149
+ print("wrong split name")
150
+ curr_dir_val = os.path.join(root, "fixed_size_masked_single_cells_HPA")
151
+
152
+ samples = _parse_csv(curr_dir_val, csv_labels_path)
153
+ image_paths, protein_location, cell_type = zip(*samples)
154
+ if mode == _Mode.PROTEIN_LOCALIZATION.value.upper():
155
+ lab = protein_location
156
+ elif mode == _Mode.CELL_TYPE.value.upper():
157
+ lab = cell_type
158
+ else:
159
+ lab = protein_location, cell_type
160
+ image_paths = list(image_paths)
161
+ return image_paths, lab
162
+
163
+
164
+ class HPAone(ExtendedVisionDataset):
165
+ Split = Union[_Split]
166
+ Mode = Union[_Mode]
167
+
168
+ def __init__(
169
+ self,
170
+ *,
171
+ split: "HPAone.Split" = _Split.ALL,
172
+ mode: "HPAone.Mode" = None,
173
+ root: str,
174
+ transforms: Optional[Callable] = None,
175
+ transform: Optional[Callable] = None,
176
+ target_transform: Optional[Callable] = None,
177
+ image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
178
+ **kwargs: Any,
179
+ ) -> None:
180
+ super().__init__(
181
+ root,
182
+ transforms,
183
+ transform,
184
+ target_transform,
185
+ image_decoder_type=image_decoder_type,
186
+ **kwargs,
187
+ )
188
+ self.mode = mode
189
+ self.split = split
190
+ self.root = root
191
+
192
+ if (
193
+ split in {_Split.TRAIN.value.upper(), _Split.VAL.value.upper()}
194
+ or split == _Split.TRAIN
195
+ or split == _Split.VAL
196
+ ):
197
+ (
198
+ self._image_paths,
199
+ self._labels,
200
+ ) = _load_file_names_and_labels_train_or_test(root, split, mode)
201
+ elif split == _Split.ALL.value.upper() or split == _Split.ALL:
202
+ self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
203
+ else:
204
+ logger.info(f"unknown split: {split}, {_Split.ALL.value.upper()}")
205
+
206
+ def get_image_relpath(self, index: int) -> str:
207
+ return self._image_paths[index]
208
+
209
+ def get_image_data(self, index: int) -> bytes:
210
+ image_relpath = self.get_image_relpath(index)
211
+ image_full_path = os.path.join(self.root, image_relpath)
212
+ with open(image_full_path, mode="rb") as f:
213
+ image_data = f.read()
214
+ return image_data
215
+
216
+ def get_target(self, index: int) -> Any:
217
+ return self._labels[index]
218
+
219
+ def get_targets(self) -> np.ndarray:
220
+ return np.array(self._labels)
221
+
222
+ def __len__(self) -> int:
223
+ return len(self._image_paths)
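A condensed illustration of how _parse_csv above turns the boolean CSV columns into a multi-hot protein-localization vector and a single cell-type index, using the PROTEIN_LOCALIZATION and CELL_TYPE lists defined in this file. The row values are hypothetical, and the one-liners are an equivalent of the loops in _parse_csv, not code from the repo:

import numpy as np

row = {name: "False" for name in PROTEIN_LOCALIZATION + CELL_TYPE}
row["cytosol"] = "True"
row["mitochondria"] = "True"
row["HeLa"] = "True"

prot_loc = np.array([row[name] == "True" for name in PROTEIN_LOCALIZATION], dtype=np.int_)
cell_type = next(k for k, name in enumerate(CELL_TYPE) if row[name] == "True")
# prot_loc is 1 at the "cytosol" and "mitochondria" positions; cell_type == 13 (HeLa)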
dinov2/data/datasets/decoders.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from io import BytesIO
7
+ from typing import Any, Type
8
+
9
+ from PIL import Image
10
+ import numpy as np
11
+ import torch
12
+ from enum import Enum
13
+
14
+ try:
15
+ import tifffile
16
+ except ImportError:
17
+ print("Could not import `tifffile`, TIFFImageDataDecoder will be disabled")
18
+
19
+
20
+ class Decoder:
21
+ def decode(self) -> Any:
22
+ raise NotImplementedError
23
+
24
+
25
+ class DecoderType(Enum):
26
+ ImageDataDecoder = "ImageDataDecoder"
27
+ XChannelsDecoder = "XChannelsDecoder"
28
+ XChannelsTIFFDecoder = "XChannelsTIFFDecoder"
29
+ ChannelSelectDecoder = "ChannelSelectDecoder"
30
+
31
+ def get_class(self) -> Type[Decoder]: # noqa: C901
32
+ if self == DecoderType.ImageDataDecoder:
33
+ return ImageDataDecoder
34
+ if self == DecoderType.XChannelsDecoder:
35
+ return XChannelsDecoder
36
+ if self == DecoderType.XChannelsTIFFDecoder:
37
+ return XChannelsTIFFDecoder
38
+ if self == DecoderType.ChannelSelectDecoder:
39
+ return ChannelSelectDecoder
40
+
41
+
42
+ class ImageDataDecoder(Decoder):
43
+ def __init__(self, image_data: bytes) -> None:
44
+ self._image_data = image_data
45
+
46
+ def decode(self) -> Image:
47
+ f = BytesIO(self._image_data)
48
+ return Image.open(f).convert(mode="RGB")
49
+
50
+
51
+ class TargetDecoder(Decoder):
52
+ def __init__(self, target: Any):
53
+ self._target = target
54
+
55
+ def decode(self) -> Any:
56
+ return self._target
57
+
58
+
59
+ class XChannelsDecoder(Decoder):
60
+ def __init__(self, image_data: bytes) -> None:
61
+ self._image_data = image_data
62
+
63
+ def decode(self):
64
+ im = np.asarray(Image.open(BytesIO(self._image_data)))
65
+ if len(im.shape) == 2:
66
+ im = np.reshape(im, (im.shape[0], im.shape[0], -1), order="F") # split square channel planes concatenated along the width
67
+ return torch.Tensor(im).permute(2, 0, 1)
68
+
69
+
70
+ class XChannelsTIFFDecoder(Decoder):
71
+ def __init__(self, image_data: bytes, num_channels: int = 3) -> None:
72
+ self._image_data = image_data
73
+ self._num_channels = num_channels
74
+
75
+ def decode(self):
76
+ numpy_array = tifffile.imread(BytesIO(self._image_data))
77
+ numpy_array = np.reshape(numpy_array, (numpy_array.shape[0], -1, self._num_channels), order="F")
78
+ return torch.Tensor(numpy_array).permute(2, 0, 1)
79
+
80
+
81
+ class ChannelSelectDecoder(Decoder):
82
+ def __init__(self, image_data: bytes, select_channel: bool = False) -> None:
83
+ self.select_channel = select_channel
84
+ if select_channel:
85
+ self._image_data = image_data[:-1]
86
+ self._channel = image_data[-1]
87
+ else:
88
+ self._image_data = image_data
89
+
90
+ def decode(self):
91
+ im = np.asarray(Image.open(BytesIO(self._image_data)))
92
+ if self.select_channel:
93
+ return torch.Tensor(im).permute(2, 0, 1)[[self._channel]]
94
+ return torch.Tensor(im).permute(2, 0, 1)
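A minimal usage sketch of the dispatch above: DecoderType maps a config value to a Decoder class, which is then instantiated per sample with the raw bytes. The file path is hypothetical:

decoder_cls = DecoderType.XChannelsDecoder.get_class()
with open("sample.png", "rb") as f:  # hypothetical image file
    tensor = decoder_cls(f.read()).decode()  # CxHxW float tensor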
dinov2/data/datasets/extended.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Tuple
7
+
8
+ from torchvision.datasets import VisionDataset
9
+
10
+ from .decoders import DecoderType, TargetDecoder
11
+
12
+
13
+ class ExtendedVisionDataset(VisionDataset):
14
+ def __init__(self, *args, **kwargs) -> None:
15
+ image_decoder_type = kwargs.pop("image_decoder_type", DecoderType.ImageDataDecoder)
16
+ self._decoder_params = {}
17
+ self._image_decoder_class = image_decoder_type.get_class()
18
+ if "image_decoder_params" in kwargs:
19
+ self._decoder_params = kwargs.pop("image_decoder_params")
20
+
21
+ super().__init__(*args, **kwargs) # type: ignore
22
+
23
+ def get_image_data(self, index: int) -> bytes:
24
+ raise NotImplementedError
25
+
26
+ def get_target(self, index: int) -> Any:
27
+ raise NotImplementedError
28
+
29
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
30
+ try:
31
+ image_data = self.get_image_data(index)
32
+ image = self._image_decoder_class(image_data, **self._decoder_params).decode()
33
+ except Exception as e:
34
+ raise RuntimeError(f"can not read image for sample {index}") from e
35
+ target = self.get_target(index)
36
+ target = TargetDecoder(target).decode()
37
+
38
+ if self.transforms is not None:
39
+ image, target = self.transforms(image, target)
40
+
41
+ return image, target
42
+
43
+ def __len__(self) -> int:
44
+ raise NotImplementedError
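A toy subclass sketch, assuming in-memory path and label lists: ExtendedVisionDataset only requires get_image_data, get_target and __len__; byte decoding and transforms are handled by the base __getitem__:

class ToyDataset(ExtendedVisionDataset):
    def __init__(self, paths, labels, **kwargs):
        super().__init__("/", **kwargs)  # root is unused in this toy example
        self._paths, self._labels = paths, labels

    def get_image_data(self, index):
        # Return raw bytes; the configured decoder turns them into an image.
        with open(self._paths[index], "rb") as f:
            return f.read()

    def get_target(self, index):
        return self._labels[index]

    def __len__(self):
        return len(self._paths)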
dinov2/data/datasets/image_net.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from enum import Enum
8
+ import logging
9
+ import os
10
+ from typing import Callable, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+
14
+ from .extended import ExtendedVisionDataset
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+ _Target = int
19
+
20
+
21
+ class _Split(Enum):
22
+ TRAIN = "train"
23
+ VAL = "val"
24
+ TEST = "test" # NOTE: torchvision does not support the test split
25
+
26
+ @property
27
+ def length(self) -> int:
28
+ split_lengths = {
29
+ _Split.TRAIN: 1_281_167,
30
+ _Split.VAL: 50_000,
31
+ _Split.TEST: 100_000,
32
+ }
33
+ return split_lengths[self]
34
+
35
+ def get_dirname(self, class_id: Optional[str] = None) -> str:
36
+ return self.value if class_id is None else os.path.join(self.value, class_id)
37
+
38
+ def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
39
+ dirname = self.get_dirname(class_id)
40
+ if self == _Split.TRAIN:
41
+ basename = f"{class_id}_{actual_index}"
42
+ else: # self in (_Split.VAL, _Split.TEST):
43
+ basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
44
+ return os.path.join(dirname, basename + ".JPEG")
45
+
46
+ def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
47
+ assert self != _Split.TEST
48
+ dirname, filename = os.path.split(image_relpath)
49
+ class_id = os.path.split(dirname)[-1]
50
+ basename, _ = os.path.splitext(filename)
51
+ actual_index = int(basename.split("_")[-1])
52
+ return class_id, actual_index
53
+
54
+
55
+ class ImageNet(ExtendedVisionDataset):
56
+ Target = Union[_Target]
57
+ Split = Union[_Split]
58
+
59
+ def __init__(
60
+ self,
61
+ *,
62
+ split: "ImageNet.Split",
63
+ root: str,
64
+ extra: str,
65
+ transforms: Optional[Callable] = None,
66
+ transform: Optional[Callable] = None,
67
+ target_transform: Optional[Callable] = None,
68
+ ) -> None:
69
+ super().__init__(root, transforms, transform, target_transform)
70
+ self._extra_root = extra
71
+ self._split = split
72
+
73
+ self._entries = None
74
+ self._class_ids = None
75
+ self._class_names = None
76
+
77
+ @property
78
+ def split(self) -> "ImageNet.Split":
79
+ return self._split
80
+
81
+ def _get_extra_full_path(self, extra_path: str) -> str:
82
+ return os.path.join(self._extra_root, extra_path)
83
+
84
+ def _load_extra(self, extra_path: str) -> np.ndarray:
85
+ extra_full_path = self._get_extra_full_path(extra_path)
86
+ return np.load(extra_full_path, mmap_mode="r")
87
+
88
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
89
+ extra_full_path = self._get_extra_full_path(extra_path)
90
+ os.makedirs(self._extra_root, exist_ok=True)
91
+ np.save(extra_full_path, extra_array)
92
+
93
+ @property
94
+ def _entries_path(self) -> str:
95
+ return f"entries-{self._split.value.upper()}.npy"
96
+
97
+ @property
98
+ def _class_ids_path(self) -> str:
99
+ return f"class-ids-{self._split.value.upper()}.npy"
100
+
101
+ @property
102
+ def _class_names_path(self) -> str:
103
+ return f"class-names-{self._split.value.upper()}.npy"
104
+
105
+ def _get_entries(self) -> np.ndarray:
106
+ if self._entries is None:
107
+ self._entries = self._load_extra(self._entries_path)
108
+ assert self._entries is not None
109
+ return self._entries
110
+
111
+ def _get_class_ids(self) -> np.ndarray:
112
+ if self._split == _Split.TEST:
113
+ assert False, "Class IDs are not available in TEST split"
114
+ if self._class_ids is None:
115
+ self._class_ids = self._load_extra(self._class_ids_path)
116
+ assert self._class_ids is not None
117
+ return self._class_ids
118
+
119
+ def _get_class_names(self) -> np.ndarray:
120
+ if self._split == _Split.TEST:
121
+ assert False, "Class names are not available in TEST split"
122
+ if self._class_names is None:
123
+ self._class_names = self._load_extra(self._class_names_path)
124
+ assert self._class_names is not None
125
+ return self._class_names
126
+
127
+ def find_class_id(self, class_index: int) -> str:
128
+ class_ids = self._get_class_ids()
129
+ return str(class_ids[class_index])
130
+
131
+ def find_class_name(self, class_index: int) -> str:
132
+ class_names = self._get_class_names()
133
+ return str(class_names[class_index])
134
+
135
+ def get_image_data(self, index: int) -> bytes:
136
+ entries = self._get_entries()
137
+ actual_index = entries[index]["actual_index"]
138
+
139
+ class_id = self.get_class_id(index)
140
+
141
+ image_relpath = self.split.get_image_relpath(actual_index, class_id)
142
+ image_full_path = os.path.join(self.root, image_relpath)
143
+ with open(image_full_path, mode="rb") as f:
144
+ image_data = f.read()
145
+ return image_data
146
+
147
+ def get_target(self, index: int) -> Optional[Target]:
148
+ entries = self._get_entries()
149
+ class_index = entries[index]["class_index"]
150
+ return None if self.split == _Split.TEST else int(class_index)
151
+
152
+ def get_targets(self) -> Optional[np.ndarray]:
153
+ entries = self._get_entries()
154
+ return None if self.split == _Split.TEST else entries["class_index"]
155
+
156
+ def get_class_id(self, index: int) -> Optional[str]:
157
+ entries = self._get_entries()
158
+ class_id = entries[index]["class_id"]
159
+ return None if self.split == _Split.TEST else str(class_id)
160
+
161
+ def get_class_name(self, index: int) -> Optional[str]:
162
+ entries = self._get_entries()
163
+ class_name = entries[index]["class_name"]
164
+ return None if self.split == _Split.TEST else str(class_name)
165
+
166
+ def __len__(self) -> int:
167
+ entries = self._get_entries()
168
+ assert len(entries) == self.split.length
169
+ return len(entries)
170
+
171
+ def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
172
+ labels_full_path = os.path.join(self.root, labels_path)
173
+ labels = []
174
+
175
+ try:
176
+ with open(labels_full_path, "r") as f:
177
+ reader = csv.reader(f)
178
+ for row in reader:
179
+ class_id, class_name = row
180
+ labels.append((class_id, class_name))
181
+ except OSError as e:
182
+ raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
183
+
184
+ return labels
185
+
186
+ def _dump_entries(self) -> None:
187
+ split = self.split
188
+ if split == ImageNet.Split.TEST:
189
+ dataset = None
190
+ sample_count = split.length
191
+ max_class_id_length, max_class_name_length = 0, 0
192
+ else:
193
+ labels_path = "labels.txt"
194
+ logger.info(f'loading labels from "{labels_path}"')
195
+ labels = self._load_labels(labels_path)
196
+
197
+ # NOTE: Using torchvision ImageFolder for consistency
198
+ from torchvision.datasets import ImageFolder
199
+
200
+ dataset_root = os.path.join(self.root, split.get_dirname())
201
+ dataset = ImageFolder(dataset_root)
202
+ sample_count = len(dataset)
203
+ max_class_id_length, max_class_name_length = -1, -1
204
+ for sample in dataset.samples:
205
+ _, class_index = sample
206
+ class_id, class_name = labels[class_index]
207
+ max_class_id_length = max(len(class_id), max_class_id_length)
208
+ max_class_name_length = max(len(class_name), max_class_name_length)
209
+
210
+ dtype = np.dtype(
211
+ [
212
+ ("actual_index", "<u4"),
213
+ ("class_index", "<u4"),
214
+ ("class_id", f"U{max_class_id_length}"),
215
+ ("class_name", f"U{max_class_name_length}"),
216
+ ]
217
+ )
218
+ entries_array = np.empty(sample_count, dtype=dtype)
219
+
220
+ if split == ImageNet.Split.TEST:
221
+ old_percent = -1
222
+ for index in range(sample_count):
223
+ percent = 100 * (index + 1) // sample_count
224
+ if percent > old_percent:
225
+ logger.info(f"creating entries: {percent}%")
226
+ old_percent = percent
227
+
228
+ actual_index = index + 1
229
+ class_index = np.uint32(-1)
230
+ class_id, class_name = "", ""
231
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
232
+ else:
233
+ class_names = {class_id: class_name for class_id, class_name in labels}
234
+
235
+ assert dataset
236
+ old_percent = -1
237
+ for index in range(sample_count):
238
+ percent = 100 * (index + 1) // sample_count
239
+ if percent > old_percent:
240
+ logger.info(f"creating entries: {percent}%")
241
+ old_percent = percent
242
+
243
+ image_full_path, class_index = dataset.samples[index]
244
+ image_relpath = os.path.relpath(image_full_path, self.root)
245
+ class_id, actual_index = split.parse_image_relpath(image_relpath)
246
+ class_name = class_names[class_id]
247
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
248
+
249
+ logger.info(f'saving entries to "{self._entries_path}"')
250
+ self._save_extra(entries_array, self._entries_path)
251
+
252
+ def _dump_class_ids_and_names(self) -> None:
253
+ split = self.split
254
+ if split == ImageNet.Split.TEST:
255
+ return
256
+
257
+ entries_array = self._load_extra(self._entries_path)
258
+
259
+ max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
260
+ for entry in entries_array:
261
+ class_index, class_id, class_name = (
262
+ entry["class_index"],
263
+ entry["class_id"],
264
+ entry["class_name"],
265
+ )
266
+ max_class_index = max(int(class_index), max_class_index)
267
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
268
+ max_class_name_length = max(len(str(class_name)), max_class_name_length)
269
+
270
+ class_count = max_class_index + 1
271
+ class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
272
+ class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
273
+ for entry in entries_array:
274
+ class_index, class_id, class_name = (
275
+ entry["class_index"],
276
+ entry["class_id"],
277
+ entry["class_name"],
278
+ )
279
+ class_ids_array[class_index] = class_id
280
+ class_names_array[class_index] = class_name
281
+
282
+ logger.info(f'saving class IDs to "{self._class_ids_path}"')
283
+ self._save_extra(class_ids_array, self._class_ids_path)
284
+
285
+ logger.info(f'saving class names to "{self._class_names_path}"')
286
+ self._save_extra(class_names_array, self._class_names_path)
287
+
288
+ def dump_extra(self) -> None:
289
+ self._dump_entries()
290
+ self._dump_class_ids_and_names()
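A sketch of the two-step workflow this class assumes, with hypothetical paths: the standard ImageNet directory layout plus a labels.txt file under root, with the numpy "extra" metadata dumped once and then lazily mmap-ed on later runs:

dataset = ImageNet(
    split=ImageNet.Split.TRAIN,
    root="/data/imagenet",
    extra="/data/imagenet-extra",
)
dataset.dump_extra()        # writes entries-TRAIN.npy, class-ids-TRAIN.npy, ...
image, target = dataset[0]  # later runs only mmap the .npy files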
dinov2/data/datasets/image_net_22k.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from functools import lru_cache
9
+ from gzip import GzipFile
10
+ from io import BytesIO
11
+ from mmap import ACCESS_READ, mmap
12
+ import os
13
+ from typing import Any, Callable, List, Optional, Set, Tuple
14
+ import warnings
15
+
16
+ import numpy as np
17
+
18
+ from .extended import ExtendedVisionDataset
19
+
20
+
21
+ _Labels = int
22
+
23
+ _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
24
+
25
+
26
+ @dataclass
27
+ class _ClassEntry:
28
+ block_offset: int
29
+ maybe_filename: Optional[str] = None
30
+
31
+
32
+ @dataclass
33
+ class _Entry:
34
+ class_index: int # noqa: E701
35
+ start_offset: int
36
+ end_offset: int
37
+ filename: str
38
+
39
+
40
+ class _Split(Enum):
41
+ TRAIN = "train"
42
+ VAL = "val"
43
+
44
+ @property
45
+ def length(self) -> int:
46
+ return {
47
+ _Split.TRAIN: 11_797_647,
48
+ _Split.VAL: 561_050,
49
+ }[self]
50
+
51
+ def entries_path(self):
52
+ return f"imagenet21kp_{self.value}.txt"
53
+
54
+
55
+ def _get_tarball_path(class_id: str) -> str:
56
+ return f"{class_id}.tar"
57
+
58
+
59
+ def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
60
+ @lru_cache(maxsize=mmap_cache_size)
61
+ def _mmap_tarball(class_id: str) -> mmap:
62
+ tarball_path = _get_tarball_path(class_id)
63
+ tarball_full_path = os.path.join(tarballs_root, tarball_path)
64
+ with open(tarball_full_path) as f:
65
+ return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
66
+
67
+ return _mmap_tarball
68
+
69
+
70
+ class ImageNet22k(ExtendedVisionDataset):
71
+ _GZIPPED_INDICES: Set[int] = {
72
+ 841_545,
73
+ 1_304_131,
74
+ 2_437_921,
75
+ 2_672_079,
76
+ 2_795_676,
77
+ 2_969_786,
78
+ 6_902_965,
79
+ 6_903_550,
80
+ 6_903_628,
81
+ 7_432_557,
82
+ 7_432_589,
83
+ 7_813_809,
84
+ 8_329_633,
85
+ 10_296_990,
86
+ 10_417_652,
87
+ 10_492_265,
88
+ 10_598_078,
89
+ 10_782_398,
90
+ 10_902_612,
91
+ 11_203_736,
92
+ 11_342_890,
93
+ 11_397_596,
94
+ 11_589_762,
95
+ 11_705_103,
96
+ 12_936_875,
97
+ 13_289_782,
98
+ }
99
+ Labels = _Labels
100
+
101
+ def __init__(
102
+ self,
103
+ *,
104
+ root: str,
105
+ extra: str,
106
+ transforms: Optional[Callable] = None,
107
+ transform: Optional[Callable] = None,
108
+ target_transform: Optional[Callable] = None,
109
+ mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
110
+ ) -> None:
111
+ super().__init__(root, transforms, transform, target_transform)
112
+ self._extra_root = extra
113
+
114
+ entries_path = self._get_entries_path(root)
115
+ self._entries = self._load_extra(entries_path)
116
+
117
+ class_ids_path = self._get_class_ids_path(root)
118
+ self._class_ids = self._load_extra(class_ids_path)
119
+
120
+ self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
121
+ self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
122
+
123
+ def _get_entries_path(self, root: Optional[str] = None) -> str:
124
+ return "entries.npy"
125
+
126
+ def _get_class_ids_path(self, root: Optional[str] = None) -> str:
127
+ return "class-ids.npy"
128
+
129
+ def _find_class_ids(self, path: str) -> List[str]:
130
+ class_ids = []
131
+
132
+ with os.scandir(path) as entries:
133
+ for entry in entries:
134
+ root, ext = os.path.splitext(entry.name)
135
+ if ext != ".tar":
136
+ continue
137
+ class_ids.append(root)
138
+
139
+ return sorted(class_ids)
140
+
141
+ def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
142
+ root = self.get_root(root)
143
+ entries: List[_Entry] = []
144
+ class_ids = self._find_class_ids(root)
145
+
146
+ for class_index, class_id in enumerate(class_ids):
147
+ path = os.path.join(root, "blocks", f"{class_id}.log")
148
+ class_entries = []
149
+
150
+ try:
151
+ with open(path) as f:
152
+ for line in f:
153
+ line = line.rstrip()
154
+ block, filename = line.split(":")
155
+ block_offset = int(block[6:])
156
+ filename = filename[1:]
157
+
158
+ maybe_filename = None
159
+ if filename != "** Block of NULs **":
160
+ maybe_filename = filename
161
+ _, ext = os.path.splitext(filename)
162
+ # assert ext == ".JPEG"
163
+
164
+ class_entry = _ClassEntry(block_offset, maybe_filename)
165
+ class_entries.append(class_entry)
166
+ except OSError as e:
167
+ raise RuntimeError(f'can not read blocks file "{path}"') from e
168
+
169
+ assert class_entries[-1].maybe_filename is None
170
+
171
+ for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
172
+ assert class_entry1.block_offset <= class_entry2.block_offset
173
+ start_offset = 512 * class_entry1.block_offset
174
+ end_offset = 512 * class_entry2.block_offset
175
+ assert class_entry1.maybe_filename is not None
176
+ filename = class_entry1.maybe_filename
177
+ entry = _Entry(class_index, start_offset, end_offset, filename)
178
+ # Skip invalid image files (PIL throws UnidentifiedImageError)
179
+ if filename == "n06470073_47249.JPEG":
180
+ continue
181
+ entries.append(entry)
182
+
183
+ return entries, class_ids
184
+
185
+ def _load_extra(self, extra_path: str) -> np.ndarray:
186
+ extra_root = self._extra_root
187
+ extra_full_path = os.path.join(extra_root, extra_path)
188
+ return np.load(extra_full_path, mmap_mode="r")
189
+
190
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
191
+ extra_root = self._extra_root
192
+ extra_full_path = os.path.join(extra_root, extra_path)
193
+ os.makedirs(extra_root, exist_ok=True)
194
+ np.save(extra_full_path, extra_array)
195
+
196
+ @property
197
+ def _tarballs_root(self) -> str:
198
+ return self.root
199
+
200
+ def find_class_id(self, class_index: int) -> str:
201
+ return str(self._class_ids[class_index])
202
+
203
+ def get_image_data(self, index: int) -> bytes:
204
+ entry = self._entries[index]
205
+ class_id = entry["class_id"]
206
+ class_mmap = self._mmap_tarball(class_id)
207
+
208
+ start_offset, end_offset = entry["start_offset"], entry["end_offset"]
209
+ try:
210
+ mapped_data = class_mmap[start_offset:end_offset]
211
+ data = mapped_data[512:] # Skip entry header block
212
+
213
+ if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
214
+ assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
215
+ with GzipFile(fileobj=BytesIO(data)) as g:
216
+ data = g.read()
217
+ except Exception as e:
218
+ raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
219
+
220
+ return data
221
+
222
+ def get_target(self, index: int) -> Any:
223
+ return int(self._entries[index]["class_index"])
224
+
225
+ def get_targets(self) -> np.ndarray:
226
+ return self._entries["class_index"]
227
+
228
+ def get_class_id(self, index: int) -> str:
229
+ return str(self._entries[index]["class_id"])
230
+
231
+ def get_class_ids(self) -> np.ndarray:
232
+ return self._entries["class_id"]
233
+
234
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
235
+ with warnings.catch_warnings():
236
+ warnings.simplefilter("ignore")
237
+ return super().__getitem__(index)
238
+
239
+ def __len__(self) -> int:
240
+ return len(self._entries)
241
+
242
+ def _dump_entries(self, *args, **kwargs) -> None:
243
+ entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
244
+
245
+ max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
246
+ for entry in entries:
247
+ class_id = class_ids[entry.class_index]
248
+ max_class_index = max(entry.class_index, max_class_index)
249
+ max_class_id_length = max(len(class_id), max_class_id_length)
250
+ max_filename_length = max(len(entry.filename), max_filename_length)
251
+
252
+ dtype = np.dtype(
253
+ [
254
+ ("class_index", "<u4"),
255
+ ("class_id", f"U{max_class_id_length}"),
256
+ ("start_offset", "<u4"),
257
+ ("end_offset", "<u4"),
258
+ ("filename", f"U{max_filename_length}"),
259
+ ]
260
+ )
261
+ sample_count = len(entries)
262
+ entries_array = np.empty(sample_count, dtype=dtype)
263
+ for i, entry in enumerate(entries):
264
+ class_index = entry.class_index
265
+ class_id = class_ids[class_index]
266
+ start_offset = entry.start_offset
267
+ end_offset = entry.end_offset
268
+ filename = entry.filename
269
+ entries_array[i] = (
270
+ class_index,
271
+ class_id,
272
+ start_offset,
273
+ end_offset,
274
+ filename,
275
+ )
276
+
277
+ entries_path = self._get_entries_path(*args, **kwargs)
278
+ self._save_extra(entries_array, entries_path)
279
+
280
+ def _dump_class_ids(self, *args, **kwargs) -> None:
281
+ entries_path = self._get_entries_path(*args, **kwargs)
282
+ entries_array = self._load_extra(entries_path)
283
+
284
+ max_class_id_length, max_class_index = -1, -1
285
+ for entry in entries_array:
286
+ class_index, class_id = entry["class_index"], entry["class_id"]
287
+ max_class_index = max(int(class_index), max_class_index)
288
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
289
+
290
+ class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
291
+ for entry in entries_array:
292
+ class_index, class_id = entry["class_index"], entry["class_id"]
293
+ class_ids_array[class_index] = class_id
294
+ class_ids_path = self._get_class_ids_path(*args, **kwargs)
295
+ self._save_extra(class_ids_array, class_ids_path)
296
+
297
+ def _dump_extra(self, *args, **kwargs) -> None:
298
+ self._dump_entries(*args, *kwargs)
299
+ self._dump_class_ids(*args, *kwargs)
300
+
301
+ def dump_extra(self, root: Optional[str] = None) -> None:
302
+ return self._dump_extra(root)
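A minimal sketch of the tar-offset trick used by get_image_data above, with hypothetical tarball name and offsets: each entry records byte offsets into its class tarball, and the 512-byte tar header block is skipped before the image payload:

from mmap import ACCESS_READ, mmap

with open("n01440764.tar", "rb") as f:  # hypothetical class tarball
    m = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
start_offset, end_offset = 0, 512 + 12_345  # hypothetical entry offsets
data = m[start_offset:end_offset][512:]     # skip the tar header block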
dinov2/data/loaders.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from enum import Enum
8
+ from typing import Any, Callable, List, Optional, TypeVar
9
+
10
+ import torch
11
+ from torch.utils.data import Sampler
12
+
13
+ from .datasets import ImageNet, ImageNet22k, HPAone, HPAFoV, CHAMMI_CP, CHAMMI_HPA, CHAMMI_WTC
14
+ from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ class SamplerType(Enum):
21
+ DISTRIBUTED = 0
22
+ EPOCH = 1
23
+ INFINITE = 2
24
+ SHARDED_INFINITE = 3
25
+ SHARDED_INFINITE_NEW = 4
26
+
27
+
28
+ def _make_bool_str(b: bool) -> str:
29
+ return "yes" if b else "no"
30
+
31
+
32
+ def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
33
+ def transform(sample):
34
+ image, target = sample
35
+ if image_transform is not None:
36
+ image = image_transform(image)
37
+ if target_transform is not None:
38
+ target = target_transform(target)
39
+ return image, target
40
+
41
+ return transform
42
+
43
+
44
+ def _parse_dataset_str(dataset_str: str):
45
+ tokens = dataset_str.split(":")
46
+
47
+ name = tokens[0]
48
+ kwargs = {}
49
+
50
+ for token in tokens[1:]:
51
+ key, value = token.split("=")
52
+ assert key in ("root", "extra", "split", "mode", "wildcard")
53
+ kwargs[key] = value
54
+
55
+ if name == "ImageNet":
56
+ class_ = ImageNet
57
+ if "split" in kwargs:
58
+ kwargs["split"] = ImageNet.Split[kwargs["split"]]
59
+ elif name == "ImageNet22k":
60
+ class_ = ImageNet22k
61
+ elif name == "HPAone":
62
+ class_ = HPAone
63
+ elif name == "HPAFoV":
64
+ class_ = HPAFoV
65
+ elif name == "CHAMMI_CP":
66
+ class_ = CHAMMI_CP
67
+ elif name == "CHAMMI_WTC":
68
+ class_ = CHAMMI_WTC
69
+ elif name == "CHAMMI_HPA":
70
+ class_ = CHAMMI_HPA
71
+ else:
72
+ raise ValueError(f'Unsupported dataset "{name}"')
73
+
74
+ return class_, kwargs
75
+
76
+
77
+ def make_dataset(
78
+ *,
79
+ dataset_str: str,
80
+ transform: Optional[Callable] = None,
81
+ target_transform: Optional[Callable] = None,
82
+ ):
83
+ """
84
+ Creates a dataset with the specified parameters.
85
+
86
+ Args:
87
+ dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
88
+ transform: A transform to apply to images.
89
+ target_transform: A transform to apply to targets.
90
+
91
+ Returns:
92
+ The created dataset.
93
+ """
94
+ logger.info(f'using dataset: "{dataset_str}"')
95
+
96
+ class_, kwargs = _parse_dataset_str(dataset_str)
97
+ dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
98
+
99
+ logger.info(f"# of dataset samples: {len(dataset):,d}")
100
+
101
+ # Aggregated datasets do not expose (yet) these attributes, so add them.
102
+ if not hasattr(dataset, "transform"):
103
+ setattr(dataset, "transform", transform)
104
+ if not hasattr(dataset, "target_transform"):
105
+ setattr(dataset, "target_transform", target_transform)
106
+
107
+ return dataset
108
+
109
+
110
+ def _make_sampler(
111
+ *,
112
+ dataset,
113
+ type: Optional[SamplerType] = None,
114
+ shuffle: bool = False,
115
+ seed: int = 0,
116
+ size: int = -1,
117
+ advance: int = 0,
118
+ ) -> Optional[Sampler]:
119
+ sample_count = len(dataset)
120
+
121
+ if type == SamplerType.INFINITE:
122
+ logger.info("sampler: infinite")
123
+ if size > 0:
124
+ raise ValueError("sampler size > 0 is invalid")
125
+ return InfiniteSampler(
126
+ sample_count=sample_count,
127
+ shuffle=shuffle,
128
+ seed=seed,
129
+ advance=advance,
130
+ )
131
+ elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
132
+ logger.info("sampler: sharded infinite")
133
+ if size > 0:
134
+ raise ValueError("sampler size > 0 is invalid")
135
+ # TODO: Remove support for old shuffling
136
+ use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
137
+ return ShardedInfiniteSampler(
138
+ sample_count=sample_count,
139
+ shuffle=shuffle,
140
+ seed=seed,
141
+ advance=advance,
142
+ use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
143
+ )
144
+ elif type == SamplerType.EPOCH:
145
+ logger.info("sampler: epoch")
146
+ if advance > 0:
147
+ raise NotImplementedError("sampler advance > 0 is not supported")
148
+ size = size if size > 0 else sample_count
149
+ logger.info(f"# of samples / epoch: {size:,d}")
150
+ return EpochSampler(
151
+ size=size,
152
+ sample_count=sample_count,
153
+ shuffle=shuffle,
154
+ seed=seed,
155
+ )
156
+ elif type == SamplerType.DISTRIBUTED:
157
+ logger.info("sampler: distributed")
158
+ if size > 0:
159
+ raise ValueError("sampler size > 0 is invalid")
160
+ if advance > 0:
161
+ raise ValueError("sampler advance > 0 is invalid")
162
+ return torch.utils.data.DistributedSampler(
163
+ dataset=dataset,
164
+ shuffle=shuffle,
165
+ seed=seed,
166
+ drop_last=False,
167
+ )
168
+
169
+ logger.info("sampler: none")
170
+ return None
171
+
172
+
173
+ T = TypeVar("T")
174
+
175
+
176
+ def make_data_loader(
177
+ *,
178
+ dataset,
179
+ batch_size: int,
180
+ num_workers: int,
181
+ shuffle: bool = True,
182
+ seed: int = 0,
183
+ sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
184
+ sampler_size: int = -1,
185
+ sampler_advance: int = 0,
186
+ drop_last: bool = True,
187
+ persistent_workers: bool = False,
188
+ collate_fn: Optional[Callable[[List[T]], Any]] = None,
189
+ ):
190
+ """
191
+ Creates a data loader with the specified parameters.
192
+
193
+ Args:
194
+ dataset: A dataset (third party, LaViDa or WebDataset).
195
+ batch_size: The size of batches to generate.
196
+ num_workers: The number of workers to use.
197
+ shuffle: Whether to shuffle samples.
198
+ seed: The random seed to use.
199
+ sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
200
+ sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
201
+ sampler_advance: How many samples to skip (when applicable).
202
+ drop_last: Whether the last non-full batch of data should be dropped.
203
+ persistent_workers: keep the worker Dataset instances alive after the dataset has been consumed once.
204
+ collate_fn: Function that performs batch collation.
205
+ """
206
+
207
+ sampler = _make_sampler(
208
+ dataset=dataset,
209
+ type=sampler_type,
210
+ shuffle=shuffle,
211
+ seed=seed,
212
+ size=sampler_size,
213
+ advance=sampler_advance,
214
+ )
215
+
216
+ logger.info("using PyTorch data loader")
217
+ data_loader = torch.utils.data.DataLoader(
218
+ dataset,
219
+ sampler=sampler,
220
+ batch_size=batch_size,
221
+ num_workers=num_workers,
222
+ pin_memory=True,
223
+ drop_last=drop_last,
224
+ persistent_workers=persistent_workers,
225
+ collate_fn=collate_fn,
226
+ )
227
+
228
+ try:
229
+ logger.info(f"# of batches: {len(data_loader):,d}")
230
+ except TypeError: # data loader has no length
231
+ logger.info("infinite data loader")
232
+ return data_loader
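A usage sketch of the dataset-string convention parsed by _parse_dataset_str above: the class name, then colon-separated key=value pairs drawn from (root, extra, split, mode, wildcard). Paths are hypothetical, and the ImageNet case assumes dump_extra has already produced the entries files:

dataset = make_dataset(dataset_str="ImageNet:split=TRAIN:root=/data/imagenet:extra=/data/extra")
data_loader = make_data_loader(
    dataset=dataset,
    batch_size=64,
    num_workers=8,
    sampler_type=SamplerType.SHARDED_INFINITE,
)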
dinov2/data/masking.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import math
8
+ import numpy as np
9
+
10
+
11
+ class MaskingGenerator:
12
+ def __init__(
13
+ self,
14
+ input_size,
15
+ num_masking_patches=None,
16
+ min_num_patches=4,
17
+ max_num_patches=None,
18
+ min_aspect=0.3,
19
+ max_aspect=None,
20
+ ):
21
+ if not isinstance(input_size, tuple):
22
+ input_size = (input_size,) * 2
23
+ self.height, self.width = input_size
24
+
25
+ self.num_patches = self.height * self.width
26
+ self.num_masking_patches = num_masking_patches
27
+
28
+ self.min_num_patches = min_num_patches
29
+ self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
30
+
31
+ max_aspect = max_aspect or 1 / min_aspect
32
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
33
+
34
+ def __repr__(self):
35
+ repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
36
+ self.height,
37
+ self.width,
38
+ self.min_num_patches,
39
+ self.max_num_patches,
40
+ self.num_masking_patches,
41
+ self.log_aspect_ratio[0],
42
+ self.log_aspect_ratio[1],
43
+ )
44
+ return repr_str
45
+
46
+ def get_shape(self):
47
+ return self.height, self.width
48
+
49
+ def _mask(self, mask, max_mask_patches):
50
+ delta = 0
51
+ for _ in range(10):
52
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
53
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
54
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
55
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
56
+ if w < self.width and h < self.height:
57
+ top = random.randint(0, self.height - h)
58
+ left = random.randint(0, self.width - w)
59
+
60
+ num_masked = mask[top : top + h, left : left + w].sum()
61
+ # Overlap
62
+ if 0 < h * w - num_masked <= max_mask_patches:
63
+ for i in range(top, top + h):
64
+ for j in range(left, left + w):
65
+ if mask[i, j] == 0:
66
+ mask[i, j] = 1
67
+ delta += 1
68
+
69
+ if delta > 0:
70
+ break
71
+ return delta
72
+
73
+ def __call__(self, num_masking_patches=0):
74
+ mask = np.zeros(shape=self.get_shape(), dtype=bool)
75
+ mask_count = 0
76
+ while mask_count < num_masking_patches:
77
+ max_mask_patches = num_masking_patches - mask_count
78
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
79
+
80
+ delta = self._mask(mask, max_mask_patches)
81
+ if delta == 0:
82
+ break
83
+ else:
84
+ mask_count += delta
85
+
86
+ return mask
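A quick usage sketch: for a 14x14 patch grid (e.g. a ViT at 224px with patch size 16), ask for roughly half the patches; the generator keeps placing random block-shaped regions until the budget is met or no further block fits:

gen = MaskingGenerator(input_size=14, num_masking_patches=98)
mask = gen(num_masking_patches=98)  # (14, 14) boolean numpy array
assert mask.sum() <= 98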
dinov2/data/samplers.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ from typing import Any, Optional
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.utils.data.sampler import Sampler
13
+
14
+ import dinov2.distributed as distributed
15
+
16
+
17
+ class EpochSampler(Sampler):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ size: int,
22
+ sample_count: int,
23
+ shuffle: bool = False,
24
+ seed: int = 0,
25
+ start: Optional[int] = None,
26
+ step: Optional[int] = None,
27
+ ):
28
+ self._size = size
29
+ self._sample_count = sample_count
30
+ self._shuffle = shuffle
31
+ self._seed = seed
32
+ self._start = distributed.get_global_rank() if start is None else start
33
+ self._step = distributed.get_global_size() if step is None else step
34
+ self._epoch = 0
35
+
36
+ def __iter__(self):
37
+ count = (self._size + self._sample_count - 1) // self._sample_count
38
+ tiled_indices = np.tile(np.arange(self._sample_count), count)
39
+ if self._shuffle:
40
+ seed = self._seed * self._epoch if self._seed != 0 else self._epoch
41
+ rng = np.random.default_rng(seed)
42
+ iterable = rng.choice(tiled_indices, self._size, replace=False)
43
+ else:
44
+ iterable = tiled_indices[: self._size]
45
+
46
+ yield from itertools.islice(iterable, self._start, None, self._step)
47
+
48
+ def __len__(self):
49
+ return (self._size - self._start + self._step - 1) // self._step
50
+
51
+ def set_epoch(self, epoch):
52
+ self._epoch = epoch
53
+
54
+
55
+ def _get_numpy_dtype(size: int) -> Any:
56
+ return np.int32 if size <= 2**31 else np.int64
57
+
58
+
59
+ def _get_torch_dtype(size: int) -> Any:
60
+ return torch.int32 if size <= 2**31 else torch.int64
61
+
62
+
63
+ def _generate_randperm_indices(*, size: int, generator: torch.Generator):
64
+ """Generate the indices of a random permutation."""
65
+ dtype = _get_torch_dtype(size)
66
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
67
+ perm = torch.arange(size, dtype=dtype)
68
+ for i in range(size):
69
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
70
+
71
+ # Always swap even if no-op
72
+ value = perm[j].item()
73
+ perm[j] = perm[i].item()
74
+ perm[i] = value
75
+ yield value
76
+
77
+
78
+ class InfiniteSampler(Sampler):
79
+ def __init__(
80
+ self,
81
+ *,
82
+ sample_count: int,
83
+ shuffle: bool = False,
84
+ seed: int = 0,
85
+ start: Optional[int] = None,
86
+ step: Optional[int] = None,
87
+ advance: int = 0,
88
+ ):
89
+ self._sample_count = sample_count
90
+ self._seed = seed
91
+ self._shuffle = shuffle
92
+ self._start = distributed.get_global_rank() if start is None else start
93
+ self._step = distributed.get_global_size() if step is None else step
94
+ self._advance = advance
95
+
96
+ def __iter__(self):
97
+ if self._shuffle:
98
+ iterator = self._shuffled_iterator()
99
+ else:
100
+ iterator = self._iterator()
101
+
102
+ yield from itertools.islice(iterator, self._advance, None)
103
+
104
+ def _iterator(self):
105
+ assert not self._shuffle
106
+
107
+ while True:
108
+ iterable = range(self._sample_count)
109
+ yield from itertools.islice(iterable, self._start, None, self._step)
110
+
111
+ def _shuffled_iterator(self):
112
+ assert self._shuffle
113
+
114
+ # Instantiate a generator here (rather than in the ctor) to keep the class
115
+ # picklable (requirement of mp.spawn)
116
+ generator = torch.Generator().manual_seed(self._seed)
117
+
118
+ while True:
119
+ iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
120
+ yield from itertools.islice(iterable, self._start, None, self._step)
121
+
122
+
123
+ # The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
124
+ # but avoids a full in-place random permutation generation.
125
+ def _shuffle_tensor_slice(
126
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
127
+ ) -> np.ndarray:
128
+ stop = len(tensor)
129
+ count = stop // step
130
+ drop_count = stop - step * count
131
+ if drop_count:
132
+ warnings.warn(f"# of dropped samples: {drop_count}")
133
+
134
+ dtype = _get_numpy_dtype(stop)
135
+ result = np.empty(count, dtype=dtype)
136
+
137
+ for i in range(count):
138
+ j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
139
+
140
+ result[i] = result[j]
141
+ result[j] = tensor[start + i * step].item()
142
+
143
+ return result
144
+
145
+
146
+ def _new_shuffle_tensor_slice(
147
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
148
+ ) -> np.ndarray:
149
+ stop = len(tensor)
150
+ count = stop // step
151
+ dtype = torch.int64 # Needed for using randperm result as indices
153
+ drop_count = stop - step * count
154
+ if drop_count:
155
+ warnings.warn(f"# of dropped samples: {drop_count}")
156
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
157
+ return tensor[start::step][indices].numpy()
158
+
159
+
160
+ def _make_seed(seed: int, start: int, iter_count: int) -> int:
161
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
162
+ return seed + start + (iter_count << 24)
163
+
164
+
165
+ class ShardedInfiniteSampler(Sampler):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ sample_count: int,
170
+ shuffle: bool = False,
171
+ seed: int = 0,
172
+ start: Optional[int] = None,
173
+ step: Optional[int] = None,
174
+ advance: int = 0,
175
+ use_new_shuffle_tensor_slice: bool = False,
176
+ ):
177
+ self._sample_count = sample_count
178
+ self._seed = seed
179
+ self._shuffle = shuffle
180
+ self._start = distributed.get_global_rank() if start is None else start
181
+ self._step = distributed.get_global_size() if step is None else step
182
+ self._advance = advance
183
+ self._iter_count = 0
184
+ self._shuffle_tensor_slice_fn = (
185
+ _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
186
+ )
187
+
188
+ def __iter__(self):
189
+ iter_count = self._advance // self._sample_count
190
+ if iter_count > 0:
191
+ self._advance -= iter_count * self._sample_count
192
+ self._iter_count += iter_count
193
+
194
+ if self._shuffle:
195
+ iterator = self._shuffled_iterator()
196
+ else:
197
+ iterator = self._iterator()
198
+
199
+ yield from itertools.islice(iterator, self._advance, None)
200
+
201
+ def _iterator(self):
202
+ assert not self._shuffle
203
+
204
+ while True:
205
+ iterable = range(self._sample_count)
206
+ yield from itertools.islice(iterable, self._start, None, self._step)
207
+
208
+ def _shuffled_iterator(self):
209
+ assert self._shuffle
210
+
211
+ # Instantiate a generator here (rather than in the ctor) to keep the class
212
+ # picklable (requirement of mp.spawn)
213
+ generator = torch.Generator()
214
+
215
+ # Always shuffle everything first
216
+ generator.manual_seed(self._seed)
217
+ dtype = _get_torch_dtype(self._sample_count)
218
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
219
+
220
+ while True:
221
+ # Re-seed on each iteration to allow skipping whole permutations
222
+ seed = _make_seed(self._seed, self._start, self._iter_count)
223
+ generator.manual_seed(seed)
224
+
225
+ iterable = self._shuffle_tensor_slice_fn(
226
+ tensor=perm, start=self._start, step=self._step, generator=generator
227
+ )
228
+ yield from iterable
229
+ self._iter_count += 1
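A sketch of how the sharded sampler resumes mid-training: whole permutations are skipped by bumping _iter_count, and _make_seed re-derives a per-iteration seed so earlier permutations never have to be replayed element by element. The counts here are made up:

sampler = ShardedInfiniteSampler(
    sample_count=1_000,
    shuffle=True,
    seed=0,
    advance=2_500,  # resume as if 2,500 samples were already consumed
)
first_index = next(iter(sampler))  # skips 2 full permutations, then 500 samples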
dinov2/data/transforms.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Sequence
7
+
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+
12
+ class GaussianBlur(transforms.RandomApply):
13
+ """
14
+ Apply Gaussian Blur to the PIL image.
15
+ """
16
+
17
+ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
18
+ # NOTE: torchvision is applying 1 - probability to return the original image
19
+ keep_p = 1 - p
20
+ transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
21
+ super().__init__(transforms=[transform], p=keep_p)
22
+
23
+
24
+ class MaybeToTensor(transforms.ToTensor):
25
+ """
26
+ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
27
+ """
28
+
29
+ def __call__(self, pic):
30
+ """
31
+ Args:
32
+ pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
33
+ Returns:
34
+ Tensor: Converted image.
35
+ """
36
+ if isinstance(pic, torch.Tensor):
37
+ return pic
38
+ return super().__call__(pic)
39
+
40
+
41
+ # Use timm's names
42
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
43
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
44
+
45
+
46
+ def make_normalize_transform(
47
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
48
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
49
+ ) -> transforms.Normalize:
50
+ return transforms.Normalize(mean=mean, std=std)
51
+
52
+
53
+ # This roughly matches torchvision's preset for classification training:
54
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
55
+ def make_classification_train_transform(
56
+ *,
57
+ crop_size: int = 224,
58
+ interpolation=transforms.InterpolationMode.BICUBIC,
59
+ hflip_prob: float = 0.5,
60
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
61
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
62
+ ):
63
+ transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
64
+ if hflip_prob > 0.0:
65
+ transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
66
+ transforms_list.extend(
67
+ [
68
+ MaybeToTensor(),
69
+ make_normalize_transform(mean=mean, std=std),
70
+ ]
71
+ )
72
+ return transforms.Compose(transforms_list)
73
+
74
+
75
+ # This matches (roughly) torchvision's preset for classification evaluation:
76
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
77
+ def make_classification_eval_transform(
78
+ *,
79
+ resize_size: int = 256,
80
+ interpolation=transforms.InterpolationMode.BICUBIC,
81
+ crop_size: int = 224,
82
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
83
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
84
+ ) -> transforms.Compose:
85
+ transforms_list = [
86
+ transforms.Resize(resize_size, interpolation=interpolation),
87
+ transforms.CenterCrop(crop_size),
88
+ MaybeToTensor(),
89
+ make_normalize_transform(mean=mean, std=std),
90
+ ]
91
+ return transforms.Compose(transforms_list)
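A usage sketch of the evaluation preset above: resize, center-crop, convert to tensor and normalize with the ImageNet statistics defined in this file. The input image is a synthetic placeholder:

from PIL import Image

transform = make_classification_eval_transform()
img = Image.new("RGB", (500, 400))  # hypothetical input image
x = transform(img)                  # tensor of shape [3, 224, 224]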
dinov2/distributed/__init__.py ADDED
@@ -0,0 +1,270 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import os
+ import random
+ import re
+ import socket
+ from typing import Dict, List
+
+ import torch
+ import torch.distributed as dist
+
+ _LOCAL_RANK = -1
+ _LOCAL_WORLD_SIZE = -1
+
+
+ def is_enabled() -> bool:
+     """
+     Returns:
+         True if distributed training is enabled
+     """
+     return dist.is_available() and dist.is_initialized()
+
+
+ def get_global_size() -> int:
+     """
+     Returns:
+         The number of processes in the process group
+     """
+     return dist.get_world_size() if is_enabled() else 1
+
+
+ def get_global_rank() -> int:
+     """
+     Returns:
+         The rank of the current process within the global process group.
+     """
+     return dist.get_rank() if is_enabled() else 0
+
+
+ def get_local_rank() -> int:
+     """
+     Returns:
+         The rank of the current process within the local (per-machine) process group.
+     """
+     if not is_enabled():
+         return 0
+     assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
+     return _LOCAL_RANK
+
+
+ def get_local_size() -> int:
+     """
+     Returns:
+         The size of the per-machine process group,
+         i.e. the number of processes per machine.
+     """
+     if not is_enabled():
+         return 1
+     assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
+     return _LOCAL_WORLD_SIZE
+
+
+ def is_main_process() -> bool:
+     """
+     Returns:
+         True if the current process is the main one.
+     """
+     return get_global_rank() == 0
+
+
+ def _restrict_print_to_main_process() -> None:
+     """
+     This function disables printing when not in the main process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_main_process() or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def _get_master_port(seed: int = 0) -> int:
+     MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
+
+     master_port_str = os.environ.get("MASTER_PORT")
+     if master_port_str is None:
+         rng = random.Random(seed)
+         return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
+
+     return int(master_port_str)
+
+
+ def _get_available_port() -> int:
+     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+         # A "" host address means INADDR_ANY i.e. binding to all interfaces.
+         # Note this is not compatible with IPv6.
+         s.bind(("", 0))
+         port = s.getsockname()[1]
+         return port
+
+
+ _TORCH_DISTRIBUTED_ENV_VARS = (
+     "MASTER_ADDR",
+     "MASTER_PORT",
+     "RANK",
+     "WORLD_SIZE",
+     "LOCAL_RANK",
+     "LOCAL_WORLD_SIZE",
+ )
+
+
+ def _collect_env_vars() -> Dict[str, str]:
+     return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
+
+
+ def _is_slurm_job_process() -> bool:
+     return "SLURM_JOB_ID" in os.environ
+
+
+ def _parse_slurm_node_list(s: str) -> List[str]:
+     nodes = []
+     # Extract "hostname", "hostname[1-2,3,4-5]," substrings
+     p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
+     for m in p.finditer(s):
+         prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
+         for suffix in suffixes.split(","):
+             span = suffix.split("-")
+             if len(span) == 1:
+                 nodes.append(prefix + suffix)
+             else:
+                 width = len(span[0])
+                 start, end = int(span[0]), int(span[1]) + 1
+                 nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
+     return nodes
+
+
+ def _check_env_variable(key: str, new_value: str):
+     # Only check for difference with preset environment variables
+     if key in os.environ and os.environ[key] != new_value:
+         raise RuntimeError(f"Cannot export environment variables as {key} is already set")
+
+
+ class _TorchDistributedEnvironment:
+     def __init__(self):
+         self.master_addr = "127.0.0.1"
+         self.master_port = 0
+         self.rank = -1
+         self.world_size = -1
+         self.local_rank = -1
+         self.local_world_size = -1
+
+         if _is_slurm_job_process():
+             return self._set_from_slurm_env()
+
+         env_vars = _collect_env_vars()
+         if not env_vars:
+             # Environment is not set
+             pass
+         elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
+             # Environment is fully set
+             return self._set_from_preset_env()
+         else:
+             # Environment is partially set
+             collected_env_vars = ", ".join(env_vars.keys())
+             raise RuntimeError(f"Partially set environment: {collected_env_vars}")
+
+         if torch.cuda.device_count() > 0:
+             return self._set_from_local()
+
+         raise RuntimeError("Can't initialize PyTorch distributed environment")
+
+     # Slurm job created with sbatch, submitit, etc...
+     def _set_from_slurm_env(self):
+         # logger.info("Initialization from Slurm environment")
+         job_id = int(os.environ["SLURM_JOB_ID"])
+         node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
+         nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
+         assert len(nodes) == node_count
+
+         self.master_addr = nodes[0]
+         self.master_port = _get_master_port(seed=job_id)
+         self.rank = int(os.environ["SLURM_PROCID"])
+         self.world_size = int(os.environ["SLURM_NTASKS"])
+         assert self.rank < self.world_size
+         self.local_rank = int(os.environ["SLURM_LOCALID"])
+         self.local_world_size = self.world_size // node_count
+         assert self.local_rank < self.local_world_size
+
+     # Single node job with preset environment (i.e. torchrun)
+     def _set_from_preset_env(self):
+         # logger.info("Initialization from preset environment")
+         self.master_addr = os.environ["MASTER_ADDR"]
+         self.master_port = os.environ["MASTER_PORT"]
+         self.rank = int(os.environ["RANK"])
+         self.world_size = int(os.environ["WORLD_SIZE"])
+         assert self.rank < self.world_size
+         self.local_rank = int(os.environ["LOCAL_RANK"])
+         self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+         assert self.local_rank < self.local_world_size
+
+     # Single node and GPU job (i.e. local script run)
+     def _set_from_local(self):
+         # logger.info("Initialization from local")
+         self.master_addr = "127.0.0.1"
+         self.master_port = _get_available_port()
+         self.rank = 0
+         self.world_size = 1
+         self.local_rank = 0
+         self.local_world_size = 1
+
+     def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
+         # See the "Environment variable initialization" section from
+         # https://pytorch.org/docs/stable/distributed.html for the complete list of
+         # environment variables required for the env:// initialization method.
+         env_vars = {
+             "MASTER_ADDR": self.master_addr,
+             "MASTER_PORT": str(self.master_port),
+             "RANK": str(self.rank),
+             "WORLD_SIZE": str(self.world_size),
+             "LOCAL_RANK": str(self.local_rank),
+             "LOCAL_WORLD_SIZE": str(self.local_world_size),
+         }
+         if not overwrite:
+             for k, v in env_vars.items():
+                 _check_env_variable(k, v)
+
+         os.environ.update(env_vars)
+         return self
+
+
+ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
240
+ """Enable distributed mode
241
+
242
+ Args:
243
+ set_cuda_current_device: If True, call torch.cuda.set_device() to set the
244
+ current PyTorch CUDA device to the one matching the local rank.
245
+ overwrite: If True, overwrites already set variables. Else fails.
246
+ """
247
+
248
+ global _LOCAL_RANK, _LOCAL_WORLD_SIZE
249
+ if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
250
+ raise RuntimeError("Distributed mode has already been enabled")
251
+ torch_env = _TorchDistributedEnvironment()
252
+ torch_env.export(overwrite=overwrite)
253
+
254
+ if set_cuda_current_device:
255
+ torch.cuda.set_device(torch_env.local_rank)
256
+
257
+ if allow_nccl_timeout:
258
+ # This allows to use torch distributed timeout in a NCCL backend
259
+ key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
260
+ if not overwrite:
261
+ _check_env_variable(key, value)
262
+ os.environ[key] = value
263
+
264
+ dist.init_process_group(backend="nccl")
265
+ dist.barrier()
266
+
267
+ # Finalize setup
268
+ _LOCAL_RANK = torch_env.local_rank
269
+ _LOCAL_WORLD_SIZE = torch_env.local_world_size
270
+ _restrict_print_to_main_process()
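
A minimal usage sketch for this module (not part of the upload; the backend is hard-coded to NCCL in enable(), so at least one CUDA device is assumed):

    # Minimal sketch: single entry point at the top of a training script.
    # enable() resolves the environment (Slurm, torchrun preset, or local
    # single-GPU run), exports the env:// variables, and initializes NCCL.
    import torch

    import dinov2.distributed as distributed

    distributed.enable(set_cuda_current_device=True)

    device = torch.device("cuda", distributed.get_local_rank())
    # After enable(), print() is silently restricted to the main process:
    print(f"rank {distributed.get_global_rank()} / world size {distributed.get_global_size()}")

    # For reference, the Slurm helper expands node lists like:
    #   _parse_slurm_node_list("node[1-2,5]") -> ["node1", "node2", "node5"]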
dinov2/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
dinov2/eval/cell_dino/knn.py ADDED
@@ -0,0 +1,479 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the CC-by-NC licence,
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
+
+ import argparse
+ from functools import partial
+ import json
+ import logging
+ import os
+ import sys
+ from typing import List, Optional, Any
+ import numpy as np
+
+ import torch
+ import torch.backends.cudnn as cudnn
+ import pandas as pd
+ from sklearn.metrics import f1_score
+
+ import dinov2.distributed as distributed
+ from dinov2.data import make_dataset, DatasetWithEnumeratedTargets, SamplerType, make_data_loader
+ from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
+ from dinov2.eval.metrics import build_metric, MetricType
+ from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+ from dinov2.eval.setup import setup_and_build_model
+
+ from dinov2.data import ResultsAccumulator
+ from dinov2.eval.utils import ModelWithNormalize
+ from dinov2.eval.cell_dino.utils import (
+     BagOfChannelsModelWithNormalize,
+     extract_features_cell_dino,
+     average_metrics,
+     create_train_dataset_dict,
+     get_num_classes,
+     extract_features_for_dataset_dict,
+     evaluate_with_accumulate,
+     KnnModule,
+ )
+ from dinov2.eval.knn import DictKeysModule
+ from torch.utils.data import Subset as SubsetEx
+ from torch.utils.data import ConcatDataset as ConcatDatasetEx
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ def get_args_parser(
+     description: Optional[str] = None,
+     parents: Optional[List[argparse.ArgumentParser]] = None,
+     add_help: bool = True,
+ ):
+     parents = parents or []
+     setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
+     parents = [setup_args_parser]
+     parser = argparse.ArgumentParser(
+         description=description,
+         parents=parents,
+         add_help=add_help,
+     )
+     parser.add_argument(
+         "--train-dataset",
+         dest="train_dataset_str",
+         type=str,
+         help="Training dataset",
+     )
+     parser.add_argument(
+         "--val-dataset",
+         dest="val_dataset_str",
+         type=str,
+         help="Validation dataset",
+     )
+     parser.add_argument(
+         "--nb_knn",
+         nargs="+",
+         type=int,
+         help="Number of nearest neighbors to use; 20 usually works best.",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         help="Temperature used in the voting coefficient",
+     )
+     parser.add_argument(
+         "--gather-on-cpu",
+         action="store_true",
+         help="Whether to gather the train features on cpu, slower"
+         " but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         help="Batch size.",
+     )
+     parser.add_argument(
+         "--n-per-class-list",
+         nargs="+",
+         type=int,
+         help="Number of samples to take per class",
+     )
+     parser.add_argument(
+         "--n-tries",
+         type=int,
+         help="Number of tries",
+     )
+     parser.add_argument(
+         "--leave-one-out-dataset",
+         type=str,
+         help="Path with indices to use the leave-one-out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
+     )
+     parser.add_argument(
+         "--bag-of-channels",
+         action="store_true",
+         help='Whether to use the "bag of channels" channel-adaptive strategy',
+     )
+     parser.add_argument(
+         "--crop-size",
+         type=int,
+         help="Crop size for train and eval",
+     )
+     parser.add_argument(
+         "--resize-size",
+         type=int,
+         help="Resize size applied just before the crop; 0 means no resize",
+     )
+     parser.add_argument(
+         "--metric-type",
+         type=MetricType,
+         choices=list(MetricType),
+         help="Validation metric",
+     )
+     parser.add_argument(
+         "--avgpool",
+         action="store_true",
+         help="Whether to use average pooling of patch tokens in addition to CLS tokens",
+     )
+
+     parser.set_defaults(
+         train_dataset_str="ImageNet:split=TRAIN",
+         val_dataset_str="ImageNet:split=VAL",
+         nb_knn=[1],
+         temperature=0.07,
+         batch_size=256,
+         resize_size=0,
+     )
+     return parser
+
+
+ class SequentialWithKwargs(torch.nn.Sequential):
+     def __init__(self, *args):
+         super().__init__(*args)
+
+     def forward(self, input, **kwargs):
+         input = self[0](input, **kwargs)
+         for module in self[1:]:
+             input = module(input)
+         return input
+
+
+ def create_train_test_dataset_dict_leave_one_out(
+     train_dataset,
+     test_dataset,
+ ) -> dict[int, Any]:
+     """
+     This function builds a train dataset dictionary with the leave-one-out (LOO) method.
+     Specifically, given a train dataset and a test dataset, it creates a train dataset for each
+     test data point: the concatenation of the train and test datasets minus that specific point.
+     The result therefore contains len(test_dataset) key/value pairs.
+
+     Format is {"nth-test-sample": dataset_without_test_sample}
+     """
+     train_dataset_dict: dict[int, Any] = {}
+     test_size = len(test_dataset)
+
+     for test_sample_index in range(test_size):
+         test_indices_bool = torch.ones(test_size, dtype=bool)
+         test_indices_bool[test_sample_index] = False
+         train_dataset_dict[test_sample_index] = ConcatDatasetEx(
+             [train_dataset, SubsetEx(test_dataset, test_indices_bool.nonzero().flatten())]
+         )
+
+     return train_dataset_dict
+
+
+ def eval_knn_with_leave_one_out(
+     model, leave_one_out_dataset, train_dataset, test_dataset, metric_type, nb_knn, temperature, batch_size, num_workers
+ ):
+     num_classes = get_num_classes(test_dataset)
+     train_dataset_dict = create_train_dataset_dict(train_dataset)
+     test_dataset_dict = create_train_dataset_dict(test_dataset)
+
+     logger.info("Extracting features for train set...")
+     train_data_dict = extract_features_for_dataset_dict(
+         model, train_dataset_dict, batch_size, num_workers, gather_on_cpu=True
+     )
+     test_data_dict = extract_features_for_dataset_dict(
+         model, test_dataset_dict, batch_size, num_workers, gather_on_cpu=True
+     )
+
+     train_features = train_data_dict[0]["train_features"]
+     train_labels = train_data_dict[0]["train_labels"]
+     test_features = test_data_dict[0]["train_features"]
+     test_labels = test_data_dict[0]["train_labels"]
+
+     metric_collection = build_metric(metric_type, num_classes=3)
+
+     device = torch.cuda.current_device()
+     partial_knn_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
+
+     logger.info("Reading the leave-one-out label metadata.")
+
+     leave_one_out_indices = {}
+     metadata = pd.read_csv(leave_one_out_dataset)
+     if "HPA" in leave_one_out_dataset:
+         metadata = metadata[metadata["Task_three"]].reset_index()
+         leave_one_out_label_type = "cell_type"
+     else:
+         metadata = metadata[metadata["Task_four"]].reset_index()
+         leave_one_out_label_type = "Plate"
+     leave_one_out_labels = metadata[leave_one_out_label_type].unique()
+
+     for leave_one_out_label in leave_one_out_labels:
+         leave_one_out_indices[leave_one_out_label] = torch.tensor(
+             metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values
+         )
+
+     # ============ evaluation ... ============
+     logger.info("Start the k-NN classification.")
+
+     eval_metrics_dict = {}
+     postprocessors, metrics = {k: DictKeysModule([k]) for k in nb_knn}, {
+         k: metric_collection.clone().to(device) for k in nb_knn
+     }
+     for metric_key in metrics.keys():
+         metrics[metric_key] = metrics[metric_key].to(device)
+
+     accumulator_class = ResultsAccumulator
+     accumulators = {k: accumulator_class() for k in postprocessors.keys()}
+     all_preds = []
+     all_target = []
+
+     for loo_label, loo_indices in leave_one_out_indices.items():
+         logger.info(f"Evaluating on test sample {loo_label}")
+         loo_for_training_indices = torch.ones(test_features.shape[0], dtype=bool)
+         loo_for_training_indices[loo_indices] = False
+         train_features_sample = torch.cat([train_features, test_features[loo_for_training_indices]])
+         train_labels_sample = torch.cat([train_labels, test_labels[loo_for_training_indices]])
+         logger.info(f"Train shape {train_features_sample.shape}, Test shape {test_features[loo_indices].shape}")
+         logger.info(
+             f"Train values {train_labels_sample.unique(return_counts=True)}, Test values {test_labels[loo_indices].unique(return_counts=True)}"
+         )
+         knn_module = partial_knn_module(
+             train_features=train_features_sample, train_labels=train_labels_sample, nb_knn=nb_knn
+         )
+
+         output = knn_module(test_features[loo_indices].to(device))
+         all_preds.append(output[1])
+         all_target.append(test_labels[loo_indices])
+         output[1] = output[1][:, 4:]
+         transformed_test_labels = test_labels[loo_indices] - 4
+         for k, metric in metrics.items():
+             metric_inputs = postprocessors[k](output, transformed_test_labels.to(device))
+             metric.update(**metric_inputs)
+             accumulators[k].update(
+                 preds=metric_inputs["preds"], target=metric_inputs["target"], index=loo_indices.to(device)
+             )
+
+     all_preds = torch.cat(all_preds).cpu().detach().numpy()
+
+     all_preds = np.argmax(all_preds, axis=1)
+     all_target = torch.cat(all_target).cpu().detach().numpy()
+
+     f1 = f1_score(all_target, all_preds, average="macro", labels=[4, 5, 6])
+     logger.info(f"Real f1 score: {f1}")
+     eval_metrics = {
+         k: metric.compute() for k, metric in metrics.items()
+     }  # overwritten below by the real f1 score computed above
+
+     for k in nb_knn:
+         if k not in eval_metrics_dict:
+             eval_metrics_dict[k] = {}
+         eval_metrics_dict[k] = {metric: f1 * 100.0 for metric, v in eval_metrics[k].items()}
+
+     if len(train_data_dict) > 1:
+         return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}
+
+     return {k: eval_metrics_dict[k] for k in eval_metrics_dict.keys()}
+
+
+ def eval_knn_with_model(
+     model,
+     output_dir,
+     train_dataset_str,
+     val_dataset_str,
+     nb_knn=(10, 20, 100, 200),
+     temperature=0.07,
+     autocast_dtype=torch.float,
+     metric_type=MetricType.MEAN_ACCURACY,
+     transform=None,
+     resize_size=256,
+     crop_size=224,
+     batch_size=256,
+     num_workers=5,
+     leave_one_out_dataset="",
+     bag_of_channels=False,
+     avgpool=False,
+ ):
+     autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
+     if bag_of_channels:
+         model = BagOfChannelsModelWithNormalize(model, autocast_ctx, avgpool)
+     else:
+         model = ModelWithNormalize(model)
+     if leave_one_out_dataset == "" or leave_one_out_dataset is None:
+         leave_one_out = False
+     else:
+         leave_one_out = True
+
+     cudnn.benchmark = True
+     # NOTE: the `transform` argument is ignored; a cell-specific eval transform is always built here.
+     transform = make_classification_eval_cell_transform(
+         normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, resize_size=resize_size, crop_size=crop_size
+     )
+
+     train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform)
+     results_dict = {}
+     test_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform)
+
+     with torch.cuda.amp.autocast(dtype=autocast_dtype):
+         if leave_one_out:
+             results_dict_knn = eval_knn_with_leave_one_out(
+                 model=model,
+                 leave_one_out_dataset=leave_one_out_dataset,
+                 train_dataset=train_dataset,
+                 test_dataset=test_dataset,
+                 metric_type=metric_type,
+                 nb_knn=nb_knn,
+                 temperature=temperature,
+                 batch_size=batch_size,
+                 num_workers=num_workers,
+             )
+         else:
+             results_dict_knn = eval_knn(
+                 model=model,
+                 train_dataset=train_dataset,
+                 test_dataset=test_dataset,
+                 metric_type=metric_type,
+                 nb_knn=nb_knn,
+                 temperature=temperature,
+                 batch_size=batch_size,
+                 num_workers=num_workers,
+             )
+
+     for knn_ in results_dict_knn.keys():
+         top1 = results_dict_knn[knn_]["top-1"]
+         results_dict[f"{val_dataset_str}_{knn_} Top 1"] = top1
+         results_string = f"{val_dataset_str} {knn_} NN classifier result: Top1: {top1:.2f}"
+         if "top-5" in results_dict_knn[knn_]:
+             top5 = results_dict_knn[knn_]["top-5"]
+             results_dict[f"{val_dataset_str}_{knn_} Top 5"] = top5
+             results_string += f", Top5: {top5:.2f}"
+         logger.info(results_string)
+
+     metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
+     with open(metrics_file_path, "a") as f:
+         for k, v in results_dict.items():
+             f.write(json.dumps({k: v}) + "\n")
+
+     if distributed.is_enabled():
+         torch.distributed.barrier()
+     return results_dict
+
+
+ def eval_knn(
+     model,
+     train_dataset,
+     test_dataset,
+     metric_type,
+     nb_knn,
+     temperature,
+     batch_size,
+     num_workers,
+     few_shot_eval=False,
+     few_shot_k_or_percent=None,
+     few_shot_n_tries=1,
+ ):
+     num_classes = get_num_classes(train_dataset)
+     train_dataset_dict = create_train_dataset_dict(
+         train_dataset,
+         few_shot_eval=few_shot_eval,
+         few_shot_k_or_percent=few_shot_k_or_percent,
+         few_shot_n_tries=few_shot_n_tries,
+     )
+
+     logger.info("Extracting features for train set...")
+
+     train_data_dict: dict[int, dict[str, torch.Tensor]] = {}
+     for try_n, dataset in train_dataset_dict.items():
+         features, labels = extract_features_cell_dino(model, dataset, batch_size, num_workers, gather_on_cpu=True)
+         train_data_dict[try_n] = {"train_features": features, "train_labels": labels}
+
+     test_data_loader = make_data_loader(
+         dataset=DatasetWithEnumeratedTargets(
+             test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
+         ),
+         batch_size=batch_size,
+         num_workers=num_workers,
+         sampler_type=SamplerType.DISTRIBUTED,
+         drop_last=False,
+         shuffle=False,
+         persistent_workers=True,
+         collate_fn=None,
+     )
+     metric_collection = build_metric(metric_type, num_classes=num_classes)
+
+     device = torch.cuda.current_device()
+     partial_knn_module = partial(
+         KnnModule,
+         T=temperature,
+         device=device,
+         num_classes=num_classes,
+     )
+
+     # ============ evaluation ... ============
+     logger.info("Start the k-NN classification.")
+     eval_metrics_dict = {}
+
+     for try_ in train_data_dict.keys():
+         train_features, train_labels = train_data_dict[try_]["train_features"], train_data_dict[try_]["train_labels"]
+         k_list = sorted(set([el if el < len(train_features) else len(train_features) for el in nb_knn]))
+         knn_module = partial_knn_module(train_features=train_features, train_labels=train_labels, nb_knn=k_list)
+         postprocessors, metrics = {k: DictKeysModule([k]) for k in k_list}, {
+             k: metric_collection.clone() for k in k_list
+         }
+         _, eval_metrics, _ = evaluate_with_accumulate(
+             SequentialWithKwargs(model, knn_module),
+             test_data_loader,
+             postprocessors,
+             metrics,
+             device,
+             accumulate_results=False,
+         )
+         for k in k_list:
+             if k not in eval_metrics_dict:
+                 eval_metrics_dict[k] = {}
+             eval_metrics_dict[k][try_] = {metric: v.item() * 100.0 for metric, v in eval_metrics[k].items()}
+
+     if len(train_data_dict) > 1:
+         return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}
+
+     return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()}
+
+
+ def main(args):
+     model, autocast_dtype = setup_and_build_model(args)
+     eval_knn_with_model(
+         model=model,
+         output_dir=args.output_dir,
+         train_dataset_str=args.train_dataset_str,
+         val_dataset_str=args.val_dataset_str,
+         nb_knn=args.nb_knn,
+         temperature=args.temperature,
+         autocast_dtype=autocast_dtype,
+         transform=None,
+         metric_type=args.metric_type,
+         batch_size=args.batch_size,
+         num_workers=5,
+         leave_one_out_dataset=args.leave_one_out_dataset,
+         resize_size=args.resize_size,
+         crop_size=args.crop_size,
+         avgpool=args.avgpool,
+         bag_of_channels=args.bag_of_channels,
+     )
+     return 0
+
+
+ if __name__ == "__main__":
+     description = "k-NN evaluation on models trained with the bag-of-channels strategy or cell DINO"
+     args_parser = get_args_parser(description=description)
+     args = args_parser.parse_args()
+     sys.exit(main(args))
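
To make the leave-one-out bookkeeping in eval_knn_with_leave_one_out concrete, here is a toy sketch of the split with random stand-in features (all shapes, indices, and values are illustrative only):

    # Toy sketch of the leave-one-out split above: for each held-out group,
    # its rows are removed from the test pool and merged into the train pool.
    import torch

    train_features = torch.randn(100, 8)          # stand-ins for extracted features
    train_labels = torch.randint(0, 3, (100,))
    test_features = torch.randn(10, 8)
    test_labels = torch.randint(0, 3, (10,))

    loo_indices = torch.tensor([2, 5])            # indices of one held-out group
    keep = torch.ones(test_features.shape[0], dtype=torch.bool)
    keep[loo_indices] = False

    fold_train_x = torch.cat([train_features, test_features[keep]])  # 100 + 8 rows
    fold_train_y = torch.cat([train_labels, test_labels[keep]])
    fold_eval_x = test_features[loo_indices]      # 2 rows, classified by KnnModule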
dinov2/eval/cell_dino/linear.py ADDED
@@ -0,0 +1,1048 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the CC-by-NC licence,
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
+
+ import argparse
+ from functools import partial
+ import json
+ import logging
+ import os
+ import sys
+ from typing import Any, Callable, Dict, Optional, Tuple, List
+ from enum import Enum
+ from dataclasses import dataclass
+
+ from sklearn.metrics import f1_score
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import TensorDataset
+ from torch.nn.parallel import DistributedDataParallel
+
+
+ from dinov2.data import SamplerType, make_data_loader, make_dataset, DatasetWithEnumeratedTargets
+ from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
+ import dinov2.distributed as distributed
+ from dinov2.eval.metrics import MetricType, build_metric
+ from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+ from dinov2.eval.setup import setup_and_build_model
+ from dinov2.eval.cell_dino.utils import (
+     evaluate_with_accumulate,
+     LossType,
+     average_metrics,
+     create_train_dataset_dict,
+     get_num_classes,
+     extract_features_for_dataset_dict,
+ )
+ from dinov2.eval.utils import ModelWithIntermediateLayers
+ from dinov2.logging import MetricLogger
+ from dinov2.utils.checkpoint import build_periodic_checkpointer, resume_or_load
+
+ logger = logging.getLogger("dinov2")
+
+ """
+ List of changes with respect to the standard linear evaluation script:
+
+ - "bag of channels" option (channel-adaptive strategy)
+ - AdamW optimizer instead of SGD
+ - scheduler: two options, OneCycleLR or CosineAnnealingLR
+ - different transforms/normalization, now calling make_classification_eval_cell_transform
+ - binary cross-entropy loss option for protein localization
+ - num_classes now defined via get_num_classes
+ - some default parameters changed (batch_size, epoch_length, epochs, learning rates)
+ - n_last_blocks option
+ - avgpool option
+ - leave-one-out strategy for the CHAMMI evaluation
+ - grid search for the optimal weight decay
+ """
+
+
+ def get_args_parser(
+     description: Optional[str] = None,
+     parents: Optional[List[argparse.ArgumentParser]] = None,
+     add_help: bool = True,
+ ):
+     parents = parents or []
+     setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
+     parents = [setup_args_parser]
+     parser = argparse.ArgumentParser(
+         description=description,
+         parents=parents,
+         add_help=add_help,
+     )
+     parser.add_argument(
+         "--train-dataset",
+         dest="train_dataset_str",
+         type=str,
+         help="Training dataset",
+     )
+     parser.add_argument(
+         "--val-dataset",
+         dest="val_dataset_str",
+         type=str,
+         help="Validation dataset",
+     )
+     parser.add_argument(
+         "--test-datasets",
+         dest="test_dataset_strs",
+         type=str,
+         nargs="+",
+         help="Test datasets, none to reuse the validation dataset",
+     )
+     parser.add_argument(
+         "--epochs",
+         type=int,
+         help="Number of training epochs",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         help="Batch size (per GPU)",
+     )
+     parser.add_argument(
+         "--num-workers",
+         type=int,
+         help="Number of workers",
+     )
+     parser.add_argument(
+         "--epoch-length",
+         type=int,
+         help="Length of an epoch in number of iterations",
+     )
+     parser.add_argument(
+         "--save-checkpoint-frequency",
+         type=int,
+         help="Number of epochs between two named checkpoint saves.",
+     )
+     parser.add_argument(
+         "--eval-period-iterations",
+         type=int,
+         help="Number of iterations between two evaluations.",
+     )
+     parser.add_argument(
+         "--learning-rates",
+         nargs="+",
+         type=float,
+         help="Learning rates to grid search.",
+     )
+     parser.add_argument(
+         "--weight_decays",
+         nargs="+",
+         type=float,
+         help="Weight decays to grid search.",
+     )
+     parser.add_argument(
+         "--n-last-blocks",
+         type=int,
+         help="Number of last backbone blocks used for the linear classifier",
+     )
+     parser.add_argument(
+         "--no-resume",
+         action="store_true",
+         help="Whether to not resume from existing checkpoints",
+     )
+     parser.add_argument(
+         "--val-metric-type",
+         type=MetricType,
+         choices=list(MetricType),
+         help="Validation metric",
+     )
+     parser.add_argument(
+         "--test-metric-types",
+         type=MetricType,
+         choices=list(MetricType),
+         nargs="+",
+         help="Evaluation metric",
+     )
+     parser.add_argument(
+         "--classifier-fpath",
+         type=str,
+         help="Path to a file containing pretrained linear classifiers",
+     )
+     parser.add_argument(
+         "--val-class-mapping-fpath",
+         type=str,
+         help="Path to a file containing a mapping to adjust classifier outputs",
+     )
+     parser.add_argument(
+         "--test-class-mapping-fpaths",
+         nargs="+",
+         type=str,
+         help="Path to a file containing a mapping to adjust classifier outputs",
+     )
+     parser.add_argument(
+         "--loss-type",
+         type=LossType,
+         help="Cross entropy or binary cross entropy; defaults to cross-entropy loss",
+     )
+     parser.add_argument(
+         "--bag-of-channels",
+         action="store_true",
+         help='Whether to use the "bag of channels" channel-adaptive strategy',
+     )
+     parser.add_argument(
+         "--leave-one-out-dataset",
+         type=str,
+         help="Path with indices to use the leave-one-out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
+     )
+     parser.add_argument(
+         "--crop-size",
+         type=int,
+         help="Crop size for train and eval",
+     )
+     parser.add_argument(
+         "--resize-size",
+         type=int,
+         help="Resize size applied just before the crop; 0 means no resize",
+     )
+     parser.add_argument(
+         "--avgpool",
+         action="store_true",
+         help="Whether to use average pooling of patch tokens in addition to CLS tokens",
+     )
+     parser.add_argument(
+         "--scheduler",
+         type=SchedulerType,
+         help="Scheduler type",
+     )
+
+     parser.set_defaults(
+         train_dataset_str="ImageNet:split=TRAIN",
+         val_dataset_str="ImageNet:split=VAL",
+         test_dataset_strs=None,
+         epochs=30,
+         batch_size=64,
+         num_workers=8,
+         epoch_length=145,
+         save_checkpoint_frequency=1250,
+         eval_period_iterations=1250,
+         learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, 1.0],
+         weight_decays=[0.0, 0.0001, 1.0e-05],
+         val_metric_type=MetricType.MEAN_ACCURACY,
+         test_metric_types=None,
+         classifier_fpath=None,
+         val_class_mapping_fpath=None,
+         test_class_mapping_fpaths=[None],
+         loss_type=LossType.CROSS_ENTROPY,
+         crop_size=384,
+         resize_size=0,
+         n_last_blocks=4,
+         avgpool=False,
+         scheduler=SchedulerType.COSINE_ANNEALING,
+     )
+     return parser
+
+
+ def has_ddp_wrapper(m: nn.Module) -> bool:
+     return isinstance(m, DistributedDataParallel)
+
+
+ def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
+     return m.module if has_ddp_wrapper(m) else m
+
+
+ def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool, bag_of_channels):
+     intermediate_output = x_tokens_list[-use_n_blocks:]
+     output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1)
+     if bag_of_channels:
+         if use_avgpool:
+             output = torch.cat(
+                 (
+                     output,
+                     torch.mean(intermediate_output[-1][0], dim=-2).reshape(intermediate_output[-1][0].shape[0], -1),
+                     # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
+                 ),
+                 dim=-1,
+             )  # concatenate the average-pooled patch tokens to the concatenated class tokens
+     else:
+         if use_avgpool:
+             output = torch.cat(
+                 (
+                     output,
+                     torch.mean(intermediate_output[-1][0], dim=1),  # patch tokens
+                 ),
+                 dim=-1,
+             )
+             output = output.reshape(output.shape[0], -1)
+     return output.float()
+
+
+ class LinearClassifier(nn.Module):
+     """Linear layer to train on top of frozen features"""
+
+     def __init__(
+         self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, bag_of_channels=False, leave_one_out=False
+     ):
+         super().__init__()
+         self.out_dim = out_dim
+         self.use_n_blocks = use_n_blocks
+         self.use_avgpool = use_avgpool
+         self.num_classes = num_classes
+         self.bag_of_channels = bag_of_channels
+         self.leave_one_out = leave_one_out
+         self.linear = nn.Linear(out_dim, num_classes)
+         self.linear.weight.data.normal_(mean=0.0, std=0.01)
+         self.linear.bias.data.zero_()
+
+     def forward(self, x_tokens_list):
+         if self.leave_one_out:
+             return self.linear(x_tokens_list)
+         output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool, self.bag_of_channels)
+         return self.linear(output)
+
+
+ class AllClassifiers(nn.Module):
+     def __init__(self, classifiers_dict):
+         super().__init__()
+         self.classifiers_dict = nn.ModuleDict()
+         self.classifiers_dict.update(classifiers_dict)
+
+     def forward(self, inputs):
+         return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()}
+
+     def __len__(self):
+         return len(self.classifiers_dict)
+
+
+ class LinearPostprocessor(nn.Module):
+     def __init__(self, linear_classifier, class_mapping=None):
+         super().__init__()
+         self.linear_classifier = linear_classifier
+         self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping))
+
+     def forward(self, samples, targets):
+         preds = self.linear_classifier(samples)
+         return {
+             "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds,
+             "target": targets,
+         }
+
+
+ def scale_lr(learning_rates, batch_size):
+     return learning_rates * (batch_size * distributed.get_global_size()) / 256.0
+
+
+ def setup_linear_classifiers(
+     sample_output,
+     n_last_blocks_list,
+     learning_rates,
+     weight_decays,
+     batch_size,
+     num_classes=1000,
+     bag_of_channels=False,
+     leave_one_out=False,
+     avgpool=False,
+ ):
+     linear_classifiers_dict = nn.ModuleDict()
+     avgpool_value = avgpool
+     optim_param_groups = []
+     for n in n_last_blocks_list:
+         for avgpool in [avgpool_value]:
+             for _lr in learning_rates:
+                 for wd in weight_decays:
+                     lr = scale_lr(_lr, batch_size)
+                     out_dim = create_linear_input(
+                         sample_output, use_n_blocks=n, use_avgpool=avgpool, bag_of_channels=bag_of_channels
+                     ).shape[1]
+                     linear_classifier = LinearClassifier(
+                         out_dim,
+                         use_n_blocks=n,
+                         use_avgpool=avgpool,
+                         num_classes=num_classes,
+                         bag_of_channels=bag_of_channels,
+                         leave_one_out=leave_one_out,
+                     )
+                     linear_classifier = linear_classifier.cuda()
+                     linear_classifiers_dict[
+                         f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}_wd_{wd:.2E}".replace(".", "_")
+                     ] = linear_classifier
+                     optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr, "weight_decay": wd})
+
+     linear_classifiers = AllClassifiers(linear_classifiers_dict)
+     if distributed.is_enabled():
+         linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)
+
+     return linear_classifiers, optim_param_groups
+
+
+ def make_eval_data_loader(
+     *,
+     test_dataset_str_or_path_or_loo_dataset,
+     config,
+     batch_size,
+     num_workers,
+ ):
+     if isinstance(test_dataset_str_or_path_or_loo_dataset, str):
+         logger.info(f"Loading dataset {test_dataset_str_or_path_or_loo_dataset}")
+         transform = make_classification_eval_cell_transform(
+             normalization_type=NormalizationType.SELF_NORM_CENTER_CROP,
+             resize_size=config["resize_size"],
+             crop_size=config["crop_size"],
+         )
+         test_dataset = make_dataset(dataset_str=test_dataset_str_or_path_or_loo_dataset, transform=transform)
+         collate_fn = None
+     else:
+         logger.info("Making data loader for a feature dataset (typically used in leave-one-out evaluation)")
+         test_dataset = test_dataset_str_or_path_or_loo_dataset
+         collate_fn = None
+     class_mapping = None
+     if hasattr(test_dataset, "get_imagenet_class_mapping"):
+         class_mapping = test_dataset.get_imagenet_class_mapping()
+
+     test_data_loader = make_data_loader(
+         dataset=DatasetWithEnumeratedTargets(
+             test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
+         ),
+         batch_size=batch_size,
+         num_workers=num_workers,
+         sampler_type=SamplerType.DISTRIBUTED,
+         drop_last=False,
+         shuffle=False,
+         persistent_workers=False,
+         collate_fn=collate_fn,
+     )
+     return test_data_loader, class_mapping
+
+
+ @dataclass
+ class Evaluator:
+     batch_size: int
+     num_workers: int
+     dataset_str_or_path: str
+     config: Dict
+     metric_type: MetricType
+     metrics_file_path: str
+     training_num_classes: int
+     save_results_func: Optional[Callable]
+     val_dataset_loo: Optional[TensorDataset] = None
+
+     def __post_init__(self):
+         self.main_metric_name = f"{self.dataset_str_or_path}_accuracy"
+
+         if self.val_dataset_loo is not None:
+             self.dataset_str_or_path = self.val_dataset_loo
+
+         self.data_loader, self.class_mapping = make_eval_data_loader(
+             test_dataset_str_or_path_or_loo_dataset=self.dataset_str_or_path,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             config=self.config,
+         )
+
+     @torch.no_grad()
+     def _evaluate_linear_classifiers(
+         self,
+         *,
+         feature_model,
+         linear_classifiers,
+         iteration,
+         prefixstring="",
+         best_classifier_on_val=None,
+         accumulate_results=False,
+         test_mode=False,
+     ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]:
+         logger.info("Running validation!")
+
+         num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes
+         metric = build_metric(self.metric_type, num_classes=num_classes)
+         postprocessors = {
+             k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items()
+         }
+         metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}
+
+         _, results_dict_temp, accumulated_results = evaluate_with_accumulate(
+             feature_model,
+             self.data_loader,
+             postprocessors,
+             metrics,
+             torch.cuda.current_device(),
+             accumulate_results=accumulate_results,
+             leave_one_out=self.config["leave_one_out"],
+             test_mode=test_mode,
+         )
+
+         logger.info("")
+         results_dict = {}
+         max_accuracy = 0
+         best_classifier = ""
+         for _, (classifier_string, metric) in enumerate(results_dict_temp.items()):
+             logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
+             if (
+                 best_classifier_on_val is None and metric["top-1"].item() > max_accuracy
+             ) or classifier_string == best_classifier_on_val:
+                 max_accuracy = metric["top-1"].item()
+                 best_classifier = classifier_string
+
+         results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}
+
+         logger.info(f"best classifier: {results_dict['best_classifier']}")
+
+         accumulated_best_results = None
+         if test_mode:
+             accumulated_best_results = accumulated_results
+         elif accumulated_results is not None:
+             accumulated_best_results = accumulated_results[best_classifier]
+
+         if distributed.is_main_process():
+             with open(self.metrics_file_path, "a") as f:
+                 f.write(f"iter: {iteration}\n")
+                 for k, v in results_dict.items():
+                     f.write(json.dumps({k: v}) + "\n")
+                 f.write("\n")
+
+         return results_dict, accumulated_best_results
+
+     def evaluate_and_maybe_save(
+         self,
+         feature_model,
+         linear_classifiers,
+         iteration: int,
+         best_classifier_on_val: Optional[Any] = None,
+         save_filename_suffix: str = "",
+         prefixstring: str = "",
+         test_mode: bool = False,
+     ):
+         logger.info(f"Testing on {self.dataset_str_or_path}")
+         save_results = self.save_results_func is not None
+         full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers(
+             feature_model=feature_model,
+             linear_classifiers=remove_ddp_wrapper(linear_classifiers),
+             iteration=iteration,
+             prefixstring=prefixstring,
+             best_classifier_on_val=best_classifier_on_val,
+             accumulate_results=save_results,
+             test_mode=test_mode,
+         )
+         if self.save_results_func is not None:
+             self.save_results_func(
+                 filename_suffix=f"{self.dataset_str_or_path}{save_filename_suffix}", **accumulated_best_results
+             )
+
+         results_dict = {
+             self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"],
+             "best_classifier": full_results_dict["best_classifier"]["name"],
+         }
+         return results_dict, accumulated_best_results
+
+
+ def make_evaluators(
+     config: Dict,
+     val_metric_type: MetricType,
+     val_dataset: str,
+     metric_type: MetricType,
+     metrics_file_path: str,
+     training_num_classes: int,
+     save_results_func: Optional[Callable],
+     val_dataset_loo: Optional[TensorDataset] = None,
+ ):
+     test_metric_types = config["test_metric_types"]
+     test_dataset_strs = config["test_datasets"]
+     if test_dataset_strs is None:
+         test_dataset_strs = (config["val_dataset"],)
+     if test_metric_types is None:
+         test_metric_types = (val_metric_type,)
+     else:
+         assert len(test_metric_types) == len(config["test_datasets"])
+
+     val_evaluator, *test_evaluators = [
+         Evaluator(
+             dataset_str_or_path=dataset_str_or_path,
+             batch_size=config["batch_size"],
+             num_workers=config["num_workers"],
+             config=config,
+             metric_type=metric_type,
+             metrics_file_path=metrics_file_path,
+             training_num_classes=training_num_classes,
+             save_results_func=save_results_func,
+             val_dataset_loo=val_dataset_loo,
+         )
+         for dataset_str_or_path, metric_type in zip(
+             (val_dataset,) + tuple(test_dataset_strs),
+             (val_metric_type,) + tuple(test_metric_types),
+         )
+     ]
+     return val_evaluator, test_evaluators
+
+
+ class SchedulerType(Enum):
+     COSINE_ANNEALING = "cosine_annealing"
+     ONE_CYCLE = "one_cycle"
+
+     def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter):
+         if self == SchedulerType.ONE_CYCLE:
+             lr_list = [optim_param_groups[i]["lr"] for i in range(len(optim_param_groups))]
+             scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                 optimizer, max_lr=lr_list, steps_per_epoch=epoch_length, epochs=epochs
+             )
+         else:
+             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
+             print("CosineAnnealingLR scheduler")
+         return scheduler
+
+
+ def setup_linear_training(
+     *,
+     config: Dict,
+     sample_output: torch.Tensor,
+     training_num_classes: int,
+     checkpoint_output_dir: str,
+ ):
+     linear_classifiers, optim_param_groups = setup_linear_classifiers(
+         sample_output,
+         config["n_last_blocks_list"],
+         config["learning_rates"],
+         config["weight_decays"],
+         config["batch_size"],
+         training_num_classes,
+         config["bag_of_channels"],
+         config["leave_one_out"],
+         config["avgpool"],
+     )
+     max_iter = config["epochs"] * config["epoch_length"]
+     optimizer = torch.optim.AdamW(optim_param_groups, weight_decay=0)
+
+     scheduler = config["scheduler"].get_scheduler(
+         optimizer=optimizer,
+         optim_param_groups=optim_param_groups,
+         epoch_length=config["epoch_length"],
+         epochs=config["epochs"],
+         max_iter=max_iter,
+     )
+     checkpoint_period = config["save_checkpoint_iterations"] or config["epoch_length"]
+     periodic_checkpointer = build_periodic_checkpointer(
+         linear_classifiers,
+         checkpoint_output_dir,
+         optimizer=optimizer,
+         scheduler=scheduler,
+         period=checkpoint_period,
+         max_iter=max_iter,
+         max_to_keep=None,
+     )
+     checkpoint = resume_or_load(periodic_checkpointer, config["classifier_fpath"] or "", resume=config["resume"])
+
+     start_iter = checkpoint.get("iteration", -1) + 1
+     best_accuracy = checkpoint.get("best_accuracy", -1)
+
+     if config["loss_type"] == LossType.BINARY_CROSS_ENTROPY:
+         criterion = nn.BCEWithLogitsLoss()
+     else:
+         criterion = nn.CrossEntropyLoss()
+
+     return (
+         linear_classifiers,
+         start_iter,
+         max_iter,
+         criterion,
+         optimizer,
+         scheduler,
+         periodic_checkpointer,
+         best_accuracy,
+     )
+
+
+ def train_linear_classifiers(
+     *,
+     feature_model,
+     train_dataset,
+     train_config: Dict,
+     training_num_classes: int,
+     val_evaluator: Evaluator,
+     checkpoint_output_dir: str,
+     sample_output: Optional[torch.Tensor] = None,
+ ):
+     if train_config["leave_one_out"]:
+         assert sample_output is not None, "sample_output should be passed as argument when using leave_one_out."
+     else:
+         sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())
+
+     (
+         linear_classifiers,
+         start_iter,
+         max_iter,
+         criterion,
+         optimizer,
+         scheduler,
+         periodic_checkpointer,
+         best_accuracy,
+     ) = setup_linear_training(
+         config=train_config,
+         sample_output=sample_output,
+         training_num_classes=training_num_classes,
+         checkpoint_output_dir=checkpoint_output_dir,
+     )
+
+     sampler_type = SamplerType.INFINITE
+     train_data_loader = make_data_loader(
+         dataset=train_dataset,
+         batch_size=train_config["batch_size"],
+         num_workers=train_config["num_workers"],
+         shuffle=True,
+         seed=0,
+         sampler_type=sampler_type,
+         sampler_advance=start_iter,
+         drop_last=True,
+         persistent_workers=True,
+     )
+     eval_period = train_config["eval_period_iterations"] or train_config["epoch_length"]
+     iteration = start_iter
+     logger.info("Starting training from iteration {}".format(start_iter))
+     metric_logger = MetricLogger(delimiter="  ")
+     header = "Training"
+
+     for data, labels in metric_logger.log_every(
+         train_data_loader,
+         10,
+         header,
+         max_iter,
+         start_iter,
+     ):
+         data = data.cuda(non_blocking=True)
+         labels = labels.cuda(non_blocking=True)
+
+         if not train_config["leave_one_out"]:
+             in_classifier = feature_model(data)
+         else:
+             in_classifier = data
+
+         outputs = linear_classifiers(in_classifier)
+
+         if len(labels.shape) > 1:
+             labels = labels.float()
+         losses = {f"loss_{k}": criterion(v, labels) for k, v in outputs.items()}
+         loss = sum(losses.values())
+
+         optimizer.zero_grad()
+         loss.backward()
+
+         optimizer.step()
+         scheduler.step()
+
+         if iteration % 10 == 0:
+             torch.cuda.synchronize()
+             metric_logger.update(loss=loss.item())
+             metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+         periodic_checkpointer.step(iteration=iteration, best_accuracy=best_accuracy)
+
+         if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
+             val_results_dict, _ = val_evaluator.evaluate_and_maybe_save(
+                 feature_model=feature_model,
+                 linear_classifiers=linear_classifiers,
+                 prefixstring=f"ITER: {iteration}",
+                 iteration=iteration,
+             )
+             val_accuracy = val_results_dict[val_evaluator.main_metric_name]
+             if val_accuracy >= best_accuracy:
+                 best_accuracy = val_accuracy
+                 periodic_checkpointer.save_best(iteration=iteration, best_accuracy=best_accuracy)
+             torch.distributed.barrier()
+
+         iteration = iteration + 1
+
+     return feature_model, linear_classifiers, iteration, periodic_checkpointer
+
+
+ def eval_linear_with_model(
+     model,
+     output_dir,
+     train_dataset_str,
+     val_dataset_str,
+     batch_size,
+     epochs,
+     epoch_length,
+     num_workers,
+     save_checkpoint_frequency,
+     eval_period_iterations,
+     learning_rates,
+     weight_decays,
+     autocast_dtype,
+     test_dataset_strs=None,
+     resume=True,
+     classifier_fpath=None,
+     val_metric_type=MetricType.MEAN_ACCURACY,
+     test_metric_types=None,
+     loss_type=LossType.CROSS_ENTROPY,
+     bag_of_channels=False,
+     leave_one_out_dataset="",
+     resize_size=0,
+     crop_size=384,
+     n_last_blocks=4,
+     avgpool=False,
+     scheduler=SchedulerType.COSINE_ANNEALING,
+ ):
+     if leave_one_out_dataset == "" or leave_one_out_dataset is None:
+         leave_one_out = False
+     else:
+         logger.info("Reading the leave-one-out label metadata.")
+
+         leave_one_out_indices = {}
+         metadata = pd.read_csv(leave_one_out_dataset)
+         if "HPA" in leave_one_out_dataset:
+             metadata = metadata[metadata["Task_three"]].reset_index()
+             leave_one_out_label_type = "cell_type"
+         else:
+             metadata = metadata[metadata["Task_four"]].reset_index()
+             leave_one_out_label_type = "Plate"
+         leave_one_out_labels = metadata[leave_one_out_label_type].unique()
+
+         for leave_one_out_label in leave_one_out_labels:
+             leave_one_out_indices[leave_one_out_label] = np.array(
+                 metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values
+             )
+
+         leave_one_out = True
+
+     train_transform = make_classification_eval_cell_transform(
+         normalization_type=NormalizationType.SELF_NORM_AUG_DECODER, crop_size=crop_size, resize_size=resize_size
+     )
+     print("train_transform", train_transform)
+     train_dataset = make_dataset(
+         dataset_str=train_dataset_str,
+         transform=train_transform,
+     )
+
+     training_num_classes = get_num_classes(train_dataset)
+     if leave_one_out:
+         training_num_classes += train_dataset.num_additional_labels_loo_eval
+     train_dataset_dict = create_train_dataset_dict(train_dataset)
+     n_last_blocks_list = [n_last_blocks]
+     n_last_blocks = max(n_last_blocks_list)
+     dataset_use_cache = True
+     autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
+     feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
+
+     if bag_of_channels:
+         sample = train_dataset[0][0].unsqueeze(0)
+         sample_output = feature_model(sample.cuda())
+
+     if leave_one_out:
+         loo_dict = {}
+         train_data_dict = extract_features_for_dataset_dict(
+             feature_model,
+             train_dataset_dict,
+             batch_size,
+             num_workers,
+             gather_on_cpu=True,
+             avgpool=avgpool,
+         )
+         val_dataset = make_dataset(
+             dataset_str=val_dataset_str,
+             transform=make_classification_eval_cell_transform(
+                 normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, crop_size=crop_size, resize_size=resize_size
+             ),
+         )
+         val_dataset_dict = create_train_dataset_dict(val_dataset)
+         val_data_dict = extract_features_for_dataset_dict(
+             feature_model,
+             val_dataset_dict,
+             batch_size,
+             num_workers,
+             gather_on_cpu=True,
+             avgpool=avgpool,
+         )
+
+         train_features = train_data_dict[0]["train_features"]
+         train_labels = train_data_dict[0]["train_labels"]
+         val_features = val_data_dict[0]["train_features"]
+         val_labels = val_data_dict[0]["train_labels"]
+
+         for loo_label, loo_indices in leave_one_out_indices.items():
+             loo_for_training_indices = torch.ones(val_features.shape[0], dtype=bool)
+             loo_for_training_indices[loo_indices] = False
+             loo_for_val_indices = torch.zeros(val_features.shape[0], dtype=bool)
+             loo_for_val_indices[loo_indices] = True
+
+             loo_dict[loo_label] = {
+                 "train_features": torch.cat([train_features, val_features[loo_for_training_indices]]),
+                 "train_labels": torch.cat([train_labels, val_labels[loo_for_training_indices]]),
+                 "val_features": val_features[loo_indices],
+                 "val_labels": val_labels[loo_indices],
+             }
+     save_results_func = None
+     # if config.save_results:
+     #     save_results_func = partial(default_save_results_func, output_dir=output_dir)
+
+     metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
+     periodic_checkpointers: list = []
+
+     train_config = {
+         "learning_rates": learning_rates,
+         "weight_decays": weight_decays,
+         "batch_size": batch_size,
+         "num_workers": num_workers,
+         "dataset_use_cache": dataset_use_cache,
+         "eval_period_iterations": eval_period_iterations,
+         "epoch_length": epoch_length,
+         "leave_one_out": leave_one_out,
+         "bag_of_channels": bag_of_channels,
+         "n_last_blocks_list": n_last_blocks_list,
+         "epochs": epochs,
+         "loss_type": loss_type,
+         "resume": resume,
+         "save_checkpoint_iterations": save_checkpoint_frequency,
+         "classifier_fpath": classifier_fpath,
+         "avgpool": avgpool,
+         "scheduler": scheduler,
+     }
+     config = {
+         "test_metric_types": test_metric_types,
+         "test_datasets": test_dataset_strs,
+         "val_metric_types": val_metric_type,
+         "val_dataset": val_dataset_str,
+         "batch_size": batch_size,
+         "num_workers": num_workers,
+         "leave_one_out": leave_one_out,
+         "crop_size": crop_size,
+         "resize_size": resize_size,
+     }
+     if not leave_one_out:
+         val_evaluator, test_evaluators = make_evaluators(
+             config=config,
+             val_metric_type=val_metric_type,
+             val_dataset=val_dataset_str,
+             metric_type=test_metric_types,
+             metrics_file_path=metrics_file_path,
+             training_num_classes=training_num_classes,
+             save_results_func=save_results_func,
+         )
+     results_dict = {}
+
+     for _try in train_dataset_dict.keys():
+         if len(train_dataset_dict) > 1:
+             checkpoint_output_dir = os.path.join(output_dir, f"checkpoints_{_try}")
+             save_filename_suffix = f"_{_try}"
+         else:
+             checkpoint_output_dir, save_filename_suffix = output_dir, ""
+         os.makedirs(checkpoint_output_dir, exist_ok=True)
+
+         feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
+             train_config=train_config,
+             feature_model=feature_model,
+             train_dataset=train_dataset_dict[_try],
+             training_num_classes=training_num_classes,
+             val_evaluator=val_evaluator,
+             checkpoint_output_dir=checkpoint_output_dir,
930
+ )
931
+ periodic_checkpointers.append(periodic_checkpointer)
932
+ results_dict[_try], _ = val_evaluator.evaluate_and_maybe_save(
933
+ feature_model=feature_model,
934
+ linear_classifiers=linear_classifiers,
935
+ iteration=iteration,
936
+ save_filename_suffix=save_filename_suffix,
937
+ )
938
+ for test_evaluator in test_evaluators:
939
+ eval_results_dict, _ = test_evaluator.evaluate_and_maybe_save(
940
+ feature_model=feature_model,
941
+ linear_classifiers=linear_classifiers,
942
+ iteration=iteration,
943
+ best_classifier_on_val=results_dict[_try]["best_classifier"],
944
+ save_filename_suffix=save_filename_suffix,
945
+ )
946
+ results_dict[_try] = {**eval_results_dict, **results_dict[_try]}
947
+ if len(train_dataset_dict) > 1:
948
+ results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"])
949
+ else:
950
+ results_dict = {**results_dict[_try]}
951
+ else: # if leave one out is True
952
+ test_results_dict = {}
953
+ for loo_label in loo_dict.keys():
954
+
955
+ checkpoint_output_dir, save_filename_suffix = os.path.join(output_dir, f"checkpoints_{loo_label}"), ""
956
+ os.makedirs(checkpoint_output_dir, exist_ok=True)
957
+
958
+ train_dataset_loo = TensorDataset(
959
+ loo_dict[loo_label]["train_features"], loo_dict[loo_label]["train_labels"]
960
+ )
961
+
962
+ logger.info(f"Creating leave_one_out evaluators. loo_label: {loo_label}")
963
+ val_dataset_loo = TensorDataset(loo_dict[loo_label]["val_features"], loo_dict[loo_label]["val_labels"])
964
+ val_evaluators_loo, _ = make_evaluators(
965
+ config=config,
966
+ val_metric_type=val_metric_type,
967
+ val_dataset="loo",
968
+ metric_type=test_metric_types,
969
+ metrics_file_path=metrics_file_path,
970
+ training_num_classes=training_num_classes,
971
+ save_results_func=save_results_func,
972
+ val_dataset_loo=val_dataset_loo,
973
+ )
974
+ feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
975
+ feature_model=feature_model,
976
+ train_dataset=train_dataset_loo,
977
+ train_config=train_config,
978
+ training_num_classes=training_num_classes,
979
+ val_evaluator=val_evaluators_loo,
980
+ checkpoint_output_dir=checkpoint_output_dir,
981
+ sample_output=sample_output,
982
+ )
983
+ periodic_checkpointers.append(periodic_checkpointer)
984
+ _, test_results_dict[loo_label] = val_evaluators_loo.evaluate_and_maybe_save(
985
+ feature_model=feature_model,
986
+ linear_classifiers=linear_classifiers,
987
+ iteration=iteration,
988
+ save_filename_suffix=save_filename_suffix,
989
+ test_mode=True,
990
+ )
991
+ classifier_names = test_results_dict[loo_label].keys()
992
+ results_dict = {k: [[], []] for k in classifier_names}
993
+ for ll in test_results_dict.keys():
994
+ for k in classifier_names:
995
+ results_dict[k][0].append(test_results_dict[ll][k][0])
996
+ results_dict[k][1].append(test_results_dict[ll][k][1])
997
+ for k in classifier_names:
998
+ results_dict[k] = [
999
+ np.argmax(torch.cat(results_dict[k][0]).cpu().detach().numpy(), axis=1),
1000
+ torch.cat(results_dict[k][1]).cpu().detach().numpy(),
1001
+ ]
1002
+ results_dict[k] = f1_score(results_dict[k][1], results_dict[k][0], average="macro", labels=[4, 5, 6])
1003
+ logger.info(
1004
+ f"Best performance is for {max(results_dict, key=results_dict.get)}, with F1-Score of {results_dict[max(results_dict, key=results_dict.get)]}"
1005
+ )
1006
+
1007
+ logger.info("Test Results Dict " + str(results_dict))
1008
+ return results_dict
1009
+
1010
+
1011
+ def main(args):
1012
+ model, autocast_dtype = setup_and_build_model(args)
1013
+ eval_linear_with_model(
1014
+ model=model,
1015
+ output_dir=args.output_dir,
1016
+ train_dataset_str=args.train_dataset_str,
1017
+ val_dataset_str=args.val_dataset_str,
1018
+ test_dataset_strs=args.test_dataset_strs,
1019
+ batch_size=args.batch_size,
1020
+ epochs=args.epochs,
1021
+ epoch_length=args.epoch_length,
1022
+ num_workers=args.num_workers,
1023
+ save_checkpoint_frequency=args.save_checkpoint_frequency,
1024
+ eval_period_iterations=args.eval_period_iterations,
1025
+ learning_rates=args.learning_rates,
1026
+ weight_decays=args.weight_decays,
1027
+ autocast_dtype=autocast_dtype,
1028
+ resume=not args.no_resume,
1029
+ classifier_fpath=args.classifier_fpath,
1030
+ val_metric_type=args.val_metric_type,
1031
+ test_metric_types=args.test_metric_types,
1032
+ loss_type=args.loss_type,
1033
+ bag_of_channels=args.bag_of_channels,
1034
+ leave_one_out_dataset=args.leave_one_out_dataset,
1035
+ crop_size=args.crop_size,
1036
+ resize_size=args.resize_size,
1037
+ n_last_blocks=args.n_last_blocks,
1038
+ avgpool=args.avgpool,
1039
+ scheduler=args.scheduler,
1040
+ )
1041
+ return 0
1042
+
1043
+
1044
+ if __name__ == "__main__":
1045
+ description = "DINOv2 linear_cell_dino evaluation"
1046
+ args_parser = get_args_parser(description=description)
1047
+ args = args_parser.parse_args()
1048
+ sys.exit(main(args))
dinov2/eval/cell_dino/utils.py ADDED
@@ -0,0 +1,542 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the CC-by-NC licence,
4
+ # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from typing import Callable, Dict, Optional, Any, List
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torchmetrics import MetricCollection
12
+
13
+ from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
14
+ from dinov2.data import NoOpAccumulator, ResultsAccumulator
15
+ import dinov2.distributed as distributed
16
+ from dinov2.logging import MetricLogger
17
+ from enum import Enum
18
+ from torch.utils.data import Subset
19
+ from torchvision.datasets.vision import StandardTransform
20
+ import numpy as np
21
+ from torch.nn.functional import one_hot, softmax
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ class LossType(Enum):
27
+ CROSS_ENTROPY = "cross_entropy"
28
+ BINARY_CROSS_ENTROPY = "binary_cross_entropy"
29
+
30
+
31
+ class BagOfChannelsModelWithNormalize(nn.Module):
32
+ def __init__(self, model, autocast_ctx, avgpool, n_last_blocks=1):
33
+ super().__init__()
34
+ self.model = model
35
+ self.autocast_ctx = autocast_ctx
36
+ self.n_last_blocks = n_last_blocks
37
+ self.avgpool = avgpool
38
+
39
+ def forward(self, samples):
40
+ with self.autocast_ctx():
41
+ features = self.model.get_intermediate_layers(samples, self.n_last_blocks, return_class_token=True)
42
+ output = create_linear_input(features, self.avgpool, use_n_blocks=self.n_last_blocks)
43
+ return nn.functional.normalize(output, dim=1, p=2)
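+ # The L2 normalization puts every embedding on the unit sphere, so features extracted
+ # per channel "bag" stay directly comparable (e.g. via dot products in knn evaluation).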
44
+
45
+
46
+ @torch.inference_mode()
47
+ def evaluate_with_accumulate(
48
+ model: nn.Module,
49
+ data_loader,
50
+ postprocessors: Dict[str, nn.Module],
51
+ metrics: Dict[str, MetricCollection],
52
+ device: torch.device,
53
+ criterion: Optional[nn.Module] = None,
54
+ test_mode: bool = False,
55
+ accumulate_results: bool = False,
56
+ leave_one_out: bool = False,
57
+ ):
58
+ model.eval()
59
+
60
+ if test_mode:
61
+ output_tensor = {k: [] for k in postprocessors.keys()}
62
+ target_tensor = {k: [] for k in postprocessors.keys()}
63
+
64
+ if criterion is not None:
65
+ criterion.eval()
66
+
67
+ accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator
68
+ accumulators = {k: accumulator_class() for k in postprocessors.keys()}
69
+
70
+ for metric in metrics.values():
71
+ metric = metric.to(device)
72
+
73
+ metric_logger = MetricLogger(delimiter=" ")
74
+ header = "Test:"
75
+
76
+ for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
77
+ if isinstance(targets, list):
78
+ index = targets[0]
79
+ targets = targets[1]
80
+ samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0]
81
+ if len(index) == 0:
82
+ continue
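+ # A list-typed target is assumed to be an (index, target) pair coming from
+ # DatasetWithEnumeratedTargets; entries with a negative index are treated as
+ # padding added for distributed evaluation and are filtered out above.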
83
+
84
+ outputs = samples.to(device) if leave_one_out else model(samples.to(device))
85
+ targets = targets.to(device)
86
+
87
+ if criterion is not None:
88
+ loss = criterion(outputs, targets)
89
+ metric_logger.update(loss=loss.item())
90
+
91
+ for k, metric in metrics.items():
92
+ metric_inputs = postprocessors[k](outputs, targets)
93
+ metric.update(**metric_inputs)
94
+ if test_mode:
95
+ output_tensor[k].append(metric_inputs["preds"])
96
+ target_tensor[k].append(metric_inputs["target"])
97
+ accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index)
98
+
99
+ metric_logger.synchronize_between_processes()
100
+ logger.info(f"Averaged stats: {metric_logger}")
101
+
102
+ stats = {k: metric.compute() for k, metric in metrics.items()}
103
+ metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
104
+
105
+ # accumulator.accumulate() returns None for the NoOpAccumulator
106
+ accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()}
107
+ if test_mode:
108
+ for k in postprocessors.keys():
109
+ output_tensor[k] = torch.cat(output_tensor[k])
110
+ target_tensor[k] = torch.cat(target_tensor[k])
111
+ accumulated_results = {k: [output_tensor[k], target_tensor[k]] for k in postprocessors.keys()}
112
+
113
+ if accumulate_results:
114
+ return metric_logger_stats, stats
115
+ return metric_logger_stats, stats, accumulated_results
116
+
117
+
118
+ def all_gather_and_flatten(tensor_rank):
119
+ tensor_all_ranks = torch.empty(
120
+ distributed.get_global_size(),
121
+ *tensor_rank.shape,
122
+ dtype=tensor_rank.dtype,
123
+ device=tensor_rank.device,
124
+ )
125
+ tensor_list = list(tensor_all_ranks.unbind(0))
126
+ torch.distributed.all_gather(tensor_list, tensor_rank.contiguous())
127
+ return tensor_all_ranks.flatten(end_dim=1)
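+ # With world size W and a per-rank tensor of shape [B, ...], every rank receives the
+ # same gathered tensor of shape [W * B, ...], ordered by rank.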
128
+
129
+
130
+ def extract_features_cell_dino(
131
+ model, dataset, batch_size, num_workers, gather_on_cpu=False, shuffle=False, avgpool=False
132
+ ):
133
+ dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
134
+ sample_count = len(dataset_with_enumerated_targets)
135
+ data_loader = make_data_loader(
136
+ dataset=dataset_with_enumerated_targets,
137
+ batch_size=batch_size,
138
+ num_workers=num_workers,
139
+ sampler_type=SamplerType.DISTRIBUTED,
140
+ drop_last=False,
141
+ shuffle=shuffle,
142
+ )
143
+ return extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu, avgpool=avgpool)
144
+
145
+
146
+ @torch.inference_mode()
147
+ def extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu=False, avgpool=False):
148
+ gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
149
+ metric_logger = MetricLogger(delimiter=" ")
150
+ features, all_labels = None, None
151
+ for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
152
+ samples = samples.cuda(non_blocking=True)
153
+ labels_rank = labels_rank.cuda(non_blocking=True)
154
+ index = index.cuda(non_blocking=True)
155
+ feat = model(samples)
156
+ if isinstance(samples, list) or isinstance(feat, tuple):
157
+ features_rank = create_linear_input(feat, avgpool=avgpool)
158
+ else:
159
+ features_rank = feat
160
+
161
+ # init storage feature matrix
162
+ if features is None:
163
+ features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
164
+ labels_shape = list(labels_rank.shape)
165
+ labels_shape[0] = sample_count
166
+ all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
167
+ logger.info(f"Storing features into tensor of shape {features.shape}")
168
+
169
+ # share indexes, features and labels between processes
170
+ index_all = all_gather_and_flatten(index).to(gather_device)
171
+ features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
172
+ labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)
173
+
174
+ # update storage feature matrix
175
+ if len(index_all) > 0:
176
+ features.index_copy_(0, index_all, features_all_ranks)
177
+ all_labels.index_copy_(0, index_all, labels_all_ranks)
178
+
179
+ logger.info(f"Features shape: {tuple(features.shape)}")
180
+ logger.info(f"Labels shape: {tuple(all_labels.shape)}")
181
+
182
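+ # all_labels was initialised to the -1 sentinel; the assert below checks that the
+ # gathered (index, label) pairs filled every position of the storage tensors.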
+ assert torch.all(all_labels > -1)
183
+
184
+ return features, all_labels
185
+
186
+
187
+ def create_linear_input(x_tokens_list, avgpool=False, use_n_blocks=1):
188
+ intermediate_output = x_tokens_list[-use_n_blocks:]
189
+ output = torch.cat(
190
+ [class_token for _, class_token in intermediate_output], dim=-1
191
+ ) # concatenate class tokens of the last n blocks
192
+ if avgpool:
193
+ output = torch.cat(
194
+ (
195
+ output,
196
+ torch.mean(intermediate_output[-1][0], dim=-2).reshape(
197
+ intermediate_output[-1][0].shape[0], -1
198
+ ), # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
199
+ ),
200
+ dim=-1,
201
+ ) # concatenate average pooling of patch tokens to concatenated patch tokens
202
+ output = output.reshape(output.shape[0], -1)
203
+
204
+ return output.float()
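+ # Shape sketch (assuming a ViT-L backbone, embed_dim=1024, use_n_blocks=4): each entry of
+ # x_tokens_list is (patch_tokens [B, N, 1024], class_token [B, 1024]); concatenating the
+ # four class tokens gives [B, 4096], and avgpool=True appends the mean patch token of the
+ # last block for [B, 5120].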
205
+
206
+
207
+ def get_target_transform(dataset) -> Optional[Callable]:
208
+ if hasattr(dataset, "transforms"):
209
+ if isinstance(dataset.transforms, StandardTransform):
210
+ return dataset.transforms.target_transform
211
+ raise ValueError("Dataset has a non-standard .transforms property")
212
+ if hasattr(dataset, "target_transform"):
213
+ return dataset.target_transform
214
+ return None
215
+
216
+
217
+ def get_labels(dataset) -> torch.Tensor:
218
+ """
219
+ Get the labels of a classification dataset, as a Tensor, using the `get_targets` method
220
+ if it is present or loading the labels one by one with `get_target`, if it exists.
221
+ If the dataset has a target transform, iterate over the whole dataset to get the
222
+ transformed labels for each element, then stack them as a torch tensor.
223
+ """
224
+ logger.info("Getting dataset labels ...")
225
+ if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"):
226
+ if hasattr(dataset, "get_targets"): # Returns a np.array
227
+ labels = dataset.get_targets()
228
+ elif hasattr(dataset, "get_target"):
229
+ labels = [dataset.get_target(i) for i in range(len(dataset))]
230
+ target_transform = get_target_transform(dataset)
231
+ if target_transform is not None:
232
+ labels = [target_transform(label) for label in labels]
233
+ else:
234
+ # Target transform is applied in this case
235
+ labels = [dataset[i][1] for i in range(len(dataset))]
236
+ return torch.stack([torch.tensor(label, dtype=int) for label in labels])
237
+
238
+
239
+ def get_num_classes(dataset) -> int:
240
+ """
241
+ Get the labels of a dataset and compute the number of classes
242
+ """
243
+ labels = get_labels(dataset)
244
+ if len(labels.shape) > 1:
245
+ return int(labels.shape[1])
246
+ return int(labels.max() + 1)
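+ # Example: labels tensor([0, 2, 1]) -> 3 classes; a multi-label one-hot matrix of
+ # shape [N, C] -> C classes.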
247
+
248
+
249
+ def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []):
250
+ """
251
+ Function that computes the average and the std on a metrics dict.
252
+ A linear evaluation dictionary contains "best_classifier",
253
+ so this specific key is removed for computing aggregated metrics.
254
+ """
255
+ output_metrics_dict = {}
256
+ metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys]
257
+ for metric in metrics:
258
+ stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()])
259
+ output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item()
260
+ output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item()
261
+
262
+ return output_metrics_dict
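+ # Example (hypothetical values): {0: {"top-1": 0.80}, 1: {"top-1": 0.90}} ->
+ # {"top-1_mean": 0.85, "top-1_std": ~0.071} (torch.std uses the unbiased estimator).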
263
+
264
+
265
+ def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]:
266
+ """
267
+ Efficiently creates a mapping between the labels and tensors containing
268
+ the indices of all the dataset elements that share this label.
269
+ In the case of multiple labels, it is not guaranteed that there
270
+ will be exactly the specified percentage of labels.
271
+ """
272
+ if len(labels.shape) > 1: # labels are a one-hot encoding
273
+ assert len(labels.shape) == 2
274
+ sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True)
275
+ else:
276
+ sorted_labels, indices = torch.sort(labels, stable=True)
277
+ unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True)
278
+ mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist())))
279
+ return mapping
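+ # Example: labels tensor([1, 0, 1]) -> {0: tensor([1]), 1: tensor([0, 2])}; with one-hot
+ # labels, a dataset element appears under every class it carries.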
280
+
281
+
282
+ def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0):
283
+ """
284
+ Shuffling a dataset by subsetting it with a random permutation of its indices
285
+ """
286
+ random_generator = torch.Generator()
287
+ random_generator.manual_seed(seed)
288
+ random_indices = torch.randperm(len(dataset), generator=random_generator)
289
+ return Subset(dataset, random_indices)
290
+
291
+
292
+ def _subset_dataset_per_class(
293
+ class_indices_mapping: dict[int, torch.Tensor],
294
+ n_or_percent_per_class: float,
295
+ dataset_size: int,
296
+ seed: int = 0,
297
+ is_percent: bool = False,
298
+ ) -> torch.Tensor:
299
+ """
300
+ Helper function to select a percentage of a dataset, equally distributed across classes,
301
+ or to take the same number of elements from each class of the dataset.
302
+ Returns a boolean mask tensor being True at indices of selected elements
303
+ """
304
+
305
+ random_generator = torch.Generator()
306
+ random_generator.manual_seed(seed)
307
+
308
+ final_indices_bool = torch.zeros(dataset_size, dtype=bool)
309
+ for class_indices in class_indices_mapping.values():
310
+ # Select at least one element
311
+ n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class
312
+ assert isinstance(n_for_class, int)
313
+ filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class]
314
+ final_indices_bool[class_indices[filtered_index]] = True
315
+ return final_indices_bool
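+ # Example (hypothetical sizes): with classes of 10 and 50 elements, is_percent=True and
+ # n_or_percent_per_class=0.2, the mask selects 2 + 10 = 12 elements; with is_percent=False
+ # and n_or_percent_per_class=5, it selects 5 elements from each class.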
316
+
317
+
318
+ def _multilabel_rebalance_subset(
319
+ class_indices_mapping: dict[int, torch.Tensor],
320
+ n_or_percent_per_class: float,
321
+ labels: torch.Tensor,
322
+ indices_bool: torch.Tensor,
323
+ dataset_size: int,
324
+ seed: int = 0,
325
+ ) -> torch.Tensor:
326
+ """
327
+ Helper function to refine a subset of a multi-label dataset (indices_bool)
328
+ to better match a target percentage of labels.
329
+ Returns a boolean mask tensor being True at indices of selected elements.
330
+ """
331
+
332
+ # Compute the number of selected labels in indices_bool
333
+ num_total_labels = labels.sum()
334
+ num_wanted_labels = int(num_total_labels * n_or_percent_per_class)
335
+ num_selected_labels = (labels[indices_bool] > 0).sum()
336
+ logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
337
+
338
+ # Compute a new percentage and a new subset that selects fewer images (and therefore fewer labels) to approximately match the target percentage of selected labels
339
+ n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels)
340
+ final_indices_bool = _subset_dataset_per_class(
341
+ class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True
342
+ )
343
+
344
+ # Compute the number of labels finally used
345
+ num_selected_labels = (labels[final_indices_bool] > 0).sum()
346
+ logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
347
+
348
+ return final_indices_bool
349
+
350
+
351
+ def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True):
352
+ """
353
+ Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class.
354
+ If `shuffle` is False, taking the first elements of each class as the validaton set.
355
+ """
356
+ assert 0 < split_percentage < 1
357
+ logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set")
358
+ if shuffle_train:
359
+ logger.info("Shuffling train dataset before splitting in train and validation sets")
360
+ train_dataset = _shuffle_dataset(train_dataset)
361
+ train_labels = get_labels(train_dataset)
362
+ class_indices_mapping = create_class_indices_mapping(train_labels)
363
+ val_mask = torch.zeros(len(train_labels), dtype=bool)
364
+ for class_indices in class_indices_mapping.values():
365
+ # If there is only one element, it goes in the train set
366
+ n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0
367
+ val_mask[class_indices[:n_for_val]] = True
368
+
369
+ val_dataset = Subset(train_dataset, val_mask.nonzero().flatten())
370
+ train_dataset = Subset(train_dataset, (~val_mask).nonzero().flatten())
371
+ return train_dataset, val_dataset
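+ # Usage sketch (hypothetical dataset): train, val = split_train_val_datasets(ds, 0.1)
+ # keeps ~90% of every class for training; classes with a single element stay in train.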
372
+
373
+
374
+ def create_train_dataset_dict(
375
+ train_dataset,
376
+ few_shot_eval: bool = False,
377
+ few_shot_k_or_percent=None,
378
+ few_shot_n_tries: int = 1,
379
+ ) -> dict[int, Any]:
380
+ """
381
+ Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being
382
+ n elements or x% of a class. Produces a dict, which keys are number of random "tries"
383
+ and values are the dataset subset for this "try".
384
+
385
+ Format is {"nth-try": dataset}
386
+ """
387
+ if few_shot_eval is False:
388
+ assert few_shot_k_or_percent is None
389
+ assert few_shot_n_tries == 1
390
+ return {0: train_dataset}
391
+
392
+ assert few_shot_k_or_percent is not None
393
+ train_labels = get_labels(train_dataset)
394
+ class_indices_mapping = create_class_indices_mapping(train_labels)
395
+ train_dataset_dict: dict[int, Any] = {}
396
+ is_percent = few_shot_k_or_percent < 1
397
+ if not is_percent:
398
+ few_shot_k_or_percent = int(few_shot_k_or_percent)
399
+
400
+ for t in range(few_shot_n_tries):
401
+ t_subset_bool = _subset_dataset_per_class(
402
+ class_indices_mapping=class_indices_mapping,
403
+ n_or_percent_per_class=few_shot_k_or_percent,
404
+ dataset_size=len(train_labels),
405
+ is_percent=is_percent,
406
+ seed=t,
407
+ )
408
+ if len(train_labels.shape) > 1 and is_percent:
409
+ t_subset_bool = _multilabel_rebalance_subset(
410
+ class_indices_mapping=class_indices_mapping,
411
+ n_or_percent_per_class=few_shot_k_or_percent,
412
+ dataset_size=len(train_labels),
413
+ labels=train_labels,
414
+ indices_bool=t_subset_bool,
415
+ seed=t,
416
+ )
417
+ train_dataset_dict[t] = Subset(train_dataset, t_subset_bool.nonzero().flatten())
418
+ return train_dataset_dict
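+ # Usage sketch (hypothetical values): create_train_dataset_dict(ds, few_shot_eval=True,
+ # few_shot_k_or_percent=5, few_shot_n_tries=3) -> {0: Subset, 1: Subset, 2: Subset},
+ # each subset holding 5 elements per class drawn with a different seed.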
419
+
420
+
421
+ def extract_features_for_dataset_dict(
422
+ model,
423
+ dataset_dict: dict[int, Any],
424
+ batch_size: int,
425
+ num_workers: int,
426
+ gather_on_cpu=False,
427
+ avgpool=False,
428
+ ) -> dict[int, dict[str, torch.Tensor]]:
429
+ """
430
+ Extract features for each subset of dataset in the context of few-shot evaluations
431
+ """
432
+ few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {}
433
+ for try_n, dataset in dataset_dict.items():
434
+ features, labels = extract_features_cell_dino(
435
+ model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu, avgpool=avgpool
436
+ )
437
+ few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels}
438
+ return few_shot_data_dict
439
+
440
+
441
+ def pad_multilabel_and_collate(batch, pad_value=-1):
442
+ """
443
+ This method pads and collates a batch of (image, (index, target)) tuples, coming from
444
+ DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes.
445
+ The targets are padded to the length of the longest target list in the batch.
446
+ """
447
+ maxlen = max(len(targets) for _, (_, targets) in batch)
448
+ padded_batch = [
449
+ (image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value)))
450
+ for image, (index, targets) in batch
451
+ ]
452
+ return torch.utils.data.default_collate(padded_batch)
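+ # Example (hypothetical targets): [(img_a, (0, [3, 7])), (img_b, (1, [5]))] is padded to
+ # [(img_a, (0, [3, 7])), (img_b, (1, [5, -1]))] before default_collate, so consumers can
+ # mask out entries equal to pad_value.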
453
+
454
+
455
+ class KnnModule(torch.nn.Module):
456
+ """
457
+ Gets knn of test features from all processes on a chunk of the train features
458
+
459
+ Each rank gets a chunk of the train features as well as a chunk of the test features.
460
+ In `compute_neighbors`, for each rank one after the other, its chunk of test features
461
+ is sent to all devices, partial knns are computed with each chunk of train features
462
+ then collated back on the original device.
463
+ """
464
+
465
+ def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
466
+ super().__init__()
467
+
468
+ self.global_rank = distributed.get_global_rank()
469
+ self.global_size = distributed.get_global_size()
470
+
471
+ self.device = device
472
+ self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
473
+ # Labels can either be integers, or in a one-hot format
474
+ self.candidates = train_labels.chunk(self.global_size)[self.global_rank].unsqueeze(0).to(self.device)
475
+ self.nb_knn = nb_knn
476
+ self.max_k = max(self.nb_knn)
477
+ self.T = T
478
+ self.num_classes = num_classes
479
+
480
+ def _get_knn_sims_and_labels(self, similarity, train_labels):
481
+ topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
482
+ if len(train_labels.shape) == 3: # If the labels are in one_hot format
483
+ indices = indices.unsqueeze(2).expand(-1, -1, self.num_classes) # Originally [bs, max_k]
484
+ neighbors_labels = torch.gather(train_labels, 1, indices)
485
+ return topk_sims, neighbors_labels
486
+
487
+ def _similarity_for_rank(self, features_rank, source_rank):
488
+ # Send the features from `source_rank` to all ranks
489
+ broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
490
+ torch.distributed.broadcast(broadcast_shape, source_rank)
491
+
492
+ broadcasted = features_rank
493
+ if self.global_rank != source_rank:
494
+ broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
495
+ torch.distributed.broadcast(broadcasted, source_rank)
496
+
497
+ # Compute the neighbors for `source_rank` among `train_features_rank_T`
498
+ similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
499
+ candidate_labels = self.candidates.expand(len(similarity_rank), *self.candidates.shape[1:])
500
+ return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
501
+
502
+ def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
503
+ # Gather all neighbors for `target_rank`
504
+ topk_sims_rank = retrieved_rank = None
505
+ if self.global_rank == target_rank:
506
+ topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
507
+ retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
508
+
509
+ torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
510
+ torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
511
+
512
+ if self.global_rank == target_rank:
513
+ # Perform a second top-k on the k * global_size retrieved neighbors
514
+ topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
515
+ retrieved_rank = torch.cat(retrieved_rank, dim=1)
516
+ results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
517
+ return results
518
+ return None
519
+
520
+ def compute_neighbors(self, features_rank):
521
+ for rank in range(self.global_size):
522
+ topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
523
+ results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
524
+ if results is not None:
525
+ topk_sims_rank, neighbors_labels_rank = results
526
+ return topk_sims_rank, neighbors_labels_rank
527
+
528
+ def forward(self, features_rank):
529
+ """
530
+ Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
531
+ """
532
+ assert all(k <= self.max_k for k in self.nb_knn)
533
+
534
+ topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
535
+ batch_size = neighbors_labels.shape[0]
536
+ topk_sims_transform = softmax(topk_sims / self.T, 1)
537
+ voting_coefficient = topk_sims_transform.view(batch_size, -1, 1)
538
+ if len(neighbors_labels.shape) == 2: # If the labels are not yet one hot
539
+ neighbors_labels = one_hot(neighbors_labels, num_classes=self.num_classes)
540
+ matmul = torch.mul(neighbors_labels, voting_coefficient)
541
+ probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
542
+ return probas_for_k
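+ # Usage sketch (hypothetical shapes/values): with L2-normalized train_features [N, D],
+ # knn = KnnModule(train_features, train_labels, nb_knn=(10, 20), T=0.07,
+ # device=torch.device("cuda"), num_classes=28)
+ # probas = knn(test_features_chunk) # -> {10: [B, 28], 20: [B, 28]} soft class votes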
dinov2/eval/depth/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
dinov2/eval/depth/models/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .backbones import * # noqa: F403
7
+ from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
8
+ from .decode_heads import * # noqa: F403
9
+ from .depther import * # noqa: F403
10
+ from .losses import * # noqa: F403
dinov2/eval/depth/models/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .vision_transformer import DinoVisionTransformer
dinov2/eval/depth/models/backbones/vision_transformer.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from mmcv.runner import BaseModule
7
+
8
+ from ..builder import BACKBONES
9
+
10
+
11
+ @BACKBONES.register_module()
12
+ class DinoVisionTransformer(BaseModule):
13
+ """Vision Transformer."""
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__()
dinov2/eval/depth/models/builder.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ from mmcv.cnn import MODELS as MMCV_MODELS
9
+ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
10
+ from mmcv.utils import Registry
11
+
12
+ MODELS = Registry("models", parent=MMCV_MODELS)
13
+ ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
14
+
15
+
16
+ BACKBONES = MODELS
17
+ NECKS = MODELS
18
+ HEADS = MODELS
19
+ LOSSES = MODELS
20
+ DEPTHER = MODELS
21
+
22
+
23
+ def build_backbone(cfg):
24
+ """Build backbone."""
25
+ return BACKBONES.build(cfg)
26
+
27
+
28
+ def build_neck(cfg):
29
+ """Build neck."""
30
+ return NECKS.build(cfg)
31
+
32
+
33
+ def build_head(cfg):
34
+ """Build head."""
35
+ return HEADS.build(cfg)
36
+
37
+
38
+ def build_loss(cfg):
39
+ """Build loss."""
40
+ return LOSSES.build(cfg)
41
+
42
+
43
+ def build_depther(cfg, train_cfg=None, test_cfg=None):
44
+ """Build depther."""
45
+ if train_cfg is not None or test_cfg is not None:
46
+ warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
47
+ assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
48
+ assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
49
+ return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
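+ # Usage sketch (hypothetical config): the registry builds models from config dicts, e.g.
+ # depther = build_depther(dict(type="DepthEncoderDecoder",
+ # backbone=dict(type="DinoVisionTransformer"),
+ # decode_head=dict(type="..."))) # every "type" must name a registered module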