import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import os.path as osp
import sys
from typing import List
from contextlib import nullcontext
import math
import hashlib
import yaml
import argparse
import shutil
import re
import importlib

import cv2
import numpy as np
import torch

torch._dynamo.config.cache_size_limit = 64

import pandas as pd
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import BlipForConditionalGeneration, BlipProcessor
from PIL import Image, ImageEnhance
import PIL.Image as PImage
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torchvision import transforms
from torchvision.transforms.functional import to_tensor

# Make the repository root importable before pulling in the infinity package.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from infinity.models.infinity import Infinity, BInfinity
from infinity.models.basic import *
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
from infinity.models.bitwise_self_correction import BitwiseSelfCorrection
from infinity.models.swinir import SwinIR
from lora_diffusion import inject_trainable_lora

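# `get_obj_from_str` / `instantiate_from_config` follow the latent-diffusion convention of
# building objects from a {"target": "dotted.path.Class", "params": {...}} dict; they are
# used below to construct the SwinIR restorer from an inline config.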
def get_obj_from_str(string: str, reload: bool = False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

def extract_key_val(text):
    pattern = r'<(.+?):(.+?)>'
    matches = re.findall(pattern, text)
    key_val = {}
    for match in matches:
        key_val[match[0]] = match[1].lstrip()
    return key_val

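# Encode a prompt with T5 and pack the per-caption features into a varlen tuple
# (kv_compact, lens, cu_seqlens_k, Ltext): padding tokens are dropped, the remaining
# features are concatenated, and cu_seqlens_k holds flash-attention-style cumulative
# sequence offsets.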
def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
    if enable_positive_prompt:
        print(f'before positive_prompt aug: {prompt}')
        prompt = aug_with_positive_prompt(prompt)
        print(f'after positive_prompt aug: {prompt}')

    captions = [prompt]
    tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt')
    input_ids = tokens.input_ids.cuda(non_blocking=True)
    mask = tokens.attention_mask.cuda(non_blocking=True)
    text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
    lens: List[int] = mask.sum(dim=-1).tolist()
    cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
    Ltext = max(lens)
    kv_compact = []
    for len_i, feat_i in zip(lens, text_features.unbind(0)):
        kv_compact.append(feat_i[:len_i])
    kv_compact = torch.cat(kv_compact, dim=0)
    text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
    return text_cond_tuple


def aug_with_positive_prompt(prompt):
    for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
                'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
        if key in prompt:
            prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
            break
    return prompt

def enhance_image(image):
    # Mild contrast boost followed by a mild color boost (both 1.05x).
    contrast_image = ImageEnhance.Contrast(image.copy()).enhance(1.05)
    color_image = ImageEnhance.Color(contrast_image).enhance(1.05)
    return color_image

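# SwinIR is used as a frozen pre-restorer that cleans the low-quality input before the
# autoregressive model runs; weights are expected at weights/general_swinir_v1.ckpt.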
def load_swinir_model(device, swinir=None):
    if swinir is not None:
        return swinir

    swinir_config = {
        "target": "infinity.models.swinir.SwinIR",
        "params": {
            "img_size": 64,
            "patch_size": 1,
            "in_chans": 3,
            "embed_dim": 180,
            "depths": [6, 6, 6, 6, 6, 6, 6, 6],
            "num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
            "window_size": 8,
            "mlp_ratio": 2,
            "sf": 8,
            "img_range": 1.0,
            "upsampler": "nearest+conv",
            "resi_connection": "1conv",
            "unshuffle": True,
            "unshuffle_scale": 8
        }
    }
    swinir = instantiate_from_config(swinir_config)
    sd = torch.load('weights/general_swinir_v1.ckpt', map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in sd.items()
    }
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    swinir.eval().to(device)
    return swinir

def pil_to_lq_tensor(pil_img, device):
    transform = transforms.ToTensor()
    lq_img = transform(pil_img)
    lq_img = lq_img * 2 - 1  # map [0, 1] to [-1, 1]
    return lq_img.unsqueeze(0).to(device, non_blocking=True)


def run_swinir(lq_img, swinir):
    # SwinIR operates on [0, 1] inputs, while the rest of the pipeline uses [-1, 1].
    lq_img = (lq_img + 1) / 2
    lq_img = swinir(lq_img)
    lq_img = lq_img * 2 - 1
    return lq_img

def prepare_prompt_conditions(
    text_tokenizer,
    text_encoder,
    prompt,
    negative_prompt='',
    enable_positive_prompt=0,
):
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None
    return text_cond_tuple, negative_label_B_or_BLT

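# Tile blending weights: a separable Gaussian bump, normalized so the tile center has
# weight 1.0 and the borders fall off smoothly. Overlapping tile outputs are accumulated
# with these weights and later divided by the summed weights, which hides seams.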
def gaussian_tile_weights(tile_height, tile_width):
    var = 0.01
    x_mid = (tile_width - 1) / 2
    y_mid = (tile_height - 1) / 2
    x_probs = [np.exp(-(x - x_mid) ** 2 / (tile_width * tile_width) / (2 * var)) / np.sqrt(2 * np.pi * var) for x in range(tile_width)]
    y_probs = [np.exp(-(y - y_mid) ** 2 / (tile_height * tile_height) / (2 * var)) / np.sqrt(2 * np.pi * var) for y in range(tile_height)]
    weights = np.outer(y_probs, x_probs).astype(np.float32)
    weights /= weights.max()
    return weights[..., None]

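# Tile origins are spaced by stride = tile_size - tile_overlap; the final position is
# clamped so the last tile ends exactly at the image border (it may overlap its
# predecessor by more than tile_overlap).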
def compute_tile_positions(length, tile_size, tile_overlap):
    if length <= tile_size:
        return [0]

    stride = tile_size - tile_overlap
    if stride <= 0:
        raise ValueError(f'tile_overlap must be smaller than tile_size, got {tile_overlap=} {tile_size=}')

    positions = list(range(0, length - tile_size + 1, stride))
    if positions[-1] != length - tile_size:
        positions.append(length - tile_size)
    return positions

def crop_tile_with_padding(image_np, top, left, tile_size):
    bottom = min(top + tile_size, image_np.shape[0])
    right = min(left + tile_size, image_np.shape[1])
    tile = image_np[top:bottom, left:right]

    pad_bottom = tile_size - tile.shape[0]
    pad_right = tile_size - tile.shape[1]
    if pad_bottom > 0 or pad_right > 0:
        # Reflect padding fails when the crop is smaller than the pad amount; fall back to edge.
        try:
            tile = np.pad(tile, ((0, pad_bottom), (0, pad_right), (0, 0)), mode='reflect')
        except ValueError:
            tile = np.pad(tile, ((0, pad_bottom), (0, pad_right), (0, 0)), mode='edge')

    return Image.fromarray(tile), bottom - top, right - left

def pad_image_for_tiling(image_np, border_pad):
    if border_pad <= 0:
        return image_np, 0

    try:
        padded = np.pad(image_np, ((border_pad, border_pad), (border_pad, border_pad), (0, 0)), mode='reflect')
    except ValueError:
        padded = np.pad(image_np, ((border_pad, border_pad), (border_pad, border_pad), (0, 0)), mode='edge')
    return padded, border_pad

def resize_image_to_sr_scale(pil_img, sr_scale):
    if sr_scale <= 0:
        raise ValueError(f'sr_scale must be positive, got {sr_scale}')

    width, height = pil_img.size
    target_width = max(1, int(round(width * sr_scale)))
    target_height = max(1, int(round(height * sr_scale)))
    if (target_width, target_height) != (width, height):
        pil_img = pil_img.resize((target_width, target_height), resample=PImage.LANCZOS)
    return pil_img, target_width, target_height


def resize_output_tensor(image_tensor, target_width, target_height):
    image_np = image_tensor.detach().cpu().numpy().astype(np.uint8)
    resized = Image.fromarray(image_np).resize((target_width, target_height), resample=PImage.LANCZOS)
    return torch.from_numpy(np.asarray(resized))

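# Single-tile SR inference: the LQ tile is encoded with the VAE, re-quantized through
# bitwise self-correction ('flip' or the longer 'long' variant) into the LQ token prefix
# x_BLC_wo_prefix_lq, and the transformer then runs CFG-guided autoregressive decoding
# conditioned on both the text tuple and that prefix.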
def infer_single_sr_tile(
    infinity_test,
    vae,
    scale_schedule,
    text_cond_tuple,
    negative_label_B_or_BLT,
    lq_img,
    cfg_list=[],
    tau_list=[],
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    args=None,
    bitwise_self_correction=None,
    requant_mode='flip',
):
    device = lq_img.device

    with torch.amp.autocast('cuda', enabled=False):
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule
        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

    if bitwise_self_correction is None:
        if args is None:
            args = globals().get('args', None)  # fall back to the script-level args, if available
        bitwise_self_correction = BitwiseSelfCorrection(vae, args)
    if requant_mode == 'flip':
        x_BLC_wo_prefix_lq, _ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
    elif requant_mode == 'long':
        x_BLC_wo_prefix_lq, _ = bitwise_self_correction.long_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)
    else:
        raise ValueError(f'Unsupported requant_mode: {requant_mode}')

    # Broadcast scalar cfg/tau settings across every scale in the schedule.
    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)

    autocast_ctx = torch.cuda.amp.autocast(enabled=device.type == 'cuda', dtype=torch.bfloat16, cache_enabled=True) if device.type == 'cuda' else nullcontext()
    with autocast_ctx:
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple,
            g_seed=g_seed,
            B=1,
            negative_label_B_or_BLT=negative_label_B_or_BLT,
            force_gt_Bhw=None,
            cfg_sc=cfg_sc,
            cfg_list=cfg_list,
            tau_list=tau_list,
            top_k=top_k,
            top_p=top_p,
            returns_vemb=1,
            ratio_Bl1=None,
            gumbel=gumbel,
            norm_cfg=False,
            cfg_exp_k=cfg_exp_k,
            cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type,
            softmax_merge_topk=softmax_merge_topk,
            ret_img=True,
            trunk_scale=1000,
            gt_leak=gt_leak,
            gt_ls_Bl=gt_ls_Bl,
            inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
        )
    return img_list[0]

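# Fixed-size path: the LQ image is resized to 512x512, cleaned by SwinIR, and decoded
# in a single pass (gen_one_img_anyres below is the general entry point).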
@torch.no_grad()
def gen_one_img(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    gt_img_path='',
    args=None,
    swinir=None,
    bitwise_self_correction=None,
):
    lq_img = Image.open(lq_img_path)
    if lq_img.mode != "RGB":
        lq_img = lq_img.convert("RGB")
    lq_img = lq_img.resize((512, 512), resample=PImage.LANCZOS)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    swinir = load_swinir_model(device, swinir=swinir)
    lq_img = pil_to_lq_tensor(lq_img, device)
    lq_img = run_swinir(lq_img, swinir)

    text_cond_tuple, negative_label_B_or_BLT = prepare_prompt_conditions(
        text_tokenizer,
        text_encoder,
        prompt,
        negative_prompt=negative_prompt,
        enable_positive_prompt=enable_positive_prompt,
    )
    print(f'cfg: {cfg_list}, tau: {tau_list}')
    img = infer_single_sr_tile(
        infinity_test,
        vae,
        scale_schedule,
        text_cond_tuple,
        negative_label_B_or_BLT,
        lq_img,
        cfg_list=cfg_list,
        tau_list=tau_list,
        top_k=top_k,
        top_p=top_p,
        cfg_sc=cfg_sc,
        cfg_exp_k=cfg_exp_k,
        cfg_insertion_layer=cfg_insertion_layer,
        vae_type=vae_type,
        gumbel=gumbel,
        softmax_merge_topk=softmax_merge_topk,
        gt_leak=gt_leak,
        gt_ls_Bl=gt_ls_Bl,
        g_seed=g_seed,
        sampling_per_bits=sampling_per_bits,
        args=args,
        bitwise_self_correction=bitwise_self_correction,
    )
    return img

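# Any-resolution entry point. Non-tiled: resize to 512x512, run one pass, resize back.
# Tiled: reflect-pad the border by tile_overlap, run the 512x512 pipeline per tile, and
# blend overlapping outputs with the Gaussian weights before cropping the padding off.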
@torch.no_grad()
def gen_one_img_anyres(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    gt_img_path='',
    args=None,
    swinir=None,
    bitwise_self_correction=None,
    tile_size=512,
    tile_overlap=128,
    sr_scale=1.0,
    tiled=0,
):
    if tile_size != 512:
        raise ValueError(f'Current B inference pipeline expects tile_size=512, got {tile_size}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    swinir = load_swinir_model(device, swinir=swinir)

    if bitwise_self_correction is None:
        if args is None:
            args = globals().get('args', None)
        bitwise_self_correction = BitwiseSelfCorrection(vae, args)

    lq_img = Image.open(lq_img_path)
    if lq_img.mode != "RGB":
        lq_img = lq_img.convert("RGB")
    original_width, original_height = lq_img.size
    lq_img, target_width, target_height = resize_image_to_sr_scale(lq_img, sr_scale)
    print(f'input size: {original_width}x{original_height}, sr_scale: {sr_scale}, target size: {target_width}x{target_height}, tiled: {tiled}')

    text_cond_tuple, negative_label_B_or_BLT = prepare_prompt_conditions(
        text_tokenizer,
        text_encoder,
        prompt,
        negative_prompt=negative_prompt,
        enable_positive_prompt=enable_positive_prompt,
    )

    width, height = lq_img.size
    if not tiled:
        print(f'use resized single-pass inference for {width}x{height}')
        resized_lq = lq_img.resize((tile_size, tile_size), resample=PImage.LANCZOS)
        lq_tensor = pil_to_lq_tensor(resized_lq, device)
        lq_tensor = run_swinir(lq_tensor, swinir)

        img = infer_single_sr_tile(
            infinity_test,
            vae,
            scale_schedule,
            text_cond_tuple,
            negative_label_B_or_BLT,
            lq_tensor,
            cfg_list=cfg_list,
            tau_list=tau_list,
            top_k=top_k,
            top_p=top_p,
            cfg_sc=cfg_sc,
            cfg_exp_k=cfg_exp_k,
            cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type,
            gumbel=gumbel,
            softmax_merge_topk=softmax_merge_topk,
            gt_leak=gt_leak,
            gt_ls_Bl=gt_ls_Bl,
            g_seed=g_seed,
            sampling_per_bits=sampling_per_bits,
            args=args,
            bitwise_self_correction=bitwise_self_correction,
            requant_mode='long',
        )

        if width != tile_size or height != tile_size:
            img = resize_output_tensor(img, width, height)
        return img

    image_np = np.asarray(lq_img)
    image_np, border_pad = pad_image_for_tiling(image_np, tile_overlap)
    padded_height, padded_width = image_np.shape[:2]

    xs = compute_tile_positions(padded_width, tile_size, tile_overlap)
    ys = compute_tile_positions(padded_height, tile_size, tile_overlap)
    tile_weights = gaussian_tile_weights(tile_size, tile_size)

    preds = np.zeros((padded_height, padded_width, 3), dtype=np.float32)
    contributors = np.zeros((padded_height, padded_width, 1), dtype=np.float32)

    total_tiles = len(xs) * len(ys)
    tile_index = 0
    for top in ys:
        for left in xs:
            tile_index += 1
            tile_pil, valid_h, valid_w = crop_tile_with_padding(image_np, top, left, tile_size)
            tile_tensor = pil_to_lq_tensor(tile_pil, device)
            tile_tensor = run_swinir(tile_tensor, swinir)

            tile_img = infer_single_sr_tile(
                infinity_test,
                vae,
                scale_schedule,
                text_cond_tuple,
                negative_label_B_or_BLT,
                tile_tensor,
                cfg_list=cfg_list,
                tau_list=tau_list,
                top_k=top_k,
                top_p=top_p,
                cfg_sc=cfg_sc,
                cfg_exp_k=cfg_exp_k,
                cfg_insertion_layer=cfg_insertion_layer,
                vae_type=vae_type,
                gumbel=gumbel,
                softmax_merge_topk=softmax_merge_topk,
                gt_leak=gt_leak,
                gt_ls_Bl=gt_ls_Bl,
                g_seed=g_seed,
                sampling_per_bits=sampling_per_bits,
                args=args,
                bitwise_self_correction=bitwise_self_correction,
                requant_mode='long',
            )

            tile_img = tile_img.detach().cpu().numpy().astype(np.float32)
            valid_weights = tile_weights[:valid_h, :valid_w]
            preds[top:top + valid_h, left:left + valid_w] += tile_img[:valid_h, :valid_w] * valid_weights
            contributors[top:top + valid_h, left:left + valid_w] += valid_weights
            print(f'processed tile {tile_index}/{total_tiles}: top={top}, left={left}, size={valid_h}x{valid_w}')

    preds /= np.clip(contributors, 1e-8, None)
    preds = np.clip(np.rint(preds), 0, 255).astype(np.uint8)
    if border_pad > 0:
        preds = preds[border_pad:border_pad + height, border_pad:border_pad + width]
    return torch.from_numpy(preds)

@torch.no_grad()
def gen_one_img_eval(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    args=None,
    blip_model=None,
    blip_processor=None,
    swinir=None,
):
    lq_img = Image.open(lq_img_path)
    if lq_img.mode != "RGB":
        lq_img = lq_img.convert("RGB")
    lq_img = lq_img.resize((512, 512))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    swinir = load_swinir_model(device, swinir=swinir)
    lq_img = pil_to_lq_tensor(lq_img, device)
    lq_img = run_swinir(lq_img, swinir)

    # If no prompt is given, caption the LQ image with BLIP and use that as the prompt.
    if not prompt:
        raw_image = Image.open(lq_img_path).convert('RGB')
        inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    with torch.amp.autocast('cuda', enabled=False):
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule

        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

        bitwise_self_correction = BitwiseSelfCorrection(vae, args)
        x_BLC_wo_prefix_lq, _ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)

    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None

    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
        )

    img = img_list[0]
    return img, prompt

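# Same evaluation path as gen_one_img_eval, but the LQ prefix tokens are produced with
# the 'long' flip-requant variant of BitwiseSelfCorrection.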
@torch.no_grad()
def gen_one_img_eval_long(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    args=None,
    blip_model=None,
    blip_processor=None,
    swinir=None,
    bitwise_self_correction=None,
):
    lq_img = Image.open(lq_img_path)
    if lq_img.mode != "RGB":
        lq_img = lq_img.convert("RGB")
    lq_img = lq_img.resize((512, 512))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    swinir = load_swinir_model(device, swinir=swinir)
    lq_img = pil_to_lq_tensor(lq_img, device)
    lq_img = run_swinir(lq_img, swinir)

    # If no prompt is given, caption the LQ image with BLIP and use that as the prompt.
    if not prompt:
        raw_image = Image.open(lq_img_path).convert('RGB')
        inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    with torch.amp.autocast('cuda', enabled=False):
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule

        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

        if bitwise_self_correction is None:
            bitwise_self_correction = BitwiseSelfCorrection(vae, args)
        x_BLC_wo_prefix_lq_long, _ = bitwise_self_correction.long_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)

    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None

    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq_long,
        )

    img = img_list[0]
    return img, prompt

def get_prompt_id(prompt):
    md5 = hashlib.md5()
    md5.update(prompt.encode('utf-8'))
    prompt_id = md5.hexdigest()
    return prompt_id


def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
    print('[Save slim model]')
    full_ckpt = torch.load(infinity_model_path, map_location=device)
    infinity_slim = full_ckpt['trainer'][key]

    if not save_file:
        save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
    print(f'Save to {save_file}')
    torch.save(infinity_slim, save_file)
    print('[Save slim model] done')
    return save_file

def load_tokenizer(t5_path=''):
    print('[Loading tokenizer and text encoder]')
    text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
    text_tokenizer.model_max_length = 512
    text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
    text_encoder.to('cuda')
    text_encoder.eval()
    text_encoder.requires_grad_(False)
    return text_tokenizer, text_encoder

def load_infinity(
    rope2d_each_sa_layer,
    rope2d_normalized_by_hw,
    use_scale_schedule_embedding,
    pn,
    use_bit_label,
    add_lvl_embeding_only_first_block,
    model_path='',
    scale_schedule=None,
    vae=None,
    device='cuda',
    model_kwargs=None,
    text_channels=2048,
    apply_spatial_patchify=0,
    use_flex_attn=False,
    bf16=False,
):
    print('[Loading Infinity]')
    text_maxlen = 512
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_test: Infinity = BInfinity(
            vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
            shared_aln=True, raw_scale_schedule=scale_schedule,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=1,
            use_flex_attn=use_flex_attn,
            add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
            use_bit_label=use_bit_label,
            rope2d_each_sa_layer=rope2d_each_sa_layer,
            rope2d_normalized_by_hw=rope2d_normalized_by_hw,
            pn=pn,
            apply_spatial_patchify=apply_spatial_patchify,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **model_kwargs,
        ).to(device=device)
    print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

    if bf16:
        for block in infinity_test.unregistered_blocks:
            block.bfloat16()

    # Inject LoRA adapters before loading weights, so the strict checkpoint load below
    # can also carry the LoRA parameters.
    unet_lora_params, train_names = inject_trainable_lora(infinity_test.block_chunks, target_replace_module={"CrossAttention", "SelfAttention"}, r=32)

    infinity_test.eval()
    infinity_test.requires_grad_(False)

    infinity_test.cuda()
    torch.cuda.empty_cache()

    print('[Load Infinity weights]')
    checkpoint = torch.load(model_path, map_location=device)
    infinity_test.load_state_dict(checkpoint, strict=True)

    infinity_test.rng = torch.Generator(device=device)
    return infinity_test

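# Aspect-preserving resize so the image covers the target box, then a center crop to
# exactly (tgt_h, tgt_w), returned as a tensor in [-1, 1].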
def transform(pil_img, tgt_h, tgt_w):
    width, height = pil_img.size
    if width / height <= tgt_w / tgt_h:
        resized_width = tgt_w
        resized_height = int(tgt_w / (width / height))
    else:
        resized_height = tgt_h
        resized_width = int((width / height) * tgt_h)
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)

    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
    return im * 2 - 1  # map [0, 1] to [-1, 1]

def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
    pil_image = Image.open(image_path).convert('RGB')
    inp = transform(pil_image, tgt_h, tgt_w)
    inp = inp.unsqueeze(0).to(device)
    scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
    h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
    recons_img = vae.decode(z)[0]
    if len(recons_img.shape) == 4:
        recons_img = recons_img.squeeze(1)
    print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    gt_img = (inp[0] + 1) / 2
    gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    print(recons_img.shape, gt_img.shape)
    return gt_img, recons_img, all_bit_indices

def load_visual_tokenizer(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.vae_type in [16, 18, 20, 24, 32, 64]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
                        encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
    else:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    return vae


def load_visual_tokenizer_lora(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.vae_type in [16, 18, 20, 24, 32, 64]:
        from infinity.models.bsq_vae.vae import vae_model_lora
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae = vae_model_lora(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
                             encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
    else:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    return vae

def load_transformer(vae, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = args.model_path

    if args.model_type == 'infinity_2b':
        kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048 // 128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)
    elif args.model_type == 'infinity_layer12':
        kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer16':
        kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer24':
        kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer32':
        kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer40':
        kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer48':
        kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    else:
        raise ValueError(f'model_type={args.model_type} not supported')
    infinity = load_infinity(
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        use_scale_schedule_embedding=args.use_scale_schedule_embedding,
        pn=args.pn,
        use_bit_label=args.use_bit_label,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        model_path=model_path,
        scale_schedule=None,
        vae=vae,
        device=device,
        model_kwargs=kwargs_model,
        text_channels=args.text_channels,
        apply_spatial_patchify=args.apply_spatial_patchify,
        use_flex_attn=args.use_flex_attn,
        bf16=args.bf16,
    )
    return infinity

def add_common_arguments(parser):
    parser.add_argument('--cfg', type=str, default='3')
    parser.add_argument('--tau', type=float, default=1)
    parser.add_argument('--pn', type=str, required=True, choices=['0.06M', '0.25M', '1M'])
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--cfg_insertion_layer', type=int, default=0)
    parser.add_argument('--vae_type', type=int, default=1)
    parser.add_argument('--vae_path', type=str, default='')
    parser.add_argument('--add_lvl_embeding_only_first_block', type=int, default=0, choices=[0, 1])
    parser.add_argument('--use_bit_label', type=int, default=1, choices=[0, 1])
    parser.add_argument('--model_type', type=str, default='infinity_2b')
    parser.add_argument('--rope2d_each_sa_layer', type=int, default=1, choices=[0, 1])
    parser.add_argument('--rope2d_normalized_by_hw', type=int, default=2, choices=[0, 1, 2])
    parser.add_argument('--use_scale_schedule_embedding', type=int, default=0, choices=[0, 1])
    parser.add_argument('--sampling_per_bits', type=int, default=1, choices=[1, 2, 4, 8, 16])
    parser.add_argument('--text_encoder_ckpt', type=str, default='')
    parser.add_argument('--text_channels', type=int, default=2048)
    parser.add_argument('--apply_spatial_patchify', type=int, default=0, choices=[0, 1])
    parser.add_argument('--h_div_w_template', type=float, default=1.000)
    parser.add_argument('--use_flex_attn', type=int, default=0, choices=[0, 1])
    parser.add_argument('--enable_positive_prompt', type=int, default=0, choices=[0, 1])
    parser.add_argument('--cache_dir', type=str, default='/dev/shm')
    parser.add_argument('--checkpoint_type', type=str, default='torch')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--bf16', type=int, default=1, choices=[0, 1])
    parser.add_argument('--tile_size', type=int, default=512)
    parser.add_argument('--tile_overlap', type=int, default=128)
    parser.add_argument('--sr_scale', type=float, default=1.0)
    parser.add_argument('--tiled', type=int, default=0, choices=[0, 1])

def encode_and_decode(lq_img_path, vae, save_path):
    lq_img = Image.open(lq_img_path)
    if lq_img.mode != "RGB":
        lq_img = lq_img.convert("RGB")
    transform = transforms.ToTensor()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lq_img = transform(lq_img)
    lq_img = lq_img * 2 - 1
    x = lq_img.unsqueeze(0).to(device, non_blocking=True)

    is_image = x.ndim == 4
    if not is_image:
        B, C, T, H, W = x.shape
    else:
        B, C, H, W = x.shape
        T = 1
    ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
    enc_dtype = ptdtype[vae.args.encoder_dtype]

    with torch.amp.autocast("cuda", dtype=enc_dtype):
        h, hs, hs_mid = vae.encoder(x, return_hidden=True)
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)

        z, _, _, _, _, _ = vae.quantizer(h)
        x_recon = vae.decoder(z)

    x_recon = (x_recon + 1) / 2
    x_recon = x_recon.squeeze(0)
    to_pil = transforms.ToPILImage()
    x_recon = to_pil(x_recon)
    x_recon.save(save_path)

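# Example invocation (a sketch; the script name, checkpoint paths, and T5 directory are
# placeholders to adapt to your setup):
#   python infer_sr.py --pn 0.06M --model_path weights/infinity_sr.pth \
#       --vae_type 32 --vae_path weights/infinity_vae_d32.pth \
#       --text_encoder_ckpt weights/flan-t5-xl --cfg 3 --tau 0.5 \
#       --lq_img_path inputs/lq.png --save_file outputs/sr.jpg --sr_scale 4 --tiled 1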
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument('--prompt', type=str, default='a dog')
    parser.add_argument('--save_file', type=str, default='./tmp.jpg')
    parser.add_argument('--lq_img_path', type=str, default='')
    parser.add_argument('--noise_apply_layers', type=int, default=-1)
    parser.add_argument('--noise_apply_requant', type=int, default=1)
    parser.add_argument('--noise_apply_strength', type=float, default=0.3)
    parser.add_argument('--debug_bsc', type=int, default=0)
    args = parser.parse_args()

    # '--cfg' accepts either a single value or a comma-separated per-scale list.
    args.cfg = list(map(float, args.cfg.split(',')))
    if len(args.cfg) == 1:
        args.cfg = args.cfg[0]

    text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
    vae = load_visual_tokenizer(args)
    infinity = load_transformer(vae, args)
    bitwise_self_correction = BitwiseSelfCorrection(vae, args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    swinir = load_swinir_model(device)

    # Force single-frame (t=1) entries in the scale schedule for image inference.
    scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

    with autocast(dtype=torch.bfloat16):
        with torch.no_grad():
            generated_image = gen_one_img_anyres(
                infinity,
                vae,
                text_tokenizer,
                text_encoder,
                args.prompt,
                g_seed=args.seed,
                gt_leak=0,
                gt_ls_Bl=None,
                cfg_list=args.cfg,
                tau_list=args.tau,
                scale_schedule=scale_schedule,
                cfg_insertion_layer=[args.cfg_insertion_layer],
                vae_type=args.vae_type,
                sampling_per_bits=args.sampling_per_bits,
                enable_positive_prompt=args.enable_positive_prompt,
                lq_img_path=args.lq_img_path,
                args=args,
                swinir=swinir,
                bitwise_self_correction=bitwise_self_correction,
                tile_size=args.tile_size,
                tile_overlap=args.tile_overlap,
                sr_scale=args.sr_scale,
                tiled=args.tiled,
            )
    os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
    # Note: cv2.imwrite interprets the array in BGR channel order.
    cv2.imwrite(args.save_file, generated_image.cpu().numpy())
    print(f'Save to {osp.abspath(args.save_file)}')