isno0907 committed on
Commit
6c49103
·
verified ·
1 Parent(s): fdd9e4e

Upload 115 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. LDMAE/.DS_Store +0 -0
  3. LDMAE/configs/accelerator/4gpu.yaml +17 -0
  4. LDMAE/configs/accelerator/8gpu.yaml +17 -0
  5. LDMAE/configs/celeba_hq/lightningdit_b_vmae_f8d16_cfg.yaml +82 -0
  6. LDMAE/configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml +80 -0
  7. LDMAE/datasets/__init__.py +0 -0
  8. LDMAE/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  9. LDMAE/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  10. LDMAE/datasets/__pycache__/img_latent_dataset.cpython-310.pyc +0 -0
  11. LDMAE/datasets/__pycache__/img_latent_dataset.cpython-38.pyc +0 -0
  12. LDMAE/datasets/img_latent_dataset.py +94 -0
  13. LDMAE/evaluate_tokenizer.py +262 -0
  14. LDMAE/extract_features.py +235 -0
  15. LDMAE/inference.py +368 -0
  16. LDMAE/models/__init__.py +0 -0
  17. LDMAE/models/__pycache__/__init__.cpython-310.pyc +0 -0
  18. LDMAE/models/__pycache__/__init__.cpython-38.pyc +0 -0
  19. LDMAE/models/__pycache__/lightningdit.cpython-310.pyc +0 -0
  20. LDMAE/models/__pycache__/lightningdit.cpython-38.pyc +0 -0
  21. LDMAE/models/__pycache__/pos_embed.cpython-310.pyc +0 -0
  22. LDMAE/models/__pycache__/pos_embed.cpython-38.pyc +0 -0
  23. LDMAE/models/__pycache__/rmsnorm.cpython-310.pyc +0 -0
  24. LDMAE/models/__pycache__/rmsnorm.cpython-38.pyc +0 -0
  25. LDMAE/models/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  26. LDMAE/models/__pycache__/swiglu_ffn.cpython-38.pyc +0 -0
  27. LDMAE/models/lightningdit.py +531 -0
  28. LDMAE/models/lpips.py +184 -0
  29. LDMAE/models/pos_embed.py +135 -0
  30. LDMAE/models/rmsnorm.py +495 -0
  31. LDMAE/models/swiglu_ffn.py +74 -0
  32. LDMAE/pretrain_weight/aef8d16.pth +3 -0
  33. LDMAE/pretrain_weight/daef8d16.pth +3 -0
  34. LDMAE/pretrain_weight/sdv3f8d16.pth +3 -0
  35. LDMAE/pretrain_weight/vaef8d16.pth +3 -0
  36. LDMAE/pretrain_weight/vmaef8d16.pth +3 -0
  37. LDMAE/requirements.txt +16 -0
  38. LDMAE/run_extract_feature.sh +22 -0
  39. LDMAE/run_fast_inference.sh +20 -0
  40. LDMAE/run_inference.sh +20 -0
  41. LDMAE/run_robustness_test.sh +81 -0
  42. LDMAE/run_train.sh +22 -0
  43. LDMAE/tokenizer/__init__.py +0 -0
  44. LDMAE/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
  45. LDMAE/tokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
  46. LDMAE/tokenizer/__pycache__/autoencoder.cpython-310.pyc +0 -0
  47. LDMAE/tokenizer/__pycache__/models_mae.cpython-310.pyc +0 -0
  48. LDMAE/tokenizer/__pycache__/sdvae.cpython-310.pyc +0 -0
  49. LDMAE/tokenizer/__pycache__/vavae.cpython-310.pyc +0 -0
  50. LDMAE/tokenizer/__pycache__/vavae.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figure/thumbnail.png filter=lfs diff=lfs merge=lfs -text
LDMAE/.DS_Store ADDED
Binary file (6.15 kB). View file
 
LDMAE/configs/accelerator/4gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 4
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
LDMAE/configs/accelerator/8gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 8
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
LDMAE/configs/celeba_hq/lightningdit_b_vmae_f8d16_cfg.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # we recommend to read config_details.yaml first.
2
+
3
+ ckpt_path: 'output/celeba_hq/lightningdit_b_vmae_f8d16/checkpoints/0060000.pt' # <---- download our pre-trained lightningdit or your own checkpoint
4
+
5
+ data:
6
+ name: celebahq
7
+ origin_path: "/data/dataset/celeba_hq/celeba_hq_256"
8
+ data_path: '/data/dataset/celeba_hq/vmae_feature_celebahq_train_256' # <---- path to your data. it is generated by extract_features.py.
9
+ # if you just want to inference, download our latent_stats.pt and give its path here is ok.
10
+ fid_reference_file: 'tools/fid_statistics/ALL_celebahq256.npz' # <---- path to your fid_reference_file.npz. download it from ADM
11
+
12
+
13
+ # recommend to use default settings
14
+ image_size: 256
15
+ num_classes: 1
16
+ num_workers: 8
17
+ latent_norm: true
18
+ latent_multiplier: 1.0
19
+ sample: true # <------------------------------ check this. you should comment out this when you don't want to use it.
20
+
21
+ # recommend to use default settings (we wil release codes with SD-VAE later)
22
+ vae:
23
+ model_name: 'vmae'
24
+ downsample_ratio: 8
25
+ weight_path: 'pretrain_weight/vmae_f8d16.pth'
26
+
27
+ # recommend to use default settings
28
+ model:
29
+ model_type: LightningDiT-B/1
30
+ use_qknorm: false # no qk normalizing in celeba
31
+ use_swiglu: true
32
+ use_rope: true
33
+ use_rmsnorm: true
34
+ wo_shift: false
35
+ in_chans: 16
36
+
37
+ # recommend to use default settings
38
+ train:
39
+ max_steps: 60000
40
+ global_batch_size: 1024 # 256 ok
41
+ global_seed: 1
42
+ output_dir: 'output'
43
+ exp_name: 'celeba_hq/lightningdit_b_vmae_f8d16' # <---- experiment name, set as you like
44
+ ckpt: null
45
+ log_every: 100
46
+ ckpt_every: 20000
47
+ use_checkpoint: false
48
+ gradient_accumulation_steps: 1
49
+ optimizer:
50
+ lr: 0.0002
51
+ beta2: 0.95
52
+ # max_grad_norm: 1.0
53
+ # recommend to use default settings
54
+ transport:
55
+ path_type: Linear
56
+ prediction: velocity
57
+ loss_weight: null
58
+ sample_eps: null
59
+ train_eps: null
60
+ use_cosine_loss: false
61
+ use_lognorm: true
62
+
63
+ # recommend to use default settings
64
+ sample:
65
+ mode: ODE
66
+ sampling_method: euler
67
+ atol: 0.000001
68
+ rtol: 0.001
69
+ reverse: false
70
+ likelihood: false
71
+ num_sampling_steps: 250
72
+ cfg_scale: 0 # <---- cfg scale, for 800 epoch performance with FID=1.35 cfg_scale=6.7
73
+ # for 64 epoch performance with FID=2.11 cfg_scale=10.0
74
+ # you may find we use a large cfg_scale, this is because of 2 reasons:
75
+ # we find a high-dimensional latent space requires a large cfg_scale to get good performance than f8d4 SD-VAE
76
+ # we enable cfg interval, which reduces the negative effects of large cfg on high-noise parts. This means larger cfg can be utilized
77
+
78
+ # recommend to use default settings
79
+ per_proc_batch_size: 128
80
+ fid_num: 50000
81
+ cfg_interval_start: 0.10
82
+ timestep_shift: 0.3
LDMAE/configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # we recommend to read config_details.yaml first.
2
+
3
+ ckpt_path: 'output/imagenet/lightningdit_b_vmae_f8d16/checkpoints/0100000.pt' # <---- download our pre-trained lightningdit or your own checkpoint
4
+
5
+ data:
6
+ origin_path: '/data/dataset/imagenet/1K_dataset'
7
+ data_path: '/data/dataset/imagenet/vmae_feature_imagenet_train_256' # <---- path to your data. it is generated by extract_features.py.
8
+ # if you just want to inference, download our latent_stats.pt and give its path here is ok.
9
+ fid_reference_file: 'tools/fid_statistics/VIRTUAL_imagenet256_labeled.npz' # <---- path to your fid_reference_file.npz. download it from ADM
10
+
11
+ # recommend to use default settings
12
+ image_size: 256
13
+ num_classes: 1000
14
+ num_workers: 8
15
+ latent_norm: true
16
+ latent_multiplier: 1.0
17
+ sample: true # <------------------------------ check this. you should comment out this when you don't want to use it.
18
+
19
+ # recommend to use default settings (we wil release codes with SD-VAE later)
20
+ vae:
21
+ model_name: 'vmae_f8d16'
22
+ downsample_ratio: 8
23
+ weight_path: 'pretrain_weight/vmaef8d16.pth'
24
+
25
+ # recommend to use default settings
26
+ model:
27
+ model_type: LightningDiT-B/1
28
+ use_qknorm: true
29
+ use_swiglu: true
30
+ use_rope: true
31
+ use_rmsnorm: true
32
+ wo_shift: false
33
+ in_chans: 16
34
+
35
+ # recommend to use default settings
36
+ train:
37
+ max_steps: 100000
38
+ global_batch_size: 256 # 256 ok
39
+ global_seed: 0
40
+ output_dir: 'output'
41
+ exp_name: 'imagenet/lightningdit_b_vmae_f8d16' # <---- experiment name, set as you like
42
+ ckpt: null
43
+ log_every: 100
44
+ ckpt_every: 20000
45
+ use_checkpoint: false
46
+ gradient_accumulation_steps: 1
47
+ optimizer:
48
+ lr: 0.0002
49
+ beta2: 0.95
50
+ # max_grad_norm: 1.0
51
+ # recommend to use default settings
52
+ transport:
53
+ path_type: Linear
54
+ prediction: velocity
55
+ loss_weight: null
56
+ sample_eps: null
57
+ train_eps: null
58
+ use_cosine_loss: false
59
+ use_lognorm: true
60
+
61
+ # recommend to use default settings
62
+ sample:
63
+ mode: ODE
64
+ sampling_method: euler
65
+ atol: 0.000001
66
+ rtol: 0.001
67
+ reverse: false
68
+ likelihood: false
69
+ num_sampling_steps: 250
70
+ cfg_scale: 10.0 # <---- cfg scale, for 800 epoch performance with FID=1.35 cfg_scale=6.7
71
+ # for 64 epoch performance with FID=2.11 cfg_scale=10.0
72
+ # you may find we use a large cfg_scale, this is because of 2 reasons:
73
+ # we find a high-dimensional latent space requires a large cfg_scale to get good performance than f8d4 SD-VAE
74
+ # we enable cfg interval, which reduces the negative effects of large cfg on high-noise parts. This means larger cfg can be utilized
75
+
76
+ # recommend to use default settings
77
+ per_proc_batch_size: 256
78
+ fid_num: 50000
79
+ cfg_interval_start: 0.10
80
+ timestep_shift: 0.3
LDMAE/datasets/__init__.py ADDED
File without changes
LDMAE/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes). View file
 
