| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import copy |
| import gc |
| import hashlib |
| import logging |
| import math |
| import os |
| import random |
| import shutil |
| from contextlib import nullcontext |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.utils.checkpoint |
| import transformers |
| from accelerate import Accelerator |
| from accelerate.logging import get_logger |
| from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed |
| from huggingface_hub import create_repo, upload_folder |
| from peft import LoraConfig, set_peft_model_state_dict |
| from peft.utils import get_peft_model_state_dict |
| from PIL import Image |
| from PIL.ImageOps import exif_transpose |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
| from torchvision.transforms.functional import crop |
| from tqdm.auto import tqdm |
|
|
| import diffusers |
| from diffusers import ( |
| AutoencoderKL, |
| FlowMatchEulerDiscreteScheduler, |
| SD3Transformer2DModel, |
| StableDiffusion3Pipeline, |
| ) |
| from diffusers.optimization import get_scheduler |
| from diffusers.training_utils import ( |
| cast_training_params, |
| compute_density_for_timestep_sampling, |
| compute_loss_weighting_for_sd3, |
| ) |
| from diffusers.utils import ( |
| check_min_version, |
| convert_unet_state_dict_to_peft, |
| is_wandb_available, |
| ) |
| from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
| from diffusers.utils.torch_utils import is_compiled_module |
|
|
|
|
# wandb is an optional tracker: import it only when it is installed, so the
# script can still run with other `--report_to` backends.
if is_wandb_available():
    import wandb

# Version guard: check_min_version is expected to fail if the installed
# diffusers is older than the version this script was written against.
check_min_version("0.30.0.dev0")

# accelerate's logger wrapper; its logging calls accept `main_process_only=`.
logger = get_logger(__name__)
|
|
|
|
def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    instance_prompt=None,
    validation_prompt=None,
    repo_folder=None,
):
    """Write a Hub model card (README.md) for the trained LoRA into ``repo_folder``.

    Each validation image is saved next to the card as ``image_{i}.png`` and listed
    in the card's widget gallery, captioned with ``validation_prompt`` (a single
    space when no prompt is given).

    Args:
        repo_id: Hub repository id the card describes.
        images: Optional iterable of images exposing ``.save(path)``.
        base_model: Identifier of the base model the LoRA was trained on.
        train_text_encoder: Whether text-encoder LoRA was enabled (only reported in the card).
        instance_prompt: Trigger prompt recorded in the card.
        validation_prompt: Caption attached to every widget image.
        repo_folder: Local directory that receives the images and README.md.
    """
    widget_caption = validation_prompt or " "
    widget_dict = []
    for idx, image in enumerate(images if images is not None else []):
        file_name = f"image_{idx}.png"
        image.save(os.path.join(repo_folder, file_name))
        widget_dict.append({"text": widget_caption, "output": {"url": file_name}})

    model_description = f"""
# SD3 DreamBooth LoRA - {repo_id}

<Gallery />

## Model description

These are {repo_id} DreamBooth weights for {base_model}.

The weights were trained using [DreamBooth](https://dreambooth.github.io/).

LoRA for the text encoder was enabled: {train_text_encoder}.

## Trigger words

You should use {instance_prompt} to trigger the image generation.

## Download model

[Download]({repo_id}/tree/main) them in the Files & versions tab.

## License

Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
"""
    card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="openrail++",
        base_model=base_model,
        prompt=instance_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    card = populate_model_card(
        card,
        tags=[
            "text-to-image",
            "diffusers-training",
            "diffusers",
            "lora",
            "sd3",
            "sd3-diffusers",
            "template:sd-lora",
        ],
    )
    card.save(os.path.join(repo_folder, "README.md"))
