Addax-Data-Science commited on
Commit
3d18251
·
verified ·
1 Parent(s): c8ae7e8

Delete dinov2

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. dinov2/__init__.py +0 -6
  2. dinov2/configs/__init__.py +0 -22
  3. dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml +0 -35
  4. dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml +0 -14
  5. dinov2/configs/eval/vitb14_pretrain.yaml +0 -6
  6. dinov2/configs/eval/vitb14_reg4_pretrain.yaml +0 -9
  7. dinov2/configs/eval/vitg14_pretrain.yaml +0 -7
  8. dinov2/configs/eval/vitg14_reg4_pretrain.yaml +0 -10
  9. dinov2/configs/eval/vitl14_pretrain.yaml +0 -6
  10. dinov2/configs/eval/vitl14_reg4_pretrain.yaml +0 -9
  11. dinov2/configs/eval/vits14_pretrain.yaml +0 -6
  12. dinov2/configs/eval/vits14_reg4_pretrain.yaml +0 -9
  13. dinov2/configs/ssl_default_config.yaml +0 -123
  14. dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml +0 -31
  15. dinov2/configs/train/cell_dino/vitl16_hpafov.yaml +0 -32
  16. dinov2/configs/train/cell_dino/vitl16_hpaone.yaml +0 -30
  17. dinov2/configs/train/vitg14.yaml +0 -26
  18. dinov2/configs/train/vitl14.yaml +0 -26
  19. dinov2/configs/train/vitl16_short.yaml +0 -6
  20. dinov2/data/__init__.py +0 -12
  21. dinov2/data/accumulators.py +0 -133
  22. dinov2/data/adapters.py +0 -51
  23. dinov2/data/augmentations.py +0 -118
  24. dinov2/data/cell_dino/augmentations.py +0 -91
  25. dinov2/data/cell_dino/transforms.py +0 -169
  26. dinov2/data/collate.py +0 -49
  27. dinov2/data/datasets/__init__.py +0 -12
  28. dinov2/data/datasets/cell_dino/chammi_cp.py +0 -112
  29. dinov2/data/datasets/cell_dino/chammi_hpa.py +0 -111
  30. dinov2/data/datasets/cell_dino/chammi_wtc.py +0 -108
  31. dinov2/data/datasets/cell_dino/hpafov.py +0 -283
  32. dinov2/data/datasets/cell_dino/hpaone.py +0 -223
  33. dinov2/data/datasets/decoders.py +0 -94
  34. dinov2/data/datasets/extended.py +0 -44
  35. dinov2/data/datasets/image_net.py +0 -290
  36. dinov2/data/datasets/image_net_22k.py +0 -302
  37. dinov2/data/loaders.py +0 -232
  38. dinov2/data/masking.py +0 -86
  39. dinov2/data/samplers.py +0 -229
  40. dinov2/data/transforms.py +0 -91
  41. dinov2/distributed/__init__.py +0 -270
  42. dinov2/eval/__init__.py +0 -4
  43. dinov2/eval/cell_dino/knn.py +0 -479
  44. dinov2/eval/cell_dino/linear.py +0 -1048
  45. dinov2/eval/cell_dino/utils.py +0 -542
  46. dinov2/eval/depth/__init__.py +0 -4
  47. dinov2/eval/depth/models/__init__.py +0 -10
  48. dinov2/eval/depth/models/backbones/__init__.py +0 -6
  49. dinov2/eval/depth/models/backbones/vision_transformer.py +0 -16
  50. dinov2/eval/depth/models/builder.py +0 -49
dinov2/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- __version__ = "0.0.1"
 
 
 
 
 
 
 
dinov2/configs/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import pathlib
7
-
8
- from omegaconf import OmegaConf
9
-
10
-
11
- def load_config(config_name: str):
12
- config_filename = config_name + ".yaml"
13
- return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
14
-
15
-
16
- dinov2_default_config = load_config("ssl_default_config")
17
-
18
-
19
- def load_and_merge_config(config_name: str):
20
- default_config = OmegaConf.create(dinov2_default_config)
21
- loaded_config = load_config(config_name)
22
- return OmegaConf.merge(default_config, loaded_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/cell_dino/vitl16_channel_adaptive_pretrain.yaml DELETED
@@ -1,35 +0,0 @@
1
- train:
2
- batch_size_per_gpu: 32
3
- OFFICIAL_EPOCH_LENGTH: 450
4
- cell_augmentation: true
5
- channel_adaptive: true
6
- student:
7
- arch: vit_large
8
- patch_size: 16
9
- num_register_tokens: 0
10
- interpolate_antialias: false
11
- interpolate_offset: 0.1
12
- drop_path_rate: 0.1
13
- in_chans: 1
14
- block_chunks: 4
15
- channel_adaptive: true
16
- teacher:
17
- momentum_teacher: 0.996
18
- warmup_teacher_temp_epochs: 20
19
- in_chans: 1
20
- channel_adaptive: true
21
- crops:
22
- global_crops_scale:
23
- - 0.4
24
- - 1.0
25
- local_crops_number: 8
26
- local_crops_scale:
27
- - 0.005
28
- - 0.4
29
- global_crops_size: 224
30
- local_crops_size: 96
31
- optim:
32
- weight_decay_end: 0.2
33
- base_lr: 5.0e-4
34
- warmup_epochs: 20
35
- epochs: 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/cell_dino/vitl16_pretrain.yaml DELETED
@@ -1,14 +0,0 @@
1
- student:
2
- arch: vit_large
3
- patch_size: 16
4
- num_register_tokens: 0
5
- interpolate_antialias: false
6
- interpolate_offset: 0.1
7
- drop_path_rate: 0.1
8
- in_chans: 4
9
- block_chunks: 4
10
- teacher:
11
- in_chans: 4
12
- crops:
13
- global_crops_size: 224
14
- local_crops_size: 96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/vitb14_pretrain.yaml DELETED
@@ -1,6 +0,0 @@
1
- student:
2
- arch: vit_base
3
- patch_size: 14
4
- crops:
5
- global_crops_size: 518 # this is to set up the position embeddings properly
6
- local_crops_size: 98
 
 
 
 
 
 
 
dinov2/configs/eval/vitb14_reg4_pretrain.yaml DELETED
@@ -1,9 +0,0 @@
1
- student:
2
- arch: vit_base
3
- patch_size: 14
4
- num_register_tokens: 4
5
- interpolate_antialias: true
6
- interpolate_offset: 0.0
7
- crops:
8
- global_crops_size: 518 # this is to set up the position embeddings properly
9
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/vitg14_pretrain.yaml DELETED
@@ -1,7 +0,0 @@
1
- student:
2
- arch: vit_giant2
3
- patch_size: 14
4
- ffn_layer: swiglufused
5
- crops:
6
- global_crops_size: 518 # this is to set up the position embeddings properly
7
- local_crops_size: 98
 
 
 
 
 
 
 
 
dinov2/configs/eval/vitg14_reg4_pretrain.yaml DELETED
@@ -1,10 +0,0 @@
1
- student:
2
- arch: vit_giant2
3
- patch_size: 14
4
- ffn_layer: swiglufused
5
- num_register_tokens: 4
6
- interpolate_antialias: true
7
- interpolate_offset: 0.0
8
- crops:
9
- global_crops_size: 518 # this is to set up the position embeddings properly
10
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/vitl14_pretrain.yaml DELETED
@@ -1,6 +0,0 @@
1
- student:
2
- arch: vit_large
3
- patch_size: 14
4
- crops:
5
- global_crops_size: 518 # this is to set up the position embeddings properly
6
- local_crops_size: 98
 
 
 
 
 
 
 
dinov2/configs/eval/vitl14_reg4_pretrain.yaml DELETED
@@ -1,9 +0,0 @@
1
- student:
2
- arch: vit_large
3
- patch_size: 14
4
- num_register_tokens: 4
5
- interpolate_antialias: true
6
- interpolate_offset: 0.0
7
- crops:
8
- global_crops_size: 518 # this is to set up the position embeddings properly
9
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
dinov2/configs/eval/vits14_pretrain.yaml DELETED
@@ -1,6 +0,0 @@
1
- student:
2
- arch: vit_small
3
- patch_size: 14
4
- crops:
5
- global_crops_size: 518 # this is to set up the position embeddings properly
6
- local_crops_size: 98
 
 
 
 
 
 
 
dinov2/configs/eval/vits14_reg4_pretrain.yaml DELETED
@@ -1,9 +0,0 @@
1
- student:
2
- arch: vit_small
3
- patch_size: 14
4
- num_register_tokens: 4
5
- interpolate_antialias: true
6
- interpolate_offset: 0.0
7
- crops:
8
- global_crops_size: 518 # this is to set up the position embeddings properly
9
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
dinov2/configs/ssl_default_config.yaml DELETED
@@ -1,123 +0,0 @@
1
- MODEL:
2
- WEIGHTS: ''
3
- compute_precision:
4
- grad_scaler: true
5
- teacher:
6
- backbone:
7
- sharding_strategy: SHARD_GRAD_OP
8
- mixed_precision:
9
- param_dtype: fp16
10
- reduce_dtype: fp16
11
- buffer_dtype: fp32
12
- dino_head:
13
- sharding_strategy: SHARD_GRAD_OP
14
- mixed_precision:
15
- param_dtype: fp16
16
- reduce_dtype: fp16
17
- buffer_dtype: fp32
18
- ibot_head:
19
- sharding_strategy: SHARD_GRAD_OP
20
- mixed_precision:
21
- param_dtype: fp16
22
- reduce_dtype: fp16
23
- buffer_dtype: fp32
24
- student:
25
- backbone:
26
- sharding_strategy: SHARD_GRAD_OP
27
- mixed_precision:
28
- param_dtype: fp16
29
- reduce_dtype: fp16
30
- buffer_dtype: fp32
31
- dino_head:
32
- sharding_strategy: SHARD_GRAD_OP
33
- mixed_precision:
34
- param_dtype: fp16
35
- reduce_dtype: fp32
36
- buffer_dtype: fp32
37
- ibot_head:
38
- sharding_strategy: SHARD_GRAD_OP
39
- mixed_precision:
40
- param_dtype: fp16
41
- reduce_dtype: fp32
42
- buffer_dtype: fp32
43
- dino:
44
- loss_weight: 1.0
45
- head_n_prototypes: 65536
46
- head_bottleneck_dim: 256
47
- head_nlayers: 3
48
- head_hidden_dim: 2048
49
- koleo_loss_weight: 0.1
50
- ibot:
51
- loss_weight: 1.0
52
- mask_sample_probability: 0.5
53
- mask_ratio_min_max:
54
- - 0.1
55
- - 0.5
56
- separate_head: false
57
- head_n_prototypes: 65536
58
- head_bottleneck_dim: 256
59
- head_nlayers: 3
60
- head_hidden_dim: 2048
61
- train:
62
- batch_size_per_gpu: 64
63
- dataset_path: ImageNet:split=TRAIN
64
- output_dir: .
65
- saveckp_freq: 20
66
- seed: 0
67
- num_workers: 10
68
- OFFICIAL_EPOCH_LENGTH: 1250
69
- cache_dataset: true
70
- centering: "centering" # or "sinkhorn_knopp"
71
- cell_augmentation: false
72
- student:
73
- arch: vit_large
74
- patch_size: 16
75
- drop_path_rate: 0.3
76
- layerscale: 1.0e-05
77
- drop_path_uniform: true
78
- pretrained_weights: ''
79
- ffn_layer: "mlp"
80
- block_chunks: 0
81
- qkv_bias: true
82
- proj_bias: true
83
- ffn_bias: true
84
- num_register_tokens: 0
85
- interpolate_antialias: false
86
- interpolate_offset: 0.1
87
- in_chans: 3
88
- channel_adaptive: false
89
- teacher:
90
- momentum_teacher: 0.992
91
- final_momentum_teacher: 1
92
- warmup_teacher_temp: 0.04
93
- teacher_temp: 0.07
94
- warmup_teacher_temp_epochs: 30
95
- in_chans: 3
96
- channel_adaptive: false
97
- optim:
98
- epochs: 100
99
- weight_decay: 0.04
100
- weight_decay_end: 0.4
101
- base_lr: 0.004 # learning rate for a batch size of 1024
102
- lr: 0. # will be set after applying scaling rule
103
- warmup_epochs: 10
104
- min_lr: 1.0e-06
105
- clip_grad: 3.0
106
- freeze_last_layer_epochs: 1
107
- scaling_rule: sqrt_wrt_1024
108
- patch_embed_lr_mult: 0.2
109
- layerwise_decay: 0.9
110
- adamw_beta1: 0.9
111
- adamw_beta2: 0.999
112
- crops:
113
- global_crops_scale:
114
- - 0.32
115
- - 1.0
116
- local_crops_number: 8
117
- local_crops_scale:
118
- - 0.05
119
- - 0.32
120
- global_crops_size: 224
121
- local_crops_size: 96
122
- evaluation:
123
- eval_period_iterations: 12500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/cell_dino/vitl16_boc_hpafov.yaml DELETED
@@ -1,31 +0,0 @@
1
- train:
2
- batch_size_per_gpu: 16
3
- OFFICIAL_EPOCH_LENGTH: 450
4
- cell_augmentation: true
5
- channel_adaptive: true
6
- student:
7
- arch: vit_large
8
- patch_size: 16
9
- in_chans: 1
10
- drop_path_rate: 0.1
11
- block_chunks: 4
12
- teacher:
13
- momentum_teacher: 0.996
14
- warmup_teacher_temp_epochs: 20
15
- in_chans: 1
16
- crops:
17
- global_crops_scale:
18
- - 0.4
19
- - 1.0
20
- local_crops_number: 8
21
- local_crops_scale:
22
- - 0.005
23
- - 0.4
24
- global_crops_size: 224
25
- local_crops_size: 96
26
- optim:
27
- weight_decay_end: 0.2
28
- base_lr: 5.0e-4
29
- warmup_epochs: 20
30
- epochs: 400
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/cell_dino/vitl16_hpafov.yaml DELETED
@@ -1,32 +0,0 @@
1
- train:
2
- batch_size_per_gpu: 16
3
- OFFICIAL_EPOCH_LENGTH: 450
4
- cell_augmentation: true
5
- student:
6
- arch: vit_large
7
- patch_size: 16
8
- in_chans: 4
9
- drop_path_rate: 0.1
10
- block_chunks: 4
11
- teacher:
12
- momentum_teacher: 0.996
13
- warmup_teacher_temp_epochs: 20
14
- in_chans: 4
15
- optim:
16
- weight_decay_end: 0.2
17
- base_lr: 5.0e-4
18
- warmup_epochs: 20
19
- crops:
20
- global_crops_scale:
21
- - 0.4
22
- - 1.0
23
- local_crops_number: 8
24
- local_crops_scale:
25
- - 0.005
26
- - 0.4
27
- global_crops_size: 224
28
- local_crops_size: 96
29
- evaluation:
30
- eval_period_iterations: 9000
31
-
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/cell_dino/vitl16_hpaone.yaml DELETED
@@ -1,30 +0,0 @@
1
- train:
2
- batch_size_per_gpu: 16
3
- OFFICIAL_EPOCH_LENGTH: 1756
4
- cell_augmentation: true
5
- student:
6
- arch: vit_large
7
- patch_size: 16
8
- in_chans: 4
9
- drop_path_rate: 0.1
10
- block_chunks: 4
11
- teacher:
12
- momentum_teacher: 0.996
13
- warmup_teacher_temp_epochs: 20
14
- in_chans: 4
15
- optim:
16
- weight_decay_end: 0.2
17
- base_lr: 5.0e-4
18
- warmup_epochs: 20
19
- crops:
20
- global_crops_scale:
21
- - 0.4
22
- - 1.0
23
- local_crops_number: 8
24
- local_crops_scale:
25
- - 0.005
26
- - 0.4
27
- global_crops_size: 224
28
- local_crops_size: 96
29
- evaluation:
30
- eval_period_iterations: 9000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/vitg14.yaml DELETED
@@ -1,26 +0,0 @@
1
- dino:
2
- head_n_prototypes: 131072
3
- head_bottleneck_dim: 384
4
- ibot:
5
- separate_head: true
6
- head_n_prototypes: 131072
7
- train:
8
- batch_size_per_gpu: 12
9
- dataset_path: ImageNet22k
10
- centering: sinkhorn_knopp
11
- student:
12
- arch: vit_giant2
13
- patch_size: 14
14
- drop_path_rate: 0.4
15
- ffn_layer: swiglufused
16
- block_chunks: 4
17
- teacher:
18
- momentum_teacher: 0.994
19
- optim:
20
- epochs: 500
21
- weight_decay_end: 0.2
22
- base_lr: 2.0e-04 # learning rate for a batch size of 1024
23
- warmup_epochs: 80
24
- layerwise_decay: 1.0
25
- crops:
26
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/vitl14.yaml DELETED
@@ -1,26 +0,0 @@
1
- dino:
2
- head_n_prototypes: 131072
3
- head_bottleneck_dim: 384
4
- ibot:
5
- separate_head: true
6
- head_n_prototypes: 131072
7
- train:
8
- batch_size_per_gpu: 32
9
- dataset_path: ImageNet22k
10
- centering: sinkhorn_knopp
11
- student:
12
- arch: vit_large
13
- patch_size: 14
14
- drop_path_rate: 0.4
15
- ffn_layer: swiglufused
16
- block_chunks: 4
17
- teacher:
18
- momentum_teacher: 0.994
19
- optim:
20
- epochs: 500
21
- weight_decay_end: 0.2
22
- base_lr: 2.0e-04 # learning rate for a batch size of 1024
23
- warmup_epochs: 80
24
- layerwise_decay: 1.0
25
- crops:
26
- local_crops_size: 98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/configs/train/vitl16_short.yaml DELETED
@@ -1,6 +0,0 @@
1
- # this corresponds to the default config
2
- train:
3
- dataset_path: ImageNet:split=TRAIN
4
- batch_size_per_gpu: 64
5
- student:
6
- block_chunks: 4
 
 
 
 
 
 
 
dinov2/data/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from .adapters import DatasetWithEnumeratedTargets
7
- from .loaders import make_data_loader, make_dataset, SamplerType
8
- from .collate import collate_data_and_cast
9
- from .masking import MaskingGenerator
10
- from .augmentations import DataAugmentationDINO
11
- from .cell_dino.augmentations import CellAugmentationDINO
12
- from .accumulators import NoOpAccumulator, ResultsAccumulator
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/accumulators.py DELETED
@@ -1,133 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from collections import defaultdict
7
- from typing import Dict, List, Optional, Any
8
-
9
- import torch
10
- from torch import Tensor
11
- from torch.nn import functional as F
12
-
13
- import torch.distributed as dist
14
- from dinov2.distributed import get_global_size
15
-
16
-
17
- def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
18
- gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
19
- dist.all_gather(gathered_result, result, group)
20
- return gathered_result
21
-
22
-
23
- def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
24
- """
25
- Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py
26
- Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
27
-
28
- Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
29
- tensors are padded, gathered and then trimmed to secure equal workload for all processes.
30
-
31
- Args:
32
- result: the value to sync
33
- group: the process group to gather results from. Defaults to all processes (world)
34
-
35
- Return:
36
- list with size equal to the process group where element i corresponds to result tensor from process i
37
- """
38
- # convert tensors to contiguous format
39
- result = result.contiguous()
40
-
41
- world_size = get_global_size()
42
- dist.barrier(group=group)
43
-
44
- # if the tensor is scalar, things are easy
45
- if result.ndim == 0:
46
- return _simple_gather_all_tensors(result, group, world_size)
47
-
48
- # 1. Gather sizes of all tensors
49
- local_size = torch.tensor(result.shape, device=result.device)
50
- local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
51
- dist.all_gather(local_sizes, local_size, group=group)
52
- max_size = torch.stack(local_sizes).max(dim=0).values
53
- all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
54
-
55
- # 2. If shapes are all the same, then do a simple gather:
56
- if all_sizes_equal:
57
- return _simple_gather_all_tensors(result, group, world_size)
58
-
59
- # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
60
- pad_dims = []
61
- pad_by = (max_size - local_size).detach().cpu()
62
- for val in reversed(pad_by):
63
- pad_dims.append(0)
64
- pad_dims.append(val.item())
65
- result_padded = F.pad(result, pad_dims)
66
- gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
67
- dist.all_gather(gathered_result, result_padded, group)
68
- for idx, item_size in enumerate(local_sizes):
69
- slice_param = [slice(dim_size) for dim_size in item_size]
70
- gathered_result[idx] = gathered_result[idx][slice_param]
71
- return gathered_result
72
-
73
-
74
- def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor:
75
- local_cat = torch.cat(tensor_list)
76
- return torch.cat(gather_all_tensors(local_cat))
77
-
78
-
79
- class Accumulator:
80
- def __init__(self) -> None:
81
- pass
82
-
83
- def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
84
- raise NotImplementedError
85
-
86
- def accumulate(self) -> Optional[Dict[str, Tensor]]:
87
- raise NotImplementedError
88
-
89
-
90
- class NoOpAccumulator(Accumulator):
91
- def __init__(self) -> None:
92
- pass
93
-
94
- def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
95
- pass
96
-
97
- def accumulate(self) -> None:
98
- return None
99
-
100
-
101
- class ResultsAccumulator(Accumulator):
102
- """
103
- Accumulate predictions and targets across processes
104
- """
105
-
106
- def __init__(self) -> None:
107
- self._local_values: Dict[str, List[Tensor]] = defaultdict(list)
108
- self._gathered_values: Dict[str, Tensor] = {}
109
- self._gathered = False
110
-
111
- def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None:
112
- assert len(preds) == len(target) == len(index)
113
- assert not self._gathered, "Tensors have already been gathered in this helper"
114
- self._local_values["preds"].append(preds)
115
- self._local_values["target"].append(target)
116
- self._local_values["index"].append(index)
117
- self._gathered = False
118
-
119
- def _gather_tensors(self):
120
- for k, tensor_list in self._local_values.items():
121
- self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list)
122
- self._gathered = True
123
-
124
- def accumulate(self) -> Dict[str, Tensor]:
125
- if not self._gathered:
126
- self._gather_tensors()
127
- preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]]
128
- assert len(preds) == len(target) == len(index) and index.min() == 0
129
- preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device)
130
- preds_ordered[index] = preds
131
- target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device)
132
- target_ordered[index] = target
133
- return {"preds": preds_ordered, "target": target_ordered}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/adapters.py DELETED
@@ -1,51 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from typing import Any, Tuple, Optional
7
-
8
- from torch.utils.data import Dataset
9
-
10
-
11
- class DatasetWithEnumeratedTargets(Dataset):
12
- """
13
- If pad_dataset is set, pads based on torch's DistributedSampler implementation, which
14
- with drop_last=False pads the last batch to be a multiple of the world size.
15
- https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
16
- """
17
-
18
- def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None):
19
- self._dataset = dataset
20
- self._size = len(self._dataset)
21
- self._padded_size = self._size
22
- self._pad_dataset = pad_dataset
23
- if self._pad_dataset:
24
- assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
25
- self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)
26
-
27
- def get_image_relpath(self, index: int) -> str:
28
- assert self._pad_dataset or index < self._size
29
- return self._dataset.get_image_relpath(index % self._size)
30
-
31
- def get_image_data(self, index: int) -> bytes:
32
- assert self._pad_dataset or index < self._size
33
- return self._dataset.get_image_data(index % self._size)
34
-
35
- def get_target(self, index: int) -> Tuple[Any, int]:
36
- target = self._dataset.get_target(index % self._size)
37
- if index >= self._size:
38
- assert self._pad_dataset
39
- return (-1, target)
40
- return (index, target)
41
-
42
- def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
43
- image, target = self._dataset[index % self._size]
44
- if index >= self._size:
45
- assert self._pad_dataset
46
- return image, (-1, target)
47
- target = index if target is None else target
48
- return image, (index, target)
49
-
50
- def __len__(self) -> int:
51
- return self._padded_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/augmentations.py DELETED
@@ -1,118 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
-
8
- from torchvision import transforms
9
-
10
- from .transforms import (
11
- GaussianBlur,
12
- make_normalize_transform,
13
- )
14
-
15
-
16
- logger = logging.getLogger("dinov2")
17
-
18
-
19
- class DataAugmentationDINO(object):
20
- def __init__(
21
- self,
22
- global_crops_scale,
23
- local_crops_scale,
24
- local_crops_number,
25
- global_crops_size=224,
26
- local_crops_size=96,
27
- ):
28
- self.global_crops_scale = global_crops_scale
29
- self.local_crops_scale = local_crops_scale
30
- self.local_crops_number = local_crops_number
31
- self.global_crops_size = global_crops_size
32
- self.local_crops_size = local_crops_size
33
-
34
- logger.info("###################################")
35
- logger.info("Using data augmentation parameters:")
36
- logger.info(f"global_crops_scale: {global_crops_scale}")
37
- logger.info(f"local_crops_scale: {local_crops_scale}")
38
- logger.info(f"local_crops_number: {local_crops_number}")
39
- logger.info(f"global_crops_size: {global_crops_size}")
40
- logger.info(f"local_crops_size: {local_crops_size}")
41
- logger.info("###################################")
42
-
43
- # random resized crop and flip
44
- self.geometric_augmentation_global = transforms.Compose(
45
- [
46
- transforms.RandomResizedCrop(
47
- global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
48
- ),
49
- transforms.RandomHorizontalFlip(p=0.5),
50
- ]
51
- )
52
-
53
- self.geometric_augmentation_local = transforms.Compose(
54
- [
55
- transforms.RandomResizedCrop(
56
- local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
57
- ),
58
- transforms.RandomHorizontalFlip(p=0.5),
59
- ]
60
- )
61
-
62
- # color distorsions / blurring
63
- color_jittering = transforms.Compose(
64
- [
65
- transforms.RandomApply(
66
- [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
67
- p=0.8,
68
- ),
69
- transforms.RandomGrayscale(p=0.2),
70
- ]
71
- )
72
-
73
- global_transfo1_extra = GaussianBlur(p=1.0)
74
-
75
- global_transfo2_extra = transforms.Compose(
76
- [
77
- GaussianBlur(p=0.1),
78
- transforms.RandomSolarize(threshold=128, p=0.2),
79
- ]
80
- )
81
-
82
- local_transfo_extra = GaussianBlur(p=0.5)
83
-
84
- # normalization
85
- self.normalize = transforms.Compose(
86
- [
87
- transforms.ToTensor(),
88
- make_normalize_transform(),
89
- ]
90
- )
91
-
92
- self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
93
- self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
94
- self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
95
-
96
- def __call__(self, image):
97
- output = {}
98
-
99
- # global crops:
100
- im1_base = self.geometric_augmentation_global(image)
101
- global_crop_1 = self.global_transfo1(im1_base)
102
-
103
- im2_base = self.geometric_augmentation_global(image)
104
- global_crop_2 = self.global_transfo2(im2_base)
105
-
106
- output["global_crops"] = [global_crop_1, global_crop_2]
107
-
108
- # global crops for teacher:
109
- output["global_crops_teacher"] = [global_crop_1, global_crop_2]
110
-
111
- # local crops:
112
- local_crops = [
113
- self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
114
- ]
115
- output["local_crops"] = local_crops
116
- output["offsets"] = ()
117
-
118
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/cell_dino/augmentations.py DELETED
@@ -1,91 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import logging
7
- import torchvision
8
- from torchvision import transforms
9
-
10
- from .transforms import (
11
- RandomContrastProteinChannel,
12
- RandomRemoveChannelExceptProtein,
13
- RandomBrightness,
14
- RandomContrast,
15
- Div255,
16
- SelfNormalizeNoDiv,
17
- )
18
-
19
- logger = logging.getLogger("dinov2")
20
-
21
-
22
- class CellAugmentationDINO(object):
23
- def __init__(
24
- self,
25
- global_crops_scale,
26
- local_crops_scale,
27
- local_crops_number,
28
- global_crops_size=224,
29
- local_crops_size=96,
30
- ):
31
- self.global_crops_scale = global_crops_scale
32
- self.local_crops_scale = local_crops_scale
33
- self.local_crops_number = local_crops_number
34
- self.global_crops_size = global_crops_size
35
- self.local_crops_size = local_crops_size
36
-
37
- logger.info("###################################")
38
- logger.info("Using data augmentation parameters:")
39
- logger.info(f"global_crops_scale: {global_crops_scale}")
40
- logger.info(f"local_crops_scale: {local_crops_scale}")
41
- logger.info(f"local_crops_number: {local_crops_number}")
42
- logger.info(f"global_crops_size: {global_crops_size}")
43
- logger.info(f"local_crops_size: {local_crops_size}")
44
- logger.info("###################################")
45
-
46
- additional_transforms_list = [
47
- torchvision.transforms.RandomHorizontalFlip(),
48
- torchvision.transforms.RandomVerticalFlip(),
49
- RandomBrightness(),
50
- RandomContrast(),
51
- SelfNormalizeNoDiv(),
52
- ]
53
-
54
- first_transforms_list = [
55
- Div255(),
56
- RandomRemoveChannelExceptProtein(),
57
- RandomContrastProteinChannel(),
58
- ]
59
-
60
- global_transforms_list = first_transforms_list.copy()
61
- global_transforms_list.append(
62
- torchvision.transforms.RandomResizedCrop(size=global_crops_size, scale=global_crops_scale)
63
- )
64
- global_transforms_list = global_transforms_list + additional_transforms_list
65
-
66
- local_transforms_list = first_transforms_list
67
- local_transforms_list.append(
68
- torchvision.transforms.RandomResizedCrop(size=local_crops_size, scale=local_crops_scale)
69
- )
70
- local_transforms_list = local_transforms_list + additional_transforms_list
71
-
72
- self.global_transform = transforms.Compose(global_transforms_list)
73
- self.local_transform = transforms.Compose(local_transforms_list)
74
-
75
- def __call__(self, image):
76
- output = {}
77
-
78
- global_crop1 = self.global_transform(image)
79
- global_crop2 = self.global_transform(image)
80
-
81
- output["global_crops"] = [global_crop1, global_crop2]
82
-
83
- local_crops = []
84
- for _ in range(self.local_crops_number):
85
- local_crops.append(self.local_transform(image))
86
-
87
- output["local_crops"] = local_crops
88
- output["global_crops_teacher"] = [global_crop1, global_crop2]
89
- output["offsets"] = ()
90
-
91
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/cell_dino/transforms.py DELETED
@@ -1,169 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import torch
7
- from torchvision import transforms
8
- import numpy as np
9
- from enum import Enum
10
-
11
-
12
class NormalizationType(Enum):
    """Cropping/normalization recipes understood by the evaluation transform factory."""

    # Random crop plus random flips, then per-image self-normalization.
    SELF_NORM_AUG_DECODER = "self_norm_aug_decoder"
    # Center crop, then per-image self-normalization.
    SELF_NORM_CENTER_CROP = "self_norm_center_crop"
15
-
16
-
17
class Div255(torch.nn.Module):
    """Divide the input by 255 and return the result as a new object."""

    def forward(self, x):
        return x / 255
21
-
22
-
23
class SelfNormalizeNoDiv(torch.nn.Module):
    """
    Standardize over the last two dimensions, without any prior rescaling.

    For every slice spanned by the last two dimensions, the mean is subtracted
    and the result is divided by the (biased) standard deviation plus ``1e-7``.
    """

    def forward(self, x):
        mean = x.mean((-2, -1), keepdim=True)
        std = x.std((-2, -1), unbiased=False, keepdim=True)
        # Computed out of place so the caller's tensor is never overwritten.
        # The epsilon keeps constant slices from dividing by zero.
        return (x - mean) / (std + 1e-7)
30
-
31
-
32
class SelfNormalize(torch.nn.Module):
    """
    Divide by 255, then standardize over the last two dimensions.

    The standardization subtracts the mean and divides by the biased standard
    deviation plus ``1e-7``. The input itself is left untouched because the
    division by 255 already produces a new tensor.
    """

    def forward(self, x):
        scaled = x / 255
        center = scaled.mean((-2, -1), keepdim=True)
        spread = scaled.std((-2, -1), unbiased=False, keepdim=True)
        return (scaled - center) / (spread + 1e-7)
40
-
41
-
42
class RandomContrastProteinChannel(torch.nn.Module):
    """
    Random contrast rescaling of the protein channel (index 1) only.
    RescaleProtein function in Dino4cell codebase.

    With probability ``p`` channel 1 is multiplied, in place, by a random
    factor ``2 * u / img.max()`` with ``u`` drawn uniformly from [0, 1).
    All-zero images and single-channel images are returned unchanged.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, img):
        peak = img.max()
        # Nothing to rescale: blank image, or no separate protein channel.
        if peak == 0 or len(img) == 1:
            return img
        # First random draw decides whether the augmentation fires at all.
        if not (np.random.rand() <= self.p):
            return img
        gain = (np.random.rand() * 2) / peak  # scaling
        img[1] = img[1] * gain
        return img
63
-
64
-
65
class RandomRemoveChannelExceptProtein(torch.nn.Module):
    """
    Drop a channel at random, except channel 1 (proteins in HPA datasets).

    With probability ``p`` one of channels 0, 2 or 3 is zeroed in place.
    Images with fewer than four channels are returned unchanged.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, img):
        # Read the channel count from the shape directly; converting the whole
        # image to a NumPy array just for its shape copies every pixel.
        if img.shape[0] < 4:
            return img
        if np.random.rand() <= self.p:
            channel_to_blacken = int(np.random.choice(np.array([0, 2, 3])))
            # Scalar fill keeps the image's own dtype and device.
            img[channel_to_blacken] = 0
        return img
