import copy import glob as glob_module import math import os import random import shutil import subprocess import sys from pathlib import Path import itertools from typing import Iterator, Iterable, List, NamedTuple import numpy as np import torch import torch.distributed as dist from einops import rearrange from accelerate import Accelerator from omegaconf import OmegaConf def build_ar_logit_mask(vis_pos_mask, sem_pos_mask, vis_cb_size, sem_cb_size): """Merge visual/semantic per-position masks into a single ``[T, vis_cb+sem_cb]`` AR logit mask.""" if vis_pos_mask is None and sem_pos_mask is None: return None ar_vocab = vis_cb_size + sem_cb_size parts = [] if sem_pos_mask is not None: sem_full = torch.full((sem_pos_mask.shape[0], ar_vocab), float('-inf')) sem_full[:, vis_cb_size:vis_cb_size + sem_cb_size] = sem_pos_mask parts.append(sem_full) if vis_pos_mask is not None: vis_full = torch.full((vis_pos_mask.shape[0], ar_vocab), float('-inf')) vis_full[:, :vis_cb_size] = vis_pos_mask parts.append(vis_full) return torch.cat(parts, dim=0) if parts else None def load_config(): """OmegaConf merge of ``--config`` / ``--configs`` (comma list, left-to-right) plus CLI ``key=value`` overrides.""" OmegaConf.register_new_resolver("eval", eval, replace=True) cli = OmegaConf.from_cli() paths_str = cli.pop("--configs", None) or cli.pop("--config", None) if paths_str is None: raise ValueError("Must provide --config or --configs") paths = [p.strip() for p in str(paths_str).split(",") if p.strip()] conf = OmegaConf.merge(*[OmegaConf.load(p) for p in paths]) for k, v in cli.items(): OmegaConf.update(conf, k, v) return conf def print0(*args, **kwargs): rank = 0 if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: rank = int(os.environ.get("LOCAL_RANK", 0)) if rank == 0: print(*args, **kwargs) # ============================================================================ # Phase / Target Training System # ============================================================================ class Target(NamedTuple): DO_AE: bool = False DO_L2: bool = False DO_L1: bool = False DO_LPIPS: bool = False DO_GAN_G: bool = False DO_GAN_D: bool = False DO_PRIOR_AR: bool = False DO_PRIOR_ENC: bool = False class Phase(NamedTuple): num_steps: int targets: List[Target] internal_steps: List[int] def parse_phases(phases_str): phases = [] for phase_str in phases_str.split(' '): num_steps, targets_str, internal_steps_str = phase_str.split(':') num_steps = int(num_steps) targets = [Target(**{k: True for obj in target_str.split(',') for k in obj.split('-') }) for target_str in targets_str.split(',')] internal_steps = [int(step) for step in internal_steps_str.split(',')] phases.append(Phase(num_steps, targets, internal_steps)) return phases def parse_training_config_from_phases(phases): train_ae = False train_ar = False use_lpips_loss = False use_gan_loss = False train_prior_enc = False for phase in phases: for target in phase.targets: if target.DO_L1 or target.DO_L2 or target.DO_LPIPS or target.DO_GAN_G : train_ae = True if target.DO_PRIOR_AR or target.DO_PRIOR_ENC: train_ar = True if target.DO_LPIPS: use_lpips_loss = True if target.DO_GAN_G or target.DO_GAN_D: use_gan_loss = True if target.DO_PRIOR_ENC: train_prior_enc = True return train_ae, train_ar, use_lpips_loss, use_gan_loss, train_prior_enc def get_phase(global_step, phases, phase_step_accum, gan_start=0): target = None for phase_idx, phase_step in enumerate(phase_step_accum): if global_step <= phase_step: internel_step = (global_step - phase_step_accum[phase_idx-1]) if phase_idx > 0 else global_step internel_accumulate = list(itertools.accumulate(phases[phase_idx].internal_steps)) internel_step = internel_step % internel_accumulate[-1] for inner_idx in range(len(internel_accumulate)): if internel_step < internel_accumulate[inner_idx]: target = phases[phase_idx].targets[inner_idx] break if target is not None: break DO_AE = any([target.DO_L2, target.DO_L1, target.DO_LPIPS, target.DO_GAN_G]) target = Target(DO_L1=target.DO_L1, DO_L2=target.DO_L2, DO_LPIPS=target.DO_LPIPS, DO_GAN_G=target.DO_GAN_G and (global_step >= gan_start), DO_GAN_D=target.DO_GAN_D and (global_step >= gan_start), DO_PRIOR_AR=target.DO_PRIOR_AR, DO_PRIOR_ENC=target.DO_PRIOR_ENC, DO_AE=DO_AE) return phase_idx, inner_idx, target, internel_step # ============================================================================ # Learning Rate Schedulers # ============================================================================ def get_linear_schedule_with_warmup_peak( optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_peak_steps: int, num_training_steps: int, last_epoch: int = -1, base_lr: float = 1e-4, end_lr: float = 0.0, ): """Linear warmup -> flat peak -> linear decay (``base_lr`` -> ``end_lr``).""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step < num_warmup_steps + num_peak_steps: return 1.0 else: decay_steps = num_training_steps - num_warmup_steps - num_peak_steps progress = float(current_step - num_warmup_steps - num_peak_steps) / float(max(1, decay_steps)) progress = min(progress, 1.0) ratio = 1.0 - progress return (end_lr + (base_lr - end_lr) * ratio) / base_lr return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) try: import wandb except ImportError: wandb = None try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except Exception: # pragma: no cover plt = None # Trie structure for computing data conditional entropy from dataclasses import dataclass, field from typing import Dict, List, Tuple @dataclass class TrieNode: count: int = 0 children: Dict[int, "TrieNode"] = field(default_factory=dict) def trie_insert(root: TrieNode, seq: List[int], max_depth: int) -> None: """Insert a sequence into the Trie up to max_depth.""" node = root for tok in seq[:max_depth]: nxt = node.children.get(tok) if nxt is None: nxt = TrieNode() node.children[tok] = nxt nxt.count += 1 node = nxt def entropy_from_counts(counts: List[int], log_base: float = 2.0) -> float: """Compute entropy from a list of counts.""" if log_base <= 0: raise ValueError("log_base must be > 0") T = int(sum(counts)) if T <= 0: return float("nan") s = 0.0 for c in counts: if c > 0: s += c * math.log(c) denom = math.log(log_base) if log_base != math.e else 1.0 return (math.log(T) - (s / float(T))) / denom def trie_conditional_entropy_all_positions(root: TrieNode, max_depth: int, log_base: float = 2.0) -> Tuple[List[float], List[int]]: """Per-position conditional entropy H(X_d | X_ 0: child_counts = [int(ch.count) for ch in node.children.values()] if len(child_counts) == 0: continue Hc = entropy_from_counts(child_counts, log_base=log_base) total_T += ctx_T weighted_sum += ctx_T * Hc ctx_cnt += 1 continue if depth > target_depth: continue for child in node.children.values(): stack.append((child, depth + 1)) H = weighted_sum / float(total_T) if total_T > 0 else float("nan") H_cond.append(float(H)) num_contexts.append(int(ctx_cnt)) return H_cond, num_contexts def _entropy_from_logits(logits: torch.Tensor, log_base: float = 2.0) -> torch.Tensor: """Masked categorical entropy from logits (bits when ``log_base == 2``).""" logits = logits.float() log_probs = torch.log_softmax(logits, dim=-1) probs = log_probs.exp() ent_nats = -torch.nan_to_num(probs * log_probs, nan=0.0).sum(dim=-1) # 0*log0 -> 0 if log_base == math.e: return ent_nats return ent_nats / math.log(log_base) def plot_data_conditional_entropy( out_path: str, H_cond: list | None, log_base: float = 2.0, codebook_size: int | None = None, title_prefix: str = "Data conditional entropy", ) -> bool: """Save bar plot of H(X_d | X_ 0: max_entropy = math.log(codebook_size) / math.log(log_base) plt.axhline(y=max_entropy, color='red', linestyle='--', linewidth=2.0, label=f'Max H (Independent): {max_entropy:.2f}', alpha=0.8, zorder=10) if N > 0: codes_per_pos = codebook_size / N if codes_per_pos >= 1.0: split_entropy = math.log(codes_per_pos) / math.log(log_base) plt.axhline(y=split_entropy, color='orange', linestyle='-.', linewidth=2.0, label=f'Split Codebook ({codebook_size}/{N}={codes_per_pos:.1f}): {split_entropy:.2f}', alpha=0.8, zorder=10) plt.xlim(-0.5, max(0, N - 0.5)) step = max(1, N // 16) xticks = list(range(0, N, step)) if (N - 1) not in xticks: xticks.append(N - 1) plt.xticks(xticks) # Compute the y-axis range from data + reference lines valid_vals = [v for v in H_cond if isinstance(v, (int, float)) and not math.isnan(v)] if codebook_size is not None and codebook_size > 0: valid_vals.append(math.log(codebook_size) / math.log(log_base)) if N > 0 and codebook_size / N >= 1.0: valid_vals.append(math.log(codebook_size / N) / math.log(log_base)) if valid_vals: ymax = max(valid_vals) ymax = max(ymax, 0.0) y_max_tick = int(math.ceil(ymax * 1.1)) # leave 10% headroom plt.yticks(list(range(0, y_max_tick + 1, max(1, y_max_tick // 5)))) plt.ylim(0, y_max_tick) else: plt.ylim(bottom=0) plt.xlabel("position d (0-based)") plt.ylabel(f"conditional entropy (log_base={log_base})") plt.title(f"{title_prefix}: H(X_d | X_ bool: """Save per-position entropy curve to ``out_path``; optional ``codebook_size`` adds reference lines. """ if plt is None: return False if H is None or len(H) <= 0: return False plot_len = len(H) xs = list(range(plot_len)) fig = plt.figure(figsize=(max(10, plot_len * 0.05), 4)) plt.bar(xs, H, width=1.0, color="#4C78A8", label="H_ar(X_d | X_ 0: max_entropy = math.log(codebook_size) / math.log(log_base) plt.axhline(y=max_entropy, color='red', linestyle='--', linewidth=2.0, label=f'Max H (K={codebook_size}): {max_entropy:.2f}', alpha=0.8, zorder=10) plt.xlim(-0.5, max(0, plot_len - 0.5)) step = max(1, plot_len // 16) xticks = list(range(0, plot_len, step)) if (plot_len - 1) not in xticks: xticks.append(plot_len - 1) plt.xticks(xticks) valid_vals = [v for v in H if isinstance(v, (int, float)) and not math.isnan(v)] if codebook_size is not None and codebook_size > 0: valid_vals.append(math.log(codebook_size) / math.log(log_base)) if valid_vals: ymax = max(max(valid_vals), 0.0) y_max_tick = int(math.ceil(ymax * 1.1)) plt.yticks(list(range(0, y_max_tick + 1, max(1, y_max_tick // 5)))) plt.ylim(0, y_max_tick) else: plt.ylim(bottom=0) plt.xlabel("position d (0-based)") plt.ylabel(f"conditional entropy (log_base={log_base})") plt.title(f"{title_prefix}, d=0..{plot_len-1} (log_base={log_base})") plt.legend() plt.tight_layout() plt.savefig(out_path, dpi=200) plt.close(fig) return True def compute_posterior_entropy_from_logits(logits: torch.Tensor, log_base: float = 2.0) -> torch.Tensor: """Posterior entropy ``-E_q log q`` from ``[B, L, K]`` logits.""" logits = logits.float() log_probs = torch.log_softmax(logits, dim=-1) probs = log_probs.exp() ent_nats = -torch.nan_to_num(probs * log_probs, nan=0.0).sum(dim=-1) # 0*log0 -> 0 if log_base == math.e: return ent_nats return ent_nats / math.log(log_base) def compute_aggregated_entropy_from_counts(count_matrix: torch.Tensor, log_base: float = 2.0) -> torch.Tensor: """Aggregated-posterior entropy ``-E_z log q(z)`` from ``[L, K]`` counts.""" probs = count_matrix.float() / count_matrix.sum(dim=-1, keepdim=True).clamp(min=1.0) log_probs = torch.log(probs.clamp(min=1e-10)) ent_nats = -(probs * log_probs).sum(dim=-1) if log_base == math.e: return ent_nats return ent_nats / math.log(log_base) def plot_posterior_entropy( sample_entropy: list | None, aggregated_entropy: list | None, *, accelerator: Accelerator, save_dir: str, global_step: int, rFID: float = 0.0, gFID: float = 0.0, log_base: float = 2.0, codebook_size: int | None = None, out_name: str = "ae_pos_posterior_entropy.png", ) -> None: """Save bar plot of per-position sample/aggregated entropy.""" if plt is None or not accelerator.is_main_process: return if sample_entropy is None and aggregated_entropy is None: return len_sample = len(sample_entropy) if sample_entropy is not None else 0 len_agg = len(aggregated_entropy) if aggregated_entropy is not None else 0 L = max(len_sample, len_agg) if L <= 0: return out_dir = Path(save_dir) / "analysis_ae" / f"Step={global_step+1}-rFID={rFID:.4f}-gFID={gFID:.4f}" out_dir.mkdir(exist_ok=True, parents=True) fig_path = out_dir / out_name # Thin bars fig = plt.figure(figsize=(max(10, L * 0.05), 4)) xs = list(range(L)) # Pick bar offsets/widths based on which series are present if sample_entropy is not None and aggregated_entropy is not None: # Both series: side-by-side with offsets if len(sample_entropy) > 0: plt.bar([x - 0.2 for x in xs[:len_sample]], sample_entropy, width=0.4, color="#F58518", label="Sample Entropy", edgecolor='none') if len(aggregated_entropy) > 0: plt.bar([x + 0.2 for x in xs[:len_agg]], aggregated_entropy, width=0.4, color="#4C78A8", label="Aggregated Entropy", edgecolor='none') else: # Only one series: centered bars if sample_entropy is not None and len(sample_entropy) > 0: plt.bar(xs[:len_sample], sample_entropy, width=1.0, color="#F58518", label="Sample Entropy", edgecolor='none') if aggregated_entropy is not None and len(aggregated_entropy) > 0: plt.bar(xs[:len_agg], aggregated_entropy, width=1.0, color="#4C78A8", label="Aggregated Entropy", edgecolor='none') plt.grid(True, linestyle="--", alpha=0.3, axis='y') # Theoretical reference lines (fixed codebook) if codebook_size is not None and codebook_size > 0: max_entropy = math.log(codebook_size) / math.log(log_base) plt.axhline(y=max_entropy, color='red', linestyle='--', linewidth=2.0, label=f'Max Entropy (Uniform over {codebook_size}): {max_entropy:.2f}', alpha=0.8, zorder=10) if L > 0: codes_per_pos = codebook_size / L if codes_per_pos >= 1.0: split_entropy = math.log(codes_per_pos) / math.log(log_base) plt.axhline(y=split_entropy, color='orange', linestyle='-.', linewidth=2.0, label=f'Split Codebook ({codebook_size}/{L}={codes_per_pos:.1f}): {split_entropy:.2f}', alpha=0.8, zorder=10) plt.xlim(-0.5, max(0, L - 0.5)) # Thin out x-axis ticks step = max(1, L // 16) xticks = list(range(0, L, step)) if (L - 1) not in xticks: xticks.append(L - 1) plt.xticks(xticks) # Compute the y-axis range valid_vals = [] if sample_entropy is not None: valid_vals += [v for v in sample_entropy if isinstance(v, (int, float)) and not math.isnan(v)] if aggregated_entropy is not None: valid_vals += [v for v in aggregated_entropy if isinstance(v, (int, float)) and not math.isnan(v)] if codebook_size is not None and codebook_size > 0: valid_vals.append(math.log(codebook_size) / math.log(log_base)) if L > 0 and codebook_size / L >= 1.0: valid_vals.append(math.log(codebook_size / L) / math.log(log_base)) if valid_vals: ymax = max(valid_vals) ymax = max(ymax, 0.0) y_max_tick = int(math.ceil(ymax * 1.1)) # leave 10% headroom plt.yticks(list(range(0, y_max_tick + 1, max(1, y_max_tick // 5)))) plt.ylim(0, y_max_tick) else: plt.ylim(bottom=0) plt.xlabel("position d (0-based)") plt.ylabel(f"entropy (log_base={log_base})") plt.title(f"Aggregated Posterior Entropy per Position") plt.legend() plt.tight_layout() plt.savefig(str(fig_path), dpi=200) plt.close(fig) accelerator.log({"analysis/ae_pos_posterior_entropy": wandb.Image(str(fig_path)), "global_step": global_step + 1}, step=global_step+1) def plot_codebook_usage( codebook_usage: torch.Tensor | None, *, accelerator: Accelerator, save_dir: str, global_step: int, rFID: float = 0.0, gFID: float = 0.0, out_name: str = "ae_pos_code_usage_rate.png", ) -> None: """Save per-position unique-code / K usage from ``codebook_usage[L, K]`` counts.""" if plt is None or not accelerator.is_main_process: return if codebook_usage is None or codebook_usage.dim() != 2: return # Per-position codebook utilization used_per_pos = (codebook_usage > 0).sum(dim=1).float() # [L] usage = (used_per_pos / float(codebook_usage.shape[1])).detach().cpu().tolist() L = int(len(usage)) if L <= 0: return out_dir = Path(save_dir) / "analysis_ae" / f"Step={global_step+1}-rFID={rFID:.4f}-gFID={gFID:.4f}" out_dir.mkdir(exist_ok=True, parents=True) fig_path = out_dir / out_name # Bar plot fig = plt.figure(figsize=(max(10, L * 0.05), 4)) plt.bar(list(range(L)), usage, color="#54A24B", width=1.0, edgecolor='none') plt.grid(True, linestyle="--", alpha=0.3, axis='y') plt.xlim(-0.5, max(0, L - 0.5)) plt.ylim(0.0, 1.05) # Thin out x-axis ticks step = max(1, L // 16) xticks = list(range(0, L, step)) if (L - 1) not in xticks: xticks.append(L - 1) plt.xticks(xticks) plt.xlabel("position d (0-based)") plt.ylabel(f"unique / K (K={codebook_usage.shape[1]})") plt.title("Codebook Usage Rate per Position") plt.tight_layout() plt.savefig(str(fig_path), dpi=200) plt.close(fig) accelerator.log({"analysis/ae_pos_code_usage_rate": wandb.Image(str(fig_path)), "global_step": global_step + 1}, step=global_step+1) def seed_everything(seed): """Set python/numpy/torch (CPU+CUDA)/hash seeds.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) os.environ['PYTHONHASHSEED'] = str(seed) def worker_init_fn(worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) def make_worker_init_fn(base_seed: int): base_seed = int(base_seed) % (2**32) def _init(worker_id: int): rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 worker_seed = (base_seed + worker_id + 1000 * rank) % (2**32) np.random.seed(worker_seed) random.seed(worker_seed) torch.manual_seed(worker_seed) return _init def load_accelerate_weights_only( *, accelerator: Accelerator, input_dir: str, strict: bool = True, map_location: str | torch.device | None = "cpu", ) -> None: """Load only ``model*.safetensors`` from an ``accelerator.save_state()`` dir (no optim/RNG/dl state).""" input_dir = os.path.expanduser(str(input_dir)) if not os.path.isdir(input_dir): raise ValueError(f"Tried to load weights from {input_dir} but folder does not exist") from accelerate.state import DistributedType from accelerate.utils import load as accelerate_load, load_fsdp_model from accelerate.checkpointing import SAFE_MODEL_NAME, MODEL_NAME, load_model device_str = "cpu" if map_location in (None, "cpu") else str(map_location) input_path = Path(input_dir) # Iterate over accelerator._models to preserve save_state ordering. models = getattr(accelerator, "_models", None) or [] if len(models) == 0: print0("[warn] No models registered in accelerator; skip loading weights.") return for i, model in enumerate(models): if accelerator.distributed_type == DistributedType.FSDP: load_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, input_dir, i) continue if accelerator.distributed_type == DistributedType.DEEPSPEED: ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" model.load_checkpoint( input_dir, ckpt_id, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_strict=bool(strict), ) continue if accelerator.distributed_type == DistributedType.MEGATRON_LM: raise NotImplementedError( "resume_train=False (weights-only) is not supported for Megatron-LM checkpoints in this script." ) ending = f"_{i}" if i > 0 else "" safe_file = input_path / f"{SAFE_MODEL_NAME}{ending}.safetensors" if safe_file.exists(): load_model(model, safe_file, strict=bool(strict), device=device_str) continue bin_file = input_path / f"{MODEL_NAME}{ending}.bin" if bin_file.exists(): state_dict = accelerate_load(bin_file, map_location=map_location) model.load_state_dict(state_dict, strict=bool(strict)) continue raise FileNotFoundError( f"Could not find model weights for model index {i} under {input_dir}. " f"Tried: {safe_file.name} and {bin_file.name}" ) @torch.no_grad() def draw_data_conditional_entropy( trie_root: TrieNode | None, *, idx: torch.Tensor | None = None, accelerator: Accelerator, save_dir: str, global_step: int, log_base: float = 2.0, finalize: bool = False, rFID: float = 0.0, gFID: float = 0.0, max_depth: int = 0, codebook_size: int | None = None, ) -> TrieNode | None: """Trie builder for data conditional entropy: idx chunks until ``finalize=True``, then plot + wandb log.""" if not finalize: if idx is None: return trie_root idx_all = accelerator.gather(idx.detach()) # [B_total, L] if accelerator.is_main_process: if trie_root is None: trie_root = TrieNode() L = int(idx_all.shape[1]) if max_depth <= 0: max_depth = L idx_cpu = idx_all.cpu().tolist() for seq in idx_cpu: trie_insert(trie_root, seq, max_depth=max_depth) return trie_root accelerator.wait_for_everyone() if not accelerator.is_main_process or trie_root is None: return trie_root if max_depth <= 0: max_depth = 256 H_cond, num_contexts = trie_conditional_entropy_all_positions( trie_root, max_depth=max_depth, log_base=log_base ) if len(H_cond) == 0: return trie_root out_dir = Path(save_dir) / "analysis_ae" / f"Step={global_step+1}-rFID={rFID:.4f}-gFID={gFID:.4f}" out_dir.mkdir(exist_ok=True, parents=True) fig_path = out_dir / "ae_data_conditional_entropy.png" saved = plot_data_conditional_entropy( out_path=str(fig_path), H_cond=H_cond, log_base=log_base, codebook_size=codebook_size, title_prefix="AE Data Conditional Entropy", ) if saved and wandb is not None: accelerator.log( {"analysis/ae_data_conditional_entropy": wandb.Image(str(fig_path)), "global_step": global_step + 1}, step=global_step+1 ) if len(H_cond) > 0: mean_H = float(np.nanmean(np.array(H_cond, dtype=np.float64))) accelerator.log( {"analysis/ae_data_cond_entropy_mean": mean_H, "global_step": global_step + 1}, step=global_step+1 ) return trie_root @torch.no_grad() def draw_conditional_entropy( acc: dict, *, logits: torch.Tensor | None = None, accelerator: Accelerator, save_dir: str, global_step: int, log_base: float = 2.0, finalize: bool = False, rFID: float = 0.0, gFID: float = 0.0, codebook_size: int | None = None, ) -> None: """Accumulate logit entropy from existing logits; ``finalize=True`` reduces + plots + wandb-logs.""" device = accelerator.device if not finalize: if logits is None: return ent = _entropy_from_logits(logits, log_base=log_base) # [B, L] L = int(ent.shape[1]) if acc.get("ent_sum") is None: acc["ent_sum"] = torch.zeros(L, dtype=ent.dtype, device=device) acc["ent_cnt"] = torch.zeros(L, dtype=torch.long, device=device) elif int(acc["ent_sum"].shape[0]) < L: pad = L - int(acc["ent_sum"].shape[0]) acc["ent_sum"] = torch.cat( [acc["ent_sum"], torch.zeros(pad, dtype=acc["ent_sum"].dtype, device=device)], dim=0, ) acc["ent_cnt"] = torch.cat( [acc["ent_cnt"], torch.zeros(pad, dtype=torch.long, device=device)], dim=0, ) acc["ent_sum"][:L] += ent.sum(dim=0) acc["ent_cnt"][:L] += int(ent.shape[0]) return # finalize mode accelerator.wait_for_everyone() if acc.get("ent_sum") is not None and acc.get("ent_cnt") is not None: acc["ent_sum"] = accelerator.reduce(acc["ent_sum"], reduction='sum') acc["ent_cnt"] = accelerator.reduce(acc["ent_cnt"], reduction='sum') if not accelerator.is_main_process: return H = None if acc.get("ent_sum") is not None and acc.get("ent_cnt") is not None: denom = acc["ent_cnt"].clamp(min=1).to(acc["ent_sum"].dtype) H = (acc["ent_sum"] / denom).detach().cpu().tolist() out_dir = Path(save_dir) / "analysis_ar" / f"Step={global_step+1}-rFID={rFID:.4f}-gFID={gFID:.4f}" out_dir.mkdir(exist_ok=True, parents=True) fig_path = out_dir / "ar_prefix_conditional_entropy.png" saved = plot_ar_prefix_conditional_entropy( out_path=str(fig_path), H=H, log_base=log_base, codebook_size=codebook_size, ) if saved: accelerator.log({"analysis/ar_prefix_conditional_entropy": wandb.Image(str(fig_path)), "global_step": global_step + 1}, step=global_step+1) if H is not None and len(H) > 0: accelerator.log( { "analysis/ar_entropy_mean_per_pos": float(np.nanmean(np.array(H, dtype=np.float64))), "global_step": global_step + 1, }, step=global_step+1, ) def generate_uniform_labels( *, num_samples: int, num_classes: int, accelerator: Accelerator, exclude_uncond: bool = True, ) -> torch.Tensor: """Uniform class label indices for this rank (excludes uncond class by default).""" num_valid_classes = num_classes - 1 if exclude_uncond else num_classes all_classes = list(range(num_valid_classes)) * (num_samples // num_valid_classes + 1) all_classes = all_classes[:num_samples] all_classes_tensor = torch.tensor(all_classes, dtype=torch.long) rank = accelerator.process_index num_devices = accelerator.num_processes samples_per_rank = num_samples // num_devices start_idx = rank * samples_per_rank end_idx = start_idx + samples_per_rank if rank < num_devices - 1 else num_samples return all_classes_tensor[start_idx:end_idx].to(accelerator.device) class InfiniteIterator(Iterator): def __init__(self, iterable: Iterable, dl_generator: torch.Generator = None): self.iterable = iterable self.dl_generator = dl_generator self._pre_epoch_gen_state = ( dl_generator.get_state().clone() if dl_generator is not None else None ) self._it = iter(iterable) self.total_yielded = 0 def __iter__(self): return self def __next__(self): try: item = next(self._it) except StopIteration: if self.dl_generator is not None: self._pre_epoch_gen_state = self.dl_generator.get_state().clone() self._it = iter(self.iterable) item = next(self._it) self.total_yielded += 1 return item # ============================================================================ # File / Checkpoint Utils # ============================================================================ def safe_remove_file(path: str): """Remove a single file safely: warn on failure but never raise.""" try: if path is not None and os.path.exists(path): os.remove(path) except Exception as e: print(f"Warning: Failed to remove {path}: {e}") def save_training_state(accelerator, save_path, extra_state, **updates): """Save accelerator state + extra state for fully consistent resume.""" if updates: extra_state.update(updates) accelerator.save_state(save_path) if accelerator.is_main_process: torch.save(extra_state, os.path.join(save_path, "extra_state.pt")) accelerator.wait_for_everyone() def remove_old_best_checkpoints(ckpt_dir: str, metric_type: str = "gFID"): """Delete older ``best-*-{metric_type}=*`` checkpoints under ``ckpt_dir``.""" pattern = os.path.join(ckpt_dir, f"best-*-{metric_type}=*") old_best_ckpts = glob_module.glob(pattern) for old_ckpt in old_best_ckpts: try: if os.path.isdir(old_ckpt): shutil.rmtree(old_ckpt) print(f"Removed old best checkpoint: {os.path.basename(old_ckpt)}") except Exception as e: print(f"Warning: Failed to remove {old_ckpt}: {e}") # ============================================================================ # Image Processing Utils # ============================================================================ def patchify(x, patch_size): x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size) return x def img_uint8_to_norm(x): return x.float() / 127.5 - 1.0 def unpatchify(x, image_size, patch_size): x = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size, w=image_size//patch_size) return x def img_denormalize(x): return x.clamp(-1, 1) * 0.5 + 0.5 def img_norm_to_uint8(x): return torch.clamp(127.5 * x + 128.0, 0, 255).byte() # ============================================================================ # FID Utils # ============================================================================ def adm_fid_evaluator(sample_cached_path, gt_cache_path, config, accelerator: Accelerator, compute_is=False): if not os.path.exists(gt_cache_path): raise FileNotFoundError(f"Ground-truth cache not found: {gt_cache_path}") if not os.path.exists(sample_cached_path): raise FileNotFoundError(f"Sample cache not found: {sample_cached_path}") fid_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eval_fid.py') env = os.environ.copy() cmd = [sys.executable, fid_script, "--ref_batch", gt_cache_path, "--sample_batch", sample_cached_path, "--batch_size", str(config.eval_batch_size)] if compute_is: cmd.append("--compute_is") print0(f"Running FID evaluation via {fid_script}..." + (" (with IS)" if compute_is else "")) FID = 0.0 IS = 0.0 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, env=env) assert process.stdout is not None for line in process.stdout: line = line.rstrip("\n") if line: print0(line, flush=True) if line.startswith("FID_RESULT:"): try: FID = float(line.split("FID_RESULT:")[1].strip()) except ValueError: pass elif line.startswith("IS_RESULT:"): try: IS = float(line.split("IS_RESULT:")[1].strip()) except ValueError: pass retcode = process.wait() if retcode != 0 and FID == 0.: print0(f"eval_fid.py exited with code {retcode} and no FID_RESULT was parsed.") if compute_is: return FID, IS return FID # ============================================================================ # Training Utils # ============================================================================ @torch.no_grad() @torch._dynamo.disable def _unwrap(model): """Unwrap torch.compile / DDP wrappers to access the raw nn.Module.""" while hasattr(model, '_orig_mod'): model = model._orig_mod while hasattr(model, 'module'): model = model.module return model @torch.no_grad() @torch._dynamo.disable def ema_update(model, ema_model, ema_rate): if model is None or ema_model is None: return for p, ema_p in zip(model.parameters(), ema_model.parameters()): ema_p.copy_(p.detach().lerp(ema_p, ema_rate)) def sync_gradients(model, sub_modules=None): import torch.distributed as dist if not dist.is_initialized(): return params = [] if sub_modules is not None: for name in sub_modules: params.extend(getattr(model, name).parameters()) else: params = list(model.parameters()) for p in params: if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) def toggle_require_grad(model, grads=True, accelerator=None, sub_modules=None): if model is None: return if accelerator is not None: model = accelerator.unwrap_model(model) elif hasattr(model, "_orig_mod"): model = model._orig_mod if sub_modules is not None: for name in sub_modules: getattr(model, name).requires_grad_(grads) elif hasattr(model, "requires_grad_"): model.requires_grad_(grads) else: for p in model.parameters(): p.requires_grad_(grads) def toggle_train_eval(model, train=True, accelerator=None, sub_modules=None): if model is None: return if accelerator is not None: model = accelerator.unwrap_model(model) elif hasattr(model, "_orig_mod"): model = model._orig_mod if sub_modules is not None: for name in sub_modules: getattr(model, name).train(mode=train) elif hasattr(model, "train"): model.train(mode=train) def zero_nan_gradients(model, accelerator=None): if model is None: return if accelerator is not None: model = accelerator.unwrap_model(model) elif hasattr(model, "_orig_mod"): model = model._orig_mod for name, param in model.named_parameters(): if param.grad is not None: param.grad.nan_to_num_(nan=0.0, posinf=1e5, neginf=-1e5) def calc_grad_norm( named_models: dict, global_step: int, grad_norm_freq: int, accelerator=None, ) -> dict: """Per-parameter grad L2 norms as a flat dict for wandb (only on ``(step+1) % freq == 0``).""" if grad_norm_freq <= 0 or (global_step + 1) % grad_norm_freq != 0: return {} result = {} for group, model in named_models.items(): if model is None: continue # unwrap DDP wrapper, then strip torch.compile's OptimizedModule raw = accelerator.unwrap_model(model) if accelerator is not None else model while hasattr(raw, "_orig_mod"): raw = raw._orig_mod sq_sum = 0.0 for name, param in raw.named_parameters(): if param.grad is not None: pnorm = param.grad.norm().item() result[f"Gradient_Norm/{group}/{name}"] = pnorm sq_sum += pnorm ** 2 if sq_sum > 0.0: result[f"Gradient_Norm/{group}/_total"] = sq_sum ** 0.5 return result def save_tensor_image_png_pdf(tensor, png_path: str, dpi: float = 300.0) -> None: """Save ``[N, C, H, W]`` in [0, 1] to ``png_path`` and a sibling PDF (figure-friendly). """ import torchvision.utils torchvision.utils.save_image(tensor, png_path) pdf_path = os.path.splitext(png_path)[0] + ".pdf" from PIL import Image im = Image.open(png_path) if im.mode == "RGBA": bg = Image.new("RGB", im.size, (255, 255, 255)) bg.paste(im, mask=im.split()[3]) im.close() im = bg else: im = im.convert("RGB") im.save(pdf_path, "PDF", resolution=dpi) im.close()