Add necessary modules
Browse files- dataset/__init__.py +1 -0
- dataset/dataset_cls.py +480 -0
- dataset/dataset_seg.py +556 -0
- demo_classfication.py +192 -0
- demo_segmentation.py +250 -0
- engine/__init__.py +1 -0
- engine/classification.py +341 -0
- engine/location.py +206 -0
- engine/pretrain.py +85 -0
- engine/pretrain_amp.py +81 -0
- engine/regression.py +142 -0
- engine/segment.py +199 -0
- models/__init__.py +1 -0
- models/build_classification.py +83 -0
- models/classifier.py +23 -0
- models/convnext_unter.py +182 -0
- models/convnextv2.py +311 -0
- models/upernet_module.py +451 -0
- models/util.py +258 -0
- requirements.txt +0 -3
- util/__init__.py +1 -0
- util/convnext_optim.py +127 -0
- util/lars.py +59 -0
- util/lr_sched.py +28 -0
- util/metric.py +340 -0
- util/misc.py +455 -0
dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# ProFound dataset package
|
dataset/dataset_cls.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from monai.transforms import (
|
| 3 |
+
Compose,
|
| 4 |
+
RandCropByPosNegLabeld,
|
| 5 |
+
CropForegroundd,
|
| 6 |
+
SpatialPadd,
|
| 7 |
+
ScaleIntensityRanged,
|
| 8 |
+
RandShiftIntensityd,
|
| 9 |
+
RandFlipd,
|
| 10 |
+
RandAffined,
|
| 11 |
+
RandZoomd,
|
| 12 |
+
RandRotated,
|
| 13 |
+
RandBiasFieldd,
|
| 14 |
+
RandRotate90d,
|
| 15 |
+
RandGaussianNoised,
|
| 16 |
+
RandGaussianSmoothd,
|
| 17 |
+
NormalizeIntensityd,
|
| 18 |
+
MapTransform,
|
| 19 |
+
RandScaleIntensityd,
|
| 20 |
+
RandSpatialCropd,
|
| 21 |
+
CenterSpatialCropd,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
|
| 25 |
+
import torch
|
| 26 |
+
import numpy as np
|
| 27 |
+
import nibabel as nib
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import os
|
| 30 |
+
import pandas as pd
|
| 31 |
+
from ast import literal_eval
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class RiskSet(Dataset):
    """PI-RADS risk classification dataset.

    Each CSV row points to three co-registered MRI volumes (``t2w``,
    high-b DWI in ``highb``, ``adc``) plus a ``pirads`` score.
    ``__getitem__`` stacks the volumes into a 3-channel tensor and returns
    the score shifted to start at 0 (label = pirads - 2, so PI-RADS 2..5
    map to classes 0..3).
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__()
        self.img_dict = pd.read_csv(image_paths)
        if phase == 'train':
            if args.data_num > 0:
                # Optionally crop the training set for data-efficiency runs.
                self.img_dict = self.img_dict.iloc[: args.data_num]
        print(f"Loading {phase} dataset with {len(self.img_dict)} samples")
        self.root = args.root
        self._set_dataset_stat()
        self.transforms = transforms
        if not args.demo:
            # Demo mode never trains, so the sampler weights are not needed.
            self.set_sampler()

    def set_sampler(self):
        """Compute inverse-frequency per-sample weights for a WeightedRandomSampler.

        BUG FIX: the weight was previously looked up by positional index
        (``class_weights[pirads - 2]``), which silently misaligns when a
        PI-RADS class is absent from the CSV.  Weights are now keyed by the
        actual label value; the result is identical when all classes are
        present.
        """
        counts = self.img_dict["pirads"].value_counts()
        weight_by_value = (1.0 / counts).to_dict()
        self.sampler_weight = np.array(
            [weight_by_value[v] for v in self.img_dict["pirads"].values]
        )

    def cal_weight(self):
        """Return per-class sample counts ordered by PI-RADS value."""
        return self.img_dict["pirads"].value_counts().sort_index().values

    def _set_dataset_stat(self):
        # Fixed acquisition statistics for this cohort.
        self.spacing = (0.5, 0.5, 1.0)
        self.spatial_index = [2, 1, 0]  # index used to convert to DHW
        self.target_class = 1

    def __len__(self):
        return len(self.img_dict)

    def read(self, path):
        """Load a NIfTI volume (path relative to ``self.root``) as a DHW float32 tensor."""
        vol = nib.load(os.path.join(self.root, path))
        vol = vol.get_fdata().astype(np.float32).transpose(self.spatial_index)
        return torch.from_numpy(vol)

    def __getitem__(self, idx):
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["highb"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        label = torch.tensor(int(path["pirads"]) - 2, dtype=torch.long)
        if self.transforms is not None:
            trans_dict = self.transforms({"image": img})
            if isinstance(trans_dict, list):
                # Some MONAI transforms return a list of result dicts.
                trans_dict = trans_dict[0]
            img = trans_dict["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ScreeningSet(RiskSet):
    """Binary screening dataset: the label is the ``result`` CSV column."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        """Compute inverse-frequency per-sample weights keyed by ``result``.

        BUG FIX: weights were previously indexed positionally by the raw
        label value, which misaligns if a class is missing from the CSV;
        they are now looked up by value (identical when both classes exist).
        """
        counts = self.img_dict["result"].value_counts()
        weight_by_value = (1.0 / counts).to_dict()
        self.sampler_weight = np.array(
            [weight_by_value[v] for v in self.img_dict["result"].values]
        )

    def cal_weight(self):
        """Return per-class sample counts ordered by ``result`` value."""
        return self.img_dict["result"].value_counts().sort_index().values

    def __getitem__(self, idx):
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        label = torch.tensor(int(path["result"]), dtype=torch.long)
        if self.transforms is not None:
            trans_dict = self.transforms({"image": img})
            if isinstance(trans_dict, list):
                # Some MONAI transforms return a list of result dicts.
                trans_dict = trans_dict[0]
            img = trans_dict["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class PromisSet(RiskSet):
    """PROMIS dataset returning per-zone labels.

    The ``zone_level`` CSV column holds a Python-literal list; it is parsed
    with ``ast.literal_eval`` and returned as a float32 tensor.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        """Compute inverse-frequency per-sample weights keyed by ``patient_level``.

        BUG FIX: weights were previously indexed positionally by the raw
        label value, which misaligns if a class is missing from the CSV;
        they are now looked up by value (identical for contiguous labels).
        """
        counts = self.img_dict["patient_level"].value_counts()
        weight_by_value = (1.0 / counts).to_dict()
        self.sampler_weight = np.array(
            [weight_by_value[v] for v in self.img_dict["patient_level"].values]
        )

    def cal_weight(self):
        """Return per-class sample counts ordered by ``patient_level`` value."""
        return self.img_dict["patient_level"].value_counts().sort_index().values

    def __getitem__(self, idx):
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        # "zone_level" is stored as a literal list, e.g. "[0, 1, 0, ...]".
        zone_level = torch.tensor(literal_eval(path["zone_level"]), dtype=torch.float32)
        if self.transforms is not None:
            trans_dict = self.transforms({"image": img})
            if isinstance(trans_dict, list):
                # Some MONAI transforms return a list of result dicts.
                trans_dict = trans_dict[0]
            img = trans_dict["image"]
        return img, zone_level, torch.tensor(idx, dtype=torch.long)
|
| 144 |
+
|
| 145 |
+
class Promis3HistSet(RiskSet):
    """PROMIS PI-RADS-3 subset with histology-defined labels (CSV column ``def``)."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        """Compute inverse-frequency per-sample weights keyed by ``def``.

        BUG FIX: weights were previously indexed positionally by the raw
        label value, which misaligns if a class is missing from the CSV;
        they are now looked up by value (identical for contiguous labels).
        """
        counts = self.img_dict["def"].value_counts()
        weight_by_value = (1.0 / counts).to_dict()
        self.sampler_weight = np.array(
            [weight_by_value[v] for v in self.img_dict["def"].values]
        )

    def cal_weight(self):
        """Return per-class sample counts ordered by ``def`` value."""
        return self.img_dict["def"].value_counts().sort_index().values

    def __getitem__(self, idx):
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        label = torch.tensor(int(path["def"]), dtype=torch.long)
        if self.transforms is not None:
            trans_dict = self.transforms({"image": img})
            if isinstance(trans_dict, list):
                # Some MONAI transforms return a list of result dicts.
                trans_dict = trans_dict[0]
            img = trans_dict["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
|
| 171 |
+
|
| 172 |
+
class Promis3GGSet(RiskSet):
    """PROMIS PI-RADS-3 subset labeled by Gleason grade (CSV column ``gleason``)."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        """Compute inverse-frequency per-sample weights keyed by ``gleason``.

        BUG FIX: weights were previously indexed positionally by the raw
        label value, which misaligns if a class is missing from the CSV;
        they are now looked up by value (identical for contiguous labels).
        """
        counts = self.img_dict["gleason"].value_counts()
        weight_by_value = (1.0 / counts).to_dict()
        self.sampler_weight = np.array(
            [weight_by_value[v] for v in self.img_dict["gleason"].values]
        )

    def cal_weight(self):
        """Return per-class sample counts ordered by ``gleason`` value."""
        return self.img_dict["gleason"].value_counts().sort_index().values

    def __getitem__(self, idx):
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        label = torch.tensor(int(path["gleason"]), dtype=torch.long)
        if self.transforms is not None:
            trans_dict = self.transforms({"image": img})
            if isinstance(trans_dict, list):
                # Some MONAI transforms return a list of result dicts.
                trans_dict = trans_dict[0]
            img = trans_dict["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_transforms(args):
    """Build the (train, val, test) MONAI transform pipelines.

    Training applies normalization, a coarse center crop, random
    rotation/zoom, padding to 1.2x the final crop size, a random spatial
    crop, a flip, and several intensity augmentations.  Validation and
    test use the same deterministic normalize / center-crop / pad pipeline
    (two independent Compose instances).
    """
    crop_size = list(args.crop_spatial_size)
    padded_size = [round(s * 1.2) for s in crop_size]
    ten_degrees = 10 / 180 * np.pi  # +/-10 degree rotation range, in radians

    train_transforms = Compose([
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        CenterSpatialCropd(keys="image", roi_size=(80, 300, 300)),
        RandRotated(
            keys="image",
            prob=0.3,
            range_x=ten_degrees,
            range_y=ten_degrees,
            range_z=ten_degrees,
            keep_size=False,
            mode="bilinear",
        ),
        RandZoomd(
            keys="image",
            prob=0.3,
            min_zoom=[0.9, 0.9, 0.9],
            max_zoom=[1.1, 1.1, 1.1],
            mode="trilinear",
        ),
        # Pad beyond the target size so the random crop has room to move.
        SpatialPadd(keys="image", spatial_size=padded_size),
        RandSpatialCropd(keys="image", roi_size=crop_size, random_size=False),
        RandFlipd(keys="image", prob=0.5, spatial_axis=2),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
        RandBiasFieldd(keys="image", prob=0.2),
        RandGaussianSmoothd(keys="image", prob=1.0),
    ])

    def _eval_pipeline():
        # Deterministic pipeline; a fresh Compose per call so val and test
        # hold independent instances, as before.
        return Compose([
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(keys="image", roi_size=args.crop_spatial_size),
            SpatialPadd(keys="image", spatial_size=crop_size),
        ])

    return train_transforms, _eval_pipeline(), _eval_pipeline()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def build_Risk_loader(args):
    """Build dataloaders for the PI-RADS risk task.

    In demo mode only a test loader is returned; otherwise returns
    ``(train_loader, val_loader, test_loader)`` where the train loader
    balances classes with a WeightedRandomSampler.  Side effects: sets
    ``args.in_channels = 3`` and ``args.num_classes = 4``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def _eval_loader(dataset):
        # Deterministic loader used for the evaluation splits.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    if args.demo:
        demo_test = RiskSet(args, "demo/data/risk/test.csv", 'test', test_transforms)
        args.in_channels = 3
        args.num_classes = 4
        return _eval_loader(demo_test)

    train_csv = "spilt/risk/train_16.csv" if args.data20 else "spilt/risk/train.csv"
    train_set = RiskSet(args, train_csv, 'train', train_transforms)
    val_set = RiskSet(args, "spilt/risk/val.csv", 'val', val_transforms)
    test_set = RiskSet(args, "spilt/risk/test.csv", 'test', test_transforms)

    balanced_sampler = WeightedRandomSampler(
        weights=train_set.sampler_weight, num_samples=len(train_set), replacement=True
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balanced_sampler,
        num_workers=args.num_workers,
        drop_last=False,
        pin_memory=True,
    )
    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    args.in_channels = 3
    args.num_classes = 4
    return train_loader, val_loader, test_loader
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def build_Screening_loader(args):
    """Build dataloaders for the binary screening task.

    Without ``args.kfold``, fixed train/val/test CSV splits are used.
    With ``args.kfold`` set, the fold's train CSV is split 90/10 into
    train/val and the fold's test CSV is loaded.  The train loader
    balances classes with a WeightedRandomSampler.  Side effects: sets
    ``args.cls_account``, ``args.in_channels = 3`` and
    ``args.num_classes = 2``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    if args.kfold is None:
        if args.data20:
            train_set = ScreeningSet(
                args, "spilt/screening/train_20.csv", 'train', train_transforms
            )
        else:
            train_set = ScreeningSet(
                args, "spilt/screening/train.csv", 'train', train_transforms
            )
        val_set = ScreeningSet(args, "spilt/screening/val.csv", 'val', val_transforms)
        test_set = ScreeningSet(args, "spilt/screening/test.csv", 'test', test_transforms)
        args.cls_account = train_set.cal_weight() / len(train_set)
        sampler_weight = train_set.sampler_weight
    else:
        # BUG FIX: the fold CSVs were previously constructed without a
        # ``phase`` argument, so the transforms landed in the ``phase``
        # parameter and no transforms were applied at all.
        full_train = ScreeningSet(
            args, f"spilt/screening/train_{args.kfold}.csv", 'train', train_transforms
        )
        args.cls_account = full_train.cal_weight() / len(full_train)
        train_set, val_set = torch.utils.data.random_split(full_train, [0.9, 0.1])
        # NOTE(review): val_set is a Subset view of the training dataset and
        # therefore still applies the training transforms; assigning
        # ``val_set.transforms`` (as the old code did) has no effect.
        # Left unchanged pending a decision on duplicating the dataset.
        test_set = ScreeningSet(
            args, f"spilt/screening/test_{args.kfold}.csv", 'test', test_transforms
        )
        # BUG FIX: a Subset has no ``sampler_weight`` attribute; gather the
        # weights of the sampled indices from the underlying dataset.
        sampler_weight = [
            train_set.dataset.sampler_weight[i] for i in train_set.indices
        ]

    sampler = WeightedRandomSampler(
        weights=sampler_weight, num_samples=len(train_set), replacement=True
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    args.in_channels = 3
    args.num_classes = 2
    return train_loader, val_loader, test_loader
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# 4.0 453
|
| 372 |
+
# 3.0 206
|
| 373 |
+
# 5.0 195
|
| 374 |
+
# 2.0 174
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def build_Promis_loader(args):
    """Build dataloaders for the PROMIS zone-level task.

    Returns ``(train_loader, val_loader, test_loader)``.  Side effects:
    sets ``args.in_channels = 3`` and ``args.num_classes = 20``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    train_csv = (
        "spilt/promis567_hist/train_20.csv"
        if args.data20
        else "spilt/promis567_hist/train.csv"
    )
    train_set = PromisSet(args, train_csv, 'train', train_transforms)
    val_set = PromisSet(args, "spilt/promis567_hist/val.csv", 'val', val_transforms)
    test_set = PromisSet(args, "spilt/promis567_hist/test.csv", 'test', test_transforms)

    def _eval_loader(dataset):
        # Deterministic loader used for the evaluation splits.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    # No weighted sampler here: it was deliberately disabled upstream.
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    args.in_channels = 3
    args.num_classes = 20
    return train_loader, val_loader, test_loader
|
| 415 |
+
|
| 416 |
+
def build_Promis3_hist_loader(args):
    """Build dataloaders for the PROMIS PI-RADS-3 histology task (3 classes).

    Returns ``(train_loader, val_loader, test_loader)``.  Side effects:
    sets ``args.in_channels = 3`` and ``args.num_classes = 3``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    train_set = Promis3HistSet(
        args, "spilt/promis_pirads3_hist/train.csv", 'train', train_transforms
    )
    val_set = Promis3HistSet(
        args, "spilt/promis_pirads3_hist/val.csv", 'val', val_transforms
    )
    test_set = Promis3HistSet(
        args, "spilt/promis_pirads3_hist/test.csv", 'test', test_transforms
    )

    def _eval_loader(dataset):
        # Deterministic loader used for the evaluation splits.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    args.in_channels = 3
    args.num_classes = 3
    return train_loader, val_loader, test_loader
|
| 448 |
+
|
| 449 |
+
def build_Promis3_gg_loader(args):
    """Build dataloaders for the PROMIS PI-RADS-3 Gleason-grade task (5 classes).

    Returns ``(train_loader, val_loader, test_loader)``.  Side effects:
    sets ``args.in_channels = 3`` and ``args.num_classes = 5``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    train_set = Promis3GGSet(
        args, "spilt/promis_pirads3_gg/train.csv", 'train', train_transforms
    )
    val_set = Promis3GGSet(
        args, "spilt/promis_pirads3_gg/val.csv", 'val', val_transforms
    )
    test_set = Promis3GGSet(
        args, "spilt/promis_pirads3_gg/test.csv", 'test', test_transforms
    )

    def _eval_loader(dataset):
        # Deterministic loader used for the evaluation splits.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    args.in_channels = 3
    args.num_classes = 5
    return train_loader, val_loader, test_loader
|
dataset/dataset_seg.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from monai.transforms import (
|
| 3 |
+
Compose,
|
| 4 |
+
RandCropByPosNegLabeld,
|
| 5 |
+
CropForegroundd,
|
| 6 |
+
SpatialPadd,
|
| 7 |
+
ScaleIntensityRanged,
|
| 8 |
+
RandShiftIntensityd,
|
| 9 |
+
RandFlipd,
|
| 10 |
+
RandAffined,
|
| 11 |
+
RandZoomd,
|
| 12 |
+
RandRotated,
|
| 13 |
+
RandRotate90d,
|
| 14 |
+
RandGaussianNoised,
|
| 15 |
+
RandGaussianSmoothd,
|
| 16 |
+
NormalizeIntensityd,
|
| 17 |
+
RandBiasFieldd,
|
| 18 |
+
MapTransform,
|
| 19 |
+
RandScaleIntensityd,
|
| 20 |
+
RandSpatialCropd,
|
| 21 |
+
CenterSpatialCropd,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from torch.utils.data import DataLoader, Dataset
|
| 25 |
+
import torch
|
| 26 |
+
import numpy as np
|
| 27 |
+
import nibabel as nib
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import os
|
| 30 |
+
import pandas as pd
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class BaseVolumeDataset(Dataset):
    """Common base for CSV-indexed NIfTI segmentation datasets.

    Handles CSV loading, optional training-set truncation, and NIfTI
    volume reading; subclasses implement ``__getitem__``.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__()
        self.img_dict = pd.read_csv(image_paths)
        if phase == 'train':
            if args.data_num > 0:
                # Optionally crop the training set for data-efficiency runs.
                self.img_dict = self.img_dict.iloc[: args.data_num]
        print(f"Loading {phase} dataset with {len(self.img_dict)} samples")
        self.root = args.root
        self._set_dataset_stat()
        self.transforms = transforms

    def _set_dataset_stat(self):
        # Fixed acquisition statistics for this cohort.
        self.spacing = (0.5, 0.5, 1.0)
        self.spatial_index = [2, 1, 0]  # index used to convert to DHW
        self.target_class = 1

    def __len__(self):
        return len(self.img_dict)

    def read(self, path):
        """Load a NIfTI volume (path relative to ``self.root``) as a DHW float32 tensor."""
        vol = nib.load(os.path.join(self.root, path))
        vol = vol.get_fdata().astype(np.float32).transpose(self.spatial_index)
        return torch.from_numpy(vol)

    def __getitem__(self, idx):
        # BUG FIX: previously returned the NotImplemented constant, which a
        # DataLoader would silently yield as data; raise so misuse is loud.
        raise NotImplementedError("subclasses must implement __getitem__")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class UCLSet(BaseVolumeDataset):
    """Lesion segmentation dataset.

    Produces a 3-channel (t2w, dwi, adc) image and a binarized lesion mask.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        channels = [self.read(row[key]) for key in ("t2w", "dwi", "adc")]
        img = torch.stack(channels, 0)
        # Binarize: any positive voxel counts as lesion.
        seg = self.read(row["lesion"]).unsqueeze(0) > 0
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": seg})
            if isinstance(out, list):
                # Some MONAI transforms return a list of result dicts.
                out = out[0]
            img, seg = out["image"], out["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
|
| 84 |
+
|
| 85 |
+
# TODO: need to update; unfinished
|
| 86 |
+
"""
|
| 87 |
+
class UCL2DSet(BaseVolumeDataset):
|
| 88 |
+
def __init__(self, args, image_paths, phase, transforms=None):
|
| 89 |
+
super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx):
|
| 92 |
+
path = self.img_dict.iloc[idx]
|
| 93 |
+
t2w = self.read(path["t2w"])
|
| 94 |
+
dwi = self.read(path["dwi"])
|
| 95 |
+
adc = self.read(path["adc"])
|
| 96 |
+
|
| 97 |
+
seg = self.read(path["lesion"]).unsqueeze(0)
|
| 98 |
+
seg = seg > 0
|
| 99 |
+
|
| 100 |
+
seg_mask = seg.squeeze(0).numpy()
|
| 101 |
+
non_zero_slices = np.where(seg_mask.any(axis=1,2))[0]
|
| 102 |
+
if len(non_zero_slices) > 0:
|
| 103 |
+
sampled_slices = np.random.choice(non_zero_slices, min(N, len(non_zero_slices)), replace=False)
|
| 104 |
+
filtered_seg = np.zeros_like(seg_mask)
|
| 105 |
+
filtered_seg[sampled_slices] = seg_mask[sampled_slices]
|
| 106 |
+
else:
|
| 107 |
+
filtered_seg = seg_mask
|
| 108 |
+
|
| 109 |
+
img = torch.stack([t2w, dwi, adc], 0)
|
| 110 |
+
seg = torch.tensor(filtered_seg, dtype=torch.float32).unsqueeze(0)
|
| 111 |
+
if self.transforms is not None:
|
| 112 |
+
trans_dict = self.transforms({"image": img, "label": seg})
|
| 113 |
+
if type(trans_dict) == list:
|
| 114 |
+
trans_dict = trans_dict[0]
|
| 115 |
+
img, seg = trans_dict["image"], trans_dict["label"]
|
| 116 |
+
return img, seg, torch.tensor(idx, dtype=torch.long)
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
class AnatomySet(BaseVolumeDataset):
    """Anatomy segmentation from T2W only.

    The single T2W volume is padded with two zero channels so the image
    matches the 3-channel interface of the multi-modal datasets.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        t2w = self.read(row["t2w"])
        blank = torch.zeros_like(t2w)
        # Align the single-modality input to the 3-channel interface.
        img = torch.stack([t2w, blank, blank], 0)
        seg = self.read(row["mask"]).unsqueeze(0)
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": seg})
            if isinstance(out, list):
                # Some MONAI transforms return a list of result dicts.
                out = out[0]
            img, seg = out["image"], out["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class BpAnatomySet(BaseVolumeDataset):
    """Biparametric anatomy segmentation from T2W only.

    Behaviorally identical to ``AnatomySet``: the T2W volume is padded to
    three channels with zero planes and the mask comes from ``mask``.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        t2w = self.read(row["t2w"])
        padding = torch.zeros_like(t2w)
        img = torch.stack([t2w, padding, padding], 0)  # 3-channel interface
        seg = self.read(row["mask"]).unsqueeze(0)
        if self.transforms is not None:
            result = self.transforms({"image": img, "label": seg})
            if isinstance(result, list):
                # Some MONAI transforms return a list of result dicts.
                result = result[0]
            img, seg = result["image"], result["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
|
| 154 |
+
|
| 155 |
+
class PromisHist(BaseVolumeDataset):
    """Tri-modal (T2W/DWI/ADC) PROMIS dataset with a gland mask and zone labels."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        """Return ``(image, gland_mask, zone_labels)`` for row *idx*."""
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)

        zone_mask = self.read(path["gland"]).unsqueeze(0)

        # "zone_label" is stored as a whitespace-separated string of ints.
        zone_level = torch.tensor([int(v) for v in path["zone_label"].split()])

        if self.transforms is not None:
            trans_dict = self.transforms({"image": img, "label": zone_mask})
            # Some MONAI transforms return a list of samples; keep the first.
            # (isinstance instead of `type(x) == list` — idiomatic type check.)
            if isinstance(trans_dict, list):
                trans_dict = trans_dict[0]
            img, zone_mask = trans_dict["image"], trans_dict["label"]

        return img, zone_mask, zone_level
|
| 178 |
+
|
| 179 |
+
class PromisZone(BaseVolumeDataset):
    """Tri-modal (T2W/DWI/ADC) PROMIS dataset with a zonal mask and zone labels."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        """Return ``(image, zone_mask, zone_labels)`` for row *idx*."""
        path = self.img_dict.iloc[idx]
        t2w = self.read(path["t2w"])
        dwi = self.read(path["dwi"])
        adc = self.read(path["adc"])
        img = torch.stack([t2w, dwi, adc], 0)

        # NOTE(review): CSV column "zome_mask" looks like a typo for
        # "zone_mask" — the key is kept as-is because the split CSVs may carry
        # the same typo; confirm against the CSV headers before renaming.
        zone_mask = self.read(path["zome_mask"]).unsqueeze(0)

        # "zone_label" is stored as a whitespace-separated string of ints.
        zone_level = torch.tensor([int(v) for v in path["zone_label"].split()])

        if self.transforms is not None:
            trans_dict = self.transforms({"image": img, "label": zone_mask})
            # Some MONAI transforms return a list of samples; keep the first.
            # (isinstance instead of `type(x) == list` — idiomatic type check.)
            if isinstance(trans_dict, list):
                trans_dict = trans_dict[0]
            img, zone_mask = trans_dict["image"], trans_dict["label"]

        return img, zone_mask, zone_level
|
| 202 |
+
|
| 203 |
+
def get_transforms(args):
    """Build (train, val, test) MONAI dictionary transforms for image/label pairs.

    Train: per-channel normalisation, random rotation/zoom/flip, pad to 1.2x
    the crop size then random crop to ``args.crop_spatial_size``, plus
    intensity augmentations. Val/test: deterministic normalise + centre crop
    + pad only.
    """
    train_transforms = [
        # Per-channel z-score normalisation over non-zero voxels only.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandRotated(
            keys=["image", "label"],
            prob=0.3,
            range_x=30 / 180 * np.pi,
            keep_size=False,
            mode=["bilinear", "nearest"],  # nearest keeps label values discrete
        ),
        RandZoomd(
            keys=["image", "label"],
            prob=0.3,
            min_zoom=[1, 0.9, 0.9],  # no zoom along the slice axis
            max_zoom=[1, 1.1, 1.1],
            mode=["trilinear", "nearest"],
        ),
        # Pad to 1.2x the crop size so the random crop below always fits.
        SpatialPadd(
            keys=["image", "label"],
            spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
        ),
        # RandCropByPosNegLabeld(
        #     keys=["image", "label"],
        #     spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
        #     label_key="label",
        #     pos=2,
        #     neg=1,
        #     num_samples=1,
        # ),
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=args.crop_spatial_size,
            random_size=False,
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        # BinarizeLabeld(keys=["label"])
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
        RandBiasFieldd(keys="image", prob=0.2),
        # NOTE(review): prob=1.0 applies Gaussian smoothing to *every*
        # training sample — confirm this is intended rather than e.g. 0.2.
        RandGaussianSmoothd(keys="image", prob=1.0)
    ]

    train_transforms = Compose(train_transforms)
    # Validation: deterministic centre crop, then pad up to the crop size in
    # case the volume is smaller than the ROI.
    val_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(
                keys=["image", "label"], roi_size=args.crop_spatial_size
            ),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=[i for i in args.crop_spatial_size],
            ),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    # Test transforms are currently identical to the validation transforms.
    test_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(
                keys=["image", "label"], roi_size=args.crop_spatial_size
            ),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=[i for i in args.crop_spatial_size],
            ),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    return train_transforms, val_transforms, test_transforms
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def build_UCL_loader(args):
    """Build DataLoaders for the UCL lesion dataset.

    In demo mode only the test loader is returned; otherwise the
    (train, val, test) triple is returned. Also sets ``args.in_channels``,
    ``args.out_channels`` and ``args.num_classes`` as a side effect.
    """
    train_tf, val_tf, test_tf = get_transforms(args)

    def _loader(dataset, batch, shuffle, drop_last):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=drop_last,
        )

    if args.demo:
        test_set = UCLSet(args, "demo/data/UCL/test.csv", 'test', test_tf)
        test_loader = _loader(test_set, 1, False, False)
        args.in_channels = 3
        args.out_channels = 1
        args.num_classes = 1
        return test_loader

    train_csv = "spilt/UCL/train_16.csv" if args.data20 else "spilt/UCL/train.csv"
    train_set = UCLSet(args, train_csv, 'train', train_tf)
    val_set = UCLSet(args, "spilt/UCL/val.csv", 'val', val_tf)
    test_set = UCLSet(args, "spilt/UCL/test.csv", 'test', test_tf)

    train_loader = _loader(train_set, args.batch_size, True, True)
    val_loader = _loader(val_set, args.batch_size, False, False)
    test_loader = _loader(test_set, 1, False, False)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def build_Promis_loader(args):
    """Build (train, val, test) DataLoaders for the PROMIS 5/6/7 split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, test_tf = get_transforms(args)

    train_csv = ("spilt/promis567/train_20.csv" if args.data20
                 else "spilt/promis567/train.csv")
    train_set = UCLSet(args, train_csv, 'train', train_tf)
    val_set = UCLSet(args, "spilt/promis567/val.csv", 'val', val_tf)
    test_set = UCLSet(args, "spilt/promis567/test.csv", 'test', test_tf)

    def _loader(dataset, batch, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_loader = _loader(train_set, args.batch_size, True)
    val_loader = _loader(val_set, args.batch_size, False)
    test_loader = _loader(test_set, 1, False)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def build_Anatomy_loader(args):
    """Build (train, val, test) DataLoaders for the anatomy segmentation split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, _ = get_transforms(args)

    train_csv = ("spilt/anatomy/train_20.csv" if args.data20
                 else "spilt/anatomy/train.csv")
    train_set = AnatomySet(args, train_csv, 'train', train_tf)
    val_set = AnatomySet(args, "spilt/anatomy/val.csv", 'val', val_tf)
    # NOTE(review): the test split gets only intensity normalisation (no
    # centre crop / pad), unlike the other loaders which use the composed
    # test transforms — presumably for full-volume inference; confirm.
    test_set = AnatomySet(
        args,
        "spilt/anatomy/test.csv",
        'test',
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    )

    def _loader(dataset, batch, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_loader = _loader(train_set, args.batch_size, True)
    val_loader = _loader(val_set, args.batch_size, False)
    test_loader = _loader(test_set, 1, False)

    if args.prompt:
        # TODO: need to update; currently not in use
        args.in_channels = 3
    else:
        args.in_channels = 3
    args.out_channels = 9
    args.num_classes = 8
    return train_loader, val_loader, test_loader
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def build_BpAnatomy_loader(args):
    """Build (train, val, test) DataLoaders for the bi-parametric anatomy split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, _ = get_transforms(args)

    train_csv = ("spilt/anatomy/train_20.csv" if args.data20
                 else "spilt/anatomy/train.csv")
    train_set = BpAnatomySet(args, train_csv, 'train', train_tf)
    val_set = BpAnatomySet(args, "spilt/anatomy/val.csv", 'val', val_tf)
    # NOTE(review): as in build_Anatomy_loader, the test split gets only
    # intensity normalisation instead of the composed test transforms.
    test_set = BpAnatomySet(
        args,
        "spilt/anatomy/test.csv",
        'test',
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    )

    def _loader(dataset, batch, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            num_workers=4,
            drop_last=False,
        )

    train_loader = _loader(train_set, args.batch_size, True)
    val_loader = _loader(val_set, args.batch_size, False)
    test_loader = _loader(test_set, 1, False)

    args.in_channels = 3
    args.out_channels = 9
    args.num_classes = 8
    return train_loader, val_loader, test_loader
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def build_PromisHist_loader(args):
    """Build (train, val, test) DataLoaders for the PROMIS histology split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, test_tf = get_transforms(args)

    train_csv = ("spilt/promis567_hist/train_20.csv" if args.data20
                 else "spilt/promis567_hist/train.csv")
    train_set = PromisHist(args, train_csv, 'train', train_tf)
    val_set = PromisHist(args, "spilt/promis567_hist/val.csv", 'val', val_tf)
    test_set = PromisHist(args, "spilt/promis567_hist/test.csv", 'test', test_tf)

    def _loader(dataset, batch, shuffle, drop_last):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=drop_last,
        )

    train_loader = _loader(train_set, args.batch_size, True, False)
    # NOTE(review): drop_last=True on the *validation* loader discards the
    # final partial batch — confirm this is intended.
    val_loader = _loader(val_set, args.batch_size, False, True)
    test_loader = _loader(test_set, 1, False, False)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
|
| 485 |
+
|
| 486 |
+
def build_PromisZone_loader(args):
    """Build (train, val, test) DataLoaders for the PROMIS zonal split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, test_tf = get_transforms(args)

    train_set = PromisZone(args, "spilt/promis_zone/train.csv", 'train', train_tf)
    val_set = PromisZone(args, "spilt/promis_zone/val.csv", 'val', val_tf)
    test_set = PromisZone(args, "spilt/promis_zone/test.csv", 'test', test_tf)

    def _loader(dataset, batch, shuffle, drop_last):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=drop_last,
        )

    train_loader = _loader(train_set, args.batch_size, True, True)
    val_loader = _loader(val_set, args.batch_size, False, True)
    test_loader = _loader(test_set, 1, False, False)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def build_PromisPirads3_loader(args):
    """Build (train, val, test) DataLoaders for the PROMIS PI-RADS 3 split.

    Sets ``args.in_channels``/``out_channels``/``num_classes`` as a side effect.
    """
    train_tf, val_tf, test_tf = get_transforms(args)

    train_csv = ("spilt/promis_pirads3/train_15.csv" if args.data20
                 else "spilt/promis_pirads3/train.csv")
    train_set = UCLSet(args, train_csv, 'train', train_tf)
    val_set = UCLSet(args, "spilt/promis_pirads3/val.csv", 'val', val_tf)
    test_set = UCLSet(args, "spilt/promis_pirads3/test.csv", 'test', test_tf)

    def _loader(dataset, batch, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_loader = _loader(train_set, args.batch_size, True)
    val_loader = _loader(val_set, args.batch_size, False)
    test_loader = _loader(test_set, 1, False)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
|
demo_classfication.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import argparse
|
| 12 |
+
import datetime
|
| 13 |
+
import json
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Callable, List, Optional, Tuple
|
| 19 |
+
import torch
|
| 20 |
+
import torch.backends.cudnn as cudnn
|
| 21 |
+
from models.classifier import Classifier
|
| 22 |
+
from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
|
| 23 |
+
from dataset.dataset_cls import build_Risk_loader, build_Screening_loader, build_Promis_loader, build_Promis3_hist_loader
|
| 24 |
+
from engine.classification import test_risk
|
| 25 |
+
|
| 26 |
+
def tuple_type(strings):
    """Parse a "(a, b, c)"-style CLI string into a tuple of ints."""
    cleaned = strings.replace("(", "").replace(")", "")
    return tuple(int(token) for token in cleaned.split(","))
|
| 30 |
+
|
| 31 |
+
def get_args_parser():
    """Build the argument parser for the classification demo.

    Fix: the second ``parser.set_defaults(data20=False)`` (placed after
    ``--prompt``) was a copy-paste slip that never set a default for
    ``prompt``; it is now ``set_defaults(prompt=False)``, matching
    demo_segmentation.py.
    """
    # Local copy of the module-level tuple_type converter so the parser is
    # self-contained; behavior is identical.
    def _tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        return tuple(map(int, strings.split(",")))

    parser = argparse.ArgumentParser("segmentation", add_help=False)
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument("--root", default="./", type=str)
    parser.add_argument("--crop_spatial_size", default=(64, 256, 256), type=_tuple_type)

    # Model parameters
    parser.add_argument("--model", help="model name")
    parser.add_argument("--input_size", default=(64, 256, 256), type=_tuple_type, help="images input size")
    parser.add_argument("--train", default="scratch", choices=["fintune", "freeze", "scratch"], help="train method")
    parser.add_argument("--pretrain", default=None, type=str)
    parser.add_argument("--tolerance", default=5, type=int)
    parser.add_argument("--spacing", default=(1.0, 0.5, 0.5), type=tuple)

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="weight decay (default: 1e-5)")
    parser.add_argument("--lr", default=0.1, type=float, metavar="LR", help="learning rate (absolute lr)")
    parser.add_argument("--min_lr", type=float, default=0.0, metavar="LR",
                        help="lower lr bound for cyclic schedulers that hit 0")
    parser.add_argument("--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR")

    # Dataset / output parameters
    parser.add_argument("--output_dir", default="./outputcls", help="path where to save, empty for no saving")
    parser.add_argument("--file_name", default="")
    parser.add_argument("--ckpt_dir", default="./outputcls")
    parser.add_argument("--log_dir", default="./outputcls", help="path where to tensorboard log")
    parser.add_argument("--dataset", default="UCL", help="dataset name")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument("--pin_mem", action="store_true",
                        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.")
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    parser.add_argument("--data20", action="store_true", help="Use 20 training data")
    parser.set_defaults(data20=False)
    parser.add_argument("--data_num", default=0, type=int, help="number of train data")
    parser.add_argument("--save_fig", action="store_true")
    parser.set_defaults(save_fig=False)
    parser.add_argument("--prompt", action="store_true", help="Use visual prompt tuning")
    # Was `set_defaults(data20=False)` — copy-paste bug; prompt's default
    # belongs here (store_true already defaults to False, but keep explicit
    # for parity with demo_segmentation.py).
    parser.set_defaults(prompt=False)

    # Distributed-training parameters
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")
    parser.add_argument("--kfold", type=int, default=None)
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # (including "False") parses as True. Kept as-is for CLI compatibility.
    parser.add_argument("--demo", type=bool, default=True, help="Run in demo mode")
    return parser
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def main(args):
    """Run classification inference on the test split and save logits/labels.

    Builds the dataset named by ``args.dataset``, restores the checkpoint at
    ``args.ckpt_dir``, runs the model over the test loader, and writes an
    ``.npz`` file with raw logits and ground-truth labels to ``args.output_dir``.
    """
    # Fix: the original hard-coded `device = "cuda"` for the model while the
    # inference loop used `args.device` for the data — a mismatch whenever
    # args.device != "cuda". Use args.device consistently.
    device = args.device

    # fix the seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.benchmark = True

    if args.dataset == "risk":
        data_loader_test = build_Risk_loader(args)
    else:
        # Fix: error message said "unknown schedule sampler" (copy-paste).
        raise NotImplementedError(f"unknown dataset: {args.dataset}")
    print(f"Loaded dataset: {args.dataset}, test set size: {len(data_loader_test.dataset)}")

    if args.model == "profound_conv":
        convnext = convnextv2_tiny(in_chans=3)
        model = Classifier(convnext, args.num_classes)
    else:
        raise NotImplementedError(f"unknown model: {args.model}")

    args.output_dir = os.path.join(args.output_dir, args.dataset)
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    model.load_state_dict(
        torch.load(args.ckpt_dir, map_location="cpu", weights_only=False)["model"]
    )
    print(f"Loaded model from {args.ckpt_dir}")
    model.to(device)

    logits, gts = [], []
    model.eval()
    with torch.no_grad():
        for img, gt, pid in data_loader_test:
            img, gt = img.to(device), gt.to(device)
            logits.append(model(img))
            gts.append(gt)

    logits = torch.cat(logits, 0).squeeze().cpu().numpy()
    gts = torch.cat(gts, 0).squeeze().cpu().numpy()
    print(f"test results: logits {logits}, gts {gts}")
    np.savez(os.path.join(args.output_dir, f"{args.file_name}.npz"), logits=logits, gts=gts)
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
    # Parse CLI arguments and launch inference.
    cli_args = get_args_parser().parse_args()
    main(cli_args)
|
demo_segmentation.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import argparse
|
| 12 |
+
import datetime
|
| 13 |
+
import json
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Callable, List, Optional, Tuple
|
| 19 |
+
import torch
|
| 20 |
+
import torch.backends.cudnn as cudnn
|
| 21 |
+
from dataset.dataset_seg import (
|
| 22 |
+
build_UCL_loader,
|
| 23 |
+
build_Anatomy_loader,
|
| 24 |
+
build_BpAnatomy_loader,
|
| 25 |
+
build_Promis_loader,
|
| 26 |
+
build_PromisPirads3_loader
|
| 27 |
+
)
|
| 28 |
+
import monai
|
| 29 |
+
from monai.inferers import sliding_window_inference
|
| 30 |
+
from monai.metrics import compute_dice
|
| 31 |
+
import SimpleITK as sitk
|
| 32 |
+
from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
|
| 33 |
+
from models.convnext_unter import ConvnextUNETR
|
| 34 |
+
from models.upernet_module import UperNet
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def tuple_type(strings):
    """Convert a comma-separated (optionally parenthesised) string to an int tuple."""
    stripped = strings.replace("(", "").replace(")", "")
    values = [int(part) for part in stripped.split(",")]
    return tuple(values)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_args_parser():
    """Build the command-line argument parser for the segmentation demo."""
    parser = argparse.ArgumentParser("segmentation", add_help=False)
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument("--root", default="./", type=str)
    parser.add_argument("--crop_spatial_size", default=(64, 256, 256), type=tuple_type)

    # Model parameters
    parser.add_argument("--model", help="model name")
    parser.add_argument("--input_size", default=(64, 256, 256), type=tuple_type, help="images input size")
    parser.add_argument("--train", default="scratch", choices=["fintune", "freeze", "scratch"], help="train method")
    parser.add_argument("--pretrain", default=None, type=str)
    parser.add_argument("--tolerance", default=5, type=int)
    parser.add_argument("--spacing", default=(1.0, 0.5, 0.5), type=tuple)

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="weight decay (default: 1e-5)")
    parser.add_argument("--lr", default=0.1, type=float, metavar="LR", help="learning rate (absolute lr)")
    parser.add_argument("--min_lr", type=float, default=0.0, metavar="LR",
                        help="lower lr bound for cyclic schedulers that hit 0")
    parser.add_argument("--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR")

    # Dataset / output parameters
    parser.add_argument("--output_dir", default="./outputseg", help="path where to save, empty for no saving")
    parser.add_argument("--file_name", default="")
    parser.add_argument("--ckpt_dir", default="./outputseg")
    parser.add_argument("--log_dir", default="./outputseg", help="path where to tensorboard log")
    parser.add_argument("--dataset", default="UCL", help="dataset name")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument("--pin_mem", action="store_true",
                        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.")
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    parser.add_argument("--data20", action="store_true", help="Use 20 training data")
    parser.set_defaults(data20=False)
    parser.add_argument("--data_num", default=0, type=int, help="number of train data")
    parser.add_argument("--save_fig", action="store_true")
    parser.set_defaults(save_fig=False)
    parser.add_argument("--prompt", action="store_true", help="Use visual prompt tuning")
    parser.set_defaults(prompt=False)

    # Distributed-training parameters
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")
    parser.add_argument("--demo", type=bool, default=True, help="Run in demo mode")
    return parser
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def main(args):
    """Run segmentation inference on the test split, save optional NIfTI
    figures, and report per-case and mean Dice.

    Relies on module-level imports (torch, np, cudnn, sitk, os, model
    builders, `build_UCL_loader`, `compute_dice`, `sliding_window_inference`)
    that live outside this view.
    """
    device = "cuda"
    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Dataset selection: only the UCL loader is wired up here; it runs
    # whole-volume inference (sliding window disabled below).
    if args.dataset == "UCL":
        data_loader_test = build_UCL_loader(args)
        args.sliding_window = False

    else:
        raise NotImplementedError(f"unknown schedule sampler: {args.dataset}")
    print(f"Loaded dataset: {args.dataset}, test set size: {len(data_loader_test)}")

    # Model selection: UperNet head or UNETR-style decoder, both over a
    # ConvNeXtV2-tiny encoder with 3 input channels.
    if args.model == "profound_conv":
        convnext = convnextv2_tiny(in_chans=3)
        model = UperNet(
            encoder=convnext,
            in_channels=[96, 192, 384, 768],  # ConvNeXtV2-tiny stage widths
            out_channels=args.out_channels,
        )
        model = model.to(device)

    elif args.model == "profound_conv_unetr3d":
        convnext = convnextv2_tiny(in_chans=3)

        model = ConvnextUNETR(
            in_channels=3, out_channels=1, convnext=convnext, feature_size=32
        )
        model = model.to(device)

    else:
        raise NotImplementedError(f"unknown model: {args.model}")

    # All outputs (figures, dice array) go under output_dir/<dataset>.
    args.output_dir = os.path.join(args.output_dir, args.dataset)
    os.makedirs(args.output_dir, exist_ok=True)

    # weights_only=False: checkpoint is a dict with a "model" state-dict entry.
    model.load_state_dict(torch.load(args.ckpt_dir, weights_only=False)["model"])
    print(f"Loaded model: {args.ckpt_dir}")

    dice_list = []
    model.eval()
    with torch.no_grad():
        for idx, (img, gt, pid) in enumerate(data_loader_test):
            img, gt = img.to(args.device), gt.to(args.device)
            if args.sliding_window:
                # Patch-wise inference with 50% overlap for large volumes.
                pred = sliding_window_inference(
                    img, args.crop_spatial_size, 4, model, overlap=0.5
                )
            else:
                pred = model(img)

            # Binarize (single-channel) or argmax (multi-class) the logits.
            if args.num_classes == 1:
                pred = torch.sigmoid(pred) > 0.5
                pred = pred.int()
            else:
                pred = torch.softmax(pred, dim=1)
                pred = torch.argmax(pred, dim=1, keepdim=True)

            dice = compute_dice(pred, gt)  # compute_dice(pred, gt, False,num_classes=9)
            print(pid, dice.item())
            # NaN dice (e.g. empty ground truth) is excluded from the mean.
            if not torch.isnan(dice):
                dice_list.append(dice)
            # dice = int(dice.mean()*10000)
            img = img.squeeze().cpu().numpy()
            pred = pred.squeeze().cpu().numpy()
            gt = gt.squeeze().cpu().numpy()
            if args.save_fig:
                # Only the first 20 cases are exported to keep output small.
                # NOTE(review): assumes img channel 0 is T2W and channel 1 is
                # DWI — confirm against the dataset loader.
                if idx < 20:
                    # print(img.shape,pred.shape, gt.shape )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(img[0]),
                        os.path.join(args.output_dir, f"{idx}_t2w.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(img[1]),
                        os.path.join(args.output_dir, f"{idx}_dwi.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(pred),
                        os.path.join(args.output_dir, f"{idx}_pred.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(gt),
                        os.path.join(args.output_dir, f"{idx}_gt.nii.gz"),
                    )
    # Persist all per-case dice values, then print the mean.
    dice_list = torch.stack(dice_list, 0)
    np.save(
        os.path.join(args.output_dir, f"{args.file_name}.npy"),
        dice_list.cpu().numpy(),
    )
    print("dice mean: ", dice_list.mean().item())
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
    # Build the CLI parser, parse the command line, and run the demo.
    parser = get_args_parser()
    main(parser.parse_args())
|
engine/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# ProFound engine package
|
engine/classification.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
import torch
|
| 14 |
+
import os
|
| 15 |
+
import util.misc as misc
|
| 16 |
+
import util.lr_sched as lr_sched
|
| 17 |
+
import numpy as np
|
| 18 |
+
from util.metric import accuracy, ConfusionMatrix, kappa
|
| 19 |
+
from sklearn.metrics import (
|
| 20 |
+
roc_auc_score,
|
| 21 |
+
top_k_accuracy_score,
|
| 22 |
+
f1_score,
|
| 23 |
+
confusion_matrix,
|
| 24 |
+
)
|
| 25 |
+
from torchmetrics.classification import (
|
| 26 |
+
BinarySpecificityAtSensitivity,
|
| 27 |
+
BinarySensitivityAtSpecificity,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
import pdb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,  # unused here; scaler path is commented out below
    log_writer=None,
    args=None,
):
    """Train the classifier for one epoch with a per-iteration LR schedule.

    Returns a dict of globally-averaged meter values (loss, lr).
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # Loss selection: 'promis' always uses multi-label BCE; otherwise
    # cross-entropy for multi-class heads, BCE for a single logit.
    if args.dataset == "promis":
        loss_cal = torch.nn.BCEWithLogitsLoss()
    else:
        if args.num_classes > 1:
            loss_cal = torch.nn.CrossEntropyLoss()
        else:
            loss_cal = torch.nn.BCEWithLogitsLoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    last_norm = 0.0  # kept for the disabled loss-scaler path; printed in diagnostics
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img)
        # print("logit: ", logit.shape, "gt: ", gt.shape, "image: ", img.shape)
        loss = loss_cal(logit, gt)
        loss_value = loss.item()

        # Abort on NaN/Inf loss; dump diagnostics about which tensor
        # (logits or input) went bad and for which sample indices.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def validation(model, data_loader_val, device, epoch, args):
    """Run one evaluation pass and return the mean validation loss.

    Uses the same loss-selection rule as training: BCE-with-logits for the
    'promis' dataset or single-logit heads, cross-entropy otherwise.
    """
    model.eval()
    # Equivalent to the training-side branch: promis -> BCE, else CE for
    # multi-class, BCE for a single logit. `or` short-circuits so
    # args.num_classes is only read when the dataset is not 'promis'.
    if args.dataset == "promis" or not args.num_classes > 1:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    batch_losses = []
    with torch.no_grad():
        for step, (image, target, _) in enumerate(data_loader_val):
            image, target = image.to(device), target.to(device)
            output = model(image)
            batch_loss = criterion(output, target).detach().cpu().numpy()
            batch_losses.append(batch_loss)
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, step, len(data_loader_val)
            )
            print(progress + " loss:" + str(batch_loss.flatten()[0]))
        avg_loss = np.mean(batch_losses)
        print("Averaged stats:", str(avg_loss))
    return avg_loss
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def test(model, test_loader, args):
    """Evaluate the best checkpoint on the test set.

    Loads `<output_dir>/best.pth.tar`, collects raw logits and labels over
    the whole loader, then dispatches to the dataset-specific metric
    routine (`test_risk` / `test_screening` / `test_promis`).

    Returns whatever the dispatched routine returns (a stats dict or None).
    Raises NotImplementedError for unknown datasets.
    """
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # BUG FIX: `weights_only` is a torch.load keyword; the original passed it
    # to Module.load_state_dict, which has no such parameter (TypeError).
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])
    model.eval()
    prob, gts = [], []
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            logit = model(img)
            prob.append(logit)
            gts.append(gt)

    if args.dataset == "risk":
        return test_risk(prob, gts)
    elif args.dataset == "screening":
        return test_screening(prob, gts)
    elif args.dataset == "promis":
        return test_promis(prob, gts)
    else:
        raise NotImplementedError(f"unknown dataset: {args.dataset}")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def test_risk(prob, gts):
    """Compute 4-class and binarized PI-RADS metrics from collected logits.

    Args:
        prob: list of raw logit tensors, one per batch.
        gts: list of matching integer label tensors (classes 0..3, mapping
            to PI-RADS 2..5 per the comment below).
    Returns:
        dict of formatted metric strings keyed by metric name.
    """
    log_stats = {}
    prob = torch.cat(prob, 0)
    prob = torch.softmax(prob, dim=-1).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy()

    # 4-class metrics: top-1 accuracy, quadratic-weighted kappa,
    # one-vs-rest AUC, macro F1 (all scaled to percent except kappa).
    score_acc = top_k_accuracy_score(gts, prob, k=1) * 100
    score_qwk = kappa(gts, np.argmax(prob, 1))
    score_auc = roc_auc_score(gts, prob, multi_class="ovr") * 100
    score_f1 = f1_score(gts, np.argmax(prob, 1), average="macro") * 100

    print("score")
    print(f"acc\t auc \t qwk \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_qwk:.4f} \t {score_f1:.2f}")
    log_stats["4-class_acc"] = f"{score_acc:.2f}"
    log_stats["4-class_auc"] = f"{score_auc:.2f}"
    log_stats["4-class_qwk"] = f"{score_qwk:.4f}"
    log_stats["4-class_f1"] = f"{score_f1:.2f}"

    # 2 3 4 5 four classes 0 1 2 3

    # Binarization 1: PI-RADS >= 3 (classes 1..3 positive). Positive score is
    # the summed probability mass over the positive classes.
    sig_prob = np.sum(prob[:, 1:], -1)
    sig_gts = (gts > 0).astype(int)
    sig_acc = top_k_accuracy_score(sig_gts, sig_prob, k=1) * 100
    sig_auc = roc_auc_score(sig_gts, sig_prob) * 100
    sig_f1 = f1_score(sig_gts, sig_prob > 0.5) * 100

    print("Pirads >=3")
    print(f"auc \t f1 ")
    print(f"{sig_auc:.2f} \t {sig_f1:.2f}")

    log_stats["leq3_auc"]=f"{sig_auc:.2f}"
    log_stats["leq3_f1"]=f"{sig_f1:.2f}"

    # Operating-point metrics at 80% / 90% minimum sensitivity/specificity.
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_specificity = sig_specificity * 100
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_sensitivity = sig_sensitivity* 100

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
        log_stats[f"leq3_specificity_at_{i}"]=f"{sig_specificity:.2f}"
        log_stats[f"leq3_sensitivity_at_{i}"]=f"{sig_sensitivity:.2f}"

    # Binarization 2: PI-RADS >= 4 (classes 2..3 positive); same metrics.
    sig_prob = np.sum(prob[:, 2:], -1)
    sig_gts = (gts > 1).astype(int)
    sig_acc = top_k_accuracy_score(sig_gts, sig_prob, k=1) * 100
    sig_auc = roc_auc_score(sig_gts, sig_prob) * 100
    sig_f1 = f1_score(sig_gts, sig_prob > 0.5) * 100

    print("Pirads >=4")
    print(f"auc \t f1 ")
    print(f"{sig_auc:.2f} \t {sig_f1:.2f}")

    log_stats["leq4_auc"]=f"{sig_auc:.2f}"
    log_stats["leq4_f1"]=f"{sig_f1:.2f}"

    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_specificity = sig_specificity * 100
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_sensitivity = sig_sensitivity* 100

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
        log_stats[f"leq4_specificity_at_{i}"]=f"{sig_specificity:.2f}"
        log_stats[f"leq4_sensitivity_at_{i}"]=f"{sig_sensitivity:.2f}"
    return log_stats
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def test_screening(prob, gts):
    """Compute binary screening metrics and dump raw results to result.npz.

    Args:
        prob: list of raw logit tensors, one per batch.
        gts: list of matching binary label tensors.
    Returns:
        None (metrics are only printed / saved, not returned).
    """
    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).long().cpu().numpy()

    # Persist raw scores/labels for offline analysis (written to CWD).
    np.savez("result.npz", gts=gts, prob=prob)
    score_acc = top_k_accuracy_score(gts, prob, k=1) * 100
    score_auc = roc_auc_score(gts, prob) * 100
    # NOTE(review): np.argmax(prob, 1) requires prob to be 2-D (N, C); for a
    # single sigmoid column it would always predict class 0 — confirm the
    # model's output shape for the screening head.
    score_f1 = f1_score(gts, np.argmax(prob, 1)) * 100

    print(f"acc\t auc \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_f1:.2f}")

    # Operating-point metrics at 80% / 90% minimum sensitivity/specificity.
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(torch.from_numpy(prob), torch.from_numpy(gts))
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(torch.from_numpy(prob), torch.from_numpy(gts))

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    # No stats dict is populated for this dataset (unlike test_risk).
    log_stats = None
    return log_stats
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def test_promis(prob, gts):
    """Report zone- and patient-level AUC and operating-point metrics
    for the PROMIS multi-label task.

    Args:
        prob: list of raw logit tensors, one per batch, shape (B, n_zones).
        gts: list of matching binary ground-truth tensors.
    Returns:
        dict of logged stats (currently left empty; metrics are printed).
    """
    log_stats = {}

    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy().astype(int)

    # Zone level: every (patient, zone) pair is an independent sample.
    zone_prob = prob.reshape(-1)
    zone_gt = gts.reshape(-1)
    print(f"zone level performance")

    # BUG FIX: sklearn's roc_auc_score signature is (y_true, y_score); the
    # original call passed (score, truth) and computed AUC against the
    # wrong argument order.
    auc = roc_auc_score(zone_gt, zone_prob) * 100
    print(f"AUC: {auc:.2f}")
    # Operating-point metrics at 80% / 90% minimum sensitivity/specificity.
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(zone_prob), torch.from_numpy(zone_gt)
        )
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(zone_prob), torch.from_numpy(zone_gt)
        )

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    # Patient level: a patient is positive if any zone is positive, so the
    # patient score is the maximum zone probability.
    patient_prob = prob.max(-1)
    patient_gt = gts.max(-1)

    print(f"patient level performance")

    # BUG FIX: same (y_true, y_score) ordering as above.
    auc = roc_auc_score(patient_gt, patient_prob) * 100
    print(f"AUC: {auc:.2f}")
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(patient_prob), torch.from_numpy(patient_gt)
        )
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(patient_prob), torch.from_numpy(patient_gt)
        )

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    return log_stats
|
engine/location.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Iterable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
import util.misc as misc
|
| 8 |
+
import util.lr_sched as lr_sched
|
| 9 |
+
from monai.losses import DiceCELoss, DiceLoss
|
| 10 |
+
import numpy as np
|
| 11 |
+
from monai.metrics import DiceHelper
|
| 12 |
+
import surface_distance
|
| 13 |
+
from surface_distance import metrics
|
| 14 |
+
from util.meter import DiceMeter, HausdorffMeter, SurfaceDistanceMeter
|
| 15 |
+
|
| 16 |
+
# from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
|
| 17 |
+
from monai.inferers import sliding_window_inference
|
| 18 |
+
from torchmetrics.classification import (
|
| 19 |
+
BinarySpecificityAtSensitivity,
|
| 20 |
+
BinarySensitivityAtSpecificity,
|
| 21 |
+
)
|
| 22 |
+
# from monai.metrics import DiceMetric
|
| 23 |
+
# from monai.transforms import Activations
|
| 24 |
+
import pdb
|
| 25 |
+
from sklearn.metrics import (
|
| 26 |
+
roc_auc_score,
|
| 27 |
+
top_k_accuracy_score,
|
| 28 |
+
f1_score,
|
| 29 |
+
confusion_matrix,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,  # unused in this variant; kept for signature parity
    log_writer=None,
    args=None,
):
    """Train the zone-localization model for one epoch.

    The model takes (image, zone_mask) and may return either a single logit
    tensor or a [main, auxiliary] list for deep supervision (aux weighted 0.4).
    Returns a dict of globally-averaged meter values (loss, lr).
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # Multi-label zone prediction: BCE-with-logits throughout.
    loss_cal = torch.nn.BCEWithLogitsLoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    last_norm = 0.0  # leftover from a loss-scaler path; only printed in diagnostics
    for data_iter_step, (img, zone_mask, gt) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):


        # we use a per iteration (instead of per epoch) lr scheduler
        img, zone_mask, gt = img.to(device, non_blocking=True), zone_mask.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        gt = gt.float()  # BCEWithLogitsLoss requires float targets
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img, zone_mask)
        # Deep supervision: list output = [main head, auxiliary head].
        if isinstance(logit, list):
            loss = loss_cal(logit[0], gt) + 0.4*loss_cal(logit[1], gt)
        else:
            loss = loss_cal(logit, gt)

        loss_value = loss.item()

        # Abort on NaN/Inf loss with diagnostics.
        # NOTE(review): torch.isnan(logit) would raise if logit is a list
        # (deep-supervision branch) — confirm this path is only hit with a
        # tensor output.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def validation(model, data_loader_val, device, epoch, args):
    """Compute the mean BCE-with-logits validation loss over one pass.

    The model is called as model(image, zone_mask); targets are cast to
    float as BCEWithLogitsLoss requires.
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()

    batch_losses = []
    with torch.no_grad():
        for step, (image, zones, target) in enumerate(data_loader_val):
            image = image.to(device, non_blocking=True)
            zones = zones.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            target = target.float()
            output = model(image, zones)
            batch_loss = criterion(output, target).detach().cpu().numpy()
            batch_losses.append(batch_loss)
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, step, len(data_loader_val)
            )
            print(progress + " loss:" + str(batch_loss.flatten()[0]))
        avg_loss = np.mean(batch_losses)
        print("Averaged stats:", str(avg_loss))
    return avg_loss
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def test(model, test_loader, args, sliding_window=False):
    """Evaluate the best zone-localization checkpoint on the test set.

    Loads `<output_dir>/best.pth.tar`, collects sigmoid zone probabilities
    over the loader, and reports AUC plus sensitivity/specificity operating
    points at both zone level and patient level (patient score = max zone).

    Args:
        sliding_window: accepted for interface compatibility; unused here.
    Returns:
        dict of formatted metric strings, with distinct `zone_*` and
        `patient_*` keys.
    """
    model.eval()
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # BUG FIX: `weights_only` is a torch.load keyword; the original passed it
    # to Module.load_state_dict, which has no such parameter (TypeError).
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])

    log_stats = {}
    with torch.no_grad():
        prob, gts = [], []

        for idx, (img, zone_mask, gt) in enumerate(test_loader):
            img = img.to(args.device, non_blocking=True)
            zone_mask = zone_mask.to(args.device, non_blocking=True)
            gt = gt.to(args.device, non_blocking=True)

            logit = model(img, zone_mask)
            prob.append(logit)
            gts.append(gt)

        prob = torch.cat(prob, 0)
        prob = torch.sigmoid(prob).cpu()
        gts = torch.cat(gts, 0).cpu()

        print("- Zone level: ")
        zone_prob = prob.reshape(-1, prob.shape[-1])
        zone_gt = gts.reshape(-1, prob.shape[-1])
        # BUG FIX: roc_auc_score expects (y_true, y_score); the original
        # passed (score, truth). Also report the AUC, which was previously
        # computed but never printed or logged.
        zone_auc = roc_auc_score(zone_gt, zone_prob) * 100
        print(f"AUC: {zone_auc:.2f}")
        log_stats["zone_auc"] = f"{zone_auc:.2f}"

        for i in [0.8, 0.9]:
            sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
            sig_specificity, _ = sig_spec(zone_prob, zone_gt)
            sig_specificity = sig_specificity * 100

            sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
            sig_sensitivity, _ = sig_sens(zone_prob, zone_gt)
            sig_sensitivity = sig_sensitivity * 100

            print(f"min: {i}")
            print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
            print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
            # BUG FIX: zone- and patient-level stats previously shared the
            # same keys, so patient values silently overwrote zone values.
            log_stats[f"zone_specificity_at_{i}"] = f"{sig_specificity:.2f}"
            log_stats[f"zone_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"

        print("- Patient level: ")
        # A patient is positive if any zone is positive.
        p_prob = prob.max(1).values
        p_gt = gts.max(1).values

        p_auc = roc_auc_score(p_gt, p_prob) * 100
        print(f"AUC: {p_auc:.2f}")
        log_stats["patient_auc"] = f"{p_auc:.2f}"

        for i in [0.8, 0.9]:
            sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
            sig_specificity, _ = sig_spec(p_prob, p_gt)
            sig_specificity = sig_specificity * 100

            sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
            sig_sensitivity, _ = sig_sens(p_prob, p_gt)
            sig_sensitivity = sig_sensitivity * 100

            print(f"min: {i}")
            print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
            print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
            log_stats[f"patient_specificity_at_{i}"] = f"{sig_specificity:.2f}"
            log_stats[f"patient_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"

    return log_stats
|
engine/pretrain.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
from typing import Iterable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import util.misc as misc
|
| 18 |
+
import util.lr_sched as lr_sched
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,  # unused; the scaler call below is commented out
    log_writer=None,
    args=None,
):
    """Run one epoch of masked-autoencoder pretraining.

    The model returns (loss, pred, mask); only the loss is used here.
    Returns a dict of globally-averaged meter values (loss, lr).
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):

        # we use a per iteration (instead of per epoch) lr scheduler
        samples = samples.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )

        # with torch.cuda.amp.autocast():
        # Model computes its own reconstruction loss over masked patches.
        loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        # Abort immediately on NaN/Inf loss.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # loss_scaler(loss, optimizer, parameters=model.parameters(),clip_grad=1.0)
        # optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
engine/pretrain_amp.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
from typing import Iterable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import util.misc as misc
|
| 18 |
+
import util.lr_sched as lr_sched
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Run one epoch of MAE-style pretraining under AMP autocast.

    The forward pass runs inside ``torch.cuda.amp.autocast()``; backward,
    (optional) gradient clipping and the optimizer step are delegated to
    ``loss_scaler``. The learning rate is adjusted per iteration.

    Returns:
        dict mapping each tracked meter name to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):

        # we use a per iteration (instead of per epoch) lr scheduler
        samples = samples.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )

        # Mixed-precision forward; model returns (loss, prediction, mask).
        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        # Abort immediately on NaN/Inf loss instead of continuing a
        # diverged run.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # loss_scaler scales the loss, backprops, clips grads and steps
        # the optimizer (NativeScaler-style callable).
        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=1.0)
        optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