84
-
85
-
86
class RandomRemoveChannel(torch.nn.Module):
    """
    Drop a channel at random.

    With probability ``p`` one uniformly chosen channel is zeroed in place.
    Images with fewer than four channels are returned unchanged.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, img):
        # Read the channel count from the shape directly; converting the whole
        # image to a NumPy array just for its shape copies every pixel.
        num_channels = img.shape[0]
        if num_channels < 4:
            return img
        if np.random.rand() <= self.p:
            channel_to_blacken = int(np.random.choice(num_channels))
            # Scalar fill keeps the image's own dtype and device.
            img[channel_to_blacken] = 0
        return img
106
-
107
-
108
class RandomContrast(torch.nn.Module):
    """
    Independent random contrast change for every channel.

    Each channel gets its own factor drawn from a normal distribution with
    mean 1 and standard deviation ``p``, floored at 0.5. All-zero images are
    returned unchanged. The image is modified in place.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, img):
        if img.max() == 0:
            return img
        for channel_idx in range(img.shape[0]):
            strength = max(np.random.normal(1, self.p), 0.5)
            # A leading axis is added so the channel is passed as a 1 x H x W image.
            single = img[channel_idx][None, ...]
            img[channel_idx] = transforms.functional.adjust_contrast(single, strength)
        return img
