Anonymise committed on
Commit
45461c9
·
1 Parent(s): 6fd4e87

add necessary module

Browse files
dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # ProFound dataset package
dataset/dataset_cls.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from monai.transforms import (
3
+ Compose,
4
+ RandCropByPosNegLabeld,
5
+ CropForegroundd,
6
+ SpatialPadd,
7
+ ScaleIntensityRanged,
8
+ RandShiftIntensityd,
9
+ RandFlipd,
10
+ RandAffined,
11
+ RandZoomd,
12
+ RandRotated,
13
+ RandBiasFieldd,
14
+ RandRotate90d,
15
+ RandGaussianNoised,
16
+ RandGaussianSmoothd,
17
+ NormalizeIntensityd,
18
+ MapTransform,
19
+ RandScaleIntensityd,
20
+ RandSpatialCropd,
21
+ CenterSpatialCropd,
22
+ )
23
+
24
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
25
+ import torch
26
+ import numpy as np
27
+ import nibabel as nib
28
+ import torch.nn.functional as F
29
+ import os
30
+ import pandas as pd
31
+ from ast import literal_eval
32
+
33
+
34
class RiskSet(Dataset):
    """PI-RADS risk-classification dataset.

    Each row of the CSV provides paths (relative to ``args.root``) to the
    t2w / highb / adc volumes plus a ``pirads`` score; labels are mapped to
    ``pirads - 2`` (so classes 0..3 for PI-RADS 2..5).
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        """Load the sample table from *image_paths* (a CSV file).

        args must provide: ``data_num`` (truncate the train split when > 0),
        ``root`` (image directory) and ``demo`` (skip sampler setup if True).
        """
        super().__init__()
        self.img_dict = pd.read_csv(image_paths)
        if phase == 'train' and args.data_num > 0:
            # crop the dataset to the requested number of training samples
            self.img_dict = self.img_dict.iloc[: args.data_num]
        print(f"Loading {phase} dataset with {len(self.img_dict)} samples")
        self.root = args.root
        self._set_dataset_stat()
        self.transforms = transforms  # self.get_transforms()
        if not args.demo:
            self.set_sampler()

    def set_sampler(self):
        """Compute per-sample inverse-class-frequency weights for a WeightedRandomSampler."""
        class_counts = self.img_dict["pirads"].value_counts().sort_index().values
        class_weights = 1.0 / class_counts
        # pirads values start at 2, so shift to 0-based class indices
        values = self.img_dict["pirads"].values.astype(int) - 2
        self.sampler_weight = class_weights[values]

    def cal_weight(self):
        """Return the per-class sample counts, ascending by class value."""
        return self.img_dict["pirads"].value_counts().sort_index().values

    def _set_dataset_stat(self):
        self.spacing = (0.5, 0.5, 1.0)
        self.spatial_index = [2, 1, 0]  # index used to convert to DHW
        self.target_class = 1

    def __len__(self):
        return len(self.img_dict)

    def read(self, path):
        """Read a NIfTI volume (path relative to self.root) as a DHW float32 tensor."""
        vol = nib.load(os.path.join(self.root, path))
        vol = vol.get_fdata().astype(np.float32).transpose(self.spatial_index)
        return torch.from_numpy(vol)

    def _apply_transforms(self, data):
        """Apply self.transforms; some MONAI transforms return a list of dicts — take the first."""
        out = self.transforms(data)
        # isinstance instead of type(...) == list: same effect, idiomatic
        if isinstance(out, list):
            out = out[0]
        return out

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        t2w = self.read(row["t2w"])
        dwi = self.read(row["highb"])
        adc = self.read(row["adc"])
        img = torch.stack([t2w, dwi, adc], 0)
        label = torch.tensor(int(row["pirads"]) - 2, dtype=torch.long)
        if self.transforms is not None:
            img = self._apply_transforms({"image": img})["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
86
+
87
+
88
class ScreeningSet(RiskSet):
    """Binary screening dataset; the label comes from the ``result`` column."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        # Weight each sample by the inverse frequency of its ``result`` class.
        counts = self.img_dict["result"].value_counts().sort_index().values
        self.sampler_weight = (1.0 / counts)[self.img_dict["result"].values]

    def cal_weight(self):
        """Per-class sample counts, ascending by class value."""
        return self.img_dict["result"].value_counts().sort_index().values

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        label = torch.tensor(int(row["result"]), dtype=torch.long)
        if self.transforms is not None:
            out = self.transforms({"image": img})
            if type(out) == list:
                out = out[0]
            img = out["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
114
+
115
+
116
class PromisSet(RiskSet):
    """PROMIS dataset returning per-zone float labels parsed from ``zone_level``."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        # Inverse-frequency weights indexed by the integer patient_level class.
        counts = self.img_dict["patient_level"].value_counts().sort_index().values
        cls_idx = self.img_dict["patient_level"].values.astype(int)
        self.sampler_weight = (1.0 / counts)[cls_idx]

    def cal_weight(self):
        """Per-class sample counts, ascending by class value."""
        return self.img_dict["patient_level"].value_counts().sort_index().values

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        # ``zone_level`` is stored as a stringified Python list in the CSV.
        zone_level = torch.tensor(literal_eval(row["zone_level"]), dtype=torch.float32)
        if self.transforms is not None:
            out = self.transforms({"image": img})
            if type(out) == list:
                out = out[0]
            img = out["image"]
        return img, zone_level, torch.tensor(idx, dtype=torch.long)
144
+
145
class Promis3HistSet(RiskSet):
    """PROMIS PI-RADS-3 histology dataset; the label comes from the ``def`` column."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        # Inverse-frequency weights indexed by the integer ``def`` class.
        counts = self.img_dict["def"].value_counts().sort_index().values
        cls_idx = self.img_dict["def"].values.astype(int)
        self.sampler_weight = (1.0 / counts)[cls_idx]

    def cal_weight(self):
        """Per-class sample counts, ascending by class value."""
        return self.img_dict["def"].value_counts().sort_index().values

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        label = torch.tensor(int(row["def"]), dtype=torch.long)
        if self.transforms is not None:
            out = self.transforms({"image": img})
            if type(out) == list:
                out = out[0]
            img = out["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
171
+
172
class Promis3GGSet(RiskSet):
    """PROMIS PI-RADS-3 Gleason-grade dataset; label from the ``gleason`` column."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def set_sampler(self):
        # Inverse-frequency weights indexed by the integer gleason class.
        counts = self.img_dict["gleason"].value_counts().sort_index().values
        cls_idx = self.img_dict["gleason"].values.astype(int)
        self.sampler_weight = (1.0 / counts)[cls_idx]

    def cal_weight(self):
        """Per-class sample counts, ascending by class value."""
        return self.img_dict["gleason"].value_counts().sort_index().values

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        label = torch.tensor(int(row["gleason"]), dtype=torch.long)
        if self.transforms is not None:
            out = self.transforms({"image": img})
            if type(out) == list:
                out = out[0]
            img = out["image"]
        return img, label, torch.tensor(idx, dtype=torch.long)
198
+
199
+
200
def get_transforms(args):
    """Build the MONAI transform pipelines for the classification datasets.

    Uses ``args.crop_spatial_size`` as the final crop size.
    Returns (train_transforms, val_transforms, test_transforms); val and
    test pipelines are identical and deterministic.
    """
    # Training: normalize, spatial augmentation, then intensity augmentation.
    train_transforms = [
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        CenterSpatialCropd(keys="image", roi_size=(80, 300, 300)),
        RandRotated(
            keys="image",
            prob=0.3,
            range_x=10 / 180 * np.pi,  # up to +/- 10 degrees per axis
            range_y=10 / 180 * np.pi,
            range_z=10 / 180 * np.pi,
            keep_size=False,
            mode="bilinear",
        ),
        RandZoomd(
            keys="image",
            prob=0.3,
            min_zoom=[0.9, 0.9, 0.9],
            max_zoom=[1.1, 1.1, 1.1],
            mode="trilinear",
        ),
        SpatialPadd(
            keys="image",
            # pad 20% larger than the crop so the random crop has room to move
            spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
        ),
        RandSpatialCropd(
            keys="image",
            roi_size=args.crop_spatial_size,
            random_size=False,
        ),
        RandFlipd(keys="image", prob=0.5, spatial_axis=2),
        # BinarizeLabeld(keys=["label"])
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
        RandBiasFieldd(keys="image", prob=0.2),
        RandGaussianSmoothd(keys="image", prob=1.0)
    ]

    train_transforms = Compose(train_transforms)
    # Validation/test: deterministic normalize + center crop + pad only.
    val_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(keys="image", roi_size=args.crop_spatial_size),
            SpatialPadd(keys="image", spatial_size=[i for i in args.crop_spatial_size]),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    test_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(keys="image", roi_size=args.crop_spatial_size),
            SpatialPadd(keys="image", spatial_size=[i for i in args.crop_spatial_size]),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    return train_transforms, val_transforms, test_transforms
255
+
256
+
257
def build_Risk_loader(args):
    """Build loaders for the PI-RADS risk task.

    Side effects: sets args.in_channels = 3 and args.num_classes = 4.
    In demo mode only the test loader is returned; otherwise a
    (train_loader, val_loader, test_loader) tuple, with the train loader
    driven by an inverse-class-frequency WeightedRandomSampler.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds):
        # Deterministic loader used for val / test / demo evaluation.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    if args.demo:
        test_set = RiskSet(args, "demo/data/risk/test.csv", 'test', test_transforms)
        args.in_channels = 3
        args.num_classes = 4
        return eval_loader(test_set)

    train_csv = "spilt/risk/train_16.csv" if args.data20 else "spilt/risk/train.csv"
    train_set = RiskSet(args, train_csv, 'train', train_transforms)
    val_set = RiskSet(args, "spilt/risk/val.csv", 'val', val_transforms)
    test_set = RiskSet(args, "spilt/risk/test.csv", 'test', test_transforms)

    sampler = WeightedRandomSampler(
        weights=train_set.sampler_weight, num_samples=len(train_set), replacement=True
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        drop_last=False,
        pin_memory=True,
    )
    args.in_channels = 3
    args.num_classes = 4
    return train_loader, eval_loader(val_set), eval_loader(test_set)
311
+
312
+
313
def build_Screening_loader(args):
    """Build train/val/test loaders for the binary screening task.

    Side effects: sets args.cls_account (per-class frequency of the train
    split), args.in_channels = 3 and args.num_classes = 2.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    if args.kfold is None:
        if args.data20:
            train_set = ScreeningSet(
                args, "spilt/screening/train_20.csv", 'train', train_transforms
            )
        else:
            train_set = ScreeningSet(
                args, "spilt/screening/train.csv", 'train', train_transforms
            )
        val_set = ScreeningSet(args, "spilt/screening/val.csv", 'val', val_transforms)
        test_set = ScreeningSet(args, "spilt/screening/test.csv", 'test', test_transforms)
        args.cls_account = train_set.cal_weight() / len(train_set)
        sampler_weight = train_set.sampler_weight
    else:
        # BUG FIX: the phase argument was missing here, so train_transforms
        # was silently consumed as ``phase`` and no transforms were applied.
        full_set = ScreeningSet(
            args, f"spilt/screening/train_{args.kfold}.csv", 'train', train_transforms
        )
        args.cls_account = full_set.cal_weight() / len(full_set)
        train_set, val_set = torch.utils.data.random_split(full_set, [0.9, 0.1])
        # NOTE(review): val_set is a Subset sharing the underlying dataset with
        # train_set, so this assignment cannot switch it to val_transforms;
        # kept for parity with the original code — needs a deeper fix.
        val_set.transforms = val_transforms
        test_set = ScreeningSet(
            args, f"spilt/screening/test_{args.kfold}.csv", 'test', test_transforms
        )
        # BUG FIX: random_split returns Subsets, which have no sampler_weight
        # attribute; index into the parent dataset's weights instead.
        sampler_weight = [full_set.sampler_weight[i] for i in train_set.indices]

    sampler = WeightedRandomSampler(
        weights=sampler_weight, num_samples=len(train_set), replacement=True
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    args.in_channels = 3
    args.num_classes = 2
    return train_loader, val_loader, test_loader
369
+
370
+
371
+ # 4.0 453
372
+ # 3.0 206
373
+ # 5.0 195
374
+ # 2.0 174
375
+
376
+
377
def build_Promis_loader(args):
    """Build train/val/test loaders for the PROMIS zone-level task.

    Side effects: sets args.in_channels = 3 and args.num_classes = 20.
    NOTE(review): the train loader has neither shuffle nor a sampler
    (the weighted sampler is disabled), so it iterates in CSV order —
    kept as-is to preserve behavior.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds):
        # Deterministic loader used for val / test evaluation.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_csv = (
        "spilt/promis567_hist/train_20.csv" if args.data20
        else "spilt/promis567_hist/train.csv"
    )
    train_set = PromisSet(args, train_csv, 'train', train_transforms)
    val_set = PromisSet(args, "spilt/promis567_hist/val.csv", 'val', val_transforms)
    test_set = PromisSet(args, "spilt/promis567_hist/test.csv", 'test', test_transforms)

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    args.in_channels = 3
    args.num_classes = 20
    return train_loader, eval_loader(val_set), eval_loader(test_set)
415
+
416
def build_Promis3_hist_loader(args):
    """Build train/val/test loaders for the PI-RADS-3 histology task.

    Side effects: sets args.in_channels = 3 and args.num_classes = 3.
    The train loader is unshuffled/unsampled, matching the original code.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds):
        # Deterministic loader used for val / test evaluation.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_set = Promis3HistSet(args, "spilt/promis_pirads3_hist/train.csv", 'train', train_transforms)
    val_set = Promis3HistSet(args, "spilt/promis_pirads3_hist/val.csv", 'val', val_transforms)
    test_set = Promis3HistSet(args, "spilt/promis_pirads3_hist/test.csv", 'test', test_transforms)

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    args.in_channels = 3
    args.num_classes = 3
    return train_loader, eval_loader(val_set), eval_loader(test_set)
448
+
449
def build_Promis3_gg_loader(args):
    """Build train/val/test loaders for the PI-RADS-3 Gleason-grade task.

    Side effects: sets args.in_channels = 3 and args.num_classes = 5.
    The train loader is unshuffled/unsampled, matching the original code.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds):
        # Deterministic loader used for val / test evaluation.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_set = Promis3GGSet(args, "spilt/promis_pirads3_gg/train.csv", 'train', train_transforms)
    val_set = Promis3GGSet(args, "spilt/promis_pirads3_gg/val.csv", 'val', val_transforms)
    test_set = Promis3GGSet(args, "spilt/promis_pirads3_gg/test.csv", 'test', test_transforms)

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True,
    )
    args.in_channels = 3
    args.num_classes = 5
    return train_loader, eval_loader(val_set), eval_loader(test_set)
dataset/dataset_seg.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from monai.transforms import (
3
+ Compose,
4
+ RandCropByPosNegLabeld,
5
+ CropForegroundd,
6
+ SpatialPadd,
7
+ ScaleIntensityRanged,
8
+ RandShiftIntensityd,
9
+ RandFlipd,
10
+ RandAffined,
11
+ RandZoomd,
12
+ RandRotated,
13
+ RandRotate90d,
14
+ RandGaussianNoised,
15
+ RandGaussianSmoothd,
16
+ NormalizeIntensityd,
17
+ RandBiasFieldd,
18
+ MapTransform,
19
+ RandScaleIntensityd,
20
+ RandSpatialCropd,
21
+ CenterSpatialCropd,
22
+ )
23
+
24
+ from torch.utils.data import DataLoader, Dataset
25
+ import torch
26
+ import numpy as np
27
+ import nibabel as nib
28
+ import torch.nn.functional as F
29
+ import os
30
+ import pandas as pd
31
+
32
+
33
class BaseVolumeDataset(Dataset):
    """Base class for the segmentation datasets: CSV-driven NIfTI volume loading.

    Subclasses must override ``__getitem__``.
    """

    def __init__(self, args, image_paths, phase, transforms=None):
        """Load the sample table from *image_paths* (a CSV file).

        args must provide: ``data_num`` (truncate the train split when > 0)
        and ``root`` (image directory).
        """
        super().__init__()
        self.img_dict = pd.read_csv(image_paths)
        if phase == 'train' and args.data_num > 0:
            # crop the dataset to the requested number of training samples
            self.img_dict = self.img_dict.iloc[: args.data_num]
        print(f"Loading {phase} dataset with {len(self.img_dict)} samples")
        self.root = args.root
        self._set_dataset_stat()
        self.transforms = transforms  # self.get_transforms()

    def _set_dataset_stat(self):
        self.spacing = (0.5, 0.5, 1.0)
        self.spatial_index = [2, 1, 0]  # index used to convert to DHW
        self.target_class = 1

    def __len__(self):
        return len(self.img_dict)

    def read(self, path):
        """Read a NIfTI volume (path relative to self.root) as a DHW float32 tensor."""
        vol = nib.load(os.path.join(self.root, path))
        vol = vol.get_fdata().astype(np.float32).transpose(self.spatial_index)
        return torch.from_numpy(vol)

    def __getitem__(self, idx):
        # BUG FIX: previously returned the NotImplemented constant, which a
        # DataLoader would silently pass through; raise loudly instead.
        raise NotImplementedError("subclasses must implement __getitem__")