|
|
|
|
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation=False,
):
    """Generate validation images with ``pipeline`` and send them to the active trackers.

    Args:
        pipeline: Pipeline object; it is called as ``pipeline(**pipeline_args, generator=...)``
            and the first entry of the result's ``images`` list is kept per call.
        args: Parsed CLI namespace; reads ``num_validation_images``, ``validation_prompt``
            and ``seed``.
        accelerator: Accelerator whose ``device`` hosts the RNG and whose ``trackers``
            receive the images ("tensorboard" and "wandb" are handled).
        pipeline_args: Keyword arguments forwarded to every pipeline call.
        epoch: Value passed to TensorBoard as the global step.
        is_final_validation: Log under "test" instead of "validation".

    Returns:
        The list of generated images (``args.num_validation_images`` entries).
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline.enable_model_cpu_offload()
    pipeline.set_progress_bar_config(disable=True)

    # 0 is a valid seed, so compare against None instead of relying on truthiness.
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    )

    # Inference runs without autocast; nullcontext keeps an explicit hook for adding one.
    with nullcontext():
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    phase_name = "test" if is_final_validation else "validation"
    # np.stack raises on an empty list, so only hand images to trackers when some exist.
    if images:
        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        phase_name: [
                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                            for i, image in enumerate(images)
                        ]
                    }
                )

    # Drops only this function's reference; the caller may still hold the pipeline.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return images
|
|
|
|
def parse_args(input_args=None):
    """Parse the command-line options for SD3 DreamBooth LoRA training.

    Args:
        input_args: Optional list of argument strings; when None, the process
            command line is parsed.

    Returns:
        argparse.Namespace with the parsed options. ``local_rank`` is replaced by the
        ``LOCAL_RANK`` environment variable when that variable is set and differs.

    Raises:
        ValueError: If ``--instance_data_dir`` or ``--data_df_path`` is not given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        help=("A folder containing the training data. "),
    )
    parser.add_argument(
        "--data_df_path",
        type=str,
        default=None,
        help=("Path to the parquet file serialized with compute_embeddings.py."),
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=77,
        help="Maximum sequence length to use with the T5 text encoder",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd3-dreambooth-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )

    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="logit_normal",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help=('The optimizer type to use. Choose between ["AdamW"]'),
    )

    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )

    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")

    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.instance_data_dir is None:
        raise ValueError("Specify `instance_data_dir`.")
    # Prompt embeddings are read from this parquet file by the dataset; without it
    # training would only fail later, deep inside pandas.
    if args.data_df_path is None:
        raise ValueError("Specify `data_df_path`.")

    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
