Upload 115 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- LDMAE/.DS_Store +0 -0
- LDMAE/configs/accelerator/4gpu.yaml +17 -0
- LDMAE/configs/accelerator/8gpu.yaml +17 -0
- LDMAE/configs/celeba_hq/lightningdit_b_vmae_f8d16_cfg.yaml +82 -0
- LDMAE/configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml +80 -0
- LDMAE/datasets/__init__.py +0 -0
- LDMAE/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- LDMAE/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- LDMAE/datasets/__pycache__/img_latent_dataset.cpython-310.pyc +0 -0
- LDMAE/datasets/__pycache__/img_latent_dataset.cpython-38.pyc +0 -0
- LDMAE/datasets/img_latent_dataset.py +94 -0
- LDMAE/evaluate_tokenizer.py +262 -0
- LDMAE/extract_features.py +235 -0
- LDMAE/inference.py +368 -0
- LDMAE/models/__init__.py +0 -0
- LDMAE/models/__pycache__/__init__.cpython-310.pyc +0 -0
- LDMAE/models/__pycache__/__init__.cpython-38.pyc +0 -0
- LDMAE/models/__pycache__/lightningdit.cpython-310.pyc +0 -0
- LDMAE/models/__pycache__/lightningdit.cpython-38.pyc +0 -0
- LDMAE/models/__pycache__/pos_embed.cpython-310.pyc +0 -0
- LDMAE/models/__pycache__/pos_embed.cpython-38.pyc +0 -0
- LDMAE/models/__pycache__/rmsnorm.cpython-310.pyc +0 -0
- LDMAE/models/__pycache__/rmsnorm.cpython-38.pyc +0 -0
- LDMAE/models/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- LDMAE/models/__pycache__/swiglu_ffn.cpython-38.pyc +0 -0
- LDMAE/models/lightningdit.py +531 -0
- LDMAE/models/lpips.py +184 -0
- LDMAE/models/pos_embed.py +135 -0
- LDMAE/models/rmsnorm.py +495 -0
- LDMAE/models/swiglu_ffn.py +74 -0
- LDMAE/pretrain_weight/aef8d16.pth +3 -0
- LDMAE/pretrain_weight/daef8d16.pth +3 -0
- LDMAE/pretrain_weight/sdv3f8d16.pth +3 -0
- LDMAE/pretrain_weight/vaef8d16.pth +3 -0
- LDMAE/pretrain_weight/vmaef8d16.pth +3 -0
- LDMAE/requirements.txt +16 -0
- LDMAE/run_extract_feature.sh +22 -0
- LDMAE/run_fast_inference.sh +20 -0
- LDMAE/run_inference.sh +20 -0
- LDMAE/run_robustness_test.sh +81 -0
- LDMAE/run_train.sh +22 -0
- LDMAE/tokenizer/__init__.py +0 -0
- LDMAE/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
- LDMAE/tokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
- LDMAE/tokenizer/__pycache__/autoencoder.cpython-310.pyc +0 -0
- LDMAE/tokenizer/__pycache__/models_mae.cpython-310.pyc +0 -0
- LDMAE/tokenizer/__pycache__/sdvae.cpython-310.pyc +0 -0
- LDMAE/tokenizer/__pycache__/vavae.cpython-310.pyc +0 -0
- LDMAE/tokenizer/__pycache__/vavae.cpython-38.pyc +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figure/thumbnail.png filter=lfs diff=lfs merge=lfs -text
|
LDMAE/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
LDMAE/configs/accelerator/4gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 4
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
LDMAE/configs/accelerator/8gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 8
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
LDMAE/configs/celeba_hq/lightningdit_b_vmae_f8d16_cfg.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# we recommend to read config_details.yaml first.
|
| 2 |
+
|
| 3 |
+
ckpt_path: 'output/celeba_hq/lightningdit_b_vmae_f8d16/checkpoints/0060000.pt' # <---- download our pre-trained lightningdit or your own checkpoint
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
name: celebahq
|
| 7 |
+
origin_path: "/data/dataset/celeba_hq/celeba_hq_256"
|
| 8 |
+
data_path: '/data/dataset/celeba_hq/vmae_feature_celebahq_train_256' # <---- path to your data. it is generated by extract_features.py.
|
| 9 |
+
# if you just want to inference, download our latent_stats.pt and give its path here is ok.
|
| 10 |
+
fid_reference_file: 'tools/fid_statistics/ALL_celebahq256.npz' # <---- path to your fid_reference_file.npz. download it from ADM
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# recommend to use default settings
|
| 14 |
+
image_size: 256
|
| 15 |
+
num_classes: 1
|
| 16 |
+
num_workers: 8
|
| 17 |
+
latent_norm: true
|
| 18 |
+
latent_multiplier: 1.0
|
| 19 |
+
sample: true # <------------------------------ check this. you should comment out this when you don't want to use it.
|
| 20 |
+
|
| 21 |
+
# recommend to use default settings (we will release code with SD-VAE later)
|
| 22 |
+
vae:
|
| 23 |
+
model_name: 'vmae'
|
| 24 |
+
downsample_ratio: 8
|
| 25 |
+
weight_path: 'pretrain_weight/vmae_f8d16.pth'
|
| 26 |
+
|
| 27 |
+
# recommend to use default settings
|
| 28 |
+
model:
|
| 29 |
+
model_type: LightningDiT-B/1
|
| 30 |
+
use_qknorm: false # no qk normalizing in celeba
|
| 31 |
+
use_swiglu: true
|
| 32 |
+
use_rope: true
|
| 33 |
+
use_rmsnorm: true
|
| 34 |
+
wo_shift: false
|
| 35 |
+
in_chans: 16
|
| 36 |
+
|
| 37 |
+
# recommend to use default settings
|
| 38 |
+
train:
|
| 39 |
+
max_steps: 60000
|
| 40 |
+
global_batch_size: 1024 # 256 ok
|
| 41 |
+
global_seed: 1
|
| 42 |
+
output_dir: 'output'
|
| 43 |
+
exp_name: 'celeba_hq/lightningdit_b_vmae_f8d16' # <---- experiment name, set as you like
|
| 44 |
+
ckpt: null
|
| 45 |
+
log_every: 100
|
| 46 |
+
ckpt_every: 20000
|
| 47 |
+
use_checkpoint: false
|
| 48 |
+
gradient_accumulation_steps: 1
|
| 49 |
+
optimizer:
|
| 50 |
+
lr: 0.0002
|
| 51 |
+
beta2: 0.95
|
| 52 |
+
# max_grad_norm: 1.0
|
| 53 |
+
# recommend to use default settings
|
| 54 |
+
transport:
|
| 55 |
+
path_type: Linear
|
| 56 |
+
prediction: velocity
|
| 57 |
+
loss_weight: null
|
| 58 |
+
sample_eps: null
|
| 59 |
+
train_eps: null
|
| 60 |
+
use_cosine_loss: false
|
| 61 |
+
use_lognorm: true
|
| 62 |
+
|
| 63 |
+
# recommend to use default settings
|
| 64 |
+
sample:
|
| 65 |
+
mode: ODE
|
| 66 |
+
sampling_method: euler
|
| 67 |
+
atol: 0.000001
|
| 68 |
+
rtol: 0.001
|
| 69 |
+
reverse: false
|
| 70 |
+
likelihood: false
|
| 71 |
+
num_sampling_steps: 250
|
| 72 |
+
cfg_scale: 0 # <---- cfg scale, for 800 epoch performance with FID=1.35 cfg_scale=6.7
|
| 73 |
+
# for 64 epoch performance with FID=2.11 cfg_scale=10.0
|
| 74 |
+
# you may find we use a large cfg_scale, this is because of 2 reasons:
|
| 75 |
+
# we find a high-dimensional latent space requires a larger cfg_scale to get good performance than the f8d4 SD-VAE
|
| 76 |
+
# we enable cfg interval, which reduces the negative effects of large cfg on high-noise parts. This means larger cfg can be utilized
|
| 77 |
+
|
| 78 |
+
# recommend to use default settings
|
| 79 |
+
per_proc_batch_size: 128
|
| 80 |
+
fid_num: 50000
|
| 81 |
+
cfg_interval_start: 0.10
|
| 82 |
+
timestep_shift: 0.3
|
LDMAE/configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# we recommend to read config_details.yaml first.
|
| 2 |
+
|
| 3 |
+
ckpt_path: 'output/imagenet/lightningdit_b_vmae_f8d16/checkpoints/0100000.pt' # <---- download our pre-trained lightningdit or your own checkpoint
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
origin_path: '/data/dataset/imagenet/1K_dataset'
|
| 7 |
+
data_path: '/data/dataset/imagenet/vmae_feature_imagenet_train_256' # <---- path to your data. it is generated by extract_features.py.
|
| 8 |
+
# if you just want to inference, download our latent_stats.pt and give its path here is ok.
|
| 9 |
+
fid_reference_file: 'tools/fid_statistics/VIRTUAL_imagenet256_labeled.npz' # <---- path to your fid_reference_file.npz. download it from ADM
|
| 10 |
+
|
| 11 |
+
# recommend to use default settings
|
| 12 |
+
image_size: 256
|
| 13 |
+
num_classes: 1000
|
| 14 |
+
num_workers: 8
|
| 15 |
+
latent_norm: true
|
| 16 |
+
latent_multiplier: 1.0
|
| 17 |
+
sample: true # <------------------------------ check this. you should comment out this when you don't want to use it.
|
| 18 |
+
|
| 19 |
+
# recommend to use default settings (we will release code with SD-VAE later)
|
| 20 |
+
vae:
|
| 21 |
+
model_name: 'vmae_f8d16'
|
| 22 |
+
downsample_ratio: 8
|
| 23 |
+
weight_path: 'pretrain_weight/vmaef8d16.pth'
|
| 24 |
+
|
| 25 |
+
# recommend to use default settings
|
| 26 |
+
model:
|
| 27 |
+
model_type: LightningDiT-B/1
|
| 28 |
+
use_qknorm: true
|
| 29 |
+
use_swiglu: true
|
| 30 |
+
use_rope: true
|
| 31 |
+
use_rmsnorm: true
|
| 32 |
+
wo_shift: false
|
| 33 |
+
in_chans: 16
|
| 34 |
+
|
| 35 |
+
# recommend to use default settings
|
| 36 |
+
train:
|
| 37 |
+
max_steps: 100000
|
| 38 |
+
global_batch_size: 256 # 256 ok
|
| 39 |
+
global_seed: 0
|
| 40 |
+
output_dir: 'output'
|
| 41 |
+
exp_name: 'imagenet/lightningdit_b_vmae_f8d16' # <---- experiment name, set as you like
|
| 42 |
+
ckpt: null
|
| 43 |
+
log_every: 100
|
| 44 |
+
ckpt_every: 20000
|
| 45 |
+
use_checkpoint: false
|
| 46 |
+
gradient_accumulation_steps: 1
|
| 47 |
+
optimizer:
|
| 48 |
+
lr: 0.0002
|
| 49 |
+
beta2: 0.95
|
| 50 |
+
# max_grad_norm: 1.0
|
| 51 |
+
# recommend to use default settings
|
| 52 |
+
transport:
|
| 53 |
+
path_type: Linear
|
| 54 |
+
prediction: velocity
|
| 55 |
+
loss_weight: null
|
| 56 |
+
sample_eps: null
|
| 57 |
+
train_eps: null
|
| 58 |
+
use_cosine_loss: false
|
| 59 |
+
use_lognorm: true
|
| 60 |
+
|
| 61 |
+
# recommend to use default settings
|
| 62 |
+
sample:
|
| 63 |
+
mode: ODE
|
| 64 |
+
sampling_method: euler
|
| 65 |
+
atol: 0.000001
|
| 66 |
+
rtol: 0.001
|
| 67 |
+
reverse: false
|
| 68 |
+
likelihood: false
|
| 69 |
+
num_sampling_steps: 250
|
| 70 |
+
cfg_scale: 10.0 # <---- cfg scale, for 800 epoch performance with FID=1.35 cfg_scale=6.7
|
| 71 |
+
# for 64 epoch performance with FID=2.11 cfg_scale=10.0
|
| 72 |
+
# you may find we use a large cfg_scale, this is because of 2 reasons:
|
| 73 |
+
# we find a high-dimensional latent space requires a larger cfg_scale to get good performance than the f8d4 SD-VAE
|
| 74 |
+
# we enable cfg interval, which reduces the negative effects of large cfg on high-noise parts. This means larger cfg can be utilized
|
| 75 |
+
|
| 76 |
+
# recommend to use default settings
|
| 77 |
+
per_proc_batch_size: 256
|
| 78 |
+
fid_num: 50000
|
| 79 |
+
cfg_interval_start: 0.10
|
| 80 |
+
timestep_shift: 0.3
|
LDMAE/datasets/__init__.py
ADDED
|
File without changes
|
LDMAE/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
LDMAE/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
LDMAE/datasets/__pycache__/img_latent_dataset.cpython-310.pyc
ADDED
|
Binary file (3.41 kB). View file
|
|
|
LDMAE/datasets/__pycache__/img_latent_dataset.cpython-38.pyc
ADDED
|
Binary file (3.33 kB). View file
|
|
|
LDMAE/datasets/img_latent_dataset.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ImageNet Latent Dataset with safetensors.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
from glob import glob
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from safetensors import safe_open
|
| 14 |
+
from tokenizer.util.misc import DiagonalGaussianDistribution
|
| 15 |
+
|
| 16 |
+
class ImgLatentDataset(Dataset):
    """Dataset of pre-extracted image latents stored in ``.safetensors`` shards.

    Each shard is expected to hold the tensors ``latents``, ``latents_flip``
    (latents of the horizontally flipped image, read in ``__getitem__``) and
    ``labels``, stacked along dim 0. A global sample index is translated to a
    (shard file, index-in-shard) pair via ``img_to_file_map``.

    Args:
        data_dir: directory holding the ``*.safetensors`` shards.
        latent_norm: if True, standardize latents with per-channel mean/std
            computed over up to 10k random samples (cached on disk as
            ``latents_stats.pt``).
        latent_multiplier: scalar multiplied into the returned latent.
        sample: if True, treat the stored tensor as Gaussian-distribution
            parameters and draw a sample via ``DiagonalGaussianDistribution``
            instead of using the tensor directly.
    """

    def __init__(self, data_dir, latent_norm=True, latent_multiplier=1.0, sample=False):
        self.data_dir = data_dir
        self.latent_norm = latent_norm
        self.latent_multiplier = latent_multiplier
        self.sample = sample

        # Sorted so the global-index -> shard mapping is stable across runs.
        self.files = sorted(glob(os.path.join(data_dir, "*.safetensors")))
        self.img_to_file_map = self.get_img_to_safefile_map()

        if latent_norm:
            # Per-channel statistics, loaded from cache or computed once.
            self._latent_mean, self._latent_std = self.get_latent_stats()

    def get_img_to_safefile_map(self):
        """Build the mapping: global index -> {'safe_file', 'idx_in_file'}.

        Only the 'labels' slice header is inspected per shard (get_slice does
        not load the full tensor), so this scan stays cheap for large shards.
        """
        img_to_file = {}
        for safe_file in self.files:
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                labels = f.get_slice('labels')
                labels_shape = labels.get_shape()
                num_imgs = labels_shape[0]
                cur_len = len(img_to_file)
                for i in range(num_imgs):
                    img_to_file[cur_len+i] = {
                        'safe_file': safe_file,
                        'idx_in_file': i
                    }
        return img_to_file

    def get_latent_stats(self):
        """Return (mean, std) tensors, computing and caching them on first use.

        The cache file lives next to the shards as ``latents_stats.pt``.
        """
        latent_stats_cache_file = os.path.join(self.data_dir, "latents_stats.pt")
        if not os.path.exists(latent_stats_cache_file):
            latent_stats = self.compute_latent_stats()
            torch.save(latent_stats, latent_stats_cache_file)
        else:
            latent_stats = torch.load(latent_stats_cache_file)
        return latent_stats['mean'], latent_stats['std']

    def compute_latent_stats(self):
        """Estimate per-channel latent mean/std from up to 10k random samples.

        NOTE(review): np.random is used without an explicit seed, so the
        cached statistics differ between first-time runs — confirm acceptable.
        """
        num_samples = min(10000, len(self.img_to_file_map))
        random_indices = np.random.choice(len(self.img_to_file_map), num_samples, replace=False)
        latents = []
        for idx in tqdm(random_indices):
            img_info = self.img_to_file_map[idx]
            safe_file, img_idx = img_info['safe_file'], img_info['idx_in_file']
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                features = f.get_slice('latents')
                # Slice keeps a leading batch dim of size 1.
                feature = features[img_idx:img_idx+1]
            if self.sample:
                # Stored tensor holds distribution parameters; sample from it.
                feature = DiagonalGaussianDistribution(feature).sample()
            latents.append(feature)
        latents = torch.cat(latents, dim=0)
        # Reduce over batch and spatial dims, keep channels -> shape (1, C, 1, 1),
        # which broadcasts against (1, C, H, W) latents in __getitem__.
        mean = latents.mean(dim=[0, 2, 3], keepdim=True)
        std = latents.std(dim=[0, 2, 3], keepdim=True)
        latent_stats = {'mean': mean, 'std': std}
        print(latent_stats)
        return latent_stats

    def __len__(self):
        # Total number of samples across all shards.
        return len(self.img_to_file_map.keys())

    def __getitem__(self, idx):
        """Return (feature, label) for global index *idx*.

        With probability 0.5 serves the horizontally-flipped latent
        ('latents_flip') as cheap augmentation, then optionally samples from
        the stored distribution, normalizes, and scales the latent.
        """
        img_info = self.img_to_file_map[idx]
        safe_file, img_idx = img_info['safe_file'], img_info['idx_in_file']
        with safe_open(safe_file, framework="pt", device="cpu") as f:
            # 50/50 choice between original and flipped latents.
            tensor_key = "latents" if np.random.uniform(0, 1) > 0.5 else "latents_flip"
            features = f.get_slice(tensor_key)
            labels = f.get_slice('labels')
            feature = features[img_idx:img_idx+1]
            label = labels[img_idx:img_idx+1]
        if self.sample:
            feature = DiagonalGaussianDistribution(feature).sample()
        if self.latent_norm:
            feature = (feature - self._latent_mean) / self._latent_std
        # Applied regardless of latent_norm (separate config knob).
        feature = feature * self.latent_multiplier

        # remove the first batch dimension (=1) kept by get_slice()
        feature = feature.squeeze(0)
        label = label.squeeze(0)
        return feature, label
