Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| from safetensors.torch import load_file, save_file | |
| import argparse | |
| import itertools | |
| import math | |
| import random | |
| import subprocess | |
| import zipfile | |
| import io | |
| from pathlib import Path | |
| from typing import Iterable, Iterator | |
| try: | |
| import wandb | |
| except ImportError: # inference-only envs | |
| wandb = None | |
| import numpy as np | |
| import numpy.lib.format as np_format | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.utils | |
| import yaml | |
| from einops import rearrange | |
| from accelerate import Accelerator, DataLoaderConfiguration, DistributedDataParallelKwargs | |
| from accelerate.utils import set_seed, ProjectConfiguration | |
| from dataset import ImageFolderDataset | |
| from model_gan import GANLoss | |
| from model_lpips import LPIPS, BothPerceptualLoss, build_perceptual_loss | |
| from models import ARModel, AROutput, Encoder, Decoder, Linear, EncoderOutput, VQLossDetail, insert_eos_token | |
| from tqdm import tqdm | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader | |
| import copy | |
| import torch.distributed as dist | |
| from glob import glob | |
| from utils import ( | |
| seed_everything, | |
| build_ar_logit_mask, | |
| make_worker_init_fn, | |
| load_accelerate_weights_only, | |
| draw_conditional_entropy, | |
| draw_data_conditional_entropy, | |
| plot_codebook_usage, | |
| plot_posterior_entropy, | |
| compute_posterior_entropy_from_logits, | |
| compute_aggregated_entropy_from_counts, | |
| InfiniteIterator, | |
| get_linear_schedule_with_warmup_peak, | |
| safe_remove_file, | |
| save_training_state, | |
| remove_old_best_checkpoints, | |
| adm_fid_evaluator, | |
| ema_update, | |
| img_denormalize, | |
| img_norm_to_uint8, | |
| img_uint8_to_norm, | |
| patchify, | |
| unpatchify, | |
| toggle_require_grad, | |
| toggle_train_eval, | |
| zero_nan_gradients, | |
| calc_grad_norm, | |
| print0, | |
| _unwrap, | |
| Target, | |
| Phase, | |
| parse_phases, | |
| parse_training_config_from_phases, | |
| get_phase, | |
| ) | |
| from util_model_profile import print_model_stats | |
| import gc | |
| from torchmetrics.functional.image import ( | |
| peak_signal_noise_ratio as calc_psnr, | |
| structural_similarity_index_measure as calc_ssim) | |
| torch.backends.cudnn.benchmark = True | |
| torch.backends.cudnn.deterministic = False | |
| torch._dynamo.config.recompile_limit = 64 | |
| def sampling(enc, dec, ar_model, bz, class_label=None, temperature=1.0, topK=None, topP=None, cfg=1.0, cfg_schedule=None, cfg_power=None, cache_kv=False, ae_label=None, | |
| semantic_cfg_schedule=None, semantic_cfg_scale=None, semantic_cfg_power=None, semantic_cfg_start=0.0, | |
| visual_cfg_schedule=None, visual_cfg_scale=None, visual_cfg_power=None, visual_cfg_start=1.0): | |
| enc_raw = _unwrap(enc) | |
| ar_raw = _unwrap(ar_model) | |
| token_ids = ar_raw.sampling(bz, class_label, temperature, topK, topP, cfg, cfg_schedule, cfg_power, cache_kv, | |
| semantic_cfg_schedule=semantic_cfg_schedule, | |
| semantic_cfg_scale=semantic_cfg_scale, | |
| semantic_cfg_power=semantic_cfg_power, | |
| semantic_cfg_start=semantic_cfg_start, | |
| visual_cfg_schedule=visual_cfg_schedule, | |
| visual_cfg_scale=visual_cfg_scale, | |
| visual_cfg_power=visual_cfg_power, | |
| visual_cfg_start=visual_cfg_start) | |
| quant = enc_raw.get_visual_codes(token_ids, ae_label) | |
| x_hat = dec(quant, ae_label) | |
| return x_hat | |
| def reconstruction(enc, dec, x, labels): | |
| out = enc(x, labels, training=False) | |
| x_hat = dec(out.visual_quant, labels) | |
| return x_hat, out.indices | |
| def calc_per_sample_reconstruction_metrics(img, img_hat, lpips): | |
| img = img.to(dtype=torch.float32) | |
| img_hat = img_hat.to(dtype=torch.float32) | |
| img01 = ((img + 1.0) * 0.5).clamp(0.0, 1.0) | |
| img_hat01 = ((img_hat + 1.0) * 0.5).clamp(0.0, 1.0) | |
| psnr_vals = calc_psnr(img_hat01, img01, data_range=1.0).to(dtype=torch.float32) | |
| ssim_vals = calc_ssim(img_hat01, img01, data_range=1.0,).to(dtype=torch.float32) | |
| out = lpips(img, img_hat) | |
| lpips_vals = out[0].mean() if isinstance(out, tuple) else out.mean() | |
| return lpips_vals, ssim_vals, psnr_vals | |
| def calculate_adaptive_weight_acc(nll_loss, g_loss, last_layer, accelerator, dec): | |
| # Avoid gradient all-reduce on decoder params during these probe grads. | |
| with accelerator.no_sync(dec): | |
| nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] | |
| g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] | |
| d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) | |
| d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() | |
| return d_weight | |
| def make_cache_kwargs(cached_enc_out, cached_imgs_hat): | |
| detach_cpu = lambda t: t.detach().cpu() if isinstance(t, torch.Tensor) else None | |
| if cached_enc_out is not None: | |
| return dict( | |
| cached_quant=detach_cpu(cached_enc_out.quant), | |
| cached_idx=detach_cpu(cached_enc_out.indices), | |
| cached_visual_quant=detach_cpu(cached_enc_out.visual_quant), | |
| cached_semantic_quant=detach_cpu(cached_enc_out.semantic_quant), | |
| cached_visual_indices=detach_cpu(cached_enc_out.visual_indices), | |
| cached_semantic_indices=detach_cpu(cached_enc_out.semantic_indices), | |
| cached_one_hot=detach_cpu(cached_enc_out.one_hot), | |
| cached_imgs_hat=detach_cpu(cached_imgs_hat), | |
| ) | |
| return dict(cached_quant=None, cached_idx=None, cached_imgs_hat=None) | |
| def train(config): | |
| if config.get("use_tf32", True): | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| if "EXPERIMENT_SAVE_DIR" in os.environ: | |
| save_dir = os.environ["EXPERIMENT_SAVE_DIR"] | |
| else: | |
| experiment_index = len(glob(f"{config.save_dir}/*")) | |
| save_dir = config.save_dir+f"/{experiment_index:03d}" | |
| os.environ["EXPERIMENT_SAVE_DIR"] = save_dir | |
| ddp_kwargs = DistributedDataParallelKwargs( | |
| find_unused_parameters=bool(config.get("find_unused_parameters", False)) | |
| ) | |
| accelerator = Accelerator( | |
| mixed_precision=config.precision if config.precision in ["fp16", "bf16"] else "no", | |
| log_with="wandb", | |
| project_config=ProjectConfiguration(project_dir=save_dir, logging_dir=save_dir), | |
| dataloader_config=DataLoaderConfiguration(even_batches=False), | |
| kwargs_handlers=[ddp_kwargs], | |
| ) | |
| phases = parse_phases(config.phases) | |
| train_ae, train_ar, use_lpips_loss, use_gan_loss, train_prior_enc = parse_training_config_from_phases(phases) | |
| total_phase_steps = sum(phase.num_steps for phase in phases) | |
| total_phase = len(phases) | |
| phase_step_accum = list(itertools.accumulate([phase.num_steps for phase in phases])) | |
| print0("Training Phases: ", phases) | |
| if config.torch_compile and use_gan_loss and config.disc_adaptive_weight: | |
| try: | |
| import torch._functorch.config as _functorch_config | |
| _functorch_config.donated_buffer = False | |
| if accelerator.is_main_process: | |
| print0("[warn] disabled torch._functorch.config.donated_buffer due to disc_adaptive_weight + torch_compile") | |
| except Exception: | |
| pass | |
| # prepare seed | |
| if config.seed is not None: | |
| seed = int(config.seed) if config.seed is not None else 0 | |
| set_seed(seed) | |
| seed_everything(seed) | |
| dl_generator = torch.Generator() | |
| dl_generator.manual_seed(seed) | |
| worker_init = make_worker_init_fn(seed) | |
| print0(config) | |
| print0(f"Global seed set to {config.seed}") | |
| # Models | |
| enc = Encoder(config) | |
| dec = Decoder(config) | |
| config["train_prior_enc"] = train_prior_enc | |
| ar_model = ARModel(config) if train_ar else None | |
| # Cache encoder properties before DDP / torch.compile wrapping | |
| _has_separate_semantic = enc.has_separate_semantic | |
| _visual_modules = enc.visual_modules | |
| _semantic_modules = enc.semantic_modules | |
| _prologue = config.get("Prologue", False) and not config.get("share_semantic_codebook", False) | |
| _ste_ar_embedding = config.get("ARModel", {}).get("ste_ar_embedding", False) | |
| _semantic_offset = int(config["Quantizer"]["codebook_size"]) if _prologue else 0 | |
| _use_eos = bool(config.get("use_eos", False)) and _prologue and int(config.get("z_len", 0)) > 0 | |
| _eos_offset = 1 if _use_eos else 0 | |
| _ae_no_label = bool(getattr(config, 'ae_no_label', False)) | |
| _prior_visual_dropout = float(config.get("prior_visual_dropout", 0.0)) | |
| if _prior_visual_dropout > 0 and not _use_eos: | |
| print0("WARNING: prior_visual_dropout requires use_eos=True and Prologue; forcing to 0.") | |
| _prior_visual_dropout = 0.0 | |
| _label_drop_always = float(getattr(config, "label_drop_prob", 0.0)) >= 1.0 # uncond viz/eval when always-drop | |
| if not train_ae: | |
| toggle_require_grad(enc, False, sub_modules=_visual_modules if _has_separate_semantic else None) | |
| toggle_train_eval(enc, train=False, sub_modules=_visual_modules if _has_separate_semantic else None) | |
| toggle_require_grad(dec, False) | |
| toggle_train_eval(dec, train=False) | |
| gan_loss = GANLoss(**config.GANLoss) if use_gan_loss else None | |
| perceptual_loss = build_perceptual_loss(config.get("perceptual_network", "vgg")).to(accelerator.device).eval().requires_grad_(False) if use_lpips_loss else None | |
| # AR logit mask: visual/semantic segments + optional EOS row. | |
| if ar_model is not None: | |
| _logit_mask = build_ar_logit_mask( | |
| getattr(enc.quantizer, 'pos_select_mask', None), | |
| getattr(getattr(enc, 'semantic_quantizer', None), 'pos_select_mask', None), | |
| vis_cb_size=int(config["Quantizer"]["codebook_size"]), | |
| sem_cb_size=int(config["SemanticQuantizer"]["codebook_size"]) if _prologue else 0, | |
| ) | |
| ar_model.set_logit_mask(_logit_mask) | |
| print0(f"logit_mask: {'set' if ar_model.logit_mask is not None else 'None'}") | |
| # EMA models | |
| ae_ema_rate = math.pow(0.5,1/config.ae_ema_halflife) if config.ae_ema_halflife>0 else 0. | |
| ar_ema_rate = math.pow(0.5,1/config.ar_ema_halflife) if config.ar_ema_halflife>0 else 0. | |
| enc_ema = copy.deepcopy(enc).requires_grad_(False).eval() if ae_ema_rate>0 else None | |
| dec_ema = copy.deepcopy(dec).requires_grad_(False).eval() if ae_ema_rate>0 else None | |
| ar_model_ema = copy.deepcopy(ar_model).requires_grad_(False).eval() if train_ar and ar_ema_rate>0 else None | |
| fixed_ar_model = copy.deepcopy(ar_model).requires_grad_(False).eval() if train_prior_enc else None | |
| if config.resume_ckpt_path != "" and not config.resume_train: | |
| ckpt_path = config.resume_ckpt_path | |
| if config.resume_enc: | |
| sd = load_file(os.path.join(ckpt_path, "model.safetensors")) | |
| load_visual_only = _has_separate_semantic and not any(k.startswith("semantic_enc.") for k in sd) | |
| if load_visual_only: | |
| use_ema_as_visual = not train_ae | |
| visual_sd = load_file(os.path.join(ckpt_path, "model_2.safetensors")) if use_ema_as_visual else sd | |
| enc.enc.load_state_dict({k.removeprefix("enc."): v for k, v in visual_sd.items() if k.startswith("enc.")}, strict=True) | |
| enc.quantizer.load_state_dict({k.removeprefix("quantizer."): v for k, v in visual_sd.items() if k.startswith("quantizer.")}, strict=True) | |
| else: | |
| enc.load_state_dict(sd, strict=True) | |
| if enc_ema is not None: | |
| ema_sd = load_file(os.path.join(ckpt_path, "model_2.safetensors")) if not load_visual_only or not use_ema_as_visual else visual_sd | |
| if load_visual_only: | |
| enc_ema.enc.load_state_dict({k.removeprefix("enc."): v for k, v in ema_sd.items() if k.startswith("enc.")}, strict=True) | |
| enc_ema.quantizer.load_state_dict({k.removeprefix("quantizer."): v for k, v in ema_sd.items() if k.startswith("quantizer.")}, strict=True) | |
| else: | |
| enc_ema.load_state_dict(ema_sd, strict=True) | |
| print0(f"Loaded encoder from {ckpt_path}" + (" (visual only, ema)" if load_visual_only and use_ema_as_visual else " (visual only)" if load_visual_only else "")) | |
| if config.resume_dec: | |
| use_ema_as_dec = not train_ae | |
| dec_file = "model_3.safetensors" if use_ema_as_dec else "model_1.safetensors" | |
| dec.load_state_dict(load_file(os.path.join(ckpt_path, dec_file)), strict=True) | |
| if dec_ema is not None: | |
| dec_ema.load_state_dict(load_file(os.path.join(ckpt_path, "model_3.safetensors")), strict=True) | |
| print0(f"Loaded decoder from {ckpt_path}" + (f" (ema as init)" if use_ema_as_dec else "")) | |
| if config.resume_ar: | |
| ar_model.load_state_dict(load_file(os.path.join(ckpt_path, "model_5.safetensors")), strict=True) | |
| if ar_model_ema is not None: | |
| ar_model_ema.load_state_dict(load_file(os.path.join(ckpt_path, "model_6.safetensors")), strict=True) | |
| if train_prior_enc: | |
| fixed_ar_model.load_state_dict(load_file(os.path.join(ckpt_path, "model_7.safetensors")), strict=True) | |
| print0(f"Loaded AR model from {ckpt_path}" + (f" (+ema)" if ar_model_ema is not None else "") + (f" (+fixed)" if train_prior_enc else "")) | |
| if config.resume_gan: | |
| gan_loss.load_state_dict(load_file(os.path.join(ckpt_path, "model_4.safetensors")), strict=True) | |
| print0(f"Loaded GAN loss from {ckpt_path}") | |
| opt_enc = torch.optim.AdamW( enc.semantic_parameters() if (_has_separate_semantic and not train_ae) else enc.parameters(), lr=config.lr_enc, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay_enc) if (train_ae or (_has_separate_semantic and train_ar)) else None | |
| if train_ar: | |
| no_decay_keywords = ['bias', 'norm', 'adaln'] | |
| decay_params, nodecay_params = [], [] | |
| for n, p in ar_model.named_parameters(): | |
| if not p.requires_grad: | |
| continue | |
| if p.dim() < 2 or any(kw in n for kw in no_decay_keywords): | |
| nodecay_params.append(p) | |
| else: | |
| decay_params.append(p) | |
| print0(f"AR weight decay groups: {len(decay_params)} decay ({sum(p.numel() for p in decay_params):,} params), " | |
| f"{len(nodecay_params)} no-decay ({sum(p.numel() for p in nodecay_params):,} params)") | |
| opt_ar = torch.optim.AdamW( | |
| [ | |
| {"params": decay_params, "weight_decay": config.weight_decay_ar}, | |
| {"params": nodecay_params, "weight_decay": 0.0}, | |
| ], | |
| lr=config.lr_ar, | |
| betas=(config.ar_beta1, config.ar_beta2), | |
| ) | |
| else: | |
| opt_ar = None | |
| opt_dec = torch.optim.AdamW(dec.parameters(), lr=config.lr_dec, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay_dec) if train_ae else None | |
| opt_gan_loss = torch.optim.AdamW(gan_loss.parameters(), lr=config.lr_gan_loss, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay_gan) if use_gan_loss else None | |
| _gc_fallback = float(getattr(config, 'grad_clip', 0.0)) | |
| grad_clip_enc = float(getattr(config, 'grad_clip_enc', _gc_fallback)) | |
| grad_clip_dec = float(getattr(config, 'grad_clip_dec', _gc_fallback)) | |
| grad_clip_ar = float(getattr(config, 'grad_clip_ar', _gc_fallback)) | |
| grad_clip_gan = float(getattr(config, 'grad_clip_gan', _gc_fallback)) | |
| print0(f"Grad clip: enc={grad_clip_enc}, dec={grad_clip_dec}, ar={grad_clip_ar}, gan={grad_clip_gan}") | |
| dataset = ImageFolderDataset(path=config['data_dir'], | |
| resolution=config.image_size, | |
| use_label=config.use_label, | |
| max_size=None, | |
| xflip=config.xflip, | |
| crop_type=config.crop_type, | |
| deterministic_crop=bool(getattr(config, "deterministic_crop", False)), | |
| crop_seed=int(config.seed) if config.seed is not None else 0, | |
| ) | |
| train_loader = DataLoader( | |
| dataset, | |
| batch_size=config.batch_size // accelerator.num_processes, | |
| shuffle=True, | |
| num_workers=config.num_workers, | |
| pin_memory=True, | |
| drop_last=True, | |
| persistent_workers=False, | |
| worker_init_fn=worker_init, | |
| generator=dl_generator, | |
| ) | |
| eval_loader=None | |
| if config.eval_data_dir is not None: | |
| eval_dataset = ImageFolderDataset(path=config['eval_data_dir'], | |
| resolution=config.image_size, | |
| use_label=config.use_label, | |
| max_size=None, | |
| xflip=False, | |
| crop_type=config.eval_crop_type, | |
| deterministic_crop=bool(getattr(config, "deterministic_crop", False)), | |
| crop_seed=int(config.seed) if config.seed is not None else 0, | |
| ) | |
| eval_loader = DataLoader( | |
| eval_dataset, | |
| batch_size=config.eval_batch_size // accelerator.num_processes, | |
| shuffle=False, | |
| pin_memory=True, | |
| drop_last=False, | |
| worker_init_fn=worker_init, | |
| generator=dl_generator, | |
| ) | |
| eval_dataset_size = len(eval_dataset) | |
| _is_resume = (config.resume_ckpt_path != "" and config.resume_train) | |
| if accelerator.is_main_process and not _is_resume: | |
| print0("Calculating model stats...") | |
| print_model_stats(config, accelerator.device, enc, dec, ar_model) | |
| # Prepare with Accelerator | |
| all_training_components = [enc, dec, enc_ema, dec_ema, gan_loss, ar_model, ar_model_ema, fixed_ar_model, | |
| opt_enc, opt_dec, opt_ar, opt_gan_loss, | |
| train_loader, eval_loader] | |
| prepared_components = iter(accelerator.prepare(*filter(lambda x: x is not None, all_training_components))) | |
| enc = next(prepared_components) if enc is not None else None | |
| dec = next(prepared_components) if dec is not None else None | |
| enc_ema = next(prepared_components) if enc_ema is not None else None | |
| dec_ema = next(prepared_components) if dec_ema is not None else None | |
| gan_loss = next(prepared_components) if gan_loss is not None else None | |
| ar_model = next(prepared_components) if ar_model is not None else None | |
| ar_model_ema = next(prepared_components) if ar_model_ema is not None else None | |
| fixed_ar_model = next(prepared_components) if fixed_ar_model is not None else None | |
| opt_enc = next(prepared_components) if opt_enc is not None else None | |
| opt_dec = next(prepared_components) if opt_dec is not None else None | |
| opt_ar = next(prepared_components) if opt_ar is not None else None | |
| opt_gan_loss = next(prepared_components) if opt_gan_loss is not None else None | |
| train_loader = next(prepared_components) if train_loader is not None else None | |
| eval_loader = next(prepared_components) if eval_loader is not None else None | |
| global_step = 0 | |
| pbar = tqdm(total=total_phase_steps, disable=not accelerator.is_main_process, dynamic_ncols=True, file=sys.stdout) | |
| pbar.set_description(f"Total Steps {total_phase_steps}") | |
| # LR schedulers (must be created and registered BEFORE load_state for resume compatibility) | |
| if config.lr_scheduler == 'linear': | |
| scheduler_enc = get_linear_schedule_with_warmup_peak(opt_enc, num_warmup_steps=config.warmup_steps, num_peak_steps=config.peak_steps, num_training_steps=total_phase_steps/2,base_lr=config.lr_enc,end_lr=config.lr_enc_min) if opt_enc is not None else None | |
| scheduler_dec = get_linear_schedule_with_warmup_peak(opt_dec, num_warmup_steps=config.warmup_steps, num_peak_steps=config.peak_steps, num_training_steps=total_phase_steps/2,base_lr=config.lr_dec,end_lr=config.lr_dec_min) if opt_dec is not None else None | |
| scheduler_ar = get_linear_schedule_with_warmup_peak(opt_ar, num_warmup_steps=config.warmup_steps, num_peak_steps=config.peak_steps, num_training_steps=total_phase_steps/2,base_lr=config.lr_ar,end_lr=config.lr_ar_min) if train_ar else None | |
| scheduler_gan_loss = get_linear_schedule_with_warmup_peak(opt_gan_loss, num_warmup_steps=config.warmup_steps, num_peak_steps=config.peak_steps, num_training_steps=total_phase_steps/2,base_lr=config.lr_gan_loss,end_lr=config.lr_gan_loss_min) if use_gan_loss else None | |
| else: | |
| scheduler_enc = None | |
| scheduler_dec = None | |
| scheduler_ar = None | |
| scheduler_gan_loss = None | |
| for sched in (scheduler_enc, scheduler_dec, scheduler_ar, scheduler_gan_loss): | |
| if sched is not None: | |
| accelerator.register_for_checkpointing(sched) | |
| extra_training_states = { | |
| "global_step": 0, "total_yielded": 0, | |
| "best_rfid": float("inf"), "best_gfid": float("inf"), | |
| "prev_phase_idx": -1, "prev_inner_idx": -1, | |
| "data_buffer": [], "cached_quant": None, "cached_idx": None, "cached_imgs_hat": None, | |
| "cached_visual_quant": None, "cached_semantic_quant": None, | |
| "cached_visual_indices": None, "cached_semantic_indices": None, | |
| "cached_one_hot": None, | |
| } | |
| if config.resume_ckpt_path != "" and config.resume_train: | |
| print0(f"Resuming training from {config.resume_ckpt_path}") | |
| accelerator.load_state(config.resume_ckpt_path) | |
| _extra_path = os.path.join(config.resume_ckpt_path, "extra_state.pt") | |
| if os.path.exists(_extra_path): | |
| _saved = torch.load(_extra_path, weights_only=False) | |
| if "dl_generator" in _saved: | |
| dl_generator.set_state(_saved["dl_generator"]) | |
| extra_training_states.update({k: _saved[k] for k in extra_training_states if k in _saved}) | |
| print0(f"Restored extra state: step={extra_training_states['global_step']}, " | |
| f"yielded={extra_training_states['total_yielded']}, " | |
| f"prev_phase={extra_training_states['prev_phase_idx']}, " | |
| f"prev_inner={extra_training_states['prev_inner_idx']}, " | |
| f"buf_len={len(extra_training_states['data_buffer'])}, " | |
| f"has_cache={extra_training_states.get('cached_quant') is not None}") | |
| global_step = extra_training_states["global_step"] | |
| if global_step == 0: | |
| global_step = int(config.resume_ckpt_path.split("Step=")[-1].split('-')[0]) | |
| pbar.update(global_step) | |
| accelerator.wait_for_everyone() | |
| if accelerator.is_main_process: | |
| root_dir = Path(save_dir) | |
| root_dir.mkdir(exist_ok=True, parents=True) | |
| img_dir = Path(save_dir+ "/images") | |
| img_dir.mkdir(exist_ok=True, parents=True) | |
| ckpt_dir = Path(save_dir+ "/ckpts") | |
| ckpt_dir.mkdir(exist_ok=True, parents=True) | |
| tmp_dir = Path(config.tmp_dir) | |
| tmp_dir.mkdir(exist_ok=True, parents=True) | |
| accelerator.wait_for_everyone() | |
| # Compilation | |
| if config.torch_compile: | |
| if enc is not None: enc = torch.compile(enc) | |
| if enc_ema is not None: enc_ema = torch.compile(enc_ema) | |
| if dec is not None: dec = torch.compile(dec) | |
| if dec_ema is not None: dec_ema = torch.compile(dec_ema) | |
| if ar_model is not None: ar_model = torch.compile(ar_model) | |
| if ar_model_ema is not None: ar_model_ema = torch.compile(ar_model_ema) | |
| # Fast-forward DataLoader to resume position | |
| _resume_total_yielded = extra_training_states["total_yielded"] | |
| if _resume_total_yielded > 0: | |
| _bpe = len(train_loader) | |
| if _bpe > 0: | |
| _epoch = _resume_total_yielded // _bpe | |
| _skip = _resume_total_yielded % _bpe | |
| if hasattr(train_loader, 'iteration'): | |
| train_loader.iteration = _epoch | |
| print0(f"DataLoader fast-forward: epoch={_epoch}, skip={_skip}/{_bpe}") | |
| # Infinite iterator wrapper | |
| train_loader = InfiniteIterator(train_loader, dl_generator=dl_generator) | |
| if _resume_total_yielded > 0 and _bpe > 0 and _skip > 0: | |
| for _ in range(_skip): | |
| next(train_loader) | |
| train_loader.total_yielded = _resume_total_yielded | |
| wandb_run_dir = os.path.join(str(config.wandb_dir), str(config.wandb_name)) | |
| os.makedirs(wandb_run_dir, exist_ok=True) | |
| accelerator.init_trackers( | |
| project_name=config.wandb_project, | |
| config=OmegaConf.to_container(config, resolve=True), | |
| init_kwargs={"wandb": {"name": config.wandb_name, "dir": wandb_run_dir}} | |
| ) | |
| best_rfid = extra_training_states["best_rfid"] | |
| best_gfid = extra_training_states["best_gfid"] | |
| rFID = 0.0 | |
| gFID = 0.0 | |
| prev_phase_idx = extra_training_states["prev_phase_idx"] | |
| prev_inner_idx = extra_training_states["prev_inner_idx"] | |
| buf = extra_training_states["data_buffer"] | |
| data_buffer = [tuple(t.to(accelerator.device) for t in item) for item in buf] if buf else [] | |
| dev = accelerator.device | |
| cq, ci, cvq, csq, cvi, csi, coh = [ | |
| t.to(dev) if isinstance(t, torch.Tensor) else None | |
| for t in (extra_training_states.get(k) for k in | |
| ("cached_quant", "cached_idx", "cached_visual_quant", | |
| "cached_semantic_quant", "cached_visual_indices", "cached_semantic_indices", "cached_one_hot")) | |
| ] | |
| if cq is not None: | |
| cached_enc_out = EncoderOutput( | |
| quant=cq, indices=ci, one_hot=coh if coh is not None else torch.zeros_like(ci), | |
| semantic_vq_loss=VQLossDetail.zero(dev), | |
| visual_vq_loss=VQLossDetail.zero(dev), | |
| semantic_quant=csq if csq is not None else cq, | |
| visual_quant=cvq if cvq is not None else cq, | |
| semantic_indices=csi if csi is not None else ci, | |
| visual_indices=cvi if cvi is not None else ci, | |
| ) | |
| else: | |
| cached_enc_out = None | |
| cached_imgs_hat = extra_training_states.get("cached_imgs_hat") | |
| cached_imgs_hat = cached_imgs_hat.to(dev) if isinstance(cached_imgs_hat, torch.Tensor) else None | |
| uncond_labels = F.one_hot(torch.full((1,), config.num_classes - 1, device=accelerator.device, dtype=torch.long), num_classes=config.num_classes).float() | |
| exp_total_phase_steps = int(getattr(config, 'exp_total_phase_steps', 0)) or total_phase_steps | |
| while global_step < exp_total_phase_steps: | |
| phase_idx, inner_idx, target, internel_step = get_phase(global_step, phases, phase_step_accum, config.gan_start) | |
| enc_grad = target.DO_AE or target.DO_PRIOR_ENC | |
| dec_grad = target.DO_AE | |
| prior_grad = target.DO_PRIOR_AR | |
| disc_grad = target.DO_GAN_D | |
| semantic_grad = target.DO_PRIOR_ENC if _has_separate_semantic else False | |
| if phase_idx != prev_phase_idx or inner_idx != prev_inner_idx: | |
| if _has_separate_semantic: | |
| toggle_require_grad(enc, target.DO_AE, accelerator=accelerator, sub_modules=_visual_modules) | |
| toggle_require_grad(enc, semantic_grad, accelerator=accelerator, sub_modules=_semantic_modules) | |
| else: | |
| toggle_require_grad(enc, enc_grad, accelerator=accelerator) | |
| toggle_require_grad(dec, dec_grad, accelerator=accelerator) | |
| toggle_require_grad(ar_model, prior_grad, accelerator=accelerator) | |
| toggle_require_grad(gan_loss, disc_grad, accelerator=accelerator) | |
| toggle_train_eval(ar_model, train=prior_grad, accelerator=accelerator) | |
| opts = [] | |
| if enc_grad and opt_enc: | |
| opts.append((enc, opt_enc, scheduler_enc, grad_clip_enc)) | |
| if dec_grad: | |
| opts.append((dec, opt_dec, scheduler_dec, grad_clip_dec)) | |
| if prior_grad: | |
| opts.append((ar_model, opt_ar, scheduler_ar, grad_clip_ar)) | |
| if disc_grad: | |
| opts.append((gan_loss, opt_gan_loss, scheduler_gan_loss, grad_clip_gan)) | |
| if internel_step == 0: | |
| data_buffer = [] | |
| if inner_idx == 0 or len(data_buffer) == 0: | |
| batch = next(train_loader) | |
| imgs, labels = batch if isinstance(batch, (list, tuple)) and len(batch) == 2 else (batch, None) | |
| imgs = img_uint8_to_norm(imgs) | |
| x = patchify(imgs, config.patch_size) | |
| uncond_batch = uncond_labels.expand(imgs.shape[0], -1) | |
| if labels is None or len(labels)==0 or not config.use_label: | |
| labels = raw_labels = uncond_batch | |
| else: | |
| labels = raw_labels = torch.cat([labels, torch.full((labels.shape[0],1), 0, device=accelerator.device, dtype=torch.long)], dim=-1) | |
| if config.label_drop_prob > 0: | |
| drop_mask = torch.rand((imgs.shape[0],1), device=accelerator.device) < config.label_drop_prob | |
| labels = torch.where(drop_mask, uncond_batch, labels) | |
| data_buffer.append((x,imgs,raw_labels,labels)) | |
| else: | |
| x, imgs, raw_labels, labels = data_buffer.pop(-1) | |
| uncond_batch = uncond_labels.expand(imgs.shape[0], -1) | |
| ae_labels = uncond_batch if _ae_no_label else raw_labels | |
| if enc_grad: | |
| enc_out = enc(x, ae_labels, training=True) | |
| cached_enc_out = EncoderOutput( | |
| quant=enc_out.quant.detach(), indices=enc_out.indices.detach(), | |
| one_hot=enc_out.one_hot.detach(), | |
| semantic_vq_loss=enc_out.semantic_vq_loss, visual_vq_loss=enc_out.visual_vq_loss, | |
| semantic_quant=enc_out.semantic_quant.detach() if enc_out.semantic_quant is not None else None, | |
| visual_quant=enc_out.visual_quant.detach(), | |
| semantic_indices=enc_out.semantic_indices.detach() if enc_out.semantic_indices is not None else None, | |
| visual_indices=enc_out.visual_indices.detach(), | |
| ) | |
| else: | |
| enc_out = cached_enc_out | |
| if target.DO_AE: | |
| x_hat = dec(enc_out.visual_quant, ae_labels) | |
| imgs_hat = unpatchify(x_hat, config.image_size, config.patch_size) | |
| cached_imgs_hat = imgs_hat.detach() | |
| if target.DO_PRIOR_AR or target.DO_PRIOR_ENC: | |
| # STE needs encoder one_hot; on cache miss / pretoken mode use idx embedding. | |
| if _ste_ar_embedding and enc_grad: | |
| semantic_one_hot = enc_out.semantic_one_hot if _prologue else enc_out.one_hot | |
| else: | |
| semantic_one_hot = None | |
| ar_indices = enc_out.indices | |
| if _semantic_offset > 0 and _prologue: | |
| ar_indices = ar_indices.clone() | |
| ar_indices[:, :config.z_len] += _semantic_offset | |
| if _use_eos: | |
| ar_indices = insert_eos_token(ar_indices, int(config.z_len), _unwrap(ar_model).eos_token_id) | |
| ar_targets = ar_indices | |
| if _prior_visual_dropout > 0 and _prologue: | |
| vis_start = config.z_len + _eos_offset | |
| drop_mask = torch.rand(ar_indices.shape[0], device=ar_indices.device) < _prior_visual_dropout | |
| if drop_mask.any(): | |
| ar_input = ar_indices.clone() | |
| ar_input[drop_mask, vis_start:] = _unwrap(ar_model).eos_token_id | |
| else: | |
| ar_input = ar_indices | |
| else: | |
| ar_input = ar_indices | |
| ar_out = ar_model(ar_input, labels=labels, semantic_one_hot=semantic_one_hot) | |
| # calculate the loss | |
| l2_loss = config.l2_weight * F.mse_loss(x, x_hat) if target.DO_L2 else 0. | |
| l1_loss = config.l1_weight * F.l1_loss(x, x_hat) if target.DO_L1 else 0. | |
| convnext_loss = 0. | |
| if target.DO_LPIPS: | |
| if isinstance(perceptual_loss, BothPerceptualLoss): | |
| _lpips_val, _convnext_val = perceptual_loss(imgs, imgs_hat) | |
| lpips_loss = config.lpips_weight * _lpips_val.mean() | |
| convnext_loss = config.get("convnext_weight", 0.1) * _convnext_val.mean() | |
| else: | |
| lpips_loss = config.lpips_weight * perceptual_loss(imgs, imgs_hat).mean() | |
| else: | |
| lpips_loss = 0. | |
| ae_loss = l2_loss + l1_loss + lpips_loss + convnext_loss | |
| semantic_vqloss_dict = enc_out.semantic_vq_loss | |
| visual_vqloss_dict = enc_out.visual_vq_loss | |
| if semantic_vqloss_dict is not None: | |
| semantic_vqloss = config.commitment_loss_weight * semantic_vqloss_dict.quant_loss + config.entropy_loss_weight * semantic_vqloss_dict.entropy_loss if enc_grad else 0. | |
| else: | |
| semantic_vqloss = 0. | |
| visual_vqloss = config.commitment_loss_weight * visual_vqloss_dict.quant_loss + config.entropy_loss_weight * visual_vqloss_dict.entropy_loss if enc_grad else 0. | |
| vqloss = semantic_vqloss + visual_vqloss | |
| gan_G_loss, gan_G_loss_dict = gan_loss(imgs, imgs_hat, global_step=global_step, loss='G') if target.DO_GAN_G else (0., {}) | |
| gan_D_loss, gan_D_loss_dict = gan_loss(imgs, cached_imgs_hat, global_step=global_step, loss='D') if target.DO_GAN_D else (0., {}) | |
| gan_G_loss_weight = 0. | |
| adapt_weight = 0. | |
| if target.DO_GAN_G: | |
| if config.disc_adaptive_weight: | |
| adapt_weight = calculate_adaptive_weight_acc(ae_loss, gan_G_loss, accelerator.unwrap_model(dec).dec.out.weight, accelerator, dec) | |
| gan_G_loss_weight = (config.gan_G_weight * adapt_weight) | |
| else: | |
| gan_G_loss_weight = config.gan_G_weight | |
| gan_G_loss = gan_G_loss * gan_G_loss_weight if target.DO_GAN_G else 0. | |
| gan_D_loss = gan_D_loss * config.gan_D_weight if target.DO_GAN_D else 0. | |
| prior_ar_loss = 0. | |
| semantic_prior_ar_loss = 0. | |
| visual_prior_ar_loss = 0. | |
| eos_prior_ar_loss = 0. | |
| correct_token_rate = 0. | |
| semantic_correct_token_rate = 0. | |
| visual_correct_token_rate = 0. | |
| if target.DO_PRIOR_AR: | |
| ar_logits = ar_out.logits | |
| if _prologue: | |
| total_len = config.z_len + _eos_offset + config.x_len | |
| V = ar_logits.shape[-1] | |
| B = ar_logits.shape[0] | |
| sem_logits = ar_logits[:, :config.z_len] | |
| vis_logits = ar_logits[:, config.z_len + _eos_offset:] | |
| sem_targets = ar_targets[:, :config.z_len] | |
| vis_targets = ar_targets[:, config.z_len + _eos_offset:] | |
| sem_loss_per_token = F.cross_entropy(sem_logits.reshape(-1, V), sem_targets.reshape(-1), reduction='none') | |
| semantic_prior_ar_loss = sem_loss_per_token.reshape(B, -1).sum(dim=1).mean() / total_len | |
| semantic_correct_token_rate = (sem_logits.argmax(dim=-1) == sem_targets).detach().float().mean().item() | |
| vis_loss_per_token = F.cross_entropy(vis_logits.reshape(-1, V), vis_targets.reshape(-1), reduction='none') | |
| visual_prior_ar_loss = vis_loss_per_token.reshape(B, -1).sum(dim=1).mean() / total_len | |
| visual_correct_token_rate = (vis_logits.argmax(dim=-1) == vis_targets).detach().float().mean().item() | |
| if _eos_offset > 0: | |
| eos_logits = ar_logits[:, config.z_len:config.z_len + 1] | |
| eos_targets = ar_targets[:, config.z_len:config.z_len + 1] | |
| eos_loss_per_token = F.cross_entropy(eos_logits.reshape(-1, V), eos_targets.reshape(-1), reduction='none') | |
| eos_prior_ar_loss = eos_loss_per_token.reshape(B, -1).sum(dim=1).mean() / total_len | |
| prior_ar_loss = semantic_prior_ar_loss + eos_prior_ar_loss + visual_prior_ar_loss | |
| else: | |
| prior_ar_loss = F.cross_entropy(ar_logits.reshape(-1, ar_logits.shape[-1]), ar_targets.reshape(-1)) | |
| visual_prior_ar_loss = prior_ar_loss | |
| visual_correct_token_rate = (ar_logits.argmax(dim=-1) == ar_targets).detach().float().mean().item() | |
| correct_token_rate = (ar_logits.argmax(dim=-1) == ar_targets).detach().float().mean().item() | |
| prior_enc_loss = raw_prior_enc_loss = 0. | |
| semantic_prior_enc_loss = 0. | |
| visual_prior_enc_loss = 0. | |
| eos_prior_enc_loss = 0. | |
| if target.DO_PRIOR_ENC: | |
| enc_ar_logits = ar_out.logits | |
| if _prologue: | |
| total_len_enc = config.z_len + _eos_offset + config.x_len | |
| V_enc = enc_ar_logits.shape[-1] | |
| B_enc = enc_ar_logits.shape[0] | |
| enc_sem_logits = enc_ar_logits[:, :config.z_len] | |
| enc_vis_logits = enc_ar_logits[:, config.z_len + _eos_offset:] | |
| enc_sem_targets = ar_targets[:, :config.z_len] | |
| enc_vis_targets = ar_targets[:, config.z_len + _eos_offset:] | |
| enc_sem_loss = F.cross_entropy(enc_sem_logits.reshape(-1, V_enc), enc_sem_targets.reshape(-1), reduction='none') | |
| semantic_prior_enc_loss = enc_sem_loss.reshape(B_enc, -1).sum(dim=1).mean() / total_len_enc | |
| enc_vis_loss = F.cross_entropy(enc_vis_logits.reshape(-1, V_enc), enc_vis_targets.reshape(-1), reduction='none') | |
| visual_prior_enc_loss = enc_vis_loss.reshape(B_enc, -1).sum(dim=1).mean() / total_len_enc | |
| if _eos_offset > 0: | |
| enc_eos_logits = enc_ar_logits[:, config.z_len:config.z_len + 1] | |
| enc_eos_targets = ar_targets[:, config.z_len:config.z_len + 1] | |
| enc_eos_loss = F.cross_entropy(enc_eos_logits.reshape(-1, V_enc), enc_eos_targets.reshape(-1), reduction='none') | |
| eos_prior_enc_loss = enc_eos_loss.reshape(B_enc, -1).sum(dim=1).mean() / total_len_enc | |
| else: | |
| enc_vis_logits = enc_ar_logits | |
| enc_vis_targets = ar_targets | |
| visual_prior_enc_loss = F.cross_entropy(enc_vis_logits.reshape(-1, enc_vis_logits.shape[-1]), enc_vis_targets.reshape(-1)) | |
| raw_prior_enc_loss = semantic_prior_enc_loss + eos_prior_enc_loss + visual_prior_enc_loss | |
| prior_enc_loss = config.prior_enc_semantic_weight * (semantic_prior_enc_loss + eos_prior_enc_loss) + config.prior_enc_visual_weight * visual_prior_enc_loss | |
| loss = ae_loss + prior_ar_loss + gan_G_loss + gan_D_loss + vqloss + prior_enc_loss | |
| # backward and optimization | |
| accelerator.backward(loss) | |
| grad_norms = calc_grad_norm( | |
| {"Enc": enc, "Dec": dec, "AR": ar_model if train_ar else None, "GAN": gan_loss if use_gan_loss else None}, | |
| global_step, int(getattr(config, "grad_norm_freq", 0)), | |
| accelerator=accelerator, | |
| ) | |
| for model, opt, scheduler, clip_val in opts: | |
| if model is not None and opt is not None: | |
| cur_lr = max(pg['lr'] for pg in opt.param_groups) | |
| if cur_lr > 0: | |
| zero_nan_gradients(model, accelerator=accelerator) | |
| if clip_val > 0: | |
| accelerator.clip_grad_norm_(model.parameters(), max_norm=clip_val) | |
| opt.step() | |
| opt.zero_grad(set_to_none=True) | |
| if scheduler is not None: | |
| scheduler.step() | |
| # EMA updates (skip when lr == 0 since params were not updated) | |
| if enc_grad and opt_enc is not None and max(pg['lr'] for pg in opt_enc.param_groups) > 0: | |
| ema_update(enc, enc_ema, ae_ema_rate) | |
| if dec_grad and opt_dec is not None and max(pg['lr'] for pg in opt_dec.param_groups) > 0: | |
| ema_update(dec, dec_ema, ae_ema_rate) | |
| if train_ar and prior_grad and opt_ar is not None and max(pg['lr'] for pg in opt_ar.param_groups) > 0: | |
| ema_update(ar_model, ar_model_ema, ar_ema_rate) | |
| #visualization | |
| if (global_step + 1) % config.visualize_freq == 0: | |
| grid = x[:config.visualize_img_num] | |
| # enc/dec: toggle whenever either ema flag is off (any non-EMA path needs train-mode toggling) | |
| if ((not config.ema_sampling) or (not config.ema_reconstruction)) and train_ae: | |
| toggle_train_eval(enc, train=False, accelerator=accelerator) | |
| toggle_train_eval(dec, train=False, accelerator=accelerator) | |
| # ar: only ema_sampling controls whether to toggle | |
| if not config.ema_sampling and train_ar: | |
| toggle_train_eval(ar_model, train=False, accelerator=accelerator) | |
| # Pick the models used for recon vs sampling | |
| if config.ema_reconstruction: | |
| recon_enc, recon_dec = enc_ema, dec_ema | |
| else: | |
| recon_enc, recon_dec = enc, dec | |
| if config.ema_sampling: | |
| sample_enc, sample_dec, sample_ar = enc_ema, dec_ema, ar_model_ema | |
| else: | |
| sample_enc, sample_dec, sample_ar = enc, dec, ar_model | |
| with torch.no_grad(): | |
| if train_ae: | |
| all_recon_images,_ = reconstruction(recon_enc, recon_dec, x, ae_labels) | |
| grid = torch.cat([grid, all_recon_images[:config.visualize_img_num]], dim=0) | |
| if train_ar: | |
| _viz_cls = uncond_batch if _label_drop_always else raw_labels | |
| all_sample_images = sampling(sample_enc, sample_dec, sample_ar, bz=x.shape[0], | |
| class_label=_viz_cls,temperature=config.temperature, | |
| topK=config.topK, topP=config.topP, | |
| cfg=config.cfg, cfg_schedule=config.cfg_schedule, cfg_power=config.cfg_power, | |
| cache_kv=config.cache_kv, | |
| ae_label=ae_labels, | |
| semantic_cfg_schedule=config.get("semantic_cfg_schedule", None), | |
| semantic_cfg_scale=config.get("semantic_cfg_scale", None), | |
| semantic_cfg_power=config.get("semantic_cfg_power", None), | |
| semantic_cfg_start=float(config.get("semantic_cfg_start", 0.0)), | |
| visual_cfg_schedule=config.get("visual_cfg_schedule", None), | |
| visual_cfg_scale=config.get("visual_cfg_scale", None), | |
| visual_cfg_power=config.get("visual_cfg_power", None), | |
| visual_cfg_start=float(config.get("visual_cfg_start", 1.0))) | |
| grid = torch.cat([grid, all_sample_images[:config.visualize_img_num]], dim=0) | |
| grid = img_denormalize(unpatchify(grid, config.image_size, config.patch_size)) | |
| grid = torchvision.utils.make_grid(grid, nrow=config.visualize_img_num, normalize=False) | |
| accelerator.log({"visualization/imgs":wandb.Image(grid),'global_step':global_step+1}, step=global_step+1) | |
| # Toggle back to train mode | |
| if ((not config.ema_sampling) or (not config.ema_reconstruction)) and train_ae: | |
| toggle_train_eval(enc, train=True, accelerator=accelerator) | |
| toggle_train_eval(dec, train=True, accelerator=accelerator) | |
| if not config.ema_sampling and train_ar: | |
| toggle_train_eval(ar_model, train=True, accelerator=accelerator) | |
| # eval and save ckpt | |
| rFID = 0.0 | |
| gFID = 0.0 | |
| # Also keep per-GPU sums/count so final aggregation can be weighted-correct in metrics block | |
| eval_lpips = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| eval_ssim = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| eval_psnr = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| semantic_codebook_usage = torch.zeros(config.z_len, config.SemanticQuantizer.codebook_size, device=accelerator.device).long() if (train_ae and _prologue) else None | |
| visual_codebook_usage = torch.zeros(config.x_len if _prologue else config.z_len, config.codebook_size, device=accelerator.device).long() if train_ae else None | |
| codebook_usage = None | |
| avg_agg_ent = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| entropy_acc = {"ent_sum": None, "ent_cnt": None} | |
| entropy_log_base = float(getattr(config, "prefix_entropy_log_base", 2.0)) | |
| data_cond_entropy_trie = None # For train_ae: collect real token sequences | |
| if eval_loader is not None and (global_step + 1) % config.eval_freq == 0: | |
| sample_cached_path = os.path.join(config.tmp_dir, "sample_images.npz") | |
| recon_cached_path = os.path.join(config.tmp_dir, "recon_images.npz") | |
| gt_cache_path = config.eval_fid_ref_path | |
| gt_buf = [] | |
| samples_buf = [] | |
| recons_buf = [] | |
| semantic_idx_buf = [] | |
| visual_idx_buf = [] | |
| lpips_sum = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| ssim_sum = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| psnr_sum = torch.zeros((), device=accelerator.device, dtype=torch.float32) | |
| count = torch.zeros((), device=accelerator.device, dtype=torch.long) | |
| # enc/dec: toggle whenever either ema flag is off (any non-EMA path needs train-mode toggling) | |
| if ((not config.ema_sampling) or (not config.ema_reconstruction)) and train_ae: | |
| toggle_train_eval(enc, train=False, accelerator=accelerator) | |
| toggle_train_eval(dec, train=False, accelerator=accelerator) | |
| # ar: only ema_sampling controls whether to toggle | |
| if not config.ema_sampling and train_ar: | |
| toggle_train_eval(ar_model, train=False, accelerator=accelerator) | |
| # Pick the models used for recon vs sampling | |
| if config.ema_reconstruction: | |
| recon_enc, recon_dec = enc_ema, dec_ema | |
| else: | |
| recon_enc, recon_dec = enc, dec | |
| if config.ema_sampling: | |
| sample_enc, sample_dec, sample_ar = enc_ema, dec_ema, ar_model_ema | |
| else: | |
| sample_enc, sample_dec, sample_ar = enc, dec, ar_model | |
| cuda_rng_state = torch.cuda.get_rng_state() | |
| eval_seed = int(config.seed) + accelerator.process_index | |
| torch.cuda.manual_seed(eval_seed) | |
| print0(f"[Eval] Per-rank CUDA seed: base={config.seed}, rank={accelerator.process_index}, effective={eval_seed}") | |
| with torch.no_grad(): | |
| accelerator.wait_for_everyone() | |
| for i, batch in enumerate(tqdm(eval_loader,disable=not accelerator.is_main_process,dynamic_ncols=True,file=sys.stdout,desc="Evaluating")): | |
| imgs, labels = batch if isinstance(batch, (list, tuple)) and len(batch) == 2 else (batch, None) | |
| imgs_norm = img_uint8_to_norm(imgs) | |
| x = patchify(imgs_norm, config.patch_size) | |
| uncond_batch = uncond_labels.expand(imgs.shape[0], -1) | |
| if labels is None or len(labels)==0 or not config.use_label: | |
| labels = uncond_batch | |
| else: | |
| labels = torch.cat([labels, torch.full((labels.shape[0],1), 0, device=accelerator.device, dtype=torch.long)], dim=-1) | |
| eval_ae_labels = uncond_batch if _ae_no_label else labels | |
| # Align AR eval (encode / teacher-forcing / sampling) with unconditional training when label_drop_prob>=1.0 | |
| ar_eval_cls = uncond_batch if _label_drop_always else labels | |
| idx = None | |
| if train_ae: | |
| recon_patches, idx = reconstruction(recon_enc, recon_dec, x, eval_ae_labels) # patch domain | |
| recon_img = unpatchify(recon_patches, config.image_size, config.patch_size) # [B,C,H,W] float in [-1,1] | |
| lpips_vals, ssim_vals, psnr_vals = calc_per_sample_reconstruction_metrics(imgs_norm, recon_img, perceptual_loss) | |
| lpips_sum += lpips_vals | |
| ssim_sum += ssim_vals | |
| psnr_sum += psnr_vals | |
| count += 1 | |
| recon_img_u8 = img_norm_to_uint8(recon_img) | |
| recons_buf.append(recon_img_u8.permute(0,2,3,1).cpu().numpy()) | |
| elif train_ar: | |
| idx = sample_enc(x, ar_eval_cls, training=False).indices | |
| if train_ar: | |
| ar_eval_idx = idx | |
| if _semantic_offset > 0 and _prologue: | |
| ar_eval_idx = idx.clone() | |
| ar_eval_idx[:, :config.z_len] += _semantic_offset | |
| if _use_eos: | |
| ar_eval_idx = insert_eos_token(ar_eval_idx, int(config.z_len), _unwrap(ar_model).eos_token_id) | |
| ar_logits = sample_ar(ar_eval_idx, ar_eval_cls).logits | |
| draw_conditional_entropy( | |
| entropy_acc, | |
| logits=ar_logits, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| log_base=entropy_log_base, | |
| finalize=False, | |
| codebook_size=config.codebook_size, | |
| ) | |
| if train_ar: | |
| sample_images = sampling(sample_enc, sample_dec, sample_ar, | |
| bz=x.shape[0], class_label=ar_eval_cls,temperature=config.temperature, | |
| topK=config.topK, topP=config.topP, cfg=config.cfg, cfg_schedule=config.cfg_schedule, | |
| cfg_power=config.cfg_power, cache_kv=config.cache_kv, | |
| ae_label=eval_ae_labels, | |
| semantic_cfg_schedule=config.get("semantic_cfg_schedule", None), | |
| semantic_cfg_scale=config.get("semantic_cfg_scale", None), | |
| semantic_cfg_power=config.get("semantic_cfg_power", None), | |
| semantic_cfg_start=float(config.get("semantic_cfg_start", 0.0)), | |
| visual_cfg_schedule=config.get("visual_cfg_schedule", None), | |
| visual_cfg_scale=config.get("visual_cfg_scale", None), | |
| visual_cfg_power=config.get("visual_cfg_power", None), | |
| visual_cfg_start=float(config.get("visual_cfg_start", 1.0))) | |
| sample_images = img_norm_to_uint8(unpatchify(sample_images, config.image_size, config.patch_size)) | |
| samples_buf.append(sample_images.permute(0,2,3,1).cpu().numpy()) | |
| if not os.path.exists(gt_cache_path): | |
| gt_buf.append(imgs.permute(0,2,3,1).cpu().numpy()) | |
| if train_ae and visual_codebook_usage is not None and idx is not None: | |
| if _prologue: | |
| semantic_idx_buf.append(idx[:, :-config.x_len]) | |
| visual_idx_buf.append(idx[:, -config.x_len:]) | |
| else: | |
| visual_idx_buf.append(idx.detach()) | |
| # Collect token sequences for data conditional entropy | |
| if train_ae and idx is not None: | |
| data_cond_entropy_trie = draw_data_conditional_entropy( | |
| data_cond_entropy_trie, | |
| idx=idx, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| log_base=entropy_log_base, | |
| finalize=False, | |
| max_depth=config.z_len, | |
| codebook_size=config.codebook_size, | |
| ) | |
| accelerator.wait_for_everyone() | |
| if train_ae: | |
| if semantic_codebook_usage is not None and len(semantic_idx_buf) > 0: | |
| s_idx_all = accelerator.gather(torch.cat(semantic_idx_buf, dim=0)) | |
| if accelerator.is_main_process: | |
| s_idx_all = s_idx_all.to(device=semantic_codebook_usage.device, dtype=torch.long) | |
| s_pos = torch.arange(s_idx_all.shape[1], device=semantic_codebook_usage.device, dtype=torch.long).unsqueeze(0).expand_as(s_idx_all).reshape(-1) | |
| semantic_codebook_usage[s_pos, s_idx_all.reshape(-1)] += 1 | |
| if visual_codebook_usage is not None and len(visual_idx_buf) > 0: | |
| v_idx_all = accelerator.gather(torch.cat(visual_idx_buf, dim=0)) | |
| if accelerator.is_main_process: | |
| v_idx_all = v_idx_all.to(device=visual_codebook_usage.device, dtype=torch.long) | |
| v_pos = torch.arange(v_idx_all.shape[1], device=visual_codebook_usage.device, dtype=torch.long).unsqueeze(0).expand_as(v_idx_all).reshape(-1) | |
| visual_codebook_usage[v_pos, v_idx_all.reshape(-1)] += 1 | |
| if train_ae: | |
| eval_lpips = lpips_sum / count | |
| eval_ssim = ssim_sum / count | |
| eval_psnr = psnr_sum / count | |
| print0("eval_lpips",eval_lpips,"eval_ssim",eval_ssim,"eval_psnr",eval_psnr) | |
| # Gather all buffers at once and save | |
| gathered_samples = accelerator.gather(torch.from_numpy(np.concatenate(samples_buf, axis=0)).to(accelerator.device)).cpu().numpy() if train_ar and len(samples_buf) > 0 else None | |
| gathered_recons = accelerator.gather(torch.from_numpy(np.concatenate(recons_buf, axis=0)).to(accelerator.device)).cpu().numpy() if train_ae and len(recons_buf) > 0 else None | |
| gathered_gt = accelerator.gather(torch.from_numpy(np.concatenate(gt_buf, axis=0)).to(accelerator.device)).cpu().numpy() if len(gt_buf) > 0 else None | |
| if accelerator.is_main_process: | |
| if gathered_samples is not None: | |
| print0(f"[Eval] Gathered samples: {gathered_samples.shape[0]}, eval_dataset: {eval_dataset_size}") | |
| if gathered_samples.shape[0] != eval_dataset_size: | |
| print0(f"WARNING: Gathered samples count ({gathered_samples.shape[0]}) != eval_dataset size ({eval_dataset_size})") | |
| np.savez(sample_cached_path, gathered_samples) | |
| if gathered_recons is not None: | |
| print0(f"[Eval] Gathered recons: {gathered_recons.shape[0]}, eval_dataset: {eval_dataset_size}") | |
| if gathered_recons.shape[0] != eval_dataset_size: | |
| print0(f"WARNING: Gathered recons count ({gathered_recons.shape[0]}) != eval_dataset size ({eval_dataset_size})") | |
| np.savez(recon_cached_path, gathered_recons) | |
| if gathered_gt is not None: | |
| print0(f"[Eval] Gathered gt: {gathered_gt.shape[0]}, eval_dataset: {eval_dataset_size}") | |
| if gathered_gt.shape[0] != eval_dataset_size: | |
| print0(f"WARNING: Gathered gt count ({gathered_gt.shape[0]}) != eval_dataset size ({eval_dataset_size})") | |
| if not os.path.exists(gt_cache_path): | |
| np.savez(gt_cache_path, gathered_gt) | |
| if train_ar: | |
| gFID = adm_fid_evaluator(sample_cached_path, gt_cache_path, config, accelerator) | |
| if train_ae: | |
| rFID = adm_fid_evaluator(recon_cached_path, gt_cache_path, config, accelerator) | |
| accelerator.wait_for_everyone() | |
| # Sync FID across all processes to ensure consistent checkpoint naming | |
| rFID_t = torch.tensor(rFID, device=accelerator.device) | |
| gFID_t = torch.tensor(gFID, device=accelerator.device) | |
| rFID = accelerator.reduce(rFID_t, reduction="sum").item() | |
| gFID = accelerator.reduce(gFID_t, reduction="sum").item() | |
| if config.save_best and config.save_ckpt: | |
| if train_ae and rFID < best_rfid: | |
| best_rfid = float(rFID) | |
| if accelerator.is_main_process: | |
| remove_old_best_checkpoints(f"{save_dir}/ckpts", metric_type="rFID") | |
| accelerator.wait_for_everyone() | |
| save_training_state(accelerator, f"{save_dir}/ckpts/best-Step={global_step+1}-rFID={best_rfid:.4f}", extra_training_states, | |
| global_step=global_step + 1, dl_generator=train_loader._pre_epoch_gen_state, | |
| total_yielded=train_loader.total_yielded, prev_phase_idx=phase_idx, prev_inner_idx=inner_idx, | |
| data_buffer=[tuple(t.detach().cpu() for t in item) for item in data_buffer], | |
| **make_cache_kwargs(cached_enc_out, cached_imgs_hat), | |
| best_rfid=best_rfid) | |
| if accelerator.is_main_process: | |
| print0(f"[best] Saved new best rFID checkpoint: {best_rfid:.4f}") | |
| if train_ar and gFID < best_gfid: | |
| best_gfid = float(gFID) | |
| if accelerator.is_main_process: | |
| remove_old_best_checkpoints(f"{save_dir}/ckpts", metric_type="gFID") | |
| accelerator.wait_for_everyone() | |
| save_training_state(accelerator, f"{save_dir}/ckpts/best-Step={global_step+1}-gFID={best_gfid:.4f}", extra_training_states, | |
| global_step=global_step + 1, dl_generator=train_loader._pre_epoch_gen_state, | |
| total_yielded=train_loader.total_yielded, prev_phase_idx=phase_idx, prev_inner_idx=inner_idx, | |
| data_buffer=[tuple(t.detach().cpu() for t in item) for item in data_buffer], | |
| **make_cache_kwargs(cached_enc_out, cached_imgs_hat), | |
| best_gfid=best_gfid) | |
| if accelerator.is_main_process: | |
| print0(f"[best] Saved new best gFID checkpoint: {best_gfid:.4f}") | |
| if train_ar: | |
| draw_conditional_entropy( | |
| entropy_acc, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| log_base=entropy_log_base, | |
| finalize=True, | |
| rFID=rFID, | |
| gFID=gFID, | |
| codebook_size=config.codebook_size, | |
| ) | |
| if train_ae: | |
| # Plot data conditional entropy (from real token sequences) | |
| draw_data_conditional_entropy( | |
| data_cond_entropy_trie, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| log_base=entropy_log_base, | |
| finalize=True, | |
| rFID=rFID, | |
| gFID=gFID, | |
| max_depth=config.z_len, | |
| codebook_size=config.codebook_size, | |
| ) | |
| if semantic_codebook_usage is not None and visual_codebook_usage is not None: | |
| max_cb = max(semantic_codebook_usage.shape[1], visual_codebook_usage.shape[1]) | |
| s_pad = F.pad(semantic_codebook_usage, (0, max_cb - semantic_codebook_usage.shape[1])) if semantic_codebook_usage.shape[1] < max_cb else semantic_codebook_usage | |
| v_pad = F.pad(visual_codebook_usage, (0, max_cb - visual_codebook_usage.shape[1])) if visual_codebook_usage.shape[1] < max_cb else visual_codebook_usage | |
| codebook_usage = torch.cat([s_pad, v_pad], dim=0) | |
| elif visual_codebook_usage is not None: | |
| codebook_usage = visual_codebook_usage | |
| else: | |
| codebook_usage = None | |
| plot_codebook_usage( | |
| codebook_usage, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| rFID=rFID, | |
| gFID=gFID, | |
| ) | |
| if codebook_usage is not None and accelerator.is_main_process: | |
| aggregated_ent = compute_aggregated_entropy_from_counts(codebook_usage, log_base=entropy_log_base) # [L] | |
| aggregated_ent_list = aggregated_ent.detach().cpu().tolist() | |
| avg_agg_ent_val = float(np.nanmean(np.array(aggregated_ent_list, dtype=np.float64))) | |
| avg_agg_ent = torch.tensor(avg_agg_ent_val, device=accelerator.device, dtype=torch.float32) | |
| plot_posterior_entropy( | |
| sample_entropy=None, | |
| aggregated_entropy=aggregated_ent_list, | |
| accelerator=accelerator, | |
| save_dir=save_dir, | |
| global_step=global_step, | |
| rFID=rFID, | |
| gFID=gFID, | |
| log_base=entropy_log_base, | |
| codebook_size=config.codebook_size, | |
| ) | |
| # Reduce avg_agg_ent across all processes | |
| if train_ae: | |
| avg_agg_ent = accelerator.reduce(avg_agg_ent, reduction='sum') | |
| torch.cuda.set_rng_state(cuda_rng_state) | |
| # Toggle back to train mode | |
| if ((not config.ema_sampling) or (not config.ema_reconstruction)) and train_ae: | |
| toggle_train_eval(enc, train=True, accelerator=accelerator) | |
| toggle_train_eval(dec, train=True, accelerator=accelerator) | |
| if not config.ema_sampling and train_ar: | |
| toggle_train_eval(ar_model, train=True, accelerator=accelerator) | |
| # Cleanup temporary npz files produced during this eval | |
| if accelerator.is_main_process: | |
| if train_ar: | |
| safe_remove_file(sample_cached_path) | |
| if train_ae: | |
| safe_remove_file(recon_cached_path) | |
| # Explicitly clear large buffers to free memory | |
| del samples_buf, recons_buf, gt_buf, semantic_idx_buf, visual_idx_buf | |
| if 'gathered_samples' in locals(): del gathered_samples | |
| if 'gathered_recons' in locals(): del gathered_recons | |
| if 'gathered_gt' in locals(): del gathered_gt | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| accelerator.wait_for_everyone() | |
| # metrics | |
| if (global_step +1) % config.log_freq == 0: | |
| lr_enc = opt_enc.param_groups[0]['lr'] if opt_enc is not None else 0. | |
| lr_dec = opt_dec.param_groups[0]['lr'] if opt_dec is not None else 0. | |
| lr_ar = opt_ar.param_groups[0]['lr'] if opt_ar is not None else 0. | |
| lr_gan_loss = opt_gan_loss.param_groups[0]['lr'] if opt_gan_loss is not None else 0. | |
| metrics = { | |
| 'Phase':phase_idx, | |
| 'ae_loss/l1_loss':l1_loss, | |
| 'ae_loss/l2_loss':l2_loss, | |
| 'ae_loss/lpips_loss':lpips_loss, | |
| 'ae_loss/convnext_loss':convnext_loss, | |
| 'ae_loss/loss':ae_loss, | |
| 'ae_loss/vqloss': vqloss, | |
| 'ae_loss/semantic_vqloss': semantic_vqloss, | |
| 'ae_loss/visual_vqloss': visual_vqloss, | |
| **(({ | |
| 'semantic_vqloss/quant_loss': semantic_vqloss_dict.quant_loss.detach(), | |
| 'semantic_vqloss/entropy_loss': semantic_vqloss_dict.entropy_loss.detach(), | |
| 'semantic_vqloss/sample_entropy': semantic_vqloss_dict.sample_entropy, | |
| 'semantic_vqloss/batch_entropy': semantic_vqloss_dict.batch_entropy, | |
| 'semantic_vqloss/z_norm': semantic_vqloss_dict.l2norm_z, | |
| 'semantic_vqloss/code_norm': semantic_vqloss_dict.l2norm_code, | |
| }) if semantic_vqloss_dict is not None else {}), | |
| 'visual_vqloss/quant_loss': visual_vqloss_dict.quant_loss.detach(), | |
| 'visual_vqloss/entropy_loss': visual_vqloss_dict.entropy_loss.detach(), | |
| 'visual_vqloss/sample_entropy': visual_vqloss_dict.sample_entropy, | |
| 'visual_vqloss/batch_entropy': visual_vqloss_dict.batch_entropy, | |
| 'visual_vqloss/z_norm': visual_vqloss_dict.l2norm_z, | |
| 'visual_vqloss/code_norm': visual_vqloss_dict.l2norm_code, | |
| 'GAN/gan_G_loss':gan_G_loss, | |
| 'GAN/gan_D_loss':gan_D_loss, | |
| 'prior_loss/ar_loss':prior_ar_loss, | |
| 'prior_loss/semantic_ar_loss':semantic_prior_ar_loss, | |
| 'prior_loss/visual_ar_loss':visual_prior_ar_loss, | |
| 'prior_loss/eos_ar_loss':eos_prior_ar_loss, | |
| 'prior_loss/prior_enc_loss':raw_prior_enc_loss, | |
| 'prior_loss/semantic_prior_enc_loss':semantic_prior_enc_loss, | |
| 'prior_loss/visual_prior_enc_loss':visual_prior_enc_loss, | |
| 'prior_loss/eos_prior_enc_loss':eos_prior_enc_loss, | |
| 'prior_loss/correct_token_rate': correct_token_rate, | |
| 'prior_loss/semantic_correct_token_rate': semantic_correct_token_rate, | |
| 'prior_loss/visual_correct_token_rate': visual_correct_token_rate, | |
| 'eval/semantic_codebook_usage': ((semantic_codebook_usage > 0).any(dim=0).float().mean().item() if (train_ae and semantic_codebook_usage is not None) else 0.0), | |
| 'eval/visual_codebook_usage': ((visual_codebook_usage > 0).any(dim=0).float().mean().item() if (train_ae and visual_codebook_usage is not None) else 0.0), | |
| 'eval/rFID':rFID, | |
| 'eval/gFID':gFID, | |
| 'eval/lpips': eval_lpips, | |
| 'eval/psnr':eval_psnr, | |
| 'eval/ssim':eval_ssim, | |
| 'eval/aggregated_post_entropy': avg_agg_ent, | |
| 'lr/lr_enc': lr_enc, | |
| 'lr/lr_dec': lr_dec, | |
| 'lr/lr_ar': lr_ar, | |
| 'lr/lr_gan_loss': lr_gan_loss, | |
| } | |
| if target.DO_GAN_G or target.DO_GAN_D: | |
| gan_loss_dict = {**gan_G_loss_dict, **gan_D_loss_dict, 'GAN/gan_G_loss_weight':gan_G_loss_weight, 'GAN/adapt_weight':adapt_weight} | |
| metrics.update(gan_loss_dict) | |
| metrics = {k: accelerator.reduce(v, reduction='mean').item() if isinstance(v, torch.Tensor) else v for k, v in metrics.items() } | |
| metrics_logger = {k: v for k, v in metrics.items() if v!=0. } | |
| metrics_logger.update(grad_norms) | |
| metrics_4f = {k: f"{v:.4f}" if k!='Phase' else int(v) for k, v in metrics.items() } | |
| accelerator.log(metrics_logger, step=global_step+1) | |
| pbar.set_postfix(metrics_4f, refresh=False) | |
| pbar.update(config.log_freq) | |
| if config.save_ckpt and ((global_step+1) % config.ckpt_freq == 0): | |
| ckpt_name = f'Phase={phase_idx}-Step={global_step+1}-rFID={rFID:.4f}-gFID={gFID:.4f}' if rFID is not None and gFID is not None else "-".join([f'{k}={v:.4f}' for k, v in metrics.items()]) | |
| save_path = f'{save_dir}/ckpts/{ckpt_name}' | |
| save_training_state(accelerator, save_path, extra_training_states, | |
| global_step=global_step + 1, dl_generator=train_loader._pre_epoch_gen_state, | |
| total_yielded=train_loader.total_yielded, prev_phase_idx=phase_idx, prev_inner_idx=inner_idx, | |
| data_buffer=[tuple(t.detach().cpu() for t in item) for item in data_buffer], | |
| **make_cache_kwargs(cached_enc_out, cached_imgs_hat)) | |
| global_step += 1 | |
| prev_phase_idx = phase_idx | |
| prev_inner_idx = inner_idx | |
| # Save last state | |
| if config.save_ckpt: | |
| save_path = f'{save_dir}/ckpts/last-Step={global_step}-rFID={rFID:.4f}-gFID={gFID:.4f}' | |
| save_training_state(accelerator, save_path, extra_training_states, | |
| global_step=global_step, dl_generator=train_loader._pre_epoch_gen_state, | |
| total_yielded=train_loader.total_yielded, prev_phase_idx=phase_idx, prev_inner_idx=inner_idx, | |
| data_buffer=[tuple(t.detach().cpu() for t in item) for item in data_buffer], | |
| **make_cache_kwargs(cached_enc_out, cached_imgs_hat)) | |
| accelerator.wait_for_everyone() | |
| pbar.close() | |
| if accelerator.is_main_process: | |
| for tracker in accelerator.trackers: | |
| tracker.finish() | |
| accelerator.wait_for_everyone() | |
| accelerator.end_training() | |
| if __name__ == "__main__": | |
| from utils import load_config | |
| train(load_config()) | |