|
|
|
class DreamBoothDataset(Dataset):
    """
    Instance images paired with precomputed prompt embeddings for fine-tuning.

    Every file in ``instance_data_root`` is loaded and pre-processed once at
    construction time (resize, crop, optional horizontal flip, normalization) and
    kept in memory as a tensor. Prompt embeddings are not computed here: they are
    looked up in a parquet file by the SHA-256 hash of each image file's bytes.
    """

    def __init__(
        self,
        data_df_path,
        instance_data_root,
        instance_prompt,
        size=1024,
        center_crop=False,
        random_flip=None,
    ):
        """
        Args:
            data_df_path: Parquet file with ``image_hash``, ``prompt_embeds`` and
                ``pooled_prompt_embeds`` columns.
            instance_data_root: Directory holding the training images.
            instance_prompt: Prompt associated with the instance images (stored only).
            size: Target resolution; images are resized then cropped to ``size x size``.
            center_crop: Center crop when True, random crop otherwise.
            random_flip: Flip each image horizontally with probability 0.5. ``None``
                falls back to ``random_flip`` on a module-level ``args`` namespace
                (False when there is none).

        Raises:
            ValueError: If ``instance_data_root`` does not exist or ``data_df_path`` is None.
        """
        self.size = size
        self.center_crop = center_crop
        self.instance_prompt = instance_prompt

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")
        if data_df_path is None:
            raise ValueError("`data_df_path` must point to the parquet file with precomputed embeddings.")

        # List the directory exactly once (sorted for determinism) so each image and
        # its hash share the same index; subdirectories are skipped.
        image_paths = sorted(path for path in self.instance_data_root.iterdir() if path.is_file())
        self.instance_images = [Image.open(path) for path in image_paths]
        self.image_hashes = [self.generate_image_hash(path) for path in image_paths]

        # Augmentations are applied once here, so every epoch sees the same crop/flip.
        self.pixel_values = self.apply_image_transformations(
            instance_images=self.instance_images,
            size=size,
            center_crop=center_crop,
            random_flip=random_flip,
        )

        # image_hash -> (prompt_embeds, pooled_prompt_embeds)
        self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)

        self.num_instance_images = len(self.instance_images)
        self._length = self.num_instance_images

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return the pre-processed image and its embeddings; ``index`` wraps around."""
        idx = index % self.num_instance_images
        image_hash = self.image_hashes[idx]
        try:
            prompt_embeds, pooled_prompt_embeds = self.data_dict[image_hash]
        except KeyError:
            raise KeyError(
                f"No precomputed embeddings for image hash {image_hash}; "
                "regenerate the parquet file for the current instance images."
            ) from None
        return {
            "instance_images": self.pixel_values[idx],
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
        }

    def apply_image_transformations(self, instance_images, size, center_crop, random_flip=None):
        """Resize, crop, optionally flip and normalize each image.

        Args:
            instance_images: Iterable of PIL images.
            size: Resize target for ``transforms.Resize`` and side length of the square crop.
            center_crop: Center crop when True, random crop otherwise.
            random_flip: Flip with probability 0.5; ``None`` uses a module-level
                ``args.random_flip`` if present, else False.

        Returns:
            List of tensors (ToTensor output normalized with mean 0.5 / std 0.5).
        """
        if random_flip is None:
            # Compatibility for callers that do not pass the flag: read it from the
            # script-level ``args`` namespace if one exists.
            random_flip = bool(getattr(globals().get("args"), "random_flip", False))

        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        train_flip = transforms.RandomHorizontalFlip(p=1.0)
        train_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        pixel_values = []
        for image in instance_images:
            image = exif_transpose(image)
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = train_resize(image)
            # p=1.0 flip gated by a coin toss -> flip with probability 0.5.
            if random_flip and random.random() < 0.5:
                image = train_flip(image)
            if center_crop:
                image = train_crop(image)
            else:
                y1, x1, h, w = train_crop.get_params(image, (size, size))
                image = crop(image, y1, x1, h, w)
            pixel_values.append(train_transforms(image))
        return pixel_values

    def convert_to_torch_tensor(self, embeddings: list):
        """Turn ``[prompt_embeds, pooled_prompt_embeds]`` row values into tensors.

        The shapes are fixed: prompt embeddings become (154, 4096) and pooled
        embeddings (2048,). NOTE(review): 154 presumably corresponds to the default
        max_sequence_length of 77 plus 77 CLIP tokens; confirm against the script that
        wrote the parquet file before changing ``--max_sequence_length``.
        """
        prompt_embeds = embeddings[0]
        pooled_prompt_embeds = embeddings[1]
        prompt_embeds = np.array(prompt_embeds).reshape(154, 4096)
        pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(2048)
        return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)

    def map_image_hash_embedding(self, data_df_path):
        """Read the parquet file and map each ``image_hash`` to its (prompt, pooled) tensors."""
        hashes_df = pd.read_parquet(data_df_path)
        data_dict = {}
        for _, row in hashes_df.iterrows():
            data_dict[row["image_hash"]] = self.convert_to_torch_tensor(
                embeddings=[row["prompt_embeds"], row["pooled_prompt_embeds"]]
            )
        return data_dict

    def generate_image_hash(self, image_path):
        """Return the hex SHA-256 digest of the raw file bytes at ``image_path``."""
        with open(image_path, "rb") as f:
            img_data = f.read()
        return hashlib.sha256(img_data).hexdigest()
|
|
|
|
def collate_fn(examples):
    """Batch dataset examples into stacked tensors.

    Args:
        examples: Sequence of dicts with ``instance_images``, ``prompt_embeds`` and
            ``pooled_prompt_embeds`` tensors of matching shapes.

    Returns:
        Dict with ``pixel_values`` (contiguous float32), ``prompt_embeds`` and
        ``pooled_prompt_embeds``, each stacked along a new leading batch dimension.
    """

    def _stack(key):
        return torch.stack([example[key] for example in examples])

    images = _stack("instance_images").to(memory_format=torch.contiguous_format).float()
    return {
        "pixel_values": images,
        "prompt_embeds": _stack("prompt_embeds"),
        "pooled_prompt_embeds": _stack("pooled_prompt_embeds"),
    }
|
|
|
|
| def main(args): |
| if args.report_to == "wandb" and args.hub_token is not None: |
| raise ValueError( |
| "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." |
| " Please use `huggingface-cli login` to authenticate with the Hub." |
| ) |
|
|
| if torch.backends.mps.is_available() and args.mixed_precision == "bf16": |
| |
| raise ValueError( |
| "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." |
| ) |
|
|
| logging_dir = Path(args.output_dir, args.logging_dir) |
|
|
| accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| accelerator = Accelerator( |
| gradient_accumulation_steps=args.gradient_accumulation_steps, |
| mixed_precision=args.mixed_precision, |
| log_with=args.report_to, |
| project_config=accelerator_project_config, |
| kwargs_handlers=[kwargs], |
| ) |
|
|
| |
| if torch.backends.mps.is_available(): |
| accelerator.native_amp = False |
|
|
| if args.report_to == "wandb": |
| if not is_wandb_available(): |
| raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
|
|
| |
| logging.basicConfig( |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| datefmt="%m/%d/%Y %H:%M:%S", |
| level=logging.INFO, |
| ) |
| logger.info(accelerator.state, main_process_only=False) |
| if accelerator.is_local_main_process: |
| transformers.utils.logging.set_verbosity_warning() |
| diffusers.utils.logging.set_verbosity_info() |
| else: |
| transformers.utils.logging.set_verbosity_error() |
| diffusers.utils.logging.set_verbosity_error() |
|
|
| |
| if args.seed is not None: |
| set_seed(args.seed) |
|
|
| |
| if accelerator.is_main_process: |
| if args.output_dir is not None: |
| os.makedirs(args.output_dir, exist_ok=True) |
|
|
| if args.push_to_hub: |
| repo_id = create_repo( |
| repo_id=args.hub_model_id or Path(args.output_dir).name, |
| exist_ok=True, |
| ).repo_id |
|
|
| |
| noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( |
| args.pretrained_model_name_or_path, subfolder="scheduler" |
| ) |
| noise_scheduler_copy = copy.deepcopy(noise_scheduler) |
| vae = AutoencoderKL.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="vae", |
| revision=args.revision, |
| variant=args.variant, |
| ) |
| transformer = SD3Transformer2DModel.from_pretrained( |
| args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant |
| ) |
|
|
| transformer.requires_grad_(False) |
| vae.requires_grad_(False) |
|
|
| |
| |
| weight_dtype = torch.float32 |
| if accelerator.mixed_precision == "fp16": |
| weight_dtype = torch.float16 |
| elif accelerator.mixed_precision == "bf16": |
| weight_dtype = torch.bfloat16 |
|
|
| if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: |
| |
| raise ValueError( |
| "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." |
| ) |
|
|
| vae.to(accelerator.device, dtype=torch.float32) |
| transformer.to(accelerator.device, dtype=weight_dtype) |
|
|
| if args.gradient_checkpointing: |
| transformer.enable_gradient_checkpointing() |
|
|
| |
| transformer_lora_config = LoraConfig( |
| r=args.rank, |
| lora_alpha=args.rank, |
| init_lora_weights="gaussian", |
| target_modules=["to_k", "to_q", "to_v", "to_out.0"], |
| ) |
| transformer.add_adapter(transformer_lora_config) |
|
|
| def unwrap_model(model): |
| model = accelerator.unwrap_model(model) |
| model = model._orig_mod if is_compiled_module(model) else model |
| return model |
|
|
| |
| def save_model_hook(models, weights, output_dir): |
| if accelerator.is_main_process: |
| transformer_lora_layers_to_save = None |
| for model in models: |
| if isinstance(model, type(unwrap_model(transformer))): |
| transformer_lora_layers_to_save = get_peft_model_state_dict(model) |
| else: |
| raise ValueError(f"unexpected save model: {model.__class__}") |
|
|
| |
| weights.pop() |
|
|
| StableDiffusion3Pipeline.save_lora_weights( |
| output_dir, |
| transformer_lora_layers=transformer_lora_layers_to_save, |
| ) |
|
|
| def load_model_hook(models, input_dir): |
| transformer_ = None |
|
|
| while len(models) > 0: |
| model = models.pop() |
|
|
| if isinstance(model, type(unwrap_model(transformer))): |
| transformer_ = model |
| else: |
| raise ValueError(f"unexpected save model: {model.__class__}") |
|
|
| lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) |
|
|
| transformer_state_dict = { |
| f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") |
| } |
| transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) |
| incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") |
| if incompatible_keys is not None: |
| |
| unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) |
| if unexpected_keys: |
| logger.warning( |
| f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " |
| f" {unexpected_keys}. " |
| ) |
|
|
| |
| |
| |
| if args.mixed_precision == "fp16": |
| models = [transformer_] |
| |
| cast_training_params(models) |
|
|
| accelerator.register_save_state_pre_hook(save_model_hook) |
| accelerator.register_load_state_pre_hook(load_model_hook) |
|
|
| |
| |
| if args.allow_tf32 and torch.cuda.is_available(): |
| torch.backends.cuda.matmul.allow_tf32 = True |
|
|
| if args.scale_lr: |
| args.learning_rate = ( |
| args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes |
| ) |
|
|
| |
| if args.mixed_precision == "fp16": |
| models = [transformer] |
| |
| cast_training_params(models, dtype=torch.float32) |
|
|
| |
| transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) |
| transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} |
| params_to_optimize = [transformer_parameters_with_lr] |
|
|
| |
| if not args.optimizer.lower() == "adamw": |
| logger.warning( |
| f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW]." |
| "Defaulting to adamW" |
| ) |
| args.optimizer = "adamw" |
|
|
| if args.use_8bit_adam and not args.optimizer.lower() == "adamw": |
| logger.warning( |
| f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " |
| f"set to {args.optimizer.lower()}" |
| ) |
|
|
| if args.optimizer.lower() == "adamw": |
| if args.use_8bit_adam: |
| try: |
| import bitsandbytes as bnb |
| except ImportError: |
| raise ImportError( |
| "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." |
| ) |
|
|
| optimizer_class = bnb.optim.AdamW8bit |
| else: |
| optimizer_class = torch.optim.AdamW |
|
|
| optimizer = optimizer_class( |
| params_to_optimize, |
| betas=(args.adam_beta1, args.adam_beta2), |
| weight_decay=args.adam_weight_decay, |
| eps=args.adam_epsilon, |
| ) |
|
|
| |
| train_dataset = DreamBoothDataset( |
| data_df_path=args.data_df_path, |
| instance_data_root=args.instance_data_dir, |
| instance_prompt=args.instance_prompt, |
| size=args.resolution, |
| center_crop=args.center_crop, |
| ) |
|
|
| train_dataloader = torch.utils.data.DataLoader( |
| train_dataset, |
| batch_size=args.train_batch_size, |
| shuffle=True, |
| collate_fn=lambda examples: collate_fn(examples), |
| num_workers=args.dataloader_num_workers, |
| ) |
|
|
| |
| overrode_max_train_steps = False |
| num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| if args.max_train_steps is None: |
| args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| overrode_max_train_steps = True |
|
|
| lr_scheduler = get_scheduler( |
| args.lr_scheduler, |
| optimizer=optimizer, |
| num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, |
| num_training_steps=args.max_train_steps * accelerator.num_processes, |
| num_cycles=args.lr_num_cycles, |
| power=args.lr_power, |
| ) |
|
|
| |
| transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( |
| transformer, optimizer, train_dataloader, lr_scheduler |
| ) |
|
|
| |
| num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| if overrode_max_train_steps: |
| args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| |
| args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
|
|
| |
| |
| if accelerator.is_main_process: |
| tracker_name = "dreambooth-sd3-lora-miniature" |
| accelerator.init_trackers(tracker_name, config=vars(args)) |
|
|
| |
| total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
|
|
| logger.info("***** Running training *****") |
| logger.info(f" Num examples = {len(train_dataset)}") |
| logger.info(f" Num batches each epoch = {len(train_dataloader)}") |
| logger.info(f" Num Epochs = {args.num_train_epochs}") |
| logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") |
| logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") |
| logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
| logger.info(f" Total optimization steps = {args.max_train_steps}") |
| global_step = 0 |
| first_epoch = 0 |
|
|
| |
| if args.resume_from_checkpoint: |
| if args.resume_from_checkpoint != "latest": |
| path = os.path.basename(args.resume_from_checkpoint) |
| else: |
| |
| dirs = os.listdir(args.output_dir) |
| dirs = [d for d in dirs if d.startswith("checkpoint")] |
| dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) |
| path = dirs[-1] if len(dirs) > 0 else None |
|
|
| if path is None: |
| accelerator.print( |
| f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." |
| ) |
| args.resume_from_checkpoint = None |
| initial_global_step = 0 |
| else: |
| accelerator.print(f"Resuming from checkpoint {path}") |
| accelerator.load_state(os.path.join(args.output_dir, path)) |
| global_step = int(path.split("-")[1]) |
|
|
| initial_global_step = global_step |
| first_epoch = global_step // num_update_steps_per_epoch |
|
|
| else: |
| initial_global_step = 0 |
|
|
| progress_bar = tqdm( |
| range(0, args.max_train_steps), |
| initial=initial_global_step, |
| desc="Steps", |
| |
| disable=not accelerator.is_local_main_process, |
| ) |
|
|
| def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): |
| sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) |
| schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) |
| timesteps = timesteps.to(accelerator.device) |
| step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] |
|
|
| sigma = sigmas[step_indices].flatten() |
| while len(sigma.shape) < n_dim: |
| sigma = sigma.unsqueeze(-1) |
| return sigma |
|
|
| for epoch in range(first_epoch, args.num_train_epochs): |
| transformer.train() |
|
|
| for step, batch in enumerate(train_dataloader): |
| models_to_accumulate = [transformer] |
| with accelerator.accumulate(models_to_accumulate): |
| pixel_values = batch["pixel_values"].to(dtype=vae.dtype) |
|
|
| |
| model_input = vae.encode(pixel_values).latent_dist.sample() |
| model_input = model_input * vae.config.scaling_factor |
| model_input = model_input.to(dtype=weight_dtype) |
|
|
| |
| noise = torch.randn_like(model_input) |
| bsz = model_input.shape[0] |
|
|
| |
| |
| u = compute_density_for_timestep_sampling( |
| weighting_scheme=args.weighting_scheme, |
| batch_size=bsz, |
| logit_mean=args.logit_mean, |
| logit_std=args.logit_std, |
| mode_scale=args.mode_scale, |
| ) |
| indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() |
| timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) |
|
|
| |
| sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) |
| noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input |
|
|
| |
| prompt_embeds, pooled_prompt_embeds = batch["prompt_embeds"], batch["pooled_prompt_embeds"] |
| prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype) |
| pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype) |
| model_pred = transformer( |
| hidden_states=noisy_model_input, |
| timestep=timesteps, |
| encoder_hidden_states=prompt_embeds, |
| pooled_projections=pooled_prompt_embeds, |
| return_dict=False, |
| )[0] |
|
|
| |
| |
| model_pred = model_pred * (-sigmas) + noisy_model_input |
|
|
| |
| |
| weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) |
|
|
| |
| target = model_input |
|
|
| |
| loss = torch.mean( |
| (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), |
| 1, |
| ) |
| loss = loss.mean() |
|
|
| accelerator.backward(loss) |
| if accelerator.sync_gradients: |
| params_to_clip = transformer_lora_parameters |
| accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) |
|
|
| optimizer.step() |
| lr_scheduler.step() |
| optimizer.zero_grad() |
|
|
| |
| if accelerator.sync_gradients: |
| progress_bar.update(1) |
| global_step += 1 |
|
|
| if accelerator.is_main_process: |
| if global_step % args.checkpointing_steps == 0: |
| |
| if args.checkpoints_total_limit is not None: |
| checkpoints = os.listdir(args.output_dir) |
| checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] |
| checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) |
|
|
| |
| if len(checkpoints) >= args.checkpoints_total_limit: |
| num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 |
| removing_checkpoints = checkpoints[0:num_to_remove] |
|
|
| logger.info( |
| f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" |
| ) |
| logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") |
|
|
| for removing_checkpoint in removing_checkpoints: |
| removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) |
| shutil.rmtree(removing_checkpoint) |
|
|
| save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") |
| accelerator.save_state(save_path) |
| logger.info(f"Saved state to {save_path}") |
|
|
| logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} |
| progress_bar.set_postfix(**logs) |
| accelerator.log(logs, step=global_step) |
|
|
| if global_step >= args.max_train_steps: |
| break |
|
|
| if accelerator.is_main_process: |
| if args.validation_prompt is not None and epoch % args.validation_epochs == 0: |
| pipeline = StableDiffusion3Pipeline.from_pretrained( |
| args.pretrained_model_name_or_path, |
| vae=vae, |
| transformer=accelerator.unwrap_model(transformer), |
| revision=args.revision, |
| variant=args.variant, |
| torch_dtype=weight_dtype, |
| ) |
| pipeline_args = {"prompt": args.validation_prompt} |
| images = log_validation( |
| pipeline=pipeline, |
| args=args, |
| accelerator=accelerator, |
| pipeline_args=pipeline_args, |
| epoch=epoch, |
| ) |
| torch.cuda.empty_cache() |
| gc.collect() |
|
|
| |
| accelerator.wait_for_everyone() |
| if accelerator.is_main_process: |
| transformer = unwrap_model(transformer) |
| transformer = transformer.to(torch.float32) |
| transformer_lora_layers = get_peft_model_state_dict(transformer) |
|
|
| StableDiffusion3Pipeline.save_lora_weights( |
| save_directory=args.output_dir, |
| transformer_lora_layers=transformer_lora_layers, |
| ) |
|
|
| |
| |
| pipeline = StableDiffusion3Pipeline.from_pretrained( |
| args.pretrained_model_name_or_path, |
| revision=args.revision, |
| variant=args.variant, |
| torch_dtype=weight_dtype, |
| ) |
| |
| pipeline.load_lora_weights(args.output_dir) |
|
|
| |
| images = [] |
| if args.validation_prompt and args.num_validation_images > 0: |
| pipeline_args = {"prompt": args.validation_prompt} |
| images = log_validation( |
| pipeline=pipeline, |
| args=args, |
| accelerator=accelerator, |
| pipeline_args=pipeline_args, |
| epoch=epoch, |
| is_final_validation=True, |
| ) |
|
|
| if args.push_to_hub: |
| save_model_card( |
| repo_id, |
| images=images, |
| base_model=args.pretrained_model_name_or_path, |
| instance_prompt=args.instance_prompt, |
| validation_prompt=args.validation_prompt, |
| repo_folder=args.output_dir, |
| ) |
| upload_folder( |
| repo_id=repo_id, |
| folder_path=args.output_dir, |
| commit_message="End of training", |
| ignore_patterns=["step_*", "epoch_*"], |
| ) |
|
|
| accelerator.end_training() |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the argument namespace and run training.
    # `parse_args` is defined elsewhere in this file (presumably an argparse
    # parser — confirm). `args` is bound at module scope here, so keep the
    # name in case other module-level code reads it.
    args = parse_args()
    main(args)
|
|