|
LDMAE/evaluate_tokenizer.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate tokenizer performance by computing reconstruction metrics.
|
| 3 |
+
|
| 4 |
+
Metrics include:
|
| 5 |
+
- rFID (Reconstruction FID)
|
| 6 |
+
- PSNR (Peak Signal-to-Noise Ratio)
|
| 7 |
+
- LPIPS (Learned Perceptual Image Patch Similarity)
|
| 8 |
+
- SSIM (Structural Similarity Index)
|
| 9 |
+
|
| 10 |
+
by Jingfeng Yao
|
| 11 |
+
from HUST-VL
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import torch, yaml
|
| 16 |
+
import numpy as np
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 21 |
+
from omegaconf import OmegaConf
|
| 22 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 23 |
+
from tools.calculate_fid import calculate_fid_given_paths
|
| 24 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 25 |
+
from torchmetrics import StructuralSimilarityIndexMeasure
|
| 26 |
+
from models.lpips import LPIPS
|
| 27 |
+
from torchvision.datasets import ImageFolder
|
| 28 |
+
from torchvision import transforms
|
| 29 |
+
from diffusers.models import AutoencoderKL
|
| 30 |
+
from tokenizer.sdvae import Diffusers_AutoencoderKL
|
| 31 |
+
from tokenizer import models_mae
|
| 32 |
+
|
| 33 |
+
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path, "r") as fh:
        return yaml.safe_load(fh)
|
| 37 |
+
|
| 38 |
+
def print_with_prefix(content, prefix='Tokenizer Evaluation', rank=0):
    """Print *content* tagged with a blue ``[prefix]``, but only on rank 0."""
    if rank != 0:
        return
    print(f"\033[34m[{prefix}]\033[0m {content}")
|
| 41 |
+
|
| 42 |
+
def save_image(image, filename):
    """Persist an HWC uint8 array to *filename* via PIL."""
    pil_image = Image.fromarray(image)
    pil_image.save(filename)