121
-
122
-
123
class RandomBrightness(torch.nn.Module):
    """
    Independent random brightness change for every channel.

    Each channel gets its own factor drawn from a normal distribution with
    mean 1 and standard deviation ``p``, floored at 0.5. All-zero images are
    returned unchanged. The image is modified in place.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.p = p

    def forward(self, img):
        if img.max() == 0:
            return img
        for channel_idx in range(img.shape[0]):
            strength = max(np.random.normal(1, self.p), 0.5)
            img[channel_idx] = transforms.functional.adjust_brightness(img[channel_idx], strength)
        return img
136
-
137
-
138
def make_classification_eval_cell_transform(
    *,
    resize_size: int = 0,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 384,
    normalization_type: Enum = NormalizationType.SELF_NORM_CENTER_CROP,
) -> transforms.Compose:
    """
    Build the evaluation transform pipeline for cell images.

    The pipeline is: divide by 255, optionally resize, crop, then
    self-normalize over the last two dimensions.

    Args:
        resize_size: target size passed to ``Resize``; no resize when <= 0.
        interpolation: interpolation mode used by the optional resize.
        crop_size: side length of the crop.
        normalization_type: ``SELF_NORM_AUG_DECODER`` uses a padded random
            crop plus random horizontal/vertical flips;
            ``SELF_NORM_CENTER_CROP`` uses a center crop.

    Raises:
        ValueError: if ``normalization_type`` is not one of the two known values.
    """
    # Div255 and SelfNormalizeNoDiv are defined in this module, so no import is needed.
    transforms_list = [Div255()]
    if resize_size > 0:
        transforms_list.append(transforms.Resize(resize_size, interpolation=interpolation))

    if normalization_type == NormalizationType.SELF_NORM_AUG_DECODER:
        transforms_list.extend(
            [
                transforms.RandomCrop(size=crop_size, pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]
        )
    elif normalization_type == NormalizationType.SELF_NORM_CENTER_CROP:
        transforms_list.append(transforms.CenterCrop(size=crop_size))
    else:
        # The f prefix was previously inside the quotes, so the offending
        # value never appeared in the message.
        raise ValueError(f"{normalization_type}: unknown NormalizationType")
    transforms_list.append(SelfNormalizeNoDiv())

    return transforms.Compose(transforms_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/collate.py DELETED
@@ -1,49 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import torch
7
- import random
8
-
9
-
10
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    """
    Collate multi-crop samples into batched tensors and draw patch masks.

    Args:
        samples_list: sequence of samples; element ``[0]`` of each sample is a
            dict with ``"global_crops"`` and ``"local_crops"`` lists.
        mask_ratio_tuple: ``(low, high)`` bounds of the masking ratio.
        mask_probability: fraction of the collated global crops that get a
            non-empty mask.
        dtype: dtype the collated crops are cast to.
        n_tokens: token count used to turn a ratio into a number of masked tokens.
        mask_generator: callable receiving the number of tokens to mask and
            returning a boolean mask.

    Returns:
        dict with the collated crops, the flattened masks, the indices of the
        masked entries, the per-masked-entry weights, ``upperbound`` and the
        total number of masked patches.
    """
    first_views = samples_list[0][0]
    global_count = len(first_views["global_crops"])
    local_count = len(first_views["local_crops"])

    # Crop-major order: crop 0 of every sample, then crop 1 of every sample, ...
    collated_global_crops = torch.stack(
        [sample[0]["global_crops"][view] for view in range(global_count) for sample in samples_list]
    )
    collated_local_crops = torch.stack(
        [sample[0]["local_crops"][view] for view in range(local_count) for sample in samples_list]
    )

    batch_size = len(collated_global_crops)
    masked_count = int(batch_size * mask_probability)
    # masked_count consecutive ratio intervals spanning mask_ratio_tuple.
    ratio_edges = torch.linspace(*mask_ratio_tuple, masked_count + 1)

    upperbound = 0
    masks_list = []
    for low, high in zip(ratio_edges[:-1], ratio_edges[1:]):
        tokens_to_mask = int(n_tokens * random.uniform(low, high))
        masks_list.append(torch.BoolTensor(mask_generator(tokens_to_mask)))
        upperbound += int(n_tokens * high)
    # The remaining crops receive an empty mask.
    masks_list.extend(torch.BoolTensor(mask_generator(0)) for _ in range(masked_count, batch_size))

    # Shuffle so masked and unmasked crops are not grouped by position.
    random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # Every masked entry is weighted by 1 / (number of masked entries in its row).
    per_row_weight = 1 / collated_masks.sum(-1).clamp(min=1.0)
    masks_weight = per_row_weight.unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from .image_net import ImageNet
7
- from .image_net_22k import ImageNet22k
8
- from .cell_dino.hpaone import HPAone
9
- from .cell_dino.hpafov import HPAFoV
10
- from .cell_dino.chammi_cp import CHAMMI_CP
11
- from .cell_dino.chammi_hpa import CHAMMI_HPA
12
- from .cell_dino.chammi_wtc import CHAMMI_WTC
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/cell_dino/chammi_cp.py DELETED
@@ -1,112 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Any, Callable, Optional, Union
11
-
12
- import numpy as np
13
-
14
- from ..extended import ExtendedVisionDataset
15
- from ..decoders import DecoderType
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
-
20
# Metadata CSV, looked up directly under the dataset root.
METADATA_FILE = "morphem70k_v2.csv"

# Label name -> integer class id; ids follow the listing order (0..6).
CLASS_LABELS = {
    label: class_id
    for class_id, label in enumerate(
        (
            "BRD-A29260609",
            "BRD-K04185004",
            "BRD-K21680192",
            "DMSO",
            "BRD-K11129031",  # labels only seen in TASK_FOUR
            "BRD-K62310379",
            "BRD-K77947974",
        )
    )
}
31
-
32
-
33
class _Split(Enum):
    """Dataset partitions; each value, upper-cased, is matched against the ``train_test_split`` CSV column."""

    TRAIN = "Train"
    TASK_ONE = "Task_one"
    TASK_TWO = "Task_two"
    TASK_THREE = "Task_three"
    TASK_FOUR = "Task_four"
39
-
40
-
41
def _load_file_names_and_targets(
    root: str,
    split: _Split,
):
    """
    Read the metadata CSV and collect the CP samples of one split.

    Args:
        root: directory containing ``METADATA_FILE``.
        split: a ``_Split`` member or its upper-case name (e.g. ``"TRAIN"``).

    Returns:
        ``(image_paths, labels)``: the ``file_path`` values of the matching
        rows and their integer class ids from ``CLASS_LABELS``.

    Raises:
        KeyError: if a matching row carries a label missing from ``CLASS_LABELS``.
    """
    # An Enum member never compares equal to a plain string, so a _Split
    # member used to match no row at all; reduce it to the upper-case string
    # form that the CSV column is compared against.
    split_key = split.value.upper() if isinstance(split, Enum) else split

    image_paths = []
    labels = []
    with open(os.path.join(root, METADATA_FILE)) as metadata:
        for row in csv.DictReader(metadata):
            # The first path component names the sub-dataset the image belongs to.
            row_dataset = row["file_path"].split("/")[0]
            if row["train_test_split"].upper() == split_key and row_dataset == "CP":
                image_paths.append(row["file_path"])
                labels.append(CLASS_LABELS[row["label"]])

    return image_paths, labels
57
-
58
-
59
class CHAMMI_CP(ExtendedVisionDataset):
    """
    CP (Cell-Painting) subset of the CHAMMI benchmark dataset.

    Paper: https://arxiv.org/pdf/2310.19224
    Github code: https://github.com/chaudatascience/channel_adaptive_models

    Samples are the metadata rows whose path starts with ``CP`` and whose
    split column matches ``split``.
    """

    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "CHAMMI_CP.Split",
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            **kwargs,
        )
        self.root = root
        self.split = split
        # NOTE(review): presumably the number of extra labels used by a
        # leave-one-out evaluation -- confirm against the consumer of this attribute.
        self.num_additional_labels_loo_eval = 3
        paths, targets = _load_file_names_and_targets(root, split)
        self._image_paths = paths
        self._targets = targets

    def get_image_relpath(self, index: int) -> str:
        """Return the metadata ``file_path`` of sample ``index``."""
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file of sample ``index``."""
        full_path = os.path.join(self.root, self.get_image_relpath(index))
        with open(full_path, mode="rb") as image_file:
            return image_file.read()

    def get_target(self, index: int) -> Any:
        """Return the integer class id of sample ``index``."""
        return self._targets[index]

    def get_targets(self) -> np.ndarray:
        """Return all class ids as a NumPy array."""
        return np.array(self._targets)

    def __len__(self) -> int:
        return len(self._image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/cell_dino/chammi_hpa.py DELETED
@@ -1,111 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Any, Callable, Optional, Union
11
-
12
- import numpy as np
13
-
14
- from ..extended import ExtendedVisionDataset
15
- from ..decoders import DecoderType
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
-
20
# Metadata CSV, looked up directly under the dataset root.
METADATA_FILE = "morphem70k_v2.csv"

# Label name -> integer class id; ids follow the listing order (0..6).
CLASS_LABELS = {
    label: class_id
    for class_id, label in enumerate(
        (
            "golgi apparatus",
            "microtubules",
            "mitochondria",
            "nuclear speckles",
            "cytosol",  # labels only seen in TASK_THREE
            "endoplasmic reticulum",
            "nucleoplasm",
        )
    )
}
31
-
32
-
33
class _Split(Enum):
    """Dataset partitions; each value, upper-cased, is matched against the ``train_test_split`` CSV column."""

    TRAIN = "Train"
    TASK_ONE = "Task_one"
    TASK_TWO = "Task_two"
    TASK_THREE = "Task_three"
38
-
39
-
40
def _load_file_names_and_targets(
    root: str,
    split: _Split,
):
    """
    Read the metadata CSV and collect the HPA samples of one split.

    Args:
        root: directory containing ``METADATA_FILE``.
        split: a ``_Split`` member or its upper-case name (e.g. ``"TRAIN"``).

    Returns:
        ``(image_paths, labels)``: the ``file_path`` values of the matching
        rows and their integer class ids from ``CLASS_LABELS``.

    Raises:
        KeyError: if a matching row carries a label missing from ``CLASS_LABELS``.
    """
    # An Enum member never compares equal to a plain string, so a _Split
    # member used to match no row at all; reduce it to the upper-case string
    # form that the CSV column is compared against.
    split_key = split.value.upper() if isinstance(split, Enum) else split

    image_paths = []
    labels = []
    with open(os.path.join(root, METADATA_FILE)) as metadata:
        for row in csv.DictReader(metadata):
            # The first path component names the sub-dataset the image belongs to.
            row_dataset = row["file_path"].split("/")[0]
            if row["train_test_split"].upper() == split_key and row_dataset == "HPA":
                image_paths.append(row["file_path"])
                labels.append(CLASS_LABELS[row["label"]])

    return image_paths, labels
55
-
56
-
57
class CHAMMI_HPA(ExtendedVisionDataset):
    """
    HPA subset of the CHAMMI benchmark dataset.

    Paper: https://arxiv.org/pdf/2310.19224
    Github code: https://github.com/chaudatascience/channel_adaptive_models

    Samples are the metadata rows whose path starts with ``HPA`` and whose
    split column matches ``split``.
    """

    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "CHAMMI_HPA.Split",
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            **kwargs,
        )
        self.root = root
        self.split = split
        # NOTE(review): presumably the number of extra labels used by a
        # leave-one-out evaluation -- confirm against the consumer of this attribute.
        self.num_additional_labels_loo_eval = 3

        paths, targets = _load_file_names_and_targets(root, split)
        self._image_paths = paths
        self._targets = targets

    def get_image_relpath(self, index: int) -> str:
        """Return the metadata ``file_path`` of sample ``index``."""
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file of sample ``index``."""
        full_path = os.path.join(self.root, self.get_image_relpath(index))
        with open(full_path, mode="rb") as image_file:
            return image_file.read()

    def get_target(self, index: int) -> Any:
        """Return the integer class id of sample ``index``."""
        return self._targets[index]

    def get_targets(self) -> np.ndarray:
        """Return all class ids as a NumPy array."""
        return np.array(self._targets)

    def __len__(self) -> int:
        return len(self._image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/cell_dino/chammi_wtc.py DELETED
@@ -1,108 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Any, Callable, Optional, Union
11
-
12
- import numpy as np
13
-
14
- from ..extended import ExtendedVisionDataset
15
- from ..decoders import DecoderType
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
-
20
# Metadata CSV, looked up directly under the dataset root.
METADATA_FILE = "morphem70k_v2.csv"

# Label name -> integer class id; ids follow the listing order (0..5).
CLASS_LABELS = {
    label: class_id
    for class_id, label in enumerate(
        (
            "M0",
            "M1M2",
            "M3",
            "M4M5",
            "M6M7_complete",
            "M6M7_single",
        )
    )
}
30
-
31
-
32
class _Split(Enum):
    """Dataset partitions; each value, upper-cased, is matched against the ``train_test_split`` CSV column."""

    TRAIN = "Train"
    TASK_ONE = "Task_one"
    TASK_TWO = "Task_two"
36
-
37
-
38
def _load_file_names_and_targets(
    root: str,
    split: _Split,
):
    """
    Read the metadata CSV and collect the "Allen" samples of one split.

    Args:
        root: directory containing ``METADATA_FILE``.
        split: a ``_Split`` member or its upper-case name (e.g. ``"TRAIN"``).

    Returns:
        ``(image_paths, labels)``: the ``file_path`` values of the matching
        rows and their integer class ids from ``CLASS_LABELS``.

    Raises:
        KeyError: if a matching row carries a label missing from ``CLASS_LABELS``.
    """
    # An Enum member never compares equal to a plain string, so a _Split
    # member used to match no row at all; reduce it to the upper-case string
    # form that the CSV column is compared against.
    split_key = split.value.upper() if isinstance(split, Enum) else split

    image_paths = []
    labels = []
    with open(os.path.join(root, METADATA_FILE)) as metadata:
        for row in csv.DictReader(metadata):
            # The first path component names the sub-dataset the image belongs to.
            row_dataset = row["file_path"].split("/")[0]
            if row["train_test_split"].upper() == split_key and row_dataset == "Allen":
                image_paths.append(row["file_path"])
                labels.append(CLASS_LABELS[row["label"]])

    return image_paths, labels
53
-
54
-
55
class CHAMMI_WTC(ExtendedVisionDataset):
    """
    WTC subset of the CHAMMI benchmark dataset.

    Paper: https://arxiv.org/pdf/2310.19224
    Github code: https://github.com/chaudatascience/channel_adaptive_models

    Samples are the metadata rows whose path starts with ``Allen`` and whose
    split column matches ``split``.
    """

    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "CHAMMI_WTC.Split",
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.XChannelsTIFFDecoder,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            **kwargs,
        )
        self.root = root
        self.split = split

        paths, targets = _load_file_names_and_targets(root, split)
        self._image_paths = paths
        self._targets = targets

    def get_image_relpath(self, index: int) -> str:
        """Return the metadata ``file_path`` of sample ``index``."""
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file of sample ``index``."""
        full_path = os.path.join(self.root, self.get_image_relpath(index))
        with open(full_path, mode="rb") as image_file:
            return image_file.read()

    def get_target(self, index: int) -> Any:
        """Return the integer class id of sample ``index``."""
        return self._targets[index]

    def get_targets(self) -> np.ndarray:
        """Return all class ids as a NumPy array."""
        return np.array(self._targets)

    def __len__(self) -> int:
        return len(self._image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/cell_dino/hpafov.py DELETED
@@ -1,283 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Any, Callable, List, Optional, Tuple, Union, Dict
11
-
12
- import numpy as np
13
-
14
- from ..extended import ExtendedVisionDataset
15
- from ..decoders import DecoderType
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
# Cell line names; a row's categorical cell-type label is its position in this list.
CELL_TYPE = [
    # entries 1-5
    "BJ",
    "LHCN-M2",
    "RH-30",
    "SH-SY5Y",
    "U-2 OS",
    # entries 6-10
    "ASC TERT1",
    "HaCaT",
    "A-431",
    "U-251 MG",
    "HEK 293",
    # entries 11-15
    "A549",
    "RT4",
    "HeLa",
    "MCF7",
    "PC-3",
    # entries 16-20
    "hTERT-RPE1",
    "SK-MEL-30",
    "EFO-21",
    "AF22",
    "HEL",
    # entries 21-25
    "Hep G2",
    "HUVEC TERT2",
    "THP-1",
    "CACO-2",
    "JURKAT",
    # entries 26-30
    "RPTEC TERT1",
    "SuSa",
    "REH",
    "HDLM-2",
    "K-562",
    # entries 31-35
    "hTCEpi",
    "NB-4",
    "HAP1",
    "OE19",
    "SiHa",
]
56
-
57
# Protein localization classes; each name is also a CSV column whose "True"
# cells set the corresponding position of the multi-hot label.
# matches https://www.kaggle.com/c/human-protein-atlas-image-classification/data
PROTEIN_LOCALIZATION = [
    # entries 1-5
    "nucleoplasm",
    "nuclear membrane",
    "nucleoli",
    "nucleoli fibrillar center",
    "nuclear speckles",
    # entries 6-10
    "nuclear bodies",
    "endoplasmic reticulum",
    "golgi apparatus",
    "peroxisomes",
    "endosomes",
    # entries 11-15
    "lysosomes",
    "intermediate filaments",
    "actin filaments",
    "focal adhesion sites",
    "microtubules",
    # entries 16-20
    "microtubule ends",
    "cytokinetic bridge",
    "mitotic spindle",
    "microtubule organizing center",
    "centrosome",
    # entries 21-25
    "lipid droplets",
    "plasma membrane",
    "cell junctions",
    "mitochondria",
    "aggresome",
    # entries 26-28
    "cytosol",
    "cytoplasmic bodies",
    "rods & rings",
]
87
-
88
-
89
class _Split(Enum):
    """Dataset partitions: two labelled splits plus ``SSL``, which is loaded from a plain list of image names."""

    TRAIN = "train"
    VAL = "val"
    SSL = "ssl"
93
-
94
-
95
def get_csv_fpath(split):
    """
    Path to data relative to root.

    ``split`` may be a ``_Split`` member or its upper-case name. Returns the
    CSV file name of the TRAIN or VAL split, and ``None`` for anything else.
    """
    if split in (_Split.TRAIN, _Split.TRAIN.value.upper()):
        return "whole_images_512_train.csv"
    if split in (_Split.VAL, _Split.VAL.value.upper()):
        return "whole_images_512_test.csv"
    return None
103
-
104
-
105
class _WildCard(Enum):
    """Optional special sampling behaviours of the dataset."""

    NONE = "none"
    # each channel from each image is treated as an independent sample,
    # overrides chosen channel configuration
    SEPARATECHANNELS = "separate_channels"
108
-
109
-
110
class _Mode(Enum):
    """
    Selects which targets the dataset yields:
    - ALL: both the multilabel protein localization encoding and the categorical cell type
    - PROTEIN_LOCALIZATION: one hot encoding of multilabel protein localization
    - CELL_TYPE: categorical encoding of cell type
    """

    ALL = "all"
    PROTEIN_LOCALIZATION = "protein_localization"
    CELL_TYPE = "cell_type"

    @property
    def nb_labels(self):
        """Number of classes of the selected target, or ``None`` for ``ALL``."""
        if self == _Mode.CELL_TYPE:
            return len(CELL_TYPE)
        if self == _Mode.PROTEIN_LOCALIZATION:
            return len(PROTEIN_LOCALIZATION)
        return None
130
-
131
-
132
def _list_images_from_csv(img_path, csv_path):
    """Return ``<img_path>/<ID>.png`` for every row of the CSV at ``csv_path``, in file order."""
    with open(csv_path) as csv_file:
        return [os.path.join(img_path, row["ID"] + ".png") for row in csv.DictReader(csv_file)]
139
-
140
-
141
def _load_file_names_and_labels_ssl(
    root: str,
) -> Tuple[List[str], List[Any]]:
    """
    List the images of the self-supervised split.

    Image names come from ``<root>/whole_images_names.csv`` and are resolved
    inside ``<root>/normalized_data``. No real labels exist for this split,
    so each image is labelled with its own position in the list.
    """
    image_dir = os.path.join(root, "normalized_data")
    names_csv = os.path.join(root, "whole_images_names.csv")
    image_paths = _list_images_from_csv(image_dir, names_csv)
    placeholder_labels = list(range(len(image_paths)))
    return image_paths, placeholder_labels
150
-
151
-
152
def _load_file_names_and_labels(
    root: str,
    split: _Split,
    mode: _Mode,
) -> Tuple[List[str], List[Any]]:
    """
    Read the split CSV and build image paths and labels for a labelled split.

    Args:
        root: dataset root; images are looked up in ``<root>/512_whole_images``.
        split: split accepted by ``get_csv_fpath``.
        mode: a ``_Mode`` member or its upper-case name. It selects the label:
            the multi-hot protein localization array, the cell-type index
            (``np.nan`` for unknown cell types), or for any other value a dict
            with both (``"prot_loc"`` and ``"cell_type"``).

    Returns:
        ``(image_paths, labels)`` of equal length. Unless the mode is
        CELL_TYPE, rows without any positive localization are skipped.

    Raises:
        FileNotFoundError: if neither the listed file name nor its ``.tiff``
            variant exists in the image directory.
    """
    data_path = os.path.join(root, "512_whole_images")
    csv_fpath = os.path.join(root, get_csv_fpath(split))

    # An Enum member never compares equal to a plain string, so _Mode members
    # used to fall through to the "both labels" branch; reduce them to the
    # upper-case string form first.
    mode_key = mode.value.upper() if isinstance(mode, Enum) else mode
    only_prot_loc = mode_key == _Mode.PROTEIN_LOCALIZATION.value.upper()
    only_cell_type = mode_key == _Mode.CELL_TYPE.value.upper()

    image_paths = []
    labels = []

    with open(csv_fpath) as csv_file:
        for row in csv.DictReader(csv_file):
            cell_type = None
            prot_loc = None

            if not only_prot_loc:
                # categorical
                if row["cell_type"] in CELL_TYPE:
                    cell_type = CELL_TYPE.index(row["cell_type"])
                else:
                    cell_type = np.nan

            if not only_cell_type:
                # one hot encoding
                prot_loc = np.array(
                    [1 if row[name] == "True" else 0 for name in PROTEIN_LOCALIZATION],
                    dtype=np.int_,
                )
                if not prot_loc.any():
                    # No positive localization: the row is not a usable sample.
                    continue

            if only_prot_loc:
                label = prot_loc
            elif only_cell_type:
                label = cell_type
            else:
                label = {"prot_loc": prot_loc, "cell_type": cell_type}

            file_name = row["file"].split("/")[-1]
            candidate_path = os.path.join(data_path, file_name)
            if not os.path.exists(candidate_path):
                # Fall back to the same stem with a .tiff extension.
                candidate_path = os.path.join(data_path, file_name.split(".")[0] + ".tiff")
                if not os.path.exists(candidate_path):
                    raise FileNotFoundError(f"File {candidate_path} not found.")

            labels.append(label)
            image_paths.append(candidate_path)

    return image_paths, labels
206
-
207
-
208
class HPAFoV(ExtendedVisionDataset):
    """
    HPA field-of-view image dataset with per-sample channel selection.

    Every sample carries a list of channel indices (``[0, 1, 2, 3]`` by
    default) that is appended to the raw image bytes for the decoder. With
    the separate-channels wildcard, every (image, channel) pair becomes its
    own sample.
    """

    Split = Union[_Split]
    Mode = Union[_Mode]
    WildCard = Union[_WildCard]

    def __init__(
        self,
        *,
        split: "HPAFoV.Split" = _Split.TRAIN,
        mode: "HPAFoV.Mode" = _Mode.ALL,
        wildcard: "HPAFoV.WildCard" = _WildCard.NONE,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.ChannelSelectDecoder,
        image_decoder_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            split: ``_Split`` member or its upper-case name.
            mode: label selection, forwarded to the labelled-split loader.
            wildcard: ``_WildCard`` member or ``"SEPARATE_CHANNELS"``.
            root: dataset root directory.
            transforms, transform, target_transform: forwarded to the base class.
            image_decoder_type: decoder used for the image bytes.
            image_decoder_params: accepted for interface compatibility; the
                decoder parameters are always derived from ``wildcard``.
        """
        # One decision used for both the decoder flag and the sample expansion.
        # Previously the expansion only recognised the string form, so passing
        # the enum member enabled channel selection without expanding.
        separate_channels = wildcard == _WildCard.SEPARATECHANNELS or wildcard == "SEPARATE_CHANNELS"
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            image_decoder_params={"select_channel": separate_channels},
            **kwargs,
        )
        self.mode = mode
        self.split = split
        self.root = root
        self.wildcard = wildcard
        self.channel_adaptive = True
        if split == _Split.SSL.value.upper() or split == _Split.SSL or split == "SSL":
            self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
        else:
            self._image_paths, self._labels = _load_file_names_and_labels(root, self.split, self.mode)

        channel_ids = [0, 1, 2, 3]
        if separate_channels:
            n_images = len(self._image_paths)
            # Channel-major layout: all images with channel 0, then all with channel 1, ...
            self._channels = [[channel] for channel in channel_ids for _ in range(n_images)]
            # The expanded lists must replace the underscored attributes that
            # every accessor reads; they were previously stored only under
            # the public names, so the expansion never took effect.
            self._image_paths = list(self._image_paths) * len(channel_ids)
            self._labels = list(self._labels) * len(channel_ids)
            # Public aliases kept for code that read the old attribute names.
            self.image_paths = self._image_paths
            self.labels = self._labels
        else:
            self._channels = [list(channel_ids) for _ in self._image_paths]

    def get_image_relpath(self, index: int) -> str:
        """Return the stored path of sample ``index``."""
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """
        Return the raw image bytes of sample ``index``.

        When ``channel_adaptive`` is set, the channel indices are appended as
        one byte each, followed by a final byte holding their count.
        """
        # NOTE(review): the loaders already prefix paths with root, so this
        # join looks like it would duplicate a relative root -- confirm.
        image_full_path = os.path.join(self.root, self.get_image_relpath(index))
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        if self.channel_adaptive:
            channels = self._channels[index]
            return image_data + bytes(channels) + (len(channels)).to_bytes(1, byteorder="big")
        return image_data

    def get_target(self, index: int) -> Any:
        """Return the label of sample ``index``."""
        return self._labels[index]

    def get_targets(self) -> np.ndarray:
        """Return all labels as a NumPy array."""
        return np.array(self._labels)

    def __len__(self) -> int:
        return len(self._image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/cell_dino/hpaone.py DELETED
@@ -1,223 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Any, Callable, List, Optional, Tuple, Union
11
-
12
- import numpy as np
13
-
14
- from ..extended import ExtendedVisionDataset
15
- from ..decoders import DecoderType
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
# Protein localization classes (19). Each entry is also a CSV column name;
# some entries are comma-joined groups that form a single class.
PROTEIN_LOCALIZATION = [
    "actin filaments,focal adhesion sites",
    "aggresome",
    "centrosome,centriolar satellite",
    "cytosol",
    "endoplasmic reticulum",
    "golgi apparatus",
    "intermediate filaments",
    "microtubules",
    "mitochondria",
    "mitotic spindle",
    "no staining",
    "nuclear bodies",
    "nuclear membrane",
    "nuclear speckles",
    "nucleoli",
    "nucleoli fibrillar center",
    "nucleoplasm",
    "plasma membrane,cell junctions",
    "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies",
]
40
-
41
-
42
# Cell line names (29). Each entry is also a CSV column name, and the
# categorical cell-type label is the entry's 0-based position in this list.
CELL_TYPE = [
    # indices 0-9
    "A-431",
    "A549",
    "AF22",
    "ASC TERT1",
    "BJ",
    "CACO-2",
    "EFO-21",
    "HAP1",
    "HDLM-2",
    "HEK 293",
    # indices 10-19
    "HEL",
    "HUVEC TERT2",
    "HaCaT",
    "HeLa",
    "Hep G2",
    "JURKAT",
    "K-562",
    "MCF7",
    "PC-3",
    "REH",
    # indices 20-28
    "RH-30",
    "RPTEC TERT1",
    "RT4",
    "SH-SY5Y",
    "SK-MEL-30",
    "SiHa",
    "U-2 OS",
    "U-251 MG",
    "hTCEpi",
]
73
-
74
-
75
class _Split(Enum):
    """Dataset partitions: two labelled splits plus ``ALL``."""

    VAL = "val"
    TRAIN = "train"
    # images without labels, for encoder training
    ALL = "all"
79
-
80
-
81
class _Mode(Enum):
    """Selects which target the dataset yields."""

    PROTEIN_LOCALIZATION = "protein_localization"
    CELL_TYPE = "cell_type"

    @property
    def num_labels(self):
        """Number of classes of the selected target."""
        # Compare against the member itself. The previous comparison with the
        # upper-cased value string was always False, so CELL_TYPE wrongly
        # reported the protein localization count.
        if self is _Mode.CELL_TYPE:
            return len(CELL_TYPE)
        return len(PROTEIN_LOCALIZATION)
90
-
91
-
92
def _simple_parse_csv(img_rootdir, csv_filepath: str):
    """
    Return ``(image_path, 0)`` for every row of the CSV, in file order.

    Paths are built by joining ``img_rootdir`` with the row's ``img_path``
    column; the constant ``0`` is a placeholder label.
    """
    with open(csv_filepath) as csv_file:
        return [(os.path.join(img_rootdir, row["img_path"]), 0) for row in csv.DictReader(csv_file)]
98
-
99
-
100
def _parse_csv(img_rootdir, csv_labels_path: str):
    """
    Parse the labelled CSV into ``(image_path, protein_location, cell_type)`` tuples.

    ``protein_location`` is a multi-hot integer array over
    ``PROTEIN_LOCALIZATION``. ``cell_type`` is the index of the last
    ``CELL_TYPE`` column equal to ``"True"``, or 0 when none is. The image
    path is ``img_rootdir`` plus the final component of the row's ``file`` value.
    """
    samples = []
    with open(csv_labels_path) as csv_file:
        for row in csv.DictReader(csv_file):
            protein_location = np.array(
                [1 if row[name] == "True" else 0 for name in PROTEIN_LOCALIZATION],
                dtype=np.int_,
            )

            # Later columns win when several are flagged; 0 is the fallback.
            flagged = [idx for idx, name in enumerate(CELL_TYPE) if row[name] == "True"]
            cell_type = flagged[-1] if flagged else 0

            file_name = row["file"].rsplit("/", 1)[1]
            samples.append((img_rootdir + "/" + file_name, protein_location, cell_type))
    return samples
125
-
126
-
127
def _load_file_names_and_labels_ssl(
    root: str,
) -> Tuple[List[str], List[Any]]:
    """
    List the unlabelled pretraining images.

    Paths come from ``<root>/varied_size_masked_single_cells_pretrain_20240507.csv``
    and are resolved inside ``<root>/varied_size_masked_single_cells_HPA``.

    Returns:
        ``(image_paths, labels)`` as two lists; every label is the
        placeholder ``0`` produced by the CSV parser.
    """
    curr_dir_train = os.path.join(root, "varied_size_masked_single_cells_HPA")
    csv_all_path = os.path.join(root, "varied_size_masked_single_cells_pretrain_20240507.csv")
    samples = _simple_parse_csv(curr_dir_train, csv_all_path)
    # Built with comprehensions rather than zip(*samples): the result is a
    # real list, as annotated, and an empty CSV yields two empty lists
    # instead of an unpacking error.
    image_paths = [path for path, _ in samples]
    lab = [label for _, label in samples]
    return image_paths, lab
136
-
137
-
138
def _load_file_names_and_labels_train_or_test(
    root: str,
    split: _Split,
    mode: _Mode,
) -> Tuple[List[str], List[Any]]:
    """
    Load image paths and labels of the TRAIN or VAL split.

    Args:
        root: dataset root directory.
        split: ``_Split.TRAIN`` / ``_Split.VAL`` or their upper-case names.
        mode: a ``_Mode`` member or its upper-case name, selecting the
            per-sample protein localization arrays or cell-type indices. Any
            other value (including ``None``) yields the pair
            ``(protein_location, cell_type)`` of both label sequences.

    Returns:
        ``(image_paths, labels)``.

    Raises:
        ValueError: if ``split`` is neither TRAIN nor VAL.
    """
    if split == _Split.TRAIN.value.upper() or split == _Split.TRAIN:
        csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_pretrain_20240507.csv")
    elif split == _Split.VAL.value.upper() or split == _Split.VAL:
        csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_evaluation_20240507.csv")
    else:
        # Previously only printed a message and then failed on the unset
        # csv_labels_path; fail with an explicit error instead.
        raise ValueError(f"wrong split name: {split!r}")
    curr_dir_val = os.path.join(root, "fixed_size_masked_single_cells_HPA")

    samples = _parse_csv(curr_dir_val, csv_labels_path)
    image_paths, protein_location, cell_type = zip(*samples)

    # An Enum member never compares equal to a plain string, so _Mode members
    # used to fall through to the last branch; reduce them to the upper-case
    # string form first.
    mode_key = mode.value.upper() if isinstance(mode, Enum) else mode
    if mode_key == _Mode.PROTEIN_LOCALIZATION.value.upper():
        lab = protein_location
    elif mode_key == _Mode.CELL_TYPE.value.upper():
        lab = cell_type
    else:
        # NOTE(review): this pair is indexed per sample by the dataset
        # accessors, which looks unintended -- confirm the expected shape.
        lab = protein_location, cell_type
    image_paths = list(image_paths)
    return image_paths, lab
162
-
163
-
164
class HPAone(ExtendedVisionDataset):
    """
    HPA single-cell image dataset.

    The TRAIN and VAL splits are labelled (protein localization and/or cell
    type, depending on ``mode``); the ALL split lists unlabelled images with
    placeholder labels.
    """

    Split = Union[_Split]
    Mode = Union[_Mode]

    def __init__(
        self,
        *,
        split: "HPAone.Split" = _Split.ALL,
        mode: "HPAone.Mode" = None,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_decoder_type: DecoderType = DecoderType.XChannelsDecoder,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            split: ``_Split`` member or its upper-case name.
            mode: label selection for the labelled splits.
            root: dataset root directory.
            transforms, transform, target_transform: forwarded to the base class.
            image_decoder_type: decoder used for the image bytes.

        Raises:
            ValueError: if ``split`` is not a known split.
        """
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
            image_decoder_type=image_decoder_type,
            **kwargs,
        )
        self.mode = mode
        self.split = split
        self.root = root

        if (
            split in {_Split.TRAIN.value.upper(), _Split.VAL.value.upper()}
            or split == _Split.TRAIN
            or split == _Split.VAL
        ):
            (
                self._image_paths,
                self._labels,
            ) = _load_file_names_and_labels_train_or_test(root, split, mode)
        elif split == _Split.ALL.value.upper() or split == _Split.ALL:
            self._image_paths, self._labels = _load_file_names_and_labels_ssl(root)
        else:
            # Previously only logged, leaving the instance without
            # _image_paths / _labels and failing later on first access.
            raise ValueError(f"unknown split: {split}, {_Split.ALL.value.upper()}")

    def get_image_relpath(self, index: int) -> str:
        """Return the stored path of sample ``index``."""
        return self._image_paths[index]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file of sample ``index``."""
        image_full_path = os.path.join(self.root, self.get_image_relpath(index))
        with open(image_full_path, mode="rb") as f:
            return f.read()

    def get_target(self, index: int) -> Any:
        """Return the label of sample ``index``."""
        return self._labels[index]

    def get_targets(self) -> np.ndarray:
        """Return all labels as a NumPy array."""
        return np.array(self._labels)

    def __len__(self) -> int:
        return len(self._image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/decoders.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from io import BytesIO
7
- from typing import Any, Type
8
-
9
- from PIL import Image
10
- import numpy as np
11
- import torch
12
- from enum import Enum
13
-
14
- try:
15
- import tifffile
16
- except ImportError:
17
- print("Could not import `tifffile`, TIFFImageDataDecoder will be disabled")
18
-
19
-
20
class Decoder:
    """Minimal decoder interface: subclasses override :meth:`decode`."""

    def decode(self) -> Any:
        """Return the decoded payload; the base class does not implement it."""
        raise NotImplementedError
23
-
24
-
25
class DecoderType(Enum):
    """Identifiers of the image decoders defined in this module.

    Each member's value equals its name; :meth:`get_class` maps a member
    to the decoder class of the same name.
    """

    ImageDataDecoder = "ImageDataDecoder"
    XChannelsDecoder = "XChannelsDecoder"
    XChannelsTIFFDecoder = "XChannelsTIFFDecoder"
    ChannelSelectDecoder = "ChannelSelectDecoder"

    def get_class(self) -> "Type[Decoder]":
        """Return the decoder class that corresponds to this member."""
        # The classes are defined further down the module, so they are
        # looked up lazily (at call time) through small thunks.
        resolvers = {
            DecoderType.ImageDataDecoder: lambda: ImageDataDecoder,
            DecoderType.XChannelsDecoder: lambda: XChannelsDecoder,
            DecoderType.XChannelsTIFFDecoder: lambda: XChannelsTIFFDecoder,
            DecoderType.ChannelSelectDecoder: lambda: ChannelSelectDecoder,
        }
        return resolvers[self]()
40
-
41
-
42
class ImageDataDecoder(Decoder):
    """Decode encoded image bytes into an RGB PIL image."""

    def __init__(self, image_data: bytes) -> None:
        """Keep the encoded image bytes until :meth:`decode` is called."""
        self._image_data = image_data

    def decode(self) -> Image.Image:
        """Open the stored bytes with PIL and convert the result to RGB.

        Returns:
            The decoded image as a ``PIL.Image.Image`` in ``"RGB"`` mode.
        """
        # ``Image`` is the PIL module; the object returned is an
        # ``Image.Image`` instance, which is what the annotation names.
        buffer = BytesIO(self._image_data)
        return Image.open(buffer).convert(mode="RGB")
49
-
50
-
51
class TargetDecoder(Decoder):
    """Pass-through decoder: :meth:`decode` returns the target unchanged."""

    def __init__(self, target: Any):
        """Store ``target`` for later retrieval."""
        self._target = target

    def decode(self) -> Any:
        """Return the target exactly as it was given to the constructor."""
        stored = self._target
        return stored
57
-
58
-
59
class XChannelsDecoder(Decoder):
    """Decode image bytes into a channel-first float tensor."""

    def __init__(self, image_data: bytes) -> None:
        """Keep the encoded image bytes until :meth:`decode` is called."""
        self._image_data = image_data

    def decode(self):
        """Return the image as a ``torch.Tensor`` with channels first.

        A 2-D array of shape ``(H, W)`` is first reinterpreted, in Fortran
        order, as ``(H, H, C)``.
        NOTE(review): this presumably corresponds to square ``H x H``
        channel tiles stored side by side along the width; confirm against
        the dataset layout.
        """
        pixels = np.asarray(Image.open(BytesIO(self._image_data)))
        if pixels.ndim == 2:
            side = pixels.shape[0]
            pixels = np.reshape(pixels, (side, side, -1), order="F")
        # (H, W, C) -> (C, H, W)
        return torch.Tensor(pixels).permute(2, 0, 1)
68
-
69
-
70
class XChannelsTIFFDecoder(Decoder):
    """Decode TIFF bytes into a channel-first float tensor."""

    def __init__(self, image_data: bytes, num_channels: int = 3) -> None:
        """Store the encoded TIFF bytes and the number of channels to split into."""
        self._image_data = image_data
        self._num_channels = num_channels

    def decode(self):
        """Return the TIFF content as a ``torch.Tensor`` with channels first.

        The array read by ``tifffile`` is reshaped, in Fortran order, to
        ``(rows, -1, num_channels)`` where ``rows`` is its first dimension,
        then permuted so the channel axis comes first.
        """
        raw = tifffile.imread(BytesIO(self._image_data))
        target_shape = (raw.shape[0], -1, self._num_channels)
        stacked = np.reshape(raw, target_shape, order="F")
        return torch.Tensor(stacked).permute(2, 0, 1)
79
-
80
-
81
class ChannelSelectDecoder(Decoder):
    """Decode image bytes, optionally keeping a single channel."""

    def __init__(self, image_data: bytes, select_channel: bool = False) -> None:
        """Store the payload.

        When ``select_channel`` is true the last element of ``image_data``
        is split off and kept as the channel index, and only the remaining
        elements are treated as the encoded image.
        """
        self.select_channel = select_channel
        if not select_channel:
            self._image_data = image_data
            return
        self._image_data = image_data[:-1]
        self._channel = image_data[-1]

    def decode(self):
        """Return a channel-first tensor, restricted to one channel if requested."""
        pixels = np.asarray(Image.open(BytesIO(self._image_data)))
        channels_first = torch.Tensor(pixels).permute(2, 0, 1)
        if self.select_channel:
            # List indexing keeps the leading (channel) dimension.
            return channels_first[[self._channel]]
        return channels_first
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/extended.py DELETED
@@ -1,44 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from typing import Any, Tuple
7
-
8
- from torchvision.datasets import VisionDataset
9
-
10
- from .decoders import DecoderType, TargetDecoder
11
-
12
-
13
class ExtendedVisionDataset(VisionDataset):
    """``VisionDataset`` that separates raw-data access from decoding.

    Subclasses implement :meth:`get_image_data`, :meth:`get_target` and
    :meth:`__len__`; :meth:`__getitem__` decodes the raw image with the
    configured decoder class and then applies ``self.transforms``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Consume the decoder-related keyword arguments and forward the rest.

        ``image_decoder_type`` (default ``DecoderType.ImageDataDecoder``)
        selects the decoder class; ``image_decoder_params`` (default ``{}``)
        holds extra keyword arguments passed to that class.
        """
        decoder_type = kwargs.pop("image_decoder_type", DecoderType.ImageDataDecoder)
        self._decoder_params = kwargs.pop("image_decoder_params", {})
        self._image_decoder_class = decoder_type.get_class()

        super().__init__(*args, **kwargs)  # type: ignore

    def get_image_data(self, index: int) -> bytes:
        """Return the raw encoded image for ``index`` (subclass hook)."""
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        """Return the target for ``index`` (subclass hook)."""
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair for ``index``.

        Raises:
            RuntimeError: if reading or decoding the image fails; the
                original exception is chained as the cause.
        """
        try:
            raw = self.get_image_data(index)
            decoder = self._image_decoder_class(raw, **self._decoder_params)
            image = decoder.decode()
        except Exception as e:
            raise RuntimeError(f"can not read image for sample {index}") from e

        target = TargetDecoder(self.get_target(index)).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of samples (subclass hook)."""
        raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/image_net.py DELETED
@@ -1,290 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import csv
7
- from enum import Enum
8
- import logging
9
- import os
10
- from typing import Callable, List, Optional, Tuple, Union
11
-
12
- import numpy as np
13
-
14
- from .extended import ExtendedVisionDataset
15
-
16
-
17
- logger = logging.getLogger("dinov2")
18
- _Target = int
19
-
20
-
21
- class _Split(Enum):
22
- TRAIN = "train"
23
- VAL = "val"
24
- TEST = "test" # NOTE: torchvision does not support the test split
25
-
26
- @property
27
- def length(self) -> int:
28
- split_lengths = {
29
- _Split.TRAIN: 1_281_167,
30
- _Split.VAL: 50_000,
31
- _Split.TEST: 100_000,
32
- }
33
- return split_lengths[self]
34
-
35
- def get_dirname(self, class_id: Optional[str] = None) -> str:
36
- return self.value if class_id is None else os.path.join(self.value, class_id)
37
-
38
- def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
39
- dirname = self.get_dirname(class_id)
40
- if self == _Split.TRAIN:
41
- basename = f"{class_id}_{actual_index}"
42
- else: # self in (_Split.VAL, _Split.TEST):
43
- basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
44
- return os.path.join(dirname, basename + ".JPEG")
45
-
46
- def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
47
- assert self != _Split.TEST
48
- dirname, filename = os.path.split(image_relpath)
49
- class_id = os.path.split(dirname)[-1]
50
- basename, _ = os.path.splitext(filename)
51
- actual_index = int(basename.split("_")[-1])
52
- return class_id, actual_index
53
-
54
-
55
class ImageNet(ExtendedVisionDataset):
    """ImageNet dataset indexed through precomputed ``.npy`` "extra" files.

    Images are read from ``root``. The per-split entries, class-id and
    class-name arrays live in the ``extra`` directory, are loaded lazily
    (memory-mapped) on first use and can be regenerated with
    :meth:`dump_extra`.
    """

    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Record the split and the extra directory; nothing is loaded yet."""
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra
        self._split = split

        # Index arrays are populated lazily by the ``_get_*`` accessors.
        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        """The split this dataset instance was created for."""
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        """Resolve ``extra_path`` inside the extra directory."""
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Memory-map (read-only) an array stored in the extra directory."""
        full_path = self._get_extra_full_path(extra_path)
        return np.load(full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` in the extra directory, creating it if needed."""
        full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        """File name of the entries array for the current split."""
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        """File name of the class-ids array for the current split."""
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        """File name of the class-names array for the current split."""
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        """Return the entries array, loading it on first access."""
        entries = self._entries
        if entries is None:
            entries = self._entries = self._load_extra(self._entries_path)
        assert entries is not None
        return entries

    def _get_class_ids(self) -> np.ndarray:
        """Return the class-ids array, loading it on first access."""
        assert self._split != _Split.TEST, "Class IDs are not available in TEST split"
        class_ids = self._class_ids
        if class_ids is None:
            class_ids = self._class_ids = self._load_extra(self._class_ids_path)
        assert class_ids is not None
        return class_ids

    def _get_class_names(self) -> np.ndarray:
        """Return the class-names array, loading it on first access."""
        assert self._split != _Split.TEST, "Class names are not available in TEST split"
        class_names = self._class_names
        if class_names is None:
            class_names = self._class_names = self._load_extra(self._class_names_path)
        assert class_names is not None
        return class_names

    def find_class_id(self, class_index: int) -> str:
        """Return the class id stored at ``class_index``."""
        return str(self._get_class_ids()[class_index])

    def find_class_name(self, class_index: int) -> str:
        """Return the class name stored at ``class_index``."""
        return str(self._get_class_names()[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Read the raw JPEG bytes of sample ``index`` from ``self.root``."""
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        class_id = self.get_class_id(index)

        relpath = self.split.get_image_relpath(actual_index, class_id)
        full_path = os.path.join(self.root, relpath)
        with open(full_path, mode="rb") as stream:
            return stream.read()

    def get_target(self, index: int) -> Optional[Target]:
        """Return the class index of sample ``index`` (``None`` for TEST)."""
        class_index = self._get_entries()[index]["class_index"]
        if self.split == _Split.TEST:
            return None
        return int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        """Return the class indices of all samples (``None`` for TEST)."""
        entries = self._get_entries()
        if self.split == _Split.TEST:
            return None
        return entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        """Return the class id of sample ``index`` (``None`` for TEST)."""
        class_id = self._get_entries()[index]["class_id"]
        if self.split == _Split.TEST:
            return None
        return str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        """Return the class name of sample ``index`` (``None`` for TEST)."""
        class_name = self._get_entries()[index]["class_name"]
        if self.split == _Split.TEST:
            return None
        return str(class_name)

    def __len__(self) -> int:
        """Return the number of entries, asserting it matches the split size."""
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        """Parse the two-column CSV labels file into ``(class_id, class_name)`` pairs.

        Raises:
            RuntimeError: if the file cannot be opened or read.
        """
        labels_full_path = os.path.join(self.root, labels_path)

        try:
            with open(labels_full_path, "r") as f:
                return [(class_id, class_name) for class_id, class_name in csv.reader(f)]
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

    def _dump_entries(self) -> None:
        """Build the per-sample entries array for the split and save it.

        For TEST, entries are synthesised: indices ``1..length``, the
        ``uint32`` wrap-around of -1 as class index and empty class
        strings. For the other splits the samples are enumerated with
        torchvision's ``ImageFolder`` and labelled from ``labels.txt``.
        """
        split = self.split
        is_test_split = split == ImageNet.Split.TEST

        if is_test_split:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # NOTE: Using torchvision ImageFolder for consistency
            from torchvision.datasets import ImageFolder

            dataset = ImageFolder(os.path.join(self.root, split.get_dirname()))
            sample_count = len(dataset)

            # The widest class id / name determine the string field sizes.
            max_class_id_length, max_class_name_length = -1, -1
            for _, class_index in dataset.samples:
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        def report_progress(index: int, last_percent: int) -> int:
            # Log once each time the integer percentage increases.
            percent = 100 * (index + 1) // sample_count
            if percent > last_percent:
                logger.info(f"creating entries: {percent}%")
                return percent
            return last_percent

        if is_test_split:
            old_percent = -1
            for index in range(sample_count):
                old_percent = report_progress(index, old_percent)

                actual_index = index + 1
                class_index = np.uint32(-1)
                class_id, class_name = "", ""
                entries_array[index] = (actual_index, class_index, class_id, class_name)
        else:
            class_names = dict(labels)

            assert dataset
            old_percent = -1
            for index in range(sample_count):
                old_percent = report_progress(index, old_percent)

                image_full_path, class_index = dataset.samples[index]
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        """Derive the class-id and class-name arrays from the saved entries.

        Does nothing for the TEST split.
        """
        if self.split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        # First pass: array length and string widths.
        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
        for record in entries_array:
            max_class_index = max(int(record["class_index"]), max_class_index)
            max_class_id_length = max(len(str(record["class_id"])), max_class_id_length)
            max_class_name_length = max(len(str(record["class_name"])), max_class_name_length)

        class_count = max_class_index + 1
        class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
        class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")

        # Second pass: fill each slot from the entries of that class.
        for record in entries_array:
            slot = record["class_index"]
            class_ids_array[slot] = record["class_id"]
            class_names_array[slot] = record["class_name"]

        logger.info(f'saving class IDs to "{self._class_ids_path}"')
        self._save_extra(class_ids_array, self._class_ids_path)

        logger.info(f'saving class names to "{self._class_names_path}"')
        self._save_extra(class_names_array, self._class_names_path)

    def dump_extra(self) -> None:
        """Regenerate all extra files: entries first, then class ids/names."""
        self._dump_entries()
        self._dump_class_ids_and_names()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/datasets/image_net_22k.py DELETED
@@ -1,302 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from dataclasses import dataclass
7
- from enum import Enum
8
- from functools import lru_cache
9
- from gzip import GzipFile
10
- from io import BytesIO
11
- from mmap import ACCESS_READ, mmap
12
- import os
13
- from typing import Any, Callable, List, Optional, Set, Tuple
14
- import warnings
15
-
16
- import numpy as np
17
-
18
- from .extended import ExtendedVisionDataset
19
-
20
-
21
- _Labels = int
22
-
23
- _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
24
-
25
-
26
- @dataclass
27
- class _ClassEntry:
28
- block_offset: int
29
- maybe_filename: Optional[str] = None
30
-
31
-
32
- @dataclass
33
- class _Entry:
34
- class_index: int # noqa: E701
35
- start_offset: int
36
- end_offset: int
37
- filename: str
38
-
39
-
40
- class _Split(Enum):
41
- TRAIN = "train"
42
- VAL = "val"
43
-
44
- @property
45
- def length(self) -> int:
46
- return {
47
- _Split.TRAIN: 11_797_647,
48
- _Split.VAL: 561_050,
49
- }[self]
50
-
51
- def entries_path(self):
52
- return f"imagenet21kp_{self.value}.txt"
53
-
54
-
55
- def _get_tarball_path(class_id: str) -> str:
56
- return f"{class_id}.tar"
57
-
58
-
59
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    """Return a cached function mapping a class id to an mmap of its tarball.

    Args:
        tarballs_root: Directory that contains the ``<class_id>.tar`` files.
        mmap_cache_size: Maximum number of mappings kept by the LRU cache.

    Returns:
        A callable ``class_id -> mmap`` giving read-only access to the
        whole tarball. Repeated calls with the same class id return the
        cached mapping object.
    """

    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(class_id: str) -> mmap:
        tarball_path = _get_tarball_path(class_id)
        tarball_full_path = os.path.join(tarballs_root, tarball_path)
        # Tarballs are binary data: open in binary mode instead of the
        # default text mode. The mapping is created from the descriptor
        # before the ``with`` block closes the file object.
        with open(tarball_full_path, "rb") as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball
68
-
69
-
70
class ImageNet22k(ExtendedVisionDataset):
    """ImageNet-22k dataset read directly from per-class tarballs.

    Each class is stored as ``<class_id>.tar`` under ``root``. Precomputed
    ``entries.npy`` / ``class-ids.npy`` arrays in the ``extra`` directory
    give, for every sample, the byte range of its tar member; image bytes
    are sliced out of a memory-mapped tarball. The arrays can be
    regenerated with :meth:`dump_extra`.
    """

    # Sample indices whose payload is expected to be gzip-compressed; any
    # other sample that starts with the gzip magic bytes is rejected in
    # ``get_image_data``.
    _GZIPPED_INDICES: Set[int] = {
        841_545,
        1_304_131,
        2_437_921,
        2_672_079,
        2_795_676,
        2_969_786,
        6_902_965,
        6_903_550,
        6_903_628,
        7_432_557,
        7_432_589,
        7_813_809,
        8_329_633,
        10_296_990,
        10_417_652,
        10_492_265,
        10_598_078,
        10_782_398,
        10_902_612,
        11_203_736,
        11_342_890,
        11_397_596,
        11_589_762,
        11_705_103,
        12_936_875,
        13_289_782,
    }
    Labels = _Labels

    def __init__(
        self,
        *,
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        """Load the entries and class-ids arrays and set up the mmap cache.

        Args:
            root: Directory holding the per-class tarballs.
            extra: Directory holding ``entries.npy`` and ``class-ids.npy``.
            transforms, transform, target_transform: Forwarded to the base class.
            mmap_cache_size: Maximum number of tarballs kept memory-mapped.
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra

        entries_path = self._get_entries_path(root)
        self._entries = self._load_extra(entries_path)

        class_ids_path = self._get_class_ids_path(root)
        self._class_ids = self._load_extra(class_ids_path)

        self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    def get_root(self, root: Optional[str] = None) -> str:
        """Return ``root`` when given, otherwise the dataset's own ``self.root``.

        ``_load_entries_class_ids`` calls this helper; it is defined here
        because no ``get_root`` is visible on the base class.
        """
        return self.root if root is None else root

    def _get_entries_path(self, root: Optional[str] = None) -> str:
        """Return the entries file name (``root`` is accepted but unused)."""
        return "entries.npy"

    def _get_class_ids_path(self, root: Optional[str] = None) -> str:
        """Return the class-ids file name (``root`` is accepted but unused)."""
        return "class-ids.npy"

    def _find_class_ids(self, path: str) -> List[str]:
        """Return the sorted stems of all ``.tar`` files directly under ``path``."""
        class_ids = []

        with os.scandir(path) as entries:
            for entry in entries:
                root, ext = os.path.splitext(entry.name)
                if ext != ".tar":
                    continue
                class_ids.append(root)

        return sorted(class_ids)

    def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
        """Build the entry list by parsing the ``blocks/<class_id>.log`` files.

        Every line is split on ``":"`` into a block field (its first six
        characters are dropped before parsing the number) and a file name
        (its first character is dropped).
        NOTE(review): this looks like a ``block N: name`` tar block
        listing -- confirm against how the ``.log`` files are produced.

        Returns:
            ``(entries, class_ids)`` where ``class_ids`` is sorted and each
            entry's ``class_index`` indexes into it.

        Raises:
            RuntimeError: if a blocks file cannot be read.
        """
        root = self.get_root(root)
        entries: List[_Entry] = []
        class_ids = self._find_class_ids(root)

        for class_index, class_id in enumerate(class_ids):
            path = os.path.join(root, "blocks", f"{class_id}.log")
            class_entries = []

            try:
                with open(path) as f:
                    for line in f:
                        line = line.rstrip()
                        block, filename = line.split(":")
                        block_offset = int(block[6:])
                        filename = filename[1:]

                        # The NUL-block marker line carries no file name.
                        maybe_filename = None
                        if filename != "** Block of NULs **":
                            maybe_filename = filename

                        class_entry = _ClassEntry(block_offset, maybe_filename)
                        class_entries.append(class_entry)
            except OSError as e:
                raise RuntimeError(f'can not read blocks file "{path}"') from e

            # The listing must end with the NUL-block marker, which
            # provides the end offset of the last real member.
            assert class_entries[-1].maybe_filename is None

            for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
                assert class_entry1.block_offset <= class_entry2.block_offset
                # Block offsets are converted to bytes with a 512-byte block size.
                start_offset = 512 * class_entry1.block_offset
                end_offset = 512 * class_entry2.block_offset
                assert class_entry1.maybe_filename is not None
                filename = class_entry1.maybe_filename
                entry = _Entry(class_index, start_offset, end_offset, filename)
                # Skip invalid image files (PIL throws UnidentifiedImageError)
                if filename == "n06470073_47249.JPEG":
                    continue
                entries.append(entry)

        return entries, class_ids

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Memory-map (read-only) an array stored in the extra directory."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` in the extra directory, creating it if needed."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        os.makedirs(extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _tarballs_root(self) -> str:
        """Directory containing the per-class tarballs (same as ``root``)."""
        return self.root

    def find_class_id(self, class_index: int) -> str:
        """Return the class id stored at ``class_index``."""
        return str(self._class_ids[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Return the raw image bytes of sample ``index``.

        The member's byte range is sliced from the memory-mapped tarball
        and its first 512 bytes (the entry header block) are dropped. A
        payload starting with the gzip magic bytes is decompressed, but
        only for indices listed in ``_GZIPPED_INDICES``.

        Raises:
            RuntimeError: if slicing or decompressing fails, or if a gzip
                header is found for an index not listed as gzipped.
        """
        entry = self._entries[index]
        class_id = entry["class_id"]
        class_mmap = self._mmap_tarball(class_id)

        start_offset, end_offset = entry["start_offset"], entry["end_offset"]
        try:
            mapped_data = class_mmap[start_offset:end_offset]
            data = mapped_data[512:]  # Skip entry header block

            if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
                assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
                with GzipFile(fileobj=BytesIO(data)) as g:
                    data = g.read()
        except Exception as e:
            raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e

        return data

    def get_target(self, index: int) -> Any:
        """Return the class index of sample ``index`` as a Python ``int``."""
        return int(self._entries[index]["class_index"])

    def get_targets(self) -> np.ndarray:
        """Return the ``class_index`` column of the entries array."""
        return self._entries["class_index"]

    def get_class_id(self, index: int) -> str:
        """Return the class id of sample ``index``."""
        return str(self._entries[index]["class_id"])

    def get_class_ids(self) -> np.ndarray:
        """Return the ``class_id`` column of the entries array."""
        return self._entries["class_id"]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` with all warnings suppressed while loading."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super().__getitem__(index)

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self._entries)

    def _dump_entries(self, *args, **kwargs) -> None:
        """Rebuild the entries array from the block logs and save it."""
        entries, class_ids = self._load_entries_class_ids(*args, **kwargs)

        # The widest strings determine the fixed-width unicode field sizes.
        max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
        for entry in entries:
            class_id = class_ids[entry.class_index]
            max_class_index = max(entry.class_index, max_class_index)
            max_class_id_length = max(len(class_id), max_class_id_length)
            max_filename_length = max(len(entry.filename), max_filename_length)

        dtype = np.dtype(
            [
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("start_offset", "<u4"),
                ("end_offset", "<u4"),
                ("filename", f"U{max_filename_length}"),
            ]
        )
        sample_count = len(entries)
        entries_array = np.empty(sample_count, dtype=dtype)
        for i, entry in enumerate(entries):
            class_index = entry.class_index
            class_id = class_ids[class_index]
            start_offset = entry.start_offset
            end_offset = entry.end_offset
            filename = entry.filename
            entries_array[i] = (
                class_index,
                class_id,
                start_offset,
                end_offset,
                filename,
            )

        entries_path = self._get_entries_path(*args, **kwargs)
        self._save_extra(entries_array, entries_path)

    def _dump_class_ids(self, *args, **kwargs) -> None:
        """Derive the class-ids array from the saved entries and save it."""
        entries_path = self._get_entries_path(*args, **kwargs)
        entries_array = self._load_extra(entries_path)

        max_class_id_length, max_class_index = -1, -1
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)

        class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            class_ids_array[class_index] = class_id
        class_ids_path = self._get_class_ids_path(*args, **kwargs)
        self._save_extra(class_ids_array, class_ids_path)

    def _dump_extra(self, *args, **kwargs) -> None:
        """Regenerate the entries array, then the class-ids array."""
        # Keyword arguments are forwarded as keywords (``**kwargs``), not
        # unpacked into extra positional arguments.
        self._dump_entries(*args, **kwargs)
        self._dump_class_ids(*args, **kwargs)

    def dump_extra(self, root: Optional[str] = None) -> None:
        """Regenerate all extra files, reading tarball logs from ``root``."""
        return self._dump_extra(root)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/loaders.py DELETED
@@ -1,232 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
- from enum import Enum
8
- from typing import Any, Callable, List, Optional, TypeVar
9
-
10
- import torch
11
- from torch.utils.data import Sampler
12
-
13
- from .datasets import ImageNet, ImageNet22k, HPAone, HPAFoV, CHAMMI_CP, CHAMMI_HPA, CHAMMI_WTC
14
- from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
15
-
16
-
17
- logger = logging.getLogger("dinov2")
18
-
19
-
20
class SamplerType(Enum):
    """Sampler kinds that ``_make_sampler`` knows how to build."""

    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4
26
-
27
-
28
- def _make_bool_str(b: bool) -> str:
29
- return "yes" if b else "no"
30
-
31
-
32
- def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
33
- def transform(sample):
34
- image, target = sample
35
- if image_transform is not None:
36
- image = image_transform(image)
37
- if target_transform is not None:
38
- target = target_transform(target)
39
- return image, target
40
-
41
- return transform
42
-
43
-
44
- def _parse_dataset_str(dataset_str: str):
45
- tokens = dataset_str.split(":")
46
-
47
- name = tokens[0]
48
- kwargs = {}
49
-
50
- for token in tokens[1:]:
51
- key, value = token.split("=")
52
- assert key in ("root", "extra", "split", "mode", "wildcard")
53
- kwargs[key] = value
54
-
55
- if name == "ImageNet":
56
- class_ = ImageNet
57
- if "split" in kwargs:
58
- kwargs["split"] = ImageNet.Split[kwargs["split"]]
59
- elif name == "ImageNet22k":
60
- class_ = ImageNet22k
61
- elif name == "HPAone":
62
- class_ = HPAone
63
- elif name == "HPAFoV":
64
- class_ = HPAFoV
65
- elif name == "CHAMMI_CP":
66
- class_ = CHAMMI_CP
67
- elif name == "CHAMMI_WTC":
68
- class_ = CHAMMI_WTC
69
- elif name == "CHAMMI_HPA":
70
- class_ = CHAMMI_HPA
71
- else:
72
- raise ValueError(f'Unsupported dataset "{name}"')
73
-
74
- return class_, kwargs
75
-
76
-
77
def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Build the dataset described by a dataset string.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    dataset_cls, dataset_kwargs = _parse_dataset_str(dataset_str)
    dataset = dataset_cls(transform=transform, target_transform=target_transform, **dataset_kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not expose (yet) these attributes, so add them.
    fallbacks = (("transform", transform), ("target_transform", target_transform))
    for attr_name, attr_value in fallbacks:
        if not hasattr(dataset, attr_name):
            setattr(dataset, attr_name, attr_value)

    return dataset
108
-
109
-
110
def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    """Build the sampler selected by ``type``, or return ``None``.

    Args:
        dataset: Dataset to sample from; only its length is used, except
            for DISTRIBUTED where the dataset itself is handed over.
        type: Sampler kind; any value not handled below yields ``None``.
        shuffle: Whether the sampler shuffles.
        seed: Random seed forwarded to the sampler.
        size: Samples per epoch; only meaningful for EPOCH.
        advance: Number of samples to skip; only for the infinite kinds.

    Raises:
        ValueError: if ``size > 0`` for a kind that does not take a size,
            or ``advance > 0`` for DISTRIBUTED.
        NotImplementedError: if ``advance > 0`` for EPOCH.
    """
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )

    if type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=(type == SamplerType.SHARDED_INFINITE_NEW),
        )

    if type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        # A non-positive size means "one pass over the whole dataset".
        epoch_size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {epoch_size:,d}")
        return EpochSampler(
            size=epoch_size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )

    if type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
171
-
172
-
173
- T = TypeVar("T")
174
-
175
-
176
def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Build a PyTorch data loader around ``dataset``.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
        collate_fn: Function that performs batch collation

    Returns:
        The configured ``torch.utils.data.DataLoader`` (memory pinning is
        always enabled).
    """

    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    loader_options = dict(
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )
    data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

    try:
        batch_count = len(data_loader)
        logger.info(f"# of batches: {batch_count:,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return data_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/masking.py DELETED
@@ -1,86 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import random
7
- import math
8
- import numpy as np
9
-
10
-
11
class MaskingGenerator:
    """Generate random block-wise boolean masks over a 2D grid of patches.

    Calling an instance returns a ``(height, width)`` boolean array in which
    randomly placed rectangles have been set to ``True`` until the requested
    number of masked patches is reached or no further rectangle can be placed.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        """
        Args:
            input_size: Grid size; an int (square grid) or a ``(height, width)`` tuple.
            num_masking_patches: Stored for ``__repr__`` and used as the default
                value of ``max_num_patches``.
            min_num_patches: Lower bound of the area sampled for each rectangle.
            max_num_patches: Upper bound on the patches a single rectangle may add.
                When it resolves to ``None`` no per-rectangle cap is applied.
            min_aspect: Smallest aspect ratio sampled for a rectangle.
            max_aspect: Largest aspect ratio; defaults to ``1 / min_aspect``.
        """
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        def _int_or_none(value):
            # These settings default to None, which "%d" cannot format.
            return "None" if value is None else "%d" % value

        return "Generator(%d, %d -> [%d ~ %s], max = %s, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            _int_or_none(self.max_num_patches),
            _int_or_none(self.num_masking_patches),
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self):
        """Return the grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try (up to 10 times) to add one rectangle to ``mask`` in place.

        Returns:
            The number of patches newly set to ``True`` (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                region = mask[top : top + h, left : left + w]
                # Only patches that are not masked yet count towards the budget.
                newly_masked = h * w - int(region.sum())
                if 0 < newly_masked <= max_mask_patches:
                    region[...] = True  # `region` is a view: this writes into `mask`
                    delta += newly_masked
                    break
        return delta

    def __call__(self, num_masking_patches=0):
        """Return a boolean mask with at most ``num_masking_patches`` entries set.

        Rectangles are added until the target count is reached or an attempt
        adds nothing, so the result may contain fewer masked patches than requested.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            if self.max_num_patches is not None:
                max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/samplers.py DELETED
@@ -1,229 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import itertools
7
- from typing import Any, Optional
8
- import warnings
9
-
10
- import numpy as np
11
- import torch
12
- from torch.utils.data.sampler import Sampler
13
-
14
- import dinov2.distributed as distributed
15
-
16
-
17
class EpochSampler(Sampler):
    """Finite sampler yielding a strided share of ``size`` indices per epoch.

    The index range ``[0, sample_count)`` is tiled until it covers ``size``
    entries, optionally shuffled, and every ``step``-th entry starting at
    ``start`` is yielded.
    """

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        """
        Args:
            size: Total number of indices per epoch (across all ranks).
            sample_count: Number of distinct indices, i.e. indices lie in ``[0, sample_count)``.
            shuffle: Whether to shuffle the indices each epoch.
            seed: Base random seed used when shuffling.
            start: Offset of the first yielded entry; defaults to the global rank.
            step: Stride between yielded entries; defaults to the global size.
        """
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._epoch = 0

    def __iter__(self):
        count = (self._size + self._sample_count - 1) // self._sample_count
        tiled_indices = np.tile(np.arange(self._sample_count), count)
        if self._shuffle:
            # Seed from the (seed, epoch) pair. A product such as seed * epoch is
            # 0 at epoch 0 for every seed and collides across pairs (2*3 == 3*2).
            rng = np.random.default_rng([self._seed, self._epoch])
            iterable = rng.choice(tiled_indices, self._size, replace=False)
        else:
            iterable = tiled_indices[: self._size]

        yield from itertools.islice(iterable, self._start, None, self._step)

    def __len__(self):
        # Number of entries in range(start, size, step).
        return (self._size - self._start + self._step - 1) // self._step

    def set_epoch(self, epoch):
        """Set the epoch number mixed into the shuffling seed."""
        self._epoch = epoch
53
-
54
-
55
- def _get_numpy_dtype(size: int) -> Any:
56
- return np.int32 if size <= 2**31 else np.int64
57
-
58
-
59
- def _get_torch_dtype(size: int) -> Any:
60
- return torch.int32 if size <= 2**31 else torch.int64
61
-
62
-
63
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Lazily yield the entries of a random permutation of ``range(size)``.

    One random index is drawn from ``generator`` per yielded value, so the
    permutation is produced incrementally rather than materialized up front.
    """
    # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
    perm = torch.arange(size, dtype=_get_torch_dtype(size))
    for pos in range(size):
        pick = torch.randint(pos, size, size=(1,), generator=generator).item()

        # Always swap even if no-op; .item() copies the scalars so the two
        # assignments below do not alias each other.
        chosen = perm[pick].item()
        perm[pick] = perm[pos].item()
        perm[pos] = chosen
        yield chosen
76
-
77
-
78
class InfiniteSampler(Sampler):
    """Endless sampler cycling over a strided share of ``[0, sample_count)``.

    Each pass over the index range (in order, or freshly permuted when
    ``shuffle`` is set) contributes every ``step``-th entry starting at ``start``.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        """
        Args:
            sample_count: Number of distinct indices.
            shuffle: Whether each pass uses a random permutation.
            seed: Seed of the generator used when shuffling.
            start: Offset within each pass; defaults to the global rank.
            step: Stride within each pass; defaults to the global size.
            advance: Number of leading samples to skip when iterating.
        """
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance

    def __iter__(self):
        source = self._shuffled_iterator() if self._shuffle else self._iterator()
        # Drop the first `advance` samples, then stream forever.
        yield from itertools.islice(source, self._advance, None)

    def _iterator(self):
        """Yield this rank's share of ``range(sample_count)``, pass after pass."""
        assert not self._shuffle

        while True:
            ordered = range(self._sample_count)
            yield from itertools.islice(ordered, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Yield this rank's share of a new random permutation on every pass."""
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        rng = torch.Generator().manual_seed(self._seed)

        while True:
            permuted = _generate_randperm_indices(size=self._sample_count, generator=rng)
            yield from itertools.islice(permuted, self._start, None, self._step)
121
-
122
-
123
- # The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
124
- # but avoids a full in-place random permutation generation.
125
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Return the entries ``tensor[start + i * step]`` in shuffled order.

    ``len(tensor) // step`` entries are taken; a warning reports how many
    trailing samples are dropped when the length is not a multiple of ``step``.
    The shuffle is built incrementally: each new entry is inserted at a random
    position and the displaced value moves to the end.
    """
    total = len(tensor)
    kept = total // step
    dropped = total - step * kept
    if dropped:
        warnings.warn(f"# of dropped samples: {dropped}")

    shuffled = np.empty(kept, dtype=_get_numpy_dtype(total))

    for pos in range(kept):
        if pos > 0:
            swap = torch.randint(0, pos + 1, size=(1,), generator=generator).item()
        else:
            # First entry: nothing to swap with, and no random draw is consumed.
            swap = 0

        shuffled[pos] = shuffled[swap]
        shuffled[swap] = tensor[start + pos * step].item()

    return shuffled
144
-
145
-
146
- def _new_shuffle_tensor_slice(
147
- *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
148
- ) -> np.ndarray:
149
- stop = len(tensor)
150
- count = stop // step
151
- dtype = torch.int64 # Needed for using randperm result as indices
152
- count = stop // step
153
- drop_count = stop - step * count
154
- if drop_count:
155
- warnings.warn(f"# of dropped samples: {drop_count}")
156
- indices = torch.randperm(count, dtype=dtype, generator=generator)
157
- return tensor[start::step][indices].numpy()
158
-
159
-
160
- def _make_seed(seed: int, start: int, iter_count: int) -> int:
161
- # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
162
- return seed + start + (iter_count << 24)
163
-
164
-
165
class ShardedInfiniteSampler(Sampler):
    """Endless sampler in which each rank reshuffles only its own shard.

    When shuffling, one global permutation is drawn from ``seed``; on every
    pass the strided slice of that permutation selected by ``start`` / ``step``
    is reshuffled with a seed derived from the pass counter.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        """
        Args:
            sample_count: Number of distinct indices.
            shuffle: Whether to shuffle.
            seed: Base random seed.
            start: Shard offset; defaults to the global rank.
            step: Shard stride; defaults to the global size.
            advance: Number of leading samples to skip when iterating.
            use_new_shuffle_tensor_slice: Select ``_new_shuffle_tensor_slice``
                instead of ``_shuffle_tensor_slice`` for the per-pass shuffle.
        """
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance
        self._iter_count = 0
        if use_new_shuffle_tensor_slice:
            self._shuffle_tensor_slice_fn = _new_shuffle_tensor_slice
        else:
            self._shuffle_tensor_slice_fn = _shuffle_tensor_slice

    def __iter__(self):
        # Convert whole multiples of `sample_count` in `advance` into skipped
        # passes (tracked by the pass counter) and keep only the remainder.
        whole_passes, remainder = divmod(self._advance, self._sample_count)
        if whole_passes > 0:
            self._advance = remainder
            self._iter_count += whole_passes

        source = self._shuffled_iterator() if self._shuffle else self._iterator()
        yield from itertools.islice(source, self._advance, None)

    def _iterator(self):
        """Yield this rank's share of ``range(sample_count)``, pass after pass."""
        assert not self._shuffle

        while True:
            ordered = range(self._sample_count)
            yield from itertools.islice(ordered, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Yield this rank's shard of the global permutation, reshuffled every pass."""
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        rng = torch.Generator()

        # Always shuffle everything first
        rng.manual_seed(self._seed)
        index_dtype = _get_torch_dtype(self._sample_count)
        perm = torch.randperm(self._sample_count, dtype=index_dtype, generator=rng)

        while True:
            # Re-seed on each iteration to allow skipping whole permutations
            rng.manual_seed(_make_seed(self._seed, self._start, self._iter_count))

            shard = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=rng
            )
            yield from shard
            self._iter_count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/data/transforms.py DELETED
@@ -1,91 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from typing import Sequence
7
-
8
- import torch
9
- from torchvision import transforms
10
-
11
-
12
class GaussianBlur(transforms.RandomApply):
    """
    Apply Gaussian Blur to the PIL image with probability ``p``.

    Args:
        p: Probability of blurring the image.
        radius_min: Lower bound of the sigma range passed to the blur.
        radius_max: Upper bound of the sigma range passed to the blur.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # NOTE: torchvision's RandomApply takes the probability of APPLYING the
        # wrapped transforms (the image is returned untouched with probability
        # 1 - p), so `p` must be forwarded as is. Passing 1 - p inverts the
        # behaviour: p=1.0 would never blur.
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
22
-
23
-
24
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that passes inputs which are already tensors through unchanged.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
        Returns:
            Tensor: ``pic`` itself when it is a tensor, otherwise the converted image.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else super().__call__(pic)
39
-
40
-
41
- # Use timm's names
42
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
43
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
44
-
45
-
46
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Build a ``transforms.Normalize`` from per-channel ``mean`` and ``std``.

    Both arguments default to the ImageNet statistics defined in this module.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
51
-
52
-
53
- # This roughly matches torchvision's preset for classification training:
54
- # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
55
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    """Compose the training pipeline: random resized crop, optional flip, tensor, normalize.

    Args:
        crop_size: Output size of the random resized crop.
        interpolation: Interpolation mode used by the crop.
        hflip_prob: Horizontal flip probability; the flip is omitted when not positive.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.
    """
    pipeline = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0.0:
        pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))
    pipeline += [
        MaybeToTensor(),
        make_normalize_transform(mean=mean, std=std),
    ]
    return transforms.Compose(pipeline)
73
-
74
-
75
- # This matches (roughly) torchvision's preset for classification evaluation:
76
- # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
77
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Compose the evaluation pipeline: resize, center crop, tensor, normalize.

    Args:
        resize_size: Size passed to ``transforms.Resize``.
        interpolation: Interpolation mode used when resizing.
        crop_size: Size of the center crop.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/distributed/__init__.py DELETED
@@ -1,270 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import os
7
- import random
8
- import re
9
- import socket
10
- from typing import Dict, List
11
-
12
- import torch
13
- import torch.distributed as dist
14
-
15
- _LOCAL_RANK = -1
16
- _LOCAL_WORLD_SIZE = -1
17
-
18
-
19
def is_enabled() -> bool:
    """
    Returns:
        True if torch.distributed is available and its process group is initialized
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
25
-
26
-
27
def get_global_size() -> int:
    """
    Returns:
        The number of processes in the process group, or 1 when
        distributed mode is not enabled.
    """
    if is_enabled():
        return dist.get_world_size()
    return 1
33
-
34
-
35
def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process within the global process group,
        or 0 when distributed mode is not enabled.
    """
    if is_enabled():
        return dist.get_rank()
    return 0
41
-
42
-
43
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process
        group, or 0 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level values are only valid once enable() has stored them.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_RANK
    return 0
52
-
53
-
54
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of
        processes per machine, or 1 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level values are only valid once enable() has stored them.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_WORLD_SIZE
    return 1
64
-
65
-
66
def is_main_process() -> bool:
    """
    Returns:
        True if the current process is the main one, i.e. its global rank is 0.
    """
    rank = get_global_rank()
    return rank == 0
72
-
73
-
74
def _restrict_print_to_main_process() -> None:
    """
    Replace the builtin ``print`` so that only the main process prints.

    Any process can still print by passing ``force=True``; that keyword is
    removed before the call is forwarded to the original ``print``.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if not (is_main_process() or force):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
88
-
89
-
90
- def _get_master_port(seed: int = 0) -> int:
91
- MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
92
-
93
- master_port_str = os.environ.get("MASTER_PORT")
94
- if master_port_str is None:
95
- rng = random.Random(seed)
96
- return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
97
-
98
- return int(master_port_str)
99
-
100
-
101
def _get_available_port() -> int:
    """Return a port number picked by the OS by binding a temporary TCP socket to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        sock.bind(("", 0))
        bound_address = sock.getsockname()
        return bound_address[1]
108
-
109
-
110
- _TORCH_DISTRIBUTED_ENV_VARS = (
111
- "MASTER_ADDR",
112
- "MASTER_PORT",
113
- "RANK",
114
- "WORLD_SIZE",
115
- "LOCAL_RANK",
116
- "LOCAL_WORLD_SIZE",
117
- )
118
-
119
-
120
- def _collect_env_vars() -> Dict[str, str]:
121
- return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
122
-
123
-
124
- def _is_slurm_job_process() -> bool:
125
- return "SLURM_JOB_ID" in os.environ
126
-
127
-
128
- def _parse_slurm_node_list(s: str) -> List[str]:
129
- nodes = []
130
- # Extract "hostname", "hostname[1-2,3,4-5]," substrings
131
- p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
132
- for m in p.finditer(s):
133
- prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
134
- for suffix in suffixes.split(","):
135
- span = suffix.split("-")
136
- if len(span) == 1:
137
- nodes.append(prefix + suffix)
138
- else:
139
- width = len(span[0])
140
- start, end = int(span[0]), int(span[1]) + 1
141
- nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
142
- return nodes
143
-
144
-
145
- def _check_env_variable(key: str, new_value: str):
146
- # Only check for difference with preset environment variables
147
- if key in os.environ and os.environ[key] != new_value:
148
- raise RuntimeError(f"Cannot export environment variables as {key} is already set")
149
-
150
-
151
class _TorchDistributedEnvironment:
    """Resolve the torch.distributed settings for the current process.

    Sources are tried in this order: a Slurm job environment, a fully preset
    set of torch.distributed environment variables, and finally a local
    single-process setup when at least one CUDA device is visible.
    """

    def __init__(self):
        """
        Raises:
            RuntimeError: When only some of the torch.distributed environment
                variables are set, or when no source applies.
        """
        self.master_addr = "127.0.0.1"
        self.master_port = 0
        self.rank = -1
        self.world_size = -1
        self.local_rank = -1
        self.local_world_size = -1

        if _is_slurm_job_process():
            self._set_from_slurm_env()
            return

        preset = _collect_env_vars()
        if preset:
            if len(preset) != len(_TORCH_DISTRIBUTED_ENV_VARS):
                # Environment is partially set
                collected_env_vars = ", ".join(preset.keys())
                raise RuntimeError(f"Partially set environment: {collected_env_vars}")
            # Environment is fully set
            self._set_from_preset_env()
            return

        # Environment is not set
        if torch.cuda.device_count() > 0:
            self._set_from_local()
            return

        raise RuntimeError("Can't initialize PyTorch distributed environment")

    # Slurm job created with sbatch, submitit, etc...
    def _set_from_slurm_env(self):
        """Fill the settings from the ``SLURM_*`` environment variables."""
        # logger.info("Initialization from Slurm environment")
        job_id = int(os.environ["SLURM_JOB_ID"])
        node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
        nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
        assert len(nodes) == node_count

        # The first node of the job hosts the rendezvous; the port is derived
        # from the job id.
        self.master_addr = nodes[0]
        self.master_port = _get_master_port(seed=job_id)
        self.rank = int(os.environ["SLURM_PROCID"])
        self.world_size = int(os.environ["SLURM_NTASKS"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["SLURM_LOCALID"])
        self.local_world_size = self.world_size // node_count
        assert self.local_rank < self.local_world_size

    # Single node job with preset environment (i.e. torchrun)
    def _set_from_preset_env(self):
        """Fill the settings from the preset torch.distributed environment variables."""
        # logger.info("Initialization from preset environment")
        env = os.environ
        self.master_addr = env["MASTER_ADDR"]
        self.master_port = env["MASTER_PORT"]
        self.rank = int(env["RANK"])
        self.world_size = int(env["WORLD_SIZE"])
        assert self.rank < self.world_size
        self.local_rank = int(env["LOCAL_RANK"])
        self.local_world_size = int(env["LOCAL_WORLD_SIZE"])
        assert self.local_rank < self.local_world_size

    # Single node and GPU job (i.e. local script run)
    def _set_from_local(self):
        """Fill the settings for a single local process on a free port."""
        # logger.info("Initialization from local")
        self.master_addr = "127.0.0.1"
        self.master_port = _get_available_port()
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.local_world_size = 1

    def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
        """Write the settings to ``os.environ`` and return ``self``.

        Args:
            overwrite: If False, each variable is first checked with
                ``_check_env_variable``, which raises when the variable is
                already set to a different value.
        """
        # See the "Environment variable initialization" section from
        # https://pytorch.org/docs/stable/distributed.html for the complete list of
        # environment variables required for the env:// initialization method.
        exported = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if not overwrite:
            for name, value in exported.items():
                _check_env_variable(name, value)

        os.environ.update(exported)
        return self
237
-
238
-
239
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
    """Enable distributed mode

    Resolves and exports the torch.distributed environment, initializes the
    NCCL process group, waits on a barrier, records the local rank and size,
    and finally restricts ``print`` to the main process.

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrites already set variables. Else fails.
        allow_nccl_timeout: If True, set NCCL_ASYNC_ERROR_HANDLING=1 before the
            process group is initialized.

    Raises:
        RuntimeError: When distributed mode has already been enabled.
    """

    global _LOCAL_RANK, _LOCAL_WORLD_SIZE
    already_enabled = _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0
    if already_enabled:
        raise RuntimeError("Distributed mode has already been enabled")
    env = _TorchDistributedEnvironment()
    env.export(overwrite=overwrite)

    if set_cuda_current_device:
        torch.cuda.set_device(env.local_rank)

    if allow_nccl_timeout:
        # This allows to use torch distributed timeout in a NCCL backend
        nccl_key, nccl_value = "NCCL_ASYNC_ERROR_HANDLING", "1"
        if not overwrite:
            _check_env_variable(nccl_key, nccl_value)
        os.environ[nccl_key] = nccl_value

    dist.init_process_group(backend="nccl")
    dist.barrier()

    # Finalize setup
    _LOCAL_RANK = env.local_rank
    _LOCAL_WORLD_SIZE = env.local_world_size
    _restrict_print_to_main_process()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
 
 
 
 
 
dinov2/eval/cell_dino/knn.py DELETED
@@ -1,479 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import argparse
7
- from functools import partial
8
- import json
9
- import logging
10
- import os
11
- import sys
12
- from typing import List, Optional, Any
13
- import numpy as np
14
-
15
- import torch
16
- import torch.backends.cudnn as cudnn
17
- import pandas as pd
18
- from sklearn.metrics import f1_score
19
-
20
- import dinov2.distributed as distributed
21
- from dinov2.data import make_dataset, DatasetWithEnumeratedTargets, SamplerType, make_data_loader
22
- from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
23
- from dinov2.eval.metrics import build_metric, MetricType
24
- from dinov2.eval.setup import get_args_parser as get_setup_args_parser
25
- from dinov2.eval.setup import setup_and_build_model
26
-
27
- from dinov2.data import ResultsAccumulator
28
- from dinov2.eval.utils import ModelWithNormalize
29
- from dinov2.eval.cell_dino.utils import (
30
- BagOfChannelsModelWithNormalize,
31
- extract_features_cell_dino,
32
- average_metrics,
33
- create_train_dataset_dict,
34
- get_num_classes,
35
- extract_features_for_dataset_dict,
36
- evaluate_with_accumulate,
37
- KnnModule,
38
- )
39
- from dinov2.eval.knn import DictKeysModule
40
- from torch.utils.data import Subset as SubsetEx
41
- from torch.utils.data import ConcatDataset as ConcatDatasetEx
42
-
43
-
44
- logger = logging.getLogger("dinov2")
45
-
46
-
47
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the cell-DINO k-NN evaluation script.

    The parser inherits every option of the shared evaluation setup parser and
    adds the k-NN specific ones.

    Args:
        description: Text displayed at the top of the ``--help`` output.
        parents: Parsers forwarded to the setup parser so their options are inherited.
        add_help: Whether argparse should register ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # The two literals are concatenated by the parser; the leading space keeps
        # "slower" and "but" from running together in the rendered help.
        help="Whether to gather the train features on cpu, slower"
        " but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.add_argument(
        "--leave-one-out-dataset",
        type=str,
        help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
    )
    parser.add_argument(
        "--bag-of-channels",
        action="store_true",
        help='Whether to use the "bag of channels" channel adaptive strategy',
    )
    parser.add_argument(
        "--crop-size",
        type=int,
        help="crop size for train and eval",
    )
    parser.add_argument(
        "--resize-size",
        type=int,
        help="resize size for image just before crop. 0: no resize",
    )
    parser.add_argument(
        "--metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Validation metric",
    )
    parser.add_argument(
        "--avgpool",
        action="store_true",
        help="Whether to use average pooling of patch tokens in addition to CLS tokens",
    )

    # Options without an entry here (e.g. --crop-size, --metric-type) default to None.
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[1],
        temperature=0.07,
        batch_size=256,
        resize_size=0,
    )
    return parser
146
-
147
-
148
class SequentialWithKwargs(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant whose first stage receives the call's keyword arguments."""

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, input, **kwargs):
        """Run the stages in order; only the first one is given ``**kwargs``."""
        result = self[0](input, **kwargs)
        # Remaining stages are plain single-argument calls.
        for position in range(1, len(self)):
            result = self[position](result)
        return result
158
-
159
-
160
def create_train_test_dataset_dict_leave_one_out(
    train_dataset,
    test_dataset,
) -> dict[int, Any]:
    """Build one leave-one-out (LOO) training dataset per test sample.

    For every index ``i`` of ``test_dataset`` the result holds the concatenation of
    ``train_dataset`` with all test samples except sample ``i``.

    Args:
        train_dataset: Map-style dataset always included in full.
        test_dataset: Map-style dataset from which one sample is held out at a time.

    Returns:
        A dict with ``len(test_dataset)`` entries, in the format
        ``{test_sample_index: dataset_without_that_test_sample}``.
    """
    test_size = len(test_dataset)
    train_dataset_dict: dict[int, Any] = {}

    for held_out_index in range(test_size):
        # Plain Python ints: Subset documents its indices as a sequence of ints,
        # so no tensor elements are handed to the wrapped dataset's __getitem__.
        kept_indices = [index for index in range(test_size) if index != held_out_index]
        train_dataset_dict[held_out_index] = ConcatDatasetEx(
            [train_dataset, SubsetEx(test_dataset, kept_indices)]
        )

    return train_dataset_dict
183
-
184
-
185
def eval_knn_with_leave_one_out(
    model, leave_one_out_dataset, train_dataset, test_dataset, metric_type, nb_knn, temperature, batch_size, num_workers
):
    """Run k-NN evaluation where each metadata group of the test set is held out in turn.

    Features are extracted once for the train and test datasets. For every group
    listed in the CSV at ``leave_one_out_dataset`` the group's test samples are
    classified against the train features plus all remaining test features.
    The value reported for every metric name is the macro F1 score (in percent)
    computed over all held-out predictions.

    Returns:
        ``{k: {metric_name: f1_percent}}`` for every ``k`` in ``nb_knn``.
    """
    num_classes = get_num_classes(test_dataset)
    train_dataset_dict = create_train_dataset_dict(train_dataset)
    test_dataset_dict = create_train_dataset_dict(test_dataset)

    logger.info("Extracting features for train set...")
    train_data_dict = extract_features_for_dataset_dict(
        model, train_dataset_dict, batch_size, num_workers, gather_on_cpu=True
    )
    test_data_dict = extract_features_for_dataset_dict(
        model, test_dataset_dict, batch_size, num_workers, gather_on_cpu=True
    )

    # Only entry 0 of each extraction result is used; the test entries reuse the "train_*" keys.
    train_entry, test_entry = train_data_dict[0], test_data_dict[0]
    train_features, train_labels = train_entry["train_features"], train_entry["train_labels"]
    test_features, test_labels = test_entry["train_features"], test_entry["train_labels"]

    # NOTE(review): num_classes=3 here, the column offset 4 below and labels=[4, 5, 6]
    # are hard-coded; presumably they match the held-out classes of the task -- confirm.
    metric_collection = build_metric(metric_type, num_classes=3)

    device = torch.cuda.current_device()

    logger.info("Reading the leave-one-out label metadata.")

    metadata = pd.read_csv(leave_one_out_dataset)
    if "HPA" in leave_one_out_dataset:
        task_column, leave_one_out_label_type = "Task_three", "cell_type"
    else:
        task_column, leave_one_out_label_type = "Task_four", "Plate"
    metadata = metadata[metadata[task_column]].reset_index()

    # Row positions (after the index reset) of every group to hold out.
    leave_one_out_indices = {
        group: torch.tensor(metadata[metadata[leave_one_out_label_type] == group].index.values)
        for group in metadata[leave_one_out_label_type].unique()
    }

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")

    postprocessors = {k: DictKeysModule([k]) for k in nb_knn}
    metrics = {k: metric_collection.clone().to(device) for k in nb_knn}
    metrics = {k: metric.to(device) for k, metric in metrics.items()}
    accumulators = {k: ResultsAccumulator() for k in postprocessors.keys()}

    all_preds = []
    all_target = []

    for loo_label, loo_indices in leave_one_out_indices.items():
        logger.info(f"Evaluating on test sample {loo_label}")

        keep_mask = torch.ones(test_features.shape[0], dtype=bool)
        keep_mask[loo_indices] = False
        held_out_features = test_features[loo_indices]
        held_out_labels = test_labels[loo_indices]

        train_features_sample = torch.cat([train_features, test_features[keep_mask]])
        train_labels_sample = torch.cat([train_labels, test_labels[keep_mask]])
        logger.info(f"Train shape {train_features_sample.shape}, Test shape {held_out_features.shape}")
        logger.info(
            f"Train values {train_labels_sample.unique(return_counts=True)}, Test shape {held_out_labels.unique(return_counts=True)}"
        )

        knn_module = KnnModule(
            train_features=train_features_sample,
            train_labels=train_labels_sample,
            nb_knn=nb_knn,
            T=temperature,
            device=device,
            num_classes=num_classes,
        )

        output = knn_module(held_out_features.to(device))
        # NOTE(review): key 1 is addressed directly, so this assumes 1 is in nb_knn -- confirm.
        all_preds.append(output[1])
        all_target.append(held_out_labels)
        output[1] = output[1][:, 4:]
        transformed_test_labels = held_out_labels - 4

        for k, metric in metrics.items():
            metric_inputs = postprocessors[k](output, transformed_test_labels.to(device))
            metric.update(**metric_inputs)
            accumulators[k].update(
                preds=metric_inputs["preds"], target=metric_inputs["target"], index=loo_indices.to(device)
            )

    all_preds = torch.cat(all_preds).cpu().detach().numpy()
    all_preds = np.argmax(all_preds, axis=1)
    all_target = torch.cat(all_target).cpu().detach().numpy()

    f1 = f1_score(all_target, all_preds, average="macro", labels=[4, 5, 6])
    logger.info(f"Real f1 score: {f1}")

    # The metric objects only supply the metric names; their values are replaced by the F1 above.
    eval_metrics = {k: metric.compute() for k, metric in metrics.items()}
    eval_metrics_dict = {k: {name: f1 * 100.0 for name in eval_metrics[k].keys()} for k in nb_knn}

    if len(train_data_dict) > 1:
        return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}

    return {k: eval_metrics_dict[k] for k in eval_metrics_dict.keys()}
288
-
289
-
290
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str,
    val_dataset_str,
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    metric_type=MetricType.MEAN_ACCURACY,
    transform=None,
    resize_size=256,
    crop_size=224,
    batch_size=256,
    num_workers=5,
    leave_one_out_dataset="",
    bag_of_channels=False,
    avgpool=False,
):
    """Evaluate ``model`` with k-NN and append the results to a JSON-lines file.

    Args:
        model: Backbone to evaluate; it is wrapped so that its features are normalized.
        output_dir: Directory receiving ``results_eval_knn.json`` (opened in append mode).
        train_dataset_str: Dataset string for the k-NN reference set.
        val_dataset_str: Dataset string for the evaluated set.
        nb_knn: Values of k to evaluate.
        temperature: Temperature passed to the k-NN module.
        autocast_dtype: dtype used for CUDA autocast.
        metric_type: Metric to build for the evaluation.
        transform: Accepted for interface compatibility. NOTE: the value is currently
            ignored; the cell evaluation transform is always built from
            ``resize_size`` and ``crop_size``.
        resize_size: Resize applied before the crop.
        crop_size: Crop size for both datasets.
        batch_size: Batch size for feature extraction and evaluation.
        num_workers: Data loader worker count.
        leave_one_out_dataset: Path to the leave-one-out metadata; ``""`` or ``None``
            selects the standard k-NN evaluation.
        bag_of_channels: Use the bag-of-channels model wrapper.
        avgpool: Forwarded to the bag-of-channels wrapper.

    Returns:
        Dict mapping ``"<val_dataset_str>_<k> Top 1"`` (and ``"... Top 5"`` when
        available) to the metric value.
    """
    autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
    if bag_of_channels:
        model = BagOfChannelsModelWithNormalize(model, autocast_ctx, avgpool)
    else:
        model = ModelWithNormalize(model)

    leave_one_out = not (leave_one_out_dataset == "" or leave_one_out_dataset is None)

    cudnn.benchmark = True
    transform = make_classification_eval_cell_transform(
        normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, resize_size=resize_size, crop_size=crop_size
    )

    train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform)
    results_dict = {}
    test_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform)

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        if leave_one_out:
            results_dict_knn = eval_knn_with_leave_one_out(
                model=model,
                leave_one_out_dataset=leave_one_out_dataset,
                train_dataset=train_dataset,
                test_dataset=test_dataset,
                metric_type=metric_type,
                nb_knn=nb_knn,
                temperature=temperature,
                batch_size=batch_size,
                num_workers=num_workers,
            )
        else:
            results_dict_knn = eval_knn(
                model=model,
                train_dataset=train_dataset,
                test_dataset=test_dataset,
                metric_type=metric_type,
                nb_knn=nb_knn,
                temperature=temperature,
                batch_size=batch_size,
                num_workers=num_workers,
            )

    for knn_, knn_results in results_dict_knn.items():
        top1 = knn_results["top-1"]
        results_dict[f"{val_dataset_str}_{knn_} Top 1"] = top1
        results_string = f"{val_dataset_str} {knn_} NN classifier result: Top1: {top1:.2f}"
        if "top-5" in knn_results:
            top5 = knn_results["top-5"]
            results_dict[f"{val_dataset_str}_{knn_} Top 5"] = top5
            # Leading space keeps the two figures apart in the log line.
            results_string += f" Top5: {top5:.2f}"
        logger.info(results_string)

    # One JSON object per line, appended so repeated runs accumulate in the same file.
    metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
    with open(metrics_file_path, "a") as f:
        for k, v in results_dict.items():
            f.write(json.dumps({k: v}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
370
-
371
-
372
def eval_knn(
    model,
    train_dataset,
    test_dataset,
    metric_type,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    few_shot_eval=False,
    few_shot_k_or_percent=None,
    few_shot_n_tries=1,
):
    """Run the k-NN evaluation of ``model`` on ``test_dataset``.

    Train features are extracted once per entry of the train dataset dict (one
    entry per few-shot try); the test set is then evaluated against each of them.

    Returns:
        ``{k: metrics}`` with metric values in percent. When several tries were
        evaluated the per-try metrics are combined with ``average_metrics``;
        otherwise the metrics stored under try key ``0`` are returned.
    """
    num_classes = get_num_classes(train_dataset)
    train_dataset_dict = create_train_dataset_dict(
        train_dataset,
        few_shot_eval=few_shot_eval,
        few_shot_k_or_percent=few_shot_k_or_percent,
        few_shot_n_tries=few_shot_n_tries,
    )

    logger.info("Extracting features for train set...")

    train_data_dict: dict[int, dict[str, torch.Tensor]] = {}
    for try_index, try_dataset in train_dataset_dict.items():
        try_features, try_labels = extract_features_cell_dino(
            model, try_dataset, batch_size, num_workers, gather_on_cpu=True
        )
        train_data_dict[try_index] = {"train_features": try_features, "train_labels": try_labels}

    enumerated_test_dataset = DatasetWithEnumeratedTargets(
        test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
    )
    test_data_loader = make_data_loader(
        dataset=enumerated_test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
        collate_fn=None,
    )
    metric_collection = build_metric(metric_type, num_classes=num_classes)

    device = torch.cuda.current_device()

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    eval_metrics_dict = {}

    for try_, try_data in train_data_dict.items():
        train_features = try_data["train_features"]
        train_labels = try_data["train_labels"]

        # k cannot exceed the number of reference samples; duplicates created by the cap are merged.
        n_train = len(train_features)
        k_list = sorted({min(k, n_train) for k in nb_knn})

        knn_module = KnnModule(
            train_features=train_features,
            train_labels=train_labels,
            nb_knn=k_list,
            T=temperature,
            device=device,
            num_classes=num_classes,
        )
        postprocessors = {k: DictKeysModule([k]) for k in k_list}
        metrics = {k: metric_collection.clone() for k in k_list}

        _, eval_metrics, _ = evaluate_with_accumulate(
            SequentialWithKwargs(model, knn_module),
            test_data_loader,
            postprocessors,
            metrics,
            device,
            accumulate_results=False,
        )
        for k in k_list:
            eval_metrics_dict.setdefault(k, {})[try_] = {
                name: value.item() * 100.0 for name, value in eval_metrics[k].items()
            }

    if len(train_data_dict) > 1:
        return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()}

    return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()}
450
-
451
-
452
def main(args):
    """Build the model from the parsed arguments, run the k-NN evaluation and return 0."""
    model, autocast_dtype = setup_and_build_model(args)

    # num_workers is fixed to 5 and transform to None; everything else comes from the CLI.
    eval_kwargs = dict(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        transform=None,
        metric_type=args.metric_type,
        batch_size=args.batch_size,
        num_workers=5,
        leave_one_out_dataset=args.leave_one_out_dataset,
        resize_size=args.resize_size,
        crop_size=args.crop_size,
        avgpool=args.avgpool,
        bag_of_channels=args.bag_of_channels,
    )
    eval_knn_with_model(**eval_kwargs)
    return 0
473
-
474
-
475
if __name__ == "__main__":
    # Script entry point: parse the command line and use main()'s return value as exit status.
    description = "k-NN evaluation on models trained with bag of channel strategy or cell dino"
    args = get_args_parser(description=description).parse_args()
    sys.exit(main(args))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/cell_dino/linear.py DELETED
@@ -1,1048 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import argparse
7
- from functools import partial
8
- import json
9
- import logging
10
- import os
11
- import sys
12
- from typing import Any, Callable, Dict, Optional, Tuple, List
13
- from enum import Enum
14
- from dataclasses import dataclass
15
-
16
- from sklearn.metrics import f1_score
17
- import numpy as np
18
- import pandas as pd
19
- import torch
20
- import torch.nn as nn
21
- from torch.utils.data import TensorDataset
22
- from torch.nn.parallel import DistributedDataParallel
23
-
24
-
25
- from dinov2.data import SamplerType, make_data_loader, make_dataset, DatasetWithEnumeratedTargets
26
- from dinov2.data.cell_dino.transforms import NormalizationType, make_classification_eval_cell_transform
27
- import dinov2.distributed as distributed
28
- from dinov2.eval.metrics import MetricType, build_metric
29
- from dinov2.eval.setup import get_args_parser as get_setup_args_parser
30
- from dinov2.eval.setup import setup_and_build_model
31
- from dinov2.eval.cell_dino.utils import (
32
- evaluate_with_accumulate,
33
- LossType,
34
- average_metrics,
35
- create_train_dataset_dict,
36
- get_num_classes,
37
- extract_features_for_dataset_dict,
38
- )
39
- from dinov2.eval.utils import ModelWithIntermediateLayers
40
- from dinov2.logging import MetricLogger
41
- from dinov2.utils.checkpoint import build_periodic_checkpointer, resume_or_load
42
-
43
- logger = logging.getLogger("dinov2")
44
-
45
- """
46
- List of changes with respect to the standard linear evaluation script:
47
-
48
- bag of channel option : SCALE ADAPTIVE STRATEGY
49
-
50
- Adam optimizer instead of SGD
51
- Scheduler : two options : onecycleLR or CosineAnnealingLR
52
- the transforms/normalization are different, now calling make_classification_eval_cell_transform
53
- add binary cross entropy loss option for protein localization
54
- change the definition of the num_classes using get_num_classes
55
- change of some default parameters (batch_size, epoch_length, epochs, lrs)
56
- defined n_last_blocks option
57
- avgpool option
58
- leave one out strategy for CHAMMI evaluation
59
- grid search for optimal weight decay
60
- """
61
-
62
-
63
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the cell-DINO linear evaluation script.

    The parser inherits every option of the shared evaluation setup parser and
    adds the linear-probing specific ones together with their defaults.

    Args:
        description: Text displayed at the top of the ``--help`` output.
        parents: Parsers forwarded to the setup parser so their options are inherited.
        add_help: Whether argparse should register ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--test-datasets",
        dest="test_dataset_strs",
        type=str,
        nargs="+",
        help="Test datasets, none to reuse the validation dataset",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch Size (per GPU)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of workers",
    )
    parser.add_argument(
        "--epoch-length",
        type=int,
        help="Length of an epoch in number of iterations",
    )
    parser.add_argument(
        "--save-checkpoint-frequency",
        type=int,
        help="Number of epochs between two named checkpoint saves.",
    )
    parser.add_argument(
        "--eval-period-iterations",
        type=int,
        help="Number of iterations between two evaluations.",
    )
    parser.add_argument(
        "--learning-rates",
        nargs="+",
        type=float,
        help="Learning rates to grid search.",
    )
    parser.add_argument(
        "--weight_decays",
        nargs="+",
        type=float,
        help="Weight decays to grid search.",
    )
    parser.add_argument(
        "--n-last-blocks",
        type=int,
        help="number of backbone last blocks used for the linear classifier",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to not resume from existing checkpoints",
    )
    parser.add_argument(
        "--val-metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Validation metric",
    )
    parser.add_argument(
        "--test-metric-types",
        type=MetricType,
        choices=list(MetricType),
        nargs="+",
        help="Evaluation metric",
    )
    parser.add_argument(
        "--classifier-fpath",
        type=str,
        help="Path to a file containing pretrained linear classifiers",
    )
    parser.add_argument(
        "--val-class-mapping-fpath",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.add_argument(
        "--test-class-mapping-fpaths",
        nargs="+",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.add_argument(
        "--loss-type",
        type=LossType,
        help="Cross Entropy or Binary Cross Entropy, default cross entropy loss",
    )
    parser.add_argument(
        "--bag-of-channels",
        action="store_true",
        help='Whether to use the "bag of channels" channel adaptive strategy',
    )
    parser.add_argument(
        "--leave-one-out-dataset",
        type=str,
        help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4",
    )
    parser.add_argument(
        "--crop-size",
        type=int,
        help="crop size for train and eval",
    )
    parser.add_argument(
        "--resize-size",
        type=int,
        help="resize size for image just before crop. 0: no resize",
    )
    parser.add_argument(
        "--avgpool",
        action="store_true",
        help="Whether to use average pooling of patch tokens in addition to CLS tokens",
    )
    parser.add_argument(
        "--scheduler",
        type=SchedulerType,
        help="Scheduler type",
    )

    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        test_dataset_strs=None,
        epochs=30,
        batch_size=64,
        num_workers=8,
        epoch_length=145,
        save_checkpoint_frequency=1250,
        eval_period_iterations=1250,
        learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, 1.0],
        weight_decays=[0.0, 0.0001, 1.0e-05],
        val_metric_type=MetricType.MEAN_ACCURACY,
        test_metric_types=None,
        classifier_fpath=None,
        val_class_mapping_fpath=None,
        test_class_mapping_fpaths=[None],
        loss_type=LossType.CROSS_ENTROPY,
        crop_size=384,
        resize_size=0,
        n_last_blocks=4,
        avgpool=False,
        scheduler=SchedulerType.COSINE_ANNEALING,
    )
    return parser
237
-
238
-
239
def has_ddp_wrapper(m: nn.Module) -> bool:
    """Tell whether ``m`` is a ``DistributedDataParallel`` instance."""
    is_wrapped = isinstance(m, DistributedDataParallel)
    return is_wrapped
241
-
242
-
243
def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
    """Return the module wrapped by ``DistributedDataParallel``, or ``m`` itself when it is not wrapped."""
    if isinstance(m, DistributedDataParallel):
        return m.module
    return m
245
-
246
-
247
def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool, bag_of_channels):
    """Assemble the float feature tensor fed to a linear classifier.

    Args:
        x_tokens_list: Sequence of ``(patch_tokens, class_token)`` pairs, one per block.
        use_n_blocks: Number of trailing blocks whose class tokens are concatenated.
        use_avgpool: Whether to append the mean of the last block's patch tokens.
        bag_of_channels: Selects how the pooled patch tokens are flattened.

    Returns:
        The concatenated features cast to float.
    """
    last_blocks = x_tokens_list[-use_n_blocks:]
    features = torch.cat([cls_token for _, cls_token in last_blocks], dim=-1)

    if not use_avgpool:
        return features.float()

    patch_tokens = last_blocks[-1][0]
    if bag_of_channels:
        # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
        pooled = torch.mean(patch_tokens, dim=-2).reshape(patch_tokens.shape[0], -1)
        # concatenate average pooling of patch tokens to concatenated class tokens
        features = torch.cat((features, pooled), dim=-1)
    else:
        pooled = torch.mean(patch_tokens, dim=1)  # patch tokens
        features = torch.cat((features, pooled), dim=-1)
        features = features.reshape(features.shape[0], -1)
    return features.float()
271
-
272
-
273
class LinearClassifier(nn.Module):
    """Single linear layer trained on top of frozen backbone features."""

    def __init__(
        self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, bag_of_channels=False, leave_one_out=False
    ):
        super().__init__()
        # Configuration is kept on the instance; forward() reads it on every call.
        self.out_dim, self.num_classes = out_dim, num_classes
        self.use_n_blocks, self.use_avgpool = use_n_blocks, use_avgpool
        self.bag_of_channels = bag_of_channels
        self.leave_one_out = leave_one_out

        head = nn.Linear(out_dim, num_classes)
        head.weight.data.normal_(mean=0.0, std=0.01)
        head.bias.data.zero_()
        self.linear = head

    def forward(self, x_tokens_list):
        """Classify either ready-made features (leave-one-out mode) or backbone token pairs."""
        if not self.leave_one_out:
            features = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool, self.bag_of_channels)
            return self.linear(features)
        # In leave-one-out mode the input is passed to the linear layer as is.
        return self.linear(x_tokens_list)
295
-
296
-
297
class AllClassifiers(nn.Module):
    """Container that applies every registered classifier to the same input."""

    def __init__(self, classifiers_dict):
        super().__init__()
        registry = nn.ModuleDict()
        registry.update(classifiers_dict)
        self.classifiers_dict = registry

    def forward(self, inputs):
        """Return ``{name: classifier output}`` for all classifiers."""
        outputs = {}
        for name, classifier in self.classifiers_dict.items():
            # forward() is invoked directly, exactly as the container has always done.
            outputs[name] = classifier.forward(inputs)
        return outputs

    def __len__(self):
        return len(self.classifiers_dict)
308
-
309
-
310
class LinearPostprocessor(nn.Module):
    """Runs a linear classifier and packages its predictions with the targets."""

    def __init__(self, linear_classifier, class_mapping=None):
        super().__init__()
        self.linear_classifier = linear_classifier
        mapping = torch.LongTensor(class_mapping) if class_mapping is not None else None
        self.register_buffer("class_mapping", mapping)

    def forward(self, samples, targets):
        """Return ``{"preds": ..., "target": targets}``; prediction columns are selected by ``class_mapping`` when set."""
        preds = self.linear_classifier(samples)
        if self.class_mapping is not None:
            preds = preds[:, self.class_mapping]
        return {"preds": preds, "target": targets}
322
-
323
-
324
def scale_lr(learning_rates, batch_size):
    """Scale a learning rate linearly with the global batch size, relative to a batch of 256."""
    global_batch_size = batch_size * distributed.get_global_size()
    return learning_rates * global_batch_size / 256.0
326
-
327
-
328
def setup_linear_classifiers(
    sample_output,
    n_last_blocks_list,
    learning_rates,
    weight_decays,
    batch_size,
    num_classes=1000,
    bag_of_channels=False,
    leave_one_out=False,
    avgpool=False,
):
    """Create one linear classifier per (n_last_blocks, learning rate, weight decay) combination.

    Args:
        sample_output: Example backbone output used to infer the classifier input width.
        n_last_blocks_list: Values of "number of last blocks" to build classifiers for.
        learning_rates: Base learning rates; each is scaled with ``scale_lr``.
        weight_decays: Weight decay values of the grid.
        batch_size: Per-process batch size used for learning-rate scaling.
        num_classes: Output width of every classifier.
        bag_of_channels: Forwarded to the feature construction and the classifiers.
        leave_one_out: Forwarded to the classifiers.
        avgpool: Whether classifiers also consume average-pooled patch tokens.

    Returns:
        ``(linear_classifiers, optim_param_groups)``: the classifier container
        (wrapped in DistributedDataParallel when distributed mode is enabled) and
        one optimizer parameter group per classifier.
    """
    classifiers = nn.ModuleDict()
    optim_param_groups = []

    for n_blocks in n_last_blocks_list:
        for base_lr in learning_rates:
            for wd in weight_decays:
                lr = scale_lr(base_lr, batch_size)
                feature_width = create_linear_input(
                    sample_output, use_n_blocks=n_blocks, use_avgpool=avgpool, bag_of_channels=bag_of_channels
                ).shape[1]
                classifier = LinearClassifier(
                    feature_width,
                    use_n_blocks=n_blocks,
                    use_avgpool=avgpool,
                    num_classes=num_classes,
                    bag_of_channels=bag_of_channels,
                    leave_one_out=leave_one_out,
                ).cuda()

                # NOTE(review): lr is rendered with 5 decimals, so two very small scaled
                # learning rates could map to the same name -- confirm this cannot occur.
                name = f"classifier_{n_blocks}_blocks_avgpool_{avgpool}_lr_{lr:.5f}_wd_{wd:.2E}".replace(".", "_")
                classifiers[name] = classifier
                optim_param_groups.append({"params": classifier.parameters(), "lr": lr, "weight_decay": wd})

    linear_classifiers = AllClassifiers(classifiers)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups
369
-
370
-
371
def make_eval_data_loader(
    *,
    test_dataset_str_or_path_or_loo_dataset,
    config,
    batch_size,
    num_workers,
):
    """Create the evaluation data loader and the optional class mapping.

    Args:
        test_dataset_str_or_path_or_loo_dataset: Either a dataset string, which is
            loaded with the cell evaluation transform, or an already-built dataset
            object, which is used as is.
        config: Mapping providing ``"resize_size"`` and ``"crop_size"`` for the transform.
        batch_size: Loader batch size.
        num_workers: Loader worker count.

    Returns:
        ``(data_loader, class_mapping)``; ``class_mapping`` is ``None`` unless the
        dataset exposes ``get_imagenet_class_mapping``.
    """
    source = test_dataset_str_or_path_or_loo_dataset

    if not isinstance(source, str):
        logger.info("Making data loader for feature dataset (typical in leave one out evaluation)")
        test_dataset = source
    else:
        logger.info(f"Loading dataset {source}")
        transform = make_classification_eval_cell_transform(
            normalization_type=NormalizationType.SELF_NORM_CENTER_CROP,
            resize_size=config["resize_size"],
            crop_size=config["crop_size"],
        )
        test_dataset = make_dataset(dataset_str=source, transform=transform)

    if hasattr(test_dataset, "get_imagenet_class_mapping"):
        class_mapping = test_dataset.get_imagenet_class_mapping()
    else:
        class_mapping = None

    enumerated_dataset = DatasetWithEnumeratedTargets(
        test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size()
    )
    test_data_loader = make_data_loader(
        dataset=enumerated_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=False,
        collate_fn=None,
    )
    return test_data_loader, class_mapping
408
-
409
-
410
@dataclass
class Evaluator:
    """Score a set of linear classifiers on one evaluation dataset.

    The data loader is built once in ``__post_init__``. Each evaluation call
    runs every classifier over it, selects the best one and appends the
    outcome to ``metrics_file_path``.
    """

    batch_size: int
    num_workers: int
    # Dataset string/path. ``__post_init__`` overwrites this attribute with
    # ``val_dataset_loo`` when that dataset is supplied.
    dataset_str_or_path: str
    # Evaluation settings: ``config["leave_one_out"]`` is read here and the
    # whole mapping is forwarded to ``make_eval_data_loader``.
    config: Dict
    metric_type: MetricType
    # Text file that receives the results (opened in append mode).
    metrics_file_path: str
    # Number of classes used when the dataset exposes no class mapping.
    training_num_classes: int
    # Optional callback receiving the accumulated results of the best classifier.
    save_results_func: Optional[Callable]
    # Pre-built dataset evaluated instead of ``dataset_str_or_path``.
    val_dataset_loo: Optional[TensorDataset] = None

    def __post_init__(self):
        """Derive the main metric name and build the evaluation data loader."""
        # The metric name must be computed before the attribute is replaced
        # by the leave-one-out dataset object below.
        self.main_metric_name = f"{self.dataset_str_or_path}_accuracy"

        if self.val_dataset_loo is not None:
            self.dataset_str_or_path = self.val_dataset_loo

        self.data_loader, self.class_mapping = make_eval_data_loader(
            test_dataset_str_or_path_or_loo_dataset=self.dataset_str_or_path,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            config=self.config,
        )

    @staticmethod
    def _select_best_classifier(results, best_classifier_on_val=None):
        """Return ``(name, accuracy)`` of the classifier to report.

        Args:
            results: mapping of classifier name to a metric mapping whose
                ``"top-1"`` entry exposes ``.item()``.
            best_classifier_on_val: when given, that classifier is reported
                regardless of its score; otherwise the highest ``top-1`` wins
                (first one on ties).

        Returns:
            ``("", 0)`` when ``results`` is empty or the requested classifier
            is absent.
        """
        best_name, best_accuracy = "", 0
        for name, classifier_metric in results.items():
            accuracy = classifier_metric["top-1"].item()
            if best_classifier_on_val is None:
                is_better = accuracy > best_accuracy
            else:
                is_better = name == best_classifier_on_val
            if is_better:
                best_name, best_accuracy = name, accuracy

        # When every classifier scores 0 the strict comparison above selects
        # nothing; fall back to the first classifier so that later lookups
        # keyed by the best classifier's name do not fail on "".
        if best_classifier_on_val is None and best_name == "" and results:
            best_name = next(iter(results))
            best_accuracy = results[best_name]["top-1"].item()
        return best_name, best_accuracy

    @torch.no_grad()
    def _evaluate_linear_classifiers(
        self,
        *,
        feature_model,
        linear_classifiers,
        iteration,
        prefixstring="",
        best_classifier_on_val=None,
        accumulate_results=False,
        test_mode=False,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]:
        """Evaluate every classifier and record the best one.

        Returns:
            ``(results_dict, accumulated_best_results)``.
            ``results_dict["best_classifier"]`` holds the selected name and
            accuracy. The second item is everything returned by
            ``evaluate_with_accumulate`` in ``test_mode``, the entry of the
            best classifier otherwise, or ``None`` when nothing was
            accumulated.
        """
        logger.info("running validation !")

        num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes
        metric = build_metric(self.metric_type, num_classes=num_classes)
        postprocessors = {
            k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items()
        }
        # One independent metric object per classifier.
        metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}

        _, results_dict_temp, accumulated_results = evaluate_with_accumulate(
            feature_model,
            self.data_loader,
            postprocessors,
            metrics,
            torch.cuda.current_device(),
            accumulate_results=accumulate_results,
            leave_one_out=self.config["leave_one_out"],
            test_mode=test_mode,
        )

        logger.info("")
        for classifier_string, classifier_metric in results_dict_temp.items():
            logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {classifier_metric}")

        best_classifier, max_accuracy = self._select_best_classifier(results_dict_temp, best_classifier_on_val)
        results_dict = {"best_classifier": {"name": best_classifier, "accuracy": max_accuracy}}

        logger.info(f"best classifier: {results_dict['best_classifier']}")

        accumulated_best_results = None
        if test_mode:
            accumulated_best_results = accumulated_results
        elif accumulated_results is not None:
            accumulated_best_results = accumulated_results[best_classifier]

        # Only the main process writes, so the file is not appended to once
        # per distributed worker.
        if distributed.is_main_process():
            with open(self.metrics_file_path, "a") as f:
                f.write(f"iter: {iteration}\n")
                for k, v in results_dict.items():
                    f.write(json.dumps({k: v}) + "\n")
                f.write("\n")

        return results_dict, accumulated_best_results

    def evaluate_and_maybe_save(
        self,
        feature_model,
        linear_classifiers,
        iteration: int,
        best_classifier_on_val: Optional[Any] = None,
        save_filename_suffix: str = "",
        prefixstring: str = "",
        test_mode: bool = False,
    ):
        """Evaluate the classifiers and optionally hand the results to ``save_results_func``.

        Returns:
            ``(results_dict, accumulated_best_results)`` where ``results_dict``
            maps ``self.main_metric_name`` to the best accuracy multiplied by
            100 and ``"best_classifier"`` to the selected classifier's name.
        """
        logger.info(f"Testing on {self.dataset_str_or_path}")
        save_results = self.save_results_func is not None
        full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers(
            feature_model=feature_model,
            linear_classifiers=remove_ddp_wrapper(linear_classifiers),
            iteration=iteration,
            prefixstring=prefixstring,
            best_classifier_on_val=best_classifier_on_val,
            # Results are only accumulated when there is a callback to consume them.
            accumulate_results=save_results,
            test_mode=test_mode,
        )
        if save_results:
            self.save_results_func(
                filename_suffix=f"{self.dataset_str_or_path}{save_filename_suffix}", **accumulated_best_results
            )

        results_dict = {
            self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"],
            "best_classifier": full_results_dict["best_classifier"]["name"],
        }
        return results_dict, accumulated_best_results
