import argparse
import copy
from copy import deepcopy
import logging
import os
from pathlib import Path
from collections import OrderedDict
import json

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

from models.sit import SiT_models
from loss import SILoss
from utils import load_encoders

from dataset import CustomDataset
from diffusers.models import AutoencoderKL
from PIL import Image
from samplers import euler_maruyama_sampler

# import wandb_utils
import wandb
import math
from torchvision.utils import make_grid

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize

logger = get_logger(__name__)

CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)


def preprocess_raw_image(x, enc_type):
    """Scale, normalise and (for some encoders) resize raw images for a vision encoder.

    Args:
        x: image batch whose last dimension is the spatial width. Values are divided
            by 255, so presumably 0..255 pixel values -- TODO confirm against dataset.
        enc_type: encoder identifier; matched by substring ('clip', 'mocov3', 'mae',
            'dinov2', 'dinov1', 'jepa').

    Returns:
        The preprocessed batch. CLIP, DINOv2 and JEPA inputs are additionally resized
        (bicubic) to ``224 * (resolution // 256)``. Unrecognised encoder types are
        returned unchanged.
    """
    resolution = x.shape[-1]
    if 'clip' in enc_type:
        x = x / 255.
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    elif 'dinov2' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
    elif 'dinov1' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    elif 'jepa' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
    return x


def array2grid(x):
    """Tile an image batch with values in [0, 1] into one square-ish HWC uint8 numpy grid."""
    nrow = round(math.sqrt(x.size(0)))
    x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
    x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    return x