|
| 44 |
+
|
| 45 |
+
def evaluate_tokenizer(args, config_path, model_type, data_path, output_path):
    """Distributed reconstruction benchmark for a tokenizer/VAE.

    Encodes a validation set, optionally perturbs the latents with Gaussian
    noise scaled by ``args.epsilon`` (latent-robustness experiment), decodes
    them back, and reports rFID / PSNR / LPIPS / SSIM.

    Note: the ``model_type`` parameter is immediately overwritten from the
    config below, so the value passed by the caller is effectively ignored.
    """
    # Initialize distributed training (one process per GPU, NCCL backend).
    dist.init_process_group(backend='nccl')
    # NOTE(review): this is the *global* rank; using it as a CUDA device index
    # assumes a single-node launch — confirm before multi-node use.
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}')
    train_config = load_config(config_path)
    # Model family is the prefix of vae.model_name, e.g. 'vmae_f8d16' -> 'vmae'.
    model_type = train_config['vae']['model_name'].split("_")[0]

    if local_rank == 0:
        print_with_prefix(f"Loading model... {model_type.upper()} {args.epsilon}")

    # Build the tokenizer: either the MAE-based 'vmae' model or a
    # diffusers-style AutoencoderKL variant, per the config.
    if train_config['vae']['model_name'].split("_")[0] == 'vmae':
        chkpt = train_config['vae']['weight_path']
        arch = 'mae_for_ldmae_f8d16_prev'
        model = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=train_config['data']['image_size'])
        checkpoint = torch.load(chkpt, map_location='cpu')
        model = model.to(device).eval()
        # strict=False: checkpoint may omit/add keys; mismatches reported below.
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    elif train_config['vae']['model_name'].split("_")[0] in ['ae','dae','vae','sdv3']:
        # f8, 16-latent-channel AutoencoderKL configuration shared by these variants.
        model = Diffusers_AutoencoderKL(
            img_size=train_config['data']['image_size'],
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        ).to(device).eval()
        chkpt_dir = train_config['vae']['weight_path']
        checkpoint = torch.load(chkpt_dir, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    else:
        # Unknown model family. NOTE(review): a bare `raise` with no active
        # exception surfaces as RuntimeError — an explicit ValueError with a
        # message would be clearer.
        raise
    # Report missing/unexpected state-dict keys from strict=False loading.
    print(msg)
    # Image preprocessing: ToTensor -> [0,1], Normalize(0.5, 0.5) -> [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create dataset and dataloader; DistributedSampler shards the val set
    # across ranks without overlap.
    dataset = ImageFolder(root=data_path, transform=transform)
    distributed_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=local_rank)
    val_dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=4,
        sampler=distributed_sampler
    )

    # Latent statistics produced by the feature-extraction step; '_sample'
    # suffix matches the directory naming used when data.sample is set.
    if 'sample' in train_config['data']:
        train_config['data']['data_path'] += '_sample'
    latent_stats_cache_file = os.path.join(train_config['data']['data_path'], 'latents_stats.pt')
    latent_stats = torch.load(latent_stats_cache_file)
    latent_mean, latent_std = latent_stats['mean'], latent_stats['std']

    latent_mean = latent_mean.clone().detach().to(device)
    latent_std = latent_std.clone().detach().to(device)


    # Setup output directories; reconstructions are grouped per
    # (model family, epsilon) pair so different runs don't collide.
    folder_name = f"{model_type}_{args.epsilon}"

    save_dir = os.path.join(output_path, folder_name, 'decoded_images')
    ref_path = os.path.join(output_path, 'ref_images')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(ref_path, exist_ok=True)

    if local_rank == 0:
        print_with_prefix(f"Output dir: {save_dir}")
        print_with_prefix(f"Reference dir: {ref_path}")

    # Save reference images if needed (skipped when a full set of 50k PNGs
    # already exists from a previous run).
    ref_png_files = [f for f in os.listdir(ref_path) if f.endswith('.png')]
    if len(ref_png_files) < 50000:
        total_samples = 0
        for batch in val_dataloader:
            images = batch[0].to(device)
            for j in range(images.size(0)):
                # Map [-1, 1] -> ~[0, 255] and store as uint8 HWC.
                img = torch.clamp(127.5 * images[j] + 128.0, 0, 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
                Image.fromarray(img).save(os.path.join(ref_path, f"ref_image_rank_{local_rank}_{total_samples}.png"))
                total_samples += 1
                if total_samples % 100 == 0 and local_rank == 0:
                    print_with_prefix(f"Rank {local_rank}, Saved {total_samples} reference images")
    dist.barrier()

    # Initialize metrics (LPIPS network and SSIM over the [-1, 1] range).
    lpips_values = []
    ssim_values = []
    lpips = LPIPS().to(device).eval()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(device)

    # Generate reconstructions and compute metrics.
    if local_rank == 0:
        print_with_prefix("Generating reconstructions...")
    all_indices = 0
    # NOTE(review): if save_dir is already populated this loop is skipped and
    # lpips_values/ssim_values stay empty, making avg_lpips/avg_ssim NaN —
    # confirm reruns are expected to only recompute FID/PSNR.
    if len(os.listdir(save_dir)) < 50000:
        for batch in val_dataloader:
            images = batch[0].to(device)
            latents = encode_images(model, images)
            # Latent-robustness perturbation: noise scaled per-channel by the
            # dataset latent std, with overall magnitude args.epsilon.
            epsilon = args.epsilon * torch.randn_like(latents)
            latents = latents + epsilon * latent_std

            with torch.no_grad():
                decoded_images_tensor = model.decode(latents).sample
                decoded_images = torch.clamp(127.5 * decoded_images_tensor + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

                # Compute metrics on the raw (unclamped) decoder output vs the
                # normalized inputs.
                lpips_values.append(lpips(decoded_images_tensor, images).mean())
                ssim_values.append(ssim_metric(decoded_images_tensor, images))

                # Save reconstructions
                for i, img in enumerate(decoded_images):
                    save_image(img, os.path.join(save_dir, f"decoded_image_rank_{local_rank}_{all_indices + i}.png"))
                    if (all_indices + i) % 100 == 0 and local_rank == 0:
                        print_with_prefix(f"Rank {local_rank}, Processed {all_indices + i} images")
                all_indices += len(decoded_images)
    dist.barrier()

    # Aggregate metrics across GPUs (element-wise average of per-batch values).
    lpips_values = torch.tensor(lpips_values).to(device)
    ssim_values = torch.tensor(ssim_values).to(device)
    dist.all_reduce(lpips_values, op=dist.ReduceOp.AVG)
    dist.all_reduce(ssim_values, op=dist.ReduceOp.AVG)

    avg_lpips = lpips_values.mean().item()
    avg_ssim = ssim_values.mean().item()

    if local_rank == 0:
        # Calculate FID between reference and reconstructed image folders.
        print_with_prefix("Computing rFID...")
        fid = calculate_fid_given_paths([ref_path, save_dir], batch_size=50, dims=2048, device=device, num_workers=16)

        # Calculate PSNR (pairs matched by sorted filename order).
        print_with_prefix("Computing PSNR...")
        psnr_values = calculate_psnr_between_folders(ref_path, save_dir)
        avg_psnr = sum(psnr_values) / len(psnr_values)

        # Print final results
        print_with_prefix(f"Final Metrics:")
        print_with_prefix(f"rFID: {fid:.3f}")
        print_with_prefix(f"PSNR: {avg_psnr:.3f}")
        print_with_prefix(f"LPIPS: {avg_lpips:.3f}")
        print_with_prefix(f"SSIM: {avg_ssim:.3f}")
    dist.barrier()
    dist.destroy_process_group()
|
| 212 |
+
|
| 213 |
+
def encode_images(model, images):
    """Encode a batch of images and return the deterministic (mode) latents as float32."""
    with torch.no_grad():
        latent_dist = model.encode(images).latent_dist
        return latent_dist.mode().to(torch.float32)
|
| 217 |
+
|
| 218 |
+
def decode_to_images(model, z):
    """Decode latents to uint8 HWC numpy images with values in [0, 255]."""
    with torch.no_grad():
        decoded = model.decode(z)
        scaled = torch.clamp(127.5 * decoded + 128.0, 0, 255)
        return scaled.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
| 223 |
+
|
| 224 |
+
def calculate_psnr(original, processed):
    """PSNR in dB between two tensors, assuming a 0-255 intensity scale.

    Returns +inf when the inputs are identical (zero MSE).
    """
    squared_error = (original - processed) ** 2
    mse = torch.mean(squared_error)
    return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
|
| 227 |
+
|
| 228 |
+
def load_image(image_path):
    """Read an image file as a float32 CHW tensor with 0-255 values."""
    rgb = Image.open(image_path).convert('RGB')
    chw = np.array(rgb).transpose(2, 0, 1)
    return torch.tensor(chw, dtype=torch.float32)
|
| 231 |
+
|
| 232 |
+
def calculate_psnr_for_pair(original_path, processed_path):
    """PSNR between the two images stored at the given paths."""
    original = load_image(original_path)
    processed = load_image(processed_path)
    return calculate_psnr(original, processed)
|
| 234 |
+
|
| 235 |
+
def calculate_psnr_between_folders(original_folder, processed_folder):
    """Compute PSNR for image pairs matched by sorted filename order.

    Returns an empty list (after printing a warning) when the two folders
    hold a different number of files. Results arrive in completion order,
    not filename order.
    """
    originals = sorted(os.listdir(original_folder))
    processed = sorted(os.listdir(processed_folder))

    if len(originals) != len(processed):
        print("Warning: Mismatched number of images in folders")
        return []

    with ThreadPoolExecutor() as pool:
        jobs = []
        for orig_name, proc_name in zip(originals, processed):
            jobs.append(pool.submit(
                calculate_psnr_for_pair,
                os.path.join(original_folder, orig_name),
                os.path.join(processed_folder, proc_name),
            ))
        return [job.result() for job in as_completed(jobs)]
|
| 251 |
+
|
| 252 |
+
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse options and run the distributed tokenizer evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='tokenizer/configs/vavae_f16d32.yaml')
    parser.add_argument('--model_type', type=str, default='vavae')
    parser.add_argument('--data_path', type=str, default='/data/dataset/imagenet/1K_dataset/val')
    parser.add_argument('--output_path', type=str, default='./rfid')
    # NOTE(review): --seed is accepted for CLI compatibility but is never read
    # by evaluate_tokenizer in this file.
    parser.add_argument('--seed', type=int, default=42)
    # Fixed typo in the user-facing help text ("pertubation" -> "perturbation").
    parser.add_argument('--epsilon', type=float, default=0, help="Noise perturbation ratio for latent robustness experiment.")
    args = parser.parse_args()
    evaluate_tokenizer(args, config_path=args.config_path, model_type=args.model_type, data_path=args.data_path, output_path=args.output_path)
|
LDMAE/extract_features.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 3 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 7 |
+
from torchvision.datasets import ImageFolder
|
| 8 |
+
import argparse
|
| 9 |
+
import os, yaml
|
| 10 |
+
from safetensors.torch import save_file
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from datasets.img_latent_dataset import ImgLatentDataset
|
| 13 |
+
from tokenizer import models_mae
|
| 14 |
+
from tokenizer.sdvae import Diffusers_AutoencoderKL
|
| 15 |
+
|
| 16 |
+
def load_config(config_path):
    """Parse the YAML file at *config_path* and return its contents."""
    with open(config_path, "r") as handle:
        return yaml.safe_load(handle)
|
| 20 |
+
|
| 21 |
+
def main(args, train_config):
    """
    Run a tokenizer on the full dataset and save the features.

    Encodes every image of the configured split twice (once as-is, once
    horizontally flipped) into latents and writes them to sharded
    .safetensors files of up to 10000 images each, then lets rank 0
    compute the latent statistics via ImgLatentDataset.
    """
    assert torch.cuda.is_available(), "Extract features currently requires at least one GPU."

    # Setup DDP; fall back to single-process mode when no process group can
    # be created (e.g. when the script is launched without torchrun).
    try:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        device = rank % torch.cuda.device_count()
        world_size = dist.get_world_size()
        seed = args.seed + rank
        if rank == 0:
            print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
    except Exception:  # was a bare `except:`; keep KeyboardInterrupt etc. propagating
        print("Failed to initialize DDP. Running in local mode.")
        rank = 0
        device = 0
        world_size = 1
        seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    model_name = train_config['vae']['model_name'].split("_")[0]
    output_path = os.path.dirname(train_config['data']['origin_path'])
    dataset_name = train_config['data']['name']

    # Setup feature folders:
    output_dir = os.path.join(output_path, f'{model_name}_feature_{dataset_name}_{args.data_split}_{args.image_size}')
    if 'sample' in train_config['data']:
        output_dir += '_sample'
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        print(model_name)

    # Create model:
    if model_name == 'vmae':
        arch = 'mae_for_ldmae_f8d16_prev'
        chkpt = train_config['vae']['weight_path']
        tokenizer = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=args.image_size)
        checkpoint = torch.load(chkpt, map_location='cpu')
        tokenizer = tokenizer.to(device).eval()
        msg = tokenizer.load_state_dict(checkpoint['model'], strict=False)
        if rank == 0:
            print(model_name, msg)
    elif model_name in ['ae', 'dae', 'vae', 'sdv3']:
        tokenizer = Diffusers_AutoencoderKL(
            img_size=args.image_size,
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        ).to(device).eval()
        chkpt = train_config['vae']['weight_path']
        checkpoint = torch.load(chkpt, map_location='cpu')
        msg = tokenizer.load_state_dict(checkpoint['model'], strict=False)
        if rank == 0:
            print(model_name, msg)
    else:
        # Was `raise("")`, which only raised "exceptions must derive from
        # BaseException"; raise a meaningful error instead.
        raise ValueError(f"Unsupported tokenizer model_name: {model_name!r}")

    print(f"{device} GPU - Model loaded")
    # Setup data: two loaders over the same split, one without and one with
    # horizontal flip, iterated in lockstep so flipped latents pair up.
    data_path = train_config['data']['origin_path']
    datasets = [
        ImageFolder(os.path.join(data_path, args.data_split), transform=tokenizer.img_transform(p_hflip=0.0, img_size=args.image_size)),
        ImageFolder(os.path.join(data_path, args.data_split), transform=tokenizer.img_transform(p_hflip=1.0, img_size=args.image_size))
    ]
    samplers = [
        DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            seed=args.seed
        ) for dataset in datasets
    ]  # Maybe gray scale files are dropped. Need to be fixed.
    loaders = [
        DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        ) for dataset, sampler in zip(datasets, samplers)
    ]
    total_data_in_loop = len(loaders[0].dataset)
    if rank == 0:
        print(f"Total data in one loop: {total_data_in_loop}")

    run_images = 0
    saved_files = 0
    latents = []
    latents_flip = []
    labels = []
    for batch_idx, batch_data in enumerate(zip(*loaders)):
        run_images += batch_data[0][0].shape[0]
        if run_images % 100 == 0 and rank == 0:
            print(f'{datetime.now()} processing {run_images} of {total_data_in_loop} images')

        for loader_idx, data in enumerate(batch_data):
            x = data[0].to(device)
            y = data[1]  # (N,)
            with torch.no_grad():
                if 'sample' in train_config['data']:
                    z = tokenizer._encode(x)
                else:
                    z = tokenizer.encode(x).latent_dist.mode().detach().cpu()  # (N, C, H, W)

            if batch_idx == 0 and rank == 0:
                print('latent shape', z.shape, 'dtype', z.dtype)

            if loader_idx == 0:
                latents.append(z)
                labels.append(y)
            else:
                latents_flip.append(z)

        # Flush a shard once ~10000 images have been accumulated.
        if len(latents) == 10000 // args.batch_size:
            _save_shard(latents, latents_flip, labels, output_dir, rank, saved_files)
            latents = []
            latents_flip = []
            labels = []
            saved_files += 1

    # save remainder latents that are fewer than 10000 images
    if len(latents) > 0:
        _save_shard(latents, latents_flip, labels, output_dir, rank, saved_files)

    # Calculate latents stats on rank 0 only; barriers keep ranks in step.
    # Guarded so the single-process fallback (no process group) still works.
    if dist.is_initialized():
        dist.barrier()
    if rank == 0:
        dataset = ImgLatentDataset(output_dir, latent_norm=True, sample=train_config['data']['sample'] if 'sample' in train_config['data'] else False,)
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


