Upload 26 files
Browse files- .gitattributes +1 -0
- README.md +23 -3
- SARdatasets.py +245 -0
- acc_pretrain.py +213 -0
- get_args.py +64 -0
- log_analyze.py +58 -0
- mae_model.py +259 -0
- mae_ori_model.py +233 -0
- overall.pdf +3 -0
- pos_embed.py +96 -0
- requirements.txt +13 -0
- util/__pycache__/lr_decay.cpython-310.pyc +0 -0
- util/__pycache__/lr_decay.cpython-312.pyc +0 -0
- util/__pycache__/lr_sched.cpython-310.pyc +0 -0
- util/__pycache__/lr_sched.cpython-312.pyc +0 -0
- util/__pycache__/misc.cpython-310.pyc +0 -0
- util/__pycache__/misc.cpython-312.pyc +0 -0
- util/__pycache__/pos_embed.cpython-310.pyc +0 -0
- util/__pycache__/pos_embed.cpython-312.pyc +0 -0
- util/crop.py +42 -0
- util/datasets.py +63 -0
- util/lars.py +47 -0
- util/lr_decay.py +76 -0
- util/lr_sched.py +22 -0
- util/misc.py +353 -0
- util/pos_embed.py +96 -0
- vit_model.py +70 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
overall.pdf filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,23 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SUMMIT: A SAR Foundation Model with Multiple Auxiliary Tasks Enhanced Intrinsic Characteristics
|
| 2 |
+
[SUMMIT: A SAR Foundation Model with Multiple Auxiliary Tasks Enhanced Intrinsic Characteristics](https://doi.org/10.1016/j.jag.2025.104624)
|
| 3 |
+
## Overview
|
| 4 |
+
This repository hosts the official implementation of SUMMIT, a state-of-the-art (SOTA) foundation model tailored for Synthetic Aperture Radar (SAR) image understanding. Proposed in the paper "SUMMIT: A SAR foundation model with multiple auxiliary tasks enhanced intrinsic characteristics" (published in International Journal of Applied Earth Observation and Geoinformation, 2025), SUMMIT addresses the limitations of existing deep learning methods in SAR processing—such as neglecting SAR’s intrinsic physical properties and poor cross-task generalization—by integrating self-supervised auxiliary tasks and SAR-specific prior knowledge.
|
| 5 |
+
|
| 6 |
+
## Key Contributions
|
| 7 |
+
1. Large-Scale SAR Dataset (MuSID): Constructed the Multi-sensor SAR Image Dataset (MuSID) with over 560,000 SAR images, covering diverse scenarios (aircraft, ships, bridges, harbors), resolutions (0.1–25 m), and sensors (Gaofen-3, Sentinel-1, TerraSAR-X, etc.). It supports large-scale self-supervised pre-training for SAR foundation models.
|
| 8 |
+
|
| 9 |
+
2. Multi-Auxiliary-Task Pre-Training Framework: Proposed three self-supervised auxiliary tasks (SSATs) to enhance SAR feature learning. Masked Image Modeling (MIM) learns robust structural representations of SAR images. Self-Supervised Denoising mitigates speckle noise (a unique artifact of SAR imaging) and improves noise resistance. Spatial Scattering Feature (SSF) Enhancement preserves geometric consistency by extracting edge features (via the Canny algorithm) and scattering point features (via Harris corner detection).
|
| 10 |
+
|
| 11 |
+
3. Auxiliary Task Coordination Module (ATCM): Designed ATCM to dynamically balance and fuse the three auxiliary tasks. Unlike simple task aggregation, ATCM aligns each task with the optimal stage of the learning process (e.g., denoising at the input level, edge reconstruction at the output level), ensuring effective integration of SAR physical properties into feature learning.
|
| 12 |
+
|
| 13 |
+
## Model Architecture
|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
SUMMIT is built on a Vision Transformer (ViT). Pre-Training Stage — Input: the MuSID dataset (images resized to 448×448). Process: ATCM coordinates the MIM, denoising, and SSF enhancement tasks. The shared ViT encoder learns SAR-specific features, with a decoder optimizing the multi-task reconstruction loss.
|
| 18 |
+
|
| 19 |
+
## Environment Setup
|
| 20 |
+
```bash
|
| 21 |
+
conda create -n summit python=3.8
|
| 22 |
+
conda activate summit
|
| 23 |
+
pip install -r requirements.txt
|
SARdatasets.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.datasets import ImageFolder
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from PIL import ImageFile
|
| 5 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
from scipy.ndimage import convolve
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SARImageFolder(ImageFolder):
    """ImageFolder variant that yields a 3-channel SAR sample per image:
    channel 0 = grayscale intensity, channel 1 = Canny edge map,
    channel 2 = Harris corner response (scaled to roughly 8-bit range).
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        """Return ``(multi_channel_image, class_target)`` for sample *index*."""
        path, target = self.samples[index]

        image = cv2.imread(path)
        if image is None:
            # cv2.imread silently returns None for unreadable files; fail
            # loudly here instead of crashing inside cvtColor with a cryptic
            # OpenCV assertion.
            raise RuntimeError(f"Failed to read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = np.float32(image)

        # Edge map computed on the 8-bit view of the intensity channel.
        edges = cv2.Canny(image.astype(np.uint8), 200, 300)

        # Harris corner response, scaled towards the 8-bit range.
        corners = cv2.cornerHarris(image, 5, 3, 0.04)
        corners = corners * 255
        # Clip before the uint8 cast below: the Harris response can be
        # negative or exceed 255, and a bare astype(np.uint8) would wrap.
        corners = np.clip(corners, 0, 255)

        multi_channel_image = np.dstack((image, edges, corners))
        multi_channel_image = multi_channel_image.astype(np.uint8)
        multi_channel_image = Image.fromarray(multi_channel_image)

        if self.transform is not None:
            multi_channel_image = self.transform(multi_channel_image)

        return multi_channel_image, target
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class build_coed_SARImageFolder(ImageFolder):
    """ImageFolder that pairs an RGB view of each SAR image with a
    3-channel (intensity, Canny edges, Harris corners) reconstruction
    target built from the same file. The ImageFolder class label is
    discarded; the structural stack is returned in its place.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        path, _ = self.samples[index]

        # Network input: the image replicated to three channels.
        rgb_input = Image.open(path).convert('RGB')

        # Grayscale view used to derive the structural channels.
        gray = np.array(Image.open(path).convert('L'))

        edge_map = cv2.Canny(gray, 200, 300)
        corner_map = cv2.cornerHarris(gray, 5, 3, 0.04) * 255

        stacked = np.dstack((gray, edge_map, corner_map)).astype(np.uint8)
        stacked = Image.fromarray(stacked)

        if self.transform is not None:
            # Apply to the stacked target first, then the RGB input, to keep
            # the random-transform ordering identical across samples.
            stacked = self.transform(stacked)
            rgb_input = self.transform(rgb_input)

        return rgb_input, stacked
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Multi_task_SARImageFolder(ImageFolder):
    """ImageFolder for multi-task pre-training.

    Input  : RGB image whose first channel is corrupted with multiplicative
             gamma (speckle-like) noise.
    Target : clean 3-channel stack (intensity, Canny edges, Harris corners).
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def add_gamma_noise(self, image_np, looks):
        """Add gamma-distributed coherent speckle noise to an image.

        :param image_np: original image as a numpy array
        :param looks: equivalent number of looks (ENL); larger means less noise
        :return: noisy uint8 image
        """
        image_np = image_np.astype(np.float32)

        # Normalize to [0, 1]; guard against an all-zero image, which would
        # otherwise divide by zero and propagate NaNs.
        peak = np.max(image_np)
        if peak > 0:
            image_np = image_np / peak

        # Multiplicative gamma noise with unit mean (shape * scale == 1).
        gamma_noise = np.random.gamma(shape=looks, scale=1.0 / looks, size=image_np.shape)

        noisy_image = image_np * gamma_noise

        return np.clip(noisy_image * 255, 0, 255).astype(np.uint8)

    def add_gaussian_noise(self, image_np, snr_db):
        """Add white Gaussian noise at a given signal-to-noise ratio.

        :param image_np: original image as a numpy array
        :param snr_db: desired SNR in decibels
        :return: noisy uint8 image
        """
        signal_power = np.mean(image_np ** 2)
        snr = 10 ** (snr_db / 10)
        noise_power = signal_power / snr
        noise_sigma = np.sqrt(noise_power)

        # NOTE(review): the original saved/restored the *torch* RNG state and
        # re-seeded torch around this call, but the noise is drawn from
        # *numpy* — that dance was a no-op and has been removed.
        noise = np.random.normal(0, noise_sigma, image_np.shape)

        noisy_image = image_np + noise
        # Clip before casting: additive noise can push values outside
        # [0, 255], and a bare astype(np.uint8) would wrap around.
        return np.clip(noisy_image, 0, 255).astype(np.uint8)

    def log_transform(self, image_np):
        """Apply a scaled log(1 + x) transform (dynamic-range compression)."""
        image_np = image_np.astype(np.float32)
        c = 20.0
        return c * np.log1p(image_np)

    def __getitem__(self, index):
        path, target = self.samples[index]

        image_3ch = Image.open(path).convert('RGB')
        image_3ch_np = np.array(image_3ch)

        image_np = np.array(Image.open(path).convert('L'))

        edges = cv2.Canny(image_np, 200, 300)
        corners = cv2.cornerHarris(image_np, 5, 3, 0.04) * 255

        # Corrupt only the first channel of the network input with speckle.
        image_3ch_np[:, :, 0] = self.add_gamma_noise(image_3ch_np[:, :, 0], 30)
        image_3ch = Image.fromarray(image_3ch_np)

        multi_channel_image = np.dstack((image_np, edges, corners)).astype(np.uint8)
        multi_channel_image = Image.fromarray(multi_channel_image)

        if self.transform is not None:
            multi_channel_image = self.transform(multi_channel_image)
            image_3ch = self.transform(image_3ch)

        # The clean structural stack is the reconstruction target.
        return image_3ch, multi_channel_image
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class Multi_task_angel_SARImageFolder(ImageFolder):
    """ImageFolder for the rotation-augmented multi-task variant.

    Returned pair (note the order is swapped vs. Multi_task_SARImageFolder):
      first  : 4-channel structural stack (intensity, intensity, edges, corners)
      second : 4-channel image = RGB plus an inserted channel in which the
               densest corner region has been rotated by a random right angle;
               its first channel is log-transformed and Gaussian-noised.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def add_gaussian_noise(self, image_np, snr_db):
        """Add white Gaussian noise at the given SNR (in dB)."""
        signal_power = np.mean(image_np ** 2)
        snr = 10 ** (snr_db / 10)
        noise_power = signal_power / snr
        noise_sigma = np.sqrt(noise_power)
        noise = np.random.normal(0, noise_sigma, image_np.shape)
        noisy_image = image_np + noise
        return noisy_image.astype(np.uint8)

    def log_transform(self, image_np):
        """Apply a scaled log(1 + x) dynamic-range compression."""
        image_np = image_np.astype(np.float32)
        c = 20.0
        return c * np.log1p(image_np)

    def __getitem__(self, index):
        path, target = self.samples[index]

        image_3ch = Image.open(path).convert('RGB')
        image_3ch_np = np.array(image_3ch)

        image = Image.open(path).convert('L')
        image_np = np.array(image)

        edges = cv2.Canny(image_np, 200, 300)

        corners = cv2.cornerHarris(image_np, 5, 3, 0.04)
        corners = corners * 255

        # Locate the kernel_size x kernel_size window with the highest corner
        # density via a box-filter convolution.
        kernel_size = 50
        kernel = np.ones((kernel_size, kernel_size))
        density = convolve(corners, kernel, mode='constant', cval=0.0)

        center_y, center_x = np.unravel_index(np.argmax(density), density.shape)

        half_size = kernel_size // 2
        start_y = max(center_y - half_size, 0)
        end_y = min(center_y + half_size, corners.shape[0])
        start_x = max(center_x - half_size, 0)
        end_x = min(center_x + half_size, corners.shape[1])

        region = image_np[start_y:end_y, start_x:end_x]

        # Rotate the densest region in place by a random right angle.
        angle = random.choice([0, 90, 180, 270])
        M = cv2.getRotationMatrix2D((region.shape[1] // 2, region.shape[0] // 2), angle, 1)
        rotated_region = cv2.warpAffine(region, M, (region.shape[1], region.shape[0]))

        rotated_image = image_np.copy()
        rotated_image[start_y:end_y, start_x:end_x] = rotated_region

        # 4-channel input: RGB with the rotated image inserted as channel 1.
        image_4ch_np = np.insert(image_3ch_np, 1, rotated_image, axis=2)

        first_channel = image_3ch_np[:, :, 0]
        first_channel = self.log_transform(first_channel)
        noisy_first_channel = self.add_gaussian_noise(first_channel, 30)
        image_4ch_np[:, :, 0] = noisy_first_channel
        # BUG FIX: the original built the PIL image from image_3ch_np, which
        # silently dropped both the rotated channel and the noisy first
        # channel that were written into image_4ch_np above.
        image_4ch = Image.fromarray(image_4ch_np)

        multi_channel_image = np.dstack((image_np, image_np, edges, corners))
        multi_channel_image = multi_channel_image.astype(np.uint8)
        multi_channel_image = Image.fromarray(multi_channel_image)

        if self.transform is not None:
            multi_channel_image = self.transform(multi_channel_image)
            image_4ch = self.transform(image_4ch)

        return multi_channel_image, image_4ch
|
acc_pretrain.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
|
| 3 |
+
from get_args import get_args_pretrain
|
| 4 |
+
import mae_model
|
| 5 |
+
# import mae_ori_model
|
| 6 |
+
import numpy as np
|
| 7 |
+
import datetime
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
import sys
|
| 12 |
+
from typing import Iterable
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from accelerate import Accelerator
|
| 15 |
+
import torch
|
| 16 |
+
import torch.backends.cudnn as cudnn
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 19 |
+
import torchvision.transforms as transforms
|
| 20 |
+
import torchvision.datasets as datasets
|
| 21 |
+
import timm.optim.optim_factory as optim_factory
|
| 22 |
+
from SARdatasets import SARImageFolder, build_coed_SARImageFolder, Multi_task_SARImageFolder
|
| 23 |
+
import util.misc as misc
|
| 24 |
+
import util.lr_sched as lr_sched
|
| 25 |
+
from util.pos_embed import interpolate_pos_embed
|
| 26 |
+
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None,
                    accelerator=None):
    """Run one pre-training epoch and return the averaged metric dict.

    Gradients are accumulated over ``args.accum_iter`` mini-batches; the
    backward pass is delegated to ``accelerator``.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        samples = samples.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, channel_loss, _, _ = model(samples, target)  # , mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so accumulated gradients average (not sum) over the window.
        accelerator.backward(loss / accum_iter)

        if (data_iter_step + 1) % accum_iter == 0:
            # BUG FIX: the original zeroed the gradients here without ever
            # calling optimizer.step(), so no parameter update took place.
            optimizer.step()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def main(args):
    """Entry point: build dataset, model, and optimizer, then pre-train."""
    misc.init_distributed_mode(args)
    torch.multiprocessing.set_start_method('spawn', force=True)
    print('work_dir:{}'.format(os.path.realpath(__file__)))

    accelerator = Accelerator()
    # Accelerate chooses the device; the original also assigned
    # torch.device(args.device) first, which was dead code and is removed.
    device = accelerator.device

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # simple augmentation
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    dataset_train = Multi_task_SARImageFolder(root=args.data_path, transform=transform_train)
    print(dataset_train)

    # The original wrapped this in `if True: ... else: RandomSampler`; the
    # else branch was unreachable and has been removed.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # shuffle=False because the DistributedSampler already shuffles.
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train, batch_size=args.batch_size,
        num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, shuffle=False)

    model = mae_model.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

    # load pretrain checkpoint of Imagenet
    checkpoint = torch.load(args.finetune, map_location='cpu')
    print("Load pre-trained checkpoint from: %s" % args.finetune)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    # interpolate position embedding
    interpolate_pos_embed(model, checkpoint_model)
    # load pre-trained model
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    # NOTE(review): 80 is this project's reference batch size for linear lr
    # scaling (the standard MAE recipe uses 256) — confirm intentional.
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 80

    print("base lr: %.2e" % (args.lr * 80 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    model, optimizer, data_loader_train = accelerator.prepare(model, optimizer, data_loader_train)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args,
            accelerator=accelerator
        )
        if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, }

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
| 208 |
+
if __name__ == '__main__':
    # Build the CLI parser, parse arguments, ensure the output directory
    # exists, then launch pre-training.
    parser = get_args_pretrain()
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|
get_args.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_args_pretrain():
    """Build and return the argument parser for MAE pre-training.

    Returns the parser (not parsed args) so callers can extend it or pass
    their own argv.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=32, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    parser.add_argument('--finetune',
                        default='.',
                        help='path to the checkpoint to initialize from')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--input_size', default=448, type=int,
                        help='images input size')

    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches).')

    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=5e-8, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    # Dataset parameters
    # Original default was an f-string with no placeholders — plain literal.
    parser.add_argument('--data_path', default='/home/SARDatasets/SARfolder/', type=str,
                        help='dataset path')

    parser.add_argument('--output_dir', default='./output',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default=False,
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    return parser
|
log_analyze.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_log(path):
    """Parse ``log.txt`` (one JSON object per line) under directory *path*.

    :param path: directory containing a ``log.txt`` file
    :return: five parallel lists — epoch, train_lr, train_loss, test_loss,
             test_acc1 — one entry per log line
    """
    epoch = []
    train_lr = []
    train_loss = []
    test_loss = []
    test_acc1 = []
    # 'with' guarantees the file is closed even if a line fails to parse;
    # the original left the handle open on error and used raw_decode where
    # json.loads suffices.
    with open(os.path.join(path, 'log.txt'), encoding='utf-8') as log:
        for raw_line in log:
            record = json.loads(raw_line.strip('\n'))
            print(record)
            epoch.append(record['epoch'])
            train_lr.append(record['train_lr'])
            train_loss.append(record['train_loss'])
            test_loss.append(record['test_loss'])
            test_acc1.append(record['test_acc1'])
    return epoch, train_lr, train_loss, test_loss, test_acc1
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Compare two fine-tuning runs: multi-task pre-trained vs. not pre-trained.
# NOTE(review): both paths currently point at the same directory, so the two
# curves will coincide — point path_noise at the second run's output dir.
path = 'output_dir_finetune/'
path_noise = 'output_dir_finetune/'
epoch, train_lr, train_loss, test_loss, test_acc1 = get_log(path)
epoch_noise, train_lr_noise, train_loss_noise, test_loss_noise, test_acc1_noise = get_log(path_noise)

# Plot test-accuracy curves.
plt.figure()
plt.plot(test_acc1, color='r', label='test accuracy of multi-task pre-trained')
plt.plot(test_acc1_noise, color='b', label='test accuracy of none pre-trained')
plt.xlabel('Epoch')
plt.legend()
# BUG FIX: save before show — plt.show() closes the current figure in
# interactive backends, so the original wrote a blank image.
plt.savefig(os.path.join(path, 'acd_acc.png'))
plt.show()

# Plot training-loss curves.
plt.figure()
plt.plot(train_loss, color='r', label='train loss of multi-task pre-trained')
plt.plot(train_loss_noise, color='b', label='train loss of none pre-trained')
plt.xlabel('Epoch')
plt.legend()
plt.savefig(os.path.join(path, 'acd_loss.png'))
plt.show()
|
mae_model.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
| 4 |
+
from pos_embed import get_2d_sincos_pos_embed
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
# Decoder depth shared by MAEViT's default arguments and the
# mae_vit_base_patch16 factory below.
dd = 12
|
| 8 |
+
class MAEViT(nn.Module):
    """Masked Autoencoder with a Vision Transformer backbone (after He et al., MAE).

    Pipeline: patch-embed the image -> add fixed sin-cos position embeddings ->
    randomly mask a fraction of patch tokens -> run the visible tokens (plus a
    cls token) through the ViT encoder -> project to decoder width, insert
    learned mask tokens -> shallow ViT decoder -> per-patch pixel prediction.

    Unlike the reference MAE, `forward` reconstructs a caller-supplied `target`
    image (which may differ from the encoder input `imgs`) and computes a
    channel-weighted reconstruction loss (see `forward_loss_separately`).
    """

    def __init__(self, img_size=448, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=dd, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """Build encoder and decoder.

        Args:
            img_size: input image side length (square images assumed).
            patch_size: side length of each square patch.
            in_chans: number of input channels.
            embed_dim, depth, num_heads: encoder width / #blocks / #heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder
                width / #blocks / #heads (depth defaults to module-level `dd`).
            mlp_ratio: hidden-dim ratio of the MLP inside each Transformer block.
            norm_layer: normalization layer constructor.
            norm_pix_loss: if True, per-patch-normalize the target (zero mean,
                unit variance) before computing the reconstruction loss.
        """
        super(MAEViT, self).__init__()
        # MAE Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-learned) sin-cos position embedding; filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # MAE Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """One-time init: sin-cos position embeddings, Linear-style patch
        projection, small-normal tokens, then Xavier init of all Linear layers."""
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init hook: Xavier-uniform Linear weights with zero bias;
        unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        # NOTE(review): the channel count is hard-coded to 3 here (while
        # unpatchify infers it), so patchify breaks for in_chans != 3 — confirm
        # all call sites use 3-channel inputs.
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)

        Inverse of patchify; the channel count is inferred from the last
        dimension, so this also works for predictions with != 3 channels.
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        hid_chans = int(x.shape[2]/(p**2))
        x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: [N, len_keep, D] kept tokens.
            mask: [N, L] binary mask in the ORIGINAL token order (0=keep, 1=removed).
            ids_restore: [N, L] indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Encode an image: patchify, add position embeddings, mask, then run
        the visible tokens (with a cls token prepended) through the encoder."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Decode: insert mask tokens at the masked positions (restoring the
        original patch order), run the decoder, predict per-patch pixels."""
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x


    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,

        Standard MAE MSE loss, averaged over the masked patches only.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward_loss_separately(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,

        Channel-weighted MSE over masked patches.  The fixed weights
        [1, 0.5, 0.5] presumably emphasize the first channel over the two
        auxiliary channels — TODO confirm against the training data layout.
        Also returns the unweighted per-channel mean loss for logging.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        channel_weights = torch.tensor([1, 0.5, 0.5], device=pred.device)

        loss = (pred - target) ** 2
        # Regroup per-pixel errors by channel: [N, L, p*p, 3].
        loss = loss.view(loss.shape[0], loss.shape[1], -1, 3)

        channel_loss_mean = loss.mean(dim=[0, 1, 2])
        # print(f"Channel Loss Mean: {channel_loss_mean}")

        loss = loss * channel_weights
        loss = loss.sum(dim=-1)
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()

        return loss, channel_loss_mean

    def forward(self, imgs, target, mask_ratio=0.75):
        """Full MAE pass.

        Args:
            imgs: encoder input, [N, 3, H, W].
            target: image whose patches are the reconstruction target (may
                differ from `imgs`, e.g. for denoising-style training).
            mask_ratio: fraction of patches to mask.

        Returns:
            (weighted loss, per-channel mean loss, pred [N, L, p*p*3], mask [N, L]).
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio=mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss, channel_loss_mean = self.forward_loss_separately(target, pred, mask)
        return loss, channel_loss_mean, pred, mask
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def mae_vit_base_patch16(**kwargs):
    """ViT-Base/16 MAE: 768-dim encoder, 12 blocks, 12 heads; 512-dim decoder
    of depth `dd`. Extra keyword arguments are forwarded to MAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAEViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=dd,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 245 |
+
|
| 246 |
+
def mae_vit_large_patch16(**kwargs):
    """ViT-Large/16 MAE: 1024-dim encoder, 24 blocks, 16 heads; 512-dim,
    8-block decoder. Extra keyword arguments are forwarded to MAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAEViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def mae_vit_huge_patch14(**kwargs):
    """ViT-Huge/14 MAE: 1280-dim encoder, 32 blocks, 16 heads; 512-dim,
    8-block decoder. Extra keyword arguments are forwarded to MAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAEViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
mae_ori_model.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.distributions as dist
|
| 4 |
+
import torchvision.transforms
|
| 5 |
+
import numpy as np
|
| 6 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
| 7 |
+
from pos_embed import get_2d_sincos_pos_embed
|
| 8 |
+
from scipy.stats import gamma, lognorm, expon
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DenoiseMAEViT(nn.Module):
    """Standard Masked Autoencoder with a ViT backbone (He et al., MAE).

    Despite the name, this class implements the original MAE objective: the
    reconstruction target is the input image itself.  Structure: patch-embed ->
    fixed sin-cos position embeddings -> random masking -> ViT encoder over
    visible tokens (+ cls token) -> projection + mask tokens -> shallow ViT
    decoder -> per-patch pixel prediction and masked-patch MSE loss.
    """

    def __init__(self, img_size=448, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """Build encoder and decoder.

        Args:
            img_size: input image side length (square images assumed).
            patch_size: side length of each square patch.
            in_chans: number of input channels.
            embed_dim, depth, num_heads: encoder width / #blocks / #heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder
                width / #blocks / #heads.
            mlp_ratio: hidden-dim ratio of the MLP inside each Transformer block.
            norm_layer: normalization layer constructor.
            norm_pix_loss: if True, per-patch-normalize the target before the loss.
        """
        super(DenoiseMAEViT, self).__init__()
        # MAE Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-learned) sin-cos position embedding; filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # MAE Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None,
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """One-time init: sin-cos position embeddings, Linear-style patch
        projection, small-normal tokens, then Xavier init of all Linear layers."""
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init hook: Xavier-uniform Linear weights with zero bias;
        unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        # NOTE(review): channel count is hard-coded to 3 here (unpatchify
        # infers it), so patchify breaks for in_chans != 3.
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)

        Inverse of patchify; the channel count is inferred from the last
        dimension.
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        hid_chans = int(x.shape[2]/(p**2))
        x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: [N, len_keep, D] kept tokens.
            mask: [N, L] binary mask in the ORIGINAL token order (0=keep, 1=removed).
            ids_restore: [N, L] indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Encode an image: patchify, add position embeddings, mask, then run
        the visible tokens (with a cls token prepended) through the encoder."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Decode: insert mask tokens at the masked positions (restoring the
        original patch order), run the decoder, predict per-patch pixels."""
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,

        Standard MAE MSE loss, averaged over the masked patches only.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Full MAE pass: encode with masking, decode, compute masked MSE
        against the input itself.  Returns (loss, pred, mask)."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def mae_vit_base_patch16(**kwargs):
    """ViT-Base/16 denoise-MAE: 768-dim encoder, 12 blocks, 12 heads;
    512-dim, 8-block decoder. Extra kwargs are forwarded to DenoiseMAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return DenoiseMAEViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def mae_vit_large_patch16(**kwargs):
    """ViT-Large/16 denoise-MAE: 1024-dim encoder, 24 blocks, 16 heads;
    512-dim, 8-block decoder. Extra kwargs are forwarded to DenoiseMAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return DenoiseMAEViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def mae_vit_huge_patch16(**kwargs):
    """ViT-Huge/16 denoise-MAE: 1280-dim encoder, 32 blocks, 16 heads;
    512-dim, 8-block decoder. Extra kwargs are forwarded to DenoiseMAEViT."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return DenoiseMAEViT(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 233 |
+
|
overall.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f990a07117a667a2972cbf8c5c609e43b1fecca305ebf4efd4f7e6fa22b35b6
|
| 3 |
+
size 229409
|
pos_embed.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Position embedding utils
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# 2D sine-cosine position embedding
|
| 16 |
+
# References:
|
| 17 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 18 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the reference MAE layout.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        # Prepend an all-zero row for the cls token.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin-cos embeddings of the grid's two axes -> (H*W, D)."""
    assert embed_dim % 2 == 0

    # use half of dimensions to encode each axis
    half = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in (grid[0], grid[1])]  # 2 x (H*W, D/2)
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder 1/10000^(2i/D), i = 0..D/2-1.
    freqs = 1. / 10000 ** (np.arange(half, dtype=np.float32) / half)  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2), outer product

    # First half sines, second half cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# --------------------------------------------------------
|
| 71 |
+
# Interpolate position embeddings for high-resolution
|
| 72 |
+
# References:
|
| 73 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 74 |
+
# --------------------------------------------------------
|
| 75 |
+
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's patch position embeddings to the model's grid.

    Extra tokens (cls/dist) are kept unchanged; the patch-grid portion is
    bicubically interpolated from the checkpoint grid to the model grid.
    Mutates checkpoint_model['pos_embed'] in place; no-op if the key is
    missing or the grids already match.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    ckpt_pe = checkpoint_model['pos_embed']
    embedding_size = ckpt_pe.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint / new position-embedding grids
    orig_size = int((ckpt_pe.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return
    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = ckpt_pe[:, :num_extra_tokens]
    # only the position tokens are interpolated
    grid_tokens = ckpt_pe[:, num_extra_tokens:]
    grid_tokens = grid_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, grid_tokens), dim=1)
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.28.0
|
| 2 |
+
matplotlib==3.8.4
|
| 3 |
+
numpy==1.26.4  # torch 2.2.1 / torchvision 0.17.1 wheels are built against NumPy 1.x; NumPy 2.x is ABI-incompatible with them
|
| 4 |
+
opencv_python==4.9.0.80
|
| 5 |
+
Pillow==11.3.0
|
| 6 |
+
Requests==2.32.5
|
| 7 |
+
scikit_learn==1.4.2
|
| 8 |
+
scipy==1.16.2
|
| 9 |
+
submitit==1.5.3
|
| 10 |
+
timm==1.0.20
|
| 11 |
+
torch==2.2.1+cu118
|
| 12 |
+
torchvision==0.17.1+cu118
|
| 13 |
+
tqdm==4.66.2
|
util/__pycache__/lr_decay.cpython-310.pyc
ADDED
|
Binary file (1.61 kB). View file
|
|
|
util/__pycache__/lr_decay.cpython-312.pyc
ADDED
|
Binary file (2.34 kB). View file
|
|
|
util/__pycache__/lr_sched.cpython-310.pyc
ADDED
|
Binary file (611 Bytes). View file
|
|
|
util/__pycache__/lr_sched.cpython-312.pyc
ADDED
|
Binary file (1.07 kB). View file
|
|
|
util/__pycache__/misc.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
util/__pycache__/misc.cpython-312.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
util/__pycache__/pos_embed.cpython-310.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
util/__pycache__/pos_embed.cpython-312.pyc
ADDED
|
Binary file (4.03 kB). View file
|
|
|
util/crop.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from torchvision.transforms import functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        """Sample a crop as (top, left, height, width) in a single draw.

        Unlike torchvision's rejection-sampling loop, one (area, aspect-ratio)
        pair is drawn and then clamped to the image bounds, mirroring the
        TF/TPU behaviour referenced in the class docstring.
        """
        # F._get_image_size was removed in torchvision >= 0.17 (the version
        # pinned in requirements.txt); the public F.get_image_size returns
        # [width, height] for both PIL images and tensors.
        width, height = F.get_image_size(img)
        area = height * width

        # One draw of the target area and a log-uniform aspect ratio.
        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Clamp instead of resampling so exactly one draw is ever made.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
|
util/datasets.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import PIL
|
| 13 |
+
|
| 14 |
+
from torchvision import datasets, transforms
|
| 15 |
+
|
| 16 |
+
from timm.data import create_transform
|
| 17 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_dataset(is_train, args):
    """Create an ImageFolder dataset for the train or val split.

    Args:
        is_train: selects the ``train`` subdirectory when truthy, ``val`` otherwise.
        args: namespace providing ``data_path`` plus the transform options
            consumed by ``build_transform``.

    Returns:
        A ``torchvision.datasets.ImageFolder`` with the appropriate transform.
    """
    split = 'train' if is_train else 'val'
    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split),
        transform=build_transform(is_train, args),
    )
    print(dataset)
    return dataset
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_transform(is_train, args):
    """Build the train (timm augmentation pipeline) or eval transform.

    Args:
        is_train: when truthy, returns timm's ``create_transform`` pipeline
            (random crop, color jitter, auto-augment, random erasing).
        args: namespace providing ``input_size``, ``color_jitter``, ``aa``,
            ``reprob``, ``remode`` and ``recount``.

    Returns:
        A callable transform mapping a PIL image to a normalized tensor.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        return transform

    # eval transform: resize keeping the standard 224/256 crop ratio,
    # center crop, then normalize with ImageNet statistics.
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    # PIL.Image.BICUBIC was removed in Pillow >= 10 (requirements.txt pins
    # Pillow 11.3.0); torchvision's InterpolationMode enum is the supported
    # way to request bicubic resampling.
    t.append(
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
|
util/lars.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# LARS optimizer, implementation from MoCo v3:
|
| 8 |
+
# https://github.com/facebookresearch/moco-v3
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        super().__init__(params, dict(lr=lr,
                                      weight_decay=weight_decay,
                                      momentum=momentum,
                                      trust_coefficient=trust_coefficient))

    @torch.no_grad()
    def step(self):
        """Apply one LARS update to every parameter that has a gradient."""
        for group in self.param_groups:
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                # Layer-wise adaptation is applied only to tensors with >= 2
                # dimensions; biases and normalization gamma/beta get the
                # plain momentum-SGD update without weight decay.
                if param.ndim > 1:
                    grad = grad.add(param, alpha=group['weight_decay'])
                    p_norm = torch.norm(param)
                    g_norm = torch.norm(grad)
                    one = torch.ones_like(p_norm)
                    # Trust ratio falls back to 1 whenever either norm is zero.
                    trust = torch.where(
                        p_norm > 0.,
                        torch.where(g_norm > 0,
                                    group['trust_coefficient'] * p_norm / g_norm,
                                    one),
                        one)
                    grad = grad.mul(trust)

                state = self.state[param]
                if 'mu' not in state:
                    state['mu'] = torch.zeros_like(param)
                buf = state['mu']
                buf.mul_(group['momentum']).add_(grad)
                param.add_(buf, alpha=-group['lr'])
|
util/lr_decay.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# ELECTRA https://github.com/google-research/electra
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    group_names = {}  # group name -> metadata + parameter *names* (debugging aid)
    groups = {}       # group name -> metadata + parameter tensors (returned)

    num_layers = len(model.blocks) + 1

    # Scale per layer id: earliest layers (embeddings) decay the most,
    # the final layer (head) keeps the full learning rate.
    scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            decay_tag, decay_val = "no_decay", 0.
        else:
            decay_tag, decay_val = "decay", weight_decay

        layer_id = get_layer_id_for_vit(name, num_layers)
        key = "layer_%d_%s" % (layer_id, decay_tag)

        if key not in groups:
            group_names[key] = {
                "lr_scale": scales[layer_id],
                "weight_decay": decay_val,
                "params": [],
            }
            groups[key] = {
                "lr_scale": scales[layer_id],
                "weight_decay": decay_val,
                "params": [],
            }

        group_names[key]["params"].append(name)
        groups[key]["params"].append(param)

    # print("parameter groups: \n%s" % json.dumps(group_names, indent=2))

    return list(groups.values())
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    # Embedding-level parameters (class token, position embedding, patch
    # projection) all belong to layer 0.
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    # Transformer blocks: "blocks.<idx>.<...>" maps to layer idx + 1.
    if name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    # Everything else (final norm, head, ...) belongs to the last layer.
    return num_layers
|
util/lr_sched.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        # Linear warmup from 0 up to the base learning rate.
        lr = args.blr * epoch / args.warmup_epochs
    else:
        # Half-cycle cosine decay from blr down to min_lr.
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.min_lr + (args.blr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        # Groups produced by layer-wise lr decay carry an "lr_scale" factor.
        group["lr"] = lr * group.get("lr_scale", 1.0)
    return lr
|
util/misc.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import builtins
|
| 13 |
+
import datetime
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from collections import defaultdict, deque
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
from PIL import ImageFile
|
| 22 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
dist_on_itp = False
|
| 26 |
+
|
| 27 |
+
# from torch._six import inf
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)  # last `window_size` raw values
        self.total = 0.0                        # running sum over all updates
        self.count = 0                          # number of samples seen
        self.fmt = fmt if fmt is not None else "{median:.4f} ({global_avg:.4f})"

    def update(self, value, n=1):
        """Record `value`, counted `n` times toward the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        # Median over the smoothing window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the smoothing window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Guard against division by zero before the first update.
        return self.total / self.count if self.count else 0

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded value.
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class MetricLogger(object):
    """Collects named SmoothedValue meters and prints periodic progress logs."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names lazily get a default SmoothedValue.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted via .item()."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow attribute-style access to meters, e.g. `logger.loss`.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        # Render every meter as "name: <formatted value>".
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize the count/total of every meter across workers."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. with a custom format)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Generator that yields from `iterable` while printing progress
        (ETA, meters, per-iteration and data-loading time, GPU memory)
        every `print_freq` iterations. Requires `len(iterable)`.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        # Per-iteration timing meters; `data_time` measures time spent
        # waiting for the next batch, `iter_time` the full step.
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                # ETA extrapolated from the global average iteration time.
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if len(iterable) == 0:
            print('Total time: {} ({:.4f} s / it)'.format(total_time_str, 0))
        else:
            print('{} Total time: {} ({:.4f} s / it)'.format(
                header, total_time_str, total_time / len(iterable)))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    # Keep a handle to the real print so the wrapper can delegate to it.
    builtin_print = builtins.print

    def print(*args, **kwargs):
        # `force=True` lets any rank print; very large jobs (>8 ranks)
        # always print so progress remains visible per node.
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    # NOTE: replaces the global builtins.print for the whole process.
    builtins.print = print
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is usable and initialized."""
    return dist.is_available() and dist.is_initialized()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_world_size():
    """Number of distributed workers; 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_rank():
    """Rank of this process; 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def is_main_process():
    """True on rank 0 (and in single-process runs)."""
    rank = get_rank()
    return rank == 0
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def save_on_master(*args, **kwargs):
    """Forward to torch.save, but only on the main (rank-0) process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    Mutates `args` in place: sets `rank`, `gpu`, `distributed`, `dist_url`,
    `dist_backend` and (on the OMPI/torchrun paths) `world_size`, then
    initializes the NCCL process group and installs rank-aware printing.
    """
    if dist_on_itp:
        # OpenMPI-style launch: ranks come from OMPI_* variables; mirror them
        # into the standard torch.distributed environment variables.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch style environment.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM launch. NOTE(review): args.world_size is NOT set on this
        # path — presumably supplied by the argument parser; verify callers.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_url = 'env://'
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # Silence print on all ranks except 0 from here on.
    setup_for_distributed(args.rank == 0)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class NativeScalerWithGradNormCount:
    """Wrapper around torch.cuda.amp.GradScaler that also reports the gradient
    norm and supports gradient accumulation via the `update_grad` flag."""

    # Key under which this scaler's state is stored in checkpoints.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backward on the scaled loss; optionally clip, step and update.

        Returns the (post-unscale) gradient norm when `update_grad` is True,
        otherwise None (pure accumulation step: gradients are kept, no step).
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # No clipping: still unscale so the reported norm is in
                # true (unscaled) gradient units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total gradient norm over `parameters` (those that have a .grad),
    like torch.nn.utils.clip_grad_norm_ but without clipping."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == float('inf'):
        # Infinity norm: the largest absolute entry across all gradients.
        return max(g.abs().max().to(device) for g in grads)
    # The p-norm of the per-tensor p-norms equals the global p-norm.
    per_tensor = [torch.norm(g, norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_tensor), norm_type)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Checkpoint model/optimizer/scaler state for `epoch` into args.output_dir.

    With a loss scaler, writes a regular torch checkpoint (rank 0 only via
    save_on_master); otherwise defers to the model's own `save_checkpoint`
    (deepspeed-style interface).
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            # Save the unwrapped (non-DDP) weights so they can be reloaded
            # in single-process runs.
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume model weights — and, unless evaluating, optimizer/epoch/scaler —
    from `args.resume` (a local path or an https URL). No-op when unset."""
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        # Restore training state only when not in eval-only mode.
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def all_reduce_mean(x):
    """Average a scalar across all workers; identity when single-process."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
|
util/pos_embed.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Position embedding utils
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# 2D sine-cosine position embedding
|
| 16 |
+
# References:
|
| 17 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 18 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    coords_h = np.arange(grid_size, dtype=np.float32)
    coords_w = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the original MAE layout.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        # Prepend an all-zero embedding for the class token.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the grid's h and w coordinates."""
    assert embed_dim % 2 == 0

    # Half of the dimensions encode the row index, half the column index.
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # The np.float alias was removed from NumPy (requirements.txt pins
    # numpy 2.x, where it raises AttributeError); use the concrete
    # float64 dtype, which is what np.float aliased.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# --------------------------------------------------------
|
| 71 |
+
# Interpolate position embeddings for high-resolution
|
| 72 |
+
# References:
|
| 73 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 74 |
+
# --------------------------------------------------------
|
| 75 |
+
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's position embedding to this model's patch grid.

    Bicubically interpolates the patch-position tokens when the checkpoint
    was trained at a different resolution; extra tokens (class/dist) are
    kept unchanged. Mutates `checkpoint_model['pos_embed']` in place.
    Assumes a square patch grid on both sides.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # Tokens beyond the patch grid (class token, possibly dist token).
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, L, D) -> (1, D, H, W) so F.interpolate treats D as channels.
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # Back to (1, H'*W', D) token layout.
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
|
vit_model.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import timm
|
| 4 |
+
from timm.models.vision_transformer import PatchEmbed
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class vit(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer encoder set up for linear probing.

    The whole backbone is frozen at construction time and only the (freshly
    re-initialized) classification head remains trainable. With
    ``global_pool=True`` the cls-token readout is replaced by mean pooling
    over patch tokens followed by a dedicated ``fc_norm`` layer.

    Fixes vs. the original snippet:
      * ``**kwargs`` are now forwarded to the parent constructor — the
        previous ``super().__init__()`` dropped them, so the backbone was
        always built with timm's defaults regardless of the requested
        patch size / embed_dim / depth (vit_large / vit_huge presets were
        silently ignored).
      * ``self.norm`` is only deleted in the global-pool branch;
        ``forward_features`` still needs it on the cls-token path.
    """

    def __init__(self, global_pool=False, **kwargs):
        # Build the backbone at the requested architecture.
        super(vit, self).__init__(**kwargs)
        self.global_pool = global_pool
        embed_dim = kwargs['embed_dim']
        num_classes = kwargs['num_classes']
        # Fresh, randomly initialized classification head.
        self.head = nn.Linear(embed_dim, num_classes, bias=True)
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            self.fc_norm = norm_layer(embed_dim)
            # fc_norm replaces the final encoder norm after pooling, so the
            # original norm is dropped only on this branch.
            del self.norm

        # Linear probing: freeze everything, then unfreeze just the head.
        for param in self.parameters():
            param.requires_grad = False
        for param in self.head.parameters():
            param.requires_grad = True

    def forward_features(self, x):
        """Encode images to a single feature vector per sample.

        Returns the mean-pooled patch features (after ``fc_norm``) when
        ``global_pool`` is set, otherwise the normalized cls token.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Average over patch tokens only (skip the cls token).
            x = x[:, 1:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def vit_base_patch16(**kwargs):
    """Build a ViT-Base/16 encoder (224px input, 768-dim, 12 blocks, 12 heads).

    Extra keyword arguments (e.g. ``num_classes``, ``global_pool``) are
    forwarded to :class:`vit`.
    """
    arch = dict(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit(**arch, **kwargs)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def vit_large_patch16(**kwargs):
    """Build a ViT-Large/16 encoder (1024-dim, 24 blocks, 16 heads).

    Extra keyword arguments are forwarded to :class:`vit`.
    """
    arch = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit(**arch, **kwargs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def vit_huge_patch14(**kwargs):
    """Build a ViT-Huge/14 encoder (1280-dim, 32 blocks, 16 heads).

    Extra keyword arguments are forwarded to :class:`vit`.
    """
    arch = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit(**arch, **kwargs)
|