import argparse import copy from copy import deepcopy import logging import os from pathlib import Path from collections import OrderedDict import json import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint from tqdm.auto import tqdm from torch.utils.data import DataLoader from accelerate import Accelerator, DistributedDataParallelKwargs from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from models.sit import SiT_models from loss import SILoss from utils import load_encoders from dataset import CustomDataset from diffusers.models import AutoencoderKL # import wandb_utils import wandb import math from torchvision.utils import make_grid from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision.transforms import Normalize from PIL import Image logger = get_logger(__name__) def semantic_dim_from_enc_type(enc_type): """DINOv2 等 enc_type 字符串推断 class token 维度(与预处理特征一致)。""" if enc_type is None: return 768 s = str(enc_type).lower() if "vit-g" in s or "vitg" in s: return 1536 if "vit-l" in s or "vitl" in s: return 1024 if "vit-s" in s or "vits" in s: return 384 return 768 CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073) CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711) def preprocess_raw_image(x, enc_type): resolution = x.shape[-1] if 'clip' in enc_type: x = x / 255. x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic') x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x) elif 'mocov3' in enc_type or 'mae' in enc_type: x = x / 255. x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) elif 'dinov2' in enc_type: x = x / 255. x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic') elif 'dinov1' in enc_type: x = x / 255. x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) elif 'jepa' in enc_type: x = x / 255. x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic') return x def array2grid(x): nrow = round(math.sqrt(x.size(0))) x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1)) x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() return x @torch.no_grad() def sample_posterior(moments, latents_scale=1., latents_bias=0.): device = moments.device mean, std = torch.chunk(moments, 2, dim=1) z = mean + std * torch.randn_like(mean) z = (z * latents_scale + latents_bias) return z @torch.no_grad() def update_ema(ema_model, model, decay=0.9999): """ Step the EMA model towards the current model. """ ema_params = OrderedDict(ema_model.named_parameters()) model_params = OrderedDict(model.named_parameters()) for name, param in model_params.items(): name = name.replace("module.", "") # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) def create_logger(logging_dir): """ Create a logger that writes to a log file and stdout. """ logging.basicConfig( level=logging.INFO, format='[\033[34m%(asctime)s\033[0m] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] ) logger = logging.getLogger(__name__) return logger def requires_grad(model, flag=True): """ Set requires_grad flag for all parameters in a model. """ for p in model.parameters(): p.requires_grad = flag ################################################################################# # Training Loop # ################################################################################# def main(args): # set accelerator logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration( project_dir=args.output_dir, logging_dir=logging_dir ) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)] ) if accelerator.is_main_process: os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) save_dir = os.path.join(args.output_dir, args.exp_name) os.makedirs(save_dir, exist_ok=True) args_dict = vars(args) # Save to a JSON file json_dir = os.path.join(save_dir, "args.json") with open(json_dir, 'w') as f: json.dump(args_dict, f, indent=4) checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints os.makedirs(checkpoint_dir, exist_ok=True) logger = create_logger(save_dir) logger.info(f"Experiment directory created at {save_dir}") device = accelerator.device if torch.backends.mps.is_available(): accelerator.native_amp = False if args.seed is not None: set_seed(args.seed + accelerator.process_index) # Create model: assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." latent_size = args.resolution // 8 train_dataset = CustomDataset( args.data_dir, semantic_features_dir=args.semantic_features_dir ) use_preprocessed_semantic = train_dataset.use_preprocessed_semantic if use_preprocessed_semantic: encoders, encoder_types, architectures = [], [], [] z_dims = [semantic_dim_from_enc_type(args.enc_type)] if accelerator.is_main_process: logger.info( f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}" ) elif args.enc_type is not None: encoders, encoder_types, architectures = load_encoders( args.enc_type, device, args.resolution ) z_dims = [encoder.embed_dim for encoder in encoders] else: raise NotImplementedError() block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm} model = SiT_models[args.model]( input_size=latent_size, num_classes=args.num_classes, use_cfg = (args.cfg_prob > 0), z_dims = z_dims, encoder_depth=args.encoder_depth, **block_kwargs ) model = model.to(device) ema = deepcopy(model).to(device) # Create an EMA of the model for use after training requires_grad(ema, False) latents_scale = torch.tensor( [0.18215, 0.18215, 0.18215, 0.18215] ).view(1, 4, 1, 1).to(device) latents_bias = torch.tensor( [0., 0., 0., 0.] ).view(1, 4, 1, 1).to(device) # VAE decoder:采样阶段将 latent 解码为图像(与根目录 train.py / 预处理一致:sd-vae-ft-mse) try: from preprocessing import dnnlib cache_dir = dnnlib.make_cache_dir_path("diffusers") os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" os.environ["HF_HOME"] = cache_dir try: vae = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", cache_dir=cache_dir, local_files_only=True, ).to(device) vae.eval() if accelerator.is_main_process: logger.info( "Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache " f"at '{cache_dir}' for intermediate sampling." ) except Exception as e_main: vae = None candidate_dir = None possible_roots = [ cache_dir, os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "diffusers"), os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"), ] checked_roots = [] for root_dir in possible_roots: if not os.path.isdir(root_dir): continue checked_roots.append(root_dir) for root, dirs, files in os.walk(root_dir): if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"): candidate_dir = root break if candidate_dir is not None: break if candidate_dir is not None: try: vae = AutoencoderKL.from_pretrained( candidate_dir, local_files_only=True, ).to(device) vae.eval() if accelerator.is_main_process: logger.info( "Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path " f"'{candidate_dir}'. Searched roots: {checked_roots}" ) except Exception as e_fallback: if accelerator.is_main_process: logger.warning( "Tried to load VAE from discovered local path " f"'{candidate_dir}' but failed: {e_fallback}" ) if vae is None and accelerator.is_main_process: logger.warning( "Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. " f"Last repo-level error: {e_main}" ) except Exception as e: vae = None if accelerator.is_main_process: logger.warning( f"Failed to initialize VAE loading logic (will skip image decoding): {e}" ) # create loss function loss_fn = SILoss( prediction=args.prediction, path_type=args.path_type, encoders=encoders, accelerator=accelerator, latents_scale=latents_scale, latents_bias=latents_bias, weighting=args.weighting, t_c=args.t_c, ot_cls=args.ot_cls, ) if accelerator.is_main_process: logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}") # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) # Setup data(train_dataset 已在上方创建) local_batch_size = int(args.batch_size // accelerator.num_processes) train_dataloader = DataLoader( train_dataset, batch_size=local_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True ) if accelerator.is_main_process: logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})") # Prepare models for training: update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights model.train() # important! This enables embedding dropout for classifier-free guidance ema.eval() # EMA model should always be in eval mode # resume: global_step = 0 if args.resume_from_ckpt is not None: ckpt = torch.load(args.resume_from_ckpt, map_location="cpu") model.load_state_dict(ckpt["model"]) ema.load_state_dict(ckpt["ema"]) if "opt" in ckpt: optimizer.load_state_dict(ckpt["opt"]) global_step = int(ckpt.get("steps", 0)) if accelerator.is_main_process: logger.info( f"Resumed from ckpt: {args.resume_from_ckpt} (global_step={global_step})" ) elif args.resume_step > 0: ckpt_name = str(args.resume_step).zfill(7) +'.pt' ckpt = torch.load( f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}', map_location='cpu', ) model.load_state_dict(ckpt['model']) ema.load_state_dict(ckpt['ema']) optimizer.load_state_dict(ckpt['opt']) global_step = ckpt['steps'] model, optimizer, train_dataloader = accelerator.prepare( model, optimizer, train_dataloader ) if accelerator.is_main_process: tracker_config = vars(copy.deepcopy(args)) accelerator.init_trackers( project_name="REG", config=tracker_config, init_kwargs={ "wandb": {"name": f"{args.exp_name}"} }, ) progress_bar = tqdm( range(0, args.max_train_steps), initial=global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) # Labels to condition the model with (feel free to change): sample_batch_size = 64 // accelerator.num_processes first_batch = next(iter(train_dataloader)) if len(first_batch) == 4: gt_raw_images, gt_xs, _, _ = first_batch else: gt_raw_images, gt_xs, _ = first_batch assert gt_raw_images.shape[-1] == args.resolution gt_xs = gt_xs[:sample_batch_size] gt_xs = sample_posterior( gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias ) ys = torch.randint(1000, size=(sample_batch_size,), device=device) ys = ys.to(device) # Create sampling noise: n = ys.size(0) xT = torch.randn((n, 4, latent_size, latent_size), device=device) for epoch in range(args.epochs): model.train() for batch in train_dataloader: if len(batch) == 4: raw_image, x, r_preprocessed, y = batch use_sem_file = True else: raw_image, x, y = batch r_preprocessed = None use_sem_file = False raw_image = raw_image.to(device) x = x.squeeze(dim=1).to(device).float() y = y.to(device) if args.legacy: # In our early experiments, we accidentally apply label dropping twice: # once in train.py and once in sit.py. # We keep this option for exact reproducibility with previous runs. drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob labels = torch.where(drop_ids, args.num_classes, y) else: labels = y with torch.no_grad(): x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias) zs = [] if use_sem_file and r_preprocessed is not None: cls_token = r_preprocessed.to(device).float() if cls_token.dim() == 1: cls_token = cls_token.unsqueeze(0) while cls_token.dim() > 2: cls_token = cls_token.squeeze(1) base_m = model.module if hasattr(model, "module") else model n_pad = base_m.x_embedder.num_patches zs = [ torch.cat( [ cls_token.unsqueeze(1), cls_token.unsqueeze(1).expand(-1, n_pad, -1), ], dim=1, ) ] else: with accelerator.autocast(): for encoder, encoder_type, arch in zip( encoders, encoder_types, architectures ): raw_image_ = preprocess_raw_image(raw_image, encoder_type) z = encoder.forward_features(raw_image_) if 'dinov2' in encoder_type: dense_z = z['x_norm_patchtokens'] cls_token = z['x_norm_clstoken'] dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1) else: exit() zs.append(dense_z) with accelerator.accumulate(model): model_kwargs = dict(y=labels) loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs, cls_token=cls_token, time_input=None, noises=None) loss_mean = loss1.mean() loss_mean_cls = loss2.mean() * args.cls proj_loss_mean = proj_loss1.mean() * args.proj_coeff tc_vel_loss = torch.tensor(0.0, device=device) if args.tc_velocity_loss_coeff > 0: tc_vel_loss = loss_fn.tc_velocity_loss( model, x, model_kwargs=model_kwargs, cls_token=cls_token, noises=noises, ).mean() loss = ( loss_mean + proj_loss_mean + loss_mean_cls + args.tc_velocity_loss_coeff * tc_vel_loss ) ## optimization accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = model.parameters() grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() optimizer.zero_grad(set_to_none=True) if accelerator.sync_gradients: update_ema(ema, model) # change ema function ### enter if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if global_step % args.checkpointing_steps == 0 and global_step > 0: if accelerator.is_main_process: checkpoint = { "model": model.module.state_dict(), "ema": ema.state_dict(), "opt": optimizer.state_dict(), "args": args, "steps": global_step, } checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt" torch.save(checkpoint, checkpoint_path) logger.info(f"Saved checkpoint to {checkpoint_path}") if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)): t_mid_vis = float(args.t_c) tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_") logging.info( f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..." ) ema.eval() with torch.no_grad(): latent_size = args.resolution // 8 n_samples = min(16, args.batch_size) base_model = model.module if hasattr(model, "module") else model cls_dim = base_model.z_dims[0] shared_seed = torch.randint(0, 2**32, (1,), device=device).item() torch.manual_seed(shared_seed) z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device) torch.manual_seed(shared_seed) cls_init = torch.randn(n_samples, cls_dim, device=device) y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device) from samplers import euler_maruyama_sampler z_0, z_mid, _ = euler_maruyama_sampler( ema, z_init, y_samples, num_steps=50, cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type=args.path_type, cls_latents=cls_init, args=args, return_mid_state=True, t_mid=t_mid_vis, ) samples_root = os.path.join(args.output_dir, args.exp_name, "samples") t0_dir = os.path.join(samples_root, "t0") t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}") os.makedirs(t0_dir, exist_ok=True) os.makedirs(t_mid_dir, exist_ok=True) if vae is not None: z_f = z_0.to(dtype=torch.float32) samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample samples_final = (samples_final + 1) / 2.0 samples_final = samples_final.clamp(0, 1) grid_final = array2grid(samples_final) Image.fromarray(grid_final).save( os.path.join(t0_dir, f"step_{global_step:07d}_t0.png") ) if z_mid is not None: z_m = z_mid.to(dtype=torch.float32) samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample samples_mid = (samples_mid + 1) / 2.0 samples_mid = samples_mid.clamp(0, 1) grid_mid = array2grid(samples_mid) Image.fromarray(grid_mid).save( os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png") ) else: logging.warning( f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; " f"skip t0_{tc_tag} image this step." ) del z_init, cls_init, y_samples, z_0 if z_mid is not None: del z_mid if vae is not None: del samples_final, grid_final if "samples_mid" in locals(): del samples_mid, grid_mid torch.cuda.empty_cache() logs = { "loss_final": accelerator.gather(loss).mean().detach().item(), "loss_mean": accelerator.gather(loss_mean).mean().detach().item(), "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(), "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(), "loss_tc_vel": accelerator.gather(tc_vel_loss).mean().detach().item(), "grad_norm": accelerator.gather(grad_norm).mean().detach().item() } log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items()) logging.info(f"Step: {global_step}, Training Logs: {log_message}") progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break if global_step >= args.max_train_steps: break model.eval() # important! This disables randomized embedding dropout # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... accelerator.wait_for_everyone() if accelerator.is_main_process: logger.info("Done!") accelerator.end_training() def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Training") # logging: parser.add_argument("--output-dir", type=str, default="exps") parser.add_argument("--exp-name", type=str, required=True) parser.add_argument("--logging-dir", type=str, default="logs") parser.add_argument("--report-to", type=str, default="wandb") parser.add_argument("--sampling-steps", type=int, default=2000) parser.add_argument("--resume-step", type=int, default=0) parser.add_argument( "--resume-from-ckpt", type=str, default=None, help="直接从指定 checkpoint 路径续训(优先于 --resume-step)。", ) # model parser.add_argument("--model", type=str) parser.add_argument("--num-classes", type=int, default=1000) parser.add_argument("--encoder-depth", type=int, default=8) parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--ops-head", type=int, default=16) # dataset parser.add_argument("--data-dir", type=str, default="../data/imagenet256") parser.add_argument( "--semantic-features-dir", type=str, default=None, help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。" "默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。", ) parser.add_argument("--resolution", type=int, choices=[256, 512], default=256) parser.add_argument("--batch-size", type=int, default=256)#256 # precision parser.add_argument("--allow-tf32", action="store_true") parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"]) # optimization parser.add_argument("--epochs", type=int, default=14000) parser.add_argument("--max-train-steps", type=int, default=10000000) parser.add_argument("--checkpointing-steps", type=int, default=10000) parser.add_argument("--gradient-accumulation-steps", type=int, default=1) parser.add_argument("--learning-rate", type=float, default=1e-4) parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.") parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.") # seed parser.add_argument("--seed", type=int, default=0) # cpu parser.add_argument("--num-workers", type=int, default=4) # loss parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"]) parser.add_argument("--prediction", type=str, default="v", choices=["v"]) # currently we only support v-prediction parser.add_argument("--cfg-prob", type=float, default=0.1) parser.add_argument("--enc-type", type=str, default='dinov2-vit-b') parser.add_argument("--proj-coeff", type=float, default=0.5) parser.add_argument("--weighting", default="uniform", type=str, help="Max gradient norm.") parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--cls", type=float, default=0.03) parser.add_argument( "--t-c", type=float, default=0.5, help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。" "t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);" "t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。", ) parser.add_argument( "--ot-cls", action=argparse.BooleanOptionalAction, default=True, help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。", ) parser.add_argument( "--tc-velocity-loss-coeff", type=float, default=0.0, help="额外 t=t_c 图像速度场监督项权重(>0 启用,用于增强单步性)。", ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() return args if __name__ == "__main__": args = parse_args() main(args)