def _save_shard(latents, latents_flip, labels, output_dir, rank, shard_idx):
    """Concatenate accumulated latent batches and write one .safetensors shard."""
    latents = torch.cat(latents, dim=0)
    latents_flip = torch.cat(latents_flip, dim=0)
    labels = torch.cat(labels, dim=0)
    save_dict = {
        'latents': latents,
        'latents_flip': latents_flip,
        'labels': labels
    }
    for key in save_dict:
        if rank == 0:
            print(key, save_dict[key].shape)
    # safetensors requires contiguous CPU tensors.
    save_dict = {key: tensor.contiguous().cpu() for key, tensor in save_dict.items()}
    save_filename = os.path.join(output_dir, f'latents_rank{rank:02d}_shard{shard_idx:03d}.safetensors')
    save_file(
        save_dict,
        save_filename,
        metadata={'total_size': f'{latents.shape[0]}', 'dtype': f'{latents.dtype}', 'device': f'{latents.device}'}
    )
    if rank == 0:
        print(f'Saved {save_filename}')
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
    # CLI entry point: read options, load the YAML config, run extraction.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_split", type=str, default='train')
    cli.add_argument("--output_path", type=str, default="/data/dataset/imagenet/")
    cli.add_argument("--image_size", type=int, default=256)
    cli.add_argument("--batch_size", type=int, default=64)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--num_workers", type=int, default=8)
    cli.add_argument('--config', type=str, default='configs/debug.yaml')
    parsed = cli.parse_args()

    main(parsed, load_config(parsed.config))
|
LDMAE/inference.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sampling Scripts of LightningDiT.
|
| 3 |
+
|
| 4 |
+
by Maple (Jingfeng Yao) from HUST-VL
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os, math, json, pickle, logging, argparse, yaml, torch, numpy as np
|
| 8 |
+
from time import time, strftime
|
| 9 |
+
from glob import glob
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from accelerate import Accelerator
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 18 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 19 |
+
import torchvision
|
| 20 |
+
# local imports
|
| 21 |
+
from tokenizer.vavae import VA_VAE
|
| 22 |
+
from tokenizer.sdvae import Diffusers_AutoencoderKL
|
| 23 |
+
from tokenizer import models_mae
|
| 24 |
+
import threading
|
| 25 |
+
|
| 26 |
+
from models.lightningdit import LightningDiT_models
|
| 27 |
+
from transport import create_transport, Sampler
|
| 28 |
+
from datasets.img_latent_dataset import ImgLatentDataset
|
| 29 |
+
from torchvision.utils import save_image
|
| 30 |
+
|
| 31 |
+
# sample function
|
| 32 |
+
def save_images_async(images, indices, save_dir):
    """Save each image in *images* to ``save_dir`` as ``<index>.png``.

    The body itself is synchronous; it is intended to be handed off to a
    worker thread by the caller.
    """
    for img, idx in zip(images, indices):
        # Convert numpy.ndarray to torch.Tensor before saving.
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # [H, W, C] → [C, H, W]
        save_image(img, f"{save_dir}/{idx:06d}.png")