LDMAE/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (171 Bytes). View file
 
LDMAE/datasets/__pycache__/img_latent_dataset.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
LDMAE/datasets/__pycache__/img_latent_dataset.cpython-38.pyc ADDED
Binary file (3.33 kB). View file
 
LDMAE/datasets/img_latent_dataset.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ImageNet Latent Dataset with safetensors.
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ from glob import glob
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+
13
+ from safetensors import safe_open
14
+ from tokenizer.util.misc import DiagonalGaussianDistribution
15
+
16
class ImgLatentDataset(Dataset):
    """Dataset over precomputed image latents stored as .safetensors shards.

    Each shard is expected to contain the tensors 'latents', 'latents_flip'
    (latents of the horizontally flipped images) and 'labels'. Samples are
    addressed through a global index -> (shard file, index-in-shard) map.
    """

    def __init__(self, data_dir, latent_norm=True, latent_multiplier=1.0, sample=False):
        # data_dir: directory holding the *.safetensors shards (produced by extract_features.py)
        # latent_norm: standardize latents with per-channel mean/std computed over the dataset
        # latent_multiplier: extra scalar applied after (optional) normalization
        # sample: treat stored latents as Gaussian parameters and draw a sample from them
        self.data_dir = data_dir
        self.latent_norm = latent_norm
        self.latent_multiplier = latent_multiplier
        self.sample = sample

        self.files = sorted(glob(os.path.join(data_dir, "*.safetensors")))
        self.img_to_file_map = self.get_img_to_safefile_map()

        # Stats are cached on disk; computing them reads up to 10000 random samples.
        if latent_norm:
            self._latent_mean, self._latent_std = self.get_latent_stats()

    def get_img_to_safefile_map(self):
        """Build the mapping from a global sample index to its shard and local index."""
        img_to_file = {}
        for safe_file in self.files:
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                # 'labels' has one entry per image, so its first dim is the shard size.
                labels = f.get_slice('labels')
                labels_shape = labels.get_shape()
                num_imgs = labels_shape[0]
                cur_len = len(img_to_file)
                for i in range(num_imgs):
                    img_to_file[cur_len+i] = {
                        'safe_file': safe_file,
                        'idx_in_file': i
                    }
        return img_to_file

    def get_latent_stats(self):
        """Return (mean, std) latent statistics, computing and caching them on first use."""
        latent_stats_cache_file = os.path.join(self.data_dir, "latents_stats.pt")
        if not os.path.exists(latent_stats_cache_file):
            latent_stats = self.compute_latent_stats()
            torch.save(latent_stats, latent_stats_cache_file)
        else:
            latent_stats = torch.load(latent_stats_cache_file)
        return latent_stats['mean'], latent_stats['std']

    def compute_latent_stats(self):
        """Estimate per-channel latent mean/std from up to 10000 randomly chosen samples."""
        num_samples = min(10000, len(self.img_to_file_map))
        random_indices = np.random.choice(len(self.img_to_file_map), num_samples, replace=False)
        latents = []
        for idx in tqdm(random_indices):
            img_info = self.img_to_file_map[idx]
            safe_file, img_idx = img_info['safe_file'], img_info['idx_in_file']
            with safe_open(safe_file, framework="pt", device="cpu") as f:
                features = f.get_slice('latents')
                feature = features[img_idx:img_idx+1]
                if self.sample:
                    # Stored latents are distribution parameters in this mode; draw a sample.
                    feature = DiagonalGaussianDistribution(feature).sample()
                latents.append(feature)
        latents = torch.cat(latents, dim=0)
        # Reduce over batch and spatial dims; keepdim so stats broadcast over (N, C, H, W).
        mean = latents.mean(dim=[0, 2, 3], keepdim=True)
        std = latents.std(dim=[0, 2, 3], keepdim=True)
        latent_stats = {'mean': mean, 'std': std}
        print(latent_stats)
        return latent_stats

    def __len__(self):
        return len(self.img_to_file_map.keys())

    def __getitem__(self, idx):
        """Return (latent, label); uses the flipped latent 50% of the time as augmentation."""
        img_info = self.img_to_file_map[idx]
        safe_file, img_idx = img_info['safe_file'], img_info['idx_in_file']
        with safe_open(safe_file, framework="pt", device="cpu") as f:
            # Random horizontal-flip augmentation performed in latent space.
            tensor_key = "latents" if np.random.uniform(0, 1) > 0.5 else "latents_flip"
            features = f.get_slice(tensor_key)
            labels = f.get_slice('labels')
            feature = features[img_idx:img_idx+1]
            label = labels[img_idx:img_idx+1]
            if self.sample:
                feature = DiagonalGaussianDistribution(feature).sample()
            if self.latent_norm:
                feature = (feature - self._latent_mean) / self._latent_std
            feature = feature * self.latent_multiplier

        # remove the first batch dimension (=1) kept by get_slice()
        feature = feature.squeeze(0)
        label = label.squeeze(0)
        return feature, label
