harry152 commited on
Commit
3040d13
·
verified ·
1 Parent(s): 5f6420c

Upload folder using huggingface_hub

Browse files
default_config.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# accelerate launch configuration: single-machine, 4-process DeepSpeed ZeRO-2 run.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1
  offload_optimizer_device: none
  offload_param_device: none
  # NOTE(review): "." is the current directory, not a JSON config file. When
  # `deepspeed_config_file` is set, accelerate reads DeepSpeed settings from that
  # file and the inline fields above are ignored — verify this should point at
  # ds_config.json (or be removed so the inline fields take effect).
  deepspeed_config_file: .
  deepspeed_moe_layer_cls_names: ''
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
# downcast_bf16: 'no'
dynamo_config:
  # NOTE(review): AOT_TS_NVFUSER is a legacy TorchDynamo backend — confirm it is
  # still available in the installed torch version.
  dynamo_backend: AOT_TS_NVFUSER
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
ds_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "zero_optimization": {
3
+ "stage": 2,
4
+ "allgather_partitions": true,
5
+ "reduce_scatter": true,
6
+ "contiguous_gradients": true
7
+ },
8
+ "bf16": {
9
+ "enabled": true
10
+ },
11
+ "gradient_clipping": 1.0,
12
+ "train_batch_size": "auto",
13
+ "gradient_accumulation_steps": 4
14
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
diffusers
transformers
bitsandbytes
peft
sentencepiece
protobuf
accelerate
datasets
torchvision
wandb
train_flux_control.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Launch Control-LoRA fine-tuning of FLUX.1-dev on the open-pose ControlNet
# dataset. Distributed/DeepSpeed settings come from the machine's accelerate
# config; per-device batch size 1 with 4 accumulation steps, rank-64 LoRA,
# bf16 mixed precision, 8-bit Adam, and wandb logging. --offload moves the VAE
# and text encoders to CPU when idle; --push_to_hub uploads the result.
accelerate launch train_flux_control_lora.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control-lora" \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --rank=64 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=5000 \
  --validation_image="openpose.png" \
  --validation_prompt="A couple, 4k photo, highly detailed" \
  --offload \
  --seed="0" \
  --push_to_hub
train_flux_control_lora.py ADDED
@@ -0,0 +1,1405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import copy
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from contextlib import nullcontext
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import diffusers
28
+ import numpy as np
29
+ import torch
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from diffusers import (AutoencoderKL, FlowMatchEulerDiscreteScheduler,
36
+ FluxControlPipeline, FluxTransformer2DModel)
37
+ from diffusers.optimization import get_scheduler
38
+ from diffusers.training_utils import (cast_training_params,
39
+ compute_density_for_timestep_sampling,
40
+ compute_loss_weighting_for_sd3,
41
+ free_memory)
42
+ from diffusers.utils import (check_min_version, is_wandb_available, load_image,
43
+ make_image_grid)
44
+ from diffusers.utils.hub_utils import (load_or_create_model_card,
45
+ populate_model_card)
46
+ from diffusers.utils.torch_utils import is_compiled_module
47
+ from huggingface_hub import create_repo, upload_folder
48
+ from packaging import version
49
+ from peft import LoraConfig, set_peft_model_state_dict
50
+ from peft.utils import get_peft_model_state_dict
51
+ from PIL import Image
52
+ from torchvision import transforms
53
+ from tqdm.auto import tqdm
54
+
55
+ if is_wandb_available():
56
+ import wandb
57
+
58
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
+ check_min_version("0.35.0.dev0")
60
+
61
+ logger = get_logger(__name__)
62
+
63
+ NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
64
+
65
+
66
def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
    """Encode a batch of images into normalized VAE latents.

    Samples from the VAE's latent distribution, applies the VAE's configured
    shift/scale normalization, and casts the result to ``weight_dtype``.
    """
    posterior_sample = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
    normalized = (posterior_sample - vae.config.shift_factor) * vae.config.scaling_factor
    return normalized.to(weight_dtype)
70
+
71
+
72
def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
    """Run validation inference and log result images to the active trackers.

    During training (``is_final_validation=False``) the in-training transformer
    is unwrapped and used directly; for the final validation a fresh base
    transformer is loaded and the saved LoRA weights from ``args.output_dir``
    are applied on top. Returns the list of per-prompt image logs.
    """
    logger.info("Running validation... ")

    if not is_final_validation:
        flux_transformer = accelerator.unwrap_model(flux_transformer)
        pipeline = FluxControlPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            transformer=flux_transformer,
            torch_dtype=weight_dtype,
        )
    else:
        transformer = FluxTransformer2DModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
        )
        initial_channels = transformer.config.in_channels
        pipeline = FluxControlPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            transformer=transformer,
            torch_dtype=weight_dtype,
        )
        pipeline.load_lora_weights(args.output_dir)
        # Loading Control-LoRA weights expands the input projection, so the
        # in_channels must have doubled (image latents + control latents).
        assert pipeline.transformer.config.in_channels == initial_channels * 2, (
            f"{pipeline.transformer.config.in_channels=}"
        )

    pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    # Pair up prompts and images: equal counts, or broadcast a single image /
    # single prompt across the other list. Mismatched counts were rejected in
    # parse_args, so the final branch should be unreachable.
    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    image_logs = []
    # Final validation already runs the pipeline in its own dtype; MPS does not
    # support autocast here, so fall back to a no-op context in both cases.
    if is_final_validation or torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)

    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
        validation_image = load_image(validation_image)
        # maybe need to inference on 1024 to get a good image
        validation_image = validation_image.resize((args.resolution, args.resolution))

        images = []

        for _ in range(args.num_validation_images):
            with autocast_ctx:
                image = pipeline(
                    prompt=validation_prompt,
                    control_image=validation_image,
                    num_inference_steps=50,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    max_sequence_length=512,
                    height=args.resolution,
                    width=args.resolution,
                ).images[0]
            image = image.resize((args.resolution, args.resolution))
            images.append(image)
        image_logs.append(
            {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
        )

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            # TensorBoard: one NHWC image grid per prompt, conditioning first.
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]
                formatted_images = []
                formatted_images.append(np.asarray(validation_image))
                for image in images:
                    formatted_images.append(np.asarray(image))
                formatted_images = np.stack(formatted_images)
                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")

        elif tracker.name == "wandb":
            # wandb: a single flat gallery of conditioning + generated images.
            formatted_images = []
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]
                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                for image in images:
                    image = wandb.Image(image, caption=validation_prompt)
                    formatted_images.append(image)

            tracker.log({tracker_key: formatted_images})
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    del pipeline
    free_memory()
    return image_logs