|
| 39 |
+
|
| 40 |
+
def do_sample(train_config, accelerator, ckpt_path=None, cfg_scale=None, model=None, vae=None, demo_sample_mode=False):
    """
    Run sampling.

    Loads the DiT checkpoint (EMA weights when present) and, unless
    ``vae`` is given, a tokenizer for decoding, then either generates an
    8-image demo grid (``demo_sample_mode``) or fills ``sample_folder_dir``
    with individually numbered .png samples for FID evaluation.

    Returns the sample folder path (``None`` in demo mode).
    """
    folder_name = f"{train_config['model']['model_type'].replace('/', '-')}-ckpt-{ckpt_path.split('/')[-1].split('.')[0]}-{train_config['sample']['sampling_method']}-{train_config['sample']['num_sampling_steps']}".lower()
    if cfg_scale is None:
        cfg_scale = train_config['sample']['cfg_scale']
    cfg_interval_start = train_config['sample']['cfg_interval_start'] if 'cfg_interval_start' in train_config['sample'] else 0
    timestep_shift = train_config['sample']['timestep_shift'] if 'timestep_shift' in train_config['sample'] else 0
    if cfg_scale > 1.0:
        folder_name += f"-interval{cfg_interval_start:.2f}" + f"-cfg{cfg_scale:.2f}"
    folder_name += f"-shift{timestep_shift:.2f}"

    if demo_sample_mode:
        cfg_interval_start = 0
        timestep_shift = 0

    sample_folder_dir = os.path.join(train_config['train']['output_dir'], train_config['train']['exp_name'], folder_name)
    if accelerator.process_index == 0:
        if not demo_sample_mode:
            print_with_prefix('Sample_folder_dir=', sample_folder_dir)
            print_with_prefix('ckpt_path=', ckpt_path)
            print_with_prefix('cfg_scale=', cfg_scale)
            print_with_prefix('cfg_interval_start=', cfg_interval_start)
            print_with_prefix('timestep_shift=', timestep_shift)
    if not demo_sample_mode:
        if not os.path.exists(sample_folder_dir):
            if accelerator.process_index == 0:
                os.makedirs(sample_folder_dir, exist_ok=True)
        else:
            # Skip sampling entirely when a previous run already produced
            # enough images in this folder.
            png_files = [f for f in os.listdir(sample_folder_dir) if f.endswith('.png')]
            png_count = len(png_files)
            if png_count > train_config['sample']['fid_num']:
                if accelerator.process_index == 0:
                    print_with_prefix(f"Found {png_count} PNG files in {sample_folder_dir}, skip sampling.")
                return sample_folder_dir

    torch.backends.cuda.matmul.allow_tf32 = True  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup accelerator:
    device = accelerator.device

    # Setup DDP:
    seed = train_config['train']['global_seed'] * accelerator.num_processes + accelerator.process_index
    torch.manual_seed(seed)
    print_with_prefix(f"Starting rank={accelerator.local_process_index}, seed={seed}, world_size={accelerator.num_processes}.")
    rank = accelerator.local_process_index

    # Load model:
    if 'downsample_ratio' in train_config['vae']:
        downsample_ratio = train_config['vae']['downsample_ratio']
    else:
        downsample_ratio = 16
    latent_size = train_config['data']['image_size'] // downsample_ratio

    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]
    model.load_state_dict(checkpoint)
    model.eval()  # important!
    model.to(device)

    transport = create_transport(
        train_config['transport']['path_type'],
        train_config['transport']['prediction'],
        train_config['transport']['loss_weight'],
        train_config['transport']['train_eps'],
        train_config['transport']['sample_eps'],
        use_cosine_loss=train_config['transport']['use_cosine_loss'] if 'use_cosine_loss' in train_config['transport'] else False,
        use_lognorm=train_config['transport']['use_lognorm'] if 'use_lognorm' in train_config['transport'] else False,
    )  # default: velocity;
    sampler = Sampler(transport)
    mode = train_config['sample']['mode']
    if mode == "ODE":
        sample_fn = sampler.sample_ode(
            sampling_method=train_config['sample']['sampling_method'],
            num_steps=train_config['sample']['num_sampling_steps'],
            atol=train_config['sample']['atol'],
            rtol=train_config['sample']['rtol'],
            reverse=train_config['sample']['reverse'],
            timestep_shift=timestep_shift,
        )
    else:
        raise NotImplementedError(f"Sampling mode {mode} is not supported.")

    if vae is None:
        if train_config['vae']['model_name'].split("_")[0] == 'vmae':
            chkpt = train_config['vae']['weight_path']
            arch = 'mae_for_ldmae_f8d16_prev'
            vae = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=train_config['data']['image_size'])
            checkpoint = torch.load(chkpt, map_location='cpu')
            vae = vae.to(device).eval()
            msg = vae.load_state_dict(checkpoint['model'], strict=False)
        elif train_config['vae']['model_name'].split("_")[0] in ['ae', 'dae', 'vae', 'sdv3']:
            vae = Diffusers_AutoencoderKL(
                img_size=train_config['data']['image_size'],
                sample_size=128,
                in_channels=3,
                out_channels=3,
                layers_per_block=2,
                latent_channels=16,
                norm_num_groups=32,
                act_fn="silu",
                block_out_channels=(128, 256, 512, 512),
                force_upcast=False,
                use_quant_conv=False,
                use_post_quant_conv=False,
                down_block_types=(
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                ),
                up_block_types=(
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                ),
            ).to(device).eval()
            chkpt_dir = train_config['vae']['weight_path']
            checkpoint = torch.load(chkpt_dir, map_location='cpu')
            msg = vae.load_state_dict(checkpoint['model'], strict=False)
        else:
            # Was a bare `raise` with no active exception (RuntimeError);
            # raise a meaningful error instead.
            raise ValueError(f"Unsupported vae model_name: {train_config['vae']['model_name']!r}")
    if accelerator.process_index == 0:
        print_with_prefix(f'Model Loaded')

    using_cfg = cfg_scale > 1.0
    if using_cfg:
        if accelerator.process_index == 0:
            print_with_prefix('Using cfg:', using_cfg)

    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
    if accelerator.process_index == 0 and not demo_sample_mode:
        print_with_prefix(f"Saving .png samples at {sample_folder_dir}")
    accelerator.wait_for_everyone()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = train_config['sample']['per_proc_batch_size']
    global_batch_size = n * accelerator.num_processes
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(train_config['sample']['fid_num'] / global_batch_size) * global_batch_size)
    if rank == 0:
        if accelerator.process_index == 0:
            print_with_prefix(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % accelerator.num_processes == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // accelerator.num_processes)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    done_iterations = int(int(num_samples // accelerator.num_processes) // n)
    pbar = range(iterations)
    if not demo_sample_mode:
        pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    if accelerator.process_index == 0:
        print_with_prefix("Using latent normalization")
    if 'sample' in train_config['data']:
        train_config['data']['data_path'] += '_sample'
    dataset = ImgLatentDataset(
        data_dir=train_config['data']['data_path'],
        latent_norm=train_config['data']['latent_norm'] if 'latent_norm' in train_config['data'] else False,
        latent_multiplier=train_config['data']['latent_multiplier'] if 'latent_multiplier' in train_config['data'] else 0.18215,
        sample=train_config['data']['sample'] if 'sample' in train_config['data'] else False,
    )
    latent_mean, latent_std = dataset.get_latent_stats()
    latent_multiplier = train_config['data']['latent_multiplier'] if 'latent_multiplier' in train_config['data'] else 0.18215
    # move to device
    latent_mean = latent_mean.clone().detach().to(device)
    latent_std = latent_std.clone().detach().to(device)

    if demo_sample_mode:
        if accelerator.process_index == 0:
            images = []
            if using_cfg:
                for label in tqdm([975, 3, 207, 387, 388, 88, 979, 279], desc="Generating Demo Samples"):
                    z = torch.randn(1, model.in_channels, latent_size, latent_size, device=device)
                    y = torch.tensor([label], device=device)
                    z = torch.cat([z, z], 0)
                    y_null = torch.tensor([1000] * 1, device=device)
                    y = torch.cat([y, y_null], 0)
                    model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=False, cfg_interval_start=cfg_interval_start)
                    model_fn = model.forward_with_cfg
                    samples = sample_fn(z, model_fn, **model_kwargs)[-1]
                    samples = (samples * latent_std) / latent_multiplier + latent_mean
                    samples = vae.decode_to_images(samples)
                    images.append(samples)
            else:
                for label in tqdm([0] * 8, desc="Generating Demo Samples"):
                    z = torch.randn(1, model.in_channels, latent_size, latent_size, device=device)
                    y = torch.tensor([label], device=device)
                    model_kwargs = dict(y=y)
                    model_fn = model.forward
                    samples = sample_fn(z, model_fn, **model_kwargs)[-1]
                    samples = (samples * latent_std) / latent_multiplier + latent_mean
                    samples = vae.decode_to_images(samples)
                    images.append(samples)

            # Combine 8 images into a 2x4 grid
            os.makedirs('demo_images', exist_ok=True)
            # Stack all images into a large numpy array
            all_images = np.stack([img[0] for img in images])  # Take first image from each batch
            # Rearrange into 2x4 grid
            h, w = all_images.shape[1:3]
            grid = np.zeros((2 * h, 4 * w, 3), dtype=np.uint8)
            for idx, image in enumerate(all_images):
                i, j = divmod(idx, 4)  # Calculate position in 2x4 grid
                grid[i*h:(i+1)*h, j*w:(j+1)*w] = image

            # Save the combined image
            exp_name = train_config['train']['exp_name']
            ckpt_iter = train_config['ckpt_path'].split("/")[-1][:-3]
            Image.fromarray(grid).save(f'demo_images/{exp_name}_cfg{cfg_scale}_{ckpt_iter}_demo_samples.png')
        return None
    else:
        for i in pbar:
            # Sample inputs:
            z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
            # Fix: the config key was checked as the typo 'trunaction' while the
            # value was read from 'truncation', so truncation either never
            # triggered or raised a KeyError. Check the correct key.
            if 'truncation' in train_config['sample']:
                truncation_bound = train_config['sample']['truncation']
                # Resample out-of-bound noise values (bounded retries).
                for _ in range(100):
                    invalid_mask = torch.abs(z) > truncation_bound
                    if not invalid_mask.any():
                        break
                    z[invalid_mask] = torch.randn_like(z[invalid_mask])
            y = torch.randint(0, train_config['data']['num_classes'], (n,), device=device)

            # Setup classifier-free guidance:
            if using_cfg:
                z = torch.cat([z, z], 0)
                y_null = torch.tensor([1000] * n, device=device)
                y = torch.cat([y, y_null], 0)
                model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=True, cfg_interval_start=cfg_interval_start)
                model_fn = model.forward_with_cfg
            else:
                model_kwargs = dict(y=y)
                model_fn = model.forward

            samples = sample_fn(z, model_fn, **model_kwargs)[-1]
            if using_cfg:
                samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

            samples = (samples * latent_std) / latent_multiplier + latent_mean
            samples = vae.decode_to_images(samples)

            # Save samples to disk as individual .png files.
            # (renamed from `i` to avoid shadowing the outer loop variable)
            for sample_idx, sample in enumerate(samples):
                index = sample_idx * accelerator.num_processes + accelerator.process_index + total
                Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            total += global_batch_size
        accelerator.wait_for_everyone()

    return sample_folder_dir
|
| 302 |
+
|
| 303 |
+
# some utils
|
| 304 |
+
def print_with_prefix(*messages):
    """Print *messages* behind a blue, timestamped ``[LightningDiT-Sampling]`` tag."""
    body = ' '.join(str(m) for m in messages)
    stamp = strftime('%Y-%m-%d %H:%M:%S')
    print(f"\033[34m[LightningDiT-Sampling {stamp}]\033[0m: {body}")
|
| 308 |
+
|
| 309 |
+
def load_config(config_path):
    """Read a YAML config file and return it as a plain Python object."""
    with open(config_path, "r") as fp:
        return yaml.safe_load(fp)
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":

    # read config
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/lightningdit_b_ldmvae_f16d16.yaml')
    parser.add_argument('--demo', action='store_true', default=False)
    args = parser.parse_args()
    accelerator = Accelerator()
    train_config = load_config(args.config)

    # get ckpt_dir
    assert 'ckpt_path' in train_config, "ckpt_path must be specified in config"
    if accelerator.process_index == 0:
        print_with_prefix('Using ckpt:', train_config['ckpt_path'])
    ckpt_dir = train_config['ckpt_path']

    # Latent grid size = image size / tokenizer downsample ratio (16 by default).
    if 'downsample_ratio' in train_config['vae']:
        latent_size = train_config['data']['image_size'] // train_config['vae']['downsample_ratio']
    else:
        latent_size = train_config['data']['image_size'] // 16

    # get model
    # Build the DiT; optional architecture flags default to False / 4 channels
    # when absent from the config.
    model = LightningDiT_models[train_config['model']['model_type']](
        input_size=latent_size,
        num_classes=train_config['data']['num_classes'],
        use_qknorm=train_config['model']['use_qknorm'],
        use_swiglu=train_config['model']['use_swiglu'] if 'use_swiglu' in train_config['model'] else False,
        use_rope=train_config['model']['use_rope'] if 'use_rope' in train_config['model'] else False,
        use_rmsnorm=train_config['model']['use_rmsnorm'] if 'use_rmsnorm' in train_config['model'] else False,
        wo_shift=train_config['model']['wo_shift'] if 'wo_shift' in train_config['model'] else False,
        in_channels=train_config['model']['in_chans'] if 'in_chans' in train_config['model'] else 4,
        learn_sigma=train_config['model']['learn_sigma'] if 'learn_sigma' in train_config['model'] else False,
        class_dropout_prob=0 if train_config['data']['num_classes'] == 1 else 0.1,
    )

    # naive sample
    sample_folder_dir = do_sample(train_config, accelerator, ckpt_path=ckpt_dir, model=model, demo_sample_mode=args.demo)

    if not args.demo:
        # calculate FID
        # Important: FID is only for reference, please use ADM evaluation for paper reporting
        if accelerator.process_index == 0:
            from tools.calculate_fid import calculate_fid_given_paths
            print_with_prefix('Calculating FID with {} number of samples'.format(train_config['sample']['fid_num']))
            assert 'fid_reference_file' in train_config['data'], "fid_reference_file must be specified in config"
            fid_reference_file = train_config['data']['fid_reference_file']
            fid = calculate_fid_given_paths(
                [fid_reference_file, sample_folder_dir],
                batch_size=50,
                dims=2048,
                device='cuda',
                num_workers=8,
                sp_len = train_config['sample']['fid_num']
            )
            print_with_prefix('fid=',fid)