LDMAE/evaluate_tokenizer.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate tokenizer performance by computing reconstruction metrics.
3
+
4
+ Metrics include:
5
+ - rFID (Reconstruction FID)
6
+ - PSNR (Peak Signal-to-Noise Ratio)
7
+ - LPIPS (Learned Perceptual Image Patch Similarity)
8
+ - SSIM (Structural Similarity Index)
9
+
10
+ by Jingfeng Yao
11
+ from HUST-VL
12
+ """
13
+
14
+ import os
15
+ import torch, yaml
16
+ import numpy as np
17
+ from tqdm import tqdm
18
+ from PIL import Image
19
+ import torch.distributed as dist
20
+ from torch.nn.parallel import DistributedDataParallel as DDP
21
+ from omegaconf import OmegaConf
22
+ from torch.utils.data import DataLoader, DistributedSampler
23
+ from tools.calculate_fid import calculate_fid_given_paths
24
+ from concurrent.futures import ThreadPoolExecutor, as_completed
25
+ from torchmetrics import StructuralSimilarityIndexMeasure
26
+ from models.lpips import LPIPS
27
+ from torchvision.datasets import ImageFolder
28
+ from torchvision import transforms
29
+ from diffusers.models import AutoencoderKL
30
+ from tokenizer.sdvae import Diffusers_AutoencoderKL
31
+ from tokenizer import models_mae
32
+
33
def load_config(config_path):
    """Parse the YAML file at *config_path* and return it as a dict."""
    with open(config_path, "r") as stream:
        return yaml.safe_load(stream)
37
+
38
def print_with_prefix(content, prefix='Tokenizer Evaluation', rank=0):
    """Print *content* behind a blue, bracketed prefix — on rank 0 only."""
    if rank != 0:
        return
    print(f"\033[34m[{prefix}]\033[0m {content}")
41
+
42
def save_image(image, filename):
    """Write an HWC uint8 array to *filename* as an image file."""
    pil_image = Image.fromarray(image)
    pil_image.save(filename)
44
+
45
def evaluate_tokenizer(args, config_path, model_type, data_path, output_path):
    """Reconstruct a validation set through the tokenizer and report rFID / PSNR / LPIPS / SSIM.

    Runs one process per GPU under torch.distributed. Reference images and
    reconstructions are written to disk as PNGs; rFID and PSNR are computed
    from the two folders on rank 0, while LPIPS/SSIM are averaged across ranks.
    """
    # Initialize distributed training
    dist.init_process_group(backend='nccl')
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}')
    train_config = load_config(config_path)
    # NOTE: the model_type argument is overridden by the config's vae.model_name prefix.
    model_type = train_config['vae']['model_name'].split("_")[0]

    if local_rank == 0:
        print_with_prefix(f"Loading model... {model_type.upper()} {args.epsilon}")

    if train_config['vae']['model_name'].split("_")[0] == 'vmae':
        # MAE-style tokenizer checkpoint (project-local models_mae architecture).
        chkpt = train_config['vae']['weight_path']
        arch = 'mae_for_ldmae_f8d16_prev'
        model = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=train_config['data']['image_size'])
        checkpoint = torch.load(chkpt, map_location='cpu')
        model = model.to(device).eval()
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    elif train_config['vae']['model_name'].split("_")[0] in ['ae','dae','vae','sdv3']:
        # Diffusers-style KL autoencoder (f8, 16 latent channels).
        model = Diffusers_AutoencoderKL(
            img_size=train_config['data']['image_size'],
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        ).to(device).eval()
        chkpt_dir = train_config['vae']['weight_path']
        checkpoint = torch.load(chkpt_dir, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)
    else:
        # NOTE(review): bare `raise` with no active exception produces
        # RuntimeError("No active exception to re-raise"); a ValueError with
        # a message naming the unsupported model would be clearer.
        raise
    print(msg)
    # Image preprocessing
    # NOTE(review): Resize/CenterCrop run after ToTensor here, i.e. on tensors.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create dataset and dataloader (each rank sees its own shard of the data)
    dataset = ImageFolder(root=data_path, transform=transform)
    distributed_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=local_rank)
    val_dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=4,
        sampler=distributed_sampler
    )

    # Latent statistics from the extraction run; used to scale the noise perturbation.
    if 'sample' in train_config['data']:
        train_config['data']['data_path'] += '_sample'
    latent_stats_cache_file = os.path.join(train_config['data']['data_path'], 'latents_stats.pt')
    latent_stats = torch.load(latent_stats_cache_file)
    latent_mean, latent_std = latent_stats['mean'], latent_stats['std']

    latent_mean = latent_mean.clone().detach().to(device)
    latent_std = latent_std.clone().detach().to(device)


    # Setup output directories
    folder_name = f"{model_type}_{args.epsilon}"

    save_dir = os.path.join(output_path, folder_name, 'decoded_images')
    ref_path = os.path.join(output_path, 'ref_images')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(ref_path, exist_ok=True)

    if local_rank == 0:
        print_with_prefix(f"Output dir: {save_dir}")
        print_with_prefix(f"Reference dir: {ref_path}")

    # Save reference images if needed (skipped once 50000 PNGs already exist)
    ref_png_files = [f for f in os.listdir(ref_path) if f.endswith('.png')]
    if len(ref_png_files) < 50000:
        total_samples = 0
        for batch in val_dataloader:
            images = batch[0].to(device)
            for j in range(images.size(0)):
                # Undo the [-1, 1] normalization and write as uint8 HWC.
                img = torch.clamp(127.5 * images[j] + 128.0, 0, 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
                Image.fromarray(img).save(os.path.join(ref_path, f"ref_image_rank_{local_rank}_{total_samples}.png"))
                total_samples += 1
                if total_samples % 100 == 0 and local_rank == 0:
                    print_with_prefix(f"Rank {local_rank}, Saved {total_samples} reference images")
        dist.barrier()

    # Initialize metrics
    lpips_values = []
    ssim_values = []
    lpips = LPIPS().to(device).eval()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(device)

    # Generate reconstructions and compute metrics
    if local_rank == 0:
        print_with_prefix("Generating reconstructions...")
    all_indices = 0
    # NOTE(review): if save_dir already holds >= 50000 files, the metric lists
    # stay empty and the means below become NaN — confirm intended behavior.
    if len(os.listdir(save_dir)) < 50000:
        for batch in val_dataloader:
            images = batch[0].to(device)
            latents = encode_images(model, images)
            # Latent-robustness experiment: add epsilon-scaled Gaussian noise
            # proportional to the dataset latent std.
            epsilon = args.epsilon * torch.randn_like(latents)
            latents = latents + epsilon * latent_std

            with torch.no_grad():
                decoded_images_tensor = model.decode(latents).sample
                decoded_images = torch.clamp(127.5 * decoded_images_tensor + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

            # Compute metrics
            lpips_values.append(lpips(decoded_images_tensor, images).mean())
            ssim_values.append(ssim_metric(decoded_images_tensor, images))

            # Save reconstructions
            for i, img in enumerate(decoded_images):
                save_image(img, os.path.join(save_dir, f"decoded_image_rank_{local_rank}_{all_indices + i}.png"))
                if (all_indices + i) % 100 == 0 and local_rank == 0:
                    print_with_prefix(f"Rank {local_rank}, Processed {all_indices + i} images")
            all_indices += len(decoded_images)
    dist.barrier()

    # Aggregate metrics across GPUs
    lpips_values = torch.tensor(lpips_values).to(device)
    ssim_values = torch.tensor(ssim_values).to(device)
    dist.all_reduce(lpips_values, op=dist.ReduceOp.AVG)
    dist.all_reduce(ssim_values, op=dist.ReduceOp.AVG)

    avg_lpips = lpips_values.mean().item()
    avg_ssim = ssim_values.mean().item()

    if local_rank == 0:
        # Calculate FID (folder-to-folder, rank 0 only)
        print_with_prefix("Computing rFID...")
        fid = calculate_fid_given_paths([ref_path, save_dir], batch_size=50, dims=2048, device=device, num_workers=16)

        # Calculate PSNR
        print_with_prefix("Computing PSNR...")
        psnr_values = calculate_psnr_between_folders(ref_path, save_dir)
        avg_psnr = sum(psnr_values) / len(psnr_values)

        # Print final results
        print_with_prefix(f"Final Metrics:")
        print_with_prefix(f"rFID: {fid:.3f}")
        print_with_prefix(f"PSNR: {avg_psnr:.3f}")
        print_with_prefix(f"LPIPS: {avg_lpips:.3f}")
        print_with_prefix(f"SSIM: {avg_ssim:.3f}")
    dist.barrier()
    dist.destroy_process_group()
212
+
213
def encode_images(model, images):
    """Encode a batch of images and return the posterior mode as float32 latents."""
    with torch.no_grad():
        posterior = model.encode(images).latent_dist
        latents = posterior.mode()
    return latents.to(torch.float32)
217
+
218
def decode_to_images(model, z):
    """Decode latents *z* and return uint8 HWC image arrays on the CPU."""
    with torch.no_grad():
        decoded = model.decode(z)
        # Scale to the displayable [0, 255] range and move channels last.
        scaled = torch.clamp(127.5 * decoded + 128.0, 0, 255)
        images = scaled.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
    return images
223
+
224
def calculate_psnr(original, processed):
    """PSNR in dB between two tensors, assuming a [0, 255] value range."""
    mse = ((original - processed) ** 2).mean()
    return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
227
+
228
def load_image(image_path):
    """Read an image file and return a float32 CHW tensor of raw pixel values."""
    rgb = Image.open(image_path).convert('RGB')
    chw = np.array(rgb).transpose(2, 0, 1)
    return torch.tensor(chw, dtype=torch.float32)
231
+
232
def calculate_psnr_for_pair(original_path, processed_path):
    """PSNR between the two image files at the given paths."""
    original = load_image(original_path)
    processed = load_image(processed_path)
    return calculate_psnr(original, processed)
234
+
235
def calculate_psnr_between_folders(original_folder, processed_folder):
    """Compute PSNR for every aligned pair of images in two folders.

    Files are paired by sorted filename order. Returns an empty list (after
    printing a warning) when the folders hold different numbers of files.
    The result order follows completion order, not filename order.
    """
    originals = sorted(os.listdir(original_folder))
    processed = sorted(os.listdir(processed_folder))

    if len(originals) != len(processed):
        print("Warning: Mismatched number of images in folders")
        return []

    pairs = [
        (os.path.join(original_folder, a), os.path.join(processed_folder, b))
        for a, b in zip(originals, processed)
    ]
    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(calculate_psnr_for_pair, a, b) for a, b in pairs]
        return [fut.result() for fut in as_completed(pending)]
251
+
252
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the distributed tokenizer evaluation.
    import argparse
    parser = argparse.ArgumentParser()
    # YAML config describing the tokenizer and dataset paths.
    parser.add_argument('--config_path', type=str, default='tokenizer/configs/vavae_f16d32.yaml')
    # NOTE(review): model_type is overridden inside evaluate_tokenizer from the config.
    parser.add_argument('--model_type', type=str, default='vavae')
    # Folder of validation images (ImageFolder layout).
    parser.add_argument('--data_path', type=str, default='/data/dataset/imagenet/1K_dataset/val')
    # Where reference images, reconstructions and metrics are written.
    parser.add_argument('--output_path', type=str, default='./rfid')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epsilon', type=float, default=0, help="Noise pertubation ratio for latent robustness experiment.")
    args = parser.parse_args()
    evaluate_tokenizer(args, config_path=args.config_path, model_type=args.model_type, data_path=args.data_path, output_path=args.output_path)
LDMAE/extract_features.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.backends.cuda.matmul.allow_tf32 = True
3
+ torch.backends.cudnn.allow_tf32 = True
4
+ import torch.distributed as dist
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data.distributed import DistributedSampler
7
+ from torchvision.datasets import ImageFolder
8
+ import argparse
9
+ import os, yaml
10
+ from safetensors.torch import save_file
11
+ from datetime import datetime
12
+ from datasets.img_latent_dataset import ImgLatentDataset
13
+ from tokenizer import models_mae
14
+ from tokenizer.sdvae import Diffusers_AutoencoderKL
15
+
16
def load_config(config_path):
    """Read the YAML file at *config_path* into a plain dict."""
    with open(config_path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed
20
+
21
def main(args, train_config):
    """
    Run a tokenizer over the full dataset and save its latent features.

    For every image, the latent of the original and of the horizontally
    flipped version are extracted (via two lockstep dataloaders) and written
    to .safetensors shards of roughly 10000 images each. After all ranks
    finish, rank 0 instantiates ImgLatentDataset once so the latent mean/std
    statistics are computed and cached next to the shards.
    """
    assert torch.cuda.is_available(), "Extract features currently requires at least one GPU."

    # Setup DDP; fall back to a single local process if initialization fails.
    try:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        device = rank % torch.cuda.device_count()
        world_size = dist.get_world_size()
        seed = args.seed + rank
        if rank == 0:
            print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
    except Exception:
        # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit still propagate.
        print("Failed to initialize DDP. Running in local mode.")
        rank = 0
        device = 0
        world_size = 1
        seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    model_name = train_config['vae']['model_name'].split("_")[0]
    output_path = os.path.dirname(train_config['data']['origin_path'])
    dataset_name = train_config['data']['name']  # NOTE(review): some configs omit 'name' — confirm

    # Setup feature folders:
    output_dir = os.path.join(output_path, f'{model_name}_feature_{dataset_name}_{args.data_split}_{args.image_size}')
    if 'sample' in train_config['data']:
        # In "sample" mode the shards hold raw distribution parameters, not modes.
        output_dir += '_sample'
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        print(model_name)

    # Create model:
    if model_name == 'vmae':
        # MAE-style tokenizer checkpoint (project-local models_mae architecture).
        arch = 'mae_for_ldmae_f8d16_prev'
        chkpt = train_config['vae']['weight_path']
        tokenizer = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=args.image_size)
        checkpoint = torch.load(chkpt, map_location='cpu')
        tokenizer = tokenizer.to(device).eval()
        msg = tokenizer.load_state_dict(checkpoint['model'], strict=False)
        if rank == 0:
            print(model_name, msg)
    elif model_name in ['ae', 'dae', 'vae', 'sdv3']:
        # Diffusers-style KL autoencoder (f8, 16 latent channels).
        tokenizer = Diffusers_AutoencoderKL(
            img_size=args.image_size,
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            latent_channels=16,
            norm_num_groups=32,
            act_fn="silu",
            block_out_channels=(128, 256, 512, 512),
            force_upcast=False,
            use_quant_conv=False,
            use_post_quant_conv=False,
            down_block_types=(
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            up_block_types=(
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ),
        ).to(device).eval()
        chkpt = train_config['vae']['weight_path']
        checkpoint = torch.load(chkpt, map_location='cpu')
        msg = tokenizer.load_state_dict(checkpoint['model'], strict=False)
        if rank == 0:
            print(model_name, msg)
    else:
        # Bug fix: the original `raise("")` raised TypeError("exceptions must
        # derive from BaseException") instead of a meaningful error.
        raise ValueError(f"Unsupported vae.model_name: {train_config['vae']['model_name']!r}")

    print(f"{device} GPU - Model loaded")
    # Setup data: one loader without and one with horizontal flip, iterated in lockstep.
    data_path = train_config['data']['origin_path']
    datasets = [
        ImageFolder(os.path.join(data_path, args.data_split), transform=tokenizer.img_transform(p_hflip=0.0, img_size=args.image_size)),
        ImageFolder(os.path.join(data_path, args.data_split), transform=tokenizer.img_transform(p_hflip=1.0, img_size=args.image_size))
    ]
    samplers = [
        DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            seed=args.seed
        ) for dataset in datasets
    ]  # Maybe gray scale files are dropped. Need to be fixed.
    loaders = [
        DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        ) for dataset, sampler in zip(datasets, samplers)
    ]
    total_data_in_loop = len(loaders[0].dataset)
    if rank == 0:
        print(f"Total data in one loop: {total_data_in_loop}")

    run_images = 0
    saved_files = 0
    latents = []
    latents_flip = []
    labels = []
    for batch_idx, batch_data in enumerate(zip(*loaders)):
        run_images += batch_data[0][0].shape[0]
        if run_images % 100 == 0 and rank == 0:
            print(f'{datetime.now()} processing {run_images} of {total_data_in_loop} images')

        for loader_idx, data in enumerate(batch_data):
            x = data[0].to(device)
            y = data[1]  # (N,)
            with torch.no_grad():
                if 'sample' in train_config['data']:
                    # Keep the raw distribution parameters instead of the mode.
                    z = tokenizer._encode(x)
                else:
                    z = tokenizer.encode(x).latent_dist.mode().detach().cpu()  # (N, C, H, W)

            if batch_idx == 0 and rank == 0:
                print('latent shape', z.shape, 'dtype', z.dtype)

            if loader_idx == 0:
                # Unflipped loader also provides the labels.
                latents.append(z)
                labels.append(y)
            else:
                latents_flip.append(z)

        # Flush a shard once roughly 10000 images have been accumulated.
        if len(latents) == 10000 // args.batch_size:
            latents = torch.cat(latents, dim=0)
            latents_flip = torch.cat(latents_flip, dim=0)
            labels = torch.cat(labels, dim=0)
            save_dict = {
                'latents': latents,
                'latents_flip': latents_flip,
                'labels': labels
            }
            for key in save_dict:
                if rank == 0:
                    print(key, save_dict[key].shape)
            save_dict = {key: tensor.contiguous().cpu() for key, tensor in save_dict.items()}
            save_filename = os.path.join(output_dir, f'latents_rank{rank:02d}_shard{saved_files:03d}.safetensors')
            save_file(
                save_dict,
                save_filename,
                metadata={'total_size': f'{latents.shape[0]}', 'dtype': f'{latents.dtype}', 'device': f'{latents.device}'}
            )
            if rank == 0:
                print(f'Saved {save_filename}')

            latents = []
            latents_flip = []
            labels = []
            saved_files += 1

    # save remainder latents that are fewer than 10000 images
    if len(latents) > 0:
        latents = torch.cat(latents, dim=0)
        latents_flip = torch.cat(latents_flip, dim=0)
        labels = torch.cat(labels, dim=0)
        save_dict = {
            'latents': latents,
            'latents_flip': latents_flip,
            'labels': labels
        }
        for key in save_dict:
            if rank == 0:
                print(key, save_dict[key].shape)

        save_dict = {key: tensor.contiguous().cpu() for key, tensor in save_dict.items()}
        save_filename = os.path.join(output_dir, f'latents_rank{rank:02d}_shard{saved_files:03d}.safetensors')
        save_file(
            save_dict,
            save_filename,
            metadata={'total_size': f'{latents.shape[0]}', 'dtype': f'{latents.dtype}', 'device': f'{latents.device}'}
        )
        if rank == 0:
            print(f'Saved {save_filename}')

    # Calculate latents stats: instantiating the dataset computes and caches them.
    # NOTE(review): dist.barrier() assumes DDP initialized — it would fail in the
    # local fallback mode above; confirm the fallback path is actually exercised.
    dist.barrier()
    if rank == 0:
        dataset = ImgLatentDataset(output_dir, latent_norm=True, sample=train_config['data']['sample'] if 'sample' in train_config['data'] else False,)
    dist.barrier()
    dist.destroy_process_group()
220
+
221
+
222
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the YAML config, and launch extraction.
    parser = argparse.ArgumentParser()
    # parser.add_argument("--data_path", type=str, default='/path/to/your/data')
    # Dataset split subfolder under data.origin_path (e.g. 'train' or 'val').
    parser.add_argument("--data_split", type=str, default='train')
    # NOTE(review): --output_path appears unused by main(), which derives the
    # output folder from the config's data.origin_path instead.
    parser.add_argument("--output_path", type=str, default="/data/dataset/imagenet/")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=8)
    # YAML config with the vae/data sections read by main().
    parser.add_argument('--config', type=str, default='configs/debug.yaml')
    args = parser.parse_args()

    train_config = load_config(args.config)
    main(args, train_config)
LDMAE/inference.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sampling Scripts of LightningDiT.
3
+
4
+ by Maple (Jingfeng Yao) from HUST-VL
5
+ """
6
+
7
+ import os, math, json, pickle, logging, argparse, yaml, torch, numpy as np
8
+ from time import time, strftime
9
+ from glob import glob
10
+ from copy import deepcopy
11
+ from collections import OrderedDict
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ import torch.distributed as dist
15
+ from accelerate import Accelerator
16
+ from torch.utils.data import DataLoader
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ import torchvision
20
+ # local imports
21
+ from tokenizer.vavae import VA_VAE
22
+ from tokenizer.sdvae import Diffusers_AutoencoderKL
23
+ from tokenizer import models_mae
24
+ import threading
25
+
26
+ from models.lightningdit import LightningDiT_models
27
+ from transport import create_transport, Sampler
28
+ from datasets.img_latent_dataset import ImgLatentDataset
29
+ from torchvision.utils import save_image
30
+
31
+ # sample function
32
def save_images_async(images, indices, save_dir):
    """Write each image to ``save_dir`` as a zero-padded six-digit PNG.

    Intended to be run from a worker thread so disk writes can overlap with
    sampling (hence the name); the function itself is synchronous.
    """
    for index, image in zip(indices, images):
        if isinstance(image, np.ndarray):
            # HWC uint8 array -> CHW float tensor in [0, 1], as save_image expects.
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        save_image(image, f"{save_dir}/{index:06d}.png")
39
+
40
def do_sample(train_config, accelerator, ckpt_path=None, cfg_scale=None, model=None, vae=None, demo_sample_mode=False):
    """
    Run sampling.

    Loads DiT weights from ``ckpt_path`` into ``model``, builds (or reuses) the
    tokenizer ``vae``, draws class-conditional latents with the configured ODE
    sampler, decodes them to images and writes them as PNGs.

    Args:
        train_config: parsed YAML config (model/vae/data/sample/transport sections).
        accelerator: ``accelerate.Accelerator`` providing device and rank info.
        ckpt_path: DiT checkpoint path; an 'ema' sub-dict is preferred when present.
        cfg_scale: classifier-free guidance scale; falls back to the config value.
        model: LightningDiT instance whose weights will be loaded in place.
        vae: optional pre-built tokenizer; constructed from the config when None.
        demo_sample_mode: when True, render a fixed 2x4 demo grid and return None.

    Returns:
        Path of the folder holding the sampled PNGs, or None in demo mode.
    """

    folder_name = f"{train_config['model']['model_type'].replace('/', '-')}-ckpt-{ckpt_path.split('/')[-1].split('.')[0]}-{train_config['sample']['sampling_method']}-{train_config['sample']['num_sampling_steps']}".lower()
    if cfg_scale is None:
        cfg_scale = train_config['sample']['cfg_scale']
    cfg_interval_start = train_config['sample']['cfg_interval_start'] if 'cfg_interval_start' in train_config['sample'] else 0
    timestep_shift = train_config['sample']['timestep_shift'] if 'timestep_shift' in train_config['sample'] else 0
    if cfg_scale > 1.0:
        folder_name += f"-interval{cfg_interval_start:.2f}"+f"-cfg{cfg_scale:.2f}"
    folder_name += f"-shift{timestep_shift:.2f}"

    if demo_sample_mode:
        # demo grids are rendered without interval guidance or timestep shifting
        cfg_interval_start = 0
        timestep_shift = 0
        # cfg_scale = 15

    sample_folder_dir = os.path.join(train_config['train']['output_dir'], train_config['train']['exp_name'], folder_name)
    if accelerator.process_index == 0:
        if not demo_sample_mode:
            print_with_prefix('Sample_folder_dir=', sample_folder_dir)
            print_with_prefix('ckpt_path=', ckpt_path)
            print_with_prefix('cfg_scale=', cfg_scale)
            print_with_prefix('cfg_interval_start=', cfg_interval_start)
            print_with_prefix('timestep_shift=', timestep_shift)
    if not demo_sample_mode:
        if not os.path.exists(sample_folder_dir):
            if accelerator.process_index == 0:
                os.makedirs(sample_folder_dir, exist_ok=True)
        else:
            # resume-friendly: skip when enough samples already exist on disk
            png_files = [f for f in os.listdir(sample_folder_dir) if f.endswith('.png')]
            png_count = len(png_files)
            if png_count > train_config['sample']['fid_num']:
                if accelerator.process_index == 0:
                    print_with_prefix(f"Found {png_count} PNG files in {sample_folder_dir}, skip sampling.")
                return sample_folder_dir

    torch.backends.cuda.matmul.allow_tf32 = True  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup accelerator:
    device = accelerator.device

    # Setup DDP: per-rank seed so every process draws different noise.
    seed = train_config['train']['global_seed'] * accelerator.num_processes + accelerator.process_index
    torch.manual_seed(seed)
    # torch.cuda.set_device(device)
    print_with_prefix(f"Starting rank={accelerator.local_process_index}, seed={seed}, world_size={accelerator.num_processes}.")
    rank = accelerator.local_process_index

    # Load model:
    if 'downsample_ratio' in train_config['vae']:
        downsample_ratio = train_config['vae']['downsample_ratio']
    else:
        downsample_ratio = 16
    latent_size = train_config['data']['image_size'] // downsample_ratio

    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]
    model.load_state_dict(checkpoint)
    model.eval()  # important!
    model.to(device)

    transport = create_transport(
        train_config['transport']['path_type'],
        train_config['transport']['prediction'],
        train_config['transport']['loss_weight'],
        train_config['transport']['train_eps'],
        train_config['transport']['sample_eps'],
        use_cosine_loss = train_config['transport']['use_cosine_loss'] if 'use_cosine_loss' in train_config['transport'] else False,
        use_lognorm = train_config['transport']['use_lognorm'] if 'use_lognorm' in train_config['transport'] else False,
    )  # default: velocity;
    sampler = Sampler(transport)
    mode = train_config['sample']['mode']
    if mode == "ODE":
        sample_fn = sampler.sample_ode(
            sampling_method=train_config['sample']['sampling_method'],
            num_steps=train_config['sample']['num_sampling_steps'],
            atol=train_config['sample']['atol'],
            rtol=train_config['sample']['rtol'],
            reverse=train_config['sample']['reverse'],
            timestep_shift=timestep_shift,
        )
    else:
        raise NotImplementedError(f"Sampling mode {mode} is not supported.")

    if vae is None:
        if train_config['vae']['model_name'].split("_")[0] == 'vmae':
            chkpt = train_config['vae']['weight_path']
            arch = 'mae_for_ldmae_f8d16_prev'
            vae = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=train_config['data']['image_size'])
            vae_checkpoint = torch.load(chkpt, map_location='cpu')
            vae = vae.to(device).eval()
            msg = vae.load_state_dict(vae_checkpoint['model'], strict=False)
        elif train_config['vae']['model_name'].split("_")[0] in ['ae','dae', 'vae','sdv3']:
            vae = Diffusers_AutoencoderKL(
                img_size=train_config['data']['image_size'],
                sample_size=128,
                in_channels=3,
                out_channels=3,
                layers_per_block=2,
                latent_channels=16,
                norm_num_groups=32,
                act_fn="silu",
                block_out_channels=(128, 256, 512, 512),
                force_upcast=False,
                use_quant_conv=False,
                use_post_quant_conv=False,
                down_block_types=(
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                ),
                up_block_types=(
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                ),
            ).to(device).eval()
            chkpt_dir = train_config['vae']['weight_path']
            vae_checkpoint = torch.load(chkpt_dir, map_location='cpu')
            msg = vae.load_state_dict(vae_checkpoint['model'], strict=False)
        else:
            # BUGFIX: was a bare `raise` with no active exception, which raises a
            # confusing RuntimeError; raise an explicit, descriptive error instead.
            raise NotImplementedError(f"Unsupported vae model_name: {train_config['vae']['model_name']}")
    if accelerator.process_index == 0:
        print_with_prefix(f'Model Loaded')

    using_cfg = cfg_scale > 1.0
    if using_cfg:
        if accelerator.process_index == 0:
            print_with_prefix('Using cfg:', using_cfg)

    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        if accelerator.process_index == 0 and not demo_sample_mode:
            print_with_prefix(f"Saving .png samples at {sample_folder_dir}")
    accelerator.wait_for_everyone()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = train_config['sample']['per_proc_batch_size']
    global_batch_size = n * accelerator.num_processes
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(train_config['sample']['fid_num'] / global_batch_size) * global_batch_size)
    if rank == 0:
        if accelerator.process_index == 0:
            print_with_prefix(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % accelerator.num_processes == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // accelerator.num_processes)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    done_iterations = int( int(num_samples // accelerator.num_processes) // n)
    pbar = range(iterations)
    if not demo_sample_mode:
        pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    if accelerator.process_index == 0:
        print_with_prefix("Using latent normalization")
    if 'sample' in train_config['data']:
        train_config['data']['data_path'] += '_sample'
    dataset = ImgLatentDataset(
        data_dir=train_config['data']['data_path'],
        latent_norm=train_config['data']['latent_norm'] if 'latent_norm' in train_config['data'] else False,
        latent_multiplier=train_config['data']['latent_multiplier'] if 'latent_multiplier' in train_config['data'] else 0.18215,
        sample=train_config['data']['sample'] if 'sample' in train_config['data'] else False,
    )
    # channel-wise stats used to undo training-time latent normalization
    latent_mean, latent_std = dataset.get_latent_stats()
    latent_multiplier = train_config['data']['latent_multiplier'] if 'latent_multiplier' in train_config['data'] else 0.18215
    # move to device
    latent_mean = latent_mean.clone().detach().to(device)
    latent_std = latent_std.clone().detach().to(device)

    if demo_sample_mode:
        if accelerator.process_index == 0:
            images = []
            if using_cfg:
                for label in tqdm([975, 3, 207, 387, 388, 88, 979, 279], desc="Generating Demo Samples"):
                    z = torch.randn(1, model.in_channels, latent_size, latent_size, device=device)
                    y = torch.tensor([label], device=device)
                    # duplicate the batch: first half conditional, second half null class
                    z = torch.cat([z, z], 0)
                    y_null = torch.tensor([1000] * 1, device=device)
                    y = torch.cat([y, y_null], 0)
                    model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=False, cfg_interval_start=cfg_interval_start)
                    model_fn = model.forward_with_cfg
                    samples = sample_fn(z, model_fn, **model_kwargs)[-1]
                    samples = (samples * latent_std) / latent_multiplier + latent_mean
                    samples = vae.decode_to_images(samples)
                    images.append(samples)
            else:
                for label in tqdm([0]*8, desc="Generating Demo Samples"):
                    z = torch.randn(1, model.in_channels, latent_size, latent_size, device=device)
                    y = torch.tensor([label], device=device)
                    model_kwargs = dict(y=y)
                    model_fn = model.forward
                    samples = sample_fn(z, model_fn, **model_kwargs)[-1]
                    samples = (samples * latent_std) / latent_multiplier + latent_mean
                    samples = vae.decode_to_images(samples)
                    images.append(samples)

            # Combine 8 images into a 2x4 grid
            os.makedirs('demo_images', exist_ok=True)
            # Stack all images into a large numpy array
            all_images = np.stack([img[0] for img in images])  # Take first image from each batch
            # Rearrange into 2x4 grid
            h, w = all_images.shape[1:3]
            grid = np.zeros((2 * h, 4 * w, 3), dtype=np.uint8)
            for idx, image in enumerate(all_images):
                i, j = divmod(idx, 4)  # Calculate position in 2x4 grid
                grid[i*h:(i+1)*h, j*w:(j+1)*w] = image

            # Save the combined image
            exp_name = train_config['train']['exp_name']
            ckpt_iter = train_config['ckpt_path'].split("/")[-1][:-3]
            Image.fromarray(grid).save(f'demo_images/{exp_name}_cfg{cfg_scale}_{ckpt_iter}_demo_samples.png')
        return None
    else:
        for i in pbar:
            # Sample inputs:
            z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
            # BUGFIX: key was misspelled 'trunaction', so the truncation branch either
            # never ran or crashed with KeyError when the misspelled key was present.
            if 'truncation' in train_config['sample']:
                truncation_bound = train_config['sample']['truncation']
                # resample out-of-bound noise entries (bounded number of retries)
                for _ in range(100):
                    invalid_mask = torch.abs(z) > truncation_bound
                    if not invalid_mask.any():
                        break
                    z[invalid_mask] = torch.randn_like(z[invalid_mask])
            y = torch.randint(0, train_config['data']['num_classes'], (n,), device=device)

            # Setup classifier-free guidance:
            if using_cfg:
                z = torch.cat([z, z], 0)
                y_null = torch.tensor([1000] * n, device=device)
                y = torch.cat([y, y_null], 0)
                model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=True, cfg_interval_start=cfg_interval_start)
                model_fn = model.forward_with_cfg
            else:
                model_kwargs = dict(y=y)
                model_fn = model.forward

            samples = sample_fn(z, model_fn, **model_kwargs)[-1]
            if using_cfg:
                samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

            samples = (samples * latent_std) / latent_multiplier + latent_mean
            samples = vae.decode_to_images(samples)

            # Save samples to disk as individual .png files
            # (renamed inner index to avoid shadowing the outer loop variable `i`)
            for j, sample in enumerate(samples):
                index = j * accelerator.num_processes + accelerator.process_index + total
                Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            total += global_batch_size
        accelerator.wait_for_everyone()

        return sample_folder_dir