529
-
530
-
531
def make_evaluators(
    config: Dict,
    val_metric_type: MetricType,
    val_dataset: str,
    metric_type: MetricType,
    metrics_file_path: str,
    training_num_classes: int,
    save_results_func: Optional[Callable],
    val_dataset_loo: Optional[TensorDataset] = None,
):
    """Build the validation evaluator and one evaluator per test dataset.

    Args:
        config: settings mapping; ``test_datasets``, ``test_metric_types``,
            ``val_dataset``, ``batch_size`` and ``num_workers`` are read here
            and the mapping itself is handed to every ``Evaluator``.
        val_metric_type: metric of the validation evaluator, and the default
            metric for every test dataset.
        val_dataset: dataset string of the validation evaluator.
        metric_type: not used by this function (each evaluator's metric comes
            from ``val_metric_type`` / ``config["test_metric_types"]``); kept
            so existing callers keep working.
        metrics_file_path: results file shared by all evaluators.
        training_num_classes: forwarded to every ``Evaluator``.
        save_results_func: forwarded to every ``Evaluator``.
        val_dataset_loo: forwarded to every ``Evaluator``.

    Returns:
        ``(val_evaluator, test_evaluators)`` where ``test_evaluators`` is a list.

    Raises:
        AssertionError: if explicit test metric types do not match the number
            of test datasets.
    """
    test_dataset_strs = config["test_datasets"]
    if test_dataset_strs is None:
        test_dataset_strs = (config["val_dataset"],)
    test_dataset_strs = tuple(test_dataset_strs)

    test_metric_types = config["test_metric_types"]
    if test_metric_types is None:
        # One metric per test dataset: a single-element default would make the
        # zip() below silently drop every test dataset after the first.
        test_metric_types = (val_metric_type,) * len(test_dataset_strs)
    else:
        test_metric_types = tuple(test_metric_types)
        # Compare against the resolved dataset tuple; config["test_datasets"]
        # may be None, which has no len().
        assert len(test_metric_types) == len(test_dataset_strs)

    val_evaluator, *test_evaluators = [
        Evaluator(
            dataset_str_or_path=dataset_str_or_path,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            config=config,
            metric_type=dataset_metric_type,
            metrics_file_path=metrics_file_path,
            training_num_classes=training_num_classes,
            save_results_func=save_results_func,
            val_dataset_loo=val_dataset_loo,
        )
        for dataset_str_or_path, dataset_metric_type in zip(
            (val_dataset,) + test_dataset_strs,
            (val_metric_type,) + test_metric_types,
        )
    ]
    return val_evaluator, test_evaluators
