| import argparse |
| import torch |
|
|
| from infinity.models.bsq_vae.flux_vqgan import AutoEncoder |
|
|
def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy matching entries of `state_dict` (under `prefix`) into `model`.

    Three matching strategies are tried for each source key:
      1. Direct key match against the model's state dict (optionally squeezing
         1x1 attention conv weights into linear weights when `use_linear`, or
         inflating 2D conv kernels to the model's 3D kernel depth when `expand`).
      2. Plain conv/attention keys remapped onto wrapper modules that expose
         the parameter as ``<name>.conv.weight`` / ``<name>.conv.bias``.
      3. Norm keys remapped onto wrapper modules exposing ``<name>.norm.*``.

    Consumed keys are removed from `state_dict` in place.

    Args:
        model: target ``nn.Module``.
        state_dict: source mapping of parameter name -> tensor; mutated in place.
        prefix: leading string stripped from each source key before matching.
        expand: if True, inflate 2D conv kernels to 3D by repeating along the
            new (temporal) dimension when shapes differ.
        use_linear: if True, load attention q/k/v/proj_out conv weights as
            linear weights (trailing singleton spatial dims squeezed away).

    Returns:
        Tuple ``(model, state_dict, loaded_keys)``: the mutated inputs plus the
        list of fully-qualified keys (including `prefix`) that were loaded.
    """
    delete_keys = []
    loaded_keys = []
    # model.state_dict() rebuilds its OrderedDict on every call; snapshot it
    # once. The returned tensors share storage with the model's parameters,
    # so copy_() below still updates the model in place.
    msd = model.state_dict()
    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        if _key in msd:
            # Strategy 1: direct name match.
            if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
                # 1x1 conv weight -> linear weight: drop trailing singleton dims.
                load_weights = state_dict[key].squeeze()
            elif _key.endswith(".conv.weight") and expand:
                if msd[_key].shape == state_dict[key].shape:
                    # 3D kernel already; load as-is.
                    load_weights = state_dict[key]
                else:
                    # Inflate 2D kernel to 3D by repeating along the depth dim.
                    _expand_dim = msd[_key].shape[2]
                    load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
            else:
                load_weights = state_dict[key]
            msd[_key].copy_(load_weights)
            delete_keys.append(key)
            loaded_keys.append(prefix + _key)
        # Strategy 2: remap plain conv/attention keys onto a wrapper module
        # that stores the parameter one level down as `.conv.*`.
        conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
        if any(k in _key for k in conv_list):
            if _key.endswith(".weight"):
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key and conv_key in msd:
                    if msd[conv_key].shape == state_dict[key].shape:
                        load_weights = state_dict[key]
                    else:
                        # Shape mismatch: inflate 2D kernel to 3D (see above).
                        _expand_dim = msd[conv_key].shape[2]
                        load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
                    msd[conv_key].copy_(load_weights)
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
            if _key.endswith(".bias"):
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key and conv_key in msd:
                    msd[conv_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
        # Strategy 3: remap norm keys onto a wrapper module exposing `.norm.*`.
        if "norm" in _key:
            if _key.endswith(".weight"):
                norm_key = _key.replace(".weight", ".norm.weight")
                if norm_key and norm_key in msd:
                    msd[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)
            if _key.endswith(".bias"):
                norm_key = _key.replace(".bias", ".norm.bias")
                if norm_key and norm_key in msd:
                    msd[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)

    # Dedupe (order-preserving) so a key consumed by more than one strategy
    # is not deleted twice, which would raise KeyError.
    for key in dict.fromkeys(delete_keys):
        del state_dict[key]

    return model, state_dict, loaded_keys
|
|
|
|
def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=None, decoder_ch_mult=None):
    """Build a BSQ-VAE ``AutoEncoder`` and optionally load checkpoint weights.

    Args:
        vqgan_ckpt: path to a checkpoint file, or an already-loaded state-dict
            mapping (any non-str value is treated as the state dict itself).
        schedule_mode: quantizer schedule mode forwarded to the model args.
        codebook_dim: BSQ codebook dimensionality.
        codebook_size: codebook size forwarded to the model args.
        test_mode: if True, switch the model to eval mode and freeze all
            parameters before returning.
        patch_size: spatial patch size for the tokenizer.
        encoder_ch_mult: encoder channel multipliers; defaults to [1, 2, 4, 4, 4].
        decoder_ch_mult: decoder channel multipliers; defaults to [1, 2, 4, 4, 4].

    Returns:
        The constructed (and possibly weight-loaded / frozen) ``AutoEncoder``.
    """
    # Avoid the mutable-default-argument pitfall: materialize defaults per call.
    if encoder_ch_mult is None:
        encoder_ch_mult = [1, 2, 4, 4, 4]
    if decoder_ch_mult is None:
        decoder_ch_mult = [1, 2, 4, 4, 4]
    # AutoEncoder expects a fully-populated argparse.Namespace; most fields
    # below are fixed training-time defaults that are unused at inference.
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type='image',
        save='./imagenet_val_bsq',
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device='cuda',
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir='',
        compile='no',
        ema='no',
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type='Adam',
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type='MultiScaleBSQ',
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version='v1',
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type='l1',
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type='group',
        disc_loss_type='hinge',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='flux',
        pretrained=None,
        pretrained_mode='full',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas='',
        username='',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[''],
        data_type=[''],
        dataset_list=['imagenet'],
        fps=-1,
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type='2d',
        cnn_version='v1',
        conv_in_out_2d='no',
        conv_inner_2d='no',
        res_conv_2d='no',
        cnn_attention='no',
        cnn_norm_axis='spatial',
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0)

    vae = AutoEncoder(args)
    # NOTE: the original read `vae.use_vae` into an unused `num_codes` local;
    # that dead code has been removed.
    if isinstance(vqgan_ckpt, str):
        # weights_only=True keeps torch.load from unpickling arbitrary objects.
        state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
    else:
        state_dict = args.vqgan_ckpt
    if state_dict:
        # Load either the EMA weights or the raw VAE weights from the checkpoint.
        if args.ema == "yes":
            vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["ema"], prefix="", expand=False)
        else:
            vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["vae"], prefix="", expand=False)
    if test_mode:
        vae.eval()
        # Freeze parameters (plain loop instead of a side-effect comprehension).
        for p in vae.parameters():
            p.requires_grad_(False)
    return vae
|
|
def vae_model_lora(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=None, decoder_ch_mult=None):
    """Build a BSQ-VAE ``AutoEncoder`` with LoRA adapters injected into the encoder.

    Same configuration as ``vae_model``, but instead of ``load_cnn`` it injects
    trainable LoRA layers into ``vae.encoder`` (via ``lora_diffusion``) before
    loading the checkpoint's ``'vae'`` state dict strictly.

    Args:
        vqgan_ckpt: path to a checkpoint file, or an already-loaded state-dict
            mapping (any non-str value is treated as the state dict itself).
        schedule_mode: quantizer schedule mode forwarded to the model args.
        codebook_dim: BSQ codebook dimensionality.
        codebook_size: codebook size forwarded to the model args.
        test_mode: if True, switch the model to eval mode and freeze all
            parameters before returning.
        patch_size: spatial patch size for the tokenizer.
        encoder_ch_mult: encoder channel multipliers; defaults to [1, 2, 4, 4, 4].
        decoder_ch_mult: decoder channel multipliers; defaults to [1, 2, 4, 4, 4].

    Returns:
        The constructed LoRA-injected (and possibly frozen) ``AutoEncoder``.
    """
    # Avoid the mutable-default-argument pitfall: materialize defaults per call.
    if encoder_ch_mult is None:
        encoder_ch_mult = [1, 2, 4, 4, 4]
    if decoder_ch_mult is None:
        decoder_ch_mult = [1, 2, 4, 4, 4]
    # AutoEncoder expects a fully-populated argparse.Namespace; most fields
    # below are fixed training-time defaults that are unused at inference.
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type='image',
        save='./imagenet_val_bsq',
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device='cuda',
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir='',
        compile='no',
        ema='no',
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type='Adam',
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type='MultiScaleBSQ',
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version='v1',
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type='l1',
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type='group',
        disc_loss_type='hinge',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='flux',
        pretrained=None,
        pretrained_mode='full',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas='',
        username='',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[''],
        data_type=[''],
        dataset_list=['imagenet'],
        fps=-1,
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type='2d',
        cnn_version='v1',
        conv_in_out_2d='no',
        conv_inner_2d='no',
        res_conv_2d='no',
        cnn_attention='no',
        cnn_norm_axis='spatial',
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0)

    vae = AutoEncoder(args)
    # NOTE: the original read `vae.use_vae` into an unused `num_codes` local;
    # that dead code has been removed.
    if isinstance(vqgan_ckpt, str):
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects — only load trusted checkpoints here. (Presumably required
        # because LoRA checkpoints contain non-tensor objects — TODO confirm.)
        state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=False)
    else:
        state_dict = args.vqgan_ckpt
    if state_dict:
        # Inject trainable LoRA adapters into the encoder before loading, so
        # the checkpoint's LoRA keys line up with the expanded module tree.
        from lora_diffusion import inject_trainable_lora_extended
        unet_lora_params_vae, train_names_vae = inject_trainable_lora_extended(vae.encoder)
        vae.load_state_dict(state_dict['vae'], strict=True)
    if test_mode:
        vae.eval()
        # Freeze parameters (plain loop instead of a side-effect comprehension).
        for p in vae.parameters():
            p.requires_grad_(False)
    return vae