182
+
183
+
184
def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
    """Write a model card (README.md) for the trained Control LoRA.

    Args:
        repo_id: Hub repository id; used in the card's title.
        image_logs: optional list of dicts with keys "images",
            "validation_prompt" and "validation_image" (as produced by
            ``log_validation``); when given, image grids are saved next to the
            card and embedded into it.
        base_model: base model identifier the LoRA was trained on.
            (The original default was the ``str`` type itself — a typo; the
            default is now ``None``. Callers that pass ``base_model``
            explicitly are unaffected.)
        repo_folder: local directory the README and images are written into.
    """
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n\n"
        for i, log in enumerate(image_logs):
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
            validation_image.save(os.path.join(repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            # Fixed: alt text previously contained a stray ")" ("![images_0)](...)"),
            # which rendered as broken Markdown.
            img_str += f"![images_{i}](./images_{i}.png)\n"

    model_description = f"""
# control-lora-{repo_id}

These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}

## License

Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
"""

    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )

    tags = [
        "flux",
        "flux-diffusers",
        "text-to-image",
        "diffusers",
        "control-lora",
        "diffusers-training",
        "lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)

    model_card.save(os.path.join(repo_folder, "README.md"))
230
+
231
+
232
def parse_args(input_args=None):
    """Build and validate the CLI arguments for Control LoRA training.

    Args:
        input_args: optional list of argument strings to parse instead of
            ``sys.argv`` (useful when calling programmatically or in tests).

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: if neither/both of ``--dataset_name``/``--jsonl_for_train``
            are given, ``--proportion_empty_prompts`` is outside [0, 1], the
            validation prompt/image options are inconsistently paired, or the
            resolution is not divisible by 8.
    """
    parser = argparse.ArgumentParser(description="Simple example of a Control LoRA training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="control-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument("--use_lora_bias", action="store_true", help="If training the bias of lora_B layers.")
    parser.add_argument(
        "--lora_layers",
        type=str,
        default=None,
        help=(
            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
        ),
    )
    parser.add_argument(
        "--gaussian_init_lora",
        action="store_true",
        help="If using the Gaussian init strategy. When False, we follow the original LoRA init strategy.",
    )
    parser.add_argument("--train_norm_layers", action="store_true", help="Whether to train the norm scales.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )

    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the control conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="flux_train_control_lora",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--jsonl_for_train",
        type=str,
        default=None,
        help="Path to the jsonl file containing the training data.",
    )

    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=30.0,
        help="the guidance scale used for transformer.",
    )

    parser.add_argument(
        "--upcast_before_saving",
        action="store_true",
        help=(
            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
            "Defaults to precision dtype used for training to save memory"
        ),
    )

    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="none",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
        help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--offload",
        action="store_true",
        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Cross-option validation: exactly one dataset source must be given.
    if args.dataset_name is None and args.jsonl_for_train is None:
        raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")

    if args.dataset_name is not None and args.jsonl_for_train is not None:
        raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

    # Prompts/images must be equal in count, or one of them a singleton
    # (broadcast across the other) — mirrors the pairing logic in log_validation.
    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
619
+
620
+
621
def get_train_dataset(args, accelerator):
    """Load the training split, validate column names, shuffle, and truncate.

    Loads either a Hub dataset (``args.dataset_name``) or a local JSONL file
    (``args.jsonl_for_train``) — parse_args guarantees exactly one is set.
    Returns the shuffled (and optionally truncated) "train" split.
    """
    dataset = None
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    if args.jsonl_for_train is not None:
        # load from json
        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
        dataset = dataset.flatten_indices()
    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    # NOTE(review): argparse defaults these columns to non-None strings, so the
    # `is None` fallbacks below only trigger if a caller clears them explicitly.
    if args.image_column is None:
        image_column = column_names[0]
        logger.info(f"image column defaulting to {image_column}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
        logger.info(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    # Shuffle/select on the main process first so all ranks share the cache.
    with accelerator.main_process_first():
        train_dataset = dataset["train"].shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset
+ return train_dataset
674
+
675
+
676
def prepare_train_dataset(dataset, accelerator):
    """Attach on-the-fly image/caption preprocessing to the dataset.

    Resizes and normalizes target and conditioning images to
    ``args.resolution`` ([-1, 1] range) and selects a caption per example.
    Transforms are applied lazily via ``with_transform``.

    NOTE(review): this function reads the module-level ``args`` (resolution and
    column names) — ``args`` is not a parameter, so it must exist in module
    scope before any DataLoader pulls a batch.
    """
    image_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    def preprocess_train(examples):
        # Columns may hold PIL images or file paths; load paths as RGB.
        images = [
            (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
            for image in examples[args.image_column]
        ]
        images = [image_transforms(image) for image in images]

        conditioning_images = [
            (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
            for image in examples[args.conditioning_image_column]
        ]
        conditioning_images = [image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images

        # When a column holds a list of captions per example, keep the longest.
        is_caption_list = isinstance(examples[args.caption_column][0], list)
        if is_caption_list:
            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
        else:
            examples["captions"] = list(examples[args.caption_column])

        return examples

    with accelerator.main_process_first():
        dataset = dataset.with_transform(preprocess_train)

    return dataset
712
+
713
+
714
def collate_fn(examples):
    """Batch a list of per-sample dicts into stacked float tensors plus a caption list."""

    def _stack(key):
        # Stack per-sample tensors and force contiguous float32 storage.
        batch = torch.stack([sample[key] for sample in examples])
        return batch.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": _stack("pixel_values"),
        "conditioning_pixel_values": _stack("conditioning_pixel_values"),
        "captions": [sample["captions"] for sample in examples],
    }
721
+
722
+
723
def main(args):
    """Train a FLUX Control LoRA.

    Loads the frozen VAE/transformer, widens the transformer's `x_embedder` so it
    accepts image latents concatenated channel-wise with control latents, attaches
    a LoRA adapter, and runs a flow-matching training loop under `accelerate`
    (optionally DeepSpeed). Handles checkpointing, validation, and Hub upload.

    Fix vs. original: the text-encoding pipeline was moved to a hard-coded "cuda"
    device, which crashes on MPS/CPU runs even though this function explicitly
    supports them (see the MPS bf16 guard and native-AMP disable below). It now
    uses `accelerator.device`.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )
    if args.use_lora_bias and args.gaussian_init_lora:
        raise ValueError("`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.")

    logging_out_dir = Path(args.output_dir, args.logging_dir)

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
    if torch.backends.mps.is_available():
        logger.info("MPS is enabled. Disabling AMP.")
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        # DEBUG, INFO, WARNING, ERROR, CRITICAL
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # Keep non-main ranks quiet; the local main process gets normal verbosity.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load models. We will load the text encoders later in a pipeline to compute
    # embeddings.
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
    )
    logger.info("All models loaded successfully")

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="scheduler",
    )
    # Keep an untouched copy for sigma/timestep lookups during training.
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    vae.requires_grad_(False)
    flux_transformer.requires_grad_(False)

    # cast down and move to the CPU
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # let's not move the VAE to the GPU yet.
    vae.to(dtype=torch.float32)  # keep the VAE in float32.
    flux_transformer.to(dtype=weight_dtype, device=accelerator.device)

    # enable image inputs: double the width of `x_embedder` so the model accepts
    # noisy image latents concatenated channel-wise with control latents.
    with torch.no_grad():
        initial_input_channels = flux_transformer.config.in_channels
        new_linear = torch.nn.Linear(
            flux_transformer.x_embedder.in_features * 2,
            flux_transformer.x_embedder.out_features,
            bias=flux_transformer.x_embedder.bias is not None,
            dtype=flux_transformer.dtype,
            device=flux_transformer.device,
        )
        # The new (control) half starts at zero so the initial predictions match the base model.
        new_linear.weight.zero_()
        new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
        if flux_transformer.x_embedder.bias is not None:
            new_linear.bias.copy_(flux_transformer.x_embedder.bias)
        flux_transformer.x_embedder = new_linear

    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

    if args.lora_layers is not None:
        if args.lora_layers != "all-linear":
            target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
            # add the input layer to the mix.
            if "x_embedder" not in target_modules:
                target_modules.append("x_embedder")
        elif args.lora_layers == "all-linear":
            target_modules = set()
            for name, module in flux_transformer.named_modules():
                if isinstance(module, torch.nn.Linear):
                    target_modules.add(name)
            target_modules = list(target_modules)
    else:
        # Default attention/FF projections plus the widened input embedder.
        target_modules = [
            "x_embedder",
            "attn.to_k",
            "attn.to_q",
            "attn.to_v",
            "attn.to_out.0",
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "ff.net.0.proj",
            "ff.net.2",
            "ff_context.net.0.proj",
            "ff_context.net.2",
        ]
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian" if args.gaussian_init_lora else True,
        target_modules=target_modules,
        lora_bias=args.use_lora_bias,
    )
    flux_transformer.add_adapter(transformer_lora_config)

    if args.train_norm_layers:
        for name, param in flux_transformer.named_parameters():
            if any(k in name for k in NORM_LAYER_PREFIXES):
                param.requires_grad = True

    def unwrap_model(model):
        # Strip the accelerate wrapper and, if present, the torch.compile wrapper.
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                transformer_lora_layers_to_save = None

                for model in models:
                    if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
                        model = unwrap_model(model)
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        if args.train_norm_layers:
                            transformer_norm_layers_to_save = {
                                f"transformer.{name}": param
                                for name, param in model.named_parameters()
                                if any(k in name for k in NORM_LAYER_PREFIXES)
                            }
                            transformer_lora_layers_to_save = {
                                **transformer_lora_layers_to_save,
                                **transformer_norm_layers_to_save,
                            }
                    else:
                        raise ValueError(f"unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                FluxControlPipeline.save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )

        def load_model_hook(models, input_dir):
            transformer_ = None

            if not accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()

                    if isinstance(model, type(unwrap_model(flux_transformer))):
                        transformer_ = model
                    else:
                        raise ValueError(f"unexpected save model: {model.__class__}")
            else:
                # DeepSpeed does not hand us live modules; rebuild the transformer from scratch.
                transformer_ = FluxTransformer2DModel.from_pretrained(
                    args.pretrained_model_name_or_path, subfolder="transformer"
                ).to(accelerator.device, weight_dtype)

                # Handle input dimension doubling before adding adapter
                with torch.no_grad():
                    initial_input_channels = transformer_.config.in_channels
                    new_linear = torch.nn.Linear(
                        transformer_.x_embedder.in_features * 2,
                        transformer_.x_embedder.out_features,
                        bias=transformer_.x_embedder.bias is not None,
                        dtype=transformer_.dtype,
                        device=transformer_.device,
                    )
                    new_linear.weight.zero_()
                    new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
                    if transformer_.x_embedder.bias is not None:
                        new_linear.bias.copy_(transformer_.x_embedder.bias)
                    transformer_.x_embedder = new_linear
                    transformer_.register_to_config(in_channels=initial_input_channels * 2)

                transformer_.add_adapter(transformer_lora_config)

            lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
            transformer_lora_state_dict = {
                f"{k.replace('transformer.', '')}": v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.") and "lora" in k
            }
            incompatible_keys = set_peft_model_state_dict(
                transformer_, transformer_lora_state_dict, adapter_name="default"
            )
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )
            if args.train_norm_layers:
                transformer_norm_state_dict = {
                    k: v
                    for k, v in lora_state_dict.items()
                    if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)
                }
                transformer_._transformer_norm_layers = FluxControlPipeline._load_norm_into_transformer(
                    transformer_norm_state_dict,
                    transformer=transformer_,
                    discard_original_layers=False,
                )

            # Make sure the trainable params are in float32. This is again needed since the base models
            # are in `weight_dtype`. More details:
            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
            if args.mixed_precision == "fp16":
                models = [transformer_]
                # only upcast trainable parameters (LoRA) into fp32
                cast_training_params(models)

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [flux_transformer]
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models, dtype=torch.float32)

    if args.gradient_checkpointing:
        flux_transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimization parameters: only the LoRA (and optionally norm) params require grad.
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, flux_transformer.parameters()))
    optimizer = optimizer_class(
        transformer_lora_parameters,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Prepare dataset and dataloader.
    train_dataset = get_train_dataset(args, accelerator)
    train_dataset = prepare_train_dataset(train_dataset, accelerator)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
    if args.max_train_steps is None:
        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
        num_training_steps_for_scheduler = (
            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
        )
    else:
        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_training_steps_for_scheduler,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )
    # Prepare everything with our `accelerator`.
    flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        flux_transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
            logger.warning(
                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
                f"This inconsistency may result in the learning rate scheduler not functioning properly."
            )
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))

        # tensorboard cannot handle list types for config
        tracker_config.pop("validation_prompt")
        tracker_config.pop("validation_image")

        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
    text_encoding_pipeline = FluxControlPipeline.from_pretrained(
        args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
    )

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
        logger.info("Logging some dataset samples.")
        formatted_images = []
        formatted_control_images = []
        all_prompts = []
        for i, batch in enumerate(train_dataloader):
            # Undo the [-1, 1] normalization for display.
            images = (batch["pixel_values"] + 1) / 2
            control_images = (batch["conditioning_pixel_values"] + 1) / 2
            prompts = batch["captions"]

            if len(formatted_images) > 10:
                break

            for img, control_img, prompt in zip(images, control_images, prompts):
                formatted_images.append(img)
                formatted_control_images.append(control_img)
                all_prompts.append(prompt)

        logged_artifacts = []
        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
            logged_artifacts.append(wandb.Image(img, caption=prompt))

        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
        wandb_tracker[0].log({"dataset_samples": logged_artifacts})

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        # Look up the per-sample sigma for the sampled timesteps and broadcast to n_dim.
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    image_logs = None
    for epoch in range(first_epoch, args.num_train_epochs):
        flux_transformer.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(flux_transformer):
                # Convert images to latent space
                # vae encode
                pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
                control_latents = encode_images(
                    batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
                )

                if args.offload:
                    # offload vae to CPU.
                    vae.cpu()

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                bsz = pixel_latents.shape[0]
                noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)

                # Add noise according to flow matching.
                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
                # Concatenate across channels.
                # Question: Should we concatenate before adding noise?
                concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)

                # pack the latents.
                packed_noisy_model_input = FluxControlPipeline._pack_latents(
                    concatenated_noisy_model_input,
                    batch_size=bsz,
                    num_channels_latents=concatenated_noisy_model_input.shape[1],
                    height=concatenated_noisy_model_input.shape[2],
                    width=concatenated_noisy_model_input.shape[3],
                )

                # latent image ids for RoPE.
                latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(
                    bsz,
                    concatenated_noisy_model_input.shape[2] // 2,
                    concatenated_noisy_model_input.shape[3] // 2,
                    accelerator.device,
                    weight_dtype,
                )

                # handle guidance
                if unwrap_model(flux_transformer).config.guidance_embeds:
                    guidance_vec = torch.full(
                        (bsz,),
                        args.guidance_scale,
                        device=noisy_model_input.device,
                        dtype=weight_dtype,
                    )
                else:
                    guidance_vec = None

                # text encoding.
                captions = batch["captions"]
                # FIX: use the accelerator's device instead of hard-coded "cuda" so this
                # also runs on MPS/CPU setups (supported by the checks at the top of main).
                text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
                with torch.no_grad():
                    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
                        captions, prompt_2=None
                    )
                # this could be optimized by not having to do any text encoding and just
                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
                    prompt_embeds.zero_()
                    pooled_prompt_embeds.zero_()
                if args.offload:
                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")

                # Predict.
                model_pred = flux_transformer(
                    hidden_states=packed_noisy_model_input,
                    timestep=timesteps / 1000,
                    guidance=guidance_vec,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    return_dict=False,
                )[0]
                model_pred = FluxControlPipeline._unpack_latents(
                    model_pred,
                    height=noisy_model_input.shape[2] * vae_scale_factor,
                    width=noisy_model_input.shape[3] * vae_scale_factor,
                    vae_scale_factor=vae_scale_factor,
                )
                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                # flow-matching loss
                target = noise - pixel_latents
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = flux_transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        image_logs = log_validation(
                            flux_transformer=flux_transformer,
                            args=args,
                            accelerator=accelerator,
                            weight_dtype=weight_dtype,
                            step=global_step,
                        )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        flux_transformer = unwrap_model(flux_transformer)
        if args.upcast_before_saving:
            flux_transformer.to(torch.float32)
        transformer_lora_layers = get_peft_model_state_dict(flux_transformer)
        if args.train_norm_layers:
            transformer_norm_layers = {
                f"transformer.{name}": param
                for name, param in flux_transformer.named_parameters()
                if any(k in name for k in NORM_LAYER_PREFIXES)
            }
            transformer_lora_layers = {**transformer_lora_layers, **transformer_norm_layers}
        FluxControlPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
        )

        del flux_transformer
        del text_encoding_pipeline
        del vae
        free_memory()

        # Run a final round of validation.
        image_logs = None
        if args.validation_prompt is not None:
            image_logs = log_validation(
                flux_transformer=None,
                args=args,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                step=global_step,
                is_final_validation=True,
            )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                image_logs=image_logs,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
            )

    accelerator.end_training()
1401
+
1402
+
1403
if __name__ == "__main__":
    # CLI entry point: parse arguments and hand them straight to training.
    main(parse_args())
train_flux_kontext.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch FLUX Kontext LoRA training on GPUs 0-3 via `accelerate launch`.
# NOTE(review): `--deepspeed_config_file` is an accelerate launch option that
# normally requires DeepSpeed to be enabled (e.g. `--use_deepspeed` or a
# DeepSpeed-enabled default accelerate config, as in default_config.yaml) —
# confirm the local accelerate config selects DeepSpeed before relying on this.
# Flags below mirror the repo's FLUX Control recipe: bf16 mixed precision,
# per-device batch size 1 with 4-step gradient accumulation, rank-16 LoRA,
# gradient checkpointing + 8-bit Adam + `--offload` to reduce memory.
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
  --deepspeed_config_file "ds_config.json" \
  train_flux_kontext_lora.py \
  --pretrained_model_name_or_path="flux-kontext-ckpt" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="long-context-flux" \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --rank=16 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=5000 \
  --validation_image="openpose.png" \
  --validation_prompt="A couple, 4k photo, highly detailed" \
  --offload \
  --seed="0" \
  --push_to_hub \
  --use_8bit_adam
train_flux_kontext_lora.py ADDED
@@ -0,0 +1,1435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import argparse
17
+ import copy
18
+ import logging
19
+ import math
20
+ import os
21
+
22
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
23
+
24
+ import random
25
+ import shutil
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import accelerate
30
+ import diffusers
31
+ import numpy as np
32
+ import torch
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
37
+ from datasets import load_dataset
38
+ from diffusers import (AutoencoderKL, FlowMatchEulerDiscreteScheduler,
39
+ FluxKontextPipeline, FluxTransformer2DModel)
40
+ from diffusers.optimization import get_scheduler
41
+ from diffusers.training_utils import (cast_training_params,
42
+ compute_density_for_timestep_sampling,
43
+ compute_loss_weighting_for_sd3,
44
+ free_memory)
45
+ from diffusers.utils import (check_min_version, is_wandb_available, load_image,
46
+ make_image_grid)
47
+ from diffusers.utils.hub_utils import (load_or_create_model_card,
48
+ populate_model_card)
49
+ from diffusers.utils.torch_utils import is_compiled_module
50
+ from huggingface_hub import create_repo, upload_folder
51
+ from packaging import version
52
+ from peft import LoraConfig, set_peft_model_state_dict
53
+ from peft.utils import get_peft_model_state_dict
54
+ from PIL import Image
55
+ from torchvision import transforms
56
+ from tqdm.auto import tqdm
57
+
58
+ if is_wandb_available():
59
+ import wandb
60
+
61
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
62
+ check_min_version("0.35.0.dev0")
63
+
64
+ logger = get_logger(__name__)
65
+
66
+ NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
67
+
68
+
69
+
70
def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
    """Encode a batch of images into normalized VAE latents.

    Samples from the VAE posterior, applies the model's configured
    shift/scale normalization, and casts the result to ``weight_dtype``.

    Args:
        pixels: Image batch; cast to the VAE's dtype before encoding.
        vae: Autoencoder exposing ``encode`` and ``config.shift_factor`` /
            ``config.scaling_factor``.
        weight_dtype: Target dtype of the returned latents.

    Returns:
        Normalized latent tensor in ``weight_dtype``.
    """
    posterior = vae.encode(pixels.to(vae.dtype)).latent_dist
    latents = posterior.sample()
    shift, scale = vae.config.shift_factor, vae.config.scaling_factor
    return ((latents - shift) * scale).to(weight_dtype)
74
+
75
+
76
def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
    """Run validation inference with the (LoRA-adapted) transformer and log images.

    During training (``is_final_validation=False``) the in-memory transformer is
    reused directly. For the final validation, a fresh base transformer is
    loaded and the saved LoRA weights are applied from ``args.output_dir``.

    Args:
        flux_transformer: The transformer being trained (possibly wrapped by
            accelerate).
        args: Parsed CLI namespace (validation prompts/images, resolution, ...).
        accelerator: The ``accelerate.Accelerator`` driving training.
        weight_dtype: Inference dtype for the pipeline.
        step: Global step used when logging to trackers.
        is_final_validation: Whether this is the post-training validation pass.

    Returns:
        A list of dicts with keys "validation_image", "images", and
        "validation_prompt", for later use in the model card.
    """
    logger.info("Running validation... ")

    if not is_final_validation:
        flux_transformer = accelerator.unwrap_model(flux_transformer)
        pipeline = FluxKontextPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            transformer=flux_transformer,
            torch_dtype=weight_dtype,
        )
    else:
        transformer = FluxTransformer2DModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
        )
        initial_channels = transformer.config.in_channels
        pipeline = FluxKontextPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            transformer=transformer,
            torch_dtype=weight_dtype,
        )
        pipeline.load_lora_weights(args.output_dir)
        # BUGFIX: unlike the control-LoRA script this file was adapted from,
        # the Kontext script does NOT double the x_embedder input channels
        # (the channel-doubling block in main() is commented out), so loading
        # the LoRA must leave in_channels unchanged. The previous
        # `initial_channels * 2` assertion would always fail here.
        assert pipeline.transformer.config.in_channels == initial_channels, (
            f"{pipeline.transformer.config.in_channels=}"
        )

    pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Seeded generator so validation samples are reproducible across runs.
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    # Broadcast a single image/prompt against the other list when the lengths
    # differ; parse_args has already validated that the lengths are compatible.
    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    image_logs = []
    # MPS does not support autocast here; the final validation also runs
    # without autocast since the pipeline is already in `weight_dtype`.
    if is_final_validation or torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)

    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
        validation_image = load_image(validation_image)
        # maybe need to inference on 1024 to get a good image
        validation_image = validation_image.resize((args.resolution, args.resolution))

        images = []
        for _ in range(args.num_validation_images):
            with autocast_ctx:
                # BUGFIX: pass `generator` so `--seed` actually controls
                # sampling; it was previously created but never used.
                image = pipeline(
                    image=validation_image,
                    prompt=validation_prompt,
                    guidance_scale=args.guidance_scale,
                    max_sequence_length=512,
                    height=args.resolution,
                    width=args.resolution,
                    generator=generator,
                ).images[0]
            image = image.resize((args.resolution, args.resolution))
            images.append(image)
        image_logs.append(
            {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
        )

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]
                # First frame is the conditioning image, followed by samples.
                formatted_images = [np.asarray(validation_image)]
                for image in images:
                    formatted_images.append(np.asarray(image))
                formatted_images = np.stack(formatted_images)
                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            formatted_images = []
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]
                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                for image in images:
                    formatted_images.append(wandb.Image(image, caption=validation_prompt))
            tracker.log({tracker_key: formatted_images})
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    del pipeline
    free_memory()
    return image_logs
186
+
187
+
188
def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
    """Write a model card (README.md) plus example image grids to ``repo_folder``.

    Args:
        repo_id: Hub repo id used in the card title.
        image_logs: Optional validation logs produced by ``log_validation``.
        base_model: Base model id the LoRA was trained on. BUGFIX: the default
            used to be the builtin ``str`` type itself (``base_model=str``),
            which would render as "<class 'str'>" in the card text.
        repo_folder: Local folder the README and example images are written to.
    """
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n\n"
        for i, log in enumerate(image_logs):
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
            validation_image.save(os.path.join(repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            # BUGFIX: the alt text previously contained a stray ")"
            # ("![images_{i})](...)"), producing broken Markdown.
            img_str += f"![images_{i}](./images_{i}.png)\n"

    model_description = f"""
# control-lora-{repo_id}

These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}

## License

Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
"""

    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )

    tags = [
        "flux-kontext",
        "flux-diffusers",
        "text-to-image",
        "diffusers",
        "control-lora",
        "diffusers-training",
        "lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)

    model_card.save(os.path.join(repo_folder, "README.md"))
234
+
235
+
236
def parse_args(input_args=None):
    """Parse and validate CLI arguments for Control LoRA training.

    Args:
        input_args: Optional list of argument strings; falls back to
            ``sys.argv`` when None.

    Returns:
        An ``argparse.Namespace`` with the parsed options.

    Raises:
        ValueError: when mutually exclusive or jointly required options are
            misused, or when out-of-range values are supplied (see the
            validation block at the end).
    """
    parser = argparse.ArgumentParser(description="Simple example of a Control LoRA training script.")
    # --- model / checkpoint selection ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="control-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    # --- training schedule / checkpointing ---
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            # BUGFIX: the concatenated sentences below previously ran together
            # without separating spaces in the rendered help text.
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    # --- LoRA configuration ---
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument("--use_lora_bias", action="store_true", help="If training the bias of lora_B layers.")
    parser.add_argument(
        "--lora_layers",
        type=str,
        default=None,
        help=(
            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
        ),
    )
    parser.add_argument(
        "--gaussian_init_lora",
        action="store_true",
        help="If using the Gaussian init strategy. When False, we follow the original LoRA init strategy.",
    )
    parser.add_argument("--train_norm_layers", action="store_true", help="Whether to train the norm scales.")
    # --- optimizer / LR schedule ---
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )

    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # --- Hub / logging ---
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            # BUGFIX: help text previously read "1.10.and" (typo).
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # --- dataset selection / columns ---
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the control conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    # BUGFIX: help text previously read "somple" (typo).
    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    # --- validation ---
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=1,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="flux_train_control_lora",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--jsonl_for_train",
        type=str,
        default=None,
        help="Path to the jsonl file containing the training data.",
    )

    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=30.0,
        help="the guidance scale used for transformer.",
    )

    parser.add_argument(
        "--upcast_before_saving",
        action="store_true",
        help=(
            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
            "Defaults to precision dtype used for training to save memory"
        ),
    )

    # --- flow-matching loss weighting ---
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="none",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
        help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--offload",
        action="store_true",
        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Cross-argument validation: exactly one data source, sane proportions,
    # and compatible validation prompt/image combinations.
    if args.dataset_name is None and args.jsonl_for_train is None:
        raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")

    if args.dataset_name is not None and args.jsonl_for_train is not None:
        raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
623
+
624
+
625
def get_train_dataset(args, accelerator):
    """Load the training split, validate column names, and shuffle it.

    Loads either a Hub dataset (``args.dataset_name``) or a local JSONL file
    (``args.jsonl_for_train``), checks that the configured image / caption /
    conditioning-image columns exist, then shuffles (and optionally truncates)
    the train split under ``accelerator.main_process_first()``.
    """
    dataset = None
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    if args.jsonl_for_train is not None:
        # load from json
        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
        dataset = dataset.flatten_indices()

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    def _resolve_column(requested, fallback_index, flag, label):
        # Either fall back to a positional default column or verify the
        # explicitly requested column exists in the dataset.
        if requested is None:
            name = column_names[fallback_index]
            logger.info(f"{label} defaulting to {name}")
            return name
        if requested not in column_names:
            raise ValueError(
                f"`{flag}` value '{requested}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )
        return requested

    _resolve_column(args.image_column, 0, "--image_column", "image column")
    _resolve_column(args.caption_column, 1, "--caption_column", "caption column")
    _resolve_column(args.conditioning_image_column, 2, "--conditioning_image_column", "conditioning image column")

    with accelerator.main_process_first():
        train_dataset = dataset["train"].shuffle(seed=args.seed)
        if args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(args.max_train_samples))
    return train_dataset
678
+
679
+
680
def prepare_train_dataset(dataset, accelerator):
    # Attach an on-the-fly preprocessing transform to the dataset: resize,
    # tensorize, and normalize both target and conditioning images to [-1, 1],
    # and pick a caption per example.
    #
    # NOTE(review): this function reads the module-level `args` (resolution and
    # column names) rather than receiving it as a parameter — it must only be
    # called after `args` is bound at module scope (e.g. from main()); confirm
    # before reusing elsewhere.
    image_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            # Normalize from [0, 1] to [-1, 1], matching the VAE's expected range.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    def preprocess_train(examples):
        # Entries may be PIL images or file-path strings; normalize both to RGB.
        images = [
            (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
            for image in examples[args.image_column]
        ]
        images = [image_transforms(image) for image in images]

        conditioning_images = [
            (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
            for image in examples[args.conditioning_image_column]
        ]
        conditioning_images = [image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images

        # A caption entry may be a list of alternative captions; when it is,
        # keep the longest one.
        is_caption_list = isinstance(examples[args.caption_column][0], list)
        if is_caption_list:
            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
        else:
            examples["captions"] = list(examples[args.caption_column])

        return examples

    # with_transform applies preprocessing lazily at __getitem__ time; guard
    # with main_process_first so caching happens once in distributed runs.
    with accelerator.main_process_first():
        dataset = dataset.with_transform(preprocess_train)

    return dataset
716
+
717
# NOTE: The meaning of "conditioning_pixel_values" differs.
# The conditioning_pixel_values can be a list of values in the future.
def collate_fn(examples):
    """Collate dataset examples into a training batch.

    Stacks the target and conditioning image tensors into contiguous float32
    batches and gathers the captions into a plain list.
    """

    def _stack(key):
        batch = torch.stack([example[key] for example in examples])
        return batch.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": _stack("pixel_values"),
        "conditioning_pixel_values": _stack("conditioning_pixel_values"),
        "captions": [example["captions"] for example in examples],
    }
726
+
727
+
728
+ def main(args):
729
+ if args.report_to == "wandb" and args.hub_token is not None:
730
+ raise ValueError(
731
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
732
+ " Please use `huggingface-cli login` to authenticate with the Hub."
733
+ )
734
+ if args.use_lora_bias and args.gaussian_init_lora:
735
+ raise ValueError("`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.")
736
+
737
+ logging_out_dir = Path(args.output_dir, args.logging_dir)
738
+
739
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
740
+ # due to pytorch#99272, MPS does not yet support bfloat16.
741
+ raise ValueError(
742
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
743
+ )
744
+
745
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
746
+
747
+ accelerator = Accelerator(
748
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
749
+ mixed_precision=args.mixed_precision,
750
+ log_with=args.report_to,
751
+ project_config=accelerator_project_config,
752
+ )
753
+
754
+ # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
755
+ if torch.backends.mps.is_available():
756
+ logger.info("MPS is enabled. Disabling AMP.")
757
+ accelerator.native_amp = False
758
+
759
+ # Make one log on every process with the configuration for debugging.
760
+ logging.basicConfig(
761
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
762
+ datefmt="%m/%d/%Y %H:%M:%S",
763
+ # DEBUG, INFO, WARNING, ERROR, CRITICAL
764
+ level=logging.INFO,
765
+ )
766
+ logger.info(accelerator.state, main_process_only=False)
767
+
768
+ if accelerator.is_local_main_process:
769
+ transformers.utils.logging.set_verbosity_warning()
770
+ diffusers.utils.logging.set_verbosity_info()
771
+ else:
772
+ transformers.utils.logging.set_verbosity_error()
773
+ diffusers.utils.logging.set_verbosity_error()
774
+
775
+ # If passed along, set the training seed now.
776
+ if args.seed is not None:
777
+ set_seed(args.seed)
778
+
779
+ # Handle the repository creation
780
+ if accelerator.is_main_process:
781
+ if args.output_dir is not None:
782
+ os.makedirs(args.output_dir, exist_ok=True)
783
+
784
+ if args.push_to_hub:
785
+ repo_id = create_repo(
786
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
787
+ ).repo_id
788
+
789
+ # Load models. We will load the text encoders later in a pipeline to compute
790
+ # embeddings.
791
+ vae = AutoencoderKL.from_pretrained(
792
+ args.pretrained_model_name_or_path,
793
+ subfolder="vae",
794
+ revision=args.revision,
795
+ variant=args.variant,
796
+ )
797
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
798
+ flux_transformer = FluxTransformer2DModel.from_pretrained(
799
+ args.pretrained_model_name_or_path,
800
+ subfolder="transformer",
801
+ revision=args.revision,
802
+ variant=args.variant,
803
+ )
804
+ logger.info("All models loaded successfully")
805
+
806
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
807
+ args.pretrained_model_name_or_path,
808
+ subfolder="scheduler",
809
+ )
810
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
811
+ vae.requires_grad_(False)
812
+ flux_transformer.requires_grad_(False)
813
+
814
+ # cast down and move to the CPU
815
+ weight_dtype = torch.float32
816
+ if accelerator.mixed_precision == "fp16":
817
+ weight_dtype = torch.float16
818
+ elif accelerator.mixed_precision == "bf16":
819
+ weight_dtype = torch.bfloat16
820
+
821
+ # let's not move the VAE to the GPU yet.
822
+ vae.to(dtype=torch.float32) # keep the VAE in float32.
823
+ flux_transformer.to(dtype=weight_dtype, device=accelerator.device)
824
+
825
+ # We do not need to double the input channels in flux kontext pipeline.
826
+ # with torch.no_grad():
827
+ # initial_input_channels = flux_transformer.config.in_channels
828
+ # new_linear = torch.nn.Linear(
829
+ # flux_transformer.x_embedder.in_features * 2,
830
+ # flux_transformer.x_embedder.out_features,
831
+ # bias=flux_transformer.x_embedder.bias is not None,
832
+ # dtype=flux_transformer.dtype,
833
+ # device=flux_transformer.device,
834
+ # )
835
+ # new_linear.weight.zero_()
836
+ # new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
837
+ # if flux_transformer.x_embedder.bias is not None:
838
+ # new_linear.bias.copy_(flux_transformer.x_embedder.bias)
839
+ # flux_transformer.x_embedder = new_linear
840
+
841
+ # assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
842
+ # flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
843
+
844
+ if args.lora_layers is not None:
845
+ if args.lora_layers != "all-linear":
846
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
847
+ # add the input layer to the mix.
848
+ if "x_embedder" not in target_modules:
849
+ target_modules.append("x_embedder")
850
+ elif args.lora_layers == "all-linear":
851
+ target_modules = set()
852
+ for name, module in flux_transformer.named_modules():
853
+ if isinstance(module, torch.nn.Linear):
854
+ target_modules.add(name)
855
+ target_modules = list(target_modules)
856
+ else:
857
+ target_modules = [
858
+ "x_embedder",
859
+ "attn.to_k",
860
+ "attn.to_q",
861
+ "attn.to_v",
862
+ "attn.to_out.0",
863
+ "attn.add_k_proj",
864
+ "attn.add_q_proj",
865
+ "attn.add_v_proj",
866
+ "attn.to_add_out",
867
+ "ff.net.0.proj",
868
+ "ff.net.2",
869
+ "ff_context.net.0.proj",
870
+ "ff_context.net.2",
871
+ ]
872
+ transformer_lora_config = LoraConfig(
873
+ r=args.rank,
874
+ lora_alpha=args.rank,
875
+ init_lora_weights="gaussian" if args.gaussian_init_lora else True,
876
+ target_modules=target_modules,
877
+ lora_bias=args.use_lora_bias,
878
+ )
879
+ flux_transformer.add_adapter(transformer_lora_config)
880
+
881
+ if args.train_norm_layers:
882
+ for name, param in flux_transformer.named_parameters():
883
+ if any(k in name for k in NORM_LAYER_PREFIXES):
884
+ param.requires_grad = True
885
+
886
def unwrap_model(model):
    # Peel off the Accelerate wrapper first, then the torch.compile
    # wrapper (OptimizedModule) if the model was compiled.
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
890
+
891
+ # `accelerate` 0.16.0 will have better support for customized saving
892
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
893
+
894
def save_model_hook(models, weights, output_dir):
    # Accelerate save-state pre-hook: instead of dumping full model weights,
    # serialize only the trainable LoRA layers (plus norm layers when
    # `--train_norm_layers` is set) in diffusers' LoRA checkpoint format.
    # NOTE(review): indentation reconstructed from a mangled diff rendering —
    # `weights.pop()` placement follows the upstream diffusers control-LoRA
    # script (inside the per-model loop); confirm against the full file.
    if accelerator.is_main_process:
        transformer_lora_layers_to_save = None

        for model in models:
            if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
                model = unwrap_model(model)
                # Extract only the adapter (LoRA) parameters from the PEFT-wrapped model.
                transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                if args.train_norm_layers:
                    # Norm layers are trained as full parameters (not adapters);
                    # prefix their keys with "transformer." so the pipeline's
                    # LoRA loader can route them back to the transformer.
                    transformer_norm_layers_to_save = {
                        f"transformer.{name}": param
                        for name, param in model.named_parameters()
                        if any(k in name for k in NORM_LAYER_PREFIXES)
                    }
                    transformer_lora_layers_to_save = {
                        **transformer_lora_layers_to_save,
                        **transformer_norm_layers_to_save,
                    }
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            if weights:
                weights.pop()

        FluxKontextPipeline.save_lora_weights(
            output_dir,
            transformer_lora_layers=transformer_lora_layers_to_save,
        )
923
+
924
def load_model_hook(models, input_dir):
    # Accelerate load-state pre-hook: restore the LoRA (and optionally norm)
    # weights written by `save_model_hook` when resuming from a checkpoint.
    # NOTE(review): indentation reconstructed from a mangled diff rendering;
    # branch placement follows the upstream diffusers control-LoRA script.
    transformer_ = None

    if not accelerator.distributed_type == DistributedType.DEEPSPEED:
        # Non-DeepSpeed path: accelerate hands us the live (already adapted)
        # models; just locate the transformer among them.
        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(flux_transformer))):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
    else:
        # DeepSpeed path: wrapped models are not handed over, so rebuild the
        # base transformer from the hub and re-attach the adapter before
        # loading the checkpointed LoRA weights into it.
        transformer_ = FluxTransformer2DModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="transformer"
        ).to(accelerator.device, weight_dtype)

        # NOTE: We do not need double the input channels in flux kontext pipeline.

        # Handle input dimension doubling before adding adapter
        # with torch.no_grad():
        #     initial_input_channels = transformer_.config.in_channels
        #     new_linear = torch.nn.Linear(
        #         transformer_.x_embedder.in_features * 2,
        #         transformer_.x_embedder.out_features,
        #         bias=transformer_.x_embedder.bias is not None,
        #         dtype=transformer_.dtype,
        #         device=transformer_.device,
        #     )
        #     new_linear.weight.zero_()
        #     new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
        #     if transformer_.x_embedder.bias is not None:
        #         new_linear.bias.copy_(transformer_.x_embedder.bias)
        #     transformer_.x_embedder = new_linear
        #     transformer_.register_to_config(in_channels=initial_input_channels * 2)

        transformer_.add_adapter(transformer_lora_config)

    lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
    # Keep only the transformer's LoRA entries and strip the "transformer."
    # prefix so keys match the PEFT adapter's state-dict naming.
    transformer_lora_state_dict = {
        f"{k.replace('transformer.', '')}": v
        for k, v in lora_state_dict.items()
        if k.startswith("transformer.") and "lora" in k
    }
    incompatible_keys = set_peft_model_state_dict(
        transformer_, transformer_lora_state_dict, adapter_name="default"
    )
    if incompatible_keys is not None:
        # check only for unexpected keys; missing keys are expected since the
        # checkpoint holds only adapter weights, not the full model.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )
    if args.train_norm_layers:
        # Norm-layer weights were saved alongside the LoRA entries; load them
        # directly into the transformer via the pipeline's private helper.
        transformer_norm_state_dict = {
            k: v
            for k, v in lora_state_dict.items()
            if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)
        }
        transformer_._transformer_norm_layers = FluxKontextPipeline._load_norm_into_transformer(
            transformer_norm_state_dict,
            transformer=transformer_,
            discard_original_layers=False,
        )

    # Make sure the trainable params are in float32. This is again needed since the base models
    # are in `weight_dtype`. More details:
    # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
    if args.mixed_precision == "fp16":
        models = [transformer_]
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models)
997
+
998
+ accelerator.register_save_state_pre_hook(save_model_hook)
999
+ accelerator.register_load_state_pre_hook(load_model_hook)
1000
+
1001
+ # Make sure the trainable params are in float32.
1002
+ if args.mixed_precision == "fp16":
1003
+ models = [flux_transformer]
1004
+ # only upcast trainable parameters (LoRA) into fp32
1005
+ cast_training_params(models, dtype=torch.float32)
1006
+
1007
+ if args.gradient_checkpointing:
1008
+ flux_transformer.enable_gradient_checkpointing()
1009
+
1010
+ # Enable TF32 for faster training on Ampere GPUs,
1011
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1012
+ if args.allow_tf32:
1013
+ torch.backends.cuda.matmul.allow_tf32 = True
1014
+
1015
+ if args.scale_lr:
1016
+ args.learning_rate = (
1017
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1018
+ )
1019
+
1020
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1021
+ if args.use_8bit_adam:
1022
+ try:
1023
+ import bitsandbytes as bnb
1024
+ except ImportError:
1025
+ raise ImportError(
1026
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1027
+ )
1028
+
1029
+ optimizer_class = bnb.optim.AdamW8bit
1030
+ else:
1031
+ optimizer_class = torch.optim.AdamW
1032
+
1033
+ # Optimization parameters
1034
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, flux_transformer.parameters()))
1035
+ optimizer = optimizer_class(
1036
+ transformer_lora_parameters,
1037
+ lr=args.learning_rate,
1038
+ betas=(args.adam_beta1, args.adam_beta2),
1039
+ weight_decay=args.adam_weight_decay,
1040
+ eps=args.adam_epsilon,
1041
+ )
1042
+
1043
+ # Prepare dataset and dataloader.
1044
+ train_dataset = get_train_dataset(args, accelerator)
1045
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
1046
+ train_dataloader = torch.utils.data.DataLoader(
1047
+ train_dataset,
1048
+ shuffle=True,
1049
+ collate_fn=collate_fn,
1050
+ batch_size=args.train_batch_size,
1051
+ num_workers=args.dataloader_num_workers,
1052
+ )
1053
+
1054
+ # Scheduler and math around the number of training steps.
1055
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
1056
+ if args.max_train_steps is None:
1057
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
1058
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
1059
+ num_training_steps_for_scheduler = (
1060
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
1061
+ )
1062
+ else:
1063
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
1064
+
1065
+ lr_scheduler = get_scheduler(
1066
+ args.lr_scheduler,
1067
+ optimizer=optimizer,
1068
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1069
+ num_training_steps=num_training_steps_for_scheduler,
1070
+ num_cycles=args.lr_num_cycles,
1071
+ power=args.lr_power,
1072
+ )
1073
+ # Prepare everything with our `accelerator`.
1074
+ flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1075
+ flux_transformer, optimizer, train_dataloader, lr_scheduler
1076
+ )
1077
+
1078
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1079
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1080
+ if args.max_train_steps is None:
1081
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1082
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1083
+ logger.warning(
1084
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1085
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1086
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
1087
+ )
1088
+ # Afterwards we recalculate our number of training epochs
1089
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1090
+
1091
+ # We need to initialize the trackers we use, and also store our configuration.
1092
+ # The trackers initializes automatically on the main process.
1093
+ if accelerator.is_main_process:
1094
+ tracker_config = dict(vars(args))
1095
+
1096
+ # tensorboard cannot handle list types for config
1097
+ tracker_config.pop("validation_prompt")
1098
+ tracker_config.pop("validation_image")
1099
+
1100
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1101
+
1102
+ # Train!
1103
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1104
+
1105
+ logger.info("***** Running training *****")
1106
+ logger.info(f" Num examples = {len(train_dataset)}")
1107
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1108
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1109
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1110
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1111
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1112
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1113
+ global_step = 0
1114
+ first_epoch = 0
1115
+
1116
+ # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
1117
+ text_encoding_pipeline = FluxKontextPipeline.from_pretrained(
1118
+ args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
1119
+ )
1120
+
1121
+ # Potentially load in the weights and states from a previous save
1122
+ if args.resume_from_checkpoint:
1123
+ if args.resume_from_checkpoint != "latest":
1124
+ path = os.path.basename(args.resume_from_checkpoint)
1125
+ else:
1126
+ # Get the most recent checkpoint
1127
+ dirs = os.listdir(args.output_dir)
1128
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1129
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1130
+ path = dirs[-1] if len(dirs) > 0 else None
1131
+
1132
+ if path is None:
1133
+ logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
1134
+ args.resume_from_checkpoint = None
1135
+ initial_global_step = 0
1136
+ else:
1137
+ logger.info(f"Resuming from checkpoint {path}")
1138
+ accelerator.load_state(os.path.join(args.output_dir, path))
1139
+ global_step = int(path.split("-")[1])
1140
+
1141
+ initial_global_step = global_step
1142
+ first_epoch = global_step // num_update_steps_per_epoch
1143
+ else:
1144
+ initial_global_step = 0
1145
+
1146
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
1147
+ logger.info("Logging some dataset samples.")
1148
+ formatted_images = []
1149
+ formatted_control_images = []
1150
+ all_prompts = []
1151
+ for i, batch in enumerate(train_dataloader):
1152
+ images = (batch["pixel_values"] + 1) / 2
1153
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
1154
+ prompts = batch["captions"]
1155
+
1156
+ if len(formatted_images) > 10:
1157
+ break
1158
+
1159
+ for img, control_img, prompt in zip(images, control_images, prompts):
1160
+ formatted_images.append(img)
1161
+ formatted_control_images.append(control_img)
1162
+ all_prompts.append(prompt)
1163
+
1164
+ logged_artifacts = []
1165
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
1166
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
1167
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
1168
+
1169
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
1170
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
1171
+
1172
+ progress_bar = tqdm(
1173
+ range(0, args.max_train_steps),
1174
+ initial=initial_global_step,
1175
+ desc="Steps",
1176
+ # Only show the progress bar once on each machine.
1177
+ disable=not accelerator.is_local_main_process,
1178
+ )
1179
+
1180
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    # Look up the flow-matching sigma for each sampled timestep, then append
    # trailing singleton dims so the result broadcasts against an n_dim tensor.
    all_sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
    timesteps = timesteps.to(accelerator.device)

    # Map each timestep value to its index in the scheduler's timestep table.
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = all_sigmas[step_indices].flatten()
    while sigma.ndim < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
1190
+
1191
+ image_logs = None
1192
+ for epoch in range(first_epoch, args.num_train_epochs):
1193
+ flux_transformer.train()
1194
+ for step, batch in enumerate(train_dataloader):
1195
+ with accelerator.accumulate(flux_transformer):
1196
+ # Convert images to latent space
1197
+ # vae encode
1198
+ pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
1199
+ control_latents = encode_images(
1200
+ batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
1201
+ )
1202
+
1203
+ if args.offload:
1204
+ # offload vae to CPU.
1205
+ vae.cpu()
1206
+
1207
+ # Sample a random timestep for each image
1208
+ # for weighting schemes where we sample timesteps non-uniformly
1209
+ bsz = pixel_latents.shape[0]
1210
+ noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
1211
+ u = compute_density_for_timestep_sampling(
1212
+ weighting_scheme=args.weighting_scheme,
1213
+ batch_size=bsz,
1214
+ logit_mean=args.logit_mean,
1215
+ logit_std=args.logit_std,
1216
+ mode_scale=args.mode_scale,
1217
+ )
1218
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1219
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
1220
+
1221
+ # Add noise according to flow matching.
1222
+ sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
1223
+ noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
1224
+ # Concatenate across **sequence dimension**.
1225
+ # 3D RoPE need to be added.
1226
+
1227
+ # concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
1228
+
1229
+ # pack the latents.
1230
+ # patchify
1231
+ packed_noisy_model_input = FluxKontextPipeline._pack_latents(
1232
+ noisy_model_input,
1233
+ batch_size=bsz,
1234
+ num_channels_latents=noisy_model_input.shape[1],
1235
+ height=noisy_model_input.shape[2],
1236
+ width=noisy_model_input.shape[3],
1237
+ )
1238
+ packed_control_latents = FluxKontextPipeline._pack_latents(
1239
+ control_latents,
1240
+ batch_size=bsz,
1241
+ num_channels_latents=control_latents.shape[1],
1242
+ height=control_latents.shape[2],
1243
+ width=control_latents.shape[3],
1244
+ )
1245
+
1246
+ # latent image ids for RoPE.
1247
+ latent_ids = FluxKontextPipeline._prepare_latent_image_ids(
1248
+ bsz,
1249
+ noisy_model_input.shape[2] // 2,
1250
+ noisy_model_input.shape[3] // 2,
1251
+ accelerator.device,
1252
+ weight_dtype,
1253
+ )
1254
+ control_latent_ids = FluxKontextPipeline._prepare_latent_image_ids(
1255
+ bsz,
1256
+ control_latents.shape[2] // 2,
1257
+ control_latents.shape[3] // 2,
1258
+ accelerator.device,
1259
+ weight_dtype,
1260
+ )
1261
+ control_latent_ids[..., 0] = 1 # set t = 1. Need modified when context are expended.
1262
+
1263
+ latent_ids = torch.cat([latent_ids, control_latent_ids], dim=0)
1264
+ # TODO: support for different latent image token len.
1265
+ model_input_length = packed_noisy_model_input.shape[1]
1266
+ latent_model_input = torch.cat([packed_noisy_model_input, packed_control_latents], dim=1)
1267
+
1268
+ # handle guidance
1269
+ if unwrap_model(flux_transformer).config.guidance_embeds:
1270
+ guidance_vec = torch.full(
1271
+ (bsz,),
1272
+ args.guidance_scale,
1273
+ device=noisy_model_input.device,
1274
+ dtype=weight_dtype,
1275
+ )
1276
+ else:
1277
+ guidance_vec = None
1278
+
1279
+ # text encoding.
1280
+ captions = batch["captions"]
1281
+ text_encoding_pipeline = text_encoding_pipeline.to("cuda")
1282
+ with torch.no_grad():
1283
+ prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
1284
+ captions, prompt_2=None
1285
+ )
1286
+ # this could be optimized by not having to do any text encoding and just
1287
+ # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
1288
+ if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
1289
+ prompt_embeds.zero_()
1290
+ pooled_prompt_embeds.zero_()
1291
+ if args.offload:
1292
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1293
+
1294
+ # Predict.
1295
+ model_pred = flux_transformer(
1296
+ hidden_states=latent_model_input,
1297
+ timestep=timesteps / 1000,
1298
+ guidance=guidance_vec,
1299
+ pooled_projections=pooled_prompt_embeds,
1300
+ encoder_hidden_states=prompt_embeds,
1301
+ txt_ids=text_ids,
1302
+ img_ids=latent_ids,
1303
+ return_dict=False,
1304
+ )[0]
1305
+ # only supervise on the noisy latent.
1306
+ model_pred = FluxKontextPipeline._unpack_latents(
1307
+ model_pred[:, :model_input_length],
1308
+ height=noisy_model_input.shape[2] * vae_scale_factor,
1309
+ width=noisy_model_input.shape[3] * vae_scale_factor,
1310
+ vae_scale_factor=vae_scale_factor,
1311
+ )
1312
+ # these weighting schemes use a uniform timestep sampling
1313
+ # and instead post-weight the loss
1314
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1315
+
1316
+ # flow-matching loss
1317
+ target = noise - pixel_latents
1318
+ loss = torch.mean(
1319
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1320
+ 1,
1321
+ )
1322
+ loss = loss.mean()
1323
+ accelerator.backward(loss)
1324
+
1325
+ if accelerator.sync_gradients:
1326
+ params_to_clip = flux_transformer.parameters()
1327
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1328
+ optimizer.step()
1329
+ lr_scheduler.step()
1330
+ optimizer.zero_grad()
1331
+
1332
+ # Checks if the accelerator has performed an optimization step behind the scenes
1333
+ if accelerator.sync_gradients:
1334
+ progress_bar.update(1)
1335
+ global_step += 1
1336
+
1337
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
1338
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
1339
+ if global_step % args.checkpointing_steps == 0:
1340
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1341
+ if args.checkpoints_total_limit is not None:
1342
+ checkpoints = os.listdir(args.output_dir)
1343
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1344
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1345
+
1346
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1347
+ if len(checkpoints) >= args.checkpoints_total_limit:
1348
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1349
+ removing_checkpoints = checkpoints[0:num_to_remove]
1350
+
1351
+ logger.info(
1352
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1353
+ )
1354
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1355
+
1356
+ for removing_checkpoint in removing_checkpoints:
1357
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1358
+ shutil.rmtree(removing_checkpoint)
1359
+
1360
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1361
+ accelerator.save_state(save_path)
1362
+ logger.info(f"Saved state to {save_path}")
1363
+
1364
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1365
+ image_logs = log_validation(
1366
+ flux_transformer=flux_transformer,
1367
+ args=args,
1368
+ accelerator=accelerator,
1369
+ weight_dtype=weight_dtype,
1370
+ step=global_step,
1371
+ )
1372
+
1373
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1374
+ progress_bar.set_postfix(**logs)
1375
+ accelerator.log(logs, step=global_step)
1376
+
1377
+ if global_step >= args.max_train_steps:
1378
+ break
1379
+
1380
+ # Create the pipeline using using the trained modules and save it.
1381
+ accelerator.wait_for_everyone()
1382
+ if accelerator.is_main_process:
1383
+ flux_transformer = unwrap_model(flux_transformer)
1384
+ if args.upcast_before_saving:
1385
+ flux_transformer.to(torch.float32)
1386
+ transformer_lora_layers = get_peft_model_state_dict(flux_transformer)
1387
+ if args.train_norm_layers:
1388
+ transformer_norm_layers = {
1389
+ f"transformer.{name}": param
1390
+ for name, param in flux_transformer.named_parameters()
1391
+ if any(k in name for k in NORM_LAYER_PREFIXES)
1392
+ }
1393
+ transformer_lora_layers = {**transformer_lora_layers, **transformer_norm_layers}
1394
+ FluxKontextPipeline.save_lora_weights(
1395
+ save_directory=args.output_dir,
1396
+ transformer_lora_layers=transformer_lora_layers,
1397
+ )
1398
+
1399
+ del flux_transformer
1400
+ del text_encoding_pipeline
1401
+ del vae
1402
+ free_memory()
1403
+
1404
+ # Run a final round of validation.
1405
+ image_logs = None
1406
+ if args.validation_prompt is not None:
1407
+ image_logs = log_validation(
1408
+ flux_transformer=None,
1409
+ args=args,
1410
+ accelerator=accelerator,
1411
+ weight_dtype=weight_dtype,
1412
+ step=global_step,
1413
+ is_final_validation=True,
1414
+ )
1415
+
1416
+ if args.push_to_hub:
1417
+ save_model_card(
1418
+ repo_id,
1419
+ image_logs=image_logs,
1420
+ base_model=args.pretrained_model_name_or_path,
1421
+ repo_folder=args.output_dir,
1422
+ )
1423
+ upload_folder(
1424
+ repo_id=repo_id,
1425
+ folder_path=args.output_dir,
1426
+ commit_message="End of training",
1427
+ ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
1428
+ )
1429
+
1430
+ accelerator.end_training()
1431
+
1432
+
1433
if __name__ == "__main__":
    # Parse CLI arguments and launch training.
    main(parse_args())