302
+
303
+ # some utils
304
def print_with_prefix(*messages):
    """Print *messages* behind a blue, timestamped sampling prefix."""
    stamp = strftime('%Y-%m-%d %H:%M:%S')
    header = f"\033[34m[LightningDiT-Sampling {stamp}]\033[0m"
    body = ' '.join(str(m) for m in messages)
    print(f"{header}: {body}")
308
+
309
def load_config(config_path):
    """Parse the YAML file at *config_path* and return it as a dict."""
    with open(config_path, "r") as handle:
        return yaml.safe_load(handle)
313
+
314
if __name__ == "__main__":

    # read config
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/lightningdit_b_ldmvae_f16d16.yaml')
    parser.add_argument('--demo', action='store_true', default=False)
    args = parser.parse_args()
    accelerator = Accelerator()
    train_config = load_config(args.config)

    # get ckpt_dir
    assert 'ckpt_path' in train_config, "ckpt_path must be specified in config"
    if accelerator.process_index == 0:
        print_with_prefix('Using ckpt:', train_config['ckpt_path'])
    ckpt_dir = train_config['ckpt_path']

    # latent grid side = image side / tokenizer downsample ratio (defaults to 16)
    if 'downsample_ratio' in train_config['vae']:
        latent_size = train_config['data']['image_size'] // train_config['vae']['downsample_ratio']
    else:
        latent_size = train_config['data']['image_size'] // 16

    # get model
    # architectural flags default to False / 4 channels when absent from the config
    model = LightningDiT_models[train_config['model']['model_type']](
        input_size=latent_size,
        num_classes=train_config['data']['num_classes'],
        use_qknorm=train_config['model']['use_qknorm'],
        use_swiglu=train_config['model']['use_swiglu'] if 'use_swiglu' in train_config['model'] else False,
        use_rope=train_config['model']['use_rope'] if 'use_rope' in train_config['model'] else False,
        use_rmsnorm=train_config['model']['use_rmsnorm'] if 'use_rmsnorm' in train_config['model'] else False,
        wo_shift=train_config['model']['wo_shift'] if 'wo_shift' in train_config['model'] else False,
        in_channels=train_config['model']['in_chans'] if 'in_chans' in train_config['model'] else 4,
        learn_sigma=train_config['model']['learn_sigma'] if 'learn_sigma' in train_config['model'] else False,
        # unconditional (single-class) models disable label dropout entirely
        class_dropout_prob=0 if train_config['data']['num_classes'] == 1 else 0.1,
    )

    # naive sample
    sample_folder_dir = do_sample(train_config, accelerator, ckpt_path=ckpt_dir, model=model, demo_sample_mode=args.demo)

    if not args.demo:
        # calculate FID
        # Important: FID is only for reference, please use ADM evaluation for paper reporting
        if accelerator.process_index == 0:
            from tools.calculate_fid import calculate_fid_given_paths
            print_with_prefix('Calculating FID with {} number of samples'.format(train_config['sample']['fid_num']))
            assert 'fid_reference_file' in train_config['data'], "fid_reference_file must be specified in config"
            fid_reference_file = train_config['data']['fid_reference_file']
            fid = calculate_fid_given_paths(
                [fid_reference_file, sample_folder_dir],
                batch_size=50,
                dims=2048,
                device='cuda',
                num_workers=8,
                sp_len = train_config['sample']['fid_num']
            )
            print_with_prefix('fid=',fid)
