VARestorer / tools /run_varestorer.py
YixuanEvan's picture
add HF model card and mirror runnable codebase
7f7272e
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import os.path as osp
from typing import List
from contextlib import nullcontext
import math
import hashlib
import yaml
import argparse
import shutil
import re
import cv2
import numpy as np
import torch
torch._dynamo.config.cache_size_limit=64
import pandas as pd
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
from PIL import Image, ImageEnhance
import torch.nn.functional as F
from torch.cuda.amp import autocast
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from infinity.models.infinity import Infinity,BInfinity
from infinity.models.basic import *
import PIL.Image as PImage
from torchvision.transforms.functional import to_tensor
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
import pdb
from torchvision import transforms
from infinity.models.bitwise_self_correction import BitwiseSelfCorrection
from infinity.models.swinir import SwinIR
import importlib
from lora_diffusion import inject_trainable_lora
from transformers import BlipForConditionalGeneration,BlipProcessor
def get_obj_from_str(string: str, reload: bool = False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def extract_key_val(text):
pattern = r'<(.+?):(.+?)>'
matches = re.findall(pattern, text)
key_val = {}
for match in matches:
key_val[match[0]] = match[1].lstrip()
return key_val
def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
if enable_positive_prompt:
print(f'before positive_prompt aug: {prompt}')
prompt = aug_with_positive_prompt(prompt)
print(f'after positive_prompt aug: {prompt}')
# print(f'prompt={prompt}')
captions = [prompt]
tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
input_ids = tokens.input_ids.cuda(non_blocking=True)
mask = tokens.attention_mask.cuda(non_blocking=True)
text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
lens: List[int] = mask.sum(dim=-1).tolist()
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
Ltext = max(lens)
kv_compact = []
for len_i, feat_i in zip(lens, text_features.unbind(0)):
kv_compact.append(feat_i[:len_i])
kv_compact = torch.cat(kv_compact, dim=0)
text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
return text_cond_tuple
def aug_with_positive_prompt(prompt):
for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
if key in prompt:
prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
break
return prompt
def enhance_image(image):
for t in range(1):
contrast_image = image.copy()
contrast_enhancer = ImageEnhance.Contrast(contrast_image)
contrast_image = contrast_enhancer.enhance(1.05)
color_image = contrast_image.copy()
color_enhancer = ImageEnhance.Color(color_image)
color_image = color_enhancer.enhance(1.05)
return color_image
def load_swinir_model(device, swinir=None):
if swinir is not None:
return swinir
swinir_config = {
"target": "infinity.models.swinir.SwinIR",
"params": {
"img_size": 64,
"patch_size": 1,
"in_chans": 3,
"embed_dim": 180,
"depths": [6, 6, 6, 6, 6, 6, 6, 6],
"num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
"window_size": 8,
"mlp_ratio": 2,
"sf": 8,
"img_range": 1.0,
"upsampler": "nearest+conv",
"resi_connection": "1conv",
"unshuffle": True,
"unshuffle_scale": 8
}
}
swinir = instantiate_from_config(swinir_config)
sd = torch.load('weights/general_swinir_v1.ckpt', map_location="cpu")
if "state_dict" in sd:
sd = sd["state_dict"]
sd = {
(k[len("module."):] if k.startswith("module.") else k): v
for k, v in sd.items()
}
swinir.load_state_dict(sd, strict=True)
for p in swinir.parameters():
p.requires_grad = False
swinir.eval().to(device)
return swinir
def pil_to_lq_tensor(pil_img, device):
transform = transforms.ToTensor()
lq_img = transform(pil_img)
lq_img = lq_img * 2 - 1
return lq_img.unsqueeze(0).to(device, non_blocking=True)
def run_swinir(lq_img, swinir):
lq_img = (lq_img + 1) / 2
lq_img = swinir(lq_img)
lq_img = lq_img + lq_img - 1
return lq_img
def prepare_prompt_conditions(
text_tokenizer,
text_encoder,
prompt,
negative_prompt='',
enable_positive_prompt=0,
):
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
if negative_prompt:
negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
else:
negative_label_B_or_BLT = None
return text_cond_tuple, negative_label_B_or_BLT
def gaussian_tile_weights(tile_height, tile_width):
from numpy import exp, pi, sqrt
var = 0.01
x_mid = (tile_width - 1) / 2
y_mid = (tile_height - 1) / 2
x_probs = [exp(-(x - x_mid) * (x - x_mid) / (tile_width * tile_width) / (2 * var)) / sqrt(2 * pi * var) for x in range(tile_width)]
y_probs = [exp(-(y - y_mid) * (y - y_mid) / (tile_height * tile_height) / (2 * var)) / sqrt(2 * pi * var) for y in range(tile_height)]
weights = np.outer(y_probs, x_probs).astype(np.float32)
weights /= weights.max()
return weights[..., None]
def compute_tile_positions(length, tile_size, tile_overlap):
if length <= tile_size:
return [0]
stride = tile_size - tile_overlap
if stride <= 0:
raise ValueError(f'tile_overlap must be smaller than tile_size, got {tile_overlap=} {tile_size=}')
positions = list(range(0, length - tile_size + 1, stride))
if positions[-1] != length - tile_size:
positions.append(length - tile_size)
return positions
def crop_tile_with_padding(image_np, top, left, tile_size):
bottom = min(top + tile_size, image_np.shape[0])
right = min(left + tile_size, image_np.shape[1])
tile = image_np[top:bottom, left:right]
pad_bottom = tile_size - tile.shape[0]
pad_right = tile_size - tile.shape[1]
if pad_bottom > 0 or pad_right > 0:
try:
tile = np.pad(tile, ((0, pad_bottom), (0, pad_right), (0, 0)), mode='reflect')
except ValueError:
tile = np.pad(tile, ((0, pad_bottom), (0, pad_right), (0, 0)), mode='edge')
return Image.fromarray(tile), bottom - top, right - left
def pad_image_for_tiling(image_np, border_pad):
if border_pad <= 0:
return image_np, 0
try:
padded = np.pad(image_np, ((border_pad, border_pad), (border_pad, border_pad), (0, 0)), mode='reflect')
except ValueError:
padded = np.pad(image_np, ((border_pad, border_pad), (border_pad, border_pad), (0, 0)), mode='edge')
return padded, border_pad
def resize_image_to_sr_scale(pil_img, sr_scale):
if sr_scale <= 0:
raise ValueError(f'sr_scale must be positive, got {sr_scale}')
width, height = pil_img.size
target_width = max(1, int(round(width * sr_scale)))
target_height = max(1, int(round(height * sr_scale)))
if (target_width, target_height) != (width, height):
pil_img = pil_img.resize((target_width, target_height), resample=PImage.LANCZOS)
return pil_img, target_width, target_height
def resize_output_tensor(image_tensor, target_width, target_height):
image_np = image_tensor.detach().cpu().numpy().astype(np.uint8)
resized = Image.fromarray(image_np).resize((target_width, target_height), resample=PImage.LANCZOS)
return torch.from_numpy(np.asarray(resized))
def infer_single_sr_tile(
infinity_test,
vae,
scale_schedule,
text_cond_tuple,
negative_label_B_or_BLT,
lq_img,
cfg_list=[],
tau_list=[],
top_k=900,
top_p=0.97,
cfg_sc=3,
cfg_exp_k=0.0,
cfg_insertion_layer=-5,
vae_type=0,
gumbel=0,
softmax_merge_topk=-1,
gt_leak=-1,
gt_ls_Bl=None,
g_seed=None,
sampling_per_bits=1,
args=None,
bitwise_self_correction=None,
requant_mode='flip',
):
device = lq_img.device
with torch.amp.autocast('cuda', enabled=False):
if infinity_test.apply_spatial_patchify:
vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
else:
vae_scale_schedule = scale_schedule
raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)
if bitwise_self_correction is None:
if args is None:
args = globals().get('args', None)
bitwise_self_correction = BitwiseSelfCorrection(vae, args)
if requant_mode == 'flip':
x_BLC_wo_prefix_lq, _ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
elif requant_mode == 'long':
x_BLC_wo_prefix_lq, _ = bitwise_self_correction.long_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
else:
raise ValueError(f'Unsupported requant_mode: {requant_mode}')
if not isinstance(cfg_list, list):
cfg_list = [cfg_list] * len(scale_schedule)
if not isinstance(tau_list, list):
tau_list = [tau_list] * len(scale_schedule)
autocast_ctx = torch.cuda.amp.autocast(enabled=device.type == 'cuda', dtype=torch.bfloat16, cache_enabled=True) if device.type == 'cuda' else nullcontext()
with autocast_ctx:
_, _, img_list = infinity_test.autoregressive_infer_cfg(
vae=vae,
scale_schedule=scale_schedule,
label_B_or_BLT=text_cond_tuple,
g_seed=g_seed,
B=1,
negative_label_B_or_BLT=negative_label_B_or_BLT,
force_gt_Bhw=None,
cfg_sc=cfg_sc,
cfg_list=cfg_list,
tau_list=tau_list,
top_k=top_k,
top_p=top_p,
returns_vemb=1,
ratio_Bl1=None,
gumbel=gumbel,
norm_cfg=False,
cfg_exp_k=cfg_exp_k,
cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type,
softmax_merge_topk=softmax_merge_topk,
ret_img=True,
trunk_scale=1000,
gt_leak=gt_leak,
gt_ls_Bl=gt_ls_Bl,
inference_mode=True,
sampling_per_bits=sampling_per_bits,
x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
)
return img_list[0]
def gen_one_img(
infinity_test,
vae,
text_tokenizer,
text_encoder,
prompt,
cfg_list=[],
tau_list=[],
negative_prompt='',
scale_schedule=None,
top_k=900,
top_p=0.97,
cfg_sc=3,
cfg_exp_k=0.0,
cfg_insertion_layer=-5,
vae_type=0,
gumbel=0,
softmax_merge_topk=-1,
gt_leak=-1,
gt_ls_Bl=None,
g_seed=None,
sampling_per_bits=1,
enable_positive_prompt=0,
lq_img_path='',
gt_img_path='',
args=None,
swinir=None,
bitwise_self_correction=None,
):
lq_img = Image.open(lq_img_path)
if lq_img.mode != "RGB":
lq_img = lq_img.convert("RGB")
lq_img = lq_img.resize((512,512))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
swinir = load_swinir_model(device, swinir=swinir)
lq_img = pil_to_lq_tensor(lq_img, device)
lq_img = run_swinir(lq_img, swinir)
text_cond_tuple, negative_label_B_or_BLT = prepare_prompt_conditions(
text_tokenizer,
text_encoder,
prompt,
negative_prompt=negative_prompt,
enable_positive_prompt=enable_positive_prompt,
)
print(f'cfg: {cfg_list}, tau: {tau_list}')
img = infer_single_sr_tile(
infinity_test,
vae,
scale_schedule,
text_cond_tuple,
negative_label_B_or_BLT,
lq_img,
cfg_list=cfg_list,
tau_list=tau_list,
top_k=top_k,
top_p=top_p,
cfg_sc=cfg_sc,
cfg_exp_k=cfg_exp_k,
cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type,
gumbel=gumbel,
softmax_merge_topk=softmax_merge_topk,
gt_leak=gt_leak,
gt_ls_Bl=gt_ls_Bl,
g_seed=g_seed,
sampling_per_bits=sampling_per_bits,
args=args,
bitwise_self_correction=bitwise_self_correction,
)
return img
@torch.no_grad()
def gen_one_img_anyres(
infinity_test,
vae,
text_tokenizer,
text_encoder,
prompt,
cfg_list=[],
tau_list=[],
negative_prompt='',
scale_schedule=None,
top_k=900,
top_p=0.97,
cfg_sc=3,
cfg_exp_k=0.0,
cfg_insertion_layer=-5,
vae_type=0,
gumbel=0,
softmax_merge_topk=-1,
gt_leak=-1,
gt_ls_Bl=None,
g_seed=None,
sampling_per_bits=1,
enable_positive_prompt=0,
lq_img_path='',
gt_img_path='',
args=None,
swinir=None,
bitwise_self_correction=None,
tile_size=512,
tile_overlap=128,
sr_scale=1.0,
tiled=0,
):
if tile_size != 512:
raise ValueError(f'Current B inference pipeline expects tile_size=512, got {tile_size}')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
swinir = load_swinir_model(device, swinir=swinir)
if bitwise_self_correction is None:
if args is None:
args = globals().get('args', None)
bitwise_self_correction = BitwiseSelfCorrection(vae, args)
lq_img = Image.open(lq_img_path)
if lq_img.mode != "RGB":
lq_img = lq_img.convert("RGB")
original_width, original_height = lq_img.size
lq_img, target_width, target_height = resize_image_to_sr_scale(lq_img, sr_scale)
print(f'input size: {original_width}x{original_height}, sr_scale: {sr_scale}, target size: {target_width}x{target_height}, tiled: {tiled}')
text_cond_tuple, negative_label_B_or_BLT = prepare_prompt_conditions(
text_tokenizer,
text_encoder,
prompt,
negative_prompt=negative_prompt,
enable_positive_prompt=enable_positive_prompt,
)
width, height = lq_img.size
if not tiled:
print(f'use resized single-pass inference for {width}x{height}')
resized_lq = lq_img.resize((tile_size, tile_size), resample=PImage.LANCZOS)
lq_tensor = pil_to_lq_tensor(resized_lq, device)
lq_tensor = run_swinir(lq_tensor, swinir)
img = infer_single_sr_tile(
infinity_test,
vae,
scale_schedule,
text_cond_tuple,
negative_label_B_or_BLT,
lq_tensor,
cfg_list=cfg_list,
tau_list=tau_list,
top_k=top_k,
top_p=top_p,
cfg_sc=cfg_sc,
cfg_exp_k=cfg_exp_k,
cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type,
gumbel=gumbel,
softmax_merge_topk=softmax_merge_topk,
gt_leak=gt_leak,
gt_ls_Bl=gt_ls_Bl,
g_seed=g_seed,
sampling_per_bits=sampling_per_bits,
args=args,
bitwise_self_correction=bitwise_self_correction,
requant_mode='long',
)
if width != tile_size or height != tile_size:
img = resize_output_tensor(img, width, height)
return img
image_np = np.asarray(lq_img)
image_np, border_pad = pad_image_for_tiling(image_np, tile_overlap)
padded_height, padded_width = image_np.shape[:2]
xs = compute_tile_positions(padded_width, tile_size, tile_overlap)
ys = compute_tile_positions(padded_height, tile_size, tile_overlap)
tile_weights = gaussian_tile_weights(tile_size, tile_size)
preds = np.zeros((padded_height, padded_width, 3), dtype=np.float32)
contributors = np.zeros((padded_height, padded_width, 1), dtype=np.float32)
total_tiles = len(xs) * len(ys)
tile_index = 0
for top in ys:
for left in xs:
tile_index += 1
tile_pil, valid_h, valid_w = crop_tile_with_padding(image_np, top, left, tile_size)
tile_tensor = pil_to_lq_tensor(tile_pil, device)
tile_tensor = run_swinir(tile_tensor, swinir)
tile_img = infer_single_sr_tile(
infinity_test,
vae,
scale_schedule,
text_cond_tuple,
negative_label_B_or_BLT,
tile_tensor,
cfg_list=cfg_list,
tau_list=tau_list,
top_k=top_k,
top_p=top_p,
cfg_sc=cfg_sc,
cfg_exp_k=cfg_exp_k,
cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type,
gumbel=gumbel,
softmax_merge_topk=softmax_merge_topk,
gt_leak=gt_leak,
gt_ls_Bl=gt_ls_Bl,
g_seed=g_seed,
sampling_per_bits=sampling_per_bits,
args=args,
bitwise_self_correction=bitwise_self_correction,
requant_mode='long',
)
tile_img = tile_img.detach().cpu().numpy().astype(np.float32)
valid_weights = tile_weights[:valid_h, :valid_w]
preds[top:top + valid_h, left:left + valid_w] += tile_img[:valid_h, :valid_w] * valid_weights
contributors[top:top + valid_h, left:left + valid_w] += valid_weights
print(f'processed tile {tile_index}/{total_tiles}: top={top}, left={left}, size={valid_h}x{valid_w}')
preds /= np.clip(contributors, 1e-8, None)
preds = np.clip(np.rint(preds), 0, 255).astype(np.uint8)
if border_pad > 0:
preds = preds[border_pad:border_pad + height, border_pad:border_pad + width]
return torch.from_numpy(preds)
@torch.no_grad()
def gen_one_img_eval(
infinity_test,
vae,
text_tokenizer,
text_encoder,
prompt,
cfg_list=[],
tau_list=[],
negative_prompt='',
scale_schedule=None,
top_k=900,
top_p=0.97,
cfg_sc=3,
cfg_exp_k=0.0,
cfg_insertion_layer=-5,
vae_type=0,
gumbel=0,
softmax_merge_topk=-1,
gt_leak=-1,
gt_ls_Bl=None,
g_seed=None,
sampling_per_bits=1,
enable_positive_prompt=0,
lq_img_path='',
args=None,
blip_model=None,
blip_processor=None,
):
lq_img = Image.open(lq_img_path)
if lq_img.mode != "RGB":
lq_img = lq_img.convert("RGB")
# if scale_schedule[-1][-1]==16:
# lq_img = lq_img.resize((256,256))
lq_img = lq_img.resize((512,512))
transform = transforms.ToTensor()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lq_img = transform(lq_img)
lq_img = lq_img*2-1
lq_img = lq_img.unsqueeze(0).to(device, non_blocking=True)
##### swinir
lq_img = (lq_img+1)/2
lq_img = swinir(lq_img)
lq_img = lq_img + lq_img -1
##### swinir
#####blip
if not prompt:
raw_image = Image.open(lq_img_path).convert('RGB')
inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = blip_model.generate(**inputs)
prompt = blip_processor.decode(out[0], skip_special_tokens=True)
#####
with torch.amp.autocast('cuda', enabled=False):
with torch.no_grad():
if infinity_test.apply_spatial_patchify:
vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
else:
vae_scale_schedule = scale_schedule
# raw_features, _, _ = vae.encode_for_raw_features(gt_img, scale_schedule=vae_scale_schedule)
raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)
#####need to change
bitwise_self_correction= BitwiseSelfCorrection(vae, args)
x_BLC_wo_prefix_lq,_ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
if not isinstance(cfg_list, list):
cfg_list = [cfg_list] * len(scale_schedule)
if not isinstance(tau_list, list):
tau_list = [tau_list] * len(scale_schedule)
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
if negative_prompt:
negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
else:
negative_label_B_or_BLT = None
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
### single step
_, _, img_list = infinity_test.autoregressive_infer_cfg(
vae=vae,
scale_schedule=scale_schedule,
label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
ret_img=True, trunk_scale=1000,
gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
sampling_per_bits=sampling_per_bits,
x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
)
# ###
img = img_list[0]
return img,prompt
@torch.no_grad()
def gen_one_img_eval_long(
infinity_test,
vae,
text_tokenizer,
text_encoder,
prompt,
cfg_list=[],
tau_list=[],
negative_prompt='',
scale_schedule=None,
top_k=900,
top_p=0.97,
cfg_sc=3,
cfg_exp_k=0.0,
cfg_insertion_layer=-5,
vae_type=0,
gumbel=0,
softmax_merge_topk=-1,
gt_leak=-1,
gt_ls_Bl=None,
g_seed=None,
sampling_per_bits=1,
enable_positive_prompt=0,
lq_img_path='',
args=None,
blip_model=None,
blip_processor=None,
swinir=None,
bitwise_self_correction=None
):
lq_img = Image.open(lq_img_path)
if lq_img.mode != "RGB":
lq_img = lq_img.convert("RGB")
# if scale_schedule[-1][-1]==16:
# lq_img = lq_img.resize((256,256))
lq_img = lq_img.resize((512,512))
transform = transforms.ToTensor()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lq_img = transform(lq_img)
lq_img = lq_img*2-1
lq_img = lq_img.unsqueeze(0).to(device, non_blocking=True)
lq_img = (lq_img+1)/2
lq_img = swinir(lq_img)
lq_img = lq_img + lq_img -1
##### swinir
#####blip
if not prompt:
raw_image = Image.open(lq_img_path).convert('RGB')
inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = blip_model.generate(**inputs)
prompt = blip_processor.decode(out[0], skip_special_tokens=True)
#####
with torch.amp.autocast('cuda', enabled=False):
with torch.no_grad():
if infinity_test.apply_spatial_patchify:
vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
else:
vae_scale_schedule = scale_schedule
# raw_features, _, _ = vae.encode_for_raw_features(gt_img, scale_schedule=vae_scale_schedule)
raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)
#####need to change
# x_BLC_wo_prefix_lq,_ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
# x_BLC_w_prefix_lq,_ = bitwise_self_correction.my_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
# last_scale_length = scale_schedule[-1][0] * scale_schedule[-1][1] * scale_schedule[-1][2]
# x_BLC_wo_prefix_lq_long = torch.cat([x_BLC_wo_prefix_lq,x_BLC_w_prefix_lq[:,-last_scale_length:,:]],dim = 1)
x_BLC_wo_prefix_lq_long,_ = bitwise_self_correction.long_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
#####
if not isinstance(cfg_list, list):
cfg_list = [cfg_list] * len(scale_schedule)
if not isinstance(tau_list, list):
tau_list = [tau_list] * len(scale_schedule)
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
if negative_prompt:
negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
else:
negative_label_B_or_BLT = None
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
### single step
_, _, img_list = infinity_test.autoregressive_infer_cfg(
vae=vae,
scale_schedule=scale_schedule,
label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
ret_img=True, trunk_scale=1000,
gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
sampling_per_bits=sampling_per_bits,
x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq_long,
)
# ###
img = img_list[0]
return img,prompt
def get_prompt_id(prompt):
md5 = hashlib.md5()
md5.update(prompt.encode('utf-8'))
prompt_id = md5.hexdigest()
return prompt_id
def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
print('[Save slim model]')
full_ckpt = torch.load(infinity_model_path, map_location=device)
infinity_slim = full_ckpt['trainer'][key]
# ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
if not save_file:
save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
print(f'Save to {save_file}')
torch.save(infinity_slim, save_file)
print('[Save slim model] done')
return save_file
def load_tokenizer(t5_path =''):
print(f'[Loading tokenizer and text encoder]')
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
text_tokenizer.model_max_length = 512
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
text_encoder.to('cuda')
text_encoder.eval()
text_encoder.requires_grad_(False)
return text_tokenizer, text_encoder
def load_infinity(
rope2d_each_sa_layer,
rope2d_normalized_by_hw,
use_scale_schedule_embedding,
pn,
use_bit_label,
add_lvl_embeding_only_first_block,
model_path='',
scale_schedule=None,
vae=None,
device='cuda',
model_kwargs=None,
text_channels=2048,
apply_spatial_patchify=0,
use_flex_attn=False,
bf16=False,
):
print(f'[Loading Infinity]')
text_maxlen = 512
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
infinity_test: Infinity = BInfinity(
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
shared_aln=True, raw_scale_schedule=scale_schedule,
checkpointing='full-block',
customized_flash_attn=False,
fused_norm=True,
pad_to_multiplier=1,
use_flex_attn=use_flex_attn,
add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
use_bit_label=use_bit_label,
rope2d_each_sa_layer=rope2d_each_sa_layer,
rope2d_normalized_by_hw=rope2d_normalized_by_hw,
pn=pn,
apply_spatial_patchify=apply_spatial_patchify,
inference_mode=True,
train_h_div_w_list=[1.0],
**model_kwargs,
).to(device=device)
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
if bf16:
for block in infinity_test.unregistered_blocks:
block.bfloat16()
### my code
# unet_lora_params, train_names = inject_trainable_lora(infinity_test.block_chunks)
unet_lora_params, train_names = inject_trainable_lora(infinity_test.block_chunks, target_replace_module={"CrossAttention", "SelfAttention"}, r=32)
infinity_test.eval()
infinity_test.requires_grad_(False)
infinity_test.cuda()
torch.cuda.empty_cache()
print(f'[Load Infinity weights]')
checkpoint = torch.load(model_path,map_location=device)
infinity_test.load_state_dict(checkpoint,strict=True)
# state_dict = checkpoint['infinity']
# lora_params = {
# k: v for k, v in state_dict.items()
# if 'lora' in k.lower()
# }
# torch.save(lora_params, 'infinity_lora.pth')
# pdb.set_trace()
infinity_test.rng = torch.Generator(device=device)
return infinity_test
def transform(pil_img, tgt_h, tgt_w):
width, height = pil_img.size
if width / height <= tgt_w / tgt_h:
resized_width = tgt_w
resized_height = int(tgt_w / (width / height))
else:
resized_height = tgt_h
resized_width = int((width / height) * tgt_h)
pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
# crop the center out
arr = np.array(pil_img)
crop_y = (arr.shape[0] - tgt_h) // 2
crop_x = (arr.shape[1] - tgt_w) // 2
im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
return im.add(im).add_(-1)
def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
pil_image = Image.open(image_path).convert('RGB')
inp = transform(pil_image, tgt_h, tgt_w)
inp = inp.unsqueeze(0).to(device)
scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
recons_img = vae.decode(z)[0]
if len(recons_img.shape) == 4:
recons_img = recons_img.squeeze(1)
print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
recons_img = (recons_img + 1) / 2
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
gt_img = (inp[0] + 1) / 2
gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
print(recons_img.shape, gt_img.shape)
return gt_img, recons_img, all_bit_indices
def load_visual_tokenizer(args):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load vae
if args.vae_type in [16,18,20,24,32,64]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
else:
raise ValueError(f'vae_type={args.vae_type} not supported')
return vae
def load_visual_tokenizer_lora(args):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load vae
if args.vae_type in [16,18,20,24,32,64]:
from infinity.models.bsq_vae.vae import vae_model_lora
schedule_mode = "dynamic"
codebook_dim = args.vae_type
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae = vae_model_lora(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
else:
raise ValueError(f'vae_type={args.vae_type} not supported')
return vae
def load_transformer(vae, args):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = args.model_path
if args.model_type == 'infinity_2b':
kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8) # 2b model
elif args.model_type == 'infinity_layer12':
kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
elif args.model_type == 'infinity_layer16':
kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
elif args.model_type == 'infinity_layer24':
kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
elif args.model_type == 'infinity_layer32':
kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
elif args.model_type == 'infinity_layer40':
kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
elif args.model_type == 'infinity_layer48':
kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
infinity = load_infinity(
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
use_scale_schedule_embedding=args.use_scale_schedule_embedding,
pn=args.pn,
use_bit_label=args.use_bit_label,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
model_path=model_path,
scale_schedule=None,
vae=vae,
device=device,
model_kwargs=kwargs_model,
text_channels=args.text_channels,
apply_spatial_patchify=args.apply_spatial_patchify,
use_flex_attn=args.use_flex_attn,
bf16=args.bf16,
)
return infinity
def add_common_arguments(parser):
parser.add_argument('--cfg', type=str, default='3')
parser.add_argument('--tau', type=float, default=1)
parser.add_argument('--pn', type=str, required=True, choices=['0.06M', '0.25M', '1M'])
parser.add_argument('--model_path', type=str, required=True)
parser.add_argument('--cfg_insertion_layer', type=int, default=0)
parser.add_argument('--vae_type', type=int, default=1)
parser.add_argument('--vae_path', type=str, default='')
parser.add_argument('--add_lvl_embeding_only_first_block', type=int, default=0, choices=[0,1])
parser.add_argument('--use_bit_label', type=int, default=1, choices=[0,1])
parser.add_argument('--model_type', type=str, default='infinity_2b')
parser.add_argument('--rope2d_each_sa_layer', type=int, default=1, choices=[0,1])
parser.add_argument('--rope2d_normalized_by_hw', type=int, default=2, choices=[0,1,2])
parser.add_argument('--use_scale_schedule_embedding', type=int, default=0, choices=[0,1])
parser.add_argument('--sampling_per_bits', type=int, default=1, choices=[1,2,4,8,16])
parser.add_argument('--text_encoder_ckpt', type=str, default='')
parser.add_argument('--text_channels', type=int, default=2048)
parser.add_argument('--apply_spatial_patchify', type=int, default=0, choices=[0,1])
parser.add_argument('--h_div_w_template', type=float, default=1.000)
parser.add_argument('--use_flex_attn', type=int, default=0, choices=[0,1])
parser.add_argument('--enable_positive_prompt', type=int, default=0, choices=[0,1])
parser.add_argument('--cache_dir', type=str, default='/dev/shm')
parser.add_argument('--checkpoint_type', type=str, default='torch')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--bf16', type=int, default=1, choices=[0,1])
parser.add_argument('--tile_size', type=int, default=512)
parser.add_argument('--tile_overlap', type=int, default=128)
parser.add_argument('--sr_scale', type=float, default=1.0)
parser.add_argument('--tiled', type=int, default=0, choices=[0,1])
def encode_and_decode(lq_img_path,vae,save_path):
lq_img = Image.open(lq_img_path)
if lq_img.mode != "RGB":
lq_img = lq_img.convert("RGB")
transform = transforms.ToTensor()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lq_img = transform(lq_img)
lq_img = lq_img*2-1
x = lq_img.unsqueeze(0).to(device, non_blocking=True)
is_image = x.ndim == 4
if not is_image:
B, C, T, H, W = x.shape
else:
B, C, H, W = x.shape
T = 1
ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
enc_dtype = ptdtype[vae.args.encoder_dtype]
with torch.amp.autocast("cuda", dtype=enc_dtype):
h, hs, hs_mid = vae.encoder(x, return_hidden=True) # B C H W or B C T H W
hs = [_h.detach() for _h in hs]
hs_mid = [_h.detach() for _h in hs_mid]
h = h.to(dtype=torch.float32)
# print(z.shape)
# Multiscale LFQ
# z, all_indices, all_loss = vae.quantizer(h)
z,_,_,_,_,_ = vae.quantizer(h)
x_recon = vae.decoder(z)
x_recon = (x_recon+1)/2
x_recon = x_recon.squeeze(0)
to_pil = transforms.ToPILImage()
x_recon = to_pil(x_recon)
x_recon.save(save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
add_common_arguments(parser)
parser.add_argument('--prompt', type=str, default='a dog')
parser.add_argument('--save_file', type=str, default='./tmp.jpg')
parser.add_argument('--lq_img_path', type=str, default='')
parser.add_argument('--noise_apply_layers',type=int,default=-1)
parser.add_argument('--noise_apply_requant',type=int,default=1)
parser.add_argument('--noise_apply_strength',type=float,default=0.3)
parser.add_argument('--debug_bsc',type=int,default=0)
args = parser.parse_args()
# noise_apply_layers: int = 13 # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
# noise_apply_strength: float = 0.3 # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
# noise_apply_requant: int = 1 # Bitwise Self-Correction: requant after apply noise
# debug_bsc: int = 0
# parse cfg
args.cfg = list(map(float, args.cfg.split(',')))
if len(args.cfg) == 1:
args.cfg = args.cfg[0]
# load text encoder
text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt)
# load vae
vae = load_visual_tokenizer(args)
# load infinity
infinity = load_transformer(vae, args)
bitwise_self_correction = BitwiseSelfCorrection(vae, args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
swinir = load_swinir_model(device)
scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']
scale_schedule = [ (1, h, w) for (_, h, w) in scale_schedule]
with autocast(dtype=torch.bfloat16):
with torch.no_grad():
generated_image = gen_one_img_anyres(
infinity,
vae,
text_tokenizer,
text_encoder,
args.prompt,
g_seed=args.seed,
gt_leak=0,
gt_ls_Bl=None,
cfg_list=args.cfg,
tau_list=args.tau,
scale_schedule=scale_schedule,
cfg_insertion_layer=[args.cfg_insertion_layer],
vae_type=args.vae_type,
sampling_per_bits=args.sampling_per_bits,
enable_positive_prompt=args.enable_positive_prompt,
lq_img_path=args.lq_img_path,
args=args,
swinir=swinir,
bitwise_self_correction=bitwise_self_correction,
tile_size=args.tile_size,
tile_overlap=args.tile_overlap,
sr_scale=args.sr_scale,
tiled=args.tiled,
)
os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
cv2.imwrite(args.save_file, generated_image.cpu().numpy())
print(f'Save to {osp.abspath(args.save_file)}')