568
-
569
-
570
class SchedulerType(Enum):
    """Learning-rate schedules selectable for linear-classifier training."""

    COSINE_ANNEALING = "cosine_annealing"
    ONE_CYCLE = "one_cycle"

    def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter):
        """Create the torch LR scheduler matching this member.

        ``ONE_CYCLE`` builds a ``OneCycleLR`` whose per-group maximum learning
        rates are the ``"lr"`` values of ``optim_param_groups``. Every other
        member builds a ``CosineAnnealingLR`` over ``max_iter`` steps.
        """
        if self != SchedulerType.ONE_CYCLE:
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
            print("CosineAnnealingLR scheduler")
            return cosine

        peak_lrs = [group["lr"] for group in optim_param_groups]
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=peak_lrs,
            steps_per_epoch=epoch_length,
            epochs=epochs,
        )
584
-
585
-
586
def setup_linear_training(
    *,
    config: Dict,
    sample_output: torch.Tensor,
    training_num_classes: int,
    checkpoint_output_dir: str,
):
    """Create everything the linear-probe training loop needs.

    Builds the classifiers and their parameter groups, an AdamW optimizer,
    the LR scheduler chosen by ``config["scheduler"]``, a periodic
    checkpointer (restored through ``resume_or_load``) and the loss.

    Returns:
        ``(linear_classifiers, start_iter, max_iter, criterion, optimizer,
        scheduler, periodic_checkpointer, best_accuracy)``.
    """
    classifiers, param_groups = setup_linear_classifiers(
        sample_output,
        config["n_last_blocks_list"],
        config["learning_rates"],
        config["weight_decays"],
        config["batch_size"],
        training_num_classes,
        config["bag_of_channels"],
        config["leave_one_out"],
        config["avgpool"],
    )
    total_iterations = config["epochs"] * config["epoch_length"]
    adamw = torch.optim.AdamW(param_groups, weight_decay=0)

    lr_scheduler = config["scheduler"].get_scheduler(
        optimizer=adamw,
        optim_param_groups=param_groups,
        epoch_length=config["epoch_length"],
        epochs=config["epochs"],
        max_iter=total_iterations,
    )

    # Checkpoint once per epoch unless an explicit (truthy) period is configured.
    checkpointer = build_periodic_checkpointer(
        classifiers,
        checkpoint_output_dir,
        optimizer=adamw,
        scheduler=lr_scheduler,
        period=config["save_checkpoint_iterations"] or config["epoch_length"],
        max_iter=total_iterations,
        max_to_keep=None,
    )
    restored = resume_or_load(checkpointer, config["classifier_fpath"] or "", resume=config["resume"])

    # With nothing restored, training starts at iteration 0 and the best
    # accuracy starts at -1.
    first_iteration = restored.get("iteration", -1) + 1
    best_so_far = restored.get("best_accuracy", -1)

    use_bce = config["loss_type"] == LossType.BINARY_CROSS_ENTROPY
    loss_fn = nn.BCEWithLogitsLoss() if use_bce else nn.CrossEntropyLoss()

    return (
        classifiers,
        first_iteration,
        total_iterations,
        loss_fn,
        adamw,
        lr_scheduler,
        checkpointer,
        best_so_far,
    )