LDMAE/models/__init__.py ADDED
File without changes
LDMAE/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes). View file
 
LDMAE/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (169 Bytes). View file
 
LDMAE/models/__pycache__/lightningdit.cpython-310.pyc ADDED
Binary file (15.9 kB). View file
 
LDMAE/models/__pycache__/lightningdit.cpython-38.pyc ADDED
Binary file (16 kB). View file
 
LDMAE/models/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (4.77 kB). View file
 
LDMAE/models/__pycache__/pos_embed.cpython-38.pyc ADDED
Binary file (4.76 kB). View file
 
LDMAE/models/__pycache__/rmsnorm.cpython-310.pyc ADDED
Binary file (16.4 kB). View file
 
LDMAE/models/__pycache__/rmsnorm.cpython-38.pyc ADDED
Binary file (16.5 kB). View file
 
LDMAE/models/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (2.16 kB). View file
 
LDMAE/models/__pycache__/swiglu_ffn.cpython-38.pyc ADDED
Binary file (2.07 kB). View file
 
LDMAE/models/lightningdit.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lightning DiT's codes are built from original DiT & SiT.
3
+ (https://github.com/facebookresearch/DiT; https://github.com/willisma/SiT)
4
+ It demonstrates that a advanced DiT together with advanced diffusion skills
5
+ could also achieve a very promising result with 1.35 FID on ImageNet 256 generation.
6
+
7
+ Enjoy everyone, DiT strikes back!
8
+
9
+ by Maple (Jingfeng Yao) from HUST-VL
10
+ """
11
+
12
+ import os
13
+ import math
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.utils.checkpoint import checkpoint
20
+
21
+ from timm.models.vision_transformer import PatchEmbed, Mlp
22
+ from models.swiglu_ffn import SwiGLUFFN
23
+ from models.pos_embed import VisionRotaryEmbeddingFast
24
+ from models.rmsnorm import RMSNorm
25
+
26
+ @torch.compile
27
+ def modulate(x, shift, scale):
28
+ if shift is None:
29
+ return x * (1 + scale.unsqueeze(1))
30
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+
32
class Attention(nn.Module):
    """
    Multi-head self-attention for LightningDiT blocks.

    Supports optional per-head-dim QK normalization (LayerNorm, or RMSNorm
    when `use_rmsnorm`), rotary position embeddings applied to q/k, and either
    the fused `scaled_dot_product_attention` kernel or a manual softmax path.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
        fused_attn: bool = True,
        use_rmsnorm: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn

        # RMSNorm, when requested, overrides whichever norm_layer was passed in.
        if use_rmsnorm:
            norm_layer = RMSNorm

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, rope=None) -> torch.Tensor:
        batch, tokens, channels = x.shape
        fused_qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = fused_qkv.permute(2, 0, 3, 1, 4).unbind(0)
        query = self.q_norm(query)
        key = self.k_norm(key)

        if rope is not None:
            # rotary embedding replaces absolute positions by rotating q/k
            query = rope(query)
            key = rope(key)

        if self.fused_attn:
            out = F.scaled_dot_product_attention(
                query, key, value,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            scores = (query * self.scale) @ key.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ value

        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
92
+
93
+
94
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar diffusion timesteps into vectors: sinusoidal frequency
    features followed by a two-layer SiLU MLP. Identical to the DiT embedder.
    """
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256) -> None:
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """
        Build sinusoidal embeddings for (possibly fractional) timesteps.

        Args:
            t: 1-D tensor of N timestep values.
            dim: output embedding width.
            max_period: controls the minimum embedding frequency.
        Returns:
            (N, dim) tensor of [cos | sin] features, zero-padded when dim is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)

        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

        if dim % 2:
            # pad odd widths with a single zero column
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)

        return embedding

    @torch.compile
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        freq_features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_features)
138
+
139
+
140
class LabelEmbedder(nn.Module):
    """
    Maps class labels to embedding vectors, with optional label dropout that
    swaps labels for a dedicated "null" class id (== num_classes) to enable
    classifier-free guidance. Behavior matches the original DiT embedder.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # one extra embedding row is reserved for the CFG null class when dropout is on
        needs_null_row = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + needs_null_row, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    @torch.compile
    def forward(self, labels, train, force_drop_ids=None):
        dropout_active = self.dropout_prob > 0
        if (train and dropout_active) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
170
+
171
class LightningDiTBlock(nn.Module):
    """
    Transformer block for LightningDiT with optional features:
    - ROPE
    - QKNorm
    - RMSNorm
    - SwiGLU
    - shift-free AdaLN
    Only a subset is enabled in the released models; see the paper for details.
    """
    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_qknorm=False,
        use_swiglu=False,
        use_rmsnorm=False,
        wo_shift=False,
        **block_kwargs
    ):
        super().__init__()

        # Pre-attention / pre-MLP norms: affine-free LayerNorm or RMSNorm.
        if use_rmsnorm:
            self.norm1 = RMSNorm(hidden_size)
            self.norm2 = RMSNorm(hidden_size)
        else:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        # Self-attention.
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=use_qknorm,
            use_rmsnorm=use_rmsnorm,
            **block_kwargs
        )

        # Feed-forward network: SwiGLU at 2/3 width, or a tanh-GELU MLP.
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if use_swiglu:
            # xformers' SwiGLU is not torch.compile-friendly, hence a local impl.
            self.mlp = SwiGLUFFN(hidden_size, int(2/3 * mlp_hidden_dim))
        else:
            self.mlp = Mlp(
                in_features=hidden_size,
                hidden_features=mlp_hidden_dim,
                act_layer=lambda: nn.GELU(approximate="tanh"),
                drop=0
            )

        # AdaLN conditioning head: 4 chunks without shift terms, 6 with them.
        modulation_chunks = 4 if wo_shift else 6
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, modulation_chunks * hidden_size, bias=True)
        )
        self.wo_shift = wo_shift

    @torch.compile
    def forward(self, x, c, feat_rope=None):
        if self.wo_shift:
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            shift_msa = shift_mlp = None
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)

        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), rope=feat_rope)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
251
+
252
class FinalLayer(nn.Module):
    """
    Output head of LightningDiT: an AdaLN-modulated norm followed by a linear
    projection from hidden_size to patch_size**2 * out_channels per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels, use_rmsnorm=False):
        super().__init__()
        if use_rmsnorm:
            self.norm_final = RMSNorm(hidden_size)
        else:
            self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    @torch.compile
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
273
+
274
+
275
+ class LightningDiT(nn.Module):
276
+ """
277
+ Diffusion model with a Transformer backbone.
278
+ """
279
    def __init__(
        self,
        input_size=32,            # latent grid side length (H == W)
        patch_size=2,             # tokens cover patch_size x patch_size latent cells
        in_channels=32,           # latent channels produced by the tokenizer
        hidden_size=1152,
        depth=28,                 # number of transformer blocks
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,   # label dropout for classifier-free guidance
        num_classes=1000,
        learn_sigma=False,        # when True, model also predicts a variance head
        use_qknorm=False,
        use_swiglu=False,
        use_rope=False,
        use_rmsnorm=False,
        wo_shift=False,           # shift-free AdaLN variant
        use_checkpoint=False,     # activation checkpointing in forward()
    ):
        """Build the LightningDiT backbone.

        Assembles patch/timestep/label embedders, `depth` LightningDiTBlocks
        and the final projection head, then applies the custom weight init.
        """
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # with learn_sigma the head emits mean and variance, doubling channels
        self.out_channels = in_channels if not learn_sigma else in_channels * 2
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.use_rope = use_rope
        self.use_rmsnorm = use_rmsnorm
        self.depth = depth
        self.hidden_size = hidden_size
        self.use_checkpoint = use_checkpoint
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # use rotary position encoding, borrow from EVA
        if self.use_rope:
            # rope rotates q/k per head, so its dim is half the per-head dim
            half_head_dim = hidden_size // num_heads // 2
            hw_seq_len = input_size // patch_size
            self.feat_rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=hw_seq_len,
            )
        else:
            self.feat_rope = None

        self.blocks = nn.ModuleList([
            LightningDiTBlock(hidden_size,
                              num_heads,
                              mlp_ratio=mlp_ratio,
                              use_qknorm=use_qknorm,
                              use_swiglu=use_swiglu,
                              use_rmsnorm=use_rmsnorm,
                              wo_shift=wo_shift,
                              ) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, use_rmsnorm=use_rmsnorm)
        self.initialize_weights()
339
+
340
    def initialize_weights(self):
        """Apply DiT-style initialization.

        Xavier for linear layers, sin-cos for the (frozen) positional table,
        small-normal for embedders, and zero-init for all AdaLN modulation and
        output projections so each block starts as an identity mapping.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in LightningDiT blocks:
        # (zero modulation => zero gates => blocks start as identity)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
375
+
376
+ def unpatchify(self, x):
377
+ """
378
+ x: (N, T, patch_size**2 * C)
379
+ imgs: (N, H, W, C)
380
+ """
381
+ c = self.out_channels
382
+ p = self.x_embedder.patch_size[0]
383
+ h = w = int(x.shape[1] ** 0.5)
384
+ assert h * w == x.shape[1]
385
+
386
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
387
+ x = torch.einsum('nhwpqc->nchpwq', x)
388
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
389
+ return imgs
390
+
391
+ def forward(self, x, t=None, y=None):
392
+ """
393
+ Forward pass of LightningDiT.
394
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
395
+ t: (N,) tensor of diffusion timesteps
396
+ y: (N,) tensor of class labels
397
+ use_checkpoint: boolean to toggle checkpointing
398
+ """
399
+
400
+ use_checkpoint = self.use_checkpoint
401
+
402
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
403
+ t = self.t_embedder(t) # (N, D)
404
+ y = self.y_embedder(y, self.training) # (N, D)
405
+ c = t + y # (N, D)
406
+
407
+ for block in self.blocks:
408
+ if use_checkpoint:
409
+ x = checkpoint(block, x, c, self.feat_rope, use_reentrant=True)
410
+ else:
411
+ x = block(x, c, self.feat_rope)
412
+
413
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
414
+ x = self.unpatchify(x) # (N, out_channels, H, W)
415
+
416
+ if self.learn_sigma:
417
+ x, _ = x.chunk(2, dim=1)
418
+ return x
419
+
420
+ def forward_with_cfg(self, x, t, y, cfg_scale, cfg_interval=None, cfg_interval_start=None):
421
+ """
422
+ Forward pass of LightningDiT, but also batches the unconditional forward pass for classifier-free guidance.
423
+ """
424
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
425
+ half = x[: len(x) // 2]
426
+ combined = torch.cat([half, half], dim=0)
427
+ model_out = self.forward(combined, t, y)
428
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
429
+ # three channels by default. The standard approach to cfg applies it to all channels.
430
+ # This can be done by uncommenting the following line and commenting-out the line following that.
431
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
432
+ eps, rest = model_out[:, :3], model_out[:, 3:]
433
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
434
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
435
+
436
+ if cfg_interval is True:
437
+ timestep = t[0]
438
+ if timestep < cfg_interval_start:
439
+ half_eps = cond_eps
440
+
441
+ eps = torch.cat([half_eps, half_eps], dim=0)
442
+ return torch.cat([eps, rest], dim=1)
443
+
444
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or
        [extra_tokens+grid_size*grid_size, embed_dim] (w/ or w/o extra tokens)
    """
    coords = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords, coords), axis=0)  # here w goes first
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed
460
+
461
+
462
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the two grid axes -> (H*W, D)."""
    assert embed_dim % 2 == 0
    # Half of the channels encode the h axis, half the w axis.
    halves = [get_1d_sincos_pos_embed_from_grid(embed_dim // 2, axis) for axis in (grid[0], grid[1])]
    return np.concatenate(halves, axis=1)
471
+
472
+
473
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder, identical to the Transformer sinusoid.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
492
+
493
+
494
+ #################################################################################
495
+ # LightningDiT Configs #
496
+ #################################################################################
497
+
498
def LightningDiT_XL_1(**kwargs):
    """XL backbone (depth 28, width 1152, 16 heads), patch size 1."""
    return LightningDiT(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)


def LightningDiT_XL_2(**kwargs):
    """XL backbone (depth 28, width 1152, 16 heads), patch size 2."""
    return LightningDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


def LightningDiT_L_2(**kwargs):
    """Large backbone (depth 24, width 1024, 16 heads), patch size 2."""
    return LightningDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)


def LightningDiT_B_1(**kwargs):
    """Base backbone (depth 12, width 768, 12 heads), patch size 1."""
    return LightningDiT(depth=12, hidden_size=768, patch_size=1, num_heads=12, **kwargs)


def LightningDiT_B_2(**kwargs):
    """Base backbone (depth 12, width 768, 12 heads), patch size 2."""
    return LightningDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)


def LightningDiT_1p0B_1(**kwargs):
    """~1.0B-parameter backbone (depth 24, width 1536, 24 heads), patch size 1."""
    return LightningDiT(depth=24, hidden_size=1536, patch_size=1, num_heads=24, **kwargs)


def LightningDiT_1p0B_2(**kwargs):
    """~1.0B-parameter backbone (depth 24, width 1536, 24 heads), patch size 2."""
    return LightningDiT(depth=24, hidden_size=1536, patch_size=2, num_heads=24, **kwargs)


def LightningDiT_1p6B_1(**kwargs):
    """~1.6B-parameter backbone (depth 28, width 1792, 28 heads), patch size 1."""
    return LightningDiT(depth=28, hidden_size=1792, patch_size=1, num_heads=28, **kwargs)


def LightningDiT_1p6B_2(**kwargs):
    """~1.6B-parameter backbone (depth 28, width 1792, 28 heads), patch size 2."""
    return LightningDiT(depth=28, hidden_size=1792, patch_size=2, num_heads=28, **kwargs)


# Registry keyed as '<name>/<patch_size>' for config-driven model selection.
LightningDiT_models = {
    'LightningDiT-B/1': LightningDiT_B_1, 'LightningDiT-B/2': LightningDiT_B_2,
    'LightningDiT-L/2': LightningDiT_L_2,
    'LightningDiT-XL/1': LightningDiT_XL_1, 'LightningDiT-XL/2': LightningDiT_XL_2,
    'LightningDiT-1p0B/1': LightningDiT_1p0B_1, 'LightningDiT-1p0B/2': LightningDiT_1p0B_2,
    'LightningDiT-1p6B/1': LightningDiT_1p6B_1, 'LightningDiT-1p6B/2': LightningDiT_1p6B_2,
}
LDMAE/models/lpips.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import hashlib
3
+ import requests
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+ import os
8
+ from tqdm import tqdm
9
+
10
# Remote checkpoint registry for the LPIPS VGG backbone: download URL,
# local filename, and expected MD5 digest, all keyed by model name.
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
15
+
16
+
17
+
18
def download(url, local_path, chunk_size=1024):
    """Stream *url* to *local_path* with a tqdm progress bar.

    Creates the destination directory if needed. The bar total comes from the
    Content-Length header (0 when absent).
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Fix: advance by bytes actually received — the final
                        # chunk is usually shorter than chunk_size, and the
                        # old `pbar.update(chunk_size)` overshot the total.
                        pbar.update(len(data))
