PaulaSerna commited on
Commit
0591151
·
1 Parent(s): e7d4561

End of training

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dog/alvan-nee-Id1DBHv4fbg-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ dog/alvan-nee-bQaAJCbNq3g-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ dog/alvan-nee-brFsZ7qszSY-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ dog/alvan-nee-eoqnr8ikwFE-unsplash.jpeg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ license: creativeml-openrail-m
4
+ base_model: /mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/
5
+ instance_prompt: a photo of sks dog
6
+ tags:
7
+ - stable-diffusion
8
+ - stable-diffusion-diffusers
9
+ - text-to-image
10
+ - diffusers
11
+ - dreambooth
12
+ inference: true
13
+ ---
14
+
15
+ # DreamBooth - PaulaSerna/dreambooth
16
+
17
+ This is a dreambooth model derived from /mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/. The weights were trained on a photo of sks dog using [DreamBooth](https://dreambooth.github.io/).
18
+ You can find some example images in the following.
19
+
20
+
21
+
22
+ DreamBooth for the text encoder was enabled: False.
accelerate/default_config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "debug": false,
4
+ "distributed_type": "MULTI_GPU",
5
+ "downcast_bf16": false,
6
+ "machine_rank": 0,
7
+ "main_training_function": "main",
8
+ "mixed_precision": "no",
9
+ "num_machines": 1,
10
+ "num_processes": 2,
11
+ "rdzv_backend": "static",
12
+ "same_network": false,
13
+ "tpu_use_cluster": false,
14
+ "tpu_use_sudo": false,
15
+ "use_cpu": false
16
+ }
dog/alvan-nee-9M0tSjb-cpA-unsplash.jpeg ADDED
dog/alvan-nee-Id1DBHv4fbg-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: a65d3a853b7c65dd4d394cb6b209f77666351d2bae7c6670c5677d8eb5981644
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
dog/alvan-nee-bQaAJCbNq3g-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 4cda55c53c11843ed368eb8eb68fd79521ac7b839bdd70f8f89589cf7006ed97
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
dog/alvan-nee-brFsZ7qszSY-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 9d8013d9efa2edb356e0f88c66de044f71247a99cab52b1628e753c2a08bb602
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
dog/alvan-nee-eoqnr8ikwFE-unsplash.jpeg ADDED