engine/regression.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
import torch
|
| 14 |
+
import os
|
| 15 |
+
import util.misc as misc
|
| 16 |
+
import util.lr_sched as lr_sched
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Run one epoch of regression training with an MSE objective.

    Uses plain FP32 backward/step; the ``loss_scaler`` parameter is kept for
    signature parity with the AMP variant but is not used here (the scaler
    call is commented out below).

    Returns:
        dict mapping each tracked meter name to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    loss_cal = torch.nn.MSELoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    # last_norm only becomes meaningful when the loss_scaler path (commented
    # out below) is enabled; it is printed in the divergence diagnostics.
    last_norm = 0.0
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img)
        loss = loss_cal(logit, gt)
        loss_value = loss.item()
        # Divergence diagnostics: report whether NaN/Inf came from the model
        # output or from the input batch (identified by dataidx), then abort.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def validation(model, data_loader_val, device, epoch, args):
    """Evaluate the model on the validation loader and return the mean MSE."""
    model.eval()
    criterion = torch.nn.MSELoss()
    per_batch = []
    n_batches = len(data_loader_val)
    with torch.no_grad():
        for batch_idx, (img, gt, _) in enumerate(data_loader_val):
            img = img.to(device)
            gt = gt.to(device)
            batch_loss = criterion(model(img), gt).detach().cpu().numpy()
            per_batch.append(batch_loss)
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, batch_idx, n_batches
            )
            print(progress + " loss:" + str(per_batch[-1].flatten()[0]))
        avg_loss = np.mean(per_batch)
    print("Averaged stats:", str(avg_loss))
    return avg_loss
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def test(model, test_loader, args):
    """Evaluate the best checkpoint on the test set.

    Loads ``best.pth.tar`` from ``args.output_dir``, runs inference,
    de-normalizes predictions and targets, and reports MSE / MAE.

    Returns:
        dict with keys ``"MSE"`` and ``"MAE"``.
    """
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # BUGFIX: ``weights_only`` is a ``torch.load`` keyword; passing it to
    # ``Module.load_state_dict`` raises TypeError at runtime.
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])

    model.eval()
    log_stats = {}
    pred, gts = [], []

    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            pred.append(model(img))
            gts.append(gt)
        pred = torch.cat(pred, 0)
        gts = torch.cat(gts, 0)
        # Undo dataset normalization before computing metrics.
        # NOTE(review): the constants 500000 / 70000 should mirror the
        # dataset's normalization — confirm against dataset code.
        pred = pred * 500000 + 70000
        gts = gts * 500000 + 70000
        mse = torch.nn.MSELoss()(pred, gts)
        mae = torch.nn.L1Loss()(pred, gts)
        print("MSE", mse.item(), "MAE", mae.item())
        log_stats = {"MSE": mse.item(), "MAE": mae.item()}
    return log_stats
