DYunt commited on
Commit
2659b26
·
verified ·
1 Parent(s): 63999bd

Upload 26 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ overall.pdf filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,23 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SUMMIT: A SAR Foundation Model with Multiple Auxiliary Tasks Enhanced Intrinsic Characteristics
2
+ [SUMMIT: A SAR Foundation Model with Multiple Auxiliary Tasks Enhanced Intrinsic Characteristics](https://doi.org/10.1016/j.jag.2025.104624)
3
+ ## Overview
4
+ This repository hosts the official implementation of SUMMIT, a state-of-the-art (SOTA) foundation model tailored for Synthetic Aperture Radar (SAR) image understanding. Proposed in the paper "SUMMIT: A SAR foundation model with multiple auxiliary tasks enhanced intrinsic characteristics" (published in International Journal of Applied Earth Observation and Geoinformation, 2025), SUMMIT addresses the limitations of existing deep learning methods in SAR processing—such as neglecting SAR’s intrinsic physical properties and poor cross-task generalization—by integrating self-supervised auxiliary tasks and SAR-specific prior knowledge.
5
+
6
+ ## Key Contributions
7
+ 1. Large-Scale SAR Dataset (MuSID): Constructed the Multi-sensor SAR Image Dataset (MuSID) with over 560,000 SAR images, covering diverse scenarios (aircraft, ships, bridges, harbors), resolutions (0.1–25 m), and sensors (Gaofen-3, Sentinel-1, TerraSAR-X, etc.). It supports large-scale self-supervised pre-training for SAR foundation models.
8
+
9
+ 2. Multi-Auxiliary-Task Pre-Training Framework: Proposed three self-supervised auxiliary tasks (SSATs) to enhance SAR feature learning: Masked Image Modeling (MIM): Learns robust structural representations of SAR images. Self-Supervised Denoising: Mitigates speckle noise (a unique artifact of SAR imaging) and improves noise resistance. Spatial Scattering Feature (SSF) Enhancement: Preserves geometric consistency by extracting edge features (via the Canny algorithm) and scattering point features (via Harris corner detection).
10
+
11
+ 3. Auxiliary Task Coordination Module (ATCM): Designed ATCM to dynamically balance and fuse the three auxiliary tasks. Unlike simple task aggregation, ATCM aligns each task with the optimal stage of the learning process (e.g., denoising at the input level, edge reconstruction at the output level), ensuring effective integration of SAR physical properties into feature learning.
12
+
13
+ ## Model Architecture
14
+
15
+ ![SUMMIT Architecture](overall.jpeg "SUMMIT Framework")
16
+
17
+ SUMMIT is built on a Vision Transformer (ViT). Pre-training stage — Input: the MuSID dataset (448×448 resized images). Process: ATCM coordinates the MIM, denoising, and SSF enhancement tasks. The shared ViT encoder learns SAR-specific features, with a decoder optimizing a multi-task reconstruction loss.
18
+
19
+ ## Environment Setup
20
+ ```bash
21
+ conda create -n summit python=3.8
22
+ conda activate summit
23
+ pip install -r requirements.txt
SARdatasets.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.datasets import ImageFolder
3
+ from PIL import Image
4
+ from PIL import ImageFile
5
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import random
10
+ from scipy.ndimage import convolve
11
+
12
+
13
class SARImageFolder(ImageFolder):
    """ImageFolder variant whose samples stack a grayscale SAR image with
    its Canny edge map and Harris corner response as a 3-channel image.

    Returns (sample, class_index) like a plain ImageFolder.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        path, target = self.samples[index]

        # Load and convert to single-channel float32 (required by cornerHarris).
        gray = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2GRAY)
        gray = np.float32(gray)

        # Structural cues: Canny edges (needs uint8 input) and a scaled
        # Harris corner response map.
        edge_map = cv2.Canny(gray.astype(np.uint8), 200, 300)
        corner_map = 255 * cv2.cornerHarris(gray, 5, 3, 0.04)

        # Fuse intensity / edges / corners into one 3-channel PIL image.
        stacked = np.dstack((gray, edge_map, corner_map)).astype(np.uint8)
        sample = Image.fromarray(stacked)

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target
37
+
38
+
39
class build_coed_SARImageFolder(ImageFolder):
    """ImageFolder that pairs the raw RGB image (input) with a 3-channel
    (grayscale, Canny edges, Harris corners) reconstruction target.

    Returns (rgb_input, feature_target); the class label is discarded.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        path, target = self.samples[index]

        # Read the sample twice: RGB for the network input, grayscale for
        # the structural-feature target.
        image_3ch = Image.open(path).convert('RGB')
        gray_np = np.array(Image.open(path).convert('L'))

        edge_map = cv2.Canny(gray_np, 200, 300)
        corner_map = cv2.cornerHarris(gray_np, 5, 3, 0.04) * 255

        target_img = Image.fromarray(
            np.dstack((gray_np, edge_map, corner_map)).astype(np.uint8))

        if self.transform is not None:
            # Note: random transforms are drawn independently for target and
            # input (kept identical to the original draw order).
            target_img = self.transform(target_img)
            image_3ch = self.transform(image_3ch)

        # The reconstruction target replaces the class label.
        target = target_img

        return image_3ch, target
66
+
67
+
68
class Multi_task_SARImageFolder(ImageFolder):
    """ImageFolder for multi-task MAE pre-training.

    __getitem__ returns:
        input  -- the RGB image with gamma (speckle-like) noise injected into
                  its first channel (self-supervised denoising task),
        target -- a 3-channel stack of (grayscale image, Canny edges,
                  Harris corner response) for reconstruction.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def add_gamma_noise(self, image_np, looks):
        """Multiply the image by gamma-distributed coherent speckle noise.

        Args:
            image_np: source image as a numpy array.
            looks: equivalent number of looks (ENL); larger -> less noise.
        Returns:
            uint8 noisy image in [0, 255].
        """
        image_np = image_np.astype(np.float32)

        # Normalize to [0, 1]; guard against an all-zero image so the
        # division cannot produce NaNs (original divided unconditionally).
        peak = np.max(image_np)
        if peak > 0:
            image_np = image_np / peak

        # Unit-mean gamma speckle: shape=looks, scale=1/looks.
        gamma_noise = np.random.gamma(shape=looks, scale=1.0 / looks, size=image_np.shape)

        noisy_image = image_np * gamma_noise

        noisy_image = np.clip(noisy_image * 255, 0, 255).astype(np.uint8)

        return noisy_image

    def add_gaussian_noise(self, image_np, snr_db):
        """Add white Gaussian noise at the requested SNR.

        Args:
            image_np: source image as a numpy array.
            snr_db: desired signal-to-noise ratio in decibels.
        Returns:
            uint8 noisy image clipped to [0, 255].
        """
        signal_power = np.mean(image_np ** 2)

        snr = 10 ** (snr_db / 10)

        noise_power = signal_power / snr

        noise_sigma = np.sqrt(noise_power)

        # NOTE: the original saved/restored the *torch* RNG state and
        # reseeded torch around this NumPy call. That had no effect on
        # np.random and called torch.cuda.get_rng_state, which fails on
        # CPU-only hosts — removed.
        noise = np.random.normal(0, noise_sigma, image_np.shape)

        noisy_image = image_np + noise

        # Clip before the uint8 cast to avoid integer wrap-around on
        # saturated pixels (original cast directly).
        return np.clip(noisy_image, 0, 255).astype(np.uint8)

    def log_transform(self, image_np):
        """Log-compress the dynamic range: c * log(1 + x)."""
        image_np = image_np.astype(np.float32)

        c = 20.0
        transformed_image = c * np.log1p(image_np)  # np.log1p computes log(1 + x)

        return transformed_image

    def __getitem__(self, index):
        path, target = self.samples[index]

        image_3ch = Image.open(path).convert('RGB')
        image_3ch_np = np.array(image_3ch)

        image = Image.open(path).convert('L')
        image_np = np.array(image)

        # Structural targets: edge map + scaled Harris corner response.
        edges = cv2.Canny(image_np, 200, 300)

        corners = cv2.cornerHarris(image_np, 5, 3, 0.04)
        corners = corners * 255

        # Denoising-task input: speckle the first channel only.
        first_channel = image_3ch_np[:, :, 0]
        noisy_first_channel = self.add_gamma_noise(first_channel, 30)
        image_3ch_np[:, :, 0] = noisy_first_channel
        image_3ch = Image.fromarray(image_3ch_np)

        multi_channel_image = np.dstack((image_np, edges, corners))
        multi_channel_image = multi_channel_image.astype(np.uint8)
        multi_channel_image = Image.fromarray(multi_channel_image)

        if self.transform is not None:
            # NOTE(review): random transforms are applied independently to
            # target and input, so their crops/flips may disagree.
            multi_channel_image = self.transform(multi_channel_image)
            image_3ch = self.transform(image_3ch)

        target = multi_channel_image

        return image_3ch, target
159
+
160
+
161
class Multi_task_angel_SARImageFolder(ImageFolder):
    """Multi-task variant that additionally rotates the densest corner
    region by a random right angle (0/90/180/270 degrees).

    __getitem__ returns:
        input  -- 4-channel stack (gray, gray, edges, corners),
        target -- 4-channel image (RGB + rotated-region channel) whose first
                  channel is log-compressed and Gaussian-noised.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def add_gaussian_noise(self, image_np, snr_db):
        """Add white Gaussian noise at `snr_db` dB SNR; returns uint8."""
        signal_power = np.mean(image_np ** 2)

        snr = 10 ** (snr_db / 10)

        noise_power = signal_power / snr

        noise_sigma = np.sqrt(noise_power)

        noise = np.random.normal(0, noise_sigma, image_np.shape)

        noisy_image = image_np + noise

        # Clip before the uint8 cast to avoid integer wrap-around on
        # saturated pixels (original cast directly).
        return np.clip(noisy_image, 0, 255).astype(np.uint8)

    def log_transform(self, image_np):
        """Log-compress the dynamic range: c * log(1 + x)."""
        image_np = image_np.astype(np.float32)

        c = 20.0
        transformed_image = c * np.log1p(image_np)

        return transformed_image

    def __getitem__(self, index):
        path, target = self.samples[index]

        image_3ch = Image.open(path).convert('RGB')
        image_3ch_np = np.array(image_3ch)

        image = Image.open(path).convert('L')
        image_np = np.array(image)

        edges = cv2.Canny(image_np, 200, 300)

        corners = cv2.cornerHarris(image_np, 5, 3, 0.04)
        corners = corners * 255

        # Find the kernel_size x kernel_size window with the highest
        # corner-response density.
        kernel_size = 50
        kernel = np.ones((kernel_size, kernel_size))
        density = convolve(corners, kernel, mode='constant', cval=0.0)

        max_density_index = np.unravel_index(np.argmax(density), density.shape)
        center_y, center_x = max_density_index

        half_size = kernel_size // 2
        start_y = max(center_y - half_size, 0)
        end_y = min(center_y + half_size, corners.shape[0])
        start_x = max(center_x - half_size, 0)
        end_x = min(center_x + half_size, corners.shape[1])

        region = image_np[start_y:end_y, start_x:end_x]

        # Rotate that region by a random right angle in place.
        angle = random.choice([0, 90, 180, 270])
        M = cv2.getRotationMatrix2D((region.shape[1] // 2, region.shape[0] // 2), angle, 1)
        rotated_region = cv2.warpAffine(region, M, (region.shape[1], region.shape[0]))

        rotated_image = image_np.copy()
        rotated_image[start_y:end_y, start_x:end_x] = rotated_region

        # Insert the rotated image as an extra (second) channel -> 4 channels.
        image_4ch_np = np.insert(image_3ch_np, 1, rotated_image, axis=2)

        first_channel = image_3ch_np[:, :, 0]
        first_channel = self.log_transform(first_channel)
        noisy_first_channel = self.add_gaussian_noise(first_channel, 30)
        image_4ch_np[:, :, 0] = noisy_first_channel
        # BUG FIX: build the target from the 4-channel array that actually
        # carries the rotated channel and the noised first channel; the
        # original used image_3ch_np, silently discarding both.
        image_4ch = Image.fromarray(image_4ch_np)

        multi_channel_image = np.dstack((image_np, image_np, edges, corners))
        multi_channel_image = multi_channel_image.astype(np.uint8)
        multi_channel_image = Image.fromarray(multi_channel_image)

        if self.transform is not None:
            multi_channel_image = self.transform(multi_channel_image)
            image_4ch = self.transform(image_4ch)

        target = image_4ch

        return multi_channel_image, target
acc_pretrain.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ from get_args import get_args_pretrain
4
+ import mae_model
5
+ # import mae_ori_model
6
+ import numpy as np
7
+ import datetime
8
+ import time
9
+ import json
10
+ import math
11
+ import sys
12
+ from typing import Iterable
13
+ from pathlib import Path
14
+ from accelerate import Accelerator
15
+ import torch
16
+ import torch.backends.cudnn as cudnn
17
+ import torch.nn as nn
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ import torchvision.transforms as transforms
20
+ import torchvision.datasets as datasets
21
+ import timm.optim.optim_factory as optim_factory
22
+ from SARdatasets import SARImageFolder, build_coed_SARImageFolder, Multi_task_SARImageFolder
23
+ import util.misc as misc
24
+ import util.lr_sched as lr_sched
25
+ from util.pos_embed import interpolate_pos_embed
26
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
27
+
28
+
29
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None,
                    accelerator=None):
    """Run one MAE pre-training epoch and return averaged metrics.

    Args:
        model: model returning (loss, channel_loss, pred, mask).
        data_loader: yields (input, reconstruction-target) batches.
        optimizer: optimizer prepared by `accelerator`.
        device: device batches are moved to.
        epoch: current epoch index (used for logging only).
        loss_scaler: unused here (Accelerate handles scaling); kept for
            interface compatibility with callers.
        log_writer: optional TensorBoard SummaryWriter.
        args: must provide `accum_iter`.
        accelerator: Hugging Face `Accelerator` used for backward().

    Returns:
        dict mapping metric name -> global average over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        samples = samples.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, channel_loss, _, _ = model(samples, target)  # , mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so gradients accumulated over `accum_iter` micro-steps
        # average out to one effective batch.
        accelerator.backward(loss / accum_iter)

        if (data_iter_step + 1) % accum_iter == 0:
            # BUG FIX: the original zeroed the gradients here without ever
            # calling optimizer.step(), so parameters were never updated.
            optimizer.step()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
88
+
89
+
90
def main(args):
    """Set up distributed MAE pre-training and run the epoch loop.

    Builds the multi-task SAR dataset, loads an external MAE checkpoint
    (args.finetune) into the model, trains via Hugging Face Accelerate,
    checkpoints every 50 epochs, and appends per-epoch JSON stats to
    <output_dir>/log.txt.
    """
    misc.init_distributed_mode(args)
    torch.multiprocessing.set_start_method('spawn', force=True)
    print ('work_dir:{}'.format(os.path.realpath(__file__)))
    accelerator = Accelerator()
    # NOTE(review): the device built from args.device is immediately
    # overridden by the accelerator's device.
    device = torch.device(args.device)
    device = accelerator.device
    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True
    # simple augmentation
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0)),  # 3 is bicubicinterpolation=3
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    dataset_train = Multi_task_SARImageFolder(root=args.data_path, transform=transform_train)

    print(dataset_train)

    # NOTE(review): `if True` always selects the distributed sampler; the
    # RandomSampler branch is dead code kept from the upstream MAE script.
    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    # TensorBoard writer only on the main process.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # shuffle=False because the DistributedSampler already shuffles.
    data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train, batch_size=args.batch_size,
                                                    num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, shuffle=False
                                                    )

    model = mae_model.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

    # load pretrain checkpoint of Imagenet
    checkpoint = torch.load(args.finetune, map_location='cpu')

    print("Load pre-trained checkpoint from: %s" % args.finetune)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    # Drop classifier-head weights whose shapes do not match this model.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    # interpolate position embedding
    interpolate_pos_embed(model, checkpoint_model)
    # load pre-trained model
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    # Linear lr scaling; this fork divides by 80 instead of MAE's usual 256.
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 80  # 256

    print("base lr: %.2e" % (args.lr * 80 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)  # add_weight_decay
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    model, optimizer, data_loader_train = accelerator.prepare(model, optimizer, data_loader_train)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args,
            accelerator=accelerator
        )
        # Checkpoint every 50 epochs and at the final epoch.
        if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, }

        # Append one JSON line per epoch (main process only).
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
206
+
207
+
208
if __name__ == '__main__':
    # Build the pre-training argument parser and parse CLI flags.
    args = get_args_pretrain()
    args = args.parse_args()
    # Ensure the checkpoint/log directory exists before training starts.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
get_args.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
def get_args_pretrain():
    """Build the CLI argument parser for MAE pre-training.

    Returns the ArgumentParser itself; callers invoke .parse_args().
    """
    p = argparse.ArgumentParser('MAE pre-training', add_help=False)

    # ---- Training schedule ----
    p.add_argument('--batch_size', default=32, type=int,
                   help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    p.add_argument('--epochs', default=100, type=int)
    p.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                   help='epochs to warmup LR')
    p.add_argument('--accum_iter', default=1, type=int,
                   help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    p.add_argument('--finetune',
                   default='.', )

    # ---- Model parameters ----
    p.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                   help='Name of model to train')
    p.add_argument('--input_size', default=448, type=int,
                   help='images input size')
    p.add_argument('--mask_ratio', default=0.75, type=float,
                   help='Masking ratio (percentage of removed patches).')
    p.add_argument('--norm_pix_loss', action='store_true',
                   help='Use (per-patch) normalized pixels as targets for computing loss')
    p.set_defaults(norm_pix_loss=False)

    # ---- Optimizer parameters ----
    p.add_argument('--weight_decay', type=float, default=0.05,
                   help='weight decay (default: 0.05)')
    p.add_argument('--lr', type=float, default=None, metavar='LR',
                   help='learning rate (absolute lr)')
    p.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                   help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    p.add_argument('--min_lr', type=float, default=5e-8, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0')

    # ---- Dataset / runtime parameters ----
    p.add_argument('--data_path', default=f'/home/SARDatasets/SARfolder/', type=str,
                   help='dataset pathpwp')
    p.add_argument('--output_dir', default='./output',
                   help='path where to save, empty for no saving')
    p.add_argument('--log_dir', default='./output',
                   help='path where to tensorboard log')
    p.add_argument('--device', default='cuda',
                   help='device to use for training / testing')
    p.add_argument('--seed', default=0, type=int)
    p.add_argument('--resume', default=False,
                   help='resume from checkpoint')
    p.add_argument('--start_epoch', default=0, type=int, metavar='N',
                   help='start epoch')
    p.add_argument('--num_workers', default=4, type=int)
    p.add_argument('--pin_mem', action='store_true',
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    p.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    p.set_defaults(pin_mem=True)

    return p
log_analyze.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
def get_log(path):
    """Parse <path>/log.txt (one JSON object per line) into parallel lists.

    Args:
        path: directory containing log.txt.
    Returns:
        (epoch, train_lr, train_loss, test_loss, test_acc1) lists with one
        entry per logged epoch.
    Raises:
        KeyError: if a line lacks one of the expected fields.
    """
    epoch = []
    train_lr = []
    train_loss = []
    test_loss = []
    test_acc1 = []
    # Use a context manager so the handle is closed even if parsing fails
    # (the original left the file open on any exception).
    with open(os.path.join(path, 'log.txt'), encoding='utf-8') as log:
        for raw_line in log:
            raw_line = raw_line.strip()
            if not raw_line:
                continue  # tolerate blank/trailing lines
            # json.loads replaces the raw_decode + tuple-indexing dance.
            record = json.loads(raw_line)
            print(record)
            epoch.append(record['epoch'])
            train_lr.append(record['train_lr'])
            train_loss.append(record['train_loss'])
            test_loss.append(record['test_loss'])
            test_acc1.append(record['test_acc1'])
    return epoch, train_lr, train_loss, test_loss, test_acc1
32
+
33
+
34
# Compare two training runs. NOTE: both paths are currently identical —
# point `path_noise` at the non-pretrained run before use.
path = 'output_dir_finetune/'
path_noise = 'output_dir_finetune/'
epoch, train_lr, train_loss, test_loss, test_acc1 = get_log(path)
epoch_noise, train_lr_noise, train_loss_noise, test_loss_noise, test_acc1_noise = get_log(path_noise)

# Plot the test-accuracy curves for both runs.
plt.figure()
plt.plot(test_acc1, color='r', label='test accuracy of multi-task pre-trained')
plt.plot(test_acc1_noise, color='b', label='test accuracy of none pre-trained')
# plt.title('Test Accuracy Over Time')
plt.xlabel('Epoch')
# plt.ylabel('test accuracy')
plt.legend()
# BUG FIX: save before show() — show() hands the figure to the GUI event
# loop and, in interactive backends, a later savefig() writes a blank image.
plt.savefig(os.path.join(path, 'acd_acc.png'))
plt.show()


# Plot the training-loss curves for both runs.
plt.figure()
plt.plot(train_loss, color='r', label='train loss of multi-task pre-trained')
plt.plot(train_loss_noise, color='b', label='train loss of none pre-trained')
# plt.title('Test Accuracy Over Time')
plt.xlabel('Epoch')
# plt.ylabel('test accuracy')
plt.legend()
plt.savefig(os.path.join(path, 'acd_loss.png'))
plt.show()
mae_model.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.vision_transformer import PatchEmbed, Block
4
+ from pos_embed import get_2d_sincos_pos_embed
5
+ from functools import partial
6
+
7
dd = 12  # shared decoder depth (also used by mae_vit_base_patch16 below)
class MAEViT(nn.Module):
    """Masked Autoencoder with a Vision Transformer backbone.

    The encoder embeds patch_size x patch_size patches, randomly masks a
    fraction of them, and processes only the visible tokens; the decoder
    reconstructs every patch. forward() returns
    (weighted loss, per-channel loss mean, prediction, binary mask).
    """
    def __init__(self, img_size=448, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=dd, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super(MAEViT, self).__init__()
        # MAE Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-learnable) sin-cos positional embedding, filled in by
        # initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # MAE Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos pos-embeds, init patch projection and
        tokens, then apply _init_weights to every submodule."""
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-init Linear layers; unit-weight / zero-bias LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Split images into flattened non-overlapping patches.
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        # Channels end up LAST in each flattened patch (…pqc).
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        Inverse of patchify.
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        # Channel count is inferred, so this also handles non-3-channel preds.
        hid_chans = int(x.shape[2]/(p**2))
        x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        Returns (kept tokens, binary mask, ids_restore for unshuffling).
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Embed, randomly mask, and encode; returns (latent, mask, ids_restore)."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Decode latent tokens (plus mask tokens) into per-patch pixel predictions."""
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x


    def forward_loss(self, imgs, pred, mask):
        """
        Plain MAE MSE loss, averaged over masked patches only.
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward_loss_separately(self, imgs, pred, mask):
        """
        Channel-weighted MSE loss: full weight on the image channel, half on
        the edge and corner auxiliary channels.
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        Returns (scalar loss, per-channel mean loss for monitoring).
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        channel_weights = torch.tensor([1, 0.5, 0.5], device=pred.device)

        loss = (pred - target) ** 2
        # Split the flattened patch dim into (pixels, channels); patchify
        # puts channels last, so the trailing axis is per-channel.
        loss = loss.view(loss.shape[0], loss.shape[1], -1, 3)

        channel_loss_mean = loss.mean(dim=[0, 1, 2])
        # print(f"Channel Loss Mean: {channel_loss_mean}")

        loss = loss * channel_weights
        loss = loss.sum(dim=-1)
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()

        return loss, channel_loss_mean

    def forward(self, imgs, target, mask_ratio=0.75):
        """Full MAE pass: encode masked `imgs`, decode, and score the
        prediction against `target` with the channel-weighted loss."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio=mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        # loss = self.forward_loss(imgs, pred, mask)
        # return loss, pred, mask
        loss, channel_loss_mean = self.forward_loss_separately(target, pred, mask)
        return loss, channel_loss_mean, pred, mask

    # def forward(self, imgs, mask_ratio=0.75):
    #     latent, mask = self.forward_encoder(imgs)
    #     pred = self.forward_decoder(latent, mask)  # Use mask instead of ids_restore
    #     loss = self.forward_loss(imgs, pred, mask)
    #     return loss, pred, mask
237
+
238
+
239
def mae_vit_base_patch16(**kwargs):
    """Build a ViT-Base MAE: 768-dim / depth-12 encoder, 16x16 patches."""
    return MAEViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=dd,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
245
+
246
def mae_vit_large_patch16(**kwargs):
    """Build a ViT-Large MAE: 1024-dim / depth-24 encoder, 16x16 patches."""
    return MAEViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
252
+
253
+
254
def mae_vit_huge_patch14(**kwargs):
    """Build a ViT-Huge MAE: 1280-dim / depth-32 encoder, 14x14 patches."""
    return MAEViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
mae_ori_model.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.distributions as dist
4
+ import torchvision.transforms
5
+ import numpy as np
6
+ from timm.models.vision_transformer import PatchEmbed, Block
7
+ from pos_embed import get_2d_sincos_pos_embed
8
+ from scipy.stats import gamma, lognorm, expon
9
+ from functools import partial
10
+
11
+
12
class DenoiseMAEViT(nn.Module):
    """Masked Autoencoder with a ViT backbone (denoising branch of SUMMIT).

    Standard MAE pipeline (He et al., "Masked Autoencoders Are Scalable
    Vision Learners"): patchify -> randomly mask a fraction of patches ->
    encode visible patches with a ViT -> decode the full token sequence ->
    per-pixel MSE computed on the masked patches only.
    """

    def __init__(self, img_size=448, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super(DenoiseMAEViT, self).__init__()
        # MAE Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-learned) sin-cos positional embedding; filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # MAE Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        # Shared learnable token standing in for every masked patch in the decoder.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        # If True, each target patch is normalized by its own mean/var before the loss.
        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init hook: xavier for Linear, constants for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)

        NOTE(review): the channel count 3 is hard-coded here even though
        __init__ takes in_chans — confirm callers always use 3-channel input.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)

        Channel count is inferred from the last dimension, so this also works
        for predictions with a different number of channels per pixel.
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        hid_chans = int(x.shape[2]/(p**2))
        x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p))
        return imgs

    # x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
    # x = torch.einsum('nhwpqc->nchpwq', x)
    # imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
    # return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns (x_masked, mask, ids_restore) where mask is [N, L] with
        0 = kept, 1 = removed, and ids_restore undoes the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Embed patches, drop a random subset, and run the ViT encoder."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Re-insert mask tokens, unshuffle, and predict per-patch pixels."""
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Per-patch normalization of the target (MAE's norm_pix_loss option).
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Full MAE pass: returns (loss, pred, mask)."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
209
+
210
+
211
def mae_vit_base_patch16(**kwargs):
    """ViT-Base/16 denoising-MAE factory; kwargs are forwarded to DenoiseMAEViT."""
    return DenoiseMAEViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
217
+
218
+
219
def mae_vit_large_patch16(**kwargs):
    """ViT-Large/16 denoising-MAE factory; kwargs are forwarded to DenoiseMAEViT."""
    return DenoiseMAEViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
225
+
226
+
227
def mae_vit_huge_patch16(**kwargs):
    """ViT-Huge/16 denoising-MAE factory; kwargs are forwarded to DenoiseMAEViT."""
    return DenoiseMAEViT(
        patch_size=16, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
233
+
overall.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f990a07117a667a2972cbf8c5c609e43b1fecca305ebf4efd4f7e6fa22b35b6
3
+ size 229409
pos_embed.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sin-cos positional embedding for a square patch grid.

    grid_size: int of the grid height and width
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or
        [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.meshgrid(axis, axis)  # here w goes first
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Prepend an all-zero row for the [CLS] token position.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
+
37
+
38
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode the two coordinate planes separately and concatenate: (H*W, D)."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+
48
+
49
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: positions to be encoded; any shape, flattened to (M,)
    out: (M, D) — first half sin, second half cos, over geometric frequencies.
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / 10000^(2i/D), i = 0 .. D/2-1.
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.
    freqs = 1. / 10000 ** freqs  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
+
69
+
70
+ # --------------------------------------------------------
71
+ # Interpolate position embeddings for high-resolution
72
+ # References:
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
def interpolate_pos_embed(model, checkpoint_model):
    """Bicubically resize a checkpoint's 2-D pos-embed grid to match `model`.

    Mutates checkpoint_model['pos_embed'] in place; extra tokens (cls/dist)
    at the front are kept unchanged. No-op when the grid sizes already match
    or the checkpoint has no 'pos_embed' entry.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    ckpt_pos = checkpoint_model['pos_embed']
    dim = ckpt_pos.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint grid vs. the new model grid
    old_size = int((ckpt_pos.shape[-2] - num_extra) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if old_size == new_size:
        return
    print("Position interpolate from %dx%d to %dx%d" % (old_size, old_size, new_size, new_size))
    extra_tokens = ckpt_pos[:, :num_extra]
    # only the position tokens are interpolated
    grid = ckpt_pos[:, num_extra:]
    grid = grid.reshape(-1, old_size, old_size, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, grid), dim=1)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.28.0
2
+ matplotlib==3.8.4
3
+ numpy<2  # torch==2.2.1 wheels are built against the NumPy 1.x C-API; NumPy 2 is incompatible
4
+ opencv_python==4.9.0.80
5
+ Pillow==11.3.0
6
+ Requests==2.32.5
7
+ scikit_learn==1.4.2
8
+ scipy==1.16.2
9
+ submitit==1.5.3
10
+ timm==1.0.20
11
+ torch==2.2.1+cu118
12
+ torchvision==0.17.1+cu118
13
+ tqdm==4.66.2
util/__pycache__/lr_decay.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
util/__pycache__/lr_decay.cpython-312.pyc ADDED
Binary file (2.34 kB). View file
 
util/__pycache__/lr_sched.cpython-310.pyc ADDED
Binary file (611 Bytes). View file
 
util/__pycache__/lr_sched.cpython-312.pyc ADDED
Binary file (1.07 kB). View file
 
util/__pycache__/misc.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
util/__pycache__/misc.cpython-312.pyc ADDED
Binary file (19.7 kB). View file
 
util/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
util/__pycache__/pos_embed.cpython-312.pyc ADDED
Binary file (4.03 kB). View file
 
util/crop.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        """Sample a single crop (top, left, height, width), clamped to the image.

        Fix: `F._get_image_size` was a private torchvision helper removed in
        modern releases (public `F.get_image_size` since 0.10); with the
        project's pinned torchvision this raised AttributeError. Fall back to
        the private name only for very old versions.
        """
        try:
            width, height = F.get_image_size(img)
        except AttributeError:  # torchvision < 0.10
            width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Single clamp instead of torchvision's rejection-sampling loop.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
util/datasets.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # --------------------------------------------------------
10
+
11
+ import os
12
+ import PIL
13
+
14
+ from torchvision import datasets, transforms
15
+
16
+ from timm.data import create_transform
17
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+
19
+
20
def build_dataset(is_train, args):
    """ImageFolder over <args.data_path>/train or /val with the split's transform."""
    split = 'train' if is_train else 'val'
    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split),
        transform=build_transform(is_train, args))

    print(dataset)

    return dataset
29
+
30
+
31
def build_transform(is_train, args):
    """Return the train (timm create_transform) or eval (resize+crop) pipeline.

    Fix: `PIL.Image.BICUBIC` was removed in Pillow 10 (the project pins
    Pillow 11), which made the eval branch raise AttributeError. Use the
    `Image.Resampling` enum when present, falling back for old Pillow.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        return transform

    # eval transform: resize so the center crop keeps the 224/256 ratio
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    # Pillow >= 9.1 exposes resampling filters on Image.Resampling.
    bicubic = getattr(PIL.Image, "Resampling", PIL.Image).BICUBIC
    t.append(
        transforms.Resize(size, interpolation=bicubic),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
util/lars.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.

    Layer-wise Adaptive Rate Scaling (as used by MoCo v3): for >1-D tensors
    the update is rescaled by trust_coefficient * ||param|| / ||grad + wd*param||
    before applying heavy-ball momentum. 1-D tensors (biases, norm scales)
    get plain SGD with momentum and no weight decay.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        # Single-step update; no closure support.
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    # Decoupled weight decay folded into the raw update.
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio, guarded against zero norms on either side.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                # Heavy-ball momentum buffer, lazily created per parameter.
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
util/lr_decay.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import json
13
+
14
+
15
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    group_names = {}  # group name -> metadata + parameter *names* (debugging aid)
    groups = {}       # group name -> metadata + parameter tensors (returned)

    num_layers = len(model.blocks) + 1

    # Deeper layers get larger lr scales: layer_decay^(num_layers - layer_id).
    scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            decay_tag, decay_val = "no_decay", 0.
        else:
            decay_tag, decay_val = "decay", weight_decay

        layer_id = get_layer_id_for_vit(name, num_layers)
        key = "layer_%d_%s" % (layer_id, decay_tag)

        if key not in groups:
            group_names[key] = {
                "lr_scale": scales[layer_id],
                "weight_decay": decay_val,
                "params": [],
            }
            groups[key] = {
                "lr_scale": scales[layer_id],
                "weight_decay": decay_val,
                "params": [],
            }

        group_names[key]["params"].append(name)
        groups[key]["params"].append(param)

    # print("parameter groups: \n%s" % json.dumps(group_names, indent=2))

    return list(groups.values())
+
63
+
64
def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter to its layer id for layer-wise lr decay.
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33

    Embeddings -> 0, blocks.k.* -> k+1, everything else (head, final norm)
    -> num_layers.
    """
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    if name.startswith('blocks'):
        return 1 + int(name.split('.')[1])
    return num_layers
util/lr_sched.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+
10
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup.

    Linear ramp from 0 to args.blr over args.warmup_epochs, then a cosine
    decay down to args.min_lr at args.epochs. Writes the (scaled) lr into
    every param group and returns the unscaled value.
    """
    if epoch < args.warmup_epochs:
        lr = args.blr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.min_lr + (args.blr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        # Groups from param_groups_lrd carry a per-layer "lr_scale".
        group["lr"] = lr * group.get("lr_scale", 1.0)
    return lr
util/misc.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ from collections import defaultdict, deque
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from PIL import ImageFile
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
23
+
24
+
25
+ dist_on_itp = False
26
+
27
+ # from torch._six import inf
28
+
29
+
30
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    `count`/`total` accumulate over the whole run; the deque keeps only the
    most recent `window_size` raw values for the windowed statistics.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # One deque entry per call, but count/total are weighted by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # torch.median returns the lower of the two middle values.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count if self.count != 0 else 0

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
93
+
94
+
95
class MetricLogger(object):
    """Collect named SmoothedValue meters and periodically print them while
    iterating a dataloader (see log_every)."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are lazily created with default SmoothedValue settings.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed scalar values (float/int or 0-dim tensor) into the named meters."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes (logger.loss -> meters['loss']).
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Reduces each meter's count/total across workers (not the deques).
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        # Register a meter with custom window/format before first update.
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Generator: yields items of `iterable`, printing progress, ETA,
        per-iteration/data-loading time, and (on CUDA) peak memory every
        `print_freq` iterations. `iterable` must support len()."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Zero-pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time between yields = data loading; time around yield = full iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if len(iterable) == 0:
            print('Total time: {} ({:.4f} s / it)'.format(total_time_str, 0))
        else:
            print('{} Total time: {} ({:.4f} s / it)'.format(
                header, total_time_str, total_time / len(iterable)))
180
+
181
+
182
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process

    Monkeypatches builtins.print globally: non-master ranks stay silent
    unless the call passes force=True or the world is large (> 8 ranks).
    Every surviving print is prefixed with a timestamp.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        # With many ranks, always show output so large-job failures are visible.
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print
197
+
198
+
199
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
205
+
206
+
207
def get_world_size():
    """Number of processes in the default group; 1 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
211
+
212
+
213
def get_rank():
    """Rank of this process in the default group; 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
217
+
218
+
219
def is_main_process():
    """True only on rank 0 (or when running without distributed)."""
    rank = get_rank()
    return rank == 0
221
+
222
+
223
def save_on_master(*args, **kwargs):
    """torch.save that is a no-op on non-master ranks (avoids write races)."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
226
+
227
+
228
def init_distributed_mode(args):
    """Initialize torch.distributed from ITP/torchrun/SLURM environment
    variables; mutates `args` (rank, world_size, gpu, distributed, dist_url,
    dist_backend) and silences print on non-master ranks."""
    if dist_on_itp:
        # ITP/OpenMPI launch: translate OMPI_* vars into the standard ones.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch path.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE(review): this branch never sets args.world_size (reference MAE
        # reads SLURM_NTASKS); init_process_group below would fail unless the
        # caller set it — confirm against the launch scripts.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    # NOTE(review): dist_url is unconditionally overwritten with 'env://',
    # discarding the tcp:// URL built in the ITP branch above.
    args.dist_url = 'env://'
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
262
+
263
+
264
class NativeScalerWithGradNormCount:
    """Thin wrapper around torch.cuda.amp.GradScaler that also reports the
    gradient norm (clipped or measured) on each optimizer update."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backward on the scaled loss; when update_grad is True also unscale,
        optionally clip to clip_grad, step the optimizer, and update the
        scaler. Returns the grad norm (None when only accumulating)."""
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # Still unscale so the reported norm is in true gradient units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
291
+
292
+
293
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total gradient norm over `parameters` (same metric as clip_grad_norm_).

    Accepts a single tensor or an iterable; parameters without gradients are
    ignored. Returns tensor(0.) when nothing has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    norm_type = float(norm_type)
    device = grads[0].device
    if norm_type == float('inf'):
        # inf-norm: max absolute entry over all gradients.
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
306
+
307
+
308
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write checkpoint-<epoch>.pth (master rank only), or delegate to the
    wrapped model's own sharded save when no loss_scaler is used."""
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is None:
        # deepspeed-style path: the wrapped model handles checkpointing itself.
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name,
                              client_state={'epoch': epoch})
        return
    payload = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(payload, output_dir / ('checkpoint-%s.pth' % epoch_name))
326
+
327
+
328
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume model (and, unless evaluating, optimizer/epoch/scaler) state
    from args.resume, which may be a local path or an https URL."""
    if not args.resume:
        return
    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % args.resume)
    resume_training_state = ('optimizer' in checkpoint and 'epoch' in checkpoint
                             and not (hasattr(args, 'eval') and args.eval))
    if resume_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")
343
+
344
+
345
def all_reduce_mean(x):
    """Average a scalar across all workers; pass-through when single-process."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
util/pos_embed.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the original MAE convention.
    mesh = np.meshgrid(coords, coords)
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Prepend an all-zero embedding slot for the cls token.
        cls_pad = np.zeros([1, embed_dim])
        pos_embed = np.concatenate([cls_pad, pos_embed], axis=0)
    return pos_embed
36
+
37
+
38
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin-cos embeddings of the two grid axes into a 2-D one."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Half of the channels encode the row coordinate, half the column.
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
47
+
48
+
49
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # BUG FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
    # use the explicit float64 dtype instead.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,) geometric frequency ladder

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
68
+
69
+
70
+ # --------------------------------------------------------
71
+ # Interpolate position embeddings for high-resolution
72
+ # References:
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
def interpolate_pos_embed(model, checkpoint_model):
    """Bicubically resize a checkpoint's 'pos_embed' to the model's patch grid.

    Extra tokens (cls/dist) are copied through unchanged; only the square
    grid of patch position tokens is interpolated. Mutates checkpoint_model
    in place; a no-op when the key is missing or the grids already match.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_pos = checkpoint_model['pos_embed']
    dim = ckpt_pos.shape[-1]
    num_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - num_patches
    # Both grids are assumed square (height == width).
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(num_patches ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    extra = ckpt_pos[:, :n_extra]
    patch_tokens = ckpt_pos[:, n_extra:]
    patch_tokens = patch_tokens.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    patch_tokens = torch.nn.functional.interpolate(
        patch_tokens, size=(new_side, new_side), mode='bicubic', align_corners=False)
    patch_tokens = patch_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra, patch_tokens), dim=1)
vit_model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ from timm.models.vision_transformer import PatchEmbed
5
+ from functools import partial
6
+
7
+
8
class vit(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer for linear probing: backbone frozen, head trainable.

    With global_pool=True the patch tokens are averaged and normalized by a
    dedicated fc_norm (MAE-style); otherwise the cls token after the standard
    final norm is used as the feature.
    """

    def __init__(self, global_pool=False, **kwargs):
        # BUG FIX: kwargs must be forwarded to the parent constructor;
        # the original called super().__init__() with no arguments, so
        # patch_size / embed_dim / depth / num_heads / ... were silently
        # ignored and every factory built a default ViT.
        super(vit, self).__init__(**kwargs)
        self.global_pool = global_pool
        embed_dim = kwargs['embed_dim']
        num_classes = kwargs['num_classes']
        self.head = nn.Linear(embed_dim, num_classes, bias=True)
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            self.fc_norm = norm_layer(embed_dim)
            # BUG FIX: drop the final norm only when it is replaced by
            # fc_norm. The original deleted it unconditionally, which broke
            # the global_pool=False path that still calls self.norm below.
            del self.norm

        # Linear probing: freeze every backbone parameter ...
        for param in self.parameters():
            param.requires_grad = False
        # ... then re-enable gradients for the classification head only.
        for param in self.head.parameters():
            param.requires_grad = True

    def forward_features(self, x):
        """Embed patches, run the transformer blocks, return a pooled feature."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Average over patch tokens only (cls token excluded).
            x = x[:, 1:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
53
+
54
+
55
def vit_base_patch16(**kwargs):
    """ViT-Base/16: 768-dim, 12 blocks, 12 heads, 16x16 patches at 224px."""
    base_cfg = dict(img_size=224, patch_size=16, embed_dim=768, depth=12,
                    num_heads=12, mlp_ratio=4, qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return vit(**base_cfg, **kwargs)
59
+
60
+
61
def vit_large_patch16(**kwargs):
    """ViT-Large/16: 1024-dim, 24 blocks, 16 heads, 16x16 patches."""
    base_cfg = dict(patch_size=16, embed_dim=1024, depth=24,
                    num_heads=16, mlp_ratio=4, qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return vit(**base_cfg, **kwargs)
65
+
66
+
67
def vit_huge_patch14(**kwargs):
    """ViT-Huge/14: 1280-dim, 32 blocks, 16 heads, 14x14 patches."""
    base_cfg = dict(patch_size=14, embed_dim=1280, depth=32,
                    num_heads=16, mlp_ratio=4, qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return vit(**base_cfg, **kwargs)