62
+
63
+
64
class UCLSet(BaseVolumeDataset):
    """Lesion segmentation dataset: 3-channel (t2w, dwi, adc) image + binary mask."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        # Binarize the lesion mask: any positive voxel counts as foreground.
        seg = self.read(row["lesion"]).unsqueeze(0) > 0
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": seg})
            if type(out) == list:
                out = out[0]
            img, seg = out["image"], out["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
84
+
85
+ # TODO: need to update; unfinished
86
+ """
87
+ class UCL2DSet(BaseVolumeDataset):
88
+ def __init__(self, args, image_paths, phase, transforms=None):
89
+ super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)
90
+
91
+ def __getitem__(self, idx):
92
+ path = self.img_dict.iloc[idx]
93
+ t2w = self.read(path["t2w"])
94
+ dwi = self.read(path["dwi"])
95
+ adc = self.read(path["adc"])
96
+
97
+ seg = self.read(path["lesion"]).unsqueeze(0)
98
+ seg = seg > 0
99
+
100
+ seg_mask = seg.squeeze(0).numpy()
101
+ non_zero_slices = np.where(seg_mask.any(axis=1,2))[0]
102
+ if len(non_zero_slices) > 0:
103
+ sampled_slices = np.random.choice(non_zero_slices, min(N, len(non_zero_slices)), replace=False)
104
+ filtered_seg = np.zeros_like(seg_mask)
105
+ filtered_seg[sampled_slices] = seg_mask[sampled_slices]
106
+ else:
107
+ filtered_seg = seg_mask
108
+
109
+ img = torch.stack([t2w, dwi, adc], 0)
110
+ seg = torch.tensor(filtered_seg, dtype=torch.float32).unsqueeze(0)
111
+ if self.transforms is not None:
112
+ trans_dict = self.transforms({"image": img, "label": seg})
113
+ if type(trans_dict) == list:
114
+ trans_dict = trans_dict[0]
115
+ img, seg = trans_dict["image"], trans_dict["label"]
116
+ return img, seg, torch.tensor(idx, dtype=torch.long)
117
+ """
118
+
119
class AnatomySet(BaseVolumeDataset):
    """Anatomy segmentation from T2W only; channels 2-3 are zero-filled."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        t2w = self.read(row["t2w"])
        # Zero-fill the other two channels to align with the 3-channel datasets.
        pad = torch.zeros_like(t2w)
        img = torch.stack([t2w, pad, pad], 0)
        seg = self.read(row["mask"]).unsqueeze(0)
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": seg})
            if type(out) == list:
                out = out[0]
            img, seg = out["image"], out["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
136
+
137
+
138
class BpAnatomySet(BaseVolumeDataset):
    """Biparametric anatomy segmentation; same layout as AnatomySet (T2W + zero pads)."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        t2w = self.read(row["t2w"])
        # Zero-fill the other two channels to align with the 3-channel datasets.
        pad = torch.zeros_like(t2w)
        img = torch.stack([t2w, pad, pad], 0)
        seg = self.read(row["mask"]).unsqueeze(0)
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": seg})
            if type(out) == list:
                out = out[0]
            img, seg = out["image"], out["label"]
        return img, seg, torch.tensor(idx, dtype=torch.long)
154
+
155
class PromisHist(BaseVolumeDataset):
    """PROMIS histology dataset: 3-channel image, gland mask and per-zone labels."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        zone_mask = self.read(row["gland"]).unsqueeze(0)
        # ``zone_label`` is a whitespace-separated list of ints in the CSV.
        zone_level = torch.tensor([int(tok) for tok in row["zone_label"].split()])
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": zone_mask})
            if type(out) == list:
                out = out[0]
            img, zone_mask = out["image"], out["label"]
        return img, zone_mask, zone_level
178
+
179
class PromisZone(BaseVolumeDataset):
    """PROMIS zone dataset: 3-channel image, zone mask and per-zone labels."""

    def __init__(self, args, image_paths, phase, transforms=None):
        super().__init__(args=args, image_paths=image_paths, phase=phase, transforms=transforms)

    def __getitem__(self, idx):
        row = self.img_dict.iloc[idx]
        img = torch.stack([self.read(row[k]) for k in ("t2w", "dwi", "adc")], 0)
        # NOTE(review): "zome_mask" looks like a typo for "zone_mask", but it
        # must match the CSV column name — verify against the data before renaming.
        zone_mask = self.read(row["zome_mask"]).unsqueeze(0)
        # ``zone_label`` is a whitespace-separated list of ints in the CSV.
        zone_level = torch.tensor([int(tok) for tok in row["zone_label"].split()])
        if self.transforms is not None:
            out = self.transforms({"image": img, "label": zone_mask})
            if type(out) == list:
                out = out[0]
            img, zone_mask = out["image"], out["label"]
        return img, zone_mask, zone_level
202
+
203
def get_transforms(args):
    """Build the MONAI transform pipelines for the segmentation datasets.

    Spatial transforms are applied to both "image" and "label" (nearest-neighbor
    interpolation for labels); intensity transforms touch "image" only.
    Uses ``args.crop_spatial_size`` as the final crop size.
    Returns (train_transforms, val_transforms, test_transforms); val and
    test pipelines are identical and deterministic.
    """
    train_transforms = [
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandRotated(
            keys=["image", "label"],
            prob=0.3,
            range_x=30 / 180 * np.pi,  # up to +/- 30 degrees, first axis only
            keep_size=False,
            mode=["bilinear", "nearest"],
        ),
        RandZoomd(
            keys=["image", "label"],
            prob=0.3,
            # no zoom along the first (depth) axis
            min_zoom=[1, 0.9, 0.9],
            max_zoom=[1, 1.1, 1.1],
            mode=["trilinear", "nearest"],
        ),
        SpatialPadd(
            keys=["image", "label"],
            # pad 20% larger than the crop so the random crop has room to move
            spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
        ),
        # RandCropByPosNegLabeld(
        #     keys=["image", "label"],
        #     spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
        #     label_key="label",
        #     pos=2,
        #     neg=1,
        #     num_samples=1,
        # ),
        RandSpatialCropd(
            keys=["image", "label"],
            roi_size=args.crop_spatial_size,
            random_size=False,
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        # BinarizeLabeld(keys=["label"])
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
        RandBiasFieldd(keys="image", prob=0.2),
        RandGaussianSmoothd(keys="image", prob=1.0)
    ]

    train_transforms = Compose(train_transforms)
    # Validation/test: deterministic normalize + center crop + pad only.
    val_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(
                keys=["image", "label"], roi_size=args.crop_spatial_size
            ),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=[i for i in args.crop_spatial_size],
            ),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    test_transforms = Compose(
        [
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            CenterSpatialCropd(
                keys=["image", "label"], roi_size=args.crop_spatial_size
            ),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=[i for i in args.crop_spatial_size],
            ),
            # BinarizeLabeld(keys=["label"])
        ]
    )
    return train_transforms, val_transforms, test_transforms