|
LDMAE/models/__init__.py
ADDED
|
File without changes
|
LDMAE/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
LDMAE/models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
LDMAE/models/__pycache__/lightningdit.cpython-310.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
LDMAE/models/__pycache__/lightningdit.cpython-38.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
LDMAE/models/__pycache__/pos_embed.cpython-310.pyc
ADDED
|
Binary file (4.77 kB). View file
|
|
|
LDMAE/models/__pycache__/pos_embed.cpython-38.pyc
ADDED
|
Binary file (4.76 kB). View file
|
|
|
LDMAE/models/__pycache__/rmsnorm.cpython-310.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
LDMAE/models/__pycache__/rmsnorm.cpython-38.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
LDMAE/models/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
|
Binary file (2.16 kB). View file
|
|
|
LDMAE/models/__pycache__/swiglu_ffn.cpython-38.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
LDMAE/models/lightningdit.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lightning DiT's codes are built from original DiT & SiT.
|
| 3 |
+
(https://github.com/facebookresearch/DiT; https://github.com/willisma/SiT)
|
| 4 |
+
It demonstrates that a advanced DiT together with advanced diffusion skills
|
| 5 |
+
could also achieve a very promising result with 1.35 FID on ImageNet 256 generation.
|
| 6 |
+
|
| 7 |
+
Enjoy everyone, DiT strikes back!
|
| 8 |
+
|
| 9 |
+
by Maple (Jingfeng Yao) from HUST-VL
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import math
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.utils.checkpoint import checkpoint
|
| 20 |
+
|
| 21 |
+
from timm.models.vision_transformer import PatchEmbed, Mlp
|
| 22 |
+
from models.swiglu_ffn import SwiGLUFFN
|
| 23 |
+
from models.pos_embed import VisionRotaryEmbeddingFast
|
| 24 |
+
from models.rmsnorm import RMSNorm
|
| 25 |
+
|
| 26 |
+
@torch.compile
def modulate(x, shift, scale):
    # AdaLN modulation: rescale (and optionally shift) per-sample.
    # `shift`/`scale` are (N, D); unsqueeze(1) broadcasts across the token dim of x (N, T, D).
    if shift is None:
        # Shift-free AdaLN variant (`wo_shift=True` in the blocks): scale-only.
        return x * (1 + scale.unsqueeze(1))
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 31 |
+
|
| 32 |
+
class Attention(nn.Module):
    """Multi-head self-attention for LightningDiT blocks.

    Supports optional QK-normalization (LayerNorm or RMSNorm), rotary
    position embeddings applied to q/k, and either the fused SDPA kernel
    or a manual softmax attention path.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
        fused_attn: bool = True,
        use_rmsnorm: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn

        # RMSNorm overrides whatever norm_layer was passed in.
        # (Ternary keeps the RMSNorm lookup lazy, as in the original.)
        norm_layer = RMSNorm if use_rmsnorm else norm_layer

        # Submodules are created in the same order as the reference
        # implementation so parameter initialization is reproducible.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, rope=None) -> torch.Tensor:
        """Apply self-attention over x (B, N, C); `rope` optionally rotates q/k."""
        batch, tokens, channels = x.shape
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            q, k = rope(q), rope(k)

        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to `hidden_size`-dim vectors.

    A fixed sinusoidal frequency embedding followed by a small two-layer
    SiLU MLP, exactly as in the original DiT.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256) -> None:
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Create sinusoidal timestep embeddings.

        Args:
            t: 1-D tensor of N (possibly fractional) timestep indices.
            dim: output embedding dimension.
            max_period: controls the minimum embedded frequency.

        Returns:
            An (N, dim) tensor of positional embeddings (cos halves first).
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)

        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

        if dim % 2:
            # Odd dim: pad one zero column so the width matches `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)

        return embedding

    @torch.compile
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class LabelEmbedder(nn.Module):
    """Embeds class labels, with label dropout for classifier-free guidance.

    When `dropout_prob > 0` an extra "null" row (index `num_classes`) is
    appended to the embedding table, and dropped labels are mapped to it.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace a subset of labels with the null-class index.

        Positions where `force_drop_ids == 1` are dropped when it is given;
        otherwise each label is dropped independently with `dropout_prob`.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels)

    @torch.compile
    def forward(self, labels, train, force_drop_ids=None):
        should_drop = (train and self.dropout_prob > 0) or (force_drop_ids is not None)
        if should_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 170 |
+
|
| 171 |
+
class LightningDiTBlock(nn.Module):
    """A LightningDiT transformer block: AdaLN-modulated attention + MLP.

    Optional features (see the LightningDiT paper; not all are enabled in
    the final model): rotary position embeddings, QK-norm, RMSNorm,
    SwiGLU MLP, and a shift-free AdaLN variant.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_qknorm=False,
        use_swiglu=False,
        use_rmsnorm=False,
        wo_shift=False,
        **block_kwargs
    ):
        super().__init__()

        # Pre-attention / pre-MLP norms: affine-free LayerNorm or RMSNorm.
        if use_rmsnorm:
            self.norm1 = RMSNorm(hidden_size)
            self.norm2 = RMSNorm(hidden_size)
        else:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=use_qknorm,
            use_rmsnorm=use_rmsnorm,
            **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if use_swiglu:
            # Local SwiGLU implementation: the xformers one is not
            # compatible with torch.compile at the moment.
            self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_hidden_dim))
        else:
            self.mlp = Mlp(
                in_features=hidden_size,
                hidden_features=mlp_hidden_dim,
                act_layer=lambda: nn.GELU(approximate="tanh"),
                drop=0,
            )

        # AdaLN emits (shift, scale, gate) per branch, or just (scale, gate)
        # in the shift-free variant: 6 vs 4 chunks of hidden_size.
        n_mod = 4 if wo_shift else 6
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, n_mod * hidden_size, bias=True),
        )
        self.wo_shift = wo_shift

    @torch.compile
    def forward(self, x, c, feat_rope=None):
        mod = self.adaLN_modulation(c)
        if self.wo_shift:
            scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=1)
            shift_msa = shift_mlp = None
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)

        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), rope=feat_rope)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
|
| 251 |
+
|
| 252 |
+
class FinalLayer(nn.Module):
    """Output head of LightningDiT.

    AdaLN-modulated norm followed by a linear projection back to
    patch_size**2 * out_channels values per token.
    """

    def __init__(self, hidden_size, patch_size, out_channels, use_rmsnorm=False):
        super().__init__()
        self.norm_final = (
            RMSNorm(hidden_size)
            if use_rmsnorm
            else nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        )
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    @torch.compile
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class LightningDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Patchifies a (latent) image, conditions on timestep + class label via
    AdaLN, runs `depth` LightningDiTBlocks, and projects back to patches.
    """
    def __init__(
        self,
        input_size=32,           # spatial size of the (latent) input
        patch_size=2,            # patch size used for tokenization
        in_channels=32,          # channels of the latent input
        hidden_size=1152,
        depth=28,                # number of transformer blocks
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,  # label dropout for classifier-free guidance
        num_classes=1000,
        learn_sigma=False,       # if True, head also predicts per-channel sigma
        use_qknorm=False,
        use_swiglu=False,
        use_rope=False,
        use_rmsnorm=False,
        wo_shift=False,
        use_checkpoint=False,    # gradient checkpointing of the blocks
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learn_sigma, the head emits (mean, sigma): twice the channels.
        self.out_channels = in_channels if not learn_sigma else in_channels * 2
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.use_rope = use_rope
        self.use_rmsnorm = use_rmsnorm
        self.depth = depth
        self.hidden_size = hidden_size
        self.use_checkpoint = use_checkpoint
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding (frozen; filled in initialize_weights):
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # use rotary position encoding, borrow from EVA
        if self.use_rope:
            half_head_dim = hidden_size // num_heads // 2
            hw_seq_len = input_size // patch_size
            self.feat_rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=hw_seq_len,
            )
        else:
            self.feat_rope = None

        self.blocks = nn.ModuleList([
            LightningDiTBlock(hidden_size,
                              num_heads,
                              mlp_ratio=mlp_ratio,
                              use_qknorm=use_qknorm,
                              use_swiglu=use_swiglu,
                              use_rmsnorm=use_rmsnorm,
                              wo_shift=wo_shift,
                              ) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, use_rmsnorm=use_rmsnorm)
        self.initialize_weights()

    def initialize_weights(self):
        """Apply DiT's init scheme: xavier linears, zeroed AdaLN and output head."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in LightningDiT blocks
        # (each residual branch starts as an identity mapping):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        Reassemble token predictions into an image-shaped tensor.
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)  (docstring above said NHWC; the einsum below
        actually produces NCHW — note for reviewers.)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)  # token grid is assumed square
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # h == w is asserted above, so `h * p` serves for both spatial dims.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t=None, y=None):
        """
        Forward pass of LightningDiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        use_checkpoint: boolean to toggle checkpointing

        NOTE(review): despite the None defaults, t and y are effectively
        required — both are passed straight into their embedders and would
        fail if omitted. Confirm whether the defaults can be removed.
        """

        use_checkpoint = self.use_checkpoint

        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        c = t + y  # (N, D) combined conditioning vector

        for block in self.blocks:
            if use_checkpoint:
                x = checkpoint(block, x, c, self.feat_rope, use_reentrant=True)
            else:
                x = block(x, c, self.feat_rope)

        x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        if self.learn_sigma:
            # Drop the predicted sigma channels; only the mean is returned.
            x, _ = x.chunk(2, dim=1)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale, cfg_interval=None, cfg_interval_start=None):
        """
        Forward pass of LightningDiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        # NOTE(review): with in_channels=32 latents this guides only the first
        # 3 channels — this convention is inherited from pixel-space DiT;
        # confirm it is intended for latent-space models.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

        # Optionally disable guidance below a timestep threshold.
        if cfg_interval is True:
            timestep = t[0]
            if timestep < cfg_interval_start:
                half_eps = cond_eps

        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
| 443 |
+
|
| 444 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2-D sin-cos positional embedding.

    Args:
        embed_dim: embedding dimension (must be even).
        grid_size: grid height and width.
        cls_token: whether the model prepends extra tokens.
        extra_tokens: number of leading tokens to pad with zeros.

    Returns:
        [grid_size*grid_size, embed_dim] array, or
        [extra_tokens + grid_size*grid_size, embed_dim] when padded.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(axis, axis), axis=0)  # here w goes first
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D embeddings of the grid's h and w coordinates."""
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """1-D sin-cos embedding over log-spaced frequencies.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions, any shape; flattened to (M,).

    Returns:
        (M, embed_dim) array: sin halves first, then cos halves.
    """
    assert embed_dim % 2 == 0
    omega = 1. / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.))  # (D/2,)
    angles = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
#################################################################################
#                                 LightningDiT Configs                          #
#################################################################################
# Factory naming: LightningDiT_<size>_<patch_size>.
# "1p0B" / "1p6B" denote roughly 1.0B- / 1.6B-parameter variants.

def LightningDiT_XL_1(**kwargs):
    return LightningDiT(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)

def LightningDiT_XL_2(**kwargs):
    return LightningDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def LightningDiT_L_2(**kwargs):
    return LightningDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def LightningDiT_B_1(**kwargs):
    return LightningDiT(depth=12, hidden_size=768, patch_size=1, num_heads=12, **kwargs)

def LightningDiT_B_2(**kwargs):
    return LightningDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def LightningDiT_1p0B_1(**kwargs):
    return LightningDiT(depth=24, hidden_size=1536, patch_size=1, num_heads=24, **kwargs)