@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a reparameterised latent sample from pre-computed VAE posterior moments.

    ``moments`` is split in half along dim 1 into ``mean`` and ``std``. The second half
    is used directly as a standard deviation. Returns
    ``(mean + std * eps) * latents_scale + latents_bias`` with ``eps ~ N(0, I)``.
    """
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model: ema = decay * ema + (1 - decay) * model.

    ``model`` may be DDP-wrapped, in which case its parameter names carry a leading
    ``module.`` prefix. Only that leading prefix is stripped, so a name that merely
    contains ``module.`` somewhere else still maps onto the EMA copy.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    ddp_prefix = "module."
    for name, param in model.named_parameters():
        if name.startswith(ddp_prefix):
            name = name[len(ddp_prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    )
    logger = logging.getLogger(__name__)
    return logger


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


@torch.no_grad()
def _build_alignment_targets(raw_image, r_preprocessed, encoders, encoder_types,
                             architectures, base_model, accelerator):
    """Build the encoder alignment targets for one batch.

    Returns ``(zs, cls_token)``. ``zs`` holds one tensor per encoder made of the cls
    token followed by the patch tokens (dim 1).

    * Pre-extracted mode (``r_preprocessed`` is not None): only a cls feature is
      available, so the dense target is the cls token followed by
      ``base_model.x_embedder.num_patches`` copies of itself.
    * Online mode: each encoder is run on ``raw_image``. Only dinov2 outputs
      (``x_norm_clstoken`` / ``x_norm_patchtokens``) are supported. ``cls_token`` is
      the one from the last encoder.
    """
    if r_preprocessed is not None:
        cls_token = r_preprocessed
        # Squeeze singleton middle dims down to (B, D). A non-singleton dim 1 would make
        # squeeze(1) a no-op and loop forever, so reject that shape explicitly.
        while cls_token.dim() > 2:
            if cls_token.shape[1] != 1:
                raise ValueError(
                    "Pre-extracted semantic features must have singleton middle "
                    f"dimensions, got shape {tuple(r_preprocessed.shape)}"
                )
            cls_token = cls_token.squeeze(1)
        n_pad = base_model.x_embedder.num_patches
        cls_seq = cls_token.unsqueeze(1)
        zs = [torch.cat([cls_seq, cls_seq.expand(-1, n_pad, -1)], dim=1)]
        return zs, cls_token

    # Online encoder mode: same behaviour as the original New/REG pipeline.
    zs = []
    cls_token = None
    with accelerator.autocast():
        for encoder, encoder_type, _arch in zip(encoders, encoder_types, architectures):
            if "dinov2" not in encoder_type:
                raise NotImplementedError(
                    f"Online alignment targets are only implemented for dinov2 encoders, got {encoder_type!r}"
                )
            raw_image_ = preprocess_raw_image(raw_image, encoder_type)
            features = encoder.forward_features(raw_image_)
            cls_token = features["x_norm_clstoken"]
            dense_z = torch.cat([cls_token.unsqueeze(1), features["x_norm_patchtokens"]], dim=1)
            zs.append(dense_z)
    return zs, cls_token


@torch.no_grad()
def _latents_to_grid(vae, latents, latents_scale, latents_bias):
    """Undo latent scaling, decode with the VAE and tile into a uint8 HWC image grid."""
    latents = latents.to(dtype=torch.float32)
    images = vae.decode((latents - latents_bias) / latents_scale).sample
    images = ((images + 1) / 2.0).clamp(0, 1)
    return array2grid(images)


@torch.no_grad()
def _save_ema_samples(args, ema, base_model, vae, latents_scale, latents_bias,
                      global_step, device):
    """Sample a small grid with the EMA model and save it at t=0 and at t=args.t_c.

    Grids are written to ``<output_dir>/<exp_name>/samples/t0`` and
    ``.../samples/t0_<tag>``. Meant to be called from one process only, because the
    file names do not depend on the rank.
    """
    t_mid_vis = float(args.t_c)
    tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
    logging.info(
        f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} -> t=0)..."
    )
    ema.eval()

    latent_size = args.resolution // 8
    n_samples = min(16, args.batch_size)
    cls_dim = base_model.z_dims[0]

    # A single seeded generator makes both noise tensors reproducible from
    # `shared_seed`. It avoids re-seeding between draws (which made cls_init a copy of
    # the leading values of z_init) and leaves the global training RNG untouched.
    shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
    generator = torch.Generator(device=device).manual_seed(shared_seed)
    z_init = torch.randn(
        n_samples, base_model.in_channels, latent_size, latent_size,
        generator=generator, device=device,
    )
    cls_init = torch.randn(n_samples, cls_dim, generator=generator, device=device)
    y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)

    z_0, z_mid, _ = euler_maruyama_sampler(
        ema,
        z_init,
        y_samples,
        num_steps=50,
        cfg_scale=1.0,
        guidance_low=0.0,
        guidance_high=1.0,
        path_type=args.path_type,
        cls_latents=cls_init,
        args=args,
        return_mid_state=True,
        t_mid=t_mid_vis,
    )

    samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
    t0_dir = os.path.join(samples_root, "t0")
    t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
    os.makedirs(t0_dir, exist_ok=True)
    os.makedirs(t_mid_dir, exist_ok=True)

    grid_final = _latents_to_grid(vae, z_0, latents_scale, latents_bias)
    Image.fromarray(grid_final).save(
        os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
    )

    if z_mid is not None:
        grid_mid = _latents_to_grid(vae, z_mid, latents_scale, latents_bias)
        Image.fromarray(grid_mid).save(
            os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
        )
    else:
        logging.warning(
            f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
            f"skip t0_{tc_tag} image this step."
        )


#################################################################################
#                                  Training Loop                                #
#################################################################################

def main(args):
    """Train a SiT model with encoder-representation alignment and an EMA copy.

    Sets up the accelerator, models, VAE, loss, optimizer and data, optionally resumes
    from ``<output_dir>/<exp_name>/checkpoints/<resume_step:07d>.pt``, and then runs the
    training loop. Inside the loop it periodically writes checkpoints and EMA sample
    grids, and logs metrics.
    """
    # Set up the accelerator.
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
    )

    # NOTE: `logger` and `checkpoint_dir` are only bound on the main process; every
    # use below is guarded by `accelerator.is_main_process`.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        # Save the run arguments to a JSON file.
        with open(os.path.join(save_dir, "args.json"), 'w') as f:
            json.dump(vars(args), f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    device = accelerator.device
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    if args.seed is not None:
        set_seed(args.seed + accelerator.process_index)

    # Create model:
    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8

    if args.enc_type is None:
        raise NotImplementedError()
    encoders, encoder_types, architectures = load_encoders(
        args.enc_type, device, args.resolution
    )
    z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0]

    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=(args.cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs
    )
    model = model.to(device)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)

    latents_scale = torch.tensor([0.18215, 0.18215, 0.18215, 0.18215]).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor([0., 0., 0., 0.]).view(1, 4, 1, 1).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()

    # Create the loss function.
    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting,
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data:
    train_dataset = CustomDataset(
        args.data_dir, semantic_features_dir=args.semantic_features_dir
    )
    local_batch_size = int(args.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Resume:
    global_step = 0
    if args.resume_step > 0:
        ckpt_name = f"{args.resume_step:07d}.pt"
        ckpt_path = os.path.join(args.output_dir, args.exp_name, "checkpoints", ckpt_name)
        # Checkpoints written below also pickle the argparse Namespace ("args"), which
        # the weights_only=True default of torch>=2.6 refuses to load. These files are
        # produced by this script itself, so full unpickling is acceptable here.
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    # Plain (non-DDP) view of the model, for attribute access and checkpointing.
    base_model = accelerator.unwrap_model(model)

    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(
            project_name="REG",
            config=tracker_config,
            init_kwargs={
                "wandb": {"name": f"{args.exp_name}"}
            },
        )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(args.epochs):
        model.train()
        for batch in train_dataloader:
            # Batches have 4 items (raw image, latent moments, pre-extracted semantic
            # feature, label) in pre-extracted mode, and 3 items otherwise.
            if len(batch) == 4:
                raw_image, x, r_preprocessed, y = batch
                r_preprocessed = r_preprocessed.to(device).float()
            else:
                raw_image, x, y = batch
                r_preprocessed = None
                # Only in online-encoder mode is raw_image an RGB image at args.resolution.
                if raw_image.shape[-1] != args.resolution:
                    raise ValueError(
                        f"raw image width {raw_image.shape[-1]} does not match --resolution {args.resolution}"
                    )
            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device)
            y = y.to(device)
            if args.legacy:
                # In our early experiments, we accidentally apply label dropping twice:
                # once in train.py and once in sit.py.
                # We keep this option for exact reproducibility with previous runs.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y

            # Both helpers run under torch.no_grad().
            x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
            zs, cls_token = _build_alignment_targets(
                raw_image, r_preprocessed, encoders, encoder_types,
                architectures, base_model, accelerator,
            )

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels)
                loss1, proj_loss1, time_input, noises, loss2 = loss_fn(
                    model, x, model_kwargs, zs=zs, cls_token=cls_token,
                    time_input=None, noises=None,
                )
                loss_mean = loss1.mean()
                loss_mean_cls = loss2.mean() * args.cls
                proj_loss_mean = proj_loss1.mean() * args.proj_coeff
                loss = loss_mean + proj_loss_mean + loss_mean_cls

                ## optimization
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    update_ema(ema, model)  # change ema function

            ### enter
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
            if global_step % args.checkpointing_steps == 0 and global_step > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        # unwrap_model works whether or not prepare() wrapped the model in DDP.
                        "model": accelerator.unwrap_model(model).state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                # EMA is not DDP-wrapped, so sampling needs no collectives. Run it once,
                # on the main process, so ranks do not overwrite each other's images.
                if accelerator.is_main_process:
                    _save_ema_samples(
                        args, ema, base_model, vae, latents_scale, latents_bias,
                        global_step, device,
                    )
                    torch.cuda.empty_cache()

            logs = {
                "loss_final": accelerator.gather(loss).mean().detach().item(),
                "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
            }

            log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
            logging.info(f"Step: {global_step}, Training Logs: {log_message}")

            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()


def parse_args(input_args=None):
    """Parse training command-line arguments (from ``input_args`` if given, else sys.argv)."""
    parser = argparse.ArgumentParser(description="Training")

    # logging:
    parser.add_argument("--output-dir", type=str, default="exps")
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--logging-dir", type=str, default="logs")
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--sampling-steps", type=int, default=10000)
    parser.add_argument("--resume-step", type=int, default=0)

    # model
    parser.add_argument("--model", type=str)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ops-head", type=int, default=16)

    # dataset
    parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
    # Directory of pre-extracted semantic features (same semantics as REG/dataset.py);
    # only changes how data is loaded and does not introduce t_c.
    parser.add_argument(
        "--semantic-features-dir",
        type=str,
        default=None,
        help="预处理 semantic features 目录(与 REG/dataset.py 语义相同),仅影响数据加载方式,不引入 t_c。",
    )
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--batch-size", type=int, default=8)  # 256

    # precision
    parser.add_argument("--allow-tf32", action="store_true")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # optimization
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--max-train-steps", type=int, default=1000000)
    parser.add_argument("--checkpointing-steps", type=int, default=10000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")

    # seed
    parser.add_argument("--seed", type=int, default=0)

    # cpu
    parser.add_argument("--num-workers", type=int, default=4)

    # loss
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
    parser.add_argument("--cfg-prob", type=float, default=0.1)
    parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
    parser.add_argument("--proj-coeff", type=float, default=0.5)
    parser.add_argument("--weighting", default="uniform", type=str, help="Loss weighting scheme passed to SILoss.")
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--cls", type=float, default=0.03)
    # Intermediate time t saved during in-training sampling (used to output t0 vs t0_tc comparison grids).
    parser.add_argument(
        "--t-c",
        type=float,
        default=0.5,
        help="训练中采样时保存的中间时刻 t(用于输出 t0 与 t0_tc 对比图)。",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)