| import argparse |
| import contextlib |
| import time |
| import gc |
| import logging |
| import math |
| import os |
| import random |
| import jsonlines |
| import functools |
| import shutil |
| import pyrallis |
| import itertools |
| from pathlib import Path |
| from collections import namedtuple, OrderedDict |
|
|
| import accelerate |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| import transformers |
| from accelerate import Accelerator |
| from accelerate.logging import get_logger |
| from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed |
| from datasets import load_dataset |
| from packaging import version |
| from PIL import Image |
| from losses.losses import * |
| from torchvision import transforms |
| from torchvision.transforms.functional import crop |
| from tqdm.auto import tqdm |
|
|
|
|
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the transformers text-encoder class for a pretrained checkpoint.

    Reads the config of the given subfolder and returns the matching class
    object (not an instance). Only the two CLIP text-encoder variants used by
    SD/SDXL are supported.

    Raises:
        ValueError: if the config declares any other architecture.
    """
    from transformers import PretrainedConfig

    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    raise ValueError(f"{architecture} is not supported.")
|
|
def get_train_dataset(dataset_name, dataset_dir, args, accelerator):
    """Load the training split and resolve the dataset column names in `args`.

    Args:
        dataset_name: Name or path passed to `datasets.load_dataset`.
        dataset_dir: Data directory; the HF cache is placed in `<dataset_dir>/.cache`.
        args: Parsed training arguments. `image_column`, `caption_column` and
            `conditioning_image_column` are filled in-place (positional
            defaults: columns 0, 1 and 2) when left as None.
        accelerator: Used so the main process shuffles/caches first.

    Returns:
        The shuffled (and optionally truncated) training dataset.
    """
    dataset = load_dataset(
        dataset_name,
        data_dir=dataset_dir,
        cache_dir=os.path.join(dataset_dir, ".cache"),
        num_proc=4,
        split="train",
    )

    column_names = dataset.column_names

    def _resolve_column(attr, default_index, description):
        # One helper for the three identical blocks the original triplicated.
        # When unset, fall back to a positional default; otherwise only warn,
        # preserving the original's non-fatal behavior for missing columns.
        requested = getattr(args, attr)
        if requested is None:
            setattr(args, attr, column_names[default_index])
            logger.info(f"{description} defaulting to {column_names[default_index]}")
        elif requested not in column_names:
            logger.warning(f"dataset {dataset_name} has no column {requested}")

    _resolve_column("image_column", 0, "image column")
    _resolve_column("caption_column", 1, "caption column")
    _resolve_column("conditioning_image_column", 2, "conditioning image column")

    # Main process shuffles/selects first so the cache is built exactly once.
    with accelerator.main_process_first():
        train_dataset = dataset.shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset
|
|
def prepare_train_dataset(dataset, accelerator, deg_pipeline, centralize=False):
    """Attach per-sample preprocessing (augment, resize, random-crop, kernel
    sampling) to `dataset` via `with_transform`.

    Args:
        dataset: HF dataset whose `args.image_column` holds image file paths.
        accelerator: Used so the main process attaches the transform first.
        deg_pipeline: Provides `augment_opt`, `degrade_opt['gt_size']` and
            `get_kernel()` (Real-ESRGAN-style degradation kernel sampling).
        centralize: When True, normalize images from [0, 1] to [-1, 1].

    Returns:
        The dataset with the on-the-fly `preprocess_train` transform attached.
    """
    # FIX: the original evaluated `random.random() < 0.5` ONCE here, at
    # dataset-preparation time, so one fixed flip/rotation combination was
    # applied to every image for the entire run. Use probabilistic transforms
    # so the augmentation is re-sampled independently per image.
    augment_transforms = []
    if deg_pipeline.augment_opt['use_hflip']:
        augment_transforms.append(transforms.RandomHorizontalFlip(p=0.5))
    if deg_pipeline.augment_opt['use_rot']:
        # As before, `use_rot` gates both the vertical flip and the 90° rotation.
        augment_transforms.append(transforms.RandomVerticalFlip(p=0.5))
        augment_transforms.append(
            transforms.RandomApply([transforms.RandomRotation(degrees=(90, 90))], p=0.5)
        )

    torch_transforms = [transforms.ToTensor()]
    if centralize:
        # Map [0, 1] -> [-1, 1].
        torch_transforms.append(transforms.Normalize([0.5], [0.5]))

    training_size = deg_pipeline.degrade_opt['gt_size']
    image_transforms = transforms.Compose(augment_transforms)
    train_transforms = transforms.Compose(torch_transforms)
    train_resize = transforms.Resize(training_size, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.RandomCrop(training_size)

    def preprocess_train(examples):
        # NOTE(review): relies on the module-level `args` for the column name.
        raw_images = [Image.open(path).convert("RGB") for path in examples[args.image_column]]

        images = []
        original_sizes = []
        crop_top_lefts = []
        kernel = []
        kernel2 = []
        sinc_kernel = []

        for raw_image in raw_images:
            raw_image = image_transforms(raw_image)
            # Size recorded after augmentation (flips/90° rotation keep PIL size).
            original_sizes.append((raw_image.height, raw_image.width))

            raw_image = train_resize(raw_image)
            # Record the crop origin; SDXL-style conditioning uses it downstream.
            y1, x1, h, w = train_crop.get_params(raw_image, (training_size, training_size))
            raw_image = crop(raw_image, y1, x1, h, w)
            crop_top_lefts.append((y1, x1))
            images.append(train_transforms(raw_image))

            # Sample a fresh degradation-kernel triple per image.
            k, k2, sk = deg_pipeline.get_kernel()
            kernel.append(k)
            kernel2.append(k2)
            sinc_kernel.append(sk)

        examples["images"] = images
        examples["original_sizes"] = original_sizes
        examples["crop_top_lefts"] = crop_top_lefts
        examples["kernel"] = kernel
        examples["kernel2"] = kernel2
        examples["sinc_kernel"] = sinc_kernel

        return examples

    with accelerator.main_process_first():
        dataset = dataset.with_transform(preprocess_train)

    return dataset
|
|
def collate_fn(examples):
    """Collate preprocessed samples into a training batch.

    Tensor fields are stacked into contiguous float tensors; captions fall
    back to the empty string when `args.caption_column` is absent from a
    sample. Relies on the module-level `args`.
    """
    def _stack(key):
        # Shared stack -> contiguous -> float boilerplate for tensor fields.
        batch = torch.stack([example[key] for example in examples])
        return batch.to(memory_format=torch.contiguous_format).float()

    # FIX: the original used `prompts.append(x) if cond else prompts.append("")`,
    # a conditional expression evaluated only for its side effects.
    prompts = [
        example[args.caption_column] if args.caption_column in example else ""
        for example in examples
    ]

    return {
        "images": _stack("images"),
        "text": prompts,
        "kernel": _stack("kernel"),
        "kernel2": _stack("kernel2"),
        "sinc_kernel": _stack("sinc_kernel"),
        "original_sizes": [example["original_sizes"] for example in examples],
        "crop_top_lefts": [example["crop_top_lefts"] for example in examples],
    }
|
|
def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):
    """Encode a batch of captions with one or more paired tokenizer/text-encoder.

    Args:
        prompt_batch: Iterable of captions; each entry is a str, or a
            list/np.ndarray of alternative captions for the same image.
        text_encoders: Text encoders, paired positionally with `tokenizers`.
        tokenizers: Tokenizers, paired positionally with `text_encoders`.
        is_train: When an entry has multiple captions, pick one at random in
            training; otherwise always use the first.

    Returns:
        (prompt_embeds, pooled_prompt_embeds): per-encoder penultimate hidden
        states concatenated along the feature dim, and the pooled output of
        the LAST encoder only (it is overwritten on every loop pass).
    """
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # Multiple captions per image: sample one in training, first at eval.
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # Pooled output: only the final encoder's value survives the loop.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden layer, not the final layer output.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate the per-encoder embeddings along the feature dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
|
|
def importance_sampling_fn(t, max_t, alpha):
    """Importance-sampling density f(t) = (1 - alpha*cos(pi*t/max_t)) / max_t."""
    base_density = 1 / max_t
    return base_density * (1 - alpha * np.cos(np.pi * t / max_t))
|
|
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep coefficients a[t] and reshape them so they
    broadcast against a tensor of shape `x_shape` (batch first)."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
|
|
def tensor_to_pil(images):
    """
    Convert an image tensor in [-1, 1] (CHW or NCHW) to a list of PIL images.
    Single-channel arrays become mode-"L" images; otherwise the first three
    channels are used as RGB.
    """
    arr = ((images + 1) / 2).detach().cpu().numpy()
    if arr.ndim == 3:
        # CHW -> HWC, then add a singleton batch axis.
        arr = np.transpose(arr, (1, 2, 0))[None, ...]
    elif arr.ndim == 4:
        # NCHW -> NHWC.
        arr = np.transpose(arr, (0, 2, 3, 1))
    arr = (arr * 255).round().astype("uint8")

    if arr.shape[-1] == 1:
        # Grayscale: drop the channel axis.
        return [Image.fromarray(frame.squeeze(), mode="L") for frame in arr]
    return [Image.fromarray(frame[:, :, :3]) for frame in arr]
|
|
def save_np_to_image(img_np, save_dir):
    """Save the first image of an NCHW float array (values in [0, 1]) to `save_dir`.

    FIX: round and clip before the uint8 cast — the original's bare `.astype`
    truncated fractional values and wrapped values outside [0, 1], which was
    also inconsistent with `tensor_to_pil`'s `.round()`.
    """
    img_np = np.transpose(img_np, (0, 2, 3, 1))  # NCHW -> NHWC
    img_np = (img_np * 255).round().clip(0, 255).astype(np.uint8)
    Image.fromarray(img_np[0]).save(save_dir)
|
|
|
|
def seperate_SFT_params_from_unet(unet):
    """Split unet parameters into SFT and non-SFT groups.

    Membership is decided by whether the parameter's qualified name contains
    the substring "SFT". Returns (sft_params, other_params).
    """
    sft_params, other_params = [], []
    for qualified_name, parameter in unet.named_parameters():
        bucket = sft_params if "SFT" in qualified_name else other_params
        bucket.append(parameter)
    return sft_params, other_params
|
|
|
|
def seperate_lora_params_from_unet(unet):
    """Split unet parameters into LoRA and non-LoRA groups.

    A parameter is a LoRA parameter when its qualified name contains "lora".
    Returns (lora_params, frozen_params), preserving registration order.
    """
    named = list(unet.named_parameters())
    lora_params = [p for n, p in named if "lora" in n]
    frozen_params = [p for n, p in named if "lora" not in n]
    return lora_params, frozen_params
|
|
|
|
def seperate_ip_params_from_unet(unet):
    """Split unet parameters into IP(-adapter)-related and other parameters.

    Selection is by substring matching on the parameter's qualified name.
    Returns (ip_params, non_ip_params).
    """
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        # Image-projection and IP-specific weights.
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                ip_params.append(param)
            # NOTE(review): attention-processor params containing neither "ip"
            # nor "ln" fall through here and land in NEITHER list — confirm
            # this is intended (they may simply not exist in practice).
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params
|
|
|
|
def seperate_ref_params_from_unet(unet):
    """Split unet parameters into reference/IP-related and other parameters.

    Like `seperate_ip_params_from_unet` but additionally treats any parameter
    whose name contains "extract" as IP-related. Returns
    (ip_params, non_ip_params), selected by name-substring matching.
    """
    ip_params = []
    non_ip_params = []
    for name, param in unet.named_parameters():
        # Image-projection and IP-specific weights.
        if "encoder_hid_proj." in name or "_ip." in name:
            ip_params.append(param)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                ip_params.append(param)
            # NOTE(review): attention-processor params containing neither "ip"
            # nor "ln" match no branch and land in NEITHER list — confirm.
        elif "extract" in name:
            ip_params.append(param)
        else:
            non_ip_params.append(param)
    return ip_params, non_ip_params
|
|
|
|
def seperate_ip_modules_from_unet(unet):
    """Split sub-modules into IP-adapter-related modules and the rest.

    A module is IP-related when its qualified name contains
    "encoder_hid_proj" or "attn2.processor". The root module (empty name)
    always lands in the non-IP list. Returns (ip_modules, non_ip_modules).
    """
    ip_modules, non_ip_modules = [], []
    for qualified_name, submodule in unet.named_modules():
        is_ip = "encoder_hid_proj" in qualified_name or "attn2.processor" in qualified_name
        (ip_modules if is_ip else non_ip_modules).append(submodule)
    return ip_modules, non_ip_modules
|
|
|
|
def seperate_SFT_keys_from_unet(unet):
    """Split unet parameter NAMES into SFT and non-SFT groups.

    Same predicate as `seperate_SFT_params_from_unet` (substring "SFT" in the
    qualified name) but returns the names rather than the parameters.
    """
    keys, non_keys = [], []
    for qualified_name, _ in unet.named_parameters():
        (keys if "SFT" in qualified_name else non_keys).append(qualified_name)
    return keys, non_keys
|
|
|
|
def seperate_ip_keys_from_unet(unet):
    """Split unet parameter NAMES into IP-related and other names.

    Same predicate as `seperate_ip_params_from_unet` but returns names instead
    of parameters. Returns (keys, non_keys).
    """
    keys = []
    non_keys = []
    for name, param in unet.named_parameters():
        # Image-projection and IP-specific weights.
        if "encoder_hid_proj." in name or "_ip." in name:
            keys.append(name)
        elif "attn" in name and "processor" in name:
            if "ip" in name or "ln" in name:
                keys.append(name)
            # NOTE(review): attention-processor names containing neither "ip"
            # nor "ln" fall through and appear in NEITHER list — confirm.
        else:
            non_keys.append(name)
    return keys, non_keys