Delete train.py
Browse files
train.py
DELETED
|
@@ -1,478 +0,0 @@
|
|
| 1 |
-
"""Training entry point: SiT diffusion training with representation alignment."""
import argparse
import copy
from copy import deepcopy
import logging
import os
import torch
# Cluster-local cache locations; NCCL timeout raised for long multi-node runs.
# NOTE(review): these hard-coded /slurm-files/jdy/ paths are machine-specific.
torch.hub.set_dir(r"/slurm-files/jdy/hub/")
os.environ["NCCL_TIMEOUT"] = "9000000000"
os.environ["HF_DATASETS_CACHE"] = "/slurm-files/jdy/cache/"
os.environ["HF_HOME"] = "/slurm-files/jdy/cache/"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/slurm-files/jdy/cache/"
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"

from pathlib import Path
from collections import OrderedDict
import json

import numpy as np
from torchvision.utils import save_image
import torch.nn.functional as F
import torch.utils.checkpoint
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

# Project-local modules: model zoo, alignment loss, encoder loaders, dataset.
from models.sit import SiT_models
from loss import SILoss
from utils import load_encoders

from dataset import CustomDataset
from diffusers.models import AutoencoderKL
# import wandb_utils
import math
from torchvision.utils import make_grid
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from PIL import Image
# Multi-process-aware logger from accelerate (logs only on the main process).
logger = get_logger(__name__)

# CLIP preprocessing statistics (per-channel mean/std in [0, 1] space).
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def preprocess_raw_image(x, enc_type):
    """Map a raw [0, 255] image batch into the input space of `enc_type`.

    Rescales to [0, 1], applies the per-encoder channel normalization, and
    bicubically resizes to 224 px for encoder families that expect it.
    Encoder types matching none of the known families pass through unchanged.
    """
    def _to_unit(t):
        # [0, 255] -> [0, 1]
        return t / 255.

    def _resize224(t):
        return torch.nn.functional.interpolate(t, 224, mode='bicubic')

    if 'clip' in enc_type:
        # CLIP: resize first, then normalize with CLIP statistics.
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(_resize224(_to_unit(x)))
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(_to_unit(x))
    elif 'dinov2' in enc_type:
        # DINOv2: normalize first, then resize to its 224-px input.
        x = _resize224(Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(_to_unit(x)))
    elif 'dinov1' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(_to_unit(x))
    elif 'jepa' in enc_type:
        # I-JEPA: normalize, then resize to 224 px.
        x = _resize224(Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(_to_unit(x)))

    return x
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def array2grid(x):
    """Tile a batch of images into a single uint8 HWC numpy grid.

    The grid is roughly square (`round(sqrt(batch))` images per row); values
    are clamped to [0, 1] and quantized to [0, 255] with round-half-up.
    """
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=per_row, value_range=(0, 1))
    # Quantize: scale to [0, 255], add 0.5 for rounding, clamp, CHW -> HWC.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    return grid.to('cpu', torch.uint8).numpy()
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a reparameterized sample from the Gaussian packed in `moments`.

    `moments` concatenates mean and std halves along dim 1; the sample
    `mean + std * eps` is then affinely mapped by `latents_scale` and
    `latents_bias`.
    """
    mu, sigma = torch.chunk(moments, 2, dim=1)
    sample = torch.randn_like(mu) * sigma + mu
    return sample * latents_scale + latents_bias
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Move `ema_model`'s parameters one EMA step toward `model`'s.

    Each EMA parameter becomes `decay * ema + (1 - decay) * current`.
    A leading "module." in parameter names (from DDP wrapping) is stripped
    so wrapped and unwrapped models line up.
    """
    ema_params = dict(ema_model.named_parameters())

    for raw_name, src_param in model.named_parameters():
        key = raw_name.replace("module.", "")
        # TODO: Consider applying only to params that require_grad to avoid
        # small numerical changes of pos_embed
        ema_params[key].mul_(decay).add_(src_param.data, alpha=1 - decay)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def create_logger(logging_dir):