644
-
645
-
646
def train_linear_classifiers(
    *,
    feature_model,
    train_dataset,
    train_config: Dict,
    training_num_classes: int,
    val_evaluator: Evaluator,
    checkpoint_output_dir: str,
    sample_output: Optional[torch.Tensor] = None,
):
    """Train the linear classifiers with periodic validation and checkpointing.

    When ``train_config["leave_one_out"]`` is set, batches are fed to the
    classifiers as they are and ``sample_output`` must be provided. Otherwise
    batches go through ``feature_model`` first and ``sample_output`` is
    computed from the first training sample.

    Returns:
        ``(feature_model, linear_classifiers, iteration, periodic_checkpointer)``
        where ``iteration`` is the counter after the last processed step.
    """
    feed_batches_directly = train_config["leave_one_out"]
    if not feed_batches_directly:
        first_sample = train_dataset[0][0]
        sample_output = feature_model(first_sample.unsqueeze(0).cuda())
    else:
        assert sample_output is not None, "sample_output should be passed as argument when using leave_one_out."

    (
        linear_classifiers,
        start_iter,
        max_iter,
        criterion,
        optimizer,
        scheduler,
        periodic_checkpointer,
        best_accuracy,
    ) = setup_linear_training(
        config=train_config,
        sample_output=sample_output,
        training_num_classes=training_num_classes,
        checkpoint_output_dir=checkpoint_output_dir,
    )

    train_data_loader = make_data_loader(
        dataset=train_dataset,
        batch_size=train_config["batch_size"],
        num_workers=train_config["num_workers"],
        shuffle=True,
        seed=0,
        sampler_type=SamplerType.INFINITE,
        # Start the sampler at the restored iteration.
        sampler_advance=start_iter,
        drop_last=True,
        persistent_workers=True,
    )
    # Validate once per epoch unless an explicit (truthy) period is configured.
    eval_period = train_config["eval_period_iterations"] or train_config["epoch_length"]
    logger.info(f"Starting training from iteration {start_iter}")
    metric_logger = MetricLogger(delimiter=" ")

    iteration = start_iter
    for data, labels in metric_logger.log_every(train_data_loader, 10, "Training", max_iter, start_iter):
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        in_classifier = data if feed_batches_directly else feature_model(data)
        outputs = linear_classifiers(in_classifier)

        # Labels with more than one dimension are cast to float before the loss.
        if labels.ndim > 1:
            labels = labels.float()
        losses = {f"loss_{name}": criterion(logits, labels) for name, logits in outputs.items()}
        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        scheduler.step()

        if iteration % 10 == 0:
            torch.cuda.synchronize()
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        periodic_checkpointer.step(iteration=iteration, best_accuracy=best_accuracy)

        # The final iteration is not validated inside the loop.
        is_final_iteration = iteration == max_iter - 1
        if eval_period > 0 and (iteration + 1) % eval_period == 0 and not is_final_iteration:
            val_results_dict, _ = val_evaluator.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                prefixstring=f"ITER: {iteration}",
                iteration=iteration,
            )
            val_accuracy = val_results_dict[val_evaluator.main_metric_name]
            if val_accuracy >= best_accuracy:
                best_accuracy = val_accuracy
                periodic_checkpointer.save_best(iteration=iteration, best_accuracy=best_accuracy)
            torch.distributed.barrier()

        iteration += 1

    return feature_model, linear_classifiers, iteration, periodic_checkpointer