28
+
29
+
30
def md5_hash(path):
    """Return the hex MD5 digest of the file at *path*."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
34
+
35
+
36
def get_ckpt_path(name, root, check=False):
    """Return the local checkpoint path for *name*, downloading if absent.

    With *check* set, an existing file whose MD5 mismatches is re-downloaded;
    after any download the digest is asserted against MD5_MAP.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    needs_fetch = not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name])
    if needs_fetch:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
45
+
46
+
47
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity metric.

    Frozen VGG16 features are channel-normalized at five depths; squared
    differences are weighted by learned 1x1 convs and spatially averaged.
    """

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature widths
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # The metric is fixed; nothing in it should ever be trained.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the learned linear weights into this instance."""
        ckpt = get_ckpt_path(name, "movqgan/modules/losses/lpips")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Construct a model and (re)load the pretrained linear weights."""
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # Bug fix: get_ckpt_path requires a cache-root argument; calling it
        # with only the name raised TypeError. Use the same root as
        # load_from_pretrained for consistency.
        ckpt = get_ckpt_path(name, "movqgan/modules/losses/lpips")
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        """Return the LPIPS distance between *input* and *target* batches."""
        in0_input, in1_input = self.scaling_layer(input), self.scaling_layer(target)
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        val = 0
        for kk in range(len(self.chns)):
            feat0 = normalize_tensor(outs0[kk])
            feat1 = normalize_tensor(outs1[kk])
            diff = (feat0 - feat1) ** 2
            # Learned per-channel weighting, then spatial mean per stage.
            val = val + spatial_average(lins[kk].model(diff), keepdim=True)
        return val
100
+
101
+
102
class ScalingLayer(nn.Module):
    """Shift/scale inputs into the range expected by the LPIPS VGG net."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        scale = torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        # convert imagenet normalized data to [-1, 1]
        return (inp - self.shift) / self.scale