|
engine/segment.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
from typing import Iterable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import os
|
| 17 |
+
import util.misc as misc
|
| 18 |
+
import util.lr_sched as lr_sched
|
| 19 |
+
from monai.losses import DiceCELoss, DiceLoss
|
| 20 |
+
import numpy as np
|
| 21 |
+
from monai.metrics import DiceHelper
|
| 22 |
+
import surface_distance
|
| 23 |
+
from surface_distance import metrics
|
| 24 |
+
from util.meter import DiceMeter, HausdorffMeter, SurfaceDistanceMeter
|
| 25 |
+
|
| 26 |
+
# from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
|
| 27 |
+
from monai.inferers import sliding_window_inference
|
| 28 |
+
|
| 29 |
+
# from monai.metrics import DiceMetric
|
| 30 |
+
# from monai.transforms import Activations
|
| 31 |
+
import pdb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Run one epoch of segmentation training with a Dice+CE objective.

    Binary tasks (``args.out_channels == 1``) use a sigmoid DiceCE loss;
    multi-class tasks use softmax DiceCE with one-hot targets and the
    background class excluded. Models that return a list of logits are
    treated as deep supervision: the auxiliary head is weighted 0.4.

    Returns:
        dict mapping each tracked meter name to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    if args.out_channels == 1:
        loss_cal = DiceCELoss(sigmoid=True)
    else:
        loss_cal = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)


    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    # last_norm is only updated by the (disabled) loss_scaler path; it is
    # printed as part of the divergence diagnostics below.
    last_norm = 0.0
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        # print(img.shape, img.mean(), img.std())
        # with torch.cuda.amp.autocast():
        logit = model(img)
        # Deep supervision: main head + 0.4-weighted auxiliary head.
        if isinstance(logit, list):
            loss = loss_cal(logit[0], gt) + 0.4*loss_cal(logit[1], gt)
        else:
            loss = loss_cal(logit, gt)

        loss_value = loss.item()

        # Divergence diagnostics: locate NaN/Inf in output vs input, then
        # abort. NOTE(review): if ``logit`` is a list (deep supervision),
        # ``torch.isnan(logit)`` here would itself fail — confirm intended.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def validation(model, data_loader_val, device, epoch, args):
    """Compute the average Dice loss over the validation loader."""
    model.eval()
    # Same loss configuration as training, minus the cross-entropy term.
    if args.out_channels == 1:
        dice_loss = DiceLoss(sigmoid=True)
    else:
        dice_loss = DiceLoss(to_onehot_y=True, softmax=True, include_background=False)

    per_batch = []
    n_batches = len(data_loader_val)
    with torch.no_grad():
        for batch_idx, (img, gt, _) in enumerate(data_loader_val):
            img = img.to(device)
            gt = gt.to(device)
            batch_loss = dice_loss(model(img), gt).detach().cpu().numpy()
            per_batch.append(batch_loss)
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, batch_idx, n_batches
            )
            print(progress + " loss:" + str(per_batch[-1].flatten()[0]))
        avg_loss = np.mean(per_batch)
        print("Averaged stats:", str(avg_loss))
    return avg_loss
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test(model, test_loader, args, sliding_window=False):
    """Evaluate the best checkpoint on the test set with three metrics.

    Loads ``best.pth.tar`` from ``args.output_dir``, runs (optionally
    sliding-window) inference, thresholds/argmaxes the logits into label
    maps, and accumulates Dice, Hausdorff95 and surface-distance metrics.

    Returns:
        dict of class-wise and averaged metrics, converted to plain lists
        so the result is JSON-serializable.
    """
    model.eval()
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # BUGFIX: ``weights_only`` is a ``torch.load`` keyword; passing it to
    # ``Module.load_state_dict`` raises TypeError at runtime.
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])
    dice_meter = DiceMeter(args)
    hausdorff_meter = HausdorffMeter(args)
    sd_meter = SurfaceDistanceMeter(args)
    log_stats = {}
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            if sliding_window:
                pred = sliding_window_inference(
                    img, args.crop_spatial_size, 4, model, overlap=0.5
                )
            else:
                pred = model(img)
            # NOTE(review): thresholding keys off ``args.num_classes`` while
            # the training loss keys off ``args.out_channels`` — confirm the
            # two settings always agree.
            if args.num_classes == 1:
                pred = torch.sigmoid(pred) > 0.5
            else:
                pred = torch.softmax(pred, dim=1)
                pred = torch.argmax(pred, dim=1, keepdim=True)
            dice_meter.update(pred, gt)
            hausdorff_meter.update(pred, gt)
            sd_meter.update(pred, gt)

    print("- Test metrics Dice: ")
    dice_class_avg, dice_avg = dice_meter.get_average()
    print("Class wise: ", dice_class_avg)
    print("Avg.: ", dice_avg)

    print("- Test metrics Hausdorff95: ")
    hsd_class_avg, hsd_avg = hausdorff_meter.get_average()
    print("Class wise: ", hsd_class_avg)
    print("Avg.: ", hsd_avg)

    print("- Test metrics SurfaceDistance: ")
    sd_class_avg, sd_avg = sd_meter.get_average()
    print("Class wise: ", sd_class_avg)
    print("Avg.: ", sd_avg)
    # Convert numpy arrays to lists so the stats can be dumped to JSON.
    log_stats = {
        "dice_class_avg": dice_class_avg.tolist() if isinstance(dice_class_avg, np.ndarray) else dice_class_avg,
        "dice_avg": dice_avg.tolist() if isinstance(dice_avg, np.ndarray) else dice_avg,
        "hsd_class_avg": hsd_class_avg.tolist() if isinstance(hsd_class_avg, np.ndarray) else hsd_class_avg,
        "hsd_avg": hsd_avg.tolist() if isinstance(hsd_avg, np.ndarray) else hsd_avg,
        "sd_class_avg": sd_class_avg.tolist() if isinstance(sd_class_avg, np.ndarray) else sd_class_avg,
        "sd_avg": sd_avg.tolist() if isinstance(sd_avg, np.ndarray) else sd_avg,
    }
    return log_stats
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# ProFound models package
|
models/build_classification.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.classifier import Classifier
|
| 2 |
+
from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
|
| 3 |
+
from util.lars import LARS
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from util.convnext_optim import get_parameter_groups, LayerDecayValueAssigner
|
| 7 |
+
|
| 8 |
+
def build_model(args, device):
    """Build the ProFound classification model and its optimizer.

    Loads ConvNeXt-V2 pretrained weights, wraps the backbone in a
    ``Classifier`` head, and returns ``(model, optimizer)``. With
    ``args.train == "freeze"`` the backbone is frozen and only the head is
    trained with LARS (linear probing); otherwise the whole model is
    fine-tuned with AdamW and layer-wise learning-rate decay.

    Raises:
        NotImplementedError: unknown ``args.model`` or missing ``args.pretrain``.
        FileNotFoundError: ``args.pretrain`` does not exist on disk.
    """
    if args.model == "profound_conv":
        convnext = convnextv2_tiny(in_chans=3, drop_path_rate=0.1)
        if args.pretrain is None:
            raise NotImplementedError("No pretrained weight")
        if not os.path.exists(args.pretrain):
            # BUGFIX: the original raised FileExistsError, which signals the
            # opposite condition (a path that unexpectedly exists).
            raise FileNotFoundError(f"{args.pretrain} Not exists")
        ckpt = torch.load(args.pretrain, map_location="cpu")
        ckpt = remap_checkpoint_keys(ckpt)
        # NOTE(review): this is the project helper from models.convnextv2,
        # not torch's — confirm it actually accepts ``weights_only``.
        load_state_dict(convnext, ckpt, weights_only=False)
        model = Classifier(convnext, args.num_classes)
        model = model.to(device)
        if args.train == "freeze":
            # Linear probing: freeze the backbone, train only the head.
            for key, value in model.encoder.named_parameters():
                value.requires_grad = False
            optimizer = LARS(model.head.parameters(), weight_decay=0, lr=args.lr)
        else:
            # Full fine-tuning: layer-wise lr decay over the backbone depth.
            num_layers = sum(convnext.depths)
            assigner = LayerDecayValueAssigner(
                list(
                    args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)
                ),
                depths=convnext.depths,
                layer_decay_type=args.layer_decay_type,
            )

            # Parameter names exempt from weight decay (e.g. norms, biases).
            skip = {}
            if hasattr(model.encoder, "no_weight_decay"):
                skip = model.encoder.no_weight_decay()

            backbone_param_groups = get_parameter_groups(
                model.encoder,
                args.weight_decay,
                skip,
                assigner.get_layer_id,
                assigner.get_scale,
            )
            # Head parameters: no weight decay, base learning rate.
            decoder_param_groups = [
                {"params": model.head.parameters(), "weight_decay": 0.0, "lr": args.lr}
            ]

            optimizer = torch.optim.AdamW(
                backbone_param_groups + decoder_param_groups, lr=args.lr
            )

    else:
        raise NotImplementedError(f"unknown model: {args.model}")

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model))
    print("number of params (M): %.2f" % (n_parameters / 1.0e6))

    return model, optimizer
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def vit_backbone_parameters(
    model: torch.nn.Module, weight_decay=1e-5, no_weight_decay_list=(), lr=1e-3
):
    """Split trainable parameters into weight-decay and no-decay groups.

    1-D tensors, ``.bias`` parameters and any name listed in
    ``no_weight_decay_list`` are exempt from weight decay. Frozen
    parameters are skipped entirely.

    Returns:
        Two optimizer parameter groups (no-decay first, decay second),
        both at learning rate ``lr``.
    """
    exempt_names = set(no_weight_decay_list)
    with_decay = []
    without_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        is_exempt = (
            param.ndim <= 1
            or name.endswith(".bias")
            or name in exempt_names
        )
        (without_decay if is_exempt else with_decay).append(param)

    return [
        {"params": without_decay, "weight_decay": 0.0, "lr": lr},
        {"params": with_decay, "weight_decay": weight_decay, "lr": lr},
    ]
