prologue-demo / sample_vis.py
Bowen Zheng
init
500ee30
"""Class-conditional sampling and prologue-fix resampling for Prologue models.
Two modes:
sample Class-conditional generation grid (one row per class).
prologue_fix Sample a reference; freeze its first z_len prologue tokens
and resample the remaining visual tokens. Requires a Prologue
tokenizer (configs/tokenizer/prologue.yaml).
The functions below are also used as a library by `app.py` (Gradio demo).
CLI usage:
python sample_vis.py --configs=configs/default.yaml,configs/ar/_defaults.yaml,\\
configs/ar/xlarge.yaml,configs/tokenizer/default.yaml,\\
configs/tokenizer/prologue.yaml,configs/train/ar.yaml,configs/train/eval_ar.yaml \\
tokenizer_ckpt_path=<tok> resume_ckpt_path=<ar> \\
mode=prologue_fix class_ids="207,388" num_resample=8 \\
output_dir=out/
Module-internal attributes (``semantic_emb``, ``z_len``, ``semantic_drop``, ...) match
the released safetensors keys; "prologue" is the user-facing name everywhere else.
"""
import math
import os
import torch
import torch.nn.functional as F
import torchvision.utils
from safetensors.torch import load_file
from train_ar import (
_codes_from_indices,
_labels_from_label_idx,
_load_decoder,
_load_quantizer,
_load_semantic_quantizer,
)
from models import ARModel
from utils import (
build_ar_logit_mask,
img_denormalize,
load_config,
print0,
save_tensor_image_png_pdf,
seed_everything,
unpatchify,
)
torch.backends.cudnn.benchmark = True
IMAGENET_NAMES = {
33: "loggerhead_turtle", 88: "macaw", 90: "lorikeet", 94: "hummingbird",
100: "black_swan", 107: "jellyfish", 117: "chambered_nautilus", 130: "flamingo",
144: "pelican", 146: "albatross", 207: "golden_retriever", 250: "Siberian_husky",
259: "Pomeranian", 279: "arctic_fox", 281: "tabby_cat", 291: "lion",
292: "tiger", 293: "cheetah", 295: "brown_bear", 323: "monarch_butterfly",
340: "zebra", 360: "otter", 386: "African_elephant", 387: "red_panda",
388: "giant_panda", 417: "balloon", 628: "liner", 817: "sports_car",
927: "trifle", 928: "ice_cream", 930: "French_loaf", 933: "cheeseburger",
934: "hotdog", 963: "pizza", 971: "bubble", 972: "cliff", 973: "coral_reef",
978: "seashore", 979: "valley", 980: "volcano", 985: "daisy", 988: "acorn",
996: "alp",
}
DEFAULT_CLASS_IDS = [
207, 388, 387, 88, 130, 279, 417, 928,
980, 973, 985, 33, 360, 250, 293, 323,
]
# ═══════════════════════════════════════════════════════════════════════════
# Model loading
# ═══════════════════════════════════════════════════════════════════════════
def load_models(config, device):
"""Return ``(quantizer, dec, prologue_quantizer, ar_model)``; prologue_quantizer is ``None`` for 1D/2D tokenizers."""
prologue = (
bool(config.get("Prologue", False))
and not bool(config.get("share_semantic_codebook", False))
)
tok_ckpt = config.tokenizer_ckpt_path
quantizer = _load_quantizer(config, ckpt_dir=tok_ckpt).to(device)
dec = _load_decoder(config, ckpt_dir=tok_ckpt).to(device)
prologue_quantizer = None
if prologue:
prologue_quantizer = _load_semantic_quantizer(
config, ckpt_dir=tok_ckpt
).to(device)
print0("Prologue: loaded prologue (semantic) quantizer")
ar_model = ARModel(config)
_logit_mask = build_ar_logit_mask(
getattr(quantizer, "pos_select_mask", None),
getattr(prologue_quantizer, "pos_select_mask", None) if prologue else None,
vis_cb_size=int(config["Quantizer"]["codebook_size"]),
sem_cb_size=int(config["SemanticQuantizer"]["codebook_size"]) if prologue else 0,
)
ar_model.set_logit_mask(_logit_mask)
ema = bool(config.get("ema_sampling", False))
if bool(config.get("continuous_training", False)) and tok_ckpt:
# OneStage: AR weights live in the tokenizer ckpt as model_5/6.safetensors.
path = os.path.join(tok_ckpt, "model_6.safetensors" if ema else "model_5.safetensors")
if ema and not os.path.exists(path):
path = os.path.join(tok_ckpt, "model_5.safetensors")
print0("AR EMA not found, falling back to regular weights")
print0(f"Loading AR weights (joint training): {path}")
ar_model.load_state_dict(load_file(path), strict=True)
elif getattr(config, "resume_ckpt_path", ""):
ckpt = config.resume_ckpt_path
fname = "model_1.safetensors" if ema else "model.safetensors"
path = os.path.join(ckpt, fname)
if ema and not os.path.exists(path):
path = os.path.join(ckpt, "model.safetensors")
print0("AR EMA not found, falling back to regular weights")
print0(f"Loading AR weights: {path}")
ar_model.load_state_dict(load_file(path), strict=False)
else:
raise ValueError(
"Must provide resume_ckpt_path or "
"(continuous_training=True + tokenizer_ckpt_path)"
)
ar_model.to(device).eval()
print0(
f"AR model ready (z_len={ar_model.z_len}, "
f"max_length={ar_model.max_length})"
)
return quantizer, dec, prologue_quantizer, ar_model
# ═══════════════════════════════════════════════════════════════════════════
# Sampling helpers
# ═══════════════════════════════════════════════════════════════════════════
def _cfg_params(config):
"""Pack every CFG-related field into a kwargs dict for ARModel.sampling."""
get = config.get
return dict(
temperature=config.temperature,
topK=config.topK,
topP=config.topP,
cfg=config.cfg,
cfg_schedule=config.cfg_schedule,
cfg_power=config.cfg_power,
cache_kv=config.cache_kv,
semantic_cfg=get("semantic_cfg", None),
semantic_cfg_schedule=get("semantic_cfg_schedule", None),
semantic_cfg_scale=get("semantic_cfg_scale", None),
semantic_cfg_power=get("semantic_cfg_power", None),
semantic_cfg_start=float(get("semantic_cfg_start", 0.0)),
visual_cfg_schedule=get("visual_cfg_schedule", None),
visual_cfg_scale=get("visual_cfg_scale", None),
visual_cfg_power=get("visual_cfg_power", None),
visual_cfg_start=float(get("visual_cfg_start", 1.0)),
cfg_continuous=bool(get("cfg_continuous", False)),
semantic_temperature=(
float(get("semantic_temperature"))
if get("semantic_temperature") is not None
else None
),
)
@torch.no_grad()
def sample_tokens(ar_model, *, bz, class_label, config):
"""Thin wrapper over ``ARModel.sampling``; returns token ids ``[bz, max_length]``."""
get = config.get
sem_temp = get("semantic_temperature")
return ar_model.sampling(
bz, class_label,
config.temperature, config.topK, config.topP,
config.cfg, config.cfg_schedule, config.cfg_power,
config.cache_kv,
semantic_cfg_schedule=get("semantic_cfg_schedule", None),
semantic_cfg_scale=get("semantic_cfg_scale", None),
semantic_cfg_power=get("semantic_cfg_power", None),
semantic_cfg_start=float(get("semantic_cfg_start", 0.0)),
visual_cfg_schedule=get("visual_cfg_schedule", None),
visual_cfg_scale=get("visual_cfg_scale", None),
visual_cfg_power=get("visual_cfg_power", None),
visual_cfg_start=float(get("visual_cfg_start", 1.0)),
semantic_temperature=float(sem_temp) if sem_temp is not None else None,
)
def decode_tokens(quantizer, dec, token_ids, *, config, ae_label):
"""Decode token ids to images ``[B, 3, H, W]`` in [0, 1]."""
prologue = (
bool(config.get("Prologue", False))
and not bool(config.get("share_semantic_codebook", False))
)
if prologue:
z_len = int(config.z_len)
eos_len = 1 if bool(config.get("use_eos", False)) and z_len > 0 else 0
visual_ids = token_ids[:, z_len + eos_len:]
quant = _codes_from_indices(quantizer, visual_ids, ae_label)
else:
quant = _codes_from_indices(quantizer, token_ids, ae_label)
patches = dec(quant, ae_label)
return img_denormalize(unpatchify(patches, config.image_size, config.patch_size))
def _make_labels(class_id, bz, num_classes, device):
"""Build ``(class_label, unconditional_label)`` each ``[bz, num_classes]``."""
uncond_idx = num_classes - 1
idx = torch.full((bz,), class_id, device=device, dtype=torch.long)
cls = _labels_from_label_idx(idx, num_classes=num_classes, uncond_idx=uncond_idx)
uncond_idx_t = torch.full((bz,), uncond_idx, device=device, dtype=torch.long)
ae = _labels_from_label_idx(uncond_idx_t, num_classes=num_classes, uncond_idx=uncond_idx)
return cls, ae
# ═══════════════════════════════════════════════════════════════════════════
# Prologue-fix sampling (teacher-force masked prologue positions)
# ═══════════════════════════════════════════════════════════════════════════
@torch.no_grad()
@torch._dynamo.disable
def sampling_with_fixed_prologue(ar_model, *, bz, class_label, config,
fixed_prologue_ids, n_fix=None, fix_mask=None):
"""AR sample with selected prologue positions teacher-forced (``fix_mask`` overrides ``n_fix``)."""
m = ar_model
z_len = m.z_len
if z_len <= 0:
raise ValueError("sampling_with_fixed_prologue requires a Prologue tokenizer (z_len > 0)")
fixed = fixed_prologue_ids.expand(bz, -1)
if fix_mask is None:
n = n_fix if n_fix is not None else z_len
fix_mask = torch.zeros(z_len, dtype=torch.bool, device=fixed.device)
fix_mask[:n] = True
p = _cfg_params(config)
cfg_val = 0.0 if class_label is None else p["cfg"]
cfg_schedule = p["cfg_schedule"]
cfg_power = p["cfg_power"]
temperature = p["temperature"]
topK = p["topK"]
topP = p["topP"]
cache_kv = p["cache_kv"]
sem_cfg = p["semantic_cfg"]
sem_cfg_sched = p["semantic_cfg_schedule"]
sem_cfg_scale = p["semantic_cfg_scale"]
sem_cfg_pow = p["semantic_cfg_power"]
sem_cfg_start = p["semantic_cfg_start"]
vis_cfg_sched = p["visual_cfg_schedule"]
vis_cfg_scale = p["visual_cfg_scale"]
vis_cfg_pow = p["visual_cfg_power"]
vis_cfg_start = p["visual_cfg_start"]
cfg_cont = p["cfg_continuous"]
sem_temp = p["semantic_temperature"]
use_seg = sem_cfg_sched is not None or vis_cfg_sched is not None
use_cfg = cfg_val > 0.0 or (sem_cfg is not None and sem_cfg > 0.0)
if use_seg:
_ss = sem_cfg_scale if sem_cfg_scale is not None else cfg_val
_vs = vis_cfg_scale if vis_cfg_scale is not None else cfg_val
use_cfg = _ss > 0.0 or _vs > 0.0
uncond_idx = int(m.ar_model.cond_input_dim) - 1
device = m.bos_emb.weight.device
if m.conditional_injection == "llamagen":
cond_bos = m.bos_emb(torch.argmax(class_label, dim=1)).unsqueeze(1)
uncond_bos = m.bos_emb(
torch.full((bz,), uncond_idx, device=device, dtype=torch.long)
).unsqueeze(1)
ar_labels = m.uncond_ar_labels.expand(bz, -1).to(device=device)
uncond_labels = m.uncond_ar_labels.expand(bz, -1).to(device=device)
else:
cond_bos = m.bos_emb(
torch.zeros(bz, device=device, dtype=torch.long)
).unsqueeze(1)
uncond_bos = cond_bos
ar_labels = class_label
uncond_labels = m.uncond_ar_labels.expand(bz, -1).to(device=device)
quant_input = (
torch.cat([cond_bos, uncond_bos], dim=0) if use_cfg else cond_bos
)
ar_labels_2x = (
torch.cat([ar_labels, uncond_labels], dim=0) if use_cfg else ar_labels
)
quant_output = []
past_kvs = None
for step in range(m.max_length):
is_sem = z_len > 0 and step < z_len
if use_cfg:
ar_out = m.ar_model(quant_input, ar_labels_2x, cache_kv=cache_kv, past_kvs=past_kvs)
hidden_all, past_kvs = ar_out if cache_kv else (ar_out, None)
if m.tied_embedding:
hidden_all = F.linear(hidden_all[:, -1:], m.semantic_emb.weight)
logits_all = hidden_all[:, -1]
logits, uncond_logits = logits_all.chunk(2, dim=0)
# Scheduled CFG strength c(step), identical to ARModel.sampling.
if use_seg:
if is_sem:
sc = sem_cfg_sched or "constant"
ss = sem_cfg_scale if sem_cfg_scale is not None else cfg_val
sp = sem_cfg_pow if sem_cfg_pow is not None else cfg_power
s0 = sem_cfg_start
st = step / z_len if z_len > 0 else 0.0
else:
sc = vis_cfg_sched or "constant"
ss = vis_cfg_scale if vis_cfg_scale is not None else cfg_val
sp = vis_cfg_pow if vis_cfg_pow is not None else cfg_power
s0 = (
(sem_cfg_scale if sem_cfg_scale is not None else cfg_val)
if cfg_cont else vis_cfg_start
)
vl = m.max_length - z_len
st = (step - z_len) / vl if vl > 0 else 0.0
if sc == "constant":
c = ss
elif sc == "linear":
c = s0 + (ss - s0) * st
elif sc == "cosine":
c = s0 + (ss - s0) * (1 - math.cos((st ** sp) * math.pi)) * 0.5
else:
raise ValueError(sc)
elif cfg_schedule == "constant" and is_sem and sem_cfg is not None:
c = sem_cfg
elif cfg_schedule == "constant":
c = cfg_val
elif cfg_schedule == "linear":
c = 1.0 * (1 - step / m.max_length) + cfg_val * (step / m.max_length)
elif cfg_schedule == "cosine":
c = (1 - math.cos(((step / m.max_length) ** cfg_power) * math.pi)) * 0.5
c = (cfg_val - 1) * c + 1
else:
raise ValueError(cfg_schedule)
logits = c * logits + (1 - c) * uncond_logits
else:
ar_out = m.ar_model(quant_input, ar_labels_2x, cache_kv=cache_kv, past_kvs=past_kvs)
hidden, past_kvs = ar_out if cache_kv else (ar_out, None)
if m.tied_embedding:
hidden = F.linear(hidden[:, -1:], m.semantic_emb.weight)
logits = hidden[:, -1]
t = sem_temp if (sem_temp is not None and is_sem) else temperature
logits = logits / t
if m.logit_mask is not None:
logits = logits + m.logit_mask[step]
if topK is not None and topK > 0.0:
tl, ti = logits.topk(int(topK), dim=-1)
logits = torch.full_like(logits, float("-inf"))
logits.scatter_(dim=-1, index=ti, src=tl)
if topP is not None and 0.0 < topP < 1.0:
sl, si = torch.sort(logits, dim=-1, descending=True)
ps = sl.softmax(dim=-1).cumsum(dim=-1)
mask = ps > topP
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
sl[mask] = float("-inf")
logits = torch.full_like(logits, float("-inf"))
logits.scatter_(dim=-1, index=si, src=sl)
if step < z_len and fix_mask[step]:
next_idx = fixed[:, step : step + 1]
else:
with torch.amp.autocast("cuda", enabled=False):
next_idx = torch.multinomial(F.softmax(logits.float(), dim=-1), 1)
next_idx = next_idx.to(dtype=torch.long)
quant_output.append(next_idx)
next_emb = m.semantic_emb(next_idx)
if use_cfg:
# semantic_drop branch kept for legacy ckpts; default = symmetric uncond.
if getattr(m, "semantic_drop", False) and is_sem:
uncond_sem = torch.full_like(next_idx, m.uncond_sem_token_id)
uncond_emb = m.semantic_emb(uncond_sem)
next_emb = torch.cat([next_emb, uncond_emb], dim=0)
else:
next_emb = torch.cat([next_emb, next_emb], dim=0)
if not cache_kv:
quant_input = torch.cat((quant_input, next_emb), dim=1)
else:
quant_input = next_emb
return torch.cat(quant_output, dim=1)
# ═══════════════════════════════════════════════════════════════════════════
# CLI mode implementations
# ═══════════════════════════════════════════════════════════════════════════
def mode_sample(quantizer, dec, ar_model, *, config, class_ids, num_per_class,
output_dir, device):
"""One row per class, ``num_per_class`` independent samples."""
num_classes = int(config.num_classes)
ae_no_label = bool(config.get("ae_no_label", False))
all_imgs = []
for cid in class_ids:
name = IMAGENET_NAMES.get(cid, f"class{cid}")
print0(f" Sampling class {cid} ({name}) -> {num_per_class} ...")
cls_lbl, uncond_lbl = _make_labels(cid, num_per_class, num_classes, device)
ae_lbl = uncond_lbl if ae_no_label else cls_lbl
token_ids = sample_tokens(ar_model, bz=num_per_class,
class_label=cls_lbl, config=config)
imgs = decode_tokens(quantizer, dec, token_ids,
config=config, ae_label=ae_lbl)
all_imgs.append(imgs)
grid = torchvision.utils.make_grid(imgs, nrow=num_per_class, padding=2)
out = os.path.join(output_dir, f"class_{cid}_{name}.png")
save_tensor_image_png_pdf(grid, out)
print0(f" -> {out} (+ .pdf)")
combined = torch.cat(all_imgs, dim=0)
grid = torchvision.utils.make_grid(combined, nrow=num_per_class, padding=2)
out = os.path.join(output_dir, "sample_grid.png")
save_tensor_image_png_pdf(grid, out)
print0(f" Combined grid -> {out} (+ .pdf)")
def mode_prologue_fix(quantizer, dec, ar_model, *, config, class_ids,
num_resample, num_prologue_sets, output_dir, device):
"""Per class: ``num_prologue_sets`` refs Γ— ``num_resample`` visuals with fixed prologue prefix."""
prologue = (
bool(config.get("Prologue", False))
and not bool(config.get("share_semantic_codebook", False))
)
if not prologue:
raise ValueError(
"prologue_fix mode requires a Prologue tokenizer "
"(Prologue=True, share_semantic_codebook=False)"
)
num_classes = int(config.num_classes)
ae_no_label = bool(config.get("ae_no_label", False))
z_len = ar_model.z_len
all_class_imgs = []
for cid in class_ids:
name = IMAGENET_NAMES.get(cid, f"class{cid}")
print0(f" Prologue-fix: class {cid} ({name}), "
f"{num_prologue_sets} set(s) -> {num_resample} visual resamples ...")
cls_lbl_1, uncond_lbl_1 = _make_labels(cid, 1, num_classes, device)
cls_lbl_n, uncond_lbl_n = _make_labels(cid, num_resample, num_classes, device)
ae_lbl_n = uncond_lbl_n if ae_no_label else cls_lbl_n
rows = []
for _ in range(num_prologue_sets):
token_ids = sample_tokens(ar_model, bz=1,
class_label=cls_lbl_1, config=config)
prologue_ids = token_ids[:, :z_len]
resampled = sampling_with_fixed_prologue(
ar_model, bz=num_resample, class_label=cls_lbl_n,
config=config, fixed_prologue_ids=prologue_ids,
)
imgs = decode_tokens(quantizer, dec, resampled,
config=config, ae_label=ae_lbl_n)
rows.append(imgs)
grid_imgs = torch.cat(rows, dim=0)
all_class_imgs.append(grid_imgs)
grid = torchvision.utils.make_grid(grid_imgs, nrow=num_resample, padding=2)
out = os.path.join(output_dir, f"prologue_fix_{cid}_{name}.png")
save_tensor_image_png_pdf(grid, out)
print0(f" -> {out} (+ .pdf)")
combined = torch.cat(all_class_imgs, dim=0)
grid = torchvision.utils.make_grid(combined, nrow=num_resample, padding=2)
out = os.path.join(output_dir, "all_prologue_fix.png")
save_tensor_image_png_pdf(grid, out)
print0(f" Combined grid -> {out} (+ .pdf)")
# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════
def main():
config = load_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config.get("use_tf32", True):
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
seed = int(config.get("seed", 42))
seed_everything(seed)
torch.cuda.manual_seed(seed)
mode = str(config.get("mode", "sample"))
output_dir = str(config.get("output_dir", "sample_vis_output"))
os.makedirs(output_dir, exist_ok=True)
raw_ids = str(config.get("class_ids", ""))
class_ids = (
[int(x.strip()) for x in raw_ids.split(",") if x.strip()]
if raw_ids else DEFAULT_CLASS_IDS
)
print0(f"Mode : {mode}")
print0(f"Class IDs : {class_ids}")
print0(f"Output dir : {output_dir}")
print0(f"Seed : {seed}")
quantizer, dec, _, ar_model = load_models(config, device)
if mode == "sample":
num_per_class = int(config.get("num_per_class", 8))
mode_sample(quantizer, dec, ar_model, config=config,
class_ids=class_ids, num_per_class=num_per_class,
output_dir=output_dir, device=device)
elif mode == "prologue_fix":
num_resample = int(config.get("num_resample", 8))
num_prologue_sets = int(config.get("num_prologue_sets", 1))
mode_prologue_fix(quantizer, dec, ar_model, config=config,
class_ids=class_ids, num_resample=num_resample,
num_prologue_sets=num_prologue_sets,
output_dir=output_dir, device=device)
else:
raise ValueError(
f"Unknown mode: {mode!r} (expected 'sample' or 'prologue_fix')"
)
print0("Done.")
if __name__ == "__main__":
main()