115
+
116
+
117
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)
133
+
134
+
135
class vgg16(torch.nn.Module):
    """VGG16 feature extractor exposing the five relu stages used by LPIPS."""

    # torchvision `features` index ranges for relu1_2 ... relu5_3.
    _STAGES = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        features = models.vgg16(pretrained=pretrained).features
        self.N_slices = 5
        for i, (start, stop) in enumerate(self._STAGES, start=1):
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, "slice{}".format(i), stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
176
+
177
+
178
def normalize_tensor(x, eps=1e-10):
    """L2-normalize *x* along the channel dimension (dim 1)."""
    norm = torch.sum(x ** 2, dim=1, keepdim=True).sqrt()
    return x / (norm + eps)
181
+
182
+
183
def spatial_average(x, keepdim=True):
    """Mean over the spatial dimensions (H, W) of an NCHW tensor."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
LDMAE/models/pos_embed.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EVA-02: A Visual Representation for Neon Genesis
3
+ # Github source: https://github.com/baaivision/EVA/EVA02
4
+ # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # By Yuxin Fang
7
+ #
8
+ # Based on https://github.com/lucidrains/rotary-embedding-torch
9
+ # --------------------------------------------------------'
10
+
11
+ from math import pi
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from einops import rearrange, repeat
17
+
18
+
19
+
20
def broadcat(tensors, dim=-1):
    """Concatenate *tensors* along *dim*, broadcasting every other dimension.

    All tensors must share the same rank; each non-concat dimension may hold
    at most two distinct sizes (a real size plus a broadcastable 1).
    """
    ranks = {t.ndim for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    if dim < 0:
        dim += rank
    sizes_per_dim = list(zip(*(t.shape for t in tensors)))
    per_dim_targets = []
    for axis, sizes in enumerate(sizes_per_dim):
        if axis == dim:
            # The concat axis keeps each tensor's own extent.
            per_dim_targets.append(sizes)
            continue
        assert len(set(sizes)) <= 2, 'invalid dimensions for broadcastable concatentation'
        per_dim_targets.append((max(sizes),) * len(tensors))
    target_shapes = list(zip(*per_dim_targets))
    expanded = [t.expand(*shape) for t, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=dim)
35
+
36
+
37
+
38
def rotate_half(x):
    """Rotate adjacent channel pairs of the last dim: (x1, x2) -> (-x2, x1)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs.unbind(dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1)
    return rotated.reshape(*x.shape)
43
+
44
+
45
+
46
class VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding applied over a trailing rot_dim slice."""

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        freqs = self._build_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs)

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Rescale fine-tune positions onto the pretraining coordinate range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Per-axis angle table; each frequency is duplicated so it covers a
        # (cos, sin) channel pair ('n -> (n 2)' interleaving).
        per_axis = torch.einsum('..., f -> ... f', t, freqs)
        per_axis = torch.repeat_interleave(per_axis, 2, dim=-1)

        # Height angles broadcast over columns, width angles over rows.
        freqs_2d = broadcat((per_axis[:, None, :], per_axis[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs_2d.cos())
        self.register_buffer("freqs_sin", freqs_2d.sin())

    @staticmethod
    def _build_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs):
        # One frequency per rotated channel pair.
        if custom_freqs:
            return custom_freqs
        if freqs_for == 'lang':
            return 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        if freqs_for == 'pixel':
            return torch.linspace(1., max_freq / 2, dim // 2) * pi
        if freqs_for == 'constant':
            return torch.ones(num_freqs).float()
        raise ValueError(f'unknown modality {freqs_for}')

    def forward(self, t, start_index=0):
        """Rotate the [start_index, start_index+rot_dim) slice of *t*."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left = t[..., :start_index]
        t_mid = t[..., start_index:end_index]
        t_right = t[..., end_index:]
        t_mid = (t_mid * self.freqs_cos) + (rotate_half(t_mid) * self.freqs_sin)
        return torch.cat((t_left, t_mid, t_right), dim=-1)