|
models/classifier.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Classifier(nn.Module):
    """A backbone encoder followed by a small bottleneck MLP head.

    The head projects encoder features through
    Linear -> BatchNorm1d -> ReLU -> Linear to produce class logits.
    """

    def __init__(self, encoder, num_classes, bottleneck_dim=256):
        super().__init__()
        self.encoder = encoder
        # The encoder is expected to expose its output width as embed_dim.
        self.embed_dim = self.encoder.embed_dim
        self.head = torch.nn.Sequential(
            nn.Linear(self.embed_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes),
        )

    def forward(self, x):
        features = self.encoder(x)
        # Some encoders return (features, hidden_states); keep features only.
        if type(features) is tuple:
            features = features[0]
        return self.head(features)
|
| 23 |
+
|
models/convnext_unter.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
from typing import Sequence, Tuple, Union
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from monai.networks.blocks.dynunet_block import UnetOutBlock
|
| 6 |
+
from monai.networks.blocks.unetr_block import (
|
| 7 |
+
UnetrBasicBlock,
|
| 8 |
+
UnetrPrUpBlock,
|
| 9 |
+
UnetrUpBlock,
|
| 10 |
+
)
|
| 11 |
+
from models.util import LayerNorm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ConvnextUNETR_Decoder(nn.Module):
    """UNETR-style decoder that fuses ConvNeXt hidden states into a mask.

    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        feature_size: int = 16,
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        spatial_dims: int = 3,
        # BUGFIX: default was a mutable list; a tuple avoids the shared
        # mutable-default pitfall and is otherwise interchangeable here.
        hidden_size: Sequence[int] = (96, 192, 384, 768),
    ) -> None:
        """
        Args:
            in_channels: channels of the raw input image.
            out_channels: number of segmentation output channels.
            feature_size: base width of the decoder feature maps.
            norm_name: normalization layer used inside the UNETR blocks.
            conv_block: use convolutional blocks in the up-projection path.
            res_block: use residual blocks in the UNETR sub-modules.
            spatial_dims: 2 for 2-D or 3 for 3-D convolutions.
            hidden_size: channel widths of the four encoder stages.
        """
        super().__init__()

        # Skip-connection encoders: project the raw input and the first
        # three hidden states to decoder widths (with one upsampling each
        # for the hidden states).
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[0],
            out_channels=feature_size * 2,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[1],
            out_channels=feature_size * 4,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[2],
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        # Top-down decoding path: each stage upsamples and merges the
        # corresponding skip connection.
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[3],
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=out_channels,
        )

    def forward(self, x, x1, x2, x3, x4):
        """Fuse raw input ``x`` with encoder stages ``x1..x4`` into a mask."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(x1)
        enc3 = self.encoder3(x2)
        enc4 = self.encoder4(x3)
        # Decode top-down, merging a skip connection at every resolution.
        dec3 = self.decoder5(x4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        mask = self.out(out)
        return mask
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class ConvnextUNETR(nn.Module):
    """ConvNeXt encoder + UNETR decoder for dense segmentation.

    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        convnext,
        feature_size: int = 16,
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        spatial_dims: int = 3,
        # BUGFIX: default was a mutable list; a tuple avoids the shared
        # mutable-default pitfall and is otherwise interchangeable here.
        hidden_size: Sequence[int] = (96, 192, 384, 768),
    ) -> None:
        """
        Args:
            in_channels: channels of the raw input image.
            out_channels: number of segmentation output channels.
            convnext: ConvNeXt backbone; must support ``ret_hids=True``.
            feature_size: base width of the decoder feature maps.
            norm_name: normalization layer used inside the UNETR blocks.
            conv_block: use convolutional blocks in the up-projection path.
            res_block: use residual blocks in the UNETR sub-modules.
            spatial_dims: 2 for 2-D or 3 for 3-D convolutions.
            hidden_size: channel widths of the four encoder stages.
        """
        super().__init__()

        self.encoder = convnext

        # Channels-first LayerNorms applied to the first three hidden
        # states before they enter the decoder skip connections.
        self.norm1 = LayerNorm(hidden_size[0], eps=1e-6, data_format="channels_first")
        self.norm2 = LayerNorm(hidden_size[1], eps=1e-6, data_format="channels_first")
        self.norm3 = LayerNorm(hidden_size[2], eps=1e-6, data_format="channels_first")

        self.decoder = ConvnextUNETR_Decoder(
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
            spatial_dims=spatial_dims,
            hidden_size=hidden_size,
        )

    def forward(self, x):
        """Encode ``x``, normalize the hidden states, and decode a mask."""
        _, hidden_states_out = self.encoder(x, ret_hids=True)
        x1, x2, x3, x4 = hidden_states_out
        x1 = self.norm1(x1)
        x2 = self.norm2(x2)
        x3 = self.norm3(x3)
        # The deepest state uses the encoder's own channels-last norm.
        x4 = x4.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        x4 = self.encoder.norm(x4)
        x4 = x4.permute(0, 4, 1, 2, 3)
        mask = self.decoder(x, x1, x2, x3, x4)
        return mask
|
models/convnextv2.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 12 |
+
from models.util import LayerNorm, GRN
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Block(nn.Module):
    """ConvNeXtV2 Block.

    Pipeline: 7x7x7 depthwise conv -> LayerNorm -> 1x1 expansion (x4) -> GELU
    -> GRN -> 1x1 projection, wrapped in a residual connection with optional
    stochastic depth.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        # Depthwise 3D convolution: one group per channel.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1x1) convs implemented as Linear on channels-last data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)

    def forward(self, x):
        residual = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 4, 1, 2, 3)  # (N, H, W, D, C) -> (N, C, H, W, D)
        return residual + self.drop_path(y)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 (3D variant).

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (list(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
    """

    def __init__(
        self,
        in_chans=3,
        depths=None,
        dims=None,
        drop_path_rate=0.0,
    ):
        super().__init__()
        # Avoid mutable default arguments: build the defaults per instance.
        if depths is None:
            depths = [3, 3, 9, 3]
        if dims is None:
            dims = [96, 192, 384, 768]
        self.depths = depths

        # Stem and 3 intermediate downsampling conv layers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last transition keeps spatial resolution (stride 1).
            stride = 1 if i == 2 else 2
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i + 1], kernel_size=stride, stride=stride),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature-resolution stages, each a stack of residual Blocks with a
        # linearly increasing stochastic-depth rate.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        self.apply(self._init_weights)
        self.embed_dim = dims[-1]

    def _init_weights(self, m):
        """Truncated-normal weight init; zero bias (when a bias exists)."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Guard against bias=False layers, whose m.bias is None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Return ``(pooled_final_feature, [4 stage outputs])``."""
        hidden_states_out = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            hidden_states_out.append(x)
        # Global average pooling over spatial dims: (N, C, H, W, D) -> (N, C).
        return self.norm(x.mean([-3, -2, -1])), hidden_states_out

    def forward(self, x, ret_hids=False):
        """Forward pass; optionally also return the per-stage hidden states."""
        x, hidden_states_out = self.forward_features(x)
        if ret_hids:
            return x, hidden_states_out
        return x
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def convnextv2_atto(**kwargs):
    """Build ConvNeXt-V2 Atto: depths (2, 2, 6, 2), dims (40, 80, 160, 320)."""
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def convnextv2_femto(**kwargs):
    """Build ConvNeXt-V2 Femto: depths (2, 2, 6, 2), dims (48, 96, 192, 384)."""
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def convnext_pico(**kwargs):
    """Build ConvNeXt-V2 Pico: depths (2, 2, 6, 2), dims (64, 128, 256, 512)."""
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def convnextv2_nano(**kwargs):
    """Build ConvNeXt-V2 Nano: depths (2, 2, 8, 2), dims (80, 160, 320, 640)."""
    return ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def convnextv2_tiny(**kwargs):
    """Build ConvNeXt-V2 Tiny: depths (3, 3, 9, 3), dims (96, 192, 384, 768)."""
    return ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def convnextv2_base(**kwargs):
    """Build ConvNeXt-V2 Base: depths (3, 3, 27, 3), dims (128, 256, 512, 1024)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def convnextv2_large(**kwargs):
    """Build ConvNeXt-V2 Large: depths (3, 3, 27, 3), dims (192, 384, 768, 1536)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def convnextv2_huge(**kwargs):
    """Build ConvNeXt-V2 Huge: depths (3, 3, 27, 3), dims (352, 704, 1408, 2816)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def remap_checkpoint_keys(ckpt):
    """Remap an FCMAE pre-training checkpoint to this dense 3D model's layout.

    Steps:
      * drop pre-training-only weights (decoder, mask token, proj/pred heads);
      * strip the leading ``encoder.`` prefix;
      * convert sparse ``*.kernel`` tensors back into dense conv ``*.weight``
        tensors, assuming cubic kernels (``ks = round(kv ** (1/3))``);
      * flatten reshaped biases and re-shape GRN affine parameters.

    Args:
        ckpt: checkpoint dict containing a ``"model"`` state dict
            (the inner dict is mutated in place by key removal).

    Returns:
        OrderedDict: the remapped state dict.
    """
    new_ckpt = OrderedDict()
    ckpt = ckpt["model"]

    # Remove pre-training-only weights.
    checkpoint_model_keys = list(ckpt.keys())
    for k in checkpoint_model_keys:
        if "decoder" in k or "mask_token" in k or "proj" in k or "pred" in k:
            print(f"Removing key {k} from pretrained checkpoint")
            del ckpt[k]

    for k, v in ckpt.items():
        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])  # remove "encoder" from the name
        if k.endswith("kernel"):
            k = ".".join(k.split(".")[:-1])  # remove "kernel" from the name
            new_k = k + ".weight"
            if len(v.shape) == 3:  # reshape standard convolution
                kv, in_dim, out_dim = v.shape
                ks = int(
                    round(kv ** (1 / 3))
                )  # calculate kernel size assuming cubic kernel
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0)
                    .reshape(out_dim, in_dim, ks, ks, ks)
                    .permute(0, 1, 4, 3, 2)
                )
                # BUGFIX: skip the fall-through assignment at the bottom of the
                # loop, which previously overwrote the reshaped weight with the
                # raw sparse tensor.
                continue
            elif len(v.shape) == 2:  # reshape depthwise convolution
                kv, dim = v.shape
                if new_k == "downsample_layers.3.1.weight":
                    # Stride-1 1x1x1 transition conv: (kv, dim) -> (dim, kv, 1, 1, 1).
                    new_ckpt[new_k] = (
                        v.permute(1, 0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    )
                else:
                    ks = int(round(kv ** (1 / 3)))
                    new_ckpt[new_k] = (
                        v.permute(1, 0)
                        .reshape(dim, 1, ks, ks, ks)
                        .permute(0, 1, 4, 3, 2)
                    )
                continue
            # Kernel tensors of any other rank fall through and are copied as-is.
        elif "ln" in k or "linear" in k:
            k = k.split(".")
            k.pop(-2)  # remove "ln"/"linear" from the name
            new_k = ".".join(k)
        else:
            new_k = k
        new_ckpt[new_k] = v

    # Reshape GRN affine parameters and flatten non-1D biases.
    for k, v in new_ckpt.items():
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1).unsqueeze(0)
    return new_ckpt
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def load_state_dict(
    model, state_dict, prefix="", ignore_missing="relative_position_index"
):
    """Load ``state_dict`` into ``model`` non-strictly, reporting mismatches.

    Recursively invokes each submodule's ``_load_from_state_dict`` so that
    per-module metadata (e.g. parameter version numbers) is honored. Missing
    keys that match any ``|``-separated pattern in ``ignore_missing`` are
    reported separately rather than warned about.

    Args:
        model: target ``nn.Module``.
        state_dict: mapping of parameter names to tensors.
        prefix: key prefix at which loading starts.
        ignore_missing: ``|``-separated substrings of keys to ignore silently.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # Copy the dict so _load_from_state_dict can modify it without side effects.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                _load(child, prefix + name + ".")

    _load(model, prefix=prefix)

    # Split missing keys into "warn" and "deliberately ignored" buckets.
    patterns = ignore_missing.split("|")
    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print(
            "Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys
            )
        )
    if len(unexpected_keys) > 0:
        print(
            "Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys
            )
        )
    if len(ignore_missing_keys) > 0:
        print(
            "Ignored weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, ignore_missing_keys
            )
        )
    if len(error_msgs) > 0:
        print("\n".join(error_msgs))
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# if __name__ == 'main':
|
| 309 |
+
# model = convnextv2_base().cuda()
|
| 310 |
+
# x = torch.rand(1,3,256,256,32).cuda()
|
| 311 |
+
# print(model(x).shape)
|
models/upernet_module.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from models.util import LayerNorm, GRN
|
| 5 |
+
|
| 6 |
+
class UperNetConvModule(nn.Module):
    """
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        # NOTE: attribute keeps the name "batch_norm" (checkpoint-key
        # compatibility), but it is a channels-first LayerNorm, not BatchNorm3d.
        self.batch_norm = LayerNorm(out_channels, eps=1e-6, data_format="channels_first")
        self.activation = nn.GELU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply conv, then norm, then GELU."""
        return self.activation(self.batch_norm(self.conv(input)))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class UperNetPyramidPoolingBlock(nn.Module):
    """One PPM branch: adaptive average pooling followed by a 1x1x1 conv module."""

    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        self.layers = [
            nn.AdaptiveAvgPool3d(pool_scale),
            UperNetConvModule(in_channels, channels, kernel_size=1),
        ]
        # self.layers is a plain list, so register each entry explicitly to
        # make its parameters visible to nn.Module.
        for index, layer in enumerate(self.layers):
            self.add_module(str(index), layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden = input
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class UperNetPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (`Tuple[int]`):
            Pooling scales used in Pooling Pyramid Module.
        in_channels (`int`):
            Input channels.
        channels (`int`):
            Channels after modules, before conv_seg.
        align_corners (`bool`):
            align_corners argument of F.interpolate.
    """

    def __init__(
        self,
        pool_scales: Tuple[int, ...],
        in_channels: int,
        channels: int,
        align_corners: bool,
    ) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.blocks = []
        for index, scale in enumerate(pool_scales):
            branch = UperNetPyramidPoolingBlock(
                pool_scale=scale, in_channels=in_channels, channels=channels
            )
            self.blocks.append(branch)
            # Register each branch; self.blocks itself is a plain list.
            self.add_module(str(index), branch)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Pool at every scale and upsample each result back to x's spatial size."""
        target_size = x.size()[2:]
        outputs = []
        for ppm in self.blocks:
            pooled = ppm(x)
            upsampled = nn.functional.interpolate(
                pooled,
                size=target_size,
                mode="trilinear",
                align_corners=self.align_corners,
            )
            outputs.append(upsampled)
        return outputs
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class UperNetHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Combines a PSP (pyramid pooling) module on the deepest feature with an FPN
    over the shallower features, fuses all levels, and produces per-voxel logits.
    """

    def __init__(self, in_channels, pool_scales, hidden_size, out_channels):
        """
        Args:
            in_channels: list of channel counts, one per backbone feature level
                (last entry is the deepest level, fed to the PSP module).
            pool_scales: PSP pooling scales, e.g. (1, 2, 3, 6).
            hidden_size: internal channel width of the head.
            out_channels: number of output classes/channels.
        """
        super().__init__()
        self.pool_scales = pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = in_channels
        self.channels = hidden_size
        self.align_corners = False
        # Final 1x1x1 projection from fused features to class logits.
        self.classifier = nn.Conv3d(self.channels, out_channels, kernel_size=1)

        # PSP Module
        self.psp_modules = UperNetPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        # Fuses the deepest feature with all PSP branch outputs.
        self.bottleneck = UperNetConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = UperNetConvModule(
                self.channels, self.channels, kernel_size=3, padding=1
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses the concatenation of all FPN levels into `hidden_size` channels.
        self.fpn_bottleneck = UperNetConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def init_weights(self):
        """Initialize all Conv3d weights/biases in the head."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Normal(0, 0.02) weights, zero biases — applied only to Conv3d layers.
        if isinstance(module, nn.Conv3d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def psp_forward(self, inputs):
        """Run the PSP module on the deepest feature and fuse via bottleneck."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Fuse multi-level backbone features and return class logits.

        `encoder_hidden_states` is a sequence of feature maps ordered shallow
        to deep; the output has the spatial size of the shallowest level.
        """
        # build laterals
        laterals = [
            lateral_conv(encoder_hidden_states[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # Upsample the deeper level to the next-shallower level's size
            # and add it in (standard FPN top-down accumulation).
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i],
                size=prev_shape,
                mode="trilinear",
                align_corners=self.align_corners,
            )

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # Bring every level to the shallowest level's spatial size before concat.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode="trilinear",
                align_corners=self.align_corners,
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class UperNetFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is the implementation of
    [FCNNet](https://arxiv.org/abs/1411.4038>).

    Used as UPerNet's auxiliary head: a small conv stack on one selected
    backbone feature level, followed by a 1x1x1 classifier.

    Args:
        in_channels (list(int)):
            Channel counts of the backbone feature levels; the level selected
            by `in_index` provides this head's input channels.
        hidden_size (int):
            Internal channel width of the conv stack.
        num_convs (int):
            Number of conv modules (0 means identity, i.e. classifier only).
        out_channels (int):
            Number of output classes/channels.
        concat_input (bool):
            If True, concatenate the head's input with the conv-stack output
            before classification.
        in_index (int):
            Index of the backbone feature level to use. Default: 2.
        kernel_size (int):
            The kernel size for convs in the head. Default: 3.
        dilation (int):
            The dilation rate for convs in the head. Default: 1.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        num_convs,
        out_channels,
        concat_input=False,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels[in_index]
        self.channels = hidden_size
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.in_index = in_index

        # "same"-style padding for the chosen kernel size and dilation.
        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            UperNetConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
            )
        )
        for i in range(self.num_convs - 1):
            convs.append(
                UperNetConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                )
            )
        if self.num_convs == 0:
            # No conv stack requested: pass the feature straight through.
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = UperNetConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )

        self.classifier = nn.Conv3d(self.channels, out_channels, kernel_size=1)

    def init_weights(self):
        """Initialize all Conv3d weights/biases in the head."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Normal(0, 0.02) weights, zero biases — applied only to Conv3d layers.
        if isinstance(module, nn.Conv3d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Classify the feature level selected by `in_index`."""
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class ViTAdapter(nn.Module):
    """Adapt a ViT's single-resolution token outputs into a 4-level feature
    pyramid for UPerNet.

    Each of the four hidden-state sequences is reshaped back to a volume
    (`proj_feat`) and rescaled by one of fpn1..fpn4 (upsample x8/x4/x2 and a
    downsample, respectively).

    NOTE(review): fpn1 is only defined when `patch_size == (16, 32, 32)`;
    any other patch size leaves `self.fpn1` unset and `forward` would raise
    AttributeError — presumably only this patch size is used. TODO confirm.
    """

    def __init__(
        self,
        img_size=(64, 256, 256),
        patch_size=(16, 32, 32),
        embed_dim=768,
        # out_indices=[3, 5, 7, 11],
    ):
        """
        Args:
            img_size: input volume size (D, H, W).
            patch_size: ViT patch size per axis.
            embed_dim: token embedding dimension (kept as channel count).
        """
        super().__init__()
        # self.out_indices = out_indices

        # Token-grid size per axis: image extent divided by patch extent.
        self.grid_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, patch_size))
        self.hidden_size = embed_dim

        if patch_size == (16, 32, 32):
            # Upsample x(4, 8, 8): two isotropic x2 steps plus one (1, 2, 2) step.
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose3d(
                    embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
                ),
                nn.BatchNorm3d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.BatchNorm3d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

        # 8
        # Upsample x(2, 4, 4).
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose3d(
                embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
            ),
            nn.BatchNorm3d(embed_dim),
            nn.GELU(),
            nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )

        # 16
        # Upsample x(1, 2, 2).
        self.fpn3 = nn.Sequential(
            nn.ConvTranspose3d(
                embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
            ),
        )

        # 32
        # Downsample depth by 2.
        self.fpn4 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1))

        # Plain list is fine here: fpn1..fpn4 are already registered as
        # attributes; this just fixes the iteration order for forward().
        self.adapters = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]

    def proj_feat(self, x):
        """Reshape ViT tokens (B, N, C) back to a volume (B, C, *grid_size).

        Assumes N == prod(grid_size) — TODO confirm against the encoder.
        """

        new_view = (x.size(0), *self.grid_size, self.hidden_size)
        # print(f"x.shape: {x.shape}, expected: {new_view}, grid_size: {self.grid_size}")
        x = x.view(new_view)
        # Move the channel axis from last to position 1.
        new_axes = (0, len(x.shape) - 1) + tuple(
            d + 1 for d in range(len(self.grid_size))
        )
        x = x.permute(new_axes).contiguous()
        return x

    def forward(self, encoder_hidden_states):
        """Map each hidden state through its matching fpn branch."""
        output = []
        # print(f"len_encoder_hidden: {len(encoder_hidden_states)}")
        for index, op in zip(range(len(encoder_hidden_states)), self.adapters):
            output.append(op(self.proj_feat(encoder_hidden_states[index])))
        return output
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class UperNet(nn.Module):
    """UPerNet-style segmentation model: encoder backbone + UperNetHead
    decoder, with an optional FCN auxiliary head for deep supervision.

    Logits are upsampled with trilinear interpolation to the input's
    spatial size, so inputs are presumably 5-D volumes
    (batch, channels, D, H, W) — TODO confirm against the encoder.
    """

    def __init__(
        self,
        encoder,
        in_channels,
        out_channels,
        adapter=None,
        out_indices=None,
        pool_scales=[1, 2, 3, 6],
        hidden_size=512,
        auxiliary_channels=256,
        use_auxiliary_head=True,
    ):
        """
        Args:
            encoder: backbone module; must accept ``ret_hids=True`` and
                return intermediate hidden states.
            in_channels: per-stage channel counts of the encoder features;
                one LayerNorm is created per entry.
            out_channels: number of output classes/channels.
            adapter: optional module applied to the normalized hidden states.
            out_indices: optional indices selecting which hidden states are
                used by the decoder.
            pool_scales: PPM pooling scales forwarded to UperNetHead.
            hidden_size: decoder channel width.
            auxiliary_channels: channel width of the auxiliary FCN head.
            use_auxiliary_head: build the auxiliary head when True.
        """
        super().__init__()
        self.encoder = encoder
        self.adapter = adapter
        self.out_indices = out_indices
        self.decode_head = UperNetHead(
            in_channels=in_channels,
            pool_scales=pool_scales,
            hidden_size=hidden_size,
            out_channels=out_channels,
        )
        self.auxiliary_head = (
            UperNetFCNHead(
                in_channels=in_channels,
                hidden_size=auxiliary_channels,
                num_convs=1,
                out_channels=out_channels,
            )
            if use_auxiliary_head
            else None
        )

        # One channels-first LayerNorm per encoder stage feature map.
        self.hidden_norm = nn.ModuleList()
        for in_channel in in_channels:
            norm = LayerNorm(in_channel, eps=1e-6, data_format="channels_first")
            self.hidden_norm.append(norm)

    def forward(self, x):
        """Return segmentation logits.

        In eval mode returns only the main logits.  In training mode, when
        an auxiliary head exists, returns ``[logits, auxiliary_logits]``.
        """
        encoder_hidden_states = self.encoder(x, ret_hids=True)
        # Some encoders return (output, hidden_states) or a nested list;
        # take the last element as the list of hidden states in that case.
        # NOTE: isinstance against typing.Tuple is special-cased by typing.
        if isinstance(encoder_hidden_states, list) or isinstance(
            encoder_hidden_states, Tuple
        ):
            encoder_hidden_states = encoder_hidden_states[-1]
        if self.out_indices:
            encoder_hidden_states = [
                encoder_hidden_states[i] for i in self.out_indices
            ]

        # Normalize each selected stage; assumes the number of hidden states
        # matches len(self.hidden_norm) — confirm for configured out_indices.
        encoder_hidden_states = [
            norm(encoder_hidden_states[i])
            for i, norm in enumerate(self.hidden_norm)
        ]

        if self.adapter:
            encoder_hidden_states = self.adapter(encoder_hidden_states)

        logits = self.decode_head(encoder_hidden_states)
        logits = nn.functional.interpolate(
            logits, size=x.shape[2:], mode="trilinear", align_corners=False
        )
        if not self.training:
            return logits

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(encoder_hidden_states)
            auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits,
                size=x.shape[2:],
                mode="trilinear",
                align_corners=False,
            )
            return [logits, auxiliary_logits]
        return logits