273
+
274
+
275
def build_UCL_loader(args):
    """Build loaders for lesion segmentation on the UCL split.

    Side effects: sets args.in_channels = 3, args.out_channels = 1 and
    args.num_classes = 1. In demo mode only the test loader is returned.
    Test batches are size 1 (full-volume evaluation).
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds, batch):
        # Deterministic loader used for val / test / demo evaluation.
        return DataLoader(
            ds,
            batch_size=batch,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    if args.demo:
        test_set = UCLSet(args, "demo/data/UCL/test.csv", 'test', test_transforms)
        args.in_channels = 3
        args.out_channels = 1
        args.num_classes = 1
        return eval_loader(test_set, 1)

    train_csv = "spilt/UCL/train_16.csv" if args.data20 else "spilt/UCL/train.csv"
    train_set = UCLSet(args, train_csv, 'train', train_transforms)
    val_set = UCLSet(args, "spilt/UCL/val.csv", 'val', val_transforms)
    test_set = UCLSet(args, "spilt/UCL/test.csv", 'test', test_transforms)

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=14,
        drop_last=True,
    )
    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, eval_loader(val_set, args.batch_size), eval_loader(test_set, 1)
326
+
327
+
328
def build_Promis_loader(args):
    """Build loaders for lesion segmentation on the PROMIS 5/6/7 split.

    Side effects: sets args.in_channels = 3, args.out_channels = 1 and
    args.num_classes = 1. Test batches are size 1 (full-volume evaluation).
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds, batch):
        # Deterministic loader used for val / test evaluation.
        return DataLoader(
            ds,
            batch_size=batch,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_csv = (
        "spilt/promis567/train_20.csv" if args.data20 else "spilt/promis567/train.csv"
    )
    train_set = UCLSet(args, train_csv, 'train', train_transforms)
    val_set = UCLSet(args, "spilt/promis567/val.csv", 'val', val_transforms)
    test_set = UCLSet(args, "spilt/promis567/test.csv", 'test', test_transforms)

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, eval_loader(val_set, args.batch_size), eval_loader(test_set, 1)
364
+
365
+
366
def build_Anatomy_loader(args):
    """Build loaders for the 8-class anatomy segmentation task.

    Side effects: sets args.in_channels = 3, args.out_channels = 9 and
    args.num_classes = 8. NOTE(review): the test set deliberately uses only
    intensity normalization (no crop/pad); test_transforms is unused here.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    def eval_loader(ds, batch):
        # Deterministic loader used for val / test evaluation.
        return DataLoader(
            ds,
            batch_size=batch,
            shuffle=False,
            pin_memory=True,
            num_workers=14,
            drop_last=False,
        )

    train_csv = "spilt/anatomy/train_20.csv" if args.data20 else "spilt/anatomy/train.csv"
    train_set = AnatomySet(args, train_csv, 'train', train_transforms)
    val_set = AnatomySet(args, "spilt/anatomy/val.csv", 'val', val_transforms)
    test_set = AnatomySet(
        args,
        "spilt/anatomy/test.csv",
        'test',
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    if args.prompt:
        # TODO: need to update; currently not in use
        args.in_channels = 3
    else:
        args.in_channels = 3
    args.out_channels = 9
    args.num_classes = 8
    return train_loader, eval_loader(val_set, args.batch_size), eval_loader(test_set, 1)
411
+
412
+
413
def build_BpAnatomy_loader(args):
    """Build train/val/test DataLoaders for the body-part anatomy task.

    Uses the same CSV splits as the anatomy task but wraps them in
    ``BpAnatomySet``. Also sets model I/O config on ``args`` (3 input
    channels, 9 outputs for 8 classes + background).

    NOTE(review): unlike the sibling builders, these loaders use
    ``num_workers=4`` and no ``pin_memory`` — confirm whether that is
    intentional (behavior preserved here).
    """
    train_transforms, val_transforms, _ = get_transforms(args)

    train_csv = "spilt/anatomy/train_20.csv" if args.data20 else "spilt/anatomy/train.csv"
    train_set = BpAnatomySet(args, train_csv, 'train', train_transforms)
    val_set = BpAnatomySet(args, "spilt/anatomy/val.csv", 'val', val_transforms)
    # Evaluation uses intensity normalization only — no augmentation.
    test_set = BpAnatomySet(
        args,
        "spilt/anatomy/test.csv",
        'test',
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    )

    common = dict(num_workers=4, drop_last=False)
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **common
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False, **common
    )
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, **common)

    args.in_channels = 3
    args.out_channels = 9
    args.num_classes = 8
    return train_loader, val_loader, test_loader
447
+
448
+
449
def build_PromisHist_loader(args):
    """Build train/val/test DataLoaders for the PROMIS-567 histology task.

    Binary classification setup: 3 input channels, a single logit output.

    Args:
        args: Parsed CLI namespace providing ``data20`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    if args.data20:
        train_set = PromisHist(args, "spilt/promis567_hist/train_20.csv", 'train', train_transforms)
    else:
        train_set = PromisHist(args, "spilt/promis567_hist/train.csv", 'train', train_transforms)
    val_set = PromisHist(args, "spilt/promis567_hist/val.csv", 'val', val_transforms)
    test_set = PromisHist(args, "spilt/promis567_hist/test.csv", 'test', test_transforms)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        # FIX: was drop_last=True, which silently discards the final partial
        # validation batch (and yields an EMPTY loader when the val split is
        # smaller than batch_size), biasing the validation loss. Every other
        # val loader in this module uses drop_last=False.
        drop_last=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
485
+
486
def build_PromisZone_loader(args):
    """Build train/val/test DataLoaders for the PROMIS zone-level task.

    Binary classification setup: 3 input channels, a single logit output.

    Args:
        args: Parsed CLI namespace providing ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)
    train_set = PromisZone(args, "spilt/promis_zone/train.csv", 'train', train_transforms)
    val_set = PromisZone(args, "spilt/promis_zone/val.csv", 'val', val_transforms)
    test_set = PromisZone(args, "spilt/promis_zone/test.csv", 'test', test_transforms)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=14,
        # Training keeps drop_last=True (stable batch statistics) — unchanged.
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        # FIX: was drop_last=True, which silently discards the final partial
        # validation batch and biases the reported validation loss; the other
        # val loaders in this module all use drop_last=False.
        drop_last=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
        num_workers=14,
        drop_last=False,
    )
    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
519
+
520
+
521
def build_PromisPirads3_loader(args):
    """Build train/val/test DataLoaders for the PROMIS PI-RADS-3 task.

    Binary classification setup: 3 input channels, a single logit output.
    Note the reduced training split here is ``train_15.csv`` (not ``_20``).
    """
    train_transforms, val_transforms, test_transforms = get_transforms(args)

    train_csv = (
        "spilt/promis_pirads3/train_15.csv"
        if args.data20
        else "spilt/promis_pirads3/train.csv"
    )
    train_set = UCLSet(args, train_csv, 'train', train_transforms)
    val_set = UCLSet(args, "spilt/promis_pirads3/val.csv", 'val', val_transforms)
    test_set = UCLSet(args, "spilt/promis_pirads3/test.csv", 'test', test_transforms)

    common = dict(pin_memory=True, num_workers=14, drop_last=False)
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **common
    )
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False, **common
    )
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, **common)

    args.in_channels = 3
    args.out_channels = 1
    args.num_classes = 1
    return train_loader, val_loader, test_loader
demo_classfication.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import argparse
12
+ import datetime
13
+ import json
14
+ import numpy as np
15
+ import os
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Callable, List, Optional, Tuple
19
+ import torch
20
+ import torch.backends.cudnn as cudnn
21
+ from models.classifier import Classifier
22
+ from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
23
+ from dataset.dataset_cls import build_Risk_loader, build_Screening_loader, build_Promis_loader, build_Promis3_hist_loader
24
+ from engine.classification import test_risk
25
+
26
def tuple_type(strings):
    """Parse a string such as ``"(64, 256, 256)"`` into a tuple of ints.

    Used as an argparse ``type=`` converter for spatial-size flags.
    """
    cleaned = strings.replace("(", "").replace(")", "")
    return tuple(int(token) for token in cleaned.split(","))
30
+
31
def get_args_parser():
    """Build the argparse parser for the classification demo.

    Returns:
        argparse.ArgumentParser; call ``.parse_args()`` for the namespace.
    """

    def str2bool(value):
        # FIX: argparse ``type=bool`` treats ANY non-empty string as True
        # (so ``--demo False`` was True). Parse common spellings explicitly.
        return value.strip().lower() in ("true", "1", "yes", "y")

    parser = argparse.ArgumentParser("segmentation", add_help=False)
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument(
        "--root", default="./", type=str
    )
    parser.add_argument("--crop_spatial_size", default=(64, 256, 256), type=tuple_type)

    # Model parameters
    parser.add_argument("--model", help="model name")
    parser.add_argument(
        "--input_size", default=(64, 256, 256), type=tuple_type, help="images input size"
    )
    # NOTE(review): "fintune" typo kept for CLI backward compatibility.
    parser.add_argument(
        "--train",
        default="scratch",
        choices=["fintune", "freeze", "scratch"],
        help="train method",
    )
    parser.add_argument("--pretrain", default=None, type=str)
    parser.add_argument("--tolerance", default=5, type=int)
    # NOTE(review): type=tuple would split a string into characters; the
    # default tuple is what's actually used — confirm before passing via CLI.
    parser.add_argument("--spacing", default=(1.0, 0.5, 0.5), type=tuple)
    # Optimizer parameters
    parser.add_argument(
        "--weight_decay", type=float, default=1e-5, help="weight decay (default: 1e-5)"
    )
    parser.add_argument(
        "--lr",
        default=0.1,
        type=float,
        metavar="LR",
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=0.0,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0",
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR"
    )

    # Dataset parameters
    parser.add_argument(
        "--output_dir",
        default="./outputcls",
        help="path where to save, empty for no saving",
    )
    parser.add_argument("--file_name", default="")
    parser.add_argument("--ckpt_dir", default="./outputcls")
    parser.add_argument(
        "--log_dir", default="./outputcls", help="path where to tensorboard log"
    )
    parser.add_argument("--dataset", default="UCL", help="dataset name")
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")

    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    parser.add_argument("--data20", action="store_true", help="Use 20 training data")
    parser.set_defaults(data20=False)

    parser.add_argument("--data_num", default=0, type=int, help="number of train data")

    parser.add_argument("--save_fig", action="store_true")
    parser.set_defaults(save_fig=False)

    parser.add_argument(
        "--prompt", action="store_true", help="Use visual prompt tuning"
    )
    # FIX: this line previously repeated ``set_defaults(data20=False)`` — a
    # copy-paste slip; the segmentation demo sets ``prompt=False`` here.
    parser.set_defaults(prompt=False)

    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )
    parser.add_argument("--kfold", type=int, default=None)
    parser.add_argument("--demo", type=str2bool, default=True, help="Run in demo mode")
    return parser
135
+
136
+
137
def main(args):
    """Run the classification demo: load data + checkpoint, dump test logits.

    Writes ``{output_dir}/{dataset}/{file_name}.npz`` with raw logits and
    ground-truth labels; no metrics are computed here (see commented call
    to ``test_risk``).
    """

    device = "cuda"
    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True


    if args.dataset == "risk":
        # NOTE(review): assigned to a single variable here, so in demo mode
        # build_Risk_loader presumably returns only the test loader — confirm
        # against dataset.dataset_cls (other builders return a 3-tuple).
        data_loader_test = build_Risk_loader(args)
    # elif args.dataset == "screening":
    #     data_loader_train, data_loader_val, data_loader_test = build_Screening_loader(
    #         args
    #     )
    # elif args.dataset == "promis":
    #     data_loader_train, data_loader_val, data_loader_test = build_Promis_loader(args)
    # elif args.dataset == "promis3hist":
    #     data_loader_train, data_loader_val, data_loader_test = build_Promis3_hist_loader(args)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {args.dataset}")
    print(f"Loaded dataset: {args.dataset}, test set size: {len(data_loader_test.dataset)}")

    if args.model == "profound_conv":
        # ConvNeXt-V2 tiny backbone wrapped with a classification head;
        # args.num_classes is set as a side effect of the loader builder.
        convnext = convnextv2_tiny(in_chans=3)
        model = Classifier(convnext, args.num_classes)
    else:
        raise NotImplementedError(f"unknown model: {args.model}")

    # Results are grouped per dataset under the output directory.
    args.output_dir = os.path.join(args.output_dir, args.dataset)
    os.makedirs(args.output_dir, exist_ok=True)

    # weights_only=False: checkpoint is a dict containing the "model" state.
    model.load_state_dict(torch.load(args.ckpt_dir, map_location='cpu', weights_only=False)["model"])
    print(f"Loaded model from {args.ckpt_dir}")
    model.to(device)
    logits, gts = [], []
    model.eval()
    with torch.no_grad():
        for idx, (img, gt, pid) in enumerate(data_loader_test):
            # NOTE(review): tensors go to args.device while the model was
            # moved to the local "cuda" — both default to CUDA; confirm they
            # always match.
            img, gt = img.to(args.device), gt.to(args.device)
            logit = model(img)
            logits.append(logit)
            gts.append(gt)

    # if args.dataset == "risk":
    #     test_risk(logits, gts)
    # Collapse batch dimension and persist raw (un-normalized) logits.
    logits = torch.cat(logits, 0).squeeze().cpu().numpy()
    gts = torch.cat(gts, 0).squeeze().cpu().numpy()
    print(f"test results: logits {logits}, gts {gts}")
    np.savez(os.path.join(args.output_dir, f"{args.file_name}.npz"), logits = logits, gts=gts)
188
+
189
if __name__ == "__main__":
    # Parse CLI arguments and run the classification demo.
    parser = get_args_parser()
    main(parser.parse_args())
demo_segmentation.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import argparse
12
+ import datetime
13
+ import json
14
+ import numpy as np
15
+ import os
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Callable, List, Optional, Tuple
19
+ import torch
20
+ import torch.backends.cudnn as cudnn
21
+ from dataset.dataset_seg import (
22
+ build_UCL_loader,
23
+ build_Anatomy_loader,
24
+ build_BpAnatomy_loader,
25
+ build_Promis_loader,
26
+ build_PromisPirads3_loader
27
+ )
28
+ import monai
29
+ from monai.inferers import sliding_window_inference
30
+ from monai.metrics import compute_dice
31
+ import SimpleITK as sitk
32
+ from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
33
+ from models.convnext_unter import ConvnextUNETR
34
+ from models.upernet_module import UperNet
35
+
36
+
37
def tuple_type(strings):
    """Argparse converter: turn ``"(d, h, w)"`` into a tuple of ints."""
    digits = strings.replace("(", "").replace(")", "")
    return tuple(map(int, digits.split(",")))
41
+
42
+
43
+
44
def get_args_parser():
    """Build the argparse parser for the segmentation demo.

    Returns:
        argparse.ArgumentParser; call ``.parse_args()`` for the namespace.
    """

    def str2bool(value):
        # FIX: argparse ``type=bool`` treats ANY non-empty string as True
        # (so ``--demo False`` was True). Parse common spellings explicitly.
        return value.strip().lower() in ("true", "1", "yes", "y")

    parser = argparse.ArgumentParser("segmentation", add_help=False)
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument(
        "--root", default="./", type=str
    )
    parser.add_argument("--crop_spatial_size", default=(64, 256, 256), type=tuple_type)

    # Model parameters
    parser.add_argument("--model", help="model name")
    parser.add_argument(
        "--input_size", default=(64, 256, 256), type=tuple_type, help="images input size"
    )
    # NOTE(review): "fintune" typo kept for CLI backward compatibility.
    parser.add_argument(
        "--train",
        default="scratch",
        choices=["fintune", "freeze", "scratch"],
        help="train method",
    )
    parser.add_argument("--pretrain", default=None, type=str)
    parser.add_argument("--tolerance", default=5, type=int)
    # NOTE(review): type=tuple would split a string into characters; the
    # default tuple is what's actually used — confirm before passing via CLI.
    parser.add_argument("--spacing", default=(1.0, 0.5, 0.5), type=tuple)
    # Optimizer parameters
    parser.add_argument(
        "--weight_decay", type=float, default=1e-5, help="weight decay (default: 1e-5)"
    )
    parser.add_argument(
        "--lr",
        default=0.1,
        type=float,
        metavar="LR",
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=0.0,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0",
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR"
    )

    # Dataset parameters
    parser.add_argument(
        "--output_dir",
        default="./outputseg",
        help="path where to save, empty for no saving",
    )
    parser.add_argument("--file_name", default="")
    parser.add_argument("--ckpt_dir", default="./outputseg")
    parser.add_argument(
        "--log_dir", default="./outputseg", help="path where to tensorboard log"
    )
    parser.add_argument("--dataset", default="UCL", help="dataset name")
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")

    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    parser.add_argument("--data20", action="store_true", help="Use 20 training data")
    parser.set_defaults(data20=False)

    parser.add_argument("--data_num", default=0, type=int, help="number of train data")

    parser.add_argument("--save_fig", action="store_true")
    parser.set_defaults(save_fig=False)

    parser.add_argument(
        "--prompt", action="store_true", help="Use visual prompt tuning"
    )
    parser.set_defaults(prompt=False)

    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )
    parser.add_argument("--demo", type=str2bool, default=True, help="Run in demo mode")
    return parser
147
+
148
+
149
def main(args):
    """Run the segmentation demo: evaluate a checkpoint on the test split.

    Computes per-case Dice, optionally saves NIfTI volumes for the first 20
    cases, and writes all Dice scores to ``{output_dir}/{dataset}/{file_name}.npy``.
    """

    device = "cuda"
    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    if args.dataset == "UCL":
        # NOTE(review): single return value here — in demo mode
        # build_UCL_loader presumably returns only the test loader; confirm
        # against dataset.dataset_seg.
        data_loader_test = build_UCL_loader(args)
        # Whole-volume inference for this dataset; no sliding window.
        args.sliding_window = False

    else:
        raise NotImplementedError(f"unknown schedule sampler: {args.dataset}")
    print(f"Loaded dataset: {args.dataset}, test set size: {len(data_loader_test)}")

    if args.model == "profound_conv":
        # ConvNeXt-V2 tiny encoder + UPerNet decoder; channel list matches
        # the tiny variant's four stage widths.
        convnext = convnextv2_tiny(in_chans=3)
        model = UperNet(
            encoder=convnext,
            in_channels=[96, 192, 384, 768],
            out_channels=args.out_channels,
        )
        model = model.to(device)

    elif args.model == "profound_conv_unetr3d":
        convnext = convnextv2_tiny(in_chans=3)

        model = ConvnextUNETR(
            in_channels=3, out_channels=1, convnext=convnext, feature_size=32
        )
        model = model.to(device)

    else:
        raise NotImplementedError(f"unknown model: {args.model}")


    # Results are grouped per dataset under the output directory.
    args.output_dir = os.path.join(args.output_dir, args.dataset)
    os.makedirs(args.output_dir, exist_ok=True)

    # weights_only=False: checkpoint is a dict containing the "model" state.
    model.load_state_dict(torch.load(args.ckpt_dir, weights_only=False)["model"])
    print(f"Loaded model: {args.ckpt_dir}")

    dice_list = []
    model.eval()
    with torch.no_grad():
        for idx, (img, gt, pid) in enumerate(data_loader_test):
            img, gt = img.to(args.device), gt.to(args.device)
            if args.sliding_window:
                # Patch-based inference with 50% overlap, 4 patches per batch.
                pred = sliding_window_inference(
                    img, args.crop_spatial_size, 4, model, overlap=0.5
                )
            else:
                pred = model(img)

            # Binarize (single-class) or argmax (multi-class) the prediction.
            if args.num_classes == 1:
                pred = torch.sigmoid(pred) > 0.5
                pred = pred.int()
            else:
                pred = torch.softmax(pred, dim=1)
                pred = torch.argmax(pred, dim=1, keepdim=True)

            dice = compute_dice(pred, gt)  # compute_dice(pred, gt, False,num_classes=9)
            print(pid, dice.item())
            # Cases with an empty ground truth can yield NaN Dice; skip them.
            if not torch.isnan(dice):
                dice_list.append(dice)
            # dice = int(dice.mean()*10000)
            img = img.squeeze().cpu().numpy()
            pred = pred.squeeze().cpu().numpy()
            gt = gt.squeeze().cpu().numpy()
            if args.save_fig:
                # Only dump volumes for the first 20 cases to limit disk use.
                if idx < 20:
                    # print(img.shape,pred.shape, gt.shape )
                    # NOTE(review): channels 0/1 appear to be T2W and DWI
                    # based on the filenames — confirm against the dataset.
                    sitk.WriteImage(
                        sitk.GetImageFromArray(img[0]),
                        os.path.join(args.output_dir, f"{idx}_t2w.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(img[1]),
                        os.path.join(args.output_dir, f"{idx}_dwi.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(pred),
                        os.path.join(args.output_dir, f"{idx}_pred.nii.gz"),
                    )
                    sitk.WriteImage(
                        sitk.GetImageFromArray(gt),
                        os.path.join(args.output_dir, f"{idx}_gt.nii.gz"),
                    )
    # NOTE(review): torch.stack raises on an empty list — this assumes at
    # least one case produced a finite Dice score.
    dice_list = torch.stack(dice_list, 0)
    np.save(
        os.path.join(args.output_dir, f"{args.file_name}.npy"),
        dice_list.cpu().numpy(),
    )
    print("dice mean: ", dice_list.mean().item())
245
+
246
+
247
if __name__ == "__main__":
    # Parse CLI arguments and run the segmentation demo.
    parser = get_args_parser()
    main(parser.parse_args())
engine/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # ProFound engine package
engine/classification.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ import torch
14
+ import os
15
+ import util.misc as misc
16
+ import util.lr_sched as lr_sched
17
+ import numpy as np
18
+ from util.metric import accuracy, ConfusionMatrix, kappa
19
+ from sklearn.metrics import (
20
+ roc_auc_score,
21
+ top_k_accuracy_score,
22
+ f1_score,
23
+ confusion_matrix,
24
+ )
25
+ from torchmetrics.classification import (
26
+ BinarySpecificityAtSensitivity,
27
+ BinarySensitivityAtSpecificity,
28
+ )
29
+
30
+
31
+ import pdb
32
+
33
+
34
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Train the classifier for one epoch and return averaged meter values.

    Loss is BCE-with-logits for binary / "promis" tasks, cross-entropy for
    multi-class. The learning rate is adjusted per iteration (not per epoch).
    ``loss_scaler`` is currently unused (AMP path is commented out below).
    Aborts the process if a non-finite loss is observed.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # Select the loss: "promis" always uses BCE; otherwise BCE for a single
    # logit and cross-entropy for multi-class outputs.
    if args.dataset == "promis":
        loss_cal = torch.nn.BCEWithLogitsLoss()
    else:
        if args.num_classes > 1:
            loss_cal = torch.nn.CrossEntropyLoss()
        else:
            loss_cal = torch.nn.BCEWithLogitsLoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    # Kept for the NaN diagnostics below; only updated by the (commented)
    # loss-scaler path.
    last_norm = 0.0
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img)
        # print("logit: ", logit.shape, "gt: ", gt.shape, "image: ", img.shape)
        loss = loss_cal(logit, gt)
        loss_value = loss.item()

        # Diagnose which tensor went non-finite, then stop hard: continuing
        # would corrupt the optimizer state.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Plain (non-AMP) optimization step; gradients are not accumulated.
        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
120
+
121
+
122
def validation(model, data_loader_val, device, epoch, args):
    """Run one validation pass and return the mean batch loss (numpy float).

    Loss selection mirrors ``train_one_epoch``: BCE-with-logits for the
    "promis" dataset or single-logit models, cross-entropy otherwise.
    Note: ``model.eval()`` is set here but not restored by this function.
    """
    model.eval()

    if args.dataset != "promis" and args.num_classes > 1:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()

    losses = []
    total = len(data_loader_val)
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(data_loader_val):
            img = img.to(device)
            gt = gt.to(device)
            batch_loss = criterion(model(img), gt).detach().cpu().numpy()
            losses.append(batch_loss)
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, idx, total
            )
            print(progress + " loss:" + str(batch_loss.flatten()[0]))
        avg_loss = np.mean(losses)
        print("Averaged stats:", str(avg_loss))
    return avg_loss
150
+
151
+
152
def test(model, test_loader, args):
    """Load the best checkpoint, collect test logits, and dispatch to the
    dataset-specific metric routine.

    Args:
        model: The classifier to evaluate (state is overwritten from
            ``{output_dir}/best.pth.tar``).
        test_loader: Iterable yielding ``(img, gt, _)`` batches.
        args: Namespace providing ``output_dir``, ``device`` and ``dataset``.

    Returns:
        The log-stats dict from the dataset-specific test function.

    Raises:
        NotImplementedError: For an unknown ``args.dataset``.
    """
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # FIX: ``weights_only`` is a ``torch.load`` keyword; it was previously
    # passed to ``load_state_dict`` (which has no such parameter and would
    # raise TypeError at runtime).
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])
    model.eval()
    prob, gts = [], []
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            logit = model(img)
            prob.append(logit)
            gts.append(gt)

    if args.dataset == "risk":
        return test_risk(prob, gts)
    elif args.dataset == "screening":
        return test_screening(prob, gts)
    elif args.dataset == "promis":
        return test_promis(prob, gts)
    else:
        raise NotImplementedError(f"unknown dataset: {args.dataset}")
172
+
173
+
174
def _pirads_binary_report(prob, gts, cut, title, prefix, log_stats):
    """Report binary metrics for the collapsed task "class >= cut".

    Prints AUC/F1 plus specificity-at-sensitivity and
    sensitivity-at-specificity at 0.8 and 0.9, and records them in
    ``log_stats`` under keys ``{prefix}_*``.
    """
    # Collapse multi-class probabilities into P(class >= cut); labels
    # 0..3 correspond to PI-RADS 2..5.
    sig_prob = np.sum(prob[:, cut:], -1)
    sig_gts = (gts >= cut).astype(int)
    sig_auc = roc_auc_score(sig_gts, sig_prob) * 100
    sig_f1 = f1_score(sig_gts, sig_prob > 0.5) * 100

    print(title)
    print(f"auc \t f1 ")
    print(f"{sig_auc:.2f} \t {sig_f1:.2f}")

    log_stats[f"{prefix}_auc"] = f"{sig_auc:.2f}"
    log_stats[f"{prefix}_f1"] = f"{sig_f1:.2f}"

    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_specificity = sig_specificity * 100
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(sig_prob), torch.from_numpy(sig_gts)
        )
        sig_sensitivity = sig_sensitivity * 100

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
        log_stats[f"{prefix}_specificity_at_{i}"] = f"{sig_specificity:.2f}"
        log_stats[f"{prefix}_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"


def test_risk(prob, gts):
    """Compute and print 4-class PI-RADS metrics plus binary collapses.

    Args:
        prob: List of logit tensors (one per batch).
        gts: List of integer label tensors (classes 0..3 = PI-RADS 2..5).

    Returns:
        Dict of formatted metric strings.
    """
    log_stats = {}
    prob = torch.cat(prob, 0)
    prob = torch.softmax(prob, dim=-1).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy()

    # Full 4-class metrics (accuracy, one-vs-rest AUC, quadratic kappa, F1).
    score_acc = top_k_accuracy_score(gts, prob, k=1) * 100
    score_qwk = kappa(gts, np.argmax(prob, 1))
    score_auc = roc_auc_score(gts, prob, multi_class="ovr") * 100
    score_f1 = f1_score(gts, np.argmax(prob, 1), average="macro") * 100

    print("score")
    print(f"acc\t auc \t qwk \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_qwk:.4f} \t {score_f1:.2f}")
    log_stats["4-class_acc"] = f"{score_acc:.2f}"
    log_stats["4-class_auc"] = f"{score_auc:.2f}"
    log_stats["4-class_qwk"] = f"{score_qwk:.4f}"
    log_stats["4-class_f1"] = f"{score_f1:.2f}"

    # 2 3 4 5 four classes 0 1 2 3
    # Refactor: the >=3 and >=4 blocks were duplicated; a single helper now
    # handles both (unused sig_acc locals removed).
    _pirads_binary_report(prob, gts, 1, "Pirads >=3", "leq3", log_stats)
    _pirads_binary_report(prob, gts, 2, "Pirads >=4", "leq4", log_stats)
    return log_stats
257
+
258
+
259
def test_screening(prob, gts):
    """Compute and print binary screening metrics; returns None.

    Also dumps raw probabilities/labels to ``result.npz`` in the current
    working directory as a side effect.
    """
    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).long().cpu().numpy()

    # NOTE(review): saves to the CWD, not args.output_dir — confirm intended.
    np.savez("result.npz", gts=gts, prob=prob)
    score_acc = top_k_accuracy_score(gts, prob, k=1) * 100
    score_auc = roc_auc_score(gts, prob) * 100
    # NOTE(review): np.argmax(prob, 1) assumes prob is 2-D here (e.g. shape
    # (N, 1), which would make every prediction class 0) — verify the logit
    # shape coming from the screening model.
    score_f1 = f1_score(gts, np.argmax(prob, 1)) * 100

    print(f"acc\t auc \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_f1:.2f}")

    # Operating-point metrics at two clinically relevant thresholds.
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(torch.from_numpy(prob), torch.from_numpy(gts))
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(torch.from_numpy(prob), torch.from_numpy(gts))

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    # Metrics are printed only; nothing is returned for logging (yet).
    log_stats = None
    return log_stats
284
+
285
+
286
+
287
def test_promis(prob, gts):
    """Print zone-level and patient-level PROMIS metrics.

    Args:
        prob: List of per-batch logit tensors, one logit per zone.
        gts: List of per-batch binary label tensors (same layout as prob).

    Returns:
        Dict of log stats (currently left empty — metrics are printed only).
    """
    log_stats = {}

    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy().astype(int)

    # zone level: flatten (patient, zone) into independent binary samples.
    zone_prob = prob.reshape(-1)
    zone_gt = gts.reshape(-1)
    print(f"zone level performance")

    # FIX: sklearn's signature is roc_auc_score(y_true, y_score); the
    # arguments were previously reversed (probabilities passed as labels).
    auc = roc_auc_score(zone_gt, zone_prob) * 100
    print(f"AUC: {auc:.2f}")
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(zone_prob), torch.from_numpy(zone_gt)
        )
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(zone_prob), torch.from_numpy(zone_gt)
        )

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    # patient level: a patient is positive if any of their zones is.
    patient_prob = prob.max(-1)
    patient_gt = gts.max(-1)

    print(f"patient level performance")

    # FIX: same reversed-argument bug as the zone-level call above.
    auc = roc_auc_score(patient_gt, patient_prob) * 100
    print(f"AUC: {auc:.2f}")
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(
            torch.from_numpy(patient_prob), torch.from_numpy(patient_gt)
        )
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(
            torch.from_numpy(patient_prob), torch.from_numpy(patient_gt)
        )

        print(f"min: {i}")
        print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    return log_stats
engine/location.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ from typing import Iterable
4
+
5
+ import torch
6
+ import os
7
+ import util.misc as misc
8
+ import util.lr_sched as lr_sched
9
+ from monai.losses import DiceCELoss, DiceLoss
10
+ import numpy as np
11
+ from monai.metrics import DiceHelper
12
+ import surface_distance
13
+ from surface_distance import metrics
14
+ from util.meter import DiceMeter, HausdorffMeter, SurfaceDistanceMeter
15
+
16
+ # from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
17
+ from monai.inferers import sliding_window_inference
18
+ from torchmetrics.classification import (
19
+ BinarySpecificityAtSensitivity,
20
+ BinarySensitivityAtSpecificity,
21
+ )
22
+ # from monai.metrics import DiceMetric
23
+ # from monai.transforms import Activations
24
+ import pdb
25
+ from sklearn.metrics import (
26
+ roc_auc_score,
27
+ top_k_accuracy_score,
28
+ f1_score,
29
+ confusion_matrix,
30
+ )
31
+
32
+
33
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Train the zone-classification model for one epoch.

    Args:
        model: network taking ``(img, zone_mask)``; may return a list of
            logits for deep supervision (main output first).
        data_loader: yields ``(img, zone_mask, gt)`` batches.
        optimizer: optimizer whose lr is re-set every iteration.
        device: target device for the batch tensors.
        epoch: current epoch index (feeds the fractional lr schedule).
        loss_scaler: unused here; kept for interface parity with AMP loops.
        log_writer: optional tensorboard ``SummaryWriter``.
        args: namespace consumed by ``lr_sched.adjust_learning_rate``.

    Returns:
        dict mapping meter name -> global average for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    loss_cal = torch.nn.BCEWithLogitsLoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    last_norm = 0.0
    for data_iter_step, (img, zone_mask, gt) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):


        # we use a per iteration (instead of per epoch) lr scheduler
        img, zone_mask, gt = img.to(device, non_blocking=True), zone_mask.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        # BCEWithLogitsLoss requires float targets
        gt = gt.float()
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img, zone_mask)
        # deep supervision: auxiliary output weighted 0.4
        if isinstance(logit, list):
            loss = loss_cal(logit[0], gt) + 0.4*loss_cal(logit[1], gt)
        else:
            loss = loss_cal(logit, gt)

        loss_value = loss.item()

        # diagnose non-finite losses before aborting.
        # NOTE(review): if `logit` is a list (deep supervision) these
        # torch.isnan/isinf calls will raise — confirm intended.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
114
+
115
+
116
def validation(model, data_loader_val, device, epoch, args):
    """Compute the mean BCE-with-logits loss over the validation loader.

    Args:
        model: network taking ``(img, zone_mask)`` and returning zone logits.
        data_loader_val: yields ``(img, zone_mask, gt)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (for progress printing only).
        args: namespace providing ``epochs`` (for progress printing only).

    Returns:
        Mean per-batch loss as a numpy scalar.
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()

    with torch.no_grad():
        losses = []

        for batch_idx, (img, zone_mask, gt) in enumerate(data_loader_val):
            img = img.to(device, non_blocking=True)
            zone_mask = zone_mask.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True).float()
            batch_loss = criterion(model(img, zone_mask), gt)
            losses.append(batch_loss.detach().cpu().numpy())
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, batch_idx, len(data_loader_val)
            )
            print(progress + " loss:" + str(losses[-1].flatten()[0]))
    avg_loss = np.mean(losses)
    print("Averaged stats:", str(avg_loss))
    return avg_loss
139
+
140
+
141
def test(model, test_loader, args, sliding_window=False):
    """Evaluate the best checkpoint at zone and patient level.

    Loads ``best.pth.tar`` from ``args.output_dir``, runs inference over
    ``test_loader`` and reports AUC plus specificity-at-sensitivity /
    sensitivity-at-specificity operating points.

    Args:
        model: network taking ``(img, zone_mask)`` and returning zone logits.
        test_loader: yields ``(img, zone_mask, gt)`` batches.
        args: namespace with ``output_dir`` and ``device``.
        sliding_window: unused; kept for interface parity with other engines.

    Returns:
        dict of formatted metric strings.
    """
    model.eval()
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # FIX: `weights_only` is a torch.load() argument, not a
    # nn.Module.load_state_dict() one; the original call raised TypeError.
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])

    log_stats = {}
    with torch.no_grad():
        prob, gts = [], []

        for idx, (img, zone_mask, gt) in enumerate(test_loader):
            img = img.to(args.device, non_blocking=True)
            zone_mask = zone_mask.to(args.device, non_blocking=True)
            gt = gt.to(args.device, non_blocking=True)

            logit = model(img, zone_mask)
            prob.append(logit)
            gts.append(gt)

        prob = torch.cat(prob, 0)
        prob = torch.sigmoid(prob).cpu()
        gts = torch.cat(gts, 0).cpu()

        print("- Zone level: ")
        zone_prob = prob.reshape(-1, prob.shape[-1])
        zone_gt = gts.reshape(-1, prob.shape[-1])
        # FIX: roc_auc_score expects (y_true, y_score); arguments were swapped.
        zone_auc = roc_auc_score(zone_gt, zone_prob) * 100
        # FIX: the AUC was computed but never printed or logged.
        print(f"AUC: {zone_auc:.2f}")
        log_stats["zone_auc"] = f"{zone_auc:.2f}"

        for i in [0.8, 0.9]:
            sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
            sig_specificity, _ = sig_spec(zone_prob, zone_gt)
            sig_specificity = sig_specificity * 100

            sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
            sig_sensitivity, _ = sig_sens(zone_prob, zone_gt)
            sig_sensitivity = sig_sensitivity * 100

            print(f"min: {i}")
            print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
            print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
            # FIX: zone- and patient-level entries previously shared the same
            # keys, so the patient loop silently overwrote the zone metrics.
            log_stats[f"zone_specificity_at_{i}"] = f"{sig_specificity:.2f}"
            log_stats[f"zone_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"

        print("- Patient level: ")
        # a patient is positive when any of its zones is positive
        p_prob = prob.max(1).values
        p_gt = gts.max(1).values

        p_auc = roc_auc_score(p_gt, p_prob) * 100
        print(f"AUC: {p_auc:.2f}")
        log_stats["patient_auc"] = f"{p_auc:.2f}"

        for i in [0.8, 0.9]:
            sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
            sig_specificity, _ = sig_spec(p_prob, p_gt)
            sig_specificity = sig_specificity * 100

            sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
            sig_sensitivity, _ = sig_sens(p_prob, p_gt)
            sig_sensitivity = sig_sensitivity * 100

            print(f"min: {i}")
            print(f"Specificity at Sensitivity \t Sensitivity at Specificity")
            print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
            log_stats[f"patient_specificity_at_{i}"] = f"{sig_specificity:.2f}"
            log_stats[f"patient_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"

    return log_stats
engine/pretrain.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ from typing import Iterable
14
+
15
+ import torch
16
+
17
+ import util.misc as misc
18
+ import util.lr_sched as lr_sched
19
+
20
+
21
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Run one epoch of masked-autoencoder pre-training (no AMP).

    Args:
        model: MAE-style model returning ``(loss, pred, mask)``.
        data_loader: yields ``(samples, _)`` pairs.
        optimizer: optimizer stepped once per batch.
        device: device the samples are moved to.
        epoch: current epoch (fractional progress feeds the lr schedule).
        loss_scaler: unused in this non-AMP loop; kept for interface parity.
        log_writer: optional tensorboard writer.
        args: namespace with ``mask_ratio`` and lr-schedule fields.

    Returns:
        dict mapping meter name -> global average for the epoch.
    """
    model.train(True)
    logger = misc.MetricLogger(delimiter=" ")
    logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    num_batches = len(data_loader)
    for step, (samples, _) in enumerate(logger.log_every(data_loader, 20, header)):

        samples = samples.to(device, non_blocking=True)
        # per-iteration (rather than per-epoch) learning-rate schedule
        lr_sched.adjust_learning_rate(optimizer, step / num_batches + epoch, args)

        loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        current_loss = loss.item()
        if not math.isfinite(current_loss):
            print("Loss is {}, stopping training".format(current_loss))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        logger.update(loss=current_loss)

        current_lr = optimizer.param_groups[0]["lr"]
        logger.update(lr=current_lr)

        reduced_loss = misc.all_reduce_mean(current_loss)
        if log_writer is not None:
            # epoch_1000x calibrates tensorboard curves across batch sizes
            tick = int((step / num_batches + epoch) * 1000)
            log_writer.add_scalar("train_loss", reduced_loss, tick)
            log_writer.add_scalar("lr", current_lr, tick)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger)
    return {k: meter.global_avg for k, meter in logger.meters.items()}
engine/pretrain_amp.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ from typing import Iterable
14
+
15
+ import torch
16
+
17
+ import util.misc as misc
18
+ import util.lr_sched as lr_sched
19
+
20
+
21
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Run one epoch of masked-autoencoder pre-training under AMP.

    The forward pass runs inside ``torch.cuda.amp.autocast`` and the
    backward pass / clipping / optimizer step are delegated to
    ``loss_scaler``.

    Args:
        model: MAE-style model returning ``(loss, pred, mask)``.
        data_loader: yields ``(samples, _)`` pairs.
        optimizer: optimizer stepped by ``loss_scaler``.
        device: device the samples are moved to.
        epoch: current epoch (fractional progress feeds the lr schedule).
        loss_scaler: AMP grad-scaler wrapper (scales, clips at 1.0, steps).
        log_writer: optional tensorboard writer.
        args: namespace with ``mask_ratio`` and lr-schedule fields.

    Returns:
        dict mapping meter name -> global average for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):

        # we use a per iteration (instead of per epoch) lr scheduler
        samples = samples.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )

        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # loss_scaler performs backward + unscale + clip + optimizer.step
        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=1.0)
        optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
engine/regression.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ import torch
14
+ import os
15
+ import util.misc as misc
16
+ import util.lr_sched as lr_sched
17
+ import numpy as np
18
+
19
+
20
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Train the regression model for one epoch with MSE loss.

    Args:
        model: network mapping ``img`` to a regression prediction.
        data_loader: yields ``(img, gt, dataidx)`` batches; ``dataidx`` is
            only used for diagnostics when the loss becomes non-finite.
        optimizer: optimizer whose lr is re-set every iteration.
        device: target device for the batch tensors.
        epoch: current epoch index (feeds the fractional lr schedule).
        loss_scaler: unused here; kept for interface parity with AMP loops.
        log_writer: optional tensorboard ``SummaryWriter``.
        args: namespace consumed by ``lr_sched.adjust_learning_rate``.

    Returns:
        dict mapping meter name -> global average for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    loss_cal = torch.nn.MSELoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    last_norm = 0.0
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        logit = model(img)
        loss = loss_cal(logit, gt)
        loss_value = loss.item()
        # diagnose non-finite losses (which sample, NaN vs inf) before aborting
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
98
+
99
+
100
def validation(model, data_loader_val, device, epoch, args):
    """Compute the mean MSE loss over the validation loader.

    Args:
        model: network mapping ``img`` to a regression prediction.
        data_loader_val: yields ``(img, gt, _)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (for progress printing only).
        args: namespace providing ``epochs`` (for progress printing only).

    Returns:
        Mean per-batch loss as a numpy scalar.
    """
    model.eval()
    criterion = torch.nn.MSELoss()
    with torch.no_grad():
        losses = []
        for batch_idx, (img, gt, _) in enumerate(data_loader_val):
            img, gt = img.to(device), gt.to(device)
            batch_loss = criterion(model(img), gt)
            losses.append(batch_loss.detach().cpu().numpy())
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, batch_idx, len(data_loader_val)
            )
            print(progress + " loss:" + str(losses[-1].flatten()[0]))
    avg_loss = np.mean(losses)
    print("Averaged stats:", str(avg_loss))
    return avg_loss
119
+
120
+
121
def test(model, test_loader, args):
    """Evaluate the best regression checkpoint with MSE and MAE.

    Loads ``best.pth.tar`` from ``args.output_dir``, predicts over the test
    loader, undoes the target normalisation and prints/returns the errors.

    Args:
        model: network mapping ``img`` to a normalised regression prediction.
        test_loader: yields ``(img, gt, _)`` batches.
        args: namespace with ``output_dir`` and ``device``.

    Returns:
        dict with float ``"MSE"`` and ``"MAE"`` entries (in original units).
    """
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # FIX: `weights_only` is a torch.load() argument, not a
    # nn.Module.load_state_dict() one; the original call raised TypeError.
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])

    model.eval()
    log_stats = {}
    pred, gts = [], []

    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            pred.append(model(img))
            gts.append(gt)
        pred = torch.cat(pred, 0)
        gts = torch.cat(gts, 0)
        # undo target normalisation: value = norm * 500000 + 70000
        # TODO(review): confirm these constants match the dataset's scaling
        pred = pred * 500000 + 70000
        gts = gts * 500000 + 70000
        mse = torch.nn.MSELoss()(pred, gts)
        mae = torch.nn.L1Loss()(pred, gts)
        print("MSE", mse.item(), "MAE", mae.item())
        log_stats = {"MSE": mse.item(), "MAE": mae.item()}
    return log_stats
engine/segment.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+ import math
12
+ import sys
13
+ from typing import Iterable
14
+
15
+ import torch
16
+ import os
17
+ import util.misc as misc
18
+ import util.lr_sched as lr_sched
19
+ from monai.losses import DiceCELoss, DiceLoss
20
+ import numpy as np
21
+ from monai.metrics import DiceHelper
22
+ import surface_distance
23
+ from surface_distance import metrics
24
+ from util.meter import DiceMeter, HausdorffMeter, SurfaceDistanceMeter
25
+
26
+ # from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
27
+ from monai.inferers import sliding_window_inference
28
+
29
+ # from monai.metrics import DiceMetric
30
+ # from monai.transforms import Activations
31
+ import pdb
32
+
33
+
34
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Train the segmentation model for one epoch with Dice+CE loss.

    Args:
        model: segmentation network; may return a list of logits for deep
            supervision (main output first).
        data_loader: yields ``(img, gt, dataidx)`` batches; ``dataidx`` is
            only used for diagnostics when the loss becomes non-finite.
        optimizer: optimizer whose lr is re-set every iteration.
        device: target device for the batch tensors.
        epoch: current epoch index (feeds the fractional lr schedule).
        loss_scaler: unused here; kept for interface parity with AMP loops.
        args: namespace with ``out_channels`` and lr-schedule fields.
        log_writer: optional tensorboard ``SummaryWriter``.

    Returns:
        dict mapping meter name -> global average for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # binary head: sigmoid Dice+CE; multi-class: softmax with one-hot targets
    if args.out_channels == 1:
        loss_cal = DiceCELoss(sigmoid=True)
    else:
        loss_cal = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)


    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    last_norm = 0.0
    for data_iter_step, (img, gt, dataidx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
        # print(img.shape, img.mean(), img.std())
        # with torch.cuda.amp.autocast():
        logit = model(img)
        # deep supervision: auxiliary output weighted 0.4
        if isinstance(logit, list):
            loss = loss_cal(logit[0], gt) + 0.4*loss_cal(logit[1], gt)
        else:
            loss = loss_cal(logit, gt)

        loss_value = loss.item()

        # diagnose non-finite losses (which sample, NaN vs inf) before aborting.
        # NOTE(review): if `logit` is a list these isnan/isinf calls will raise.
        if not math.isfinite(loss_value):
            print(
                "nan",
                torch.isnan(logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
        # optimizer.zero_grad()
        # torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
123
+
124
+
125
def validation(model, data_loader_val, device, epoch, args):
    """Compute the mean Dice loss over the validation loader.

    Args:
        model: segmentation network mapping ``img`` to logits.
        data_loader_val: yields ``(img, gt, _)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (for progress printing only).
        args: namespace providing ``out_channels`` and ``epochs``.

    Returns:
        Mean per-batch loss as a numpy scalar.
    """
    model.eval()
    # binary head: sigmoid Dice; multi-class: softmax with one-hot targets
    if args.out_channels == 1:
        dice_loss = DiceLoss(sigmoid=True)
    else:
        dice_loss = DiceLoss(to_onehot_y=True, softmax=True, include_background=False)

    with torch.no_grad():
        losses = []
        for batch_idx, (img, gt, _) in enumerate(data_loader_val):
            img, gt = img.to(device), gt.to(device)
            batch_loss = dice_loss(model(img), gt)
            losses.append(batch_loss.detach().cpu().numpy())
            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, batch_idx, len(data_loader_val)
            )
            print(progress + " loss:" + str(losses[-1].flatten()[0]))
        avg_loss = np.mean(losses)
        print("Averaged stats:", str(avg_loss))
        return avg_loss
149
+
150
+
151
def test(model, test_loader, args, sliding_window=False):
    """Evaluate the best segmentation checkpoint (Dice / HD95 / surface dist).

    Args:
        model: segmentation network mapping ``img`` to logits.
        test_loader: yields ``(img, gt, _)`` batches.
        args: namespace with ``output_dir``, ``device``, ``num_classes`` and
            (when ``sliding_window``) ``crop_spatial_size``.
        sliding_window: if True, run MONAI sliding-window inference instead of
            a single full-volume forward pass.

    Returns:
        dict of class-wise and averaged metrics (JSON-serialisable lists).
    """

    def _jsonable(value):
        # numpy arrays are not JSON-serialisable; plain lists are
        return value.tolist() if isinstance(value, np.ndarray) else value

    model.eval()
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # FIX: `weights_only` is a torch.load() argument, not a
    # nn.Module.load_state_dict() one; the original call raised TypeError.
    model.load_state_dict(torch.load(filepath_best, weights_only=False)["model"])
    dice_meter = DiceMeter(args)
    hausdorff_meter = HausdorffMeter(args)
    sd_meter = SurfaceDistanceMeter(args)
    log_stats = {}
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(test_loader):
            img, gt = img.to(args.device), gt.to(args.device)
            if sliding_window:
                pred = sliding_window_inference(
                    img, args.crop_spatial_size, 4, model, overlap=0.5
                )
            else:
                pred = model(img)
            # NOTE(review): train/validation branch on args.out_channels while
            # this branches on args.num_classes — confirm the two stay in sync.
            if args.num_classes == 1:
                pred = torch.sigmoid(pred) > 0.5
            else:
                pred = torch.softmax(pred, dim=1)
                pred = torch.argmax(pred, dim=1, keepdim=True)
            dice_meter.update(pred, gt)
            hausdorff_meter.update(pred, gt)
            sd_meter.update(pred, gt)

        print("- Test metrics Dice: ")
        dice_class_avg, dice_avg = dice_meter.get_average()
        print("Class wise: ", dice_class_avg)
        print("Avg.: ", dice_avg)

        print("- Test metrics Hausdorff95: ")
        hsd_class_avg, hsd_avg = hausdorff_meter.get_average()
        print("Class wise: ", hsd_class_avg)
        print("Avg.: ", hsd_avg)

        print("- Test metrics SurfaceDistance: ")
        sd_class_avg, sd_avg = sd_meter.get_average()
        print("Class wise: ", sd_class_avg)
        print("Avg.: ", sd_avg)
        log_stats = {
            "dice_class_avg": _jsonable(dice_class_avg),
            "dice_avg": _jsonable(dice_avg),
            "hsd_class_avg": _jsonable(hsd_class_avg),
            "hsd_avg": _jsonable(hsd_avg),
            "sd_class_avg": _jsonable(sd_class_avg),
            "sd_avg": _jsonable(sd_avg),
        }
    return log_stats
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # ProFound models package
models/build_classification.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.classifier import Classifier
2
+ from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
3
+ from util.lars import LARS
4
+ import torch
5
+ import os
6
+ from util.convnext_optim import get_parameter_groups, LayerDecayValueAssigner
7
+
8
def build_model(args, device):
    """Build the classification model and a matching optimizer.

    For ``args.model == "profound_conv"``: loads a pretrained ConvNeXtV2-tiny
    backbone, wraps it in a :class:`Classifier` head, and builds either a
    head-only LARS optimizer (``args.train == "freeze"``, linear probing) or
    an AdamW optimizer with layer-wise lr decay on the backbone.

    Args:
        args: namespace with ``model``, ``pretrain``, ``num_classes``,
            ``train``, ``lr``, ``weight_decay``, ``layer_decay`` and
            ``layer_decay_type``.
        device: device the model is moved to.

    Returns:
        ``(model, optimizer)``.

    Raises:
        NotImplementedError: unknown model name or no pretrained weight set.
        FileNotFoundError: the configured checkpoint path does not exist.
    """
    if args.model == "profound_conv":
        convnext = convnextv2_tiny(in_chans=3, drop_path_rate=0.1)
        if args.pretrain is None:
            raise NotImplementedError(f"No pretrained weight")
        if not os.path.exists(args.pretrain):
            # FIX: FileExistsError means "path already exists"; a missing
            # checkpoint must raise FileNotFoundError.
            raise FileNotFoundError(f"{args.pretrain} Not exists")
        ckpt = torch.load(args.pretrain, map_location="cpu")
        ckpt = remap_checkpoint_keys(ckpt)
        load_state_dict(convnext, ckpt, weights_only=False)
        model = Classifier(convnext, args.num_classes)
        model = model.to(device)
        if args.train == "freeze":
            # linear probing: freeze the backbone, train only the head
            for key, value in model.encoder.named_parameters():
                value.requires_grad = False
            optimizer = LARS(model.head.parameters(), weight_decay=0, lr=args.lr)
        else:
            # full fine-tuning with layer-wise lr decay over the backbone
            num_layers = sum(convnext.depths)
            assigner = LayerDecayValueAssigner(
                list(
                    args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)
                ),
                depths=convnext.depths,
                layer_decay_type=args.layer_decay_type,
            )

            skip = {}
            if hasattr(model.encoder, "no_weight_decay"):
                skip = model.encoder.no_weight_decay()

            backbone_param_groups = get_parameter_groups(
                model.encoder,
                args.weight_decay,
                skip,
                assigner.get_layer_id,
                assigner.get_scale,
            )
            # the freshly-initialised head trains at full lr, no weight decay
            decoder_param_groups = [
                {"params": model.head.parameters(), "weight_decay": 0.0, "lr": args.lr}
            ]

            optimizer = torch.optim.AdamW(
                backbone_param_groups + decoder_param_groups, lr=args.lr
            )

    else:
        raise NotImplementedError(f"unknown model: {args.model}")

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model))
    print("number of params (M): %.2f" % (n_parameters / 1.0e6))

    return model, optimizer
62
+
63
+
64
def vit_backbone_parameters(
    model: torch.nn.Module, weight_decay=1e-5, no_weight_decay_list=(), lr=1e-3
):
    """Split trainable parameters into weight-decay and no-decay groups.

    Biases, 1-D parameters (norm scales etc.) and any name listed in
    ``no_weight_decay_list`` are exempted from weight decay.

    Args:
        model: module whose trainable parameters are grouped.
        weight_decay: decay applied to the decaying group.
        no_weight_decay_list: parameter names exempt from decay.
        lr: learning rate attached to both groups.

    Returns:
        Two optimizer param-group dicts: [no-decay, decay].
    """
    exempt = set(no_weight_decay_list)
    with_decay = []
    without_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        skip_decay = (
            param.ndim <= 1 or name.endswith(".bias") or name in exempt
        )
        (without_decay if skip_decay else with_decay).append(param)

    return [
        {"params": without_decay, "weight_decay": 0.0, "lr": lr},
        {"params": with_decay, "weight_decay": weight_decay, "lr": lr},
    ]
models/classifier.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class Classifier(nn.Module):
    """Classification head on top of a feature encoder.

    Args:
        encoder: backbone exposing an ``embed_dim`` attribute; its forward
            returns either a feature tensor of shape (batch, embed_dim) or a
            tuple whose first item is that tensor.
        num_classes: number of output logits.
        bottleneck_dim: hidden width of the two-layer MLP head.
    """

    def __init__(self, encoder, num_classes, bottleneck_dim=256):
        super().__init__()
        self.encoder = encoder
        self.embed_dim = self.encoder.embed_dim
        # Linear -> BN -> ReLU -> Linear bottleneck head
        self.head = torch.nn.Sequential(
            nn.Linear(self.embed_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.encoder(x)
        # some encoders return (features, extras); keep only the features.
        # FIX: use isinstance for type checks instead of `type(x) == tuple`.
        if isinstance(x, tuple):
            x = x[0]
        x = self.head(x)
        return x
23
+
models/convnext_unter.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ from typing import Sequence, Tuple, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
6
+ from monai.networks.blocks.unetr_block import (
7
+ UnetrBasicBlock,
8
+ UnetrPrUpBlock,
9
+ UnetrUpBlock,
10
+ )
11
+ from models.util import LayerNorm
12
+
13
+
14
class ConvnextUNETR_Decoder(nn.Module):
    """UNETR-style decoder over a ConvNeXt encoder's hidden states.

    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    Projects the raw input ``x`` and four encoder feature maps through
    UNETR encoder/up-sampling blocks and fuses them coarsest-first into a
    segmentation map with ``out_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        feature_size: int = 16,
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        spatial_dims: int = 3,
        hidden_size = [96, 192, 384, 768]
    ) -> None:

        super().__init__()

        # full-resolution projection of the raw input (skip path 0)
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # skip paths for the three intermediate encoder stages; each
        # UnetrPrUpBlock up-samples once (upsample_kernel_size=2)
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[0],
            out_channels=feature_size * 2,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[1],
            out_channels=feature_size * 4,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[2],
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        # decoder cascade: each UnetrUpBlock up-samples and merges a skip
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size[3],
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # 1x1 projection to the final number of segmentation channels
        self.out = UnetOutBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=out_channels,
        )

    def forward(self, x, x1, x2, x3, x4):
        """Fuse input + four encoder stages (x1 shallowest, x4 deepest)."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(x1)
        enc3 = self.encoder3(x2)
        enc4 = self.encoder4(x3)
        # decode coarsest-first, merging one skip connection per step
        dec3 = self.decoder5(x4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        mask = self.out(out)
        return mask
132
+
133
+
134
class ConvnextUNETR(nn.Module):
    """ConvNeXt encoder combined with a UNETR decoder for 3D segmentation.

    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        convnext,
        feature_size: int = 16,
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        spatial_dims: int = 3,
        hidden_size = [96, 192, 384, 768]
    ) -> None:

        super().__init__()

        # pretrained ConvNeXt backbone supplying per-stage hidden states
        self.encoder = convnext

        # channels-first LayerNorms for the three intermediate stages
        self.norm1 = LayerNorm(hidden_size[0], eps=1e-6, data_format="channels_first")
        self.norm2 = LayerNorm(hidden_size[1], eps=1e-6, data_format="channels_first")
        self.norm3 = LayerNorm(hidden_size[2], eps=1e-6, data_format="channels_first")

        self.decoder = ConvnextUNETR_Decoder(
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
            spatial_dims=spatial_dims,
            hidden_size=hidden_size
        )

    def forward(self, x):
        """Return the segmentation logits for input volume ``x``."""
        # ret_hids=True makes the encoder also return per-stage hidden states
        _, hidden_states_out = self.encoder(x, ret_hids=True)
        x1, x2, x3, x4 = hidden_states_out
        x1 = self.norm1(x1)
        x2 = self.norm2(x2)
        x3 = self.norm3(x3)
        # the encoder's own final norm is channels-last, so permute around it
        x4 = x4.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        x4 = self.encoder.norm(x4)
        x4 = x4.permute(0, 4, 1, 2, 3)
        mask = self.decoder(x, x1, x2, x3, x4)
        return mask
models/convnextv2.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+ from models.util import LayerNorm, GRN
13
+ from collections import OrderedDict
14
+ import math
15
+
16
+
17
class Block(nn.Module):
    """ConvNeXt V2 block: 7^3 depthwise conv -> LN -> inverted MLP with GRN.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        # Depthwise 7x7x7 convolution (groups == channels).
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Inverted-bottleneck MLP implemented with Linear layers, which act
        # on the channels-last layout.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        # (N, C, H, W, D) -> (N, H, W, D, C) for the channels-last MLP.
        x = x.permute(0, 2, 3, 4, 1)
        x = self.norm(x)
        x = self.pwconv2(self.grn(self.act(self.pwconv1(x))))
        # Back to channels-first before the residual add.
        x = x.permute(0, 4, 1, 2, 3)
        return shortcut + self.drop_path(x)
51
+
52
+
53
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 3D backbone.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
    """

    def __init__(
        self,
        in_chans=3,
        # Fix: mutable list defaults are shared across instances; immutable
        # tuples behave identically for the indexing/sum done below.
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
    ):
        super().__init__()
        self.depths = depths
        # Stem plus three inter-stage downsampling layers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last transition (stage 3 -> 4) uses a 1x1 conv with
            # stride 1, so the final two stages share spatial resolution.
            stride = 1 if i == 2 else 2
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i + 1], kernel_size=stride, stride=stride),
            )
            self.downsample_layers.append(downsample_layer)

        # Four residual stages of ConvNeXt V2 blocks with a linearly
        # increasing stochastic-depth schedule.
        self.stages = nn.ModuleList()
        dp_rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        self.apply(self._init_weights)
        self.embed_dim = dims[-1]

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights; zero biases."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Fix: guard against bias-free layers (bias=False) which would
            # otherwise crash nn.init.constant_ with a None bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run all stages; return (pooled feature, per-stage feature maps)."""
        hidden_states_out = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            hidden_states_out.append(x)
        # Global average pool over spatial dims, then final LayerNorm:
        # (N, C, H, W, D) -> (N, C).
        return self.norm(x.mean([-3, -2, -1])), hidden_states_out

    def forward(self, x, ret_hids=False):
        """Return the pooled feature, optionally with per-stage features."""
        pooled, hidden_states_out = self.forward_features(x)
        if ret_hids:
            return pooled, hidden_states_out
        return pooled
135
+
136
+
137
+
138
def convnextv2_atto(**kwargs):
    """Build a ConvNeXt V2-Atto backbone (depths 2/2/6/2, dims 40..320)."""
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
141
+
142
+
143
def convnextv2_femto(**kwargs):
    """Build a ConvNeXt V2-Femto backbone (depths 2/2/6/2, dims 48..384)."""
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
146
+
147
+
148
def convnext_pico(**kwargs):
    """Build a ConvNeXt V2-Pico backbone (depths 2/2/6/2, dims 64..512).

    NOTE: the name lacks the "v2" infix the sibling factories use; kept
    as-is because callers may rely on it.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
151
+
152
+
153
def convnextv2_nano(**kwargs):
    """Build a ConvNeXt V2-Nano backbone (depths 2/2/8/2, dims 80..640)."""
    return ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
156
+
157
+
158
def convnextv2_tiny(**kwargs):
    """Build a ConvNeXt V2-Tiny backbone (depths 3/3/9/3, dims 96..768)."""
    return ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
161
+
162
+
163
def convnextv2_base(**kwargs):
    """Build a ConvNeXt V2-Base backbone (depths 3/3/27/3, dims 128..1024)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
166
+
167
+
168
def convnextv2_large(**kwargs):
    """Build a ConvNeXt V2-Large backbone (depths 3/3/27/3, dims 192..1536)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
171
+
172
+
173
def convnextv2_huge(**kwargs):
    """Build a ConvNeXt V2-Huge backbone (depths 3/3/27/3, dims 352..2816)."""
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
176
+
177
+
178
def remap_checkpoint_keys(ckpt):
    """Convert a (sparse) pretraining checkpoint into this model's key/shape layout.

    Drops decoder/mask-token/projection/prediction heads, strips the
    ``encoder.`` prefix, renames ``*.kernel`` entries to ``*.weight`` and
    reshapes their flattened values into 3D conv kernels, and finally
    reshapes GRN parameters and flattened biases.

    Args:
        ckpt: a checkpoint dict with a ``"model"`` state-dict entry.

    Returns:
        OrderedDict: the remapped state dict, ready for ``load_state_dict``.
    """
    new_ckpt = OrderedDict()
    ckpt = ckpt["model"]

    # Remove pretraining-only heads; iterate over a snapshot of the keys
    # because we delete from the dict while looping.
    checkpoint_model_keys = list(ckpt.keys())
    for k in checkpoint_model_keys:
        if "decoder" in k or "mask_token" in k or "proj" in k or "pred" in k:
            print(f"Removing key {k} from pretrained checkpoint")
            del ckpt[k]

    for k, v in ckpt.items():
        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])  # remove encoder in the name
        if k.endswith("kernel"):
            k = ".".join(k.split(".")[:-1])  # remove kernel in the name
            new_k = k + ".weight"
            if len(v.shape) == 3:  # reshape standard convolution
                kv, in_dim, out_dim = v.shape
                # Flattened (k^3, in, out) -> dense 3D kernel
                # (out, in, k, k, k); the final permute flips the spatial
                # axis order to match this model's convention.
                ks = int(
                    round(kv ** (1 / 3))
                )  # calculate kernel size assuming cubic kernel
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0)
                    .reshape(out_dim, in_dim, ks, ks, ks)
                    .permute(0, 1, 4, 3, 2)
                )
            elif len(v.shape) == 2:  # reshape depthwise convolution
                kv, dim = v.shape
                if new_k == "downsample_layers.3.1.weight":
                    # Special case: this layer is a 1x1x1 conv, so the
                    # (1, dim) value just gains singleton spatial dims.
                    new_ckpt[new_k] = (
                        v.permute(1, 0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    )
                else:
                    # Flattened (k^3, dim) -> depthwise (dim, 1, k, k, k).
                    ks = int(round(kv ** (1 / 3)))
                    new_ckpt[new_k] = (
                        v.permute(1, 0)
                        .reshape(dim, 1, ks, ks, ks)
                        .permute(0, 1, 4, 3, 2)
                    )
            continue
        elif "ln" in k or "linear" in k:
            k = k.split(".")
            k.pop(-2)  # remove ln and linear in the name
            new_k = ".".join(k)
        else:
            new_k = k
        new_ckpt[new_k] = v

    # reshape grn affine parameters and biases
    for k, v in new_ckpt.items():
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            # GRN gamma/beta expect a broadcastable (1, 1, 1, 1, C) shape.
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1).unsqueeze(0)
    return new_ckpt
240
+
241
+
242
def load_state_dict(
    model, state_dict, prefix="", ignore_missing="relative_position_index"
):
    """Recursively load ``state_dict`` into ``model``, reporting mismatches.

    Unlike ``nn.Module.load_state_dict(strict=True)``, this never raises:
    missing keys, unexpected keys, and shape errors are printed, and keys
    matching any ``|``-separated pattern in ``ignore_missing`` are reported
    separately instead of warned about.

    Args:
        model: the target module.
        state_dict: the parameters to load.
        prefix: optional key prefix to start loading from.
        ignore_missing: ``|``-separated substrings of keys whose absence
            should not be treated as a warning.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        # Mirrors nn.Module._load_from_state_dict's recursive traversal.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model, prefix=prefix)

    # Split the missing keys into "worth warning about" and "ignored".
    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split("|"):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print(
            "Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys
            )
        )
    if len(unexpected_keys) > 0:
        print(
            "Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys
            )
        )
    if len(ignore_missing_keys) > 0:
        print(
            "Ignored weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, ignore_missing_keys
            )
        )
    if len(error_msgs) > 0:
        print("\n".join(error_msgs))
308
+ # if __name__ == 'main':
309
+ # model = convnextv2_base().cuda()
310
+ # x = torch.rand(1,3,256,256,32).cuda()
311
+ # print(model(x).shape)
models/upernet_module.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ from torch import nn
4
+ from models.util import LayerNorm, GRN
5
+
6
class UperNetConvModule(nn.Module):
    """Conv3d -> norm -> GELU bundle used throughout the UPerNet heads.

    A convolutional block that bundles conv/norm/activation layers,
    simplifying the common conv + norm + activation pattern.  The norm is
    a channels_first LayerNorm standing in for BatchNorm3d.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        # Attribute kept as "batch_norm" for state-dict compatibility even
        # though the layer is a LayerNorm.
        self.batch_norm = LayerNorm(out_channels, eps=1e-6, data_format="channels_first")
        self.activation = nn.GELU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.activation(self.batch_norm(self.conv(input)))
39
+
40
+
41
class UperNetPyramidPoolingBlock(nn.Module):
    """One PPM branch: adaptive average-pool to ``pool_scale``, then 1x1 conv."""

    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        self.layers = [
            nn.AdaptiveAvgPool3d(pool_scale),
            UperNetConvModule(in_channels, channels, kernel_size=1),
        ]
        # ``self.layers`` is a plain list, so register each submodule
        # explicitly to make parameters visible to the optimizer.
        for idx, layer in enumerate(self.layers):
            self.add_module(str(idx), layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = input
        for layer in self.layers:
            out = layer(out)
        return out
56
+
57
+
58
class UperNetPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (`Tuple[int]`):
            Pooling scales used in Pooling Pyramid Module.
        in_channels (`int`):
            Input channels.
        channels (`int`):
            Channels after modules, before conv_seg.
        align_corners (`bool`):
            align_corners argument of F.interpolate.
    """

    def __init__(
        self,
        pool_scales: Tuple[int, ...],
        in_channels: int,
        channels: int,
        align_corners: bool,
    ) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        # Plain list of branches; each is registered under its index so
        # its parameters are tracked.
        self.blocks = []
        for idx, scale in enumerate(pool_scales):
            branch = UperNetPyramidPoolingBlock(
                pool_scale=scale, in_channels=in_channels, channels=channels
            )
            self.blocks.append(branch)
            self.add_module(str(idx), branch)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        # Pool at every scale, then upsample each result back to the
        # spatial size of the input so they can be concatenated.
        outs = []
        for ppm in self.blocks:
            pooled = ppm(x)
            outs.append(
                nn.functional.interpolate(
                    pooled,
                    size=x.size()[2:],
                    mode="trilinear",
                    align_corners=self.align_corners,
                )
            )
        return outs
105
+
106
+
107
class UperNetHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Combines a PSP (pyramid pooling) module on the deepest feature map
    with an FPN top-down pathway over the shallower ones, then fuses all
    levels and classifies per voxel.
    """

    def __init__(self, in_channels, pool_scales, hidden_size, out_channels):
        """
        Args:
            in_channels: per-level channel counts of the encoder features,
                shallowest first.
            pool_scales: PPM pooling scales, e.g. (1, 2, 3, 6).
            hidden_size: internal channel width of the head.
            out_channels: number of output classes.
        """
        super().__init__()
        self.pool_scales = pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = in_channels
        self.channels = hidden_size
        self.align_corners = False
        self.classifier = nn.Conv3d(self.channels, out_channels, kernel_size=1)

        # PSP Module (applied to the deepest feature map only)
        self.psp_modules = UperNetPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = UperNetConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        # NOTE: the loop variable deliberately shadows self.in_channels'
        # elements; skip the top (deepest) layer, which goes through PSP.
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = UperNetConvModule(
                self.channels, self.channels, kernel_size=3, padding=1
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses the concatenation of all FPN levels into `hidden_size` channels.
        self.fpn_bottleneck = UperNetConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def init_weights(self):
        """Apply normal(0, 0.02) init to all Conv3d layers in the head."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv3d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def psp_forward(self, inputs):
        """Run the pyramid pooling module on the deepest feature map."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Fuse multi-level features and return per-voxel class logits.

        Args:
            encoder_hidden_states: list of feature maps, shallowest first.
        """
        # build laterals
        laterals = [
            lateral_conv(encoder_hidden_states[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path: add each level (upsampled) into the one below
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i],
                size=prev_shape,
                mode="trilinear",
                align_corners=self.align_corners,
            )

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # Upsample every level to the finest resolution before fusing.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode="trilinear",
                align_corners=self.align_corners,
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output
209
+
210
+
211
class UperNetFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is the implementation of
    [FCNNet](https://arxiv.org/abs/1411.4038>).

    Used as the auxiliary head in UperNet: a small conv stack over one
    intermediate feature level followed by a 1x1 classifier.

    Args:
        in_channels (list[int]):
            Per-level channel counts; the level used is selected by `in_index`.
        hidden_size (int):
            Channel width of the conv stack.
        num_convs (int):
            Number of convs in the head (0 means identity).
        out_channels (int):
            Number of output classes.
        concat_input (bool):
            If True, concatenate the input with the conv output before classifying.
        in_index (int):
            Which feature level to use. Default: 2.
        kernel_size (int):
            The kernel size for convs in the head. Default: 3.
        dilation (int):
            The dilation rate for convs in the head. Default: 1.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        num_convs,
        out_channels,
        concat_input=False,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels[in_index]
        self.channels = hidden_size
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.in_index = in_index

        # "same"-style padding for the chosen kernel/dilation.
        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            UperNetConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
            )
        )
        for i in range(self.num_convs - 1):
            convs.append(
                UperNetConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                )
            )
        # With num_convs == 0 the (already built) first conv is discarded
        # and the head degenerates to a pure classifier.
        if self.num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = UperNetConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )

        self.classifier = nn.Conv3d(self.channels, out_channels, kernel_size=1)

    def init_weights(self):
        """Apply normal(0, 0.02) init to all Conv3d layers in the head."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv3d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Classify from the feature level selected by ``self.in_index``."""
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output
296
+
297
+
298
class ViTAdapter(nn.Module):
    """Turns flat ViT token sequences into a 4-level feature pyramid.

    Each of the four selected transformer outputs is reshaped back to a
    volume and passed through a dedicated up/down-sampling branch
    (fpn1..fpn4) to emulate multi-scale CNN features.

    NOTE(review): ``self.fpn1`` is only defined when
    ``patch_size == (16, 32, 32)``; other patch sizes would raise
    AttributeError when building ``self.adapters`` — confirm intended.
    """

    def __init__(
        self,
        img_size=(64, 256, 256),
        patch_size=(16, 32, 32),
        embed_dim=768,
    ):
        super().__init__()

        # Number of patches along each spatial axis.
        self.grid_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, patch_size))
        self.hidden_size = embed_dim

        if patch_size == (16, 32, 32):
            # 3 transposed convs: strongest upsampling (finest pyramid level).
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose3d(
                    embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
                ),
                nn.BatchNorm3d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.BatchNorm3d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

        # 2 transposed convs: second pyramid level.
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose3d(
                embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
            ),
            nn.BatchNorm3d(embed_dim),
            nn.GELU(),
            nn.ConvTranspose3d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )

        # 1 transposed conv: third pyramid level.
        self.fpn3 = nn.Sequential(
            nn.ConvTranspose3d(
                embed_dim, embed_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)
            ),
        )

        # Max-pool: coarsest pyramid level.
        self.fpn4 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1))

        # Plain list is fine here: all members are registered attributes.
        self.adapters = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]

    def proj_feat(self, x):
        """Reshape tokens (N, L, C) back to a volume (N, C, D, H, W)."""
        new_view = (x.size(0), *self.grid_size, self.hidden_size)
        x = x.view(new_view)
        # Move the channel axis in front of the spatial axes.
        new_axes = (0, len(x.shape) - 1) + tuple(
            d + 1 for d in range(len(self.grid_size))
        )
        x = x.permute(new_axes).contiguous()
        return x

    def forward(self, encoder_hidden_states):
        """Apply each fpn branch to the matching hidden state; return a list."""
        output = []
        for index, op in zip(range(len(encoder_hidden_states)), self.adapters):
            output.append(op(self.proj_feat(encoder_hidden_states[index])))
        return output
364
+
365
+
366
class UperNet(nn.Module):
    """Full UPerNet segmentation model: encoder + decode head (+ aux head).

    In training mode, returns ``[logits, auxiliary_logits]`` when an
    auxiliary head is configured; in eval mode, returns only the main
    logits, already upsampled to the input resolution.
    """

    def __init__(
        self,
        encoder,
        in_channels,
        out_channels,
        adapter=None,
        out_indices=None,
        pool_scales=[1, 2, 3, 6],
        hidden_size=512,
        auxiliary_channels=256,
        use_auxiliary_head=True,
    ):
        """
        Args:
            encoder: backbone exposing ``forward(x, ret_hids=True)``.
            in_channels: per-level channel counts fed to the heads.
            out_channels: number of output classes.
            adapter: optional module (e.g. ViTAdapter) applied to the
                hidden states before the heads.
            out_indices: optional indices selecting which hidden states
                to use.
            pool_scales: PPM pooling scales for the decode head.
            hidden_size: decode-head channel width.
            auxiliary_channels: auxiliary-head channel width.
            use_auxiliary_head: whether to build the FCN auxiliary head.
        """
        super().__init__()
        self.encoder = encoder
        self.adapter = adapter
        self.out_indices = out_indices
        self.decode_head = UperNetHead(
            in_channels=in_channels,
            pool_scales=pool_scales,
            hidden_size=hidden_size,
            out_channels=out_channels,
        )
        self.auxiliary_head = (
            UperNetFCNHead(
                in_channels=in_channels,
                hidden_size=auxiliary_channels,
                num_convs=1,
                out_channels=out_channels,
            )
            if use_auxiliary_head
            else None
        )

        # One channels_first LayerNorm per selected hidden state.
        self.hidden_norm = nn.ModuleList()
        for in_channel in in_channels:
            norm = LayerNorm(in_channel, eps=1e-6, data_format="channels_first")
            self.hidden_norm.append(norm)

    def forward(self, x):
        """Segment ``x``; see class docstring for train/eval return shapes."""
        encoder_hidden_states = self.encoder(x, ret_hids=True)
        # Encoders returning (pooled, hidden_states) — keep the hidden states.
        if isinstance(encoder_hidden_states, list) or isinstance(
            encoder_hidden_states, Tuple
        ):
            encoder_hidden_states = encoder_hidden_states[-1]
        if self.out_indices:
            encoder_hidden_states = [
                encoder_hidden_states[i] for i in self.out_indices
            ]

        # Normalize each level before the heads.
        encoder_hidden_states = [
            norm(encoder_hidden_states[i])
            for i, norm in enumerate(self.hidden_norm)
        ]

        if self.adapter:
            encoder_hidden_states = self.adapter(encoder_hidden_states)

        logits = self.decode_head(encoder_hidden_states)
        logits = nn.functional.interpolate(
            logits, size=x.shape[2:], mode="trilinear", align_corners=False
        )
        if not self.training:
            return logits

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(encoder_hidden_states)
            auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits,
                size=x.shape[2:],
                mode="trilinear",
                align_corners=False,
            )
            return [logits, auxiliary_logits]
        return logits
models/util.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from itertools import chain
4
+ from typing import Callable
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ import numpy.random as random
8
+ import torch.nn.functional as F
9
+ # from MinkowskiEngine import SparseTensor
10
+
11
+
12
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
13
+
14
+ # All rights reserved.
15
+
16
+ # This source code is licensed under the license found in the
17
+ # LICENSE file in the root directory of this source tree.
18
+
19
+
20
+ # class MinkowskiGRN(nn.Module):
21
+ # """GRN layer for sparse tensors."""
22
+
23
+ # def __init__(self, dim):
24
+ # super().__init__()
25
+ # self.gamma = nn.Parameter(torch.zeros(1, dim))
26
+ # self.beta = nn.Parameter(torch.zeros(1, dim))
27
+
28
+ # def forward(self, x):
29
+ # cm = x.coordinate_manager
30
+ # in_key = x.coordinate_map_key
31
+
32
+ # Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
33
+ # Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
34
+ # return SparseTensor(
35
+ # self.gamma * (x.F * Nx) + self.beta + x.F,
36
+ # coordinate_map_key=in_key,
37
+ # coordinate_manager=cm,
38
+ # )
39
+
40
+
41
class MinkowskiDropPath(nn.Module):
    """Drop Path (stochastic depth) for Minkowski sparse tensors.

    Drops entire samples in the batch with probability ``drop_prob`` by
    zeroing their features, optionally rescaling the kept samples by
    ``1 / keep_prob``.

    NOTE(review): ``SparseTensor`` comes from MinkowskiEngine, whose
    import is commented out at the top of this module — calling this
    class as-is raises NameError. Confirm MinkowskiEngine availability
    before use.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(MinkowskiDropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity in eval mode or when dropping is disabled.
        if self.drop_prob == 0.0 or not self.training:
            return x
        cm = x.coordinate_manager
        in_key = x.coordinate_map_key
        keep_prob = 1 - self.drop_prob
        # Build a per-point mask: all points of a sample share one
        # keep/drop decision (``random`` here is numpy.random).
        mask = (
            torch.cat(
                [
                    (
                        torch.ones(len(_))
                        if random.uniform(0, 1) > self.drop_prob
                        else torch.zeros(len(_))
                    )
                    for _ in x.decomposed_coordinates
                ]
            )
            .view(-1, 1)
            .to(x.device)
        )
        if keep_prob > 0.0 and self.scale_by_keep:
            # Rescale kept samples so the expectation is unchanged.
            mask.div_(keep_prob)
        return SparseTensor(
            x.F * mask, coordinate_map_key=in_key, coordinate_manager=cm
        )
74
+
75
+
76
class MinkowskiLayerNorm(nn.Module):
    """Channel-wise layer normalization for Minkowski sparse tensors.

    Applies a standard nn.LayerNorm to the dense feature matrix of the
    sparse tensor and re-wraps the result with the same coordinates.

    NOTE(review): ``SparseTensor`` comes from MinkowskiEngine, whose
    import is commented out at the top of this module — calling this
    class as-is raises NameError.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-6,
    ):
        super(MinkowskiLayerNorm, self).__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)

    def forward(self, input):
        # input.F is the (num_points, channels) feature matrix.
        output = self.ln(input.F)
        return SparseTensor(
            output,
            coordinate_map_key=input.coordinate_map_key,
            coordinate_manager=input.coordinate_manager,
        )
94
+
95
+
96
class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first.

    channels_last expects inputs shaped (..., channels) and delegates to
    F.layer_norm; channels_first expects (batch, channels, ...) and
    normalizes along dim 1 manually, with a special affine broadcast for
    3D inputs coming from the ViT adapter.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over the channel axis (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        if len(x.shape) == 3:
            # 3D input from the ViT adapter: affine params broadcast directly.
            return self.weight * x + self.bias
        # 5D volumetric input: broadcast affine params over the spatial dims.
        return (
            self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
        )
131
+
132
+
133
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer.

    Operates on channels-last 5D input (N, H, W, D, C): aggregates a
    global L2 response per channel, normalizes it across channels, and
    applies a learnable affine plus residual.
    """

    def __init__(self, dim):
        super().__init__()
        shape = (1, 1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        # Per-channel global response: L2 norm over the spatial dims.
        gx = torch.norm(x, p=2, dim=(1, 2, 3), keepdim=True)
        # Divisive normalization across channels.
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * nx) + self.beta + x
145
+
146
+
147
def get_tokens(embed_dim: int, n_tokens: int) -> nn.Parameter:
    """Create a learnable token tensor of shape (1, n_tokens, embed_dim).

    Initialized with truncated-normal noise (std=0.02, upper cutoff 2.0),
    which at this std is effectively a plain normal init.

    Args:
        embed_dim: number of embedding channels.
        n_tokens: number of tokens.

    Returns:
        A learnable ``nn.Parameter`` holding the tokens.
    """
    shape = (1, n_tokens, embed_dim)
    token = nn.Parameter(torch.zeros(*shape))
    nn.init.trunc_normal_(token, std=0.02, b=2.0)
    return token
161
+
162
+
163
def init_weights(m):
    """Module initializer following the official JAX ViT recipe.

    Linear layers get Xavier-uniform weights and zero biases; LayerNorm
    layers get unit weight and zero bias. Other module types are left
    untouched.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)
172
+
173
+
174
+ """Gradient checkpointing utilities.
175
+
176
+ Copied from
177
+ https://github.com/huggingface/pytorch-image-models/blob/f8979d4f50b7920c78511746f7315df8f1857bc5/timm/models/_manipulate.py
178
+ and added use_reentrant=False following warnings in pytorch docs.
179
+ """
180
+
181
+
182
def checkpoint_seq(
    functions: nn.Sequential,
    x: torch.Tensor,
    every: int = 1,
    flatten: bool = False,
    skip_last: bool = False,
    preserve_rng_state: bool = True,
) -> torch.Tensor:
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. All segments except run in :func:`torch.no_grad`
    manner, i.e., not storing the intermediate activations. The inputs of each
    checkpointed segment will be saved for re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True): stash and restore
            the RNG state during each checkpoint so randomness replays
            identically in the recomputation.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """

    def run_function(
        start: int, end: int, functions: nn.Sequential
    ) -> Callable[[torch.Tensor], torch.Tensor]:
        # Returns a closure running functions[start..end] inclusive.
        def forward(_x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x

        return forward

    # Normalize `functions` to an indexable tuple of callables.
    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    # `end` must survive the loop for the skip_last tail below; -1 covers
    # the degenerate case where no segment is checkpointed.
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(
            run_function(start, end, functions),
            x,
            use_reentrant=False,
            preserve_rng_state=preserve_rng_state,
        )
    if skip_last:
        # Run the remaining tail without checkpointing.
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x
requirements.txt CHANGED
@@ -34,6 +34,3 @@ tqdm==4.67.1
34
  # Additional dependencies for model architecture
35
  einops==0.8.1
36
  timm==1.0.15
37
-
38
- # ProFound package from GitHub
39
- git+https://github.com/pipiwang/ProFound.git@demo
 
34
  # Additional dependencies for model architecture
35
  einops==0.8.1
36
  timm==1.0.15
 
 
 
util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # ProFound utilities package
util/convnext_optim.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import torch
10
+ from torch import optim as optim
11
+ import json
12
+
13
+
14
def get_num_layer_for_convnext_single(var_name, depths):
    """Assign a distinct layer id to every ConvNeXt block.

    Downsample layer of stage ``s`` shares the id of the first block of that
    stage; anything that is neither a downsample layer nor a stage block
    (stem, head, ...) gets the last id ``sum(depths) + 1``.
    """
    parts = var_name.split(".")
    if var_name.startswith("downsample_layers"):
        stage = int(parts[1])
        return sum(depths[:stage]) + 1
    if var_name.startswith("stages"):
        stage, block = int(parts[1]), int(parts[2])
        return sum(depths[:stage]) + block + 1
    # everything else (stem/head/norm) is lumped into one final group
    return sum(depths) + 1
31
+
32
+
33
def get_num_layer_for_convnext(var_name):
    """
    Divide [3, 3, 27, 3] layers into 12 groups; each group is three
    consecutive blocks, including possible neighboring downsample layers;
    adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
    """
    num_max_layer = 12
    parts = var_name.split(".")

    if var_name.startswith("downsample_layers"):
        stage = int(parts[1])
        if stage == 0:
            layer_id = 0
        elif stage in (1, 2):
            layer_id = stage + 1
        elif stage == 3:
            layer_id = num_max_layer
        return layer_id

    if var_name.startswith("stages"):
        stage, block = int(parts[1]), int(parts[2])
        if stage in (0, 1):
            layer_id = stage + 1
        elif stage == 2:
            # stage 2 has 27 blocks: group them three at a time
            layer_id = 3 + block // 3
        elif stage == 3:
            layer_id = num_max_layer
        return layer_id

    # stem/head/etc. fall past the deepest group
    return num_max_layer + 1
62
+
63
+
64
class LayerDecayValueAssigner(object):
    """Maps parameter names to layer-wise learning-rate scales.

    ``values[i]`` is the lr scale applied to layer ``i``; the layer id of a
    parameter is computed by one of the two ``get_num_layer_for_convnext*``
    helpers depending on ``layer_decay_type``.

    Fix: the default for ``depths`` was a mutable list shared across all
    instances; a tuple is equivalent here (only sliced and summed) and safe.
    """

    def __init__(self, values, depths=(3, 3, 27, 3), layer_decay_type="single"):
        self.values = values
        self.depths = depths
        self.layer_decay_type = layer_decay_type

    def get_scale(self, layer_id):
        """Return the lr scale for ``layer_id``."""
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        """Return the layer id for parameter name ``var_name``."""
        if self.layer_decay_type == "single":
            return get_num_layer_for_convnext_single(var_name, self.depths)
        else:
            return get_num_layer_for_convnext(var_name)
78
+
79
+
80
def get_parameter_groups(
    model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None
):
    """Build optimizer parameter groups with optional layer-wise lr scales.

    1-D tensors, ``.bias``/``.gamma``/``.beta`` parameters and anything in
    ``skip_list`` get zero weight decay. When ``get_num_layer`` is given,
    groups are additionally split per layer id so that ``get_layer_scale``
    can attach a per-layer ``lr_scale``. Returns a list of group dicts
    suitable for a torch optimizer constructor.
    """
    parameter_group_names = {}  # same structure, names only (for logging)
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        skip_decay = (
            param.ndim == 1
            or name.endswith((".bias", ".gamma", ".beta"))
            or name in skip_list
        )
        group_name = "no_decay" if skip_decay else "decay"
        this_weight_decay = 0.0 if skip_decay else weight_decay

        layer_id = None
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)

        if group_name not in parameter_group_names:
            scale = 1.0 if get_layer_scale is None else get_layer_scale(layer_id)
            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
    return list(parameter_group_vars.values())
util/lars.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """

    def __init__(
        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
    ):
        super().__init__(
            params,
            dict(
                lr=lr,
                weight_decay=weight_decay,
                momentum=momentum,
                trust_coefficient=trust_coefficient,
            ),
        )

    @torch.no_grad()
    def step(self):
        """Single optimization step: trust-ratio-scaled SGD with momentum."""
        for group in self.param_groups:
            for param in group["params"]:
                grad = param.grad
                if grad is None:
                    continue

                # 1-D tensors (norm gamma/beta, biases) skip decay and scaling.
                if param.ndim > 1:
                    grad = grad.add(param, alpha=group["weight_decay"])
                    param_norm = torch.norm(param)
                    update_norm = torch.norm(grad)
                    one = torch.ones_like(param_norm)
                    # trust ratio; falls back to 1 when either norm is zero
                    trust = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0,
                            (group["trust_coefficient"] * param_norm / update_norm),
                            one,
                        ),
                        one,
                    )
                    grad = grad.mul(trust)

                state = self.state[param]
                if "mu" not in state:
                    state["mu"] = torch.zeros_like(param)
                momentum_buf = state["mu"]
                momentum_buf.mul_(group["momentum"]).add_(grad)
                param.add_(momentum_buf, alpha=-group["lr"])
util/lr_sched.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+
10
def adjust_learning_rate(optimizer, epoch, args):
    """Half-cycle cosine decay after a linear warmup; returns the base lr.

    Groups carrying an ``lr_scale`` entry (layer-decay groups) get the base
    lr multiplied by that scale.
    """
    if epoch < args.warmup_epochs:
        # linear warmup from 0 up to args.lr
        lr = args.lr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
            1.0 + math.cos(math.pi * progress)
        )
    for group in optimizer.param_groups:
        if "lr_scale" in group:
            group["lr"] = lr * group["lr_scale"]
        else:
            group["lr"] = lr
    return lr
util/metric.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import prettytable
3
+ import copy
4
+ import sys
5
+ from importlib import import_module
6
+ from inspect import signature
7
+ from pathlib import Path
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ from scipy.stats import kendalltau, pearsonr, spearmanr
12
+ from sklearn.metrics import (
13
+ confusion_matrix,
14
+ f1_score,
15
+ fbeta_score,
16
+ get_scorer,
17
+ get_scorer_names,
18
+ make_scorer,
19
+ )
20
+
21
+
22
def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    """Accuracy (percent) for binary classification with a 0.5 threshold.

    Note: returns a 0-dim tensor holding the percentage, not a Python float.
    """
    with torch.no_grad():
        n = target.size(0)
        predictions = (output >= 0.5).float().t().view(-1)
        hits = predictions.eq(target.view(-1)).float().sum()
        hits.mul_(100.0 / n)
        return hits
30
+
31
+
32
def accuracy(output, target, topk=(1,)):
    r"""
    Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`
        target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`
        topk (sequence[int]): A list of top-N number.

    Returns:
        Top-N accuracies (N :math:`\in` topK), each a 0-dim float tensor in percent.
    """
    with torch.no_grad():
        k_max = max(topk)
        n = target.size(0)

        # pred: (k_max, N) class indices sorted by descending score
        _, pred = output.topk(k_max, 1, True, True)
        hits = pred.t().eq(target[None])

        return [
            hits[:k].flatten().sum(dtype=torch.float32) * (100.0 / n) for k in topk
        ]
57
+
58
+
59
class ConfusionMatrix(object):
    """Accumulates an integer ``num_classes x num_classes`` confusion matrix
    over successive batches of (target, prediction) class indices."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Allocated lazily on first update() so it lands on target's device.
        self.mat = None

    def update(self, target, output):
        """
        Update confusion matrix.

        Args:
            target: ground truth
            output: predictions of models

        Shape:
            - target: :math:`(minibatch, C)` where C means the number of classes.
            - output: :math:`(minibatch, C)` where C means the number of classes.
        """
        # NOTE(review): despite the Shape doc above, the indexing below treats
        # target/output as integer class-index tensors, not (N, C) scores —
        # confirm against callers.
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.no_grad():
            # Keep only in-range labels, then flatten each (row, col) pair to a
            # single index so one bincount fills the whole n*n matrix.
            k = (target >= 0) & (target < n)
            inds = n * target[k].to(torch.int64) + output[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        # NOTE(review): assumes update() ran at least once (self.mat is None
        # before that).
        self.mat.zero_()

    def compute(self):
        """compute global accuracy, per-class accuracy and per-class IoU"""
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        # IoU = TP / (TP + FN + FP); row sum + col sum double-counts the diagonal.
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    # def reduce_from_all_processes(self):
    #     if not torch.distributed.is_available():
    #         return
    #     if not torch.distributed.is_initialized():
    #         return
    #     torch.distributed.barrier()
    #     torch.distributed.all_reduce(self.mat)

    def __str__(self):
        # Human-readable summary of compute() with percentages.
        acc_global, acc, iu = self.compute()
        return (
            "global correct: {:.1f}\n"
            "average row correct: {}\n"
            "IoU: {}\n"
            "mean IoU: {:.1f}"
        ).format(
            acc_global.item() * 100,
            ["{:.1f}".format(i) for i in (acc * 100).tolist()],
            ["{:.1f}".format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100,
        )

    def format(self, classes: list):
        """Get the accuracy and IoU for each class in the table format"""
        acc_global, acc, iu = self.compute()

        # Requires the third-party `prettytable` package (module-level import).
        table = prettytable.PrettyTable(["class", "acc", "iou"])
        for i, class_name, per_acc, per_iu in zip(
            range(len(classes)), classes, (acc * 100).tolist(), (iu * 100).tolist()
        ):
            table.add_row([class_name, per_acc, per_iu])

        return (
            "global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}".format(
                acc_global.item() * 100,
                acc.mean().item() * 100,
                iu.mean().item() * 100,
                table.get_string(),
            )
        )
135
+
136
+
137
def kappa(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    weights: Optional[Union[str, np.ndarray]] = None,
    allow_off_by_one: bool = False,
) -> float:
    """
    Calculate the kappa inter-rater agreement.

    The agreement is calculated between the gold standard and the predicted
    ratings. Potential values range from -1 (representing complete disagreement)
    to 1 (representing complete agreement). A kappa value of 0 is expected if
    all agreement is due to chance.

    In the course of calculating kappa, all items in ``y_true`` and ``y_pred`` will
    first be converted to floats and then rounded to integers.

    It is assumed that y_true and y_pred contain the complete range of possible
    ratings.

    This function contains a combination of code from yorchopolis's kappa-stats
    and Ben Hamner's Metrics projects on Github.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.
    weights : Optional[Union[str, numpy.ndarray]], default=None
        Specifies the weight matrix for the calculation.
        Possible values are: ``None`` (unweighted-kappa), ``"quadratic"``
        (quadratically weighted kappa), ``"linear"`` (linearly weighted kappa),
        and a two-dimensional numpy array (a custom matrix of weights). Each
        weight in this array corresponds to the :math:`w_{ij}` values in the
        Wikipedia description of how to calculate weighted Cohen's kappa.
    allow_off_by_one : bool, default=False
        If true, ratings that are off by one are counted as
        equal, and all other differences are reduced by
        one. For example, 1 and 2 will be considered to be
        equal, whereas 1 and 3 will have a difference of 1
        for when building the weights matrix.

    Returns
    -------
    float
        The weighted or unweighted kappa score.

    Raises
    ------
    AssertionError
        If ``y_true`` != ``y_pred``.
    ValueError
        If labels cannot be converted to int.
    ValueError
        If invalid weight scheme.
    """
    # Ensure that the lists are both the same length
    assert len(y_true) == len(y_pred)

    # This rather crazy looking typecast is intended to work as follows:
    # If an input is an int, the operations will have no effect.
    # If it is a float, it will be rounded and then converted to an int
    # because the ml_metrics package requires ints.
    # If it is a str like "1", then it will be converted to a (rounded) int.
    # If it is a str that can't be typecast, then the user is
    # given a hopefully useful error message.
    try:
        y_true = np.array([int(np.round(float(y))) for y in y_true])
        y_pred = np.array([int(np.round(float(y))) for y in y_pred])
    except ValueError:
        raise ValueError(
            "For kappa, the labels should be integers or strings"
            " that can be converted to ints (E.g., '4.0' or "
            "'3')."
        )

    # Figure out normalized expected values
    min_rating = min(min(y_true), min(y_pred))
    max_rating = max(max(y_true), max(y_pred))

    # shift the values so that the lowest value is 0
    # (to support scales that include negative values)
    y_true = y_true - min_rating
    y_pred = y_pred - min_rating

    # Build the observed/confusion matrix (sklearn confusion_matrix, rows=true)
    num_ratings = max_rating - min_rating + 1
    observed = confusion_matrix(y_true, y_pred, labels=list(range(num_ratings)))
    num_scored_items = float(len(y_true))

    # Build weight array if weren't passed one
    if isinstance(weights, str):
        wt_scheme = weights
        weights = None
    else:
        wt_scheme = ""

    if weights is None:
        kappa_weights = np.empty((num_ratings, num_ratings))
        for i in range(num_ratings):
            for j in range(num_ratings):
                diff = abs(i - j)
                if allow_off_by_one and diff:
                    diff -= 1
                if wt_scheme == "linear":
                    kappa_weights[i, j] = diff
                elif wt_scheme == "quadratic":
                    kappa_weights[i, j] = diff**2
                elif not wt_scheme:  # unweighted
                    kappa_weights[i, j] = bool(diff)
                else:
                    raise ValueError(
                        "Invalid weight scheme specified for " f"kappa: {wt_scheme}"
                    )
    else:
        kappa_weights = weights

    # Marginal rating histograms, normalized to probabilities.
    hist_true: np.ndarray = np.bincount(y_true, minlength=num_ratings)
    hist_true = hist_true[:num_ratings] / num_scored_items
    hist_pred: np.ndarray = np.bincount(y_pred, minlength=num_ratings)
    hist_pred = hist_pred[:num_ratings] / num_scored_items
    expected = np.outer(hist_true, hist_pred)

    # Normalize observed array
    observed = observed / num_scored_items

    # If all weights are zero, that means no disagreements matter.
    k = 1.0
    if np.count_nonzero(kappa_weights):
        observed_sum = np.sum(kappa_weights * observed)
        expected_sum = np.sum(kappa_weights * expected)
        # observed_sum/expected_sum are already scalars; the extra np.sum
        # wrappers are no-ops kept for fidelity with the upstream code.
        k -= np.sum(observed_sum) / np.sum(expected_sum)

    return k
272
+
273
+
274
def correlation(
    y_true: np.ndarray, y_pred: np.ndarray, corr_type: str = "pearson"
) -> float:
    """Correlation of ``corr_type`` between ``y_true`` and ``y_pred``.

    ``y_pred`` may be 1-D (probabilities, labels, or regressor outputs, used
    as-is) or 2-D per-class probabilities, in which case the argmax labels are
    correlated against ``y_true``. Unknown ``corr_type`` values fall back to
    Pearson, matching the original if/elif chain.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.
    corr_type : str, default="pearson"
        One of "pearson", "spearman", "kendall_tau".

    Returns
    -------
    float
        The correlation coefficient.
    """
    dispatch = {"spearman": spearmanr, "kendall_tau": kendalltau}
    corr_func = dispatch.get(corr_type, pearsonr)

    # accept plain lists as well
    y_pred = np.array(y_pred)

    # 2-D probability array -> reduce to most-likely labels first
    scores = np.argmax(y_pred, axis=1) if y_pred.ndim > 1 else y_pred
    return corr_func(y_true, scores)[0]
321
+
322
+
323
def f1_score_least_frequent(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 score of the class that appears least often in ``y_true``.

    Parameters
    ----------
    y_true : numpy.ndarray
        The true/actual/gold labels for the data.
    y_pred : numpy.ndarray
        The predicted/observed labels for the data.

    Returns
    -------
    float
        F1 score of the least frequent label.
    """
    rarest_label = np.bincount(y_true).argmin()
    per_class_f1 = f1_score(y_true, y_pred, average=None)
    return per_class_f1[rarest_label]
util/misc.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ from collections import defaultdict, deque
17
+ from pathlib import Path
18
+ import shutil
19
+ import torch
20
+ import torch.distributed as dist
21
+ from torch import inf
22
+ import json
23
+
24
+
25
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``deque`` keeps only the last ``window_size`` values (windowed stats);
    ``total``/``count`` run over the whole series (global average).
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt if fmt is not None else "{median:.4f} ({global_avg:.4f})"

    def update(self, value, n=1):
        """Record ``value``, counted ``n`` times toward the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # torch.median returns the lower of the two middle values for an
        # even-length window.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # most recent value
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
86
+
87
+
88
class MetricLogger(object):
    """Collects named SmoothedValue meters and prints periodic progress lines
    (with ETA, iteration/data timing and optionally CUDA memory) while
    iterating a data loader via ``log_every``."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created on first update().
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed scalar values into the meters named by the keyword arguments.

        ``None`` values are skipped; tensors are converted with ``.item()``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss` as a shorthand for `logger.meters["loss"]`.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        # One "name: stats" entry per meter, joined by the delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """All-reduce every meter's count/total across workers."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        # Register a meter with a custom format string / window size.
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Generator: yields each item of ``iterable``, printing a progress
        line every ``print_freq`` iterations and a total-time summary at the end.

        ``data_time`` measures time spent waiting for the next item;
        ``iter_time`` measures the full per-iteration wall time.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            # time after yield includes the caller's work on this item
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )
184
+
185
+
186
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    # Keep a handle on the real print; the replacement below shadows it
    # process-wide via builtins.
    builtin_print = builtins.print

    def print(*args, **kwargs):
        # Any rank can force output with print(..., force=True); jobs with
        # more than 8 ranks always print so progress is visible everywhere.
        force = kwargs.pop("force", False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print("[{}] ".format(now), end="")  # print with time stamp
            builtin_print(*args, **kwargs)

    # Replace the builtin for the whole process.
    builtins.print = print
201
+
202
+
203
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
209
+
210
+
211
def get_world_size():
    """Number of distributed workers; 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
215
+
216
+
217
def get_rank():
    """Rank of this worker; 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
221
+
222
+
223
def is_main_process():
    """True on rank 0, the process responsible for logging and checkpoints."""
    return get_rank() == 0
225
+
226
+
227
def save_on_master(*args, **kwargs):
    """``torch.save`` that is a no-op on non-master ranks (avoids file races)."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
230
+
231
+
232
def init_distributed_mode(args):
    """Initialize torch.distributed (NCCL) from whichever launcher set the
    environment: ITP/OpenMPI, torchrun-style env vars, or SLURM. Falls back
    to non-distributed single-process mode. Mutates ``args`` in place
    (rank, world_size, gpu, distributed, dist_backend)."""
    if args.dist_on_itp:
        # OpenMPI/ITP launcher: translate OMPI_* vars into the standard ones.
        args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.dist_url = "tcp://%s:%s" % (
            os.environ["MASTER_ADDR"],
            os.environ["MASTER_PORT"],
        )
        os.environ["LOCAL_RANK"] = str(args.gpu)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # torchrun / torch.distributed.launch
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # NOTE(review): this branch sets rank/gpu only; args.world_size and
        # args.dist_url must already be set by the caller — confirm.
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    # Bind this process to its GPU before creating the process group.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}): {}, gpu {}".format(
            args.rank, args.dist_url, args.gpu
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # Silence print() on all ranks except rank 0.
    setup_for_distributed(args.rank == 0)
276
+
277
+
278
class NativeScalerWithGradNormCount:
    """AMP loss-scaler wrapper: performs the scaled backward pass, optional
    gradient clipping, the optimizer step and scaler update in one call,
    returning the (pre-clip) gradient norm."""

    # Key under which the scaler state is stored in checkpoints.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        # Backward on the scaled loss; create_graph enables higher-order grads.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(
                    optimizer
                )  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # Still unscale so the reported norm is in true gradient units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            # Gradient-accumulation step: no optimizer update, no norm.
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
315
+
316
+
317
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total norm of the gradients of ``parameters``.

    Parameters without a gradient are ignored; returns a 0-dim tensor
    (0.0 when nothing has a gradient). ``norm_type`` may be ``inf``.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)
    device = grads[0].device
    if norm_type == inf:
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
335
+
336
+
337
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Persist a per-epoch training checkpoint under ``args.output_dir``.

    With a native ``loss_scaler``, writes a single ``checkpoint-<epoch>.pth``
    file from the master rank; otherwise delegates to the model's own
    ``save_checkpoint`` (deepspeed-style) API.
    """
    output_dir = Path(args.output_dir)
    tag = "checkpoint-%s" % str(epoch)
    if loss_scaler is not None:
        state = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scaler": loss_scaler.state_dict(),
            "args": args,
        }
        save_on_master(state, output_dir / (tag + ".pth"))
    else:
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag=tag,
            client_state={"epoch": epoch},
        )
