| import torch
|
| import numpy as np
|
| from PIL import Image
|
| from safetensors.torch import load_file
|
| from .models.gpt_t2i import GPT_models
|
| from .models.generate import generate
|
| from .tokenizer.vq_model import VQ_models
|
|
|
class CondRefARPipeline:
    """Inference pipeline: autoregressive GPT token sampler + VQ decoder.

    ``__call__`` samples a grid of code indices with ``generate`` from the GPT
    model, conditioned on a prompt embedding and a control image, and decodes
    those indices to pixels with the VQ model's ``decode_code``.

    Build a ready-to-use instance with :meth:`from_pretrained`; a bare
    ``CondRefARPipeline()`` has no models attached.
    """

    def __init__(self, device=None, torch_dtype=torch.bfloat16):
        """Create an empty pipeline.

        Args:
            device: Target device. When falsy, "cuda" is chosen if available,
                otherwise "cpu".
            torch_dtype: dtype used for the GPT model and its inputs.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch_dtype
        self.gpt = None
        self.vq = None
        self.image_size = None
        self.downsample = None
        # Kept for backward compatibility with code that reads/sets it; the
        # token reshape in __call__ derives the leading dimension from the
        # sampled sequence length instead.
        self.n_q = 8
        # Second entry of the shape handed to ``vq.decode_code``.
        # from_pretrained overwrites it with vq_config["codebook_embed_dim"].
        self.codebook_embed_dim = 8

    @classmethod
    def from_pretrained(
        cls,
        repo_or_path,
        gpt_config,
        vq_config,
        gpt_weights="weights/sketch-gpt-xl.safetensors",
        vq_weights="weights/vq-16.safetensors",
        device=None,
        torch_dtype=torch.bfloat16,
    ):
        """Build a pipeline and load both models from safetensors files.

        Args:
            repo_or_path: Path prefix; weight files are read from
                ``f"{repo_or_path}/{gpt_weights}"`` and
                ``f"{repo_or_path}/{vq_weights}"``.
            gpt_config: Mapping with required keys ``image_size`` and
                ``vocab_size``; optional ``gpt_name`` ("GPT-XL"),
                ``num_classes`` (1000), ``cls_token_num`` (120),
                ``model_type`` ("t2i"), ``adapter_size`` ("small") and
                ``condition_type`` ("sketch").
            vq_config: Mapping with required keys ``downsample_size``,
                ``codebook_size`` and ``codebook_embed_dim``; optional
                ``model_name`` ("VQ-16").
            gpt_weights: GPT weights file, relative to ``repo_or_path``.
            vq_weights: VQ weights file, relative to ``repo_or_path``.
            device: Forwarded to the constructor.
            torch_dtype: Forwarded to the constructor.

        Returns:
            A pipeline with both models loaded and switched to eval mode.

        Raises:
            KeyError: If a required config key or a model name is missing.
        """
        pipe = cls(device=device, torch_dtype=torch_dtype)

        # --- VQ tokenizer / decoder ---
        pipe.downsample = int(vq_config["downsample_size"])
        codebook_size = int(vq_config["codebook_size"])
        # Remembered on the pipeline so decoding uses the same width the VQ
        # model was constructed with, rather than a hard-coded constant.
        pipe.codebook_embed_dim = int(vq_config["codebook_embed_dim"])
        vq_factory = VQ_models[vq_config.get("model_name", "VQ-16")]
        pipe.vq = vq_factory(
            codebook_size=codebook_size,
            codebook_embed_dim=pipe.codebook_embed_dim,
        )
        vq_state = load_file(f"{repo_or_path}/{vq_weights}")
        pipe.vq.load_state_dict(vq_state, strict=True)
        pipe.vq.to(pipe.device)
        pipe.vq.eval()

        # --- GPT sampler ---
        pipe.image_size = int(gpt_config["image_size"])
        latent_size = pipe.image_size // pipe.downsample
        gpt_factory = GPT_models[gpt_config.get("gpt_name", "GPT-XL")]
        pipe.gpt = gpt_factory(
            vocab_size=int(gpt_config["vocab_size"]),
            block_size=latent_size ** 2,
            num_classes=int(gpt_config.get("num_classes", 1000)),
            cls_token_num=int(gpt_config.get("cls_token_num", 120)),
            model_type=gpt_config.get("model_type", "t2i"),
            adapter_size=gpt_config.get("adapter_size", "small"),
            condition_type=gpt_config.get("condition_type", "sketch"),
        ).to(device=pipe.device, dtype=pipe.dtype)
        gpt_state = load_file(f"{repo_or_path}/{gpt_weights}")
        # strict=False: key mismatches between checkpoint and model are
        # tolerated here (unlike the VQ load above).
        pipe.gpt.load_state_dict(gpt_state, strict=False)
        pipe.gpt.eval()

        return pipe

    @torch.inference_mode()
    def __call__(
        self,
        prompt_emb,
        control_image,
        cfg_scale=4,
        cfg_interval=-1,
        temperature=1.0,
        top_k=2000,
        top_p=1.0,
    ):
        """Generate images from a prompt embedding and a control image.

        Args:
            prompt_emb: torch.Tensor [B, T_txt, D]; moved to the pipeline
                device/dtype and passed to ``generate`` as the condition
                tokens.
            control_image: ``PIL.Image``, ``np.ndarray`` (H, W, 3) or (H, W),
                or an already prepared ``torch.Tensor``. Arrays and PIL images
                are converted to a (1, 3, H, W) tensor scaled to [-1, 1];
                tensors are only moved to the pipeline device/dtype.
            cfg_scale, cfg_interval, temperature, top_k, top_p: Sampling
                options forwarded unchanged to ``generate``.

        Returns:
            list of ``PIL.Image``, one per decoded sample.

        Raises:
            RuntimeError: If the pipeline has no models loaded.
            TypeError: If ``control_image`` is of an unsupported type.
            ValueError: If an array control image is not 2-D or 3-D.
        """
        if (
            self.gpt is None
            or self.vq is None
            or self.image_size is None
            or self.downsample is None
        ):
            raise RuntimeError(
                "Pipeline is not initialized; create it with "
                "CondRefARPipeline.from_pretrained(...) before calling it."
            )

        control = self._prepare_control(control_image)
        c_indices = prompt_emb.to(self.device, dtype=self.dtype)
        c_emb_masks = None

        # Square latent grid: one token per downsampled cell.
        grid = self.image_size // self.downsample
        seq_len = grid * grid

        index_sample = generate(
            self.gpt, c_indices, seq_len, c_emb_masks,
            condition=control, cfg_scale=cfg_scale, cfg_interval=cfg_interval,
            temperature=temperature, top_k=top_k, top_p=top_p, sample_logits=True
        )

        tokens = self._reshape_tokens(index_sample, grid).to(self.device)
        # NOTE(review): assumes the second entry of the decode shape is the
        # codebook embedding width the VQ model was built with - confirm
        # against the VQ model's decode_code implementation.
        qzshape = [tokens.size(0), self.codebook_embed_dim, grid, grid]
        samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()
        return self._to_pil(samples)

    def _prepare_control(self, control_image):
        """Convert a control image to a tensor on the pipeline device/dtype."""
        if isinstance(control_image, torch.Tensor):
            # Tensors are taken as already laid out and scaled by the caller.
            return control_image.to(self.device, dtype=self.dtype)

        if not isinstance(control_image, np.ndarray):
            if isinstance(control_image, Image.Image):
                control_image = np.array(control_image.convert("RGB"))
            else:
                raise TypeError(
                    "control_image must be a PIL.Image, np.ndarray or "
                    f"torch.Tensor, got {type(control_image).__name__}"
                )

        if control_image.ndim == 2:
            # Grayscale array: replicate to three channels, mirroring what
            # convert("RGB") does on the PIL path.
            control_image = np.stack((control_image,) * 3, axis=-1)
        if control_image.ndim != 3:
            raise ValueError(
                "control_image array must have shape (H, W, C) or (H, W), "
                f"got {control_image.shape}"
            )

        # uint8 data is always on the 0..255 scale, even when the image is so
        # dark that its maximum is <= 1; only fall back to the value-based
        # heuristic for other dtypes.
        is_uint8 = control_image.dtype == np.uint8
        tensor = torch.from_numpy(control_image).permute(2, 0, 1).unsqueeze(0).float()
        if is_uint8 or tensor.max() > 1.0:
            tensor = tensor / 255.0
        tensor = 2.0 * (tensor - 0.5)  # [0, 1] -> [-1, 1]
        return tensor.to(self.device, dtype=self.dtype)

    def _reshape_tokens(self, index_sample, grid):
        """Reshape sampled indices to (B, n, grid, grid) long tensor.

        ``n`` is how many full grids fit in the sampled length (at least 1);
        any trailing indices beyond ``n * grid * grid`` are dropped.
        """
        cells = grid * grid
        n_grids = max(1, index_sample.shape[1] // cells)
        flat = index_sample[:, : n_grids * cells]
        return flat.view(index_sample.size(0), n_grids, grid, grid).long()

    def _to_pil(self, samples):
        """Convert decoded (B, C, H, W) float samples to a list of PIL images."""
        # Heuristic: a minimum below -0.9 is taken to mean the decoder output
        # is on the [-1, 1] scale and needs shifting to [0, 1].
        if samples.min() < -0.9:
            samples = (samples + 1.0) / 2.0
        samples = samples.clamp(0, 1)
        arr = (samples * 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
        return [Image.fromarray(frame) for frame in arr]