DyJiang committed
Commit 1fb0d58 · verified · 1 Parent(s): 800f26e

Delete train.py

Files changed (1)
train.py +0 -478
train.py DELETED
@@ -1,478 +0,0 @@
- import argparse
- import copy
- from copy import deepcopy
- import logging
- import os
- import torch
- torch.hub.set_dir(r"/slurm-files/jdy/hub/")
- os.environ["NCCL_TIMEOUT"] = "9000000000"
- os.environ["HF_DATASETS_CACHE"] = "/slurm-files/jdy/cache/"
- os.environ["HF_HOME"] = "/slurm-files/jdy/cache/"
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/slurm-files/jdy/cache/"
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-
- from pathlib import Path
- from collections import OrderedDict
- import json
-
- import numpy as np
- from torchvision.utils import save_image
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from tqdm.auto import tqdm
- from torch.utils.data import DataLoader
-
- from accelerate import Accelerator
- from accelerate.logging import get_logger
- from accelerate.utils import ProjectConfiguration, set_seed
-
- from models.sit import SiT_models
- from loss import SILoss
- from utils import load_encoders
-
- from dataset import CustomDataset
- from diffusers.models import AutoencoderKL
- # import wandb_utils
- import math
- from torchvision.utils import make_grid
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from torchvision.transforms import Normalize
- from PIL import Image
-
- logger = get_logger(__name__)
-
- CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
- CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
-
-
- def preprocess_raw_image(x, enc_type):
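-     # map uint8 images in [0, 255] into each encoder's expected input
-     # distribution (normalization stats and, where needed, 224px resolution)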
-     if 'clip' in enc_type:
-         x = x / 255.
-         x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
-         x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
-     elif 'mocov3' in enc_type or 'mae' in enc_type:
-         x = x / 255.
-         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-     elif 'dinov2' in enc_type:
-         x = x / 255.
-         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-         x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
-     elif 'dinov1' in enc_type:
-         x = x / 255.
-         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-     elif 'jepa' in enc_type:
-         x = x / 255.
-         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
-         x = torch.nn.functional.interpolate(x, 224, mode='bicubic')
-
-     return x
-
-
- def array2grid(x):
-     nrow = round(math.sqrt(x.size(0)))
-     x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
-     x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
-     return x
-
-
- @torch.no_grad()
- def sample_posterior(moments, latents_scale=1., latents_bias=0.):
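-     # split the cached VAE moments into mean and std, sample z = mean + std * eps
-     # (reparameterization trick), then apply the latent scale and bias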
-     device = moments.device
-
-     mean, std = torch.chunk(moments, 2, dim=1)
-     z = mean + std * torch.randn_like(mean)
-     z = (z * latents_scale + latents_bias)
-     return z
-
-
- @torch.no_grad()
- def update_ema(ema_model, model, decay=0.9999):
-     """
-     Step the EMA model towards the current model.
-     """
-     ema_params = OrderedDict(ema_model.named_parameters())
-     model_params = OrderedDict(model.named_parameters())
-
-     for name, param in model_params.items():
-         name = name.replace("module.", "")
-         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
-         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
-
-
- def create_logger(logging_dir):
-     """
-     Create a logger that writes to a log file and stdout.
-     """
-     logging.basicConfig(
-         level=logging.INFO,
-         format='[\033[34m%(asctime)s\033[0m] %(message)s',
-         datefmt='%Y-%m-%d %H:%M:%S',
-         handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
-     )
-     logger = logging.getLogger(__name__)
-     return logger
-
-
- def requires_grad(model, flag=True):
-     """
-     Set requires_grad flag for all parameters in a model.
-     """
-     for p in model.parameters():
-         p.requires_grad = flag
-
-
- #################################################################################
- #                                 Training Loop                                 #
- #################################################################################
-
- def main(args):
-     # set accelerator
-     logging_dir = Path(args.output_dir, args.logging_dir)
-     accelerator_project_config = ProjectConfiguration(
-         project_dir=args.output_dir, logging_dir=logging_dir
-     )
-
-     accelerator = Accelerator(
-         gradient_accumulation_steps=args.gradient_accumulation_steps,
-         mixed_precision=args.mixed_precision,
-         project_config=accelerator_project_config,
-     )
-
-     if accelerator.is_main_process:
-         os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
-         save_dir = os.path.join(args.output_dir, args.exp_name)
-         os.makedirs(save_dir, exist_ok=True)
-         args_dict = vars(args)
-         # Save to a JSON file
-         json_dir = os.path.join(save_dir, "args.json")
-         with open(json_dir, 'w') as f:
-             json.dump(args_dict, f, indent=4)
-         checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
-         os.makedirs(checkpoint_dir, exist_ok=True)
-         logger = create_logger(save_dir)
-         logger.info(f"Experiment directory created at {save_dir}")
-     device = accelerator.device
-     if torch.backends.mps.is_available():
-         accelerator.native_amp = False
-     if args.seed is not None:
-         set_seed(args.seed + accelerator.process_index)
-
-     # Create model:
-     assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
-     latent_size = args.resolution // 8
-
-     if args.enc_type != 'None':
-         encoders, encoder_types, architectures = load_encoders(args.enc_type, device)
-     else:
-         encoders, encoder_types, architectures = [None], [None], [None]
-     z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0]
-     block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
-     model = SiT_models[args.model](
-         input_size=latent_size,
-         num_classes=args.num_classes,
-         use_cfg=(args.cfg_prob > 0),
-         z_dims=z_dims,
-         encoder_depth=args.encoder_depth,
-         **block_kwargs
-     )
-
-     model = model.to(device)
-     ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
-     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
-     requires_grad(ema, False)
-
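-     # sd-vae-ft-mse latents use the standard SD scaling factor 0.18215 and no bias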
-     latents_scale = torch.tensor(
-         [0.18215, 0.18215, 0.18215, 0.18215]
-     ).view(1, 4, 1, 1).to(device)
-     latents_bias = torch.tensor(
-         [0., 0., 0., 0.]
-     ).view(1, 4, 1, 1).to(device)
-
-     # create loss function
-     loss_fn = SILoss(
-         prediction=args.prediction,
-         path_type=args.path_type,
-         encoders=encoders,
-         accelerator=accelerator,
-         latents_scale=latents_scale,
-         latents_bias=latents_bias,
-         weighting=args.weighting
-     )
-     if accelerator.is_main_process:
-         logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-     # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
-     if args.allow_tf32:
-         torch.backends.cuda.matmul.allow_tf32 = True
-         torch.backends.cudnn.allow_tf32 = True
-
-     optimizer = torch.optim.AdamW(
-         model.parameters(),
-         lr=args.learning_rate,
-         betas=(args.adam_beta1, args.adam_beta2),
-         weight_decay=args.adam_weight_decay,
-         eps=args.adam_epsilon,
-     )
-
-     # Setup data:
-     train_dataset = CustomDataset(args.data_dir)
-     local_batch_size = int(args.batch_size)
-
-     train_dataloader = DataLoader(
-         train_dataset,
-         batch_size=local_batch_size,
-         shuffle=True,
-         num_workers=args.num_workers,
-         pin_memory=True,
-         drop_last=True
-     )
-     if accelerator.is_main_process:
-         logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
-
-     # Prepare models for training:
-     update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
-     model.train()  # important! This enables embedding dropout for classifier-free guidance
-     ema.eval()  # EMA model should always be in eval mode
-
-     # resume:
-     global_step = 0
-     if args.resume_step > 0:
-         ckpt_name = str(args.resume_step).zfill(7) + '.pt'
-         ckpt = torch.load(
-             f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
-             map_location='cpu',
-         )
-         model.load_state_dict(ckpt['model'])
-         ema.load_state_dict(ckpt['ema'])
-         optimizer.load_state_dict(ckpt['opt'])
-         global_step = ckpt['steps']
-
-     model, optimizer, train_dataloader = accelerator.prepare(
-         model, optimizer, train_dataloader
-     )
-
-     if accelerator.is_main_process:
-         logger.info(f"Starting training experiment: {args.exp_name}")
-
-     progress_bar = tqdm(
-         range(0, args.max_train_steps),
-         initial=global_step,
-         desc="Steps",
-         # Only show the progress bar once on each machine.
-         disable=not accelerator.is_local_main_process,
-     )
-
-     # Labels to condition the model with (feel free to change):
-     sample_batch_size = 16 // accelerator.num_processes
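-     # (16 preview images in total, divided evenly across data-parallel processes)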
-     _, gt_xs, _ = next(iter(train_dataloader))
-     gt_xs = gt_xs[:sample_batch_size]
-     gt_xs = sample_posterior(
-         gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
-     )
-     ys = torch.randint(1000, size=(sample_batch_size,), device=device)
-     # Create sampling noise:
-     n = ys.size(0)
-     xT = torch.randn((n, 4, latent_size, latent_size), device=device)
-
-     for epoch in range(args.epochs):
-         model.train()
-         for raw_image, x, y in train_dataloader:
-             raw_image = raw_image.to(device)
-             x = x.squeeze(dim=1).to(device)
-             y = y.to(device)
-             z = None
-             if args.legacy:
-                 # In our early experiments, we accidentally applied label dropping twice:
-                 # once in train.py and once in sit.py.
-                 # We keep this option for exact reproducibility with previous runs.
-                 drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
-                 labels = torch.where(drop_ids, args.num_classes, y)
-             else:
-                 labels = y
-             with torch.no_grad():
-                 x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
-             zs = []
-             import time
-             start = time.perf_counter()
-             with accelerator.autocast():
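-                 # run each frozen encoder on the raw images; the resulting patch
-                 # tokens are the targets for the projection (alignment) loss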
-                 for encoder, encoder_type, arch in zip(encoders, encoder_types, architectures):
-                     raw_image_ = preprocess_raw_image(raw_image, encoder_type)
-                     z = encoder.forward_features(raw_image_)
-                     if 'mocov3' in encoder_type: z = z[:, 1:]  # drop the [CLS] token
-                     if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
-                     zs.append(z)
-
-             end = time.perf_counter()
-             elapsed = end - start  # encoder feature-extraction time (seconds)
-
-             with accelerator.accumulate(model):
-                 model_kwargs = dict(y=labels)
-                 loss, proj_loss = loss_fn(model, x, model_kwargs, zs=zs)
-                 loss_mean = loss.mean()
-                 proj_loss_mean = proj_loss.mean()
-                 loss = loss_mean + proj_loss_mean * args.proj_coeff
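-                 # total objective: diffusion loss plus the projection (alignment)
-                 # loss weighted by --proj-coeff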
-
-                 ## optimization
-                 start = time.perf_counter()
-                 accelerator.backward(loss)
-                 end = time.perf_counter()
-                 bpt = end - start  # backward-pass time (seconds)
-                 if accelerator.sync_gradients:
-                     params_to_clip = model.parameters()
-                     grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-                 optimizer.step()
-                 optimizer.zero_grad(set_to_none=True)
-
-             if accelerator.sync_gradients:
-                 update_ema(ema, model)
-
-             if accelerator.sync_gradients:
-                 progress_bar.update(1)
-                 global_step += 1
-             if global_step % args.checkpointing_steps == 0 and global_step > 0:
-                 if accelerator.is_main_process:
-                     checkpoint = {
-                         "model": model.module.state_dict(),
-                         "ema": ema.state_dict(),
-                         "opt": optimizer.state_dict(),
-                         "args": args,
-                         "steps": global_step,
-                     }
-                     checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
-                     torch.save(checkpoint, checkpoint_path)
-                     logger.info(f"Saved checkpoint to {checkpoint_path}")
-
-             # you can replace 100000 with 1 below to trigger sampling early when debugging
-             if (global_step == 100000 or (global_step % args.sampling_steps == 0 and global_step > 0)):
-                 from samplers import euler_sampler
-                 with torch.no_grad():
-                     samples = euler_sampler(
-                         model,
-                         xT,
-                         ys,
-                         num_steps=50,
-                         cfg_scale=4.0,
-                         guidance_low=0.,
-                         guidance_high=1.,
-                         path_type=args.path_type,
-                         heun=False,
-                     ).to(torch.float32)
-                     samples = vae.decode((samples - latents_bias) / latents_scale).sample
-                     gt_samples = vae.decode((gt_xs - latents_bias) / latents_scale).sample
-                     samples = (samples + 1) / 2.
-                     gt_samples = (gt_samples + 1) / 2.
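-                     # vae.decode returns images in [-1, 1]; the two lines above map them to [0, 1]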
-
-                 # Save images locally instead of logging to wandb
-                 out_samples = accelerator.gather(samples.to(torch.float32))
-                 gt_samples = accelerator.gather(gt_samples.to(torch.float32))
-
-                 # Save as grid images
-                 out_samples = Image.fromarray(array2grid(out_samples))
-                 gt_samples = Image.fromarray(array2grid(gt_samples))
-
-                 if accelerator.is_main_process:
-                     base_dir = os.path.join(args.output_dir, args.exp_name)
-                     sample_dir = os.path.join(base_dir, "samples")
-                     os.makedirs(sample_dir, exist_ok=True)
-                     out_samples.save(f"{sample_dir}/samples_step_{global_step}.png")
-                     gt_samples.save(f"{sample_dir}/gt_samples_step_{global_step}.png")
-                     logger.info(f"Saved samples at step {global_step}")
-
-                 logging.info("Done generating samples.")
-
-             logs = {
-                 "ex_f_t": elapsed,
-                 "bp_t": bpt,
-                 "loss": accelerator.gather(loss_mean).mean().detach().item(),
-                 "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
-             }
-             progress_bar.set_postfix(**logs)
-             accelerator.log(logs, step=global_step)
-
-             if global_step >= args.max_train_steps:
-                 break
-         if global_step >= args.max_train_steps:
-             break
-
-     model.eval()  # important! This disables randomized embedding dropout
-     # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
-
-     accelerator.wait_for_everyone()
-     if accelerator.is_main_process:
-         logger.info("Done!")
-     accelerator.end_training()
-
-
- def parse_args(input_args=None):
-     parser = argparse.ArgumentParser(description="Training")
-
-     # logging:
-     parser.add_argument("--output-dir", type=str, default="exps")
-     parser.add_argument("--exp-name", type=str, required=True)
-     parser.add_argument("--logging-dir", type=str, default="logs")
-     parser.add_argument("--sampling-steps", type=int, default=5000000)
-     parser.add_argument("--resume-step", type=int, default=0)
-
-     # model
-     parser.add_argument("--model", type=str)
-     parser.add_argument("--num-classes", type=int, default=1000)
-     parser.add_argument("--encoder-depth", type=int, default=8)
-     parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
-     parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
-
-     # dataset
-     parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
-     parser.add_argument("--resolution", type=int, choices=[256], default=256)
-     parser.add_argument("--batch-size", type=int, default=32)
-
-     # precision
-     parser.add_argument("--allow-tf32", action="store_true")
-     parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
-
-     # optimization
-     parser.add_argument("--epochs", type=int, default=800)
-     parser.add_argument("--max-train-steps", type=int, default=2000000)
-     parser.add_argument("--checkpointing-steps", type=int, default=500000)
-     parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
-     parser.add_argument("--learning-rate", type=float, default=1e-4)
-     parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
-     parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
-     parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
-     parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
-     parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
-
-     # seed
-     parser.add_argument("--seed", type=int, default=0)
-
-     # cpu
-     parser.add_argument("--num-workers", type=int, default=8)
-
-     # loss
-     parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
-     parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
-     parser.add_argument("--cfg-prob", type=float, default=0.1)
-     parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
-     parser.add_argument("--proj-coeff", type=float, default=0.5)
-     parser.add_argument("--weighting", default="uniform", type=str, help="Loss weighting scheme.")
-     parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
-
-     if input_args is not None:
-         args = parser.parse_args(input_args)
-     else:
-         args = parser.parse_args()
-
-     return args
-
-
- if __name__ == "__main__":
-     args = parse_args()
-
-     main(args)