from typing import Any, Dict
import base64
from io import BytesIO
from pathlib import Path

import torch
from torch import autocast
import open_clip
from open_clip import tokenizer
from rudalle import get_vae
from einops import rearrange
from PIL import Image

from modules import DenoiseUNet


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sampling defaults: one image per request, 11 renoising steps, and a
# classifier-free guidance scale of 5.
batch_size = 1
steps = 11
scale = 5

def to_pil(images):
    # Convert a (B, C, H, W) float tensor in [0, 1] into a list of PIL images.
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    images = (images * 255).round().astype("uint8")
    images = [Image.fromarray(image) for image in images]
    return images


def log(t, eps=1e-20):
    # Numerically stable log.
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1):
    # Categorical sampling over `dim` via the Gumbel-max trick.
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=(1.0, 1.0),
           typical_filtering=True, typical_mass=0.2, typical_min_tokens=1,
           classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
    with torch.inference_mode():
        r_range = torch.linspace(0, 1, T + 1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        preds = []
        if x is None:
            # Start from uniform noise over the discrete token vocabulary.
            x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
        elif mask is not None:
            # Inpainting: renoise the masked region, keep the rest of x intact.
            noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
            x = noise * mask + (1 - mask) * x
        init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                # Classifier-free guidance: extrapolate from the unconditional
                # logits towards the conditional ones.
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            x = logits
            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
            if typical_filtering:
                # Typical sampling: keep the tokens whose information content is
                # closest to the distribution's entropy, up to `typical_mass`.
                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
                x_flat_norm_p = torch.exp(x_flat_norm)
                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)

                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)

                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
                if typical_min_tokens > 1:
                    sorted_indices_to_remove[..., :typical_min_tokens] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))

            x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(x.size(0), *x.shape[2:])
            if mask is not None:
                x = x * mask + (1 - mask) * init_x
            if i < renoise_steps:
                # Re-add noise for the next step, seeded from the initial noise,
                # the previous sample, or fresh randomness.
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=prev_x)
                else:
                    x, _ = model.add_noise(x, r_range[i + 1])
            preds.append(x.detach())
    return preds

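
# Note on `sample`'s interface, as wired up in EndpointHandler below: `c` is a
# (B, D) batch of CLIP text embeddings, and each entry of the returned list is
# a (B, 32, 32) LongTensor of VQ token indices; the final entry gets decoded.
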
class EndpointHandler:
    def __init__(self, path=""):
        # Token-space denoising model (a DenoiseUNet over 8192 VQ tokens).
        model_path = Path(path) / "model_600000.pt"
        state_dict = torch.load(model_path, map_location=device)
        model = DenoiseUNet(num_labels=8192).to(device)
        model.load_state_dict(state_dict)
        model.eval().requires_grad_(False)
        self.model = model

        # VQGAN used to turn token grids back into RGB images.
        vqmodel = get_vae().to(device)
        vqmodel.eval().requires_grad_(False)
        self.vqmodel = vqmodel

        # CLIP text encoder that provides the conditioning embeddings.
        clip_model, _, _ = open_clip.create_model_and_transforms(
            'ViT-g-14', pretrained='laion2b_s12b_b42k'
        )
        clip_model = clip_model.to(device).eval().requires_grad_(False)
        self.clip_model = clip_model

    def encode(self, x):
        # Map an RGB image in [0, 1] to a grid of VQ token indices.
        return self.vqmodel.model.encode(2 * x - 1)[-1][-1]

    def decode(self, img_seq, shape=(32, 32)):
        # Map a grid of VQ token indices back to an RGB image in [0, 1].
        img_seq = img_seq.view(img_seq.shape[0], -1)
        b, n = img_seq.shape
        one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.vqmodel.num_tokens).float()
        z = one_hot_indices @ self.vqmodel.model.quantize.embed.weight
        z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
        img = self.vqmodel.model.decode(z)
        img = (img.clamp(-1.0, 1.0) + 1) * 0.5
        return img

    def __call__(self, data: Any) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                Includes the text prompt under "inputs" and the parameters for the inference.
        Return:
            A :obj:`dict` containing the base64-encoded generated image under "image".
        """
        inputs = data.pop("inputs", data)

        # Encode the prompt with CLIP, sample a 32x32 grid of VQ tokens,
        # then decode the final step back to pixels.
        latent_shape = (32, 32)
        tokenized_text = tokenizer.tokenize([inputs] * batch_size).to(device)
        with autocast(device.type):
            clip_embeddings = self.clip_model.encode_text(tokenized_text)
            images = sample(
                self.model, clip_embeddings, T=12, size=latent_shape, starting_t=0,
                temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2,
                typical_min_tokens=1, classifier_free_scale=scale,
                renoise_steps=steps, renoise_mode="start"
            )
            images = self.decode(images[-1], latent_shape)
        images = to_pil(images)

        # Serialize the first image as a base64-encoded JPEG.
        buffered = BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        return {"image": img_str.decode()}