|
models/util.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from itertools import chain
|
| 4 |
+
from typing import Callable
|
| 5 |
+
from torch.utils.checkpoint import checkpoint
|
| 6 |
+
|
| 7 |
+
import numpy.random as random
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
# from MinkowskiEngine import SparseTensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 13 |
+
|
| 14 |
+
# All rights reserved.
|
| 15 |
+
|
| 16 |
+
# This source code is licensed under the license found in the
|
| 17 |
+
# LICENSE file in the root directory of this source tree.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# class MinkowskiGRN(nn.Module):
|
| 21 |
+
# """GRN layer for sparse tensors."""
|
| 22 |
+
|
| 23 |
+
# def __init__(self, dim):
|
| 24 |
+
# super().__init__()
|
| 25 |
+
# self.gamma = nn.Parameter(torch.zeros(1, dim))
|
| 26 |
+
# self.beta = nn.Parameter(torch.zeros(1, dim))
|
| 27 |
+
|
| 28 |
+
# def forward(self, x):
|
| 29 |
+
# cm = x.coordinate_manager
|
| 30 |
+
# in_key = x.coordinate_map_key
|
| 31 |
+
|
| 32 |
+
# Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
|
| 33 |
+
# Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
| 34 |
+
# return SparseTensor(
|
| 35 |
+
# self.gamma * (x.F * Nx) + self.beta + x.F,
|
| 36 |
+
# coordinate_map_key=in_key,
|
| 37 |
+
# coordinate_manager=cm,
|
| 38 |
+
# )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MinkowskiDropPath(nn.Module):
    """Drop Path (stochastic depth) for MinkowskiEngine sparse tensors.

    NOTE(review): ``SparseTensor`` comes from MinkowskiEngine, whose import
    is commented out at the top of this module — calling forward() in
    training mode with drop_prob > 0 will raise NameError until that import
    is restored.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        # drop_prob: probability of dropping a whole sample's features.
        # scale_by_keep: rescale kept samples by 1/keep_prob (inverted dropout).
        super(MinkowskiDropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity in eval mode or when dropping is disabled.
        if self.drop_prob == 0.0 or not self.training:
            return x
        cm = x.coordinate_manager
        in_key = x.coordinate_map_key
        keep_prob = 1 - self.drop_prob
        # Build a per-point 0/1 column mask: each decomposed batch element is
        # kept or dropped as a whole.  `random` here is numpy.random (see
        # module imports), so this is not torch-RNG controlled.
        mask = (
            torch.cat(
                [
                    (
                        torch.ones(len(_))
                        if random.uniform(0, 1) > self.drop_prob
                        else torch.zeros(len(_))
                    )
                    for _ in x.decomposed_coordinates
                ]
            )
            .view(-1, 1)
            .to(x.device)
        )
        if keep_prob > 0.0 and self.scale_by_keep:
            mask.div_(keep_prob)
        return SparseTensor(
            x.F * mask, coordinate_map_key=in_key, coordinate_manager=cm
        )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MinkowskiLayerNorm(nn.Module):
    """Channel-wise layer normalization for sparse tensors.

    Applies a dense ``nn.LayerNorm`` to the feature matrix of a
    MinkowskiEngine sparse tensor and re-wraps the result.

    NOTE(review): ``SparseTensor`` comes from MinkowskiEngine, whose import
    is commented out at the top of this module — forward() will raise
    NameError until that import is restored.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-6,
    ):
        super(MinkowskiLayerNorm, self).__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)

    def forward(self, input):
        # input.F is the dense (n_points, channels) feature matrix.
        output = self.ln(input.F)
        # Preserve the coordinate structure of the input sparse tensor.
        return SparseTensor(
            output,
            coordinate_map_key=input.coordinate_map_key,
            coordinate_manager=input.coordinate_manager,
        )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class LayerNorm(nn.Module):
    """LayerNorm supporting two data layouts.

    ``channels_last`` expects the channel axis last (``(..., C)``) and uses
    ``F.layer_norm`` directly.  ``channels_first`` expects the channel axis
    at dim 1 and normalizes over it manually; 3-D inputs (ViT-adapter
    tokens) broadcast the affine parameters without extra axes, while other
    ranks broadcast them as ``(C, 1, 1, 1)`` (i.e. 5-D volumes).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over the channel axis (dim 1) by hand.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        if len(x.shape) == 3:
            # ViT-adapter path: affine params broadcast against the last dim.
            return self.weight * x_hat + self.bias
        # Volumetric path: broadcast affine params as (C, 1, 1, 1).
        return (
            self.weight[:, None, None, None] * x_hat
            + self.bias[:, None, None, None]
        )
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer for channels-last 5-D input.

    Parameters are shaped ``(1, 1, 1, 1, C)``, so the input is expected to
    be ``(batch, D, H, W, C)``.  With zero-initialized gamma/beta the layer
    starts as the identity.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Global L2 response per channel over the spatial axes.
        response = x.norm(p=2, dim=(1, 2, 3), keepdim=True)
        # Normalize by the mean response across channels.
        scale = response / (response.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * scale) + self.beta + x
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def get_tokens(embed_dim: int, n_tokens: int) -> nn.Parameter:
    """Create a learnable token tensor of shape (1, n_tokens, embed_dim).

    Args:
        embed_dim: number of embedding channels.
        n_tokens: number of tokens.

    Returns:
        A zero-initialized ``nn.Parameter`` re-initialized with a truncated
        normal (std=0.02, cutoff 2.0) — effectively normal_(std=0.02), as in
        timm's trunc_normal_.
    """
    tokens = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
    nn.init.trunc_normal_(tokens, std=0.02, b=2.0)
    return tokens
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def init_weights(m):
    """Module initializer for use with ``model.apply(init_weights)``.

    Linear layers get xavier-uniform weights (the official JAX ViT scheme)
    with zero bias; LayerNorm gets unit weight and zero bias.  Other module
    types are left untouched.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