747
-
748
-
749
def eval_linear_with_model(
    model,
    output_dir,
    train_dataset_str,
    val_dataset_str,
    batch_size,
    epochs,
    epoch_length,
    num_workers,
    save_checkpoint_frequency,
    eval_period_iterations,
    learning_rates,
    weight_decays,
    autocast_dtype,
    test_dataset_strs=None,
    resume=True,
    classifier_fpath=None,
    val_metric_type=MetricType.MEAN_ACCURACY,
    test_metric_types=None,
    loss_type=LossType.CROSS_ENTROPY,
    bag_of_channels=False,
    leave_one_out_dataset="",
    resize_size=0,
    crop_size=384,
    n_last_blocks=4,
    avgpool=False,
    scheduler=SchedulerType.COSINE_ANNEALING,
):
    """Train and evaluate linear classifiers on top of ``model``.

    Two modes are supported:

    * Standard (``leave_one_out_dataset`` empty or ``None``): classifiers are
      trained for every entry of ``create_train_dataset_dict(train_dataset)``
      and evaluated on the validation and test datasets. The returned mapping
      holds the accuracy entries produced by the evaluators, averaged with
      ``average_metrics`` when there are several entries.
    * Leave-one-out (``leave_one_out_dataset`` is a CSV path): features are
      extracted once, then one set of classifiers is trained per held-out
      label. The returned mapping gives a macro F1 score per classifier,
      computed over the predictions of all held-out groups.

    Args:
        model: backbone wrapped in ``ModelWithIntermediateLayers``.
        output_dir: directory for checkpoints and ``results_eval_linear.json``.
        train_dataset_str, val_dataset_str, test_dataset_strs: dataset strings.
        batch_size, num_workers: data-loading settings.
        epochs, epoch_length: training length (``epochs * epoch_length`` steps).
        save_checkpoint_frequency: checkpoint period in iterations.
        eval_period_iterations: validation period in iterations.
        learning_rates, weight_decays: forwarded to the classifier setup.
        autocast_dtype: dtype used for CUDA autocast in the feature model.
        resume, classifier_fpath: checkpoint restoring options.
        val_metric_type, test_metric_types: metric selection.
        loss_type: cross-entropy or binary cross-entropy.
        bag_of_channels: forwarded to the classifier setup; also triggers the
            computation of ``sample_output`` here.
        leave_one_out_dataset: CSV metadata path enabling leave-one-out mode.
        resize_size, crop_size: transform sizes.
        n_last_blocks: number of last blocks whose outputs are used.
        avgpool: forwarded to feature extraction and classifier setup.
        scheduler: ``SchedulerType`` member used to build the LR scheduler.

    Returns:
        The results mapping described above.
    """
    leave_one_out = not (leave_one_out_dataset == "" or leave_one_out_dataset is None)
    if leave_one_out:
        logger.info("Reading the leave-one-out label metadata.")

        metadata = pd.read_csv(leave_one_out_dataset)
        # The boolean task column selects the rows; the grouping column depends
        # on whether the metadata path contains "HPA".
        if "HPA" in leave_one_out_dataset:
            metadata = metadata[metadata["Task_three"]].reset_index()
            leave_one_out_label_type = "cell_type"
        else:
            metadata = metadata[metadata["Task_four"]].reset_index()
            leave_one_out_label_type = "Plate"

        # Row positions (after the reset above) of each held-out group.
        leave_one_out_indices = {
            leave_one_out_label: np.array(
                metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values
            )
            for leave_one_out_label in metadata[leave_one_out_label_type].unique()
        }

    train_transform = make_classification_eval_cell_transform(
        normalization_type=NormalizationType.SELF_NORM_AUG_DECODER, crop_size=crop_size, resize_size=resize_size
    )
    print("train_transform", train_transform)
    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=train_transform,
    )

    training_num_classes = get_num_classes(train_dataset)
    if leave_one_out:
        training_num_classes += train_dataset.num_additional_labels_loo_eval
    train_dataset_dict = create_train_dataset_dict(train_dataset)
    n_last_blocks_list = [n_last_blocks]
    n_last_blocks = max(n_last_blocks_list)
    dataset_use_cache = True
    autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
    feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)

    # Always bind sample_output: the leave-one-out branch below forwards it, and
    # train_linear_classifiers raises an explicit assertion when it is None
    # (previously an UnboundLocalError when bag_of_channels was off).
    sample_output = None
    if bag_of_channels:
        sample = train_dataset[0][0].unsqueeze(0)
        sample_output = feature_model(sample.cuda())

    if leave_one_out:
        loo_dict = {}
        train_data_dict = extract_features_for_dataset_dict(
            feature_model,
            train_dataset_dict,
            batch_size,
            num_workers,
            gather_on_cpu=True,
            avgpool=avgpool,
        )
        val_dataset = make_dataset(
            dataset_str=val_dataset_str,
            transform=make_classification_eval_cell_transform(
                normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, crop_size=crop_size, resize_size=resize_size
            ),
        )
        val_dataset_dict = create_train_dataset_dict(val_dataset)
        val_data_dict = extract_features_for_dataset_dict(
            feature_model,
            val_dataset_dict,
            batch_size,
            num_workers,
            gather_on_cpu=True,
            avgpool=avgpool,
        )

        train_features = train_data_dict[0]["train_features"]
        train_labels = train_data_dict[0]["train_labels"]
        val_features = val_data_dict[0]["train_features"]
        val_labels = val_data_dict[0]["train_labels"]

        for loo_label, loo_indices in leave_one_out_indices.items():
            # Everything in the validation set except the held-out group is
            # added to the training data of this split.
            loo_for_training_indices = torch.ones(val_features.shape[0], dtype=bool)
            loo_for_training_indices[loo_indices] = False

            loo_dict[loo_label] = {
                "train_features": torch.cat([train_features, val_features[loo_for_training_indices]]),
                "train_labels": torch.cat([train_labels, val_labels[loo_for_training_indices]]),
                "val_features": val_features[loo_indices],
                "val_labels": val_labels[loo_indices],
            }
    save_results_func = None
    # if config.save_results:
    #     save_results_func = partial(default_save_results_func, output_dir=output_dir)

    metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
    periodic_checkpointers: list = []

    train_config = {
        "learning_rates": learning_rates,
        "weight_decays": weight_decays,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "dataset_use_cache": dataset_use_cache,
        "eval_period_iterations": eval_period_iterations,
        "epoch_length": epoch_length,
        "leave_one_out": leave_one_out,
        "bag_of_channels": bag_of_channels,
        "n_last_blocks_list": n_last_blocks_list,
        "epochs": epochs,
        "loss_type": loss_type,
        "resume": resume,
        "save_checkpoint_iterations": save_checkpoint_frequency,
        "classifier_fpath": classifier_fpath,
        "avgpool": avgpool,
        "scheduler": scheduler,
    }
    config = {
        "test_metric_types": test_metric_types,
        "test_datasets": test_dataset_strs,
        "val_metric_types": val_metric_type,
        "val_dataset": val_dataset_str,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "leave_one_out": leave_one_out,
        "crop_size": crop_size,
        "resize_size": resize_size,
    }
    if not leave_one_out:
        val_evaluator, test_evaluators = make_evaluators(
            config=config,
            val_metric_type=val_metric_type,
            val_dataset=val_dataset_str,
            metric_type=test_metric_types,
            metrics_file_path=metrics_file_path,
            training_num_classes=training_num_classes,
            save_results_func=save_results_func,
        )
        results_dict = {}

        for _try in train_dataset_dict.keys():
            # Several training sets each get their own checkpoint directory
            # and file-name suffix.
            if len(train_dataset_dict) > 1:
                checkpoint_output_dir = os.path.join(output_dir, f"checkpoints_{_try}")
                save_filename_suffix = f"_{_try}"
            else:
                checkpoint_output_dir, save_filename_suffix = output_dir, ""
            os.makedirs(checkpoint_output_dir, exist_ok=True)

            feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
                train_config=train_config,
                feature_model=feature_model,
                train_dataset=train_dataset_dict[_try],
                training_num_classes=training_num_classes,
                val_evaluator=val_evaluator,
                checkpoint_output_dir=checkpoint_output_dir,
            )
            periodic_checkpointers.append(periodic_checkpointer)
            results_dict[_try], _ = val_evaluator.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                iteration=iteration,
                save_filename_suffix=save_filename_suffix,
            )
            for test_evaluator in test_evaluators:
                # Test sets are scored with the classifier chosen on validation.
                eval_results_dict, _ = test_evaluator.evaluate_and_maybe_save(
                    feature_model=feature_model,
                    linear_classifiers=linear_classifiers,
                    iteration=iteration,
                    best_classifier_on_val=results_dict[_try]["best_classifier"],
                    save_filename_suffix=save_filename_suffix,
                )
                # Validation entries take precedence on key collisions.
                results_dict[_try] = {**eval_results_dict, **results_dict[_try]}
        if len(train_dataset_dict) > 1:
            results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"])
        else:
            results_dict = {**results_dict[_try]}
    else:  # if leave one out is True
        test_results_dict = {}
        for loo_label in loo_dict.keys():

            checkpoint_output_dir, save_filename_suffix = os.path.join(output_dir, f"checkpoints_{loo_label}"), ""
            os.makedirs(checkpoint_output_dir, exist_ok=True)

            train_dataset_loo = TensorDataset(
                loo_dict[loo_label]["train_features"], loo_dict[loo_label]["train_labels"]
            )

            logger.info(f"Creating leave_one_out evaluators. loo_label: {loo_label}")
            val_dataset_loo = TensorDataset(loo_dict[loo_label]["val_features"], loo_dict[loo_label]["val_labels"])
            val_evaluators_loo, _ = make_evaluators(
                config=config,
                val_metric_type=val_metric_type,
                val_dataset="loo",
                metric_type=test_metric_types,
                metrics_file_path=metrics_file_path,
                training_num_classes=training_num_classes,
                save_results_func=save_results_func,
                val_dataset_loo=val_dataset_loo,
            )
            feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers(
                feature_model=feature_model,
                train_dataset=train_dataset_loo,
                train_config=train_config,
                training_num_classes=training_num_classes,
                val_evaluator=val_evaluators_loo,
                checkpoint_output_dir=checkpoint_output_dir,
                sample_output=sample_output,
            )
            periodic_checkpointers.append(periodic_checkpointer)
            _, test_results_dict[loo_label] = val_evaluators_loo.evaluate_and_maybe_save(
                feature_model=feature_model,
                linear_classifiers=linear_classifiers,
                iteration=iteration,
                save_filename_suffix=save_filename_suffix,
                test_mode=True,
            )
        classifier_names = test_results_dict[loo_label].keys()
        # Per classifier, gather the two accumulated entries of every held-out
        # group; below, entry 0 is arg-maxed into predictions and entry 1 is
        # used as the targets.
        results_dict = {k: [[], []] for k in classifier_names}
        for ll in test_results_dict.keys():
            for k in classifier_names:
                results_dict[k][0].append(test_results_dict[ll][k][0])
                results_dict[k][1].append(test_results_dict[ll][k][1])
        for k in classifier_names:
            predictions = np.argmax(torch.cat(results_dict[k][0]).cpu().detach().numpy(), axis=1)
            targets = torch.cat(results_dict[k][1]).cpu().detach().numpy()
            # NOTE(review): the F1 score is restricted to labels 4, 5 and 6;
            # presumably these are the held-out classes -- confirm against the dataset.
            results_dict[k] = f1_score(targets, predictions, average="macro", labels=[4, 5, 6])
        best_name = max(results_dict, key=results_dict.get)
        logger.info(f"Best performance is for {best_name}, with F1-Score of {results_dict[best_name]}")

    logger.info("Test Results Dict " + str(results_dict))
    return results_dict