93
+
94
+
95
+
96
class VisionRotaryEmbeddingFast(nn.Module):
    """Precomputed flat rotary tables; forward rotates the whole feature dim."""

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        positions = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Per-position angles, duplicated per (cos, sin) channel pair.
        table = torch.einsum('..., f -> ... f', positions, freqs)
        table = torch.repeat_interleave(table, 2, dim=-1)
        # Combine the two axes, then flatten the (H, W) grid into a sequence.
        table = broadcat((table[:, None, :], table[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", table.cos().view(-1, table.shape[-1]))
        self.register_buffer("freqs_sin", table.sin().view(-1, table.shape[-1]))

    def forward(self, t):
        # Standard RoPE rotation applied across the full last dimension.
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
LDMAE/models/rmsnorm.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+
8
+ import fairscale.nn.model_parallel.initialize as fs_init
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairscale.nn.model_parallel.layers import (
12
+ ColumnParallelLinear,
13
+ ParallelEmbedding,
14
+ RowParallelLinear,
15
+ )
16
+ from torch import nn
17
+
18
+
19
@dataclass
class ModelArgs:
    """Llama-style transformer hyperparameters."""

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None  # grouped-query attention; None => n_heads
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
32
+
33
+
34
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm (no mean subtraction), as used in Llama.

    Normalizes over the last dimension in float32 for stability, casts back
    to the input dtype, then scales by a learnable per-channel weight.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: size of the normalized (last) dimension.
            eps: numerical-stability floor added inside the rsqrt.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed with rsqrt.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Return RMS-normalized and scaled *x*."""
        return self._norm(x.float()).type_as(x) * self.weight
78
+
79
+
80
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the rotary-embedding table of unit complex numbers.

    Args:
        dim: head dimension; dim // 2 frequencies are generated.
        end: number of positions to precompute.
        theta: base of the geometric frequency ladder.

    Returns:
        complex64 tensor of shape (end, dim // 2) with entries
        exp(i * position * frequency).
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    # Unit-magnitude phasors: cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)  # complex64
105
+
106
+
107
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape a (seq, head_dim) frequency table to broadcast against *x*.

    Keeps dims 1 (sequence) and -1 (features) of *x* and inserts singleton
    dimensions everywhere else, so `freqs_cis * x` broadcasts elementwise.

    Raises:
        AssertionError: if freqs_cis does not match (x.shape[1], x.shape[-1])
            or *x* has fewer than two dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [size if axis in (1, ndim - 1) else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(*shape)
130
+
131
+
132
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by precomputed complex frequencies.

    Adjacent channel pairs are viewed as complex numbers, multiplied by the
    unit phasors in *freqs_cis* (broadcast to the batch layout), and
    flattened back to real channels; input dtypes are preserved.

    Returns:
        (rotated xq, rotated xk) with the same shapes/dtypes as the inputs.
    """
    def as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c, k_c = as_complex(xq), as_complex(xk)
    phases = reshape_for_broadcast(freqs_cis, q_c)
    q_rot = torch.view_as_real(q_c * phases).flatten(3)
    k_rot = torch.view_as_real(k_c * phases).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
162
+
163
+
164
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        # Fast path: grouped-query ratio of 1 needs no duplication.
        return x
    expanded = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
174
+
175
+
176
class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """Build the model-parallel projections and the static KV cache.

        Heads are split across model-parallel ranks; the cache stores keys
        and values for incremental decoding up to max_seq_len positions.
        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_local_kv_heads,
            self.head_dim,
        )
        self.cache_k = torch.zeros(cache_shape).cuda()
        self.cache_v = torch.zeros(cache_shape).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Attend over the cached prefix plus the new *seqlen* tokens.

        Args:
            x: (bsz, seqlen, dim) input activations.
            start_pos: write offset of these tokens inside the KV cache.
            freqs_cis: rotary phasors for these positions.
            mask: optional additive attention mask.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Keep the cache on the same device/dtype as the activations.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Full prefix, with KV heads duplicated when n_kv_heads < n_heads.
        keys = repeat_kv(self.cache_k[:bsz, : start_pos + seqlen], self.n_rep)
        values = repeat_kv(self.cache_v[:bsz, : start_pos + seqlen], self.n_rep)

        # Move to (bs, heads, seq, head_dim) layout for batched matmul.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, heads, seqlen, cache_len + seqlen)
        # Softmax in fp32 for numerical stability, then cast back.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
305
+
306
+
307
class FeedForward(nn.Module):
    """SwiGLU feed-forward block with model-parallel linear layers."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Args:
            dim: input/output width.
            hidden_dim: nominal hidden width before SwiGLU resizing.
            multiple_of: round the hidden width up to this multiple.
            ffn_dim_multiplier: optional extra scaling of the hidden width.
        """
        super().__init__()
        # SwiGLU uses 2/3 of the nominal width, optionally rescaled, then
        # rounded up to a multiple_of boundary for hardware efficiency.
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        # SwiGLU: silu-gated elementwise product, then down-projection.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
349
+
350
+
351
class TransformerBlock(nn.Module):
    """One pre-norm Transformer layer: self-attention followed by a SwiGLU
    feed-forward network, each wrapped in an RMSNorm and a residual add."""

    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Index of this layer within the stack.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Model dimension.
            head_dim (int): Per-head dimension (``dim // n_heads``).
            attention (Attention): Self-attention sub-module.
            feed_forward (FeedForward): SwiGLU feed-forward sub-module.
            layer_id (int): Index of this layer.
            attention_norm (RMSNorm): Pre-norm applied before attention.
            ffn_norm (RMSNorm): Pre-norm applied before the feed-forward.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.layer_id = layer_id
        self.attention = Attention(args)
        # Nominal FFN width is 4*dim; FeedForward applies the SwiGLU 2/3
        # correction and rounding internally.
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Run one attention + feed-forward step with residual connections.

        Args:
            x (torch.Tensor): Input activations.
            start_pos (int): Position offset used for attention KV caching.
            freqs_cis (torch.Tensor): Precomputed rotary-embedding frequencies.
            mask (torch.Tensor, optional): Additive attention mask, or None.

        Returns:
            torch.Tensor: Output activations, same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        h = x + attn_out
        return h + self.feed_forward(self.ffn_norm(h))
411
+
412
+
413
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model (Llama-style decoder-only stack).

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output logits.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies
                for rotary position embeddings.
        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    # Inference-only: gradients and autograd tracking are disabled for the
    # whole forward pass.
    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices of shape (bsz, seqlen).
            start_pos (int): Starting position for attention caching (number of
                tokens already present in the KV cache).

        Returns:
            torch.Tensor: Output logits (float) after applying the Transformer.
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        # Keep the rotary frequencies on the same device as the activations,
        # then slice out the rows for the positions being processed now.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # A mask is only needed when decoding more than one token at a time
        # (e.g. the prompt); single-token decode steps attend to everything
        # already in the cache.
        mask = None
        if seqlen > 1:
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        # Cast logits to float32 regardless of the compute dtype.
        output = self.output(h).float()
        return output
LDMAE/models/swiglu_ffn.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ from typing import Callable, Optional
9
+ import warnings
10
+
11
+ from torch import Tensor, nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class SwiGLUFFN(nn.Module):
16
+ def __init__(
17
+ self,
18
+ in_features: int,
19
+ hidden_features: Optional[int] = None,
20
+ out_features: Optional[int] = None,
21
+ act_layer: Callable[..., nn.Module] = None,
22
+ drop: float = 0.0,
23
+ bias: bool = True,
24
+ ) -> None:
25
+ super().__init__()
26
+ out_features = out_features or in_features
27
+ hidden_features = hidden_features or in_features
28
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
29
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
30
+
31
+ @torch.compile
32
+ def forward(self, x: Tensor) -> Tensor:
33
+ x12 = self.w12(x)
34
+ x1, x2 = x12.chunk(2, dim=-1)
35
+ hidden = F.silu(x1) * x2
36
+ return self.w3(hidden)
37
+
38
+
39
# xFormers ships a fused SwiGLU op; prefer it when available, falling back to
# the pure-PyTorch SwiGLUFFN when the package is missing or when the user has
# explicitly opted out via the XFORMERS_DISABLED environment variable.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None

XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
    except ImportError:
        pass

if not XFORMERS_AVAILABLE:
    # Fallback: alias the plain implementation under the same name so the
    # rest of the file can use ``SwiGLU`` unconditionally.
    SwiGLU = SwiGLUFFN
54
+
55
+
56
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is adjusted for the fused kernel.

    The requested ``hidden_features`` is shrunk to 2/3 (the usual SwiGLU
    parameter-count correction) and rounded up to the next multiple of 8
    before being passed to the underlying ``SwiGLU`` implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        # ``act_layer`` and ``drop`` are accepted for interface compatibility
        # but are not forwarded to the base class.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Apply the 2/3 correction, then round up to a multiple of 8 for
        # alignment with the fused implementation.
        reduced = int(hidden_features * 2 / 3)
        hidden_features = ((reduced + 7) // 8) * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
LDMAE/pretrain_weight/aef8d16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97ad3672641653fcd74106cd050dc8f5042089b8edc06e30cbcde642be239aa6
3
+ size 1006144522
LDMAE/pretrain_weight/daef8d16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbf39522a2f602df6f271b8b0ea0a73a7c23687fd54f98c5b60ef85289b15168
3
+ size 1006144522
LDMAE/pretrain_weight/sdv3f8d16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b367d60d708cb371261c005a44bd68f8d17dd211f8c771fb2b3802e51df2f8c
3
+ size 1098238157
LDMAE/pretrain_weight/vaef8d16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f0e12f87137cdb19bd8f461dc1c8d7c572628d79e03f070ebdcc7a802e610c6
3
+ size 1006144522
LDMAE/pretrain_weight/vmaef8d16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:441e7360993a03978e729dafc77432a372f066bf46a3c9610c7c33c4a0f09fc1
3
+ size 147225897
LDMAE/requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Core deep-learning stack
torch>=2.0.0
torchvision
accelerate
transformers
# Imaging and numerics
Pillow
numpy
scipy
# Training utilities and logging
tqdm
matplotlib
tensorboard
omegaconf
einops
timm
# Vision / evaluation helpers
opencv-python
scikit-learn
lpips
LDMAE/run_extract_feature.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Extract latent features with accelerate.
# Usage: run_extract_feature.sh <config.yaml>
CONFIG_PATH=$1

# Cluster topology — overridable via environment, single-node 8-GPU defaults.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1235}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

echo $CONFIG_PATH

# Total process count = GPUs per node x number of nodes.
NUM_PROCESSES=$(($GPUS_PER_NODE*$NNODES))

accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    extract_features.py \
    --config $CONFIG_PATH \
LDMAE/run_fast_inference.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Single-GPU demo inference launcher.
# Usage: run_fast_inference.sh <config.yaml>
CONFIG_PATH=$1

# Fixed single-node, single-GPU topology for the fast demo path.
GPUS_PER_NODE=1
NNODES=1
NODE_RANK=0
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1236}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

# Total process count = GPUs per node x number of nodes (here: 1).
NUM_PROCESSES=$(($GPUS_PER_NODE*$NNODES))

accelerate launch \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    inference.py \
    --config $CONFIG_PATH \
    --demo
LDMAE/run_inference.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Multi-GPU inference launcher via accelerate.
# Usage: run_inference.sh <config.yaml>
CONFIG_PATH=$1

# Cluster topology — overridable via environment, single-node 8-GPU defaults.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1237}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

# Total process count = GPUs per node x number of nodes.
NUM_PROCESSES=$(($GPUS_PER_NODE*$NNODES))

accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    inference.py \
    --config $CONFIG_PATH
LDMAE/run_robustness_test.sh ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Robustness evaluation sweep: run the tokenizer evaluation once without
# perturbation, then at increasing adversarial epsilon levels.
# CONFIG=$1

# Cluster topology — overridable via environment, single-node 8-GPU defaults.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1241}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

CONFIG=configs/imagenet/lightningdit_b_vmae_f8d16_cfg.yaml
NUM_PROCESSES=$(($GPUS_PER_NODE*$NNODES))

# Helper: launch evaluate_tokenizer.py with the shared accelerate settings,
# forwarding any extra per-run flags (e.g. --epsilon, --robust_exp).
run_eval() {
    accelerate launch \
        --config-file configs/accelerator/8gpu.yaml \
        --main_process_ip $MASTER_ADDR \
        --main_process_port $MASTER_PORT \
        --machine_rank $NODE_RANK \
        --num_processes $NUM_PROCESSES \
        --num_machines $NNODES \
        --mixed_precision $PRECISION \
        evaluate_tokenizer.py \
        --config $CONFIG \
        "$@"
}

# VMAE reconstruction (robustness-experiment baseline, no perturbation)
run_eval --robust_exp True

# Epsilon sweep.
# FIX: the original script invoked evaluate_tokenizer_mae.py for epsilon 0.2
# while every other run used evaluate_tokenizer.py; normalized to the same
# evaluation entry point for a consistent sweep.
run_eval --epsilon 0.01
run_eval --epsilon 0.05
run_eval --epsilon 0.1
run_eval --epsilon 0.2
run_eval --epsilon 0.3
LDMAE/run_train.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Distributed training launcher via accelerate.
# Usage: run_train.sh <config.yaml>
CONFIG_PATH=$1

# Cluster topology — overridable via environment, single-node 8-GPU defaults.
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-1235}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
PRECISION=${PRECISION:-bf16}

echo $CONFIG_PATH

# Total process count = GPUs per node x number of nodes.
NUM_PROCESSES=$(($GPUS_PER_NODE*$NNODES))

accelerate launch \
    --config-file configs/accelerator/8gpu.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank $NODE_RANK \
    --num_processes $NUM_PROCESSES \
    --num_machines $NNODES \
    --mixed_precision $PRECISION \
    train_accum.py \
    --config $CONFIG_PATH
LDMAE/tokenizer/__init__.py ADDED
File without changes
LDMAE/tokenizer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (174 Bytes). View file
 
LDMAE/tokenizer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (146 Bytes). View file
 
LDMAE/tokenizer/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
LDMAE/tokenizer/__pycache__/models_mae.cpython-310.pyc ADDED
Binary file (28.8 kB). View file
 
LDMAE/tokenizer/__pycache__/sdvae.cpython-310.pyc ADDED
Binary file (3.5 kB). View file
 
LDMAE/tokenizer/__pycache__/vavae.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
LDMAE/tokenizer/__pycache__/vavae.cpython-38.pyc ADDED
Binary file (4.35 kB). View file