# jsflow / New / REG / train.py
# Uploaded by xiangzai via the upload-large-folder tool (commit db5d40d, verified).
import argparse
import copy
from copy import deepcopy
import logging
import os
from pathlib import Path
from collections import OrderedDict
import json
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from models.sit import SiT_models
from loss import SILoss
from utils import load_encoders
from dataset import CustomDataset
from diffusers.models import AutoencoderKL
from PIL import Image
from samplers import euler_maruyama_sampler
# import wandb_utils
import wandb
import math
from torchvision.utils import make_grid
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
logger = get_logger(__name__)
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
def preprocess_raw_image(x, enc_type):
    """Prepare a raw uint8 image batch for a given pretrained encoder.

    Scales pixels to [0, 1], applies the encoder-specific channel
    normalization, and — for CLIP / DINOv2 / JEPA — bicubically resizes to
    the encoder's native 224-based resolution. Unrecognized encoder types
    are returned untouched.
    """
    resolution = x.shape[-1]
    # Side length used by the 224-native encoders (224 for 256px inputs).
    target_size = 224 * (resolution // 256)
    if 'clip' in enc_type:
        # CLIP: resize first, then apply CLIP statistics.
        x = torch.nn.functional.interpolate(x / 255., target_size, mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'dinov2' in enc_type:
        # DINOv2: normalize first, then resize.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
    elif 'dinov1' in enc_type:
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
    elif 'jepa' in enc_type:
        # JEPA: same ImageNet statistics, then resize.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
    return x
def array2grid(x):
    """Tile a batch of [0, 1] images into one uint8 HWC numpy grid image."""
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=per_row, value_range=(0, 1))
    grid = grid.mul(255).add_(0.5).clamp_(0, 255)
    return grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a latent sample from a VAE posterior and affinely rescale it.

    Args:
        moments: tensor packing (mean, std) along channel dim 1, i.e. the
            first half of dim 1 is the mean and the second half the std.
        latents_scale: multiplicative factor applied to the sample.
        latents_bias: additive offset applied after scaling.

    Returns:
        z ~ N(mean, std^2), mapped through z * latents_scale + latents_bias.
    """
    # Fix: removed the unused local `device = moments.device`.
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    z = (z * latents_scale + latents_bias)
    return z
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move each EMA parameter one step toward the live model (in place):
    ema = decay * ema + (1 - decay) * param.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Strip any DDP "module." prefix so names line up with the EMA copy.
        key = name.replace("module.", "")
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[key].mul_(decay).add_(param.data, alpha=1 - decay)
def create_logger(logging_dir):
    """
    Configure root logging to write both to stdout and to
    `logging_dir`/log.txt, then return this module's logger.
    """
    handlers = [logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=handlers
    )
    return logging.getLogger(__name__)
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for every parameter of `model`.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
#################################################################################
# Training Loop #
#################################################################################
def main(args):
    """Run the full REG/SiT training loop.

    Builds a SiT model plus an EMA copy, a frozen SD-VAE, and (optionally)
    online feature encoders; trains with SILoss under HF Accelerate,
    periodically checkpointing and saving Euler-Maruyama sample grids at
    t=0 and at an intermediate time t_c.
    """
    # set accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
    )
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = vars(args)
        # Save to a JSON file so the run configuration is reproducible.
        json_dir = os.path.join(save_dir, "args.json")
        with open(json_dir, 'w') as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    device = accelerator.device
    if torch.backends.mps.is_available():
        # Native AMP is disabled on Apple MPS backends.
        accelerator.native_amp = False
    if args.seed is not None:
        # Different seed per process so data augmentation/noise differs per rank.
        set_seed(args.seed + accelerator.process_index)
    # Create model:
    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8
    if args.enc_type != None:
        encoders, encoder_types, architectures = load_encoders(
            args.enc_type, device, args.resolution
        )
    else:
        raise NotImplementedError()
    # NOTE(review): this compares against the *string* 'None' while the branch
    # above compares against the object None — confirm which sentinel callers use.
    z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0]
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = (args.cfg_prob > 0),
        z_dims = z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs
    )
    model = model.to(device)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    # SD-VAE latent statistics: scale 0.18215 per channel, zero bias.
    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    # create loss function
    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    # Setup data:
    train_dataset = CustomDataset(
        args.data_dir, semantic_features_dir=args.semantic_features_dir
    )
    # Per-process batch size; args.batch_size is the global batch size.
    local_batch_size = int(args.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode
    # resume:
    global_step = 0
    if args.resume_step > 0:
        # Load model/EMA/optimizer state from a zero-padded step checkpoint.
        ckpt_name = str(args.resume_step).zfill(7) +'.pt'
        ckpt = torch.load(
            f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
            map_location='cpu',
        )
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(
            project_name="REG",
            config=tracker_config,
            init_kwargs={
                "wandb": {"name": f"{args.exp_name}"}
            },
        )
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    # Labels to condition the model with (feel free to change):
    sample_batch_size = 64 // accelerator.num_processes
    first_batch = next(iter(train_dataloader))
    # 4-tuple batches carry precomputed semantic features; 3-tuple batches do not.
    preprocessed_semantic = len(first_batch) == 4
    if preprocessed_semantic:
        gt_raw_images, gt_xs, _r_pre, _y = first_batch
    else:
        gt_raw_images, gt_xs, _y = first_batch
    # Only in the non-preprocessed-semantic mode is raw_image an RGB image
    # (its resolution should match args.resolution).
    if not preprocessed_semantic:
        assert gt_raw_images.shape[-1] == args.resolution
    gt_xs = gt_xs[:sample_batch_size]
    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
    )
    ys = torch.randint(1000, size=(sample_batch_size,), device=device)
    ys = ys.to(device)
    # Create sampling noise:
    n = ys.size(0)
    xT = torch.randn((n, 4, latent_size, latent_size), device=device)
    for epoch in range(args.epochs):
        model.train()
        for batch in train_dataloader:
            if len(batch) == 4:
                raw_image, x, r_preprocessed, y = batch
                r_preprocessed = r_preprocessed.to(device).float()
            else:
                raw_image, x, y = batch
                r_preprocessed = None
            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device)
            y = y.to(device)
            z = None
            if args.legacy:
                # In our early experiments, we accidentally apply label dropping twice:
                # once in train.py and once in sit.py.
                # We keep this option for exact reproducibility with previous runs.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            with torch.no_grad():
                # Sample a VAE latent from the stored posterior moments.
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
                zs = []
                if r_preprocessed is not None:
                    # Preprocessed-semantic mode: build dense tokens directly
                    # from the cls token.
                    cls_token = r_preprocessed
                    # Squeeze any singleton dims down to (batch, dim).
                    while cls_token.dim() > 2:
                        cls_token = cls_token.squeeze(1)
                    base_m = model.module if hasattr(model, "module") else model
                    n_pad = base_m.x_embedder.num_patches
                    # Fake a [cls | patches] token sequence by repeating the
                    # cls token across all patch positions.
                    zs = [
                        torch.cat(
                            [
                                cls_token.unsqueeze(1),
                                cls_token.unsqueeze(1).expand(-1, n_pad, -1),
                            ],
                            dim=1,
                        )
                    ]
                else:
                    # Online-encoder mode: matches the original New/REG behavior.
                    with accelerator.autocast():
                        for encoder, encoder_type, arch in zip(
                            encoders, encoder_types, architectures
                        ):
                            raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                            z = encoder.forward_features(raw_image_)
                            if "dinov2" in encoder_type:
                                # Prepend the cls token to the patch tokens.
                                dense_z = z["x_norm_patchtokens"]
                                cls_token = z["x_norm_clstoken"]
                                dense_z = torch.cat(
                                    [cls_token.unsqueeze(1), dense_z], dim=1
                                )
                            else:
                                # Only DINOv2 is supported here; abort otherwise.
                                exit()
                            zs.append(dense_z)
            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels)
                # SILoss returns (diffusion loss, projection loss, sampled t,
                # sampled noise, cls loss).
                loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
                                                                       cls_token=cls_token,
                                                                       time_input=None, noises=None)
                loss_mean = loss1.mean()
                loss_mean_cls = loss2.mean() * args.cls
                proj_loss_mean = proj_loss1.mean() * args.proj_coeff
                loss = loss_mean + proj_loss_mean + loss_mean_cls
                ## optimization
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    # NOTE(review): grad_norm is only defined on sync steps; the
                    # logging block below reads it every step — fine for
                    # gradient_accumulation_steps=1, verify for larger values.
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            if accelerator.sync_gradients:
                update_ema(ema, model)  # change ema function
            ### enter
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
            if global_step % args.checkpointing_steps == 0 and global_step > 0:
                if accelerator.is_main_process:
                    # Save the unwrapped (DDP-free) model plus EMA/optimizer state.
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
            if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                # Periodic visual sanity check: sample with the EMA model and
                # save grids at t=0 and at the intermediate time t_c.
                t_mid_vis = float(args.t_c)
                # Filesystem-safe tag for t_c, e.g. 0.5 -> "0_5".
                tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
                logging.info(
                    f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} -> t=0)..."
                )
                ema.eval()
                with torch.no_grad():
                    latent_size = args.resolution // 8
                    n_samples = min(16, args.batch_size)
                    base_model = model.module if hasattr(model, "module") else model
                    cls_dim = base_model.z_dims[0]
                    # Share one seed across the latent and cls draws so the two
                    # random inits are reproducibly paired.
                    shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
                    torch.manual_seed(shared_seed)
                    z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
                    torch.manual_seed(shared_seed)
                    cls_init = torch.randn(n_samples, cls_dim, device=device)
                    y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
                    z_0, z_mid, _ = euler_maruyama_sampler(
                        ema,
                        z_init,
                        y_samples,
                        num_steps=50,
                        cfg_scale=1.0,
                        guidance_low=0.0,
                        guidance_high=1.0,
                        path_type=args.path_type,
                        cls_latents=cls_init,
                        args=args,
                        return_mid_state=True,
                        t_mid=t_mid_vis,
                    )
                    samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
                    t0_dir = os.path.join(samples_root, "t0")
                    t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
                    os.makedirs(t0_dir, exist_ok=True)
                    os.makedirs(t_mid_dir, exist_ok=True)
                    # Decode final latents back to pixel space via the VAE.
                    z_f = z_0.to(dtype=torch.float32)
                    samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
                    samples_final = (samples_final + 1) / 2.0
                    samples_final = samples_final.clamp(0, 1)
                    grid_final = array2grid(samples_final)
                    Image.fromarray(grid_final).save(
                        os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
                    )
                    if z_mid is not None:
                        # Also decode and save the intermediate-time state.
                        z_m = z_mid.to(dtype=torch.float32)
                        samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
                        samples_mid = (samples_mid + 1) / 2.0
                        samples_mid = samples_mid.clamp(0, 1)
                        grid_mid = array2grid(samples_mid)
                        Image.fromarray(grid_mid).save(
                            os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
                        )
                    else:
                        logging.warning(
                            f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
                            f"skip t0_{tc_tag} image this step."
                        )
                    # Free sampling tensors before resuming training.
                    del z_init, cls_init, y_samples, z_0
                    if z_mid is not None:
                        del z_mid
                    del samples_final, grid_final
                    if "samples_mid" in locals():
                        del samples_mid, grid_mid
                    torch.cuda.empty_cache()
            logs = {
                "loss_final": accelerator.gather(loss).mean().detach().item(),
                "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
            }
            log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
            logging.info(f"Step: {global_step}, Training Logs: {log_message}")
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break
    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
def parse_args(input_args=None):
    """Build and parse the training CLI.

    Args:
        input_args: optional list of argument strings; when None, sys.argv
            is parsed (standard argparse behavior).

    Returns:
        The parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="Training")
    # logging:
    parser.add_argument("--output-dir", type=str, default="exps")
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--logging-dir", type=str, default="logs")
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--sampling-steps", type=int, default=10000)
    parser.add_argument("--resume-step", type=int, default=0)
    # model
    parser.add_argument("--model", type=str)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ops-head", type=int, default=16)
    # dataset
    parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
    parser.add_argument(
        "--semantic-features-dir",
        type=str,
        default=None,
        help="预处理 semantic features 目录(与 REG/dataset.py 语义相同),仅影响数据加载方式,不引入 t_c。",
    )
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--batch-size", type=int, default=8)  # 256
    # precision
    parser.add_argument("--allow-tf32", action="store_true")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
    # optimization
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--max-train-steps", type=int, default=1000000)
    parser.add_argument("--checkpointing-steps", type=int, default=10000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
    # seed
    parser.add_argument("--seed", type=int, default=0)
    # cpu
    parser.add_argument("--num-workers", type=int, default=4)
    # loss
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
    parser.add_argument("--cfg-prob", type=float, default=0.1)
    parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
    parser.add_argument("--proj-coeff", type=float, default=0.5)
    # Fix: help text was a copy-paste of --max-grad-norm's ("Max gradient norm.").
    parser.add_argument("--weighting", default="uniform", type=str, help="Loss weighting scheme for SILoss.")
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--cls", type=float, default=0.03)
    parser.add_argument(
        "--t-c",
        type=float,
        default=0.5,
        help="训练中采样时保存的中间时刻 t(用于输出 t0 与 t0_tc 对比图)。",
    )
    # parse_args(None) already falls back to sys.argv, so no branching needed.
    return parser.parse_args(input_args)
if __name__ == "__main__":
args = parse_args()
main(args)