1009
-
1010
-
1011
def main(args):
    """Build the model from ``args`` and run the linear evaluation.

    Returns 0 so the value can be passed to ``sys.exit``.
    """
    model, autocast_dtype = setup_and_build_model(args)

    # Options copied unchanged from the parsed arguments.
    passthrough = (
        "output_dir",
        "train_dataset_str",
        "val_dataset_str",
        "test_dataset_strs",
        "batch_size",
        "epochs",
        "epoch_length",
        "num_workers",
        "save_checkpoint_frequency",
        "eval_period_iterations",
        "learning_rates",
        "weight_decays",
        "classifier_fpath",
        "val_metric_type",
        "test_metric_types",
        "loss_type",
        "bag_of_channels",
        "leave_one_out_dataset",
        "crop_size",
        "resize_size",
        "n_last_blocks",
        "avgpool",
        "scheduler",
    )
    options = {name: getattr(args, name) for name in passthrough}

    eval_linear_with_model(
        model=model,
        autocast_dtype=autocast_dtype,
        # The command line exposes the negated flag.
        resume=not args.no_resume,
        **options,
    )
    return 0
1042
-
1043
-
1044
if __name__ == "__main__":
    # Script entry point: parse the command line, run the evaluation and use
    # the value returned by main() as the process exit status.
    args_parser = get_args_parser(description="DINOv2 linear_cell_dino evaluation")
    args = args_parser.parse_args()
    sys.exit(main(args))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/cell_dino/utils.py DELETED
@@ -1,542 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the CC-by-NC licence,
4
- # found in the LICENSE_CELL_DINO_CODE file in the root directory of this source tree.
5
-
6
- import logging
7
- from typing import Callable, Dict, Optional, Any, List
8
-
9
- import torch
10
- from torch import nn
11
- from torchmetrics import MetricCollection
12
-
13
- from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
14
- from dinov2.data import NoOpAccumulator, ResultsAccumulator
15
- import dinov2.distributed as distributed
16
- from dinov2.logging import MetricLogger
17
- from enum import Enum
18
- from torch.utils.data import Subset
19
- from torchvision.datasets.vision import StandardTransform
20
- import numpy as np
21
- from torch.nn.functional import one_hot, softmax
22
-
23
- logger = logging.getLogger("dinov2")
24
-
25
-
26
- class LossType(Enum):
27
- CROSS_ENTROPY = "cross_entropy"
28
- BINARY_CROSS_ENTROPY = "binary_cross_entropy"
29
-
30
-
31
- class BagOfChannelsModelWithNormalize(nn.Module):
32
- def __init__(self, model, autocast_ctx, avgpool, n_last_blocks=1):
33
- super().__init__()
34
- self.model = model
35
- self.autocast_ctx = autocast_ctx
36
- self.n_last_blocks = n_last_blocks
37
- self.avgpool = avgpool
38
-
39
- def forward(self, samples):
40
- with self.autocast_ctx():
41
- features = self.model.get_intermediate_layers(samples, self.n_last_blocks, return_class_token=True)
42
- output = create_linear_input(features, self.avgpool, use_n_blocks=self.n_last_blocks)
43
- return nn.functional.normalize(output, dim=1, p=2)
44
-
45
-
46
- @torch.inference_mode()
47
- def evaluate_with_accumulate(
48
- model: nn.Module,
49
- data_loader,
50
- postprocessors: Dict[str, nn.Module],
51
- metrics: Dict[str, MetricCollection],
52
- device: torch.device,
53
- criterion: Optional[nn.Module] = None,
54
- test_mode: bool = False,
55
- accumulate_results: bool = False,
56
- leave_one_out: bool = False,
57
- ):
58
- model.eval()
59
-
60
- if test_mode:
61
- output_tensor = {k: [] for k in postprocessors.keys()}
62
- target_tensor = {k: [] for k in postprocessors.keys()}
63
-
64
- if criterion is not None:
65
- criterion.eval()
66
-
67
- accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator
68
- accumulators = {k: accumulator_class() for k in postprocessors.keys()}
69
-
70
- for metric in metrics.values():
71
- metric = metric.to(device)
72
-
73
- metric_logger = MetricLogger(delimiter=" ")
74
- header = "Test:"
75
-
76
- for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
77
- if isinstance(targets, list):
78
- index = targets[0]
79
- targets = targets[1]
80
- samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0]
81
- if len(index) == 0:
82
- continue
83
-
84
- outputs = samples.to(device) if leave_one_out else model(samples.to(device))
85
- targets = targets.to(device)
86
-
87
- if criterion is not None:
88
- loss = criterion(outputs, targets)
89
- metric_logger.update(loss=loss.item())
90
-
91
- for k, metric in metrics.items():
92
- metric_inputs = postprocessors[k](outputs, targets)
93
- metric.update(**metric_inputs)
94
- if test_mode:
95
- output_tensor[k].append(metric_inputs["preds"])
96
- target_tensor[k].append(metric_inputs["target"])
97
- accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index)
98
-
99
- metric_logger.synchronize_between_processes()
100
- logger.info(f"Averaged stats: {metric_logger}")
101
-
102
- stats = {k: metric.compute() for k, metric in metrics.items()}
103
- metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
104
-
105
- # accumulator.accumulate() returns None for the NoOpAccumulator
106
- accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()}
107
- if test_mode:
108
- for k in postprocessors.keys():
109
- output_tensor[k] = torch.cat(output_tensor[k])
110
- target_tensor[k] = torch.cat(target_tensor[k])
111
- accumulated_results = {k: [output_tensor[k], target_tensor[k]] for k in postprocessors.keys()}
112
-
113
- if accumulate_results:
114
- return metric_logger_stats, stats
115
- return metric_logger_stats, stats, accumulated_results
116
-
117
-
118
- def all_gather_and_flatten(tensor_rank):
119
- tensor_all_ranks = torch.empty(
120
- distributed.get_global_size(),
121
- *tensor_rank.shape,
122
- dtype=tensor_rank.dtype,
123
- device=tensor_rank.device,
124
- )
125
- tensor_list = list(tensor_all_ranks.unbind(0))
126
- torch.distributed.all_gather(tensor_list, tensor_rank.contiguous())
127
- return tensor_all_ranks.flatten(end_dim=1)
128
-
129
-
130
- def extract_features_cell_dino(
131
- model, dataset, batch_size, num_workers, gather_on_cpu=False, shuffle=False, avgpool=False
132
- ):
133
- dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
134
- sample_count = len(dataset_with_enumerated_targets)
135
- data_loader = make_data_loader(
136
- dataset=dataset_with_enumerated_targets,
137
- batch_size=batch_size,
138
- num_workers=num_workers,
139
- sampler_type=SamplerType.DISTRIBUTED,
140
- drop_last=False,
141
- shuffle=shuffle,
142
- )
143
- return extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu, avgpool=avgpool)
144
-
145
-
146
- @torch.inference_mode()
147
- def extract_features_with_dataloader_cell_dino(model, data_loader, sample_count, gather_on_cpu=False, avgpool=False):
148
- gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
149
- metric_logger = MetricLogger(delimiter=" ")
150
- features, all_labels = None, None
151
- for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
152
- samples = samples.cuda(non_blocking=True)
153
- labels_rank = labels_rank.cuda(non_blocking=True)
154
- index = index.cuda(non_blocking=True)
155
- feat = model(samples)
156
- if isinstance(samples, list) or isinstance(feat, tuple):
157
- features_rank = create_linear_input(feat, avgpool=avgpool)
158
- else:
159
- features_rank = feat
160
-
161
- # init storage feature matrix
162
- if features is None:
163
- features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
164
- labels_shape = list(labels_rank.shape)
165
- labels_shape[0] = sample_count
166
- all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
167
- logger.info(f"Storing features into tensor of shape {features.shape}")
168
-
169
- # share indexes, features and labels between processes
170
- index_all = all_gather_and_flatten(index).to(gather_device)
171
- features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
172
- labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)
173
-
174
- # update storage feature matrix
175
- if len(index_all) > 0:
176
- features.index_copy_(0, index_all, features_all_ranks)
177
- all_labels.index_copy_(0, index_all, labels_all_ranks)
178
-
179
- logger.info(f"Features shape: {tuple(features.shape)}")
180
- logger.info(f"Labels shape: {tuple(all_labels.shape)}")
181
-
182
- assert torch.all(all_labels > -1)
183
-
184
- return features, all_labels
185
-
186
-
187
- def create_linear_input(x_tokens_list, avgpool=False, use_n_blocks=1):
188
- intermediate_output = x_tokens_list[-use_n_blocks:]
189
- output = torch.cat(
190
- [class_token for _, class_token in intermediate_output], dim=-1
191
- ) # concatenate class tokens of the last n blocks
192
- if avgpool:
193
- output = torch.cat(
194
- (
195
- output,
196
- torch.mean(intermediate_output[-1][0], dim=-2).reshape(
197
- intermediate_output[-1][0].shape[0], -1
198
- ), # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model
199
- ),
200
- dim=-1,
201
- ) # concatenate average pooling of patch tokens to concatenated patch tokens
202
- output = output.reshape(output.shape[0], -1)
203
-
204
- return output.float()
205
-
206
-
207
- def get_target_transform(dataset) -> Optional[Callable]:
208
- if hasattr(dataset, "transforms"):
209
- if isinstance(dataset.transforms, StandardTransform):
210
- return dataset.transforms.target_transform
211
- raise ValueError("Dataset has a non-standard .transforms property")
212
- if hasattr(dataset, "target_transform"):
213
- return dataset.target_transform
214
- return None
215
-
216
-
217
- def get_labels(dataset) -> torch.Tensor:
218
- """
219
- Get the labels of a classification dataset, as a Tensor, using the `get_targets` method
220
- if it is present or loading the labels one by one with `get_target`, if it exists.
221
- If the dataset has a target transform, iterate over the whole dataset to get the
222
- transformed labels for each element, then stack them as a torch tensor.
223
- """
224
- logger.info("Getting dataset labels ...")
225
- if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"):
226
- if hasattr(dataset, "get_targets"): # Returns a np.array
227
- labels = dataset.get_targets()
228
- elif hasattr(dataset, "get_target"):
229
- labels = [dataset.get_target(i) for i in range(len(dataset))]
230
- target_transform = get_target_transform(dataset)
231
- if target_transform is not None:
232
- labels = [target_transform(label) for label in labels]
233
- else:
234
- # Target transform is applied in this case
235
- labels = [dataset[i][1] for i in range(len(dataset))]
236
- return torch.stack([torch.tensor(label, dtype=int) for label in labels])
237
-
238
-
239
- def get_num_classes(dataset) -> int:
240
- """
241
- Get the labels of a dataset and compute the number of classes
242
- """
243
- labels = get_labels(dataset)
244
- if len(labels.shape) > 1:
245
- return int(labels.shape[1])
246
- return int(labels.max() + 1)
247
-
248
-
249
- def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []):
250
- """
251
- Function that computes the average and the std on a metrics dict.
252
- A linear evaluation dictionary contains "best_classifier",
253
- so this specific key is removed for computing aggregated metrics.
254
- """
255
- output_metrics_dict = {}
256
- metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys]
257
- for metric in metrics:
258
- stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()])
259
- output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item()
260
- output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item()
261
-
262
- return output_metrics_dict
263
-
264
-
265
- def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]:
266
- """
267
- Efficiently creates a mapping between the labels and tensors containing
268
- the indices of all the dataset elements that share this label.
269
- In the case of multiple labels, it is not guaranteed that there
270
- will be exactly the specified percentage of labels.
271
- """
272
- if len(labels.shape) > 1: # labels are a one-hot encoding
273
- assert len(labels.shape) == 2
274
- sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True)
275
- else:
276
- sorted_labels, indices = torch.sort(labels, stable=True)
277
- unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True)
278
- mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist())))
279
- return mapping
280
-
281
-
282
- def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0):
283
- """
284
- Shuffling a dataset by subsetting it with a random permutation of its indices
285
- """
286
- random_generator = torch.Generator()
287
- random_generator.manual_seed(seed)
288
- random_indices = torch.randperm(len(dataset), generator=random_generator)
289
- return Subset(dataset, random_indices)
290
-
291
-
292
- def _subset_dataset_per_class(
293
- class_indices_mapping: dict[int, torch.Tensor],
294
- n_or_percent_per_class: float,
295
- dataset_size: int,
296
- seed: int = 0,
297
- is_percent: bool = False,
298
- ) -> torch.Tensor:
299
- """
300
- Helper function to select a percentage of a dataset, equally distributed across classes,
301
- or to take the same number of elements from each class of the dataset.
302
- Returns a boolean mask tensor being True at indices of selected elements
303
- """
304
-
305
- random_generator = torch.Generator()
306
- random_generator.manual_seed(seed)
307
-
308
- final_indices_bool = torch.zeros(dataset_size, dtype=bool)
309
- for class_indices in class_indices_mapping.values():
310
- # Select at least one element
311
- n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class
312
- assert isinstance(n_for_class, int)
313
- filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class]
314
- final_indices_bool[class_indices[filtered_index]] = True
315
- return final_indices_bool
316
-
317
-
318
- def _multilabel_rebalance_subset(
319
- class_indices_mapping: dict[int, torch.Tensor],
320
- n_or_percent_per_class: float,
321
- labels: torch.Tensor,
322
- indices_bool: torch.Tensor,
323
- dataset_size: int,
324
- seed: int = 0,
325
- ) -> torch.Tensor:
326
- """
327
- Helper function to refine a subset of a multi-label dataset (indices_bool)
328
- to better match a target percentage of labels.
329
- Returns a boolean mask tensor being True at indices of selected elements.
330
- """
331
-
332
- # Compute the number of selected labels in indices_bool
333
- num_total_labels = labels.sum()
334
- num_wanted_labels = int(num_total_labels * n_or_percent_per_class)
335
- num_selected_labels = (labels[indices_bool] > 0).sum()
336
- logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
337
-
338
- # Compute a new percentage and new set selecting less images, therefore less labels, to match approximatelly the exact percentage of labels selected
339
- n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels)
340
- final_indices_bool = _subset_dataset_per_class(
341
- class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True
342
- )
343
-
344
- # Compute the number of labels finally used
345
- num_selected_labels = (labels[final_indices_bool] > 0).sum()
346
- logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}")
347
-
348
- return final_indices_bool
349
-
350
-
351
- def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True):
352
- """
353
- Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class.
354
- If `shuffle` is False, taking the first elements of each class as the validaton set.
355
- """
356
- assert 0 < split_percentage < 1
357
- logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set")
358
- if shuffle_train:
359
- logger.info("Shuffling train dataset before splitting in train and validation sets")
360
- train_dataset = _shuffle_dataset(train_dataset)
361
- train_labels = get_labels(train_dataset)
362
- class_indices_mapping = create_class_indices_mapping(train_labels)
363
- val_mask = torch.zeros(len(train_labels), dtype=bool)
364
- for class_indices in class_indices_mapping.values():
365
- # If there is only one element, it goes in the train set
366
- n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0
367
- val_mask[class_indices[:n_for_val]] = True
368
-
369
- val_dataset = Subset(train_dataset, val_mask.nonzero().flatten())
370
- train_dataset = Subset(train_dataset, (~val_mask).nonzero().flatten())
371
- return train_dataset, val_dataset
372
-
373
-
374
- def create_train_dataset_dict(
375
- train_dataset,
376
- few_shot_eval: bool = False,
377
- few_shot_k_or_percent=None,
378
- few_shot_n_tries: int = 1,
379
- ) -> dict[int, dict[int, Any]]:
380
- """
381
- Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being
382
- n elements or x% of a class. Produces a dict, which keys are number of random "tries"
383
- and values are the dataset subset for this "try".
384
-
385
- Format is {"nth-try": dataset}
386
- """
387
- if few_shot_eval is False:
388
- assert few_shot_k_or_percent is None
389
- assert few_shot_n_tries == 1
390
- return {0: train_dataset}
391
-
392
- assert few_shot_k_or_percent is not None
393
- train_labels = get_labels(train_dataset)
394
- class_indices_mapping = create_class_indices_mapping(train_labels)
395
- train_dataset_dict: dict[int, Any] = {}
396
- is_percent = few_shot_k_or_percent < 1
397
- if not is_percent:
398
- few_shot_k_or_percent = int(few_shot_k_or_percent)
399
-
400
- for t in range(few_shot_n_tries):
401
- t_subset_bool = _subset_dataset_per_class(
402
- class_indices_mapping=class_indices_mapping,
403
- n_or_percent_per_class=few_shot_k_or_percent,
404
- dataset_size=len(train_labels),
405
- is_percent=is_percent,
406
- seed=t,
407
- )
408
- if len(train_labels.shape) > 1 and is_percent:
409
- t_subset_bool = _multilabel_rebalance_subset(
410
- class_indices_mapping=class_indices_mapping,
411
- n_or_percent_per_class=few_shot_k_or_percent,
412
- dataset_size=len(train_labels),
413
- labels=train_labels,
414
- indices_bool=t_subset_bool,
415
- seed=t,
416
- )
417
- train_dataset_dict[t] = Subset(train_dataset, t_subset_bool.nonzero().flatten())
418
- return train_dataset_dict
419
-
420
-
421
- def extract_features_for_dataset_dict(
422
- model,
423
- dataset_dict: dict[int, dict[int, Any]],
424
- batch_size: int,
425
- num_workers: int,
426
- gather_on_cpu=False,
427
- avgpool=False,
428
- ) -> dict[int, dict[str, torch.Tensor]]:
429
- """
430
- Extract features for each subset of dataset in the context of few-shot evaluations
431
- """
432
- few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {}
433
- for try_n, dataset in dataset_dict.items():
434
- features, labels = extract_features_cell_dino(
435
- model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu, avgpool=avgpool
436
- )
437
- few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels}
438
- return few_shot_data_dict
439
-
440
-
441
- def pad_multilabel_and_collate(batch, pad_value=-1):
442
- """
443
- This method pads and collates a batch of (image, (index, target)) tuples, coming from
444
- DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes.
445
- The targets are padded to the length of the longest target list in the batch.
446
- """
447
- maxlen = max(len(targets) for _, (_, targets) in batch)
448
- padded_batch = [
449
- (image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value)))
450
- for image, (index, targets) in batch
451
- ]
452
- return torch.utils.data.default_collate(padded_batch)
453
-
454
-
455
- class KnnModule(torch.nn.Module):
456
- """
457
- Gets knn of test features from all processes on a chunk of the train features
458
-
459
- Each rank gets a chunk of the train features as well as a chunk of the test features.
460
- In `compute_neighbors`, for each rank one after the other, its chunk of test features
461
- is sent to all devices, partial knns are computed with each chunk of train features
462
- then collated back on the original device.
463
- """
464
-
465
- def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
466
- super().__init__()
467
-
468
- self.global_rank = distributed.get_global_rank()
469
- self.global_size = distributed.get_global_size()
470
-
471
- self.device = device
472
- self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
473
- # Labels can either be integers, or in a one-hot format
474
- self.candidates = train_labels.chunk(self.global_size)[self.global_rank].unsqueeze(0).to(self.device)
475
- self.nb_knn = nb_knn
476
- self.max_k = max(self.nb_knn)
477
- self.T = T
478
- self.num_classes = num_classes
479
-
480
- def _get_knn_sims_and_labels(self, similarity, train_labels):
481
- topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
482
- if len(train_labels.shape) == 3: # If the labels are in one_hot format
483
- indices = indices.unsqueeze(2).expand(-1, -1, self.num_classes) # Orignally [bs, max_k]
484
- neighbors_labels = torch.gather(train_labels, 1, indices)
485
- return topk_sims, neighbors_labels
486
-
487
- def _similarity_for_rank(self, features_rank, source_rank):
488
- # Send the features from `source_rank` to all ranks
489
- broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
490
- torch.distributed.broadcast(broadcast_shape, source_rank)
491
-
492
- broadcasted = features_rank
493
- if self.global_rank != source_rank:
494
- broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
495
- torch.distributed.broadcast(broadcasted, source_rank)
496
-
497
- # Compute the neighbors for `source_rank` among `train_features_rank_T`
498
- similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
499
- candidate_labels = self.candidates.expand(len(similarity_rank), *self.candidates.shape[1:])
500
- return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
501
-
502
- def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
503
- # Gather all neighbors for `target_rank`
504
- topk_sims_rank = retrieved_rank = None
505
- if self.global_rank == target_rank:
506
- topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
507
- retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
508
-
509
- torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
510
- torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
511
-
512
- if self.global_rank == target_rank:
513
- # Perform a second top-k on the k * global_size retrieved neighbors
514
- topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
515
- retrieved_rank = torch.cat(retrieved_rank, dim=1)
516
- results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
517
- return results
518
- return None
519
-
520
- def compute_neighbors(self, features_rank):
521
- for rank in range(self.global_size):
522
- topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
523
- results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
524
- if results is not None:
525
- topk_sims_rank, neighbors_labels_rank = results
526
- return topk_sims_rank, neighbors_labels_rank
527
-
528
- def forward(self, features_rank):
529
- """
530
- Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
531
- """
532
- assert all(k <= self.max_k for k in self.nb_knn)
533
-
534
- topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
535
- batch_size = neighbors_labels.shape[0]
536
- topk_sims_transform = softmax(topk_sims / self.T, 1)
537
- voting_coefficient = topk_sims_transform.view(batch_size, -1, 1)
538
- if len(neighbors_labels.shape) == 2: # If the labels are not yet one hot
539
- neighbors_labels = one_hot(neighbors_labels, num_classes=self.num_classes)
540
- matmul = torch.mul(neighbors_labels, voting_coefficient)
541
- probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
542
- return probas_for_k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/depth/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
 
 
 
 
 
dinov2/eval/depth/models/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from .backbones import * # noqa: F403
7
- from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
8
- from .decode_heads import * # noqa: F403
9
- from .depther import * # noqa: F403
10
- from .losses import * # noqa: F403
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/depth/models/backbones/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from .vision_transformer import DinoVisionTransformer
 
 
 
 
 
 
 
dinov2/eval/depth/models/backbones/vision_transformer.py DELETED
@@ -1,16 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from mmcv.runner import BaseModule
7
-
8
- from ..builder import BACKBONES
9
-
10
-
11
- @BACKBONES.register_module()
12
- class DinoVisionTransformer(BaseModule):
13
- """Vision Transformer."""
14
-
15
- def __init__(self, *args, **kwargs):
16
- super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dinov2/eval/depth/models/builder.py DELETED
@@ -1,49 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import warnings
7
-
8
- from mmcv.cnn import MODELS as MMCV_MODELS
9
- from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
10
- from mmcv.utils import Registry
11
-
12
- MODELS = Registry("models", parent=MMCV_MODELS)
13
- ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
14
-
15
-
16
- BACKBONES = MODELS
17
- NECKS = MODELS
18
- HEADS = MODELS
19
- LOSSES = MODELS
20
- DEPTHER = MODELS
21
-
22
-
23
- def build_backbone(cfg):
24
- """Build backbone."""
25
- return BACKBONES.build(cfg)
26
-
27
-
28
- def build_neck(cfg):
29
- """Build neck."""
30
- return NECKS.build(cfg)
31
-
32
-
33
- def build_head(cfg):
34
- """Build head."""
35
- return HEADS.build(cfg)
36
-
37
-
38
- def build_loss(cfg):
39
- """Build loss."""
40
- return LOSSES.build(cfg)
41
-
42
-
43
- def build_depther(cfg, train_cfg=None, test_cfg=None):
44
- """Build depther."""
45
- if train_cfg is not None or test_cfg is not None:
46
- warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
47
- assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
48
- assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
49
- return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))