Git LFS Details

  • SHA256: 5c9805758a8f8950a35df820f3bfc32b3c6ca2a0e0e214a7978ea147a233bd54
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
logs/dreambooth/1698075438.4966881/events.out.tfevents.1698075438.worker2.3659641.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:052c7dfc7353b967c3e5c0bde6d6cd1078e6f850e16291e7ab7109ba14398a5a
3
+ size 2891
logs/dreambooth/1698075438.5000386/hparams.yml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ center_crop: false
7
+ checkpointing_steps: 500
8
+ checkpoints_total_limit: null
9
+ class_data_dir: null
10
+ class_labels_conditioning: null
11
+ class_prompt: null
12
+ dataloader_num_workers: 0
13
+ enable_xformers_memory_efficient_attention: false
14
+ gradient_accumulation_steps: 1
15
+ gradient_checkpointing: false
16
+ hub_model_id: null
17
+ hub_token: null
18
+ instance_data_dir: /home/rocky/llms/dreambooth/dog/
19
+ instance_prompt: a photo of sks dog
20
+ learning_rate: 5.0e-06
21
+ local_rank: 0
22
+ logging_dir: logs
23
+ lr_num_cycles: 1
24
+ lr_power: 1.0
25
+ lr_scheduler: constant
26
+ lr_warmup_steps: 0
27
+ max_grad_norm: 1.0
28
+ max_train_steps: 400
29
+ mixed_precision: null
30
+ num_class_images: 100
31
+ num_train_epochs: 134
32
+ num_validation_images: 4
33
+ offset_noise: false
34
+ output_dir: /home/rocky/llms/dreambooth/
35
+ pre_compute_text_embeddings: false
36
+ pretrained_model_name_or_path: /mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/
37
+ prior_generation_precision: null
38
+ prior_loss_weight: 1.0
39
+ push_to_hub: true
40
+ report_to: tensorboard
41
+ resolution: 512
42
+ resume_from_checkpoint: null
43
+ revision: null
44
+ sample_batch_size: 4
45
+ scale_lr: false
46
+ seed: null
47
+ set_grads_to_none: false
48
+ skip_save_text_encoder: false
49
+ snr_gamma: null
50
+ text_encoder_use_attention_mask: false
51
+ tokenizer_max_length: null
52
+ tokenizer_name: null
53
+ train_batch_size: 1
54
+ train_text_encoder: false
55
+ use_8bit_adam: false
56
+ validation_prompt: null
57
+ validation_scheduler: DPMSolverMultistepScheduler
58
+ validation_steps: 100
59
+ with_prior_preservation: false
logs/dreambooth/1698077664.7348418/events.out.tfevents.1698077664.worker2.3729036.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:416a61aeb40b7cd850d7fbea73c7110acba3a253f70e565ffd8f92491a193894
3
+ size 2891
logs/dreambooth/1698077664.7381294/hparams.yml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ center_crop: false
7
+ checkpointing_steps: 500
8
+ checkpoints_total_limit: null
9
+ class_data_dir: null
10
+ class_labels_conditioning: null
11
+ class_prompt: null
12
+ dataloader_num_workers: 0
13
+ enable_xformers_memory_efficient_attention: false
14
+ gradient_accumulation_steps: 1
15
+ gradient_checkpointing: false
16
+ hub_model_id: null
17
+ hub_token: null
18
+ instance_data_dir: /home/rocky/llms/dreambooth/dog/
19
+ instance_prompt: a photo of sks dog
20
+ learning_rate: 5.0e-06
21
+ local_rank: 0
22
+ logging_dir: logs
23
+ lr_num_cycles: 1
24
+ lr_power: 1.0
25
+ lr_scheduler: constant
26
+ lr_warmup_steps: 0
27
+ max_grad_norm: 1.0
28
+ max_train_steps: 400
29
+ mixed_precision: null
30
+ num_class_images: 100
31
+ num_train_epochs: 134
32
+ num_validation_images: 4
33
+ offset_noise: false
34
+ output_dir: /home/rocky/llms/dreambooth/
35
+ pre_compute_text_embeddings: false
36
+ pretrained_model_name_or_path: /mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/
37
+ prior_generation_precision: null
38
+ prior_loss_weight: 1.0
39
+ push_to_hub: true
40
+ report_to: tensorboard
41
+ resolution: 512
42
+ resume_from_checkpoint: null
43
+ revision: null
44
+ sample_batch_size: 4
45
+ scale_lr: false
46
+ seed: null
47
+ set_grads_to_none: false
48
+ skip_save_text_encoder: false
49
+ snr_gamma: null
50
+ text_encoder_use_attention_mask: false
51
+ tokenizer_max_length: null
52
+ tokenizer_name: null
53
+ train_batch_size: 1
54
+ train_text_encoder: false
55
+ use_8bit_adam: false
56
+ validation_prompt: null
57
+ validation_scheduler: DPMSolverMultistepScheduler
58
+ validation_steps: 100
59
+ with_prior_preservation: false
logs/dreambooth/events.out.tfevents.1698075438.worker2.3659641.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c14500c6563fbf99ec4149f714c8dbdc8c6daa64bdbcb5dc566f479fb739ae13
3
+ size 88
logs/dreambooth/events.out.tfevents.1698077664.worker2.3729036.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09f3a04f1835de758b9a10f02aa9fbb12794c71b69316591fe91aff3fabc3208
3
+ size 33434
model_index.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.22.0.dev0",
4
+ "_name_or_path": "/mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/",
5
+ "feature_extractor": [
6
+ "transformers",
7
+ "CLIPImageProcessor"
8
+ ],
9
+ "requires_safety_checker": true,
10
+ "safety_checker": [
11
+ "stable_diffusion",
12
+ "StableDiffusionSafetyChecker"
13
+ ],
14
+ "scheduler": [
15
+ "diffusers",
16
+ "PNDMScheduler"
17
+ ],
18
+ "text_encoder": [
19
+ "transformers",
20
+ "CLIPTextModel"
21
+ ],
22
+ "tokenizer": [
23
+ "transformers",
24
+ "CLIPTokenizer"
25
+ ],
26
+ "unet": [
27
+ "diffusers",
28
+ "UNet2DConditionModel"
29
+ ],
30
+ "vae": [
31
+ "diffusers",
32
+ "AutoencoderKL"
33
+ ]
34
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
safety_checker/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/safety_checker",
3
+ "architectures": [
4
+ "StableDiffusionSafetyChecker"
5
+ ],
6
+ "initializer_factor": 1.0,
7
+ "logit_scale_init_value": 2.6592,
8
+ "model_type": "clip",
9
+ "projection_dim": 768,
10
+ "text_config": {
11
+ "dropout": 0.0,
12
+ "hidden_size": 768,
13
+ "intermediate_size": 3072,
14
+ "model_type": "clip_text_model",
15
+ "num_attention_heads": 12
16
+ },
17
+ "torch_dtype": "float32",
18
+ "transformers_version": "4.33.2",
19
+ "vision_config": {
20
+ "dropout": 0.0,
21
+ "hidden_size": 1024,
22
+ "intermediate_size": 4096,
23
+ "model_type": "clip_vision_model",
24
+ "num_attention_heads": 16,
25
+ "num_hidden_layers": 24,
26
+ "patch_size": 14
27
+ }
28
+ }
safety_checker/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb351a5ded815c3ff744968ad9c6b218d071b9d313d04f35e813b84b4c0ffde8
3
+ size 1215979664
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.22.0.dev0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "epsilon",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "timestep_spacing": "leading",
14
+ "trained_betas": null
15
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.33.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:778d02eb9e707c3fbaae0b67b79ea0d1399b52e624fb634f2f19375ae7c047c3
3
+ size 492265168
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "clean_up_tokenization_spaces": true,
12
+ "do_lower_case": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 77,
23
+ "pad_token": "<|endoftext|>",
24
+ "tokenizer_class": "CLIPTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
train_dreambooth.py ADDED
@@ -0,0 +1,1422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import copy
18
+ import gc
19
+ import hashlib
20
+ import importlib
21
+ import itertools
22
+ import logging
23
+ import math
24
+ import os
25
+ import shutil
26
+ import warnings
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import ProjectConfiguration, set_seed
37
+ from huggingface_hub import create_repo, model_info, upload_folder
38
+ from packaging import version
39
+ from PIL import Image
40
+ from PIL.ImageOps import exif_transpose
41
+ from torch.utils.data import Dataset
42
+ from torchvision import transforms
43
+ from tqdm.auto import tqdm
44
+ from transformers import AutoTokenizer, PretrainedConfig
45
+
46
+ import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ DiffusionPipeline,
51
+ StableDiffusionPipeline,
52
+ UNet2DConditionModel,
53
+ )
54
+ from diffusers.optimization import get_scheduler
55
+ from diffusers.training_utils import compute_snr
56
+ from diffusers.utils import check_min_version, is_wandb_available
57
+ from diffusers.utils.import_utils import is_xformers_available
58
+
59
+
60
+ if is_wandb_available():
61
+ import wandb
62
+
63
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
64
+ check_min_version("0.22.0.dev0")
65
+
66
+ logger = get_logger(__name__)
67
+
68
+
69
+ def save_model_card(
70
+ repo_id: str,
71
+ images=None,
72
+ base_model=str,
73
+ train_text_encoder=False,
74
+ prompt=str,
75
+ repo_folder=None,
76
+ pipeline: DiffusionPipeline = None,
77
+ ):
78
+ img_str = ""
79
+ for i, image in enumerate(images):
80
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
81
+ img_str += f"![img_{i}](./image_{i}.png)\n"
82
+
83
+ yaml = f"""
84
+ ---
85
+ license: creativeml-openrail-m
86
+ base_model: {base_model}
87
+ instance_prompt: {prompt}
88
+ tags:
89
+ - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
90
+ - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
91
+ - text-to-image
92
+ - diffusers
93
+ - dreambooth
94
+ inference: true
95
+ ---
96
+ """
97
+ model_card = f"""
98
+ # DreamBooth - {repo_id}
99
+
100
+ This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
101
+ You can find some example images in the following. \n
102
+ {img_str}
103
+
104
+ DreamBooth for the text encoder was enabled: {train_text_encoder}.
105
+ """
106
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
107
+ f.write(yaml + model_card)
108
+
109
+
110
+ def log_validation(
111
+ text_encoder,
112
+ tokenizer,
113
+ unet,
114
+ vae,
115
+ args,
116
+ accelerator,
117
+ weight_dtype,
118
+ global_step,
119
+ prompt_embeds,
120
+ negative_prompt_embeds,
121
+ ):
122
+ logger.info(
123
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
124
+ f" {args.validation_prompt}."
125
+ )
126
+
127
+ pipeline_args = {}
128
+
129
+ if vae is not None:
130
+ pipeline_args["vae"] = vae
131
+
132
+ if text_encoder is not None:
133
+ text_encoder = accelerator.unwrap_model(text_encoder)
134
+
135
+ # create pipeline (note: unet and vae are loaded again in float32)
136
+ pipeline = DiffusionPipeline.from_pretrained(
137
+ args.pretrained_model_name_or_path,
138
+ tokenizer=tokenizer,
139
+ text_encoder=text_encoder,
140
+ unet=accelerator.unwrap_model(unet),
141
+ revision=args.revision,
142
+ torch_dtype=weight_dtype,
143
+ **pipeline_args,
144
+ )
145
+
146
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
147
+ scheduler_args = {}
148
+
149
+ if "variance_type" in pipeline.scheduler.config:
150
+ variance_type = pipeline.scheduler.config.variance_type
151
+
152
+ if variance_type in ["learned", "learned_range"]:
153
+ variance_type = "fixed_small"
154
+
155
+ scheduler_args["variance_type"] = variance_type
156
+
157
+ module = importlib.import_module("diffusers")
158
+ scheduler_class = getattr(module, args.validation_scheduler)
159
+ pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
160
+ pipeline = pipeline.to(accelerator.device)
161
+ pipeline.set_progress_bar_config(disable=True)
162
+
163
+ if args.pre_compute_text_embeddings:
164
+ pipeline_args = {
165
+ "prompt_embeds": prompt_embeds,
166
+ "negative_prompt_embeds": negative_prompt_embeds,
167
+ }
168
+ else:
169
+ pipeline_args = {"prompt": args.validation_prompt}
170
+
171
+ # run inference
172
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
173
+ images = []
174
+ if args.validation_images is None:
175
+ for _ in range(args.num_validation_images):
176
+ with torch.autocast("cuda"):
177
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
178
+ images.append(image)
179
+ else:
180
+ for image in args.validation_images:
181
+ image = Image.open(image)
182
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
183
+ images.append(image)
184
+
185
+ for tracker in accelerator.trackers:
186
+ if tracker.name == "tensorboard":
187
+ np_images = np.stack([np.asarray(img) for img in images])
188
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
189
+ if tracker.name == "wandb":
190
+ tracker.log(
191
+ {
192
+ "validation": [
193
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
194
+ ]
195
+ }
196
+ )
197
+
198
+ del pipeline
199
+ torch.cuda.empty_cache()
200
+
201
+ return images
202
+
203
+
204
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
205
+ text_encoder_config = PretrainedConfig.from_pretrained(
206
+ pretrained_model_name_or_path,
207
+ subfolder="text_encoder",
208
+ revision=revision,
209
+ )
210
+ model_class = text_encoder_config.architectures[0]
211
+
212
+ if model_class == "CLIPTextModel":
213
+ from transformers import CLIPTextModel
214
+
215
+ return CLIPTextModel
216
+ elif model_class == "RobertaSeriesModelWithTransformation":
217
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
218
+
219
+ return RobertaSeriesModelWithTransformation
220
+ elif model_class == "T5EncoderModel":
221
+ from transformers import T5EncoderModel
222
+
223
+ return T5EncoderModel
224
+ else:
225
+ raise ValueError(f"{model_class} is not supported.")
226
+
227
+
228
+ def parse_args(input_args=None):
229
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
230
+ parser.add_argument(
231
+ "--pretrained_model_name_or_path",
232
+ type=str,
233
+ default=None,
234
+ required=True,
235
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
236
+ )
237
+ parser.add_argument(
238
+ "--revision",
239
+ type=str,
240
+ default=None,
241
+ required=False,
242
+ help=(
243
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
244
+ " float32 precision."
245
+ ),
246
+ )
247
+ parser.add_argument(
248
+ "--tokenizer_name",
249
+ type=str,
250
+ default=None,
251
+ help="Pretrained tokenizer name or path if not the same as model_name",
252
+ )
253
+ parser.add_argument(
254
+ "--instance_data_dir",
255
+ type=str,
256
+ default=None,
257
+ required=True,
258
+ help="A folder containing the training data of instance images.",
259
+ )
260
+ parser.add_argument(
261
+ "--class_data_dir",
262
+ type=str,
263
+ default=None,
264
+ required=False,
265
+ help="A folder containing the training data of class images.",
266
+ )
267
+ parser.add_argument(
268
+ "--instance_prompt",
269
+ type=str,
270
+ default=None,
271
+ required=True,
272
+ help="The prompt with identifier specifying the instance",
273
+ )
274
+ parser.add_argument(
275
+ "--class_prompt",
276
+ type=str,
277
+ default=None,
278
+ help="The prompt to specify images in the same class as provided instance images.",
279
+ )
280
+ parser.add_argument(
281
+ "--with_prior_preservation",
282
+ default=False,
283
+ action="store_true",
284
+ help="Flag to add prior preservation loss.",
285
+ )
286
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
287
+ parser.add_argument(
288
+ "--num_class_images",
289
+ type=int,
290
+ default=100,
291
+ help=(
292
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
293
+ " class_data_dir, additional images will be sampled with class_prompt."
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--output_dir",
298
+ type=str,
299
+ default="text-inversion-model",
300
+ help="The output directory where the model predictions and checkpoints will be written.",
301
+ )
302
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
303
+ parser.add_argument(
304
+ "--resolution",
305
+ type=int,
306
+ default=512,
307
+ help=(
308
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
309
+ " resolution"
310
+ ),
311
+ )
312
+ parser.add_argument(
313
+ "--center_crop",
314
+ default=False,
315
+ action="store_true",
316
+ help=(
317
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
318
+ " cropped. The images will be resized to the resolution first before cropping."
319
+ ),
320
+ )
321
+ parser.add_argument(
322
+ "--train_text_encoder",
323
+ action="store_true",
324
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
325
+ )
326
+ parser.add_argument(
327
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
328
+ )
329
+ parser.add_argument(
330
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
331
+ )
332
+ parser.add_argument("--num_train_epochs", type=int, default=1)
333
+ parser.add_argument(
334
+ "--max_train_steps",
335
+ type=int,
336
+ default=None,
337
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
338
+ )
339
+ parser.add_argument(
340
+ "--checkpointing_steps",
341
+ type=int,
342
+ default=500,
343
+ help=(
344
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
345
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
346
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
347
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
348
+ "instructions."
349
+ ),
350
+ )
351
+ parser.add_argument(
352
+ "--checkpoints_total_limit",
353
+ type=int,
354
+ default=None,
355
+ help=(
356
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
357
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
358
+ " for more details"
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--resume_from_checkpoint",
363
+ type=str,
364
+ default=None,
365
+ help=(
366
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
367
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
368
+ ),
369
+ )
370
+ parser.add_argument(
371
+ "--gradient_accumulation_steps",
372
+ type=int,
373
+ default=1,
374
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
375
+ )
376
+ parser.add_argument(
377
+ "--gradient_checkpointing",
378
+ action="store_true",
379
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
380
+ )
381
+ parser.add_argument(
382
+ "--learning_rate",
383
+ type=float,
384
+ default=5e-6,
385
+ help="Initial learning rate (after the potential warmup period) to use.",
386
+ )
387
+ parser.add_argument(
388
+ "--scale_lr",
389
+ action="store_true",
390
+ default=False,
391
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
392
+ )
393
+ parser.add_argument(
394
+ "--lr_scheduler",
395
+ type=str,
396
+ default="constant",
397
+ help=(
398
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
399
+ ' "constant", "constant_with_warmup"]'
400
+ ),
401
+ )
402
+ parser.add_argument(
403
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
404
+ )
405
+ parser.add_argument(
406
+ "--lr_num_cycles",
407
+ type=int,
408
+ default=1,
409
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
410
+ )
411
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
412
+ parser.add_argument(
413
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
414
+ )
415
+ parser.add_argument(
416
+ "--dataloader_num_workers",
417
+ type=int,
418
+ default=0,
419
+ help=(
420
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
421
+ ),
422
+ )
423
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
424
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
425
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
426
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
427
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
428
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
429
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
430
+ parser.add_argument(
431
+ "--hub_model_id",
432
+ type=str,
433
+ default=None,
434
+ help="The name of the repository to keep in sync with the local `output_dir`.",
435
+ )
436
+ parser.add_argument(
437
+ "--logging_dir",
438
+ type=str,
439
+ default="logs",
440
+ help=(
441
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
442
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
443
+ ),
444
+ )
445
+ parser.add_argument(
446
+ "--allow_tf32",
447
+ action="store_true",
448
+ help=(
449
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
450
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
451
+ ),
452
+ )
453
+ parser.add_argument(
454
+ "--report_to",
455
+ type=str,
456
+ default="tensorboard",
457
+ help=(
458
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
459
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
460
+ ),
461
+ )
462
+ parser.add_argument(
463
+ "--validation_prompt",
464
+ type=str,
465
+ default=None,
466
+ help="A prompt that is used during validation to verify that the model is learning.",
467
+ )
468
+ parser.add_argument(
469
+ "--num_validation_images",
470
+ type=int,
471
+ default=4,
472
+ help="Number of images that should be generated during validation with `validation_prompt`.",
473
+ )
474
+ parser.add_argument(
475
+ "--validation_steps",
476
+ type=int,
477
+ default=100,
478
+ help=(
479
+ "Run validation every X steps. Validation consists of running the prompt"
480
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
481
+ " and logging the images."
482
+ ),
483
+ )
484
+ parser.add_argument(
485
+ "--mixed_precision",
486
+ type=str,
487
+ default=None,
488
+ choices=["no", "fp16", "bf16"],
489
+ help=(
490
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
491
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
492
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
493
+ ),
494
+ )
495
+ parser.add_argument(
496
+ "--prior_generation_precision",
497
+ type=str,
498
+ default=None,
499
+ choices=["no", "fp32", "fp16", "bf16"],
500
+ help=(
501
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
502
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
503
+ ),
504
+ )
505
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
506
+ parser.add_argument(
507
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
508
+ )
509
+ parser.add_argument(
510
+ "--set_grads_to_none",
511
+ action="store_true",
512
+ help=(
513
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
514
+ " behaviors, so disable this argument if it causes any problems. More info:"
515
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
516
+ ),
517
+ )
518
+
519
+ parser.add_argument(
520
+ "--offset_noise",
521
+ action="store_true",
522
+ default=False,
523
+ help=(
524
+ "Fine-tuning against a modified noise"
525
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
526
+ ),
527
+ )
528
+ parser.add_argument(
529
+ "--snr_gamma",
530
+ type=float,
531
+ default=None,
532
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
533
+ "More details here: https://arxiv.org/abs/2303.09556.",
534
+ )
535
+ parser.add_argument(
536
+ "--pre_compute_text_embeddings",
537
+ action="store_true",
538
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
539
+ )
540
+ parser.add_argument(
541
+ "--tokenizer_max_length",
542
+ type=int,
543
+ default=None,
544
+ required=False,
545
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
546
+ )
547
+ parser.add_argument(
548
+ "--text_encoder_use_attention_mask",
549
+ action="store_true",
550
+ required=False,
551
+ help="Whether to use attention mask for the text encoder",
552
+ )
553
+ parser.add_argument(
554
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
555
+ )
556
+ parser.add_argument(
557
+ "--validation_images",
558
+ required=False,
559
+ default=None,
560
+ nargs="+",
561
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
562
+ )
563
+ parser.add_argument(
564
+ "--class_labels_conditioning",
565
+ required=False,
566
+ default=None,
567
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
568
+ )
569
+ parser.add_argument(
570
+ "--validation_scheduler",
571
+ type=str,
572
+ default="DPMSolverMultistepScheduler",
573
+ choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
574
+ help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
575
+ )
576
+
577
+ if input_args is not None:
578
+ args = parser.parse_args(input_args)
579
+ else:
580
+ args = parser.parse_args()
581
+
582
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
583
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
584
+ args.local_rank = env_local_rank
585
+
586
+ if args.with_prior_preservation:
587
+ if args.class_data_dir is None:
588
+ raise ValueError("You must specify a data directory for class images.")
589
+ if args.class_prompt is None:
590
+ raise ValueError("You must specify prompt for class images.")
591
+ else:
592
+ # logger is not available yet
593
+ if args.class_data_dir is not None:
594
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
595
+ if args.class_prompt is not None:
596
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
597
+
598
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
599
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
600
+
601
+ return args
602
+
603
+
604
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    When `encoder_hidden_states` / `class_prompt_encoder_hidden_states` are provided
    (pre-computed text embeddings), those tensors are returned directly and no
    tokenization happens; otherwise each example carries tokenized prompt ids and
    the tokenizer's attention mask.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        # Pre-computed text embeddings; when set, __getitem__ returns them directly
        # instead of tokenizing the prompts (used with --pre_compute_text_embeddings).
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                # Cap the number of class images actually used for prior preservation.
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            # Length covers the larger of the two image sets; the smaller set is
            # cycled via modulo indexing in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),  # map [0, 1] pixel range to [-1, 1]
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        # Modulo indexing cycles through the images when index exceeds their count.
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            # Pre-computed embeddings: no tokenization, and no attention-mask key.
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            # NOTE: the "*_attention_mask" keys exist only in this branch;
            # collate_fn keys off their presence.
            text_inputs = tokenize_prompt(
                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example
701
+
702
+
703
def collate_fn(examples, with_prior_preservation=False):
    """Collate DreamBooth examples into a single training batch.

    With prior preservation enabled, class examples are appended to the
    instance examples along the batch dimension so one forward pass covers
    both. Attention masks are included only when the examples carry them
    (i.e. prompts were tokenized rather than pre-embedded).
    """
    # Masks are either present on every example or on none of them.
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [ex["instance_prompt_ids"] for ex in examples]
    pixel_values = [ex["instance_images"] for ex in examples]
    attention_mask = [ex["instance_attention_mask"] for ex in examples] if has_attention_mask else None

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids.extend(ex["class_prompt_ids"] for ex in examples)
        pixel_values.extend(ex["class_images"] for ex in examples)
        if has_attention_mask:
            attention_mask.extend(ex["class_attention_mask"] for ex in examples)

    batch = {
        "input_ids": torch.cat(input_ids, dim=0),
        "pixel_values": torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float(),
    }

    if has_attention_mask:
        batch["attention_mask"] = torch.cat(attention_mask, dim=0)

    return batch
736
+
737
+
738
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Every sample shares the same prompt; the index lets callers
        # derive unique file names for the generated images.
        return {"prompt": self.prompt, "index": index}
753
+
754
+
755
def model_has_vae(args):
    """Return True if the pretrained model ships a VAE (``vae/`` subfolder).

    Checks the local directory when ``args.pretrained_model_name_or_path`` is a
    path on disk, otherwise queries the Hub file listing for the repo.
    """
    if os.path.isdir(args.pretrained_model_name_or_path):
        # Local checkout: build the path with OS-native separators.
        config_file_name = os.path.join(args.pretrained_model_name_or_path, "vae", AutoencoderKL.config_name)
        return os.path.isfile(config_file_name)
    else:
        # Hub repos list files with forward-slash separators regardless of the
        # host OS, so do NOT use os.path.join here (it would produce a
        # backslash on Windows and never match any `rfilename`).
        config_file_name = "vae/" + AutoencoderKL.config_name
        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
        return any(file.rfilename == config_file_name for file in files_in_repo)
763
+
764
+
765
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` padded/truncated to a fixed length.

    Uses ``tokenizer_max_length`` when given, otherwise falls back to the
    tokenizer's own ``model_max_length``. Returns the tokenizer's output
    (with ``input_ids`` and ``attention_mask`` as PyTorch tensors).
    """
    max_length = tokenizer_max_length if tokenizer_max_length is not None else tokenizer.model_max_length

    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
780
+
781
+
782
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Run the text encoder over ``input_ids`` and return the hidden states.

    The attention mask is forwarded to the encoder only when
    ``text_encoder_use_attention_mask`` is truthy; otherwise ``None`` is
    passed. Inputs are moved to the encoder's device first.
    """
    text_input_ids = input_ids.to(text_encoder.device)

    # Some encoders (e.g. T5 for DeepFloyd IF) need the mask; CLIP does not.
    mask = attention_mask.to(text_encoder.device) if text_encoder_use_attention_mask else None

    # Index 0 of the encoder output holds the last hidden states.
    return text_encoder(text_input_ids, attention_mask=mask)[0]
797
+
798
+
799
+ def main(args):
800
+ logging_dir = Path(args.output_dir, args.logging_dir)
801
+
802
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
803
+
804
+ accelerator = Accelerator(
805
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
806
+ mixed_precision=args.mixed_precision,
807
+ log_with=args.report_to,
808
+ project_config=accelerator_project_config,
809
+ )
810
+
811
+ if args.report_to == "wandb":
812
+ if not is_wandb_available():
813
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
814
+
815
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
816
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
817
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
818
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
819
+ raise ValueError(
820
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
821
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
822
+ )
823
+
824
+ # Make one log on every process with the configuration for debugging.
825
+ logging.basicConfig(
826
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
827
+ datefmt="%m/%d/%Y %H:%M:%S",
828
+ level=logging.INFO,
829
+ )
830
+ logger.info(accelerator.state, main_process_only=False)
831
+ if accelerator.is_local_main_process:
832
+ transformers.utils.logging.set_verbosity_warning()
833
+ diffusers.utils.logging.set_verbosity_info()
834
+ else:
835
+ transformers.utils.logging.set_verbosity_error()
836
+ diffusers.utils.logging.set_verbosity_error()
837
+
838
+ # If passed along, set the training seed now.
839
+ if args.seed is not None:
840
+ set_seed(args.seed)
841
+
842
+ # Generate class images if prior preservation is enabled.
843
+ if args.with_prior_preservation:
844
+ class_images_dir = Path(args.class_data_dir)
845
+ if not class_images_dir.exists():
846
+ class_images_dir.mkdir(parents=True)
847
+ cur_class_images = len(list(class_images_dir.iterdir()))
848
+
849
+ if cur_class_images < args.num_class_images:
850
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
851
+ if args.prior_generation_precision == "fp32":
852
+ torch_dtype = torch.float32
853
+ elif args.prior_generation_precision == "fp16":
854
+ torch_dtype = torch.float16
855
+ elif args.prior_generation_precision == "bf16":
856
+ torch_dtype = torch.bfloat16
857
+ pipeline = DiffusionPipeline.from_pretrained(
858
+ args.pretrained_model_name_or_path,
859
+ torch_dtype=torch_dtype,
860
+ safety_checker=None,
861
+ revision=args.revision,
862
+ )
863
+ pipeline.set_progress_bar_config(disable=True)
864
+
865
+ num_new_images = args.num_class_images - cur_class_images
866
+ logger.info(f"Number of class images to sample: {num_new_images}.")
867
+
868
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
869
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
870
+
871
+ sample_dataloader = accelerator.prepare(sample_dataloader)
872
+ pipeline.to(accelerator.device)
873
+
874
+ for example in tqdm(
875
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
876
+ ):
877
+ images = pipeline(example["prompt"]).images
878
+
879
+ for i, image in enumerate(images):
880
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
881
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
882
+ image.save(image_filename)
883
+
884
+ del pipeline
885
+ if torch.cuda.is_available():
886
+ torch.cuda.empty_cache()
887
+
888
+ # Handle the repository creation
889
+ if accelerator.is_main_process:
890
+ if args.output_dir is not None:
891
+ os.makedirs(args.output_dir, exist_ok=True)
892
+
893
+ if args.push_to_hub:
894
+ repo_id = create_repo(
895
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
896
+ ).repo_id
897
+
898
+ # Load the tokenizer
899
+ if args.tokenizer_name:
900
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
901
+ elif args.pretrained_model_name_or_path:
902
+ tokenizer = AutoTokenizer.from_pretrained(
903
+ args.pretrained_model_name_or_path,
904
+ subfolder="tokenizer",
905
+ revision=args.revision,
906
+ use_fast=False,
907
+ )
908
+
909
+ # import correct text encoder class
910
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
911
+
912
+ # Load scheduler and models
913
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
914
+ text_encoder = text_encoder_cls.from_pretrained(
915
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
916
+ )
917
+
918
+ if model_has_vae(args):
919
+ vae = AutoencoderKL.from_pretrained(
920
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
921
+ )
922
+ else:
923
+ vae = None
924
+
925
+ unet = UNet2DConditionModel.from_pretrained(
926
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
927
+ )
928
+
929
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
930
+ def save_model_hook(models, weights, output_dir):
931
+ if accelerator.is_main_process:
932
+ for model in models:
933
+ sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
934
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
935
+
936
+ # make sure to pop weight so that corresponding model is not saved again
937
+ weights.pop()
938
+
939
+ def load_model_hook(models, input_dir):
940
+ while len(models) > 0:
941
+ # pop models so that they are not loaded again
942
+ model = models.pop()
943
+
944
+ if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
945
+ # load transformers style into model
946
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
947
+ model.config = load_model.config
948
+ else:
949
+ # load diffusers style into model
950
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
951
+ model.register_to_config(**load_model.config)
952
+
953
+ model.load_state_dict(load_model.state_dict())
954
+ del load_model
955
+
956
+ accelerator.register_save_state_pre_hook(save_model_hook)
957
+ accelerator.register_load_state_pre_hook(load_model_hook)
958
+
959
+ if vae is not None:
960
+ vae.requires_grad_(False)
961
+
962
+ if not args.train_text_encoder:
963
+ text_encoder.requires_grad_(False)
964
+
965
+ if args.enable_xformers_memory_efficient_attention:
966
+ if is_xformers_available():
967
+ import xformers
968
+
969
+ xformers_version = version.parse(xformers.__version__)
970
+ if xformers_version == version.parse("0.0.16"):
971
+ logger.warn(
972
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
973
+ )
974
+ unet.enable_xformers_memory_efficient_attention()
975
+ else:
976
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
977
+
978
+ if args.gradient_checkpointing:
979
+ unet.enable_gradient_checkpointing()
980
+ if args.train_text_encoder:
981
+ text_encoder.gradient_checkpointing_enable()
982
+
983
+ # Check that all trainable models are in full precision
984
+ low_precision_error_string = (
985
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
986
+ " doing mixed precision training. copy of the weights should still be float32."
987
+ )
988
+
989
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
990
+ raise ValueError(
991
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
992
+ )
993
+
994
+ if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
995
+ raise ValueError(
996
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
997
+ f" {low_precision_error_string}"
998
+ )
999
+
1000
+ # Enable TF32 for faster training on Ampere GPUs,
1001
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1002
+ if args.allow_tf32:
1003
+ torch.backends.cuda.matmul.allow_tf32 = True
1004
+
1005
+ if args.scale_lr:
1006
+ args.learning_rate = (
1007
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1008
+ )
1009
+
1010
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1011
+ if args.use_8bit_adam:
1012
+ try:
1013
+ import bitsandbytes as bnb
1014
+ except ImportError:
1015
+ raise ImportError(
1016
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1017
+ )
1018
+
1019
+ optimizer_class = bnb.optim.AdamW8bit
1020
+ else:
1021
+ optimizer_class = torch.optim.AdamW
1022
+
1023
+ # Optimizer creation
1024
+ params_to_optimize = (
1025
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
1026
+ )
1027
+ optimizer = optimizer_class(
1028
+ params_to_optimize,
1029
+ lr=args.learning_rate,
1030
+ betas=(args.adam_beta1, args.adam_beta2),
1031
+ weight_decay=args.adam_weight_decay,
1032
+ eps=args.adam_epsilon,
1033
+ )
1034
+
1035
+ if args.pre_compute_text_embeddings:
1036
+
1037
+ def compute_text_embeddings(prompt):
1038
+ with torch.no_grad():
1039
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
1040
+ prompt_embeds = encode_prompt(
1041
+ text_encoder,
1042
+ text_inputs.input_ids,
1043
+ text_inputs.attention_mask,
1044
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1045
+ )
1046
+
1047
+ return prompt_embeds
1048
+
1049
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
1050
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
1051
+
1052
+ if args.validation_prompt is not None:
1053
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
1054
+ else:
1055
+ validation_prompt_encoder_hidden_states = None
1056
+
1057
+ if args.class_prompt is not None:
1058
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
1059
+ else:
1060
+ pre_computed_class_prompt_encoder_hidden_states = None
1061
+
1062
+ text_encoder = None
1063
+ tokenizer = None
1064
+
1065
+ gc.collect()
1066
+ torch.cuda.empty_cache()
1067
+ else:
1068
+ pre_computed_encoder_hidden_states = None
1069
+ validation_prompt_encoder_hidden_states = None
1070
+ validation_prompt_negative_prompt_embeds = None
1071
+ pre_computed_class_prompt_encoder_hidden_states = None
1072
+
1073
+ # Dataset and DataLoaders creation:
1074
+ train_dataset = DreamBoothDataset(
1075
+ instance_data_root=args.instance_data_dir,
1076
+ instance_prompt=args.instance_prompt,
1077
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1078
+ class_prompt=args.class_prompt,
1079
+ class_num=args.num_class_images,
1080
+ tokenizer=tokenizer,
1081
+ size=args.resolution,
1082
+ center_crop=args.center_crop,
1083
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
1084
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1085
+ tokenizer_max_length=args.tokenizer_max_length,
1086
+ )
1087
+
1088
+ train_dataloader = torch.utils.data.DataLoader(
1089
+ train_dataset,
1090
+ batch_size=args.train_batch_size,
1091
+ shuffle=True,
1092
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1093
+ num_workers=args.dataloader_num_workers,
1094
+ )
1095
+
1096
+ # Scheduler and math around the number of training steps.
1097
+ overrode_max_train_steps = False
1098
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1099
+ if args.max_train_steps is None:
1100
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1101
+ overrode_max_train_steps = True
1102
+
1103
+ lr_scheduler = get_scheduler(
1104
+ args.lr_scheduler,
1105
+ optimizer=optimizer,
1106
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1107
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1108
+ num_cycles=args.lr_num_cycles,
1109
+ power=args.lr_power,
1110
+ )
1111
+
1112
+ # Prepare everything with our `accelerator`.
1113
+ if args.train_text_encoder:
1114
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1115
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1116
+ )
1117
+ else:
1118
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1119
+ unet, optimizer, train_dataloader, lr_scheduler
1120
+ )
1121
+
1122
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
1123
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1124
+ weight_dtype = torch.float32
1125
+ if accelerator.mixed_precision == "fp16":
1126
+ weight_dtype = torch.float16
1127
+ elif accelerator.mixed_precision == "bf16":
1128
+ weight_dtype = torch.bfloat16
1129
+
1130
+ # Move vae and text_encoder to device and cast to weight_dtype
1131
+ if vae is not None:
1132
+ vae.to(accelerator.device, dtype=weight_dtype)
1133
+
1134
+ if not args.train_text_encoder and text_encoder is not None:
1135
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1136
+
1137
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1138
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1139
+ if overrode_max_train_steps:
1140
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1141
+ # Afterwards we recalculate our number of training epochs
1142
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1143
+
1144
+ # We need to initialize the trackers we use, and also store our configuration.
1145
+ # The trackers initializes automatically on the main process.
1146
+ if accelerator.is_main_process:
1147
+ tracker_config = vars(copy.deepcopy(args))
1148
+ tracker_config.pop("validation_images")
1149
+ accelerator.init_trackers("dreambooth", config=tracker_config)
1150
+
1151
+ # Train!
1152
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1153
+
1154
+ logger.info("***** Running training *****")
1155
+ logger.info(f" Num examples = {len(train_dataset)}")
1156
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1157
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1158
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1159
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1160
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1161
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1162
+ global_step = 0
1163
+ first_epoch = 0
1164
+
1165
+ # Potentially load in the weights and states from a previous save
1166
+ if args.resume_from_checkpoint:
1167
+ if args.resume_from_checkpoint != "latest":
1168
+ path = os.path.basename(args.resume_from_checkpoint)
1169
+ else:
1170
+ # Get the mos recent checkpoint
1171
+ dirs = os.listdir(args.output_dir)
1172
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1173
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1174
+ path = dirs[-1] if len(dirs) > 0 else None
1175
+
1176
+ if path is None:
1177
+ accelerator.print(
1178
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1179
+ )
1180
+ args.resume_from_checkpoint = None
1181
+ initial_global_step = 0
1182
+ else:
1183
+ accelerator.print(f"Resuming from checkpoint {path}")
1184
+ accelerator.load_state(os.path.join(args.output_dir, path))
1185
+ global_step = int(path.split("-")[1])
1186
+
1187
+ initial_global_step = global_step
1188
+ first_epoch = global_step // num_update_steps_per_epoch
1189
+ else:
1190
+ initial_global_step = 0
1191
+
1192
+ progress_bar = tqdm(
1193
+ range(0, args.max_train_steps),
1194
+ initial=initial_global_step,
1195
+ desc="Steps",
1196
+ # Only show the progress bar once on each machine.
1197
+ disable=not accelerator.is_local_main_process,
1198
+ )
1199
+
1200
+ for epoch in range(first_epoch, args.num_train_epochs):
1201
+ unet.train()
1202
+ if args.train_text_encoder:
1203
+ text_encoder.train()
1204
+ for step, batch in enumerate(train_dataloader):
1205
+ with accelerator.accumulate(unet):
1206
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1207
+
1208
+ if vae is not None:
1209
+ # Convert images to latent space
1210
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1211
+ model_input = model_input * vae.config.scaling_factor
1212
+ else:
1213
+ model_input = pixel_values
1214
+
1215
+ # Sample noise that we'll add to the model input
1216
+ if args.offset_noise:
1217
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
1218
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
1219
+ )
1220
+ else:
1221
+ noise = torch.randn_like(model_input)
1222
+ bsz, channels, height, width = model_input.shape
1223
+ # Sample a random timestep for each image
1224
+ timesteps = torch.randint(
1225
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1226
+ )
1227
+ timesteps = timesteps.long()
1228
+
1229
+ # Add noise to the model input according to the noise magnitude at each timestep
1230
+ # (this is the forward diffusion process)
1231
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1232
+
1233
+ # Get the text embedding for conditioning
1234
+ if args.pre_compute_text_embeddings:
1235
+ encoder_hidden_states = batch["input_ids"]
1236
+ else:
1237
+ encoder_hidden_states = encode_prompt(
1238
+ text_encoder,
1239
+ batch["input_ids"],
1240
+ batch["attention_mask"],
1241
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1242
+ )
1243
+
1244
+ if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
1245
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
1246
+
1247
+ if args.class_labels_conditioning == "timesteps":
1248
+ class_labels = timesteps
1249
+ else:
1250
+ class_labels = None
1251
+
1252
+ # Predict the noise residual
1253
+ model_pred = unet(
1254
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1255
+ ).sample
1256
+
1257
+ if model_pred.shape[1] == 6:
1258
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
1259
+
1260
+ # Get the target for loss depending on the prediction type
1261
+ if noise_scheduler.config.prediction_type == "epsilon":
1262
+ target = noise
1263
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1264
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1265
+ else:
1266
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1267
+
1268
+ if args.with_prior_preservation:
1269
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1270
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1271
+ target, target_prior = torch.chunk(target, 2, dim=0)
1272
+ # Compute prior loss
1273
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1274
+
1275
+ # Compute instance loss
1276
+ if args.snr_gamma is None:
1277
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1278
+ else:
1279
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1280
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1281
+ # This is discussed in Section 4.2 of the same paper.
1282
+ snr = compute_snr(noise_scheduler, timesteps)
1283
+ base_weight = (
1284
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1285
+ )
1286
+
1287
+ if noise_scheduler.config.prediction_type == "v_prediction":
1288
+ # Velocity objective needs to be floored to an SNR weight of one.
1289
+ mse_loss_weights = base_weight + 1
1290
+ else:
1291
+ # Epsilon and sample both use the same loss weights.
1292
+ mse_loss_weights = base_weight
1293
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1294
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1295
+ loss = loss.mean()
1296
+
1297
+ if args.with_prior_preservation:
1298
+ # Add the prior loss to the instance loss.
1299
+ loss = loss + args.prior_loss_weight * prior_loss
1300
+
1301
+ accelerator.backward(loss)
1302
+ if accelerator.sync_gradients:
1303
+ params_to_clip = (
1304
+ itertools.chain(unet.parameters(), text_encoder.parameters())
1305
+ if args.train_text_encoder
1306
+ else unet.parameters()
1307
+ )
1308
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1309
+ optimizer.step()
1310
+ lr_scheduler.step()
1311
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1312
+
1313
+ # Checks if the accelerator has performed an optimization step behind the scenes
1314
+ if accelerator.sync_gradients:
1315
+ progress_bar.update(1)
1316
+ global_step += 1
1317
+
1318
+ if accelerator.is_main_process:
1319
+ if global_step % args.checkpointing_steps == 0:
1320
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1321
+ if args.checkpoints_total_limit is not None:
1322
+ checkpoints = os.listdir(args.output_dir)
1323
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1324
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1325
+
1326
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1327
+ if len(checkpoints) >= args.checkpoints_total_limit:
1328
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1329
+ removing_checkpoints = checkpoints[0:num_to_remove]
1330
+
1331
+ logger.info(
1332
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1333
+ )
1334
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1335
+
1336
+ for removing_checkpoint in removing_checkpoints:
1337
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1338
+ shutil.rmtree(removing_checkpoint)
1339
+
1340
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1341
+ accelerator.save_state(save_path)
1342
+ logger.info(f"Saved state to {save_path}")
1343
+
1344
+ images = []
1345
+
1346
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1347
+ images = log_validation(
1348
+ text_encoder,
1349
+ tokenizer,
1350
+ unet,
1351
+ vae,
1352
+ args,
1353
+ accelerator,
1354
+ weight_dtype,
1355
+ global_step,
1356
+ validation_prompt_encoder_hidden_states,
1357
+ validation_prompt_negative_prompt_embeds,
1358
+ )
1359
+
1360
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1361
+ progress_bar.set_postfix(**logs)
1362
+ accelerator.log(logs, step=global_step)
1363
+
1364
+ if global_step >= args.max_train_steps:
1365
+ break
1366
+
1367
+ # Create the pipeline using using the trained modules and save it.
1368
+ accelerator.wait_for_everyone()
1369
+ if accelerator.is_main_process:
1370
+ pipeline_args = {}
1371
+
1372
+ if text_encoder is not None:
1373
+ pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
1374
+
1375
+ if args.skip_save_text_encoder:
1376
+ pipeline_args["text_encoder"] = None
1377
+
1378
+ pipeline = DiffusionPipeline.from_pretrained(
1379
+ args.pretrained_model_name_or_path,
1380
+ unet=accelerator.unwrap_model(unet),
1381
+ revision=args.revision,
1382
+ **pipeline_args,
1383
+ )
1384
+
1385
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1386
+ scheduler_args = {}
1387
+
1388
+ if "variance_type" in pipeline.scheduler.config:
1389
+ variance_type = pipeline.scheduler.config.variance_type
1390
+
1391
+ if variance_type in ["learned", "learned_range"]:
1392
+ variance_type = "fixed_small"
1393
+
1394
+ scheduler_args["variance_type"] = variance_type
1395
+
1396
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1397
+
1398
+ pipeline.save_pretrained(args.output_dir)
1399
+
1400
+ if args.push_to_hub:
1401
+ save_model_card(
1402
+ repo_id,
1403
+ images=images,
1404
+ base_model=args.pretrained_model_name_or_path,
1405
+ train_text_encoder=args.train_text_encoder,
1406
+ prompt=args.instance_prompt,
1407
+ repo_folder=args.output_dir,
1408
+ pipeline=pipeline,
1409
+ )
1410
+ upload_folder(
1411
+ repo_id=repo_id,
1412
+ folder_path=args.output_dir,
1413
+ commit_message="End of training",
1414
+ ignore_patterns=["step_*", "epoch_*"],
1415
+ )
1416
+
1417
+ accelerator.end_training()
1418
+
1419
+
1420
+ if __name__ == "__main__":
1421
+ args = parse_args()
1422
+ main(args)
unet/config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.22.0.dev0",
4
+ "_name_or_path": "/mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "attention_type": "default",
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "class_embeddings_concat": false,
20
+ "conv_in_kernel": 3,
21
+ "conv_out_kernel": 3,
22
+ "cross_attention_dim": 768,
23
+ "cross_attention_norm": null,
24
+ "down_block_types": [
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "DownBlock2D"
29
+ ],
30
+ "downsample_padding": 1,
31
+ "dropout": 0.0,
32
+ "dual_cross_attention": false,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "in_channels": 4,
38
+ "layers_per_block": 2,
39
+ "mid_block_only_cross_attention": null,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "out_channels": 4,
48
+ "projection_class_embeddings_input_dim": null,
49
+ "resnet_out_scale_factor": 1.0,
50
+ "resnet_skip_time_act": false,
51
+ "resnet_time_scale_shift": "default",
52
+ "reverse_transformer_layers_per_block": null,
53
+ "sample_size": 64,
54
+ "time_cond_proj_dim": null,
55
+ "time_embedding_act_fn": null,
56
+ "time_embedding_dim": null,
57
+ "time_embedding_type": "positional",
58
+ "timestep_post_act": null,
59
+ "transformer_layers_per_block": 1,
60
+ "up_block_types": [
61
+ "UpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "CrossAttnUpBlock2D",
64
+ "CrossAttnUpBlock2D"
65
+ ],
66
+ "upcast_attention": false,
67
+ "use_linear_projection": false
68
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:290dc731af38c3be51298e4da72d908aa52f385617cf2c9b26fda7c505b8f890
3
+ size 3438167536
vae/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.22.0.dev0",
4
+ "_name_or_path": "/mnt/nfs_disk/huggingface/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "layers_per_block": 2,
22
+ "norm_num_groups": 32,
23
+ "out_channels": 3,
24
+ "sample_size": 512,
25
+ "scaling_factor": 0.18215,
26
+ "up_block_types": [
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D",
30
+ "UpDecoderBlock2D"
31
+ ]
32
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4d2b5932bb4151e54e694fd31ccf51fca908223c9485bd56cd0e1d83ad94c49
3
+ size 334643268