358
+
359
+
360
def save_best_model(
    args, epoch, model, model_without_ddp, optimizer, loss_scaler, is_best
):
    """Save the latest checkpoint as ``last.pth.tar`` and, when ``is_best``,
    copy it to ``best.pth.tar``.

    Fix: ``checkpoint_path`` was only assigned in the ``loss_scaler`` branch,
    so ``is_best`` in the deepspeed branch raised NameError; the best-copy is
    now only attempted when a single-file checkpoint was actually written.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    checkpoint_path = None
    if loss_scaler is not None:
        checkpoint_path = output_dir / ("last.pth.tar")
        to_save = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scaler": loss_scaler.state_dict(),
            "args": args,
        }
        save_on_master(to_save, checkpoint_path)
    else:
        # Deepspeed-style checkpoints are directories; nothing to copy below.
        client_state = {"epoch": epoch}
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state=client_state,
        )
    if is_best and checkpoint_path is not None:
        filepath_best = output_dir / ("best.pth.tar")
        shutil.copyfile(checkpoint_path, filepath_best)
385
+
386
+
387
def save_current_best_model(
    args, epoch, model, model_without_ddp, optimizer, loss_scaler, is_best, current_interval
):
    """Save the latest checkpoint for ``current_interval`` as
    ``<interval>_last.pth.tar`` and, when ``is_best``, copy it to
    ``<interval>_best.pth.tar``.

    Fix: ``checkpoint_path`` was only assigned in the ``loss_scaler`` branch,
    so ``is_best`` in the deepspeed branch raised NameError; the best-copy is
    now only attempted when a single-file checkpoint was actually written.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    checkpoint_path = None
    if loss_scaler is not None:
        checkpoint_path = output_dir / (f"{current_interval}_last.pth.tar")
        to_save = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scaler": loss_scaler.state_dict(),
            "args": args,
        }
        save_on_master(to_save, checkpoint_path)
    else:
        # Deepspeed-style checkpoints are directories; nothing to copy below.
        client_state = {"epoch": epoch}
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state=client_state,
        )
    if is_best and checkpoint_path is not None:
        filepath_best = output_dir / (f"{current_interval}_best.pth.tar")
        shutil.copyfile(checkpoint_path, filepath_best)