def LightningDiT_1p0B_2(**kwargs):
    return LightningDiT(depth=24, hidden_size=1536, patch_size=2, num_heads=24, **kwargs)

def LightningDiT_1p6B_1(**kwargs):
    return LightningDiT(depth=28, hidden_size=1792, patch_size=1, num_heads=28, **kwargs)

def LightningDiT_1p6B_2(**kwargs):
    return LightningDiT(depth=28, hidden_size=1792, patch_size=2, num_heads=28, **kwargs)

# Registry keyed by the "<name>/<patch>" strings referenced from the
# `model_type` field of the YAML configs.
LightningDiT_models = {
    'LightningDiT-B/1': LightningDiT_B_1, 'LightningDiT-B/2': LightningDiT_B_2,
    'LightningDiT-L/2': LightningDiT_L_2,
    'LightningDiT-XL/1': LightningDiT_XL_1, 'LightningDiT-XL/2': LightningDiT_XL_2,
    'LightningDiT-1p0B/1': LightningDiT_1p0B_1, 'LightningDiT-1p0B/2': LightningDiT_1p0B_2,
    'LightningDiT-1p6B/1': LightningDiT_1p6B_1, 'LightningDiT-1p6B/2': LightningDiT_1p6B_2,
}
|
LDMAE/models/lpips.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import hashlib
|
| 3 |
+
import requests
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torchvision import models
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
import os
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
# Download location of the pretrained LPIPS (VGG16) "lin" weights.
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

# Local filename used to cache the checkpoint.
CKPT_MAP = {"vgg_lpips": "vgg.pth"}