|
| 102 |
-
"""
|
| 103 |
-
Create a logger that writes to a log file and stdout.
|
| 104 |
-
"""
|
| 105 |
-
logging.basicConfig(
|
| 106 |
-
level=logging.INFO,
|
| 107 |
-
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 108 |
-
datefmt='%Y-%m-%d %H:%M:%S',
|
| 109 |
-
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 110 |
-
)
|
| 111 |
-
logger = logging.getLogger(__name__)
|
| 112 |
-
return logger
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking on every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad_(flag)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
#################################################################################
|
| 124 |
-
# Training Loop #
|
| 125 |
-
#################################################################################
|
| 126 |
-
|
| 127 |
-
def main(args):
    """Run SiT training with representation alignment (REPA-style).

    Builds an `accelerate` distributed context, the SiT model plus an EMA
    copy, a frozen SD-VAE, and optional pretrained visual encoders whose
    features drive the projection loss; then trains with periodic
    checkpointing and image sampling.

    Args:
        args: Namespace from `parse_args()`.

    Side effects: creates `args.output_dir/args.exp_name/` with args.json,
    checkpoints/, samples/, and log.txt.
    """
    # set accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        project_config=accelerator_project_config,
    )

    # NOTE: `logger` is only rebound on the main process; every later use is
    # guarded by `accelerator.is_main_process` (non-main ranks would otherwise
    # fall back to the module-level logger).
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = vars(args)
        # Save to a JSON file
        json_dir = os.path.join(save_dir, "args.json")
        with open(json_dir, 'w') as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    device = accelerator.device
    if torch.backends.mps.is_available():
        # Apple-silicon backend does not support native AMP here.
        accelerator.native_amp = False
    if args.seed is not None:
        # Per-rank offset keeps data/noise streams distinct across processes.
        set_seed(args.seed + accelerator.process_index)

    # Create model:
    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8  # SD-VAE downsamples by 8x

    if args.enc_type != 'None':
        encoders, encoder_types, architectures = load_encoders(args.enc_type, device)
    else:
        encoders, encoder_types, architectures = [None], [None], [None]
    # Feature width of each alignment encoder (0 when alignment is disabled).
    z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0]
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = (args.cfg_prob > 0),
        z_dims = z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs
    )


    model = model.to(device)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-mse").to(device)
    requires_grad(ema, False)  # EMA weights are updated manually, never by grad

    # Standard SD-VAE latent scaling; bias is zero for this VAE.
    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)

    # create loss function
    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data:
    train_dataset = CustomDataset(args.data_dir)
    local_batch_size = int(args.batch_size)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # resume:
    global_step = 0
    if args.resume_step > 0:
        ckpt_name = str(args.resume_step).zfill(7) +'.pt'
        ckpt = torch.load(
            f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
            map_location='cpu',
        )
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    if accelerator.is_main_process:
        logger.info(f"Starting training experiment: {args.exp_name}")

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # Labels to condition the model with (feel free to change):
    # Fixed visualization batch: 16 images split across processes.
    sample_batch_size = 16 // accelerator.num_processes
    # Dataset yields (raw_image, vae_moments, label); take moments for GT latents.
    _, gt_xs, _ = next(iter(train_dataloader))
    gt_xs = gt_xs[:sample_batch_size]
    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
    )
    ys = torch.randint(1000, size=(sample_batch_size,), device=device)
    ys = ys.to(device)
    # Create sampling noise:
    n = ys.size(0)
    xT = torch.randn((n, 4, latent_size, latent_size), device=device)

    for epoch in range(args.epochs):
        model.train()
        for raw_image, x, y in train_dataloader:
            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device)
            y = y.to(device)
            z = None
            if args.legacy:
                # In our early experiments, we accidentally apply label dropping twice:
                # once in train.py and once in sit.py.
                # We keep this option for exact reproducibility with previous runs.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            with torch.no_grad():
                # Sample VAE latents and extract frozen-encoder targets (no grads).
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
                zs = []
                import time
                start = time.perf_counter()
                with accelerator.autocast():

                    for encoder, encoder_type, arch in zip(encoders, encoder_types, architectures):
                        raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                        z = encoder.forward_features(raw_image_)
                        # NOTE(review): `z = z = z[:, 1:]` is a harmless duplicated
                        # assignment; it drops the first token (presumably CLS — verify).
                        if 'mocov3' in encoder_type: z = z = z[:, 1:]
                        if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
                        zs.append(z)

                end = time.perf_counter()
                # NOTE(review): perf_counter returns seconds, so despite the name
                # this value is in seconds, not milliseconds.
                elapsed_ms = end - start

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels)
                loss, proj_loss = loss_fn(model, x, model_kwargs, zs=zs)
                loss_mean = loss.mean()
                proj_loss_mean = proj_loss.mean()
                # Total objective: denoising loss + weighted projection (alignment) loss.
                loss = loss_mean + proj_loss_mean * args.proj_coeff

                ## optimization
                start = time.perf_counter()
                accelerator.backward(loss)
                end = time.perf_counter()
                bpt = end - start  # backward-pass wall time (seconds)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                update_ema(ema, model)  # change ema function

            ### enter
            # Count a "global step" only when gradients actually synced
            # (i.e. once per accumulation cycle).
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
            if global_step % args.checkpointing_steps == 0 and global_step > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        # NOTE(review): `model.module` assumes a DDP-style wrapper
                        # from accelerator.prepare — fails for unwrapped models.
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            # you can set global_step==1 instead of 1e10 to help to debug
            if (global_step == 100000 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                from samplers import euler_sampler
                with torch.no_grad():
                    samples = euler_sampler(
                        model,
                        xT,
                        ys,
                        num_steps=50,
                        cfg_scale=4.0,
                        guidance_low=0.,
                        guidance_high=1.,
                        path_type=args.path_type,
                        heun=False,
                    ).to(torch.float32)
                    # Undo latent scaling, decode, then map VAE output [-1, 1] -> [0, 1].
                    samples = vae.decode((samples - latents_bias) / latents_scale).sample
                    gt_samples = vae.decode((gt_xs - latents_bias) / latents_scale).sample
                    samples = (samples + 1) / 2.
                    gt_samples = (gt_samples + 1) / 2.

                # Save images locally instead of logging to wandb
                out_samples = accelerator.gather(samples.to(torch.float32))
                gt_samples = accelerator.gather(gt_samples.to(torch.float32))



                # Save as grid images
                out_samples = Image.fromarray(array2grid(out_samples))
                gt_samples = Image.fromarray(array2grid(gt_samples))

                if accelerator.is_main_process:
                    base_dir = os.path.join(args.output_dir, args.exp_name)
                    sample_dir = os.path.join(base_dir, "samples")
                    os.makedirs(sample_dir, exist_ok=True)
                    out_samples.save(f"{sample_dir}/samples_step_{global_step}.png")
                    gt_samples.save(f"{sample_dir}/gt_samples_step_{global_step}.png")



                    # NOTE(review): uses module-level logging, not the run logger.
                    logging.info(f"Saved samples at step {global_step}")



                logging.info("Generating EMA samples done.")

            logs = {
                "ex_f_t": elapsed_ms,
                "bp_t": bpt,
                "loss": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
|
| 414 |
-
|
| 415 |
-
def parse_args(input_args=None):
    """Build the training CLI and parse `input_args` (defaults to sys.argv)."""
    parser = argparse.ArgumentParser(description="Training")
    add = parser.add_argument

    # logging:
    add("--output-dir", type=str, default="exps")
    add("--exp-name", type=str, required=True)
    add("--logging-dir", type=str, default="logs")
    add("--sampling-steps", type=int, default=5000000)
    add("--resume-step", type=int, default=0)

    # model
    add("--model", type=str)
    add("--num-classes", type=int, default=1000)
    add("--encoder-depth", type=int, default=8)
    add("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    add("--qk-norm", action=argparse.BooleanOptionalAction, default=False)

    # dataset
    add("--data-dir", type=str, default="../data/imagenet256")
    add("--resolution", type=int, choices=[256], default=256)
    add("--batch-size", type=int, default=32)

    # precision
    add("--allow-tf32", action="store_true")
    add("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # optimization
    add("--epochs", type=int, default=800)
    add("--max-train-steps", type=int, default=2000000)
    add("--checkpointing-steps", type=int, default=500000)
    add("--gradient-accumulation-steps", type=int, default=1)
    add("--learning-rate", type=float, default=1e-4)
    add("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    add("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    add("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    add("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    add("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")

    # seed
    add("--seed", type=int, default=0)

    # cpu
    add("--num-workers", type=int, default=8)

    # loss
    add("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    add("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
    add("--cfg-prob", type=float, default=0.1)
    add("--enc-type", type=str, default='dinov2-vit-b')
    add("--proj-coeff", type=float, default=0.5)
    # NOTE(review): help text below looks copy-pasted from --max-grad-norm.
    add("--weighting", default="uniform", type=str, help="Max gradient norm.")
    add("--legacy", action=argparse.BooleanOptionalAction, default=False)

    return parser.parse_args(input_args) if input_args is not None else parser.parse_args()
|
| 474 |
-
|
| 475 |
-
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the training loop.
    args = parse_args()

    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|