"""Gradient checkpointing utilities.
|
| 175 |
+
|
| 176 |
+
Copied from
|
| 177 |
+
https://github.com/huggingface/pytorch-image-models/blob/f8979d4f50b7920c78511746f7315df8f1857bc5/timm/models/_manipulate.py
|
| 178 |
+
and added use_reentrant=False following warnings in pytorch docs.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def checkpoint_seq(
    functions: nn.Sequential,
    x: torch.Tensor,
    every: int = 1,
    flatten: bool = False,
    skip_last: bool = False,
    preserve_rng_state: bool = True,
) -> torch.Tensor:
    r"""Checkpoint a sequential model in segments of ``every`` functions.

    The sequence is split into consecutive segments; each segment runs under
    :func:`torch.utils.checkpoint.checkpoint` (non-reentrant), so its
    intermediate activations are not stored and are recomputed on backward.

    Args:
        functions: an ``nn.Sequential`` or a list/iterable of callables to
            run in order.
        x: input tensor fed to the first function.
        every: number of functions per checkpointed segment (default 1).
        flatten: flatten an ``nn.Sequential`` of ``nn.Sequential``s first.
        skip_last: run the final function without checkpointing.
        preserve_rng_state: stash/restore RNG state per checkpoint.

    Returns:
        Output of running ``functions`` sequentially on ``x``.

    Example:
        >>> model = nn.Sequential(...)
        >>> out = checkpoint_seq(model, inp, every=2)
    """

    def make_segment(lo, hi, fns):
        # Build a callable running fns[lo..hi] inclusive.
        def segment(inp):
            for idx in range(lo, hi + 1):
                inp = fns[idx](inp)
            return inp

        return segment

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1

    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(
            make_segment(start, end, functions),
            x,
            use_reentrant=False,
            preserve_rng_state=preserve_rng_state,
        )
    if skip_last:
        # Tail functions run eagerly (activations kept).
        return make_segment(end + 1, len(functions) - 1, functions)(x)
    return x
|
requirements.txt
CHANGED
|
@@ -34,6 +34,3 @@ tqdm==4.67.1
|
|
| 34 |
# Additional dependencies for model architecture
|
| 35 |
einops==0.8.1
|
| 36 |
timm==1.0.15
|
| 37 |
-
|
| 38 |
-
# ProFound package from GitHub
|
| 39 |
-
git+https://github.com/pipiwang/ProFound.git@demo
|
|
|
|
| 34 |
# Additional dependencies for model architecture
|
| 35 |
einops==0.8.1
|
| 36 |
timm==1.0.15
|
|
|
|
|
|
|
|
|
util/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# ProFound utilities package
|
util/convnext_optim.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import optim as optim
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_num_layer_for_convnext_single(var_name, depths):
    """Assign a distinct layer id to every ConvNeXt block.

    ``downsample_layers.<s>`` maps to the id of the first block of stage s;
    ``stages.<s>.<b>`` maps to its absolute block index (1-based); anything
    else (stem/head) gets ``sum(depths) + 1``.
    """
    parts = var_name.split(".")
    if var_name.startswith("downsample_layers"):
        stage_idx = int(parts[1])
        return sum(depths[:stage_idx]) + 1
    if var_name.startswith("stages"):
        stage_idx = int(parts[1])
        block_idx = int(parts[2])
        return sum(depths[:stage_idx]) + block_idx + 1
    return sum(depths) + 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_num_layer_for_convnext(var_name):
    """Group ConvNeXt [3, 3, 27, 3] layers into 12 layer-decay buckets.

    Each bucket is three consecutive blocks of stage 2 (plus neighboring
    downsample layers); stages 0/1 get buckets 1/2, stage 3 collapses into
    bucket 12.  Non-backbone params (head, norms) get bucket 13.
    Adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
    """
    num_max_layer = 12
    parts = var_name.split(".")
    if var_name.startswith("downsample_layers"):
        stage = int(parts[1])
        if stage == 0:
            layer_id = 0
        elif stage in (1, 2):
            layer_id = stage + 1
        elif stage == 3:
            layer_id = 12
        return layer_id
    if var_name.startswith("stages"):
        stage = int(parts[1])
        block = int(parts[2])
        if stage in (0, 1):
            layer_id = stage + 1
        elif stage == 2:
            # Three consecutive blocks per bucket within stage 2.
            layer_id = 3 + block // 3
        elif stage == 3:
            layer_id = 12
        return layer_id
    return num_max_layer + 1
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LayerDecayValueAssigner(object):
    """Maps parameter names to per-layer learning-rate scales.

    ``layer_decay_type="single"`` assigns one id per block via
    ``get_num_layer_for_convnext_single``; any other value uses the 12-group
    scheme in ``get_num_layer_for_convnext``.
    """

    def __init__(self, values, depths=[3, 3, 27, 3], layer_decay_type="single"):
        # values[i] is the lr scale applied to parameters in layer i.
        self.values = values
        self.depths = depths
        self.layer_decay_type = layer_decay_type

    def get_scale(self, layer_id):
        """Look up the lr scale for a given layer id."""
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        """Resolve a parameter name to its layer id."""
        return (
            get_num_layer_for_convnext_single(var_name, self.depths)
            if self.layer_decay_type == "single"
            else get_num_layer_for_convnext(var_name)
        )
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_parameter_groups(
    model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None
):
    """Build optimizer parameter groups with weight-decay filtering and
    optional layer-wise lr scaling.

    1-D params, biases, GRN gamma/beta, and names in ``skip_list`` get no
    weight decay.  When ``get_num_layer`` is given, groups are additionally
    split per layer id, and ``get_layer_scale`` (if given) sets each group's
    ``lr_scale``.  Frozen params are skipped.  Returns a list of group dicts
    suitable for a torch optimizer; group contents are also printed as JSON.
    """
    group_names = {}
    groups = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        exempt = (
            param.ndim == 1
            or name.endswith(".bias")
            or name in skip_list
            or name.endswith(".gamma")
            or name.endswith(".beta")
        )
        group_name = "no_decay" if exempt else "decay"
        this_weight_decay = 0.0 if exempt else weight_decay

        layer_id = None
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)

        if group_name not in group_names:
            scale = 1.0 if get_layer_scale is None else get_layer_scale(layer_id)
            group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }
            groups[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }

        groups[group_name]["params"].append(param)
        group_names[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(group_names, indent=2))
    return list(groups.values())
|
util/lars.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# LARS optimizer, implementation from MoCo v3:
|
| 8 |
+
# https://github.com/facebookresearch/moco-v3
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
class LARS(torch.optim.Optimizer):
    """LARS optimizer (MoCo v3 implementation).

    Parameters with ndim <= 1 (biases, norm gamma/beta) get no weight decay
    and no trust-ratio scaling; multi-dimensional parameters get decoupled
    weight decay and the LARS trust coefficient applied to the update.
    """

    def __init__(
        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
    ):
        super().__init__(
            params,
            dict(
                lr=lr,
                weight_decay=weight_decay,
                momentum=momentum,
                trust_coefficient=trust_coefficient,
            ),
        )

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue

                if p.ndim > 1:  # skip normalization gamma/beta and biases
                    grad = grad.add(p, alpha=group["weight_decay"])
                    p_norm = torch.norm(p)
                    g_norm = torch.norm(grad)
                    one = torch.ones_like(p_norm)
                    # Trust ratio; falls back to 1 when either norm is zero.
                    trust = torch.where(
                        p_norm > 0.0,
                        torch.where(
                            g_norm > 0,
                            group["trust_coefficient"] * p_norm / g_norm,
                            one,
                        ),
                        one,
                    )
                    grad = grad.mul(trust)

                state = self.state[p]
                if "mu" not in state:
                    state["mu"] = torch.zeros_like(p)
                momentum_buf = state["mu"]
                momentum_buf.mul_(group["momentum"]).add_(grad)
                p.add_(momentum_buf, alpha=-group["lr"])
|
util/lr_sched.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def adjust_learning_rate(optimizer, epoch, args):
    """Set lr with linear warmup then half-cycle cosine decay.

    During warmup the lr ramps linearly from 0 to ``args.lr``; afterwards it
    follows a cosine from ``args.lr`` down to ``args.min_lr``.  Groups with
    an ``lr_scale`` entry are scaled by it.  Returns the unscaled lr.
    """
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        lr = args.min_lr + (args.lr - args.min_lr) * cosine
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr * param_group.get("lr_scale", 1.0)
    return lr
|
util/metric.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import prettytable
|
| 3 |
+
import copy
|
| 4 |
+
import sys
|
| 5 |
+
from importlib import import_module
|
| 6 |
+
from inspect import signature
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from scipy.stats import kendalltau, pearsonr, spearmanr
|
| 12 |
+
from sklearn.metrics import (
|
| 13 |
+
confusion_matrix,
|
| 14 |
+
f1_score,
|
| 15 |
+
fbeta_score,
|
| 16 |
+
get_scorer,
|
| 17 |
+
get_scorer_names,
|
| 18 |
+
make_scorer,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    """Compute accuracy (in %) for binary classification.

    Predictions are thresholded at 0.5; both tensors are flattened before
    comparison.  Returns a 0-dim tensor holding the percentage.
    """
    with torch.no_grad():
        n = target.size(0)
        preds = (output >= 0.5).float().t().view(-1)
        hits = preds.eq(target.view(-1)).float().sum()
        return hits * (100.0 / n)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def accuracy(output, target, topk=(1,)):
    r"""Top-k accuracy over the specified values of k.

    Args:
        output (tensor): classification scores, shape :math:`(N, C)`.
        target (tensor): shape :math:`(N)`, values in :math:`[0, C-1]`.
        topk (sequence[int]): the N values for top-N accuracy.

    Returns:
        List of top-N accuracies (percentages), one per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch = target.size(0)
        _, top_pred = output.topk(maxk, 1, True, True)
        # hits[k, i] is True when the (k+1)-th ranked prediction for sample
        # i matches the target.
        hits = top_pred.t().eq(target[None])
        return [
            hits[:k].flatten().sum(dtype=torch.float32) * (100.0 / batch)
            for k in topk
        ]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ConfusionMatrix(object):
    """Running confusion matrix for multi-class evaluation."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None  # lazily allocated on first update()

    def update(self, target, output):
        """Accumulate counts from a batch of label tensors.

        Args:
            target: ground-truth labels.
            output: predicted labels.

        Shape:
            - target: :math:`(minibatch, C)` where C means the number of classes.
            - output: :math:`(minibatch, C)` where C means the number of classes.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.no_grad():
            valid = (target >= 0) & (target < n)
            # Encode (true, pred) pairs as flat indices and histogram them.
            flat = n * target[valid].to(torch.int64) + output[valid]
            self.mat += torch.bincount(flat, minlength=n**2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        """Return (global accuracy, per-class accuracy, per-class IoU)."""
        h = self.mat.float()
        diag = torch.diag(h)
        acc_global = diag.sum() / h.sum()
        acc = diag / h.sum(1)
        iu = diag / (h.sum(1) + h.sum(0) - diag)
        return acc_global, acc, iu

    def __str__(self):
        acc_global, acc, iu = self.compute()
        template = (
            "global correct: {:.1f}\n"
            "average row correct: {}\n"
            "IoU: {}\n"
            "mean IoU: {:.1f}"
        )
        return template.format(
            acc_global.item() * 100,
            ["{:.1f}".format(v) for v in (acc * 100).tolist()],
            ["{:.1f}".format(v) for v in (iu * 100).tolist()],
            iu.mean().item() * 100,
        )

    def format(self, classes: list):
        """Get the accuracy and IoU for each class in the table format"""
        acc_global, acc, iu = self.compute()

        table = prettytable.PrettyTable(["class", "acc", "iou"])
        for cls_name, cls_acc, cls_iu in zip(
            classes, (acc * 100).tolist(), (iu * 100).tolist()
        ):
            table.add_row([cls_name, cls_acc, cls_iu])

        return (
            "global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}".format(
                acc_global.item() * 100,
                acc.mean().item() * 100,
                iu.mean().item() * 100,
                table.get_string(),
            )
        )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def kappa(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    weights: Optional[Union[str, np.ndarray]] = None,
    allow_off_by_one: bool = False,
) -> float:
    """
    Calculate the kappa inter-rater agreement between gold and predicted
    ratings.

    Values range from -1 (complete disagreement) to 1 (complete agreement);
    0 means agreement is no better than chance. All items in ``y_true`` and
    ``y_pred`` are first converted to floats, rounded, and cast to ints, and
    the two arrays are assumed to cover the complete range of possible
    ratings.

    Combines code from yorchopolis's kappa-stats and Ben Hamner's Metrics
    projects on Github.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.
    weights : Optional[Union[str, numpy.ndarray]], default=None
        ``None`` (unweighted), ``"quadratic"``, ``"linear"``, or a custom
        two-dimensional weight matrix (the :math:`w_{ij}` values of weighted
        Cohen's kappa).
    allow_off_by_one : bool, default=False
        If true, ratings that differ by one are treated as equal and every
        other difference is reduced by one when building the weight matrix.

    Returns
    -------
    float
        The weighted or unweighted kappa score.

    Raises
    ------
    AssertionError
        If the two arrays have different lengths.
    ValueError
        If labels cannot be converted to int, or an invalid weight scheme
        is given.
    """
    assert len(y_true) == len(y_pred)

    # Coerce every rating to a rounded int; give a helpful error for
    # strings that cannot be parsed as numbers.
    try:
        y_true = np.array([int(np.round(float(label))) for label in y_true])
        y_pred = np.array([int(np.round(float(label))) for label in y_pred])
    except ValueError:
        raise ValueError(
            "For kappa, the labels should be integers or strings"
            " that can be converted to ints (E.g., '4.0' or "
            "'3')."
        )

    # Shift ratings so the lowest value becomes 0 (supports negative scales).
    lowest = min(min(y_true), min(y_pred))
    highest = max(max(y_true), max(y_pred))
    y_true = y_true - lowest
    y_pred = y_pred - lowest

    # Observed agreement as a confusion matrix over the full rating range.
    num_ratings = highest - lowest + 1
    observed = confusion_matrix(y_true, y_pred, labels=list(range(num_ratings)))
    num_scored_items = float(len(y_true))

    # Resolve the weighting scheme: a string selects a built-in scheme,
    # an array is used verbatim, None means unweighted.
    if isinstance(weights, str):
        wt_scheme, weights = weights, None
    else:
        wt_scheme = ""

    if weights is not None:
        kappa_weights = weights
    else:
        kappa_weights = np.empty((num_ratings, num_ratings))
        for i in range(num_ratings):
            for j in range(num_ratings):
                diff = abs(i - j)
                if allow_off_by_one and diff:
                    diff -= 1
                if wt_scheme == "linear":
                    kappa_weights[i, j] = diff
                elif wt_scheme == "quadratic":
                    kappa_weights[i, j] = diff**2
                elif not wt_scheme:  # unweighted
                    kappa_weights[i, j] = bool(diff)
                else:
                    raise ValueError(
                        "Invalid weight scheme specified for " f"kappa: {wt_scheme}"
                    )

    # Expected agreement from the marginal rating distributions.
    hist_true = np.bincount(y_true, minlength=num_ratings)
    hist_true = hist_true[:num_ratings] / num_scored_items
    hist_pred = np.bincount(y_pred, minlength=num_ratings)
    hist_pred = hist_pred[:num_ratings] / num_scored_items
    expected = np.outer(hist_true, hist_pred)

    # Normalize the observed matrix to frequencies.
    observed = observed / num_scored_items

    # All-zero weights mean no disagreement matters -> perfect agreement.
    k = 1.0
    if np.count_nonzero(kappa_weights):
        k -= np.sum(kappa_weights * observed) / np.sum(kappa_weights * expected)

    return k
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def correlation(
    y_true: np.ndarray, y_pred: np.ndarray, corr_type: str = "pearson"
) -> float:
    """
    Calculate the given correlation type between ``y_true`` and ``y_pred``.

    ``y_pred`` may be multi-dimensional: in that case it is treated as class
    probabilities, the most-likely label per row is inferred via argmax, and
    the correlation is computed between those labels and ``y_true``. A
    1-dimensional ``y_pred`` (probabilities, labels, or regressor outputs)
    is used as-is.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.
    corr_type : str, default="pearson"
        One of "pearson", "spearman", or "kendall_tau"; anything else
        falls back to Pearson.

    Returns
    -------
    float
        Correlation value if well-defined.
    """
    # Dispatch table; unknown types fall back to Pearson, matching the
    # original if/elif chain.
    corr_func = {"spearman": spearmanr, "kendall_tau": kendalltau}.get(
        corr_type, pearsonr
    )

    # Accept plain lists as well as arrays.
    y_pred = np.array(y_pred)

    # Multi-dimensional input is a probability matrix -> take the argmax label.
    if y_pred.ndim > 1:
        y_pred = np.argmax(y_pred, axis=1)

    return corr_func(y_true, y_pred)[0]
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def f1_score_least_frequent(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Calculate the F1 score of the least frequent label/class.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.

    Returns
    -------
    float
        F1 score of the least frequent label.
    """
    # Rarest class according to the gold labels.
    rarest_label = np.bincount(y_true).argmin()
    per_class_f1 = f1_score(y_true, y_pred, average=None)
    return per_class_f1[rarest_label]
|
util/misc.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import builtins
|
| 13 |
+
import datetime
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from collections import defaultdict, deque
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import shutil
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
from torch import inf
|
| 22 |
+
import json
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    sliding window, as well as the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record *value*, counted *n* times toward the global statistics."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum count/total across distributed workers.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        count, total = t.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median over the sliding window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the sliding window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over the entire series, not just the window.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded value.
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class MetricLogger(object):
    """Collect named SmoothedValue meters and pretty-print training progress.

    Meters are created lazily per keyword on first ``update``; ``log_every``
    wraps an iterable and periodically prints timing/ETA/meter summaries.
    """

    def __init__(self, delimiter="\t"):
        # delimiter separates the per-meter strings in the printed output
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Push scalar values into the named meters (None values are skipped)."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                # 0-d tensors are converted to Python scalars before logging
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes (e.g. ``logger.loss``); fall back to the
        # instance dict, otherwise raise a normal attribute error.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        """Join every meter's summary string with the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Reduce every meter's count/total across distributed workers."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        # Register a meter with a custom format (e.g. lr with more precision).
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from *iterable*, printing progress every *print_freq* steps.

        Tracks per-iteration data-loading time and total iteration time,
        estimates time to completion, and (when CUDA is available) reports
        peak GPU memory use.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # width of the iteration counter, padded to the total length
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # time spent waiting for the data loader
            data_time.update(time.time() - end)
            yield obj
            # total time for this iteration (data loading + compute)
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        # Printing can be forced per-call or globally for very large jobs.
        force = kwargs.pop("force", False) or (get_world_size() > 8)
        if is_master or force:
            timestamp = datetime.datetime.now().time()
            builtin_print("[{}] ".format(timestamp), end="")  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is usable and a process group exists."""
    return dist.is_available() and dist.is_initialized()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_world_size():
    """World size of the current process group, or 1 when not distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def get_rank():
    """Rank of the current process, or 0 when not distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def is_main_process():
    """True on rank 0, the process responsible for logging and checkpointing."""
    return get_rank() == 0
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def save_on_master(*args, **kwargs):
    """torch.save that is a no-op on non-master ranks (avoids duplicate writes)."""
    if is_main_process():
        torch.save(*args, **kwargs)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def init_distributed_mode(args):
    """Initialize torch.distributed (NCCL) from environment variables.

    Supports three launch styles: OpenMPI/ITP (``args.dist_on_itp``),
    torchrun-style (``RANK``/``WORLD_SIZE``/``LOCAL_RANK``), and SLURM
    (``SLURM_PROCID``). Falls back to single-process mode otherwise.
    Mutates ``args`` in place (rank, world_size, gpu, distributed,
    dist_backend, dist_url) and replaces builtins.print so only rank 0 logs.
    """
    if args.dist_on_itp:
        args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.dist_url = "tcp://%s:%s" % (
            os.environ["MASTER_ADDR"],
            os.environ["MASTER_PORT"],
        )
        # Re-export in torchrun-style variables for downstream libraries.
        os.environ["LOCAL_RANK"] = str(args.gpu)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # NOTE(review): this branch never sets args.world_size — presumably
        # the caller supplies it (e.g. from SLURM_NTASKS); confirm before
        # relying on SLURM launches.
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}): {}, gpu {}".format(
            args.rank, args.dist_url, args.gpu
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # Silence printing on all ranks except 0.
    setup_for_distributed(args.rank == 0)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class NativeScalerWithGradNormCount:
    """Wrapper around torch.cuda.amp.GradScaler that also reports the
    gradient norm (optionally clipped) on each optimizer update."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            # Gradient-accumulation step: backward only, no optimizer update.
            return None
        # Unscale in-place so clipping / norm computation see true gradients.
        self._scaler.unscale_(optimizer)
        if clip_grad is not None:
            assert parameters is not None
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Expose the underlying GradScaler state for checkpointing."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the underlying GradScaler state from a checkpoint."""
        self._scaler.load_state_dict(state_dict)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Compute the total *norm_type* norm over the gradients of *parameters*.

    Parameters without a gradient are skipped; an empty parameter list
    yields a zero tensor. ``norm_type=inf`` returns the max absolute value.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)
    # Accumulate on the device of the first gradient.
    device = grads[0].device
    if norm_type == inf:
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write a per-epoch training checkpoint to ``args.output_dir``.

    With a loss scaler, saves model/optimizer/scaler/args via torch.save
    (master rank only); otherwise delegates to the model's own
    ``save_checkpoint`` (DeepSpeed-style).
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is None:
        # DeepSpeed-style checkpointing handled by the model itself.
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state={"epoch": epoch},
        )
        return
    state = {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "scaler": loss_scaler.state_dict(),
        "args": args,
    }
    save_on_master(state, output_dir / ("checkpoint-%s.pth" % epoch_name))
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def save_best_model(
    args, epoch, model, model_without_ddp, optimizer, loss_scaler, is_best
):
    """Save the latest checkpoint and, when *is_best*, copy it to best.pth.tar.

    With a loss scaler, writes ``last.pth.tar`` (model/optimizer/scaler/args)
    via torch.save on the master rank; otherwise delegates to the model's own
    ``save_checkpoint`` (DeepSpeed-style).

    BUGFIX: previously ``checkpoint_path`` was never assigned on the
    DeepSpeed branch, so the ``is_best`` copy raised NameError. The best-copy
    is now only performed on the torch.save path; DeepSpeed checkpoints are
    directories and cannot be copied with shutil.copyfile.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_path = output_dir / ("last.pth.tar")
        to_save = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scaler": loss_scaler.state_dict(),
            "args": args,
        }
        save_on_master(to_save, checkpoint_path)
        if is_best:
            filepath_best = output_dir / ("best.pth.tar")
            shutil.copyfile(checkpoint_path, filepath_best)
    else:
        client_state = {"epoch": epoch}
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state=client_state,
        )
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def save_current_best_model(
    args, epoch, model, model_without_ddp, optimizer, loss_scaler, is_best, current_interval
):
    """Save the latest checkpoint for *current_interval* and, when *is_best*,
    copy it to ``<current_interval>_best.pth.tar``.

    Mirrors ``save_best_model`` but prefixes filenames with the interval so
    several training intervals keep independent last/best checkpoints.

    BUGFIX: previously ``checkpoint_path`` was never assigned on the
    DeepSpeed branch (``loss_scaler is None``), so the ``is_best`` copy
    raised NameError. The best-copy is now only performed on the torch.save
    path; DeepSpeed checkpoints are directories and cannot be copied with
    shutil.copyfile.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_path = output_dir / (f"{current_interval}_last.pth.tar")
        to_save = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scaler": loss_scaler.state_dict(),
            "args": args,
        }
        save_on_master(to_save, checkpoint_path)
        if is_best:
            filepath_best = output_dir / (f"{current_interval}_best.pth.tar")
            shutil.copyfile(checkpoint_path, filepath_best)
    else:
        client_state = {"epoch": epoch}
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state=client_state,
        )
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume training state from ``args.resume`` (local path or https URL).

    Loads model weights always; optimizer/epoch/scaler state only when the
    checkpoint contains them and we are not in eval mode. Sets
    ``args.start_epoch`` on a full resume.

    BUGFIX: ``weights_only=False`` was previously passed to
    ``load_state_dict`` on the module, optimizer, and scaler — none of which
    accept that keyword, so every resume raised TypeError. ``weights_only``
    belongs on ``torch.load`` (needed here because the checkpoint stores the
    args Namespace, which the safe default would reject).
    """
    if not args.resume:
        return
    if args.resume.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location="cpu", check_hash=True, weights_only=False
        )
    else:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
    model_without_ddp.load_state_dict(checkpoint["model"])
    print("Resume checkpoint %s" % args.resume)
    if (
        "optimizer" in checkpoint
        and "epoch" in checkpoint
        and not (hasattr(args, "eval") and args.eval)
    ):
        optimizer.load_state_dict(checkpoint["optimizer"])
        args.start_epoch = checkpoint["epoch"] + 1
        if "scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["scaler"])
        print("With optim & sched!")
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def all_reduce_mean(x):
    """Average a scalar across all distributed workers; identity when not
    distributed."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def write_log(log_writer, log_stats, args):
    """Append one JSON line of *log_stats* to <output_dir>/log.txt on the
    main process, flushing the (tensorboard-style) writer first."""
    if not (args.output_dir and is_main_process()):
        return
    if log_writer is not None:
        log_writer.flush()
    log_path = os.path.join(args.output_dir, "log.txt")
    with open(log_path, mode="a", encoding="utf-8") as f:
        f.write(json.dumps(log_stats) + "\n")
|