# Expected MD5 of the checkpoint, used to validate the download.
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def download(url, local_path, chunk_size=1024):
    """Stream `url` to `local_path` with a progress bar.

    Parent directories are created as needed; the download is streamed in
    `chunk_size`-byte pieces.
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar, \
                open(local_path, "wb") as f:
            for data in r.iter_content(chunk_size=chunk_size):
                if data:
                    f.write(data)
                    pbar.update(chunk_size)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def md5_hash(path):
    """Return the hex MD5 digest of the file at `path`."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_ckpt_path(name, root, check=False):
    """Return the local checkpoint path for `name`, downloading if needed.

    Args:
        name: key into URL_MAP / CKPT_MAP / MD5_MAP.
        root: directory in which the checkpoint is cached.
        check: if True, also re-download when the cached file's MD5 mismatches.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    missing = not os.path.exists(path)
    stale = check and md5_hash(path) != MD5_MAP[name] if not missing else False
    if missing or stale:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) metric.

    Compares two images via channel-normalized VGG16 feature differences,
    weighted by pretrained 1x1 "lin" layers. All parameters are frozen —
    this module is a fixed metric, not a trainable network.
    """

    # Cache directory for the pretrained LPIPS "lin" weights; shared by
    # load_from_pretrained and from_pretrained.
    _CKPT_ROOT = "movqgan/modules/losses/lpips"

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature widths per stage
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # Freeze everything: LPIPS is used only for evaluation/loss.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the pretrained lin-layer weights into this instance."""
        ckpt = get_ckpt_path(name, self._CKPT_ROOT)
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Construct an LPIPS instance with pretrained weights.

        Only the "vgg_lpips" variant is supported.
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # Bug fix: get_ckpt_path requires a cache-root argument; the original
        # call omitted it and always raised a TypeError.
        ckpt = get_ckpt_path(name, cls._CKPT_ROOT)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        """Return the LPIPS distance between `input` and `target` (NCHW).

        Both inputs pass through the scaling layer, then VGG16; per-stage
        normalized feature differences are squared, projected by the lin
        layers, spatially averaged, and summed across stages.
        """
        in0_input, in1_input = self.scaling_layer(input), self.scaling_layer(target)
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        val = 0
        for kk in range(len(self.chns)):
            feat0 = normalize_tensor(outs0[kk])
            feat1 = normalize_tensor(outs1[kk])
            diff = (feat0 - feat1) ** 2
            val = val + spatial_average(lins[kk].model(diff), keepdim=True)
        return val
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ScalingLayer(nn.Module):
    """Remap ImageNet-normalized inputs to the range the LPIPS VGG expects."""

    def __init__(self):
        super().__init__()
        # fixed per-channel statistics, shaped (1, 3, 1, 1) for broadcasting
        shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        # convert imagenet normalized data to [-1, 1]
        return (inp - self.shift) / self.scale
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        # bias-free 1x1 convolution: a learned per-channel weighting
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class vgg16(torch.nn.Module):
    """Frozen VGG16 feature extractor returning five intermediate activations."""

    # feature-module index ranges producing relu1_2 ... relu5_3
    _STAGES = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        features = models.vgg16(pretrained=pretrained).features
        self.N_slices = 5
        # slice1..slice5, each a Sequential over a contiguous run of VGG layers;
        # submodule names keep the original torchvision indices so pretrained
        # state_dict keys still line up
        for idx, (start, stop) in enumerate(self._STAGES, start=1):
            stage = torch.nn.Sequential()
            for j in range(start, stop):
                stage.add_module(str(j), features[j])
            setattr(self, "slice%d" % idx, stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run X through the five stages and return all tap activations."""
        taps = []
        h = X
        for idx in range(1, self.N_slices + 1):
            h = getattr(self, "slice%d" % idx)(h)
            taps.append(h)
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        return vgg_outputs(*taps)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def normalize_tensor(x, eps=1e-10):
    """L2-normalize x along the channel dimension (dim=1).

    eps guards against division by zero for all-zero channels.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def spatial_average(x, keepdim=True):
    """Average an NCHW tensor over its spatial dimensions (H, W)."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
|
LDMAE/models/pos_embed.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# EVA-02: A Visual Representation for Neon Genesis
|
| 3 |
+
# Github source: https://github.com/baaivision/EVA/EVA02
|
| 4 |
+
# Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# By Yuxin Fang
|
| 7 |
+
#
|
| 8 |
+
# Based on https://github.com/lucidrains/rotary-embedding-torch
|
| 9 |
+
# --------------------------------------------------------'
|
| 10 |
+
|
| 11 |
+
from math import pi
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def broadcat(tensors, dim=-1):
    """Concatenate `tensors` along `dim` after broadcasting all other axes.

    Every non-concat axis must contain at most two distinct sizes (so one of
    them can be 1 and expand to the other); the concat axis keeps each
    tensor's own length.
    """
    ranks = set(t.dim() for t in tensors)
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    dim = dim + rank if dim < 0 else dim
    # sizes of each axis collected across all tensors
    axis_sizes = list(zip(*(t.shape for t in tensors)))
    other_axes = [(i, sizes) for i, sizes in enumerate(axis_sizes) if i != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in other_axes), 'invalid dimensions for broadcastable concatentation'
    target = [max(sizes) for sizes in axis_sizes]
    expanded = []
    for t in tensors:
        shape = list(target)
        shape[dim] = t.shape[dim]  # the concat axis keeps its own length
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def rotate_half(x):
    """Rotate adjacent feature pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).

    This is the pairwise 90-degree rotation used by rotary embeddings; the
    last dimension must have even length.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.reshape(x.shape)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding applied to a slice of the feature dim.

    Precomputes cos/sin tables for an (ft_seq_len, ft_seq_len) token grid,
    with positions rescaled so a fine-tune grid matches the pretrain grid.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # choose the base frequency spectrum
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # positions rescaled onto the pretraining coordinate range
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # position x frequency angles, each duplicated for the (cos, sin)
        # pair layout rotate_half expects
        freqs_h = torch.repeat_interleave(
            torch.einsum('..., f -> ... f', t, freqs), 2, dim=-1
        )
        freqs_w = freqs_h.clone()

        # combine row and column angles over the 2D grid
        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

    def forward(self, t, start_index=0):
        """Rotate t[..., start_index:start_index+rot_dim]; pass the rest through."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        left = t[..., :start_index]
        mid = t[..., start_index:end_index]
        right = t[..., end_index:]
        mid = (mid * self.freqs_cos) + (rotate_half(mid) * self.freqs_sin)
        return torch.cat((left, mid, right), dim=-1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class VisionRotaryEmbeddingFast(nn.Module):
    """Rotary embedding with the 2D token grid pre-flattened for fast apply."""

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # choose the base frequency spectrum
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # positions rescaled onto the pretraining coordinate range
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # per-position angles, duplicated for the (cos, sin) pair layout,
        # combined over rows and columns of the 2D grid
        angles = torch.repeat_interleave(
            torch.einsum('..., f -> ... f', t, freqs), 2, dim=-1
        )
        angles = broadcat((angles[:, None, :], angles[None, :, :]), dim=-1)

        # flatten the (h, w) grid into a single token axis
        self.register_buffer("freqs_cos", angles.cos().view(-1, angles.shape[-1]))
        self.register_buffer("freqs_sin", angles.sin().view(-1, angles.shape[-1]))

    def forward(self, t):
        """Rotate every feature of t by its precomputed per-token angle."""
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
LDMAE/models/rmsnorm.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import fairscale.nn.model_parallel.initialize as fs_init
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairscale.nn.model_parallel.layers import (
|
| 12 |
+
ColumnParallelLinear,
|
| 13 |
+
ParallelEmbedding,
|
| 14 |
+
RowParallelLinear,
|
| 15 |
+
)
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class ModelArgs:
    """Transformer hyperparameters consumed by Attention/FeedForward/Transformer."""
    dim: int = 4096  # model (embedding) width
    n_layers: int = 32
    n_heads: int = 32  # query heads
    n_kv_heads: Optional[int] = None  # grouped-query attention; None means n_heads
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional scaling of the FFN width
    norm_eps: float = 1e-5  # RMSNorm epsilon

    # bounds used to pre-allocate the attention KV caches
    max_batch_size: int = 32
    max_seq_len: int = 2048
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes over the last dimension by the RMS of the activations and
    applies a learnable per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim (int): Size of the last (normalized) dimension.
            eps (float, optional): Stability term added under the rsqrt.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return x / sqrt(mean(x^2) + eps), computed via rsqrt."""
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        """Normalize x in float32 for stability, cast back, apply the gain."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary frequencies e^(i * position * freq).

    Args:
        dim (int): Head dimension; dim // 2 distinct frequencies are produced.
        end (int): Number of positions to precompute.
        theta (float, optional): Base of the geometric frequency spectrum.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: Complex64 tensor of shape (end, dim // 2), where entry
        (t, k) is the unit complex number at angle t * theta^(-2k/dim).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freqs.device)  # type: ignore
    angles = torch.outer(positions, inv_freqs).float()  # type: ignore
    # unit-magnitude complex numbers at each angle (complex64)
    return torch.polar(torch.ones_like(angles), angles)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape a (seqlen, head_dim) frequency tensor to broadcast against x.

    The result has size 1 on every axis of x except the sequence axis (1)
    and the last axis, so element-wise multiplication broadcasts over batch
    and heads.

    Args:
        freqs_cis (torch.Tensor): Tensor of shape (x.shape[1], x.shape[-1]).
        x (torch.Tensor): Target tensor supplying the broadcast layout.

    Returns:
        torch.Tensor: View of freqs_cis with singleton broadcast axes.

    Raises:
        AssertionError: If freqs_cis does not match (x.shape[1], x.shape[-1])
            or x has fewer than 2 dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [1] * ndim
    shape[1] = x.shape[1]
    shape[-1] = x.shape[-1]
    return freqs_cis.view(*shape)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key features by the precomputed complex frequencies.

    The last dimension of each tensor is viewed as adjacent (real, imag)
    pairs, multiplied by `freqs_cis` (broadcast over batch and heads), and
    flattened back to real values.

    Args:
        xq (torch.Tensor): Query tensor, last dim of even length.
        xk (torch.Tensor): Key tensor, last dim of even length.
        freqs_cis (torch.Tensor): Output of precompute_freqs_cis for this
            position range.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated (xq, xk), cast back to
        the input dtypes.
    """
    def to_complex(t: torch.Tensor) -> torch.Tensor:
        # pair up the last dim and reinterpret as complex numbers
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = to_complex(xq)
    k_c = to_complex(xk)
    # reshape once against the query layout; keys broadcast the same way
    freqs_cis = reshape_for_broadcast(freqs_cis, q_c)
    q_rot = torch.view_as_real(q_c * freqs_cis).flatten(3)
    k_rot = torch.view_as_real(k_c * freqs_cis).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    # unpack to validate the expected (bs, seqlen, n_kv_heads, head_dim) layout
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # each KV head is duplicated n_rep times consecutively along the head axis
    return torch.repeat_interleave(x, repeats=n_rep, dim=2)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        # grouped-query attention: fall back to one KV head per query head
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        # heads are sharded evenly across model-parallel ranks
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # how many query heads share each KV head
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # NOTE(review): the KV caches are plain CUDA tensors assigned as
        # attributes (not registered buffers), so they are excluded from
        # state_dict; construction assumes a GPU is available.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # rotary position embedding applied to queries and keys only
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # keep caches on the same device/dtype as the activations
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # write the new positions into the cache, then read back the full prefix
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # softmax in float32 for numerical stability, then cast back
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)) with parallel linears."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Args:
            dim (int): Input/output dimension.
            hidden_dim (int): Nominal hidden dimension before SwiGLU scaling.
            multiple_of (int): The effective hidden size is rounded up to a
                multiple of this value.
            ffn_dim_multiplier (float, optional): Extra scaling factor applied
                to the hidden size. Defaults to None (no scaling).
        """
        super().__init__()
        # SwiGLU uses 2/3 of the nominal hidden size
        width = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # round up to the next multiple of `multiple_of`
        width = multiple_of * ((width + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, width, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            width, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, width, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        """Apply the gated SwiGLU transformation."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and SwiGLU FFN, each with a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        # separate pre-norms for the attention and FFN sub-layers
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Run one transformer layer.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Attention mask. Defaults to None.

        Returns:
            torch.Tensor: Layer output with both residual connections applied.
        """
        attn_out = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        h = x + attn_out
        return h + self.feed_forward(self.ffn_norm(h))
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        # plain tensor attribute (not a buffer): moved to the right device
        # lazily in forward
        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    # inference-only: no autograd state is recorded inside this forward
    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        # slice the frequencies for exactly the positions being decoded
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # single-token decoding (seqlen == 1) needs no causal mask
        mask = None
        if seqlen > 1:
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        # project to vocabulary logits in float32
        output = self.output(h).float()
        return output
|
LDMAE/models/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
from typing import Callable, Optional
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SwiGLU gate.

    A single linear layer (``w12``) produces both the gate and the value
    halves in one matmul; the output is ``w3(silu(gate) * value)``.
    ``act_layer`` and ``drop`` are accepted for interface compatibility
    but are unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Fall back to the input width when a size is omitted (or falsy).
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Gate and value projections fused into one 2x-wide linear layer.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    @torch.compile
    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Allow opting out of xFormers via the XFORMERS_DISABLED env var; fall back
# to the pure-PyTorch SwiGLUFFN when xFormers is disabled or not installed.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if not XFORMERS_ENABLED:
        raise ImportError
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is rescaled by 2/3 and rounded up to a
    multiple of 8 before delegating to ``SwiGLU`` (xFormers op when available,
    otherwise the pure-PyTorch fallback). ``act_layer`` and ``drop`` are
    accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # 2/3 rescaling — presumably to keep the parameter count comparable
        # to a plain 2-matmul MLP (SwiGLU uses three projections) — then
        # round up to the next multiple of 8 for alignment.
        scaled = int(hidden_features * 2 / 3)
        hidden_features = -(-scaled // 8) * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
LDMAE/pretrain_weight/aef8d16.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97ad3672641653fcd74106cd050dc8f5042089b8edc06e30cbcde642be239aa6
|
| 3 |
+
size 1006144522
|
LDMAE/pretrain_weight/daef8d16.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbf39522a2f602df6f271b8b0ea0a73a7c23687fd54f98c5b60ef85289b15168
|
| 3 |
+
size 1006144522
|
LDMAE/pretrain_weight/sdv3f8d16.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b367d60d708cb371261c005a44bd68f8d17dd211f8c771fb2b3802e51df2f8c
|
| 3 |
+
size 1098238157
|
LDMAE/pretrain_weight/vaef8d16.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f0e12f87137cdb19bd8f461dc1c8d7c572628d79e03f070ebdcc7a802e610c6
|
| 3 |
+
size 1006144522
|
LDMAE/pretrain_weight/vmaef8d16.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:441e7360993a03978e729dafc77432a372f066bf46a3c9610c7c33c4a0f09fc1
|
| 3 |
+
size 147225897
|
LDMAE/requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision
|
| 3 |
+
accelerate
|
| 4 |
+
transformers
|
| 5 |
+
Pillow
|
| 6 |
+
numpy
|
| 7 |
+
scipy
|
| 8 |
+
tqdm
|
| 9 |
+
matplotlib
|
| 10 |
+
tensorboard
|
| 11 |
+
omegaconf
|
| 12 |
+
einops
|
| 13 |
+
timm
|
| 14 |
+
opencv-python
|
| 15 |
+
scikit-learn
|
| 16 |
+
lpips
|
LDMAE/run_extract_feature.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Distributed latent-feature extraction via Hugging Face Accelerate.
# Usage: bash run_extract_feature.sh <path/to/config.yaml>
CONFIG_PATH=$1

# Distributed topology; every value can be overridden from the environment.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
# NOTE(review): env WORLD_SIZE is read here as the *node* count, then
# redefined below as the total process count — confirm the cluster launcher
# exports WORLD_SIZE with that meaning.
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1235}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

echo $CONFIG_PATH

# One process per GPU across all nodes; the config path is forwarded to
# extract_features.py.
accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $(($GPUS_PER_NODE*$NNODES)) \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    extract_features.py \
    --config $CONFIG_PATH \
|
LDMAE/run_fast_inference.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Single-GPU "fast" inference via Hugging Face Accelerate.
# Usage: bash run_fast_inference.sh <path/to/config.yaml>
CONFIG_PATH=$1

# Fixed single-node, single-GPU topology (no accelerator config file used).
GPUS_PER_NODE=1
NNODES=1
NODE_RANK=0
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1236}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

# --demo presumably switches inference.py into a lightweight demo mode —
# confirm against inference.py's argument parser.
accelerate launch \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $(($GPUS_PER_NODE*$NNODES)) \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    inference.py \
    --config $CONFIG_PATH \
    --demo
|
LDMAE/run_inference.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-GPU inference via Hugging Face Accelerate.
# Usage: bash run_inference.sh <path/to/config.yaml>
CONFIG_PATH=$1

# Distributed topology; every value can be overridden from the environment.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
# NOTE(review): env WORLD_SIZE is read as the node count, then redefined
# below as the total process count — confirm launcher semantics.
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1237}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

# One process per GPU across all nodes; inference.py receives the config path.
accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $(($GPUS_PER_NODE*$NNODES)) \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    inference.py \
    --config $CONFIG_PATH
|
LDMAE/run_robustness_test.sh
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Robustness evaluation for the VMAE tokenizer: one clean reconstruction run,
# then reconstructions under increasing input perturbation strength (epsilon).
# Every run launches evaluate_tokenizer.py through Hugging Face Accelerate
# with the same distributed settings.

# Distributed topology; every value can be overridden from the environment.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
# NOTE(review): env WORLD_SIZE is read as the node count, then redefined
# below as the total process count — confirm launcher semantics.
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1241}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

CONFIG=configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml

# Launch evaluate_tokenizer.py with the shared distributed settings,
# forwarding any extra arguments (e.g. --epsilon 0.1).
run_eval() {
    accelerate launch \
        --config-file configs/accelerator/8gpu.yaml \
        --main_process_ip $MASTER_ADDR \
        --main_process_port $MASTER_PORT \
        --machine_rank $NODE_RANK \
        --num_processes $(($GPUS_PER_NODE*$NNODES)) \
        --num_machines $NNODES \
        --mixed_precision $PRECISION \
        evaluate_tokenizer.py \
        --config $CONFIG \
        "$@"
}

# Clean VMAE reconstruction baseline.
run_eval --robust_exp True

# Perturbed reconstructions at increasing epsilon.
# BUG FIX: the epsilon=0.2 run previously invoked evaluate_tokenizer_mae.py,
# which does not exist in this repository; all runs now use
# evaluate_tokenizer.py like the other epsilon values.
for EPS in 0.01 0.05 0.1 0.2 0.3; do
    run_eval --epsilon $EPS
done
|
LDMAE/run_train.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Distributed training via Hugging Face Accelerate (the script name
# train_accum.py suggests gradient accumulation — confirm in that file).
# Usage: bash run_train.sh <path/to/config.yaml>
CONFIG_PATH=$1

# Distributed topology; every value can be overridden from the environment.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
# NOTE(review): env WORLD_SIZE is read as the node count, then redefined
# below as the total process count — confirm launcher semantics.
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1235}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

echo $CONFIG_PATH

# One process per GPU across all nodes; the config path is forwarded to
# train_accum.py.
accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $(($GPUS_PER_NODE*$NNODES)) \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    train_accum.py \
    --config $CONFIG_PATH
|
LDMAE/tokenizer/__init__.py
ADDED
|
File without changes
|
LDMAE/tokenizer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (174 Bytes). View file
|
|
|
LDMAE/tokenizer/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
LDMAE/tokenizer/__pycache__/autoencoder.cpython-310.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
LDMAE/tokenizer/__pycache__/models_mae.cpython-310.pyc
ADDED
|
Binary file (28.8 kB). View file
|
|
|
LDMAE/tokenizer/__pycache__/sdvae.cpython-310.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
LDMAE/tokenizer/__pycache__/vavae.cpython-310.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
LDMAE/tokenizer/__pycache__/vavae.cpython-38.pyc
ADDED
|
Binary file (4.35 kB). View file
|
|
|