413
+
414
+
415
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume training state from ``args.resume`` (URL or local path).

    Restores model weights always; optimizer state, ``args.start_epoch`` and
    the AMP scaler only when present in the checkpoint and not in eval mode.

    Bug fix: ``weights_only=False`` was being passed to
    ``nn.Module.load_state_dict`` / ``Optimizer.load_state_dict`` /
    scaler ``load_state_dict``, none of which accept that kwarg — every resume
    raised ``TypeError``.  ``weights_only`` belongs on ``torch.load``; it is
    set explicitly to False there because the checkpoint embeds the argparse
    Namespace (PyTorch >= 2.6 defaults to True, which would reject it).
    """
    if not args.resume:
        return
    if args.resume.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location="cpu", check_hash=True, weights_only=False
        )
    else:
        # Checkpoint stores non-tensor objects (args Namespace), so full
        # unpickling is required; only resume checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
    model_without_ddp.load_state_dict(checkpoint["model"])
    print("Resume checkpoint %s" % args.resume)
    if (
        "optimizer" in checkpoint
        and "epoch" in checkpoint
        and not (hasattr(args, "eval") and args.eval)
    ):
        optimizer.load_state_dict(checkpoint["optimizer"])
        args.start_epoch = checkpoint["epoch"] + 1
        if "scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["scaler"])
        print("With optim & sched!")
435
+
436
+
437
def all_reduce_mean(x):
    """Average the scalar ``x`` across all distributed workers.

    Single-process runs return ``x`` unchanged; otherwise the value is
    all-reduced on GPU and divided by the world size.
    """
    n = get_world_size()
    if n <= 1:
        return x
    buf = torch.tensor(x).cuda()
    dist.all_reduce(buf)
    buf /= n
    return buf.item()
446
+
447
+
448
def write_log(log_writer, log_stats, args):
    """Append one JSON line of ``log_stats`` to ``<output_dir>/log.txt``.

    Only the main process writes (guarded by ``is_main_process``); an
    optional tensorboard-style ``log_writer`` is flushed first.
    """
    if not (args.output_dir and is_main_process()):
        return
    if log_writer is not None:
        log_writer.flush()
    log_path = os.path.join(args.output_dir, "log.txt")
    with open(log_path, mode="a", encoding="utf-8") as fh:
        fh.write(json.dumps(log_stats) + "\n")